# Patch summary: two files under superbench/benchmarks/micro_benchmarks/dist_inference_cpp.
#
# CMakeLists.txt (hunk @@ -33,6 +33,11 @@)
#   - Adds an `elseif` branch: when the environment variable USE_HIP_DATATYPE is defined and
#     USE_HIPBLASLT_DATATYPE is not, -DUSE_HIP_DATATYPE=1 is appended to CMAKE_CXX_FLAGS.
#     USE_HIPBLASLT_DATATYPE keeps precedence because it is tested first.
#   - Adds a separate `if` block: when USE_HIPBLAS_COMPUTETYPE is defined in the environment,
#     -DUSE_HIPBLAS_COMPUTETYPE=1 is appended to CMAKE_CXX_FLAGS, independently of the
#     datatype selection above.
#   - Both tests are `if(DEFINED ENV{...})`, so only the presence of the variable matters; its
#     value is not inspected. The environment is read when CMake configures the project.
#
# dist_inference.cu (hunks @@ -48,10 +48,18 @@ and @@ -244,8 +252,10 @@)
#   - With USE_HIP_DATATYPE defined (and USE_HIPBLASLT_DATATYPE not defined),
#     DIST_INF_HIP_DATATYPE_R_16F / DIST_INF_HIP_DATATYPE_R_32F expand to HIP_R_16F / HIP_R_32F
#     instead of the HIPBLAS_R_16F / HIPBLAS_R_32F fallback.
#   - Introduces DIST_INF_HIP_COMPUTETYPE_F32: HIPBLAS_COMPUTE_32F when USE_HIPBLAS_COMPUTETYPE
#     is defined, otherwise HIPBLASLT_COMPUTE_F32 (the value that was previously hard-coded).
#   - Both hipblasLtMatmulDescCreate calls (matmul1, matmul2) now pass
#     DIST_INF_HIP_COMPUTETYPE_F32 instead of the literal HIPBLASLT_COMPUTE_F32.
#
# NOTE(review): the line breaks of this patch appear to have been collapsed, and the two
# `#include` context lines have lost their header names -- presumably an extraction artifact.
# TODO confirm against the original patch before attempting to apply it.
diff --git a/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt index f29149e50..27a220f90 100644 --- a/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt +++ b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt @@ -33,6 +33,11 @@ else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DROCM_USE_FLOAT16=1") if(DEFINED ENV{USE_HIPBLASLT_DATATYPE}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIPBLASLT_DATATYPE=1") + elseif(DEFINED ENV{USE_HIP_DATATYPE}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIP_DATATYPE=1") + endif() + if(DEFINED ENV{USE_HIPBLAS_COMPUTETYPE}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIPBLAS_COMPUTETYPE=1") endif() target_link_libraries(dist_inference MPI::MPI_CXX rccl hipblaslt hip::device) else() diff --git a/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu index 1db7e270f..c6d7cd033 100644 --- a/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu +++ b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu @@ -48,10 +48,18 @@ using cublasLtHalf = hipblasLtHalf; #if defined(USE_HIPBLASLT_DATATYPE) #define DIST_INF_HIP_DATATYPE_R_16F HIPBLASLT_R_16F #define DIST_INF_HIP_DATATYPE_R_32F HIPBLASLT_R_32F +#elif defined(USE_HIP_DATATYPE) +#define DIST_INF_HIP_DATATYPE_R_16F HIP_R_16F +#define DIST_INF_HIP_DATATYPE_R_32F HIP_R_32F #else #define DIST_INF_HIP_DATATYPE_R_16F HIPBLAS_R_16F #define DIST_INF_HIP_DATATYPE_R_32F HIPBLAS_R_32F #endif +#if defined(USE_HIPBLAS_COMPUTETYPE) +#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLAS_COMPUTE_32F +#else +#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLASLT_COMPUTE_F32 +#endif #else #include #include @@ -244,8 +252,10 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t 
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matF, DIST_INF_HIP_DATATYPE_R_16F, k, n, k)); CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matG, DIST_INF_HIP_DATATYPE_R_16F, k, n, k)); - CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul1, HIPBLASLT_COMPUTE_F32, DIST_INF_HIP_DATATYPE_R_32F)); - CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul2, HIPBLASLT_COMPUTE_F32, DIST_INF_HIP_DATATYPE_R_32F)); + CHECK_CUBLASLT_ERROR( + hipblasLtMatmulDescCreate(&matmul1, DIST_INF_HIP_COMPUTETYPE_F32, DIST_INF_HIP_DATATYPE_R_32F)); + CHECK_CUBLASLT_ERROR( + hipblasLtMatmulDescCreate(&matmul2, DIST_INF_HIP_COMPUTETYPE_F32, DIST_INF_HIP_DATATYPE_R_32F)); hipblasOperation_t trans = HIPBLAS_OP_N; CHECK_CUBLASLT_ERROR(