Skip to content

Commit

Permalink
fix lint and ut
Browse files Browse the repository at this point in the history
  • Loading branch information
yukirora committed Nov 17, 2023
1 parent 48215b3 commit 3462f7c
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ template <typename T> cudaDataType_t get_datatype() {
if (std::is_same<T, fp8e5m2>::value)
return CUDA_R_8F_E5M2;
if (std::is_same<T, int8>::value)
return CUDA_R_8I;
return CUDA_R_8I;
throw std::invalid_argument("Unknown type");
}

Expand Down Expand Up @@ -166,7 +166,7 @@ int main(int argc, char **argv) {
else if (args.in_type == "fp8e5m2")
run<fp8e5m2, fp8e4m3, fp16>(&args);
else if (args.in_type == "int8")
run<int8>(&args);
run<int8>(&args);
else
throw std::invalid_argument("Unknown type " + args.in_type);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
if (a_type == CUDA_R_64F || b_type == CUDA_R_64F)
gemm_compute_type = CUBLAS_COMPUTE_64F;
if (a_type == CUDA_R_8I)
gemm_compute_type = CUBLAS_COMPUTE_32I;
gemm_compute_type = CUBLAS_COMPUTE_32I;

cublasLtMatmulDesc_t op_desc = nullptr;
CUBLAS_CHECK(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, CUDA_R_32F));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_cublaslt_gemm_command_generation(self):
parameters='--batch 2:16:2 --shapes 2:4,4:8,8:32 32:128:4,128,128 --in_types fp16 fp32 fp64 int8',
)
self.assertTrue(benchmark._preprocess())
self.assertEqual(4 * (2 * 2 * 3 + 2) * 3, len(benchmark._commands))
self.assertEqual(4 * (2 * 2 * 3 + 2) * len(benchmark._args.in_types), len(benchmark._commands))

def cmd(t, b, m, n, k):
return f'{benchmark._CublasLtBenchmark__bin_path} -m {m} -n {n} -k {k} -b {b} -w 20 -i 50 -t {t}'
Expand Down

0 comments on commit 3462f7c

Please sign in to comment.