Add Fused Div Add Softmax Operator #259

Draft. Wants to merge 9 commits into base: yuqxia/prepare.
365 changes: 365 additions & 0 deletions src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp
@@ -2673,6 +2673,188 @@ inline void DispatchSoftmax(cudaStream_t stream, const int64_t rows, const int64
DispatchSoftmaxBlockUncachedImpl(stream, rows, cols, x, y);
}
}

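// Fused Div + Add + Softmax (warp-per-row): each warp processes one or more rows and
// computes y[row] = softmax(x[row] / x1[0] + x2[row]) entirely in registers; every
// lane keeps cols_per_thread elements, so a row must satisfy cols <= cols_per_thread * kWarpSize.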
template<typename T, int pack_size, int cols_per_thread, bool padding>
__global__ void FusedSoftmaxWarpImpl(const int64_t rows, const int64_t cols, const T* x, const T* x1, const T* x2, T* y) {
  static_assert(cols_per_thread % pack_size == 0, "");
  constexpr int num_packs = cols_per_thread / pack_size;
  assert(cols <= cols_per_thread * kWarpSize);
  using ComputeType = typename GetComputeType<T>::type;
  ComputeType buf[cols_per_thread];
  ComputeType buf_bias[cols_per_thread];
  const int global_warp_id = blockIdx.x * blockDim.y + threadIdx.y;
  const int num_global_warp = gridDim.x * blockDim.y;
  const int lane_id = threadIdx.x;
  for (int64_t row = global_warp_id; row < rows; row += num_global_warp) {
    const int64_t row_offset = row * cols;
    const T* row_x = x + row_offset;
    const T scale = x1[0];
    const T* row_x_bias = x2 + row_offset;

    T* row_y = y + row_offset;
    ComputeType thread_max = -Inf<ComputeType>();
    #pragma unroll
    for (int pack_id = 0; pack_id < num_packs; ++pack_id) {
      const int col = (pack_id * kWarpSize + lane_id) * pack_size;
      if (!padding || col < cols) {
        MultiFetch<T, ComputeType, pack_size>()(buf + pack_id * pack_size, row_x + col);
        MultiFetch<T, ComputeType, pack_size>()(buf_bias + pack_id * pack_size, row_x_bias + col);
        #pragma unroll
        for (int i = 0; i < pack_size; ++i) {
          buf[pack_id * pack_size + i] = add(fdividef(buf[pack_id * pack_size + i], scale), buf_bias[pack_id * pack_size + i]);
          thread_max = max(thread_max, buf[pack_id * pack_size + i]);
        }
      } else {
        #pragma unroll
        for (int i = 0; i < pack_size; ++i) { buf[pack_id * pack_size + i] = -Inf<ComputeType>(); }
        #pragma unroll
        for (int i = 0; i < pack_size; ++i) { buf_bias[pack_id * pack_size + i] = -Inf<ComputeType>(); }
      }
    }
    const ComputeType warp_max = WarpAllReduce<MaxOp, ComputeType>(thread_max);
    ComputeType thread_sum = 0;
    #pragma unroll
    for (int i = 0; i < cols_per_thread; ++i) {
      buf[i] = exp(buf[i] - warp_max);
      thread_sum += buf[i];
    }
    const ComputeType warp_sum = WarpAllReduce<SumOp, ComputeType>(thread_sum);
    #pragma unroll
    for (int i = 0; i < cols_per_thread; ++i) { buf[i] = buf[i] / warp_sum; }
    #pragma unroll
    for (int i = 0; i < num_packs; ++i) {
      const int col = (i * kWarpSize + lane_id) * pack_size;
      if (!padding || col < cols) {
        MultiStore<ComputeType, T, pack_size>()(row_y + col, buf + i * pack_size);
      }
    }
  }
}

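// Launch configuration: 128 threads per block, laid out as kWarpSize lanes by
// rows_per_block warps; the grid size is taken from GetNumBlocks based on the
// number of row-blocks and the `waves` constant.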
template<typename T, int pack_size, int cols_per_thread, bool padding>
inline void LaunchFusedSoftmaxWarpImpl(cudaStream_t stream, const int64_t rows, const int64_t cols,
                                       const T* x, const T* x1, const T* x2, T* y) {
  constexpr int block_size = 128;
  constexpr int waves = 32;
  static_assert(block_size % kWarpSize == 0, "");
  constexpr int rows_per_block = block_size / kWarpSize;
  dim3 block_dim(kWarpSize, rows_per_block);
  const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block;
  const int grid_dim_x = GetNumBlocks(block_size, num_blocks, waves);
  FusedSoftmaxWarpImpl<T, pack_size, cols_per_thread, padding>
      <<<grid_dim_x, block_dim, 0, stream>>>(rows, cols, x, x1, x2, y);
}

template<typename T, int pack_size, int cols_per_thread>
inline void DispatchFusedSoftmaxWarpImplPadding(cudaStream_t stream, const int64_t rows,
                                                const int64_t cols, const T* x, const T* x1, const T* x2, T* y) {
  if (cols == cols_per_thread * kWarpSize) {
    LaunchFusedSoftmaxWarpImpl<T, pack_size, cols_per_thread, false>(stream, rows, cols, x, x1, x2, y);
  } else {
    LaunchFusedSoftmaxWarpImpl<T, pack_size, cols_per_thread, true>(stream, rows, cols, x, x1, x2, y);
  }
}

template<typename T, int pack_size>
typename std::enable_if<pack_size == 1, void>::type DispatchFusedSoftmaxWarpImplCols(cudaStream_t stream,
const int64_t rows,
const int64_t cols,
const T* x, const T* x1, const T* x2, T* y) {
if (cols <= 0) { return; }
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
DispatchFusedSoftmaxWarpImplPadding<T, pack_size, col>(stream, rows, cols, x, x1, x2, y); \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(3)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(5)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(7)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(9)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(11)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(13)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(15)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(17)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(19)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(21)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(23)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(25)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(27)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(29)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(31)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return;
}
}

template<typename T, int pack_size>
typename std::enable_if<pack_size == 2, void>::type DispatchFusedSoftmaxWarpImplCols(cudaStream_t stream,
const int64_t rows,
const int64_t cols,
const T* x, const T* x1, const T* x2, T* y) {
if (cols <= 0) { return; }
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
DispatchFusedSoftmaxWarpImplPadding<T, pack_size, col>(stream, rows, cols, x, x1, x2, y); \
}
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return;
}
}

template<typename T>
inline void DispatchFusedSoftmaxWarpImplPackSize(cudaStream_t stream, const int64_t rows,
const int64_t cols, const T* x, const T* x1, const T* x2, T* y) {
DispatchFusedSoftmaxWarpImplCols<T, 1>(stream, rows, cols, x, x1, x2, y);
}

template<typename T>
inline void DispatchFusedSoftmaxWarpImpl(cudaStream_t stream, const int64_t rows, const int64_t cols,
const T* x, const T* x1, const T* x2, T* y) {
DispatchFusedSoftmaxWarpImplPackSize<T>(stream, rows, cols, x, x1, x2, y);
}


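// Entry point for the fused Div + Add + Softmax operator. Only the warp-per-row
// path is wired up here, so rows wider than 1024 columns (32 lanes x 32 columns
// per lane with pack_size 1) fall through without being processed.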
template<typename T>
inline void DispatchFusedSoftmaxWarp(cudaStream_t stream, const int64_t rows, const int64_t cols, const T* x, const T* x1, const T* x2, T* y) {
  if (cols <= 1024) {
    DispatchFusedSoftmaxWarpImpl<T>(stream, rows, cols, x, x1, x2, y);
  }
}
)",
R"(
/*
@@ -3159,6 +3341,189 @@ inline void DispatchSoftmax(cudaStream_t stream, const int64_t rows, const int64
DispatchSoftmaxBlockUncachedImpl(stream, rows, cols, x, y);
}
}

template<typename T, int pack_size, int cols_per_thread, bool padding>
__global__ void FusedSoftmaxWarpImpl(const int64_t rows, const int64_t cols, const T* x, const T* x1, const T* x2, T* y) {
  static_assert(cols_per_thread % pack_size == 0, "");
  constexpr int num_packs = cols_per_thread / pack_size;
  assert(cols <= cols_per_thread * kWarpSize);
  using ComputeType = typename GetComputeType<T>::type;
  ComputeType buf[cols_per_thread];
  ComputeType buf_bias[cols_per_thread];
  const int global_warp_id = blockIdx.x * blockDim.y + threadIdx.y;
  const int num_global_warp = gridDim.x * blockDim.y;
  const int lane_id = threadIdx.x;
  for (int64_t row = global_warp_id; row < rows; row += num_global_warp) {
    const int64_t row_offset = row * cols;
    const T* row_x = x + row_offset;
    const T scale = x1[0];
    const T* row_x_bias = x2 + row_offset;

    T* row_y = y + row_offset;
    ComputeType thread_max = -Inf<ComputeType>();
    #pragma unroll
    for (int pack_id = 0; pack_id < num_packs; ++pack_id) {
      const int col = (pack_id * kWarpSize + lane_id) * pack_size;
      if (!padding || col < cols) {
        MultiFetch<T, ComputeType, pack_size>()(buf + pack_id * pack_size, row_x + col);
        MultiFetch<T, ComputeType, pack_size>()(buf_bias + pack_id * pack_size, row_x_bias + col);
        #pragma unroll
        for (int i = 0; i < pack_size; ++i) {
          buf[pack_id * pack_size + i] = add(fdividef(buf[pack_id * pack_size + i], scale), buf_bias[pack_id * pack_size + i]);
          thread_max = max(thread_max, buf[pack_id * pack_size + i]);
        }
      } else {
        #pragma unroll
        for (int i = 0; i < pack_size; ++i) { buf[pack_id * pack_size + i] = -Inf<ComputeType>(); }
        #pragma unroll
        for (int i = 0; i < pack_size; ++i) { buf_bias[pack_id * pack_size + i] = -Inf<ComputeType>(); }
      }
    }
    const ComputeType warp_max = WarpAllReduce<MaxOp, ComputeType>(thread_max);
    ComputeType thread_sum = 0;
    #pragma unroll
    for (int i = 0; i < cols_per_thread; ++i) {
      buf[i] = exp(buf[i] - warp_max);
      thread_sum += buf[i];
    }
    const ComputeType warp_sum = WarpAllReduce<SumOp, ComputeType>(thread_sum);
    #pragma unroll
    for (int i = 0; i < cols_per_thread; ++i) { buf[i] = buf[i] / warp_sum; }
    #pragma unroll
    for (int i = 0; i < num_packs; ++i) {
      const int col = (i * kWarpSize + lane_id) * pack_size;
      if (!padding || col < cols) {
        MultiStore<ComputeType, T, pack_size>()(row_y + col, buf + i * pack_size);
      }
    }
  }
}

template<typename T, int pack_size, int cols_per_thread, bool padding>
inline void LaunchFusedSoftmaxWarpImpl(cudaStream_t stream, const int64_t rows, const int64_t cols,
                                       const T* x, const T* x1, const T* x2, T* y) {
  constexpr int block_size = 128;
  constexpr int waves = 32;
  static_assert(block_size % kWarpSize == 0, "");
  constexpr int rows_per_block = block_size / kWarpSize;
  dim3 block_dim(kWarpSize, rows_per_block);
  const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block;
  const int grid_dim_x = GetNumBlocks(block_size, num_blocks, waves);
  FusedSoftmaxWarpImpl<T, pack_size, cols_per_thread, padding>
      <<<grid_dim_x, block_dim, 0, stream>>>(rows, cols, x, x1, x2, y);
}

template<typename T, int pack_size, int cols_per_thread>
inline void DispatchFusedSoftmaxWarpImplPadding(cudaStream_t stream, const int64_t rows,
                                                const int64_t cols, const T* x, const T* x1, const T* x2, T* y) {
  if (cols == cols_per_thread * kWarpSize) {
    LaunchFusedSoftmaxWarpImpl<T, pack_size, cols_per_thread, false>(stream, rows, cols, x, x1, x2, y);
  } else {
    LaunchFusedSoftmaxWarpImpl<T, pack_size, cols_per_thread, true>(stream, rows, cols, x, x1, x2, y);
  }
}

template<typename T, int pack_size>
typename std::enable_if<pack_size == 1, void>::type DispatchFusedSoftmaxWarpImplCols(cudaStream_t stream,
const int64_t rows,
const int64_t cols,
const T* x, const T* x1, const T* x2, T* y) {
if (cols <= 0) { return; }
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
DispatchFusedSoftmaxWarpImplPadding<T, pack_size, col>(stream, rows, cols, x, x1, x2, y); \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(3)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(5)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(7)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(9)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(11)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(13)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(15)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(17)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(19)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(21)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(23)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(25)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(27)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(29)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(31)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return;
}
}

template<typename T, int pack_size>
typename std::enable_if<pack_size == 2, void>::type DispatchFusedSoftmaxWarpImplCols(cudaStream_t stream,
const int64_t rows,
const int64_t cols,
const T* x, const T* x1, const T* x2, T* y) {
if (cols <= 0) { return; }
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
DispatchFusedSoftmaxWarpImplPadding<T, pack_size, col>(stream, rows, cols, x, x1, x2, y); \
}
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return;
}
}

template<typename T>
inline void DispatchFusedSoftmaxWarpImplPackSize(cudaStream_t stream, const int64_t rows,
const int64_t cols, const T* x, const T* x1, const T* x2, T* y) {
DispatchFusedSoftmaxWarpImplCols<T, 1>(stream, rows, cols, x, x1, x2, y);
}

template<typename T>
inline void DispatchFusedSoftmaxWarpImpl(cudaStream_t stream, const int64_t rows, const int64_t cols,
const T* x, const T* x1, const T* x2, T* y) {
DispatchFusedSoftmaxWarpImplPackSize<T>(stream, rows, cols, x, x1, x2, y);
}


template<typename T>
inline void DispatchFusedSoftmaxWarp(cudaStream_t stream, const int64_t rows, const int64_t cols, const T* x, const T* x1, const T* x2, T* y) {
  if (cols <= 1024) {
    DispatchFusedSoftmaxWarpImpl<T>(stream, rows, cols, x, x1, x2, y);
  }
}

)"

,
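For reviewers, the sketch below shows how the new entry point could be exercised from host code. It is not part of the diff: it assumes the helpers defined in the generated header above (kWarpSize, GetNumBlocks, MultiFetch/MultiStore, DispatchFusedSoftmaxWarp) are in scope, and the tensor sizes, buffer names, and test harness are hypothetical. The fused operator computes, per row, y = softmax(x / x1[0] + x2), i.e. the attention-style pattern softmax(scores / scale + mask).

// Illustrative usage sketch (not part of the PR): exercises the warp path
// (cols <= 1024) of DispatchFusedSoftmaxWarp<float>. All sizes and buffer
// names here are hypothetical.
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

int main() {
    const int64_t rows = 64;
    const int64_t cols = 512;                      // <= 1024, so the warp-per-row kernel is used
    std::vector<float> h_x(rows * cols, 1.0f);     // attention scores
    std::vector<float> h_bias(rows * cols, 0.0f);  // additive mask / bias
    const float h_scale = 8.0f;                    // e.g. sqrt(head_dim)

    float *d_x = nullptr, *d_scale = nullptr, *d_bias = nullptr, *d_y = nullptr;
    cudaMalloc(&d_x, rows * cols * sizeof(float));
    cudaMalloc(&d_scale, sizeof(float));
    cudaMalloc(&d_bias, rows * cols * sizeof(float));
    cudaMalloc(&d_y, rows * cols * sizeof(float));
    cudaMemcpy(d_x, h_x.data(), rows * cols * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_scale, &h_scale, sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_bias, h_bias.data(), rows * cols * sizeof(float), cudaMemcpyHostToDevice);

    cudaStream_t stream;
    cudaStreamCreate(&stream);
    // y[r][:] = softmax(x[r][:] / scale + bias[r][:]) for every row r.
    DispatchFusedSoftmaxWarp<float>(stream, rows, cols, d_x, d_scale, d_bias, d_y);
    cudaStreamSynchronize(stream);

    cudaFree(d_x); cudaFree(d_scale); cudaFree(d_bias); cudaFree(d_y);
    cudaStreamDestroy(stream);
    return 0;
}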