Skip to content

Commit

Permalink
feat: Add key/value support to radix sort in breeze (#11733)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #11733

Reviewed By: mbasmanova

Differential Revision: D66792931

Pulled By: bikramSingh91

fbshipit-source-id: a61ea1aa73c5213d971b02c220247f1e32e73f16
  • Loading branch information
David Reveman authored and facebook-github-bot committed Dec 5, 2024
1 parent eb49cea commit 8b4663d
Show file tree
Hide file tree
Showing 29 changed files with 1,188 additions and 351 deletions.
96 changes: 69 additions & 27 deletions velox/experimental/breeze/breeze/algorithms/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,20 @@ struct SortBlockType<unsigned> {
}
};

template <typename PlatformT, int ITEMS_PER_THREAD, int RADIX_BITS, typename T>
// Storage for scattering one block's worth of keys and values by rank.
// Holds BLOCK_ITEMS keys and BLOCK_ITEMS values in two separate arrays.
template <class KeyT, class ValueT, int BLOCK_ITEMS>
struct KeyValueScatterType {
  // one slot per key in the block
  KeyT keys[BLOCK_ITEMS];
  // one slot per value in the block
  ValueT values[BLOCK_ITEMS];
};

// Partial specialization for ValueT == utils::NullType: there are no values
// to scatter, so only the key array is stored.
template <class KeyT, int BLOCK_ITEMS>
struct KeyValueScatterType<KeyT, utils::NullType, BLOCK_ITEMS> {
  // one slot per key in the block
  KeyT keys[BLOCK_ITEMS];
};

template <typename PlatformT, int ITEMS_PER_THREAD, int RADIX_BITS,
typename KeyT, typename ValueT>
struct DeviceRadixSort {
enum {
BLOCK_THREADS = PlatformT::BLOCK_THREADS,
Expand All @@ -164,25 +177,26 @@ struct DeviceRadixSort {
unsigned global_offsets[NUM_BINS];
int block_idx;
};
struct {
T items[BLOCK_ITEMS];
} scatter;
KeyValueScatterType<KeyT, ValueT, BLOCK_ITEMS> scatter;
};
};

