Skip to content

Commit

Permalink
Merge branch 'release/0.10' into yutji/hipblas-unit
Browse files Browse the repository at this point in the history
  • Loading branch information
yukirora authored Dec 22, 2023
2 parents 0d707c8 + bfb625b commit f77b8f9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DROCM_USE_FLOAT16=1")
if(DEFINED ENV{USE_HIPBLASLT_DATATYPE})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIPBLASLT_DATATYPE=1")
elseif(DEFINED ENV{USE_HIP_DATATYPE})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIP_DATATYPE=1")
endif()
if(DEFINED ENV{USE_HIPBLAS_COMPUTETYPE})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIPBLAS_COMPUTETYPE=1")
endif()
target_link_libraries(dist_inference MPI::MPI_CXX rccl hipblaslt hip::device)
else()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,18 @@ using cublasLtHalf = hipblasLtHalf;
#if defined(USE_HIPBLASLT_DATATYPE)
#define DIST_INF_HIP_DATATYPE_R_16F HIPBLASLT_R_16F
#define DIST_INF_HIP_DATATYPE_R_32F HIPBLASLT_R_32F
#elif defined(USE_HIP_DATATYPE)
#define DIST_INF_HIP_DATATYPE_R_16F HIP_R_16F
#define DIST_INF_HIP_DATATYPE_R_32F HIP_R_32F
#else
#define DIST_INF_HIP_DATATYPE_R_16F HIPBLAS_R_16F
#define DIST_INF_HIP_DATATYPE_R_32F HIPBLAS_R_32F
#endif
#if defined(USE_HIPBLAS_COMPUTETYPE)
#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLAS_COMPUTE_32F
#else
#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLASLT_COMPUTE_F32
#endif
#else
#include <cublasLt.h>
#include <cuda_fp16.h>
Expand Down Expand Up @@ -244,8 +252,10 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matF, DIST_INF_HIP_DATATYPE_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matG, DIST_INF_HIP_DATATYPE_R_16F, k, n, k));

CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul1, HIPBLASLT_COMPUTE_F32, DIST_INF_HIP_DATATYPE_R_32F));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul2, HIPBLASLT_COMPUTE_F32, DIST_INF_HIP_DATATYPE_R_32F));
CHECK_CUBLASLT_ERROR(
hipblasLtMatmulDescCreate(&matmul1, DIST_INF_HIP_COMPUTETYPE_F32, DIST_INF_HIP_DATATYPE_R_32F));
CHECK_CUBLASLT_ERROR(
hipblasLtMatmulDescCreate(&matmul2, DIST_INF_HIP_COMPUTETYPE_F32, DIST_INF_HIP_DATATYPE_R_32F));

hipblasOperation_t trans = HIPBLAS_OP_N;
CHECK_CUBLASLT_ERROR(
Expand Down

0 comments on commit f77b8f9

Please sign in to comment.