/* SPDX-License-Identifier: Apache-2.0
 * Copyright(c) 2021 Cisco Systems, Inc.
 */

#include <vppinfra/clib.h>
#include <vppinfra/mem.h>
#include <vppinfra/vector/toeplitz.h>

/*
 * Built-in 40-byte key. clib_toeplitz_hash_key_init () falls back to it
 * when the caller passes a null key pointer.
 * NOTE(review): the byte values look like the well-known sample secret key
 * from Microsoft's RSS hash verification documentation - confirm.
 */
static u8 default_key[40] = {
  0x6d, 0x5a, 0x56, 0xda, 0x25, 0x5b, 0x0e, 0xc2, 0x41, 0x67,
  0x25, 0x3d, 0x43, 0xa3, 0x8f, 0xb0, 0xd0, 0xca, 0x2b, 0xcb,
  0xae, 0x7b, 0x30, 0xb4, 0x77, 0xcb, 0x2d, 0xa3, 0x80, 0x30,
  0xf2, 0x0c, 0x6a, 0x42, 0xb7, 0x3b, 0xbe, 0xac, 0x01, 0xfa,
};

#ifdef __x86_64__
/*
 * Build one u64x8 worth of matrix data (eight 8x8 bit matrices for
 * gf2p8affine) from a 16-byte slice of the key and store it to *m.
 *
 * kv: two 64-bit lanes of key material. The shuffle index lists below only
 *     name byte positions 0x0 - 0xb, i.e. the first 12 bytes of kv
 *     (assuming the indices are byte positions within kv4 - confirm against
 *     the u8x64_shuffle definition).
 * m:  destination; exactly one u64x8 is written through it.
 */
static_always_inline void
clib_toeplitz_hash_key_expand_8 (u64x2 kv, u64x8u *m)
{
  /* shift[i] == i: 64-bit element i is later shifted left by i bits */
  u64x8 kv4, a, b, shift = { 0, 1, 2, 3, 4, 5, 6, 7 };

  /* replicate the 16-byte key slice four times to fill the 64-byte vector */
  kv4 = (u64x8){ kv[0], kv[1], kv[0], kv[1], kv[0], kv[1], kv[0], kv[1] };

  /* clang-format off */
  /* create 8 byte-swapped copies of the bytes 0 - 7 */
  a = (u64x8) u8x64_shuffle (kv4,
    0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, 0x0,
    0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, 0x0,
    0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, 0x0,
    0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, 0x0,
    0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, 0x0,
    0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, 0x0,
    0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, 0x0,
    0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, 0x0);
  /* create 8 byte-swapped copies of the bytes 4 - 11 */
  b = (u64x8) u8x64_shuffle (kv4,
    0xb, 0xa, 0x9, 0x8, 0x7, 0x6, 0x5, 0x4,
    0xb, 0xa, 0x9, 0x8, 0x7, 0x6, 0x5, 0x4,
    0xb, 0xa, 0x9, 0x8, 0x7, 0x6, 0x5, 0x4,
    0xb, 0xa, 0x9, 0x8, 0x7, 0x6, 0x5, 0x4,
    0xb, 0xa, 0x9, 0x8, 0x7, 0x6, 0x5, 0x4,
    0xb, 0xa, 0x9, 0x8, 0x7, 0x6, 0x5, 0x4,
    0xb, 0xa, 0x9, 0x8, 0x7, 0x6, 0x5, 0x4,
    0xb, 0xa, 0x9, 0x8, 0x7, 0x6, 0x5, 0x4);
  /* clang-format on */

  /* shift each 64-bit element for 0 - 7 bits */
  a <<= shift;
  b <<= shift;

  /* clang-format off */
  /* construct eight 8x8 bit matrix used by gf2p8affine */
  /*
   * Each row of eight indices steps by 8, i.e. it names the same byte
   * offset (7, 6, 5 or 4) in each of eight consecutive 64-bit elements.
   * The first four rows use indices 0x00 - 0x3f and the last four use
   * 0x40 - 0x7f; presumably these select from a and b respectively -
   * confirm against the u8x64_shuffle2 definition.
   */
  * m = (u64x8) u8x64_shuffle2 (a, b,
    0x07, 0x0f, 0x17, 0x1f, 0x27, 0x2f, 0x37, 0x3f,
    0x06, 0x0e, 0x16, 0x1e, 0x26, 0x2e, 0x36, 0x3e,
    0x05, 0x0d, 0x15, 0x1d, 0x25, 0x2d, 0x35, 0x3d,
    0x04, 0x0c, 0x14, 0x1c, 0x24, 0x2c, 0x34, 0x3c,
    0x47, 0x4f, 0x57, 0x5f, 0x67, 0x6f, 0x77, 0x7f,
    0x46, 0x4e, 0x56, 0x5e, 0x66, 0x6e, 0x76, 0x7e,
    0x45, 0x4d, 0x55, 0x5d, 0x65, 0x6d, 0x75, 0x7d,
    0x44, 0x4c, 0x54, 0x5c, 0x64, 0x6c, 0x74, 0x7c);
  /* clang-format on */
}

