From 16bd608ed1ba13cee2592665ce005937375baf82 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Mon, 21 Aug 2023 12:09:03 -0700 Subject: [PATCH 1/3] Fix an equal check bug in packed_case: shouldn't compare the whole slot for maps --- .../cuco/detail/open_addressing_ref_impl.cuh | 64 +++++++++++++------ .../cuco/detail/static_map/static_map_ref.inl | 20 +++--- .../cuco/detail/static_set/static_set_ref.inl | 20 +++--- 3 files changed, 69 insertions(+), 35 deletions(-) diff --git a/include/cuco/detail/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing_ref_impl.cuh index 99187cc51..8abaecbef 100644 --- a/include/cuco/detail/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing_ref_impl.cuh @@ -132,6 +132,7 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. * + * @tparam IsSet Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param key Key of the element to insert @@ -140,7 +141,7 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(key_type const& key, value_type const& value, Predicate const& predicate) noexcept @@ -158,7 +159,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EQUAL) { return false; } if (eq_res == detail::equal_result::EMPTY) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); - switch (attempt_insert( + switch (attempt_insert( (storage_ref_.data() + *probing_iter)->data() + intra_window_index, value, predicate)) { case insert_result::CONTINUE: continue; case insert_result::SUCCESS: return true; @@ -173,6 +174,7 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. * + * @tparam IsSet Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert @@ -182,7 +184,7 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, key_type const& key, value_type const& value, @@ -214,9 +216,10 @@ class open_addressing_ref_impl { auto const src_lane = __ffs(group_contains_empty) - 1; auto const status = (group.thread_rank() == src_lane) - ? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, - value, - predicate) + ? attempt_insert( + (storage_ref_.data() + *probing_iter)->data() + intra_window_index, + value, + predicate) : insert_result::CONTINUE; switch (group.shfl(status, src_lane)) { @@ -237,6 +240,7 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * + * @tparam IsSet Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param key Key of the element to insert @@ -246,7 +250,7 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find(key_type const& key, value_type const& value, Predicate const& predicate) noexcept @@ -266,7 +270,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EMPTY) { switch ([&]() { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(window_ptr + i, value, predicate); + return packed_cas(window_ptr + i, value, predicate); } else { return cas_dependent_write(window_ptr + i, value, predicate); } @@ -292,6 +296,7 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * + * @tparam IsSet Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert_and_find @@ -302,7 +307,7 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, key_type const& key, @@ -343,7 +348,7 @@ class open_addressing_ref_impl { auto const status = [&]() { if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; } if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot_ptr, value, predicate); + return packed_cas(slot_ptr, value, predicate); } else { return cas_dependent_write(slot_ptr, value, predicate); } @@ -649,6 +654,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with one single CAS operation. * + * @tparam IsSet Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -657,20 +663,39 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot, value_type const& value, Predicate const& predicate) noexcept { - auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value); - auto* old_ptr = reinterpret_cast(&old); - if (cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_)) { + auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value); + auto* old_ptr = reinterpret_cast(&old); + auto const inserted = [&]() { + if constexpr (IsSet) { + // If it's a set implementation, compare the whole slot content + return cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_); + } + if constexpr (not IsSet) { + // If it's a map implementation, compare keys only + return cuco::detail::bitwise_compare(old_ptr->first, this->empty_slot_sentinel_.first); + } + }(); + if (inserted) { return insert_result::SUCCESS; } else { // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - return predicate.equal_to(*old_ptr, value) == detail::equal_result::EQUAL - ? insert_result::DUPLICATE - : insert_result::CONTINUE; + auto const res = [&]() { + if constexpr (IsSet) { + // If it's a set implementation, compare the whole slot content + return predicate.equal_to(*old_ptr, value); + } + if constexpr (not IsSet) { + // If it's a map implementation, compare keys only + return predicate.equal_to(old_ptr->first, value.first); + } + }(); + return res == detail::equal_result::EQUAL ? insert_result::DUPLICATE + : insert_result::CONTINUE; } } @@ -761,6 +786,7 @@ class open_addressing_ref_impl { * @note Dispatches the correct implementation depending on the container * type and presence of other operator mixins. * + * @tparam IsSet Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -769,13 +795,13 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ insert_result attempt_insert(value_type* slot, value_type const& value, Predicate const& predicate) noexcept { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot, value, predicate); + return packed_cas(slot, value, predicate); } else { #if (_CUDA_ARCH__ < 700) return cas_dependent_write(slot, value, predicate); diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index f3c412924..5c2fba842 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -209,8 +209,9 @@ class operator_impl< */ __device__ bool insert(value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - return ref_.impl_.insert(value.first, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr is_set = false; + return ref_.impl_.insert(value.first, value, ref_.predicate_); } /** @@ -223,8 +224,9 @@ class operator_impl< __device__ bool insert(cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { - auto& ref_ = static_cast(*this); - return ref_.impl_.insert(group, value.first, value, ref_.predicate_); + auto& ref_ = static_cast(*this); + auto constexpr is_set = false; + return ref_.impl_.insert(group, value.first, value, ref_.predicate_); } }; @@ -289,8 +291,9 @@ class operator_impl< */ __device__ thrust::pair insert_and_find(value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - return ref_.impl_.insert_and_find(value.first, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr is_set = false; + return ref_.impl_.insert_and_find(value.first, value, ref_.predicate_); } /** @@ -309,8 +312,9 @@ class operator_impl< __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - return ref_.impl_.insert_and_find(group, value.first, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr is_set = false; + return ref_.impl_.insert_and_find(group, value.first, value, ref_.predicate_); } }; diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 3482738cc..007420500 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -100,8 +100,9 @@ class operator_impl(*this); - return ref_.impl_.insert(value, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr is_set = true; + return ref_.impl_.insert(value, value, ref_.predicate_); } /** @@ -115,8 +116,9 @@ class operator_impl const& group, value_type const& value) noexcept { - auto& ref_ = static_cast(*this); - return ref_.impl_.insert(group, value, value, ref_.predicate_); + auto& ref_ = static_cast(*this); + auto constexpr is_set = true; + return ref_.impl_.insert(group, value, value, ref_.predicate_); } }; @@ -179,8 +181,9 @@ class operator_impl insert_and_find(value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - return ref_.impl_.insert_and_find(value, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr is_set = true; + return ref_.impl_.insert_and_find(value, value, ref_.predicate_); } /** @@ -199,8 +202,9 @@ class operator_impl insert_and_find( cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - return ref_.impl_.insert_and_find(group, value, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr is_set = true; + return ref_.impl_.insert_and_find(group, value, value, ref_.predicate_); } }; From db72af5d1822ac982abb24f9ce602d5f62a76cde Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 22 Aug 2023 09:27:53 -0700 Subject: [PATCH 2/3] Renaming: hash_payload instead of is_set --- .../cuco/detail/open_addressing_ref_impl.cuh | 42 +++++++++---------- .../cuco/detail/static_map/static_map_ref.inl | 24 +++++------ .../cuco/detail/static_set/static_set_ref.inl | 24 +++++------ 3 files changed, 45 insertions(+), 45 deletions(-) diff --git a/include/cuco/detail/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing_ref_impl.cuh index 8abaecbef..b2565969d 100644 --- a/include/cuco/detail/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing_ref_impl.cuh @@ -132,7 +132,7 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. * - * @tparam IsSet Boolean indicating it's a set or map implementation + * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param key Key of the element to insert @@ -141,7 +141,7 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(key_type const& key, value_type const& value, Predicate const& predicate) noexcept @@ -159,7 +159,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EQUAL) { return false; } if (eq_res == detail::equal_result::EMPTY) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); - switch (attempt_insert( + switch (attempt_insert( (storage_ref_.data() + *probing_iter)->data() + intra_window_index, value, predicate)) { case insert_result::CONTINUE: continue; case insert_result::SUCCESS: return true; @@ -174,7 +174,7 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. * - * @tparam IsSet Boolean indicating it's a set or map implementation + * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert @@ -184,7 +184,7 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, key_type const& key, value_type const& value, @@ -216,7 +216,7 @@ class open_addressing_ref_impl { auto const src_lane = __ffs(group_contains_empty) - 1; auto const status = (group.thread_rank() == src_lane) - ? attempt_insert( + ? attempt_insert( (storage_ref_.data() + *probing_iter)->data() + intra_window_index, value, predicate) @@ -240,7 +240,7 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam IsSet Boolean indicating it's a set or map implementation + * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param key Key of the element to insert @@ -250,7 +250,7 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find(key_type const& key, value_type const& value, Predicate const& predicate) noexcept @@ -270,7 +270,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EMPTY) { switch ([&]() { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(window_ptr + i, value, predicate); + return packed_cas(window_ptr + i, value, predicate); } else { return cas_dependent_write(window_ptr + i, value, predicate); } @@ -296,7 +296,7 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam IsSet Boolean indicating it's a set or map implementation + * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert_and_find @@ -307,7 +307,7 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, key_type const& key, @@ -348,7 +348,7 @@ class open_addressing_ref_impl { auto const status = [&]() { if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; } if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot_ptr, value, predicate); + return packed_cas(slot_ptr, value, predicate); } else { return cas_dependent_write(slot_ptr, value, predicate); } @@ -654,7 +654,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with one single CAS operation. * - * @tparam IsSet Boolean indicating it's a set or map implementation + * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -663,7 +663,7 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot, value_type const& value, Predicate const& predicate) noexcept @@ -671,11 +671,11 @@ class open_addressing_ref_impl { auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value); auto* old_ptr = reinterpret_cast(&old); auto const inserted = [&]() { - if constexpr (IsSet) { + if constexpr (HasPayload) { // If it's a set implementation, compare the whole slot content return cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_); } - if constexpr (not IsSet) { + if constexpr (not HasPayload) { // If it's a map implementation, compare keys only return cuco::detail::bitwise_compare(old_ptr->first, this->empty_slot_sentinel_.first); } @@ -685,11 +685,11 @@ class open_addressing_ref_impl { } else { // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare auto const res = [&]() { - if constexpr (IsSet) { + if constexpr (HasPayload) { // If it's a set implementation, compare the whole slot content return predicate.equal_to(*old_ptr, value); } - if constexpr (not IsSet) { + if constexpr (not HasPayload) { // If it's a map implementation, compare keys only return predicate.equal_to(old_ptr->first, value.first); } @@ -786,7 +786,7 @@ class open_addressing_ref_impl { * @note Dispatches the correct implementation depending on the container * type and presence of other operator mixins. * - * @tparam IsSet Boolean indicating it's a set or map implementation + * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -795,13 +795,13 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ insert_result attempt_insert(value_type* slot, value_type const& value, Predicate const& predicate) noexcept { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot, value, predicate); + return packed_cas(slot, value, predicate); } else { #if (_CUDA_ARCH__ < 700) return cas_dependent_write(slot, value, predicate); diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 5c2fba842..536973b20 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -209,9 +209,9 @@ class operator_impl< */ __device__ bool insert(value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr is_set = false; - return ref_.impl_.insert(value.first, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr has_payload = false; + return ref_.impl_.insert(value.first, value, ref_.predicate_); } /** @@ -224,9 +224,9 @@ class operator_impl< __device__ bool insert(cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { - auto& ref_ = static_cast(*this); - auto constexpr is_set = false; - return ref_.impl_.insert(group, value.first, value, ref_.predicate_); + auto& ref_ = static_cast(*this); + auto constexpr has_payload = false; + return ref_.impl_.insert(group, value.first, value, ref_.predicate_); } }; @@ -291,9 +291,9 @@ class operator_impl< */ __device__ thrust::pair insert_and_find(value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr is_set = false; - return ref_.impl_.insert_and_find(value.first, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr has_payload = false; + return ref_.impl_.insert_and_find(value.first, value, ref_.predicate_); } /** @@ -312,9 +312,9 @@ class operator_impl< __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr is_set = false; - return ref_.impl_.insert_and_find(group, value.first, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr has_payload = false; + return ref_.impl_.insert_and_find(group, value.first, value, ref_.predicate_); } }; diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 007420500..3131f3764 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -100,9 +100,9 @@ class operator_impl(*this); - auto constexpr is_set = true; - return ref_.impl_.insert(value, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr has_payload = true; + return ref_.impl_.insert(value, value, ref_.predicate_); } /** @@ -116,9 +116,9 @@ class operator_impl const& group, value_type const& value) noexcept { - auto& ref_ = static_cast(*this); - auto constexpr is_set = true; - return ref_.impl_.insert(group, value, value, ref_.predicate_); + auto& ref_ = static_cast(*this); + auto constexpr has_payload = true; + return ref_.impl_.insert(group, value, value, ref_.predicate_); } }; @@ -181,9 +181,9 @@ class operator_impl insert_and_find(value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr is_set = true; - return ref_.impl_.insert_and_find(value, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr has_payload = true; + return ref_.impl_.insert_and_find(value, value, ref_.predicate_); } /** @@ -202,9 +202,9 @@ class operator_impl insert_and_find( cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr is_set = true; - return ref_.impl_.insert_and_find(group, value, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr has_payload = true; + return ref_.impl_.insert_and_find(group, value, value, ref_.predicate_); } }; From 4f8953daf5aad6ad3f08ae60d9cc5fa3bfbd8ea9 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 22 Aug 2023 09:34:57 -0700 Subject: [PATCH 3/3] Minor cleanups --- include/cuco/detail/open_addressing_ref_impl.cuh | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/include/cuco/detail/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing_ref_impl.cuh index b2565969d..4aa701759 100644 --- a/include/cuco/detail/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing_ref_impl.cuh @@ -674,8 +674,7 @@ class open_addressing_ref_impl { if constexpr (HasPayload) { // If it's a set implementation, compare the whole slot content return cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_); - } - if constexpr (not HasPayload) { + } else { // If it's a map implementation, compare keys only return cuco::detail::bitwise_compare(old_ptr->first, this->empty_slot_sentinel_.first); } @@ -688,8 +687,7 @@ class open_addressing_ref_impl { if constexpr (HasPayload) { // If it's a set implementation, compare the whole slot content return predicate.equal_to(*old_ptr, value); - } - if constexpr (not HasPayload) { + } else { // If it's a map implementation, compare keys only return predicate.equal_to(old_ptr->first, value.first); }