Skip to content

Commit

Permalink
Merge branch 'huikang/sparse_read_evict' into 'main'
Browse files Browse the repository at this point in the history
add sparse read evict

See merge request dl/hugectr/hugectr!1526
  • Loading branch information
minseokl committed Feb 19, 2024
2 parents 4dc3775 + 02af789 commit 5a8ee1b
Show file tree
Hide file tree
Showing 16 changed files with 543 additions and 15 deletions.
18 changes: 18 additions & 0 deletions sparse_operation_kit/kit_src/variable/impl/det_variable.cu
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,24 @@ void DETVariable<KeyType, ValueType>::lookup(const KeyType* keys, ValueType** va
map_->lookup(keys, values, num_keys, stream);
}

// Unsupported on the DET backend: lookup-with-eviction is only provided by the
// HKV backend (see HKVVariable::lookup_with_evict); calling this always throws.
// All parameters are ignored.
template <typename KeyType, typename ValueType>
void DETVariable<KeyType, ValueType>::lookup_with_evict(const KeyType* keys, KeyType* tmp_keys,
                                                        ValueType* tmp_values, ValueType* values,
                                                        uint64_t* evict_num_keys, uint64_t num_keys,
                                                        cudaStream_t stream) {
  throw std::runtime_error(
      "SOK dynamic variable with DET backend don't support lookup_with_evict!");
}

template <typename KeyType, typename ValueType>
void DETVariable<KeyType, ValueType>::copy_evict_keys(const KeyType* keys, const ValueType* values,
size_t num_keys, size_t dim,
KeyType* ret_keys, ValueType* ret_values,
cudaStream_t stream) {
throw std::runtime_error(
"SOK dynamic variable with DET backend don't support lookup_with_evict!");
}

template <typename KeyType, typename ValueType>
void DETVariable<KeyType, ValueType>::scatter_add(const KeyType* keys, const ValueType* values,
size_t num_keys, cudaStream_t stream) {
Expand Down
7 changes: 7 additions & 0 deletions sparse_operation_kit/kit_src/variable/impl/det_variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ class DETVariable : public VariableBase<KeyType, ValueType> {
cudaStream_t stream = 0) override;
void lookup(const KeyType *keys, ValueType **values, size_t num_keys,
cudaStream_t stream = 0) override;

void lookup_with_evict(const KeyType *keys, KeyType *tmp_keys, ValueType *tmp_values,
ValueType *values, uint64_t *evict_num_keys, uint64_t num_keys,
cudaStream_t stream = 0) override;

void copy_evict_keys(const KeyType *keys, const ValueType *values, size_t num_keys, size_t dim,
KeyType *ret_keys, ValueType *ret_values, cudaStream_t stream = 0) override;
void scatter_add(const KeyType *keys, const ValueType *values, size_t num_keys,
cudaStream_t stream = 0) override;
void scatter_update(const KeyType *keys, const ValueType *values, size_t num_keys,
Expand Down
181 changes: 168 additions & 13 deletions sparse_operation_kit/kit_src/variable/impl/hkv_variable.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,34 @@ __global__ void generate_uniform_kernel(curandState* state, T** result, bool* d_
load_state = true;
}
for (size_t i = emb_vec_id; i < dim; i += blockDim.x) {
result[emb_id][i] = curand_normal_double(&localState);
result[emb_id][i] = curand_uniform_double(&localState);
}
}
}
/* Copy state back to global memory */
if (load_state) {
state[GlobalThreadId()] = localState;
}
}

