Merge pull request #35 from ROCm/rocm-jaxlib-v0.4.30-qa-cleanup
Rocm jaxlib v0.4.30 qa cleanup
hsharsha authored Aug 29, 2024
2 parents 8c73dfe + 8ae1de7 commit 5945307
Showing 33 changed files with 546 additions and 272 deletions.
35 changes: 35 additions & 0 deletions third_party/llvm/rocdl_shuffle_down.patch
@@ -0,0 +1,35 @@
From a46b9e979ffa523bfed61487a2404e1f48140288 Mon Sep 17 00:00:00 2001
From: Dragan Mladjenovic <[email protected]>
Date: Fri, 29 Mar 2024 12:27:36 +0000
Subject: [PATCH] Support gpu::ShuffleMode::DOWN lowering

---
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e2cb3687d872..9317e30290c6 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -140,7 +140,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value width = adaptor.getWidth();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
@@ -151,6 +151,10 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
// TODO: Use ds_swizzle for XOR when step/offsets are constants for better
// perf.
switch (op.getMode()) {
+ case gpu::ShuffleMode::DOWN:
+ dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
+ adaptor.getOffset());
+ break;
case gpu::ShuffleMode::XOR:
dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
adaptor.getOffset());
--
2.25.1
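
For context, the switch statement this patch extends chooses a destination lane per shuffle mode. Below is a minimal standalone C++ sketch of that lane arithmetic; the enum and function names are illustrative stand-ins, not the MLIR API.

#include <cstdint>

// Illustrative stand-in for gpu::ShuffleMode; only the two modes visible in
// the hunk above are modeled.
enum class ShuffleMode { kXor, kDown };

// For DOWN, lane i reads from lane i + offset (the case added by the patch);
// for XOR, lanes exchange pairwise in a butterfly pattern.
int32_t DestLane(int32_t src_lane_id, int32_t offset, ShuffleMode mode) {
  switch (mode) {
    case ShuffleMode::kDown:
      return src_lane_id + offset;
    case ShuffleMode::kXor:
      return src_lane_id ^ offset;
  }
  return src_lane_id;
}

The full lowering (elided in the hunk) additionally checks the computed lane against the shuffle width, so lanes past the end do not contribute a valid result.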

1 change: 1 addition & 0 deletions third_party/llvm/workspace.bzl
@@ -23,6 +23,7 @@ def repo(name):
"//third_party/llvm:mathextras.patch",
"//third_party/llvm:toolchains.patch",
"//third_party/llvm:zstd.patch",
"//third_party/llvm:rocdl_shuffle_down.patch",
],
link_files = {"//third_party/llvm:run_lit.sh": "mlir/run_lit.sh"},
)
35 changes: 35 additions & 0 deletions third_party/tsl/third_party/llvm/rocdl_shuffle_down.patch
@@ -0,0 +1,35 @@
From a46b9e979ffa523bfed61487a2404e1f48140288 Mon Sep 17 00:00:00 2001
From: Dragan Mladjenovic <[email protected]>
Date: Fri, 29 Mar 2024 12:27:36 +0000
Subject: [PATCH] Support gpu::ShuffleMode::DOWN lowering

---
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e2cb3687d872..9317e30290c6 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -140,7 +140,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value width = adaptor.getWidth();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
@@ -151,6 +151,10 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
// TODO: Use ds_swizzle for XOR when step/offsets are constants for better
// perf.
switch (op.getMode()) {
+ case gpu::ShuffleMode::DOWN:
+ dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
+ adaptor.getOffset());
+ break;
case gpu::ShuffleMode::XOR:
dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
adaptor.getOffset());
--
2.25.1

1 change: 1 addition & 0 deletions third_party/tsl/third_party/llvm/workspace.bzl
@@ -23,6 +23,7 @@ def repo(name):
"//third_party/llvm:mathextras.patch",
"//third_party/llvm:toolchains.patch",
"//third_party/llvm:zstd.patch",
"//third_party/llvm:rocdl_shuffle_down.patch",
],
link_files = {"//third_party/llvm:run_lit.sh": "mlir/run_lit.sh"},
)
17 changes: 13 additions & 4 deletions xla/service/algorithm_util.cc
@@ -155,11 +155,17 @@ bool IsSupportedDotAlgorithmOnGpu(
std::get<se::CudaComputeCapability>(gpu_compute_capability)
.IsAtLeast(8, 9);

+  const bool is_rocm_mi100_and_above =
+      std::holds_alternative<se::RocmComputeCapability>(
+          gpu_compute_capability) &&
+      std::get<se::RocmComputeCapability>(gpu_compute_capability)
+          .gfx9_mi100_or_later();
+
switch (algorithm) {
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32:
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
// Other F8 types are actually not supported by NVIDIA GPUs.
-      return is_cuda_ge_ada &&
+      return (is_cuda_ge_ada || is_rocm_mi100_and_above) &&
(input_storage_type == F8E5M2 || input_storage_type == F8E4M3FN) &&
(output_storage_type == F8E5M2 ||
output_storage_type == F8E4M3FN || output_storage_type == F16 ||
@@ -168,14 +174,17 @@
return input_storage_type == F16 &&
(output_storage_type == F16 || output_storage_type == F32);
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
-      return is_cuda_ge_ampere && input_storage_type == BF16 &&
+      return (is_cuda_ge_ampere || is_rocm_mi100_and_above) &&
+             input_storage_type == BF16 &&
(output_storage_type == BF16 || output_storage_type == F32);
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
-      return is_cuda_ge_ampere && input_storage_type == F32 &&
+      return (is_cuda_ge_ampere || is_rocm_mi100_and_above)
+          && input_storage_type == F32 &&
output_storage_type == F32;
case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
-      return is_cuda_ge_ampere && input_storage_type == F32 &&
+      return (is_cuda_ge_ampere || is_rocm_mi100_and_above) &&
+             input_storage_type == F32 &&
output_storage_type == F32;
case PrecisionConfig::ALG_DOT_F32_F32_F32:
return input_storage_type == F32 && output_storage_type == F32;
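
A note on the gating pattern above: gpu_compute_capability is a variant over vendor-specific capability types, and each algorithm case is answered by a vendor predicate. A self-contained C++ sketch of that pattern, using simplified stand-ins rather than the real se:: classes:

#include <variant>

// Simplified stand-ins for se::CudaComputeCapability and
// se::RocmComputeCapability.
struct CudaCapability {
  int major, minor;
  bool IsAtLeast(int M, int m) const {
    return major > M || (major == M && minor >= m);
  }
};
struct RocmCapability {
  bool gfx9_mi100_or_later;
};
using GpuCapability = std::variant<CudaCapability, RocmCapability>;

// Mirrors the is_cuda_ge_ampere / is_rocm_mi100_and_above checks above: the
// predicate holds only when the active alternative matches the vendor and
// that vendor's capability clears the bar.
bool SupportsBf16Dot(const GpuCapability& cc) {
  const bool is_cuda_ge_ampere =
      std::holds_alternative<CudaCapability>(cc) &&
      std::get<CudaCapability>(cc).IsAtLeast(8, 0);
  const bool is_rocm_mi100_and_above =
      std::holds_alternative<RocmCapability>(cc) &&
      std::get<RocmCapability>(cc).gfx9_mi100_or_later;
  return is_cuda_ge_ampere || is_rocm_mi100_and_above;
}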
7 changes: 6 additions & 1 deletion xla/service/gpu/BUILD
@@ -265,6 +265,7 @@ cc_library(
testonly = 1,
srcs = ["gpu_device_info_for_tests.cc"],
hdrs = ["gpu_device_info_for_tests.h"],
+    local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
compatible_with = get_compatible_with_portable(),
deps = [
"//xla/stream_executor:device_description",
@@ -5681,6 +5682,8 @@ cc_library(
xla_test(
name = "dot_algorithm_support_test",
srcs = if_gpu_is_configured(["dot_algorithm_support_test.cc"]),
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) +
+                    if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
backends = [
"gpu_v100",
"gpu_a100",
@@ -5699,7 +5702,9 @@ xla_test(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_googletest//:gtest",
-    ],
+    ] + if_rocm_is_configured([
+        "@local_config_rocm//rocm:rocm_headers",
+    ]),
)

cc_library(
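
The local_defines added above let the test and library sources branch on the active backend at compile time. A minimal illustration of how such defines are typically consumed in C++ (illustrative, not the actual test source):

#if GOOGLE_CUDA
#include <cuda_fp8.h>           // CUDA-side FP8 types
#elif TENSORFLOW_USE_ROCM
#include "rocm/rocm_config.h"   // provides TF_ROCM_VERSION
#endif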
61 changes: 61 additions & 0 deletions xla/service/gpu/buffer_comparator.cu.cc
@@ -25,6 +25,11 @@ using bfloat16 = __nv_bfloat16;
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>

#include "rocm/rocm_config.h"
#if TF_ROCM_VERSION >= 60200
#include <hip/hip_fp8.h>
#endif // TF_ROCM_VERSION >= 60200

using bfloat16 = hip_bfloat16;
#define BF16_TO_F32 float

@@ -97,6 +102,52 @@ __global__ void xla_fp8_e5m2_comparison(__nv_fp8_storage_t* buffer_a,
}
#endif // GOOGLE_CUDA

+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
+__global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,
+                                            __hip_fp8_storage_t* buffer_b,
+                                            float rel_error_threshold,
+                                            uint64_t buffer_length,
+                                            int* mismatch_count) {
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx >= buffer_length) return;
+  __hip_fp8_e4m3_fnuz elem_a_fp8, elem_b_fp8;
+  elem_a_fp8.__x = buffer_a[idx];
+  elem_b_fp8.__x = buffer_b[idx];
+  float elem_a = static_cast<float>(elem_a_fp8);
+  float elem_b = static_cast<float>(elem_b_fp8);
+  elem_a = Canonicalize(elem_a);
+  elem_b = Canonicalize(elem_b);
+  if (isnan(elem_a) && isnan(elem_b)) return;
+
+  float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
+
+  if (rel_error > rel_error_threshold || isnan(rel_error))
+    atomicAdd(mismatch_count, 1);
+}
+
+__global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
+                                            __hip_fp8_storage_t* buffer_b,
+                                            float rel_error_threshold,
+                                            uint64_t buffer_length,
+                                            int* mismatch_count) {
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx >= buffer_length) return;
+  __hip_fp8_e5m2_fnuz elem_a_fp8, elem_b_fp8;
+  elem_a_fp8.__x = buffer_a[idx];
+  elem_b_fp8.__x = buffer_b[idx];
+  float elem_a = static_cast<float>(elem_a_fp8);
+  float elem_b = static_cast<float>(elem_b_fp8);
+  elem_a = Canonicalize(elem_a);
+  elem_b = Canonicalize(elem_b);
+  if (isnan(elem_a) && isnan(elem_b)) return;
+
+  float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
+
+  if (rel_error > rel_error_threshold || isnan(rel_error))
+    atomicAdd(mismatch_count, 1);
+}
+#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
+
__global__ void xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
@@ -206,6 +257,16 @@ void* fp8_e5m2_comparison() {
}
#endif

+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
+void* fp8_e4m3fnuz_comparison() {
+  return reinterpret_cast<void*>(&xla_fp8_e4m3fnuz_comparison);
+}
+
+void* fp8_e5m2fnuz_comparison() {
+  return reinterpret_cast<void*>(&xla_fp8_e5m2fnuz_comparison);
+}
+#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
+
void* fp16_comparison() {
return reinterpret_cast<void*>(&xla_fp16_comparison);
}
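
Both new kernels apply the same per-element test: a pair of NaNs counts as equal, and otherwise a mismatch is flagged when the relative error |a - b| / (max(|a|, |b|) + 1) exceeds the threshold (the + 1 keeps the denominator away from zero for tiny magnitudes). A host-side C++ reference of that test, as a sketch; the device kernels additionally pass both values through Canonicalize first:

#include <cmath>

// Host-side reference of the per-element mismatch test used by the kernels
// above.
bool Mismatch(float a, float b, float rel_error_threshold) {
  if (std::isnan(a) && std::isnan(b)) return false;  // two NaNs compare equal
  float rel_error =
      std::fabs(a - b) / (std::fmax(std::fabs(a), std::fabs(b)) + 1.0f);
  // If exactly one side is NaN, rel_error is NaN and counts as a mismatch.
  return rel_error > rel_error_threshold || std::isnan(rel_error);
}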
8 changes: 8 additions & 0 deletions xla/service/gpu/buffer_comparator.h
@@ -22,6 +22,10 @@ limitations under the License.
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream_executor.h"

+#if TENSORFLOW_USE_ROCM
+#include "rocm/rocm_config.h"
+#endif
+
namespace xla::gpu {

// A device-side comparator that compares buffers.
@@ -56,6 +60,10 @@ namespace buffer_comparator {
// Returns a pointer to CUDA C++ device function implementing comparison.
void* fp8_e4m3fn_comparison();
void* fp8_e5m2_comparison();
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
+void* fp8_e4m3fnuz_comparison();
+void* fp8_e5m2fnuz_comparison();
+#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
void* fp16_comparison();
void* bf16_comparison();
void* fp32_comparison();
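
The accessors above return type-erased device-function pointers; XLA dispatches them through StreamExecutor. Purely as an illustration, a raw HIP launch of one of the new kernels might look like the sketch below (hypothetical standalone usage, not how XLA invokes it):

#include <hip/hip_fp8.h>
#include <hip/hip_runtime.h>

// Hypothetical launcher: d_a/d_b are device buffers of buffer_length fp8
// values, d_mismatch is a device int initialized to zero.
void LaunchFp8Comparison(__hip_fp8_storage_t* d_a, __hip_fp8_storage_t* d_b,
                         float rel_error_threshold, uint64_t buffer_length,
                         int* d_mismatch, hipStream_t stream) {
  constexpr unsigned kThreadsPerBlock = 256;
  unsigned blocks = static_cast<unsigned>(
      (buffer_length + kThreadsPerBlock - 1) / kThreadsPerBlock);
  hipLaunchKernelGGL(xla_fp8_e4m3fnuz_comparison, dim3(blocks),
                     dim3(kThreadsPerBlock), /*sharedMemBytes=*/0, stream,
                     d_a, d_b, rel_error_threshold, buffer_length, d_mismatch);
}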
(The remaining 25 changed files are not rendered here.)