template <typename BlockT, typename InputSlice, typename OffsetSlice,
typename OutputSlice, typename BlockIdxSlice, typename BlockSlice,
typename ScratchSlice>
static ATTR void Sort(PlatformT p, const InputSlice in,
template <typename BlockT, typename KeyInputSlice, typename ValueInputSlice,
typename OffsetSlice, typename KeyOutputSlice,
typename ValueOutputSlice, typename BlockIdxSlice,
typename BlockSlice, typename ScratchSlice>
static ATTR void Sort(PlatformT p, const KeyInputSlice in_keys,
const ValueInputSlice in_values,
const OffsetSlice in_offsets, int start_bit,
int num_pass_bits, OutputSlice out,
int num_pass_bits, KeyOutputSlice out_keys,
ValueOutputSlice out_values,
BlockIdxSlice next_block_idx, BlockSlice blocks,
ScratchSlice scratch, int num_items) {
using namespace functions;
using namespace utils;

enum {
END_BIT = sizeof(T) * /*BITS_PER_BYTE=*/8,
END_BIT = sizeof(KeyT) * /*BITS_PER_BYTE=*/8,
WARP_THREADS = PlatformT::WARP_THREADS,
NUM_WARPS = BLOCK_THREADS / WARP_THREADS,
WARP_ITEMS = WARP_THREADS * ITEMS_PER_THREAD,
Expand Down Expand Up @@ -211,19 +225,19 @@ struct DeviceRadixSort {
// load keys into warp-striped arrangement after initializing all keys
// to all bits set as that allows us to always use the fast-path version
// of the radix rank function
T items[ITEMS_PER_THREAD];
KeyT keys[ITEMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
items[i] = NumericLimits<T>::max();
keys[i] = NumericLimits<KeyT>::max();
}
const InputSlice it = in.subslice(block.offset);
const KeyInputSlice it = in_keys.subslice(block.offset);
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, it, make_slice<THREAD, WARP_STRIPED>(items), block.num_items);
p, it, make_slice<THREAD, WARP_STRIPED>(keys), block.num_items);

// convert keys to bit ordered representation
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
items[i] = RadixSortTraits<T>::to_bit_ordered(items[i]);
keys[i] = RadixSortTraits<KeyT>::to_bit_ordered(keys[i]);
}

// determine stable rank for each key
Expand All @@ -232,18 +246,31 @@ struct DeviceRadixSort {
int exclusive_scan[BINS_PER_THREAD];
BlockRadixRankT::Rank(
p,
make_bitfield_extractor(make_slice<THREAD, WARP_STRIPED>(items),
make_bitfield_extractor(make_slice<THREAD, WARP_STRIPED>(keys),
start_bit, num_pass_bits),
make_slice<THREAD, WARP_STRIPED>(ranks), make_slice(histogram),
blocks.subslice(block_idx * NUM_BINS), make_slice(exclusive_scan),
make_slice<SHARED>(&scratch->rank));
p.syncthreads();

// scatter items by storing them in shared memory using ranks
// scatter keys by storing them in scratch using ranks
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<THREAD, WARP_STRIPED>(items),
p, make_slice<THREAD, WARP_STRIPED>(keys),
make_slice<THREAD, WARP_STRIPED>(ranks),
make_slice<SHARED>(scratch->scatter.items));
make_slice<SHARED>(scratch->scatter.keys));

// load and scatter optional values
ValueT values[ITEMS_PER_THREAD];
if constexpr (IsDifferent<ValueT, NullType>::VALUE) {
const ValueInputSlice it = in_values.subslice(block.offset);
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, it, make_slice<THREAD, WARP_STRIPED>(values), block.num_items);
// scatter values by storing them in scratch using ranks
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<THREAD, WARP_STRIPED>(values),
make_slice<THREAD, WARP_STRIPED>(ranks),
make_slice<SHARED>(scratch->scatter.values));
}
p.syncthreads();

// first block loads initial global offsets from input and other blocks
Expand Down Expand Up @@ -334,9 +361,16 @@ struct DeviceRadixSort {
global_offsets[i] -= exclusive_scan[i];
}

// gather scattered items from scratch
// gather scattered keys from scratch
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<SHARED>(scratch->scatter.items), make_slice(items));
p, make_slice<SHARED>(scratch->scatter.keys), make_slice(keys));

// gather optional scattered values from scratch
if constexpr (IsDifferent<ValueT, NullType>::VALUE) {
// gather scattered values from scratch
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<SHARED>(scratch->scatter.values), make_slice(values));
}
p.syncthreads();

// store global offsets in scratch
Expand All @@ -349,7 +383,7 @@ struct DeviceRadixSort {
unsigned out_offsets[ITEMS_PER_THREAD];
BlockLoadFrom<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<SHARED>(scratch->global_offsets),
make_bitfield_extractor(make_slice(items), start_bit, num_pass_bits),
make_bitfield_extractor(make_slice(keys), start_bit, num_pass_bits),
make_slice(out_offsets));

// add item index (same as rank after scatter/gather) to output offsets
Expand All @@ -358,15 +392,23 @@ struct DeviceRadixSort {
out_offsets[i] += p.thread_idx() + i * BLOCK_THREADS;
}

// convert items back to original representation
// convert keys back to original representation
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
items[i] = RadixSortTraits<T>::from_bit_ordered(items[i]);
keys[i] = RadixSortTraits<KeyT>::from_bit_ordered(keys[i]);
}

// store gathered items in global memory using output offsets
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice(items), make_slice(out_offsets), out, block.num_items);
// store gathered keys in global memory using output offsets
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(p, make_slice(keys),
make_slice(out_offsets),
out_keys, block.num_items);

