diff --git a/sparse_operation_kit/kit_src/variable/impl/det_variable.cu b/sparse_operation_kit/kit_src/variable/impl/det_variable.cu
index 152969f8b0..a9c1a84def 100644
--- a/sparse_operation_kit/kit_src/variable/impl/det_variable.cu
+++ b/sparse_operation_kit/kit_src/variable/impl/det_variable.cu
@@ -218,6 +218,24 @@ void DETVariable<KeyType, ValueType>::lookup(const KeyType* keys, ValueType** va
   map_->lookup(keys, values, num_keys, stream);
 }
 
+template <typename KeyType, typename ValueType>
+void DETVariable<KeyType, ValueType>::lookup_with_evict(const KeyType* keys, KeyType* tmp_keys,
+                                                        ValueType* tmp_values, ValueType* values,
+                                                        uint64_t* evict_num_keys, uint64_t num_keys,
+                                                        cudaStream_t stream) {
+  throw std::runtime_error(
+      "SOK dynamic variable with DET backend does not support lookup_with_evict!");
+}
+
+template <typename KeyType, typename ValueType>
+void DETVariable<KeyType, ValueType>::copy_evict_keys(const KeyType* keys, const ValueType* values,
+                                                      size_t num_keys, size_t dim,
+                                                      KeyType* ret_keys, ValueType* ret_values,
+                                                      cudaStream_t stream) {
+  throw std::runtime_error(
+      "SOK dynamic variable with DET backend does not support copy_evict_keys!");
+}
+
 template <typename KeyType, typename ValueType>
 void DETVariable<KeyType, ValueType>::scatter_add(const KeyType* keys, const ValueType* values,
                                                   size_t num_keys, cudaStream_t stream) {
diff --git a/sparse_operation_kit/kit_src/variable/impl/det_variable.h b/sparse_operation_kit/kit_src/variable/impl/det_variable.h
index c628b1e271..243308bd7b 100644
--- a/sparse_operation_kit/kit_src/variable/impl/det_variable.h
+++ b/sparse_operation_kit/kit_src/variable/impl/det_variable.h
@@ -45,6 +45,13 @@ class DETVariable : public VariableBase<KeyType, ValueType> {
               cudaStream_t stream = 0) override;
   void lookup(const KeyType *keys, ValueType **values, size_t num_keys,
               cudaStream_t stream = 0) override;
+
+  void lookup_with_evict(const KeyType *keys, KeyType *tmp_keys, ValueType *tmp_values,
+                         ValueType *values, uint64_t *evict_num_keys, uint64_t num_keys,
+                         cudaStream_t stream = 0) override;
+
+  void copy_evict_keys(const KeyType *keys, const ValueType *values, size_t num_keys, size_t dim,
+                       KeyType *ret_keys, ValueType *ret_values, cudaStream_t stream = 0) override;
   void scatter_add(const KeyType *keys, const ValueType *values, size_t num_keys,
                    cudaStream_t stream = 0) override;
   void scatter_update(const KeyType *keys, const ValueType *values, size_t num_keys,
diff --git a/sparse_operation_kit/kit_src/variable/impl/hkv_variable.cu b/sparse_operation_kit/kit_src/variable/impl/hkv_variable.cu
index 38c519b872..9a9fa2ab22 100644
--- a/sparse_operation_kit/kit_src/variable/impl/hkv_variable.cu
+++ b/sparse_operation_kit/kit_src/variable/impl/hkv_variable.cu
@@ -67,7 +67,34 @@ __global__ void generate_uniform_kernel(curandState* state, T** result, bool* d_
         load_state = true;
       }
       for (size_t i = emb_vec_id; i < dim; i += blockDim.x) {
-        result[emb_id][i] = curand_normal_double(&localState);
+        result[emb_id][i] = curand_uniform_double(&localState);
+      }
+    }
+  }
+  /* Copy state back to global memory */
+  if (load_state) {
+    state[GlobalThreadId()] = localState;
+  }
+}
+
+template <typename T>
+__global__ void generate_uniform_kernel(curandState* state, T* result, bool* d_found, size_t dim,
+                                        size_t num_embedding) {
+  auto id = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
+  size_t block_id = blockIdx.x;
+  size_t emb_vec_id = threadIdx.x;
+  /* Copy state to local memory for efficiency */
+  curandState localState;
+  bool load_state = false;
+  for (size_t emb_id = block_id; emb_id < num_embedding; emb_id += gridDim.x) {
+    if (!d_found[emb_id]) {
+      if (!load_state) {
+        localState = state[GlobalThreadId()];
+        load_state = true;
+      }
+      for (size_t i = emb_vec_id; i < dim; i += blockDim.x) {
+        T* tmp_result = result + emb_id * dim;
+        tmp_result[i] = curand_uniform_double(&localState);
       }
     }
   }
@@ -97,6 +124,19 @@ __global__ void const_initializer_kernel(float val, T** result, bool* d_found, s
   }
 }
 
