Skip to content

Commit

Permalink
Add set for_each
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Oct 31, 2024
1 parent 69817e2 commit 63f26dd
Show file tree
Hide file tree
Showing 7 changed files with 310 additions and 10 deletions.
7 changes: 3 additions & 4 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
#include <cuco/operator.hpp>

#include <cuda/atomic>
#include <cuda/std/functional>
#include <cuda/std/type_traits>
#include <thrust/tuple.h>
#include <cuda/std/utility>

#include <cooperative_groups.h>

Expand Down Expand Up @@ -1335,7 +1334,7 @@ class operator_impl<
{
// CRTP: cast `this` to the actual ref type
auto const& ref_ = static_cast<ref_type const&>(*this);
ref_.impl_.for_each(key, std::forward<CallbackOp>(callback_op));
ref_.impl_.for_each(key, cuda::std::forward<CallbackOp>(callback_op));
}

/**
Expand Down Expand Up @@ -1363,7 +1362,7 @@ class operator_impl<
{
// CRTP: cast `this` to the actual ref type
auto const& ref_ = static_cast<ref_type const&>(*this);
ref_.impl_.for_each(group, key, std::forward<CallbackOp>(callback_op));
ref_.impl_.for_each(group, key, cuda::std::forward<CallbackOp>(callback_op));
}
};

Expand Down
60 changes: 60 additions & 0 deletions include/cuco/detail/static_set/static_set.inl
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,66 @@ void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
impl_->find_async(first, last, output_begin, ref(op::find), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename CallbackOp>
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
stream.wait();
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename CallbackOp>
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each_async(
CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename CallbackOp>
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
stream.wait();
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename CallbackOp>
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each_async(
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const noexcept
{
impl_->for_each_async(
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
Expand Down
69 changes: 69 additions & 0 deletions include/cuco/detail/static_set/static_set_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <cuda/atomic>
#include <cuda/std/type_traits>
#include <cuda/std/utility>

#include <cooperative_groups.h>

Expand Down Expand Up @@ -629,6 +630,74 @@ class operator_impl<op::find_tag,
}
};

template <typename Key,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
class operator_impl<op::for_each_tag,
static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
using base_type = static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef>;
using ref_type = static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
using key_type = typename base_type::key_type;
using value_type = typename base_type::value_type;
using iterator = typename base_type::iterator;
using const_iterator = typename base_type::const_iterator;

static constexpr auto cg_size = base_type::cg_size;
static constexpr auto window_size = base_type::window_size;

public:
/**
* @brief For a given key, applies the function object `callback_op` to its match found in the
* container.
*
* @note The return value of `callback_op`, if any, is ignored.
*
* @tparam ProbeKey Probe key type
* @tparam CallbackOp Type of unary callback function object
*
* @param key The key to search for
* @param callback_op Function to apply to the copy of the matched slot
*/
template <class ProbeKey, class CallbackOp>
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
{
// CRTP: cast `this` to the actual ref type
auto const& ref_ = static_cast<ref_type const&>(*this);
ref_.impl_.for_each(key, cuda::std::forward<CallbackOp>(callback_op));
}

/**
* @brief For a given key, applies the function object `callback_op` to its match found in the
* container.
*
* @note This function uses cooperative group semantics, meaning that any thread may call the
* callback if it finds a matching slot.
*
* @note The return value of `callback_op`, if any, is ignored.
*
* @note Synchronizing `group` within `callback_op` is undefined behavior.
*
* @tparam ProbeKey Probe key type
* @tparam CallbackOp Type of unary callback function object
*
* @param group The Cooperative Group used to perform this operation
* @param key The key to search for
* @param callback_op Function to apply to the copy of the matched slot
*/
template <class ProbeKey, class CallbackOp>
__device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
ProbeKey const& key,
CallbackOp&& callback_op) const noexcept
{
// CRTP: cast `this` to the actual ref type
auto const& ref_ = static_cast<ref_type const&>(*this);
ref_.impl_.for_each(group, key, cuda::std::forward<CallbackOp>(callback_op));
}
};

