diff options
-rw-r--r-- | src/plugins/wireguard/wireguard_peer.c | 12 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_timer.c | 10 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_timer.h | 1 | ||||
-rw-r--r-- | test/test_wireguard.py | 33 |
4 files changed, 46 insertions, 10 deletions
diff --git a/src/plugins/wireguard/wireguard_peer.c b/src/plugins/wireguard/wireguard_peer.c index a8f1ab91644..f7bf2352db4 100644 --- a/src/plugins/wireguard/wireguard_peer.c +++ b/src/plugins/wireguard/wireguard_peer.c @@ -244,11 +244,7 @@ wg_peer_enable (vlib_main_t *vm, wg_peer_t *peer) noise_remote_init (&peer->remote, peeri, public_key, wg_if->local_idx); - wg_send_handshake (vm, peer, false); - if (peer->persistent_keepalive_interval != 0) - { - wg_send_keepalive (vm, peer); - } + wg_timers_send_first_handshake (peer); } walk_rc_t @@ -494,11 +490,7 @@ wg_peer_add (u32 tun_sw_if_index, const u8 public_key[NOISE_PUBLIC_KEY_LEN], if (vnet_sw_interface_is_admin_up (vnet_get_main (), tun_sw_if_index)) { - wg_send_handshake (vm, peer, false); - if (peer->persistent_keepalive_interval != 0) - { - wg_send_keepalive (vm, peer); - } + wg_timers_send_first_handshake (peer); } *peer_index = peer - wg_peer_pool; diff --git a/src/plugins/wireguard/wireguard_timer.c b/src/plugins/wireguard/wireguard_timer.c index b95801122fc..4319d534ffc 100644 --- a/src/plugins/wireguard/wireguard_timer.c +++ b/src/plugins/wireguard/wireguard_timer.c @@ -239,6 +239,16 @@ wg_timers_handshake_initiated (wg_peer_t * peer) } void +wg_timers_send_first_handshake (wg_peer_t *peer) +{ + // zero value is not allowed + peer->new_handshake_interval_tick = + get_random_u32_max (REKEY_TIMEOUT_JITTER) + 1; + start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_NEW_HANDSHAKE, + peer->new_handshake_interval_tick); +} + +void wg_timers_session_derived (wg_peer_t * peer) { peer->session_derived = vlib_time_now (vlib_get_main ()); diff --git a/src/plugins/wireguard/wireguard_timer.h b/src/plugins/wireguard/wireguard_timer.h index ebde47e9067..47638bfd74d 100644 --- a/src/plugins/wireguard/wireguard_timer.h +++ b/src/plugins/wireguard/wireguard_timer.h @@ -50,6 +50,7 @@ void wg_timers_any_authenticated_packet_received_opt (wg_peer_t *peer, f64 time); void wg_timers_handshake_initiated (wg_peer_t * peer); void wg_timers_handshake_complete (wg_peer_t * peer); +void wg_timers_send_first_handshake (wg_peer_t *peer); void wg_timers_session_derived (wg_peer_t * peer); void wg_timers_any_authenticated_packet_traversal (wg_peer_t * peer); diff --git a/test/test_wireguard.py b/test/test_wireguard.py index 80ebdd89aa6..72a317ca8c2 100644 --- a/test/test_wireguard.py +++ b/test/test_wireguard.py @@ -147,6 +147,8 @@ UNDER_LOAD_INTERVAL = 1.0 HANDSHAKE_NUM_PER_PEER_UNTIL_UNDER_LOAD = 40 HANDSHAKE_NUM_BEFORE_RATELIMITING = 5 +HANDSHAKE_JITTER = 0.5 + class VppWgPeer(VppObject): def __init__(self, test, itf, endpoint, port, allowed_ips, persistent_keepalive=15): @@ -672,6 +674,9 @@ class TestWg(VppTestCase): self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1) if is_resp: + # skip the first automatic handshake + self.pg1.get_capture(1, timeout=HANDSHAKE_JITTER) + # prepare and send a handshake initiation # expect the peer to send a handshake response init = peer_1.mk_handshake(self.pg1, is_ip6=is_ip6) @@ -764,6 +769,9 @@ class TestWg(VppTestCase): # reset noise to be able to turn into initiator later peer_1.noise_reset() else: + # skip the first automatic handshake + self.pg1.get_capture(1, timeout=HANDSHAKE_JITTER) + # prepare and send a bunch of handshake initiations # expect to switch to under load state init = peer_1.mk_handshake(self.pg1, is_ip6=is_ip6) @@ -829,6 +837,9 @@ class TestWg(VppTestCase): ).add_vpp_config() self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1) + # skip the first automatic handshake + self.pg1.get_capture(1, timeout=HANDSHAKE_JITTER) + # prepare and send a bunch of handshake initiations # expect to switch to under load state init = peer_1.mk_handshake(self.pg1) @@ -893,6 +904,9 @@ class TestWg(VppTestCase): ).add_vpp_config() self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1) + # skip the first automatic handshake + self.pg1.get_capture(1, timeout=HANDSHAKE_JITTER) + # prepare and send a bunch of handshake initiations # expect to switch to under load state init = peer_1.mk_handshake(self.pg1, is_ip6=is_ip6) @@ -963,6 +977,9 @@ class TestWg(VppTestCase): ).add_vpp_config() self.assertEqual(len(self.vapi.wireguard_peers_dump()), 2) + # skip the first automatic handshake + self.pg1.get_capture(NUM_PEERS, timeout=HANDSHAKE_JITTER) + # (peer_1) prepare and send a bunch of handshake initiations # expect not to switch to under load state init_1 = peer_1.mk_handshake(self.pg1) @@ -2098,6 +2115,9 @@ class TestWg(VppTestCase): self.pg0.generate_remote_hosts(NUM_IFS) self.pg0.configure_ipv4_neighbors() + self.pg_enable_capture(self.pg_interfaces) + self.pg_start() + # Create interfaces with a peer on each peers = [] routes = [] @@ -2129,6 +2149,9 @@ class TestWg(VppTestCase): self.assertEqual(len(self.vapi.wireguard_peers_dump()), NUM_IFS) + # skip the first automatic handshake + self.pg1.get_capture(NUM_IFS, timeout=HANDSHAKE_JITTER) + for i in range(NUM_IFS): # send a valid handsake init for which we expect a response p = peers[i].mk_handshake(self.pg1) @@ -2282,6 +2305,10 @@ class TestWg(VppTestCase): self.assertEqual(len(self.vapi.wireguard_peers_dump()), NUM_PEERS * 2) + # skip the first automatic handshake + self.pg1.get_capture(NUM_PEERS, timeout=HANDSHAKE_JITTER) + self.pg2.get_capture(NUM_PEERS, timeout=HANDSHAKE_JITTER) + # Want events from the first perr of wg0 # and from all wg1 peers peers_0[0].want_events() @@ -2472,6 +2499,9 @@ class WireguardHandoffTests(TestWg): wg0.admin_up() wg0.config_ip4() + self.pg_enable_capture(self.pg_interfaces) + self.pg_start() + peer_1 = VppWgPeer( self, wg0, self.pg1.remote_ip4, port + 1, ["10.11.2.0/24", "10.11.3.0/24"] ).add_vpp_config() @@ -2481,6 +2511,9 @@ class WireguardHandoffTests(TestWg): self, "10.11.3.0", 24, [VppRoutePath("10.11.3.1", wg0.sw_if_index)] ).add_vpp_config() + # skip the first automatic handshake + self.pg1.get_capture(1, timeout=HANDSHAKE_JITTER) + # send a valid handsake init for which we expect a response p = peer_1.mk_handshake(self.pg1) |