From 9e24f7eb911fc5ab7558109286fe8e1d2774ea93 Mon Sep 17 00:00:00 2001 From: Artem Glazychev Date: Tue, 25 May 2021 12:06:42 +0700 Subject: wireguard: use the same udp-port for multi-tunnel now we can reuse udp-port for many wireguard interfaces Type: improvement Change-Id: I14b5a9dbe917d83300ccb4d6907743d88355e5c5 Signed-off-by: Artem Glazychev --- src/plugins/wireguard/wireguard_if.c | 37 +++++++----- src/plugins/wireguard/wireguard_if.h | 12 ++-- src/plugins/wireguard/wireguard_input.c | 34 ++++++++--- src/plugins/wireguard/wireguard_peer.c | 7 +++ src/plugins/wireguard/wireguard_peer.h | 1 + test/test_wireguard.py | 104 ++++++++++++++++++++++++++++++++ 6 files changed, 166 insertions(+), 29 deletions(-) diff --git a/src/plugins/wireguard/wireguard_if.c b/src/plugins/wireguard/wireguard_if.c index b1c5cb20638..0866d24e775 100644 --- a/src/plugins/wireguard/wireguard_if.c +++ b/src/plugins/wireguard/wireguard_if.c @@ -32,7 +32,7 @@ static uword *wg_if_instances; static index_t *wg_if_index_by_sw_if_index; /* vector of interfaces key'd on their UDP port (in network order) */ -index_t *wg_if_index_by_port; +index_t **wg_if_indexes_by_port; static u8 * format_wg_if_name (u8 * s, va_list * args) @@ -253,13 +253,6 @@ wg_if_create (u32 user_instance, *sw_if_indexp = (u32) ~ 0; - /* - * Check if the required port is already in use - */ - udp_dst_port_info_t *pi = udp_get_dst_port_info (&udp_main, port, UDP_IP4); - if (pi) - return VNET_API_ERROR_UDP_PORT_TAKEN; - /* * Allocate a wg_if instance. Either select on dynamically * or try to use the desired user_instance number. @@ -295,10 +288,11 @@ wg_if_create (u32 user_instance, if (~0 == wg_if->user_instance) wg_if->user_instance = t_idx; - udp_register_dst_port (vlib_get_main (), port, wg_input_node.index, 1); + vec_validate_init_empty (wg_if_indexes_by_port, port, NULL); + if (vec_len (wg_if_indexes_by_port[port]) == 0) + udp_register_dst_port (vlib_get_main (), port, wg_input_node.index, 1); - vec_validate_init_empty (wg_if_index_by_port, port, INDEX_INVALID); - wg_if_index_by_port[port] = wg_if - wg_if_pool; + vec_add1 (wg_if_indexes_by_port[port], t_idx); wg_if->port = port; wg_if->local_idx = local - noise_local_pool; @@ -334,15 +328,30 @@ wg_if_delete (u32 sw_if_index) return VNET_API_ERROR_INVALID_VALUE; wg_if_t *wg_if; - wg_if = wg_if_get (wg_if_find_by_sw_if_index (sw_if_index)); + index_t wgii = wg_if_find_by_sw_if_index (sw_if_index); + wg_if = wg_if_get (wgii); if (NULL == wg_if) return VNET_API_ERROR_INVALID_SW_IF_INDEX_2; if (wg_if_instance_free (wg_if->user_instance) < 0) return VNET_API_ERROR_INVALID_VALUE_2; - udp_unregister_dst_port (vlib_get_main (), wg_if->port, 1); - wg_if_index_by_port[wg_if->port] = INDEX_INVALID; + // Remove peers before interface deletion + wg_if_peer_walk (wg_if, wg_peer_if_delete, NULL); + + index_t *ii; + index_t *ifs = wg_if_indexes_get_by_port (wg_if->port); + vec_foreach (ii, ifs) + { + if (*ii == wgii) + { + vec_del1 (ifs, ifs - ii); + break; + } + } + if (vec_len (ifs) == 0) + udp_unregister_dst_port (vlib_get_main (), wg_if->port, 1); + vnet_delete_hw_interface (vnm, hw->hw_if_index); pool_put_index (noise_local_pool, wg_if->local_idx); pool_put (wg_if_pool, wg_if); diff --git a/src/plugins/wireguard/wireguard_if.h b/src/plugins/wireguard/wireguard_if.h index e43557b56fe..3153a380066 100644 --- a/src/plugins/wireguard/wireguard_if.h +++ b/src/plugins/wireguard/wireguard_if.h @@ -71,16 +71,16 @@ wg_if_get (index_t wgii) return (pool_elt_at_index (wg_if_pool, wgii)); } -extern index_t *wg_if_index_by_port; +extern index_t **wg_if_indexes_by_port; -static_always_inline wg_if_t * -wg_if_get_by_port (u16 port) +static_always_inline index_t * +wg_if_indexes_get_by_port (u16 port) { - if (vec_len (wg_if_index_by_port) < port) + if (vec_len (wg_if_indexes_by_port) == 0) return (NULL); - if (INDEX_INVALID == wg_if_index_by_port[port]) + if (vec_len (wg_if_indexes_by_port[port]) == 0) return (NULL); - return (wg_if_get (wg_if_index_by_port[port])); + return (wg_if_indexes_by_port[port]); } diff --git a/src/plugins/wireguard/wireguard_input.c b/src/plugins/wireguard/wireguard_input.c index 8dfba7615ef..ad002dcb3c2 100644 --- a/src/plugins/wireguard/wireguard_input.c +++ b/src/plugins/wireguard/wireguard_input.c @@ -116,6 +116,7 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b) enum cookie_mac_state mac_state; bool packet_needs_cookie; bool under_load; + index_t *wg_ifs; wg_if_t *wg_if; wg_peer_t *peer = NULL; @@ -131,11 +132,6 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b) message_header_t *header = current_b_data; under_load = false; - wg_if = wg_if_get_by_port (udp_dst_port); - - if (NULL == wg_if) - return WG_INPUT_ERROR_INTERFACE; - if (PREDICT_FALSE (header->type == MESSAGE_HANDSHAKE_COOKIE)) { message_handshake_cookie_t *packet = @@ -159,10 +155,30 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b) message_macs_t *macs = (message_macs_t *) ((u8 *) current_b_data + len - sizeof (*macs)); - mac_state = - cookie_checker_validate_macs (vm, &wg_if->cookie_checker, macs, - current_b_data, len, under_load, ip4_src, - udp_src_port); + index_t *ii; + wg_ifs = wg_if_indexes_get_by_port (udp_dst_port); + if (NULL == wg_ifs) + return WG_INPUT_ERROR_INTERFACE; + + vec_foreach (ii, wg_ifs) + { + wg_if = wg_if_get (*ii); + if (NULL == wg_if) + continue; + + mac_state = cookie_checker_validate_macs ( + vm, &wg_if->cookie_checker, macs, current_b_data, len, under_load, + ip4_src, udp_src_port); + if (mac_state == INVALID_MAC) + { + wg_if = NULL; + continue; + } + break; + } + + if (NULL == wg_if) + return WG_INPUT_ERROR_HANDSHAKE_MAC; if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)) diff --git a/src/plugins/wireguard/wireguard_peer.c b/src/plugins/wireguard/wireguard_peer.c index 5e2011b401e..d4a85a20cf6 100644 --- a/src/plugins/wireguard/wireguard_peer.c +++ b/src/plugins/wireguard/wireguard_peer.c @@ -207,6 +207,13 @@ wg_peer_adj_walk (adj_index_t ai, void *data) ADJ_WALK_RC_STOP; } +walk_rc_t +wg_peer_if_delete (index_t peeri, void *data) +{ + wg_peer_remove (peeri); + return (WALK_CONTINUE); +} + static int wg_peer_fill (vlib_main_t *vm, wg_peer_t *peer, u32 table_id, const ip46_address_t *dst, u16 port, diff --git a/src/plugins/wireguard/wireguard_peer.h b/src/plugins/wireguard/wireguard_peer.h index cf859f32b1f..e23feb7866f 100644 --- a/src/plugins/wireguard/wireguard_peer.h +++ b/src/plugins/wireguard/wireguard_peer.h @@ -106,6 +106,7 @@ index_t wg_peer_walk (wg_peer_walk_cb_t fn, void *data); u8 *format_wg_peer (u8 * s, va_list * va); walk_rc_t wg_peer_if_admin_state_change (index_t peeri, void *data); +walk_rc_t wg_peer_if_delete (index_t peeri, void *data); walk_rc_t wg_peer_if_adj_change (index_t peeri, void *data); adj_walk_rc_t wg_peer_adj_walk (adj_index_t ai, void *data); diff --git a/test/test_wireguard.py b/test/test_wireguard.py index 206425e14fd..96c1bc0f5fd 100755 --- a/test/test_wireguard.py +++ b/test/test_wireguard.py @@ -677,6 +677,106 @@ class TestWg(VppTestCase): wg0.remove_vpp_config() wg1.remove_vpp_config() + def test_wg_multi_interface(self): + """ Multi-tunnel on the same port """ + port = 12500 + + # Create many wireguard interfaces + NUM_IFS = 4 + self.pg1.generate_remote_hosts(NUM_IFS) + self.pg1.configure_ipv4_neighbors() + self.pg0.generate_remote_hosts(NUM_IFS) + self.pg0.configure_ipv4_neighbors() + + # Create interfaces with a peer on each + peers = [] + routes = [] + wg_ifs = [] + for i in range(NUM_IFS): + # Use the same port for each interface + wg0 = VppWgInterface(self, + self.pg1.local_ip4, + port).add_vpp_config() + wg0.admin_up() + wg0.config_ip4() + wg_ifs.append(wg0) + peers.append(VppWgPeer(self, + wg0, + self.pg1.remote_hosts[i].ip4, + port+1+i, + ["10.0.%d.0/24" % i]).add_vpp_config()) + + routes.append(VppIpRoute(self, "10.0.%d.0" % i, 24, + [VppRoutePath("10.0.%d.4" % i, + wg0.sw_if_index)]).add_vpp_config()) + + self.assertEqual(len(self.vapi.wireguard_peers_dump()), NUM_IFS) + + for i in range(NUM_IFS): + # send a valid handsake init for which we expect a response + p = peers[i].mk_handshake(self.pg1) + rx = self.send_and_expect(self.pg1, [p], self.pg1) + peers[i].consume_response(rx[0]) + + # send a data packet from the peer through the tunnel + # this completes the handshake + p = (IP(src="10.0.%d.4" % i, + dst=self.pg0.remote_hosts[i].ip4, ttl=20) / + UDP(sport=222, dport=223) / + Raw()) + d = peers[i].encrypt_transport(p) + p = (peers[i].mk_tunnel_header(self.pg1) / + (Wireguard(message_type=4, reserved_zero=0) / + WireguardTransport(receiver_index=peers[i].sender, + counter=0, + encrypted_encapsulated_packet=d))) + rxs = self.send_and_expect(self.pg1, [p], self.pg0) + for rx in rxs: + self.assertEqual(rx[IP].dst, self.pg0.remote_hosts[i].ip4) + self.assertEqual(rx[IP].ttl, 19) + + # send a packets that are routed into the tunnel + for i in range(NUM_IFS): + p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) / + IP(src=self.pg0.remote_hosts[i].ip4, dst="10.0.%d.4" % i) / + UDP(sport=555, dport=556) / + Raw(b'\x00' * 80)) + + rxs = self.send_and_expect(self.pg0, p * 64, self.pg1) + + for rx in rxs: + rx = IP(peers[i].decrypt_transport(rx)) + + # check the oringial packet is present + self.assertEqual(rx[IP].dst, p[IP].dst) + self.assertEqual(rx[IP].ttl, p[IP].ttl-1) + + # send packets into the tunnel + for i in range(NUM_IFS): + p = [(peers[i].mk_tunnel_header(self.pg1) / + Wireguard(message_type=4, reserved_zero=0) / + WireguardTransport( + receiver_index=peers[i].sender, + counter=ii+1, + encrypted_encapsulated_packet=peers[i].encrypt_transport( + (IP(src="10.0.%d.4" % i, + dst=self.pg0.remote_hosts[i].ip4, ttl=20) / + UDP(sport=222, dport=223) / + Raw())))) for ii in range(64)] + + rxs = self.send_and_expect(self.pg1, p, self.pg0) + + for rx in rxs: + self.assertEqual(rx[IP].dst, self.pg0.remote_hosts[i].ip4) + self.assertEqual(rx[IP].ttl, 19) + + for r in routes: + r.remove_vpp_config() + for p in peers: + p.remove_vpp_config() + for i in wg_ifs: + i.remove_vpp_config() + class WireguardHandoffTests(TestWg): """ Wireguard Tests in multi worker setup """ @@ -768,3 +868,7 @@ class WireguardHandoffTests(TestWg): r1.remove_vpp_config() peer_1.remove_vpp_config() wg0.remove_vpp_config() + + @unittest.skip("test disabled") + def test_wg_multi_interface(self): + """ Multi-tunnel on the same port """ -- cgit 1.2.3-korg