aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/wireguard
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins/wireguard')
-rwxr-xr-xsrc/plugins/wireguard/CMakeLists.txt1
-rwxr-xr-xsrc/plugins/wireguard/test/test_wireguard.py101
-rwxr-xr-xsrc/plugins/wireguard/wireguard.c12
-rwxr-xr-xsrc/plugins/wireguard/wireguard.h18
-rwxr-xr-xsrc/plugins/wireguard/wireguard_api.c4
-rw-r--r--src/plugins/wireguard/wireguard_handoff.c197
-rw-r--r--src/plugins/wireguard/wireguard_if.c86
-rw-r--r--src/plugins/wireguard/wireguard_if.h5
-rwxr-xr-xsrc/plugins/wireguard/wireguard_input.c255
-rwxr-xr-xsrc/plugins/wireguard/wireguard_noise.c131
-rwxr-xr-xsrc/plugins/wireguard/wireguard_noise.h23
-rwxr-xr-xsrc/plugins/wireguard/wireguard_output_tun.c94
-rwxr-xr-xsrc/plugins/wireguard/wireguard_peer.c58
-rwxr-xr-xsrc/plugins/wireguard/wireguard_peer.h40
-rwxr-xr-xsrc/plugins/wireguard/wireguard_send.c89
-rwxr-xr-xsrc/plugins/wireguard/wireguard_send.h1
-rwxr-xr-xsrc/plugins/wireguard/wireguard_timer.c220
-rwxr-xr-xsrc/plugins/wireguard/wireguard_timer.h2
18 files changed, 889 insertions, 448 deletions
diff --git a/src/plugins/wireguard/CMakeLists.txt b/src/plugins/wireguard/CMakeLists.txt
index db5bb2d8910..db74f9cdce0 100755
--- a/src/plugins/wireguard/CMakeLists.txt
+++ b/src/plugins/wireguard/CMakeLists.txt
@@ -30,6 +30,7 @@ add_vpp_plugin(wireguard
wireguard_if.h
wireguard_input.c
wireguard_output_tun.c
+ wireguard_handoff.c
wireguard_key.c
wireguard_key.h
wireguard_cli.c
diff --git a/src/plugins/wireguard/test/test_wireguard.py b/src/plugins/wireguard/test/test_wireguard.py
index 77349396ec6..cee1e938bb0 100755
--- a/src/plugins/wireguard/test/test_wireguard.py
+++ b/src/plugins/wireguard/test/test_wireguard.py
@@ -327,6 +327,14 @@ class VppWgPeer(VppObject):
def encrypt_transport(self, p):
return self.noise.encrypt(bytes(p))
+ def validate_encapped(self, rxs, tx):
+ for rx in rxs:
+ rx = IP(self.decrypt_transport(rx))
+
+ # chech the oringial packet is present
+ self._test.assertEqual(rx[IP].dst, tx[IP].dst)
+ self._test.assertEqual(rx[IP].ttl, tx[IP].ttl-1)
+
class TestWg(VppTestCase):
""" Wireguard Test Case """
@@ -455,11 +463,7 @@ class TestWg(VppTestCase):
rxs = self.send_and_expect(self.pg0, p * 255, self.pg1)
- for rx in rxs:
- rx = IP(peer_1.decrypt_transport(rx))
- # chech the oringial packet is present
- self.assertEqual(rx[IP].dst, p[IP].dst)
- self.assertEqual(rx[IP].ttl, p[IP].ttl-1)
+ peer_1.validate_encapped(rxs, p)
# send packets into the tunnel, expect to receive them on
# the other side
@@ -655,3 +659,90 @@ class TestWg(VppTestCase):
wg0.remove_vpp_config()
wg1.remove_vpp_config()
+
+
+class WireguardHandoffTests(TestWg):
+ """ Wireguard Tests in multi worker setup """
+ worker_config = "workers 2"
+
+ def test_wg_peer_init(self):
+ """ Handoff """
+ wg_output_node_name = '/err/wg-output-tun/'
+ wg_input_node_name = '/err/wg-input/'
+
+ port = 12323
+
+ # Create interfaces
+ wg0 = VppWgInterface(self,
+ self.pg1.local_ip4,
+ port).add_vpp_config()
+ wg0.admin_up()
+ wg0.config_ip4()
+
+ peer_1 = VppWgPeer(self,
+ wg0,
+ self.pg1.remote_ip4,
+ port+1,
+ ["10.11.2.0/24",
+ "10.11.3.0/24"]).add_vpp_config()
+ self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
+
+ # send a valid handsake init for which we expect a response
+ p = peer_1.mk_handshake(self.pg1)
+
+ rx = self.send_and_expect(self.pg1, [p], self.pg1)
+
+ peer_1.consume_response(rx[0])
+
+ # send a data packet from the peer through the tunnel
+ # this completes the handshake and pins the peer to worker 0
+ p = (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) /
+ UDP(sport=222, dport=223) /
+ Raw())
+ d = peer_1.encrypt_transport(p)
+ p = (peer_1.mk_tunnel_header(self.pg1) /
+ (Wireguard(message_type=4, reserved_zero=0) /
+ WireguardTransport(receiver_index=peer_1.sender,
+ counter=0,
+ encrypted_encapsulated_packet=d)))
+ rxs = self.send_and_expect(self.pg1, [p], self.pg0,
+ worker=0)
+
+ for rx in rxs:
+ self.assertEqual(rx[IP].dst, self.pg0.remote_ip4)
+ self.assertEqual(rx[IP].ttl, 19)
+
+ # send a packets that are routed into the tunnel
+ # and pins the peer tp worker 1
+ pe = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
+ UDP(sport=555, dport=556) /
+ Raw(b'\x00' * 80))
+ rxs = self.send_and_expect(self.pg0, pe * 255, self.pg1, worker=1)
+ peer_1.validate_encapped(rxs, pe)
+
+ # send packets into the tunnel, from the other worker
+ p = [(peer_1.mk_tunnel_header(self.pg1) /
+ Wireguard(message_type=4, reserved_zero=0) /
+ WireguardTransport(
+ receiver_index=peer_1.sender,
+ counter=ii+1,
+ encrypted_encapsulated_packet=peer_1.encrypt_transport(
+ (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) /
+ UDP(sport=222, dport=223) /
+ Raw())))) for ii in range(255)]
+
+ rxs = self.send_and_expect(self.pg1, p, self.pg0, worker=1)
+
+ for rx in rxs:
+ self.assertEqual(rx[IP].dst, self.pg0.remote_ip4)
+ self.assertEqual(rx[IP].ttl, 19)
+
+ # send a packets that are routed into the tunnel
+ # from owrker 0
+ rxs = self.send_and_expect(self.pg0, pe * 255, self.pg1, worker=0)
+
+ peer_1.validate_encapped(rxs, pe)
+
+ peer_1.remove_vpp_config()
+ wg0.remove_vpp_config()
diff --git a/src/plugins/wireguard/wireguard.c b/src/plugins/wireguard/wireguard.c
index 00921811e4a..9510a0ad385 100755
--- a/src/plugins/wireguard/wireguard.c
+++ b/src/plugins/wireguard/wireguard.c
@@ -32,7 +32,17 @@ wg_init (vlib_main_t * vm)
wg_main_t *wmp = &wg_main;
wmp->vlib_main = vm;
- wmp->peers = 0;
+
+ wmp->in_fq_index = vlib_frame_queue_main_init (wg_input_node.index, 0);
+ wmp->out_fq_index =
+ vlib_frame_queue_main_init (wg_output_tun_node.index, 0);
+
+ vlib_thread_main_t *tm = vlib_get_thread_main ();
+
+ vec_validate_aligned (wmp->per_thread_data, tm->n_vlib_mains,
+ CLIB_CACHE_LINE_BYTES);
+
+ wg_timer_wheel_init ();
return (NULL);
}
diff --git a/src/plugins/wireguard/wireguard.h b/src/plugins/wireguard/wireguard.h
index 70a692e602f..2c892a374b8 100755
--- a/src/plugins/wireguard/wireguard.h
+++ b/src/plugins/wireguard/wireguard.h
@@ -17,13 +17,17 @@
#include <wireguard/wireguard_index_table.h>
#include <wireguard/wireguard_messages.h>
-#include <wireguard/wireguard_peer.h>
+#include <wireguard/wireguard_timer.h>
+
+#define WG_DEFAULT_DATA_SIZE 2048
extern vlib_node_registration_t wg_input_node;
extern vlib_node_registration_t wg_output_tun_node;
-
-
+typedef struct wg_per_thread_data_t_
+{
+ u8 data[WG_DEFAULT_DATA_SIZE];
+} wg_per_thread_data_t;
typedef struct
{
/* convenience */
@@ -31,10 +35,14 @@ typedef struct
u16 msg_id_base;
- // Peers pool
- wg_peer_t *peers;
wg_index_table_t index_table;
+ u32 in_fq_index;
+ u32 out_fq_index;
+
+ wg_per_thread_data_t *per_thread_data;
+
+ tw_timer_wheel_16t_2w_512sl_t timer_wheel;
} wg_main_t;
extern wg_main_t wg_main;
diff --git a/src/plugins/wireguard/wireguard_api.c b/src/plugins/wireguard/wireguard_api.c
index 8bbacddaf45..27ed6ea05da 100755
--- a/src/plugins/wireguard/wireguard_api.c
+++ b/src/plugins/wireguard/wireguard_api.c
@@ -97,15 +97,17 @@ wireguard_if_send_details (index_t wgii, void *data)
vl_api_wireguard_interface_details_t *rmp;
wg_deatils_walk_t *ctx = data;
const wg_if_t *wgi;
+ const noise_local_t *local;
wgi = wg_if_get (wgii);
+ local = noise_local_get (wgi->local_idx);
rmp = vl_msg_api_alloc_zero (sizeof (*rmp));
rmp->_vl_msg_id = htons (VL_API_WIREGUARD_INTERFACE_DETAILS +
wg_main.msg_id_base);
clib_memcpy (rmp->interface.private_key,
- wgi->local.l_private, NOISE_PUBLIC_KEY_LEN);
+ local->l_private, NOISE_PUBLIC_KEY_LEN);
rmp->interface.sw_if_index = htonl (wgi->sw_if_index);
rmp->interface.port = htons (wgi->port);
ip_address_encode2 (&wgi->src_ip, &rmp->interface.src_ip);
diff --git a/src/plugins/wireguard/wireguard_handoff.c b/src/plugins/wireguard/wireguard_handoff.c
new file mode 100644
index 00000000000..b0b74229452
--- /dev/null
+++ b/src/plugins/wireguard/wireguard_handoff.c
@@ -0,0 +1,197 @@
+/*
+ * Copyright (c) 2020 Doc.ai and/or its affiliates.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at:
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <wireguard/wireguard.h>
+#include <wireguard/wireguard_peer.h>
+
+#define foreach_wg_handoff_error \
+_(CONGESTION_DROP, "congestion drop")
+
+typedef enum
+{
+#define _(sym,str) WG_HANDOFF_ERROR_##sym,
+ foreach_wg_handoff_error
+#undef _
+ HANDOFF_N_ERROR,
+} ipsec_handoff_error_t;
+
+static char *wg_handoff_error_strings[] = {
+#define _(sym,string) string,
+ foreach_wg_handoff_error
+#undef _
+};
+
+typedef enum
+{
+ WG_HANDOFF_HANDSHAKE,
+ WG_HANDOFF_INP_DATA,
+ WG_HANDOFF_OUT_TUN,
+} wg_handoff_mode_t;
+
+typedef struct wg_handoff_trace_t_
+{
+ u32 next_worker_index;
+ index_t peer;
+} wg_handoff_trace_t;
+
+static u8 *
+format_wg_handoff_trace (u8 * s, va_list * args)
+{
+ CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
+ CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
+ wg_handoff_trace_t *t = va_arg (*args, wg_handoff_trace_t *);
+
+ s = format (s, "next-worker %d peer %d", t->next_worker_index, t->peer);
+
+ return s;
+}
+
+static_always_inline uword
+wg_handoff (vlib_main_t * vm,
+ vlib_node_runtime_t * node,
+ vlib_frame_t * frame, u32 fq_index, wg_handoff_mode_t mode)
+{
+ vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b;
+ u16 thread_indices[VLIB_FRAME_SIZE], *ti;
+ u32 n_enq, n_left_from, *from;
+ wg_main_t *wmp;
+
+ wmp = &wg_main;
+ from = vlib_frame_vector_args (frame);
+ n_left_from = frame->n_vectors;
+ vlib_get_buffers (vm, from, bufs, n_left_from);
+
+ b = bufs;
+ ti = thread_indices;
+
+ while (n_left_from > 0)
+ {
+ const wg_peer_t *peer;
+ index_t peeri;
+
+ if (PREDICT_FALSE (mode == WG_HANDOFF_HANDSHAKE))
+ {
+ ti[0] = 0;
+ }
+ else if (mode == WG_HANDOFF_INP_DATA)
+ {
+ message_data_t *data = vlib_buffer_get_current (b[0]);
+ u32 *entry =
+ wg_index_table_lookup (&wmp->index_table, data->receiver_index);
+ peeri = *entry;
+ peer = wg_peer_get (peeri);
+
+ ti[0] = peer->input_thread_index;
+ }
+ else
+ {
+ peeri =
+ wg_peer_get_by_adj_index (vnet_buffer (b[0])->
+ ip.adj_index[VLIB_TX]);
+ peer = wg_peer_get (peeri);
+ ti[0] = peer->output_thread_index;
+ }
+
+ if (PREDICT_FALSE (b[0]->flags & VLIB_BUFFER_IS_TRACED))
+ {
+ wg_handoff_trace_t *t =
+ vlib_add_trace (vm, node, b[0], sizeof (*t));
+ t->next_worker_index = ti[0];
+ t->peer = peeri;
+ }
+
+ n_left_from -= 1;
+ ti += 1;
+ b += 1;
+ }
+
+ n_enq = vlib_buffer_enqueue_to_thread (vm, fq_index, from,
+ thread_indices, frame->n_vectors, 1);
+
+ if (n_enq < frame->n_vectors)
+ vlib_node_increment_counter (vm, node->node_index,
+ WG_HANDOFF_ERROR_CONGESTION_DROP,
+ frame->n_vectors - n_enq);
+
+ return n_enq;
+}
+
+VLIB_NODE_FN (wg_handshake_handoff) (vlib_main_t * vm,
+ vlib_node_runtime_t * node,
+ vlib_frame_t * from_frame)
+{
+ wg_main_t *wmp = &wg_main;
+
+ return wg_handoff (vm, node, from_frame, wmp->in_fq_index,
+ WG_HANDOFF_HANDSHAKE);
+}
+
+VLIB_NODE_FN (wg_input_data_handoff) (vlib_main_t * vm,
+ vlib_node_runtime_t * node,
+ vlib_frame_t * from_frame)
+{
+ wg_main_t *wmp = &wg_main;
+
+ return wg_handoff (vm, node, from_frame, wmp->in_fq_index,
+ WG_HANDOFF_INP_DATA);
+}
+
+VLIB_NODE_FN (wg_output_tun_handoff) (vlib_main_t * vm,
+ vlib_node_runtime_t * node,
+ vlib_frame_t * from_frame)
+{
+ wg_main_t *wmp = &wg_main;
+
+ return wg_handoff (vm, node, from_frame, wmp->out_fq_index,
+ WG_HANDOFF_OUT_TUN);
+}
+
+VLIB_REGISTER_NODE (wg_handshake_handoff) =
+{
+ .name = "wg-handshake-handoff",.vector_size = sizeof (u32),.format_trace =
+ format_wg_handoff_trace,.type = VLIB_NODE_TYPE_INTERNAL,.n_errors =
+ ARRAY_LEN (wg_handoff_error_strings),.error_strings =
+ wg_handoff_error_strings,.n_next_nodes = 1,.next_nodes =
+ {
+ [0] = "error-drop",}
+,};
+
+VLIB_REGISTER_NODE (wg_input_data_handoff) =
+{
+ .name = "wg-input-data-handoff",.vector_size = sizeof (u32),.format_trace =
+ format_wg_handoff_trace,.type = VLIB_NODE_TYPE_INTERNAL,.n_errors =
+ ARRAY_LEN (wg_handoff_error_strings),.error_strings =
+ wg_handoff_error_strings,.n_next_nodes = 1,.next_nodes =
+ {
+ [0] = "error-drop",}
+,};
+
+VLIB_REGISTER_NODE (wg_output_tun_handoff) =
+{
+ .name = "wg-output-tun-handoff",.vector_size = sizeof (u32),.format_trace =
+ format_wg_handoff_trace,.type = VLIB_NODE_TYPE_INTERNAL,.n_errors =
+ ARRAY_LEN (wg_handoff_error_strings),.error_strings =
+ wg_handoff_error_strings,.n_next_nodes = 1,.next_nodes =
+ {
+ [0] = "error-drop",}
+,};
+
+/*
+ * fd.io coding-style-patch-verification: ON
+ *
+ * Local Variables:
+ * eval: (c-set-style "gnu")
+ * End:
+ */
diff --git a/src/plugins/wireguard/wireguard_if.c b/src/plugins/wireguard/wireguard_if.c
index c91667bb234..7509923a1bf 100644
--- a/src/plugins/wireguard/wireguard_if.c
+++ b/src/plugins/wireguard/wireguard_if.c
@@ -5,6 +5,7 @@
#include <wireguard/wireguard_messages.h>
#include <wireguard/wireguard_if.h>
#include <wireguard/wireguard.h>
+#include <wireguard/wireguard_peer.h>
/* pool of interfaces */
wg_if_t *wg_if_pool;
@@ -30,28 +31,28 @@ format_wg_if (u8 * s, va_list * args)
{
index_t wgii = va_arg (*args, u32);
wg_if_t *wgi = wg_if_get (wgii);
+ noise_local_t *local = noise_local_get (wgi->local_idx);
u8 key[NOISE_KEY_LEN_BASE64];
- key_to_base64 (wgi->local.l_private, NOISE_PUBLIC_KEY_LEN, key);
s = format (s, "[%d] %U src:%U port:%d",
wgii,
format_vnet_sw_if_index_name, vnet_get_main (),
wgi->sw_if_index, format_ip_address, &wgi->src_ip, wgi->port);
- key_to_base64 (wgi->local.l_private, NOISE_PUBLIC_KEY_LEN, key);
+ key_to_base64 (local->l_private, NOISE_PUBLIC_KEY_LEN, key);
s = format (s, " private-key:%s", key);
s =
- format (s, " %U", format_hex_bytes, wgi->local.l_private,
+ format (s, " %U", format_hex_bytes, local->l_private,
NOISE_PUBLIC_KEY_LEN);
- key_to_base64 (wgi->local.l_public, NOISE_PUBLIC_KEY_LEN, key);
+ key_to_base64 (local->l_public, NOISE_PUBLIC_KEY_LEN, key);
s = format (s, " public-key:%s", key);
s =
- format (s, " %U", format_hex_bytes, wgi->local.l_public,
+ format (s, " %U", format_hex_bytes, local->l_public,
NOISE_PUBLIC_KEY_LEN);
s = format (s, " mac-key: %U", format_hex_bytes,
@@ -72,23 +73,28 @@ wg_if_find_by_sw_if_index (u32 sw_if_index)
return (ti);
}
+static walk_rc_t
+wg_if_find_peer_by_public_key (index_t peeri, void *data)
+{
+ uint8_t *public = data;
+ wg_peer_t *peer = wg_peer_get (peeri);
+
+ if (!memcmp (peer->remote.r_public, public, NOISE_PUBLIC_KEY_LEN))
+ return (WALK_STOP);
+ return (WALK_CONTINUE);
+}
+
static noise_remote_t *
-wg_remote_get (uint8_t public[NOISE_PUBLIC_KEY_LEN])
+wg_remote_get (const uint8_t public[NOISE_PUBLIC_KEY_LEN])
{
- wg_main_t *wmp = &wg_main;
- wg_peer_t *peer = NULL;
- wg_peer_t *peer_iter;
- /* *INDENT-OFF* */
- pool_foreach (peer_iter, wmp->peers,
- ({
- if (!memcmp (peer_iter->remote.r_public, public, NOISE_PUBLIC_KEY_LEN))
- {
- peer = peer_iter;
- break;
- }
- }));
- /* *INDENT-ON* */
- return peer ? &peer->remote : NULL;
+ index_t peeri;
+
+ peeri = wg_peer_walk (wg_if_find_peer_by_public_key, (void *) public);
+
+ if (INDEX_INVALID != peeri)
+ return &wg_peer_get (peeri)->remote;
+
+ return NULL;
}
static uint32_t
@@ -223,6 +229,7 @@ wg_if_create (u32 user_instance,
u32 instance, hw_if_index;
vnet_hw_interface_t *hi;
wg_if_t *wg_if;
+ noise_local_t *local;
ASSERT (sw_if_indexp);
@@ -236,6 +243,24 @@ wg_if_create (u32 user_instance,
if (instance == ~0)
return VNET_API_ERROR_INVALID_REGISTRATION;
+ /* *INDENT-OFF* */
+ struct noise_upcall upcall = {
+ .u_remote_get = wg_remote_get,
+ .u_index_set = wg_index_set,
+ .u_index_drop = wg_index_drop,
+ };
+ /* *INDENT-ON* */
+
+ pool_get (noise_local_pool, local);
+
+ noise_local_init (local, &upcall);
+ if (!noise_local_set_private (local, private_key))
+ {
+ pool_put (noise_local_pool, local);
+ wg_if_instance_free (instance);
+ return VNET_API_ERROR_INVALID_REGISTRATION;
+ }
+
pool_get (wg_if_pool, wg_if);
/* tunnel index (or instance) */
@@ -251,18 +276,8 @@ wg_if_create (u32 user_instance,
wg_if_index_by_port[port] = wg_if - wg_if_pool;
wg_if->port = port;
-
- /* *INDENT-OFF* */
- struct noise_upcall upcall = {
- .u_remote_get = wg_remote_get,
- .u_index_set = wg_index_set,
- .u_index_drop = wg_index_drop,
- };
- /* *INDENT-ON* */
-
- noise_local_init (&wg_if->local, &upcall);
- noise_local_set_private (&wg_if->local, private_key);
- cookie_checker_update (&wg_if->cookie_checker, wg_if->local.l_public);
+ wg_if->local_idx = local - noise_local_pool;
+ cookie_checker_update (&wg_if->cookie_checker, local->l_public);
hw_if_index = vnet_register_interface (vnm,
wg_if_device_class.index,
@@ -304,6 +319,7 @@ wg_if_delete (u32 sw_if_index)
udp_unregister_dst_port (vlib_get_main (), wg_if->port, 1);
wg_if_index_by_port[wg_if->port] = INDEX_INVALID;
vnet_delete_hw_interface (vnm, hw->hw_if_index);
+ pool_put_index (noise_local_pool, wg_if->local_idx);
pool_put (wg_if_pool, wg_if);
return 0;
@@ -343,7 +359,7 @@ wg_if_walk (wg_if_walk_cb_t fn, void *data)
/* *INDENT-ON* */
}
-void
+index_t
wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data)
{
index_t peeri, val;
@@ -352,9 +368,11 @@ wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data)
hash_foreach (peeri, val, wgi->peers,
{
if (WALK_STOP == fn(wgi, peeri, data))
- break;
+ return peeri;
});
/* *INDENT-ON* */
+
+ return INDEX_INVALID;
}
diff --git a/src/plugins/wireguard/wireguard_if.h b/src/plugins/wireguard/wireguard_if.h
index 9e6b6190e0e..d8c2a87dc71 100644
--- a/src/plugins/wireguard/wireguard_if.h
+++ b/src/plugins/wireguard/wireguard_if.h
@@ -25,7 +25,8 @@ typedef struct wg_if_t_
u32 sw_if_index;
// Interface params
- noise_local_t local;
+ /* noise_local_pool elt index */
+ u32 local_idx;
cookie_checker_t cookie_checker;
u16 port;
@@ -52,7 +53,7 @@ void wg_if_walk (wg_if_walk_cb_t fn, void *data);
typedef walk_rc_t (*wg_if_peer_walk_cb_t) (wg_if_t * wgi, index_t peeri,
void *data);
-void wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data);
+index_t wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data);
void wg_if_peer_add (wg_if_t * wgi, index_t peeri);
void wg_if_peer_remove (wg_if_t * wgi, index_t peeri);
diff --git a/src/plugins/wireguard/wireguard_input.c b/src/plugins/wireguard/wireguard_input.c
index cdd65f87b51..b15c265cdac 100755
--- a/src/plugins/wireguard/wireguard_input.c
+++ b/src/plugins/wireguard/wireguard_input.c
@@ -30,6 +30,7 @@
_(DECRYPTION, "Failed during decryption") \
_(KEEPALIVE_SEND, "Failed while sending Keepalive") \
_(HANDSHAKE_SEND, "Failed while sending Handshake") \
+ _(TOO_BIG, "Packet too big") \
_(UNDEFINED, "Undefined error")
typedef enum
@@ -51,7 +52,7 @@ typedef struct
message_type_t type;
u16 current_length;
bool is_keepalive;
-
+ index_t peer;
} wg_input_trace_t;
u8 *
@@ -79,6 +80,7 @@ format_wg_input_trace (u8 * s, va_list * args)
s = format (s, "WG input: \n");
s = format (s, " Type: %U\n", format_wg_message_type, t->type);
+ s = format (s, " peer: %d\n", t->peer);
s = format (s, " Length: %d\n", t->current_length);
s = format (s, " Keepalive: %s", t->is_keepalive ? "true" : "false");
@@ -87,6 +89,8 @@ format_wg_input_trace (u8 * s, va_list * args)
typedef enum
{
+ WG_INPUT_NEXT_HANDOFF_HANDSHAKE,
+ WG_INPUT_NEXT_HANDOFF_DATA,
WG_INPUT_NEXT_IP4_INPUT,
WG_INPUT_NEXT_PUNT,
WG_INPUT_NEXT_ERROR,
@@ -106,6 +110,8 @@ typedef enum
static wg_input_error_t
wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
{
+ ASSERT (vm->thread_index == 0);
+
enum cookie_mac_state mac_state;
bool packet_needs_cookie;
bool under_load;
@@ -129,17 +135,15 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
if (NULL == wg_if)
return WG_INPUT_ERROR_INTERFACE;
- if (header->type == MESSAGE_HANDSHAKE_COOKIE)
+ if (PREDICT_FALSE (header->type == MESSAGE_HANDSHAKE_COOKIE))
{
message_handshake_cookie_t *packet =
(message_handshake_cookie_t *) current_b_data;
u32 *entry =
wg_index_table_lookup (&wmp->index_table, packet->receiver_index);
if (entry)
- {
- peer = pool_elt_at_index (wmp->peers, *entry);
- }
- if (!peer)
+ peer = wg_peer_get (*entry);
+ else
return WG_INPUT_ERROR_PEER;
// TODO: Implement cookie_maker_consume_payload
@@ -178,17 +182,17 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
// TODO: Add processing
}
noise_remote_t *rp;
-
if (noise_consume_initiation
- (wmp->vlib_main, &wg_if->local, &rp, message->sender_index,
- message->unencrypted_ephemeral, message->encrypted_static,
- message->encrypted_timestamp))
+ (vm, noise_local_get (wg_if->local_idx), &rp,
+ message->sender_index, message->unencrypted_ephemeral,
+ message->encrypted_static, message->encrypted_timestamp))
{
- peer = pool_elt_at_index (wmp->peers, rp->r_peer_idx);
+ peer = wg_peer_get (rp->r_peer_idx);
+ }
+ else
+ {
+ return WG_INPUT_ERROR_PEER;
}
-
- if (!peer)
- return WG_INPUT_ERROR_PEER;
// set_peer_address (peer, ip4_src, udp_src_port);
if (PREDICT_FALSE (!wg_send_handshake_response (vm, peer)))
@@ -203,15 +207,18 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
message_handshake_response_t *resp = current_b_data;
u32 *entry =
wg_index_table_lookup (&wmp->index_table, resp->receiver_index);
- if (entry)
+
+ if (PREDICT_TRUE (entry != NULL))
{
- peer = pool_elt_at_index (wmp->peers, *entry);
- if (!peer || peer->is_dead)
+ peer = wg_peer_get (*entry);
+ if (peer->is_dead)
return WG_INPUT_ERROR_PEER;
}
+ else
+ return WG_INPUT_ERROR_PEER;
if (!noise_consume_response
- (wmp->vlib_main, &peer->remote, resp->sender_index,
+ (vm, &peer->remote, resp->sender_index,
resp->receiver_index, resp->unencrypted_ephemeral,
resp->encrypted_nothing))
{
@@ -223,8 +230,9 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
}
// set_peer_address (peer, ip4_src, udp_src_port);
- if (noise_remote_begin_session (wmp->vlib_main, &peer->remote))
+ if (noise_remote_begin_session (vm, &peer->remote))
{
+
wg_timers_session_derived (peer);
wg_timers_handshake_complete (peer);
if (PREDICT_FALSE (!wg_send_keepalive (vm, peer)))
@@ -272,6 +280,7 @@ VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm,
u32 *from;
vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b;
u16 nexts[VLIB_FRAME_SIZE], *next;
+ u32 thread_index = vm->thread_index;
from = vlib_frame_vector_args (frame);
n_left_from = frame->n_vectors;
@@ -289,120 +298,132 @@ VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm,
next[0] = WG_INPUT_NEXT_PUNT;
header_type =
((message_header_t *) vlib_buffer_get_current (b[0]))->type;
+ u32 *peer_idx;
- switch (header_type)
+ if (PREDICT_TRUE (header_type == MESSAGE_DATA))
{
- case MESSAGE_HANDSHAKE_INITIATION:
- case MESSAGE_HANDSHAKE_RESPONSE:
- case MESSAGE_HANDSHAKE_COOKIE:
- {
- wg_input_error_t ret = wg_handshake_process (vm, wmp, b[0]);
- if (ret != WG_INPUT_ERROR_NONE)
- {
- next[0] = WG_INPUT_NEXT_ERROR;
- b[0]->error = node->errors[ret];
- }
- break;
- }
- case MESSAGE_DATA:
- {
- message_data_t *data = vlib_buffer_get_current (b[0]);
- u32 *entry =
- wg_index_table_lookup (&wmp->index_table, data->receiver_index);
-
- if (entry)
- {
- peer = pool_elt_at_index (wmp->peers, *entry);
- }
- else
- {
- next[0] = WG_INPUT_NEXT_ERROR;
- b[0]->error = node->errors[WG_INPUT_ERROR_PEER];
- goto out;
- }
+ message_data_t *data = vlib_buffer_get_current (b[0]);
- u16 encr_len = b[0]->current_length - sizeof (message_data_t);
- u16 decr_len = encr_len - NOISE_AUTHTAG_LEN;
- u8 *decr_data = clib_mem_alloc (decr_len);
+ peer_idx = wg_index_table_lookup (&wmp->index_table,
+ data->receiver_index);
- enum noise_state_crypt state_cr =
- noise_remote_decrypt (wmp->vlib_main,
- &peer->remote,
- data->receiver_index,
- data->counter,
- data->encrypted_data,
- encr_len,
- decr_data);
+ if (peer_idx)
+ {
+ peer = wg_peer_get (*peer_idx);
+ }
+ else
+ {
+ next[0] = WG_INPUT_NEXT_ERROR;
+ b[0]->error = node->errors[WG_INPUT_ERROR_PEER];
+ goto out;
+ }
- switch (state_cr)
- {
- case SC_OK:
- break;
- case SC_CONN_RESET:
- wg_timers_handshake_complete (peer);
- break;
- case SC_KEEP_KEY_FRESH:
- if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false)))
- {
- vlib_node_increment_counter (vm, wg_input_node.index,
- WG_INPUT_ERROR_HANDSHAKE_SEND,
- 1);
- }
- break;
- case SC_FAILED:
- next[0] = WG_INPUT_NEXT_ERROR;
- b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION];
- goto out;
- default:
- break;
- }
+ if (PREDICT_FALSE (~0 == peer->input_thread_index))
+ {
+ /* this is the first packet to use this peer, claim the peer
+ * for this thread.
+ */
+ clib_atomic_cmp_and_swap (&peer->input_thread_index, ~0,
+ wg_peer_assign_thread (thread_index));
+ }
- clib_memcpy (vlib_buffer_get_current (b[0]), decr_data, decr_len);
- b[0]->current_length = decr_len;
- b[0]->flags &= ~VNET_BUFFER_F_OFFLOAD_UDP_CKSUM;
+ if (PREDICT_TRUE (thread_index != peer->input_thread_index))
+ {
+ next[0] = WG_INPUT_NEXT_HANDOFF_DATA;
+ goto next;
+ }
- clib_mem_free (decr_data);
+ u16 encr_len = b[0]->current_length - sizeof (message_data_t);
+ u16 decr_len = encr_len - NOISE_AUTHTAG_LEN;
+ if (PREDICT_FALSE (decr_len >= WG_DEFAULT_DATA_SIZE))
+ {
+ b[0]->error = node->errors[WG_INPUT_ERROR_TOO_BIG];
+ goto out;
+ }
- wg_timers_any_authenticated_packet_received (peer);
- wg_timers_any_authenticated_packet_traversal (peer);
+ u8 *decr_data = wmp->per_thread_data[thread_index].data;
- if (decr_len == 0)
- {
- is_keepalive = true;
- goto out;
- }
+ enum noise_state_crypt state_cr = noise_remote_decrypt (vm,
+ &peer->remote,
+ data->receiver_index,
+ data->counter,
+ data->encrypted_data,
+ encr_len,
+ decr_data);
- wg_timers_data_received (peer);
+ if (PREDICT_FALSE (state_cr == SC_CONN_RESET))
+ {
+ wg_timers_handshake_complete (peer);
+ }
+ else if (PREDICT_FALSE (state_cr == SC_KEEP_KEY_FRESH))
+ {
+ wg_send_handshake_from_mt (*peer_idx, false);
+ }
+ else if (PREDICT_FALSE (state_cr == SC_FAILED))
+ {
+ next[0] = WG_INPUT_NEXT_ERROR;
+ b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION];
+ goto out;
+ }
- ip4_header_t *iph = vlib_buffer_get_current (b[0]);
+ clib_memcpy (vlib_buffer_get_current (b[0]), decr_data, decr_len);
+ b[0]->current_length = decr_len;
+ b[0]->flags &= ~VNET_BUFFER_F_OFFLOAD_UDP_CKSUM;
- const wg_peer_allowed_ip_t *allowed_ip;
- bool allowed = false;
+ wg_timers_any_authenticated_packet_received (peer);
+ wg_timers_any_authenticated_packet_traversal (peer);
- /*
- * we could make this into an ACL, but the expectation
- * is that there aren't many allowed IPs and thus a linear
- * walk is fater than an ACL
- */
- vec_foreach (allowed_ip, peer->allowed_ips)
+ /* Keepalive packet has zero length */
+ if (decr_len == 0)
{
- if (fib_prefix_is_cover_addr_4 (&allowed_ip->prefix,
- &iph->src_address))
- {
- allowed = true;
- break;
- }
+ is_keepalive = true;
+ goto out;
}
- if (allowed)
+
+ wg_timers_data_received (peer);
+
+ ip4_header_t *iph = vlib_buffer_get_current (b[0]);
+
+ const wg_peer_allowed_ip_t *allowed_ip;
+ bool allowed = false;
+
+ /*
+ * we could make this into an ACL, but the expectation
+ * is that there aren't many allowed IPs and thus a linear
+ * walk is fater than an ACL
+ */
+ vec_foreach (allowed_ip, peer->allowed_ips)
+ {
+ if (fib_prefix_is_cover_addr_4 (&allowed_ip->prefix,
+ &iph->src_address))
{
- vnet_buffer (b[0])->sw_if_index[VLIB_RX] =
- peer->wg_sw_if_index;
- next[0] = WG_INPUT_NEXT_IP4_INPUT;
+ allowed = true;
+ break;
}
- break;
}
- default:
- break;
+ if (allowed)
+ {
+ vnet_buffer (b[0])->sw_if_index[VLIB_RX] = peer->wg_sw_if_index;
+ next[0] = WG_INPUT_NEXT_IP4_INPUT;
+ }
+ }
+ else
+ {
+ peer_idx = NULL;
+
+ /* Handshake packets should be processed in main thread */
+ if (thread_index != 0)
+ {
+ next[0] = WG_INPUT_NEXT_HANDOFF_HANDSHAKE;
+ goto next;
+ }
+
+ wg_input_error_t ret = wg_handshake_process (vm, wmp, b[0]);
+ if (ret != WG_INPUT_ERROR_NONE)
+ {
+ next[0] = WG_INPUT_NEXT_ERROR;
+ b[0]->error = node->errors[ret];
+ }
}
out:
@@ -413,7 +434,9 @@ VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm,
t->type = header_type;
t->current_length = b[0]->current_length;
t->is_keepalive = is_keepalive;
+ t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
}
+ next:
n_left_from -= 1;
next += 1;
b += 1;
@@ -435,6 +458,8 @@ VLIB_REGISTER_NODE (wg_input_node) =
.n_next_nodes = WG_INPUT_N_NEXT,
/* edit / add dispositions here */
.next_nodes = {
+ [WG_INPUT_NEXT_HANDOFF_HANDSHAKE] = "wg-handshake-handoff",
+ [WG_INPUT_NEXT_HANDOFF_DATA] = "wg-input-data-handoff",
[WG_INPUT_NEXT_IP4_INPUT] = "ip4-input-no-checksum",
[WG_INPUT_NEXT_PUNT] = "error-punt",
[WG_INPUT_NEXT_ERROR] = "error-drop",
diff --git a/src/plugins/wireguard/wireguard_noise.c b/src/plugins/wireguard/wireguard_noise.c
index b47bb5747b9..00b67109de4 100755
--- a/src/plugins/wireguard/wireguard_noise.c
+++ b/src/plugins/wireguard/wireguard_noise.c
@@ -26,6 +26,8 @@
* <- e, ee, se, psk, {}
*/
+noise_local_t *noise_local_pool;
+
/* Private functions */
static noise_keypair_t *noise_remote_keypair_allocate (noise_remote_t *);
static void noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t *,
@@ -80,81 +82,31 @@ noise_local_set_private (noise_local_t * l,
const uint8_t private[NOISE_PUBLIC_KEY_LEN])
{
clib_memcpy (l->l_private, private, NOISE_PUBLIC_KEY_LEN);
- l->l_has_identity = curve25519_gen_public (l->l_public, private);
-
- return l->l_has_identity;
-}
-bool
-noise_local_keys (noise_local_t * l, uint8_t public[NOISE_PUBLIC_KEY_LEN],
- uint8_t private[NOISE_PUBLIC_KEY_LEN])
-{
- if (l->l_has_identity)
- {
- if (public != NULL)
- clib_memcpy (public, l->l_public, NOISE_PUBLIC_KEY_LEN);
- if (private != NULL)
- clib_memcpy (private, l->l_private, NOISE_PUBLIC_KEY_LEN);
- }
- else
- {
- return false;
- }
- return true;
+ return curve25519_gen_public (l->l_public, private);
}
void
noise_remote_init (noise_remote_t * r, uint32_t peer_pool_idx,
const uint8_t public[NOISE_PUBLIC_KEY_LEN],
- noise_local_t * l)
+ u32 noise_local_idx)
{
clib_memset (r, 0, sizeof (*r));
clib_memcpy (r->r_public, public, NOISE_PUBLIC_KEY_LEN);
+ clib_rwlock_init (&r->r_keypair_lock);
r->r_peer_idx = peer_pool_idx;
-
- ASSERT (l != NULL);
- r->r_local = l;
+ r->r_local_idx = noise_local_idx;
r->r_handshake.hs_state = HS_ZEROED;
- noise_remote_precompute (r);
-}
-
-bool
-noise_remote_set_psk (noise_remote_t * r,
- uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
-{
- int same;
- same = !clib_memcmp (r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
- if (!same)
- {
- clib_memcpy (r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
- }
- return same == 0;
-}
-
-bool
-noise_remote_keys (noise_remote_t * r, uint8_t public[NOISE_PUBLIC_KEY_LEN],
- uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
-{
- static uint8_t null_psk[NOISE_SYMMETRIC_KEY_LEN];
- int ret;
- if (public != NULL)
- clib_memcpy (public, r->r_public, NOISE_PUBLIC_KEY_LEN);
-
- if (psk != NULL)
- clib_memcpy (psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
- ret = clib_memcmp (r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN);
-
- return ret;
+ noise_remote_precompute (r);
}
void
noise_remote_precompute (noise_remote_t * r)
{
- noise_local_t *l = r->r_local;
- if (!l->l_has_identity)
- clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN);
- else if (!curve25519_gen_shared (r->r_ss, l->l_private, r->r_public))
+ noise_local_t *l = noise_local_get (r->r_local_idx);
+
+ if (!curve25519_gen_shared (r->r_ss, l->l_private, r->r_public))
clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN);
noise_remote_handshake_index_drop (r);
@@ -169,7 +121,7 @@ noise_create_initiation (vlib_main_t * vm, noise_remote_t * r,
uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
{
noise_handshake_t *hs = &r->r_handshake;
- noise_local_t *l = r->r_local;
+ noise_local_t *l = noise_local_get (r->r_local_idx);
uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
uint32_t key_idx;
uint8_t *key;
@@ -180,8 +132,6 @@ noise_create_initiation (vlib_main_t * vm, noise_remote_t * r,
NOISE_SYMMETRIC_KEY_LEN);
key = vnet_crypto_get_key (key_idx)->data;
- if (!l->l_has_identity)
- goto error;
noise_param_init (hs->hs_ck, hs->hs_hash, r->r_public);
/* e */
@@ -239,8 +189,6 @@ noise_consume_initiation (vlib_main_t * vm, noise_local_t * l,
NOISE_SYMMETRIC_KEY_LEN);
key = vnet_crypto_get_key (key_idx)->data;
- if (!l->l_has_identity)
- goto error;
noise_param_init (hs.hs_ck, hs.hs_hash, l->l_public);
/* e */
@@ -294,6 +242,7 @@ noise_consume_initiation (vlib_main_t * vm, noise_local_t * l,
r->r_handshake = hs;
*rp = r;
ret = true;
+
error:
vnet_crypto_key_del (vm, key_idx);
secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
@@ -359,7 +308,7 @@ noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx,
uint32_t r_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN],
uint8_t en[0 + NOISE_AUTHTAG_LEN])
{
- noise_local_t *l = r->r_local;
+ noise_local_t *l = noise_local_get (r->r_local_idx);
noise_handshake_t hs;
uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
uint8_t preshared_key[NOISE_PUBLIC_KEY_LEN];
@@ -372,9 +321,6 @@ noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx,
NOISE_SYMMETRIC_KEY_LEN);
key = vnet_crypto_get_key (key_idx)->data;
- if (!l->l_has_identity)
- goto error;
-
hs = r->r_handshake;
clib_memcpy (preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
@@ -460,6 +406,7 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r)
clib_memset (&kp.kp_ctr, 0, sizeof (kp.kp_ctr));
/* Now we need to add_new_keypair */
+ clib_rwlock_writer_lock (&r->r_keypair_lock);
next = r->r_next;
current = r->r_current;
previous = r->r_previous;
@@ -491,7 +438,10 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r)
r->r_next = noise_remote_keypair_allocate (r);
*r->r_next = kp;
}
+ clib_rwlock_writer_unlock (&r->r_keypair_lock);
+
secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
+
secure_zero_memory (&kp, sizeof (kp));
return true;
}
@@ -502,21 +452,25 @@ noise_remote_clear (vlib_main_t * vm, noise_remote_t * r)
noise_remote_handshake_index_drop (r);
secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
+ clib_rwlock_writer_lock (&r->r_keypair_lock);
noise_remote_keypair_free (vm, r, &r->r_next);
noise_remote_keypair_free (vm, r, &r->r_current);
noise_remote_keypair_free (vm, r, &r->r_previous);
r->r_next = NULL;
r->r_current = NULL;
r->r_previous = NULL;
+ clib_rwlock_writer_unlock (&r->r_keypair_lock);
}
void
noise_remote_expire_current (noise_remote_t * r)
{
+ clib_rwlock_writer_lock (&r->r_keypair_lock);
if (r->r_next != NULL)
r->r_next->kp_valid = 0;
if (r->r_current != NULL)
r->r_current->kp_valid = 0;
+ clib_rwlock_writer_unlock (&r->r_keypair_lock);
}
bool
@@ -525,6 +479,7 @@ noise_remote_ready (noise_remote_t * r)
noise_keypair_t *kp;
int ret;
+ clib_rwlock_reader_lock (&r->r_keypair_lock);
if ((kp = r->r_current) == NULL ||
!kp->kp_valid ||
wg_birthdate_has_expired (kp->kp_birthdate, REJECT_AFTER_TIME) ||
@@ -533,6 +488,7 @@ noise_remote_ready (noise_remote_t * r)
ret = false;
else
ret = true;
+ clib_rwlock_reader_unlock (&r->r_keypair_lock);
return ret;
}
@@ -592,6 +548,7 @@ noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx,
noise_keypair_t *kp;
enum noise_state_crypt ret = SC_FAILED;
+ clib_rwlock_reader_lock (&r->r_keypair_lock);
if ((kp = r->r_current) == NULL)
goto error;
@@ -631,6 +588,7 @@ noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx,
ret = SC_OK;
error:
+ clib_rwlock_reader_unlock (&r->r_keypair_lock);
return ret;
}
@@ -641,6 +599,7 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
{
noise_keypair_t *kp;
enum noise_state_crypt ret = SC_FAILED;
+ clib_rwlock_reader_lock (&r->r_keypair_lock);
if (r->r_current != NULL && r->r_current->kp_local_index == r_idx)
{
@@ -682,18 +641,26 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
* next keypair into current. If we do slide the next keypair in, then
* we skip the REKEY_AFTER_TIME_RECV check. This is safe to do as a
* data packet can't confirm a session that we are an INITIATOR of. */
- if (kp == r->r_next && kp->kp_local_index == r_idx)
+ if (kp == r->r_next)
{
- noise_remote_keypair_free (vm, r, &r->r_previous);
- r->r_previous = r->r_current;
- r->r_current = r->r_next;
- r->r_next = NULL;
+ clib_rwlock_reader_unlock (&r->r_keypair_lock);
+ clib_rwlock_writer_lock (&r->r_keypair_lock);
+ if (kp == r->r_next && kp->kp_local_index == r_idx)
+ {
+ noise_remote_keypair_free (vm, r, &r->r_previous);
+ r->r_previous = r->r_current;
+ r->r_current = r->r_next;
+ r->r_next = NULL;
- ret = SC_CONN_RESET;
- goto error;
+ ret = SC_CONN_RESET;
+ clib_rwlock_writer_unlock (&r->r_keypair_lock);
+ clib_rwlock_reader_lock (&r->r_keypair_lock);
+ goto error;
+ }
+ clib_rwlock_writer_unlock (&r->r_keypair_lock);
+ clib_rwlock_reader_lock (&r->r_keypair_lock);
}
-
/* Similar to when we encrypt, we want to notify the caller when we
* are approaching our tolerances. We notify if:
* - we're the initiator and the current keypair is older than
@@ -708,6 +675,7 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
ret = SC_OK;
error:
+ clib_rwlock_reader_unlock (&r->r_keypair_lock);
return ret;
}
@@ -725,7 +693,8 @@ static void
noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t * r,
noise_keypair_t ** kp)
{
- struct noise_upcall *u = &r->r_local->l_upcall;
+ noise_local_t *local = noise_local_get (r->r_local_idx);
+ struct noise_upcall *u = &local->l_upcall;
if (*kp)
{
u->u_index_drop ((*kp)->kp_local_index);
@@ -738,7 +707,8 @@ noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t * r,
static uint32_t
noise_remote_handshake_index_get (noise_remote_t * r)
{
- struct noise_upcall *u = &r->r_local->l_upcall;
+ noise_local_t *local = noise_local_get (r->r_local_idx);
+ struct noise_upcall *u = &local->l_upcall;
return u->u_index_set (r);
}
@@ -746,7 +716,8 @@ static void
noise_remote_handshake_index_drop (noise_remote_t * r)
{
noise_handshake_t *hs = &r->r_handshake;
- struct noise_upcall *u = &r->r_local->l_upcall;
+ noise_local_t *local = noise_local_get (r->r_local_idx);
+ struct noise_upcall *u = &local->l_upcall;
if (hs->hs_state != HS_ZEROED)
u->u_index_drop (hs->hs_local_index);
}
@@ -754,7 +725,8 @@ noise_remote_handshake_index_drop (noise_remote_t * r)
static uint64_t
noise_counter_send (noise_counter_t * ctr)
{
- uint64_t ret = ctr->c_send++;
+ uint64_t ret;
+ ret = ctr->c_send++;
return ret;
}
@@ -765,7 +737,6 @@ noise_counter_recv (noise_counter_t * ctr, uint64_t recv)
unsigned long bit;
bool ret = false;
-
/* Check that the recv counter is valid */
if (ctr->c_recv >= REJECT_AFTER_MESSAGES || recv >= REJECT_AFTER_MESSAGES)
goto error;
diff --git a/src/plugins/wireguard/wireguard_noise.h b/src/plugins/wireguard/wireguard_noise.h
index 1f6804c59ca..5b5a88fa250 100755
--- a/src/plugins/wireguard/wireguard_noise.h
+++ b/src/plugins/wireguard/wireguard_noise.h
@@ -100,7 +100,7 @@ typedef struct noise_remote
{
uint32_t r_peer_idx;
uint8_t r_public[NOISE_PUBLIC_KEY_LEN];
- noise_local_t *r_local;
+ uint32_t r_local_idx;
uint8_t r_ss[NOISE_PUBLIC_KEY_LEN];
noise_handshake_t r_handshake;
@@ -108,37 +108,40 @@ typedef struct noise_remote
uint8_t r_timestamp[NOISE_TIMESTAMP_LEN];
f64 r_last_init;
+ clib_rwlock_t r_keypair_lock;
noise_keypair_t *r_next, *r_current, *r_previous;
} noise_remote_t;
typedef struct noise_local
{
- bool l_has_identity;
uint8_t l_public[NOISE_PUBLIC_KEY_LEN];
uint8_t l_private[NOISE_PUBLIC_KEY_LEN];
struct noise_upcall
{
void *u_arg;
- noise_remote_t *(*u_remote_get) (uint8_t[NOISE_PUBLIC_KEY_LEN]);
+ noise_remote_t *(*u_remote_get) (const uint8_t[NOISE_PUBLIC_KEY_LEN]);
uint32_t (*u_index_set) (noise_remote_t *);
void (*u_index_drop) (uint32_t);
} l_upcall;
} noise_local_t;
+/* pool of noise_local */
+extern noise_local_t *noise_local_pool;
+
/* Set/Get noise parameters */
+static_always_inline noise_local_t *
+noise_local_get (uint32_t locali)
+{
+ return (pool_elt_at_index (noise_local_pool, locali));
+}
+
void noise_local_init (noise_local_t *, struct noise_upcall *);
bool noise_local_set_private (noise_local_t *,
const uint8_t[NOISE_PUBLIC_KEY_LEN]);
-bool noise_local_keys (noise_local_t *, uint8_t[NOISE_PUBLIC_KEY_LEN],
- uint8_t[NOISE_PUBLIC_KEY_LEN]);
void noise_remote_init (noise_remote_t *, uint32_t,
- const uint8_t[NOISE_PUBLIC_KEY_LEN], noise_local_t *);
-bool noise_remote_set_psk (noise_remote_t *,
- uint8_t[NOISE_SYMMETRIC_KEY_LEN]);
-bool noise_remote_keys (noise_remote_t *, uint8_t[NOISE_PUBLIC_KEY_LEN],
- uint8_t[NOISE_SYMMETRIC_KEY_LEN]);
+ const uint8_t[NOISE_PUBLIC_KEY_LEN], uint32_t);
/* Should be called anytime noise_local_set_private is called */
void noise_remote_precompute (noise_remote_t *);
diff --git a/src/plugins/wireguard/wireguard_output_tun.c b/src/plugins/wireguard/wireguard_output_tun.c
index cdfd9d730f6..9a8710b77db 100755
--- a/src/plugins/wireguard/wireguard_output_tun.c
+++ b/src/plugins/wireguard/wireguard_output_tun.c
@@ -15,10 +15,6 @@
#include <vlib/vlib.h>
#include <vnet/vnet.h>
-#include <vnet/pg/pg.h>
-#include <vnet/fib/ip6_fib.h>
-#include <vnet/fib/ip4_fib.h>
-#include <vnet/fib/fib_entry.h>
#include <vppinfra/error.h>
#include <wireguard/wireguard.h>
@@ -28,19 +24,8 @@
_(NONE, "No error") \
_(PEER, "Peer error") \
_(KEYPAIR, "Keypair error") \
- _(HANDSHAKE_SEND, "Handshake sending failed") \
_(TOO_BIG, "packet too big") \
-#define WG_OUTPUT_SCRATCH_SIZE 2048
-
-typedef struct wg_output_scratch_t_
-{
- u8 scratch[WG_OUTPUT_SCRATCH_SIZE];
-} wg_output_scratch_t;
-
-/* Cache line aligned per-thread scratch space */
-static wg_output_scratch_t *wg_output_scratchs;
-
typedef enum
{
#define _(sym,str) WG_OUTPUT_ERROR_##sym,
@@ -58,6 +43,7 @@ static char *wg_output_error_strings[] = {
typedef enum
{
WG_OUTPUT_NEXT_ERROR,
+ WG_OUTPUT_NEXT_HANDOFF,
WG_OUTPUT_NEXT_INTERFACE_OUTPUT,
WG_OUTPUT_N_NEXT,
} wg_output_next_t;
@@ -65,6 +51,7 @@ typedef enum
typedef struct
{
ip4_udp_header_t hdr;
+ index_t peer;
} wg_output_tun_trace_t;
u8 *
@@ -87,7 +74,8 @@ format_wg_output_tun_trace (u8 * s, va_list * args)
wg_output_tun_trace_t *t = va_arg (*args, wg_output_tun_trace_t *);
- s = format (s, "Encrypted packet: %U\n", format_ip4_udp_header, &t->hdr);
+ s = format (s, "peer: %d\n", t->peer);
+ s = format (s, " Encrypted packet: %U", format_ip4_udp_header, &t->hdr);
return s;
}
@@ -109,7 +97,6 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
vlib_get_buffers (vm, from, bufs, n_left_from);
wg_main_t *wmp = &wg_main;
- u32 handsh_fails = 0;
wg_peer_t *peer = NULL;
while (n_left_from > 0)
@@ -119,11 +106,12 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
sizeof (ip4_udp_header_t));
u16 plain_data_len =
clib_net_to_host_u16 (((ip4_header_t *) plain_data)->length);
+ index_t peeri;
next[0] = WG_OUTPUT_NEXT_ERROR;
-
- peer =
+ peeri =
wg_peer_get_by_adj_index (vnet_buffer (b[0])->ip.adj_index[VLIB_TX]);
+ peer = wg_peer_get (peeri);
if (!peer || peer->is_dead)
{
@@ -131,21 +119,34 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
goto out;
}
+ if (PREDICT_FALSE (~0 == peer->output_thread_index))
+ {
+ /* this is the first packet to use this peer, claim the peer
+ * for this thread.
+ */
+ clib_atomic_cmp_and_swap (&peer->output_thread_index, ~0,
+ wg_peer_assign_thread (thread_index));
+ }
+
+ if (PREDICT_TRUE (thread_index != peer->output_thread_index))
+ {
+ next[0] = WG_OUTPUT_NEXT_HANDOFF;
+ goto next;
+ }
+
if (PREDICT_FALSE (!peer->remote.r_current))
{
- if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false)))
- handsh_fails++;
+ wg_send_handshake_from_mt (peeri, false);
b[0]->error = node->errors[WG_OUTPUT_ERROR_KEYPAIR];
goto out;
}
-
size_t encrypted_packet_len = message_data_len (plain_data_len);
/*
* Ensure there is enough space to write the encrypted data
* into the packet
*/
- if (PREDICT_FALSE (encrypted_packet_len >= WG_OUTPUT_SCRATCH_SIZE) ||
+ if (PREDICT_FALSE (encrypted_packet_len >= WG_DEFAULT_DATA_SIZE) ||
PREDICT_FALSE ((b[0]->current_data + encrypted_packet_len) >=
vlib_buffer_get_default_data_size (vm)))
{
@@ -154,35 +155,29 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
}
message_data_t *encrypted_packet =
- (message_data_t *) wg_output_scratchs[thread_index].scratch;
+ (message_data_t *) wmp->per_thread_data[thread_index].data;
enum noise_state_crypt state;
state =
- noise_remote_encrypt (wmp->vlib_main,
+ noise_remote_encrypt (vm,
&peer->remote,
&encrypted_packet->receiver_index,
&encrypted_packet->counter, plain_data,
plain_data_len,
encrypted_packet->encrypted_data);
- switch (state)
+
+ if (PREDICT_FALSE (state == SC_KEEP_KEY_FRESH))
+ {
+ wg_send_handshake_from_mt (peeri, false);
+ }
+ else if (PREDICT_FALSE (state == SC_FAILED))
{
- case SC_OK:
- break;
- case SC_KEEP_KEY_FRESH:
- if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false)))
- handsh_fails++;
- break;
- case SC_FAILED:
//TODO: Maybe wrong
- if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false)))
- handsh_fails++;
- clib_mem_free (encrypted_packet);
+ wg_send_handshake_from_mt (peeri, false);
goto out;
- default:
- break;
}
- // Here we are sure that can send packet to next node.
+ /* Here we are sure that can send packet to next node */
next[0] = WG_OUTPUT_NEXT_INTERFACE_OUTPUT;
encrypted_packet->header.type = MESSAGE_DATA;
@@ -195,9 +190,9 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
ip4_header_set_len_w_chksum
(&hdr->ip4, clib_host_to_net_u16 (b[0]->current_length));
- wg_timers_any_authenticated_packet_traversal (peer);
wg_timers_any_authenticated_packet_sent (peer);
wg_timers_data_sent (peer);
+ wg_timers_any_authenticated_packet_traversal (peer);
out:
if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE)
@@ -206,17 +201,15 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
wg_output_tun_trace_t *t =
vlib_add_trace (vm, node, b[0], sizeof (*t));
t->hdr = *hdr;
+ t->peer = peeri;
}
+ next:
n_left_from -= 1;
next += 1;
b += 1;
}
vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors);
-
- vlib_node_increment_counter (vm, node->node_index,
- WG_OUTPUT_ERROR_HANDSHAKE_SEND, handsh_fails);
-
return frame->n_vectors;
}
@@ -231,24 +224,13 @@ VLIB_REGISTER_NODE (wg_output_tun_node) =
.error_strings = wg_output_error_strings,
.n_next_nodes = WG_OUTPUT_N_NEXT,
.next_nodes = {
+ [WG_OUTPUT_NEXT_HANDOFF] = "wg-output-tun-handoff",
[WG_OUTPUT_NEXT_INTERFACE_OUTPUT] = "adj-midchain-tx",
[WG_OUTPUT_NEXT_ERROR] = "error-drop",
},
};
/* *INDENT-ON* */
-static clib_error_t *
-wireguard_output_module_init (vlib_main_t * vm)
-{
- vlib_thread_main_t *tm = vlib_get_thread_main ();
-
- vec_validate_aligned (wg_output_scratchs, tm->n_vlib_mains,
- CLIB_CACHE_LINE_BYTES);
- return (NULL);
-}
-
-VLIB_INIT_FUNCTION (wireguard_output_module_init);
-
/*
* fd.io coding-style-patch-verification: ON
*
diff --git a/src/plugins/wireguard/wireguard_peer.c b/src/plugins/wireguard/wireguard_peer.c
index 30adea82647..b41118f83d1 100755
--- a/src/plugins/wireguard/wireguard_peer.c
+++ b/src/plugins/wireguard/wireguard_peer.c
@@ -23,15 +23,10 @@
#include <wireguard/wireguard.h>
static fib_source_t wg_fib_source;
+wg_peer_t *wg_peer_pool;
index_t *wg_peer_by_adj_index;
-wg_peer_t *
-wg_peer_get (index_t peeri)
-{
- return (pool_elt_at_index (wg_main.peers, peeri));
-}
-
static void
wg_peer_endpoint_reset (wg_peer_endpoint_t * ep)
{
@@ -82,7 +77,11 @@ static void
wg_peer_clear (vlib_main_t * vm, wg_peer_t * peer)
{
wg_timers_stop (peer);
- noise_remote_clear (vm, &peer->remote);
+ for (int i = 0; i < WG_N_TIMERS; i++)
+ {
+ peer->timers[i] = ~0;
+ }
+
peer->last_sent_handshake = vlib_time_now (vm) - (REKEY_TIMEOUT + 1);
clib_memset (&peer->cookie_maker, 0, sizeof (peer->cookie_maker));
@@ -97,9 +96,18 @@ wg_peer_clear (vlib_main_t * vm, wg_peer_t * peer)
}
wg_peer_fib_flush (peer);
+ peer->input_thread_index = ~0;
+ peer->output_thread_index = ~0;
peer->adj_index = INDEX_INVALID;
+ peer->timer_wheel = 0;
peer->persistent_keepalive_interval = 0;
peer->timer_handshake_attempts = 0;
+ peer->last_sent_packet = 0;
+ peer->last_received_packet = 0;
+ peer->session_derived = 0;
+ peer->rehandshake_started = 0;
+ peer->new_handshake_interval_tick = 0;
+ peer->rehandshake_interval_tick = 0;
peer->timer_need_another_keepalive = false;
peer->is_dead = true;
vec_free (peer->allowed_ips);
@@ -108,7 +116,7 @@ wg_peer_clear (vlib_main_t * vm, wg_peer_t * peer)
static void
wg_peer_init (vlib_main_t * vm, wg_peer_t * peer)
{
- wg_timers_init (peer, vlib_time_now (vm));
+ peer->adj_index = INDEX_INVALID;
wg_peer_clear (vm, peer);
}
@@ -205,8 +213,9 @@ wg_peer_fill (vlib_main_t * vm, wg_peer_t * peer,
wg_peer_endpoint_init (&peer->dst, dst, port);
peer->table_id = table_id;
- peer->persistent_keepalive_interval = persistent_keepalive_interval;
peer->wg_sw_if_index = wg_sw_if_index;
+ peer->timer_wheel = &wg_main.timer_wheel;
+ peer->persistent_keepalive_interval = persistent_keepalive_interval;
peer->last_sent_handshake = vlib_time_now (vm) - (REKEY_TIMEOUT + 1);
peer->is_dead = false;
@@ -230,7 +239,7 @@ wg_peer_fill (vlib_main_t * vm, wg_peer_t * peer,
vec_validate_init_empty (wg_peer_by_adj_index,
peer->adj_index, INDEX_INVALID);
- wg_peer_by_adj_index[peer->adj_index] = peer - wg_main.peers;
+ wg_peer_by_adj_index[peer->adj_index] = peer - wg_peer_pool;
adj_nbr_midchain_update_rewrite (peer->adj_index,
NULL,
@@ -280,7 +289,7 @@ wg_peer_add (u32 tun_sw_if_index,
return (VNET_API_ERROR_INVALID_SW_IF_INDEX);
/* *INDENT-OFF* */
- pool_foreach (peer, wg_main.peers,
+ pool_foreach (peer, wg_peer_pool,
({
if (!memcmp (peer->remote.r_public, public_key, NOISE_PUBLIC_KEY_LEN))
{
@@ -289,10 +298,10 @@ wg_peer_add (u32 tun_sw_if_index,
}));
/* *INDENT-ON* */
- if (pool_elts (wg_main.peers) > MAX_PEERS)
+ if (pool_elts (wg_peer_pool) > MAX_PEERS)
return (VNET_API_ERROR_LIMIT_EXCEEDED);
- pool_get (wg_main.peers, peer);
+ pool_get (wg_peer_pool, peer);
wg_peer_init (vm, peer);
@@ -302,12 +311,12 @@ wg_peer_add (u32 tun_sw_if_index,
if (rv)
{
wg_peer_clear (vm, peer);
- pool_put (wg_main.peers, peer);
+ pool_put (wg_peer_pool, peer);
return (rv);
}
- noise_remote_init (&peer->remote, peer - wg_main.peers, public_key,
- &wg_if->local);
+ noise_remote_init (&peer->remote, peer - wg_peer_pool, public_key,
+ wg_if->local_idx);
cookie_maker_init (&peer->cookie_maker, public_key);
if (peer->persistent_keepalive_interval != 0)
@@ -315,7 +324,7 @@ wg_peer_add (u32 tun_sw_if_index,
wg_send_keepalive (vm, peer);
}
- *peer_index = peer - wg_main.peers;
+ *peer_index = peer - wg_peer_pool;
wg_if_peer_add (wg_if, *peer_index);
return (0);
@@ -328,34 +337,37 @@ wg_peer_remove (index_t peeri)
wg_peer_t *peer = NULL;
wg_if_t *wgi;
- if (pool_is_free_index (wmp->peers, peeri))
+ if (pool_is_free_index (wg_peer_pool, peeri))
return VNET_API_ERROR_NO_SUCH_ENTRY;
- peer = pool_elt_at_index (wmp->peers, peeri);
+ peer = pool_elt_at_index (wg_peer_pool, peeri);
wgi = wg_if_get (wg_if_find_by_sw_if_index (peer->wg_sw_if_index));
wg_if_peer_remove (wgi, peeri);
vnet_feature_enable_disable ("ip4-output", "wg-output-tun",
peer->wg_sw_if_index, 0, 0, 0);
+
+ noise_remote_clear (wmp->vlib_main, &peer->remote);
wg_peer_clear (wmp->vlib_main, peer);
- pool_put (wmp->peers, peer);
+ pool_put (wg_peer_pool, peer);
return (0);
}
-void
+index_t
wg_peer_walk (wg_peer_walk_cb_t fn, void *data)
{
index_t peeri;
/* *INDENT-OFF* */
- pool_foreach_index(peeri, wg_main.peers,
+ pool_foreach_index(peeri, wg_peer_pool,
{
if (WALK_STOP == fn(peeri, data))
- break;
+ return peeri;
});
/* *INDENT-ON* */
+ return INDEX_INVALID;
}
static u8 *
diff --git a/src/plugins/wireguard/wireguard_peer.h b/src/plugins/wireguard/wireguard_peer.h
index 99c73f3a0ed..009a6f67aeb 100755
--- a/src/plugins/wireguard/wireguard_peer.h
+++ b/src/plugins/wireguard/wireguard_peer.h
@@ -49,6 +49,9 @@ typedef struct wg_peer
noise_remote_t remote;
cookie_maker_t cookie_maker;
+ u32 input_thread_index;
+ u32 output_thread_index;
+
/* Peer addresses */
wg_peer_endpoint_t dst;
wg_peer_endpoint_t src;
@@ -65,11 +68,22 @@ typedef struct wg_peer
u32 wg_sw_if_index;
/* Timers */
- tw_timer_wheel_16t_2w_512sl_t timer_wheel;
+ tw_timer_wheel_16t_2w_512sl_t *timer_wheel;
u32 timers[WG_N_TIMERS];
u32 timer_handshake_attempts;
u16 persistent_keepalive_interval;
+
+ /* Timestamps */
f64 last_sent_handshake;
+ f64 last_sent_packet;
+ f64 last_received_packet;
+ f64 session_derived;
+ f64 rehandshake_started;
+
+ /* Variable intervals */
+ u32 new_handshake_interval_tick;
+ u32 rehandshake_interval_tick;
+
bool timer_need_another_keepalive;
bool is_dead;
@@ -91,10 +105,9 @@ int wg_peer_add (u32 tun_sw_if_index,
int wg_peer_remove (u32 peer_index);
typedef walk_rc_t (*wg_peer_walk_cb_t) (index_t peeri, void *arg);
-void wg_peer_walk (wg_peer_walk_cb_t fn, void *data);
+index_t wg_peer_walk (wg_peer_walk_cb_t fn, void *data);
u8 *format_wg_peer (u8 * s, va_list * va);
-wg_peer_t *wg_peer_get (index_t peeri);
walk_rc_t wg_peer_if_admin_state_change (wg_if_t * wgi, index_t peeri,
void *data);
@@ -104,11 +117,30 @@ walk_rc_t wg_peer_if_table_change (wg_if_t * wgi, index_t peeri, void *data);
* Expoed for the data-plane
*/
extern index_t *wg_peer_by_adj_index;
+extern wg_peer_t *wg_peer_pool;
static inline wg_peer_t *
+wg_peer_get (index_t peeri)
+{
+ return (pool_elt_at_index (wg_peer_pool, peeri));
+}
+
+static inline index_t
wg_peer_get_by_adj_index (index_t ai)
{
- return wg_peer_get (wg_peer_by_adj_index[ai]);
+ return (wg_peer_by_adj_index[ai]);
+}
+
+/*
+ * Makes choice for thread_id should be assigned.
+*/
+static inline u32
+wg_peer_assign_thread (u32 thread_id)
+{
+ return ((thread_id) ? thread_id
+ : (vlib_num_workers ()?
+ ((unix_time_now_nsec () % vlib_num_workers ()) +
+ 1) : thread_id));
}
#endif // __included_wg_peer_h__
diff --git a/src/plugins/wireguard/wireguard_send.c b/src/plugins/wireguard/wireguard_send.c
index a5d8aaf6900..2e29a9b4b76 100755
--- a/src/plugins/wireguard/wireguard_send.c
+++ b/src/plugins/wireguard/wireguard_send.c
@@ -14,13 +14,11 @@
*/
#include <vnet/vnet.h>
-#include <vnet/fib/ip6_fib.h>
-#include <vnet/fib/ip4_fib.h>
-#include <vnet/fib/fib_entry.h>
#include <vnet/ip/ip6_link.h>
#include <vnet/pg/pg.h>
#include <vnet/udp/udp.h>
#include <vppinfra/error.h>
+#include <vlibmemory/api.h>
#include <wireguard/wireguard.h>
#include <wireguard/wireguard_send.h>
@@ -86,7 +84,8 @@ wg_create_buffer (vlib_main_t * vm,
bool
wg_send_handshake (vlib_main_t * vm, wg_peer_t * peer, bool is_retry)
{
- wg_main_t *wmp = &wg_main;
+ ASSERT (vm->thread_index == 0);
+
message_handshake_initiation_t packet;
if (!is_retry)
@@ -94,41 +93,73 @@ wg_send_handshake (vlib_main_t * vm, wg_peer_t * peer, bool is_retry)
if (!wg_birthdate_has_expired (peer->last_sent_handshake,
REKEY_TIMEOUT) || peer->is_dead)
- {
- return true;
- }
- if (noise_create_initiation (wmp->vlib_main,
+ return true;
+
+ if (noise_create_initiation (vm,
&peer->remote,
&packet.sender_index,
packet.unencrypted_ephemeral,
packet.encrypted_static,
packet.encrypted_timestamp))
{
- f64 now = vlib_time_now (vm);
packet.header.type = MESSAGE_HANDSHAKE_INITIATION;
cookie_maker_mac (&peer->cookie_maker, &packet.macs, &packet,
sizeof (packet));
- wg_timers_any_authenticated_packet_traversal (peer);
wg_timers_any_authenticated_packet_sent (peer);
- peer->last_sent_handshake = now;
wg_timers_handshake_initiated (peer);
+ wg_timers_any_authenticated_packet_traversal (peer);
+
+ peer->last_sent_handshake = vlib_time_now (vm);
}
else
return false;
+
u32 bi0 = 0;
if (!wg_create_buffer (vm, peer, (u8 *) & packet, sizeof (packet), &bi0))
return false;
- ip46_enqueue_packet (vm, bi0, false);
+ ip46_enqueue_packet (vm, bi0, false);
return true;
}
+typedef struct
+{
+ u32 peer_idx;
+ bool is_retry;
+} wg_send_args_t;
+
+static void *
+wg_send_handshake_thread_fn (void *arg)
+{
+ wg_send_args_t *a = arg;
+
+ wg_main_t *wmp = &wg_main;
+ wg_peer_t *peer = wg_peer_get (a->peer_idx);
+
+ wg_send_handshake (wmp->vlib_main, peer, a->is_retry);
+ return 0;
+}
+
+void
+wg_send_handshake_from_mt (u32 peer_idx, bool is_retry)
+{
+ wg_send_args_t a = {
+ .peer_idx = peer_idx,
+ .is_retry = is_retry,
+ };
+
+ vl_api_rpc_call_main_thread (wg_send_handshake_thread_fn,
+ (u8 *) & a, sizeof (a));
+}
+
bool
wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer)
{
- wg_main_t *wmp = &wg_main;
+ ASSERT (vm->thread_index == 0);
+
u32 size_of_packet = message_data_len (0);
- message_data_t *packet = clib_mem_alloc (size_of_packet);
+ message_data_t *packet =
+ (message_data_t *) wg_main.per_thread_data[vm->thread_index].data;
u32 bi0 = 0;
bool ret = true;
enum noise_state_crypt state;
@@ -140,23 +171,21 @@ wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer)
}
state =
- noise_remote_encrypt (wmp->vlib_main,
+ noise_remote_encrypt (vm,
&peer->remote,
&packet->receiver_index,
&packet->counter, NULL, 0, packet->encrypted_data);
- switch (state)
+
+ if (PREDICT_FALSE (state == SC_KEEP_KEY_FRESH))
{
- case SC_OK:
- break;
- case SC_KEEP_KEY_FRESH:
wg_send_handshake (vm, peer, false);
- break;
- case SC_FAILED:
+ }
+ else if (PREDICT_FALSE (state == SC_FAILED))
+ {
ret = false;
goto out;
- default:
- break;
}
+
packet->header.type = MESSAGE_DATA;
if (!wg_create_buffer (vm, peer, (u8 *) packet, size_of_packet, &bi0))
@@ -166,22 +195,19 @@ wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer)
}
ip46_enqueue_packet (vm, bi0, false);
- wg_timers_any_authenticated_packet_traversal (peer);
+
wg_timers_any_authenticated_packet_sent (peer);
+ wg_timers_any_authenticated_packet_traversal (peer);
out:
- clib_mem_free (packet);
return ret;
}
bool
wg_send_handshake_response (vlib_main_t * vm, wg_peer_t * peer)
{
- wg_main_t *wmp = &wg_main;
message_handshake_response_t packet;
- peer->last_sent_handshake = vlib_time_now (vm);
-
if (noise_create_response (vm,
&peer->remote,
&packet.sender_index,
@@ -189,17 +215,16 @@ wg_send_handshake_response (vlib_main_t * vm, wg_peer_t * peer)
packet.unencrypted_ephemeral,
packet.encrypted_nothing))
{
- f64 now = vlib_time_now (vm);
packet.header.type = MESSAGE_HANDSHAKE_RESPONSE;
cookie_maker_mac (&peer->cookie_maker, &packet.macs, &packet,
sizeof (packet));
- if (noise_remote_begin_session (wmp->vlib_main, &peer->remote))
+ if (noise_remote_begin_session (vm, &peer->remote))
{
wg_timers_session_derived (peer);
- wg_timers_any_authenticated_packet_traversal (peer);
wg_timers_any_authenticated_packet_sent (peer);
- peer->last_sent_handshake = now;
+ wg_timers_any_authenticated_packet_traversal (peer);
+ peer->last_sent_handshake = vlib_time_now (vm);
u32 bi0 = 0;
if (!wg_create_buffer (vm, peer, (u8 *) & packet,
diff --git a/src/plugins/wireguard/wireguard_send.h b/src/plugins/wireguard/wireguard_send.h
index 4ea1f6effea..efe41949428 100755
--- a/src/plugins/wireguard/wireguard_send.h
+++ b/src/plugins/wireguard/wireguard_send.h
@@ -20,6 +20,7 @@
bool wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer);
bool wg_send_handshake (vlib_main_t * vm, wg_peer_t * peer, bool is_retry);
+void wg_send_handshake_from_mt (u32 peer_index, bool is_retry);
bool wg_send_handshake_response (vlib_main_t * vm, wg_peer_t * peer);
always_inline void
diff --git a/src/plugins/wireguard/wireguard_timer.c b/src/plugins/wireguard/wireguard_timer.c
index e4d4030bb18..b7fd6891d14 100755
--- a/src/plugins/wireguard/wireguard_timer.c
+++ b/src/plugins/wireguard/wireguard_timer.c
@@ -13,6 +13,7 @@
* limitations under the License.
*/
+#include <vlibmemory/api.h>
#include <wireguard/wireguard.h>
#include <wireguard/wireguard_send.h>
#include <wireguard/wireguard_timer.h>
@@ -30,31 +31,77 @@ stop_timer (wg_peer_t * peer, u32 timer_id)
{
if (peer->timers[timer_id] != ~0)
{
- tw_timer_stop_16t_2w_512sl (&peer->timer_wheel, peer->timers[timer_id]);
+ tw_timer_stop_16t_2w_512sl (peer->timer_wheel, peer->timers[timer_id]);
peer->timers[timer_id] = ~0;
}
}
static void
-start_or_update_timer (wg_peer_t * peer, u32 timer_id, u32 interval)
+start_timer (wg_peer_t * peer, u32 timer_id, u32 interval_ticks)
{
+ ASSERT (vlib_get_thread_index () == 0);
+
if (peer->timers[timer_id] == ~0)
{
- wg_main_t *wmp = &wg_main;
peer->timers[timer_id] =
- tw_timer_start_16t_2w_512sl (&peer->timer_wheel, peer - wmp->peers,
- timer_id, interval);
- }
- else
- {
- tw_timer_update_16t_2w_512sl (&peer->timer_wheel,
- peer->timers[timer_id], interval);
+ tw_timer_start_16t_2w_512sl (peer->timer_wheel, peer - wg_peer_pool,
+ timer_id, interval_ticks);
}
}
+typedef struct
+{
+ u32 peer_idx;
+ u32 timer_id;
+ u32 interval_ticks;
+
+} wg_timers_args;
+
+static void *
+start_timer_thread_fn (void *arg)
+{
+ wg_timers_args *a = arg;
+ wg_peer_t *peer = wg_peer_get (a->peer_idx);
+
+ start_timer (peer, a->timer_id, a->interval_ticks);
+ return 0;
+}
+
+static void
+start_timer_from_mt (u32 peer_idx, u32 timer_id, u32 interval_ticks)
+{
+ wg_timers_args a = {
+ .peer_idx = peer_idx,
+ .timer_id = timer_id,
+ .interval_ticks = interval_ticks,
+ };
+
+ vl_api_rpc_call_main_thread (start_timer_thread_fn, (u8 *) & a, sizeof (a));
+}
+
+static inline u32
+timer_ticks_left (vlib_main_t * vm, f64 init_time_sec, u32 interval_ticks)
+{
+ static const int32_t rounding = (int32_t) (WHZ / 2);
+ int32_t ticks_remain;
+
+ ticks_remain = (init_time_sec - vlib_time_now (vm)) * WHZ + interval_ticks;
+ return (ticks_remain > rounding) ? (u32) ticks_remain : 0;
+}
+
static void
wg_expired_retransmit_handshake (vlib_main_t * vm, wg_peer_t * peer)
{
+ if (peer->rehandshake_started == ~0)
+ return;
+
+ u32 ticks = timer_ticks_left (vm, peer->rehandshake_started,
+ peer->rehandshake_interval_tick);
+ if (ticks)
+ {
+ start_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE, ticks);
+ return;
+ }
if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES)
{
@@ -63,17 +110,8 @@ wg_expired_retransmit_handshake (vlib_main_t * vm, wg_peer_t * peer)
/* We set a timer for destroying any residue that might be left
* of a partial exchange.
*/
+ start_timer (peer, WG_TIMER_KEY_ZEROING, REJECT_AFTER_TIME * 3 * WHZ);
- if (peer->timers[WG_TIMER_KEY_ZEROING] == ~0)
- {
- wg_main_t *wmp = &wg_main;
-
- peer->timers[WG_TIMER_KEY_ZEROING] =
- tw_timer_start_16t_2w_512sl (&peer->timer_wheel,
- peer - wmp->peers,
- WG_TIMER_KEY_ZEROING,
- REJECT_AFTER_TIME * 3 * WHZ);
- }
}
else
{
@@ -85,13 +123,23 @@ wg_expired_retransmit_handshake (vlib_main_t * vm, wg_peer_t * peer)
static void
wg_expired_send_keepalive (vlib_main_t * vm, wg_peer_t * peer)
{
- wg_send_keepalive (vm, peer);
-
- if (peer->timer_need_another_keepalive)
+ if (peer->last_sent_packet < peer->last_received_packet)
{
- peer->timer_need_another_keepalive = false;
- start_or_update_timer (peer, WG_TIMER_SEND_KEEPALIVE,
- KEEPALIVE_TIMEOUT * WHZ);
+ u32 ticks = timer_ticks_left (vm, peer->last_received_packet,
+ KEEPALIVE_TIMEOUT * WHZ);
+ if (ticks)
+ {
+ start_timer (peer, WG_TIMER_SEND_KEEPALIVE, ticks);
+ return;
+ }
+
+ wg_send_keepalive (vm, peer);
+ if (peer->timer_need_another_keepalive)
+ {
+ peer->timer_need_another_keepalive = false;
+ start_timer (peer, WG_TIMER_SEND_KEEPALIVE,
+ KEEPALIVE_TIMEOUT * WHZ);
+ }
}
}
@@ -100,6 +148,18 @@ wg_expired_send_persistent_keepalive (vlib_main_t * vm, wg_peer_t * peer)
{
if (peer->persistent_keepalive_interval)
{
+ f64 latest_time = peer->last_sent_packet > peer->last_received_packet
+ ? peer->last_sent_packet : peer->last_received_packet;
+
+ u32 ticks = timer_ticks_left (vm, latest_time,
+ peer->persistent_keepalive_interval *
+ WHZ);
+ if (ticks)
+ {
+ start_timer (peer, WG_TIMER_PERSISTENT_KEEPALIVE, ticks);
+ return;
+ }
+
wg_send_keepalive (vm, peer);
}
}
@@ -107,64 +167,81 @@ wg_expired_send_persistent_keepalive (vlib_main_t * vm, wg_peer_t * peer)
static void
wg_expired_new_handshake (vlib_main_t * vm, wg_peer_t * peer)
{
+ u32 ticks = timer_ticks_left (vm, peer->last_sent_packet,
+ peer->new_handshake_interval_tick);
+ if (ticks)
+ {
+ start_timer (peer, WG_TIMER_NEW_HANDSHAKE, ticks);
+ return;
+ }
+
wg_send_handshake (vm, peer, false);
}
static void
wg_expired_zero_key_material (vlib_main_t * vm, wg_peer_t * peer)
{
+ u32 ticks =
+ timer_ticks_left (vm, peer->session_derived, REJECT_AFTER_TIME * 3 * WHZ);
+ if (ticks)
+ {
+ start_timer (peer, WG_TIMER_KEY_ZEROING, ticks);
+ return;
+ }
+
if (!peer->is_dead)
{
noise_remote_clear (vm, &peer->remote);
}
}
-
void
wg_timers_any_authenticated_packet_traversal (wg_peer_t * peer)
{
if (peer->persistent_keepalive_interval)
{
- start_or_update_timer (peer, WG_TIMER_PERSISTENT_KEEPALIVE,
- peer->persistent_keepalive_interval * WHZ);
+ start_timer_from_mt (peer - wg_peer_pool,
+ WG_TIMER_PERSISTENT_KEEPALIVE,
+ peer->persistent_keepalive_interval * WHZ);
}
}
void
wg_timers_any_authenticated_packet_sent (wg_peer_t * peer)
{
- stop_timer (peer, WG_TIMER_SEND_KEEPALIVE);
+ peer->last_sent_packet = vlib_time_now (vlib_get_main ());
}
void
wg_timers_handshake_initiated (wg_peer_t * peer)
{
- u32 interval =
+ peer->rehandshake_started = vlib_time_now (vlib_get_main ());
+ peer->rehandshake_interval_tick =
REKEY_TIMEOUT * WHZ + get_random_u32_max (REKEY_TIMEOUT_JITTER);
- start_or_update_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE, interval);
+
+ start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_RETRANSMIT_HANDSHAKE,
+ peer->rehandshake_interval_tick);
}
void
wg_timers_session_derived (wg_peer_t * peer)
{
- start_or_update_timer (peer, WG_TIMER_KEY_ZEROING,
- REJECT_AFTER_TIME * 3 * WHZ);
+ peer->session_derived = vlib_time_now (vlib_get_main ());
+
+ start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_KEY_ZEROING,
+ REJECT_AFTER_TIME * 3 * WHZ);
}
/* Should be called after an authenticated data packet is sent. */
void
wg_timers_data_sent (wg_peer_t * peer)
{
- u32 interval = (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * WHZ +
+ peer->new_handshake_interval_tick =
+ (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * WHZ +
get_random_u32_max (REKEY_TIMEOUT_JITTER);
- if (peer->timers[WG_TIMER_NEW_HANDSHAKE] == ~0)
- {
- wg_main_t *wmp = &wg_main;
- peer->timers[WG_TIMER_NEW_HANDSHAKE] =
- tw_timer_start_16t_2w_512sl (&peer->timer_wheel, peer - wmp->peers,
- WG_TIMER_NEW_HANDSHAKE, interval);
- }
+ start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_NEW_HANDSHAKE,
+ peer->new_handshake_interval_tick);
}
/* Should be called after an authenticated data packet is received. */
@@ -173,16 +250,11 @@ wg_timers_data_received (wg_peer_t * peer)
{
if (peer->timers[WG_TIMER_SEND_KEEPALIVE] == ~0)
{
- wg_main_t *wmp = &wg_main;
- peer->timers[WG_TIMER_SEND_KEEPALIVE] =
- tw_timer_start_16t_2w_512sl (&peer->timer_wheel, peer - wmp->peers,
- WG_TIMER_SEND_KEEPALIVE,
- KEEPALIVE_TIMEOUT * WHZ);
+ start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_SEND_KEEPALIVE,
+ KEEPALIVE_TIMEOUT * WHZ);
}
else
- {
- peer->timer_need_another_keepalive = true;
- }
+ peer->timer_need_another_keepalive = true;
}
/* Should be called after a handshake response message is received and processed
@@ -191,15 +263,14 @@ wg_timers_data_received (wg_peer_t * peer)
void
wg_timers_handshake_complete (wg_peer_t * peer)
{
- stop_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE);
-
+ peer->rehandshake_started = ~0;
peer->timer_handshake_attempts = 0;
}
void
wg_timers_any_authenticated_packet_received (wg_peer_t * peer)
{
- stop_timer (peer, WG_TIMER_NEW_HANDSHAKE);
+ peer->last_received_packet = vlib_time_now (vlib_get_main ());
}
static vlib_node_registration_t wg_timer_mngr_node;
@@ -222,7 +293,7 @@ expired_timer_callback (u32 * expired_timers)
pool_index = expired_timers[i] & 0x0FFFFFFF;
timer_id = expired_timers[i] >> 28;
- peer = pool_elt_at_index (wmp->peers, pool_index);
+ peer = wg_peer_get (pool_index);
peer->timers[timer_id] = ~0;
}
@@ -231,7 +302,7 @@ expired_timer_callback (u32 * expired_timers)
pool_index = expired_timers[i] & 0x0FFFFFFF;
timer_id = expired_timers[i] >> 28;
- peer = pool_elt_at_index (wmp->peers, pool_index);
+ peer = wg_peer_get (pool_index);
switch (timer_id)
{
case WG_TIMER_RETRANSMIT_HANDSHAKE:
@@ -256,18 +327,14 @@ expired_timer_callback (u32 * expired_timers)
}
void
-wg_timers_init (wg_peer_t * peer, f64 now)
+wg_timer_wheel_init ()
{
- for (int i = 0; i < WG_N_TIMERS; i++)
- {
- peer->timers[i] = ~0;
- }
- tw_timer_wheel_16t_2w_512sl_t *tw = &peer->timer_wheel;
+ wg_main_t *wmp = &wg_main;
+ tw_timer_wheel_16t_2w_512sl_t *tw = &wmp->timer_wheel;
tw_timer_wheel_init_16t_2w_512sl (tw,
expired_timer_callback,
WG_TICK /* timer period in s */ , ~0);
- tw->last_run_time = now;
- peer->adj_index = INDEX_INVALID;
+ tw->last_run_time = vlib_time_now (wmp->vlib_main);
}
static uword
@@ -275,22 +342,13 @@ wg_timer_mngr_fn (vlib_main_t * vm, vlib_node_runtime_t * rt,
vlib_frame_t * f)
{
wg_main_t *wmp = &wg_main;
- wg_peer_t *peers;
- wg_peer_t *peer;
-
while (1)
{
vlib_process_wait_for_event_or_clock (vm, WG_TICK);
vlib_process_get_events (vm, NULL);
- peers = wmp->peers;
- /* *INDENT-OFF* */
- pool_foreach (peer, peers,
- ({
- tw_timer_expire_timers_16t_2w_512sl
- (&peer->timer_wheel, vlib_time_now (vm));
- }));
- /* *INDENT-ON* */
+ tw_timer_expire_timers_16t_2w_512sl (&wmp->timer_wheel,
+ vlib_time_now (vm));
}
return 0;
@@ -299,11 +357,15 @@ wg_timer_mngr_fn (vlib_main_t * vm, vlib_node_runtime_t * rt,
void
wg_timers_stop (wg_peer_t * peer)
{
- stop_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE);
- stop_timer (peer, WG_TIMER_PERSISTENT_KEEPALIVE);
- stop_timer (peer, WG_TIMER_SEND_KEEPALIVE);
- stop_timer (peer, WG_TIMER_NEW_HANDSHAKE);
- stop_timer (peer, WG_TIMER_KEY_ZEROING);
+ ASSERT (vlib_get_thread_index () == 0);
+ if (peer->timer_wheel)
+ {
+ stop_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE);
+ stop_timer (peer, WG_TIMER_PERSISTENT_KEEPALIVE);
+ stop_timer (peer, WG_TIMER_SEND_KEEPALIVE);
+ stop_timer (peer, WG_TIMER_NEW_HANDSHAKE);
+ stop_timer (peer, WG_TIMER_KEY_ZEROING);
+ }
}
/* *INDENT-OFF* */
diff --git a/src/plugins/wireguard/wireguard_timer.h b/src/plugins/wireguard/wireguard_timer.h
index 457dce28674..2cc5dd01284 100755
--- a/src/plugins/wireguard/wireguard_timer.h
+++ b/src/plugins/wireguard/wireguard_timer.h
@@ -38,7 +38,7 @@ typedef enum _wg_timers
typedef struct wg_peer wg_peer_t;
-void wg_timers_init (wg_peer_t * peer, f64 now);
+void wg_timer_wheel_init ();
void wg_timers_stop (wg_peer_t * peer);
void wg_timers_data_sent (wg_peer_t * peer);
void wg_timers_data_received (wg_peer_t * peer);