Skip to content

Commit

Permalink
Implements compare_absolute : apply predicate on absolute values - SV…
Browse files Browse the repository at this point in the history
…E optimisation
  • Loading branch information
jtlap authored Oct 25, 2024
1 parent dddd90f commit c2dacd7
Show file tree
Hide file tree
Showing 6 changed files with 334 additions and 0 deletions.
118 changes: 118 additions & 0 deletions include/eve/module/core/regular/compare_absolute.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
//==================================================================================================
/*
EVE - Expressive Vector Engine
Copyright : EVE Project Contributors
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once

#include <eve/arch.hpp>
#include <eve/concept/value.hpp>
#include <eve/detail/function/friends.hpp>
#include <eve/detail/implementation.hpp>

namespace eve
{
template<typename Options>
struct compare_absolute_t : strict_elementwise_callable<compare_absolute_t, Options, saturated_option>
{
template< value T, typename F>
constexpr EVE_FORCEINLINE logical<T> operator()(T a, F f) const
{
return EVE_DISPATCH_CALL(a, f);
}

template<value T, value U, typename F>
requires(eve::same_lanes_or_scalar<T, U>)
constexpr EVE_FORCEINLINE common_logical_t<T, U> operator()(T a, U b, F f) const
{
return EVE_DISPATCH_CALL(a, b, f);
}

EVE_CALLABLE_OBJECT(compare_absolute_t, compare_absolute_);
};

//================================================================================================
//! @addtogroup core_predicates
//! @{
//! @var compare_absolute
//! @brief `elementwise callable` returning a logical true if and only if the absolute value of the first parameters
//! satisfy the predicate parameters.
//!
//! @groupheader{Header file}
//!
//! @code
//! #include <eve/module/core.hpp>
//! @endcode
//!
//! @groupheader{Callable Signatures}
//!
//! @code
//! namespace eve
//! {
//! // Regular overload
//! constexpr auto compare_absolute(value auto x, auto f) noexcept; // 1
//! constexpr auto compare_absolute(value auto x, value auto y, auto f) noexcept; // 1
//!
//! // Lanes masking
//! constexpr auto compare_absolute[conditional_expr auto c](value auto x, value auto y, auto f) noexcept; // 2
//! constexpr auto compare_absolute[logical_value auto m](value auto x, value auto y, auto f) noexcept; // 2
//! constexpr auto compare_absolute[conditional_expr auto c](value auto x, auto f) noexcept; // 2
//! constexpr auto compare_absolute[logical_value auto m](value auto x, auto f) noexcept; // 2
//!
//! // Semantic options
//! constexpr auto compare_absolute[saturated](value auto x, auto f) noexcept; // 3
//! constexpr auto compare_absolute[saturated](value auto x, value auto y, auto f) noexcept; // 3
//! }
//! @endcode
//!
//! **Parameters**
//!
//! * `x`, `y`: [arguments](@ref eve::value).
//! * `f`: one or two parameter predicate according to `y` presence.
//! * `c`: [Conditional expression](@ref conditional_expr) masking the operation.
//! * `m`: [Logical value](@ref logical) masking the operation.
//!
//! **Return value**
//!
//! 1. The call `eve::compare_absolute(x,y,f)` is semantically equivalent to `f(abs(x), abs(y))` and
//! `eve::compare_absolute(x,f)` is semantically equivalent to `f(abs(x))` and
//! 2. [The operation is performed conditionnaly](@ref conditional).
//! 3. the option is transfered to 'abs' in the call.
//!
//! @groupheader{Example}
//! @godbolt{doc/core/compare_absolute.cpp}
//================================================================================================
inline constexpr auto compare_absolute = functor<compare_absolute_t>;
//================================================================================================
//! @}
//================================================================================================
}

#include <eve/module/core/regular/abs.hpp>

namespace eve::detail
{
template<value T, typename F, callable_options O>
EVE_FORCEINLINE constexpr logical<T>
compare_absolute_(EVE_REQUIRES(cpu_), O const& o, T a, F f) noexcept
{
return f(eve::abs[o](a));
}

template<value T, value U, typename F, callable_options O>
EVE_FORCEINLINE constexpr common_logical_t<T,U>
compare_absolute_(EVE_REQUIRES(cpu_), O const& o, T a, U b, F f) noexcept
{
return f(eve::abs[o](a), abs[o](b));
}
}

#if defined(EVE_INCLUDE_ARM_NEON_HEADER)
# include <eve/module/core/regular/impl/simd/arm/neon/compare_absolute.hpp>
#endif

#if defined(EVE_INCLUDE_ARM_SVE_HEADER)
# include <eve/module/core/regular/impl/simd/arm/sve/compare_absolute.hpp>
#endif
1 change: 1 addition & 0 deletions include/eve/module/core/regular/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include <eve/module/core/regular/chi.hpp>
#include <eve/module/core/regular/clamp.hpp>
#include <eve/module/core/regular/combine.hpp>
#include <eve/module/core/regular/compare_absolute.hpp>
#include <eve/module/core/regular/convert.hpp>
#include <eve/module/core/regular/copysign.hpp>
#include <eve/module/core/regular/count_true.hpp>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//==================================================================================================
/*
EVE - Expressive Vector Engine
Copyright : EVE Project Contributors
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once

#include <eve/concept/value.hpp>
#include <eve/detail/abi.hpp>
#include <eve/detail/category.hpp>
#include <eve/forward.hpp>
#include <eve/module/core/regular/simd_cast.hpp>
#include <eve/module/core/regular/is_greater_equal.hpp>
#include <eve/module/core/regular/is_greater.hpp>
#include <eve/module/core/regular/is_less_equal.hpp>
#include <eve/module/core/regular/is_less.hpp>

namespace eve::detail
{
template<callable_options O, floating_scalar_value T, typename N, typename F>
EVE_FORCEINLINE logical<wide<T, N>> compare_absolute_(EVE_REQUIRES(neon128_), O const& opts,
wide<T, N> v, wide<T, N> w, F f) noexcept
requires (arm_abi<abi_t<T, N>>)
{
constexpr auto c = categorize<wide<T, N>>();
if constexpr(F{} == is_less_equal)
{
if constexpr( c == category::float32x2 ) return vcale_f32 (v, w);
else if constexpr( c == category::float32x4 ) return vcaleq_f32(v, w);
else if constexpr( current_api >= asimd )
{
if constexpr( c == category::float64x1 ) return vcale_f64 (v, w);
else if constexpr( c == category::float64x2 ) return vcaleq_f64 (v, w);
}
else return compare_absolute.behavior(cpu_{}, v, w, f);
}
else if constexpr(F{} == is_greater_equal)
{
if constexpr( c == category::float32x2 ) return vcage_f32 (v, w);
else if constexpr( c == category::float32x4 ) return vcageq_f32(v, w);
else if constexpr( current_api >= asimd )
{
if constexpr( c == category::float64x1 ) return vcage_f64 (v, w);
else if constexpr( c == category::float64x2 ) return vcageq_f64 (v, w);
}
else return compare_absolute.behavior(cpu_{}, v, w, f);
}
else if constexpr(F{} == is_less)
{
if constexpr( c == category::float32x2 ) return vcalt_f32 (v, w);
else if constexpr( c == category::float32x4 ) return vcaltq_f32(v, w);
else if constexpr( current_api >= asimd )
{
if constexpr( c == category::float64x1 ) return vcalt_f64 (v, w);
else if constexpr( c == category::float64x2 ) return vcaltq_f64 (v, w);
}
else return compare_absolute.behavior(cpu_{}, v, w, f);
}
else if constexpr(F{} == is_greater)
{
if constexpr( c == category::float32x2 ) return vcagt_f32 (v, w);
else if constexpr( c == category::float32x4 ) return vcagtq_f32(v, w);
else if constexpr( current_api >= asimd )
{
if constexpr( c == category::float64x1 ) return vcagt_f64 (v, w);
else if constexpr( c == category::float64x2 ) return vcagtq_f64 (v, w);
}
else return compare_absolute.behavior(cpu_{}, v, w, f);
}
else return compare_absolute.behavior(cpu_{}, v, w, f);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
//==================================================================================================
/*
EVE - Expressive Vector Engine
Copyright : EVE Project Contributors
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once

#include <eve/concept/value.hpp>
#include <eve/detail/abi.hpp>
#include <eve/detail/category.hpp>
#include <eve/forward.hpp>
#include <eve/module/core/regular/simd_cast.hpp>
#include <eve/module/core/regular/is_greater_equal.hpp>
#include <eve/module/core/regular/is_greater.hpp>
#include <eve/module/core/regular/is_less_equal.hpp>
#include <eve/module/core/regular/is_less.hpp>


namespace eve::detail
{

template<callable_options O, floating_scalar_value T, typename N, typename F>
EVE_FORCEINLINE logical<wide<T, N>>
compare_absolute_(EVE_REQUIRES(sve_), O const& opts,
wide<T, N> v, wide<T, N> w, F f) noexcept
requires (sve_abi<abi_t<T, N>>)
{
auto m = sve_true<T>();
if constexpr(F{} == is_less_equal) return svacle(m, v, w);
else if constexpr(F{} == is_less) return svaclt(m, v, w);
else if constexpr(F{} == is_greater_equal) return svacge(m, v, w);
else if constexpr(F{} == is_less_equal) return svacgt(m, v, w);
else return compare_absolute.behavior(cpu_{}, opts, v, w, f);
}

template<callable_options O, floating_scalar_value T, typename N, conditional_expr C, typename F>
EVE_FORCEINLINE logical<wide<T, N>> compare_absolute_(EVE_REQUIRES(sve_), C const& mask, O const& opts,
wide<T, N> v, wide<T, N> w, F f) noexcept
requires (sve_abi<abi_t<T, N>>)
{
auto const alt = alternative(mask, v, as(to_logical(v)));
if constexpr( C::is_complete ) return alt;
else if constexpr (!C::has_alternative)
{
auto m = expand_mask(mask, as(v));
if constexpr(F{} == is_less_equal) return svacle(m, v, w);
else if constexpr(F{} == is_less) return svaclt(m, v, w);
else if constexpr(F{} == is_greater_equal) return svacge(m, v, w);
else if constexpr(F{} == is_less_equal) return svacgt(m, v, w);
else return compare_absolute.behavior(cpu_{}, opts, v, w, f);
}
else return compare_absolute.behavior(cpu_{}, opts, v, w, f);
}
}
30 changes: 30 additions & 0 deletions test/doc/core/compare_absolute.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// revision 0
#include <eve/module/core.hpp>
#include <iostream>

int main()
{
eve::wide wf0{0.0, 1.0, 2.0, 3.0, -1.0, -2.0, -3.0, -4.0};
eve::wide wf1{0.0, -4.0, 1.0, -1.0, 2.0, -2.0, 3.0, -3.0};
eve::wide wi0{0, 1, 2, 3, -1, -2, -3, -4};
eve::wide wi1{0, -4, 1, -1, 2, -2, 3, -3};
eve::wide wu0{0u, 1u, 2u, 3u, 4u, 5u, 6u, 7u};
eve::wide wu1{7u, 6u, 5u, 4u, 3u, 2u, 1u, 0u};

std::cout << "<- wf0 = " << wf0 << "\n";
std::cout << "<- wf1 = " << wf1 << "\n";
std::cout << "<- wi0 = " << wi0 << "\n";
std::cout << "<- wi1 = " << wi1 << "\n";
std::cout << "<- wu0 = " << wu0 << "\n";
std::cout << "<- wu1 = " << wu1 << "\n";
using eve::is_greater;
std::cout << "-> compare_absolute(wf0, wf1, is_greater) = " << eve::compare_absolute(wf0, wf1, is_greater) << "\n";
std::cout << "-> compare_absolute[ignore_last(2)](wf0, wf1, is_greater) = " << eve::compare_absolute[eve::ignore_last(2)](wf0, wf1, is_greater) << "\n";
std::cout << "-> compare_absolute[wf0 != 0](wf0, wf1, is_greater) = " << eve::compare_absolute[wf0 != 0](wf0, wf1, is_greater) << "\n";
std::cout << "-> compare_absolute(wu0, wu1, is_greater) = " << eve::compare_absolute(wu0, wu1, is_greater) << "\n";
std::cout << "-> compare_absolute[ignore_last(2)](wu0, wu1, is_greater) = " << eve::compare_absolute[eve::ignore_last(2)](wu0, wu1, is_greater) << "\n";
std::cout << "-> compare_absolute[wu0 != 0](wu0, wu1, is_greater) = " << eve::compare_absolute[wu0 != 0](wu0, wu1, is_greater) << "\n";
std::cout << "-> compare_absolute(wi0, wi1, is_greater) = " << eve::compare_absolute(wi0, wi1, is_greater) << "\n";
std::cout << "-> compare_absolute[ignore_last(2)](wi0, wi1, is_greater) = " << eve::compare_absolute[eve::ignore_last(2)](wi0, wi1, is_greater) << "\n";
std::cout << "-> compare_absolute[wi0 != 0](wi0, wi1, is_greater) = " << eve::compare_absolute[wi0 != 0](wi0, wi1, is_greater) << "\n";
}
55 changes: 55 additions & 0 deletions test/unit/module/core/compare_absolute.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//==================================================================================================
/**
EVE - Expressive Vector Engine
Copyright : EVE Project Contributors
SPDX-License-Identifier: BSL-1.0
**/
//==================================================================================================
#include "test.hpp"

#include <eve/module/core.hpp>

//==================================================================================================
//== Types tests
//==================================================================================================
TTS_CASE_TPL("Check return types of eve::compare_absolute(simd)", eve::test::simd::all_types)
<typename T>(tts::type<T>)
{
using eve::logical;
using v_t = eve::element_type_t<T>;
auto f = eve::is_greater;
TTS_EXPR_IS(eve::compare_absolute(T(), T(), f), logical<T>);
TTS_EXPR_IS(eve::compare_absolute(v_t(), v_t(), f), logical<v_t>);
TTS_EXPR_IS(eve::compare_absolute(T(), v_t(), f), logical<T>);
TTS_EXPR_IS(eve::compare_absolute(v_t(), T(), f), logical<T>);
TTS_EXPR_IS(eve::compare_absolute[eve::saturated](T(), T(), f), logical<T>);
TTS_EXPR_IS(eve::compare_absolute[eve::saturated](v_t(), v_t(), f), logical<v_t>);
TTS_EXPR_IS(eve::compare_absolute[eve::saturated](T(), v_t(), f), logical<T>);
TTS_EXPR_IS(eve::compare_absolute[eve::saturated](v_t(), T(), f), logical<T>);

};

//==================================================================================================
//== Tests for eve::compare_absolute
//==================================================================================================
TTS_CASE_WITH("Check behavior of eve::compare_absolute(simd)",
eve::test::simd::all_types,
tts::generate(tts::ramp(0), tts::reverse_ramp(4, 2), tts::logicals(0, 3)))
<typename T, typename M>(T const& a0, T const& a1, M const& t)
{
using v_t = eve::element_type_t<T>;
auto ff = eve::is_greater;
TTS_EQUAL(eve::compare_absolute(a0, a1, ff),
tts::map([](auto e, auto f) -> eve::logical<v_t> { return eve::abs(e) > eve::abs(f); }, a0, a1));
TTS_EQUAL(eve::compare_absolute(a0, a0, ff),
tts::map([](auto e, auto f) -> eve::logical<v_t> { return eve::abs(e) > eve::abs(f); }, a0, a0));
TTS_EQUAL(eve::compare_absolute(a0, v_t(1), ff),
tts::map([](auto e) -> eve::logical<v_t> { return eve::abs(e) > v_t(1); }, a0));
TTS_EQUAL(eve::compare_absolute[t](a0, a1, ff),
eve::if_else(t, eve::compare_absolute(a0, a1, ff), eve::false_(eve::as(a0))));
auto gg = eve::is_gtz;
TTS_EQUAL(eve::compare_absolute(a0, gg),
tts::map([](auto e) -> eve::logical<v_t> { return eve::abs(e) > v_t(0); }, a0));
TTS_EQUAL(eve::compare_absolute[t](a0, gg),
eve::if_else(t, eve::compare_absolute(a0, gg), eve::false_(eve::as(a0))));
};

0 comments on commit c2dacd7

Please sign in to comment.