Skip to content

Commit

Permalink
Add amax
Browse files Browse the repository at this point in the history
Signed-off-by: Jiang, Zhiwei <[email protected]>
  • Loading branch information
zhiweij1 committed Jul 11, 2024
1 parent fa9f8da commit 7fe2025
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions features/feature_case/cublasLt/matmul.cu
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,8 @@ void fgemmlt(cublasLtHandle_t ltHandle, int m, int n, int k,
cublasLtMatrixLayout_t Adesc,
cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc,
cublasLtMatrixLayout_t Ddesc) {
cublasLtMatrixLayout_t Ddesc,
float *amax_d) {
cublasLtMatmulDesc_t matmulDesc = NULL;
cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);

Expand All @@ -746,6 +747,7 @@ void fgemmlt(cublasLtHandle_t ltHandle, int m, int n, int k,
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(scale_a));
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(scale_b));
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scale_d, sizeof(scale_d));
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &amax_d, sizeof(amax_d));

cublasLtEpilogue_t ep = CUBLASLT_EPILOGUE_RELU;
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &ep, sizeof(ep));
Expand Down Expand Up @@ -816,8 +818,11 @@ bool test7() {

// Matmul

float *amax_d;
cudaMallocManaged(&amax_d, sizeof(float));

fgemmlt(ltHandle, m, n, k, (const float *)Adev, (const float *)Bdev, (const float *)Cdev, (float *)Ddev,
&alpha, &beta, lda, ldb, ldc, ldd, Adesc_col_major, Bdesc_col_major, Cdesc_col_major, Ddesc_col_major);
&alpha, &beta, lda, ldb, ldc, ldd, Adesc_col_major, Bdesc_col_major, Cdesc_col_major, Ddesc_col_major, amax_d);
cudaStreamSynchronize(0);

// Check result
Expand All @@ -832,11 +837,14 @@ bool test7() {
break;
}
}
if (*amax_d != 8300)
error = true;

printf("d:\n");
for (int i = 0; i < ldd * n; i++)
printf("%f, ", Dhost[i]);
printf("\n");
printf("amax_d:%f\n", *amax_d);

if (error) {
printf("error\n");
Expand All @@ -852,6 +860,7 @@ bool test7() {
cudaFree(Adev);
cudaFree(Bdev);
cudaFree(Ddev);
cudaFree(amax_d);

return !error;
}
Expand Down

0 comments on commit 7fe2025

Please sign in to comment.