Fix missing X86 FMA related optimisations
jtlap authored Apr 25, 2024
1 parent 99f533d commit 8af225a
Showing 9 changed files with 230 additions and 210 deletions.
2 changes: 1 addition & 1 deletion include/eve/module/core/regular/impl/fam.hpp
@@ -105,6 +105,6 @@ namespace eve::detail
}
// REGULAR ---------------------
else
return a + b * c;
return fma(b, c, a);
}
}
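The regular fallback in fam.hpp now routes through fma instead of a separate multiply and add, so backends with hardware FMA can fuse the operation into a single rounding. A minimal scalar sketch of the intended equivalence, using std::fma purely for illustration (the helper name is made up, not part of the commit):

#include <cmath>

// fam(a, b, c) computes a + b*c; writing it as fma(b, c, a) lets the backend
// emit a fused multiply-add, which multiplies and adds with a single rounding.
double fam_scalar_sketch(double a, double b, double c)
{
  return std::fma(b, c, a);   // b*c + a, fused
}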
50 changes: 22 additions & 28 deletions include/eve/module/core/regular/impl/simd/x86/fam.hpp
@@ -8,49 +8,43 @@
#pragma once

#include <eve/concept/value.hpp>
#include <eve/detail/abi.hpp>
#include <eve/detail/category.hpp>
#include <eve/detail/implementation.hpp>

#include <type_traits>
#include <eve/forward.hpp>

namespace eve::detail
{
// -----------------------------------------------------------------------------------------------
// Masked case
template<conditional_expr C, arithmetic_scalar_value T, typename N, callable_options O>
EVE_FORCEINLINE wide<T, N> fam_(EVE_SUPPORTS(sse2_),
C const &cx,
O const &opts,
EVE_FORCEINLINE wide<T, N> fam_(EVE_REQUIRES(avx512_),
C const &mask,
O const &,
wide<T, N> const &v,
wide<T, N> const &w,
wide<T, N> const &x) noexcept
requires x86_abi<abi_t<T, N>>
{
constexpr auto c = categorize<wide<T, N>>();
// NOTE: As these masked versions are at the AVX512 level, they always use a variant of the
// hardware VMADD instruction, ensuring the pedantic behavior by default, hence we don't care
// about PEDANTIC. As usual, we don't care about PROMOTE as we only accept similar types.

if constexpr( C::is_complete || abi_t<T, N>::is_wide_logical )
{
return fam.behavior(cpu_{}, opts, v, w, x);
}
else
if constexpr( C::is_complete )
return alternative(mask, v, as(v));
else if constexpr( !C::has_alternative )
{
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value;
constexpr auto c = categorize<wide<T, N>>();
[[maybe_unused]] auto const m = expand_mask(mask, as(v)).storage().value;

if constexpr( !C::has_alternative )
{
if constexpr( c == category::float32x16 ) return _mm512_mask3_fmadd_ps(w, x, v, m);
else if constexpr( c == category::float64x8 ) return _mm512_mask3_fmadd_pd(w, x, v, m);
else if constexpr( c == category::float32x8 ) return _mm256_mask3_fmadd_ps(w, x, v, m);
else if constexpr( c == category::float64x4 ) return _mm256_mask3_fmadd_pd(w, x, v, m);
else if constexpr( c == category::float32x8 ) return _mm128_mask3_fmadd_ps(w, x, v, m);
else if constexpr( c == category::float64x4 ) return _mm128_mask3_fmadd_pd(w, x, v, m);
else return fam.behavior(cpu_{}, opts, v, w, x);
}
else
{
auto src = alternative(cx, v, as<wide<T, N>> {});
return fam.behavior(cpu_{}, opts, v, w, x);
}
if constexpr( c == category::float32x16 ) return _mm512_mask3_fmadd_ps(w, x, v, m);
else if constexpr( c == category::float64x8 ) return _mm512_mask3_fmadd_pd(w, x, v, m);
else if constexpr( c == category::float32x8 ) return _mm256_mask3_fmadd_ps(w, x, v, m);
else if constexpr( c == category::float64x4 ) return _mm256_mask3_fmadd_pd(w, x, v, m);
else if constexpr( c == category::float32x4 ) return _mm_mask3_fmadd_ps(w, x, v, m);
else if constexpr( c == category::float64x2 ) return _mm_mask3_fmadd_pd(w, x, v, m);
// No rounding issue with integers, so we just mask over regular FMA
else return if_else(mask, eve::fam(v, w, x), v);
}
else return if_else(mask, eve::fam(v, w, x), alternative(mask, v, as(v)));
}
}
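The masked fam overload now dispatches at the AVX512 level: the conditional is expanded into a k-mask and passed to the mask3 form of the FMA intrinsic, which writes its third operand back into lanes the mask leaves unselected. A standalone sketch of that mapping for the float32x16 case (the wrapper name is illustrative, not from the commit; requires AVX512F):

#include <immintrin.h>

// fam(v, w, x) == v + w*x. _mm512_mask3_fmadd_ps(w, x, v, m) computes w*x + v
// for lanes where m is set and copies v, its third operand, otherwise, which is
// exactly the masked behaviour with alternative v used above.
__m512 masked_fam_f32x16(__m512 v, __m512 w, __m512 x, __mmask16 m)
{
  return _mm512_mask3_fmadd_ps(w, x, v, m);
}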
65 changes: 32 additions & 33 deletions include/eve/module/core/regular/impl/simd/x86/fanm.hpp
@@ -8,48 +8,47 @@
#pragma once

#include <eve/concept/value.hpp>
#include <eve/detail/abi.hpp>
#include <eve/detail/category.hpp>
#include <eve/detail/implementation.hpp>

#include <type_traits>
#include <eve/forward.hpp>

namespace eve::detail
{
// -----------------------------------------------------------------------------------------------
// Masked case
template<conditional_expr C, arithmetic_scalar_value T, typename N>
EVE_FORCEINLINE wide<T, N>
fanm_(EVE_SUPPORTS(sse2_),
C const &cx,
wide<T, N> const &v,
wide<T, N> const &w,
wide<T, N> const &x) noexcept requires x86_abi<abi_t<T, N>>
{
constexpr auto c = categorize<wide<T, N>>();

if constexpr( C::is_complete || abi_t<T, N>::is_wide_logical )
{
return fanm_(EVE_RETARGET(cpu_), cx, v, w, x);
}
else
template<conditional_expr C, arithmetic_scalar_value T, typename N>
EVE_FORCEINLINE wide<T, N> fanm_(EVE_SUPPORTS(avx512_),
C const &mask,
wide<T, N> const &v,
wide<T, N> const &w,
wide<T, N> const &x) noexcept
requires x86_abi<abi_t<T, N>>
{
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value;
// NOTE: As these masked versions are at the AVX512 level, they always use a variant of the
// hardware VMADD instruction, ensuring the pedantic behavior by default, hence we don't care
// about PEDANTIC. As usual, we don't care about PROMOTE as we only accept similar types.

if constexpr( !C::has_alternative )
if constexpr( C::is_complete )
return alternative(mask, v, as(v));
else if constexpr( !C::has_alternative )
{
if constexpr( c == category::float32x16 ) return _mm512_mask3_fnmadd_ps(w, x, v, m);
else if constexpr( c == category::float64x8 ) return _mm512_mask3_fnmadd_pd(w, x, v, m);
else if constexpr( c == category::float32x8 ) return _mm256_mask3_fnmadd_ps(w, x, v, m);
else if constexpr( c == category::float64x4 ) return _mm256_mask3_fnmadd_pd(w, x, v, m);
else if constexpr( c == category::float32x8 ) return _mm128_mask3_fnmadd_ps(w, x, v, m);
else if constexpr( c == category::float64x4 ) return _mm128_mask3_fnmadd_pd(w, x, v, m);
else return fanm_(EVE_RETARGET(cpu_), cx, v, w, x);
}
else
{
auto src = alternative(cx, v, as<wide<T, N>> {});
return fanm_(EVE_RETARGET(cpu_), cx, v, w, x);
constexpr auto c = categorize<wide<T, N>>();
[[maybe_unused]] auto const m = expand_mask(mask, as(v)).storage().value;

if constexpr( !C::has_alternative )
{
if constexpr( c == category::float32x16 ) return _mm512_mask3_fnmadd_ps(w, x, v, m);
else if constexpr( c == category::float64x8 ) return _mm512_mask3_fnmadd_pd(w, x, v, m);
else if constexpr( c == category::float32x8 ) return _mm256_mask3_fnmadd_ps(w, x, v, m);
else if constexpr( c == category::float64x4 ) return _mm256_mask3_fnmadd_pd(w, x, v, m);
else if constexpr( c == category::float32x4 ) return _mm_mask3_fnmadd_ps(w, x, v, m);
else if constexpr( c == category::float64x2 ) return _mm_mask3_fnmadd_pd(w, x, v, m);
// No rounding issue with integers, so we just mask over regular FMA
else
return if_else(mask, eve::fanm(v, w, x), v);
}
else
return if_else(mask, eve::fanm(v, w, x), alternative(mask, v, as(v)));
}
}
}
}
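fanm(v, w, x) computes v - w*x, which is the negated-multiply-add family: fnmadd evaluates -(a*b) + c. A short sketch of the correspondence the masked branch above relies on (illustrative wrapper name, requires AVX512F):

#include <immintrin.h>

// _mm512_mask3_fnmadd_ps(w, x, v, m) computes -(w*x) + v, i.e. v - w*x, for
// lanes where m is set, and copies v otherwise.
__m512 masked_fanm_f32x16(__m512 v, __m512 w, __m512 x, __mmask16 m)
{
  return _mm512_mask3_fnmadd_ps(w, x, v, m);
}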
16 changes: 11 additions & 5 deletions include/eve/module/core/regular/impl/simd/x86/fma.hpp
@@ -15,8 +15,11 @@
namespace eve::detail
{
template<typename T, typename N, callable_options O>
EVE_FORCEINLINE wide<T, N>
fma_(EVE_REQUIRES(sse2_), O const& opts, wide<T, N> const& a, wide<T, N> const& b, wide<T, N> const& c) noexcept
EVE_FORCEINLINE wide<T, N> fma_(EVE_REQUIRES(sse2_),
O const& opts,
wide<T, N> const& a,
wide<T, N> const& b,
wide<T, N> const& c) noexcept
requires x86_abi<abi_t<T, N>>
{
// Integral don't do anything special ----
@@ -47,9 +50,12 @@ namespace eve::detail
}

template<typename T, typename N, conditional_expr C, callable_options O>
EVE_FORCEINLINE wide<T, N>
fma_( EVE_REQUIRES(avx512_), C const& mask, O const&
, wide<T, N> const& a, wide<T, N> const& b, wide<T, N> const& c
EVE_FORCEINLINE wide<T, N> fma_( EVE_REQUIRES(avx512_),
C const& mask,
O const&,
wide<T, N> const& a,
wide<T, N> const& b,
wide<T, N> const& c
)
noexcept requires x86_abi<abi_t<T, N>>
{
99 changes: 53 additions & 46 deletions include/eve/module/core/regular/impl/simd/x86/fms.hpp
@@ -7,75 +7,82 @@
//==================================================================================================
#pragma once


#include <eve/concept/value.hpp>
#include <eve/detail/abi.hpp>
#include <eve/detail/category.hpp>
#include <eve/detail/implementation.hpp>
#include <eve/forward.hpp>
#include <eve/module/core/regular/fma.hpp>

#include <type_traits>

namespace eve::detail
{
template<arithmetic_scalar_value T, typename N>
EVE_FORCEINLINE wide<T, N>
fms_(EVE_SUPPORTS(avx2_),
wide<T, N> const &a,
wide<T, N> const &b,
wide<T, N> const &c) noexcept requires x86_abi<abi_t<T, N>>
{
if constexpr( std::is_integral_v<T> ) { return fms_(EVE_RETARGET(cpu_), a, b, c); }
else
template<arithmetic_scalar_value T, typename N, callable_options O>
EVE_FORCEINLINE wide<T, N> fms_(EVE_REQUIRES(avx2_),
O const &opts,
wide<T, N> const &a,
wide<T, N> const &b,
wide<T, N> const &c) noexcept
requires x86_abi<abi_t<T, N>>
{
constexpr auto cat = categorize<wide<T, N>>();

if constexpr( cat == category::float64x8 ) return _mm512_fmsub_pd(a, b, c);
else if constexpr( cat == category::float32x16 ) return _mm512_fmsub_ps(a, b, c);
else if constexpr( supports_fma3 )
// Integral don't do anything special ----
if constexpr( std::integral<T> ) return fms.behavior(cpu_{}, opts, a, b, c);
// PEDANTIC ---
else if constexpr(O::contains(pedantic2) )
{
if constexpr( cat == category::float64x4 ) return _mm256_fmsub_pd(a, b, c);
else if constexpr( cat == category::float64x2 ) return _mm_fmsub_pd(a, b, c);
else if constexpr( cat == category::float32x8 ) return _mm256_fmsub_ps(a, b, c);
else if constexpr( cat == category::float32x4 ) return _mm_fmsub_ps(a, b, c);
if constexpr( supports_fma3 ) return fms(a, b, c);
else return fms.behavior(cpu_{}, opts, a, b, c);
}
// REGULAR ---
// we don't care about PROMOTE as we only accept similar types.
else
{
constexpr auto cat = categorize<wide<T, N>>();

if constexpr( cat == category::float64x8 ) return _mm512_fmsub_pd(a, b, c);
else if constexpr( cat == category::float32x16 ) return _mm512_fmsub_ps(a, b, c);
else if constexpr( supports_fma3 )
{
if constexpr( cat == category::float64x4 ) return _mm256_fmsub_pd(a, b, c);
else if constexpr( cat == category::float64x2 ) return _mm_fmsub_pd(a, b, c);
else if constexpr( cat == category::float32x8 ) return _mm256_fmsub_ps(a, b, c);
else if constexpr( cat == category::float32x4 ) return _mm_fmsub_ps(a, b, c);
}
else return fma(a, b, -c);
}
else return fma(a, b, -c);
}
}

// -----------------------------------------------------------------------------------------------
// Masked case
template<conditional_expr C, arithmetic_scalar_value T, typename N>
EVE_FORCEINLINE wide<T, N>
fms_(EVE_SUPPORTS(avx512_),
C const &cx,
wide<T, N> const &v,
wide<T, N> const &w,
wide<T, N> const &x) noexcept requires x86_abi<abi_t<T, N>>
{
constexpr auto c = categorize<wide<T, N>>();

if constexpr( C::is_complete || abi_t<T, N>::is_wide_logical )
{
return fms_(EVE_RETARGET(cpu_), cx, v, w, x);
}
else
template<conditional_expr C, arithmetic_scalar_value T, typename N, callable_options O>
EVE_FORCEINLINE wide<T, N> fms_(EVE_SUPPORTS(avx512_),
C const &mask,
O const &,
wide<T, N> const &v,
wide<T, N> const &w,
wide<T, N> const &x) noexcept
requires x86_abi<abi_t<T, N>>
{
auto m = expand_mask(cx, as<wide<T, N>> {}).storage().value;
// NOTE: As these masked versions are at the AVX512 level, they always use a variant of the
// hardware VMADD instruction, ensuring the pedantic behavior by default, hence we don't care
// about PEDANTIC. As usual, we don't care about PROMOTE as we only accept similar types.

if constexpr( !C::has_alternative )
if constexpr( C::is_complete ) return alternative(mask, v, as(v));
else if constexpr( !C::has_alternative )
{
if constexpr( c == category::float32x16 ) return _mm512_mask_fmsub_ps(v, m, w, x);
constexpr auto c = categorize<wide<T, N>>();
[[maybe_unused]] auto const m = expand_mask(mask, as(v)).storage().value;

if constexpr( c == category::float32x16 ) return _mm512_mask_fmsub_ps(v, m, w, x);
else if constexpr( c == category::float64x8 ) return _mm512_mask_fmsub_pd(v, m, w, x);
else if constexpr( c == category::float32x8 ) return _mm256_mask_fmsub_ps(v, m, w, x);
else if constexpr( c == category::float64x4 ) return _mm256_mask_fmsub_pd(v, m, w, x);
else if constexpr( c == category::float32x4 ) return _mm_mask_fmsub_ps(v, m, w, x);
else if constexpr( c == category::float64x2 ) return _mm_mask_fmsub_pd(v, m, w, x);
else return fms_(EVE_RETARGET(cpu_), cx, v, w, x);
}
else
{
auto src = alternative(cx, v, as<wide<T, N>> {});
return fms_(EVE_RETARGET(cpu_), cx, v, w, x);
// No rounding issue with integers, so we just mask over regular FMA
else return if_else(mask, eve::fms(v, w, x), v);
}
else return if_else(mask, eve::fms(v, w, x), alternative(mask, v, as(v)));
}
}
}
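The unmasked fms overload keeps three paths: integers go through the generic behaviour, the pedantic option uses a hardware fused operation only when FMA3 is available, and the regular path either emits an FMSUB intrinsic or rewrites the call as fma(a, b, -c). A scalar sketch of that last rewrite, with std::fma standing in for the vector fma (hypothetical helper name):

#include <cmath>

// fms(a, b, c) computes a*b - c. Without an FMSUB intrinsic the code above
// falls back to fma(a, b, -c), which yields the same value with a single rounding.
double fms_scalar_sketch(double a, double b, double c)
{
  return std::fma(a, b, -c);
}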
16 changes: 11 additions & 5 deletions include/eve/module/core/regular/impl/simd/x86/fnma.hpp
@@ -19,8 +19,11 @@
namespace eve::detail
{
template<typename T, typename N, callable_options O>
EVE_FORCEINLINE wide<T, N>
fnma_(EVE_REQUIRES(sse2_), O const& opts, wide<T, N> const& a, wide<T, N> const& b, wide<T, N> const& c) noexcept
EVE_FORCEINLINE wide<T, N> fnma_(EVE_REQUIRES(sse2_),
O const& opts,
wide<T, N> const& a,
wide<T, N> const& b,
wide<T, N> const& c) noexcept
requires x86_abi<abi_t<T, N>>
{
// Integral don't do anything special ----
@@ -51,9 +54,12 @@ namespace eve::detail
}

template<typename T, typename N, conditional_expr C, callable_options O>
EVE_FORCEINLINE wide<T, N>
fnma_( EVE_REQUIRES(avx512_), C const& mask, O const&
, wide<T, N> const& a, wide<T, N> const& b, wide<T, N> const& c
EVE_FORCEINLINE wide<T, N> fnma_( EVE_REQUIRES(avx512_),
C const& mask,
O const&,
wide<T, N> const& a,
wide<T, N> const& b,
wide<T, N> const& c
)
noexcept requires x86_abi<abi_t<T, N>>
{
