From c2dacd782c29177ffc4b8ef86a900d6bc6389e67 Mon Sep 17 00:00:00 2001 From: jtlap Date: Fri, 25 Oct 2024 15:06:16 +0200 Subject: [PATCH] Implements compare_absolute : apply predicate on absolute values - SVE optimisation --- .../module/core/regular/compare_absolute.hpp | 118 ++++++++++++++++++ include/eve/module/core/regular/core.hpp | 1 + .../impl/simd/arm/neon/compare_absolute.hpp | 74 +++++++++++ .../impl/simd/arm/sve/compare_absolute.hpp | 56 +++++++++ test/doc/core/compare_absolute.cpp | 30 +++++ test/unit/module/core/compare_absolute.cpp | 55 ++++++++ 6 files changed, 334 insertions(+) create mode 100644 include/eve/module/core/regular/compare_absolute.hpp create mode 100644 include/eve/module/core/regular/impl/simd/arm/neon/compare_absolute.hpp create mode 100644 include/eve/module/core/regular/impl/simd/arm/sve/compare_absolute.hpp create mode 100644 test/doc/core/compare_absolute.cpp create mode 100644 test/unit/module/core/compare_absolute.cpp diff --git a/include/eve/module/core/regular/compare_absolute.hpp b/include/eve/module/core/regular/compare_absolute.hpp new file mode 100644 index 0000000000..d9a4be5ad5 --- /dev/null +++ b/include/eve/module/core/regular/compare_absolute.hpp @@ -0,0 +1,118 @@ +//================================================================================================== +/* + EVE - Expressive Vector Engine + Copyright : EVE Project Contributors + SPDX-License-Identifier: BSL-1.0 +*/ +//================================================================================================== +#pragma once + +#include +#include +#include +#include + +namespace eve +{ + template + struct compare_absolute_t : strict_elementwise_callable + { + template< value T, typename F> + constexpr EVE_FORCEINLINE logical operator()(T a, F f) const + { + return EVE_DISPATCH_CALL(a, f); + } + + template + requires(eve::same_lanes_or_scalar) + constexpr EVE_FORCEINLINE common_logical_t operator()(T a, U b, F f) const + { + return EVE_DISPATCH_CALL(a, b, f); + } + + EVE_CALLABLE_OBJECT(compare_absolute_t, compare_absolute_); + }; + +//================================================================================================ +//! @addtogroup core_predicates +//! @{ +//! @var compare_absolute +//! @brief `elementwise callable` returning a logical true if and only if the absolute value of the first parameters +//! satisfy the predicate parameters. +//! +//! @groupheader{Header file} +//! +//! @code +//! #include +//! @endcode +//! +//! @groupheader{Callable Signatures} +//! +//! @code +//! namespace eve +//! { +//! // Regular overload +//! constexpr auto compare_absolute(value auto x, auto f) noexcept; // 1 +//! constexpr auto compare_absolute(value auto x, value auto y, auto f) noexcept; // 1 +//! +//! // Lanes masking +//! constexpr auto compare_absolute[conditional_expr auto c](value auto x, value auto y, auto f) noexcept; // 2 +//! constexpr auto compare_absolute[logical_value auto m](value auto x, value auto y, auto f) noexcept; // 2 +//! constexpr auto compare_absolute[conditional_expr auto c](value auto x, auto f) noexcept; // 2 +//! constexpr auto compare_absolute[logical_value auto m](value auto x, auto f) noexcept; // 2 +//! +//! // Semantic options +//! constexpr auto compare_absolute[saturated](value auto x, auto f) noexcept; // 3 +//! constexpr auto compare_absolute[saturated](value auto x, value auto y, auto f) noexcept; // 3 +//! } +//! @endcode +//! +//! **Parameters** +//! +//! * `x`, `y`: [arguments](@ref eve::value). +//! * `f`: one or two parameter predicate according to `y` presence. +//! * `c`: [Conditional expression](@ref conditional_expr) masking the operation. +//! * `m`: [Logical value](@ref logical) masking the operation. +//! +//! **Return value** +//! +//! 1. The call `eve::compare_absolute(x,y,f)` is semantically equivalent to `f(abs(x), abs(y))` and +//! `eve::compare_absolute(x,f)` is semantically equivalent to `f(abs(x))` and +//! 2. [The operation is performed conditionnaly](@ref conditional). +//! 3. the option is transfered to 'abs' in the call. +//! +//! @groupheader{Example} +//! @godbolt{doc/core/compare_absolute.cpp} +//================================================================================================ + inline constexpr auto compare_absolute = functor; +//================================================================================================ +//! @} +//================================================================================================ +} + +#include + +namespace eve::detail +{ + template + EVE_FORCEINLINE constexpr logical + compare_absolute_(EVE_REQUIRES(cpu_), O const& o, T a, F f) noexcept + { + return f(eve::abs[o](a)); + } + + template + EVE_FORCEINLINE constexpr common_logical_t + compare_absolute_(EVE_REQUIRES(cpu_), O const& o, T a, U b, F f) noexcept + { + return f(eve::abs[o](a), abs[o](b)); + } +} + +#if defined(EVE_INCLUDE_ARM_NEON_HEADER) +# include +#endif + +#if defined(EVE_INCLUDE_ARM_SVE_HEADER) +# include +#endif diff --git a/include/eve/module/core/regular/core.hpp b/include/eve/module/core/regular/core.hpp index 119217d944..56078df620 100644 --- a/include/eve/module/core/regular/core.hpp +++ b/include/eve/module/core/regular/core.hpp @@ -47,6 +47,7 @@ #include #include #include +#include #include #include #include diff --git a/include/eve/module/core/regular/impl/simd/arm/neon/compare_absolute.hpp b/include/eve/module/core/regular/impl/simd/arm/neon/compare_absolute.hpp new file mode 100644 index 0000000000..13e9030e70 --- /dev/null +++ b/include/eve/module/core/regular/impl/simd/arm/neon/compare_absolute.hpp @@ -0,0 +1,74 @@ +//================================================================================================== +/* + EVE - Expressive Vector Engine + Copyright : EVE Project Contributors + SPDX-License-Identifier: BSL-1.0 +*/ +//================================================================================================== +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace eve::detail +{ + template + EVE_FORCEINLINE logical> compare_absolute_(EVE_REQUIRES(neon128_), O const& opts, + wide v, wide w, F f) noexcept + requires (arm_abi>) + { + constexpr auto c = categorize>(); + if constexpr(F{} == is_less_equal) + { + if constexpr( c == category::float32x2 ) return vcale_f32 (v, w); + else if constexpr( c == category::float32x4 ) return vcaleq_f32(v, w); + else if constexpr( current_api >= asimd ) + { + if constexpr( c == category::float64x1 ) return vcale_f64 (v, w); + else if constexpr( c == category::float64x2 ) return vcaleq_f64 (v, w); + } + else return compare_absolute.behavior(cpu_{}, v, w, f); + } + else if constexpr(F{} == is_greater_equal) + { + if constexpr( c == category::float32x2 ) return vcage_f32 (v, w); + else if constexpr( c == category::float32x4 ) return vcageq_f32(v, w); + else if constexpr( current_api >= asimd ) + { + if constexpr( c == category::float64x1 ) return vcage_f64 (v, w); + else if constexpr( c == category::float64x2 ) return vcageq_f64 (v, w); + } + else return compare_absolute.behavior(cpu_{}, v, w, f); + } + else if constexpr(F{} == is_less) + { + if constexpr( c == category::float32x2 ) return vcalt_f32 (v, w); + else if constexpr( c == category::float32x4 ) return vcaltq_f32(v, w); + else if constexpr( current_api >= asimd ) + { + if constexpr( c == category::float64x1 ) return vcalt_f64 (v, w); + else if constexpr( c == category::float64x2 ) return vcaltq_f64 (v, w); + } + else return compare_absolute.behavior(cpu_{}, v, w, f); + } + else if constexpr(F{} == is_greater) + { + if constexpr( c == category::float32x2 ) return vcagt_f32 (v, w); + else if constexpr( c == category::float32x4 ) return vcagtq_f32(v, w); + else if constexpr( current_api >= asimd ) + { + if constexpr( c == category::float64x1 ) return vcagt_f64 (v, w); + else if constexpr( c == category::float64x2 ) return vcagtq_f64 (v, w); + } + else return compare_absolute.behavior(cpu_{}, v, w, f); + } + else return compare_absolute.behavior(cpu_{}, v, w, f); + } +} diff --git a/include/eve/module/core/regular/impl/simd/arm/sve/compare_absolute.hpp b/include/eve/module/core/regular/impl/simd/arm/sve/compare_absolute.hpp new file mode 100644 index 0000000000..8ab9cb3689 --- /dev/null +++ b/include/eve/module/core/regular/impl/simd/arm/sve/compare_absolute.hpp @@ -0,0 +1,56 @@ +//================================================================================================== +/* + EVE - Expressive Vector Engine + Copyright : EVE Project Contributors + SPDX-License-Identifier: BSL-1.0 +*/ +//================================================================================================== +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace eve::detail +{ + + template + EVE_FORCEINLINE logical> + compare_absolute_(EVE_REQUIRES(sve_), O const& opts, + wide v, wide w, F f) noexcept + requires (sve_abi>) + { + auto m = sve_true(); + if constexpr(F{} == is_less_equal) return svacle(m, v, w); + else if constexpr(F{} == is_less) return svaclt(m, v, w); + else if constexpr(F{} == is_greater_equal) return svacge(m, v, w); + else if constexpr(F{} == is_less_equal) return svacgt(m, v, w); + else return compare_absolute.behavior(cpu_{}, opts, v, w, f); + } + + template + EVE_FORCEINLINE logical> compare_absolute_(EVE_REQUIRES(sve_), C const& mask, O const& opts, + wide v, wide w, F f) noexcept + requires (sve_abi>) + { + auto const alt = alternative(mask, v, as(to_logical(v))); + if constexpr( C::is_complete ) return alt; + else if constexpr (!C::has_alternative) + { + auto m = expand_mask(mask, as(v)); + if constexpr(F{} == is_less_equal) return svacle(m, v, w); + else if constexpr(F{} == is_less) return svaclt(m, v, w); + else if constexpr(F{} == is_greater_equal) return svacge(m, v, w); + else if constexpr(F{} == is_less_equal) return svacgt(m, v, w); + else return compare_absolute.behavior(cpu_{}, opts, v, w, f); + } + else return compare_absolute.behavior(cpu_{}, opts, v, w, f); + } +} diff --git a/test/doc/core/compare_absolute.cpp b/test/doc/core/compare_absolute.cpp new file mode 100644 index 0000000000..b46bc83464 --- /dev/null +++ b/test/doc/core/compare_absolute.cpp @@ -0,0 +1,30 @@ +// revision 0 +#include +#include + +int main() +{ + eve::wide wf0{0.0, 1.0, 2.0, 3.0, -1.0, -2.0, -3.0, -4.0}; + eve::wide wf1{0.0, -4.0, 1.0, -1.0, 2.0, -2.0, 3.0, -3.0}; + eve::wide wi0{0, 1, 2, 3, -1, -2, -3, -4}; + eve::wide wi1{0, -4, 1, -1, 2, -2, 3, -3}; + eve::wide wu0{0u, 1u, 2u, 3u, 4u, 5u, 6u, 7u}; + eve::wide wu1{7u, 6u, 5u, 4u, 3u, 2u, 1u, 0u}; + + std::cout << "<- wf0 = " << wf0 << "\n"; + std::cout << "<- wf1 = " << wf1 << "\n"; + std::cout << "<- wi0 = " << wi0 << "\n"; + std::cout << "<- wi1 = " << wi1 << "\n"; + std::cout << "<- wu0 = " << wu0 << "\n"; + std::cout << "<- wu1 = " << wu1 << "\n"; + using eve::is_greater; + std::cout << "-> compare_absolute(wf0, wf1, is_greater) = " << eve::compare_absolute(wf0, wf1, is_greater) << "\n"; + std::cout << "-> compare_absolute[ignore_last(2)](wf0, wf1, is_greater) = " << eve::compare_absolute[eve::ignore_last(2)](wf0, wf1, is_greater) << "\n"; + std::cout << "-> compare_absolute[wf0 != 0](wf0, wf1, is_greater) = " << eve::compare_absolute[wf0 != 0](wf0, wf1, is_greater) << "\n"; + std::cout << "-> compare_absolute(wu0, wu1, is_greater) = " << eve::compare_absolute(wu0, wu1, is_greater) << "\n"; + std::cout << "-> compare_absolute[ignore_last(2)](wu0, wu1, is_greater) = " << eve::compare_absolute[eve::ignore_last(2)](wu0, wu1, is_greater) << "\n"; + std::cout << "-> compare_absolute[wu0 != 0](wu0, wu1, is_greater) = " << eve::compare_absolute[wu0 != 0](wu0, wu1, is_greater) << "\n"; + std::cout << "-> compare_absolute(wi0, wi1, is_greater) = " << eve::compare_absolute(wi0, wi1, is_greater) << "\n"; + std::cout << "-> compare_absolute[ignore_last(2)](wi0, wi1, is_greater) = " << eve::compare_absolute[eve::ignore_last(2)](wi0, wi1, is_greater) << "\n"; + std::cout << "-> compare_absolute[wi0 != 0](wi0, wi1, is_greater) = " << eve::compare_absolute[wi0 != 0](wi0, wi1, is_greater) << "\n"; +} diff --git a/test/unit/module/core/compare_absolute.cpp b/test/unit/module/core/compare_absolute.cpp new file mode 100644 index 0000000000..250d5877e8 --- /dev/null +++ b/test/unit/module/core/compare_absolute.cpp @@ -0,0 +1,55 @@ +//================================================================================================== +/** + EVE - Expressive Vector Engine + Copyright : EVE Project Contributors + SPDX-License-Identifier: BSL-1.0 +**/ +//================================================================================================== +#include "test.hpp" + +#include + +//================================================================================================== +//== Types tests +//================================================================================================== +TTS_CASE_TPL("Check return types of eve::compare_absolute(simd)", eve::test::simd::all_types) +(tts::type) +{ + using eve::logical; + using v_t = eve::element_type_t; + auto f = eve::is_greater; + TTS_EXPR_IS(eve::compare_absolute(T(), T(), f), logical); + TTS_EXPR_IS(eve::compare_absolute(v_t(), v_t(), f), logical); + TTS_EXPR_IS(eve::compare_absolute(T(), v_t(), f), logical); + TTS_EXPR_IS(eve::compare_absolute(v_t(), T(), f), logical); + TTS_EXPR_IS(eve::compare_absolute[eve::saturated](T(), T(), f), logical); + TTS_EXPR_IS(eve::compare_absolute[eve::saturated](v_t(), v_t(), f), logical); + TTS_EXPR_IS(eve::compare_absolute[eve::saturated](T(), v_t(), f), logical); + TTS_EXPR_IS(eve::compare_absolute[eve::saturated](v_t(), T(), f), logical); + +}; + +//================================================================================================== +//== Tests for eve::compare_absolute +//================================================================================================== +TTS_CASE_WITH("Check behavior of eve::compare_absolute(simd)", + eve::test::simd::all_types, + tts::generate(tts::ramp(0), tts::reverse_ramp(4, 2), tts::logicals(0, 3))) +(T const& a0, T const& a1, M const& t) +{ + using v_t = eve::element_type_t; + auto ff = eve::is_greater; + TTS_EQUAL(eve::compare_absolute(a0, a1, ff), + tts::map([](auto e, auto f) -> eve::logical { return eve::abs(e) > eve::abs(f); }, a0, a1)); + TTS_EQUAL(eve::compare_absolute(a0, a0, ff), + tts::map([](auto e, auto f) -> eve::logical { return eve::abs(e) > eve::abs(f); }, a0, a0)); + TTS_EQUAL(eve::compare_absolute(a0, v_t(1), ff), + tts::map([](auto e) -> eve::logical { return eve::abs(e) > v_t(1); }, a0)); + TTS_EQUAL(eve::compare_absolute[t](a0, a1, ff), + eve::if_else(t, eve::compare_absolute(a0, a1, ff), eve::false_(eve::as(a0)))); + auto gg = eve::is_gtz; + TTS_EQUAL(eve::compare_absolute(a0, gg), + tts::map([](auto e) -> eve::logical { return eve::abs(e) > v_t(0); }, a0)); + TTS_EQUAL(eve::compare_absolute[t](a0, gg), + eve::if_else(t, eve::compare_absolute(a0, gg), eve::false_(eve::as(a0)))); +};