diff options
Diffstat (limited to 'src/plugins/wireguard')
-rw-r--r-- | src/plugins/wireguard/wireguard.h | 2 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_input.c | 326 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_noise.c | 174 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_noise.h | 76 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_output_tun.c | 100 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_timer.c | 30 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_timer.h | 11 |
7 files changed, 538 insertions, 181 deletions
diff --git a/src/plugins/wireguard/wireguard.h b/src/plugins/wireguard/wireguard.h index 829c9e6f22b..4cbee1fcf7a 100644 --- a/src/plugins/wireguard/wireguard.h +++ b/src/plugins/wireguard/wireguard.h @@ -28,6 +28,8 @@ extern vlib_node_registration_t wg6_output_tun_node; typedef struct wg_per_thread_data_t_ { + CLIB_CACHE_LINE_ALIGN_MARK (cacheline0); + vnet_crypto_op_t *crypto_ops; u8 data[WG_DEFAULT_DATA_SIZE]; } wg_per_thread_data_t; typedef struct diff --git a/src/plugins/wireguard/wireguard_input.c b/src/plugins/wireguard/wireguard_input.c index d146d0d8fb5..10827ca0e64 100644 --- a/src/plugins/wireguard/wireguard_input.c +++ b/src/plugins/wireguard/wireguard_input.c @@ -298,50 +298,109 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b, return WG_INPUT_ERROR_NONE; } +static_always_inline void +wg_input_process_ops (vlib_main_t *vm, vlib_node_runtime_t *node, + vnet_crypto_op_t *ops, vlib_buffer_t *b[], u16 *nexts, + u16 drop_next) +{ + u32 n_fail, n_ops = vec_len (ops); + vnet_crypto_op_t *op = ops; + + if (n_ops == 0) + return; + + n_fail = n_ops - vnet_crypto_process_ops (vm, op, n_ops); + + while (n_fail) + { + ASSERT (op - ops < n_ops); + + if (op->status != VNET_CRYPTO_OP_STATUS_COMPLETED) + { + u32 bi = op->user_data; + b[bi]->error = node->errors[WG_INPUT_ERROR_DECRYPTION]; + nexts[bi] = drop_next; + n_fail--; + } + op++; + } +} + always_inline uword wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame, u8 is_ip4) { - message_type_t header_type; - u32 n_left_from; - u32 *from; - vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b; - u16 nexts[VLIB_FRAME_SIZE], *next; - u32 thread_index = vm->thread_index; + wg_main_t *wmp = &wg_main; + wg_per_thread_data_t *ptd = + vec_elt_at_index (wmp->per_thread_data, vm->thread_index); + u32 *from = vlib_frame_vector_args (frame); + u32 n_left_from = frame->n_vectors; - from = vlib_frame_vector_args (frame); - n_left_from = frame->n_vectors; - b = bufs; - next = nexts; + vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b = bufs; + u32 thread_index = vm->thread_index; + vnet_crypto_op_t **crypto_ops = &ptd->crypto_ops; + const u16 drop_next = WG_INPUT_NEXT_PUNT; + message_type_t header_type; + vlib_buffer_t *data_bufs[VLIB_FRAME_SIZE]; + u32 data_bi[VLIB_FRAME_SIZE]; /* buffer index for data */ + u32 other_bi[VLIB_FRAME_SIZE]; /* buffer index for drop or handoff */ + u16 other_nexts[VLIB_FRAME_SIZE], *other_next = other_nexts, n_other = 0; + u16 data_nexts[VLIB_FRAME_SIZE], *data_next = data_nexts, n_data = 0; vlib_get_buffers (vm, from, bufs, n_left_from); + vec_reset_length (ptd->crypto_ops); + + f64 time = clib_time_now (&vm->clib_time) + vm->time_offset; - wg_main_t *wmp = &wg_main; wg_peer_t *peer = NULL; + u32 *last_peer_time_idx = NULL; + u32 last_rec_idx = ~0; + + bool is_keepalive = false; + u32 *peer_idx = NULL; while (n_left_from > 0) { - bool is_keepalive = false; - next[0] = WG_INPUT_NEXT_PUNT; + if (n_left_from > 2) + { + u8 *p; + vlib_prefetch_buffer_header (b[2], LOAD); + p = vlib_buffer_get_current (b[1]); + CLIB_PREFETCH (p, CLIB_CACHE_LINE_BYTES, LOAD); + CLIB_PREFETCH (vlib_buffer_get_tail (b[1]), CLIB_CACHE_LINE_BYTES, + LOAD); + } + + other_next[n_other] = WG_INPUT_NEXT_PUNT; + data_nexts[n_data] = WG_INPUT_N_NEXT; + header_type = ((message_header_t *) vlib_buffer_get_current (b[0]))->type; - u32 *peer_idx; if (PREDICT_TRUE (header_type == MESSAGE_DATA)) { message_data_t *data = vlib_buffer_get_current (b[0]); - + u8 *iv_data = b[0]->pre_data; peer_idx = wg_index_table_lookup (&wmp->index_table, data->receiver_index); - if (peer_idx) + if (data->receiver_index != last_rec_idx) { - peer = wg_peer_get (*peer_idx); + peer_idx = wg_index_table_lookup (&wmp->index_table, + data->receiver_index); + if (PREDICT_TRUE (peer_idx != NULL)) + { + peer = wg_peer_get (*peer_idx); + } + last_rec_idx = data->receiver_index; } - else + + if (PREDICT_FALSE (!peer_idx)) { - next[0] = WG_INPUT_NEXT_ERROR; + other_next[n_other] = WG_INPUT_NEXT_ERROR; b[0]->error = node->errors[WG_INPUT_ERROR_PEER]; + other_bi[n_other] = from[b - bufs]; + n_other += 1; goto out; } @@ -356,7 +415,9 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, if (PREDICT_TRUE (thread_index != peer->input_thread_index)) { - next[0] = WG_INPUT_NEXT_HANDOFF_DATA; + other_next[n_other] = WG_INPUT_NEXT_HANDOFF_DATA; + other_bi[n_other] = from[b - bufs]; + n_other += 1; goto next; } @@ -365,83 +426,47 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, if (PREDICT_FALSE (decr_len >= WG_DEFAULT_DATA_SIZE)) { b[0]->error = node->errors[WG_INPUT_ERROR_TOO_BIG]; + other_bi[n_other] = from[b - bufs]; + n_other += 1; goto out; } - enum noise_state_crypt state_cr = noise_remote_decrypt ( - vm, &peer->remote, data->receiver_index, data->counter, - data->encrypted_data, encr_len, data->encrypted_data); + enum noise_state_crypt state_cr = noise_sync_remote_decrypt ( + vm, crypto_ops, &peer->remote, data->receiver_index, data->counter, + data->encrypted_data, decr_len, data->encrypted_data, n_data, + iv_data, time); if (PREDICT_FALSE (state_cr == SC_CONN_RESET)) { wg_timers_handshake_complete (peer); + data_bufs[n_data] = b[0]; + data_bi[n_data] = from[b - bufs]; + n_data += 1; + goto next; } else if (PREDICT_FALSE (state_cr == SC_KEEP_KEY_FRESH)) { wg_send_handshake_from_mt (*peer_idx, false); + data_bufs[n_data] = b[0]; + data_bi[n_data] = from[b - bufs]; + n_data += 1; + goto next; } else if (PREDICT_FALSE (state_cr == SC_FAILED)) { wg_peer_update_flags (*peer_idx, WG_PEER_ESTABLISHED, false); - next[0] = WG_INPUT_NEXT_ERROR; + other_next[n_other] = WG_INPUT_NEXT_ERROR; b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION]; + other_bi[n_other] = from[b - bufs]; + n_other += 1; goto out; } - - vlib_buffer_advance (b[0], sizeof (message_data_t)); - b[0]->current_length = decr_len; - vnet_buffer_offload_flags_clear (b[0], - VNET_BUFFER_OFFLOAD_F_UDP_CKSUM); - - wg_timers_any_authenticated_packet_received (peer); - wg_timers_any_authenticated_packet_traversal (peer); - - /* Keepalive packet has zero length */ - if (decr_len == 0) - { - is_keepalive = true; - goto out; - } - - wg_timers_data_received (peer); - - ip46_address_t src_ip; - u8 is_ip4_inner = is_ip4_header (vlib_buffer_get_current (b[0])); - if (is_ip4_inner) - { - ip46_address_set_ip4 ( - &src_ip, &((ip4_header_t *) vlib_buffer_get_current (b[0])) - ->src_address); - } - else - { - ip46_address_set_ip6 ( - &src_ip, &((ip6_header_t *) vlib_buffer_get_current (b[0])) - ->src_address); - } - - const fib_prefix_t *allowed_ip; - bool allowed = false; - - /* - * we could make this into an ACL, but the expectation - * is that there aren't many allowed IPs and thus a linear - * walk is fater than an ACL - */ - - vec_foreach (allowed_ip, peer->allowed_ips) - { - if (fib_prefix_is_cover_addr_46 (allowed_ip, &src_ip)) - { - allowed = true; - break; - } - } - if (allowed) + else if (PREDICT_TRUE (state_cr == SC_OK)) { - vnet_buffer (b[0])->sw_if_index[VLIB_RX] = peer->wg_sw_if_index; - next[0] = is_ip4_inner ? WG_INPUT_NEXT_IP4_INPUT : - WG_INPUT_NEXT_IP6_INPUT; + data_bufs[n_data] = b[0]; + data_bi[n_data] = from[b - bufs]; + n_data += 1; + goto next; } } else @@ -451,7 +476,9 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, /* Handshake packets should be processed in main thread */ if (thread_index != 0) { - next[0] = WG_INPUT_NEXT_HANDOFF_HANDSHAKE; + other_next[n_other] = WG_INPUT_NEXT_HANDOFF_HANDSHAKE; + other_bi[n_other] = from[b - bufs]; + n_other += 1; goto next; } @@ -459,14 +486,21 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, wg_handshake_process (vm, wmp, b[0], node->node_index, is_ip4); if (ret != WG_INPUT_ERROR_NONE) { - next[0] = WG_INPUT_NEXT_ERROR; + other_next[n_other] = WG_INPUT_NEXT_ERROR; b[0]->error = node->errors[ret]; + other_bi[n_other] = from[b - bufs]; + n_other += 1; + } + else + { + other_bi[n_other] = from[b - bufs]; + n_other += 1; } } out: - if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) - && (b[0]->flags & VLIB_BUFFER_IS_TRACED))) + if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) && + (b[0]->flags & VLIB_BUFFER_IS_TRACED))) { wg_input_trace_t *t = vlib_add_trace (vm, node, b[0], sizeof (*t)); t->type = header_type; @@ -474,12 +508,136 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, t->is_keepalive = is_keepalive; t->peer = peer_idx ? *peer_idx : INDEX_INVALID; } + next: n_left_from -= 1; - next += 1; b += 1; } - vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors); + + /* decrypt packets */ + wg_input_process_ops (vm, node, ptd->crypto_ops, data_bufs, data_nexts, + drop_next); + + /* process after decryption */ + b = data_bufs; + n_left_from = n_data; + n_data = 0; + last_rec_idx = ~0; + last_peer_time_idx = NULL; + while (n_left_from > 0) + { + bool is_keepalive = false; + u32 *peer_idx = NULL; + + if (data_next[n_data] == WG_INPUT_NEXT_PUNT) + { + goto trace; + } + else + { + data_next[n_data] = WG_INPUT_NEXT_PUNT; + } + + message_data_t *data = vlib_buffer_get_current (b[0]); + + if (data->receiver_index != last_rec_idx) + { + peer_idx = + wg_index_table_lookup (&wmp->index_table, data->receiver_index); + /* already checked and excisting */ + peer = wg_peer_get (*peer_idx); + last_rec_idx = data->receiver_index; + } + + noise_keypair_t *kp = + wg_get_active_keypair (&peer->remote, data->receiver_index); + + if (!noise_counter_recv (&kp->kp_ctr, data->counter)) + { + goto trace; + } + + u16 encr_len = b[0]->current_length - sizeof (message_data_t); + u16 decr_len = encr_len - NOISE_AUTHTAG_LEN; + + vlib_buffer_advance (b[0], sizeof (message_data_t)); + b[0]->current_length = decr_len; + vnet_buffer_offload_flags_clear (b[0], VNET_BUFFER_OFFLOAD_F_UDP_CKSUM); + + if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx))) + { + wg_timers_any_authenticated_packet_received_opt (peer, time); + wg_timers_any_authenticated_packet_traversal (peer); + last_peer_time_idx = peer_idx; + } + + /* Keepalive packet has zero length */ + if (decr_len == 0) + { + is_keepalive = true; + goto trace; + } + + wg_timers_data_received (peer); + + ip46_address_t src_ip; + u8 is_ip4_inner = is_ip4_header (vlib_buffer_get_current (b[0])); + if (is_ip4_inner) + { + ip46_address_set_ip4 ( + &src_ip, + &((ip4_header_t *) vlib_buffer_get_current (b[0]))->src_address); + } + else + { + ip46_address_set_ip6 ( + &src_ip, + &((ip6_header_t *) vlib_buffer_get_current (b[0]))->src_address); + } + + const fib_prefix_t *allowed_ip; + bool allowed = false; + + /* + * we could make this into an ACL, but the expectation + * is that there aren't many allowed IPs and thus a linear + * walk is fater than an ACL + */ + + vec_foreach (allowed_ip, peer->allowed_ips) + { + if (fib_prefix_is_cover_addr_46 (allowed_ip, &src_ip)) + { + allowed = true; + break; + } + } + if (allowed) + { + vnet_buffer (b[0])->sw_if_index[VLIB_RX] = peer->wg_sw_if_index; + data_next[n_data] = + is_ip4_inner ? WG_INPUT_NEXT_IP4_INPUT : WG_INPUT_NEXT_IP6_INPUT; + } + trace: + if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) && + (b[0]->flags & VLIB_BUFFER_IS_TRACED))) + { + wg_input_trace_t *t = vlib_add_trace (vm, node, b[0], sizeof (*t)); + t->type = header_type; + t->current_length = b[0]->current_length; + t->is_keepalive = is_keepalive; + t->peer = peer_idx ? *peer_idx : INDEX_INVALID; + } + + b += 1; + n_left_from -= 1; + n_data += 1; + } + /* enqueue other bufs */ + vlib_buffer_enqueue_to_next (vm, node, other_bi, other_next, n_other); + + /* enqueue data bufs */ + vlib_buffer_enqueue_to_next (vm, node, data_bi, data_nexts, n_data); return frame->n_vectors; } diff --git a/src/plugins/wireguard/wireguard_noise.c b/src/plugins/wireguard/wireguard_noise.c index 36de8ae9cac..c8605f117cd 100644 --- a/src/plugins/wireguard/wireguard_noise.c +++ b/src/plugins/wireguard/wireguard_noise.c @@ -36,7 +36,7 @@ static uint32_t noise_remote_handshake_index_get (noise_remote_t *); static void noise_remote_handshake_index_drop (noise_remote_t *); static uint64_t noise_counter_send (noise_counter_t *); -static bool noise_counter_recv (noise_counter_t *, uint64_t); +bool noise_counter_recv (noise_counter_t *, uint64_t); static void noise_kdf (uint8_t *, uint8_t *, uint8_t *, const uint8_t *, size_t, size_t, size_t, size_t, @@ -407,6 +407,8 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r) /* Now we need to add_new_keypair */ clib_rwlock_writer_lock (&r->r_keypair_lock); + /* Activate barrier to synchronization keys between threads */ + vlib_worker_thread_barrier_sync (vm); next = r->r_next; current = r->r_current; previous = r->r_previous; @@ -438,6 +440,7 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r) r->r_next = noise_remote_keypair_allocate (r); *r->r_next = kp; } + vlib_worker_thread_barrier_release (vm); clib_rwlock_writer_unlock (&r->r_keypair_lock); secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake)); @@ -541,6 +544,41 @@ chacha20poly1305_calc (vlib_main_t * vm, return (op->status == VNET_CRYPTO_OP_STATUS_COMPLETED); } +always_inline void +wg_prepare_sync_op (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops, u8 *src, + u32 src_len, u8 *dst, u8 *aad, u32 aad_len, u64 nonce, + vnet_crypto_op_id_t op_id, + vnet_crypto_key_index_t key_index, u32 bi, u8 *iv) +{ + vnet_crypto_op_t _op, *op = &_op; + u8 src_[] = {}; + + clib_memset (iv, 0, 4); + clib_memcpy (iv + 4, &nonce, sizeof (nonce)); + + vec_add2_aligned (crypto_ops[0], op, 1, CLIB_CACHE_LINE_BYTES); + vnet_crypto_op_init (op, op_id); + + op->tag_len = NOISE_AUTHTAG_LEN; + if (op_id == VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC) + { + op->tag = src + src_len; + op->flags |= VNET_CRYPTO_OP_FLAG_HMAC_CHECK; + } + else + op->tag = dst + src_len; + + op->src = !src ? src_ : src; + op->len = src_len; + + op->dst = dst; + op->key_index = key_index; + op->aad = aad; + op->aad_len = aad_len; + op->iv = iv; + op->user_data = bi; +} + enum noise_state_crypt noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx, uint64_t * nonce, uint8_t * src, size_t srclen, @@ -592,26 +630,67 @@ error: } enum noise_state_crypt -noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx, - uint64_t nonce, uint8_t * src, size_t srclen, - uint8_t * dst) +noise_sync_remote_encrypt (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops, + noise_remote_t *r, uint32_t *r_idx, uint64_t *nonce, + uint8_t *src, size_t srclen, uint8_t *dst, u32 bi, + u8 *iv, f64 time) { noise_keypair_t *kp; enum noise_state_crypt ret = SC_FAILED; - if (r->r_current != NULL && r->r_current->kp_local_index == r_idx) - { - kp = r->r_current; - } - else if (r->r_previous != NULL && r->r_previous->kp_local_index == r_idx) - { - kp = r->r_previous; - } - else if (r->r_next != NULL && r->r_next->kp_local_index == r_idx) - { - kp = r->r_next; - } - else + if ((kp = r->r_current) == NULL) + goto error; + + /* We confirm that our values are within our tolerances. We want: + * - a valid keypair + * - our keypair to be less than REJECT_AFTER_TIME seconds old + * - our receive counter to be less than REJECT_AFTER_MESSAGES + * - our send counter to be less than REJECT_AFTER_MESSAGES + */ + if (!kp->kp_valid || + wg_birthdate_has_expired_opt (kp->kp_birthdate, REJECT_AFTER_TIME, + time) || + kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES || + ((*nonce = noise_counter_send (&kp->kp_ctr)) > REJECT_AFTER_MESSAGES)) + goto error; + + /* We encrypt into the same buffer, so the caller must ensure that buf + * has NOISE_AUTHTAG_LEN bytes to store the MAC. The nonce and index + * are passed back out to the caller through the provided data pointer. */ + *r_idx = kp->kp_remote_index; + + wg_prepare_sync_op (vm, crypto_ops, src, srclen, dst, NULL, 0, *nonce, + VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC, kp->kp_send_index, + bi, iv); + + /* If our values are still within tolerances, but we are approaching + * the tolerances, we notify the caller with ESTALE that they should + * establish a new keypair. The current keypair can continue to be used + * until the tolerances are hit. We notify if: + * - our send counter is valid and not less than REKEY_AFTER_MESSAGES + * - we're the initiator and our keypair is older than + * REKEY_AFTER_TIME seconds */ + ret = SC_KEEP_KEY_FRESH; + if ((kp->kp_valid && *nonce >= REKEY_AFTER_MESSAGES) || + (kp->kp_is_initiator && wg_birthdate_has_expired_opt ( + kp->kp_birthdate, REKEY_AFTER_TIME, time))) + goto error; + + ret = SC_OK; +error: + return ret; +} + +enum noise_state_crypt +noise_sync_remote_decrypt (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops, + noise_remote_t *r, uint32_t r_idx, uint64_t nonce, + uint8_t *src, size_t srclen, uint8_t *dst, u32 bi, + u8 *iv, f64 time) +{ + noise_keypair_t *kp; + enum noise_state_crypt ret = SC_FAILED; + + if ((kp = wg_get_active_keypair (r, r_idx)) == NULL) { goto error; } @@ -620,20 +699,17 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx, * are the same as the encrypt routine. * * kp_ctr isn't locked here, we're happy to accept a racy read. */ - if (wg_birthdate_has_expired (kp->kp_birthdate, REJECT_AFTER_TIME) || + if (wg_birthdate_has_expired_opt (kp->kp_birthdate, REJECT_AFTER_TIME, + time) || kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES) goto error; /* Decrypt, then validate the counter. We don't want to validate the * counter before decrypting as we do not know the message is authentic * prior to decryption. */ - if (!chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, nonce, - VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, - kp->kp_recv_index)) - goto error; - - if (!noise_counter_recv (&kp->kp_ctr, nonce)) - goto error; + wg_prepare_sync_op (vm, crypto_ops, src, srclen, dst, NULL, 0, nonce, + VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, kp->kp_recv_index, + bi, iv); /* If we've received the handshake confirming data packet then move the * next keypair into current. If we do slide the next keypair in, then @@ -662,10 +738,9 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx, * REKEY_AFTER_TIME_RECV seconds. */ ret = SC_KEEP_KEY_FRESH; kp = r->r_current; - if (kp != NULL && - kp->kp_valid && - kp->kp_is_initiator && - wg_birthdate_has_expired (kp->kp_birthdate, REKEY_AFTER_TIME_RECV)) + if (kp != NULL && kp->kp_valid && kp->kp_is_initiator && + wg_birthdate_has_expired_opt (kp->kp_birthdate, REKEY_AFTER_TIME_RECV, + time)) goto error; ret = SC_OK; @@ -724,47 +799,6 @@ noise_counter_send (noise_counter_t * ctr) return ret; } -static bool -noise_counter_recv (noise_counter_t * ctr, uint64_t recv) -{ - uint64_t i, top, index_recv, index_ctr; - unsigned long bit; - bool ret = false; - - /* Check that the recv counter is valid */ - if (ctr->c_recv >= REJECT_AFTER_MESSAGES || recv >= REJECT_AFTER_MESSAGES) - goto error; - - /* If the packet is out of the window, invalid */ - if (recv + COUNTER_WINDOW_SIZE < ctr->c_recv) - goto error; - - /* If the new counter is ahead of the current counter, we'll need to - * zero out the bitmap that has previously been used */ - index_recv = recv / COUNTER_BITS; - index_ctr = ctr->c_recv / COUNTER_BITS; - - if (recv > ctr->c_recv) - { - top = clib_min (index_recv - index_ctr, COUNTER_NUM); - for (i = 1; i <= top; i++) - ctr->c_backtrack[(i + index_ctr) & (COUNTER_NUM - 1)] = 0; - ctr->c_recv = recv; - } - - index_recv %= COUNTER_NUM; - bit = 1ul << (recv % COUNTER_BITS); - - if (ctr->c_backtrack[index_recv] & bit) - goto error; - - ctr->c_backtrack[index_recv] |= bit; - - ret = true; -error: - return ret; -} - static void noise_kdf (uint8_t * a, uint8_t * b, uint8_t * c, const uint8_t * x, size_t a_len, size_t b_len, size_t c_len, size_t x_len, diff --git a/src/plugins/wireguard/wireguard_noise.h b/src/plugins/wireguard/wireguard_noise.h index 5b5a88fa250..ef1e7dcbfca 100644 --- a/src/plugins/wireguard/wireguard_noise.h +++ b/src/plugins/wireguard/wireguard_noise.h @@ -187,12 +187,80 @@ noise_remote_encrypt (vlib_main_t * vm, noise_remote_t *, uint32_t * r_idx, uint64_t * nonce, uint8_t * src, size_t srclen, uint8_t * dst); + enum noise_state_crypt -noise_remote_decrypt (vlib_main_t * vm, noise_remote_t *, - uint32_t r_idx, - uint64_t nonce, - uint8_t * src, size_t srclen, uint8_t * dst); +noise_sync_remote_encrypt (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops, + noise_remote_t *r, uint32_t *r_idx, uint64_t *nonce, + uint8_t *src, size_t srclen, uint8_t *dst, u32 bi, + u8 *iv, f64 time); + +enum noise_state_crypt +noise_sync_remote_decrypt (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops, + noise_remote_t *, uint32_t r_idx, uint64_t nonce, + uint8_t *src, size_t srclen, uint8_t *dst, u32 bi, + u8 *iv, f64 time); + +static_always_inline noise_keypair_t * +wg_get_active_keypair (noise_remote_t *r, uint32_t r_idx) +{ + if (r->r_current != NULL && r->r_current->kp_local_index == r_idx) + { + return r->r_current; + } + else if (r->r_previous != NULL && r->r_previous->kp_local_index == r_idx) + { + return r->r_previous; + } + else if (r->r_next != NULL && r->r_next->kp_local_index == r_idx) + { + return r->r_next; + } + else + { + return NULL; + } +} + +inline bool +noise_counter_recv (noise_counter_t *ctr, uint64_t recv) +{ + uint64_t i, top, index_recv, index_ctr; + unsigned long bit; + bool ret = false; + /* Check that the recv counter is valid */ + if (ctr->c_recv >= REJECT_AFTER_MESSAGES || recv >= REJECT_AFTER_MESSAGES) + goto error; + + /* If the packet is out of the window, invalid */ + if (recv + COUNTER_WINDOW_SIZE < ctr->c_recv) + goto error; + + /* If the new counter is ahead of the current counter, we'll need to + * zero out the bitmap that has previously been used */ + index_recv = recv / COUNTER_BITS; + index_ctr = ctr->c_recv / COUNTER_BITS; + + if (recv > ctr->c_recv) + { + top = clib_min (index_recv - index_ctr, COUNTER_NUM); + for (i = 1; i <= top; i++) + ctr->c_backtrack[(i + index_ctr) & (COUNTER_NUM - 1)] = 0; + ctr->c_recv = recv; + } + + index_recv %= COUNTER_NUM; + bit = 1ul << (recv % COUNTER_BITS); + + if (ctr->c_backtrack[index_recv] & bit) + goto error; + + ctr->c_backtrack[index_recv] |= bit; + + ret = true; +error: + return ret; +} #endif /* __included_wg_noise_h__ */ diff --git a/src/plugins/wireguard/wireguard_output_tun.c b/src/plugins/wireguard/wireguard_output_tun.c index c792d4b713e..2feb0570a31 100644 --- a/src/plugins/wireguard/wireguard_output_tun.c +++ b/src/plugins/wireguard/wireguard_output_tun.c @@ -93,32 +93,67 @@ format_wg_output_tun_trace (u8 * s, va_list * args) return s; } +static_always_inline void +wg_output_process_ops (vlib_main_t *vm, vlib_node_runtime_t *node, + vnet_crypto_op_t *ops, vlib_buffer_t *b[], u16 *nexts, + u16 drop_next) +{ + u32 n_fail, n_ops = vec_len (ops); + vnet_crypto_op_t *op = ops; + + if (n_ops == 0) + return; + + n_fail = n_ops - vnet_crypto_process_ops (vm, op, n_ops); + + while (n_fail) + { + ASSERT (op - ops < n_ops); + + if (op->status != VNET_CRYPTO_OP_STATUS_COMPLETED) + { + u32 bi = op->user_data; + b[bi]->error = node->errors[WG_OUTPUT_ERROR_KEYPAIR]; + nexts[bi] = drop_next; + n_fail--; + } + op++; + } +} + /* is_ip4 - inner header flag */ always_inline uword wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame, u8 is_ip4) { - u32 n_left_from; - u32 *from; + wg_main_t *wmp = &wg_main; + wg_per_thread_data_t *ptd = + vec_elt_at_index (wmp->per_thread_data, vm->thread_index); + u32 *from = vlib_frame_vector_args (frame); + u32 n_left_from = frame->n_vectors; ip4_udp_wg_header_t *hdr4_out = NULL; ip6_udp_wg_header_t *hdr6_out = NULL; message_data_t *message_data_wg = NULL; - vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b; - u16 nexts[VLIB_FRAME_SIZE], *next; + vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b = bufs; + vnet_crypto_op_t **crypto_ops = &ptd->crypto_ops; + u16 nexts[VLIB_FRAME_SIZE], *next = nexts; + vlib_buffer_t *sync_bufs[VLIB_FRAME_SIZE]; u32 thread_index = vm->thread_index; - - from = vlib_frame_vector_args (frame); - n_left_from = frame->n_vectors; - b = bufs; - next = nexts; + u16 n_sync = 0; + u16 drop_next = WG_OUTPUT_NEXT_ERROR; vlib_get_buffers (vm, from, bufs, n_left_from); + vec_reset_length (ptd->crypto_ops); wg_peer_t *peer = NULL; + u32 adj_index = 0; + u32 last_adj_index = ~0; + index_t peeri = INDEX_INVALID; + + f64 time = clib_time_now (&vm->clib_time) + vm->time_offset; while (n_left_from > 0) { - index_t peeri; u8 iph_offset = 0; u8 is_ip4_out = 1; u8 *plain_data; @@ -130,14 +165,21 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_prefetch_buffer_header (b[2], LOAD); p = vlib_buffer_get_current (b[1]); CLIB_PREFETCH (p, CLIB_CACHE_LINE_BYTES, LOAD); + CLIB_PREFETCH (vlib_buffer_get_tail (b[1]), CLIB_CACHE_LINE_BYTES, + LOAD); } next[0] = WG_OUTPUT_NEXT_ERROR; - peeri = - wg_peer_get_by_adj_index (vnet_buffer (b[0])->ip.adj_index[VLIB_TX]); - peer = wg_peer_get (peeri); - if (wg_peer_is_dead (peer)) + adj_index = vnet_buffer (b[0])->ip.adj_index[VLIB_TX]; + + if (PREDICT_FALSE (last_adj_index != adj_index)) + { + peeri = wg_peer_get_by_adj_index (adj_index); + peer = wg_peer_get (peeri); + } + + if (!peer || wg_peer_is_dead (peer)) { b[0]->error = node->errors[WG_OUTPUT_ERROR_PEER]; goto out; @@ -179,6 +221,7 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, iph_offset = vnet_buffer (b[0])->ip.save_rewrite_length; plain_data = vlib_buffer_get_current (b[0]) + iph_offset; plain_data_len = vlib_buffer_length_in_chain (vm, b[0]) - iph_offset; + u8 *iv_data = b[0]->pre_data; size_t encrypted_packet_len = message_data_len (plain_data_len); @@ -194,11 +237,20 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, goto out; } + if (PREDICT_FALSE (last_adj_index != adj_index)) + { + wg_timers_any_authenticated_packet_sent_opt (peer, time); + wg_timers_data_sent_opt (peer, time); + wg_timers_any_authenticated_packet_traversal (peer); + last_adj_index = adj_index; + } + enum noise_state_crypt state; - state = noise_remote_encrypt ( - vm, &peer->remote, &message_data_wg->receiver_index, - &message_data_wg->counter, plain_data, plain_data_len, plain_data); + state = noise_sync_remote_encrypt ( + vm, crypto_ops, &peer->remote, &message_data_wg->receiver_index, + &message_data_wg->counter, plain_data, plain_data_len, plain_data, + n_sync, iv_data, time); if (PREDICT_FALSE (state == SC_KEEP_KEY_FRESH)) { @@ -206,7 +258,7 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, } else if (PREDICT_FALSE (state == SC_FAILED)) { - //TODO: Maybe wrong + // TODO: Maybe wrong wg_send_handshake_from_mt (peeri, false); wg_peer_update_flags (peeri, WG_PEER_ESTABLISHED, false); goto out; @@ -236,10 +288,6 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, clib_host_to_net_u16 (b[0]->current_length); } - wg_timers_any_authenticated_packet_sent (peer); - wg_timers_data_sent (peer); - wg_timers_any_authenticated_packet_traversal (peer); - out: if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) && (b[0]->flags & VLIB_BUFFER_IS_TRACED))) @@ -256,12 +304,18 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, } next: + sync_bufs[n_sync] = b[0]; + n_sync += 1; n_left_from -= 1; next += 1; b += 1; } - vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors); + /* wg-output-process-ops */ + wg_output_process_ops (vm, node, ptd->crypto_ops, sync_bufs, nexts, + drop_next); + + vlib_buffer_enqueue_to_next (vm, node, from, nexts, n_sync); return frame->n_vectors; } diff --git a/src/plugins/wireguard/wireguard_timer.c b/src/plugins/wireguard/wireguard_timer.c index 97c861b459c..d2ae7ebf93e 100644 --- a/src/plugins/wireguard/wireguard_timer.c +++ b/src/plugins/wireguard/wireguard_timer.c @@ -26,6 +26,13 @@ get_random_u32_max (u32 max) return random_u32 (&seed) % max; } +static u32 +get_random_u32_max_opt (u32 max, f64 time) +{ + u32 seed = (u32) (time * 1e6); + return random_u32 (&seed) % max; +} + static void stop_timer (wg_peer_t * peer, u32 timer_id) { @@ -215,6 +222,12 @@ wg_timers_any_authenticated_packet_sent (wg_peer_t * peer) } void +wg_timers_any_authenticated_packet_sent_opt (wg_peer_t *peer, f64 time) +{ + peer->last_sent_packet = time; +} + +void wg_timers_handshake_initiated (wg_peer_t * peer) { peer->rehandshake_started = vlib_time_now (vlib_get_main ()); @@ -246,6 +259,17 @@ wg_timers_data_sent (wg_peer_t * peer) peer->new_handshake_interval_tick); } +void +wg_timers_data_sent_opt (wg_peer_t *peer, f64 time) +{ + peer->new_handshake_interval_tick = + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * WHZ + + get_random_u32_max_opt (REKEY_TIMEOUT_JITTER, time); + + start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_NEW_HANDSHAKE, + peer->new_handshake_interval_tick); +} + /* Should be called after an authenticated data packet is received. */ void wg_timers_data_received (wg_peer_t * peer) @@ -275,6 +299,12 @@ wg_timers_any_authenticated_packet_received (wg_peer_t * peer) peer->last_received_packet = vlib_time_now (vlib_get_main ()); } +void +wg_timers_any_authenticated_packet_received_opt (wg_peer_t *peer, f64 time) +{ + peer->last_received_packet = time; +} + static vlib_node_registration_t wg_timer_mngr_node; static void diff --git a/src/plugins/wireguard/wireguard_timer.h b/src/plugins/wireguard/wireguard_timer.h index 6b59a39f815..cc8e123f3a2 100644 --- a/src/plugins/wireguard/wireguard_timer.h +++ b/src/plugins/wireguard/wireguard_timer.h @@ -41,9 +41,13 @@ typedef struct wg_peer wg_peer_t; void wg_timer_wheel_init (); void wg_timers_stop (wg_peer_t * peer); void wg_timers_data_sent (wg_peer_t * peer); +void wg_timers_data_sent_opt (wg_peer_t *peer, f64 time); void wg_timers_data_received (wg_peer_t * peer); void wg_timers_any_authenticated_packet_sent (wg_peer_t * peer); +void wg_timers_any_authenticated_packet_sent_opt (wg_peer_t *peer, f64 time); void wg_timers_any_authenticated_packet_received (wg_peer_t * peer); +void wg_timers_any_authenticated_packet_received_opt (wg_peer_t *peer, + f64 time); void wg_timers_handshake_initiated (wg_peer_t * peer); void wg_timers_handshake_complete (wg_peer_t * peer); void wg_timers_session_derived (wg_peer_t * peer); @@ -57,6 +61,13 @@ wg_birthdate_has_expired (f64 birthday_seconds, f64 expiration_seconds) return (birthday_seconds + expiration_seconds) < now_seconds; } +static inline bool +wg_birthdate_has_expired_opt (f64 birthday_seconds, f64 expiration_seconds, + f64 time) +{ + return (birthday_seconds + expiration_seconds) < time; +} + #endif /* __included_wg_timer_h__ */ /* |