-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement mask shifts (take 2) (#1827)
* take 1 * take 2 * take 3 * take 4
- Loading branch information
1 parent
ec8c1f7
commit 85aa3ce
Showing
6 changed files
with
272 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
//================================================================================================== | ||
/* | ||
EVE - Expressive Vector Engine | ||
Copyright : EVE Project Contributors | ||
SPDX-License-Identifier: BSL-1.0 | ||
*/ | ||
//================================================================================================== | ||
#pragma once | ||
|
||
#include <eve/concept/value.hpp> | ||
#include <eve/detail/abi.hpp> | ||
#include <eve/detail/category.hpp> | ||
#include <eve/forward.hpp> | ||
|
||
namespace eve::detail | ||
{ | ||
// bit_shr[mask](wide_val, wide_mask) | ||
template<conditional_expr C, integral_scalar_value T, typename S, typename N, callable_options O> | ||
EVE_FORCEINLINE wide<T,N> bit_shr_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, wide<S, N> s) noexcept | ||
requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>) | ||
{ | ||
constexpr auto c = categorize<wide<T, N>>(); | ||
auto src = alternative(cx, v, as<wide<T, N>> {}); | ||
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value; | ||
|
||
// perform a logical shift right for the ints, arithmetic shift is defined in eve::shr[mask] | ||
if constexpr( c == category::int16x32 ) return _mm512_mask_srlv_epi16 (src, m, v, s); | ||
else if constexpr( c == category::int16x16 ) return _mm256_mask_srlv_epi16 (src, m, v, s); | ||
else if constexpr( c == category::int16x8 ) return _mm_mask_srlv_epi16 (src, m, v, s); | ||
|
||
else if constexpr( c == category::int32x16 ) return _mm512_mask_srlv_epi32 (src, m, v, s); | ||
else if constexpr( c == category::int32x8 ) return _mm256_mask_srlv_epi32 (src, m, v, s); | ||
else if constexpr( c == category::int32x4 ) return _mm_mask_srlv_epi32 (src, m, v, s); | ||
|
||
else if constexpr( c == category::int64x8 ) return _mm512_mask_srlv_epi64 (src, m, v, s); | ||
else if constexpr( c == category::int64x4 ) return _mm256_mask_srlv_epi64 (src, m, v, s); | ||
else if constexpr( c == category::int64x2 ) return _mm_mask_srlv_epi64 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint16x32) return _mm512_mask_srlv_epi16 (src, m, v, s); | ||
else if constexpr( c == category::uint16x16) return _mm256_mask_srlv_epi16 (src, m, v, s); | ||
else if constexpr( c == category::uint16x8 ) return _mm_mask_srlv_epi16 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint32x16) return _mm512_mask_srlv_epi32 (src, m, v, s); | ||
else if constexpr( c == category::uint32x8 ) return _mm256_mask_srlv_epi32 (src, m, v, s); | ||
else if constexpr( c == category::uint32x4 ) return _mm_mask_srlv_epi32 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint64x8 ) return _mm512_mask_srlv_epi64 (src, m, v, s); | ||
else if constexpr( c == category::uint64x4 ) return _mm256_mask_srlv_epi64 (src, m, v, s); | ||
else if constexpr( c == category::uint64x2 ) return _mm_mask_srlv_epi64 (src, m, v, s); | ||
} | ||
|
||
// bit_shr[mask](wide_val, imm_mask) | ||
template<conditional_expr C, integral_scalar_value T, typename N, callable_options O> | ||
EVE_FORCEINLINE wide<T,N> bit_shr_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, unsigned int s) noexcept | ||
requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>) | ||
{ | ||
constexpr auto c = categorize<wide<T, N>>(); | ||
auto src = alternative(cx, v, as<wide<T, N>> {}); | ||
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value; | ||
|
||
if constexpr( c == category::int16x32 ) return _mm512_mask_srli_epi16 (src, m, v, s); | ||
else if constexpr( c == category::int16x16 ) return _mm256_mask_srli_epi16 (src, m, v, s); | ||
else if constexpr( c == category::int16x8 ) return _mm_mask_srli_epi16 (src, m, v, s); | ||
|
||
else if constexpr( c == category::int32x16 ) return _mm512_mask_srli_epi32 (src, m, v, s); | ||
else if constexpr( c == category::int32x8 ) return _mm256_mask_srli_epi32 (src, m, v, s); | ||
else if constexpr( c == category::int32x4 ) return _mm_mask_srli_epi32 (src, m, v, s); | ||
|
||
else if constexpr( c == category::int64x8 ) return _mm512_mask_srli_epi64 (src, m, v, s); | ||
else if constexpr( c == category::int64x4 ) return _mm256_mask_srli_epi64 (src, m, v, s); | ||
else if constexpr( c == category::int64x2 ) return _mm_mask_srli_epi64 (src, m, v, s); | ||
|
||
// perform a logical shift right for the uints too, arithmetic shift is defined in eve::shr[mask] | ||
else if constexpr( c == category::uint16x32) return _mm512_mask_srli_epi16 (src, m, v, s); | ||
else if constexpr( c == category::uint16x16) return _mm256_mask_srli_epi16 (src, m, v, s); | ||
else if constexpr( c == category::uint16x8 ) return _mm_mask_srli_epi16 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint32x16) return _mm512_mask_srli_epi32 (src, m, v, s); | ||
else if constexpr( c == category::uint32x8 ) return _mm256_mask_srli_epi32 (src, m, v, s); | ||
else if constexpr( c == category::uint32x4 ) return _mm_mask_srli_epi32 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint64x8 ) return _mm512_mask_srli_epi64 (src, m, v, s); | ||
else if constexpr( c == category::uint64x4 ) return _mm256_mask_srli_epi64 (src, m, v, s); | ||
else if constexpr( c == category::uint64x2 ) return _mm_mask_srli_epi64 (src, m, v, s); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
//================================================================================================== | ||
/* | ||
EVE - Expressive Vector Engine | ||
Copyright : EVE Project Contributors | ||
SPDX-License-Identifier: BSL-1.0 | ||
*/ | ||
//================================================================================================== | ||
#pragma once | ||
|
||
#include <eve/concept/value.hpp> | ||
#include <eve/detail/abi.hpp> | ||
#include <eve/detail/category.hpp> | ||
#include <eve/forward.hpp> | ||
|
||
namespace eve::detail | ||
{ | ||
// slh[mask](wide_val, wide_mask) | ||
template<conditional_expr C, integral_scalar_value T, typename S, typename N, callable_options O> | ||
EVE_FORCEINLINE wide<T,N> shl_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, wide<S, N> s) noexcept | ||
requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>) | ||
{ | ||
constexpr auto c = categorize<wide<T, N>>(); | ||
auto src = alternative(cx, v, as<wide<T, N>> {}); | ||
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value; | ||
|
||
// perform a logical shift left for ints as it is equivalent to an arithmetic shift left | ||
if constexpr( c == category::int16x32 ) return _mm512_mask_sllv_epi16 (src, m, v, s); | ||
else if constexpr( c == category::int16x16 ) return _mm256_mask_sllv_epi16 (src, m, v, s); | ||
else if constexpr( c == category::int16x8 ) return _mm_mask_sllv_epi16 (src, m, v, s); | ||
|
||
else if constexpr( c == category::int32x16 ) return _mm512_mask_sllv_epi32 (src, m, v, s); | ||
else if constexpr( c == category::int32x8 ) return _mm256_mask_sllv_epi32 (src, m, v, s); | ||
else if constexpr( c == category::int32x4 ) return _mm_mask_sllv_epi32 (src, m, v, s); | ||
|
||
else if constexpr( c == category::int64x8 ) return _mm512_mask_sllv_epi64 (src, m, v, s); | ||
else if constexpr( c == category::int64x4 ) return _mm256_mask_sllv_epi64 (src, m, v, s); | ||
else if constexpr( c == category::int64x2 ) return _mm_mask_sllv_epi64 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint16x32) return _mm512_mask_sllv_epi16 (src, m, v, s); | ||
else if constexpr( c == category::uint16x16) return _mm256_mask_sllv_epi16 (src, m, v, s); | ||
else if constexpr( c == category::uint16x8 ) return _mm_mask_sllv_epi16 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint32x16) return _mm512_mask_sllv_epi32 (src, m, v, s); | ||
else if constexpr( c == category::uint32x8 ) return _mm256_mask_sllv_epi32 (src, m, v, s); | ||
else if constexpr( c == category::uint32x4 ) return _mm_mask_sllv_epi32 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint64x8 ) return _mm512_mask_sllv_epi64 (src, m, v, s); | ||
else if constexpr( c == category::uint64x4 ) return _mm256_mask_sllv_epi64 (src, m, v, s); | ||
else if constexpr( c == category::uint64x2 ) return _mm_mask_sllv_epi64 (src, m, v, s); | ||
} | ||
|
||
// shr[mask](wide_val, imm_mask) | ||
template<conditional_expr C, integral_scalar_value T, typename N, callable_options O> | ||
EVE_FORCEINLINE wide<T,N> shl_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, unsigned int s) noexcept | ||
requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>) | ||
{ | ||
constexpr auto c = categorize<wide<T, N>>(); | ||
auto src = alternative(cx, v, as<wide<T, N>> {}); | ||
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value; | ||
|
||
if constexpr( c == category::int16x32 ) return _mm512_mask_slli_epi16 (src, m, v, s); | ||
else if constexpr( c == category::int16x16 ) return _mm256_mask_slli_epi16 (src, m, v, s); | ||
else if constexpr( c == category::int16x8 ) return _mm_mask_slli_epi16 (src, m, v, s); | ||
|
||
else if constexpr( c == category::int32x16 ) return _mm512_mask_slli_epi32 (src, m, v, s); | ||
else if constexpr( c == category::int32x8 ) return _mm256_mask_slli_epi32 (src, m, v, s); | ||
else if constexpr( c == category::int32x4 ) return _mm_mask_slli_epi32 (src, m, v, s); | ||
|
||
else if constexpr( c == category::int64x8 ) return _mm512_mask_slli_epi64 (src, m, v, s); | ||
else if constexpr( c == category::int64x4 ) return _mm256_mask_slli_epi64 (src, m, v, s); | ||
else if constexpr( c == category::int64x2 ) return _mm_mask_slli_epi64 (src, m, v, s); | ||
|
||
// perform a logical shift left for uints as it is equivalent to an arithmetic shift left | ||
else if constexpr( c == category::uint16x32) return _mm512_mask_slli_epi16 (src, m, v, s); | ||
else if constexpr( c == category::uint16x16) return _mm256_mask_slli_epi16 (src, m, v, s); | ||
else if constexpr( c == category::uint16x8 ) return _mm_mask_slli_epi16 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint32x16) return _mm512_mask_slli_epi32 (src, m, v, s); | ||
else if constexpr( c == category::uint32x8 ) return _mm256_mask_slli_epi32 (src, m, v, s); | ||
else if constexpr( c == category::uint32x4 ) return _mm_mask_slli_epi32 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint64x8 ) return _mm512_mask_slli_epi64 (src, m, v, s); | ||
else if constexpr( c == category::uint64x4 ) return _mm256_mask_slli_epi64 (src, m, v, s); | ||
else if constexpr( c == category::uint64x2 ) return _mm_mask_slli_epi64 (src, m, v, s); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
//================================================================================================== | ||
/* | ||
EVE - Expressive Vector Engine | ||
Copyright : EVE Project Contributors | ||
SPDX-License-Identifier: BSL-1.0 | ||
*/ | ||
//================================================================================================== | ||
#pragma once | ||
|
||
#include <eve/concept/value.hpp> | ||
#include <eve/detail/abi.hpp> | ||
#include <eve/detail/category.hpp> | ||
#include <eve/forward.hpp> | ||
|
||
namespace eve::detail | ||
{ | ||
// shr[mask](wide_val, wide_mask) | ||
template<conditional_expr C, integral_scalar_value T, typename S, typename N, callable_options O> | ||
EVE_FORCEINLINE wide<T,N> shr_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, wide<S, N> s) noexcept | ||
requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>) | ||
{ | ||
constexpr auto c = categorize<wide<T, N>>(); | ||
auto src = alternative(cx, v, as<wide<T, N>> {}); | ||
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value; | ||
|
||
// perform an arithmetic shift right for the ints | ||
if constexpr( c == category::int16x32 ) return _mm512_mask_srav_epi16 (src, m, v, s); | ||
else if constexpr( c == category::int16x16 ) return _mm256_mask_srav_epi16 (src, m, v, s); | ||
else if constexpr( c == category::int16x8 ) return _mm_mask_srav_epi16 (src, m, v, s); | ||
|
||
else if constexpr( c == category::int32x16 ) return _mm512_mask_srav_epi32 (src, m, v, s); | ||
else if constexpr( c == category::int32x8 ) return _mm256_mask_srav_epi32 (src, m, v, s); | ||
else if constexpr( c == category::int32x4 ) return _mm_mask_srav_epi32 (src, m, v, s); | ||
|
||
else if constexpr( c == category::int64x8 ) return _mm512_mask_srav_epi64 (src, m, v, s); | ||
else if constexpr( c == category::int64x4 ) return _mm256_mask_srav_epi64 (src, m, v, s); | ||
else if constexpr( c == category::int64x2 ) return _mm_mask_srav_epi64 (src, m, v, s); | ||
|
||
// it does not matter for the uints, so just perform a logical shift | ||
else if constexpr( c == category::uint16x32) return _mm512_mask_srlv_epi16 (src, m, v, s); | ||
else if constexpr( c == category::uint16x16) return _mm256_mask_srlv_epi16 (src, m, v, s); | ||
else if constexpr( c == category::uint16x8 ) return _mm_mask_srlv_epi16 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint32x16) return _mm512_mask_srlv_epi32 (src, m, v, s); | ||
else if constexpr( c == category::uint32x8 ) return _mm256_mask_srlv_epi32 (src, m, v, s); | ||
else if constexpr( c == category::uint32x4 ) return _mm_mask_srlv_epi32 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint64x8 ) return _mm512_mask_srlv_epi64 (src, m, v, s); | ||
else if constexpr( c == category::uint64x4 ) return _mm256_mask_srlv_epi64 (src, m, v, s); | ||
else if constexpr( c == category::uint64x2 ) return _mm_mask_srlv_epi64 (src, m, v, s); | ||
} | ||
|
||
// shr[mask](wide_val, imm_mask) | ||
template<conditional_expr C, integral_scalar_value T, typename N, callable_options O> | ||
EVE_FORCEINLINE wide<T,N> shr_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, unsigned int s) noexcept | ||
requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>) | ||
{ | ||
constexpr auto c = categorize<wide<T, N>>(); | ||
auto src = alternative(cx, v, as<wide<T, N>> {}); | ||
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value; | ||
|
||
// perform an arithmetic shift right for the uints to preserve the sign bit | ||
if constexpr( c == category::int16x32 ) return _mm512_mask_srai_epi16 (src, m, v, s); | ||
else if constexpr( c == category::int16x16 ) return _mm256_mask_srai_epi16 (src, m, v, s); | ||
else if constexpr( c == category::int16x8 ) return _mm_mask_srai_epi16 (src, m, v, s); | ||
|
||
else if constexpr( c == category::int32x16 ) return _mm512_mask_srai_epi32 (src, m, v, s); | ||
else if constexpr( c == category::int32x8 ) return _mm256_mask_srai_epi32 (src, m, v, s); | ||
else if constexpr( c == category::int32x4 ) return _mm_mask_srai_epi32 (src, m, v, s); | ||
|
||
else if constexpr( c == category::int64x8 ) return _mm512_mask_srai_epi64 (src, m, v, s); | ||
else if constexpr( c == category::int64x4 ) return _mm256_mask_srai_epi64 (src, m, v, s); | ||
else if constexpr( c == category::int64x2 ) return _mm_mask_srai_epi64 (src, m, v, s); | ||
|
||
// perform an logical shift right for the ints, as it does not matter | ||
else if constexpr( c == category::uint16x32) return _mm512_mask_srli_epi16 (src, m, v, s); | ||
else if constexpr( c == category::uint16x16) return _mm256_mask_srli_epi16 (src, m, v, s); | ||
else if constexpr( c == category::uint16x8 ) return _mm_mask_srli_epi16 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint32x16) return _mm512_mask_srli_epi32 (src, m, v, s); | ||
else if constexpr( c == category::uint32x8 ) return _mm256_mask_srli_epi32 (src, m, v, s); | ||
else if constexpr( c == category::uint32x4 ) return _mm_mask_srli_epi32 (src, m, v, s); | ||
|
||
else if constexpr( c == category::uint64x8 ) return _mm512_mask_srli_epi64 (src, m, v, s); | ||
else if constexpr( c == category::uint64x4 ) return _mm256_mask_srli_epi64 (src, m, v, s); | ||
else if constexpr( c == category::uint64x2 ) return _mm_mask_srli_epi64 (src, m, v, s); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters