Skip to content

Commit

Permalink
#1520 - scatter emulation
Browse files Browse the repository at this point in the history
  • Loading branch information
jfalcou authored Jan 19, 2024
1 parent 3ef389f commit 6fac6da
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/eve/module/core/regular/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@
#include <eve/module/core/regular/safe.hpp>
#include <eve/module/core/regular/saturate.hpp>
#include <eve/module/core/regular/scan.hpp>
#include <eve/module/core/regular/scatter.hpp>
#include <eve/module/core/regular/shl.hpp>
#include <eve/module/core/regular/shr.hpp>
#include <eve/module/core/regular/shuffle_v2.hpp>
Expand Down
42 changes: 42 additions & 0 deletions include/eve/module/core/regular/impl/scatter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//==================================================================================================
/*
EVE - Expressive Vector Engine
Copyright : EVE Project Contributors
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once

#include <eve/detail/abi.hpp>
#include <eve/forward.hpp>
#include <eve/module/core/regular/store_equivalent.hpp>
#include <eve/module/core/regular/unalign.hpp>
#include <eve/module/core/regular/write.hpp>

namespace eve::detail
{
template<typename T, typename Idx, typename Ptr, callable_options O>
EVE_FORCEINLINE void scatter_(EVE_REQUIRES(cpu_), O const& o, T const& v, Ptr p, Idx const& idx)
{
// Retrieve element to scatter
auto se = store_equivalent(o[condition_key],v,p);

// Extract the pointer from a potential aligned_ptr
auto base = unalign(get<2>(se));

// Single-value scatter
auto sc = [&](auto n, auto c, auto v)
{
// We only write if mask is set
if constexpr(match_option<condition_key,O,ignore_none_>) write(v.get(n),base+idx.get(n));
else
{
auto m = c.mask( as<as_logical_t<T>>{} );
if(m.get(n)) write(v.get(n),base+idx.get(n));
}
};

// Scatter all (clang doesn't like capturing structured bindings)
eve::detail::for_<0, 1, T::size()>([&](auto... I) { ( sc(I,get<0>(se),get<1>(se)), ...); });
}
}
74 changes: 74 additions & 0 deletions include/eve/module/core/regular/scatter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//======================================================================================================================
/*
EVE - Expressive Vector Engine
Copyright : EVE Project Contributors
SPDX-License-Identifier: BSL-1.0
*/
//======================================================================================================================
#pragma once

#include <eve/arch.hpp>
#include <eve/concept/memory.hpp>
#include <eve/traits/overload.hpp>
#include <eve/memory/aligned_ptr.hpp>
#include <eve/module/core/decorator/core.hpp>

namespace eve
{
template<typename Options>
struct scatter_t : callable<scatter_t, Options, relative_conditional_no_alternative_option>
{
template<simd_value T, integral_simd_value Idx, simd_compatible_ptr<T> Ptr>
EVE_FORCEINLINE void operator()(T const& v, Ptr ptr, Idx const& idx) const noexcept { EVE_DISPATCH_CALL(v,ptr,idx); }

EVE_CALLABLE_OBJECT(scatter_t, scatter_);
};

//======================================================================================================================
//! @addtogroup core_arithmetic
//! @{
//! @var scatter
//! @brief Store a SIMD register to memory using scattered indexes
//!
//! Store each element of a given [SIMD value](@ref eve::simd_value) `v`in different memory address computed form a base
//! SIMD compatible iterator `ptr` and a [SIMD integral value](@ref eve::integral_simd_value) `idx` used as indexes.
//!
//! A call to `eve::scatter(v,ptr,idx)` is semantically equivalent to:
//!
//! ```
//! for(std::size_t i=0;i<v.size();++i)
//! ptr[idx.get(i)] = v.get(i);
//! ```
//!
//! @groupheader{Header file}
//!
//! @code
//! #include <eve/module/core.hpp>
//! @endcode
//!
//! @groupheader{Callable Signatures}
//!
//! @code
//! namespace eve
//! {
//! template<simd_value T, integral_simd_value Idx, simd_compatible_ptr<T> Ptr>
//! void scatter(T const& v, Ptr ptr, Idx const& idx) noexcept;
//! }
//! @endcode
//!
//! **Parameters**
//!
//! * `v` : [SIMD value](@ref eve::simd_value) to scatter
//! * `ptr` : Base pointer to scatter to.
//! * `idx` : [Integral SIMD value](@ref eve::integral_simd_value) containing the index to scatter to.
//!
//! @groupheader{Example}
//!
//! @godbolt{doc/core/regular/scatter.cpp}
//!
//! @}
//======================================================================================================================
inline constexpr auto scatter = functor<scatter_t>;
}

