diff --git a/include/eve/module/core/regular/bit_shr.hpp b/include/eve/module/core/regular/bit_shr.hpp
index 7e40af119f..7d56ea4981 100644
--- a/include/eve/module/core/regular/bit_shr.hpp
+++ b/include/eve/module/core/regular/bit_shr.hpp
@@ -104,3 +104,7 @@ namespace eve
 #if defined(EVE_INCLUDE_POWERPC_HEADER)
 #  include <eve/module/core/regular/impl/simd/ppc/bit_shr.hpp>
 #endif
+
+#if defined(EVE_INCLUDE_X86_HEADER)
+#  include <eve/module/core/regular/impl/simd/x86/bit_shr.hpp>
+#endif
diff --git a/include/eve/module/core/regular/impl/simd/x86/bit_shr.hpp b/include/eve/module/core/regular/impl/simd/x86/bit_shr.hpp
new file mode 100644
index 0000000000..6627dd9dbd
--- /dev/null
+++ b/include/eve/module/core/regular/impl/simd/x86/bit_shr.hpp
@@ -0,0 +1,86 @@
+//==================================================================================================
+/*
+  EVE - Expressive Vector Engine
+  Copyright : EVE Project Contributors
+  SPDX-License-Identifier: BSL-1.0
+*/
+//==================================================================================================
+#pragma once
+
+#include <eve/concept/value.hpp>
+#include <eve/detail/category.hpp>
+#include <eve/detail/implementation.hpp>
+#include <eve/forward.hpp>
+
+namespace eve::detail
+{
+  // bit_shr[mask](wide_val, wide_shift)
+  template<conditional_expr C, typename T, typename N, callable_options O>
+  EVE_FORCEINLINE wide<T, N> bit_shr_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, wide<T, N> s) noexcept
+  requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>)
+  {
+    constexpr auto c = categorize<wide<T, N>>();
+    auto src = alternative(cx, v, as<wide<T, N>> {});
+    auto m   = expand_mask(cx, as<wide<T, N>> {}).storage().value;
+
+    // perform a logical shift right for the ints; the arithmetic shift is defined in eve::shr[mask]
+    if constexpr( c == category::int16x32 ) return _mm512_mask_srlv_epi16 (src, m, v, s);
+    else if constexpr( c == category::int16x16 ) return _mm256_mask_srlv_epi16 (src, m, v, s);
+    else if constexpr( c == category::int16x8 ) return _mm_mask_srlv_epi16 (src, m, v, s);
+
+    else if constexpr( c == category::int32x16 ) return _mm512_mask_srlv_epi32 (src, m, v, s);
+    else if constexpr( c == category::int32x8 ) return _mm256_mask_srlv_epi32 (src, m, v, s);
+    else if constexpr( c == category::int32x4 ) return _mm_mask_srlv_epi32 (src, m, v, s);
+
+    else if constexpr( c == category::int64x8 ) return _mm512_mask_srlv_epi64 (src, m, v, s);
+    else if constexpr( c == category::int64x4 ) return _mm256_mask_srlv_epi64 (src, m, v, s);
+    else if constexpr( c == category::int64x2 ) return _mm_mask_srlv_epi64 (src, m, v, s);
+
+    else if constexpr( c == category::uint16x32) return _mm512_mask_srlv_epi16 (src, m, v, s);
+    else if constexpr( c == category::uint16x16) return _mm256_mask_srlv_epi16 (src, m, v, s);
+    else if constexpr( c == category::uint16x8 ) return _mm_mask_srlv_epi16 (src, m, v, s);
+
+    else if constexpr( c == category::uint32x16) return _mm512_mask_srlv_epi32 (src, m, v, s);
+    else if constexpr( c == category::uint32x8 ) return _mm256_mask_srlv_epi32 (src, m, v, s);
+    else if constexpr( c == category::uint32x4 ) return _mm_mask_srlv_epi32 (src, m, v, s);
+
+    else if constexpr( c == category::uint64x8 ) return _mm512_mask_srlv_epi64 (src, m, v, s);
+    else if constexpr( c == category::uint64x4 ) return _mm256_mask_srlv_epi64 (src, m, v, s);
+    else if constexpr( c == category::uint64x2 ) return _mm_mask_srlv_epi64 (src, m, v, s);
+  }
+
+  // bit_shr[mask](wide_val, imm_shift)
+  template<conditional_expr C, typename T, typename N, callable_options O>
+  EVE_FORCEINLINE wide<T, N> bit_shr_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, unsigned int s) noexcept
+  requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>)
+  {
+    constexpr auto c = categorize<wide<T, N>>();
+    auto src = alternative(cx, v, as<wide<T, N>> {});
+    auto m   = expand_mask(cx, as<wide<T, N>> {}).storage().value;
+
+    if constexpr( c == category::int16x32 ) return _mm512_mask_srli_epi16 (src, m, v, s);
+    else if constexpr( c == category::int16x16 ) return _mm256_mask_srli_epi16 (src, m, v, s);
+    else if constexpr( c == category::int16x8 ) return _mm_mask_srli_epi16 (src, m, v, s);
+
+    else if constexpr( c == category::int32x16 ) return _mm512_mask_srli_epi32 (src, m, v, s);
+    else if constexpr( c == category::int32x8 ) return _mm256_mask_srli_epi32 (src, m, v, s);
+    else if constexpr( c == category::int32x4 ) return _mm_mask_srli_epi32 (src, m, v, s);
+
+    else if constexpr( c == category::int64x8 ) return _mm512_mask_srli_epi64 (src, m, v, s);
+    else if constexpr( c == category::int64x4 ) return _mm256_mask_srli_epi64 (src, m, v, s);
+    else if constexpr( c == category::int64x2 ) return _mm_mask_srli_epi64 (src, m, v, s);
+
+    // perform a logical shift right for the uints too; the arithmetic shift is defined in eve::shr[mask]
+    else if constexpr( c == category::uint16x32) return _mm512_mask_srli_epi16 (src, m, v, s);
+    else if constexpr( c == category::uint16x16) return _mm256_mask_srli_epi16 (src, m, v, s);
+    else if constexpr( c == category::uint16x8 ) return _mm_mask_srli_epi16 (src, m, v, s);
+
+    else if constexpr( c == category::uint32x16) return _mm512_mask_srli_epi32 (src, m, v, s);
+    else if constexpr( c == category::uint32x8 ) return _mm256_mask_srli_epi32 (src, m, v, s);
+    else if constexpr( c == category::uint32x4 ) return _mm_mask_srli_epi32 (src, m, v, s);
+
+    else if constexpr( c == category::uint64x8 ) return _mm512_mask_srli_epi64 (src, m, v, s);
+    else if constexpr( c == category::uint64x4 ) return _mm256_mask_srli_epi64 (src, m, v, s);
+    else if constexpr( c == category::uint64x2 ) return _mm_mask_srli_epi64 (src, m, v, s);
+  }
+}
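For reference, here is how the masked `bit_shr` entry point reaches these overloads. This is a minimal sketch, not part of the patch: the element values and the `eve::fixed<4>` cardinal are illustrative, and unmasked lanes keep their input value (the `alternative` source above).

```cpp
#include <cstdint>
#include <eve/module/core.hpp>
#include <eve/wide.hpp>

int main()
{
  using w32 = eve::wide<std::uint32_t, eve::fixed<4>>;
  w32  v{0x80u, 0x80u, 0x80u, 0x80u};
  w32  s{1u, 2u, 3u, 4u};
  w32  idx{0u, 1u, 2u, 3u};
  auto low = idx < 2u;

  // Per-lane counts: masked lanes take the *_srlv_* branch; the rest keep v.
  auto r = eve::bit_shr[low](v, s);   // {0x40, 0x20, 0x80, 0x80}

  // A scalar count routes to the *_srli_* branch instead.
  auto q = eve::bit_shr[low](v, 2u);  // {0x20, 0x20, 0x80, 0x80}
}
```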
diff --git a/include/eve/module/core/regular/impl/simd/x86/shl.hpp b/include/eve/module/core/regular/impl/simd/x86/shl.hpp
new file mode 100644
index 0000000000..b7abee6d96
--- /dev/null
+++ b/include/eve/module/core/regular/impl/simd/x86/shl.hpp
@@ -0,0 +1,86 @@
+//==================================================================================================
+/*
+  EVE - Expressive Vector Engine
+  Copyright : EVE Project Contributors
+  SPDX-License-Identifier: BSL-1.0
+*/
+//==================================================================================================
+#pragma once
+
+#include <eve/concept/value.hpp>
+#include <eve/detail/category.hpp>
+#include <eve/detail/implementation.hpp>
+#include <eve/forward.hpp>
+
+namespace eve::detail
+{
+  // shl[mask](wide_val, wide_shift)
+  template<conditional_expr C, typename T, typename N, callable_options O>
+  EVE_FORCEINLINE wide<T, N> shl_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, wide<T, N> s) noexcept
+  requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>)
+  {
+    constexpr auto c = categorize<wide<T, N>>();
+    auto src = alternative(cx, v, as<wide<T, N>> {});
+    auto m   = expand_mask(cx, as<wide<T, N>> {}).storage().value;
+
+    // perform a logical shift left for the ints, as it is equivalent to an arithmetic shift left
+    if constexpr( c == category::int16x32 ) return _mm512_mask_sllv_epi16 (src, m, v, s);
+    else if constexpr( c == category::int16x16 ) return _mm256_mask_sllv_epi16 (src, m, v, s);
+    else if constexpr( c == category::int16x8 ) return _mm_mask_sllv_epi16 (src, m, v, s);
+
+    else if constexpr( c == category::int32x16 ) return _mm512_mask_sllv_epi32 (src, m, v, s);
+    else if constexpr( c == category::int32x8 ) return _mm256_mask_sllv_epi32 (src, m, v, s);
+    else if constexpr( c == category::int32x4 ) return _mm_mask_sllv_epi32 (src, m, v, s);
+
+    else if constexpr( c == category::int64x8 ) return _mm512_mask_sllv_epi64 (src, m, v, s);
+    else if constexpr( c == category::int64x4 ) return _mm256_mask_sllv_epi64 (src, m, v, s);
+    else if constexpr( c == category::int64x2 ) return _mm_mask_sllv_epi64 (src, m, v, s);
+
+    else if constexpr( c == category::uint16x32) return _mm512_mask_sllv_epi16 (src, m, v, s);
+    else if constexpr( c == category::uint16x16) return _mm256_mask_sllv_epi16 (src, m, v, s);
+    else if constexpr( c == category::uint16x8 ) return _mm_mask_sllv_epi16 (src, m, v, s);
+
+    else if constexpr( c == category::uint32x16) return _mm512_mask_sllv_epi32 (src, m, v, s);
+    else if constexpr( c == category::uint32x8 ) return _mm256_mask_sllv_epi32 (src, m, v, s);
+    else if constexpr( c == category::uint32x4 ) return _mm_mask_sllv_epi32 (src, m, v, s);
+
+    else if constexpr( c == category::uint64x8 ) return _mm512_mask_sllv_epi64 (src, m, v, s);
+    else if constexpr( c == category::uint64x4 ) return _mm256_mask_sllv_epi64 (src, m, v, s);
+    else if constexpr( c == category::uint64x2 ) return _mm_mask_sllv_epi64 (src, m, v, s);
+  }
+
+  // shl[mask](wide_val, imm_shift)
+  template<conditional_expr C, typename T, typename N, callable_options O>
+  EVE_FORCEINLINE wide<T, N> shl_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, unsigned int s) noexcept
+  requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>)
+  {
+    constexpr auto c = categorize<wide<T, N>>();
+    auto src = alternative(cx, v, as<wide<T, N>> {});
+    auto m   = expand_mask(cx, as<wide<T, N>> {}).storage().value;
+
+    if constexpr( c == category::int16x32 ) return _mm512_mask_slli_epi16 (src, m, v, s);
+    else if constexpr( c == category::int16x16 ) return _mm256_mask_slli_epi16 (src, m, v, s);
+    else if constexpr( c == category::int16x8 ) return _mm_mask_slli_epi16 (src, m, v, s);
+
+    else if constexpr( c == category::int32x16 ) return _mm512_mask_slli_epi32 (src, m, v, s);
+    else if constexpr( c == category::int32x8 ) return _mm256_mask_slli_epi32 (src, m, v, s);
+    else if constexpr( c == category::int32x4 ) return _mm_mask_slli_epi32 (src, m, v, s);
+
+    else if constexpr( c == category::int64x8 ) return _mm512_mask_slli_epi64 (src, m, v, s);
+    else if constexpr( c == category::int64x4 ) return _mm256_mask_slli_epi64 (src, m, v, s);
+    else if constexpr( c == category::int64x2 ) return _mm_mask_slli_epi64 (src, m, v, s);
+
+    // perform a logical shift left for the uints as well, as it is equivalent to an arithmetic shift left
+    else if constexpr( c == category::uint16x32) return _mm512_mask_slli_epi16 (src, m, v, s);
+    else if constexpr( c == category::uint16x16) return _mm256_mask_slli_epi16 (src, m, v, s);
+    else if constexpr( c == category::uint16x8 ) return _mm_mask_slli_epi16 (src, m, v, s);
+
+    else if constexpr( c == category::uint32x16) return _mm512_mask_slli_epi32 (src, m, v, s);
+    else if constexpr( c == category::uint32x8 ) return _mm256_mask_slli_epi32 (src, m, v, s);
+    else if constexpr( c == category::uint32x4 ) return _mm_mask_slli_epi32 (src, m, v, s);
+
+    else if constexpr( c == category::uint64x8 ) return _mm512_mask_slli_epi64 (src, m, v, s);
+    else if constexpr( c == category::uint64x4 ) return _mm256_mask_slli_epi64 (src, m, v, s);
+    else if constexpr( c == category::uint64x2 ) return _mm_mask_slli_epi64 (src, m, v, s);
+  }
+}
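As with `bit_shr`, the two `shl_` overloads are reached through `eve::shl[mask]`, depending on whether the shift count is a wide or a scalar. A usage sketch (values illustrative; on non-AVX-512 targets the generic fallback applies):

```cpp
#include <cstdint>
#include <eve/module/core.hpp>
#include <eve/wide.hpp>

int main()
{
  using w16 = eve::wide<std::int16_t, eve::fixed<8>>;
  w16  v{1, 2, 3, 4, 5, 6, 7, 8};
  w16  idx{[](auto i, auto) { return i; }};   // lane indices 0..7
  auto even = (idx & 1) == 0;

  // Scalar count: the *_slli_* branch above.
  auto a = eve::shl[even](v, 3);              // even lanes become v << 3

  // Per-lane counts: the *_sllv_* branch above.
  w16  s{0, 1, 2, 3, 0, 1, 2, 3};
  auto b = eve::shl[even](v, s);
}
```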
diff --git a/include/eve/module/core/regular/impl/simd/x86/shr.hpp b/include/eve/module/core/regular/impl/simd/x86/shr.hpp
new file mode 100644
index 0000000000..c830dfb22c
--- /dev/null
+++ b/include/eve/module/core/regular/impl/simd/x86/shr.hpp
@@ -0,0 +1,88 @@
+//==================================================================================================
+/*
+  EVE - Expressive Vector Engine
+  Copyright : EVE Project Contributors
+  SPDX-License-Identifier: BSL-1.0
+*/
+//==================================================================================================
+#pragma once
+
+#include <eve/concept/value.hpp>
+#include <eve/detail/category.hpp>
+#include <eve/detail/implementation.hpp>
+#include <eve/forward.hpp>
+
+namespace eve::detail
+{
+  // shr[mask](wide_val, wide_shift)
+  template<conditional_expr C, typename T, typename N, callable_options O>
+  EVE_FORCEINLINE wide<T, N> shr_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, wide<T, N> s) noexcept
+  requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>)
+  {
+    constexpr auto c = categorize<wide<T, N>>();
+    auto src = alternative(cx, v, as<wide<T, N>> {});
+    auto m   = expand_mask(cx, as<wide<T, N>> {}).storage().value;
+
+    // perform an arithmetic shift right for the ints
+    if constexpr( c == category::int16x32 ) return _mm512_mask_srav_epi16 (src, m, v, s);
+    else if constexpr( c == category::int16x16 ) return _mm256_mask_srav_epi16 (src, m, v, s);
+    else if constexpr( c == category::int16x8 ) return _mm_mask_srav_epi16 (src, m, v, s);
+
+    else if constexpr( c == category::int32x16 ) return _mm512_mask_srav_epi32 (src, m, v, s);
+    else if constexpr( c == category::int32x8 ) return _mm256_mask_srav_epi32 (src, m, v, s);
+    else if constexpr( c == category::int32x4 ) return _mm_mask_srav_epi32 (src, m, v, s);
+
+    else if constexpr( c == category::int64x8 ) return _mm512_mask_srav_epi64 (src, m, v, s);
+    else if constexpr( c == category::int64x4 ) return _mm256_mask_srav_epi64 (src, m, v, s);
+    else if constexpr( c == category::int64x2 ) return _mm_mask_srav_epi64 (src, m, v, s);
+
+    // it does not matter for the uints, so just perform a logical shift
+    else if constexpr( c == category::uint16x32) return _mm512_mask_srlv_epi16 (src, m, v, s);
+    else if constexpr( c == category::uint16x16) return _mm256_mask_srlv_epi16 (src, m, v, s);
+    else if constexpr( c == category::uint16x8 ) return _mm_mask_srlv_epi16 (src, m, v, s);
+
+    else if constexpr( c == category::uint32x16) return _mm512_mask_srlv_epi32 (src, m, v, s);
+    else if constexpr( c == category::uint32x8 ) return _mm256_mask_srlv_epi32 (src, m, v, s);
+    else if constexpr( c == category::uint32x4 ) return _mm_mask_srlv_epi32 (src, m, v, s);
+
+    else if constexpr( c == category::uint64x8 ) return _mm512_mask_srlv_epi64 (src, m, v, s);
+    else if constexpr( c == category::uint64x4 ) return _mm256_mask_srlv_epi64 (src, m, v, s);
+    else if constexpr( c == category::uint64x2 ) return _mm_mask_srlv_epi64 (src, m, v, s);
+  }
+
+  // shr[mask](wide_val, imm_shift)
+  template<conditional_expr C, typename T, typename N, callable_options O>
+  EVE_FORCEINLINE wide<T, N> shr_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, unsigned int s) noexcept
+  requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>)
+  {
+    constexpr auto c = categorize<wide<T, N>>();
+    auto src = alternative(cx, v, as<wide<T, N>> {});
+    auto m   = expand_mask(cx, as<wide<T, N>> {}).storage().value;
+
+    // perform an arithmetic shift right for the ints to preserve the sign bit
+    if constexpr( c == category::int16x32 ) return _mm512_mask_srai_epi16 (src, m, v, s);
+    else if constexpr( c == category::int16x16 ) return _mm256_mask_srai_epi16 (src, m, v, s);
+    else if constexpr( c == category::int16x8 ) return _mm_mask_srai_epi16 (src, m, v, s);
+
+    else if constexpr( c == category::int32x16 ) return _mm512_mask_srai_epi32 (src, m, v, s);
+    else if constexpr( c == category::int32x8 ) return _mm256_mask_srai_epi32 (src, m, v, s);
+    else if constexpr( c == category::int32x4 ) return _mm_mask_srai_epi32 (src, m, v, s);
+
+    else if constexpr( c == category::int64x8 ) return _mm512_mask_srai_epi64 (src, m, v, s);
+    else if constexpr( c == category::int64x4 ) return _mm256_mask_srai_epi64 (src, m, v, s);
+    else if constexpr( c == category::int64x2 ) return _mm_mask_srai_epi64 (src, m, v, s);
+
+    // perform a logical shift right for the uints, as it does not matter for them
+    else if constexpr( c == category::uint16x32) return _mm512_mask_srli_epi16 (src, m, v, s);
+    else if constexpr( c == category::uint16x16) return _mm256_mask_srli_epi16 (src, m, v, s);
+    else if constexpr( c == category::uint16x8 ) return _mm_mask_srli_epi16 (src, m, v, s);
+
+    else if constexpr( c == category::uint32x16) return _mm512_mask_srli_epi32 (src, m, v, s);
+    else if constexpr( c == category::uint32x8 ) return _mm256_mask_srli_epi32 (src, m, v, s);
+    else if constexpr( c == category::uint32x4 ) return _mm_mask_srli_epi32 (src, m, v, s);
+
+    else if constexpr( c == category::uint64x8 ) return _mm512_mask_srli_epi64 (src, m, v, s);
+    else if constexpr( c == category::uint64x4 ) return _mm256_mask_srli_epi64 (src, m, v, s);
+    else if constexpr( c == category::uint64x2 ) return _mm_mask_srli_epi64 (src, m, v, s);
+  }
+}
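The split between `srav/srai` and `srlv/srli` is what makes `eve::shr` sign-correct on signed elements, while `eve::bit_shr` always shifts in zeros. A small sketch of the difference (illustrative values):

```cpp
#include <cstdint>
#include <eve/module/core.hpp>
#include <eve/wide.hpp>

int main()
{
  using w32 = eve::wide<std::int32_t, eve::fixed<4>>;
  w32  v{-8, -4, 4, 8};
  auto neg = v < 0;

  // Arithmetic shift: the sign bit is replicated on the masked lanes.
  auto a = eve::shr[neg](v, 1);   // {-4, -2, 4, 8} -- unmasked lanes keep v
  // Logical shift: zeros are shifted in, even for signed elements.
  auto b = eve::bit_shr(v, 1);    // lane 0 becomes 0x7ffffffc
}
```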
diff --git a/include/eve/module/core/regular/shl.hpp b/include/eve/module/core/regular/shl.hpp
index ff601f95ed..d5086ff078 100644
--- a/include/eve/module/core/regular/shl.hpp
+++ b/include/eve/module/core/regular/shl.hpp
@@ -121,3 +121,7 @@ namespace eve
 }
 }
 }
+
+#if defined(EVE_INCLUDE_X86_HEADER)
+#  include <eve/module/core/regular/impl/simd/x86/shl.hpp>
+#endif
diff --git a/include/eve/module/core/regular/shr.hpp b/include/eve/module/core/regular/shr.hpp
index 22599c9f34..ff020492ee 100644
--- a/include/eve/module/core/regular/shr.hpp
+++ b/include/eve/module/core/regular/shr.hpp
@@ -122,3 +122,7 @@ namespace eve
 }
 }
 }
+
+#if defined(EVE_INCLUDE_X86_HEADER)
+#  include <eve/module/core/regular/impl/simd/x86/shr.hpp>
+#endif
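With the header hooks above in place, masked shifts on AVX-512 targets dispatch directly to the masked intrinsics rather than the generic masked fallback. A possible smoke test, assuming an AVX-512 build (for example `g++ -std=c++20 -O2 -mavx512f -mavx512bw`; the flag choice is illustrative):

```cpp
#include <cstdint>
#include <iostream>
#include <eve/module/core.hpp>
#include <eve/wide.hpp>

int main()
{
  using w16 = eve::wide<std::uint16_t, eve::fixed<32>>;  // maps to category::uint16x32
  w16  v(0x0100);
  w16  idx{[](auto i, auto) { return i; }};
  auto first_half = idx < 16;

  std::cout << eve::shl[first_half](v, 1)     << '\n';   // 0x0200 in the first 16 lanes
  std::cout << eve::shr[first_half](v, 1)     << '\n';   // 0x0080 in the first 16 lanes
  std::cout << eve::bit_shr[first_half](v, 1) << '\n';   // identical to shr for unsigned
}
```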