// store gathered values in global memory using output offsets
if constexpr (IsDifferent<ValueT, NullType>::VALUE) {
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice(values), make_slice(out_offsets), out_values,
block.num_items);
}
}
};

Expand Down
104 changes: 76 additions & 28 deletions velox/experimental/breeze/breeze/functions/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,13 @@ struct BlockRadixRank {
}
};

template <typename PlatformT, int ITEMS_PER_THREAD, int RADIX_BITS, typename T>
template <typename PlatformT, int ITEMS_PER_THREAD, int RADIX_BITS,
typename KeyT, typename ValueT>
struct BlockRadixSort {
enum {
BLOCK_THREADS = PlatformT::BLOCK_THREADS,
WARP_THREADS = PlatformT::WARP_THREADS,
END_BIT = sizeof(T) * /*BITS_PER_BYTE=*/8,
END_BIT = sizeof(KeyT) * /*BITS_PER_BYTE=*/8,
NUM_PASSES = utils::DivideAndRoundUp<END_BIT, RADIX_BITS>::VALUE,
NUM_BINS = 1 << RADIX_BITS,
BINS_PER_THREAD = utils::DivideAndRoundUp<NUM_BINS, BLOCK_THREADS>::VALUE,
Expand All @@ -284,21 +285,27 @@ struct BlockRadixSort {
union {
typename BlockRadixRank<PlatformT, ITEMS_PER_THREAD, RADIX_BITS>::Scratch
rank;
T scatter[BLOCK_THREADS * ITEMS_PER_THREAD];
struct {
union {
KeyT keys[BLOCK_THREADS * ITEMS_PER_THREAD];
ValueT values[BLOCK_THREADS * ITEMS_PER_THREAD];
};
} scatter;
};
};

template <typename ItemSlice, typename ScratchSlice>
static ATTR void Sort(PlatformT p, ItemSlice items, ScratchSlice scratch) {
template <typename KeySlice, typename ValueSlice, typename ScratchSlice>
static ATTR void Sort(PlatformT p, KeySlice keys, ValueSlice values,
ScratchSlice scratch) {
using namespace utils;

static_assert(IsSame<typename ScratchSlice::data_type, Scratch>::VALUE,
"incorrect scratch type");

// convert items to bit ordered representation if needed
// convert keys to bit ordered representation if needed
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
items[i] = RadixSortTraits<T>::to_bit_ordered(items[i]);
keys[i] = RadixSortTraits<KeyT>::to_bit_ordered(keys[i]);
}

// start from LSB and loop until no bits are left
Expand All @@ -307,36 +314,49 @@ struct BlockRadixSort {
int start_bit = i * RADIX_BITS;
int num_pass_bits = p.min(RADIX_BITS, END_BIT - start_bit);

// determine stable rank for each item
// determine stable rank for each key
int ranks[ITEMS_PER_THREAD];
BlockRadixRank<PlatformT, ITEMS_PER_THREAD, RADIX_BITS>::Rank(
p, make_bitfield_extractor(items, start_bit, num_pass_bits),
make_slice<THREAD, ItemSlice::ARRANGEMENT>(ranks),
p, make_bitfield_extractor(keys, start_bit, num_pass_bits),
make_slice<THREAD, KeySlice::ARRANGEMENT>(ranks),
make_slice<SHARED>(&scratch->rank));
p.syncthreads();

// scatter items by storing them in shared memory using ranks
// scatter keys by storing them in scratch using ranks
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, items, make_slice<THREAD, ItemSlice::ARRANGEMENT>(ranks),
make_slice<SHARED>(scratch->scatter));
p, keys, make_slice<THREAD, KeySlice::ARRANGEMENT>(ranks),
make_slice<SHARED>(scratch->scatter.keys));
p.syncthreads();

// load scattered items
// load scattered keys
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<SHARED>(scratch->scatter), items);
p, make_slice<SHARED>(scratch->scatter.keys), keys);
p.syncthreads();

