summaryrefslogtreecommitdiffstats
path: root/src/plugins/wireguard/wireguard_noise.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins/wireguard/wireguard_noise.c')
-rwxr-xr-xsrc/plugins/wireguard/wireguard_noise.c131
1 files changed, 51 insertions, 80 deletions
diff --git a/src/plugins/wireguard/wireguard_noise.c b/src/plugins/wireguard/wireguard_noise.c
index b47bb5747b9..00b67109de4 100755
--- a/src/plugins/wireguard/wireguard_noise.c
+++ b/src/plugins/wireguard/wireguard_noise.c
@@ -26,6 +26,8 @@
* <- e, ee, se, psk, {}
*/
+noise_local_t *noise_local_pool;
+
/* Private functions */
static noise_keypair_t *noise_remote_keypair_allocate (noise_remote_t *);
static void noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t *,
@@ -80,81 +82,31 @@ noise_local_set_private (noise_local_t * l,
const uint8_t private[NOISE_PUBLIC_KEY_LEN])
{
clib_memcpy (l->l_private, private, NOISE_PUBLIC_KEY_LEN);
- l->l_has_identity = curve25519_gen_public (l->l_public, private);
-
- return l->l_has_identity;
-}
-bool
-noise_local_keys (noise_local_t * l, uint8_t public[NOISE_PUBLIC_KEY_LEN],
- uint8_t private[NOISE_PUBLIC_KEY_LEN])
-{
- if (l->l_has_identity)
- {
- if (public != NULL)
- clib_memcpy (public, l->l_public, NOISE_PUBLIC_KEY_LEN);
- if (private != NULL)
- clib_memcpy (private, l->l_private, NOISE_PUBLIC_KEY_LEN);
- }
- else
- {
- return false;
- }
- return true;
+ return curve25519_gen_public (l->l_public, private);
}
void
noise_remote_init (noise_remote_t * r, uint32_t peer_pool_idx,
const uint8_t public[NOISE_PUBLIC_KEY_LEN],
- noise_local_t * l)
+ u32 noise_local_idx)
{
clib_memset (r, 0, sizeof (*r));
clib_memcpy (r->r_public, public, NOISE_PUBLIC_KEY_LEN);
+ clib_rwlock_init (&r->r_keypair_lock);
r->r_peer_idx = peer_pool_idx;
-
- ASSERT (l != NULL);
- r->r_local = l;
+ r->r_local_idx = noise_local_idx;
r->r_handshake.hs_state = HS_ZEROED;
- noise_remote_precompute (r);
-}
-
-bool
-noise_remote_set_psk (noise_remote_t * r,
- uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
-{
- int same;
- same = !clib_memcmp (r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
- if (!same)
- {
- clib_memcpy (r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
- }
- return same == 0;
-}
-
-bool
-noise_remote_keys (noise_remote_t * r, uint8_t public[NOISE_PUBLIC_KEY_LEN],
- uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
-{
- static uint8_t null_psk[NOISE_SYMMETRIC_KEY_LEN];
- int ret;
- if (public != NULL)
- clib_memcpy (public, r->r_public, NOISE_PUBLIC_KEY_LEN);
-
- if (psk != NULL)
- clib_memcpy (psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
- ret = clib_memcmp (r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN);
-
- return ret;
+ noise_remote_precompute (r);
}
void
noise_remote_precompute (noise_remote_t * r)
{
- noise_local_t *l = r->r_local;
- if (!l->l_has_identity)
- clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN);
- else if (!curve25519_gen_shared (r->r_ss, l->l_private, r->r_public))
+ noise_local_t *l = noise_local_get (r->r_local_idx);
+
+ if (!curve25519_gen_shared (r->r_ss, l->l_private, r->r_public))
clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN);
noise_remote_handshake_index_drop (r);
@@ -169,7 +121,7 @@ noise_create_initiation (vlib_main_t * vm, noise_remote_t * r,
uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
{
noise_handshake_t *hs = &r->r_handshake;
- noise_local_t *l = r->r_local;
+ noise_local_t *l = noise_local_get (r->r_local_idx);
uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
uint32_t key_idx;
uint8_t *key;
@@ -180,8 +132,6 @@ noise_create_initiation (vlib_main_t * vm, noise_remote_t * r,
NOISE_SYMMETRIC_KEY_LEN);
key = vnet_crypto_get_key (key_idx)->data;
- if (!l->l_has_identity)
- goto error;
noise_param_init (hs->hs_ck, hs->hs_hash, r->r_public);
/* e */
@@ -239,8 +189,6 @@ noise_consume_initiation (vlib_main_t * vm, noise_local_t * l,
NOISE_SYMMETRIC_KEY_LEN);
key = vnet_crypto_get_key (key_idx)->data;
- if (!l->l_has_identity)
- goto error;
noise_param_init (hs.hs_ck, hs.hs_hash, l->l_public);
/* e */
@@ -294,6 +242,7 @@ noise_consume_initiation (vlib_main_t * vm, noise_local_t * l,
r->r_handshake = hs;
*rp = r;
ret = true;
+
error:
vnet_crypto_key_del (vm, key_idx);
secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
@@ -359,7 +308,7 @@ noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx,
uint32_t r_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN],
uint8_t en[0 + NOISE_AUTHTAG_LEN])
{
- noise_local_t *l = r->r_local;
+ noise_local_t *l = noise_local_get (r->r_local_idx);
noise_handshake_t hs;
uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
uint8_t preshared_key[NOISE_PUBLIC_KEY_LEN];
@@ -372,9 +321,6 @@ noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx,
NOISE_SYMMETRIC_KEY_LEN);
key = vnet_crypto_get_key (key_idx)->data;
- if (!l->l_has_identity)
- goto error;
-
hs = r->r_handshake;
clib_memcpy (preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
@@ -460,6 +406,7 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r)
clib_memset (&kp.kp_ctr, 0, sizeof (kp.kp_ctr));
/* Now we need to add_new_keypair */
+ clib_rwlock_writer_lock (&r->r_keypair_lock);
next = r->r_next;
current = r->r_current;
previous = r->r_previous;
@@ -491,7 +438,10 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r)
r->r_next = noise_remote_keypair_allocate (r);
*r->r_next = kp;
}
+ clib_rwlock_writer_unlock (&r->r_keypair_lock);
+
secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
+
secure_zero_memory (&kp, sizeof (kp));
return true;
}
@@ -502,21 +452,25 @@ noise_remote_clear (vlib_main_t * vm, noise_remote_t * r)
noise_remote_handshake_index_drop (r);
secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
+ clib_rwlock_writer_lock (&r->r_keypair_lock);
noise_remote_keypair_free (vm, r, &r->r_next);
noise_remote_keypair_free (vm, r, &r->r_current);
noise_remote_keypair_free (vm, r, &r->r_previous);
r->r_next = NULL;
r->r_current = NULL;
r->r_previous = NULL;
+ clib_rwlock_writer_unlock (&r->r_keypair_lock);
}
void
noise_remote_expire_current (noise_remote_t * r)
{
+ clib_rwlock_writer_lock (&r->r_keypair_lock);
if (r->r_next != NULL)
r->r_next->kp_valid = 0;
if (r->r_current != NULL)
r->r_current->kp_valid = 0;
+ clib_rwlock_writer_unlock (&r->r_keypair_lock);
}
bool
@@ -525,6 +479,7 @@ noise_remote_ready (noise_remote_t * r)
noise_keypair_t *kp;
int ret;
+ clib_rwlock_reader_lock (&r->r_keypair_lock);
if ((kp = r->r_current) == NULL ||
!kp->kp_valid ||
wg_birthdate_has_expired (kp->kp_birthdate, REJECT_AFTER_TIME) ||
@@ -533,6 +488,7 @@ noise_remote_ready (noise_remote_t * r)
ret = false;
else
ret = true;
+ clib_rwlock_reader_unlock (&r->r_keypair_lock);
return ret;
}
@@ -592,6 +548,7 @@ noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx,
noise_keypair_t *kp;
enum noise_state_crypt ret = SC_FAILED;
+ clib_rwlock_reader_lock (&r->r_keypair_lock);
if ((kp = r->r_current) == NULL)
goto error;
@@ -631,6 +588,7 @@ noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx,
ret = SC_OK;
error:
+ clib_rwlock_reader_unlock (&r->r_keypair_lock);
return ret;
}
@@ -641,6 +599,7 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
{
noise_keypair_t *kp;
enum noise_state_crypt ret = SC_FAILED;
+ clib_rwlock_reader_lock (&r->r_keypair_lock);
if (r->r_current != NULL && r->r_current->kp_local_index == r_idx)
{
@@ -682,18 +641,26 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
* next keypair into current. If we do slide the next keypair in, then
* we skip the REKEY_AFTER_TIME_RECV check. This is safe to do as a
* data packet can't confirm a session that we are an INITIATOR of. */
- if (kp == r->r_next && kp->kp_local_index == r_idx)
+ if (kp == r->r_next)
{
- noise_remote_keypair_free (vm, r, &r->r_previous);
- r->r_previous = r->r_current;
- r->r_current = r->r_next;
- r->r_next = NULL;
+ clib_rwlock_reader_unlock (&r->r_keypair_lock);
+ clib_rwlock_writer_lock (&r->r_keypair_lock);
+ if (kp == r->r_next && kp->kp_local_index == r_idx)
+ {
+ noise_remote_keypair_free (vm, r, &r->r_previous);
+ r->r_previous = r->r_current;
+ r->r_current = r->r_next;
+ r->r_next = NULL;
- ret = SC_CONN_RESET;
- goto error;
+ ret = SC_CONN_RESET;
+ clib_rwlock_writer_unlock (&r->r_keypair_lock);
+ clib_rwlock_reader_lock (&r->r_keypair_lock);
+ goto error;
+ }
+ clib_rwlock_writer_unlock (&r->r_keypair_lock);
+ clib_rwlock_reader_lock (&r->r_keypair_lock);
}
-
/* Similar to when we encrypt, we want to notify the caller when we
* are approaching our tolerances. We notify if:
* - we're the initiator and the current keypair is older than
@@ -708,6 +675,7 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
ret = SC_OK;
error:
+ clib_rwlock_reader_unlock (&r->r_keypair_lock);
return ret;
}
@@ -725,7 +693,8 @@ static void
noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t * r,
noise_keypair_t ** kp)
{
- struct noise_upcall *u = &r->r_local->l_upcall;
+ noise_local_t *local = noise_local_get (r->r_local_idx);
+ struct noise_upcall *u = &local->l_upcall;
if (*kp)
{
u->u_index_drop ((*kp)->kp_local_index);
@@ -738,7 +707,8 @@ noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t * r,
static uint32_t
noise_remote_handshake_index_get (noise_remote_t * r)
{
- struct noise_upcall *u = &r->r_local->l_upcall;
+ noise_local_t *local = noise_local_get (r->r_local_idx);
+ struct noise_upcall *u = &local->l_upcall;
return u->u_index_set (r);
}
@@ -746,7 +716,8 @@ static void
noise_remote_handshake_index_drop (noise_remote_t * r)
{
noise_handshake_t *hs = &r->r_handshake;
- struct noise_upcall *u = &r->r_local->l_upcall;
+ noise_local_t *local = noise_local_get (r->r_local_idx);
+ struct noise_upcall *u = &local->l_upcall;
if (hs->hs_state != HS_ZEROED)
u->u_index_drop (hs->hs_local_index);
}
@@ -754,7 +725,8 @@ noise_remote_handshake_index_drop (noise_remote_t * r)
static uint64_t
noise_counter_send (noise_counter_t * ctr)
{
- uint64_t ret = ctr->c_send++;
+ uint64_t ret;
+ ret = ctr->c_send++;
return ret;
}
@@ -765,7 +737,6 @@ noise_counter_recv (noise_counter_t * ctr, uint64_t recv)
unsigned long bit;
bool ret = false;
-
/* Check that the recv counter is valid */
if (ctr->c_recv >= REJECT_AFTER_MESSAGES || recv >= REJECT_AFTER_MESSAGES)
goto error;