template <typename T>
__global__ void generate_uniform_kernel(curandState* state, T* result, bool* d_found, size_t dim,
size_t num_embedding) {
auto id = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
size_t block_id = blockIdx.x;
size_t emb_vec_id = threadIdx.x;
/* Copy state to local memory for efficiency */
curandState localState;
bool load_state = false;
for (size_t emb_id = block_id; emb_id < num_embedding; emb_id += gridDim.x) {
if (!d_found[emb_id]) {
if (!load_state) {
localState = state[GlobalThreadId()];
load_state = true;
}
for (size_t i = emb_vec_id; i < dim; i += blockDim.x) {
T* tmp_reslut = result + emb_id * dim;
tmp_reslut[i] = curand_uniform_double(&localState);
}
}
}
Expand Down Expand Up @@ -97,6 +124,19 @@ __global__ void const_initializer_kernel(float val, T** result, bool* d_found, s
}
}

// Fill the rows of `result` whose keys were NOT found in the table with a
// constant value; rows already found are left untouched.
//
// Layout contract: one block per embedding row (blockIdx.x == row index) —
// there is no grid-stride loop, so the launcher must use one block per row
// (the call site launches with `num_keys` blocks). Threads in a block stride
// over the `dim` elements of their row.
//
// result  : contiguous [rows, dim] output buffer
// d_found : per-row hit flags from the preceding table lookup
template <typename T>
__global__ void const_initializer_kernel(float val, T* result, bool* d_found, size_t dim) {
  const size_t emb_id = blockIdx.x;       // one embedding row per block
  const size_t emb_vec_id = threadIdx.x;  // lane within the row
  if (!d_found[emb_id]) {
    // Hoisted out of the loop: the row base pointer is loop-invariant
    // (also fixes the "tmp_reslut" misspelling and drops the unused `id`).
    T* row = result + emb_id * dim;
    for (size_t i = emb_vec_id; i < dim; i += blockDim.x) {
      row[i] = static_cast<T>(val);
    }
  }
}

template <typename T>
__global__ void generate_normal_kernel(curandState* state, T* result, size_t n) {
auto id = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
Expand Down Expand Up @@ -137,6 +177,62 @@ __global__ void generate_normal_kernel(curandState* state, T** result, bool* d_f
}
}

// Fill the rows of `result` whose keys were NOT found in the table with
// normally-distributed random values; found rows are left untouched.
//
// Layout contract: grid-stride over embedding rows (block -> row); threads in
// a block stride over the `dim` elements of a row. The curand state is copied
// into a register once on first use and written back only if used, so blocks
// whose rows were all found never touch the RNG state.
//
// NOTE(review): curand_normal_double yields a double that is narrowed to T on
// store — this matches the T** overload above; confirm the double-precision
// draw (vs curand_normal) is intentional.
template <typename T>
__global__ void generate_normal_kernel(curandState* state, T* result, bool* d_found, size_t dim,
                                       size_t num_embedding) {
  const size_t emb_vec_id = threadIdx.x;
  /* Copy state to local memory for efficiency */
  curandState localState;
  bool load_state = false;
  for (size_t emb_id = blockIdx.x; emb_id < num_embedding; emb_id += gridDim.x) {
    if (!d_found[emb_id]) {
      if (!load_state) {
        localState = state[GlobalThreadId()];
        load_state = true;
      }
      // Hoisted: row base pointer is invariant w.r.t. the inner loop
      // (also fixes the "tmp_reslut" misspelling and drops the unused `id`).
      T* row = result + emb_id * dim;
      for (size_t i = emb_vec_id; i < dim; i += blockDim.x) {
        row[i] = curand_normal_double(&localState);
      }
    }
  }

  /* Copy state back to global memory */
  if (load_state) {
    state[GlobalThreadId()] = localState;
  }
}

