Skip to content

Commit

Permalink
change hash map logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasmine-ge committed Nov 21, 2024
1 parent c91a82d commit 1050a2d
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 241 deletions.
13 changes: 7 additions & 6 deletions src/AggregateFunctions/Streaming/AggregateFunctionUniq.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <AggregateFunctions/ThetaSketchData.h>
#include <AggregateFunctions/UniqVariadicHash.h>
#include <AggregateFunctions/Streaming/CountedValueMap.h>
#include <AggregateFunctions/Streaming/CountedValueHashMap.h>

#include "config.h"

Expand All @@ -52,7 +53,7 @@ namespace Streaming
/// uniq
struct AggregateFunctionUniqUniquesHashSetData
{
using Set = CountedValueSet<UInt64>;
using Set = CountedValueHashMap<UInt64>;
Set set;

constexpr static bool is_able_to_parallelize_merge = false;
Expand All @@ -66,7 +67,7 @@ struct AggregateFunctionUniqUniquesHashSetData
template <bool is_exact_, bool argument_is_tuple_>
struct AggregateFunctionUniqUniquesHashSetDataForVariadic
{
using Set = CountedValueSet<UInt64>;
using Set = CountedValueHashMap<UInt64>;
Set set;

constexpr static bool is_able_to_parallelize_merge = false;
Expand All @@ -83,7 +84,7 @@ struct AggregateFunctionUniqUniquesHashSetDataForVariadic
template <typename T, bool is_able_to_parallelize_merge_>
struct AggregateFunctionUniqExactData
{
using Set = CountedValueSet<T>;
using Set = CountedValueHashMap<T>;
Set set;

constexpr static bool is_able_to_parallelize_merge = is_able_to_parallelize_merge_;
Expand All @@ -97,7 +98,7 @@ struct AggregateFunctionUniqExactData
template <bool is_able_to_parallelize_merge_>
struct AggregateFunctionUniqExactData<String, is_able_to_parallelize_merge_>
{
using Set = CountedValueSet<UInt128>;
using Set = CountedValueHashMap<UInt128>;
Set set;

constexpr static bool is_able_to_parallelize_merge = is_able_to_parallelize_merge_;
Expand Down Expand Up @@ -126,8 +127,8 @@ struct IsUniqExactSet : std::false_type
{
};

template <typename T1>
struct IsUniqExactSet<CountedValueSet<T1>> : std::true_type
template <typename T>
struct IsUniqExactSet<CountedValueHashMap<T>> : std::true_type
{
};

Expand Down
149 changes: 24 additions & 125 deletions src/AggregateFunctions/Streaming/CountedValueHashMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ namespace DB
namespace Streaming
{
template <typename T>
struct CountedValueArena
struct CountedValueHashmapArena
{
CountedValueArena() = default;
CountedValueHashmapArena() = default;
T emplace(T key) { return std::move(key); }
void free(const T & /*key*/) { }
};
Expand All @@ -28,9 +28,9 @@ struct CountedValueArena
* otherwise the reference would be invalid after the processed block is released.
*/
template <>
struct CountedValueArena<StringRef>
struct CountedValueHashmapArena<StringRef>
{
CountedValueArena() = default;
CountedValueHashmapArena() = default;
StringRef emplace(StringRef key) { return copyStringInArena(arena, key); }

void free(StringRef key)
Expand All @@ -48,21 +48,20 @@ struct CountedValueArena<StringRef>
/// CountedValueHashMap maintain count for each key with maximum capacity
/// When capacity hits the max capacity threshold, it will delete
/// the minimum / maximum key in the map to maintain the capacity constrain
template <typename T, typenmae KeyEq = void>
template <typename T, typename KeyCompare = void>
class CountedValueHashMap
{
public:
using Eq = std::conditional_t<std::is_void_v<KeyEq>, std::equal<T>, KeyCompare>;

/// NOTE: Generally we prefer to use absl::btree_map, but it requires `Compare` is nothrow copy constructible, so if not, we use std::map
using FlatHashMap = absl::flat_hash_map<T, uint32_t, Eq>;
using STDUnorderedMap = std::unordered_map<T, uint32_t, Eq>;
using Map = std::conditional_t<std::is_nothrow_copy_constructible<Eq>::value, FlatHashMap, STDUnorderedMap>;
using Compare = std::conditional_t<std::is_void_v<KeyCompare>, std::equal_to<T>, KeyCompare>;
using FlatHashMap = absl::flat_hash_map<T, uint32_t, DefaultHash<T>, Compare>;
using STDHashMap = std::unordered_map<T, uint32_t, DefaultHash<T>, Compare>;
using Map = std::conditional_t<std::is_nothrow_copy_constructible<Compare>::value, FlatHashMap, STDHashMap>;
using size_type = typename Map::size_type;
using key_type = T;

CountedValueHashMap() = default;
explicit CountedValueHashMap(size_type max_size_, Eq && eq = Eq{})
: max_size(max_size_), arena(std::make_unique<CountedValueArena<T>>()), m(std::move(comp))
explicit CountedValueHashMap(size_type max_size_, Compare && comp = Compare{})
: max_size(max_size_), arena(std::make_unique<CountedValueHashmapArena<T>>())
{
}

Expand All @@ -78,12 +77,7 @@ class CountedValueHashMap
Map::iterator emplace(T v)
{
if (atCapacity())
{
/// At capacity, this is an optimization
/// fast ignore elements we don't want to maintain
if (less(lastValue(), v))
return m.end();
}
return m.end();

if (auto iter = m.find(v); iter != m.end())
{
Expand All @@ -95,8 +89,6 @@ class CountedValueHashMap
/// Didn't find v in the map
auto [new_iter, inserted] = m.emplace(arena->emplace(std::move(v)), 1);
assert(inserted);

eraseExtraElements();
return new_iter;
}
}
Expand Down Expand Up @@ -134,44 +126,10 @@ class CountedValueHashMap
return m.find(v) != m.end();
}

bool firstValue(T & v) const
{
if (unlikely(m.empty()))
return false;

v = m.begin()->first;
return true;
}

const T & firstValue() const
{
if (unlikely(m.empty()))
throw std::logic_error("Call top on empty value map");

return m.begin()->first;
}

bool lastValue(T & v) const
{
if (unlikely(m.empty()))
return false;

v = m.rbegin()->first;
return true;
}

const T & lastValue() const
{
if (unlikely(m.empty()))
throw std::logic_error("Call top on empty value map");

return m.rbegin()->first;
}

void merge(const CountedValueMap & rhs) { merge<true>(rhs); }
void merge(const CountedValueHashMap & rhs) { merge<true>(rhs); }

/// After merge, `rhs` will be empty
void merge(CountedValueMap & rhs)
void merge(CountedValueHashMap & rhs)
{
merge<false>(rhs);
rhs.clear();
Expand All @@ -180,7 +138,7 @@ class CountedValueHashMap
void clear()
{
m.clear();
arena = std::make_unique<CountedValueArena<T>>();
arena = std::make_unique<CountedValueHashmapArena<T>>();
}

inline bool atCapacity() const { return max_size > 0 && m.size() == max_size; }
Expand All @@ -193,7 +151,7 @@ class CountedValueHashMap

bool empty() const { return m.empty(); }

void swap(CountedValueMap & rhs)
void swap(CountedValueHashMap & rhs)
{
std::swap(max_size, rhs.max_size);
m.swap(rhs.m);
Expand All @@ -206,9 +164,9 @@ class CountedValueHashMap
auto end() { return m.end(); }
auto end() const { return m.end(); }

CountedValueArena<T> & getArena() { return *arena; }
CountedValueHashmapArena<T> & getArena() { return *arena; }

static CountedValueMap & merge(CountedValueMap & lhs, CountedValueMap & rhs)
static CountedValueHashMap & merge(CountedValueHashMap & lhs, CountedValueHashMap & rhs)
{
if (rhs.size() > lhs.size())
lhs.swap(rhs);
Expand All @@ -232,58 +190,18 @@ class CountedValueHashMap
return swap(rhs);
}

/// The algorithm (for example minimum sorting):
/// I) If lhs and rhs have no overlap value ranges : all values in lhs is greater or less than those in the rhs and
/// one of them are at capacity. There are 2 fast paths
/// 1) if lhs is at capacity, and values in lhs are all greater than those in the rhs, there is nothing to merge. Just return
/// 2) if rhs is at capacity, and values in rhs are all greater than those in the lhs, clear lhs and then copy everything from rhs.
/// Ideally it can be a lightweight swap, but since it is const rhs, we can't do that
/// II) Slow path
assert(!rhs.empty() && !empty());

/// Optimize path : if lhs is at capacity and rhs has no overlap of lhs
if (atCapacity())
{
/// If all values in lhs are less/greater (i.e. for minimum/maximum) than rhs
/// we don't need any merge
if (less(lastValue(), rhs.firstValue()))
return;

/// If all values in lhs are greater than rhs and rhs are at capacity as well
if (rhs.atCapacity() && rhs.capacity() == capacity() && greater(firstValue(), rhs.lastValue()))
{
if constexpr (copy)
return clearAndClone(rhs);
else
return clearAndSwap(rhs);
}
}

if (rhs.atCapacity() && rhs.capacity() == capacity())
{
/// If all values in lhs are greater than rhs
/// we can clear up lhs elements and copy over elements from rhs
if (greater(firstValue(), rhs.lastValue()))
{
if constexpr (copy)
return clearAndClone(rhs);
else
return clearAndSwap(rhs);
}
}

/// Loop from min to max
/// Directly loop all elements as there's no order nor capacity
for (auto src_iter = rhs.m.begin(); src_iter != rhs.m.end(); ++src_iter)
{
if (atCapacity() && less(lastValue(), src_iter->first))
if (atCapacity())
/// We reached maximum capacity and all other values from rhs will be
/// greater than those already in lhs. Stop merging more
break;

doMerge<copy>(src_iter);
}

eraseExtraElements();
}

template <bool copy, typename Iter>
Expand All @@ -303,20 +221,7 @@ class CountedValueHashMap
}
}

void eraseExtraElements()
{
if (max_size <= 0)
return;

while (m.size() > max_size)
{
auto last_elem = --m.end();
arena->free(last_elem->first);
m.erase(last_elem);
}
}

inline void clearAndClone(const CountedValueMap & rhs)
inline void clearAndClone(const CountedValueHashMap & rhs)
{
clear();

Expand All @@ -325,21 +230,15 @@ class CountedValueHashMap
m.emplace(arena->emplace(src_iter->first), src_iter->second);
}

inline void clearAndSwap(CountedValueMap & rhs)
inline void clearAndSwap(CountedValueHashMap & rhs)
{
clear();
swap(rhs);
}

/// \returns: true means l < r for minimum order
inline bool less(const T & l, const T & r) const { return m.key_comp()(l, r); }

/// \returns: true means l > r for minimum order
inline bool greater(const T & l, const T & r) const { return m.key_comp()(r, l); }

private:
size_type max_size;
std::unique_ptr<CountedValueArena<T>> arena;
std::unique_ptr<CountedValueHashmapArena<T>> arena;
Map m;
};
}
Expand Down
Loading

0 comments on commit 1050a2d

Please sign in to comment.