diff options
author | Gabriel Oginski <gabrielx.oginski@intel.com> | 2021-11-04 07:23:08 +0000 |
---|---|---|
committer | Fan Zhang <roy.fan.zhang@intel.com> | 2022-01-24 10:01:42 +0000 |
commit | ab2478ceedc1756e56b2c3406b168826ffa17555 (patch) | |
tree | 460ceddb5caeb9995bb99343944e07cd77f820cd | |
parent | 93e5bea2d3b0c324b6b09ee87a922236b2b3eaf9 (diff) |
wireguard: add burst mode
Originally wireguard does packet by packet encryption and decryption.
This patch adds burst mode for encryption and decryption packets. In
addition, it contains some performance improvement such as prefetching
packet header and reducing the number of current time function calls.
Type: improvement
Signed-off-by: Gabriel Oginski <gabrielx.oginski@intel.com>
Change-Id: I04c7daa9b6dc56cd15c789661a64ec642b35aa3f
(cherry picked from commit 8ca08496a43e8d98fe2d4130d760c6fb600d0a93)
-rw-r--r-- | src/plugins/wireguard/wireguard.h | 2 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_input.c | 326 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_noise.c | 174 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_noise.h | 76 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_output_tun.c | 100 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_timer.c | 30 | ||||
-rw-r--r-- | src/plugins/wireguard/wireguard_timer.h | 11 |
7 files changed, 538 insertions, 181 deletions
diff --git a/src/plugins/wireguard/wireguard.h b/src/plugins/wireguard/wireguard.h index 829c9e6f22b..4cbee1fcf7a 100644 --- a/src/plugins/wireguard/wireguard.h +++ b/src/plugins/wireguard/wireguard.h @@ -28,6 +28,8 @@ extern vlib_node_registration_t wg6_output_tun_node; typedef struct wg_per_thread_data_t_ { + CLIB_CACHE_LINE_ALIGN_MARK (cacheline0); + vnet_crypto_op_t *crypto_ops; u8 data[WG_DEFAULT_DATA_SIZE]; } wg_per_thread_data_t; typedef struct diff --git a/src/plugins/wireguard/wireguard_input.c b/src/plugins/wireguard/wireguard_input.c index d146d0d8fb5..10827ca0e64 100644 --- a/src/plugins/wireguard/wireguard_input.c +++ b/src/plugins/wireguard/wireguard_input.c @@ -298,50 +298,109 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b, return WG_INPUT_ERROR_NONE; } +static_always_inline void +wg_input_process_ops (vlib_main_t *vm, vlib_node_runtime_t *node, + vnet_crypto_op_t *ops, vlib_buffer_t *b[], u16 *nexts, + u16 drop_next) +{ + u32 n_fail, n_ops = vec_len (ops); + vnet_crypto_op_t *op = ops; + + if (n_ops == 0) + return; + + n_fail = n_ops - vnet_crypto_process_ops (vm, op, n_ops); + + while (n_fail) + { + ASSERT (op - ops < n_ops); + + if (op->status != VNET_CRYPTO_OP_STATUS_COMPLETED) + { + u32 bi = op->user_data; + b[bi]->error = node->errors[WG_INPUT_ERROR_DECRYPTION]; + nexts[bi] = drop_next; + n_fail--; + } + op++; + } +} + always_inline uword wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame, u8 is_ip4) { - message_type_t header_type; - u32 n_left_from; - u32 *from; - vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b; - u16 nexts[VLIB_FRAME_SIZE], *next; - u32 thread_index = vm->thread_index; + wg_main_t *wmp = &wg_main; + wg_per_thread_data_t *ptd = + vec_elt_at_index (wmp->per_thread_data, vm->thread_index); + u32 *from = vlib_frame_vector_args (frame); + u32 n_left_from = frame->n_vectors; - from = vlib_frame_vector_args (frame); - n_left_from = frame->n_vectors; - b = bufs; - next = nexts; + vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b = bufs; + u32 thread_index = vm->thread_index; + vnet_crypto_op_t **crypto_ops = &ptd->crypto_ops; + const u16 drop_next = WG_INPUT_NEXT_PUNT; + message_type_t header_type; + vlib_buffer_t *data_bufs[VLIB_FRAME_SIZE]; + u32 data_bi[VLIB_FRAME_SIZE]; /* buffer index for data */ + u32 other_bi[VLIB_FRAME_SIZE]; /* buffer index for drop or handoff */ + u16 other_nexts[VLIB_FRAME_SIZE], *other_next = other_nexts, n_other = 0; + u16 data_nexts[VLIB_FRAME_SIZE], *data_next = data_nexts, n_data = 0; vlib_get_buffers (vm, from, bufs, n_left_from); + vec_reset_length (ptd->crypto_ops); + + f64 time = clib_time_now (&vm->clib_time) + vm->time_offset; - wg_main_t *wmp = &wg_main; wg_peer_t *peer = NULL; + u32 *last_peer_time_idx = NULL; + u32 last_rec_idx = ~0; + + bool is_keepalive = false; + u32 *peer_idx = NULL; while (n_left_from > 0) { - bool is_keepalive = false; - next[0] = WG_INPUT_NEXT_PUNT; + if (n_left_from > 2) + { + u8 *p; + vlib_prefetch_buffer_header (b[2], LOAD); + p = vlib_buffer_get_current (b[1]); + CLIB_PREFETCH (p, CLIB_CACHE_LINE_BYTES, LOAD); + CLIB_PREFETCH (vlib_buffer_get_tail (b[1]), CLIB_CACHE_LINE_BYTES, + LOAD); + } + + other_next[n_other] = WG_INPUT_NEXT_PUNT; + data_nexts[n_data] = WG_INPUT_N_NEXT; + header_type = ((message_header_t *) vlib_buffer_get_current (b[0]))->type; - u32 *peer_idx; if (PREDICT_TRUE (header_type == MESSAGE_DATA)) { message_data_t *data = vlib_buffer_get_current (b[0]); - + u8 *iv_data = b[0]->pre_data; peer_idx = wg_index_table_lookup (&wmp->index_table, data->receiver_index); - if (peer_idx) + if (data->receiver_index != last_rec_idx) { - peer = wg_peer_get (*peer_idx); + peer_idx = wg_index_table_lookup (&wmp->index_table, + data->receiver_index); + if (PREDICT_TRUE (peer_idx != NULL)) + { + peer = wg_peer_get (*peer_idx); + } + last_rec_idx = data->receiver_index; } - else + + if (PREDICT_FALSE (!peer_idx)) { - next[0] = WG_INPUT_NEXT_ERROR; + other_next[n_other] = WG_INPUT_NEXT_ERROR; b[0]->error = node->errors[WG_INPUT_ERROR_PEER]; + other_bi[n_other] = from[b - bufs]; + n_other += 1; goto out; } @@ -356,7 +415,9 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, if (PREDICT_TRUE (thread_index != peer->input_thread_index)) { - next[0] = WG_INPUT_NEXT_HANDOFF_DATA; + other_next[n_other] = WG_INPUT_NEXT_HANDOFF_DATA; + other_bi[n_other] = from[b - bufs]; + n_other += 1; goto next; } @@ -365,83 +426,47 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, if (PREDICT_FALSE (decr_len >= WG_DEFAULT_DATA_SIZE)) { b[0]->error = node->errors[WG_INPUT_ERROR_TOO_BIG]; + other_bi[n_other] = from[b - bufs]; + n_other += 1; goto out; } - enum noise_state_crypt state_cr = noise_remote_decrypt ( - vm, &peer->remote, data->receiver_index, data->counter, - data->encrypted_data, encr_len, data->encrypted_data); + enum noise_state_crypt state_cr = noise_sync_remote_decrypt ( + vm, crypto_ops, &peer->remote, data->receiver_index, data->counter, + data->encrypted_data, decr_len, data->encrypted_data, n_data, + iv_data, time); if (PREDICT_FALSE (state_cr == SC_CONN_RESET)) { wg_timers_handshake_complete (peer); + data_bufs[n_data] = b[0]; + data_bi[n_data] = from[b - bufs]; + n_data += 1; + goto next; } else if (PREDICT_FALSE (state_cr == SC_KEEP_KEY_FRESH)) { wg_send_handshake_from_mt (*peer_idx, false); + data_bufs[n_data] = b[0]; + data_bi[n_data] = from[b - bufs]; + n_data += 1; + goto next; } else if (PREDICT_FALSE (state_cr == SC_FAILED)) { wg_peer_update_flags (*peer_idx, WG_PEER_ESTABLISHED, false); - next[0] = WG_INPUT_NEXT_ERROR; + other_next[n_other] = WG_INPUT_NEXT_ERROR; b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION]; + other_bi[n_other] = from[b - bufs]; + n_other += 1; goto out; } - - vlib_buffer_advance (b[0], sizeof (message_data_t)); - b[0]->current_length = decr_len; - vnet_buffer_offload_flags_clear (b[0], - VNET_BUFFER_OFFLOAD_F_UDP_CKSUM); - - wg_timers_any_authenticated_packet_received (peer); - wg_timers_any_authenticated_packet_traversal (peer); - - /* Keepalive packet has zero length */ - if (decr_len == 0) - { - is_keepalive = true; - goto out; - } - - wg_timers_data_received (peer); - - ip46_address_t src_ip; - u8 is_ip4_inner = is_ip4_header (vlib_buffer_get_current (b[0])); - if (is_ip4_inner) - { - ip46_address_set_ip4 ( - &src_ip, &((ip4_header_t *) vlib_buffer_get_current (b[0])) - ->src_address); - } - else - { - ip46_address_set_ip6 ( - &src_ip, &((ip6_header_t *) vlib_buffer_get_current (b[0])) - ->src_address); - } - - const fib_prefix_t *allowed_ip; - bool allowed = false; - - /* - * we could make this into an ACL, but the expectation - * is that there aren't many allowed IPs and thus a linear - * walk is fater than an ACL - */ - - vec_foreach (allowed_ip, peer->allowed_ips) - { - if (fib_prefix_is_cover_addr_46 (allowed_ip, &src_ip)) - { - allowed = true; - break; - } - } - if (allowed) + else if (PREDICT_TRUE (state_cr == SC_OK)) { - vnet_buffer (b[0])->sw_if_index[VLIB_RX] = peer->wg_sw_if_index; - next[0] = is_ip4_inner ? WG_INPUT_NEXT_IP4_INPUT : - WG_INPUT_NEXT_IP6_INPUT; + data_bufs[n_data] = b[0]; + data_bi[n_data] = from[b - bufs]; + n_data += 1; + goto next; } } else @@ -451,7 +476,9 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, /* Handshake packets should be processed in main thread */ if (thread_index != 0) { - next[0] = WG_INPUT_NEXT_HANDOFF_HANDSHAKE; + other_next[n_other] = WG_INPUT_NEXT_HANDOFF_HANDSHAKE; + other_bi[n_other] = from[b - bufs]; + n_other += 1; goto next; } @@ -459,14 +486,21 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, wg_handshake_process (vm, wmp, b[0], node->node_index, is_ip4); if (ret != WG_INPUT_ERROR_NONE) { - next[0] = WG_INPUT_NEXT_ERROR; + other_next[n_other] = WG_INPUT_NEXT_ERROR; b[0]->error = node->errors[ret]; + other_bi[n_other] = from[b - bufs]; + n_other += 1; + } + else + { + other_bi[n_other] = from[b - bufs]; + n_other += 1; } } out: - if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) - && (b[0]->flags & VLIB_BUFFER_IS_TRACED))) + if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) && + (b[0]->flags & VLIB_BUFFER_IS_TRACED))) { wg_input_trace_t *t = vlib_add_trace (vm, node, b[0], sizeof (*t)); t->type = header_type; @@ -474,12 +508,136 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node, t->is_keepalive = is_keepalive; t->peer = peer_idx ? *peer_idx : INDEX_INVALID; } + next: n_left_from -= 1; - next += 1; b += 1; } - vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors); + + /* decrypt packets */ + wg_input_process_ops (vm, node, ptd->crypto_ops, data_bufs, data_nexts, + drop_next); + + /* process after decryption */ + b = data_bufs; + n_left_from = n_data; + n_data = 0; + last_rec_idx = ~0; + last_peer_time_idx = NULL; + while (n_left_from > 0) + { + bool is_keepalive = false; + u32 *peer_idx = NULL; + + if (data_next[n_data] == WG_INPUT_NEXT_PUNT) + { + goto trace; + } + else + { + data_next[n_data] = WG_INPUT_NEXT_PUNT; + } + + message_data_t *data = vlib_buffer_get_current (b[0]); + + if (data->receiver_index != last_rec_idx) + { + peer_idx = + wg_index_table_lookup (&wmp->index_table, data->receiver_index); + /* already checked and excisting */ + peer = wg_peer_get (*peer_idx); + last_rec_idx = data->receiver_index; + } + + noise_keypair_t *kp = + wg_get_active_keypair (&peer->remote, data->receiver_index); + + if (!noise_counter_recv (&kp->kp_ctr, data->counter)) + { + goto trace; + } + + u16 encr_len = b[0]->current_length - sizeof (message_data_t); + u16 decr_len = encr_len - NOISE_AUTHTAG_LEN; + + vlib_buffer_advance (b[0], sizeof (message_data_t)); + b[0]->current_length = decr_len; + vnet_buffer_offload_flags_clear (b[0], VNET_BUFFER_OFFLOAD_F_UDP_CKSUM); + + if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx))) + { + wg_timers_any_authenticated_packet_received_opt (peer, time); + wg_timers_any_authenticated_packet_traversal (peer); + last_peer_time_idx = peer_idx; + } + + /* Keepalive packet has zero length */ + if (decr_len == 0) + { + is_keepalive = true; + goto trace; + } + + wg_timers_data_received (peer); + + ip46_address_t src_ip; + u8 is_ip4_inner = is_ip4_header (vlib_buffer_get_current (b[0])); + if (is_ip4_inner) + { + ip46_address_set_ip4 ( + &src_ip, + &((ip4_header_t *) vlib_buffer_get_current (b[0]))->src_address); + } + else + { + ip46_address_set_ip6 ( + &src_ip, + &((ip6_header_t *) vlib_buffer_get_current (b[0]))->src_address); + } + + const fib_prefix_t *allowed_ip; + bool allowed = false; + + /* + * we could make this into an ACL, but the expectation + * is that there aren't many allowed IPs and thus a linear + * walk is fater than an ACL + */ + + vec_foreach (allowed_ip, peer->allowed_ips) + { + if (fib_prefix_is_cover_addr_46 (allowed_ip, &src_ip)) + { + allowed = true; + break; + } + } + if (allowed) + { + vnet_buffer (b[0])->sw_if_index[VLIB_RX] = peer->wg_sw_if_index; + data_next[n_data] = + is_ip4_inner ? WG_INPUT_NEXT_IP4_INPUT : WG_INPUT_NEXT_IP6_INPUT; + } + trace: + if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) && + (b[0]->flags & VLIB_BUFFER_IS_TRACED))) + { + wg_input_trace_t *t = vlib_add_trace (vm, node, b[0], sizeof (*t)); + t->type = header_type; + t->current_length = b[0]->current_length; + t->is_keepalive = is_keepalive; + t->peer = peer_idx ? *peer_idx : INDEX_INVALID; + } + + b += 1; + n_left_from -= 1; + n_data += 1; + } + /* enqueue other bufs */ + vlib_buffer_enqueue_to_next (vm, node, other_bi, other_next, n_other); + + /* enqueue data bufs */ + vlib_buffer_enqueue_to_next (vm, node, data_bi, data_nexts, n_data); return frame->n_vectors; } diff --git a/src/plugins/wireguard/wireguard_noise.c b/src/plugins/wireguard/wireguard_noise.c index 36de8ae9cac..c8605f117cd 100644 --- a/src/plugins/wireguard/wireguard_noise.c +++ b/src/plugins/wireguard/wireguard_noise.c @@ -36,7 +36,7 @@ static uint32_t noise_remote_handshake_index_get (noise_remote_t *); static void noise_remote_handshake_index_drop (noise_remote_t *); static uint64_t noise_counter_send (noise_counter_t *); -static bool noise_counter_recv (noise_counter_t *, uint64_t); +bool noise_counter_recv (noise_counter_t *, uint64_t); static void noise_kdf (uint8_t *, uint8_t *, uint8_t *, const uint8_t *, size_t, size_t, size_t, size_t, @@ -407,6 +407,8 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r) /* Now we need to add_new_keypair */ clib_rwlock_writer_lock (&r->r_keypair_lock); + /* Activate barrier to synchronization keys between threads */ + vlib_worker_thread_barrier_sync (vm); next = r->r_next; current = r->r_current; previous = r->r_previous; @@ -438,6 +440,7 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r) r->r_next = noise_remote_keypair_allocate (r); *r->r_next = kp; } + vlib_worker_thread_barrier_release (vm); clib_rwlock_writer_unlock (&r->r_keypair_lock); secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake)); @@ -541,6 +544,41 @@ chacha20poly1305_calc (vlib_main_t * vm, return (op->status == VNET_CRYPTO_OP_STATUS_COMPLETED); } +always_inline void +wg_prepare_sync_op (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops, u8 *src, + u32 src_len, u8 *dst, u8 *aad, u32 aad_len, u64 nonce, + vnet_crypto_op_id_t op_id, + vnet_crypto_key_index_t key_index, u32 bi, u8 *iv) +{ + vnet_crypto_op_t _op, *op = &_op; + u8 src_[] = {}; + + clib_memset (iv, 0, 4); + clib_memcpy (iv + 4, &nonce, sizeof (nonce)); + + vec_add2_aligned (crypto_ops[0], op, 1, CLIB_CACHE_LINE_BYTES); + vnet_crypto_op_init (op, op_id); + + op->tag_len = NOISE_AUTHTAG_LEN; + if (op_id == VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC) + { + op->tag = src + src_len; + op->flags |= VNET_CRYPTO_OP_FLAG_HMAC_CHECK; + } + else + op->tag = dst + src_len; + + op->src = !src ? src_ : src; + op->len = src_len; + + op->dst = dst; + op->key_index = key_index; + op->aad = aad; + op->aad_len = aad_len; + op->iv = iv; + op->user_data = bi; +} + enum noise_state_crypt noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx, uint64_t * nonce, uint8_t * src, size_t srclen, @@ -592,26 +630,67 @@ error: } enum noise_state_crypt -noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx, - uint64_t nonce, uint8_t * src, size_t srclen, - uint8_t * dst) +noise_sync_remote_encrypt (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops, + noise_remote_t *r, uint32_t *r_idx, uint64_t *nonce, + uint8_t *src, size_t srclen, uint8_t *dst, u32 bi, + u8 *iv, f64 time) { noise_keypair_t *kp; enum noise_state_crypt ret = SC_FAILED; - if (r->r_current != NULL && r->r_current->kp_local_index == r_idx) - { - kp = r->r_current; - } - else if (r->r_previous != NULL && r->r_previous->kp_local_index == r_idx) - { - kp = r->r_previous; - } - else if (r->r_next != NULL && r->r_next->kp_local_index == r_idx) - { - kp = r->r_next; - } - else + if ((kp = r->r_current) == NULL) + goto error; + + /* We confirm that our values are within our tolerances. We want: + * - a valid keypair + * - our keypair to be less than REJECT_AFTER_TIME seconds old + * - our receive counter to be less than REJECT_AFTER_MESSAGES + * - our send counter to be less than REJECT_AFTER_MESSAGES + */ + if (!kp->kp_valid || + wg_birthdate_has_expired_opt (kp->kp_birthdate, REJECT_AFTER_TIME, + time) || + kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES || + ((*nonce = noise_counter_send (&kp->kp_ctr)) > REJECT_AFTER_MESSAGES)) + goto error; + + /* We encrypt into the same buffer, so the caller must ensure that buf + * has NOISE_AUTHTAG_LEN bytes to store the MAC. The nonce and index + * are passed back out to the caller through the provided data pointer. */ + *r_idx = kp->kp_remote_index; + + wg_prepare_sync_op (vm, crypto_ops, src, srclen, dst, NULL, 0, *nonce, + VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC, kp->kp_send_index, + bi, iv); + + /* If our values are still within tolerances, but we are approaching + * the tolerances, we notify the caller with ESTALE that they should + * establish a new keypair. The current keypair can continue to be used + * until the tolerances are hit. We notify if: + * - our send counter is valid and not less than REKEY_AFTER_MESSAGES + * - we're the initiator and our keypair is older than + * REKEY_AFTER_TIME seconds */ + ret = SC_KEEP_KEY_FRESH; + if ((kp->kp_valid && *nonce >= REKEY_AFTER_MESSAGES) || + (kp->kp_is_initiator && wg_birthdate_has_expired_opt ( + kp->kp_birthdate, REKEY_AFTER_TIME, time))) + goto error; + + ret = SC_OK; +error: + return ret; +} + +enum noise_state_crypt +noise_sync_remote_decrypt (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops, + noise_remote_t *r, uint32_t r_idx, uint64_t nonce, + uint8_t *src, size_t srclen, uint8_t *dst, u32 bi, + u8 *iv, f64 time) +{ + noise_keypair_t *kp; + enum noise_state_crypt ret = SC_FAILED; + + if ((kp = wg_get_active_keypair (r, r_idx)) == NULL) { goto error; } @@ -620,20 +699,17 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx, * are the same as the encrypt routine. * * kp_ctr isn't locked here, we're happy to accept a racy read. */ - if (wg_birthdate_has_expired (kp->kp_birthdate, REJECT_AFTER_TIME) || + if (wg_birthdate_has_expired_opt (kp->kp_birthdate, REJECT_AFTER_TIME, + time) || kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES) goto error; /* Decrypt, then validate the counter. We don't want to validate the * counter before decrypting as we do not know the message is authentic * prior to decryption. */ - if (!chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, nonce, - VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, - kp->kp_recv_index)) - goto error; - - if (!noise_counter_recv (&kp->kp_ctr, nonce)) - goto error; + wg_prepare_sync_op (vm, crypto_ops, src, srclen, dst, NULL, 0, nonce, + VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, kp->kp_recv_index, + bi, iv); /* If we've received the handshake confirming data packet then move the * next keypair into current. If we do slide the next keypair in, then @@ -662,10 +738,9 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx, * REKEY_AFTER_TIME_RECV seconds. */ ret = SC_KEEP_KEY_FRESH; kp = r->r_current; - if (kp != NULL && - kp->kp_valid && - kp->kp_is_initiator && - wg_birthdate_has_expired (kp->kp_birthdate, REKEY_AFTER_TIME_RECV)) + if (kp != NULL && kp->kp_valid && kp->kp_is_initiator && + wg_birthdate_has_expired_opt (kp->kp_birthdate, REKEY_AFTER_TIME_RECV, + time)) goto error; ret = SC_OK; @@ -724,47 +799,6 @@ noise_counter_send (noise_counter_t * ctr) return ret; } -static bool -noise_counter_recv (noise_counter_t * ctr, uint64_t recv) -{ - uint64_t i, top, index_recv, index_ctr; - unsigned long bit; - bool ret = false; - - /* Check that the recv counter is valid */ - if (ctr->c_recv >= REJECT_AFTER_MESSAGES || recv >= REJECT_AFTER_MESSAGES) - goto error; - - /* If the packet is out of the window, invalid */ - if (recv + COUNTER_WINDOW_SIZE < ctr->c_recv) - goto error; - - /* If the new counter is ahead of the current counter, we'll need to - * zero out the bitmap that has previously been used */ - index_recv = recv / COUNTER_BITS; - index_ctr = ctr->c_recv / COUNTER_BITS; - - if (recv > ctr->c_recv) - { - top = clib_min (index_recv - index_ctr, COUNTER_NUM); - for (i = 1; i <= top; i++) - ctr->c_backtrack[(i + index_ctr) & (COUNTER_NUM - 1)] = 0; - ctr->c_recv = recv; - } - - index_recv %= COUNTER_NUM; - bit = 1ul << (recv % COUNTER_BITS); - - if (ctr->c_backtrack[index_recv] & bit) - goto error; - - ctr->c_backtrack[index_recv] |= bit; - - ret = true; -error: - return ret; -} - static void noise_kdf (uint8_t * a, uint8_t * b, uint8_t * c, const uint8_t * x, size_t a_len, size_t b_len, size_t c_len, size_t x_len, diff --git a/src/plugins/wireguard/wireguard_noise.h b/src/plugins/wireguard/wireguard_noise.h index 5b5a88fa250..ef1e7dcbfca 100644 --- a/src/plugins/wireguard/wireguard_noise.h +++ b/src/plugins/wireguard/wireguard_noise.h @@ -187,12 +187,80 @@ noise_remote_encrypt (vlib_main_t * vm, noise_remote_t *, uint32_t * r_idx, uint64_t * nonce, uint8_t * src, size_t srclen, uint8_t * dst); + enum noise_state_crypt -noise_remote_decrypt (vlib_main_t * vm, noise_remote_t *, - uint32_t r_idx, - uint64_t nonce, - uint8_t * src, size_t srclen, uint8_t * dst); +noise_sync_remote_encrypt (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops, + noise_remote_t *r, uint32_t *r_idx, uint64_t *nonce, + uint8_t *src, size_t srclen, uint8_t *dst, u32 bi, + u8 *iv, f64 time); + +enum noise_state_crypt +noise_sync_remote_decrypt (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops, + noise_remote_t *, uint32_t r_idx, uint64_t nonce, + uint8_t *src, size_t srclen, uint8_t *dst, u32 bi, + u8 *iv, f64 time); + +static_always_inline noise_keypair_t * +wg_get_active_keypair (noise_remote_t *r, uint32_t r_idx) +{ + if (r->r_current != NULL && r->r_current->kp_local_index == r_idx) + { + return r->r_current; + } + else if (r->r_previous != NULL && r->r_previous->kp_local_index == r_idx) + { + return r->r_previous; + } + else if (r->r_next != NULL && r->r_next->kp_local_index == r_idx) + { + return r->r_next; + } + else + { + return NULL; + } +} + +inline bool +noise_counter_recv (noise_counter_t *ctr, uint64_t recv) +{ + uint64_t i, top, index_recv, index_ctr; + unsigned long bit; + bool ret = false; + /* Check that the recv counter is valid */ + if (ctr->c_recv >= REJECT_AFTER_MESSAGES || recv >= REJECT_AFTER_MESSAGES) + goto error; + + /* If the packet is out of the window, invalid */ + if (recv + COUNTER_WINDOW_SIZE < ctr->c_recv) + goto error; + + /* If the new counter is ahead of the current counter, we'll need to + * zero out the bitmap that has previously been used */ + index_recv = recv / COUNTER_BITS; + index_ctr = ctr->c_recv / COUNTER_BITS; + + if (recv > ctr->c_recv) + { + top = clib_min (index_recv - index_ctr, COUNTER_NUM); + for (i = 1; i <= top; i++) + ctr->c_backtrack[(i + index_ctr) & (COUNTER_NUM - 1)] = 0; + ctr->c_recv = recv; + } + + index_recv %= COUNTER_NUM; + bit = 1ul << (recv % COUNTER_BITS); + + if (ctr->c_backtrack[index_recv] & bit) + goto error; + + ctr->c_backtrack[index_recv] |= bit; + + ret = true; +error: + return ret; +} #endif /* __included_wg_noise_h__ */ diff --git a/src/plugins/wireguard/wireguard_output_tun.c b/src/plugins/wireguard/wireguard_output_tun.c index c792d4b713e..2feb0570a31 100644 --- a/src/plugins/wireguard/wireguard_output_tun.c +++ b/src/plugins/wireguard/wireguard_output_tun.c @@ -93,32 +93,67 @@ format_wg_output_tun_trace (u8 * s, va_list * args) return s; } +static_always_inline void +wg_output_process_ops (vlib_main_t *vm, vlib_node_runtime_t *node, + vnet_crypto_op_t *ops, vlib_buffer_t *b[], u16 *nexts, + u16 drop_next) +{ + u32 n_fail, n_ops = vec_len (ops); + vnet_crypto_op_t *op = ops; + + if (n_ops == 0) + return; + + n_fail = n_ops - vnet_crypto_process_ops (vm, op, n_ops); + + while (n_fail) + { + ASSERT (op - ops < n_ops); + + if (op->status != VNET_CRYPTO_OP_STATUS_COMPLETED) + { + u32 bi = op->user_data; + b[bi]->error = node->errors[WG_OUTPUT_ERROR_KEYPAIR]; + nexts[bi] = drop_next; + n_fail--; + } + op++; + } +} + /* is_ip4 - inner header flag */ always_inline uword wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame, u8 is_ip4) { - u32 n_left_from; - u32 *from; + wg_main_t *wmp = &wg_main; + wg_per_thread_data_t *ptd = + vec_elt_at_index (wmp->per_thread_data, vm->thread_index); + u32 *from = vlib_frame_vector_args (frame); + u32 n_left_from = frame->n_vectors; ip4_udp_wg_header_t *hdr4_out = NULL; ip6_udp_wg_header_t *hdr6_out = NULL; message_data_t *message_data_wg = NULL; - vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b; - u16 nexts[VLIB_FRAME_SIZE], *next; + vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b = bufs; + vnet_crypto_op_t **crypto_ops = &ptd->crypto_ops; + u16 nexts[VLIB_FRAME_SIZE], *next = nexts; + vlib_buffer_t *sync_bufs[VLIB_FRAME_SIZE]; u32 thread_index = vm->thread_index; - - from = vlib_frame_vector_args (frame); - n_left_from = frame->n_vectors; - b = bufs; - next = nexts; + u16 n_sync = 0; + u16 drop_next = WG_OUTPUT_NEXT_ERROR; vlib_get_buffers (vm, from, bufs, n_left_from); + vec_reset_length (ptd->crypto_ops); wg_peer_t *peer = NULL; + u32 adj_index = 0; + u32 last_adj_index = ~0; + index_t peeri = INDEX_INVALID; + + f64 time = clib_time_now (&vm->clib_time) + vm->time_offset; while (n_left_from > 0) { - index_t peeri; u8 iph_offset = 0; u8 is_ip4_out = 1; u8 *plain_data; @@ -130,14 +165,21 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_prefetch_buffer_header (b[2], LOAD); p = vlib_buffer_get_current (b[1]); CLIB_PREFETCH (p, CLIB_CACHE_LINE_BYTES, LOAD); + CLIB_PREFETCH (vlib_buffer_get_tail (b[1]), CLIB_CACHE_LINE_BYTES, + LOAD); } next[0] = WG_OUTPUT_NEXT_ERROR; - peeri = - wg_peer_get_by_adj_index (vnet_buffer (b[0])->ip.adj_index[VLIB_TX]); - peer = wg_peer_get (peeri); - if (wg_peer_is_dead (peer)) + adj_index = vnet_buffer (b[0])->ip.adj_index[VLIB_TX]; + + if (PREDICT_FALSE (last_adj_index != adj_index)) + { + peeri = wg_peer_get_by_adj_index (adj_index); + peer = wg_peer_get (peeri); + } + + if (!peer || wg_peer_is_dead (peer)) { b[0]->error = node->errors[WG_OUTPUT_ERROR_PEER]; goto out; @@ -179,6 +221,7 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, iph_offset = vnet_buffer (b[0])->ip.save_rewrite_length; plain_data = vlib_buffer_get_current (b[0]) + iph_offset; plain_data_len = vlib_buffer_length_in_chain (vm, b[0]) - iph_offset; + u8 *iv_data = b[0]->pre_data; size_t encrypted_packet_len = message_data_len (plain_data_len); @@ -194,11 +237,20 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, goto out; } + if (PREDICT_FALSE (last_adj_index != adj_index)) + { + wg_timers_any_authenticated_packet_sent_opt (peer, time); + wg_timers_data_sent_opt (peer, time); + wg_timers_any_authenticated_packet_traversal (peer); + last_adj_index = adj_index; + } + enum noise_state_crypt state; - state = noise_remote_encrypt ( - vm, &peer->remote, &message_data_wg->receiver_index, - &message_data_wg->counter, plain_data, plain_data_len, plain_data); + state = noise_sync_remote_encrypt ( + vm, crypto_ops, &peer->remote, &message_data_wg->receiver_index, + &message_data_wg->counter, plain_data, plain_data_len, plain_data, + n_sync, iv_data, time); if (PREDICT_FALSE (state == SC_KEEP_KEY_FRESH)) { @@ -206,7 +258,7 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, } else if (PREDICT_FALSE (state == SC_FAILED)) { - //TODO: Maybe wrong + // TODO: Maybe wrong wg_send_handshake_from_mt (peeri, false); wg_peer_update_flags (peeri, WG_PEER_ESTABLISHED, false); goto out; @@ -236,10 +288,6 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, clib_host_to_net_u16 (b[0]->current_length); } - wg_timers_any_authenticated_packet_sent (peer); - wg_timers_data_sent (peer); - wg_timers_any_authenticated_packet_traversal (peer); - out: if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) && (b[0]->flags & VLIB_BUFFER_IS_TRACED))) @@ -256,12 +304,18 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, } next: + sync_bufs[n_sync] = b[0]; + n_sync += 1; n_left_from -= 1; next += 1; b += 1; } - vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors); + /* wg-output-process-ops */ + wg_output_process_ops (vm, node, ptd->crypto_ops, sync_bufs, nexts, + drop_next); + + vlib_buffer_enqueue_to_next (vm, node, from, nexts, n_sync); return frame->n_vectors; } diff --git a/src/plugins/wireguard/wireguard_timer.c b/src/plugins/wireguard/wireguard_timer.c index 97c861b459c..d2ae7ebf93e 100644 --- a/src/plugins/wireguard/wireguard_timer.c +++ b/src/plugins/wireguard/wireguard_timer.c @@ -26,6 +26,13 @@ get_random_u32_max (u32 max) return random_u32 (&seed) % max; } +static u32 +get_random_u32_max_opt (u32 max, f64 time) +{ + u32 seed = (u32) (time * 1e6); + return random_u32 (&seed) % max; +} + static void stop_timer (wg_peer_t * peer, u32 timer_id) { @@ -215,6 +222,12 @@ wg_timers_any_authenticated_packet_sent (wg_peer_t * peer) } void +wg_timers_any_authenticated_packet_sent_opt (wg_peer_t *peer, f64 time) +{ + peer->last_sent_packet = time; +} + +void wg_timers_handshake_initiated (wg_peer_t * peer) { peer->rehandshake_started = vlib_time_now (vlib_get_main ()); @@ -246,6 +259,17 @@ wg_timers_data_sent (wg_peer_t * peer) peer->new_handshake_interval_tick); } +void +wg_timers_data_sent_opt (wg_peer_t *peer, f64 time) +{ + peer->new_handshake_interval_tick = + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * WHZ + + get_random_u32_max_opt (REKEY_TIMEOUT_JITTER, time); + + start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_NEW_HANDSHAKE, + peer->new_handshake_interval_tick); +} + /* Should be called after an authenticated data packet is received. */ void wg_timers_data_received (wg_peer_t * peer) @@ -275,6 +299,12 @@ wg_timers_any_authenticated_packet_received (wg_peer_t * peer) peer->last_received_packet = vlib_time_now (vlib_get_main ()); } +void +wg_timers_any_authenticated_packet_received_opt (wg_peer_t *peer, f64 time) +{ + peer->last_received_packet = time; +} + static vlib_node_registration_t wg_timer_mngr_node; static void diff --git a/src/plugins/wireguard/wireguard_timer.h b/src/plugins/wireguard/wireguard_timer.h index 6b59a39f815..cc8e123f3a2 100644 --- a/src/plugins/wireguard/wireguard_timer.h +++ b/src/plugins/wireguard/wireguard_timer.h @@ -41,9 +41,13 @@ typedef struct wg_peer wg_peer_t; void wg_timer_wheel_init (); void wg_timers_stop (wg_peer_t * peer); void wg_timers_data_sent (wg_peer_t * peer); +void wg_timers_data_sent_opt (wg_peer_t *peer, f64 time); void wg_timers_data_received (wg_peer_t * peer); void wg_timers_any_authenticated_packet_sent (wg_peer_t * peer); +void wg_timers_any_authenticated_packet_sent_opt (wg_peer_t *peer, f64 time); void wg_timers_any_authenticated_packet_received (wg_peer_t * peer); +void wg_timers_any_authenticated_packet_received_opt (wg_peer_t *peer, + f64 time); void wg_timers_handshake_initiated (wg_peer_t * peer); void wg_timers_handshake_complete (wg_peer_t * peer); void wg_timers_session_derived (wg_peer_t * peer); @@ -57,6 +61,13 @@ wg_birthdate_has_expired (f64 birthday_seconds, f64 expiration_seconds) return (birthday_seconds + expiration_seconds) < now_seconds; } +static inline bool +wg_birthdate_has_expired_opt (f64 birthday_seconds, f64 expiration_seconds, + f64 time) +{ + return (birthday_seconds + expiration_seconds) < time; +} + #endif /* __included_wg_timer_h__ */ /* |