diff --git a/HugeCTR/src/layers/softmax_layer.cu b/HugeCTR/src/layers/softmax_layer.cu
index 662c1637ce..9e518fe732 100644
--- a/HugeCTR/src/layers/softmax_layer.cu
+++ b/HugeCTR/src/layers/softmax_layer.cu
@@ -22,6 +22,7 @@
 #include
 #include
 #include
+#include
 #include

 namespace HugeCTR {
@@ -36,14 +37,15 @@ SoftmaxLayer<T>::SoftmaxLayer(const core23::Tensor& input_tensor,
   dims_ = input_tensor.shape().dims();
   hidden_size_ = input_tensor.shape().size(dims_ - 1);
   n_rows_ = len_ / hidden_size_;
-  workspace23_ =
-      core23::Tensor({(int64_t)n_rows_}, core23::DataType(core23::ToScalarType<T>::value));
-  identity23_ =
-      core23::Tensor({(int64_t)hidden_size_}, core23::DataType(core23::ToScalarType<T>::value));
-  softmax_out23_ =
-      core23::Tensor(input_tensor.shape(), core23::DataType(core23::ToScalarType<T>::value));
+  core23::BufferParams buf_p{.channel = GetBlobsBufferChannel()};
+  auto param = (input_tensor.my_params().buffer_params(buf_p));
+  workspace23_ = core23::Tensor(
+      param.shape({(int64_t)n_rows_}).data_type(core23::DataType(core23::ToScalarType<T>::value)));
+  identity23_ = core23::Tensor(param.shape({(int64_t)hidden_size_})
+                                   .data_type(core23::DataType(core23::ToScalarType<T>::value)));
+  softmax_out23_ = core23::Tensor(param.shape(input_tensor.shape())
+                                      .data_type(core23::DataType(core23::ToScalarType<T>::value)));
 }
-
 template <typename T>
 void SoftmaxLayer<T>::initialize() {
   CudaDeviceContext context(get_device_id());