Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
yzygitzh committed Mar 29, 2024
1 parent 60bb72e commit 9cd7a54
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ using cublasLtHalf = hipblasLtHalf;
#else
#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLASLT_COMPUTE_F32
#endif
#if ROCM_VERSION >= 50700
#if HIP_VERSION >= 50700000
#include <hipblaslt/hipblaslt-ext.hpp>
#if ROCM_VERSION >= 60000
#if HIP_VERSION >= 60000000
#define HIPBLASLT_GETINDEXFROMALGO(algo) hipblaslt_ext::getIndexFromAlgo(algo)
#else
static int getIndexFromAlgo(hipblasLtMatmulAlgo_t& algo) {
int* algo_ptr = (int*)algo.data;
if(*algo_ptr < 0) {
static int getIndexFromAlgo(hipblasLtMatmulAlgo_t &algo) {
int *algo_ptr = (int *)algo.data;
if (*algo_ptr < 0) {
return -1;
}
return *algo_ptr;
Expand Down Expand Up @@ -203,7 +203,7 @@ void InitializeABCDEF(std::vector<cublasLtHalf> &ha, int64_t size_a, std::vector
}
}

#if defined(__HIP_PLATFORM_AMD__) && ROCM_VERSION >= 50700
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 50700000
// Tune GEMM algorithm in local rank.
// Write <0 to ret_algo_time_in_ms if nothing found.
// Write >=0 to ret_algo_time_in_ms and write ret_algo if something is found.
Expand Down Expand Up @@ -389,7 +389,7 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t
CHECK_CUBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmul2, matE, matD, matF, matG, pref, 1,
heuristicResult2, &returnedAlgoCount));
hipblasLtMatmulAlgo_t algo2 = heuristicResult2[0].algo;
#if ROCM_VERSION >= 50700
#if HIP_VERSION >= 50700000
if (tune_gemm) {
hipblasLtMatmulAlgo_t ret_algo;
float ret_algo_time_in_ms;
Expand Down

0 comments on commit 9cd7a54

Please sign in to comment.