+template <typename T>
+__global__ void const_initializer_kernel(float val, T* result, bool* d_found, size_t dim) {
+  size_t id = threadIdx.x + blockIdx.x * blockDim.x;
+  size_t emb_id = blockIdx.x;
+  size_t emb_vec_id = threadIdx.x;
+  if (!d_found[emb_id]) {
+    for (size_t i = emb_vec_id; i < dim; i += blockDim.x) {
+      T* tmp_result = result + emb_id * dim;
+      tmp_result[i] = static_cast<T>(val);
+    }
+  }
+}
+
 template <typename T>
 __global__ void generate_normal_kernel(curandState* state, T* result, size_t n) {
   auto id = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
@@ -137,6 +177,62 @@ __global__ void generate_normal_kernel(curandState* state, T** result, bool* d_f
   }
 }
 
+template <typename T>
+__global__ void generate_normal_kernel(curandState* state, T* result, bool* d_found, size_t dim,
+                                       size_t num_embedding) {
+  auto id = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
+  size_t block_id = blockIdx.x;
+  size_t emb_vec_id = threadIdx.x;
+  /* Copy state to local memory for efficiency */
+  curandState localState;
+  bool load_state = false;
+  for (size_t emb_id = block_id; emb_id < num_embedding; emb_id += gridDim.x) {
+    if (!d_found[emb_id]) {
+      if (!load_state) {
+        localState = state[GlobalThreadId()];
+        load_state = true;
+      }
+      for (size_t i = emb_vec_id; i < dim; i += blockDim.x) {
+        T* tmp_result = result + emb_id * dim;
+        tmp_result[i] = curand_normal_double(&localState);
+      }
+    }
+  }
+
+  /* Copy state back to global memory */
+  if (load_state) {
+    state[GlobalThreadId()] = localState;
+  }
+}
+
+template <typename KeyType, typename ValueType>
+__global__ void select_no_found_kernel(const KeyType* keys, const ValueType* values, bool* d_found,
+                                       uint64_t num_keys, uint64_t dim, KeyType* ret_keys,
+                                       ValueType* ret_values, uint64_t* num_no_found) {
+  auto id = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
+  __shared__ uint64_t smem[1];
+  uint64_t block_id = blockIdx.x;
+  uint64_t emb_vec_id = threadIdx.x;
+  /* One block handles one key at a time; thread 0 reserves the output slot. */
+  for (uint64_t emb_id = block_id; emb_id < num_keys; emb_id += gridDim.x) {
+    if (!d_found[emb_id]) {
+      if (emb_vec_id == 0) {
+        uint64_t index = atomicAdd(num_no_found, 1);
+        smem[0] = index;
+        ret_keys[index] = keys[emb_id];
+      }
+      __syncthreads();
+      uint64_t output_index = smem[0];
+
+      for (uint64_t i = emb_vec_id; i < dim; i += blockDim.x) {
+        const ValueType* tmp_values = values + emb_id * dim;
+        ValueType* tmp_ret_values = ret_values + output_index * dim;
+        tmp_ret_values[i] = tmp_values[i];
+      }
+      // Make sure every thread has read smem[0] before the next iteration overwrites it.
+      __syncthreads();
+    }
+  }
+}
+
 template <class K, class S>
 struct ExportIfPredFunctor {
   __forceinline__ __device__ bool operator()(const K& key, S& score, const K& pattern,
@@ -286,24 +382,12 @@ template <typename KeyType, typename ValueType>
 void HKVVariable<KeyType, ValueType>::assign(const KeyType* keys, const ValueType* values,
                                              size_t num_keys, cudaStream_t stream) {
   int64_t dim = cols();
-  // `keys` and `values` are pointers of host memory
-  // KeyType* d_keys;
-  // CUDACHECK(cudaMalloc(&d_keys, sizeof(KeyType) * num_keys));
-  // ValueType* d_values;
-  // CUDACHECK(cudaMalloc(&d_values, sizeof(ValueType) * num_keys * dim));
   KeyType* d_keys;
   CUDACHECK(cudaMallocManaged(&d_keys, sizeof(KeyType) * num_keys));
   ValueType* d_values;
   CUDACHECK(cudaMallocManaged(&d_values, sizeof(ValueType) * num_keys * dim));
 
   // clang-format off
-  //CUDACHECK(cudaMemcpyAsync(d_keys, keys, sizeof(KeyType) * num_keys,
-  //                          cudaMemcpyHostToDevice, stream));
-
-  //CUDACHECK(cudaMemcpyAsync(d_values, values, sizeof(ValueType) * num_keys * dim,
-  //                          cudaMemcpyHostToDevice, stream));
-
-  //CUDACHECK(cudaStreamSynchronize(stream));
   std::memcpy(d_keys, keys, sizeof(KeyType) * num_keys);
   std::memcpy(d_values, values, sizeof(ValueType) * num_keys * dim);
   hkv_table_->insert_or_assign(num_keys, d_keys, d_values, nullptr, stream);
@@ -344,6 +428,77 @@ void HKVVariable<KeyType, ValueType>::lookup(const KeyType* keys, ValueType* val
   CUDACHECK(cudaFree(d_found));
 }
 
+template <typename KeyType, typename ValueType>
+void HKVVariable<KeyType, ValueType>::lookup_with_evict(const KeyType* keys, KeyType* tmp_keys,
+                                                        ValueType* tmp_values, ValueType* values,
+                                                        uint64_t* evict_num_keys, uint64_t num_keys,
+                                                        cudaStream_t stream) {
+  int64_t dim = cols();
+
+  bool* d_found;
+  KeyType* tmp_key_buffer;
+  ValueType* tmp_value_buffer;
+  uint64_t* tmp_counters;
+  uint64_t* d_evict_num_keys;
+  uint64_t h_tmp_counters[1];
+
+  // Single workspace allocation; each sub-buffer is 16-byte aligned via align_length().
+  uint64_t tmp_buffer_size = 0;
+  tmp_buffer_size += align_length(num_keys * sizeof(bool));
+  tmp_buffer_size += align_length(num_keys * sizeof(KeyType));
+  tmp_buffer_size += align_length(num_keys * dim * sizeof(ValueType));
+  tmp_buffer_size += align_length(sizeof(uint64_t));
+  tmp_buffer_size += align_length(sizeof(uint64_t));
+
+  CUDACHECK(cudaMallocAsync(&d_found, tmp_buffer_size, stream));
+  CUDACHECK(cudaMemsetAsync(d_found, 0, tmp_buffer_size, stream));
+  CUDACHECK(cudaStreamSynchronize(stream));
+
+  tmp_key_buffer = (KeyType*)(((char*)d_found) + align_length(num_keys * sizeof(bool)));
+  tmp_value_buffer = (ValueType*)(((char*)tmp_key_buffer) + align_length(num_keys * sizeof(KeyType)));
+  tmp_counters = (uint64_t*)(((char*)tmp_value_buffer) + align_length(num_keys * dim * sizeof(ValueType)));
+  d_evict_num_keys = (uint64_t*)(((char*)tmp_counters) + align_length(sizeof(uint64_t)));
+
+  // Look up existing keys first; d_found marks which keys are already in the table.
+  hkv_table_->find(num_keys, keys, values, d_found, nullptr, stream);
+
+  // Initialize the embedding vectors of the keys that were not found.
+  uint32_t block_dim = max(dim, static_cast<int64_t>(32));
+  uint32_t grid_dim = SM_NUM * (NTHREAD_PER_SM / block_dim);
+  if (num_keys < grid_dim) {
+    grid_dim = num_keys;
+  }
+  if (initializer_ == "normal" || initializer_ == "random") {
+    generate_normal_kernel<<<grid_dim, block_dim, 0, stream>>>(
+        curand_states_, values, d_found, dim, num_keys);
+  } else if (initializer_ == "uniform") {
+    generate_uniform_kernel<<<grid_dim, block_dim, 0, stream>>>(
+        curand_states_, values, d_found, dim, num_keys);
+  } else {
+    try {
+      float val = std::stof(initializer_);
+      const_initializer_kernel<<<grid_dim, block_dim, 0, stream>>>(
+          val, values, d_found, dim);
+    } catch (std::invalid_argument& err) {
+      throw std::runtime_error("Unrecognized initializer {" + initializer_ + "}");
+    }
+  }
+
+  // Gather the missing keys and their freshly initialized values into the temporary buffers.
+  select_no_found_kernel<<<grid_dim, block_dim, 0, stream>>>(
+      keys, values, d_found, num_keys, dim, tmp_key_buffer, tmp_value_buffer, tmp_counters);
+  CUDACHECK(cudaMemcpyAsync(h_tmp_counters, tmp_counters, sizeof(uint64_t),
+                            cudaMemcpyDeviceToHost, stream));
+  CUDACHECK(cudaStreamSynchronize(stream));
+
+  // Insert the missing keys; keys evicted by HKV are returned through tmp_keys/tmp_values.
+  hkv_table_->insert_and_evict(h_tmp_counters[0], tmp_key_buffer, tmp_value_buffer, nullptr,
+                               tmp_keys, tmp_values, nullptr, d_evict_num_keys, stream);
+
+  CUDACHECK(cudaMemcpyAsync(evict_num_keys, d_evict_num_keys, sizeof(uint64_t),
+                            cudaMemcpyDeviceToHost, stream));
+  CUDACHECK(cudaFreeAsync(d_found, stream));
+  CUDACHECK(cudaStreamSynchronize(stream));
+}
+
+template <typename KeyType, typename ValueType>
+void HKVVariable<KeyType, ValueType>::copy_evict_keys(const KeyType* keys, const ValueType* values,
+                                                      size_t num_keys, size_t dim, KeyType* ret_keys,
+                                                      ValueType* ret_values, cudaStream_t stream) {
+  CUDACHECK(cudaMemcpyAsync(ret_keys, keys, sizeof(KeyType) * num_keys,
+                            cudaMemcpyDeviceToDevice, stream));
+  CUDACHECK(cudaMemcpyAsync(ret_values, values, sizeof(ValueType) * num_keys * dim,
+                            cudaMemcpyDeviceToDevice, stream));
+  CUDACHECK(cudaStreamSynchronize(stream));
+}
+
 template <typename KeyType, typename ValueType>
 void HKVVariable<KeyType, ValueType>::lookup(const KeyType* keys, ValueType** values,
                                              size_t num_keys, cudaStream_t stream) {
diff --git a/sparse_operation_kit/kit_src/variable/impl/hkv_variable.h b/sparse_operation_kit/kit_src/variable/impl/hkv_variable.h
index c87013a728..631d57b3be 100644
--- a/sparse_operation_kit/kit_src/variable/impl/hkv_variable.h
+++ b/sparse_operation_kit/kit_src/variable/impl/hkv_variable.h
@@ -47,6 +47,14 @@ class HKVVariable : public VariableBase<KeyType, ValueType> {
               cudaStream_t stream = 0) override;
   void lookup(const KeyType *keys, ValueType *values, size_t num_keys,
               cudaStream_t stream = 0) override;
+
+  void lookup_with_evict(const KeyType *keys, KeyType *tmp_keys, ValueType *tmp_values,
+                         ValueType *values, uint64_t *evict_num_keys, uint64_t num_keys,
+                         cudaStream_t stream = 0) override;
+
+  void copy_evict_keys(const KeyType *keys, const ValueType *values, size_t num_keys, size_t dim,
+                       KeyType *ret_keys, ValueType *ret_values, cudaStream_t stream = 0) override;
+
   void lookup(const KeyType *keys, ValueType **values, size_t num_keys,
               cudaStream_t stream = 0) override;
   void scatter_add(const KeyType *keys, const ValueType *values, size_t num_keys,
diff --git a/sparse_operation_kit/kit_src/variable/impl/variable_base.cu b/sparse_operation_kit/kit_src/variable/impl/variable_base.cu
index 286d511c04..7c89ea13e3 100644
--- a/sparse_operation_kit/kit_src/variable/impl/variable_base.cu
+++ b/sparse_operation_kit/kit_src/variable/impl/variable_base.cu
@@ -22,6 +22,16 @@
 
 namespace sok {
 
+uint64_t align_length(uint64_t num) {
+  // Check if num is already a multiple of 16
+  if (num % 16 == 0) {
+    return num;
+  }
+  // Find the next multiple of 16
+  uint64_t alignedNum = ((num / 16) + 1) * 16;
+  return alignedNum;
+}
+
 template <typename KeyType, typename ValueType>
 std::shared_ptr<VariableBase<KeyType, ValueType>> VariableFactory::create(
     int64_t rows, int64_t cols, const std::string &type, const std::string &initializer,
diff --git a/sparse_operation_kit/kit_src/variable/impl/variable_base.h b/sparse_operation_kit/kit_src/variable/impl/variable_base.h
index 4bc7195614..544ecd0571 100644
--- a/sparse_operation_kit/kit_src/variable/impl/variable_base.h
+++ b/sparse_operation_kit/kit_src/variable/impl/variable_base.h
@@ -24,6 +24,8 @@
 
 namespace sok {
 
+uint64_t align_length(uint64_t num);
+
 template <typename KeyType, typename ValueType>
 class VariableBase {
  public:
@@ -43,6 +45,14 @@ class VariableBase {
                       cudaStream_t stream = 0) = 0;
   virtual void lookup(const KeyType *keys, ValueType **values, size_t num_keys,
                       cudaStream_t stream = 0) = 0;
+
+  virtual void lookup_with_evict(const KeyType *keys, KeyType *tmp_keys, ValueType *tmp_values,
+                                 ValueType *values, uint64_t *evict_num_keys, uint64_t num_keys,
+                                 cudaStream_t stream = 0) = 0;
+
+  virtual void copy_evict_keys(const KeyType *keys, const ValueType *values, size_t num_keys,
+                               size_t dim, KeyType *ret_keys, ValueType *ret_values,
+                               cudaStream_t stream = 0) = 0;
   virtual void scatter_add(const KeyType *keys, const ValueType *values, size_t num_keys,
                            cudaStream_t stream = 0) = 0;
   virtual void scatter_update(const KeyType *keys, const ValueType *values, size_t num_keys,
diff --git a/sparse_operation_kit/kit_src/variable/kernels/dummy_var.cc b/sparse_operation_kit/kit_src/variable/kernels/dummy_var.cc
index 7e282fb32f..d922fe4616 100644
--- a/sparse_operation_kit/kit_src/variable/kernels/dummy_var.cc
+++ b/sparse_operation_kit/kit_src/variable/kernels/dummy_var.cc
@@ -85,6 +85,20 @@ void DummyVar<KeyType, ValueType>::SparseRead(const void* keys, void* values, si
               stream);
 }
 
+template <typename KeyType, typename ValueType>
+void DummyVar<KeyType, ValueType>::SparseReadEvict(const void* keys, void* tmp_keys,
+                                                   void* tmp_values, void* values,
+                                                   uint64_t* evict_num_keys, size_t num_keys,
+                                                   cudaStream_t stream) {
+  check_var();
+  var_->lookup_with_evict(static_cast<const KeyType*>(keys), static_cast<KeyType*>(tmp_keys),
+                          static_cast<ValueType*>(tmp_values), static_cast<ValueType*>(values),
+                          evict_num_keys, num_keys, stream);
+}
+
+template <typename KeyType, typename ValueType>
+void DummyVar<KeyType, ValueType>::CopyEvictKeys(const void* keys, const void* values,
+                                                 size_t num_keys, size_t dim, void* ret_keys,
+                                                 void* ret_values, cudaStream_t stream) {
+  check_var();
+  var_->copy_evict_keys(static_cast<const KeyType*>(keys), static_cast<const ValueType*>(values),
+                        num_keys, dim, static_cast<KeyType*>(ret_keys),
+                        static_cast<ValueType*>(ret_values), stream);
+}
+
 template <typename KeyType, typename ValueType>
 void DummyVar<KeyType, ValueType>::ScatterAdd(const void* keys, const void* values, size_t num_keys,
                                               cudaStream_t stream) {
diff --git a/sparse_operation_kit/kit_src/variable/kernels/dummy_var.h b/sparse_operation_kit/kit_src/variable/kernels/dummy_var.h
index 49d005512f..0ba4a3841d 100644
--- a/sparse_operation_kit/kit_src/variable/kernels/dummy_var.h
+++ b/sparse_operation_kit/kit_src/variable/kernels/dummy_var.h
@@ -57,6 +57,10 @@ class DummyVar : public ResourceBase {
   void Assign(const void *keys, const void *values, size_t num_keys, cudaStream_t stream);
 
   void SparseRead(const void *keys, void *values, size_t num_keys, cudaStream_t stream);
+  void SparseReadEvict(const void *keys, void *tmp_keys, void *tmp_values, void *values,
+                       uint64_t *evict_num_keys, size_t num_keys, cudaStream_t stream);
+  void CopyEvictKeys(const void *keys, const void *values, size_t num_keys, size_t dim,
+                     void *ret_keys, void *ret_values, cudaStream_t stream);
   void ScatterAdd(const void *keys, const void *values, size_t num_keys, cudaStream_t stream);
   void ScatterUpdate(const void *keys, const void *values, size_t num_keys, cudaStream_t stream);
 
diff --git a/sparse_operation_kit/kit_src/variable/kernels/dummy_var_ops.cc b/sparse_operation_kit/kit_src/variable/kernels/dummy_var_ops.cc
index ec77e02613..3f6418b78f 100644
--- a/sparse_operation_kit/kit_src/variable/kernels/dummy_var_ops.cc
+++ b/sparse_operation_kit/kit_src/variable/kernels/dummy_var_ops.cc
@@ -252,6 +252,82 @@ REGISTER_GPU_KERNELS(int32_t, int32_t, float, float);
 #endif
 #undef REGISTER_GPU_KERNELS
 
+// -----------------------------------------------------------------------------------------------
+// DummyVarSparseReadEvict
+// -----------------------------------------------------------------------------------------------
+template <typename KeyType, typename ValueType>
+class DummyVarSparseReadEvictOp : public OpKernel {
+ public:
+  explicit DummyVarSparseReadEvictOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    core::RefCountPtr<DummyVar<KeyType, ValueType>> var;
+    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));
+
+    tf_shared_lock ml(*var->mu());
+
+    const Tensor* indices = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices));
+
+    int64_t cols = var->cols();
+    int64_t rows = indices->NumElements();
+
+    AllocatorAttributes alloc_attr;
+    alloc_attr.set_on_host(true);
+    // Temporary buffers that receive the evicted keys/values, plus a host-side counter
+    // with the number of evicted keys.
+    Tensor tmp_indices;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT64, {rows}, &tmp_indices));
+
+    Tensor tmp_values;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_temp(DataTypeToEnum<ValueType>::v(), {rows * cols}, &tmp_values));
+
+    Tensor tmp_evict_num;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_UINT64, {1}, &tmp_evict_num, alloc_attr));
+
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {rows, cols}, &output));
+
+    // Get cuda stream of tensorflow
+    auto device_ctx = ctx->op_device_context();
+    OP_REQUIRES(ctx, device_ctx != nullptr, errors::Aborted("No valid device context."));
+    cudaStream_t stream = stream_executor::gpu::AsGpuStreamValue(device_ctx->stream());
+
+    var->SparseReadEvict(indices->data(), tmp_indices.data(), tmp_values.data(), output->data(),
+                         (uint64_t*)tmp_evict_num.data(), rows, stream);
+    // SparseReadEvict synchronizes the stream, so the host-side counter is valid here.
+    int64_t evict_num = static_cast<int64_t>(((uint64_t*)tmp_evict_num.data())[0]);
+
+    Tensor* output_evict_keys = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {evict_num}, &output_evict_keys));
+
+    Tensor* output_evict_value = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(2, {evict_num, cols}, &output_evict_value));
+    if (evict_num > 0) {
+      var->CopyEvictKeys(tmp_indices.data(), tmp_values.data(), evict_num, cols,
+                         output_evict_keys->data(), output_evict_value->data(), stream);
+    }
+  }
+};
+
+#define REGISTER_GPU_KERNELS(key_type_tf, key_type, dtype_tf, dtype)      \
+  REGISTER_KERNEL_BUILDER(Name("DummyVarSparseReadEvict")                 \
+                              .Device(DEVICE_GPU)                         \
+                              .HostMemory("resource")                     \
+                              .TypeConstraint<key_type_tf>("key_type")    \
+                              .TypeConstraint<dtype_tf>("dtype"),         \
+                          DummyVarSparseReadEvictOp<key_type, dtype>)
+#if TF_VERSION_MAJOR == 1
+REGISTER_GPU_KERNELS(int64, int64_t, float, float);
+REGISTER_GPU_KERNELS(int32, int32_t, float, float);
+// REGISTER_GPU_KERNELS(int64, int64_t, Eigen::half, __half);
+// REGISTER_GPU_KERNELS(int32, int32_t, Eigen::half, __half);
+#else
+REGISTER_GPU_KERNELS(int64_t, int64_t, float, float);
+REGISTER_GPU_KERNELS(int32_t, int32_t, float, float);
+// REGISTER_GPU_KERNELS(int64_t, int64_t, Eigen::half, __half);
+// REGISTER_GPU_KERNELS(int32_t, int32_t, Eigen::half, __half);
+#endif
+#undef REGISTER_GPU_KERNELS
+
 // -----------------------------------------------------------------------------------------------
 // DummyVarScatterAdd
 // -----------------------------------------------------------------------------------------------
