change index's dtype from int to int64 (#55949)
AnnaTrainingG authored Aug 9, 2023
1 parent 4eba647 commit 8d181e3
Showing 5 changed files with 38 additions and 27 deletions.
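Why the change matters: numel() on a DenseTensor returns an int64_t, so once a tensor holds more than 2^31 - 1 elements a 32-bit linear index overflows and threads read and write the wrong addresses. A minimal, hypothetical CUDA sketch of the failure mode and of the fix this commit applies (kernel names and the element count are illustrative, not part of the commit):

#include <cstdint>
#include <cstdio>

// Hypothetical kernel: the 32-bit index wraps once the element count
// exceeds INT_MAX (2^31 - 1 = 2147483647), so very large tensors hit
// out-of-range accesses or are simply not covered.
__global__ void scale_int32(float* x, int numel) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // 32-bit arithmetic
  if (i < numel) x[i] *= 2.0f;
}

// Fixed variant: widen the index to int64_t and cast the launch variables
// before multiplying, the same pattern this commit applies to the fused
// rope kernels.
__global__ void scale_int64(float* x, int64_t numel) {
  int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (; i < numel; i += stride) {
    x[i] *= 2.0f;
  }
}

int main() {
  // 2^31 + 8 elements is representable in int64_t but not in a 32-bit int.
  const int64_t numel = (1LL << 31) + 8;
  std::printf("numel = %lld, fits in int32? %s\n",
              static_cast<long long>(numel),
              numel <= INT32_MAX ? "yes" : "no");
  return 0;
}

The widened host-side numel, grid, and block declarations in the two kernel files below follow the same reasoning.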
9 changes: 3 additions & 6 deletions paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
@@ -32,10 +32,9 @@ void FusedRopeGradKernel(const Context& dev_ctx,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
-  int numel = dout_q.numel();
+  int64_t numel = dout_q.numel();
if (numel <= 0) return;
dev_ctx.template Alloc<T>(dq);
-  dq->Resize(dout_q.dims());
// small size for broadcast
auto batch_size = dout_q.dims()[0];
auto num_heads = dout_q.dims()[2];
@@ -51,8 +50,8 @@ void FusedRopeGradKernel(const Context& dev_ctx,
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);

-  int grid = config.block_per_grid.x;
-  int block = config.thread_per_block.x;
+  int64_t grid = config.block_per_grid.x;
+  int64_t block = config.thread_per_block.x;
auto stream = dev_ctx.stream();

phi::Array<T*, 3> outs_data;
@@ -65,15 +64,13 @@ void FusedRopeGradKernel(const Context& dev_ctx,

if (dout_k.get_ptr()) {
dev_ctx.template Alloc<T>(dk);
-  dk->Resize(dout_q.dims());
outs_data[1] = dk->data<T>();
ins_data[1] = dout_k->data<T>();
num_inputs++;
}

if (dout_v.get_ptr()) {
dev_ctx.template Alloc<T>(dv);
-  dv->Resize(dout_q.dims());
outs_data[2] = dv->data<T>();
ins_data[2] = dout_v->data<T>();
num_inputs++;
9 changes: 3 additions & 6 deletions paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
@@ -32,10 +32,9 @@ void FusedRopeKernel(const Context& dev_ctx,
DenseTensor* out_q,
DenseTensor* out_k,
DenseTensor* out_v) {
-  int numel = q.numel();
+  int64_t numel = q.numel();
if (numel <= 0) return;
dev_ctx.template Alloc<T>(out_q);
-  out_q->Resize(q.dims());
// small size for broadcast
auto batch_size = q.dims()[0];
auto num_heads = q.dims()[2];
@@ -51,8 +50,8 @@ void FusedRopeKernel(const Context& dev_ctx,
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);

-  int grid = config.block_per_grid.x;
-  int block = config.thread_per_block.x;
+  int64_t grid = config.block_per_grid.x;
+  int64_t block = config.thread_per_block.x;
auto stream = dev_ctx.stream();

phi::Array<T*, 3> outs_data;
@@ -65,15 +64,13 @@ void FusedRopeKernel(const Context& dev_ctx,

if (k.get_ptr()) {
dev_ctx.template Alloc<T>(out_k);
-  out_k->Resize(q.dims());
ins_data[1] = k->data<T>();
outs_data[1] = out_k->data<T>();
num_inputs++;
}

if (v.get_ptr()) {
dev_ctx.template Alloc<T>(out_v);
-  out_v->Resize(q.dims());
ins_data[2] = v->data<T>();
outs_data[2] = out_v->data<T>();
num_inputs++;
32 changes: 18 additions & 14 deletions paddle/phi/kernels/fusion/gpu/fused_rope_utils.h
@@ -24,16 +24,20 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
phi::Array<const T*, 2> sin_cos_data,
bool flag_sin_cos,
int sign,
- int batch_size,
- int seq_len,
- int num_heads,
- int head_dim,
+ int64_t batch_size,
+ int64_t seq_len,
+ int64_t num_heads,
+ int64_t head_dim,
phi::Array<T*, 3> outs_data,
int num_inputs,
MPType div_c) {
- int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
- int stride = gridDim.x * blockDim.x * VecSize;
- int size = batch_size * seq_len * num_heads * head_dim;
+ int64_t index =
+     (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      threadIdx.x) *
+     VecSize;
+ int64_t stride = static_cast<int64_t>(gridDim.x) *
+                  static_cast<int64_t>(blockDim.x) * VecSize;
+ int64_t size = batch_size * seq_len * num_heads * head_dim;
MPType sin_value[VecSize];
MPType cos_value[VecSize];
MPType result[VecSize];
@@ -44,11 +48,11 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
for (; index < size; index += stride) {
if (flag_sin_cos) {
#pragma unroll
- for (int nx = 0; nx < VecSize; ++nx) {
-   int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
-   int pos_seq = index_wc / (num_heads * head_dim);
-   int pos_head = index_wc % head_dim;
-   int index_sc = pos_seq * head_dim + pos_head;
+ for (int64_t nx = 0; nx < VecSize; ++nx) {
+   int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
+   int64_t pos_seq = index_wc / (num_heads * head_dim);
+   int64_t pos_head = index_wc % head_dim;
+   int64_t index_sc = pos_seq * head_dim + pos_head;
const T* sin_input = sin_cos_data[0] + index_sc;
const T* cos_input = sin_cos_data[1] + index_sc;

@@ -59,8 +63,8 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
#pragma unroll
for (int nx = 0; nx < VecSize; ++nx) {
// get sin_index and cos_index
- int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
- int pos_seq = index_wc / (num_heads * head_dim);
+ int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
+ int64_t pos_seq = index_wc / (num_heads * head_dim);
MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
MPType indicses =
static_cast<MPType>(1) /
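A detail worth noting in the fused_rope_utils.h hunk above: switching index to int64_t on its own would not be enough, because blockIdx.x * blockDim.x is evaluated in 32-bit unsigned arithmetic before any assignment happens; the casts therefore have to be applied to the operands, not to the result. A self-contained sketch of that grid-stride pattern under the same assumption the kernel makes, namely that size is a multiple of VecSize (the kernel name and the copy operation are illustrative only):

#include <cstdint>

// Grid-stride loop with 64-bit indexing; VecSize is the vectorization width.
template <int VecSize>
__global__ void CopyGridStride(const float* in, float* out, int64_t size) {
  // Cast each launch variable before the multiplication so the intermediate
  // products are computed in 64-bit arithmetic and cannot wrap.
  int64_t index =
      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
       threadIdx.x) *
      VecSize;
  int64_t stride = static_cast<int64_t>(gridDim.x) *
                   static_cast<int64_t>(blockDim.x) * VecSize;
  for (; index < size; index += stride) {
#pragma unroll
    for (int nx = 0; nx < VecSize; ++nx) {
      out[index + nx] = in[index + nx];  // each thread copies VecSize elements
    }
  }
}

Launched as, for example, CopyGridStride<2><<<grid, block, 0, stream>>>(in, out, size), each thread then touches VecSize consecutive elements per iteration without risking a wrapped offset.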
6 changes: 5 additions & 1 deletion python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py
@@ -17,7 +17,7 @@
from paddle.framework import in_dynamic_mode


-def fused_rotary_position_embedding(q, k, v, sin=None, cos=None):
+def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None):
r"""
Fused rotary position embedding.
@@ -53,3 +53,7 @@ def fused_rotary_position_embedding(q, k, v, sin=None, cos=None):
"""
if in_dynamic_mode():
return _C_ops.fused_rotary_position_embedding(q, k, v, sin, cos)

+    raise RuntimeError(
+        "This feature is currently supported only in dynamic mode and with CUDAPlace."
+    )
9 changes: 9 additions & 0 deletions test/legacy_test/test_fused_rotary_position_embedding.py
@@ -168,6 +168,15 @@ def test_fused_dropout_add_sin_cos(self):
p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
)

+    def test_error(self):
+        paddle.enable_static()
+        with self.assertRaises(RuntimeError):
+            static_q = paddle.static.data(
+                name="q", shape=self.shape, dtype=self.dtype
+            )
+            fused_rotary_position_embedding(static_q, static_q, static_q)
+        paddle.disable_static()


if __name__ == '__main__':
unittest.main()
