Skip to content

Commit

Permalink
[ROCm] Fixed linker issues related to fp8 buffer_comparator functions
Browse files Browse the repository at this point in the history
  • Loading branch information
zoranjovanovic-ns committed Nov 13, 2024
1 parent bf81e49 commit 4951842
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions xla/service/gpu/buffer_comparator.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= buffer_length) return;
__hip_fp8_e4m3_fnuz elem_a_fp8, elem_b_fp8;
Expand All @@ -123,13 +124,17 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,

if (rel_error > rel_error_threshold || isnan(rel_error))
atomicAdd(mismatch_count, 1);
#else
abort();
#endif // defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
}

__global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
__hip_fp8_storage_t* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= buffer_length) return;
__hip_fp8_e5m2_fnuz elem_a_fp8, elem_b_fp8;
Expand All @@ -145,6 +150,9 @@ __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,

if (rel_error > rel_error_threshold || isnan(rel_error))
atomicAdd(mismatch_count, 1);
#else
abort();
#endif // defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
}
#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200

Expand Down

0 comments on commit 4951842

Please sign in to comment.