From 4172448954d7787f511c91eecfb48897e946ed0b Mon Sep 17 00:00:00 2001 From: Damjan Marion Date: Thu, 23 Mar 2023 13:44:01 +0000 Subject: vppinfra: small improvement and polishing of AES GCM code Type: improvement Change-Id: Ie9661792ec68d4ea3c62ee9eb31b455d3b2b0a42 Signed-off-by: Damjan Marion --- src/vppinfra/crypto/aes_gcm.h | 133 ++++++++++++++++++++++++------------------ src/vppinfra/crypto/ghash.h | 36 ++++++------ 2 files changed, 95 insertions(+), 74 deletions(-) diff --git a/src/vppinfra/crypto/aes_gcm.h b/src/vppinfra/crypto/aes_gcm.h index 8a5f76c3b33..3d1b220f7b8 100644 --- a/src/vppinfra/crypto/aes_gcm.h +++ b/src/vppinfra/crypto/aes_gcm.h @@ -103,9 +103,15 @@ typedef struct aes_gcm_counter_t Y; /* ghash */ - ghash_data_t gd; + ghash_ctx_t gd; } aes_gcm_ctx_t; +static_always_inline u8x16 +aes_gcm_final_block (aes_gcm_ctx_t *ctx) +{ + return (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3); +} + static_always_inline void aes_gcm_ghash_mul_first (aes_gcm_ctx_t *ctx, aes_data_t data, u32 n_lanes) { @@ -137,19 +143,18 @@ aes_gcm_ghash_mul_next (aes_gcm_ctx_t *ctx, aes_data_t data) } static_always_inline void -aes_gcm_ghash_mul_bit_len (aes_gcm_ctx_t *ctx) +aes_gcm_ghash_mul_final_block (aes_gcm_ctx_t *ctx) { - u8x16 r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3); #if N_LANES == 4 u8x64 h = u8x64_insert_u8x16 (u8x64_zero (), ctx->Hi[NUM_HI - 1], 0); - u8x64 r4 = u8x64_insert_u8x16 (u8x64_zero (), r, 0); + u8x64 r4 = u8x64_insert_u8x16 (u8x64_zero (), aes_gcm_final_block (ctx), 0); ghash4_mul_next (&ctx->gd, r4, h); #elif N_LANES == 2 u8x32 h = u8x32_insert_lo (u8x32_zero (), ctx->Hi[NUM_HI - 1]); - u8x32 r2 = u8x32_insert_lo (u8x32_zero (), r); + u8x32 r2 = u8x32_insert_lo (u8x32_zero (), aes_gcm_final_block (ctx)); ghash2_mul_next (&ctx->gd, r2, h); #else - ghash_mul_next (&ctx->gd, r, ctx->Hi[NUM_HI - 1]); + ghash_mul_next (&ctx->gd, aes_gcm_final_block (ctx), ctx->Hi[NUM_HI - 1]); #endif } @@ -178,7 +183,7 @@ aes_gcm_ghash (aes_gcm_ctx_t *ctx, u8 *data, u32 n_left) aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_LANES + 1); for (i = 1; i < 8; i++) aes_gcm_ghash_mul_next (ctx, d[i]); - aes_gcm_ghash_mul_bit_len (ctx); + aes_gcm_ghash_mul_final_block (ctx); aes_gcm_ghash_reduce (ctx); aes_gcm_ghash_reduce2 (ctx); aes_gcm_ghash_final (ctx); @@ -243,16 +248,14 @@ aes_gcm_ghash (aes_gcm_ctx_t *ctx, u8 *data, u32 n_left) } if (ctx->operation == AES_GCM_OP_GMAC) - aes_gcm_ghash_mul_bit_len (ctx); + aes_gcm_ghash_mul_final_block (ctx); aes_gcm_ghash_reduce (ctx); aes_gcm_ghash_reduce2 (ctx); aes_gcm_ghash_final (ctx); } else if (ctx->operation == AES_GCM_OP_GMAC) - { - u8x16 r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3); - ctx->T = ghash_mul (r ^ ctx->T, ctx->Hi[NUM_HI - 1]); - } + ctx->T = + ghash_mul (aes_gcm_final_block (ctx) ^ ctx->T, ctx->Hi[NUM_HI - 1]); done: /* encrypt counter 0 E(Y0, k) */ @@ -267,6 +270,11 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks) const aes_gcm_expaned_key_t Ke0 = ctx->Ke[0]; uword i = 0; + /* As counter is stored in network byte order for performance reasons we + are incrementing least significant byte only except in case where we + overlow. As we are processing four 128, 256 or 512-blocks in parallel + except the last round, overflow can happen only when n_blocks == 4 */ + #if N_LANES == 4 const u32x16 ctr_inv_4444 = { 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24 }; @@ -275,15 +283,10 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks) 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, }; - /* As counter is stored in network byte order for performance reasons we - are incrementing least significant byte only except in case where we - overlow. As we are processing four 512-blocks in parallel except the - last round, overflow can happen only when n == 4 */ - if (n_blocks == 4) for (; i < 2; i++) { - r[i] = Ke0.x4 ^ (u8x64) ctx->Y; + r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */ ctx->Y += ctr_inv_4444; } @@ -293,7 +296,7 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks) for (; i < n_blocks; i++) { - r[i] = Ke0.x4 ^ (u8x64) ctx->Y; + r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */ Yr += ctr_4444; ctx->Y = (u32x16) aes_gcm_reflect ((u8x64) Yr); } @@ -302,7 +305,7 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks) { for (; i < n_blocks; i++) { - r[i] = Ke0.x4 ^ (u8x64) ctx->Y; + r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */ ctx->Y += ctr_inv_4444; } } @@ -311,15 +314,10 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks) const u32x8 ctr_inv_22 = { 0, 0, 0, 2 << 24, 0, 0, 0, 2 << 24 }; const u32x8 ctr_22 = { 2, 0, 0, 0, 2, 0, 0, 0 }; - /* As counter is stored in network byte order for performance reasons we - are incrementing least significant byte only except in case where we - overlow. As we are processing four 512-blocks in parallel except the - last round, overflow can happen only when n == 4 */ - if (n_blocks == 4) for (; i < 2; i++) { - r[i] = Ke0.x2 ^ (u8x32) ctx->Y; + r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */ ctx->Y += ctr_inv_22; } @@ -329,7 +327,7 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks) for (; i < n_blocks; i++) { - r[i] = Ke0.x2 ^ (u8x32) ctx->Y; + r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */ Yr += ctr_22; ctx->Y = (u32x8) aes_gcm_reflect ((u8x32) Yr); } @@ -338,7 +336,7 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks) { for (; i < n_blocks; i++) { - r[i] = Ke0.x2 ^ (u8x32) ctx->Y; + r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */ ctx->Y += ctr_inv_22; } } @@ -350,20 +348,20 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks) { for (; i < n_blocks; i++) { - r[i] = Ke0.x1 ^ (u8x16) ctx->Y; + r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */ ctx->Y += ctr_inv_1; } ctx->counter += n_blocks; } else { - r[i++] = Ke0.x1 ^ (u8x16) ctx->Y; + r[i++] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */ ctx->Y += ctr_inv_1; ctx->counter += 1; for (; i < n_blocks; i++) { - r[i] = Ke0.x1 ^ (u8x16) ctx->Y; + r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */ ctx->counter++; ctx->Y[3] = clib_host_to_net_u32 (ctx->counter); } @@ -510,8 +508,7 @@ aes_gcm_calc (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst, u32 n, } static_always_inline void -aes_gcm_calc_double (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst, - int with_ghash) +aes_gcm_calc_double (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst) { const aes_gcm_expaned_key_t *k = ctx->Ke; const aes_mem_t *sv = (aes_mem_t *) src; @@ -680,7 +677,7 @@ aes_gcm_calc_last (aes_gcm_ctx_t *ctx, aes_data_t *d, int n_blocks, aes_gcm_enc_ctr0_round (ctx, 8); aes_gcm_enc_ctr0_round (ctx, 9); - aes_gcm_ghash_mul_bit_len (ctx); + aes_gcm_ghash_mul_final_block (ctx); aes_gcm_ghash_reduce (ctx); for (i = 10; i < ctx->rounds; i++) @@ -731,6 +728,7 @@ aes_gcm_enc (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, u32 n_left) } return; } + aes_gcm_calc (ctx, d, src, dst, 4, 4 * N, /* with_ghash */ 0); /* next */ @@ -739,7 +737,7 @@ aes_gcm_enc (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, u32 n_left) src += 4 * N; for (; n_left >= 8 * N; n_left -= 8 * N, src += 8 * N, dst += 8 * N) - aes_gcm_calc_double (ctx, d, src, dst, /* with_ghash */ 1); + aes_gcm_calc_double (ctx, d, src, dst); if (n_left >= 4 * N) { @@ -785,8 +783,11 @@ static_always_inline void aes_gcm_dec (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, uword n_left) { aes_data_t d[4] = {}; + ghash_ctx_t gd; + + /* main encryption loop */ for (; n_left >= 8 * N; n_left -= 8 * N, dst += 8 * N, src += 8 * N) - aes_gcm_calc_double (ctx, d, src, dst, /* with_ghash */ 1); + aes_gcm_calc_double (ctx, d, src, dst); if (n_left >= 4 * N) { @@ -798,27 +799,48 @@ aes_gcm_dec (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, uword n_left) src += N * 4; } - if (n_left == 0) - goto done; + if (n_left) + { + ctx->last = 1; - ctx->last = 1; + if (n_left > 3 * N) + aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1); + else if (n_left > 2 * N) + aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1); + else if (n_left > N) + aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1); + else + aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1); + } - if (n_left > 3 * N) - aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1); - else if (n_left > 2 * N) - aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1); - else if (n_left > N) - aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1); - else - aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1); + /* interleaved counter 0 encryption E(Y0, k) and ghash of final GCM + * (bit length) block */ - u8x16 r; -done: - r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3); - ctx->T = ghash_mul (r ^ ctx->T, ctx->Hi[NUM_HI - 1]); + aes_gcm_enc_ctr0_round (ctx, 0); + aes_gcm_enc_ctr0_round (ctx, 1); - /* encrypt counter 0 E(Y0, k) */ - for (int i = 0; i < ctx->rounds + 1; i += 1) + ghash_mul_first (&gd, aes_gcm_final_block (ctx) ^ ctx->T, + ctx->Hi[NUM_HI - 1]); + + aes_gcm_enc_ctr0_round (ctx, 2); + aes_gcm_enc_ctr0_round (ctx, 3); + + ghash_reduce (&gd); + + aes_gcm_enc_ctr0_round (ctx, 4); + aes_gcm_enc_ctr0_round (ctx, 5); + + ghash_reduce2 (&gd); + + aes_gcm_enc_ctr0_round (ctx, 6); + aes_gcm_enc_ctr0_round (ctx, 7); + + ctx->T = ghash_final (&gd); + + aes_gcm_enc_ctr0_round (ctx, 8); + aes_gcm_enc_ctr0_round (ctx, 9); + + for (int i = 10; i < ctx->rounds + 1; i += 1) aes_gcm_enc_ctr0_round (ctx, i); } @@ -835,6 +857,7 @@ aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag, .operation = op, .data_bytes = data_bytes, .aad_bytes = aad_bytes, + .Ke = kd->Ke, .Hi = kd->Hi }, *ctx = &_ctx; @@ -843,7 +866,7 @@ aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag, Y0[2] = *(u32u *) (ivp + 8); Y0[3] = 1 << 24; ctx->EY0 = (u8x16) Y0; - ctx->Ke = kd->Ke; + #if N_LANES == 4 ctx->Y = u32x16_splat_u32x4 (Y0) + (u32x16){ 0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24, 0, 0, 0, 3 << 24, 0, 0, 0, 4 << 24, @@ -858,8 +881,6 @@ aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag, /* calculate ghash for AAD */ aes_gcm_ghash (ctx, addt, aad_bytes); - clib_prefetch_load (tag); - /* ghash and encrypt/edcrypt */ if (op == AES_GCM_OP_ENCRYPT) aes_gcm_enc (ctx, src, dst, data_bytes); diff --git a/src/vppinfra/crypto/ghash.h b/src/vppinfra/crypto/ghash.h index bae8badb5fc..66e3f6a673a 100644 --- a/src/vppinfra/crypto/ghash.h +++ b/src/vppinfra/crypto/ghash.h @@ -89,7 +89,7 @@ * u8x16 Hi[4]; * ghash_precompute (H, Hi, 4); * - * ghash_data_t _gd, *gd = &_gd; + * ghash_ctx_t _gd, *gd = &_gd; * ghash_mul_first (gd, GH ^ b0, Hi[3]); * ghash_mul_next (gd, b1, Hi[2]); * ghash_mul_next (gd, b2, Hi[1]); @@ -154,7 +154,7 @@ typedef struct u8x32 hi2, lo2, mid2, tmp_lo2, tmp_hi2; u8x64 hi4, lo4, mid4, tmp_lo4, tmp_hi4; int pending; -} ghash_data_t; +} ghash_ctx_t; static const u8x16 ghash_poly = { 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -167,7 +167,7 @@ static const u8x16 ghash_poly2 = { }; static_always_inline void -ghash_mul_first (ghash_data_t * gd, u8x16 a, u8x16 b) +ghash_mul_first (ghash_ctx_t *gd, u8x16 a, u8x16 b) { /* a1 * b1 */ gd->hi = gmul_hi_hi (a, b); @@ -182,7 +182,7 @@ ghash_mul_first (ghash_data_t * gd, u8x16 a, u8x16 b) } static_always_inline void -ghash_mul_next (ghash_data_t * gd, u8x16 a, u8x16 b) +ghash_mul_next (ghash_ctx_t *gd, u8x16 a, u8x16 b) { /* a1 * b1 */ u8x16 hi = gmul_hi_hi (a, b); @@ -211,7 +211,7 @@ ghash_mul_next (ghash_data_t * gd, u8x16 a, u8x16 b) } static_always_inline void -ghash_reduce (ghash_data_t * gd) +ghash_reduce (ghash_ctx_t *gd) { u8x16 r; @@ -236,14 +236,14 @@ ghash_reduce (ghash_data_t * gd) } static_always_inline void -ghash_reduce2 (ghash_data_t * gd) +ghash_reduce2 (ghash_ctx_t *gd) { gd->tmp_lo = gmul_lo_lo (ghash_poly2, gd->lo); gd->tmp_hi = gmul_lo_hi (ghash_poly2, gd->lo); } static_always_inline u8x16 -ghash_final (ghash_data_t * gd) +ghash_final (ghash_ctx_t *gd) { return u8x16_xor3 (gd->hi, u8x16_word_shift_right (gd->tmp_lo, 4), u8x16_word_shift_left (gd->tmp_hi, 4)); @@ -252,7 +252,7 @@ ghash_final (ghash_data_t * gd) static_always_inline u8x16 ghash_mul (u8x16 a, u8x16 b) { - ghash_data_t _gd, *gd = &_gd; + ghash_ctx_t _gd, *gd = &_gd; ghash_mul_first (gd, a, b); ghash_reduce (gd); ghash_reduce2 (gd); @@ -297,7 +297,7 @@ gmul4_hi_hi (u8x64 a, u8x64 b) } static_always_inline void -ghash4_mul_first (ghash_data_t *gd, u8x64 a, u8x64 b) +ghash4_mul_first (ghash_ctx_t *gd, u8x64 a, u8x64 b) { gd->hi4 = gmul4_hi_hi (a, b); gd->lo4 = gmul4_lo_lo (a, b); @@ -306,7 +306,7 @@ ghash4_mul_first (ghash_data_t *gd, u8x64 a, u8x64 b) } static_always_inline void -ghash4_mul_next (ghash_data_t *gd, u8x64 a, u8x64 b) +ghash4_mul_next (ghash_ctx_t *gd, u8x64 a, u8x64 b) { u8x64 hi = gmul4_hi_hi (a, b); u8x64 lo = gmul4_lo_lo (a, b); @@ -329,7 +329,7 @@ ghash4_mul_next (ghash_data_t *gd, u8x64 a, u8x64 b) } static_always_inline void -ghash4_reduce (ghash_data_t *gd) +ghash4_reduce (ghash_ctx_t *gd) { u8x64 r; @@ -356,14 +356,14 @@ ghash4_reduce (ghash_data_t *gd) } static_always_inline void -ghash4_reduce2 (ghash_data_t *gd) +ghash4_reduce2 (ghash_ctx_t *gd) { gd->tmp_lo4 = gmul4_lo_lo (ghash4_poly2, gd->lo4); gd->tmp_hi4 = gmul4_lo_hi (ghash4_poly2, gd->lo4); } static_always_inline u8x16 -ghash4_final (ghash_data_t *gd) +ghash4_final (ghash_ctx_t *gd) { u8x64 r; u8x32 t; @@ -410,7 +410,7 @@ gmul2_hi_hi (u8x32 a, u8x32 b) } static_always_inline void -ghash2_mul_first (ghash_data_t *gd, u8x32 a, u8x32 b) +ghash2_mul_first (ghash_ctx_t *gd, u8x32 a, u8x32 b) { gd->hi2 = gmul2_hi_hi (a, b); gd->lo2 = gmul2_lo_lo (a, b); @@ -419,7 +419,7 @@ ghash2_mul_first (ghash_data_t *gd, u8x32 a, u8x32 b) } static_always_inline void -ghash2_mul_next (ghash_data_t *gd, u8x32 a, u8x32 b) +ghash2_mul_next (ghash_ctx_t *gd, u8x32 a, u8x32 b) { u8x32 hi = gmul2_hi_hi (a, b); u8x32 lo = gmul2_lo_lo (a, b); @@ -442,7 +442,7 @@ ghash2_mul_next (ghash_data_t *gd, u8x32 a, u8x32 b) } static_always_inline void -ghash2_reduce (ghash_data_t *gd) +ghash2_reduce (ghash_ctx_t *gd) { u8x32 r; @@ -469,14 +469,14 @@ ghash2_reduce (ghash_data_t *gd) } static_always_inline void -ghash2_reduce2 (ghash_data_t *gd) +ghash2_reduce2 (ghash_ctx_t *gd) { gd->tmp_lo2 = gmul2_lo_lo (ghash2_poly2, gd->lo2); gd->tmp_hi2 = gmul2_lo_hi (ghash2_poly2, gd->lo2); } static_always_inline u8x16 -ghash2_final (ghash_data_t *gd) +ghash2_final (ghash_ctx_t *gd) { u8x32 r; -- cgit 1.2.3-korg