diff options
Diffstat (limited to 'src/plugins/wireguard')
-rwxr-xr-x | src/plugins/wireguard/CMakeLists.txt | 1 | ||||
-rwxr-xr-x | src/plugins/wireguard/test/test_wireguard.py | 101 | ||||
-rwxr-xr-x | src/plugins/wireguard/wireguard.c | 12 | ||||
-rwxr-xr-x | src/plugins/wireguard/wireguard.h | 18 | ||||
-rwxr-xr-x | src/plugins/wireguard/wireguard_api.c | 4 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_handoff.c | 197 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_if.c | 86 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_if.h | 5 | ||||
-rwxr-xr-x | src/plugins/wireguard/wireguard_input.c | 255 | ||||
-rwxr-xr-x | src/plugins/wireguard/wireguard_noise.c | 131 | ||||
-rwxr-xr-x | src/plugins/wireguard/wireguard_noise.h | 23 | ||||
-rwxr-xr-x | src/plugins/wireguard/wireguard_output_tun.c | 94 | ||||
-rwxr-xr-x | src/plugins/wireguard/wireguard_peer.c | 58 | ||||
-rwxr-xr-x | src/plugins/wireguard/wireguard_peer.h | 40 | ||||
-rwxr-xr-x | src/plugins/wireguard/wireguard_send.c | 89 | ||||
-rwxr-xr-x | src/plugins/wireguard/wireguard_send.h | 1 | ||||
-rwxr-xr-x | src/plugins/wireguard/wireguard_timer.c | 220 | ||||
-rwxr-xr-x | src/plugins/wireguard/wireguard_timer.h | 2 |
18 files changed, 889 insertions, 448 deletions
diff --git a/src/plugins/wireguard/CMakeLists.txt b/src/plugins/wireguard/CMakeLists.txt index db5bb2d8910..db74f9cdce0 100755 --- a/src/plugins/wireguard/CMakeLists.txt +++ b/src/plugins/wireguard/CMakeLists.txt @@ -30,6 +30,7 @@ add_vpp_plugin(wireguard wireguard_if.h wireguard_input.c wireguard_output_tun.c + wireguard_handoff.c wireguard_key.c wireguard_key.h wireguard_cli.c diff --git a/src/plugins/wireguard/test/test_wireguard.py b/src/plugins/wireguard/test/test_wireguard.py index 77349396ec6..cee1e938bb0 100755 --- a/src/plugins/wireguard/test/test_wireguard.py +++ b/src/plugins/wireguard/test/test_wireguard.py @@ -327,6 +327,14 @@ class VppWgPeer(VppObject): def encrypt_transport(self, p): return self.noise.encrypt(bytes(p)) + def validate_encapped(self, rxs, tx): + for rx in rxs: + rx = IP(self.decrypt_transport(rx)) + + # chech the oringial packet is present + self._test.assertEqual(rx[IP].dst, tx[IP].dst) + self._test.assertEqual(rx[IP].ttl, tx[IP].ttl-1) + class TestWg(VppTestCase): """ Wireguard Test Case """ @@ -455,11 +463,7 @@ class TestWg(VppTestCase): rxs = self.send_and_expect(self.pg0, p * 255, self.pg1) - for rx in rxs: - rx = IP(peer_1.decrypt_transport(rx)) - # chech the oringial packet is present - self.assertEqual(rx[IP].dst, p[IP].dst) - self.assertEqual(rx[IP].ttl, p[IP].ttl-1) + peer_1.validate_encapped(rxs, p) # send packets into the tunnel, expect to receive them on # the other side @@ -655,3 +659,90 @@ class TestWg(VppTestCase): wg0.remove_vpp_config() wg1.remove_vpp_config() + + +class WireguardHandoffTests(TestWg): + """ Wireguard Tests in multi worker setup """ + worker_config = "workers 2" + + def test_wg_peer_init(self): + """ Handoff """ + wg_output_node_name = '/err/wg-output-tun/' + wg_input_node_name = '/err/wg-input/' + + port = 12323 + + # Create interfaces + wg0 = VppWgInterface(self, + self.pg1.local_ip4, + port).add_vpp_config() + wg0.admin_up() + wg0.config_ip4() + + peer_1 = VppWgPeer(self, + wg0, + self.pg1.remote_ip4, + port+1, + ["10.11.2.0/24", + "10.11.3.0/24"]).add_vpp_config() + self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1) + + # send a valid handsake init for which we expect a response + p = peer_1.mk_handshake(self.pg1) + + rx = self.send_and_expect(self.pg1, [p], self.pg1) + + peer_1.consume_response(rx[0]) + + # send a data packet from the peer through the tunnel + # this completes the handshake and pins the peer to worker 0 + p = (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) / + UDP(sport=222, dport=223) / + Raw()) + d = peer_1.encrypt_transport(p) + p = (peer_1.mk_tunnel_header(self.pg1) / + (Wireguard(message_type=4, reserved_zero=0) / + WireguardTransport(receiver_index=peer_1.sender, + counter=0, + encrypted_encapsulated_packet=d))) + rxs = self.send_and_expect(self.pg1, [p], self.pg0, + worker=0) + + for rx in rxs: + self.assertEqual(rx[IP].dst, self.pg0.remote_ip4) + self.assertEqual(rx[IP].ttl, 19) + + # send a packets that are routed into the tunnel + # and pins the peer tp worker 1 + pe = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) / + IP(src=self.pg0.remote_ip4, dst="10.11.3.2") / + UDP(sport=555, dport=556) / + Raw(b'\x00' * 80)) + rxs = self.send_and_expect(self.pg0, pe * 255, self.pg1, worker=1) + peer_1.validate_encapped(rxs, pe) + + # send packets into the tunnel, from the other worker + p = [(peer_1.mk_tunnel_header(self.pg1) / + Wireguard(message_type=4, reserved_zero=0) / + WireguardTransport( + receiver_index=peer_1.sender, + counter=ii+1, + encrypted_encapsulated_packet=peer_1.encrypt_transport( + (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) / + UDP(sport=222, dport=223) / + Raw())))) for ii in range(255)] + + rxs = self.send_and_expect(self.pg1, p, self.pg0, worker=1) + + for rx in rxs: + self.assertEqual(rx[IP].dst, self.pg0.remote_ip4) + self.assertEqual(rx[IP].ttl, 19) + + # send a packets that are routed into the tunnel + # from owrker 0 + rxs = self.send_and_expect(self.pg0, pe * 255, self.pg1, worker=0) + + peer_1.validate_encapped(rxs, pe) + + peer_1.remove_vpp_config() + wg0.remove_vpp_config() diff --git a/src/plugins/wireguard/wireguard.c b/src/plugins/wireguard/wireguard.c index 00921811e4a..9510a0ad385 100755 --- a/src/plugins/wireguard/wireguard.c +++ b/src/plugins/wireguard/wireguard.c @@ -32,7 +32,17 @@ wg_init (vlib_main_t * vm) wg_main_t *wmp = &wg_main; wmp->vlib_main = vm; - wmp->peers = 0; + + wmp->in_fq_index = vlib_frame_queue_main_init (wg_input_node.index, 0); + wmp->out_fq_index = + vlib_frame_queue_main_init (wg_output_tun_node.index, 0); + + vlib_thread_main_t *tm = vlib_get_thread_main (); + + vec_validate_aligned (wmp->per_thread_data, tm->n_vlib_mains, + CLIB_CACHE_LINE_BYTES); + + wg_timer_wheel_init (); return (NULL); } diff --git a/src/plugins/wireguard/wireguard.h b/src/plugins/wireguard/wireguard.h index 70a692e602f..2c892a374b8 100755 --- a/src/plugins/wireguard/wireguard.h +++ b/src/plugins/wireguard/wireguard.h @@ -17,13 +17,17 @@ #include <wireguard/wireguard_index_table.h> #include <wireguard/wireguard_messages.h> -#include <wireguard/wireguard_peer.h> +#include <wireguard/wireguard_timer.h> + +#define WG_DEFAULT_DATA_SIZE 2048 extern vlib_node_registration_t wg_input_node; extern vlib_node_registration_t wg_output_tun_node; - - +typedef struct wg_per_thread_data_t_ +{ + u8 data[WG_DEFAULT_DATA_SIZE]; +} wg_per_thread_data_t; typedef struct { /* convenience */ @@ -31,10 +35,14 @@ typedef struct u16 msg_id_base; - // Peers pool - wg_peer_t *peers; wg_index_table_t index_table; + u32 in_fq_index; + u32 out_fq_index; + + wg_per_thread_data_t *per_thread_data; + + tw_timer_wheel_16t_2w_512sl_t timer_wheel; } wg_main_t; extern wg_main_t wg_main; diff --git a/src/plugins/wireguard/wireguard_api.c b/src/plugins/wireguard/wireguard_api.c index 8bbacddaf45..27ed6ea05da 100755 --- a/src/plugins/wireguard/wireguard_api.c +++ b/src/plugins/wireguard/wireguard_api.c @@ -97,15 +97,17 @@ wireguard_if_send_details (index_t wgii, void *data) vl_api_wireguard_interface_details_t *rmp; wg_deatils_walk_t *ctx = data; const wg_if_t *wgi; + const noise_local_t *local; wgi = wg_if_get (wgii); + local = noise_local_get (wgi->local_idx); rmp = vl_msg_api_alloc_zero (sizeof (*rmp)); rmp->_vl_msg_id = htons (VL_API_WIREGUARD_INTERFACE_DETAILS + wg_main.msg_id_base); clib_memcpy (rmp->interface.private_key, - wgi->local.l_private, NOISE_PUBLIC_KEY_LEN); + local->l_private, NOISE_PUBLIC_KEY_LEN); rmp->interface.sw_if_index = htonl (wgi->sw_if_index); rmp->interface.port = htons (wgi->port); ip_address_encode2 (&wgi->src_ip, &rmp->interface.src_ip); diff --git a/src/plugins/wireguard/wireguard_handoff.c b/src/plugins/wireguard/wireguard_handoff.c new file mode 100644 index 00000000000..b0b74229452 --- /dev/null +++ b/src/plugins/wireguard/wireguard_handoff.c @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2020 Doc.ai and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <wireguard/wireguard.h> +#include <wireguard/wireguard_peer.h> + +#define foreach_wg_handoff_error \ +_(CONGESTION_DROP, "congestion drop") + +typedef enum +{ +#define _(sym,str) WG_HANDOFF_ERROR_##sym, + foreach_wg_handoff_error +#undef _ + HANDOFF_N_ERROR, +} ipsec_handoff_error_t; + +static char *wg_handoff_error_strings[] = { +#define _(sym,string) string, + foreach_wg_handoff_error +#undef _ +}; + +typedef enum +{ + WG_HANDOFF_HANDSHAKE, + WG_HANDOFF_INP_DATA, + WG_HANDOFF_OUT_TUN, +} wg_handoff_mode_t; + +typedef struct wg_handoff_trace_t_ +{ + u32 next_worker_index; + index_t peer; +} wg_handoff_trace_t; + +static u8 * +format_wg_handoff_trace (u8 * s, va_list * args) +{ + CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *); + CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *); + wg_handoff_trace_t *t = va_arg (*args, wg_handoff_trace_t *); + + s = format (s, "next-worker %d peer %d", t->next_worker_index, t->peer); + + return s; +} + +static_always_inline uword +wg_handoff (vlib_main_t * vm, + vlib_node_runtime_t * node, + vlib_frame_t * frame, u32 fq_index, wg_handoff_mode_t mode) +{ + vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b; + u16 thread_indices[VLIB_FRAME_SIZE], *ti; + u32 n_enq, n_left_from, *from; + wg_main_t *wmp; + + wmp = &wg_main; + from = vlib_frame_vector_args (frame); + n_left_from = frame->n_vectors; + vlib_get_buffers (vm, from, bufs, n_left_from); + + b = bufs; + ti = thread_indices; + + while (n_left_from > 0) + { + const wg_peer_t *peer; + index_t peeri; + + if (PREDICT_FALSE (mode == WG_HANDOFF_HANDSHAKE)) + { + ti[0] = 0; + } + else if (mode == WG_HANDOFF_INP_DATA) + { + message_data_t *data = vlib_buffer_get_current (b[0]); + u32 *entry = + wg_index_table_lookup (&wmp->index_table, data->receiver_index); + peeri = *entry; + peer = wg_peer_get (peeri); + + ti[0] = peer->input_thread_index; + } + else + { + peeri = + wg_peer_get_by_adj_index (vnet_buffer (b[0])-> + ip.adj_index[VLIB_TX]); + peer = wg_peer_get (peeri); + ti[0] = peer->output_thread_index; + } + + if (PREDICT_FALSE (b[0]->flags & VLIB_BUFFER_IS_TRACED)) + { + wg_handoff_trace_t *t = + vlib_add_trace (vm, node, b[0], sizeof (*t)); + t->next_worker_index = ti[0]; + t->peer = peeri; + } + + n_left_from -= 1; + ti += 1; + b += 1; + } + + n_enq = vlib_buffer_enqueue_to_thread (vm, fq_index, from, + thread_indices, frame->n_vectors, 1); + + if (n_enq < frame->n_vectors) + vlib_node_increment_counter (vm, node->node_index, + WG_HANDOFF_ERROR_CONGESTION_DROP, + frame->n_vectors - n_enq); + + return n_enq; +} + +VLIB_NODE_FN (wg_handshake_handoff) (vlib_main_t * vm, + vlib_node_runtime_t * node, + vlib_frame_t * from_frame) +{ + wg_main_t *wmp = &wg_main; + + return wg_handoff (vm, node, from_frame, wmp->in_fq_index, + WG_HANDOFF_HANDSHAKE); +} + +VLIB_NODE_FN (wg_input_data_handoff) (vlib_main_t * vm, + vlib_node_runtime_t * node, + vlib_frame_t * from_frame) +{ + wg_main_t *wmp = &wg_main; + + return wg_handoff (vm, node, from_frame, wmp->in_fq_index, + WG_HANDOFF_INP_DATA); +} + +VLIB_NODE_FN (wg_output_tun_handoff) (vlib_main_t * vm, + vlib_node_runtime_t * node, + vlib_frame_t * from_frame) +{ + wg_main_t *wmp = &wg_main; + + return wg_handoff (vm, node, from_frame, wmp->out_fq_index, + WG_HANDOFF_OUT_TUN); +} + +VLIB_REGISTER_NODE (wg_handshake_handoff) = +{ + .name = "wg-handshake-handoff",.vector_size = sizeof (u32),.format_trace = + format_wg_handoff_trace,.type = VLIB_NODE_TYPE_INTERNAL,.n_errors = + ARRAY_LEN (wg_handoff_error_strings),.error_strings = + wg_handoff_error_strings,.n_next_nodes = 1,.next_nodes = + { + [0] = "error-drop",} +,}; + +VLIB_REGISTER_NODE (wg_input_data_handoff) = +{ + .name = "wg-input-data-handoff",.vector_size = sizeof (u32),.format_trace = + format_wg_handoff_trace,.type = VLIB_NODE_TYPE_INTERNAL,.n_errors = + ARRAY_LEN (wg_handoff_error_strings),.error_strings = + wg_handoff_error_strings,.n_next_nodes = 1,.next_nodes = + { + [0] = "error-drop",} +,}; + +VLIB_REGISTER_NODE (wg_output_tun_handoff) = +{ + .name = "wg-output-tun-handoff",.vector_size = sizeof (u32),.format_trace = + format_wg_handoff_trace,.type = VLIB_NODE_TYPE_INTERNAL,.n_errors = + ARRAY_LEN (wg_handoff_error_strings),.error_strings = + wg_handoff_error_strings,.n_next_nodes = 1,.next_nodes = + { + [0] = "error-drop",} +,}; + +/* + * fd.io coding-style-patch-verification: ON + * + * Local Variables: + * eval: (c-set-style "gnu") + * End: + */ diff --git a/src/plugins/wireguard/wireguard_if.c b/src/plugins/wireguard/wireguard_if.c index c91667bb234..7509923a1bf 100644 --- a/src/plugins/wireguard/wireguard_if.c +++ b/src/plugins/wireguard/wireguard_if.c @@ -5,6 +5,7 @@ #include <wireguard/wireguard_messages.h> #include <wireguard/wireguard_if.h> #include <wireguard/wireguard.h> +#include <wireguard/wireguard_peer.h> /* pool of interfaces */ wg_if_t *wg_if_pool; @@ -30,28 +31,28 @@ format_wg_if (u8 * s, va_list * args) { index_t wgii = va_arg (*args, u32); wg_if_t *wgi = wg_if_get (wgii); + noise_local_t *local = noise_local_get (wgi->local_idx); u8 key[NOISE_KEY_LEN_BASE64]; - key_to_base64 (wgi->local.l_private, NOISE_PUBLIC_KEY_LEN, key); s = format (s, "[%d] %U src:%U port:%d", wgii, format_vnet_sw_if_index_name, vnet_get_main (), wgi->sw_if_index, format_ip_address, &wgi->src_ip, wgi->port); - key_to_base64 (wgi->local.l_private, NOISE_PUBLIC_KEY_LEN, key); + key_to_base64 (local->l_private, NOISE_PUBLIC_KEY_LEN, key); s = format (s, " private-key:%s", key); s = - format (s, " %U", format_hex_bytes, wgi->local.l_private, + format (s, " %U", format_hex_bytes, local->l_private, NOISE_PUBLIC_KEY_LEN); - key_to_base64 (wgi->local.l_public, NOISE_PUBLIC_KEY_LEN, key); + key_to_base64 (local->l_public, NOISE_PUBLIC_KEY_LEN, key); s = format (s, " public-key:%s", key); s = - format (s, " %U", format_hex_bytes, wgi->local.l_public, + format (s, " %U", format_hex_bytes, local->l_public, NOISE_PUBLIC_KEY_LEN); s = format (s, " mac-key: %U", format_hex_bytes, @@ -72,23 +73,28 @@ wg_if_find_by_sw_if_index (u32 sw_if_index) return (ti); } +static walk_rc_t +wg_if_find_peer_by_public_key (index_t peeri, void *data) +{ + uint8_t *public = data; + wg_peer_t *peer = wg_peer_get (peeri); + + if (!memcmp (peer->remote.r_public, public, NOISE_PUBLIC_KEY_LEN)) + return (WALK_STOP); + return (WALK_CONTINUE); +} + static noise_remote_t * -wg_remote_get (uint8_t public[NOISE_PUBLIC_KEY_LEN]) +wg_remote_get (const uint8_t public[NOISE_PUBLIC_KEY_LEN]) { - wg_main_t *wmp = &wg_main; - wg_peer_t *peer = NULL; - wg_peer_t *peer_iter; - /* *INDENT-OFF* */ - pool_foreach (peer_iter, wmp->peers, - ({ - if (!memcmp (peer_iter->remote.r_public, public, NOISE_PUBLIC_KEY_LEN)) - { - peer = peer_iter; - break; - } - })); - /* *INDENT-ON* */ - return peer ? &peer->remote : NULL; + index_t peeri; + + peeri = wg_peer_walk (wg_if_find_peer_by_public_key, (void *) public); + + if (INDEX_INVALID != peeri) + return &wg_peer_get (peeri)->remote; + + return NULL; } static uint32_t @@ -223,6 +229,7 @@ wg_if_create (u32 user_instance, u32 instance, hw_if_index; vnet_hw_interface_t *hi; wg_if_t *wg_if; + noise_local_t *local; ASSERT (sw_if_indexp); @@ -236,6 +243,24 @@ wg_if_create (u32 user_instance, if (instance == ~0) return VNET_API_ERROR_INVALID_REGISTRATION; + /* *INDENT-OFF* */ + struct noise_upcall upcall = { + .u_remote_get = wg_remote_get, + .u_index_set = wg_index_set, + .u_index_drop = wg_index_drop, + }; + /* *INDENT-ON* */ + + pool_get (noise_local_pool, local); + + noise_local_init (local, &upcall); + if (!noise_local_set_private (local, private_key)) + { + pool_put (noise_local_pool, local); + wg_if_instance_free (instance); + return VNET_API_ERROR_INVALID_REGISTRATION; + } + pool_get (wg_if_pool, wg_if); /* tunnel index (or instance) */ @@ -251,18 +276,8 @@ wg_if_create (u32 user_instance, wg_if_index_by_port[port] = wg_if - wg_if_pool; wg_if->port = port; - - /* *INDENT-OFF* */ - struct noise_upcall upcall = { - .u_remote_get = wg_remote_get, - .u_index_set = wg_index_set, - .u_index_drop = wg_index_drop, - }; - /* *INDENT-ON* */ - - noise_local_init (&wg_if->local, &upcall); - noise_local_set_private (&wg_if->local, private_key); - cookie_checker_update (&wg_if->cookie_checker, wg_if->local.l_public); + wg_if->local_idx = local - noise_local_pool; + cookie_checker_update (&wg_if->cookie_checker, local->l_public); hw_if_index = vnet_register_interface (vnm, wg_if_device_class.index, @@ -304,6 +319,7 @@ wg_if_delete (u32 sw_if_index) udp_unregister_dst_port (vlib_get_main (), wg_if->port, 1); wg_if_index_by_port[wg_if->port] = INDEX_INVALID; vnet_delete_hw_interface (vnm, hw->hw_if_index); + pool_put_index (noise_local_pool, wg_if->local_idx); pool_put (wg_if_pool, wg_if); return 0; @@ -343,7 +359,7 @@ wg_if_walk (wg_if_walk_cb_t fn, void *data) /* *INDENT-ON* */ } -void +index_t wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data) { index_t peeri, val; @@ -352,9 +368,11 @@ wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data) hash_foreach (peeri, val, wgi->peers, { if (WALK_STOP == fn(wgi, peeri, data)) - break; + return peeri; }); /* *INDENT-ON* */ + + return INDEX_INVALID; } diff --git a/src/plugins/wireguard/wireguard_if.h b/src/plugins/wireguard/wireguard_if.h index 9e6b6190e0e..d8c2a87dc71 100644 --- a/src/plugins/wireguard/wireguard_if.h +++ b/src/plugins/wireguard/wireguard_if.h @@ -25,7 +25,8 @@ typedef struct wg_if_t_ u32 sw_if_index; // Interface params - noise_local_t local; + /* noise_local_pool elt index */ + u32 local_idx; cookie_checker_t cookie_checker; u16 port; @@ -52,7 +53,7 @@ void wg_if_walk (wg_if_walk_cb_t fn, void *data); typedef walk_rc_t (*wg_if_peer_walk_cb_t) (wg_if_t * wgi, index_t peeri, void *data); -void wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data); +index_t wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data); void wg_if_peer_add (wg_if_t * wgi, index_t peeri); void wg_if_peer_remove (wg_if_t * wgi, index_t peeri); diff --git a/src/plugins/wireguard/wireguard_input.c b/src/plugins/wireguard/wireguard_input.c index cdd65f87b51..b15c265cdac 100755 --- a/src/plugins/wireguard/wireguard_input.c +++ b/src/plugins/wireguard/wireguard_input.c @@ -30,6 +30,7 @@ _(DECRYPTION, "Failed during decryption") \ _(KEEPALIVE_SEND, "Failed while sending Keepalive") \ _(HANDSHAKE_SEND, "Failed while sending Handshake") \ + _(TOO_BIG, "Packet too big") \ _(UNDEFINED, "Undefined error") typedef enum @@ -51,7 +52,7 @@ typedef struct message_type_t type; u16 current_length; bool is_keepalive; - + index_t peer; } wg_input_trace_t; u8 * @@ -79,6 +80,7 @@ format_wg_input_trace (u8 * s, va_list * args) s = format (s, "WG input: \n"); s = format (s, " Type: %U\n", format_wg_message_type, t->type); + s = format (s, " peer: %d\n", t->peer); s = format (s, " Length: %d\n", t->current_length); s = format (s, " Keepalive: %s", t->is_keepalive ? "true" : "false"); @@ -87,6 +89,8 @@ format_wg_input_trace (u8 * s, va_list * args) typedef enum { + WG_INPUT_NEXT_HANDOFF_HANDSHAKE, + WG_INPUT_NEXT_HANDOFF_DATA, WG_INPUT_NEXT_IP4_INPUT, WG_INPUT_NEXT_PUNT, WG_INPUT_NEXT_ERROR, @@ -106,6 +110,8 @@ typedef enum static wg_input_error_t wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b) { + ASSERT (vm->thread_index == 0); + enum cookie_mac_state mac_state; bool packet_needs_cookie; bool under_load; @@ -129,17 +135,15 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b) if (NULL == wg_if) return WG_INPUT_ERROR_INTERFACE; - if (header->type == MESSAGE_HANDSHAKE_COOKIE) + if (PREDICT_FALSE (header->type == MESSAGE_HANDSHAKE_COOKIE)) { message_handshake_cookie_t *packet = (message_handshake_cookie_t *) current_b_data; u32 *entry = wg_index_table_lookup (&wmp->index_table, packet->receiver_index); if (entry) - { - peer = pool_elt_at_index (wmp->peers, *entry); - } - if (!peer) + peer = wg_peer_get (*entry); + else return WG_INPUT_ERROR_PEER; // TODO: Implement cookie_maker_consume_payload @@ -178,17 +182,17 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b) // TODO: Add processing } noise_remote_t *rp; - if (noise_consume_initiation - (wmp->vlib_main, &wg_if->local, &rp, message->sender_index, - message->unencrypted_ephemeral, message->encrypted_static, - message->encrypted_timestamp)) + (vm, noise_local_get (wg_if->local_idx), &rp, + message->sender_index, message->unencrypted_ephemeral, + message->encrypted_static, message->encrypted_timestamp)) { - peer = pool_elt_at_index (wmp->peers, rp->r_peer_idx); + peer = wg_peer_get (rp->r_peer_idx); + } + else + { + return WG_INPUT_ERROR_PEER; } - - if (!peer) - return WG_INPUT_ERROR_PEER; // set_peer_address (peer, ip4_src, udp_src_port); if (PREDICT_FALSE (!wg_send_handshake_response (vm, peer))) @@ -203,15 +207,18 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b) message_handshake_response_t *resp = current_b_data; u32 *entry = wg_index_table_lookup (&wmp->index_table, resp->receiver_index); - if (entry) + + if (PREDICT_TRUE (entry != NULL)) { - peer = pool_elt_at_index (wmp->peers, *entry); - if (!peer || peer->is_dead) + peer = wg_peer_get (*entry); + if (peer->is_dead) return WG_INPUT_ERROR_PEER; } + else + return WG_INPUT_ERROR_PEER; if (!noise_consume_response - (wmp->vlib_main, &peer->remote, resp->sender_index, + (vm, &peer->remote, resp->sender_index, resp->receiver_index, resp->unencrypted_ephemeral, resp->encrypted_nothing)) { @@ -223,8 +230,9 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b) } // set_peer_address (peer, ip4_src, udp_src_port); - if (noise_remote_begin_session (wmp->vlib_main, &peer->remote)) + if (noise_remote_begin_session (vm, &peer->remote)) { + wg_timers_session_derived (peer); wg_timers_handshake_complete (peer); if (PREDICT_FALSE (!wg_send_keepalive (vm, peer))) @@ -272,6 +280,7 @@ VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm, u32 *from; vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b; u16 nexts[VLIB_FRAME_SIZE], *next; + u32 thread_index = vm->thread_index; from = vlib_frame_vector_args (frame); n_left_from = frame->n_vectors; @@ -289,120 +298,132 @@ VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm, next[0] = WG_INPUT_NEXT_PUNT; header_type = ((message_header_t *) vlib_buffer_get_current (b[0]))->type; + u32 *peer_idx; - switch (header_type) + if (PREDICT_TRUE (header_type == MESSAGE_DATA)) { - case MESSAGE_HANDSHAKE_INITIATION: - case MESSAGE_HANDSHAKE_RESPONSE: - case MESSAGE_HANDSHAKE_COOKIE: - { - wg_input_error_t ret = wg_handshake_process (vm, wmp, b[0]); - if (ret != WG_INPUT_ERROR_NONE) - { - next[0] = WG_INPUT_NEXT_ERROR; - b[0]->error = node->errors[ret]; - } - break; - } - case MESSAGE_DATA: - { - message_data_t *data = vlib_buffer_get_current (b[0]); - u32 *entry = - wg_index_table_lookup (&wmp->index_table, data->receiver_index); - - if (entry) - { - peer = pool_elt_at_index (wmp->peers, *entry); - } - else - { - next[0] = WG_INPUT_NEXT_ERROR; - b[0]->error = node->errors[WG_INPUT_ERROR_PEER]; - goto out; - } + message_data_t *data = vlib_buffer_get_current (b[0]); - u16 encr_len = b[0]->current_length - sizeof (message_data_t); - u16 decr_len = encr_len - NOISE_AUTHTAG_LEN; - u8 *decr_data = clib_mem_alloc (decr_len); + peer_idx = wg_index_table_lookup (&wmp->index_table, + data->receiver_index); - enum noise_state_crypt state_cr = - noise_remote_decrypt (wmp->vlib_main, - &peer->remote, - data->receiver_index, - data->counter, - data->encrypted_data, - encr_len, - decr_data); + if (peer_idx) + { + peer = wg_peer_get (*peer_idx); + } + else + { + next[0] = WG_INPUT_NEXT_ERROR; + b[0]->error = node->errors[WG_INPUT_ERROR_PEER]; + goto out; + } - switch (state_cr) - { - case SC_OK: - break; - case SC_CONN_RESET: - wg_timers_handshake_complete (peer); - break; - case SC_KEEP_KEY_FRESH: - if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false))) - { - vlib_node_increment_counter (vm, wg_input_node.index, - WG_INPUT_ERROR_HANDSHAKE_SEND, - 1); - } - break; - case SC_FAILED: - next[0] = WG_INPUT_NEXT_ERROR; - b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION]; - goto out; - default: - break; - } + if (PREDICT_FALSE (~0 == peer->input_thread_index)) + { + /* this is the first packet to use this peer, claim the peer + * for this thread. + */ + clib_atomic_cmp_and_swap (&peer->input_thread_index, ~0, + wg_peer_assign_thread (thread_index)); + } - clib_memcpy (vlib_buffer_get_current (b[0]), decr_data, decr_len); - b[0]->current_length = decr_len; - b[0]->flags &= ~VNET_BUFFER_F_OFFLOAD_UDP_CKSUM; + if (PREDICT_TRUE (thread_index != peer->input_thread_index)) + { + next[0] = WG_INPUT_NEXT_HANDOFF_DATA; + goto next; + } - clib_mem_free (decr_data); + u16 encr_len = b[0]->current_length - sizeof (message_data_t); + u16 decr_len = encr_len - NOISE_AUTHTAG_LEN; + if (PREDICT_FALSE (decr_len >= WG_DEFAULT_DATA_SIZE)) + { + b[0]->error = node->errors[WG_INPUT_ERROR_TOO_BIG]; + goto out; + } - wg_timers_any_authenticated_packet_received (peer); - wg_timers_any_authenticated_packet_traversal (peer); + u8 *decr_data = wmp->per_thread_data[thread_index].data; - if (decr_len == 0) - { - is_keepalive = true; - goto out; - } + enum noise_state_crypt state_cr = noise_remote_decrypt (vm, + &peer->remote, + data->receiver_index, + data->counter, + data->encrypted_data, + encr_len, + decr_data); - wg_timers_data_received (peer); + if (PREDICT_FALSE (state_cr == SC_CONN_RESET)) + { + wg_timers_handshake_complete (peer); + } + else if (PREDICT_FALSE (state_cr == SC_KEEP_KEY_FRESH)) + { + wg_send_handshake_from_mt (*peer_idx, false); + } + else if (PREDICT_FALSE (state_cr == SC_FAILED)) + { + next[0] = WG_INPUT_NEXT_ERROR; + b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION]; + goto out; + } - ip4_header_t *iph = vlib_buffer_get_current (b[0]); + clib_memcpy (vlib_buffer_get_current (b[0]), decr_data, decr_len); + b[0]->current_length = decr_len; + b[0]->flags &= ~VNET_BUFFER_F_OFFLOAD_UDP_CKSUM; - const wg_peer_allowed_ip_t *allowed_ip; - bool allowed = false; + wg_timers_any_authenticated_packet_received (peer); + wg_timers_any_authenticated_packet_traversal (peer); - /* - * we could make this into an ACL, but the expectation - * is that there aren't many allowed IPs and thus a linear - * walk is fater than an ACL - */ - vec_foreach (allowed_ip, peer->allowed_ips) + /* Keepalive packet has zero length */ + if (decr_len == 0) { - if (fib_prefix_is_cover_addr_4 (&allowed_ip->prefix, - &iph->src_address)) - { - allowed = true; - break; - } + is_keepalive = true; + goto out; } - if (allowed) + + wg_timers_data_received (peer); + + ip4_header_t *iph = vlib_buffer_get_current (b[0]); + + const wg_peer_allowed_ip_t *allowed_ip; + bool allowed = false; + + /* + * we could make this into an ACL, but the expectation + * is that there aren't many allowed IPs and thus a linear + * walk is fater than an ACL + */ + vec_foreach (allowed_ip, peer->allowed_ips) + { + if (fib_prefix_is_cover_addr_4 (&allowed_ip->prefix, + &iph->src_address)) { - vnet_buffer (b[0])->sw_if_index[VLIB_RX] = - peer->wg_sw_if_index; - next[0] = WG_INPUT_NEXT_IP4_INPUT; + allowed = true; + break; } - break; } - default: - break; + if (allowed) + { + vnet_buffer (b[0])->sw_if_index[VLIB_RX] = peer->wg_sw_if_index; + next[0] = WG_INPUT_NEXT_IP4_INPUT; + } + } + else + { + peer_idx = NULL; + + /* Handshake packets should be processed in main thread */ + if (thread_index != 0) + { + next[0] = WG_INPUT_NEXT_HANDOFF_HANDSHAKE; + goto next; + } + + wg_input_error_t ret = wg_handshake_process (vm, wmp, b[0]); + if (ret != WG_INPUT_ERROR_NONE) + { + next[0] = WG_INPUT_NEXT_ERROR; + b[0]->error = node->errors[ret]; + } } out: @@ -413,7 +434,9 @@ VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm, t->type = header_type; t->current_length = b[0]->current_length; t->is_keepalive = is_keepalive; + t->peer = peer_idx ? *peer_idx : INDEX_INVALID; } + next: n_left_from -= 1; next += 1; b += 1; @@ -435,6 +458,8 @@ VLIB_REGISTER_NODE (wg_input_node) = .n_next_nodes = WG_INPUT_N_NEXT, /* edit / add dispositions here */ .next_nodes = { + [WG_INPUT_NEXT_HANDOFF_HANDSHAKE] = "wg-handshake-handoff", + [WG_INPUT_NEXT_HANDOFF_DATA] = "wg-input-data-handoff", [WG_INPUT_NEXT_IP4_INPUT] = "ip4-input-no-checksum", [WG_INPUT_NEXT_PUNT] = "error-punt", [WG_INPUT_NEXT_ERROR] = "error-drop", diff --git a/src/plugins/wireguard/wireguard_noise.c b/src/plugins/wireguard/wireguard_noise.c index b47bb5747b9..00b67109de4 100755 --- a/src/plugins/wireguard/wireguard_noise.c +++ b/src/plugins/wireguard/wireguard_noise.c @@ -26,6 +26,8 @@ * <- e, ee, se, psk, {} */ +noise_local_t *noise_local_pool; + /* Private functions */ static noise_keypair_t *noise_remote_keypair_allocate (noise_remote_t *); static void noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t *, @@ -80,81 +82,31 @@ noise_local_set_private (noise_local_t * l, const uint8_t private[NOISE_PUBLIC_KEY_LEN]) { clib_memcpy (l->l_private, private, NOISE_PUBLIC_KEY_LEN); - l->l_has_identity = curve25519_gen_public (l->l_public, private); - - return l->l_has_identity; -} -bool -noise_local_keys (noise_local_t * l, uint8_t public[NOISE_PUBLIC_KEY_LEN], - uint8_t private[NOISE_PUBLIC_KEY_LEN]) -{ - if (l->l_has_identity) - { - if (public != NULL) - clib_memcpy (public, l->l_public, NOISE_PUBLIC_KEY_LEN); - if (private != NULL) - clib_memcpy (private, l->l_private, NOISE_PUBLIC_KEY_LEN); - } - else - { - return false; - } - return true; + return curve25519_gen_public (l->l_public, private); } void noise_remote_init (noise_remote_t * r, uint32_t peer_pool_idx, const uint8_t public[NOISE_PUBLIC_KEY_LEN], - noise_local_t * l) + u32 noise_local_idx) { clib_memset (r, 0, sizeof (*r)); clib_memcpy (r->r_public, public, NOISE_PUBLIC_KEY_LEN); + clib_rwlock_init (&r->r_keypair_lock); r->r_peer_idx = peer_pool_idx; - - ASSERT (l != NULL); - r->r_local = l; + r->r_local_idx = noise_local_idx; r->r_handshake.hs_state = HS_ZEROED; - noise_remote_precompute (r); -} - -bool -noise_remote_set_psk (noise_remote_t * r, - uint8_t psk[NOISE_SYMMETRIC_KEY_LEN]) -{ - int same; - same = !clib_memcmp (r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN); - if (!same) - { - clib_memcpy (r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN); - } - return same == 0; -} - -bool -noise_remote_keys (noise_remote_t * r, uint8_t public[NOISE_PUBLIC_KEY_LEN], - uint8_t psk[NOISE_SYMMETRIC_KEY_LEN]) -{ - static uint8_t null_psk[NOISE_SYMMETRIC_KEY_LEN]; - int ret; - if (public != NULL) - clib_memcpy (public, r->r_public, NOISE_PUBLIC_KEY_LEN); - - if (psk != NULL) - clib_memcpy (psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN); - ret = clib_memcmp (r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN); - - return ret; + noise_remote_precompute (r); } void noise_remote_precompute (noise_remote_t * r) { - noise_local_t *l = r->r_local; - if (!l->l_has_identity) - clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN); - else if (!curve25519_gen_shared (r->r_ss, l->l_private, r->r_public)) + noise_local_t *l = noise_local_get (r->r_local_idx); + + if (!curve25519_gen_shared (r->r_ss, l->l_private, r->r_public)) clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN); noise_remote_handshake_index_drop (r); @@ -169,7 +121,7 @@ noise_create_initiation (vlib_main_t * vm, noise_remote_t * r, uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]) { noise_handshake_t *hs = &r->r_handshake; - noise_local_t *l = r->r_local; + noise_local_t *l = noise_local_get (r->r_local_idx); uint8_t _key[NOISE_SYMMETRIC_KEY_LEN]; uint32_t key_idx; uint8_t *key; @@ -180,8 +132,6 @@ noise_create_initiation (vlib_main_t * vm, noise_remote_t * r, NOISE_SYMMETRIC_KEY_LEN); key = vnet_crypto_get_key (key_idx)->data; - if (!l->l_has_identity) - goto error; noise_param_init (hs->hs_ck, hs->hs_hash, r->r_public); /* e */ @@ -239,8 +189,6 @@ noise_consume_initiation (vlib_main_t * vm, noise_local_t * l, NOISE_SYMMETRIC_KEY_LEN); key = vnet_crypto_get_key (key_idx)->data; - if (!l->l_has_identity) - goto error; noise_param_init (hs.hs_ck, hs.hs_hash, l->l_public); /* e */ @@ -294,6 +242,7 @@ noise_consume_initiation (vlib_main_t * vm, noise_local_t * l, r->r_handshake = hs; *rp = r; ret = true; + error: vnet_crypto_key_del (vm, key_idx); secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN); @@ -359,7 +308,7 @@ noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx, uint32_t r_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t en[0 + NOISE_AUTHTAG_LEN]) { - noise_local_t *l = r->r_local; + noise_local_t *l = noise_local_get (r->r_local_idx); noise_handshake_t hs; uint8_t _key[NOISE_SYMMETRIC_KEY_LEN]; uint8_t preshared_key[NOISE_PUBLIC_KEY_LEN]; @@ -372,9 +321,6 @@ noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx, NOISE_SYMMETRIC_KEY_LEN); key = vnet_crypto_get_key (key_idx)->data; - if (!l->l_has_identity) - goto error; - hs = r->r_handshake; clib_memcpy (preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN); @@ -460,6 +406,7 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r) clib_memset (&kp.kp_ctr, 0, sizeof (kp.kp_ctr)); /* Now we need to add_new_keypair */ + clib_rwlock_writer_lock (&r->r_keypair_lock); next = r->r_next; current = r->r_current; previous = r->r_previous; @@ -491,7 +438,10 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r) r->r_next = noise_remote_keypair_allocate (r); *r->r_next = kp; } + clib_rwlock_writer_unlock (&r->r_keypair_lock); + secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake)); + secure_zero_memory (&kp, sizeof (kp)); return true; } @@ -502,21 +452,25 @@ noise_remote_clear (vlib_main_t * vm, noise_remote_t * r) noise_remote_handshake_index_drop (r); secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake)); + clib_rwlock_writer_lock (&r->r_keypair_lock); noise_remote_keypair_free (vm, r, &r->r_next); noise_remote_keypair_free (vm, r, &r->r_current); noise_remote_keypair_free (vm, r, &r->r_previous); r->r_next = NULL; r->r_current = NULL; r->r_previous = NULL; + clib_rwlock_writer_unlock (&r->r_keypair_lock); } void noise_remote_expire_current (noise_remote_t * r) { + clib_rwlock_writer_lock (&r->r_keypair_lock); if (r->r_next != NULL) r->r_next->kp_valid = 0; if (r->r_current != NULL) r->r_current->kp_valid = 0; + clib_rwlock_writer_unlock (&r->r_keypair_lock); } bool @@ -525,6 +479,7 @@ noise_remote_ready (noise_remote_t * r) noise_keypair_t *kp; int ret; + clib_rwlock_reader_lock (&r->r_keypair_lock); if ((kp = r->r_current) == NULL || !kp->kp_valid || wg_birthdate_has_expired (kp->kp_birthdate, REJECT_AFTER_TIME) || @@ -533,6 +488,7 @@ noise_remote_ready (noise_remote_t * r) ret = false; else ret = true; + clib_rwlock_reader_unlock (&r->r_keypair_lock); return ret; } @@ -592,6 +548,7 @@ noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx, noise_keypair_t *kp; enum noise_state_crypt ret = SC_FAILED; + clib_rwlock_reader_lock (&r->r_keypair_lock); if ((kp = r->r_current) == NULL) goto error; @@ -631,6 +588,7 @@ noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx, ret = SC_OK; error: + clib_rwlock_reader_unlock (&r->r_keypair_lock); return ret; } @@ -641,6 +599,7 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx, { noise_keypair_t *kp; enum noise_state_crypt ret = SC_FAILED; + clib_rwlock_reader_lock (&r->r_keypair_lock); if (r->r_current != NULL && r->r_current->kp_local_index == r_idx) { @@ -682,18 +641,26 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx, * next keypair into current. If we do slide the next keypair in, then * we skip the REKEY_AFTER_TIME_RECV check. This is safe to do as a * data packet can't confirm a session that we are an INITIATOR of. */ - if (kp == r->r_next && kp->kp_local_index == r_idx) + if (kp == r->r_next) { - noise_remote_keypair_free (vm, r, &r->r_previous); - r->r_previous = r->r_current; - r->r_current = r->r_next; - r->r_next = NULL; + clib_rwlock_reader_unlock (&r->r_keypair_lock); + clib_rwlock_writer_lock (&r->r_keypair_lock); + if (kp == r->r_next && kp->kp_local_index == r_idx) + { + noise_remote_keypair_free (vm, r, &r->r_previous); + r->r_previous = r->r_current; + r->r_current = r->r_next; + r->r_next = NULL; - ret = SC_CONN_RESET; - goto error; + ret = SC_CONN_RESET; + clib_rwlock_writer_unlock (&r->r_keypair_lock); + clib_rwlock_reader_lock (&r->r_keypair_lock); + goto error; + } + clib_rwlock_writer_unlock (&r->r_keypair_lock); + clib_rwlock_reader_lock (&r->r_keypair_lock); } - /* Similar to when we encrypt, we want to notify the caller when we * are approaching our tolerances. We notify if: * - we're the initiator and the current keypair is older than @@ -708,6 +675,7 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx, ret = SC_OK; error: + clib_rwlock_reader_unlock (&r->r_keypair_lock); return ret; } @@ -725,7 +693,8 @@ static void noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t * r, noise_keypair_t ** kp) { - struct noise_upcall *u = &r->r_local->l_upcall; + noise_local_t *local = noise_local_get (r->r_local_idx); + struct noise_upcall *u = &local->l_upcall; if (*kp) { u->u_index_drop ((*kp)->kp_local_index); @@ -738,7 +707,8 @@ noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t * r, static uint32_t noise_remote_handshake_index_get (noise_remote_t * r) { - struct noise_upcall *u = &r->r_local->l_upcall; + noise_local_t *local = noise_local_get (r->r_local_idx); + struct noise_upcall *u = &local->l_upcall; return u->u_index_set (r); } @@ -746,7 +716,8 @@ static void noise_remote_handshake_index_drop (noise_remote_t * r) { noise_handshake_t *hs = &r->r_handshake; - struct noise_upcall *u = &r->r_local->l_upcall; + noise_local_t *local = noise_local_get (r->r_local_idx); + struct noise_upcall *u = &local->l_upcall; if (hs->hs_state != HS_ZEROED) u->u_index_drop (hs->hs_local_index); } @@ -754,7 +725,8 @@ noise_remote_handshake_index_drop (noise_remote_t * r) static uint64_t noise_counter_send (noise_counter_t * ctr) { - uint64_t ret = ctr->c_send++; + uint64_t ret; + ret = ctr->c_send++; return ret; } @@ -765,7 +737,6 @@ noise_counter_recv (noise_counter_t * ctr, uint64_t recv) unsigned long bit; bool ret = false; - /* Check that the recv counter is valid */ if (ctr->c_recv >= REJECT_AFTER_MESSAGES || recv >= REJECT_AFTER_MESSAGES) goto error; diff --git a/src/plugins/wireguard/wireguard_noise.h b/src/plugins/wireguard/wireguard_noise.h index 1f6804c59ca..5b5a88fa250 100755 --- a/src/plugins/wireguard/wireguard_noise.h +++ b/src/plugins/wireguard/wireguard_noise.h @@ -100,7 +100,7 @@ typedef struct noise_remote { uint32_t r_peer_idx; uint8_t r_public[NOISE_PUBLIC_KEY_LEN]; - noise_local_t *r_local; + uint32_t r_local_idx; uint8_t r_ss[NOISE_PUBLIC_KEY_LEN]; noise_handshake_t r_handshake; @@ -108,37 +108,40 @@ typedef struct noise_remote uint8_t r_timestamp[NOISE_TIMESTAMP_LEN]; f64 r_last_init; + clib_rwlock_t r_keypair_lock; noise_keypair_t *r_next, *r_current, *r_previous; } noise_remote_t; typedef struct noise_local { - bool l_has_identity; uint8_t l_public[NOISE_PUBLIC_KEY_LEN]; uint8_t l_private[NOISE_PUBLIC_KEY_LEN]; struct noise_upcall { void *u_arg; - noise_remote_t *(*u_remote_get) (uint8_t[NOISE_PUBLIC_KEY_LEN]); + noise_remote_t *(*u_remote_get) (const uint8_t[NOISE_PUBLIC_KEY_LEN]); uint32_t (*u_index_set) (noise_remote_t *); void (*u_index_drop) (uint32_t); } l_upcall; } noise_local_t; +/* pool of noise_local */ +extern noise_local_t *noise_local_pool; + /* Set/Get noise parameters */ +static_always_inline noise_local_t * +noise_local_get (uint32_t locali) +{ + return (pool_elt_at_index (noise_local_pool, locali)); +} + void noise_local_init (noise_local_t *, struct noise_upcall *); bool noise_local_set_private (noise_local_t *, const uint8_t[NOISE_PUBLIC_KEY_LEN]); -bool noise_local_keys (noise_local_t *, uint8_t[NOISE_PUBLIC_KEY_LEN], - uint8_t[NOISE_PUBLIC_KEY_LEN]); void noise_remote_init (noise_remote_t *, uint32_t, - const uint8_t[NOISE_PUBLIC_KEY_LEN], noise_local_t *); -bool noise_remote_set_psk (noise_remote_t *, - uint8_t[NOISE_SYMMETRIC_KEY_LEN]); -bool noise_remote_keys (noise_remote_t *, uint8_t[NOISE_PUBLIC_KEY_LEN], - uint8_t[NOISE_SYMMETRIC_KEY_LEN]); + const uint8_t[NOISE_PUBLIC_KEY_LEN], uint32_t); /* Should be called anytime noise_local_set_private is called */ void noise_remote_precompute (noise_remote_t *); diff --git a/src/plugins/wireguard/wireguard_output_tun.c b/src/plugins/wireguard/wireguard_output_tun.c index cdfd9d730f6..9a8710b77db 100755 --- a/src/plugins/wireguard/wireguard_output_tun.c +++ b/src/plugins/wireguard/wireguard_output_tun.c @@ -15,10 +15,6 @@ #include <vlib/vlib.h> #include <vnet/vnet.h> -#include <vnet/pg/pg.h> -#include <vnet/fib/ip6_fib.h> -#include <vnet/fib/ip4_fib.h> -#include <vnet/fib/fib_entry.h> #include <vppinfra/error.h> #include <wireguard/wireguard.h> @@ -28,19 +24,8 @@ _(NONE, "No error") \ _(PEER, "Peer error") \ _(KEYPAIR, "Keypair error") \ - _(HANDSHAKE_SEND, "Handshake sending failed") \ _(TOO_BIG, "packet too big") \ -#define WG_OUTPUT_SCRATCH_SIZE 2048 - -typedef struct wg_output_scratch_t_ -{ - u8 scratch[WG_OUTPUT_SCRATCH_SIZE]; -} wg_output_scratch_t; - -/* Cache line aligned per-thread scratch space */ -static wg_output_scratch_t *wg_output_scratchs; - typedef enum { #define _(sym,str) WG_OUTPUT_ERROR_##sym, @@ -58,6 +43,7 @@ static char *wg_output_error_strings[] = { typedef enum { WG_OUTPUT_NEXT_ERROR, + WG_OUTPUT_NEXT_HANDOFF, WG_OUTPUT_NEXT_INTERFACE_OUTPUT, WG_OUTPUT_N_NEXT, } wg_output_next_t; @@ -65,6 +51,7 @@ typedef enum typedef struct { ip4_udp_header_t hdr; + index_t peer; } wg_output_tun_trace_t; u8 * @@ -87,7 +74,8 @@ format_wg_output_tun_trace (u8 * s, va_list * args) wg_output_tun_trace_t *t = va_arg (*args, wg_output_tun_trace_t *); - s = format (s, "Encrypted packet: %U\n", format_ip4_udp_header, &t->hdr); + s = format (s, "peer: %d\n", t->peer); + s = format (s, " Encrypted packet: %U", format_ip4_udp_header, &t->hdr); return s; } @@ -109,7 +97,6 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm, vlib_get_buffers (vm, from, bufs, n_left_from); wg_main_t *wmp = &wg_main; - u32 handsh_fails = 0; wg_peer_t *peer = NULL; while (n_left_from > 0) @@ -119,11 +106,12 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm, sizeof (ip4_udp_header_t)); u16 plain_data_len = clib_net_to_host_u16 (((ip4_header_t *) plain_data)->length); + index_t peeri; next[0] = WG_OUTPUT_NEXT_ERROR; - - peer = + peeri = wg_peer_get_by_adj_index (vnet_buffer (b[0])->ip.adj_index[VLIB_TX]); + peer = wg_peer_get (peeri); if (!peer || peer->is_dead) { @@ -131,21 +119,34 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm, goto out; } + if (PREDICT_FALSE (~0 == peer->output_thread_index)) + { + /* this is the first packet to use this peer, claim the peer + * for this thread. + */ + clib_atomic_cmp_and_swap (&peer->output_thread_index, ~0, + wg_peer_assign_thread (thread_index)); + } + + if (PREDICT_TRUE (thread_index != peer->output_thread_index)) + { + next[0] = WG_OUTPUT_NEXT_HANDOFF; + goto next; + } + if (PREDICT_FALSE (!peer->remote.r_current)) { - if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false))) - handsh_fails++; + wg_send_handshake_from_mt (peeri, false); b[0]->error = node->errors[WG_OUTPUT_ERROR_KEYPAIR]; goto out; } - size_t encrypted_packet_len = message_data_len (plain_data_len); /* * Ensure there is enough space to write the encrypted data * into the packet */ - if (PREDICT_FALSE (encrypted_packet_len >= WG_OUTPUT_SCRATCH_SIZE) || + if (PREDICT_FALSE (encrypted_packet_len >= WG_DEFAULT_DATA_SIZE) || PREDICT_FALSE ((b[0]->current_data + encrypted_packet_len) >= vlib_buffer_get_default_data_size (vm))) { @@ -154,35 +155,29 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm, } message_data_t *encrypted_packet = - (message_data_t *) wg_output_scratchs[thread_index].scratch; + (message_data_t *) wmp->per_thread_data[thread_index].data; enum noise_state_crypt state; state = - noise_remote_encrypt (wmp->vlib_main, + noise_remote_encrypt (vm, &peer->remote, &encrypted_packet->receiver_index, &encrypted_packet->counter, plain_data, plain_data_len, encrypted_packet->encrypted_data); - switch (state) + + if (PREDICT_FALSE (state == SC_KEEP_KEY_FRESH)) + { + wg_send_handshake_from_mt (peeri, false); + } + else if (PREDICT_FALSE (state == SC_FAILED)) { - case SC_OK: - break; - case SC_KEEP_KEY_FRESH: - if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false))) - handsh_fails++; - break; - case SC_FAILED: //TODO: Maybe wrong - if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false))) - handsh_fails++; - clib_mem_free (encrypted_packet); + wg_send_handshake_from_mt (peeri, false); goto out; - default: - break; } - // Here we are sure that can send packet to next node. + /* Here we are sure that can send packet to next node */ next[0] = WG_OUTPUT_NEXT_INTERFACE_OUTPUT; encrypted_packet->header.type = MESSAGE_DATA; @@ -195,9 +190,9 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm, ip4_header_set_len_w_chksum (&hdr->ip4, clib_host_to_net_u16 (b[0]->current_length)); - wg_timers_any_authenticated_packet_traversal (peer); wg_timers_any_authenticated_packet_sent (peer); wg_timers_data_sent (peer); + wg_timers_any_authenticated_packet_traversal (peer); out: if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) @@ -206,17 +201,15 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm, wg_output_tun_trace_t *t = vlib_add_trace (vm, node, b[0], sizeof (*t)); t->hdr = *hdr; + t->peer = peeri; } + next: n_left_from -= 1; next += 1; b += 1; } vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors); - - vlib_node_increment_counter (vm, node->node_index, - WG_OUTPUT_ERROR_HANDSHAKE_SEND, handsh_fails); - return frame->n_vectors; } @@ -231,24 +224,13 @@ VLIB_REGISTER_NODE (wg_output_tun_node) = .error_strings = wg_output_error_strings, .n_next_nodes = WG_OUTPUT_N_NEXT, .next_nodes = { + [WG_OUTPUT_NEXT_HANDOFF] = "wg-output-tun-handoff", [WG_OUTPUT_NEXT_INTERFACE_OUTPUT] = "adj-midchain-tx", [WG_OUTPUT_NEXT_ERROR] = "error-drop", }, }; /* *INDENT-ON* */ -static clib_error_t * -wireguard_output_module_init (vlib_main_t * vm) -{ - vlib_thread_main_t *tm = vlib_get_thread_main (); - - vec_validate_aligned (wg_output_scratchs, tm->n_vlib_mains, - CLIB_CACHE_LINE_BYTES); - return (NULL); -} - -VLIB_INIT_FUNCTION (wireguard_output_module_init); - /* * fd.io coding-style-patch-verification: ON * diff --git a/src/plugins/wireguard/wireguard_peer.c b/src/plugins/wireguard/wireguard_peer.c index 30adea82647..b41118f83d1 100755 --- a/src/plugins/wireguard/wireguard_peer.c +++ b/src/plugins/wireguard/wireguard_peer.c @@ -23,15 +23,10 @@ #include <wireguard/wireguard.h> static fib_source_t wg_fib_source; +wg_peer_t *wg_peer_pool; index_t *wg_peer_by_adj_index; -wg_peer_t * -wg_peer_get (index_t peeri) -{ - return (pool_elt_at_index (wg_main.peers, peeri)); -} - static void wg_peer_endpoint_reset (wg_peer_endpoint_t * ep) { @@ -82,7 +77,11 @@ static void wg_peer_clear (vlib_main_t * vm, wg_peer_t * peer) { wg_timers_stop (peer); - noise_remote_clear (vm, &peer->remote); + for (int i = 0; i < WG_N_TIMERS; i++) + { + peer->timers[i] = ~0; + } + peer->last_sent_handshake = vlib_time_now (vm) - (REKEY_TIMEOUT + 1); clib_memset (&peer->cookie_maker, 0, sizeof (peer->cookie_maker)); @@ -97,9 +96,18 @@ wg_peer_clear (vlib_main_t * vm, wg_peer_t * peer) } wg_peer_fib_flush (peer); + peer->input_thread_index = ~0; + peer->output_thread_index = ~0; peer->adj_index = INDEX_INVALID; + peer->timer_wheel = 0; peer->persistent_keepalive_interval = 0; peer->timer_handshake_attempts = 0; + peer->last_sent_packet = 0; + peer->last_received_packet = 0; + peer->session_derived = 0; + peer->rehandshake_started = 0; + peer->new_handshake_interval_tick = 0; + peer->rehandshake_interval_tick = 0; peer->timer_need_another_keepalive = false; peer->is_dead = true; vec_free (peer->allowed_ips); @@ -108,7 +116,7 @@ wg_peer_clear (vlib_main_t * vm, wg_peer_t * peer) static void wg_peer_init (vlib_main_t * vm, wg_peer_t * peer) { - wg_timers_init (peer, vlib_time_now (vm)); + peer->adj_index = INDEX_INVALID; wg_peer_clear (vm, peer); } @@ -205,8 +213,9 @@ wg_peer_fill (vlib_main_t * vm, wg_peer_t * peer, wg_peer_endpoint_init (&peer->dst, dst, port); peer->table_id = table_id; - peer->persistent_keepalive_interval = persistent_keepalive_interval; peer->wg_sw_if_index = wg_sw_if_index; + peer->timer_wheel = &wg_main.timer_wheel; + peer->persistent_keepalive_interval = persistent_keepalive_interval; peer->last_sent_handshake = vlib_time_now (vm) - (REKEY_TIMEOUT + 1); peer->is_dead = false; @@ -230,7 +239,7 @@ wg_peer_fill (vlib_main_t * vm, wg_peer_t * peer, vec_validate_init_empty (wg_peer_by_adj_index, peer->adj_index, INDEX_INVALID); - wg_peer_by_adj_index[peer->adj_index] = peer - wg_main.peers; + wg_peer_by_adj_index[peer->adj_index] = peer - wg_peer_pool; adj_nbr_midchain_update_rewrite (peer->adj_index, NULL, @@ -280,7 +289,7 @@ wg_peer_add (u32 tun_sw_if_index, return (VNET_API_ERROR_INVALID_SW_IF_INDEX); /* *INDENT-OFF* */ - pool_foreach (peer, wg_main.peers, + pool_foreach (peer, wg_peer_pool, ({ if (!memcmp (peer->remote.r_public, public_key, NOISE_PUBLIC_KEY_LEN)) { @@ -289,10 +298,10 @@ wg_peer_add (u32 tun_sw_if_index, })); /* *INDENT-ON* */ - if (pool_elts (wg_main.peers) > MAX_PEERS) + if (pool_elts (wg_peer_pool) > MAX_PEERS) return (VNET_API_ERROR_LIMIT_EXCEEDED); - pool_get (wg_main.peers, peer); + pool_get (wg_peer_pool, peer); wg_peer_init (vm, peer); @@ -302,12 +311,12 @@ wg_peer_add (u32 tun_sw_if_index, if (rv) { wg_peer_clear (vm, peer); - pool_put (wg_main.peers, peer); + pool_put (wg_peer_pool, peer); return (rv); } - noise_remote_init (&peer->remote, peer - wg_main.peers, public_key, - &wg_if->local); + noise_remote_init (&peer->remote, peer - wg_peer_pool, public_key, + wg_if->local_idx); cookie_maker_init (&peer->cookie_maker, public_key); if (peer->persistent_keepalive_interval != 0) @@ -315,7 +324,7 @@ wg_peer_add (u32 tun_sw_if_index, wg_send_keepalive (vm, peer); } - *peer_index = peer - wg_main.peers; + *peer_index = peer - wg_peer_pool; wg_if_peer_add (wg_if, *peer_index); return (0); @@ -328,34 +337,37 @@ wg_peer_remove (index_t peeri) wg_peer_t *peer = NULL; wg_if_t *wgi; - if (pool_is_free_index (wmp->peers, peeri)) + if (pool_is_free_index (wg_peer_pool, peeri)) return VNET_API_ERROR_NO_SUCH_ENTRY; - peer = pool_elt_at_index (wmp->peers, peeri); + peer = pool_elt_at_index (wg_peer_pool, peeri); wgi = wg_if_get (wg_if_find_by_sw_if_index (peer->wg_sw_if_index)); wg_if_peer_remove (wgi, peeri); vnet_feature_enable_disable ("ip4-output", "wg-output-tun", peer->wg_sw_if_index, 0, 0, 0); + + noise_remote_clear (wmp->vlib_main, &peer->remote); wg_peer_clear (wmp->vlib_main, peer); - pool_put (wmp->peers, peer); + pool_put (wg_peer_pool, peer); return (0); } -void +index_t wg_peer_walk (wg_peer_walk_cb_t fn, void *data) { index_t peeri; /* *INDENT-OFF* */ - pool_foreach_index(peeri, wg_main.peers, + pool_foreach_index(peeri, wg_peer_pool, { if (WALK_STOP == fn(peeri, data)) - break; + return peeri; }); /* *INDENT-ON* */ + return INDEX_INVALID; } static u8 * diff --git a/src/plugins/wireguard/wireguard_peer.h b/src/plugins/wireguard/wireguard_peer.h index 99c73f3a0ed..009a6f67aeb 100755 --- a/src/plugins/wireguard/wireguard_peer.h +++ b/src/plugins/wireguard/wireguard_peer.h @@ -49,6 +49,9 @@ typedef struct wg_peer noise_remote_t remote; cookie_maker_t cookie_maker; + u32 input_thread_index; + u32 output_thread_index; + /* Peer addresses */ wg_peer_endpoint_t dst; wg_peer_endpoint_t src; @@ -65,11 +68,22 @@ typedef struct wg_peer u32 wg_sw_if_index; /* Timers */ - tw_timer_wheel_16t_2w_512sl_t timer_wheel; + tw_timer_wheel_16t_2w_512sl_t *timer_wheel; u32 timers[WG_N_TIMERS]; u32 timer_handshake_attempts; u16 persistent_keepalive_interval; + + /* Timestamps */ f64 last_sent_handshake; + f64 last_sent_packet; + f64 last_received_packet; + f64 session_derived; + f64 rehandshake_started; + + /* Variable intervals */ + u32 new_handshake_interval_tick; + u32 rehandshake_interval_tick; + bool timer_need_another_keepalive; bool is_dead; @@ -91,10 +105,9 @@ int wg_peer_add (u32 tun_sw_if_index, int wg_peer_remove (u32 peer_index); typedef walk_rc_t (*wg_peer_walk_cb_t) (index_t peeri, void *arg); -void wg_peer_walk (wg_peer_walk_cb_t fn, void *data); +index_t wg_peer_walk (wg_peer_walk_cb_t fn, void *data); u8 *format_wg_peer (u8 * s, va_list * va); -wg_peer_t *wg_peer_get (index_t peeri); walk_rc_t wg_peer_if_admin_state_change (wg_if_t * wgi, index_t peeri, void *data); @@ -104,11 +117,30 @@ walk_rc_t wg_peer_if_table_change (wg_if_t * wgi, index_t peeri, void *data); * Expoed for the data-plane */ extern index_t *wg_peer_by_adj_index; +extern wg_peer_t *wg_peer_pool; static inline wg_peer_t * +wg_peer_get (index_t peeri) +{ + return (pool_elt_at_index (wg_peer_pool, peeri)); +} + +static inline index_t wg_peer_get_by_adj_index (index_t ai) { - return wg_peer_get (wg_peer_by_adj_index[ai]); + return (wg_peer_by_adj_index[ai]); +} + +/* + * Makes choice for thread_id should be assigned. +*/ +static inline u32 +wg_peer_assign_thread (u32 thread_id) +{ + return ((thread_id) ? thread_id + : (vlib_num_workers ()? + ((unix_time_now_nsec () % vlib_num_workers ()) + + 1) : thread_id)); } #endif // __included_wg_peer_h__ diff --git a/src/plugins/wireguard/wireguard_send.c b/src/plugins/wireguard/wireguard_send.c index a5d8aaf6900..2e29a9b4b76 100755 --- a/src/plugins/wireguard/wireguard_send.c +++ b/src/plugins/wireguard/wireguard_send.c @@ -14,13 +14,11 @@ */ #include <vnet/vnet.h> -#include <vnet/fib/ip6_fib.h> -#include <vnet/fib/ip4_fib.h> -#include <vnet/fib/fib_entry.h> #include <vnet/ip/ip6_link.h> #include <vnet/pg/pg.h> #include <vnet/udp/udp.h> #include <vppinfra/error.h> +#include <vlibmemory/api.h> #include <wireguard/wireguard.h> #include <wireguard/wireguard_send.h> @@ -86,7 +84,8 @@ wg_create_buffer (vlib_main_t * vm, bool wg_send_handshake (vlib_main_t * vm, wg_peer_t * peer, bool is_retry) { - wg_main_t *wmp = &wg_main; + ASSERT (vm->thread_index == 0); + message_handshake_initiation_t packet; if (!is_retry) @@ -94,41 +93,73 @@ wg_send_handshake (vlib_main_t * vm, wg_peer_t * peer, bool is_retry) if (!wg_birthdate_has_expired (peer->last_sent_handshake, REKEY_TIMEOUT) || peer->is_dead) - { - return true; - } - if (noise_create_initiation (wmp->vlib_main, + return true; + + if (noise_create_initiation (vm, &peer->remote, &packet.sender_index, packet.unencrypted_ephemeral, packet.encrypted_static, packet.encrypted_timestamp)) { - f64 now = vlib_time_now (vm); packet.header.type = MESSAGE_HANDSHAKE_INITIATION; cookie_maker_mac (&peer->cookie_maker, &packet.macs, &packet, sizeof (packet)); - wg_timers_any_authenticated_packet_traversal (peer); wg_timers_any_authenticated_packet_sent (peer); - peer->last_sent_handshake = now; wg_timers_handshake_initiated (peer); + wg_timers_any_authenticated_packet_traversal (peer); + + peer->last_sent_handshake = vlib_time_now (vm); } else return false; + u32 bi0 = 0; if (!wg_create_buffer (vm, peer, (u8 *) & packet, sizeof (packet), &bi0)) return false; - ip46_enqueue_packet (vm, bi0, false); + ip46_enqueue_packet (vm, bi0, false); return true; } +typedef struct +{ + u32 peer_idx; + bool is_retry; +} wg_send_args_t; + +static void * +wg_send_handshake_thread_fn (void *arg) +{ + wg_send_args_t *a = arg; + + wg_main_t *wmp = &wg_main; + wg_peer_t *peer = wg_peer_get (a->peer_idx); + + wg_send_handshake (wmp->vlib_main, peer, a->is_retry); + return 0; +} + +void +wg_send_handshake_from_mt (u32 peer_idx, bool is_retry) +{ + wg_send_args_t a = { + .peer_idx = peer_idx, + .is_retry = is_retry, + }; + + vl_api_rpc_call_main_thread (wg_send_handshake_thread_fn, + (u8 *) & a, sizeof (a)); +} + bool wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer) { - wg_main_t *wmp = &wg_main; + ASSERT (vm->thread_index == 0); + u32 size_of_packet = message_data_len (0); - message_data_t *packet = clib_mem_alloc (size_of_packet); + message_data_t *packet = + (message_data_t *) wg_main.per_thread_data[vm->thread_index].data; u32 bi0 = 0; bool ret = true; enum noise_state_crypt state; @@ -140,23 +171,21 @@ wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer) } state = - noise_remote_encrypt (wmp->vlib_main, + noise_remote_encrypt (vm, &peer->remote, &packet->receiver_index, &packet->counter, NULL, 0, packet->encrypted_data); - switch (state) + + if (PREDICT_FALSE (state == SC_KEEP_KEY_FRESH)) { - case SC_OK: - break; - case SC_KEEP_KEY_FRESH: wg_send_handshake (vm, peer, false); - break; - case SC_FAILED: + } + else if (PREDICT_FALSE (state == SC_FAILED)) + { ret = false; goto out; - default: - break; } + packet->header.type = MESSAGE_DATA; if (!wg_create_buffer (vm, peer, (u8 *) packet, size_of_packet, &bi0)) @@ -166,22 +195,19 @@ wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer) } ip46_enqueue_packet (vm, bi0, false); - wg_timers_any_authenticated_packet_traversal (peer); + wg_timers_any_authenticated_packet_sent (peer); + wg_timers_any_authenticated_packet_traversal (peer); out: - clib_mem_free (packet); return ret; } bool wg_send_handshake_response (vlib_main_t * vm, wg_peer_t * peer) { - wg_main_t *wmp = &wg_main; message_handshake_response_t packet; - peer->last_sent_handshake = vlib_time_now (vm); - if (noise_create_response (vm, &peer->remote, &packet.sender_index, @@ -189,17 +215,16 @@ wg_send_handshake_response (vlib_main_t * vm, wg_peer_t * peer) packet.unencrypted_ephemeral, packet.encrypted_nothing)) { - f64 now = vlib_time_now (vm); packet.header.type = MESSAGE_HANDSHAKE_RESPONSE; cookie_maker_mac (&peer->cookie_maker, &packet.macs, &packet, sizeof (packet)); - if (noise_remote_begin_session (wmp->vlib_main, &peer->remote)) + if (noise_remote_begin_session (vm, &peer->remote)) { wg_timers_session_derived (peer); - wg_timers_any_authenticated_packet_traversal (peer); wg_timers_any_authenticated_packet_sent (peer); - peer->last_sent_handshake = now; + wg_timers_any_authenticated_packet_traversal (peer); + peer->last_sent_handshake = vlib_time_now (vm); u32 bi0 = 0; if (!wg_create_buffer (vm, peer, (u8 *) & packet, diff --git a/src/plugins/wireguard/wireguard_send.h b/src/plugins/wireguard/wireguard_send.h index 4ea1f6effea..efe41949428 100755 --- a/src/plugins/wireguard/wireguard_send.h +++ b/src/plugins/wireguard/wireguard_send.h @@ -20,6 +20,7 @@ bool wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer); bool wg_send_handshake (vlib_main_t * vm, wg_peer_t * peer, bool is_retry); +void wg_send_handshake_from_mt (u32 peer_index, bool is_retry); bool wg_send_handshake_response (vlib_main_t * vm, wg_peer_t * peer); always_inline void diff --git a/src/plugins/wireguard/wireguard_timer.c b/src/plugins/wireguard/wireguard_timer.c index e4d4030bb18..b7fd6891d14 100755 --- a/src/plugins/wireguard/wireguard_timer.c +++ b/src/plugins/wireguard/wireguard_timer.c @@ -13,6 +13,7 @@ * limitations under the License. */ +#include <vlibmemory/api.h> #include <wireguard/wireguard.h> #include <wireguard/wireguard_send.h> #include <wireguard/wireguard_timer.h> @@ -30,31 +31,77 @@ stop_timer (wg_peer_t * peer, u32 timer_id) { if (peer->timers[timer_id] != ~0) { - tw_timer_stop_16t_2w_512sl (&peer->timer_wheel, peer->timers[timer_id]); + tw_timer_stop_16t_2w_512sl (peer->timer_wheel, peer->timers[timer_id]); peer->timers[timer_id] = ~0; } } static void -start_or_update_timer (wg_peer_t * peer, u32 timer_id, u32 interval) +start_timer (wg_peer_t * peer, u32 timer_id, u32 interval_ticks) { + ASSERT (vlib_get_thread_index () == 0); + if (peer->timers[timer_id] == ~0) { - wg_main_t *wmp = &wg_main; peer->timers[timer_id] = - tw_timer_start_16t_2w_512sl (&peer->timer_wheel, peer - wmp->peers, - timer_id, interval); - } - else - { - tw_timer_update_16t_2w_512sl (&peer->timer_wheel, - peer->timers[timer_id], interval); + tw_timer_start_16t_2w_512sl (peer->timer_wheel, peer - wg_peer_pool, + timer_id, interval_ticks); } } +typedef struct +{ + u32 peer_idx; + u32 timer_id; + u32 interval_ticks; + +} wg_timers_args; + +static void * +start_timer_thread_fn (void *arg) +{ + wg_timers_args *a = arg; + wg_peer_t *peer = wg_peer_get (a->peer_idx); + + start_timer (peer, a->timer_id, a->interval_ticks); + return 0; +} + +static void +start_timer_from_mt (u32 peer_idx, u32 timer_id, u32 interval_ticks) +{ + wg_timers_args a = { + .peer_idx = peer_idx, + .timer_id = timer_id, + .interval_ticks = interval_ticks, + }; + + vl_api_rpc_call_main_thread (start_timer_thread_fn, (u8 *) & a, sizeof (a)); +} + +static inline u32 +timer_ticks_left (vlib_main_t * vm, f64 init_time_sec, u32 interval_ticks) +{ + static const int32_t rounding = (int32_t) (WHZ / 2); + int32_t ticks_remain; + + ticks_remain = (init_time_sec - vlib_time_now (vm)) * WHZ + interval_ticks; + return (ticks_remain > rounding) ? (u32) ticks_remain : 0; +} + static void wg_expired_retransmit_handshake (vlib_main_t * vm, wg_peer_t * peer) { + if (peer->rehandshake_started == ~0) + return; + + u32 ticks = timer_ticks_left (vm, peer->rehandshake_started, + peer->rehandshake_interval_tick); + if (ticks) + { + start_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE, ticks); + return; + } if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES) { @@ -63,17 +110,8 @@ wg_expired_retransmit_handshake (vlib_main_t * vm, wg_peer_t * peer) /* We set a timer for destroying any residue that might be left * of a partial exchange. */ + start_timer (peer, WG_TIMER_KEY_ZEROING, REJECT_AFTER_TIME * 3 * WHZ); - if (peer->timers[WG_TIMER_KEY_ZEROING] == ~0) - { - wg_main_t *wmp = &wg_main; - - peer->timers[WG_TIMER_KEY_ZEROING] = - tw_timer_start_16t_2w_512sl (&peer->timer_wheel, - peer - wmp->peers, - WG_TIMER_KEY_ZEROING, - REJECT_AFTER_TIME * 3 * WHZ); - } } else { @@ -85,13 +123,23 @@ wg_expired_retransmit_handshake (vlib_main_t * vm, wg_peer_t * peer) static void wg_expired_send_keepalive (vlib_main_t * vm, wg_peer_t * peer) { - wg_send_keepalive (vm, peer); - - if (peer->timer_need_another_keepalive) + if (peer->last_sent_packet < peer->last_received_packet) { - peer->timer_need_another_keepalive = false; - start_or_update_timer (peer, WG_TIMER_SEND_KEEPALIVE, - KEEPALIVE_TIMEOUT * WHZ); + u32 ticks = timer_ticks_left (vm, peer->last_received_packet, + KEEPALIVE_TIMEOUT * WHZ); + if (ticks) + { + start_timer (peer, WG_TIMER_SEND_KEEPALIVE, ticks); + return; + } + + wg_send_keepalive (vm, peer); + if (peer->timer_need_another_keepalive) + { + peer->timer_need_another_keepalive = false; + start_timer (peer, WG_TIMER_SEND_KEEPALIVE, + KEEPALIVE_TIMEOUT * WHZ); + } } } @@ -100,6 +148,18 @@ wg_expired_send_persistent_keepalive (vlib_main_t * vm, wg_peer_t * peer) { if (peer->persistent_keepalive_interval) { + f64 latest_time = peer->last_sent_packet > peer->last_received_packet + ? peer->last_sent_packet : peer->last_received_packet; + + u32 ticks = timer_ticks_left (vm, latest_time, + peer->persistent_keepalive_interval * + WHZ); + if (ticks) + { + start_timer (peer, WG_TIMER_PERSISTENT_KEEPALIVE, ticks); + return; + } + wg_send_keepalive (vm, peer); } } @@ -107,64 +167,81 @@ wg_expired_send_persistent_keepalive (vlib_main_t * vm, wg_peer_t * peer) static void wg_expired_new_handshake (vlib_main_t * vm, wg_peer_t * peer) { + u32 ticks = timer_ticks_left (vm, peer->last_sent_packet, + peer->new_handshake_interval_tick); + if (ticks) + { + start_timer (peer, WG_TIMER_NEW_HANDSHAKE, ticks); + return; + } + wg_send_handshake (vm, peer, false); } static void wg_expired_zero_key_material (vlib_main_t * vm, wg_peer_t * peer) { + u32 ticks = + timer_ticks_left (vm, peer->session_derived, REJECT_AFTER_TIME * 3 * WHZ); + if (ticks) + { + start_timer (peer, WG_TIMER_KEY_ZEROING, ticks); + return; + } + if (!peer->is_dead) { noise_remote_clear (vm, &peer->remote); } } - void wg_timers_any_authenticated_packet_traversal (wg_peer_t * peer) { if (peer->persistent_keepalive_interval) { - start_or_update_timer (peer, WG_TIMER_PERSISTENT_KEEPALIVE, - peer->persistent_keepalive_interval * WHZ); + start_timer_from_mt (peer - wg_peer_pool, + WG_TIMER_PERSISTENT_KEEPALIVE, + peer->persistent_keepalive_interval * WHZ); } } void wg_timers_any_authenticated_packet_sent (wg_peer_t * peer) { - stop_timer (peer, WG_TIMER_SEND_KEEPALIVE); + peer->last_sent_packet = vlib_time_now (vlib_get_main ()); } void wg_timers_handshake_initiated (wg_peer_t * peer) { - u32 interval = + peer->rehandshake_started = vlib_time_now (vlib_get_main ()); + peer->rehandshake_interval_tick = REKEY_TIMEOUT * WHZ + get_random_u32_max (REKEY_TIMEOUT_JITTER); - start_or_update_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE, interval); + + start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_RETRANSMIT_HANDSHAKE, + peer->rehandshake_interval_tick); } void wg_timers_session_derived (wg_peer_t * peer) { - start_or_update_timer (peer, WG_TIMER_KEY_ZEROING, - REJECT_AFTER_TIME * 3 * WHZ); + peer->session_derived = vlib_time_now (vlib_get_main ()); + + start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_KEY_ZEROING, + REJECT_AFTER_TIME * 3 * WHZ); } /* Should be called after an authenticated data packet is sent. */ void wg_timers_data_sent (wg_peer_t * peer) { - u32 interval = (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * WHZ + + peer->new_handshake_interval_tick = + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * WHZ + get_random_u32_max (REKEY_TIMEOUT_JITTER); - if (peer->timers[WG_TIMER_NEW_HANDSHAKE] == ~0) - { - wg_main_t *wmp = &wg_main; - peer->timers[WG_TIMER_NEW_HANDSHAKE] = - tw_timer_start_16t_2w_512sl (&peer->timer_wheel, peer - wmp->peers, - WG_TIMER_NEW_HANDSHAKE, interval); - } + start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_NEW_HANDSHAKE, + peer->new_handshake_interval_tick); } /* Should be called after an authenticated data packet is received. */ @@ -173,16 +250,11 @@ wg_timers_data_received (wg_peer_t * peer) { if (peer->timers[WG_TIMER_SEND_KEEPALIVE] == ~0) { - wg_main_t *wmp = &wg_main; - peer->timers[WG_TIMER_SEND_KEEPALIVE] = - tw_timer_start_16t_2w_512sl (&peer->timer_wheel, peer - wmp->peers, - WG_TIMER_SEND_KEEPALIVE, - KEEPALIVE_TIMEOUT * WHZ); + start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_SEND_KEEPALIVE, + KEEPALIVE_TIMEOUT * WHZ); } else - { - peer->timer_need_another_keepalive = true; - } + peer->timer_need_another_keepalive = true; } /* Should be called after a handshake response message is received and processed @@ -191,15 +263,14 @@ wg_timers_data_received (wg_peer_t * peer) void wg_timers_handshake_complete (wg_peer_t * peer) { - stop_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE); - + peer->rehandshake_started = ~0; peer->timer_handshake_attempts = 0; } void wg_timers_any_authenticated_packet_received (wg_peer_t * peer) { - stop_timer (peer, WG_TIMER_NEW_HANDSHAKE); + peer->last_received_packet = vlib_time_now (vlib_get_main ()); } static vlib_node_registration_t wg_timer_mngr_node; @@ -222,7 +293,7 @@ expired_timer_callback (u32 * expired_timers) pool_index = expired_timers[i] & 0x0FFFFFFF; timer_id = expired_timers[i] >> 28; - peer = pool_elt_at_index (wmp->peers, pool_index); + peer = wg_peer_get (pool_index); peer->timers[timer_id] = ~0; } @@ -231,7 +302,7 @@ expired_timer_callback (u32 * expired_timers) pool_index = expired_timers[i] & 0x0FFFFFFF; timer_id = expired_timers[i] >> 28; - peer = pool_elt_at_index (wmp->peers, pool_index); + peer = wg_peer_get (pool_index); switch (timer_id) { case WG_TIMER_RETRANSMIT_HANDSHAKE: @@ -256,18 +327,14 @@ expired_timer_callback (u32 * expired_timers) } void -wg_timers_init (wg_peer_t * peer, f64 now) +wg_timer_wheel_init () { - for (int i = 0; i < WG_N_TIMERS; i++) - { - peer->timers[i] = ~0; - } - tw_timer_wheel_16t_2w_512sl_t *tw = &peer->timer_wheel; + wg_main_t *wmp = &wg_main; + tw_timer_wheel_16t_2w_512sl_t *tw = &wmp->timer_wheel; tw_timer_wheel_init_16t_2w_512sl (tw, expired_timer_callback, WG_TICK /* timer period in s */ , ~0); - tw->last_run_time = now; - peer->adj_index = INDEX_INVALID; + tw->last_run_time = vlib_time_now (wmp->vlib_main); } static uword @@ -275,22 +342,13 @@ wg_timer_mngr_fn (vlib_main_t * vm, vlib_node_runtime_t * rt, vlib_frame_t * f) { wg_main_t *wmp = &wg_main; - wg_peer_t *peers; - wg_peer_t *peer; - while (1) { vlib_process_wait_for_event_or_clock (vm, WG_TICK); vlib_process_get_events (vm, NULL); - peers = wmp->peers; - /* *INDENT-OFF* */ - pool_foreach (peer, peers, - ({ - tw_timer_expire_timers_16t_2w_512sl - (&peer->timer_wheel, vlib_time_now (vm)); - })); - /* *INDENT-ON* */ + tw_timer_expire_timers_16t_2w_512sl (&wmp->timer_wheel, + vlib_time_now (vm)); } return 0; @@ -299,11 +357,15 @@ wg_timer_mngr_fn (vlib_main_t * vm, vlib_node_runtime_t * rt, void wg_timers_stop (wg_peer_t * peer) { - stop_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE); - stop_timer (peer, WG_TIMER_PERSISTENT_KEEPALIVE); - stop_timer (peer, WG_TIMER_SEND_KEEPALIVE); - stop_timer (peer, WG_TIMER_NEW_HANDSHAKE); - stop_timer (peer, WG_TIMER_KEY_ZEROING); + ASSERT (vlib_get_thread_index () == 0); + if (peer->timer_wheel) + { + stop_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE); + stop_timer (peer, WG_TIMER_PERSISTENT_KEEPALIVE); + stop_timer (peer, WG_TIMER_SEND_KEEPALIVE); + stop_timer (peer, WG_TIMER_NEW_HANDSHAKE); + stop_timer (peer, WG_TIMER_KEY_ZEROING); + } } /* *INDENT-OFF* */ diff --git a/src/plugins/wireguard/wireguard_timer.h b/src/plugins/wireguard/wireguard_timer.h index 457dce28674..2cc5dd01284 100755 --- a/src/plugins/wireguard/wireguard_timer.h +++ b/src/plugins/wireguard/wireguard_timer.h @@ -38,7 +38,7 @@ typedef enum _wg_timers typedef struct wg_peer wg_peer_t; -void wg_timers_init (wg_peer_t * peer, f64 now); +void wg_timer_wheel_init (); void wg_timers_stop (wg_peer_t * peer); void wg_timers_data_sent (wg_peer_t * peer); void wg_timers_data_received (wg_peer_t * peer); |