cuda : use CUBLAS_COMPUTE_16F for non-attention ops
ggerganov committed Oct 27, 2023
1 parent 3b9ea65 commit 0f2498f
Showing 1 changed file with 12 additions and 4 deletions.
ggml-cuda.cu: 12 additions & 4 deletions
@@ -6385,19 +6385,27 @@ inline void ggml_cuda_op_mul_mat_cublas(
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
 
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
+        size_t dst_as = 0;
+        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
+
+        const half alpha = 1.0f;
+        const half beta = 0.0f;
 
         CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
         CUBLAS_CHECK(
             cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                     row_diff, src1_ncols, ne10,
                     &alpha, src0_ptr, CUDA_R_16F, ne00,
                             src1_ptr, CUDA_R_16F, ne10,
-                    &beta,  dst_dd_i, CUDA_R_32F, ldc,
-                    CUBLAS_COMPUTE_32F,
+                    &beta,  dst_f16, CUDA_R_16F, ldc,
+                    CUBLAS_COMPUTE_16F,
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
+
+        ggml_cuda_pool_free(dst_f16, dst_as);
+
         if (src0_as != 0) {
             ggml_cuda_pool_free(src0_as_f16, src0_as);
         }
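
For reference, below is a minimal, self-contained sketch (not the ggml code itself) of the pattern this commit adopts: run the F16 x F16 GEMM with cublasGemmEx using CUBLAS_COMPUTE_16F into a temporary F16 buffer, then convert that buffer to F32 for the destination. The helper names gemm_f16_acc_f16_to_f32 and half_to_float are hypothetical, plain cudaMalloc stands in for ggml's CUDA memory pool, and the check macros are simplified stand-ins for ggml's.

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>

// simplified stand-ins for ggml's error-checking macros
#define CUDA_CHECK(x)   do { if ((x) != cudaSuccess)           { fprintf(stderr, "CUDA error\n");   abort(); } } while (0)
#define CUBLAS_CHECK(x) do { if ((x) != CUBLAS_STATUS_SUCCESS) { fprintf(stderr, "cuBLAS error\n"); abort(); } } while (0)

// half -> float conversion kernel, standing in for the to_fp32_cuda path
static __global__ void half_to_float(const half * x, float * y, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = __half2float(x[i]);
    }
}

// computes C (m x n, F32) = A^T * B with A (k x m) and B (k x n) in F16,
// accumulating in F16 via CUBLAS_COMPUTE_16F, then converting the result to F32
static void gemm_f16_acc_f16_to_f32(
        cublasHandle_t handle, cudaStream_t stream,
        const half * A, const half * B, float * C,
        const int m, const int n, const int k) {
    // scratch buffer for the F16 result (the commit draws this from ggml's
    // CUDA memory pool instead of calling cudaMalloc directly)
    half * C_f16 = nullptr;
    CUDA_CHECK(cudaMalloc(&C_f16, (size_t)m*n*sizeof(half)));

    // with CUBLAS_COMPUTE_16F, alpha and beta must be passed as half values
    const half alpha = 1.0f;
    const half beta  = 0.0f;

    CUBLAS_CHECK(cublasSetStream(handle, stream));
    CUBLAS_CHECK(
        cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                m, n, k,
                &alpha, A,     CUDA_R_16F, k,
                        B,     CUDA_R_16F, k,
                &beta,  C_f16, CUDA_R_16F, m,
                CUBLAS_COMPUTE_16F,
                CUBLAS_GEMM_DEFAULT_TENSOR_OP));

    // convert the F16 result into the caller's F32 destination
    half_to_float<<<(m*n + 255)/256, 256, 0, stream>>>(C_f16, C, m*n);

    CUDA_CHECK(cudaStreamSynchronize(stream));
    CUDA_CHECK(cudaFree(C_f16));
}

Accumulating in F16 trades some numerical range for tensor-core throughput, which is presumably why the commit title limits the change to non-attention mat-muls.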