diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/vppinfra/vector/count_equal.h | 90 | ||||
-rw-r--r-- | src/vppinfra/vector_avx512.h | 21 |
2 files changed, 98 insertions, 13 deletions
diff --git a/src/vppinfra/vector/count_equal.h b/src/vppinfra/vector/count_equal.h index 98770cff7c0..a2aeecd9ba0 100644 --- a/src/vppinfra/vector/count_equal.h +++ b/src/vppinfra/vector/count_equal.h @@ -67,28 +67,62 @@ clib_count_equal_u32 (u32 *data, uword max_count) count = 0; first = data[0]; -#if defined(CLIB_HAVE_VEC256) +#if defined(CLIB_HAVE_VEC512) + u32x16 splat = u32x16_splat (first); + while (count + 15 < max_count) + { + u32 bmp; + bmp = u32x16_is_equal_mask (u32x16_load_unaligned (data), splat); + if (bmp != pow2_mask (16)) + return count + count_trailing_zeros (~bmp); + + data += 16; + count += 16; + } + if (count == max_count) + return count; + else + { + u32 mask = pow2_mask (max_count - count); + u32 bmp = + u32x16_is_equal_mask (u32x16_mask_load_zero (data, mask), splat); + return count + count_trailing_zeros (~bmp); + } +#elif defined(CLIB_HAVE_VEC256) u32x8 splat = u32x8_splat (first); while (count + 7 < max_count) { - u64 bmp; + u32 bmp; +#ifdef __AVX512F__ + bmp = u32x8_is_equal_mask (u32x8_load_unaligned (data), splat); + if (bmp != pow2_mask (8)) + return count + count_trailing_zeros (~bmp); +#else bmp = u8x32_msb_mask ((u8x32) (u32x8_load_unaligned (data) == splat)); if (bmp != 0xffffffff) - { - count += count_trailing_zeros (~bmp) / 4; - return count; - } + return count + count_trailing_zeros (~bmp) / 4; +#endif data += 8; count += 8; } + if (count == max_count) + return count; +#if defined(CxLIB_HAVE_VEC256_MASK_LOAD_STORE) + else + { + u32 mask = pow2_mask (max_count - count); + u32 bmp = u32x8_is_equal_mask (u32x8_mask_load_zero (data, mask), splat); + return count + count_trailing_zeros (~bmp); + } +#endif #elif defined(CLIB_HAVE_VEC128) && defined(CLIB_HAVE_VEC128_MSB_MASK) u32x4 splat = u32x4_splat (first); while (count + 3 < max_count) { u64 bmp; bmp = u8x16_msb_mask ((u8x16) (u32x4_load_unaligned (data) == splat)); - if (bmp != 0xffff) + if (bmp != pow2_mask (4 * 4)) { count += count_trailing_zeros (~bmp) / 4; return count; @@ -191,18 +225,50 @@ clib_count_equal_u8 (u8 *data, uword max_count) count = 0; first = data[0]; -#if defined(CLIB_HAVE_VEC256) +#if defined(CLIB_HAVE_VEC512) + u8x64 splat = u8x64_splat (first); + while (count + 63 < max_count) + { + u64 bmp; + bmp = u8x64_is_equal_mask (u8x64_load_unaligned (data), splat); + if (bmp != -1) + return count + count_trailing_zeros (~bmp); + + data += 64; + count += 64; + } + if (count == max_count) + return count; +#if defined(CLIB_HAVE_VEC512_MASK_LOAD_STORE) + else + { + u64 mask = pow2_mask (max_count - count); + u64 bmp = u8x64_is_equal_mask (u8x64_mask_load_zero (data, mask), splat); + return count + count_trailing_zeros (~bmp); + } +#endif +#elif defined(CLIB_HAVE_VEC256) u8x32 splat = u8x32_splat (first); while (count + 31 < max_count) { u64 bmp; bmp = u8x32_msb_mask ((u8x32) (u8x32_load_unaligned (data) == splat)); if (bmp != 0xffffffff) - return max_count; + return count + count_trailing_zeros (~bmp); data += 32; count += 32; } + if (count == max_count) + return count; +#if defined(CLIB_HAVE_VEC256_MASK_LOAD_STORE) + else + { + u32 mask = pow2_mask (max_count - count); + u64 bmp = u8x32_msb_mask (u8x32_mask_load_zero (data, mask) == splat); + return count + count_trailing_zeros (~bmp); + } +#endif #elif defined(CLIB_HAVE_VEC128) && defined(CLIB_HAVE_VEC128_MSB_MASK) u8x16 splat = u8x16_splat (first); while (count + 15 < max_count) @@ -210,10 +276,7 @@ clib_count_equal_u8 (u8 *data, uword max_count) u64 bmp; bmp = u8x16_msb_mask ((u8x16) (u8x16_load_unaligned (data) == splat)); if (bmp != 0xffff) - { - count += count_trailing_zeros (~bmp); - return count; - } + return count + count_trailing_zeros (~bmp); data += 16; count += 16; @@ -235,4 +298,5 @@ clib_count_equal_u8 (u8 *data, uword max_count) } return count; } + #endif diff --git a/src/vppinfra/vector_avx512.h b/src/vppinfra/vector_avx512.h index a82231ac025..1a5c2528bf7 100644 --- a/src/vppinfra/vector_avx512.h +++ b/src/vppinfra/vector_avx512.h @@ -301,6 +301,27 @@ _ (u32x16, u16, epu32, _mm512, __m512i) _ (u64x8, u8, epu64, _mm512, __m512i) #undef _ +#define _(t, m, e, p, it) \ + static_always_inline m t##_is_not_equal_mask (t a, t b) \ + { \ + return p##_cmpneq_##e##_mask ((it) a, (it) b); \ + } +_ (u8x16, u16, epu8, _mm, __m128i) +_ (u16x8, u8, epu16, _mm, __m128i) +_ (u32x4, u8, epu32, _mm, __m128i) +_ (u64x2, u8, epu64, _mm, __m128i) + +_ (u8x32, u32, epu8, _mm256, __m256i) +_ (u16x16, u16, epu16, _mm256, __m256i) +_ (u32x8, u8, epu32, _mm256, __m256i) +_ (u64x4, u8, epu64, _mm256, __m256i) + +_ (u8x64, u64, epu8, _mm512, __m512i) +_ (u16x32, u32, epu16, _mm512, __m512i) +_ (u32x16, u16, epu32, _mm512, __m512i) +_ (u64x8, u8, epu64, _mm512, __m512i) +#undef _ + #define _(f, t, fn, it) \ static_always_inline t t##_from_##f (f x) { return (t) fn ((it) x); } _ (u16x16, u32x16, _mm512_cvtepi16_epi32, __m256i) |