diff --git a/sparse_operation_kit/kit_src/variable/ops/dummy_var_ops.cc b/sparse_operation_kit/kit_src/variable/ops/dummy_var_ops.cc
index 899d11462a..70e8199596 100644
--- a/sparse_operation_kit/kit_src/variable/ops/dummy_var_ops.cc
+++ b/sparse_operation_kit/kit_src/variable/ops/dummy_var_ops.cc
@@ -50,7 +50,6 @@ REGISTER_OP("DummyVarExportIf")
     .Attr("dtype: {float32} = DT_FLOAT")
     .SetShapeFn([](InferenceContext* c) { return sok_tsl_status(); });
 
-
 REGISTER_OP("DummyVarSparseRead")
     .Input("resource: resource")
     .Input("indices: key_type")
@@ -79,6 +78,36 @@ REGISTER_OP("DummyVarSparseRead")
       return sok_tsl_status();
     });
 
+REGISTER_OP("DummyVarSparseReadEvict")
+    .Input("resource: resource")
+    .Input("indices: key_type")
+    .Output("output: dtype")
+    .Output("evict_keys: key_type")
+    .Output("evict_values: dtype")
+    .Attr("key_type: {int32, int64}")
+    .Attr("dtype: {float32, float16} = DT_FLOAT")
+    .SetShapeFn([](InferenceContext* c) {
+      //// Get handle.shape[1]
+      //auto handle_shapes_and_types = c->input_handle_shapes_and_types(0);
+      //if (handle_shapes_and_types == nullptr) {
+      //  return sok_tsl_status();
+      //}
+      //auto handle_shape = (*handle_shapes_and_types)[0].shape;
+      //ShapeHandle handle_shape_1;
+      //TF_RETURN_IF_ERROR(c->Subshape(handle_shape, 1, 2, &handle_shape_1));
+
+      //// rank(indices) should == 1
+      //ShapeHandle indices_shape;
+      //TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
+
+      //// Set output shape = [indices.shape[0], handle.shape[1]]
+      //ShapeHandle output_shape;
+      //TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), handle_shape_1, &output_shape));
+      //c->set_output(0, output_shape);
+
+      return sok_tsl_status();
+    });
+
 namespace {
 Status DummyVarScatterShapeFn(InferenceContext* c) {
   // Get handle.shape[1]
diff --git a/sparse_operation_kit/sparse_operation_kit/__init__.py b/sparse_operation_kit/sparse_operation_kit/__init__.py
index 6769b78bc0..4b27f6cd23 100644
--- a/sparse_operation_kit/sparse_operation_kit/__init__.py
+++ b/sparse_operation_kit/sparse_operation_kit/__init__.py
@@ -65,7 +65,7 @@ from sparse_operation_kit.optimizer import SGD
 
-from sparse_operation_kit.lookup import lookup_sparse
+from sparse_operation_kit.lookup import lookup_sparse, sparse_read_and_evict
 from sparse_operation_kit.lookup import all2all_dense_embedding
 
 from sparse_operation_kit.dump_load import dump, load, incremental_model_dump
diff --git a/sparse_operation_kit/sparse_operation_kit/dynamic_variable.py b/sparse_operation_kit/sparse_operation_kit/dynamic_variable.py
index f28b62e8fb..b2f26b4cef 100644
--- a/sparse_operation_kit/sparse_operation_kit/dynamic_variable.py
+++ b/sparse_operation_kit/sparse_operation_kit/dynamic_variable.py
@@ -17,6 +17,7 @@
 import tensorflow as tf
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.resource_variable_ops import variable_shape
 from tensorflow.python.ops.resource_variable_ops import ResourceVariable, variable_accessed
 from tensorflow.python.eager import context
 from tensorflow.python.ops import resource_variable_ops
diff --git a/sparse_operation_kit/sparse_operation_kit/lookup.py b/sparse_operation_kit/sparse_operation_kit/lookup.py
index 7026500d7a..6831fd919e 100644
--- a/sparse_operation_kit/sparse_operation_kit/lookup.py
+++ b/sparse_operation_kit/sparse_operation_kit/lookup.py
@@ -37,6 +37,7 @@
 from sparse_operation_kit.distributed_variable import LocalizedVariable
 from sparse_operation_kit.dynamic_variable import DynamicVariable
+from sparse_operation_kit.utils import SOK_IndexedSlices
 
 import importlib
 
 try:
@@ -50,6 +51,30 @@
     pass
 
 
+@tf.RegisterGradient("DummyVarSparseReadEvict")
+def _DummyVarSparseReadEvictGrad(op, *top_grads):
+    handle = op.inputs[0]
+    indices = op.inputs[1]
+    key_type = op.get_attr("key_type")
+    dtype = op.get_attr("dtype")
+    variable_shape = raw_ops.dummy_var_shape(handle, key_type=key_type, dtype=dtype)
+    size = array_ops.expand_dims(array_ops.size(indices), 0)
+    values_shape = array_ops.concat([size, variable_shape[1:]], 0)
+    grad = array_ops.reshape(top_grads[0], values_shape)
+    indices = array_ops.reshape(indices, size)
+
+    grads = [SOK_IndexedSlices()(grad, indices, values_shape)]
+    return grads + [None]
+
+
+def sparse_read_and_evict(var, indices, name=None):
+    # Only supported on the hybrid (HKV) backend of sok.DynamicVariable.
+    if var.backend_type != "hybrid":
+        raise TypeError("sparse_read_and_evict can only be used with the hybrid backend")
+    variable_accessed(var)
+    return raw_ops.dummy_var_sparse_read_evict(var._dummy_handle, indices, dtype=var.handle_dtype)
+
+
 def group_lookup(params, indices, dtype=None, name=None):
     # Fused-version of tf.nn.embedding_lookup on single GPU
     if not (isinstance(params, list) or isinstance(params, tuple)):
diff --git a/sparse_operation_kit/sparse_operation_kit/test/function_test/run_function_test_multi_process.sh b/sparse_operation_kit/sparse_operation_kit/test/function_test/run_function_test_multi_process.sh
index da9a9e40fd..72b7c64e53 100755
--- a/sparse_operation_kit/sparse_operation_kit/test/function_test/run_function_test_multi_process.sh
+++ b/sparse_operation_kit/sparse_operation_kit/test/function_test/run_function_test_multi_process.sh
@@ -19,6 +19,7 @@ horovodrun -np ${task_num} python lookup_sparse_distributed_test.py
 horovodrun -np ${task_num} python lookup_sparse_distributed_dynamic_test.py
 horovodrun -np ${task_num} python lookup_sparse_distributed_hkv_test.py
 horovodrun -np ${task_num} python lookup_sparse_hkv_incremental_dump_test.py
+horovodrun -np 1 python sparse_read_evict.py
 #horovodrun -np ${task_num} python lookup_sparse_localized_test.py
 #horovodrun -np ${task_num} python lookup_sparse_localized_dynamic_test.py
 #horovodrun -np ${task_num} python lookup_sparse_localized_hkv_test.py
diff --git a/sparse_operation_kit/sparse_operation_kit/test/function_test/tf1/lookup/sparse_read_evict.py b/sparse_operation_kit/sparse_operation_kit/test/function_test/tf1/lookup/sparse_read_evict.py
new file mode 100644
index 0000000000..6d2ba75be6
--- /dev/null
+++ b/sparse_operation_kit/sparse_operation_kit/test/function_test/tf1/lookup/sparse_read_evict.py
@@ -0,0 +1,86 @@
+"""
+ Copyright (c) 2022, NVIDIA CORPORATION.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import time
+import numpy as np
+import tensorflow as tf
+import horovod.tensorflow as hvd
+
+import sparse_operation_kit as sok
+
+
+def check_overlap(arr_list, arr2):
+    for i in range(len(arr_list)):
+        # Use np.in1d to check whether any element of arr2 already appears in arr_list[i]
+        overlap = np.in1d(arr2, arr_list[i])
+        tmp_overlap = np.any(overlap)
+        if tmp_overlap:
+            return tmp_overlap
+    return False
+
+
+if __name__ == "__main__":
+    hvd.init()
+    config = tf.compat.v1.ConfigProto()
+    config.gpu_options.visible_device_list = str(hvd.local_rank())
+    config.gpu_options.allow_growth = True
+    sess = tf.compat.v1.Session(config=config)
+    sok.init()
+    iter_num = 5
+    input_length = 8192
+    evict_keys_list = []
+    evict_values_list = []
+
+    # sok variables
+    sok_var = sok.DynamicVariable(
+        dimension=16,
+        var_type="hybrid",
+        initializer=str(11),
+        init_capacity=input_length,
+        max_capacity=input_length * 2,
+    )
+    first_values = np.arange(16, dtype=np.int64)
+    second_values = np.arange(16, 24, dtype=np.int64)
+
+    indices = tf.placeholder(shape=[None], dtype=tf.int64)
+
+    # initialize optimizer
+    optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
+    sok_optimizer = sok.OptimizerWrapper(optimizer)
+
+    embedding, evict_key, evict_value = sok.sparse_read_and_evict(sok_var, indices)
+    loss = tf.reduce_sum(embedding)
+    grads = tf.gradients(loss, [sok_var])
+    apply_gradients_op = sok_optimizer.apply_gradients(zip(grads, [sok_var]))
+
+    init_op = tf.compat.v1.global_variables_initializer()
+    sess.run(init_op)
+    for i in range(iter_num):
+        indices_values = np.arange(i * input_length, (i + 1) * input_length, dtype=np.int64)
+        embedding_np, evict_key_np, evict_value_np, _ = sess.run(
+            [embedding, evict_key, evict_value, apply_gradients_op],
+            feed_dict=dict(zip([indices], [indices_values])),
+        )
+
+        if i > 0:
+            assert not check_overlap(
+                evict_keys_list, evict_key_np
+            ), "Evicted keys overlap with keys evicted in earlier iterations."
+            assert np.all(evict_value_np == 10), "Not all evicted values were updated correctly."
+        evict_keys_list.append(evict_key_np)
+        evict_values_list.append(evict_value_np)
+
+    print("[SOK INFO] : sparse_read_evict run success!")
diff --git a/sparse_operation_kit/sparse_operation_kit/test/function_test/tf2/lookup/sparse_read_evict.py b/sparse_operation_kit/sparse_operation_kit/test/function_test/tf2/lookup/sparse_read_evict.py
new file mode 100644
index 0000000000..5ebf64813e
--- /dev/null
+++ b/sparse_operation_kit/sparse_operation_kit/test/function_test/tf2/lookup/sparse_read_evict.py
@@ -0,0 +1,84 @@
+"""
+ Copyright (c) 2022, NVIDIA CORPORATION.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import time
+import numpy as np
+import tensorflow as tf
+import horovod.tensorflow as hvd
+
+import sparse_operation_kit as sok
+
+
+def check_overlap(arr_list, arr2):
+    for i in range(len(arr_list)):
+        # Use np.in1d to check whether any element of arr2 already appears in arr_list[i]
+        overlap = np.in1d(arr2, arr_list[i])
+        tmp_overlap = np.any(overlap)
+        if tmp_overlap:
+            return tmp_overlap
+    return False
+
+
+if __name__ == "__main__":
+    hvd.init()
+    gpus = tf.config.experimental.list_physical_devices("GPU")
+    for gpu in gpus:
+        tf.config.experimental.set_memory_growth(gpu, True)
+    if gpus:
+        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")
+    sok.init()
+    iter_num = 5
+    input_length = 8192
+
+    evict_keys_list = []
+    evict_values_list = []
+    # sok variables
+    sok_var = sok.DynamicVariable(
+        dimension=16,
+        var_type="hybrid",
+        initializer=str(11),
+        init_capacity=input_length,
+        max_capacity=input_length * 2,
+    )
+
+    optimizer = tf.optimizers.SGD(learning_rate=1.0, momentum=0.9)
+    sok_optimizer = sok.OptimizerWrapper(optimizer)
+
+    for i in range(iter_num):
+        indices_values = tf.constant(
+            range(i * input_length, (i + 1) * input_length), dtype=tf.int64
+        )
+        indices = tf.ragged.constant(indices_values, dtype=tf.int64)
+
+        with tf.GradientTape() as tape:
+            embedding_first, evict_key, evict_value = sok.sparse_read_and_evict(sok_var, indices)
+            loss = tf.reduce_sum(embedding_first)
+        grads = tape.gradient(loss, [sok_var])
+        grad_pair = zip(grads, [sok_var])
+        sok_optimizer.apply_gradients(grad_pair)
+
+        evict_key_np = evict_key.numpy()
+        evict_value_np = evict_value.numpy()
+
+        if i > 0:
+            assert not check_overlap(
+                evict_keys_list, evict_key_np
+            ), "Evicted keys overlap with keys evicted in earlier iterations."
+            assert np.all(evict_value_np == 10), "Not all evicted values were updated correctly."
+        evict_keys_list.append(evict_key_np)
+        evict_values_list.append(evict_value_np)
+
+    print("[SOK INFO] : sparse_read_evict run success!")