Skip to content

Commit

Permalink
low frequency filter
Browse files Browse the repository at this point in the history
  • Loading branch information
ccccjunkang committed Jul 18, 2024
1 parent 2e8e855 commit 0182e7c
Show file tree
Hide file tree
Showing 15 changed files with 435 additions and 17 deletions.
126 changes: 124 additions & 2 deletions HugeCTR/embedding/all2all_embedding_collection.cu
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ void weighted_sparse_forward_per_gpu(
const core23::Tensor &sp_weights_all_gather_recv_buffer, ILookup *emb_storage,
std::vector<core23::Tensor> &emb_vec_model_buffer, int64_t *num_model_key,
int64_t *num_model_offsets, core23::Tensor &ret_model_key, core23::Tensor &ret_model_offset,
core23::Tensor &ret_sp_weight) {
core23::Tensor &ret_sp_weight, bool use_filter) {
HugeCTR::CudaDeviceContext context(core->get_device_id());

int tensor_device_id = core->get_device_id();
Expand Down Expand Up @@ -369,14 +369,88 @@ void weighted_sparse_forward_per_gpu(
*num_model_offsets = model_offsets.num_elements();
}

template <typename offset_t>
__global__ void cal_lookup_idx(size_t lookup_num, offset_t *bucket_after_filter, size_t batch_size,
offset_t *lookup_offset, size_t bucket_num) {
int32_t i = blockIdx.x * blockDim.x + threadIdx.x;
int32_t step = blockDim.x * gridDim.x;
for (; i < (lookup_num); i += step) {
lookup_offset[i] = bucket_after_filter[i * batch_size];
}
}

template <typename offset_t>
__global__ void count_ratio_filter(size_t bucket_num, char *filterd, const offset_t *bucket_range,
offset_t *bucket_after_filter) {
int32_t i = blockIdx.x * blockDim.x + threadIdx.x;
int32_t step = blockDim.x * gridDim.x;
for (; i < (bucket_num); i += step) {
offset_t start = bucket_range[i];
offset_t end = bucket_range[i + 1];
bucket_after_filter[i + 1] = 0;
for (offset_t idx = start; idx < end; idx++) {
if (filterd[idx] == 1) {
bucket_after_filter[i + 1]++;
}
}
if (i == 0) {
bucket_after_filter[i] = 0;
}
}
}

void filter(std::shared_ptr<CoreResourceManager> core,
const UniformModelParallelEmbeddingMeta &meta, const core23::Tensor &filterd,
core23::Tensor &bucket_range, core23::Tensor &bucket_after_filter,
core23::TensorParams &params, EmbeddingInput &emb_input, core23::Tensor &lookup_offset,
core23::Tensor &temp_scan_storage, core23::Tensor &temp_select_storage,
size_t temp_scan_bytes, size_t temp_select_bytes, core23::Tensor &keys_after_filter) {
auto stream = core->get_local_gpu()->get_stream();
// bucket_range length = bucket_num+1 , so here we minus 1.
int bucket_num = bucket_range.num_elements() - 1;
const int block_size = 256;
const int grid_size =
core->get_kernel_param().num_sms * core->get_kernel_param().max_thread_per_block / block_size;

DISPATCH_INTEGRAL_FUNCTION_CORE23(bucket_range.data_type().type(), offset_t, [&] {
DISPATCH_INTEGRAL_FUNCTION_CORE23(keys_after_filter.data_type().type(), key_t, [&] {
offset_t *bucket_after_filter_ptr = bucket_after_filter.data<offset_t>();
const offset_t *bucket_range_ptr = bucket_range.data<offset_t>();
char *filterd_ptr = filterd.data<char>();
count_ratio_filter<<<grid_size, block_size, 0, stream>>>(
bucket_num, filterd_ptr, bucket_range_ptr, bucket_after_filter_ptr);
cub::DeviceScan::InclusiveSum(
temp_scan_storage.data(), temp_scan_bytes, bucket_after_filter.data<offset_t>(),
bucket_after_filter.data<offset_t>(), bucket_after_filter.num_elements(), stream);

key_t *keys_ptr = emb_input.keys.data<key_t>();

cub::DeviceSelect::Flagged(temp_select_storage.data(), temp_select_bytes, keys_ptr,
filterd_ptr, keys_after_filter.data<key_t>(),
emb_input.num_keys.data<uint64_t>(), emb_input.h_num_keys, stream);

size_t batch_size = (bucket_num) / meta.num_lookup_;

cal_lookup_idx<<<1, block_size, 0, stream>>>(meta.num_lookup_ + 1,
bucket_after_filter.data<offset_t>(), batch_size,
lookup_offset.data<offset_t>(), bucket_num);
HCTR_LIB_THROW(cudaStreamSynchronize(stream));
emb_input.h_num_keys = static_cast<size_t>(emb_input.num_keys.data<uint64_t>()[0]);
emb_input.keys = keys_after_filter;
emb_input.bucket_range = bucket_after_filter;
});
});
}

void sparse_forward_per_gpu(std::shared_ptr<CoreResourceManager> core,
const EmbeddingCollectionParam &ebc_param,
const UniformModelParallelEmbeddingMeta &meta,
const core23::Tensor &key_all_gather_recv_buffer,
const core23::Tensor &row_lengths_all_gather_recv_buffer,
ILookup *emb_storage, std::vector<core23::Tensor> &emb_vec_model_buffer,
int64_t *num_model_key, int64_t *num_model_offsets,
core23::Tensor *ret_model_key, core23::Tensor *ret_model_offset) {
core23::Tensor *ret_model_key, core23::Tensor *ret_model_offset,
bool use_filter) {
/*
There are some steps in this function:
1.reorder key to feature major
Expand Down Expand Up @@ -500,8 +574,56 @@ void sparse_forward_per_gpu(std::shared_ptr<CoreResourceManager> core,
compress_offset_.compute(embedding_input.bucket_range, batch_size, &num_key_per_lookup_offset);
HCTR_LIB_THROW(cudaStreamSynchronize(stream));

if (use_filter) {
core23::Tensor bucket_range_after_filter;
core23::Tensor keys_after_filter;
core23::Tensor filtered;

filtered = core23::Tensor(
params.shape({(int64_t)embedding_input.h_num_keys}).data_type(core23::ScalarType::Char));
bucket_range_after_filter =
core23::Tensor(params.shape({embedding_input.bucket_range.num_elements()})
.data_type(embedding_input.bucket_range.data_type().type()));
keys_after_filter = core23::Tensor(params.shape({(int64_t)embedding_input.h_num_keys + 1})
.data_type(embedding_input.keys.data_type().type()));

core23::Tensor temp_scan_storage;
core23::Tensor temp_select_storage;

size_t temp_scan_bytes = 0;
size_t temp_select_bytes = 0;

DISPATCH_INTEGRAL_FUNCTION_CORE23(
embedding_input.bucket_range.data_type().type(), offset_t, [&] {
DISPATCH_INTEGRAL_FUNCTION_CORE23(embedding_input.keys.data_type().type(), key_t, [&] {
cub::DeviceScan::InclusiveSum(nullptr, temp_scan_bytes, (offset_t *)nullptr,
(offset_t *)nullptr,
bucket_range_after_filter.num_elements());

temp_scan_storage = core23::Tensor(params.shape({static_cast<int64_t>(temp_scan_bytes)})
.data_type(core23::ScalarType::Char));

cub::DeviceSelect::Flagged(nullptr, temp_select_bytes, (key_t *)nullptr,
(char *)nullptr, (key_t *)nullptr, (uint64_t *)nullptr,
embedding_input.h_num_keys);

temp_select_storage =
core23::Tensor(params.shape({static_cast<int64_t>(temp_select_bytes)})
.data_type(core23::ScalarType::Char));
});
});

emb_storage->ratio_filter(embedding_input.keys, embedding_input.h_num_keys,
num_key_per_lookup_offset, meta.num_local_lookup_ + 1,
meta.d_local_table_id_list_, filtered);

filter(core, meta, filtered, embedding_input.bucket_range, bucket_range_after_filter, params,
embedding_input, num_key_per_lookup_offset, temp_scan_storage, temp_select_storage,
temp_scan_bytes, temp_select_bytes, keys_after_filter);
}
core23::Tensor embedding_vec = core23::init_tensor_list<float>(
key_all_gather_recv_buffer.num_elements(), params.device().index());

emb_storage->lookup(embedding_input.keys, embedding_input.h_num_keys, num_key_per_lookup_offset,
meta.num_local_lookup_ + 1, meta.d_local_table_id_list_, embedding_vec);

Expand Down
5 changes: 3 additions & 2 deletions HugeCTR/embedding/all2all_embedding_collection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void weighted_sparse_forward_per_gpu(
const core23::Tensor &sp_weights_all_gather_recv_buffer, ILookup *emb_storage,
std::vector<core23::Tensor> &emb_vec_model_buffer, int64_t *num_model_key,
int64_t *num_model_offsets, core23::Tensor &ret_model_key, core23::Tensor &ret_model_offset,
core23::Tensor &ret_sp_weight);
core23::Tensor &ret_sp_weight, bool use_filter);

void weighted_copy_model_keys_and_offsets(
std::shared_ptr<CoreResourceManager> core, const core23::Tensor &model_key,
Expand All @@ -71,7 +71,8 @@ void sparse_forward_per_gpu(std::shared_ptr<CoreResourceManager> core,
const core23::Tensor &row_lengths_all_gather_recv_buffer,
ILookup *emb_storage, std::vector<core23::Tensor> &emb_vec_model_buffer,
int64_t *num_model_key, int64_t *num_model_offsets,
core23::Tensor *ret_model_key, core23::Tensor *ret_model_offset);
core23::Tensor *ret_model_key, core23::Tensor *ret_model_offset,
bool use_filter);

void copy_model_keys_and_offsets(std::shared_ptr<CoreResourceManager> core,
const core23::Tensor &model_key,
Expand Down
10 changes: 7 additions & 3 deletions HugeCTR/embedding/embedding_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ class ILookup {
public:
virtual ~ILookup() = default;

virtual void lookup(const core23::Tensor &keys, size_t num_keys,
const core23::Tensor &num_keys_per_table_offset, size_t num_table_offset,
const core23::Tensor &table_id_list, core23::Tensor &embedding_vec) = 0;
virtual void lookup(const core23::Tensor& keys, size_t num_keys,
const core23::Tensor& num_keys_per_table_offset, size_t num_table_offset,
const core23::Tensor& table_id_list, core23::Tensor& embedding_vec) = 0;

virtual void ratio_filter(const core23::Tensor& keys, size_t num_keys,
const core23::Tensor& id_space_offset, size_t num_id_space_offset,
const core23::Tensor& id_space, core23::Tensor& filtered){};
};

} // namespace embedding
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,56 @@ void DummyVarAdapter<KeyType, OffsetType, DType>::lookup(
}
}

template <typename KeyType, typename OffsetType, typename DType>
void DummyVarAdapter<KeyType, OffsetType, DType>::ratio_filter(
const core23::Tensor& keys, size_t num_keys, const core23::Tensor& id_space_offset,
size_t num_id_space_offset, const core23::Tensor& id_space, core23::Tensor& filtered) {
// clang-format off
id_space_offset_.clear();
id_space_.clear();
id_space_offset_.resize(num_id_space_offset);
CUDACHECK(cudaMemcpyAsync(id_space_offset_.data(),
id_space_offset.data<OffsetType>(),
sizeof(OffsetType) * (num_id_space_offset),
cudaMemcpyDeviceToHost, stream_));
id_space_.resize(num_id_space_offset - 1);
CUDACHECK(cudaMemcpyAsync(id_space_.data(),
id_space.data<int>(),
sizeof(int) * (num_id_space_offset - 1),
cudaMemcpyDeviceToHost, stream_));
// clang-format on
CUDACHECK(cudaStreamSynchronize(stream_));
const KeyType* input = keys.data<KeyType>();
bool* output_filtered = filtered.data<bool>();
int start_index = 0;
size_t num = 0;
bool is_lookup = false;

for (int i = 0; i < num_id_space_offset - 1; ++i) {
if (i == num_id_space_offset - 2) {
num += id_space_offset_[i + 1] - id_space_offset_[i];
is_lookup = true;
} else {
if (same_table_[i + 1] != same_table_[i]) {
num += id_space_offset_[i + 1] - id_space_offset_[i];
is_lookup = true;
} else {
num += id_space_offset_[i + 1] - id_space_offset_[i];
}
}
if (num != 0 && is_lookup) {
auto var = vars_[id_space_[start_index]];
var->ratio_filter(input, output_filtered, num, stream_);
CUDACHECK(cudaStreamSynchronize(stream_));
input += num;
output_filtered += num;
num = 0;
is_lookup = false;
start_index = i + 1;
}
}
}

template class DummyVarAdapter<int32_t, int32_t, float>;
template class DummyVarAdapter<int32_t, int64_t, float>;
// template class DummyVarAdapter<int32_t, __half>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ class DummyVarAdapter : public ::embedding::ILookup {
size_t num_id_space_offset, const core23::Tensor& id_space,
core23::Tensor& embedding_vec) override;

void ratio_filter(const core23::Tensor& keys, size_t num_keys,
const core23::Tensor& id_space_offset, size_t num_id_space_offset,
const core23::Tensor& id_space, core23::Tensor& filtered) override;

private:
std::shared_ptr<sok::CoreResourceManager> tf_backend_;
int sm_count_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class EmbeddingCollectionBase : public OpKernel {
int global_gpu_id_;
int num_local_lookups_;
bool use_sp_weight_;
bool use_filter_;
HugeCTR::core23::KernelParams kernel_params_;

std::unique_ptr<sok::EmbeddingCollectionParam> ebc_param_;
Expand Down Expand Up @@ -143,6 +144,7 @@ class EmbeddingCollectionBase : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("id_in_local_rank", &id_in_local_rank_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_gpus", &num_gpus_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_sp_weight", &use_sp_weight_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_filter", &use_filter_));

// check rank/num_ranks/id_in_local_rank/num_gpus
OP_REQUIRES(ctx, rank_ >= 0 && rank_ < num_ranks_, errors::InvalidArgument("Invalid rank."));
Expand Down Expand Up @@ -477,13 +479,13 @@ class LookupFowardBase : public EmbeddingCollectionBase<KeyType, OffsetType, DTy
tf_backend, *this->meta_, this->global_gpu_id_, key_recv_buffer_tensor,
row_length_recv_buffer_tensor, sp_weight_recv_buffer_tensor, &adapter_,
emb_vec_model_buffer, &num_model_key, &num_model_offsets, ret_model_key, ret_model_offset,
ret_sp_weight);
ret_sp_weight,this->use_filter_);

} else {
::embedding::tf::model_forward::sparse_forward_per_gpu(
tf_backend, *this->ebc_param_, *this->meta_, key_recv_buffer_tensor,
row_length_recv_buffer_tensor, &adapter_, emb_vec_model_buffer, &num_model_key,
&num_model_offsets, &ret_model_key, &ret_model_offset);
&num_model_offsets, &ret_model_key, &ret_model_offset,this->use_filter_);
}

// Prepare model_key & model_offsets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ REGISTER_OP("PreprocessingForward")
.Attr("id_in_local_rank: int")
.Attr("num_gpus: int")
.Attr("use_sp_weight: bool")
.Attr("use_filter: bool")
.Attr("Tindices: {int32, int64} = DT_INT64")
.Attr("Toffsets: {int32, int64} = DT_INT64")
.Attr("dtype: {float32, float16} = DT_FLOAT")
Expand Down Expand Up @@ -80,6 +81,7 @@ REGISTER_OP("PreprocessingForwardWithWeight")
.Attr("id_in_local_rank: int")
.Attr("num_gpus: int")
.Attr("use_sp_weight: bool")
.Attr("use_filter: bool")
.Attr("Tindices: {int32, int64} = DT_INT64")
.Attr("Toffsets: {int32, int64} = DT_INT64")
.Attr("dtype: {float32, float16} = DT_FLOAT")
Expand Down Expand Up @@ -112,6 +114,7 @@ REGISTER_OP("LookupForward")
.Attr("id_in_local_rank: int")
.Attr("num_gpus: int")
.Attr("use_sp_weight: bool")
.Attr("use_filter: bool")
.Attr("Tindices: {int32, int64} = DT_INT64")
.Attr("Toffsets: {int32, int64} = DT_INT64")
.Attr("dtype: {float32, float16} = DT_FLOAT")
Expand Down Expand Up @@ -165,6 +168,7 @@ REGISTER_OP("LookupForwardVariable")
.Attr("id_in_local_rank: int")
.Attr("num_gpus: int")
.Attr("use_sp_weight: bool")
.Attr("use_filter: bool")
.Attr("Tindices: {int32, int64} = DT_INT64")
.Attr("Toffsets: {int32, int64} = DT_INT64")
.Attr("dtype: {float32, float16} = DT_FLOAT")
Expand Down Expand Up @@ -218,6 +222,7 @@ REGISTER_OP("LookupForwardDynamic")
.Attr("id_in_local_rank: int")
.Attr("num_gpus: int")
.Attr("use_sp_weight: bool")
.Attr("use_filter: bool")
.Attr("Tindices: {int32, int64} = DT_INT64")
.Attr("Toffsets: {int32, int64} = DT_INT64")
.Attr("dtype: {float32, float16} = DT_FLOAT")
Expand Down Expand Up @@ -273,6 +278,7 @@ REGISTER_OP("LookupForwardEmbeddingVarGPU")
.Attr("id_in_local_rank: int")
.Attr("num_gpus: int")
.Attr("use_sp_weight: bool")
.Attr("use_filter: bool")
.Attr("Tindices: {int32, int64} = DT_INT64")
.Attr("Toffsets: {int32, int64} = DT_INT64")
.Attr("dtype: {float32, float16} = DT_FLOAT")
Expand Down Expand Up @@ -324,6 +330,7 @@ REGISTER_OP("LookupBackward")
.Attr("id_in_local_rank: int")
.Attr("num_gpus: int")
.Attr("use_sp_weight: bool")
.Attr("use_filter: bool")
.Attr("Tindices: {int32, int64} = DT_INT64")
.Attr("Toffsets: {int32, int64} = DT_INT64")
.Attr("dtype: {float32, float16} = DT_FLOAT")
Expand Down Expand Up @@ -362,6 +369,7 @@ REGISTER_OP("PostprocessingForward")
.Attr("id_in_local_rank: int")
.Attr("num_gpus: int")
.Attr("use_sp_weight: bool")
.Attr("use_filter: bool")
.Attr("Tindices: {int32, int64} = DT_INT64")
.Attr("Toffsets: {int32, int64} = DT_INT64")
.Attr("dtype: {float32, float16} = DT_FLOAT")
Expand Down Expand Up @@ -403,6 +411,7 @@ REGISTER_OP("PostprocessingBackward")
.Attr("id_in_local_rank: int")
.Attr("num_gpus: int")
.Attr("use_sp_weight: bool")
.Attr("use_filter: bool")
.Attr("Tindices: {int32, int64} = DT_INT64")
.Attr("Toffsets: {int32, int64} = DT_INT64")
.Attr("dtype: {float32, float16} = DT_FLOAT")
Expand Down
6 changes: 6 additions & 0 deletions sparse_operation_kit/kit_src/variable/impl/det_variable.cu
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,12 @@ void DETVariable<KeyType, ValueType>::scatter_update(const KeyType* keys, const
map_->scatter_update(keys, values, num_keys, stream);
}

template <typename KeyType, typename ValueType>
void DETVariable<KeyType, ValueType>::ratio_filter(const KeyType* keys, bool* filtered,
size_t num_keys, cudaStream_t stream) {
throw std::runtime_error("SOK dynamic variable with DET backend don't support ratio_filter!");
}

template class DETVariable<int32_t, float>;
template class DETVariable<int64_t, float>;

Expand Down
3 changes: 3 additions & 0 deletions sparse_operation_kit/kit_src/variable/impl/det_variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,13 @@ class DETVariable : public VariableBase<KeyType, ValueType> {
cudaStream_t stream = 0) override;
void scatter_update(const KeyType *keys, const ValueType *values, size_t num_keys,
cudaStream_t stream = 0) override;
void ratio_filter(const KeyType *keys, bool *filtered, size_t num_keys,
cudaStream_t stream = 0) override;

private:
std::unique_ptr<cuco::dynamic_map<KeyType, ValueType, cuco::initializer>> map_;

float filter_ratio_;
size_t dimension_;
size_t initial_capacity_;
std::string initializer_;
Expand Down
Loading

0 comments on commit 0182e7c

Please sign in to comment.