Diffstat (limited to 'src/plugins/crypto_native/aes_cbc.c')
-rw-r--r-- | src/plugins/crypto_native/aes_cbc.c | 357
1 file changed, 297 insertions(+), 60 deletions(-)
diff --git a/src/plugins/crypto_native/aes_cbc.c b/src/plugins/crypto_native/aes_cbc.c
index 7896c8814b1..02d96b31c79 100644
--- a/src/plugins/crypto_native/aes_cbc.c
+++ b/src/plugins/crypto_native/aes_cbc.c
@@ -25,17 +25,35 @@
 #pragma GCC optimize ("O3")
 #endif
 
+#if defined(__VAES__) && defined(__AVX512F__)
+#define N 16
+#define u8xN u8x64
+#define u32xN u32x16
+#define u32xN_min_scalar u32x16_min_scalar
+#define u32xN_is_all_zero u32x16_is_all_zero
+#define u32xN_splat u32x16_splat
+#elif defined(__VAES__)
+#define N 8
+#define u8xN u8x32
+#define u32xN u32x8
+#define u32xN_min_scalar u32x8_min_scalar
+#define u32xN_is_all_zero u32x8_is_all_zero
+#define u32xN_splat u32x8_splat
+#else
+#define N 4
+#define u8xN u8x16
+#define u32xN u32x4
+#define u32xN_min_scalar u32x4_min_scalar
+#define u32xN_is_all_zero u32x4_is_all_zero
+#define u32xN_splat u32x4_splat
+#endif
+
 typedef struct
 {
   u8x16 encrypt_key[15];
-#if __VAES__
-  u8x64 decrypt_key[15];
-#else
-  u8x16 decrypt_key[15];
-#endif
+  u8xN decrypt_key[15];
 } aes_cbc_key_data_t;
 
-
 static_always_inline void __clib_unused
 aes_cbc_dec (u8x16 * k, u8x16u * src, u8x16u * dst, u8x16u * iv, int count,
              int rounds)
@@ -119,7 +137,7 @@ aes_cbc_dec (u8x16 * k, u8x16u * src, u8x16u * dst, u8x16u * iv, int count,
 }
 
 #if __x86_64__
-#ifdef __VAES__
+#if defined(__VAES__) && defined(__AVX512F__)
 
 static_always_inline u8x64
 aes_block_load_x4 (u8 * src[], int i)
@@ -142,14 +160,13 @@ aes_block_store_x4 (u8 * dst[], int i, u8x64 r)
 }
 
 static_always_inline u8x64
-aes_cbc_dec_permute (u8x64 a, u8x64 b)
+aes4_cbc_dec_permute (u8x64 a, u8x64 b)
 {
-  __m512i perm = { 6, 7, 8, 9, 10, 11, 12, 13 };
-  return (u8x64) _mm512_permutex2var_epi64 ((__m512i) a, perm, (__m512i) b);
+  return (u8x64) u64x8_shuffle2 (a, b, 6, 7, 8, 9, 10, 11, 12, 13);
 }
 
 static_always_inline void
-vaes_cbc_dec (u8x64 *k, u8x64u *src, u8x64u *dst, u8x16u *iv, int count,
+aes4_cbc_dec (u8x64 *k, u8x64u *src, u8x64u *dst, u8x16u *iv, int count,
               aes_key_size_t rounds)
 {
   u8x64 f, r[4], c[4] = { };
@@ -184,10 +201,10 @@ vaes_cbc_dec (u8x64 *k, u8x64u *src, u8x64u *dst, u8x16u *iv, int count,
       r[2] = aes_dec_last_round_x4 (r[2], k[i]);
       r[3] = aes_dec_last_round_x4 (r[3], k[i]);
 
-      dst[0] = r[0] ^= aes_cbc_dec_permute (f, c[0]);
-      dst[1] = r[1] ^= aes_cbc_dec_permute (c[0], c[1]);
-      dst[2] = r[2] ^= aes_cbc_dec_permute (c[1], c[2]);
-      dst[3] = r[3] ^= aes_cbc_dec_permute (c[2], c[3]);
+      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
+      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
+      dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
+      dst[3] = r[3] ^= aes4_cbc_dec_permute (c[2], c[3]);
       f = c[3];
 
       n_blocks -= 16;
@@ -195,40 +212,248 @@ vaes_cbc_dec (u8x64 *k, u8x64u *src, u8x64u *dst, u8x16u *iv, int count,
       dst += 4;
     }
 
-  while (n_blocks > 0)
+  if (n_blocks >= 12)
+    {
+      c[0] = src[0];
+      c[1] = src[1];
+      c[2] = src[2];
+
+      r[0] = c[0] ^ k[0];
+      r[1] = c[1] ^ k[0];
+      r[2] = c[2] ^ k[0];
+
+      for (i = 1; i < rounds; i++)
+        {
+          r[0] = aes_dec_round_x4 (r[0], k[i]);
+          r[1] = aes_dec_round_x4 (r[1], k[i]);
+          r[2] = aes_dec_round_x4 (r[2], k[i]);
+        }
+
+      r[0] = aes_dec_last_round_x4 (r[0], k[i]);
+      r[1] = aes_dec_last_round_x4 (r[1], k[i]);
+      r[2] = aes_dec_last_round_x4 (r[2], k[i]);
+
+      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
+      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
+      dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
+      f = c[2];
+
+      n_blocks -= 12;
+      src += 3;
+      dst += 3;
+    }
+  else if (n_blocks >= 8)
+    {
+      c[0] = src[0];
+      c[1] = src[1];
+
+      r[0] = c[0] ^ k[0];
+      r[1] = c[1] ^ k[0];
+
+      for (i = 1; i < rounds; i++)
+        {
+          r[0] = aes_dec_round_x4 (r[0], k[i]);
+          r[1] = aes_dec_round_x4 (r[1], k[i]);
+        }
+
+      r[0] = aes_dec_last_round_x4 (r[0], k[i]);
+      r[1] = aes_dec_last_round_x4 (r[1], k[i]);
+
+      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
+      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
+      f = c[1];
+
+      n_blocks -= 8;
+      src += 2;
+      dst += 2;
+    }
+  else if (n_blocks >= 4)
+    {
+      c[0] = src[0];
+
+      r[0] = c[0] ^ k[0];
+
+      for (i = 1; i < rounds; i++)
+        {
+          r[0] = aes_dec_round_x4 (r[0], k[i]);
+        }
+
+      r[0] = aes_dec_last_round_x4 (r[0], k[i]);
+
+      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
+      f = c[0];
+
+      n_blocks -= 4;
+      src += 1;
+      dst += 1;
+    }
+
+  if (n_blocks > 0)
     {
       m = (1 << (n_blocks * 2)) - 1;
       c[0] = (u8x64) _mm512_mask_loadu_epi64 ((__m512i) c[0], m,
                                               (__m512i *) src);
-      f = aes_cbc_dec_permute (f, c[0]);
+      f = aes4_cbc_dec_permute (f, c[0]);
       r[0] = c[0] ^ k[0];
       for (i = 1; i < rounds; i++)
         r[0] = aes_dec_round_x4 (r[0], k[i]);
       r[0] = aes_dec_last_round_x4 (r[0], k[i]);
       _mm512_mask_storeu_epi64 ((__m512i *) dst, m, (__m512i) (r[0] ^ f));
-      f = c[0];
-      n_blocks -= 4;
-      src += 1;
-      dst += 1;
     }
 }
+#elif defined(__VAES__)
+
+static_always_inline u8x32
+aes_block_load_x2 (u8 *src[], int i)
+{
+  u8x32 r = {};
+  r = u8x32_insert_lo (r, aes_block_load (src[0] + i));
+  r = u8x32_insert_hi (r, aes_block_load (src[1] + i));
+  return r;
+}
+
+static_always_inline void
+aes_block_store_x2 (u8 *dst[], int i, u8x32 r)
+{
+  aes_block_store (dst[0] + i, u8x32_extract_lo (r));
+  aes_block_store (dst[1] + i, u8x32_extract_hi (r));
+}
+
+static_always_inline u8x32
+aes2_cbc_dec_permute (u8x32 a, u8x32 b)
+{
+  return (u8x32) u64x4_shuffle2 ((u64x4) a, (u64x4) b, 2, 3, 4, 5);
+}
+
+static_always_inline void
+aes2_cbc_dec (u8x32 *k, u8x32u *src, u8x32u *dst, u8x16u *iv, int count,
+              aes_key_size_t rounds)
+{
+  u8x32 f = {}, r[4], c[4] = {};
+  int i, n_blocks = count >> 4;
+
+  f = u8x32_insert_hi (f, *iv);
+
+  while (n_blocks >= 8)
+    {
+      c[0] = src[0];
+      c[1] = src[1];
+      c[2] = src[2];
+      c[3] = src[3];
+
+      r[0] = c[0] ^ k[0];
+      r[1] = c[1] ^ k[0];
+      r[2] = c[2] ^ k[0];
+      r[3] = c[3] ^ k[0];
+
+      for (i = 1; i < rounds; i++)
+        {
+          r[0] = aes_dec_round_x2 (r[0], k[i]);
+          r[1] = aes_dec_round_x2 (r[1], k[i]);
+          r[2] = aes_dec_round_x2 (r[2], k[i]);
+          r[3] = aes_dec_round_x2 (r[3], k[i]);
+        }
+
+      r[0] = aes_dec_last_round_x2 (r[0], k[i]);
+      r[1] = aes_dec_last_round_x2 (r[1], k[i]);
+      r[2] = aes_dec_last_round_x2 (r[2], k[i]);
+      r[3] = aes_dec_last_round_x2 (r[3], k[i]);
+
+      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
+      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
+      dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
+      dst[3] = r[3] ^= aes2_cbc_dec_permute (c[2], c[3]);
+      f = c[3];
+
+      n_blocks -= 8;
+      src += 4;
+      dst += 4;
+    }
+
+  if (n_blocks >= 6)
+    {
+      c[0] = src[0];
+      c[1] = src[1];
+      c[2] = src[2];
+
+      r[0] = c[0] ^ k[0];
+      r[1] = c[1] ^ k[0];
+      r[2] = c[2] ^ k[0];
+
+      for (i = 1; i < rounds; i++)
+        {
+          r[0] = aes_dec_round_x2 (r[0], k[i]);
+          r[1] = aes_dec_round_x2 (r[1], k[i]);
+          r[2] = aes_dec_round_x2 (r[2], k[i]);
+        }
+
+      r[0] = aes_dec_last_round_x2 (r[0], k[i]);
+      r[1] = aes_dec_last_round_x2 (r[1], k[i]);
+      r[2] = aes_dec_last_round_x2 (r[2], k[i]);
+
+      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
+      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
+      dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
+      f = c[2];
+
+      n_blocks -= 6;
+      src += 3;
+      dst += 3;
+    }
+  else if (n_blocks >= 4)
+    {
+      c[0] = src[0];
+      c[1] = src[1];
+
+      r[0] = c[0] ^ k[0];
+      r[1] = c[1] ^ k[0];
+
+      for (i = 1; i < rounds; i++)
+        {
+          r[0] = aes_dec_round_x2 (r[0], k[i]);
+          r[1] = aes_dec_round_x2 (r[1], k[i]);
+        }
+
+      r[0] = aes_dec_last_round_x2 (r[0], k[i]);
+      r[1] = aes_dec_last_round_x2 (r[1], k[i]);
+
+      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
+      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
+      f = c[1];
+
+      n_blocks -= 4;
+      src += 2;
+      dst += 2;
+    }
+  else if (n_blocks >= 2)
+    {
+      c[0] = src[0];
+      r[0] = c[0] ^ k[0];
+
+      for (i = 1; i < rounds; i++)
+        r[0] = aes_dec_round_x2 (r[0], k[i]);
+
+      r[0] = aes_dec_last_round_x2 (r[0], k[i]);
+      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
+      f = c[0];
+
+      n_blocks -= 2;
+      src += 1;
+      dst += 1;
+    }
+
+  if (n_blocks > 0)
+    {
+      u8x16 rl = *(u8x16u *) src ^ u8x32_extract_lo (k[0]);
+      for (i = 1; i < rounds; i++)
+        rl = aes_dec_round (rl, u8x32_extract_lo (k[i]));
+      rl = aes_dec_last_round (rl, u8x32_extract_lo (k[i]));
+      *(u8x16 *) dst = rl ^ u8x32_extract_hi (f);
+    }
+}
 #endif
 #endif
 
-#ifdef __VAES__
-#define N 16
-#define u32xN u32x16
-#define u32xN_min_scalar u32x16_min_scalar
-#define u32xN_is_all_zero u32x16_is_all_zero
-#define u32xN_splat u32x16_splat
-#else
-#define N 4
-#define u32xN u32x4
-#define u32xN_min_scalar u32x4_min_scalar
-#define u32xN_is_all_zero u32x4_is_all_zero
-#define u32xN_splat u32x4_splat
-#endif
-
 static_always_inline u32
 aes_ops_enc_aes_cbc (vlib_main_t * vm, vnet_crypto_op_t * ops[],
                      u32 n_ops, aes_key_size_t ks)
@@ -242,14 +467,8 @@ aes_ops_enc_aes_cbc (vlib_main_t * vm, vnet_crypto_op_t * ops[],
   vnet_crypto_key_index_t key_index[N];
   u8 *src[N] = { };
   u8 *dst[N] = { };
-#if __VAES__
-  u8x64 r[N / 4] = { };
-  u8x64 k[15][N / 4] = { };
-  u8x16 *kq, *rq = (u8x16 *) r;
-#else
-  u8x16 r[N] = { };
-  u8x16 k[15][N] = { };
-#endif
+  u8xN r[4] = {};
+  u8xN k[15][4] = {};
 
   for (i = 0; i < N; i++)
     key_index[i] = ~0;
@@ -268,11 +487,7 @@ more:
           else
             {
               u8x16 t = aes_block_load (ops[0]->iv);
-#if __VAES__
-              rq[i] = t;
-#else
-              r[i] = t;
-#endif
+              ((u8x16 *) r)[i] = t;
 
               src[i] = ops[0]->src;
               dst[i] = ops[0]->dst;
@@ -284,14 +499,7 @@ more:
                   key_index[i] = ops[0]->key_index;
                   kd = (aes_cbc_key_data_t *) cm->key_data[key_index[i]];
                   for (j = 0; j < rounds + 1; j++)
-                    {
-#if __VAES__
-                      kq = (u8x16 *) k[j];
-                      kq[i] = kd->encrypt_key[j];
-#else
-                      k[j][i] = kd->encrypt_key[j];
-#endif
-                    }
+                    ((u8x16 *) k[j])[i] = kd->encrypt_key[j];
                 }
               ops[0]->status = VNET_CRYPTO_OP_STATUS_COMPLETED;
               n_left--;
@@ -305,7 +513,7 @@ more:
 
   for (i = 0; i < count; i += 16)
     {
-#ifdef __VAES__
+#if defined(__VAES__) && defined(__AVX512F__)
      r[0] = u8x64_xor3 (r[0], aes_block_load_x4 (src, i), k[0][0]);
      r[1] = u8x64_xor3 (r[1], aes_block_load_x4 (src + 4, i), k[0][1]);
      r[2] = u8x64_xor3 (r[2], aes_block_load_x4 (src + 8, i), k[0][2]);
@@ -327,6 +535,28 @@ more:
      aes_block_store_x4 (dst + 4, i, r[1]);
      aes_block_store_x4 (dst + 8, i, r[2]);
      aes_block_store_x4 (dst + 12, i, r[3]);
+#elif defined(__VAES__)
+      r[0] = u8x32_xor3 (r[0], aes_block_load_x2 (src, i), k[0][0]);
+      r[1] = u8x32_xor3 (r[1], aes_block_load_x2 (src + 2, i), k[0][1]);
+      r[2] = u8x32_xor3 (r[2], aes_block_load_x2 (src + 4, i), k[0][2]);
+      r[3] = u8x32_xor3 (r[3], aes_block_load_x2 (src + 6, i), k[0][3]);
+
+      for (j = 1; j < rounds; j++)
+        {
+          r[0] = aes_enc_round_x2 (r[0], k[j][0]);
+          r[1] = aes_enc_round_x2 (r[1], k[j][1]);
+          r[2] = aes_enc_round_x2 (r[2], k[j][2]);
+          r[3] = aes_enc_round_x2 (r[3], k[j][3]);
+        }
+      r[0] = aes_enc_last_round_x2 (r[0], k[j][0]);
+      r[1] = aes_enc_last_round_x2 (r[1], k[j][1]);
+      r[2] = aes_enc_last_round_x2 (r[2], k[j][2]);
+      r[3] = aes_enc_last_round_x2 (r[3], k[j][3]);
+
+      aes_block_store_x2 (dst, i, r[0]);
+      aes_block_store_x2 (dst + 2, i, r[1]);
+      aes_block_store_x2 (dst + 4, i, r[2]);
+      aes_block_store_x2 (dst + 6, i, r[3]);
 #else
 #if __x86_64__
       r[0] = u8x16_xor3 (r[0], aes_block_load (src[0] + i), k[0][0]);
@@ -406,8 +636,11 @@ aes_ops_dec_aes_cbc (vlib_main_t * vm, vnet_crypto_op_t * ops[],
   ASSERT (n_ops >= 1);
 
 decrypt:
-#ifdef __VAES__
-  vaes_cbc_dec (kd->decrypt_key, (u8x64u *) op->src, (u8x64u *) op->dst,
+#if defined(__VAES__) && defined(__AVX512F__)
+  aes4_cbc_dec (kd->decrypt_key, (u8x64u *) op->src, (u8x64u *) op->dst,
+                (u8x16u *) op->iv, op->len, rounds);
+#elif defined(__VAES__)
+  aes2_cbc_dec (kd->decrypt_key, (u8x32u *) op->src, (u8x32u *) op->dst,
                 (u8x16u *) op->iv, op->len, rounds);
 #else
   aes_cbc_dec (kd->decrypt_key, (u8x16u *) op->src, (u8x16u *) op->dst,
@@ -435,8 +668,10 @@ aes_cbc_key_exp (vnet_crypto_key_t * key, aes_key_size_t ks)
   aes_key_enc_to_dec (e, d, ks);
   for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
     {
-#if __VAES__
-      kd->decrypt_key[i] = (u8x64) _mm512_broadcast_i64x2 ((__m128i) d[i]);
+#if defined(__VAES__) && defined(__AVX512F__)
+      kd->decrypt_key[i] = u8x64_splat_u8x16 (d[i]);
+#elif defined(__VAES__)
+      kd->decrypt_key[i] = u8x32_splat_u8x16 (d[i]);
 #else
       kd->decrypt_key[i] = d[i];
 #endif
@@ -463,8 +698,10 @@ foreach_aes_cbc_handler_type;
 #include <fcntl.h>
 
 clib_error_t *
-#ifdef __VAES__
-crypto_native_aes_cbc_init_icl (vlib_main_t * vm)
+#if defined(__VAES__) && defined(__AVX512F__)
+crypto_native_aes_cbc_init_icl (vlib_main_t *vm)
+#elif defined(__VAES__)
+crypto_native_aes_cbc_init_adl (vlib_main_t *vm)
 #elif __AVX512F__
 crypto_native_aes_cbc_init_skx (vlib_main_t * vm)
 #elif __aarch64__
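For reference, the recurrence that both new decrypt paths vectorize is plain CBC decryption: each plaintext block is the block decryption of its ciphertext XORed with the previous ciphertext block (the IV for the first block). The aes4_cbc_dec_permute / aes2_cbc_dec_permute helpers in the patch only build that "previous ciphertext" vector, shifting the concatenation of two wide registers back by one 16-byte block. A minimal scalar sketch of the same recurrence follows; block_decrypt_fn is a hypothetical stand-in for the aes_dec_round / aes_dec_last_round sequence above, not part of the patch:

    #include <stdint.h>
    #include <string.h>

    /* Hypothetical single-block decryptor standing in for the AES round
       sequence used in the patch. */
    typedef void (*block_decrypt_fn) (const uint8_t in[16], uint8_t out[16],
                                      const void *key);

    /* Reference CBC decryption: P[i] = D(C[i]) ^ C[i-1], with C[-1] = IV.
       The vector code above computes 4 (AVX512) or 2 (AVX2) such blocks per
       register and lets aes4/aes2_cbc_dec_permute supply the shifted
       "previous ciphertext" vector. */
    static void
    cbc_dec_ref (block_decrypt_fn dec, const void *key, const uint8_t *iv,
                 const uint8_t *src, uint8_t *dst, int n_blocks)
    {
      uint8_t prev[16], tmp[16];
      memcpy (prev, iv, 16);
      for (int i = 0; i < n_blocks; i++)
        {
          dec (src + 16 * i, tmp, key);           /* D(C[i]) */
          for (int j = 0; j < 16; j++)
            dst[16 * i + j] = tmp[j] ^ prev[j];   /* XOR previous ciphertext */
          memcpy (prev, src + 16 * i, 16);        /* chain to this ciphertext */
        }
    }

The n_blocks >= 12 / 8 / 4 (AVX512) and >= 6 / 4 / 2 (AVX2) ladders in the diff are this loop unrolled to keep several wide registers in flight, with a masked or single-block tail handling whatever remains.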