if constexpr (IsDifferent<ValueT, NullType>::VALUE) {
// scatter values by storing them in scratch using ranks
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, values, make_slice<THREAD, KeySlice::ARRANGEMENT>(ranks),
make_slice<SHARED>(scratch->scatter.values));
p.syncthreads();

// load scattered values
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<SHARED>(scratch->scatter.values), values);
p.syncthreads();
}
}

// convert items back to original representation
// convert keys back to original representation
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
items[i] = RadixSortTraits<T>::from_bit_ordered(items[i]);
keys[i] = RadixSortTraits<KeyT>::from_bit_ordered(keys[i]);
}
}

template <typename ItemSlice, typename ScratchSlice>
static ATTR void Sort(PlatformT p, ItemSlice items, ScratchSlice scratch,
int num_items) {
template <typename KeySlice, typename ValueSlice, typename ScratchSlice>
static ATTR void Sort(PlatformT p, KeySlice keys, ValueSlice values,
ScratchSlice scratch, int num_items) {
using namespace utils;

enum {
Expand All @@ -345,31 +365,59 @@ struct BlockRadixSort {

static_assert((BLOCK_THREADS % WARP_THREADS) == 0,
"BLOCK_THREADS must be a multiple of WARP_THREADS");
static_assert(ItemSlice::ARRANGEMENT == WARP_STRIPED,
static_assert(KeySlice::ARRANGEMENT == WARP_STRIPED,
"input must have warp-striped arrangement");

int thread_offset = p.warp_idx() * WARP_ITEMS + p.lane_idx();

// pad items with values that have all bits set
T padded_items[ITEMS_PER_THREAD];
// pad keys with values that have all bits set
KeyT padded_keys[ITEMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
padded_items[i] = NumericLimits<T>::max();
padded_keys[i] = NumericLimits<KeyT>::max();
}
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
if (thread_offset + (i * WARP_THREADS) < num_items) {
padded_items[i] = items[i];
padded_keys[i] = keys[i];
}
}

Sort(p, make_slice<THREAD, WARP_STRIPED>(padded_items), scratch);
if constexpr (IsDifferent<ValueT, NullType>::VALUE) {
static_assert(ValueSlice::ARRANGEMENT == WARP_STRIPED,
"input must have warp-striped arrangement");

ValueT padded_values[ITEMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
padded_values[i] = static_cast<ValueT>(0);
}
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
if (thread_offset + (i * WARP_THREADS) < num_items) {
padded_values[i] = values[i];
}
}

Sort(p, make_slice<THREAD, WARP_STRIPED>(padded_keys),
make_slice<THREAD, WARP_STRIPED>(padded_values), scratch);

// copy valid values back
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
if (thread_offset + (i * WARP_THREADS) < num_items) {
values[i] = padded_values[i];
}
}
} else {
Sort(p, make_slice<THREAD, WARP_STRIPED>(padded_keys), values, scratch);
}

// copy valid items back
// copy valid keys back
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
if (thread_offset + (i * WARP_THREADS) < num_items) {
items[i] = padded_items[i];
keys[i] = padded_keys[i];
}
}
}
Expand Down
6 changes: 5 additions & 1 deletion velox/experimental/breeze/breeze/utils/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,11 @@ enum DataArrangement {
WARP_STRIPED,
};

class EmptySlice {};
// Empty tag type standing for "no type"; it is the data type of EmptySlice.
class NullType {};

// Slice type that carries no data.
class EmptySlice {
 public:
  // Element type reported by this slice. It must be public: members of a
  // `class` are private by default, which would make `EmptySlice::data_type`
  // inaccessible to generic code that queries `Slice::data_type`.
  using data_type = NullType;
};

// Creates a value-initialized EmptySlice.
ATTR constexpr EmptySlice make_empty_slice() {
  return EmptySlice{};
}

Expand Down
Loading

0 comments on commit 8b4663d

Please sign in to comment.