Diffstat (limited to 'src/vppinfra/crypto/aes_gcm.h')
-rw-r--r-- | src/vppinfra/crypto/aes_gcm.h | 975
1 file changed, 975 insertions, 0 deletions
diff --git a/src/vppinfra/crypto/aes_gcm.h b/src/vppinfra/crypto/aes_gcm.h
new file mode 100644
index 00000000000..8a5f76c3b33
--- /dev/null
+++ b/src/vppinfra/crypto/aes_gcm.h
@@ -0,0 +1,975 @@
+/* SPDX-License-Identifier: Apache-2.0
+ * Copyright(c) 2023 Cisco Systems, Inc.
+ */
+
+#ifndef __crypto_aes_gcm_h__
+#define __crypto_aes_gcm_h__
+
+#include <vppinfra/clib.h>
+#include <vppinfra/vector.h>
+#include <vppinfra/cache.h>
+#include <vppinfra/string.h>
+#include <vppinfra/crypto/aes.h>
+#include <vppinfra/crypto/ghash.h>
+
+#define NUM_HI 36
+#if defined(__VAES__) && defined(__AVX512F__)
+typedef u8x64 aes_data_t;
+typedef u8x64u aes_ghash_t;
+typedef u8x64u aes_mem_t;
+typedef u32x16 aes_gcm_counter_t;
+#define N 64
+#define aes_gcm_load_partial(p, n) u8x64_load_partial ((u8 *) (p), n)
+#define aes_gcm_store_partial(v, p, n) u8x64_store_partial (v, (u8 *) (p), n)
+#define aes_gcm_splat(v) u8x64_splat (v)
+#define aes_gcm_reflect(r) u8x64_reflect_u8x16 (r)
+#define aes_gcm_ghash_reduce(c) ghash4_reduce (&(c)->gd)
+#define aes_gcm_ghash_reduce2(c) ghash4_reduce2 (&(c)->gd)
+#define aes_gcm_ghash_final(c) (c)->T = ghash4_final (&(c)->gd)
+#elif defined(__VAES__)
+typedef u8x32 aes_data_t;
+typedef u8x32u aes_ghash_t;
+typedef u8x32u aes_mem_t;
+typedef u32x8 aes_gcm_counter_t;
+#define N 32
+#define aes_gcm_load_partial(p, n) u8x32_load_partial ((u8 *) (p), n)
+#define aes_gcm_store_partial(v, p, n) u8x32_store_partial (v, (u8 *) (p), n)
+#define aes_gcm_splat(v) u8x32_splat (v)
+#define aes_gcm_reflect(r) u8x32_reflect_u8x16 (r)
+#define aes_gcm_ghash_reduce(c) ghash2_reduce (&(c)->gd)
+#define aes_gcm_ghash_reduce2(c) ghash2_reduce2 (&(c)->gd)
+#define aes_gcm_ghash_final(c) (c)->T = ghash2_final (&(c)->gd)
+#else
+typedef u8x16 aes_data_t;
+typedef u8x16 aes_ghash_t;
+typedef u8x16u aes_mem_t;
+typedef u32x4 aes_gcm_counter_t;
+#define N 16
+#define aes_gcm_load_partial(p, n) u8x16_load_partial ((u8 *) (p), n)
+#define aes_gcm_store_partial(v, p, n) u8x16_store_partial (v, (u8 *) (p), n)
+#define aes_gcm_splat(v) u8x16_splat (v)
+#define aes_gcm_reflect(r) u8x16_reflect (r)
+#define aes_gcm_ghash_reduce(c) ghash_reduce (&(c)->gd)
+#define aes_gcm_ghash_reduce2(c) ghash_reduce2 (&(c)->gd)
+#define aes_gcm_ghash_final(c) (c)->T = ghash_final (&(c)->gd)
+#endif
+#define N_LANES (N / 16)
+
+typedef enum
+{
+  AES_GCM_OP_UNKNONW = 0,
+  AES_GCM_OP_ENCRYPT,
+  AES_GCM_OP_DECRYPT,
+  AES_GCM_OP_GMAC
+} aes_gcm_op_t;
+
+typedef union
+{
+  u8x16 x1;
+  u8x32 x2;
+  u8x64 x4;
+  u8x16 lanes[4];
+} __clib_aligned (64)
+aes_gcm_expaned_key_t;
+
+typedef struct
+{
+  /* pre-calculated hash key values */
+  const u8x16 Hi[NUM_HI];
+  /* extracted AES key */
+  const aes_gcm_expaned_key_t Ke[AES_KEY_ROUNDS (AES_KEY_256) + 1];
+} aes_gcm_key_data_t;
+
+typedef struct
+{
+  aes_gcm_op_t operation;
+  int last;
+  u8 rounds;
+  uword data_bytes;
+  uword aad_bytes;
+
+  u8x16 T;
+
+  /* hash */
+  const u8x16 *Hi;
+  const aes_ghash_t *next_Hi;
+
+  /* expanded keys */
+  const aes_gcm_expaned_key_t *Ke;
+
+  /* counter */
+  u32 counter;
+  u8x16 EY0;
+  aes_gcm_counter_t Y;
+
+  /* ghash */
+  ghash_data_t gd;
+} aes_gcm_ctx_t;
+
+static_always_inline void
+aes_gcm_ghash_mul_first (aes_gcm_ctx_t *ctx, aes_data_t data, u32 n_lanes)
+{
+  uword hash_offset = NUM_HI - n_lanes;
+  ctx->next_Hi = (aes_ghash_t *) (ctx->Hi + hash_offset);
+#if N_LANES == 4
+  u8x64 tag4 = {};
+  tag4 = u8x64_insert_u8x16 (tag4, ctx->T, 0);
+  ghash4_mul_first (&ctx->gd, aes_gcm_reflect (data) ^ tag4, *ctx->next_Hi++);
+#elif N_LANES == 2
+  u8x32 tag2 = {};
+  tag2 = u8x32_insert_lo (tag2, ctx->T);
+  ghash2_mul_first (&ctx->gd, aes_gcm_reflect (data) ^ tag2, *ctx->next_Hi++);
+#else
+  ghash_mul_first (&ctx->gd, aes_gcm_reflect (data) ^ ctx->T, *ctx->next_Hi++);
+#endif
+}
+
+static_always_inline void
+aes_gcm_ghash_mul_next (aes_gcm_ctx_t *ctx, aes_data_t data)
+{
+#if N_LANES == 4
+  ghash4_mul_next (&ctx->gd, aes_gcm_reflect (data), *ctx->next_Hi++);
+#elif N_LANES == 2
+  ghash2_mul_next (&ctx->gd, aes_gcm_reflect (data), *ctx->next_Hi++);
+#else
+  ghash_mul_next (&ctx->gd, aes_gcm_reflect (data), *ctx->next_Hi++);
+#endif
+}
+
+static_always_inline void
+aes_gcm_ghash_mul_bit_len (aes_gcm_ctx_t *ctx)
+{
+  u8x16 r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
+#if N_LANES == 4
+  u8x64 h = u8x64_insert_u8x16 (u8x64_zero (), ctx->Hi[NUM_HI - 1], 0);
+  u8x64 r4 = u8x64_insert_u8x16 (u8x64_zero (), r, 0);
+  ghash4_mul_next (&ctx->gd, r4, h);
+#elif N_LANES == 2
+  u8x32 h = u8x32_insert_lo (u8x32_zero (), ctx->Hi[NUM_HI - 1]);
+  u8x32 r2 = u8x32_insert_lo (u8x32_zero (), r);
+  ghash2_mul_next (&ctx->gd, r2, h);
+#else
+  ghash_mul_next (&ctx->gd, r, ctx->Hi[NUM_HI - 1]);
+#endif
+}
+
+static_always_inline void
+aes_gcm_enc_ctr0_round (aes_gcm_ctx_t *ctx, int aes_round)
+{
+  if (aes_round == 0)
+    ctx->EY0 ^= ctx->Ke[0].x1;
+  else if (aes_round == ctx->rounds)
+    ctx->EY0 = aes_enc_last_round (ctx->EY0, ctx->Ke[aes_round].x1);
+  else
+    ctx->EY0 = aes_enc_round (ctx->EY0, ctx->Ke[aes_round].x1);
+}
+
+static_always_inline void
+aes_gcm_ghash (aes_gcm_ctx_t *ctx, u8 *data, u32 n_left)
+{
+  uword i;
+  aes_data_t r = {};
+  const aes_mem_t *d = (aes_mem_t *) data;
+
+  for (; n_left >= 8 * N; n_left -= 8 * N, d += 8)
+    {
+      if (ctx->operation == AES_GCM_OP_GMAC && n_left == N * 8)
+        {
+          aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_LANES + 1);
+          for (i = 1; i < 8; i++)
+            aes_gcm_ghash_mul_next (ctx, d[i]);
+          aes_gcm_ghash_mul_bit_len (ctx);
+          aes_gcm_ghash_reduce (ctx);
+          aes_gcm_ghash_reduce2 (ctx);
+          aes_gcm_ghash_final (ctx);
+          goto done;
+        }
+
+      aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_LANES);
+      for (i = 1; i < 8; i++)
+        aes_gcm_ghash_mul_next (ctx, d[i]);
+      aes_gcm_ghash_reduce (ctx);
+      aes_gcm_ghash_reduce2 (ctx);
+      aes_gcm_ghash_final (ctx);
+    }
+
+  if (n_left > 0)
+    {
+      int n_lanes = (n_left + 15) / 16;
+
+      if (ctx->operation == AES_GCM_OP_GMAC)
+        n_lanes++;
+
+      if (n_left < N)
+        {
+          clib_memcpy_fast (&r, d, n_left);
+          aes_gcm_ghash_mul_first (ctx, r, n_lanes);
+        }
+      else
+        {
+          aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);
+          n_left -= N;
+          i = 1;
+
+          if (n_left >= 4 * N)
+            {
+              aes_gcm_ghash_mul_next (ctx, d[i]);
+              aes_gcm_ghash_mul_next (ctx, d[i + 1]);
+              aes_gcm_ghash_mul_next (ctx, d[i + 2]);
+              aes_gcm_ghash_mul_next (ctx, d[i + 3]);
+              n_left -= 4 * N;
+              i += 4;
+            }
+          if (n_left >= 2 * N)
+            {
+              aes_gcm_ghash_mul_next (ctx, d[i]);
+              aes_gcm_ghash_mul_next (ctx, d[i + 1]);
+              n_left -= 2 * N;
+              i += 2;
+            }
+
+          if (n_left >= N)
+            {
+              aes_gcm_ghash_mul_next (ctx, d[i]);
+              n_left -= N;
+              i += 1;
+            }
+
+          if (n_left)
+            {
+              clib_memcpy_fast (&r, d + i, n_left);
+              aes_gcm_ghash_mul_next (ctx, r);
+            }
+        }
+
+      if (ctx->operation == AES_GCM_OP_GMAC)
+        aes_gcm_ghash_mul_bit_len (ctx);
+      aes_gcm_ghash_reduce (ctx);
+      aes_gcm_ghash_reduce2 (ctx);
+      aes_gcm_ghash_final (ctx);
+    }
+  else if (ctx->operation == AES_GCM_OP_GMAC)
+    {
+      u8x16 r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
+      ctx->T = ghash_mul (r ^ ctx->T, ctx->Hi[NUM_HI - 1]);
+    }
+
+done:
+  /* encrypt counter 0 E(Y0, k) */
+  if (ctx->operation == AES_GCM_OP_GMAC)
+    for (int i = 0; i < ctx->rounds + 1; i += 1)
+      aes_gcm_enc_ctr0_round (ctx, i);
+}
+
+static_always_inline void
+aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
+{
+  const aes_gcm_expaned_key_t Ke0 = ctx->Ke[0];
+  uword i = 0;
+
+#if N_LANES == 4
+  const u32x16 ctr_inv_4444 = { 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24,
+                                0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24 };
+
+  const u32x16 ctr_4444 = {
+    4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0,
+  };
+
+  /* As counter is stored in network byte order for performance reasons we
+     are incrementing least significant byte only except in case where we
+     overflow. As we are processing four 512-blocks in parallel except the
+     last round, overflow can happen only when n == 4 */
+
+  if (n_blocks == 4)
+    for (; i < 2; i++)
+      {
+        r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
+        ctx->Y += ctr_inv_4444;
+      }
+
+  if (n_blocks == 4 && PREDICT_FALSE ((u8) ctx->counter == 242))
+    {
+      u32x16 Yr = (u32x16) aes_gcm_reflect ((u8x64) ctx->Y);
+
+      for (; i < n_blocks; i++)
+        {
+          r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
+          Yr += ctr_4444;
+          ctx->Y = (u32x16) aes_gcm_reflect ((u8x64) Yr);
+        }
+    }
+  else
+    {
+      for (; i < n_blocks; i++)
+        {
+          r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
+          ctx->Y += ctr_inv_4444;
+        }
+    }
+  ctx->counter += n_blocks * 4;
+#elif N_LANES == 2
+  const u32x8 ctr_inv_22 = { 0, 0, 0, 2 << 24, 0, 0, 0, 2 << 24 };
+  const u32x8 ctr_22 = { 2, 0, 0, 0, 2, 0, 0, 0 };
+
+  /* As counter is stored in network byte order for performance reasons we
+     are incrementing least significant byte only except in case where we
+     overflow. As we are processing four 512-blocks in parallel except the
+     last round, overflow can happen only when n == 4 */
+
+  if (n_blocks == 4)
+    for (; i < 2; i++)
+      {
+        r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
+        ctx->Y += ctr_inv_22;
+      }
+
+  if (n_blocks == 4 && PREDICT_FALSE ((u8) ctx->counter == 250))
+    {
+      u32x8 Yr = (u32x8) aes_gcm_reflect ((u8x32) ctx->Y);
+
+      for (; i < n_blocks; i++)
+        {
+          r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
+          Yr += ctr_22;
+          ctx->Y = (u32x8) aes_gcm_reflect ((u8x32) Yr);
+        }
+    }
+  else
+    {
+      for (; i < n_blocks; i++)
+        {
+          r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
+          ctx->Y += ctr_inv_22;
+        }
+    }
+  ctx->counter += n_blocks * 2;
+#else
+  const u32x4 ctr_inv_1 = { 0, 0, 0, 1 << 24 };
+
+  if (PREDICT_TRUE ((u8) ctx->counter < 0xfe) || n_blocks < 3)
+    {
+      for (; i < n_blocks; i++)
+        {
+          r[i] = Ke0.x1 ^ (u8x16) ctx->Y;
+          ctx->Y += ctr_inv_1;
+        }
+      ctx->counter += n_blocks;
+    }
+  else
+    {
+      r[i++] = Ke0.x1 ^ (u8x16) ctx->Y;
+      ctx->Y += ctr_inv_1;
+      ctx->counter += 1;
+
+      for (; i < n_blocks; i++)
+        {
+          r[i] = Ke0.x1 ^ (u8x16) ctx->Y;
+          ctx->counter++;
+          ctx->Y[3] = clib_host_to_net_u32 (ctx->counter);
+        }
+    }
+#endif
+}
+
+static_always_inline void
+aes_gcm_enc_round (aes_data_t *r, const aes_gcm_expaned_key_t *Ke,
+                   uword n_blocks)
+{
+  for (int i = 0; i < n_blocks; i++)
+#if N_LANES == 4
+    r[i] = aes_enc_round_x4 (r[i], Ke->x4);
+#elif N_LANES == 2
+    r[i] = aes_enc_round_x2 (r[i], Ke->x2);
+#else
+    r[i] = aes_enc_round (r[i], Ke->x1);
+#endif
+}
+
+static_always_inline void
+aes_gcm_enc_last_round (aes_gcm_ctx_t *ctx, aes_data_t *r, aes_data_t *d,
+                        const aes_gcm_expaned_key_t *Ke, uword n_blocks)
+{
+  /* additional rounds for AES-192 and AES-256 */
+  for (int i = 10; i < ctx->rounds; i++)
+    aes_gcm_enc_round (r, Ke + i, n_blocks);
+
+  for (int i = 0; i < n_blocks; i++)
+#if N_LANES == 4
+    d[i] ^= aes_enc_last_round_x4 (r[i], Ke[ctx->rounds].x4);
+#elif N_LANES == 2
+    d[i] ^= aes_enc_last_round_x2 (r[i], Ke[ctx->rounds].x2);
+#else
+    d[i] ^= aes_enc_last_round (r[i], Ke[ctx->rounds].x1);
+#endif
+}
+
+static_always_inline void
+aes_gcm_calc (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst, u32 n,
+              u32 n_bytes, int with_ghash)
+{
+  const aes_gcm_expaned_key_t *k = ctx->Ke;
+  const aes_mem_t *sv = (aes_mem_t *) src;
+  aes_mem_t *dv = (aes_mem_t *) dst;
+  uword ghash_blocks, gc = 1;
+  aes_data_t r[4];
+  u32 i, n_lanes;
+
+  if (ctx->operation == AES_GCM_OP_ENCRYPT)
+    {
+      ghash_blocks = 4;
+      n_lanes = N_LANES * 4;
+    }
+  else
+    {
+      ghash_blocks = n;
+      n_lanes = n * N_LANES;
+#if N_LANES != 1
+      if (ctx->last)
+        n_lanes = (n_bytes + 15) / 16;
+#endif
+    }
+
+  n_bytes -= (n - 1) * N;
+
+  /* AES rounds 0 and 1 */
+  aes_gcm_enc_first_round (ctx, r, n);
+  aes_gcm_enc_round (r, k + 1, n);
+
+  /* load data - decrypt round */
+  if (ctx->operation == AES_GCM_OP_DECRYPT)
+    {
+      for (i = 0; i < n - ctx->last; i++)
+        d[i] = sv[i];
+
+      if (ctx->last)
+        d[n - 1] = aes_gcm_load_partial ((u8 *) (sv + n - 1), n_bytes);
+    }
+
+  /* GHASH multiply block 0 */
+  if (with_ghash)
+    aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);
+
+  /* AES rounds 2 and 3 */
+  aes_gcm_enc_round (r, k + 2, n);
+  aes_gcm_enc_round (r, k + 3, n);
+
+  /* GHASH multiply block 1 */
+  if (with_ghash && gc++ < ghash_blocks)
+    aes_gcm_ghash_mul_next (ctx, (d[1]));
+
+  /* AES rounds 4 and 5 */
+  aes_gcm_enc_round (r, k + 4, n);
+  aes_gcm_enc_round (r, k + 5, n);
+
+  /* GHASH multiply block 2 */
+  if (with_ghash && gc++ < ghash_blocks)
+    aes_gcm_ghash_mul_next (ctx, (d[2]));
+
+  /* AES rounds 6 and 7 */
+  aes_gcm_enc_round (r, k + 6, n);
+  aes_gcm_enc_round (r, k + 7, n);
+
+  /* GHASH multiply block 3 */
+  if (with_ghash && gc++ < ghash_blocks)
+    aes_gcm_ghash_mul_next (ctx, (d[3]));
+
+  /* load 4 blocks of data - encrypt round */
+  if (ctx->operation == AES_GCM_OP_ENCRYPT)
+    {
+      for (i = 0; i < n - ctx->last; i++)
+        d[i] = sv[i];
+
+      if (ctx->last)
+        d[n - 1] = aes_gcm_load_partial (sv + n - 1, n_bytes);
+    }
+
+  /* AES rounds 8 and 9 */
+  aes_gcm_enc_round (r, k + 8, n);
+  aes_gcm_enc_round (r, k + 9, n);
+
+  /* AES last round(s) */
+  aes_gcm_enc_last_round (ctx, r, d, k, n);
+
+  /* store data */
+  for (i = 0; i < n - ctx->last; i++)
+    dv[i] = d[i];
+
+  if (ctx->last)
+    aes_gcm_store_partial (d[n - 1], dv + n - 1, n_bytes);
+
+  /* GHASH reduce 1st step */
+  aes_gcm_ghash_reduce (ctx);
+
+  /* GHASH reduce 2nd step */
+  if (with_ghash)
+    aes_gcm_ghash_reduce2 (ctx);
+
+  /* GHASH final step */
+  if (with_ghash)
+    aes_gcm_ghash_final (ctx);
+}
+
+static_always_inline void
+aes_gcm_calc_double (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst,
+                     int with_ghash)
+{
+  const aes_gcm_expaned_key_t *k = ctx->Ke;
+  const aes_mem_t *sv = (aes_mem_t *) src;
+  aes_mem_t *dv = (aes_mem_t *) dst;
+  aes_data_t r[4];
+
+  /* AES rounds 0 and 1 */
+  aes_gcm_enc_first_round (ctx, r, 4);
+  aes_gcm_enc_round (r, k + 1, 4);
+
+  /* load 4 blocks of data - decrypt round */
+  if (ctx->operation == AES_GCM_OP_DECRYPT)
+    for (int i = 0; i < 4; i++)
+      d[i] = sv[i];
+
+  /* GHASH multiply block 0 */
+  aes_gcm_ghash_mul_first (ctx, d[0], N_LANES * 8);
+
+  /* AES rounds 2 and 3 */
+  aes_gcm_enc_round (r, k + 2, 4);
+  aes_gcm_enc_round (r, k + 3, 4);
+
+  /* GHASH multiply block 1 */
+  aes_gcm_ghash_mul_next (ctx, (d[1]));
+
+  /* AES rounds 4 and 5 */
+  aes_gcm_enc_round (r, k + 4, 4);
+  aes_gcm_enc_round (r, k + 5, 4);
+
+  /* GHASH multiply block 2 */
+  aes_gcm_ghash_mul_next (ctx, (d[2]));
+
+  /* AES rounds 6 and 7 */
+  aes_gcm_enc_round (r, k + 6, 4);
+  aes_gcm_enc_round (r, k + 7, 4);
+
+  /* GHASH multiply block 3 */
+  aes_gcm_ghash_mul_next (ctx, (d[3]));
+
+  /* AES rounds 8 and 9 */
+  aes_gcm_enc_round (r, k + 8, 4);
+  aes_gcm_enc_round (r, k + 9, 4);
+
+  /* load 4 blocks of data - encrypt round */
+  if (ctx->operation == AES_GCM_OP_ENCRYPT)
+    for (int i = 0; i < 4; i++)
+      d[i] = sv[i];
+
+  /* AES last round(s) */
+  aes_gcm_enc_last_round (ctx, r, d, k, 4);
+
+  /* store 4 blocks of data */
+  for (int i = 0; i < 4; i++)
+    dv[i] = d[i];
+
+  /* load next 4 blocks of data - decrypt round */
+  if (ctx->operation == AES_GCM_OP_DECRYPT)
+    for (int i = 0; i < 4; i++)
+      d[i] = sv[i + 4];
+
+  /* GHASH multiply block 4 */
+  aes_gcm_ghash_mul_next (ctx, (d[0]));
+
+  /* AES rounds 0 and 1 */
+  aes_gcm_enc_first_round (ctx, r, 4);
+  aes_gcm_enc_round (r, k + 1, 4);
+
+  /* GHASH multiply block 5 */
+  aes_gcm_ghash_mul_next (ctx, (d[1]));
+
+  /* AES rounds 2 and 3 */
+  aes_gcm_enc_round (r, k + 2, 4);
+  aes_gcm_enc_round (r, k + 3, 4);
+
+  /* GHASH multiply block 6 */
+  aes_gcm_ghash_mul_next (ctx, (d[2]));
+
+  /* AES rounds 4 and 5 */
+  aes_gcm_enc_round (r, k + 4, 4);
+  aes_gcm_enc_round (r, k + 5, 4);
+
+  /* GHASH multiply block 7 */
+  aes_gcm_ghash_mul_next (ctx, (d[3]));
+
+  /* AES rounds 6 and 7 */
+  aes_gcm_enc_round (r, k + 6, 4);
+  aes_gcm_enc_round (r, k + 7, 4);
+
+  /* GHASH reduce 1st step */
+  aes_gcm_ghash_reduce (ctx);
+
+  /* AES rounds 8 and 9 */
+  aes_gcm_enc_round (r, k + 8, 4);
+  aes_gcm_enc_round (r, k + 9, 4);
+
+  /* GHASH reduce 2nd step */
+  aes_gcm_ghash_reduce2 (ctx);
+
+  /* load 4 blocks of data - encrypt round */
+  if (ctx->operation == AES_GCM_OP_ENCRYPT)
+    for (int i = 0; i < 4; i++)
+      d[i] = sv[i + 4];
+
+  /* AES last round(s) */
+  aes_gcm_enc_last_round (ctx, r, d, k, 4);
+
+  /* store data */
+  for (int i = 0; i < 4; i++)
+    dv[i + 4] = d[i];
+
+  /* GHASH final step */
+  aes_gcm_ghash_final (ctx);
+}
+
+static_always_inline void
+aes_gcm_mask_bytes (aes_data_t *d, uword n_bytes)
+{
+  const union
+  {
+    u8 b[64];
+    aes_data_t r;
+  } scale = {
+    .b = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+           16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+           32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+           48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63 },
+  };
+
+  d[0] &= (aes_gcm_splat (n_bytes) > scale.r);
+}
+
+static_always_inline void
+aes_gcm_calc_last (aes_gcm_ctx_t *ctx, aes_data_t *d, int n_blocks,
+                   u32 n_bytes)
+{
+  int n_lanes = (N_LANES == 1 ? n_blocks : (n_bytes + 15) / 16) + 1;
+  n_bytes -= (n_blocks - 1) * N;
+  int i;
+
+  aes_gcm_enc_ctr0_round (ctx, 0);
+  aes_gcm_enc_ctr0_round (ctx, 1);
+
+  if (n_bytes != N)
+    aes_gcm_mask_bytes (d + n_blocks - 1, n_bytes);
+
+  aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);
+
+  aes_gcm_enc_ctr0_round (ctx, 2);
+  aes_gcm_enc_ctr0_round (ctx, 3);
+
+  if (n_blocks > 1)
+    aes_gcm_ghash_mul_next (ctx, d[1]);
+
+  aes_gcm_enc_ctr0_round (ctx, 4);
+  aes_gcm_enc_ctr0_round (ctx, 5);
+
+  if (n_blocks > 2)
+    aes_gcm_ghash_mul_next (ctx, d[2]);
+
+  aes_gcm_enc_ctr0_round (ctx, 6);
+  aes_gcm_enc_ctr0_round (ctx, 7);
+
+  if (n_blocks > 3)
+    aes_gcm_ghash_mul_next (ctx, d[3]);
+
+  aes_gcm_enc_ctr0_round (ctx, 8);
+  aes_gcm_enc_ctr0_round (ctx, 9);
+
+  aes_gcm_ghash_mul_bit_len (ctx);
+  aes_gcm_ghash_reduce (ctx);
+
+  for (i = 10; i < ctx->rounds; i++)
+    aes_gcm_enc_ctr0_round (ctx, i);
+
+  aes_gcm_ghash_reduce2 (ctx);
+
+  aes_gcm_ghash_final (ctx);
+
+  aes_gcm_enc_ctr0_round (ctx, i);
+}
+
+static_always_inline void
+aes_gcm_enc (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, u32 n_left)
+{
+  aes_data_t d[4];
+
+  if (PREDICT_FALSE (n_left == 0))
+    {
+      int i;
+      for (i = 0; i < ctx->rounds + 1; i++)
+        aes_gcm_enc_ctr0_round (ctx, i);
+      return;
+    }
+
+  if (n_left < 4 * N)
+    {
+      ctx->last = 1;
+      if (n_left > 3 * N)
+        {
+          aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 0);
+          aes_gcm_calc_last (ctx, d, 4, n_left);
+        }
+      else if (n_left > 2 * N)
+        {
+          aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 0);
+          aes_gcm_calc_last (ctx, d, 3, n_left);
+        }
+      else if (n_left > N)
+        {
+          aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 0);
+          aes_gcm_calc_last (ctx, d, 2, n_left);
+        }
+      else
+        {
+          aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 0);
+          aes_gcm_calc_last (ctx, d, 1, n_left);
+        }
+      return;
+    }
+  aes_gcm_calc (ctx, d, src, dst, 4, 4 * N, /* with_ghash */ 0);
+
+  /* next */
+  n_left -= 4 * N;
+  dst += 4 * N;
+  src += 4 * N;
+
+  for (; n_left >= 8 * N; n_left -= 8 * N, src += 8 * N, dst += 8 * N)
+    aes_gcm_calc_double (ctx, d, src, dst, /* with_ghash */ 1);
+
+  if (n_left >= 4 * N)
+    {
+      aes_gcm_calc (ctx, d, src, dst, 4, 4 * N, /* with_ghash */ 1);
+
+      /* next */
+      n_left -= 4 * N;
+      dst += 4 * N;
+      src += 4 * N;
+    }
+
+  if (n_left == 0)
+    {
+      aes_gcm_calc_last (ctx, d, 4, 4 * N);
+      return;
+    }
+
+  ctx->last = 1;
+
+  if (n_left > 3 * N)
+    {
+      aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
+      aes_gcm_calc_last (ctx, d, 4, n_left);
+    }
+  else if (n_left > 2 * N)
+    {
+      aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
+      aes_gcm_calc_last (ctx, d, 3, n_left);
+    }
+  else if (n_left > N)
+    {
+      aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
+      aes_gcm_calc_last (ctx, d, 2, n_left);
+    }
+  else
+    {
+      aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
+      aes_gcm_calc_last (ctx, d, 1, n_left);
+    }
+}
+
+static_always_inline void
+aes_gcm_dec (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, uword n_left)
+{
+  aes_data_t d[4] = {};
+  for (; n_left >= 8 * N; n_left -= 8 * N, dst += 8 * N, src += 8 * N)
+    aes_gcm_calc_double (ctx, d, src, dst, /* with_ghash */ 1);
+
+  if (n_left >= 4 * N)
+    {
+      aes_gcm_calc (ctx, d, src, dst, 4, 4 * N, /* with_ghash */ 1);
+
+      /* next */
+      n_left -= 4 * N;
+      dst += N * 4;
+      src += N * 4;
+    }
+
+  if (n_left == 0)
+    goto done;
+
+  ctx->last = 1;
+
+  if (n_left > 3 * N)
+    aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
+  else if (n_left > 2 * N)
+    aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
+  else if (n_left > N)
+    aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
+  else
+    aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
+
+  u8x16 r;
+done:
+  r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
+  ctx->T = ghash_mul (r ^ ctx->T, ctx->Hi[NUM_HI - 1]);
+
+  /* encrypt counter 0 E(Y0, k) */
+  for (int i = 0; i < ctx->rounds + 1; i += 1)
+    aes_gcm_enc_ctr0_round (ctx, i);
+}
+
+static_always_inline int
+aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag,
+         u32 data_bytes, u32 aad_bytes, u8 tag_len,
+         const aes_gcm_key_data_t *kd, int aes_rounds, aes_gcm_op_t op)
+{
+  u8 *addt = (u8 *) aad;
+  u32x4 Y0;
+
+  aes_gcm_ctx_t _ctx = { .counter = 2,
+                         .rounds = aes_rounds,
+                         .operation = op,
+                         .data_bytes = data_bytes,
+                         .aad_bytes = aad_bytes,
+                         .Hi = kd->Hi },
+                *ctx = &_ctx;
+
+  /* initialize counter */
+  Y0 = (u32x4) (u64x2){ *(u64u *) ivp, 0 };
+  Y0[2] = *(u32u *) (ivp + 8);
+  Y0[3] = 1 << 24;
+  ctx->EY0 = (u8x16) Y0;
+  ctx->Ke = kd->Ke;
+#if N_LANES == 4
+  ctx->Y = u32x16_splat_u32x4 (Y0) + (u32x16){
+    0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24, 0, 0, 0, 3 << 24, 0, 0, 0, 4 << 24,
+  };
+#elif N_LANES == 2
+  ctx->Y =
+    u32x8_splat_u32x4 (Y0) + (u32x8){ 0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24 };
+#else
+  ctx->Y = Y0 + (u32x4){ 0, 0, 0, 1 << 24 };
+#endif
+
+  /* calculate ghash for AAD */
+  aes_gcm_ghash (ctx, addt, aad_bytes);
+
+  clib_prefetch_load (tag);
+
+  /* ghash and encrypt/decrypt */
+  if (op == AES_GCM_OP_ENCRYPT)
+    aes_gcm_enc (ctx, src, dst, data_bytes);
+  else if (op == AES_GCM_OP_DECRYPT)
+    aes_gcm_dec (ctx, src, dst, data_bytes);
+
+  /* final tag */
+  ctx->T = u8x16_reflect (ctx->T) ^ ctx->EY0;
+
+  /* tag_len 16 -> 0 */
+  tag_len &= 0xf;
+
+  if (op == AES_GCM_OP_ENCRYPT || op == AES_GCM_OP_GMAC)
+    {
+      /* store tag */
+      if (tag_len)
+        u8x16_store_partial (ctx->T, tag, tag_len);
+      else
+        ((u8x16u *) tag)[0] = ctx->T;
+    }
+  else
+    {
+      /* check tag */
+      if (tag_len)
+        {
+          u16 mask = pow2_mask (tag_len);
+          u8x16 expected = u8x16_load_partial (tag, tag_len);
+          if ((u8x16_msb_mask (expected == ctx->T) & mask) == mask)
+            return 1;
+        }
+      else
+        {
+          if (u8x16_is_equal (ctx->T, *(u8x16u *) tag))
+            return 1;
+        }
+    }
+  return 0;
+}
+
+static_always_inline void
+clib_aes_gcm_key_expand (aes_gcm_key_data_t *kd, const u8 *key,
+                         aes_key_size_t ks)
+{
+  u8x16 H;
+  u8x16 ek[AES_KEY_ROUNDS (AES_KEY_256) + 1];
+  aes_gcm_expaned_key_t *Ke = (aes_gcm_expaned_key_t *) kd->Ke;
+
+  /* expand AES key */
+  aes_key_expand (ek, key, ks);
+  for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
+    Ke[i].lanes[0] = Ke[i].lanes[1] = Ke[i].lanes[2] = Ke[i].lanes[3] = ek[i];
+
+  /* pre-calculate H */
+  H = aes_encrypt_block (u8x16_zero (), ek, ks);
+  H = u8x16_reflect (H);
+  ghash_precompute (H, (u8x16 *) kd->Hi, ARRAY_LEN (kd->Hi));
+}
+
+static_always_inline void
+clib_aes128_gcm_enc (const aes_gcm_key_data_t *kd, const u8 *plaintext,
+                     u32 data_bytes, const u8 *aad, u32 aad_bytes,
+                     const u8 *iv, u32 tag_bytes, u8 *cyphertext, u8 *tag)
+{
+  aes_gcm (plaintext, cyphertext, aad, (u8 *) iv, tag, data_bytes, aad_bytes,
+           tag_bytes, kd, AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_ENCRYPT);
+}
+
+static_always_inline void
+clib_aes256_gcm_enc (const aes_gcm_key_data_t *kd, const u8 *plaintext,
+                     u32 data_bytes, const u8 *aad, u32 aad_bytes,
+                     const u8 *iv, u32 tag_bytes, u8 *cyphertext, u8 *tag)
+{
+  aes_gcm (plaintext, cyphertext, aad, (u8 *) iv, tag, data_bytes, aad_bytes,
+           tag_bytes, kd, AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_ENCRYPT);
+}
+
+static_always_inline int
+clib_aes128_gcm_dec (const aes_gcm_key_data_t *kd, const u8 *cyphertext,
+                     u32 data_bytes, const u8 *aad, u32 aad_bytes,
+                     const u8 *iv, const u8 *tag, u32 tag_bytes, u8 *plaintext)
+{
+  return aes_gcm (cyphertext, plaintext, aad, (u8 *) iv, (u8 *) tag,
+                  data_bytes, aad_bytes, tag_bytes, kd,
+                  AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_DECRYPT);
+}
+
+static_always_inline int
+clib_aes256_gcm_dec (const aes_gcm_key_data_t *kd, const u8 *cyphertext,
+                     u32 data_bytes, const u8 *aad, u32 aad_bytes,
+                     const u8 *iv, const u8 *tag, u32 tag_bytes, u8 *plaintext)
+{
+  return aes_gcm (cyphertext, plaintext, aad, (u8 *) iv, (u8 *) tag,
+                  data_bytes, aad_bytes, tag_bytes, kd,
+                  AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_DECRYPT);
+}
+
+static_always_inline void
+clib_aes128_gmac (const aes_gcm_key_data_t *kd, const u8 *data, u32 data_bytes,
+                  const u8 *iv, u32 tag_bytes, u8 *tag)
+{
+  aes_gcm (0, 0, data, (u8 *) iv, tag, 0, data_bytes, tag_bytes, kd,
+           AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_GMAC);
+}
+
+static_always_inline void
+clib_aes256_gmac (const aes_gcm_key_data_t *kd, const u8 *data, u32 data_bytes,
+                  const u8 *iv, u32 tag_bytes, u8 *tag)
+{
+  aes_gcm (0, 0, data, (u8 *) iv, tag, 0, data_bytes, tag_bytes, kd,
+           AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_GMAC);
+}
+
+#endif /* __crypto_aes_gcm_h__ */
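The functions at the bottom of the new header are the public entry points (clib_aes_gcm_key_expand, clib_aes128/256_gcm_enc, clib_aes128/256_gcm_dec and clib_aes128/256_gmac). The following caller sketch is not part of the commit; it is an illustration based only on the signatures above, and the buffer sizes and 12-byte IV (aes_gcm() reads one u64 plus one u32 from ivp) are assumptions for the example.

/* illustrative usage sketch, assuming the vppinfra headers are available */
#include <vppinfra/crypto/aes_gcm.h>

static int
example_aes128_gcm_roundtrip (void)
{
  aes_gcm_key_data_t kd;
  u8 key[16] = { 0 };  /* 128-bit AES key (example value) */
  u8 iv[12] = { 0 };   /* 96-bit IV, as consumed by aes_gcm() */
  u8 aad[20] = { 0 };  /* additional authenticated data */
  u8 pt[64] = { 0 }, ct[64], dt[64];
  u8 tag[16];

  /* one-time AES key schedule and GHASH key precomputation */
  clib_aes_gcm_key_expand (&kd, key, AES_KEY_128);

  /* encrypt: produces ciphertext and a 16-byte tag */
  clib_aes128_gcm_enc (&kd, pt, sizeof (pt), aad, sizeof (aad), iv,
                       sizeof (tag), ct, tag);

  /* decrypt: returns 1 when the tag verifies, 0 otherwise */
  return clib_aes128_gcm_dec (&kd, ct, sizeof (ct), aad, sizeof (aad), iv,
                              tag, sizeof (tag), dt);
}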