#include <eve/module/core/regular/impl/scatter.hpp>
17 changes: 17 additions & 0 deletions include/eve/traits/overload/supports.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,23 @@ namespace eve
return options<decltype(new_opts)>{new_opts};
}
};

struct relative_conditional_no_alternative_option
{
template<relative_conditional_expr Opt>
EVE_FORCEINLINE constexpr auto process(auto const& base, Opt opt) const
requires( !Opt::has_alternative )
{
auto new_opts = rbr::merge(options{condition_key = opt}, base);
return options<decltype(new_opts)>{new_opts};
}

EVE_FORCEINLINE constexpr auto default_to(auto const& base) const
{
auto new_opts = rbr::merge(base, options{condition_key = ignore_none});
return options<decltype(new_opts)>{new_opts};
}
};
}

namespace eve::detail
Expand Down
17 changes: 17 additions & 0 deletions test/doc/core/regular/scatter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include <eve/module/core.hpp>
#include <iostream>

int main()
{
float data[2*eve::wide<float>::size()] = {};

eve::wide<int> indexes = [](auto i, auto) { return 2*i; };
eve::wide<float> values = [](auto i, auto) { return 1.5f * (1+i); };

eve::scatter(values, data, indexes);

for(auto e : data)
std::cout << e << " ";
std::cout << "\n";
}

55 changes: 55 additions & 0 deletions test/unit/module/core/scatter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//==================================================================================================
/*
EVE - Expressive Vector Engine
Copyright : EVE Project Contributors
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#include "test.hpp"

#include <eve/module/core.hpp>
#include <algorithm>

template<typename Index, typename T>
void test_scatter(T const& values, auto ptr)
{
using e_t = eve::element_type_t<T>;
eve::wide<Index, eve::fixed<T::size()>> indexes = [](auto i, auto) { return 2*i; };

e_t ref [2*T::size()];
auto base = eve::unalign(ptr);
std::copy(base,base+2*T::size(),&ref[0]);
for( std::size_t i = 0; i < T::size(); ++i ) ref[indexes.get(i)] = values.get(i);

eve::scatter(values, ptr, indexes);

TTS_EXPECT(std::equal(&ref[0],&ref[0]+2*T::size(),base));
}

TTS_CASE_TPL("Check eve::scatter behavior with pointer", eve::test::simd::all_types)
<typename T>(tts::type<T>)
{
using e_t = eve::element_type_t<T>;
e_t data[2*T::size()];
for( std::size_t i = 0; i < 2*T::size(); ++i ) data[i] = e_t(99);
T values = [](auto i, auto) { return 2*i+1; };

test_scatter<std::int32_t >(values,&data[0]);
test_scatter<std::uint32_t>(values,&data[0]);
test_scatter<std::int64_t >(values,&data[0]);
test_scatter<std::uint64_t>(values,&data[0]);
};

TTS_CASE_TPL("Check eve::scatter behavior with aligned_ptr", eve::test::simd::all_types)
<typename T>(tts::type<T>)
{
using e_t = eve::element_type_t<T>;
alignas(T::alignment()) e_t data[2*T::size()];
for( std::size_t i = 0; i < 2*T::size(); ++i ) data[i] = e_t(99);
T values = [](auto i, auto) { return 2*i+1; };

test_scatter<std::int32_t >(values,eve::as_aligned(&data[0],typename T::cardinal_type{}));
test_scatter<std::uint32_t>(values,eve::as_aligned(&data[0],typename T::cardinal_type{}));
test_scatter<std::int64_t >(values,eve::as_aligned(&data[0],typename T::cardinal_type{}));
test_scatter<std::uint64_t>(values,eve::as_aligned(&data[0],typename T::cardinal_type{}));
};

0 comments on commit 6fac6da

Please sign in to comment.