// Compact the not-found rows of (keys, values) into (ret_keys, ret_values).
// For every row with d_found[row] == false, thread 0 of the block claims an
// output slot via atomicAdd on *num_no_found, writes the key, and publishes
// the slot index through shared memory; the whole block then copies the
// `dim`-wide embedding row. Output order is nondeterministic (atomic slots).
//
// Layout contract: grid-stride over rows (block -> row); threads stride over
// the row's elements. Because all threads of a block share the same emb_id,
// `d_found[emb_id]` is block-uniform and the barriers below are reached by
// either every thread of the block or none — no divergent-barrier hazard.
template <typename KeyType, typename ValueType>
__global__ void select_no_found_kernel(const KeyType* keys, const ValueType* values, bool* d_found,
                                       uint64_t num_keys, uint64_t dim, KeyType* ret_keys,
                                       ValueType* ret_values, uint64_t* num_no_found) {
  __shared__ uint64_t smem[1];  // output slot index for the current row
  const uint64_t emb_vec_id = threadIdx.x;
  for (uint64_t emb_id = blockIdx.x; emb_id < num_keys; emb_id += gridDim.x) {
    if (!d_found[emb_id]) {
      if (emb_vec_id == 0) {
        uint64_t index = atomicAdd(num_no_found, 1);
        smem[0] = index;
        ret_keys[index] = keys[emb_id];
      }
      __syncthreads();  // publish smem[0] to the whole block
      uint64_t output_index = smem[0];

      const ValueType* src_row = values + emb_id * dim;
      ValueType* dst_row = ret_values + output_index * dim;
      for (uint64_t i = emb_vec_id; i < dim; i += blockDim.x) {
        dst_row[i] = src_row[i];
      }
      // BUGFIX: barrier before the next grid-stride iteration overwrites
      // smem[0]. Without it, thread 0 could finish its copy, advance to the
      // next row, and clobber the shared slot index while slower threads are
      // still reading it for the current row.
      __syncthreads();
    }
  }
}