/*
 * Expand a key into the table of matrices produced by
 * clib_toeplitz_hash_key_expand_8 ().
 *
 * matrixes: destination table. One u64x8 (8 x u64) entry is written for
 *           every complete 8-byte group of the key, followed by one
 *           trailing entry, i.e. (size / 8) + 1 entries in total.
 * key:      key bytes. While at least 8 bytes remain, 16 bytes are loaded
 *           per step, so the buffer must be readable for 8 bytes past the
 *           last complete 8-byte group. Whatever those bytes contain ends
 *           up in the matrices, so callers normally zero-pad the key.
 * size:     key length in bytes.
 */
void
clib_toeplitz_hash_key_expand (u64 *matrixes, u8 *key, int size)
{
  u64x8u *m = (u64x8u *) matrixes;
  u64x2 kv = {}, zero = {};

  if (size > 0 && size < 8)
    {
      /*
       * A key shorter than one 8-byte group is never loaded by the loop
       * below, which used to leave kv zero and produce an all-zero matrix
       * that ignored the key. Gather the available bytes (zero-extended)
       * into lane 1, which is the lane the trailing step consumes.
       * Byte i goes to bits 8*i .. 8*i+7, matching the in-memory layout of
       * the vector loads used below (this code is x86_64 only).
       */
      u64 head = 0;
      int i;

      for (i = 0; i < size; i++)
        head |= (u64) key[i] << (8 * i);

      kv = (u64x2){ 0, head };
    }

  while (size >= 8)
    {
      /* unaligned 16-byte load: this group plus the 8 bytes after it */
      kv = *(u64x2u *) key;
      clib_toeplitz_hash_key_expand_8 (kv, m);
      key += 8;
      m++;
      size -= 8;
    }

  /*
   * Trailing entry: move lane 1 (the 8 bytes following the last complete
   * group, or the short key gathered above) into lane 0 and clear lane 1.
   */
  kv = u64x2_shuffle2 (kv, zero, 1, 2);
  clib_toeplitz_hash_key_expand_8 (kv, m);
}
#endif

/*
 * Allocate and initialise a Toeplitz hash key object.
 *
 * key:    key bytes to copy, or a null pointer to use the built-in
 *         default_key (keylen is then replaced by sizeof (default_key)).
 * keylen: number of bytes at key.
 *
 * Returns a cache-line aligned, zero-filled object holding the key copy
 * and, on x86_64, the expanded matrix table located gfni_offset bytes from
 * the start of the object. Release it with clib_toeplitz_hash_key_free ().
 */
__clib_export clib_toeplitz_hash_key_t *
clib_toeplitz_hash_key_init (u8 *key, u32 keylen)
{
  clib_toeplitz_hash_key_t *k;
  u32 key_area, total;
  u32 matrix_area = 0;

  /* no key supplied: fall back to the built-in one */
  if (!key)
    {
      keylen = sizeof (default_key);
      key = default_key;
    }

  /* header + key rounded to 16 bytes, the whole rounded to a cache line */
  key_area =
    round_pow2 (sizeof (clib_toeplitz_hash_key_t) + round_pow2 (keylen, 16),
                CLIB_CACHE_LINE_BYTES);

#ifdef __x86_64__
  /* room for the matrix table written by clib_toeplitz_hash_key_expand () */
  matrix_area = round_pow2 ((keylen + 1) * 8, CLIB_CACHE_LINE_BYTES);
#endif

  total = key_area + matrix_area;

  k = clib_mem_alloc_aligned (total, CLIB_CACHE_LINE_BYTES);

  /* zero everything so the padding after the key copy reads as zero */
  clib_memset_u8 (k, 0, total);
  clib_memcpy_fast (k->data, key, keylen);
  k->gfni_offset = key_area;
  k->key_length = keylen;

#ifdef __x86_64__
  /* the matrix table starts right after the key area */
  clib_toeplitz_hash_key_expand ((u64 *) ((u8 *) k + k->gfni_offset),
                                 k->data, k->key_length);
#endif

  return k;
}

/*
 * Release a key object by handing it to clib_mem_free ().
 *
 * k: key object to free; presumably one returned by
 *    clib_toeplitz_hash_key_init (), which obtains the whole object
 *    (key copy and matrix table) from a single allocation.
 */
__clib_export void
clib_toeplitz_hash_key_free (clib_toeplitz_hash_key_t *k)
{
  clib_mem_free (k);
}