diff options
author | Alexander Chernavin <achernavin@netgate.com> | 2022-08-04 08:11:57 +0000 |
---|---|---|
committer | Alexander Chernavin <achernavin@netgate.com> | 2022-08-09 15:55:45 +0000 |
commit | fee9853a4f5d9a180ef6309cc37bd4060d27a51e (patch) | |
tree | 09ed324ca250603af84f2994683765a78a2c4191 | |
parent | a6328e51e0c831ba3f0f4977f776491ac44eaec5 (diff) |
wireguard: add peers roaming support
Type: feature
With this change, peers are able to roam between different external
endpoints. Successfully authenticated handshake or data packet that is
received from a new endpoint will cause the peer's endpoint to be
updated accordingly.
Signed-off-by: Alexander Chernavin <achernavin@netgate.com>
Change-Id: Ib4eb7dfa3403f3fb9e8bbe19ba6237c4960c764c
-rw-r--r-- | src/plugins/wireguard/FEATURE.yaml | 2 | ||||
-rw-r--r-- | src/plugins/wireguard/README.rst | 1 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_cli.c | 8 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_input.c | 61 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_peer.c | 125 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_peer.h | 19 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_send.c | 20 | ||||
-rw-r--r-- | test/test_wireguard.py | 266 |
8 files changed, 434 insertions, 68 deletions
diff --git a/src/plugins/wireguard/FEATURE.yaml b/src/plugins/wireguard/FEATURE.yaml index 4c6946d2a33..5c0a588a484 100644 --- a/src/plugins/wireguard/FEATURE.yaml +++ b/src/plugins/wireguard/FEATURE.yaml @@ -7,5 +7,3 @@ features: description: "Wireguard protocol implementation" state: development properties: [API, CLI] -missing: - - Peers roaming between different external IPs diff --git a/src/plugins/wireguard/README.rst b/src/plugins/wireguard/README.rst index ead412519b7..35dd2c41382 100644 --- a/src/plugins/wireguard/README.rst +++ b/src/plugins/wireguard/README.rst @@ -77,4 +77,3 @@ Main next steps for improving this implementation ------------------------------------------------- 1. Use all benefits of VPP-engine. -2. Add peers roaming support diff --git a/src/plugins/wireguard/wireguard_cli.c b/src/plugins/wireguard/wireguard_cli.c index 214e6a5e2b4..5fa620507d6 100644 --- a/src/plugins/wireguard/wireguard_cli.c +++ b/src/plugins/wireguard/wireguard_cli.c @@ -165,7 +165,7 @@ wg_peer_add_command_fn (vlib_main_t * vm, u8 public_key[NOISE_PUBLIC_KEY_LEN + 1]; fib_prefix_t allowed_ip, *allowed_ips = NULL; ip_prefix_t pfx; - ip_address_t ip; + ip_address_t ip = ip_address_initializer; u32 portDst = 0, table_id = 0; u32 persistent_keepalive = 0; u32 tun_sw_if_index = ~0; @@ -213,6 +213,12 @@ wg_peer_add_command_fn (vlib_main_t * vm, } } + if (0 == vec_len (allowed_ips)) + { + error = clib_error_return (0, "Allowed IPs are not specified"); + goto done; + } + rv = wg_peer_add (tun_sw_if_index, public_key, table_id, &ip_addr_46 (&ip), allowed_ips, portDst, persistent_keepalive, &peer_index); diff --git a/src/plugins/wireguard/wireguard_input.c b/src/plugins/wireguard/wireguard_input.c index b85cdc610e4..22850b832b4 100644 --- a/src/plugins/wireguard/wireguard_input.c +++ b/src/plugins/wireguard/wireguard_input.c @@ -125,16 +125,6 @@ typedef enum WG_INPUT_N_NEXT, } wg_input_next_t; -/* static void */ -/* set_peer_address (wg_peer_t * peer, ip4_address_t ip4, u16 udp_port) */ -/* { */ -/* if (peer) */ -/* { */ -/* ip46_address_set_ip4 (&peer->dst.addr, &ip4); */ -/* peer->dst.port = udp_port; */ -/* } */ -/* } */ - static u8 is_ip4_header (u8 *data) { @@ -171,8 +161,8 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b, } udp_header_t *uhd = current_b_data - sizeof (udp_header_t); - u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port);; - u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);; + u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port); + u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port); message_header_t *header = current_b_data; @@ -269,7 +259,8 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b, return WG_INPUT_ERROR_PEER; } - // set_peer_address (peer, ip4_src, udp_src_port); + wg_peer_update_endpoint (rp->r_peer_idx, &src_ip, udp_src_port); + if (PREDICT_FALSE (!wg_send_handshake_response (vm, peer))) { vlib_node_increment_counter (vm, node_idx, @@ -318,7 +309,8 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b, return WG_INPUT_ERROR_PEER; } - // set_peer_address (peer, ip4_src, udp_src_port); + wg_peer_update_endpoint (peeri, &src_ip, udp_src_port); + if (noise_remote_begin_session (vm, &peer->remote)) { @@ -582,6 +574,26 @@ error: return ret; } +static_always_inline void +wg_find_outer_addr_port (vlib_buffer_t *b, ip46_address_t *addr, u16 *port, + u8 is_ip4) +{ + if (is_ip4) + { + ip4_udp_header_t *ip4_udp_hdr = + vlib_buffer_get_current (b) - sizeof (ip4_udp_header_t); + ip46_address_set_ip4 (addr, &ip4_udp_hdr->ip4.src_address); + *port = clib_net_to_host_u16 (ip4_udp_hdr->udp.src_port); + } + else + { + ip6_udp_header_t *ip6_udp_hdr = + vlib_buffer_get_current (b) - sizeof (ip6_udp_header_t); + ip46_address_set_ip6 (addr, &ip6_udp_hdr->ip6.src_address); + *port = clib_net_to_host_u16 (ip6_udp_hdr->udp.src_port); + } +} + always_inline uword wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame, u8 is_ip4, u16 async_next_node) @@ -735,8 +747,6 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, } else { - peer_idx = NULL; - /* Handshake packets should be processed in main thread */ if (thread_index != 0) { @@ -808,6 +818,10 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, } message_data_t *data = vlib_buffer_get_current (b[0]); + ip46_address_t out_src_ip; + u16 out_udp_src_port; + + wg_find_outer_addr_port (b[0], &out_src_ip, &out_udp_src_port, is_ip4); if (data->receiver_index != last_rec_idx) { @@ -823,6 +837,8 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx))) { + wg_peer_update_endpoint_from_mt (*peer_idx, &out_src_ip, + out_udp_src_port); wg_timers_any_authenticated_packet_received_opt (peer, time); wg_timers_any_authenticated_packet_traversal (peer); last_peer_time_idx = peer_idx; @@ -890,7 +906,8 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, } always_inline uword -wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame) +wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame, + u8 is_ip4) { vnet_main_t *vnm = vnet_get_main (); vnet_interface_main_t *im = &vnm->interface_main; @@ -925,6 +942,10 @@ wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame) bool is_keepalive = false; message_data_t *data = vlib_buffer_get_current (b[0]); + ip46_address_t out_src_ip; + u16 out_udp_src_port; + + wg_find_outer_addr_port (b[0], &out_src_ip, &out_udp_src_port, is_ip4); if (data->receiver_index != last_rec_idx) { @@ -949,6 +970,8 @@ wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame) if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx))) { + wg_peer_update_endpoint_from_mt (*peer_idx, &out_src_ip, + out_udp_src_port); wg_timers_any_authenticated_packet_received_opt (peer, time); wg_timers_any_authenticated_packet_traversal (peer); last_peer_time_idx = peer_idx; @@ -995,13 +1018,13 @@ VLIB_NODE_FN (wg6_input_node) VLIB_NODE_FN (wg4_input_post_node) (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *from_frame) { - return wg_input_post (vm, node, from_frame); + return wg_input_post (vm, node, from_frame, /* is_ip4 */ 1); } VLIB_NODE_FN (wg6_input_post_node) (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *from_frame) { - return wg_input_post (vm, node, from_frame); + return wg_input_post (vm, node, from_frame, /* is_ip4 */ 0); } /* *INDENT-OFF* */ diff --git a/src/plugins/wireguard/wireguard_peer.c b/src/plugins/wireguard/wireguard_peer.c index 589f71272f6..922ca8cdae5 100644 --- a/src/plugins/wireguard/wireguard_peer.c +++ b/src/plugins/wireguard/wireguard_peer.c @@ -16,6 +16,7 @@ #include <vnet/adj/adj_midchain.h> #include <vnet/fib/fib_table.h> +#include <vnet/fib/fib_entry_track.h> #include <wireguard/wireguard_peer.h> #include <wireguard/wireguard_if.h> #include <wireguard/wireguard_messages.h> @@ -63,13 +64,14 @@ wg_peer_clear (vlib_main_t * vm, wg_peer_t * peer) wg_peer_endpoint_reset (&peer->src); wg_peer_endpoint_reset (&peer->dst); - adj_index_t *adj_index; - vec_foreach (adj_index, peer->adj_indices) + wg_peer_adj_t *peer_adj; + vec_foreach (peer_adj, peer->adjs) { - if (INDEX_INVALID != *adj_index) - { - wg_peer_by_adj_index[*adj_index] = INDEX_INVALID; - } + wg_peer_by_adj_index[peer_adj->adj_index] = INDEX_INVALID; + if (FIB_NODE_INDEX_INVALID != peer_adj->fib_entry_index) + fib_entry_untrack (peer_adj->fib_entry_index, peer_adj->sibling_index); + if (adj_is_valid (peer_adj->adj_index)) + adj_nbr_midchain_unstack (peer_adj->adj_index); } peer->input_thread_index = ~0; peer->output_thread_index = ~0; @@ -83,8 +85,9 @@ wg_peer_clear (vlib_main_t * vm, wg_peer_t * peer) peer->new_handshake_interval_tick = 0; peer->rehandshake_interval_tick = 0; peer->timer_need_another_keepalive = false; + vec_free (peer->rewrite); vec_free (peer->allowed_ips); - vec_free (peer->adj_indices); + vec_free (peer->adjs); } static void @@ -96,17 +99,17 @@ wg_peer_init (vlib_main_t * vm, wg_peer_t * peer) } static void -wg_peer_adj_stack (wg_peer_t *peer, adj_index_t ai) +wg_peer_adj_stack (wg_peer_t *peer, wg_peer_adj_t *peer_adj) { ip_adjacency_t *adj; u32 sw_if_index; wg_if_t *wgi; fib_protocol_t fib_proto; - if (!adj_is_valid (ai)) + if (!adj_is_valid (peer_adj->adj_index)) return; - adj = adj_get (ai); + adj = adj_get (peer_adj->adj_index); sw_if_index = adj->rewrite_header.sw_if_index; u8 is_ip4 = ip46_address_is_ip4 (&peer->src.addr); fib_proto = is_ip4 ? FIB_PROTOCOL_IP4 : FIB_PROTOCOL_IP6; @@ -116,9 +119,10 @@ wg_peer_adj_stack (wg_peer_t *peer, adj_index_t ai) if (!wgi) return; - if (!vnet_sw_interface_is_admin_up (vnet_get_main (), wgi->sw_if_index)) + if (!vnet_sw_interface_is_admin_up (vnet_get_main (), wgi->sw_if_index) || + !wg_peer_can_send (peer)) { - adj_midchain_delegate_unstack (ai); + adj_nbr_midchain_unstack (peer_adj->adj_index); } else { @@ -132,8 +136,13 @@ wg_peer_adj_stack (wg_peer_t *peer, adj_index_t ai) u32 fib_index; fib_index = fib_table_find (fib_proto, peer->table_id); + peer_adj->fib_entry_index = + fib_entry_track (fib_index, &dst, FIB_NODE_TYPE_ADJ, + peer_adj->adj_index, &peer_adj->sibling_index); - adj_midchain_delegate_stack (ai, fib_index, &dst); + adj_nbr_midchain_stack_on_fib_entry ( + peer_adj->adj_index, peer_adj->fib_entry_index, + fib_forw_chain_type_from_fib_proto (dst.fp_proto)); } } @@ -198,11 +207,11 @@ walk_rc_t wg_peer_if_admin_state_change (index_t peeri, void *data) { wg_peer_t *peer; - adj_index_t *adj_index; + wg_peer_adj_t *peer_adj; peer = wg_peer_get (peeri); - vec_foreach (adj_index, peer->adj_indices) + vec_foreach (peer_adj, peer->adjs) { - wg_peer_adj_stack (peer, *adj_index); + wg_peer_adj_stack (peer, peer_adj); } return (WALK_CONTINUE); } @@ -215,6 +224,7 @@ wg_peer_if_adj_change (index_t peeri, void *data) ip_adjacency_t *adj; wg_peer_t *peer; fib_prefix_t *allowed_ip; + wg_peer_adj_t *peer_adj; adj = adj_get (*adj_index); @@ -224,17 +234,21 @@ wg_peer_if_adj_change (index_t peeri, void *data) if (fib_prefix_is_cover_addr_46 (allowed_ip, &adj->sub_type.nbr.next_hop)) { - vec_add1 (peer->adj_indices, *adj_index); + vec_add2 (peer->adjs, peer_adj, 1); + peer_adj->adj_index = *adj_index; + peer_adj->fib_entry_index = FIB_NODE_INDEX_INVALID; + peer_adj->sibling_index = ~0; + vec_validate_init_empty (wg_peer_by_adj_index, *adj_index, INDEX_INVALID); - wg_peer_by_adj_index[*adj_index] = peer - wg_peer_pool; + wg_peer_by_adj_index[*adj_index] = peeri; fixup = wg_peer_get_fixup (peer, adj_get_link_type (*adj_index)); adj_nbr_midchain_update_rewrite (*adj_index, fixup, NULL, ADJ_FLAG_MIDCHAIN_IP_STACK, vec_dup (peer->rewrite)); - wg_peer_adj_stack (peer, *adj_index); + wg_peer_adj_stack (peer, peer_adj); return (WALK_STOP); } } @@ -313,6 +327,71 @@ wg_peer_update_flags (index_t peeri, wg_peer_flags flag, bool add_del) wg_api_peer_event (peeri, peer->flags); } +void +wg_peer_update_endpoint (index_t peeri, const ip46_address_t *addr, u16 port) +{ + wg_peer_t *peer = wg_peer_get (peeri); + + if (ip46_address_is_equal (&peer->dst.addr, addr) && peer->dst.port == port) + return; + + wg_peer_endpoint_init (&peer->dst, addr, port); + + u8 is_ip4 = ip46_address_is_ip4 (&peer->dst.addr); + vec_free (peer->rewrite); + peer->rewrite = wg_build_rewrite (&peer->src.addr, peer->src.port, + &peer->dst.addr, peer->dst.port, is_ip4); + + wg_peer_adj_t *peer_adj; + vec_foreach (peer_adj, peer->adjs) + { + if (FIB_NODE_INDEX_INVALID != peer_adj->fib_entry_index) + { + fib_entry_untrack (peer_adj->fib_entry_index, + peer_adj->sibling_index); + peer_adj->fib_entry_index = FIB_NODE_INDEX_INVALID; + peer_adj->sibling_index = ~0; + } + + if (adj_is_valid (peer_adj->adj_index)) + { + adj_midchain_fixup_t fixup = + wg_peer_get_fixup (peer, adj_get_link_type (peer_adj->adj_index)); + adj_nbr_midchain_update_rewrite (peer_adj->adj_index, fixup, NULL, + ADJ_FLAG_MIDCHAIN_IP_STACK, + vec_dup (peer->rewrite)); + wg_peer_adj_stack (peer, peer_adj); + } + } +} + +typedef struct wg_peer_upd_ep_args_t_ +{ + index_t peeri; + ip46_address_t addr; + u16 port; +} wg_peer_upd_ep_args_t; + +static void +wg_peer_update_endpoint_thread_fn (wg_peer_upd_ep_args_t *args) +{ + wg_peer_update_endpoint (args->peeri, &args->addr, args->port); +} + +void +wg_peer_update_endpoint_from_mt (index_t peeri, const ip46_address_t *addr, + u16 port) +{ + wg_peer_upd_ep_args_t args = { + .peeri = peeri, + .port = port, + }; + + ip46_address_copy (&args.addr, addr); + vlib_rpc_call_main_thread (wg_peer_update_endpoint_thread_fn, (u8 *) &args, + sizeof (args)); +} + int wg_peer_add (u32 tun_sw_if_index, const u8 public_key[NOISE_PUBLIC_KEY_LEN], u32 table_id, const ip46_address_t *endpoint, @@ -345,7 +424,7 @@ wg_peer_add (u32 tun_sw_if_index, const u8 public_key[NOISE_PUBLIC_KEY_LEN], if (pool_elts (wg_peer_pool) > MAX_PEERS) return (VNET_API_ERROR_LIMIT_EXCEEDED); - pool_get (wg_peer_pool, peer); + pool_get_zero (wg_peer_pool, peer); wg_peer_init (vm, peer); @@ -428,9 +507,9 @@ format_wg_peer (u8 * s, va_list * va) { index_t peeri = va_arg (*va, index_t); fib_prefix_t *allowed_ip; - adj_index_t *adj_index; u8 key[NOISE_KEY_LEN_BASE64]; wg_peer_t *peer; + wg_peer_adj_t *peer_adj; peer = wg_peer_get (peeri); key_to_base64 (peer->remote.r_public, NOISE_PUBLIC_KEY_LEN, key); @@ -443,9 +522,9 @@ format_wg_peer (u8 * s, va_list * va) peer->wg_sw_if_index, peer->persistent_keepalive_interval, peer->flags, pool_elts (peer->api_clients)); s = format (s, "\n adj:"); - vec_foreach (adj_index, peer->adj_indices) + vec_foreach (peer_adj, peer->adjs) { - s = format (s, " %d", *adj_index); + s = format (s, " %d", peer_adj->adj_index); } s = format (s, "\n key:%=s %U", key, format_hex_bytes, peer->remote.r_public, NOISE_PUBLIC_KEY_LEN); diff --git a/src/plugins/wireguard/wireguard_peer.h b/src/plugins/wireguard/wireguard_peer.h index a14f2692b1c..c07ea894b36 100644 --- a/src/plugins/wireguard/wireguard_peer.h +++ b/src/plugins/wireguard/wireguard_peer.h @@ -68,6 +68,13 @@ typedef enum WG_PEER_ESTABLISHED = 0x2, } wg_peer_flags; +typedef struct wg_peer_adj_t_ +{ + adj_index_t adj_index; + fib_node_index_t fib_entry_index; + u32 sibling_index; +} wg_peer_adj_t; + typedef struct wg_peer { noise_remote_t remote; @@ -80,7 +87,7 @@ typedef struct wg_peer wg_peer_endpoint_t dst; wg_peer_endpoint_t src; u32 table_id; - adj_index_t *adj_indices; + wg_peer_adj_t *adjs; /* rewrite built from address information */ u8 *rewrite; @@ -144,6 +151,10 @@ adj_walk_rc_t wg_peer_adj_walk (adj_index_t ai, void *data); void wg_api_peer_event (index_t peeri, wg_peer_flags flags); void wg_peer_update_flags (index_t peeri, wg_peer_flags flag, bool add_del); +void wg_peer_update_endpoint (index_t peeri, const ip46_address_t *addr, + u16 port); +void wg_peer_update_endpoint_from_mt (index_t peeri, + const ip46_address_t *addr, u16 port); static inline bool wg_peer_is_dead (wg_peer_t *peer) @@ -200,6 +211,12 @@ fib_prefix_is_cover_addr_46 (const fib_prefix_t *p1, const ip46_address_t *ip) return (false); } +static inline bool +wg_peer_can_send (wg_peer_t *peer) +{ + return peer && peer->rewrite; +} + #endif // __included_wg_peer_h__ /* diff --git a/src/plugins/wireguard/wireguard_send.c b/src/plugins/wireguard/wireguard_send.c index 509fe70c777..93e808ad050 100644 --- a/src/plugins/wireguard/wireguard_send.c +++ b/src/plugins/wireguard/wireguard_send.c @@ -104,6 +104,9 @@ u8 * wg_build_rewrite (ip46_address_t *src_addr, u16 src_port, ip46_address_t *dst_addr, u16 dst_port, u8 is_ip4) { + if (ip46_address_is_zero (dst_addr) || 0 == dst_port) + return NULL; + u8 *rewrite = NULL; if (is_ip4) { @@ -151,6 +154,9 @@ wg_send_handshake (vlib_main_t * vm, wg_peer_t * peer, bool is_retry) { ASSERT (vm->thread_index == 0); + if (!wg_peer_can_send (peer)) + return false; + message_handshake_initiation_t packet; if (!is_retry) @@ -224,6 +230,9 @@ wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer) { ASSERT (vm->thread_index == 0); + if (!wg_peer_can_send (peer)) + return false; + u32 size_of_packet = message_data_len (0); message_data_t *packet = (message_data_t *) wg_main.per_thread_data[vm->thread_index].data; @@ -278,6 +287,9 @@ wg_send_handshake_response (vlib_main_t * vm, wg_peer_t * peer) { message_handshake_response_t packet; + if (!wg_peer_can_send (peer)) + return false; + if (noise_create_response (vm, &peer->remote, &packet.sender_index, @@ -329,10 +341,14 @@ wg_send_handshake_cookie (vlib_main_t *vm, u32 sender_index, u32 bi0 = 0; u8 is_ip4 = ip46_address_is_ip4 (remote_addr); + bool ret; rewrite = wg_build_rewrite (wg_if_addr, wg_if_port, remote_addr, remote_port, is_ip4); - if (!wg_create_buffer (vm, rewrite, (u8 *) &packet, sizeof (packet), &bi0, - is_ip4)) + + ret = wg_create_buffer (vm, rewrite, (u8 *) &packet, sizeof (packet), &bi0, + is_ip4); + vec_free (rewrite); + if (!ret) return false; ip46_enqueue_packet (vm, bi0, is_ip4); diff --git a/test/test_wireguard.py b/test/test_wireguard.py index b8c5d2afd93..95cfe68d2a9 100644 --- a/test/test_wireguard.py +++ b/test/test_wireguard.py @@ -137,15 +137,6 @@ class VppWgInterface(VppInterface): return "wireguard-%d" % self._sw_if_index -def find_route(test, prefix, is_ip6, table_id=0): - routes = test.vapi.ip_route_dump(table_id, is_ip6) - - for e in routes: - if table_id == e.route.table_id and str(e.route.prefix) == str(prefix): - return True - return False - - NOISE_HANDSHAKE_NAME = b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" NOISE_IDENTIFIER_NAME = b"WireGuard v1 zx2c4 Jason@zx2c4.com" @@ -176,6 +167,10 @@ class VppWgPeer(VppObject): self.noise = NoiseConnection.from_name(NOISE_HANDSHAKE_NAME) + def change_endpoint(self, endpoint, port): + self.endpoint = endpoint + self.port = port + def add_vpp_config(self, is_ip6=False): rv = self._test.vapi.wireguard_peer_add( peer={ @@ -206,10 +201,12 @@ class VppWgPeer(VppObject): peers = self._test.vapi.wireguard_peers_dump() for p in peers: + # "::" endpoint will be returned as "0.0.0.0" in peer's details + endpoint = "0.0.0.0" if self.endpoint == "::" else self.endpoint if ( p.peer.public_key == self.public_key_bytes() and p.peer.port == self.port - and str(p.peer.endpoint) == self.endpoint + and str(p.peer.endpoint) == endpoint and p.peer.sw_if_index == self.itf.sw_if_index and len(self.allowed_ips) == p.peer.n_allowed_ips ): @@ -470,17 +467,17 @@ class VppWgPeer(VppObject): def validate_encapped(self, rxs, tx, is_ip6=False): for rx in rxs: if is_ip6 is False: - rx = IP(self.decrypt_transport(rx)) + rx = IP(self.decrypt_transport(rx, is_ip6=is_ip6)) - # chech the oringial packet is present + # check the original packet is present self._test.assertEqual(rx[IP].dst, tx[IP].dst) self._test.assertEqual(rx[IP].ttl, tx[IP].ttl - 1) else: - rx = IPv6(self.decrypt_transport(rx)) + rx = IPv6(self.decrypt_transport(rx, is_ip6=is_ip6)) - # chech the oringial packet is present + # check the original packet is present self._test.assertEqual(rx[IPv6].dst, tx[IPv6].dst) - self._test.assertEqual(rx[IPv6].ttl, tx[IPv6].ttl - 1) + self._test.assertEqual(rx[IPv6].hlim, tx[IPv6].hlim - 1) def want_events(self): self._test.vapi.want_wireguard_peer_events( @@ -997,6 +994,237 @@ class TestWg(VppTestCase): peer_2.remove_vpp_config() wg0.remove_vpp_config() + def _test_wg_peer_roaming_on_handshake_tmpl(self, is_endpoint_set, is_resp, is_ip6): + port = 12323 + + # create wg interface + if is_ip6: + wg0 = VppWgInterface(self, self.pg1.local_ip6, port).add_vpp_config() + wg0.admin_up() + wg0.config_ip6() + else: + wg0 = VppWgInterface(self, self.pg1.local_ip4, port).add_vpp_config() + wg0.admin_up() + wg0.config_ip4() + + self.pg_enable_capture(self.pg_interfaces) + self.pg_start() + + # create more remote hosts + NUM_REMOTE_HOSTS = 2 + self.pg1.generate_remote_hosts(NUM_REMOTE_HOSTS) + if is_ip6: + self.pg1.configure_ipv6_neighbors() + else: + self.pg1.configure_ipv4_neighbors() + + # create a peer + if is_ip6: + peer_1 = VppWgPeer( + test=self, + itf=wg0, + endpoint=self.pg1.remote_hosts[0].ip6 if is_endpoint_set else "::", + port=port + 1 if is_endpoint_set else 0, + allowed_ips=["1::3:0/112"], + ).add_vpp_config() + else: + peer_1 = VppWgPeer( + test=self, + itf=wg0, + endpoint=self.pg1.remote_hosts[0].ip4 if is_endpoint_set else "0.0.0.0", + port=port + 1 if is_endpoint_set else 0, + allowed_ips=["10.11.3.0/24"], + ).add_vpp_config() + self.assertTrue(peer_1.query_vpp_config()) + + if is_resp: + # wait for the peer to send a handshake initiation + rxs = self.pg1.get_capture(1, timeout=2) + # prepare a handshake response + resp = peer_1.consume_init(rxs[0], self.pg1, is_ip6=is_ip6) + # change endpoint + if is_ip6: + peer_1.change_endpoint(self.pg1.remote_hosts[1].ip6, port + 100) + resp[IPv6].src, resp[UDP].sport = peer_1.endpoint, peer_1.port + else: + peer_1.change_endpoint(self.pg1.remote_hosts[1].ip4, port + 100) + resp[IP].src, resp[UDP].sport = peer_1.endpoint, peer_1.port + # send the handshake response + # expect a keepalive message sent to the new endpoint + rxs = self.send_and_expect(self.pg1, [resp], self.pg1) + # verify the keepalive message + b = peer_1.decrypt_transport(rxs[0], is_ip6=is_ip6) + self.assertEqual(0, len(b)) + else: + # change endpoint + if is_ip6: + peer_1.change_endpoint(self.pg1.remote_hosts[1].ip6, port + 100) + else: + peer_1.change_endpoint(self.pg1.remote_hosts[1].ip4, port + 100) + # prepare and send a handshake initiation + # expect a handshake response sent to the new endpoint + init = peer_1.mk_handshake(self.pg1, is_ip6=is_ip6) + rxs = self.send_and_expect(self.pg1, [init], self.pg1) + # verify the response + peer_1.consume_response(rxs[0], is_ip6=is_ip6) + self.assertTrue(peer_1.query_vpp_config()) + + # remove configs + peer_1.remove_vpp_config() + wg0.remove_vpp_config() + + def test_wg_peer_roaming_on_init_v4(self): + """Peer roaming on handshake initiation (v4)""" + self._test_wg_peer_roaming_on_handshake_tmpl( + is_endpoint_set=False, is_resp=False, is_ip6=False + ) + + def test_wg_peer_roaming_on_init_v6(self): + """Peer roaming on handshake initiation (v6)""" + self._test_wg_peer_roaming_on_handshake_tmpl( + is_endpoint_set=False, is_resp=False, is_ip6=True + ) + + def test_wg_peer_roaming_on_resp_v4(self): + """Peer roaming on handshake response (v4)""" + self._test_wg_peer_roaming_on_handshake_tmpl( + is_endpoint_set=True, is_resp=True, is_ip6=False + ) + + def test_wg_peer_roaming_on_resp_v6(self): + """Peer roaming on handshake response (v6)""" + self._test_wg_peer_roaming_on_handshake_tmpl( + is_endpoint_set=True, is_resp=True, is_ip6=True + ) + + def _test_wg_peer_roaming_on_data_tmpl(self, is_async, is_ip6): + self.vapi.wg_set_async_mode(is_async) + port = 12323 + + # create wg interface + if is_ip6: + wg0 = VppWgInterface(self, self.pg1.local_ip6, port).add_vpp_config() + wg0.admin_up() + wg0.config_ip6() + else: + wg0 = VppWgInterface(self, self.pg1.local_ip4, port).add_vpp_config() + wg0.admin_up() + wg0.config_ip4() + + self.pg_enable_capture(self.pg_interfaces) + self.pg_start() + + # create more remote hosts + NUM_REMOTE_HOSTS = 2 + self.pg1.generate_remote_hosts(NUM_REMOTE_HOSTS) + if is_ip6: + self.pg1.configure_ipv6_neighbors() + else: + self.pg1.configure_ipv4_neighbors() + + # create a peer + if is_ip6: + peer_1 = VppWgPeer( + self, wg0, self.pg1.remote_hosts[0].ip6, port + 1, ["1::3:0/112"] + ).add_vpp_config() + else: + peer_1 = VppWgPeer( + self, wg0, self.pg1.remote_hosts[0].ip4, port + 1, ["10.11.3.0/24"] + ).add_vpp_config() + self.assertTrue(peer_1.query_vpp_config()) + + # create a route to rewrite traffic into the wg interface + if is_ip6: + r1 = VppIpRoute( + self, "1::3:0", 112, [VppRoutePath("1::3:1", wg0.sw_if_index)] + ).add_vpp_config() + else: + r1 = VppIpRoute( + self, "10.11.3.0", 24, [VppRoutePath("10.11.3.1", wg0.sw_if_index)] + ).add_vpp_config() + + # wait for the peer to send a handshake initiation + rxs = self.pg1.get_capture(1, timeout=2) + + # prepare and send a handshake response + # expect a keepalive message + resp = peer_1.consume_init(rxs[0], self.pg1, is_ip6=is_ip6) + rxs = self.send_and_expect(self.pg1, [resp], self.pg1) + + # verify the keepalive message + b = peer_1.decrypt_transport(rxs[0], is_ip6=is_ip6) + self.assertEqual(0, len(b)) + + # change endpoint + if is_ip6: + peer_1.change_endpoint(self.pg1.remote_hosts[1].ip6, port + 100) + else: + peer_1.change_endpoint(self.pg1.remote_hosts[1].ip4, port + 100) + + # prepare and send a data packet + # expect endpoint change + if is_ip6: + ip_header = IPv6(src="1::3:1", dst=self.pg0.remote_ip6, hlim=20) + else: + ip_header = IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) + data = ( + peer_1.mk_tunnel_header(self.pg1, is_ip6=is_ip6) + / Wireguard(message_type=4, reserved_zero=0) + / WireguardTransport( + receiver_index=peer_1.sender, + counter=0, + encrypted_encapsulated_packet=peer_1.encrypt_transport( + ip_header / UDP(sport=222, dport=223) / Raw() + ), + ) + ) + rxs = self.send_and_expect(self.pg1, [data], self.pg0) + if is_ip6: + self.assertEqual(rxs[0][IPv6].dst, self.pg0.remote_ip6) + self.assertEqual(rxs[0][IPv6].hlim, 19) + else: + self.assertEqual(rxs[0][IP].dst, self.pg0.remote_ip4) + self.assertEqual(rxs[0][IP].ttl, 19) + self.assertTrue(peer_1.query_vpp_config()) + + # prepare and send a packet that will be rewritten into the wg interface + # expect a data packet sent to the new endpoint + if is_ip6: + ip_header = IPv6(src=self.pg0.remote_ip6, dst="1::3:2") + else: + ip_header = IP(src=self.pg0.remote_ip4, dst="10.11.3.2") + p = ( + Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) + / ip_header + / UDP(sport=555, dport=556) + / Raw() + ) + rxs = self.send_and_expect(self.pg0, [p], self.pg1) + + # verify the data packet + peer_1.validate_encapped(rxs, p, is_ip6=is_ip6) + + # remove configs + r1.remove_vpp_config() + peer_1.remove_vpp_config() + wg0.remove_vpp_config() + + def test_wg_peer_roaming_on_data_v4_sync(self): + """Peer roaming on data packet (v4, sync)""" + self._test_wg_peer_roaming_on_data_tmpl(is_async=False, is_ip6=False) + + def test_wg_peer_roaming_on_data_v6_sync(self): + """Peer roaming on data packet (v6, sync)""" + self._test_wg_peer_roaming_on_data_tmpl(is_async=False, is_ip6=True) + + def test_wg_peer_roaming_on_data_v4_async(self): + """Peer roaming on data packet (v4, async)""" + self._test_wg_peer_roaming_on_data_tmpl(is_async=True, is_ip6=False) + + def test_wg_peer_roaming_on_data_v6_async(self): + """Peer roaming on data packet (v6, async)""" + self._test_wg_peer_roaming_on_data_tmpl(is_async=True, is_ip6=True) + def test_wg_peer_resp(self): """Send handshake response""" port = 12323 @@ -1197,7 +1425,7 @@ class TestWg(VppTestCase): for rx in rxs: rx = IP(peer_1.decrypt_transport(rx)) - # chech the oringial packet is present + # check the original packet is present self.assertEqual(rx[IP].dst, p[IP].dst) self.assertEqual(rx[IP].ttl, p[IP].ttl - 1) @@ -1358,7 +1586,7 @@ class TestWg(VppTestCase): for rx in rxs: rx = IPv6(peer_1.decrypt_transport(rx, True)) - # chech the oringial packet is present + # check the original packet is present self.assertEqual(rx[IPv6].dst, p[IPv6].dst) self.assertEqual(rx[IPv6].hlim, p[IPv6].hlim - 1) @@ -1499,7 +1727,7 @@ class TestWg(VppTestCase): for rx in rxs: rx = IPv6(peer_1.decrypt_transport(rx)) - # chech the oringial packet is present + # check the original packet is present self.assertEqual(rx[IPv6].dst, p[IPv6].dst) self.assertEqual(rx[IPv6].hlim, p[IPv6].hlim - 1) @@ -1638,7 +1866,7 @@ class TestWg(VppTestCase): for rx in rxs: rx = IP(peer_1.decrypt_transport(rx, True)) - # chech the oringial packet is present + # check the original packet is present self.assertEqual(rx[IP].dst, p[IP].dst) self.assertEqual(rx[IP].ttl, p[IP].ttl - 1) |