template <class K, class S>
struct ExportIfPredFunctor {
__forceinline__ __device__ bool operator()(const K& key, S& score, const K& pattern,
Expand Down Expand Up @@ -286,24 +382,12 @@ template <typename KeyType, typename ValueType>
void HKVVariable<KeyType, ValueType>::assign(const KeyType* keys, const ValueType* values,
size_t num_keys, cudaStream_t stream) {
int64_t dim = cols();
// `keys` and `values` are pointers of host memory
// KeyType* d_keys;
// CUDACHECK(cudaMalloc(&d_keys, sizeof(KeyType) * num_keys));
// ValueType* d_values;
// CUDACHECK(cudaMalloc(&d_values, sizeof(ValueType) * num_keys * dim));

KeyType* d_keys;
CUDACHECK(cudaMallocManaged(&d_keys, sizeof(KeyType) * num_keys));
ValueType* d_values;
CUDACHECK(cudaMallocManaged(&d_values, sizeof(ValueType) * num_keys * dim));
// clang-format off
//CUDACHECK(cudaMemcpyAsync(d_keys, keys, sizeof(KeyType) * num_keys,
// cudaMemcpyHostToDevice, stream));

//CUDACHECK(cudaMemcpyAsync(d_values, values, sizeof(ValueType) * num_keys * dim,
// cudaMemcpyHostToDevice, stream));

//CUDACHECK(cudaStreamSynchronize(stream));
std::memcpy(d_keys, keys, sizeof(KeyType) * num_keys);
std::memcpy(d_values, values, sizeof(ValueType) * num_keys * dim);
hkv_table_->insert_or_assign(num_keys, d_keys, d_values, nullptr, stream);
Expand Down Expand Up @@ -344,6 +428,77 @@ void HKVVariable<KeyType, ValueType>::lookup(const KeyType* keys, ValueType* val
CUDACHECK(cudaFree(d_found));
}

// Look up `num_keys` keys, writing their embedding rows into `values`. Rows
// missing from the table are filled with the configured initializer, then
// inserted into the table in one batch; any rows that insertion evicts are
// returned through (tmp_keys, tmp_values) with their count in *evict_num_keys.
//
// `evict_num_keys` is a HOST pointer; the function synchronizes `stream`
// before returning, so the caller may read it (and the evicted rows)
// immediately.
template <typename KeyType, typename ValueType>
void HKVVariable<KeyType, ValueType>::lookup_with_evict(const KeyType* keys, KeyType* tmp_keys,
                                                        ValueType* tmp_values, ValueType* values,
                                                        uint64_t* evict_num_keys, uint64_t num_keys,
                                                        cudaStream_t stream) {
  int64_t dim = cols();

  // Robustness: a zero-key batch would otherwise produce a zero-sized grid
  // launch below (an invalid configuration).
  if (num_keys == 0) {
    *evict_num_keys = 0;
    return;
  }

  // One stream-ordered allocation carved into 16B-aligned sub-buffers:
  //   [d_found flags][compacted keys][compacted values][miss counter][evict count]
  uint64_t tmp_buffer_size = 0;
  tmp_buffer_size += align_length(num_keys * sizeof(bool));
  tmp_buffer_size += align_length(num_keys * sizeof(KeyType));
  tmp_buffer_size += align_length(num_keys * dim * sizeof(ValueType));
  tmp_buffer_size += align_length(sizeof(uint64_t));  // not-found counter
  tmp_buffer_size += align_length(sizeof(uint64_t));  // evicted-key count

  bool* d_found;
  CUDACHECK(cudaMallocAsync(&d_found, tmp_buffer_size, stream));
  CUDACHECK(cudaMemsetAsync(d_found, 0, tmp_buffer_size, stream));
  // Sub-buffer pointers are plain host-side arithmetic; no stream sync is
  // needed here (the original synchronized unnecessarily, serializing the
  // pipeline). Types are kept consistently uint64_t* (was a size_t* cast
  // assigned to a uint64_t* variable).
  char* cursor = reinterpret_cast<char*>(d_found) + align_length(num_keys * sizeof(bool));
  KeyType* tmp_key_buffer = reinterpret_cast<KeyType*>(cursor);
  cursor += align_length(num_keys * sizeof(KeyType));
  ValueType* tmp_value_buffer = reinterpret_cast<ValueType*>(cursor);
  cursor += align_length(num_keys * dim * sizeof(ValueType));
  uint64_t* tmp_counters = reinterpret_cast<uint64_t*>(cursor);
  cursor += align_length(sizeof(uint64_t));
  uint64_t* d_evict_num_keys = reinterpret_cast<uint64_t*>(cursor);

  // 1) Find existing rows; d_found marks hits.
  hkv_table_->find(num_keys, keys, values, d_found, nullptr, stream);

  // 2) Initialize the missing rows in-place in `values`.
  // NOTE(review): block_dim == dim when dim > 32; dims above the 1024-thread
  // limit would make these launches fail — confirm callers guarantee
  // dim <= 1024.
  uint32_t block_dim = max(dim, static_cast<int64_t>(32));
  uint32_t grid_dim = SM_NUM * (NTHREAD_PER_SM / block_dim);
  if (num_keys < grid_dim) grid_dim = num_keys;
  if (initializer_ == "normal" || initializer_ == "random") {
    generate_normal_kernel<<<grid_dim, block_dim, 0, stream>>>(curand_states_, values, d_found,
                                                               dim, num_keys);
  } else if (initializer_ == "uniform") {
    generate_uniform_kernel<<<grid_dim, block_dim, 0, stream>>>(curand_states_, values, d_found,
                                                                dim, num_keys);
  } else {
    try {
      float val = std::stof(initializer_);
      // const_initializer_kernel has no grid-stride loop: one block per key.
      const_initializer_kernel<<<num_keys, block_dim, 0, stream>>>(val, values, d_found, dim);
    } catch (std::invalid_argument& err) {
      throw std::runtime_error("Unrecognized initializer {" + initializer_ + "}");
    }
  }

  // 3) Compact the missing rows so they can be inserted as one batch; the
  //    miss count must be on the host before calling insert_and_evict.
  select_no_found_kernel<<<grid_dim, block_dim, 0, stream>>>(
      keys, values, d_found, num_keys, dim, tmp_key_buffer, tmp_value_buffer, tmp_counters);
  uint64_t h_tmp_counters[1];
  CUDACHECK(cudaMemcpyAsync(h_tmp_counters, tmp_counters, sizeof(uint64_t),
                            cudaMemcpyDeviceToHost, stream));
  CUDACHECK(cudaStreamSynchronize(stream));

  // 4) Insert the new rows; evicted rows come back in (tmp_keys, tmp_values).
  hkv_table_->insert_and_evict(h_tmp_counters[0], tmp_key_buffer, tmp_value_buffer, nullptr,
                               tmp_keys, tmp_values, nullptr, d_evict_num_keys, stream);

  CUDACHECK(cudaMemcpyAsync(evict_num_keys, d_evict_num_keys, sizeof(uint64_t),
                            cudaMemcpyDeviceToHost, stream));
  CUDACHECK(cudaFreeAsync(d_found, stream));
  CUDACHECK(cudaStreamSynchronize(stream));
}

// Copy `num_keys` evicted keys and their `dim`-wide embedding rows into the
// caller-provided device buffers (device-to-device), then block until both
// copies have completed so the outputs are safe to consume on return.
template <typename KeyType, typename ValueType>
void HKVVariable<KeyType, ValueType>::copy_evict_keys(const KeyType* keys, const ValueType* values,
                                                      size_t num_keys, size_t dim,
                                                      KeyType* ret_keys, ValueType* ret_values,
                                                      cudaStream_t stream) {
  const size_t key_bytes = num_keys * sizeof(KeyType);
  const size_t value_bytes = num_keys * dim * sizeof(ValueType);
  CUDACHECK(cudaMemcpyAsync(ret_keys, keys, key_bytes, cudaMemcpyDeviceToDevice, stream));
  CUDACHECK(cudaMemcpyAsync(ret_values, values, value_bytes, cudaMemcpyDeviceToDevice, stream));
  CUDACHECK(cudaStreamSynchronize(stream));
}

template <typename KeyType, typename ValueType>
void HKVVariable<KeyType, ValueType>::lookup(const KeyType* keys, ValueType** values,
size_t num_keys, cudaStream_t stream) {
Expand Down
8 changes: 8 additions & 0 deletions sparse_operation_kit/kit_src/variable/impl/hkv_variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ class HKVVariable : public VariableBase<KeyType, ValueType> {

void lookup(const KeyType *keys, ValueType *values, size_t num_keys,
cudaStream_t stream = 0) override;

void lookup_with_evict(const KeyType *keys, KeyType *tmp_keys, ValueType *tmp_values,
ValueType *values, uint64_t *evict_num_keys, uint64_t num_keys,
cudaStream_t stream = 0) override;

void copy_evict_keys(const KeyType *keys, const ValueType *values, size_t num_keys, size_t dim,
KeyType *ret_keys, ValueType *ret_values, cudaStream_t stream = 0) override;

void lookup(const KeyType *keys, ValueType **values, size_t num_keys,
cudaStream_t stream = 0) override;
void scatter_add(const KeyType *keys, const ValueType *values, size_t num_keys,
Expand Down
10 changes: 10 additions & 0 deletions sparse_operation_kit/kit_src/variable/impl/variable_base.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@

namespace sok {

// Round `num` up to the next multiple of 16 (identity when already aligned).
// Used to pad sub-buffer sizes so carved-out device pointers stay 16B-aligned.
uint64_t align_length(uint64_t num) {
  const uint64_t remainder = num % 16;
  return remainder == 0 ? num : num + (16 - remainder);
}

template <typename KeyType, typename ValueType>
std::shared_ptr<VariableBase<KeyType, ValueType>> VariableFactory::create(
int64_t rows, int64_t cols, const std::string &type, const std::string &initializer,
Expand Down
10 changes: 10 additions & 0 deletions sparse_operation_kit/kit_src/variable/impl/variable_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

namespace sok {

uint64_t align_length(uint64_t num);

template <typename KeyType, typename ValueType>
class VariableBase {
public:
Expand All @@ -43,6 +45,14 @@ class VariableBase {
cudaStream_t stream = 0) = 0;
virtual void lookup(const KeyType *keys, ValueType **values, size_t num_keys,
cudaStream_t stream = 0) = 0;

virtual void lookup_with_evict(const KeyType *keys, KeyType *tmp_keys, ValueType *tmp_values,
ValueType *values, uint64_t *evict_num_keys, uint64_t num_keys,
cudaStream_t stream = 0) = 0;

virtual void copy_evict_keys(const KeyType *keys, const ValueType *values, size_t num_keys,
size_t dim, KeyType *ret_keys, ValueType *ret_values,
cudaStream_t stream = 0) = 0;
virtual void scatter_add(const KeyType *keys, const ValueType *values, size_t num_keys,
cudaStream_t stream = 0) = 0;
virtual void scatter_update(const KeyType *keys, const ValueType *values, size_t num_keys,
Expand Down
14 changes: 14 additions & 0 deletions sparse_operation_kit/kit_src/variable/kernels/dummy_var.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,20 @@ void DummyVar<KeyType, ValueType>::SparseRead(const void* keys, void* values, si
stream);
}

// Type-erased entry point: restore the concrete Key/Value pointer types and
// forward to the backing variable's lookup_with_evict.
template <typename KeyType, typename ValueType>
void DummyVar<KeyType, ValueType>::SparseReadEvict(const void* keys, void* tmp_keys,
                                                   void* tmp_values, void* values,
                                                   uint64_t* evict_num_keys, size_t num_keys,
                                                   cudaStream_t stream) {
  check_var();
  const KeyType* typed_keys = static_cast<const KeyType*>(keys);
  KeyType* typed_tmp_keys = static_cast<KeyType*>(tmp_keys);
  ValueType* typed_tmp_values = static_cast<ValueType*>(tmp_values);
  ValueType* typed_values = static_cast<ValueType*>(values);
  var_->lookup_with_evict(typed_keys, typed_tmp_keys, typed_tmp_values, typed_values,
                          evict_num_keys, num_keys, stream);
}

// Type-erased entry point: restore the concrete Key/Value pointer types and
// forward to the backing variable's copy_evict_keys.
template <typename KeyType, typename ValueType>
void DummyVar<KeyType, ValueType>::CopyEvictKeys(const void* keys, const void* values,
                                                 size_t num_keys, size_t dim, void* ret_keys,
                                                 void* ret_values, cudaStream_t stream) {
  check_var();
  const KeyType* typed_keys = static_cast<const KeyType*>(keys);
  const ValueType* typed_values = static_cast<const ValueType*>(values);
  var_->copy_evict_keys(typed_keys, typed_values, num_keys, dim,
                        static_cast<KeyType*>(ret_keys), static_cast<ValueType*>(ret_values),
                        stream);
}

template <typename KeyType, typename ValueType>
void DummyVar<KeyType, ValueType>::ScatterAdd(const void* keys, const void* values, size_t num_keys,
cudaStream_t stream) {
Expand Down
4 changes: 4 additions & 0 deletions sparse_operation_kit/kit_src/variable/kernels/dummy_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class DummyVar : public ResourceBase {
void Assign(const void *keys, const void *values, size_t num_keys, cudaStream_t stream);

void SparseRead(const void *keys, void *values, size_t num_keys, cudaStream_t stream);
void SparseReadEvict(const void *keys, void *tmp_keys, void *tmp_values, void *values,
uint64_t *evict_num_keys, size_t num_keys, cudaStream_t stream);
void CopyEvictKeys(const void *keys, const void *values, size_t num_keys, size_t dim,
void *ret_keys, void *ret_values, cudaStream_t stream);
void ScatterAdd(const void *keys, const void *values, size_t num_keys, cudaStream_t stream);
void ScatterUpdate(const void *keys, const void *values, size_t num_keys, cudaStream_t stream);

Expand Down
Loading

0 comments on commit 5a8ee1b

Please sign in to comment.