Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distinct code path for key equality checks in packed_cas #356

Merged
merged 4 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 43 additions & 19 deletions include/cuco/detail/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class open_addressing_ref_impl {
/**
* @brief Inserts an element.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Predicate Predicate type
*
* @param key Key of the element to insert
Expand All @@ -140,7 +141,7 @@ class open_addressing_ref_impl {
*
* @return True if the given element is successfully inserted
*/
template <typename Predicate>
template <bool HasPayload, typename Predicate>
__device__ bool insert(key_type const& key,
value_type const& value,
Predicate const& predicate) noexcept
Expand All @@ -158,7 +159,7 @@ class open_addressing_ref_impl {
if (eq_res == detail::equal_result::EQUAL) { return false; }
if (eq_res == detail::equal_result::EMPTY) {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
switch (attempt_insert(
switch (attempt_insert<HasPayload>(
(storage_ref_.data() + *probing_iter)->data() + intra_window_index, value, predicate)) {
case insert_result::CONTINUE: continue;
case insert_result::SUCCESS: return true;
Expand All @@ -173,6 +174,7 @@ class open_addressing_ref_impl {
/**
* @brief Inserts an element.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Predicate Predicate type
*
* @param group The Cooperative Group used to perform group insert
Expand All @@ -182,7 +184,7 @@ class open_addressing_ref_impl {
*
* @return True if the given element is successfully inserted
*/
template <typename Predicate>
template <bool HasPayload, typename Predicate>
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
key_type const& key,
value_type const& value,
Expand Down Expand Up @@ -214,9 +216,10 @@ class open_addressing_ref_impl {
auto const src_lane = __ffs(group_contains_empty) - 1;
auto const status =
(group.thread_rank() == src_lane)
? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
value,
predicate)
? attempt_insert<HasPayload>(
(storage_ref_.data() + *probing_iter)->data() + intra_window_index,
value,
predicate)
: insert_result::CONTINUE;

switch (group.shfl(status, src_lane)) {
Expand All @@ -237,6 +240,7 @@ class open_addressing_ref_impl {
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Predicate Predicate type
*
* @param key Key of the element to insert
Expand All @@ -246,7 +250,7 @@ class open_addressing_ref_impl {
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <typename Predicate>
template <bool HasPayload, typename Predicate>
__device__ thrust::pair<iterator, bool> insert_and_find(key_type const& key,
value_type const& value,
Predicate const& predicate) noexcept
Expand All @@ -266,7 +270,7 @@ class open_addressing_ref_impl {
if (eq_res == detail::equal_result::EMPTY) {
switch ([&]() {
if constexpr (sizeof(value_type) <= 8) {
return packed_cas(window_ptr + i, value, predicate);
return packed_cas<HasPayload>(window_ptr + i, value, predicate);
} else {
return cas_dependent_write(window_ptr + i, value, predicate);
}
Expand All @@ -292,6 +296,7 @@ class open_addressing_ref_impl {
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Predicate Predicate type
*
* @param group The Cooperative Group used to perform group insert_and_find
Expand All @@ -302,7 +307,7 @@ class open_addressing_ref_impl {
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <typename Predicate>
template <bool HasPayload, typename Predicate>
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group,
key_type const& key,
Expand Down Expand Up @@ -343,7 +348,7 @@ class open_addressing_ref_impl {
auto const status = [&]() {
if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; }
if constexpr (sizeof(value_type) <= 8) {
return packed_cas(slot_ptr, value, predicate);
return packed_cas<HasPayload>(slot_ptr, value, predicate);
} else {
return cas_dependent_write(slot_ptr, value, predicate);
}
Expand Down Expand Up @@ -649,6 +654,7 @@ class open_addressing_ref_impl {
/**
* @brief Inserts the specified element with one single CAS operation.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Predicate Predicate type
*
* @param slot Pointer to the slot in memory
Expand All @@ -657,20 +663,37 @@ class open_addressing_ref_impl {
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <typename Predicate>
template <bool HasPayload, typename Predicate>
[[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot,
value_type const& value,
Predicate const& predicate) noexcept
{
auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value);
auto* old_ptr = reinterpret_cast<value_type*>(&old);
if (cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_)) {
auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value);
auto* old_ptr = reinterpret_cast<value_type*>(&old);
auto const inserted = [&]() {
if constexpr (HasPayload) {
// If it's a set implementation, compare the whole slot content
return cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_);
} else {
// If it's a map implementation, compare keys only
return cuco::detail::bitwise_compare(old_ptr->first, this->empty_slot_sentinel_.first);
}
}();
if (inserted) {
return insert_result::SUCCESS;
} else {
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
return predicate.equal_to(*old_ptr, value) == detail::equal_result::EQUAL
? insert_result::DUPLICATE
: insert_result::CONTINUE;
auto const res = [&]() {
if constexpr (HasPayload) {
// If it's a set implementation, compare the whole slot content
return predicate.equal_to(*old_ptr, value);
} else {
// If it's a map implementation, compare keys only
return predicate.equal_to(old_ptr->first, value.first);
}
}();
return res == detail::equal_result::EQUAL ? insert_result::DUPLICATE
: insert_result::CONTINUE;
}
}

Expand Down Expand Up @@ -761,6 +784,7 @@ class open_addressing_ref_impl {
* @note Dispatches the correct implementation depending on the container
* type and presence of other operator mixins.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Predicate Predicate type
*
* @param slot Pointer to the slot in memory
Expand All @@ -769,13 +793,13 @@ class open_addressing_ref_impl {
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <typename Predicate>
template <bool HasPayload, typename Predicate>
[[nodiscard]] __device__ insert_result attempt_insert(value_type* slot,
value_type const& value,
Predicate const& predicate) noexcept
{
if constexpr (sizeof(value_type) <= 8) {
return packed_cas(slot, value, predicate);
return packed_cas<HasPayload>(slot, value, predicate);
} else {
#if (_CUDA_ARCH__ < 700)
return cas_dependent_write(slot, value, predicate);
Expand Down
20 changes: 12 additions & 8 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,9 @@ class operator_impl<
*/
__device__ bool insert(value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(value.first, value, ref_.predicate_);
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
return ref_.impl_.insert<has_payload>(value.first, value, ref_.predicate_);
}

/**
Expand All @@ -223,8 +224,9 @@ class operator_impl<
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value) noexcept
{
auto& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(group, value.first, value, ref_.predicate_);
auto& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
return ref_.impl_.insert<has_payload>(group, value.first, value, ref_.predicate_);
}
};

Expand Down Expand Up @@ -289,8 +291,9 @@ class operator_impl<
*/
__device__ thrust::pair<iterator, bool> insert_and_find(value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert_and_find(value.first, value, ref_.predicate_);
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
return ref_.impl_.insert_and_find<has_payload>(value.first, value, ref_.predicate_);
}

/**
Expand All @@ -309,8 +312,9 @@ class operator_impl<
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group, value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert_and_find(group, value.first, value, ref_.predicate_);
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
return ref_.impl_.insert_and_find<has_payload>(group, value.first, value, ref_.predicate_);
}
};

Expand Down
20 changes: 12 additions & 8 deletions include/cuco/detail/static_set/static_set_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ class operator_impl<op::insert_tag,
*/
__device__ bool insert(value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(value, value, ref_.predicate_);
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
return ref_.impl_.insert<has_payload>(value, value, ref_.predicate_);
}

/**
Expand All @@ -115,8 +116,9 @@ class operator_impl<op::insert_tag,
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value) noexcept
{
auto& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(group, value, value, ref_.predicate_);
auto& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
return ref_.impl_.insert<has_payload>(group, value, value, ref_.predicate_);
}
};

Expand Down Expand Up @@ -179,8 +181,9 @@ class operator_impl<op::insert_and_find_tag,
*/
__device__ thrust::pair<iterator, bool> insert_and_find(value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert_and_find(value, value, ref_.predicate_);
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
return ref_.impl_.insert_and_find<has_payload>(value, value, ref_.predicate_);
}

/**
Expand All @@ -199,8 +202,9 @@ class operator_impl<op::insert_and_find_tag,
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group, value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert_and_find(group, value, value, ref_.predicate_);
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
return ref_.impl_.insert_and_find<has_payload>(group, value, value, ref_.predicate_);
}
};

Expand Down
Loading