🐛 Bug

For BatchMatMul_half_half_half_cuda_lib_BatchMatMul_427, cublasSgemmStridedBatched should be cublasHgemmStridedBatched.
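For reference, a minimal sketch of the fp16 batched GEMM call that should be emitted instead. The wrapper name, dimensions, and strides here are placeholders, not the values nnfusion generates; the point is that cublasHgemmStridedBatched takes __half pointers and __half alpha/beta, where the Sgemm variant takes float.

```cuda
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Hypothetical wrapper: column-major A (m x k), B (k x n), C (m x n),
// laid out contiguously per batch.
void batch_matmul_fp16(cublasHandle_t handle,
                       const half* A, const half* B, half* C,
                       int m, int n, int k, int batch)
{
    // alpha/beta must be __half for cublasHgemm*, not float.
    const half alpha = __float2half(1.0f);
    const half beta  = __float2half(0.0f);
    cublasHgemmStridedBatched(handle,
                              CUBLAS_OP_N, CUBLAS_OP_N,
                              m, n, k,
                              &alpha,
                              A, m, (long long)m * k,
                              B, k, (long long)k * n,
                              &beta,
                              C, m, (long long)m * n,
                              batch);
}
```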
For Sum_half_half_cuda_Sum_355:
extern "C" __launch_bounds__(512) __global__ void Sum_half_half_cuda_Sum_355(half* input0, half* output0)
{
    int width = 768;
    int block_size = 512;
    const int warp_size = 32;
    __shared__ float shm[warp_size];
    int thread_idx = threadIdx.x;
    int block_idx = blockIdx.x;
    int data_idx_offset = block_idx * width;
    float val = 0.0;
    for (int tidx = thread_idx; tidx < width; tidx += block_size) {
        int data_idx = tidx + data_idx_offset;
        val += input0[data_idx];
    }
    val = reduceSum(val, thread_idx, block_size, shm);
    if (thread_idx == 0) output0[block_idx] = val;
}
the datatype of val should be half. After fixing these two problems, the inference produces correct output for bert-fp16.
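A minimal sketch of the fixed kernel, assuming the generated reduceSum helper has (or is given) a half overload; the shared-memory buffer is switched to half as well so the whole reduction stays in fp16. Note that half arithmetic operators require cuda_fp16.h and compute capability 5.3 or higher.

```cuda
#include <cuda_fp16.h>

extern "C" __launch_bounds__(512) __global__ void Sum_half_half_cuda_Sum_355(half* input0, half* output0)
{
    int width = 768;
    int block_size = 512;
    const int warp_size = 32;
    __shared__ half shm[warp_size];          // was float
    int thread_idx = threadIdx.x;
    int block_idx = blockIdx.x;
    int data_idx_offset = block_idx * width;
    half val = __float2half(0.0f);           // was float
    for (int tidx = thread_idx; tidx < width; tidx += block_size) {
        int data_idx = tidx + data_idx_offset;
        val += input0[data_idx];
    }
    // assumed: reduceSum overload taking (half, int, int, half*)
    val = reduceSum(val, thread_idx, block_size, shm);
    if (thread_idx == 0) output0[block_idx] = val;
}
```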
(base) root@bad3554e6e95:/workspace/v-leiwang3/nnfusion_rt/cuda_codegen# ./main_test
Result_1913_0:
1.553955e-01 -1.488037e-01 2.043457e-01 4.392090e-01 -1.478271e-01 -6.176758e-02 -4.776001e-02 -1.222229e-02 2.309570e-01 -5.352783e-02 .. (size = 100, ends with 1.050415e-01);
Result_1913_0:
1.553955e-01 -1.488037e-01 2.043457e-01 4.392090e-01 -1.478271e-01 -6.176758e-02 -4.776001e-02 -1.222229e-02 2.309570e-01 -5.352783e-02 .. (size = 100, ends with 1.050415e-01);
Result_1913_0:
1.553955e-01 -1.488037e-01 2.043457e-01 4.392090e-01 -1.478271e-01 -6.176758e-02 -4.776001e-02 -1.222229e-02 2.309570e-01 -5.352783e-02 .. (size = 100, ends with 1.050415e-01);
Result_1913_0:
1.553955e-01 -1.488037e-01 2.043457e-01 4.392090e-01 -1.478271e-01 -6.176758e-02 -4.776001e-02 -1.222229e-02 2.309570e-01 -5.352783e-02 .. (size = 100, ends with 1.050415e-01);
Result_1913_0:
1.553955e-01 -1.488037e-01 2.043457e-01 4.392090e-01 -1.478271e-01 -6.176758e-02 -4.776001e-02 -1.222229e-02 2.309570e-01 -5.352783e-02 .. (size = 100, ends with 1.050415e-01);
Iteration time 4.993408 ms
Iteration time 4.985344 ms
Iteration time 4.983872 ms
Iteration time 4.976992 ms
Iteration time 4.975488 ms
Iteration time 4.990912 ms
Iteration time 4.998816 ms
Iteration time 4.996576 ms
Iteration time 4.985600 ms
Iteration time 4.986848 ms
Iteration time 5.005440 ms
Iteration time 5.035680 ms
Iteration time 4.999136 ms
Iteration time 4.961440 ms
Iteration time 4.982464 ms
Iteration time 4.978112 ms
Iteration time 4.981376 ms
Iteration time 4.976672 ms
Iteration time 4.970368 ms
Iteration time 4.965472 ms
Iteration time 4.961984 ms
Iteration time 4.962720 ms
Iteration time 4.976032 ms
Iteration time 4.980736 ms
Iteration time 4.964320 ms
Iteration time 4.981536 ms
Iteration time 4.958144 ms
Iteration time 4.954144 ms
Iteration time 4.967424 ms
Iteration time 4.977216 ms
Iteration time 4.976416 ms
Iteration time 4.992608 ms
Iteration time 4.969056 ms
Iteration time 4.973120 ms
Iteration time 4.973536 ms
Iteration time 4.973696 ms
Iteration time 4.979328 ms
Iteration time 4.986464 ms
Iteration time 4.961376 ms
Iteration time 4.949312 ms
Iteration time 4.963584 ms
Iteration time 4.963776 ms
Iteration time 4.956448 ms
Iteration time 4.955840 ms
Iteration time 4.962272 ms
Iteration time 4.967616 ms
Iteration time 4.967904 ms
Iteration time 4.964000 ms
Iteration time 4.974656 ms
Iteration time 4.969856 ms
Iteration time 4.950016 ms
Iteration time 4.953728 ms
Iteration time 4.949120 ms
Iteration time 5.183392 ms
Iteration time 4.967072 ms
Iteration time 5.142752 ms
Iteration time 4.955520 ms
Iteration time 4.955168 ms
Iteration time 4.949344 ms
Iteration time 4.943904 ms
Iteration time 4.933536 ms
Iteration time 4.954464 ms
Iteration time 4.960832 ms
Iteration time 5.271488 ms
Iteration time 4.963872 ms
Iteration time 4.951264 ms
Iteration time 4.952640 ms
Iteration time 4.954688 ms
Iteration time 4.939296 ms
Iteration time 4.944832 ms
Iteration time 4.935328 ms
Iteration time 4.945664 ms
Iteration time 4.944992 ms
Iteration time 4.948544 ms
Iteration time 4.959392 ms
Iteration time 4.950944 ms
Iteration time 4.950848 ms
Iteration time 4.964416 ms
Iteration time 4.951296 ms
Iteration time 4.958496 ms
Iteration time 4.943648 ms
Iteration time 4.951904 ms
Iteration time 4.970528 ms
Iteration time 4.963584 ms
Iteration time 4.956128 ms
Iteration time 4.953024 ms
Iteration time 4.948032 ms
Iteration time 4.947712 ms
Iteration time 5.234112 ms
Iteration time 4.954336 ms
Iteration time 4.960640 ms
Iteration time 4.970816 ms
Iteration time 4.982112 ms
Iteration time 4.968320 ms
Iteration time 6.102816 ms
Iteration time 4.962592 ms
Iteration time 4.957952 ms
Iteration time 4.954656 ms
Iteration time 4.949184 ms
Iteration time 4.951040 ms
Summary: [min, max, mean] = [4.933536, 6.102816, 4.986650] ms