From 9cd7a54d12cd9d0a58d94f8b0838d98e46be7885 Mon Sep 17 00:00:00 2001 From: Ziyue Yang Date: Fri, 29 Mar 2024 02:15:15 +0000 Subject: [PATCH] fix bug --- .../dist_inference_cpp/dist_inference.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu index be6464d76..f33fae492 100644 --- a/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu +++ b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu @@ -61,14 +61,14 @@ using cublasLtHalf = hipblasLtHalf; #else #define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLASLT_COMPUTE_F32 #endif -#if ROCM_VERSION >= 50700 +#if HIP_VERSION >= 50700000 #include -#if ROCM_VERSION >= 60000 +#if HIP_VERSION >= 60000000 #define HIPBLASLT_GETINDEXFROMALGO(algo) hipblaslt_ext::getIndexFromAlgo(algo) #else -static int getIndexFromAlgo(hipblasLtMatmulAlgo_t& algo) { - int* algo_ptr = (int*)algo.data; - if(*algo_ptr < 0) { +static int getIndexFromAlgo(hipblasLtMatmulAlgo_t &algo) { + int *algo_ptr = (int *)algo.data; + if (*algo_ptr < 0) { return -1; } return *algo_ptr; @@ -203,7 +203,7 @@ void InitializeABCDEF(std::vector &ha, int64_t size_a, std::vector } } -#if defined(__HIP_PLATFORM_AMD__) && ROCM_VERSION >= 50700 +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 50700000 // Tune GEMM algorithm in local rank. // Write <0 to ret_algo_time_in_ms if nothing found. // Write >=0 to ret_algo_time_in_ms and write ret_algo if something is found. @@ -389,7 +389,7 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t CHECK_CUBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmul2, matE, matD, matF, matG, pref, 1, heuristicResult2, &returnedAlgoCount)); hipblasLtMatmulAlgo_t algo2 = heuristicResult2[0].algo; -#if ROCM_VERSION >= 50700 +#if HIP_VERSION >= 50700000 if (tune_gemm) { hipblasLtMatmulAlgo_t ret_algo; float ret_algo_time_in_ms;