Implement mask shifts (take 2) (#1827)
* take 1

* take 2

* take 3

* take 4
SadiinsoSnowfall authored May 30, 2024
1 parent ec8c1f7 commit 85aa3ce
Showing 6 changed files with 272 additions and 0 deletions.
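
The six files below wire EVE's masked shift calls (eve::shl, eve::shr and eve::bit_shr invoked with a conditional) to the AVX-512 merge-masked shift intrinsics. A minimal usage sketch of the API being accelerated, assuming an AVX-512 target; the values and fixed cardinality are illustrative:

#include <eve/module/core.hpp>
#include <eve/wide.hpp>
#include <cstdint>

int main()
{
  eve::wide<std::int32_t, eve::fixed<4>> v{-8, 16, -32, 64};

  // conditional calls: lanes where the condition is false keep their input value
  auto a = eve::shr[v > 0](v, 2);     // arithmetic shift right on positive lanes
  auto l = eve::bit_shr[v > 0](v, 2); // logical shift right on positive lanes
  auto s = eve::shl[v > 0](v, 2);     // shift left on positive lanes
}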
4 changes: 4 additions & 0 deletions include/eve/module/core/regular/bit_shr.hpp
@@ -104,3 +104,7 @@ namespace eve
#if defined(EVE_INCLUDE_POWERPC_HEADER)
# include <eve/module/core/regular/impl/simd/ppc/bit_shr.hpp>
#endif

#if defined(EVE_INCLUDE_X86_HEADER)
# include <eve/module/core/regular/impl/simd/x86/bit_shr.hpp>
#endif
86 changes: 86 additions & 0 deletions include/eve/module/core/regular/impl/simd/x86/bit_shr.hpp
@@ -0,0 +1,86 @@
//==================================================================================================
/*
EVE - Expressive Vector Engine
Copyright : EVE Project Contributors
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once

#include <eve/concept/value.hpp>
#include <eve/detail/abi.hpp>
#include <eve/detail/category.hpp>
#include <eve/forward.hpp>

namespace eve::detail
{
// bit_shr[mask](wide_val, wide_shift)
template<conditional_expr C, integral_scalar_value T, typename S, typename N, callable_options O>
EVE_FORCEINLINE wide<T,N> bit_shr_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, wide<S, N> s) noexcept
requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>)
{
constexpr auto c = categorize<wide<T, N>>();
auto src = alternative(cx, v, as<wide<T, N>> {});
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value;

// perform a logical shift right for the signed ints as well; the arithmetic shift is implemented in eve::shr[mask]
if constexpr( c == category::int16x32 ) return _mm512_mask_srlv_epi16 (src, m, v, s);
else if constexpr( c == category::int16x16 ) return _mm256_mask_srlv_epi16 (src, m, v, s);
else if constexpr( c == category::int16x8 ) return _mm_mask_srlv_epi16 (src, m, v, s);

else if constexpr( c == category::int32x16 ) return _mm512_mask_srlv_epi32 (src, m, v, s);
else if constexpr( c == category::int32x8 ) return _mm256_mask_srlv_epi32 (src, m, v, s);
else if constexpr( c == category::int32x4 ) return _mm_mask_srlv_epi32 (src, m, v, s);

else if constexpr( c == category::int64x8 ) return _mm512_mask_srlv_epi64 (src, m, v, s);
else if constexpr( c == category::int64x4 ) return _mm256_mask_srlv_epi64 (src, m, v, s);
else if constexpr( c == category::int64x2 ) return _mm_mask_srlv_epi64 (src, m, v, s);

else if constexpr( c == category::uint16x32) return _mm512_mask_srlv_epi16 (src, m, v, s);
else if constexpr( c == category::uint16x16) return _mm256_mask_srlv_epi16 (src, m, v, s);
else if constexpr( c == category::uint16x8 ) return _mm_mask_srlv_epi16 (src, m, v, s);

else if constexpr( c == category::uint32x16) return _mm512_mask_srlv_epi32 (src, m, v, s);
else if constexpr( c == category::uint32x8 ) return _mm256_mask_srlv_epi32 (src, m, v, s);
else if constexpr( c == category::uint32x4 ) return _mm_mask_srlv_epi32 (src, m, v, s);

else if constexpr( c == category::uint64x8 ) return _mm512_mask_srlv_epi64 (src, m, v, s);
else if constexpr( c == category::uint64x4 ) return _mm256_mask_srlv_epi64 (src, m, v, s);
else if constexpr( c == category::uint64x2 ) return _mm_mask_srlv_epi64 (src, m, v, s);
}

// bit_shr[mask](wide_val, imm_shift)
template<conditional_expr C, integral_scalar_value T, typename N, callable_options O>
EVE_FORCEINLINE wide<T,N> bit_shr_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, unsigned int s) noexcept
requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>)
{
constexpr auto c = categorize<wide<T, N>>();
auto src = alternative(cx, v, as<wide<T, N>> {});
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value;

if constexpr( c == category::int16x32 ) return _mm512_mask_srli_epi16 (src, m, v, s);
else if constexpr( c == category::int16x16 ) return _mm256_mask_srli_epi16 (src, m, v, s);
else if constexpr( c == category::int16x8 ) return _mm_mask_srli_epi16 (src, m, v, s);

else if constexpr( c == category::int32x16 ) return _mm512_mask_srli_epi32 (src, m, v, s);
else if constexpr( c == category::int32x8 ) return _mm256_mask_srli_epi32 (src, m, v, s);
else if constexpr( c == category::int32x4 ) return _mm_mask_srli_epi32 (src, m, v, s);

else if constexpr( c == category::int64x8 ) return _mm512_mask_srli_epi64 (src, m, v, s);
else if constexpr( c == category::int64x4 ) return _mm256_mask_srli_epi64 (src, m, v, s);
else if constexpr( c == category::int64x2 ) return _mm_mask_srli_epi64 (src, m, v, s);

// perform a logical shift right for the uints too; the arithmetic shift is implemented in eve::shr[mask]
else if constexpr( c == category::uint16x32) return _mm512_mask_srli_epi16 (src, m, v, s);
else if constexpr( c == category::uint16x16) return _mm256_mask_srli_epi16 (src, m, v, s);
else if constexpr( c == category::uint16x8 ) return _mm_mask_srli_epi16 (src, m, v, s);

else if constexpr( c == category::uint32x16) return _mm512_mask_srli_epi32 (src, m, v, s);
else if constexpr( c == category::uint32x8 ) return _mm256_mask_srli_epi32 (src, m, v, s);
else if constexpr( c == category::uint32x4 ) return _mm_mask_srli_epi32 (src, m, v, s);

else if constexpr( c == category::uint64x8 ) return _mm512_mask_srli_epi64 (src, m, v, s);
else if constexpr( c == category::uint64x4 ) return _mm256_mask_srli_epi64 (src, m, v, s);
else if constexpr( c == category::uint64x2 ) return _mm_mask_srli_epi64 (src, m, v, s);
}
}
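
Each _mm*_mask_srlv_* / _mm*_mask_srli_* call above follows the AVX-512 merge-masking contract: where the mask bit is set, the shifted value is written; otherwise the lane falls back to src. A scalar model of one lane's behaviour, purely illustrative:

// scalar model of one lane of a merge-masked logical right shift
template<typename U> // U: unsigned lane type
U mask_srl_lane(U src, bool m, U v, unsigned s)
{
  return m ? static_cast<U>(v >> s) : src;
}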
86 changes: 86 additions & 0 deletions include/eve/module/core/regular/impl/simd/x86/shl.hpp
@@ -0,0 +1,86 @@
//==================================================================================================
/*
EVE - Expressive Vector Engine
Copyright : EVE Project Contributors
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once

#include <eve/concept/value.hpp>
#include <eve/detail/abi.hpp>
#include <eve/detail/category.hpp>
#include <eve/forward.hpp>

namespace eve::detail
{
// shl[mask](wide_val, wide_shift)
template<conditional_expr C, integral_scalar_value T, typename S, typename N, callable_options O>
EVE_FORCEINLINE wide<T,N> shl_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, wide<S, N> s) noexcept
requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>)
{
constexpr auto c = categorize<wide<T, N>>();
auto src = alternative(cx, v, as<wide<T, N>> {});
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value;

// perform a logical shift left for ints as it is equivalent to an arithmetic shift left
if constexpr( c == category::int16x32 ) return _mm512_mask_sllv_epi16 (src, m, v, s);
else if constexpr( c == category::int16x16 ) return _mm256_mask_sllv_epi16 (src, m, v, s);
else if constexpr( c == category::int16x8 ) return _mm_mask_sllv_epi16 (src, m, v, s);

else if constexpr( c == category::int32x16 ) return _mm512_mask_sllv_epi32 (src, m, v, s);
else if constexpr( c == category::int32x8 ) return _mm256_mask_sllv_epi32 (src, m, v, s);
else if constexpr( c == category::int32x4 ) return _mm_mask_sllv_epi32 (src, m, v, s);

else if constexpr( c == category::int64x8 ) return _mm512_mask_sllv_epi64 (src, m, v, s);
else if constexpr( c == category::int64x4 ) return _mm256_mask_sllv_epi64 (src, m, v, s);
else if constexpr( c == category::int64x2 ) return _mm_mask_sllv_epi64 (src, m, v, s);

else if constexpr( c == category::uint16x32) return _mm512_mask_sllv_epi16 (src, m, v, s);
else if constexpr( c == category::uint16x16) return _mm256_mask_sllv_epi16 (src, m, v, s);
else if constexpr( c == category::uint16x8 ) return _mm_mask_sllv_epi16 (src, m, v, s);

else if constexpr( c == category::uint32x16) return _mm512_mask_sllv_epi32 (src, m, v, s);
else if constexpr( c == category::uint32x8 ) return _mm256_mask_sllv_epi32 (src, m, v, s);
else if constexpr( c == category::uint32x4 ) return _mm_mask_sllv_epi32 (src, m, v, s);

else if constexpr( c == category::uint64x8 ) return _mm512_mask_sllv_epi64 (src, m, v, s);
else if constexpr( c == category::uint64x4 ) return _mm256_mask_sllv_epi64 (src, m, v, s);
else if constexpr( c == category::uint64x2 ) return _mm_mask_sllv_epi64 (src, m, v, s);
}

// shl[mask](wide_val, imm_shift)
template<conditional_expr C, integral_scalar_value T, typename N, callable_options O>
EVE_FORCEINLINE wide<T,N> shl_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, unsigned int s) noexcept
requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>)
{
constexpr auto c = categorize<wide<T, N>>();
auto src = alternative(cx, v, as<wide<T, N>> {});
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value;

if constexpr( c == category::int16x32 ) return _mm512_mask_slli_epi16 (src, m, v, s);
else if constexpr( c == category::int16x16 ) return _mm256_mask_slli_epi16 (src, m, v, s);
else if constexpr( c == category::int16x8 ) return _mm_mask_slli_epi16 (src, m, v, s);

else if constexpr( c == category::int32x16 ) return _mm512_mask_slli_epi32 (src, m, v, s);
else if constexpr( c == category::int32x8 ) return _mm256_mask_slli_epi32 (src, m, v, s);
else if constexpr( c == category::int32x4 ) return _mm_mask_slli_epi32 (src, m, v, s);

else if constexpr( c == category::int64x8 ) return _mm512_mask_slli_epi64 (src, m, v, s);
else if constexpr( c == category::int64x4 ) return _mm256_mask_slli_epi64 (src, m, v, s);
else if constexpr( c == category::int64x2 ) return _mm_mask_slli_epi64 (src, m, v, s);

// perform a logical shift left for uints as it is equivalent to an arithmetic shift left
else if constexpr( c == category::uint16x32) return _mm512_mask_slli_epi16 (src, m, v, s);
else if constexpr( c == category::uint16x16) return _mm256_mask_slli_epi16 (src, m, v, s);
else if constexpr( c == category::uint16x8 ) return _mm_mask_slli_epi16 (src, m, v, s);

else if constexpr( c == category::uint32x16) return _mm512_mask_slli_epi32 (src, m, v, s);
else if constexpr( c == category::uint32x8 ) return _mm256_mask_slli_epi32 (src, m, v, s);
else if constexpr( c == category::uint32x4 ) return _mm_mask_slli_epi32 (src, m, v, s);

else if constexpr( c == category::uint64x8 ) return _mm512_mask_slli_epi64 (src, m, v, s);
else if constexpr( c == category::uint64x4 ) return _mm256_mask_slli_epi64 (src, m, v, s);
else if constexpr( c == category::uint64x2 ) return _mm_mask_slli_epi64 (src, m, v, s);
}
}
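
Both overloads above hand the intrinsics the same three ingredients: the fallback lanes from alternative, the hardware mask extracted via expand_mask(...).storage().value, and the operands. A sketch of the resulting call shape for 32-bit lanes in a 512-bit register; the wrapper name is illustrative, the intrinsic is real:

#include <immintrin.h>

// lane i: (m >> i) & 1 ? v[i] << s[i] : src[i]
__m512i masked_shift_left_epi32(__m512i src, __mmask16 m, __m512i v, __m512i s)
{
  return _mm512_mask_sllv_epi32(src, m, v, s);
}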
88 changes: 88 additions & 0 deletions include/eve/module/core/regular/impl/simd/x86/shr.hpp
@@ -0,0 +1,88 @@
//==================================================================================================
/*
EVE - Expressive Vector Engine
Copyright : EVE Project Contributors
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once

#include <eve/concept/value.hpp>
#include <eve/detail/abi.hpp>
#include <eve/detail/category.hpp>
#include <eve/forward.hpp>

namespace eve::detail
{
// shr[mask](wide_val, wide_shift)
template<conditional_expr C, integral_scalar_value T, typename S, typename N, callable_options O>
EVE_FORCEINLINE wide<T,N> shr_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, wide<S, N> s) noexcept
requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>)
{
constexpr auto c = categorize<wide<T, N>>();
auto src = alternative(cx, v, as<wide<T, N>> {});
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value;

// perform an arithmetic shift right for the ints
if constexpr( c == category::int16x32 ) return _mm512_mask_srav_epi16 (src, m, v, s);
else if constexpr( c == category::int16x16 ) return _mm256_mask_srav_epi16 (src, m, v, s);
else if constexpr( c == category::int16x8 ) return _mm_mask_srav_epi16 (src, m, v, s);

else if constexpr( c == category::int32x16 ) return _mm512_mask_srav_epi32 (src, m, v, s);
else if constexpr( c == category::int32x8 ) return _mm256_mask_srav_epi32 (src, m, v, s);
else if constexpr( c == category::int32x4 ) return _mm_mask_srav_epi32 (src, m, v, s);

else if constexpr( c == category::int64x8 ) return _mm512_mask_srav_epi64 (src, m, v, s);
else if constexpr( c == category::int64x4 ) return _mm256_mask_srav_epi64 (src, m, v, s);
else if constexpr( c == category::int64x2 ) return _mm_mask_srav_epi64 (src, m, v, s);

// it does not matter for the uints, so just perform a logical shift
else if constexpr( c == category::uint16x32) return _mm512_mask_srlv_epi16 (src, m, v, s);
else if constexpr( c == category::uint16x16) return _mm256_mask_srlv_epi16 (src, m, v, s);
else if constexpr( c == category::uint16x8 ) return _mm_mask_srlv_epi16 (src, m, v, s);

else if constexpr( c == category::uint32x16) return _mm512_mask_srlv_epi32 (src, m, v, s);
else if constexpr( c == category::uint32x8 ) return _mm256_mask_srlv_epi32 (src, m, v, s);
else if constexpr( c == category::uint32x4 ) return _mm_mask_srlv_epi32 (src, m, v, s);

else if constexpr( c == category::uint64x8 ) return _mm512_mask_srlv_epi64 (src, m, v, s);
else if constexpr( c == category::uint64x4 ) return _mm256_mask_srlv_epi64 (src, m, v, s);
else if constexpr( c == category::uint64x2 ) return _mm_mask_srlv_epi64 (src, m, v, s);
}

// shr[mask](wide_val, imm_shift)
template<conditional_expr C, integral_scalar_value T, typename N, callable_options O>
EVE_FORCEINLINE wide<T,N> shr_(EVE_REQUIRES(avx512_), C const& cx, O const&, wide<T, N> v, unsigned int s) noexcept
requires((sizeof(T) >= 2) && x86_abi<abi_t<T, N>>)
{
constexpr auto c = categorize<wide<T, N>>();
auto src = alternative(cx, v, as<wide<T, N>> {});
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value;

// perform an arithmetic shift right for the ints to preserve the sign bit
if constexpr( c == category::int16x32 ) return _mm512_mask_srai_epi16 (src, m, v, s);
else if constexpr( c == category::int16x16 ) return _mm256_mask_srai_epi16 (src, m, v, s);
else if constexpr( c == category::int16x8 ) return _mm_mask_srai_epi16 (src, m, v, s);

else if constexpr( c == category::int32x16 ) return _mm512_mask_srai_epi32 (src, m, v, s);
else if constexpr( c == category::int32x8 ) return _mm256_mask_srai_epi32 (src, m, v, s);
else if constexpr( c == category::int32x4 ) return _mm_mask_srai_epi32 (src, m, v, s);

else if constexpr( c == category::int64x8 ) return _mm512_mask_srai_epi64 (src, m, v, s);
else if constexpr( c == category::int64x4 ) return _mm256_mask_srai_epi64 (src, m, v, s);
else if constexpr( c == category::int64x2 ) return _mm_mask_srai_epi64 (src, m, v, s);

// perform a logical shift right for the uints, as the distinction does not matter for them
else if constexpr( c == category::uint16x32) return _mm512_mask_srli_epi16 (src, m, v, s);
else if constexpr( c == category::uint16x16) return _mm256_mask_srli_epi16 (src, m, v, s);
else if constexpr( c == category::uint16x8 ) return _mm_mask_srli_epi16 (src, m, v, s);

else if constexpr( c == category::uint32x16) return _mm512_mask_srli_epi32 (src, m, v, s);
else if constexpr( c == category::uint32x8 ) return _mm256_mask_srli_epi32 (src, m, v, s);
else if constexpr( c == category::uint32x4 ) return _mm_mask_srli_epi32 (src, m, v, s);

else if constexpr( c == category::uint64x8 ) return _mm512_mask_srli_epi64 (src, m, v, s);
else if constexpr( c == category::uint64x4 ) return _mm256_mask_srli_epi64 (src, m, v, s);
else if constexpr( c == category::uint64x2 ) return _mm_mask_srli_epi64 (src, m, v, s);
}
}
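
The signed/unsigned split above mirrors the scalar distinction between arithmetic and logical right shifts; C++20 fixes signed right shift to the arithmetic behaviour shown here (it is implementation-defined in earlier standards):

#include <cstdint>

// arithmetic shift (srav/srai, used for signed lanes) replicates the sign bit
static_assert((std::int32_t{-8} >> 1) == -4);
// logical shift (srlv/srli, used for unsigned lanes) shifts in zeros
static_assert((std::uint32_t{0xFFFFFFF8u} >> 1) == 0x7FFFFFFCu);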
4 changes: 4 additions & 0 deletions include/eve/module/core/regular/shl.hpp
@@ -121,3 +121,7 @@ namespace eve
}
}
}

#if defined(EVE_INCLUDE_X86_HEADER)
# include <eve/module/core/regular/impl/simd/x86/shl.hpp>
#endif
4 changes: 4 additions & 0 deletions include/eve/module/core/regular/shr.hpp
@@ -122,3 +122,7 @@ namespace eve
}
}
}

#if defined(EVE_INCLUDE_X86_HEADER)
# include <eve/module/core/regular/impl/simd/x86/shr.hpp>
#endif