template <typename Key,
cuda::thread_scope Scope,
typename KeyEqual,
Expand Down
8 changes: 4 additions & 4 deletions include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ class static_map {
*
* @tparam CallbackOp Type of unary callback function object
*
* @param callback_op Function to apply to the copy of the matched key-value pair
* @param callback_op Function to apply to the copy of the filled slot
* @param stream CUDA stream used for this operation
*/
template <typename CallbackOp>
Expand All @@ -789,7 +789,7 @@ class static_map {
*
* @tparam CallbackOp Type of unary callback function object
*
* @param callback_op Function to apply to the copy of the matched key-value pair
* @param callback_op Function to apply to the copy of the filled slot
* @param stream CUDA stream used for this operation
*/
template <typename CallbackOp>
Expand All @@ -806,7 +806,7 @@ class static_map {
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param callback_op Function to apply to the copy of the matched key-value pair
* @param callback_op Function to apply to the copy of the matched slot
* @param stream CUDA stream used for this operation
*/
template <typename InputIt, typename CallbackOp>
Expand All @@ -826,7 +826,7 @@ class static_map {
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param callback_op Function to apply to the copy of the matched key-value pair
* @param callback_op Function to apply to the copy of the matched slot
* @param stream CUDA stream used for this operation
*/
template <typename InputIt, typename CallbackOp>
Expand Down
72 changes: 70 additions & 2 deletions include/cuco/static_set.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ class static_set {
*
* @tparam InputIt Device accessible random access input iterator
* @tparam FoundIt Device accessible random access output iterator whose `value_type`
* is constructible from `map::iterator` type
* is constructible from `set::iterator` type
* @tparam InsertedIt Device accessible random access output iterator whose `value_type`
* is constructible from `bool`
*
Expand Down Expand Up @@ -379,7 +379,7 @@ class static_set {
*
* @tparam InputIt Device accessible random access input iterator
* @tparam FoundIt Device accessible random access output iterator whose `value_type`
* is constructible from `map::iterator` type
* is constructible from `set::iterator` type
* @tparam InsertedIt Device accessible random access output iterator whose `value_type`
* is constructible from `bool`
*
Expand Down Expand Up @@ -590,6 +590,74 @@ class static_set {
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief Applies the given function object `callback_op` to the copy of every filled slot in the
* container
*
* @note The return value of `callback_op`, if any, is ignored.
*
* @tparam CallbackOp Type of unary callback function object
*
* @param callback_op Function to apply to the copy of the filled slot
* @param stream CUDA stream used for this operation
*/
template <typename CallbackOp>
void for_each(CallbackOp&& callback_op, cuda::stream_ref stream = {}) const;

/**
* @brief Asynchronously applies the given function object `callback_op` to the copy of every
* filled slot in the container
*
* @note The return value of `callback_op`, if any, is ignored.
*
* @tparam CallbackOp Type of unary callback function object
*
* @param callback_op Function to apply to the copy of the filled slot
* @param stream CUDA stream used for this operation
*/
template <typename CallbackOp>
void for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream = {}) const;

/**
* @brief For each key in the range [first, last), applies the function object `callback_op` to
* the copy of all corresponding matches found in the container.
*
* @note The return value of `callback_op`, if any, is ignored.
*
* @tparam InputIt Device accessible random access input iterator
* @tparam CallbackOp Type of unary callback function object
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param callback_op Function to apply to the copy of the matched slot
* @param stream CUDA stream used for this operation
*/
template <typename InputIt, typename CallbackOp>
void for_each(InputIt first,
InputIt last,
CallbackOp&& callback_op,
cuda::stream_ref stream = {}) const;

/**
* @brief For each key in the range [first, last), asynchronously applies the function object
* `callback_op` to the copy of all corresponding matches found in the container.
*
* @note The return value of `callback_op`, if any, is ignored.
*
* @tparam InputIt Device accessible random access input iterator
* @tparam CallbackOp Type of unary callback function object
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param callback_op Function to apply to the copy of the matched slot
* @param stream CUDA stream used for this operation
*/
template <typename InputIt, typename CallbackOp>
void for_each_async(InputIt first,
InputIt last,
CallbackOp&& callback_op,
cuda::stream_ref stream = {}) const noexcept;

/**
* @brief Counts the occurrences of keys in `[first, last)` contained in the set
*
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ ConfigureTest(UTILITY_TEST
# - static_set tests ------------------------------------------------------------------------------
ConfigureTest(STATIC_SET_TEST
static_set/capacity_test.cu
static_set/for_each_test.cu
static_set/heterogeneous_lookup_test.cu
static_set/insert_and_find_test.cu
static_set/large_input_test.cu
Expand Down
Loading

0 comments on commit 63f26dd

Please sign in to comment.