From f53d941a22fc0746e98ef3560a6799422be8fa47 Mon Sep 17 00:00:00 2001 From: Yuting Jiang Date: Mon, 20 Nov 2023 11:21:20 +0800 Subject: [PATCH] Benchmarks: micro benchmarks - add int8 support for cublaslt function (#574) **Description** add int8 support for cublaslt function. --- superbench/benchmarks/micro_benchmarks/cublaslt_function.py | 2 +- .../micro_benchmarks/cublaslt_gemm/cublaslt_gemm.cu | 5 +++++ .../micro_benchmarks/cublaslt_gemm/cublaslt_utils.cc | 2 ++ tests/benchmarks/micro_benchmarks/test_cublaslt_function.py | 6 +++--- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/superbench/benchmarks/micro_benchmarks/cublaslt_function.py b/superbench/benchmarks/micro_benchmarks/cublaslt_function.py index 59733ea10..9bf3d99f3 100644 --- a/superbench/benchmarks/micro_benchmarks/cublaslt_function.py +++ b/superbench/benchmarks/micro_benchmarks/cublaslt_function.py @@ -23,7 +23,7 @@ def __init__(self, name, parameters=''): super().__init__(name, parameters) self._bin_name = 'cublaslt_gemm' - self._in_types = ['fp64', 'fp32', 'fp16', 'bf16', 'fp8e4m3', 'fp8e5m2'] + self._in_types = ['fp64', 'fp32', 'fp16', 'bf16', 'fp8e4m3', 'fp8e5m2', 'int8'] def mrange(self, start, stop=-1, multiplication_factor=2): """Range constructor with multiplication factor. diff --git a/superbench/benchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_gemm.cu b/superbench/benchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_gemm.cu index 788b1989d..002b06447 100644 --- a/superbench/benchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_gemm.cu +++ b/superbench/benchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_gemm.cu @@ -16,6 +16,7 @@ using fp16 = half; using bf16 = nv_bfloat16; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; +using int8 = int8_t; struct Args { int m = 16; @@ -84,6 +85,8 @@ template cudaDataType_t get_datatype() { return CUDA_R_8F_E4M3; if (std::is_same::value) return CUDA_R_8F_E5M2; + if (std::is_same::value) + return CUDA_R_8I; throw std::invalid_argument("Unknown type"); } @@ -162,6 +165,8 @@ int main(int argc, char **argv) { run(&args); else if (args.in_type == "fp8e5m2") run(&args); + else if (args.in_type == "int8") + run(&args); else throw std::invalid_argument("Unknown type " + args.in_type); diff --git a/superbench/benchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_utils.cc b/superbench/benchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_utils.cc index 4842c22d1..6ec5a101e 100644 --- a/superbench/benchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_utils.cc +++ b/superbench/benchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_utils.cc @@ -62,6 +62,8 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l gemm_compute_type = CUBLAS_COMPUTE_32F; if (a_type == CUDA_R_64F || b_type == CUDA_R_64F) gemm_compute_type = CUBLAS_COMPUTE_64F; + if (a_type == CUDA_R_8I) + gemm_compute_type = CUBLAS_COMPUTE_32I; cublasLtMatmulDesc_t op_desc = nullptr; CUBLAS_CHECK(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, CUDA_R_32F)); diff --git a/tests/benchmarks/micro_benchmarks/test_cublaslt_function.py b/tests/benchmarks/micro_benchmarks/test_cublaslt_function.py index b504062a2..a6fae8f0e 100644 --- a/tests/benchmarks/micro_benchmarks/test_cublaslt_function.py +++ b/tests/benchmarks/micro_benchmarks/test_cublaslt_function.py @@ -63,15 +63,15 @@ def test_cublaslt_gemm_command_generation(self): (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA) benchmark = benchmark_cls( self.benchmark_name, - parameters='--batch 2:16:2 --shapes 2:4,4:8,8:32 32:128:4,128,128 --in_types fp16 fp32 fp64', + parameters='--batch 2:16:2 --shapes 2:4,4:8,8:32 32:128:4,128,128 --in_types fp16 fp32 fp64 int8', ) self.assertTrue(benchmark._preprocess()) - self.assertEqual(4 * (2 * 2 * 3 + 2) * 3, len(benchmark._commands)) + self.assertEqual(4 * (2 * 2 * 3 + 2) * len(benchmark._args.in_types), len(benchmark._commands)) def cmd(t, b, m, n, k): return f'{benchmark._CublasLtBenchmark__bin_path} -m {m} -n {n} -k {k} -b {b} -w 20 -i 50 -t {t}' - for _t in ['fp16', 'fp32', 'fp64']: + for _t in ['fp16', 'fp32', 'fp64', 'int8']: for _b in [2, 4, 8, 16]: for _m in [2, 4]: for _n in [4, 8]: