From 729d638c800d263c1947ee27d8d55db5e913b750 Mon Sep 17 00:00:00 2001
From: Zoran Jovanovic
Date: Thu, 16 May 2024 20:43:10 +0000
Subject: [PATCH 1/6] Switch on Triton feature for ROCm.

---
 third_party/triton/temporary/amd_pr7.patch    | 181 ++++++++++++++++++
 third_party/triton/temporary/series.bzl       |   1 +
 .../tsl/third_party/gpus/rocm_configure.bzl   |   1 +
 xla/service/gpu/BUILD                         |   9 +-
 xla/service/gpu/gemm_fusion.cc                |   6 +-
 xla/service/gpu/gpu_compiler.cc               |   4 +
 xla/service/gpu/ir_emitter_triton_rocm.cc     |   7 +-
 xla/service/gpu/ir_emitter_triton_test.cc     |  20 +-
 8 files changed, 217 insertions(+), 12 deletions(-)
 create mode 100644 third_party/triton/temporary/amd_pr7.patch

diff --git a/third_party/triton/temporary/amd_pr7.patch b/third_party/triton/temporary/amd_pr7.patch
new file mode 100644
index 0000000000000..59482237bfdbc
--- /dev/null
+++ b/third_party/triton/temporary/amd_pr7.patch
@@ -0,0 +1,181 @@
+==== triton/BUILD#46 - /google/src/cloud/csigg/triton_amd/triton/BUILD ====
+# action=edit type=text
+--- triton/BUILD 2024-04-11 02:00:21.000000000 -0700
++++ triton/BUILD 2024-04-21 23:52:01.000000000 -0700
+@@ -725,12 +725,12 @@
+         "@llvm-project//llvm:Support",
+         "@llvm-project//mlir:ControlFlowDialect",
+         "@llvm-project//mlir:GPUDialect",
++        "@llvm-project//mlir:GPUToROCDLTransforms",
+         "@llvm-project//mlir:IR",
+-        "@llvm-project//mlir:LLVMCommonConversion",
+         "@llvm-project//mlir:LLVMDialect",
+         "@llvm-project//mlir:NVVMDialect",
+         "@llvm-project//mlir:Pass",
+-        "@llvm-project//mlir:TransformUtils",
++        "@llvm-project//mlir:ROCDLDialect",
+         "@llvm-project//mlir:Transforms",
+     ],
+ )
+==== triton/third_party/amd/BUILD#None - /google/src/cloud/csigg/triton_amd/triton/third_party/amd/BUILD ====
+# action=add type=text
+diff --git a/third_party/amd/BUILD b/third_party/amd/BUILD
+new file mode 100644
+index 0000000..ee4bc37
+--- /dev/null
++++ b/third_party/amd/BUILD
+@@ -0,0 +1,128 @@
++load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
++
++
++package(
++    # copybara:uncomment_begin
++    # default_applicable_licenses = ["//:license"],
++    # default_visibility = [
++    #     "//third_party/tensorflow/compiler/xla/service/gpu:__subpackages__",
++    #     "//:__subpackages__",
++    # ],
++    # copybara:uncomment_end_and_comment_begin
++    default_visibility = ["//visibility:public"],
++    # copybara:comment_end
++)
++
++# TODO(csigg): fix, enable error upstream, remove.
++_no_unused_variable = select({
++    "//:compiler_is_msvc": [],
++    "//conditions:default": ["-Wno-unused-variable"],
++})
++
++cc_library(
++    name = "TritonAMDGPUTransforms",
++    srcs = glob([
++        "lib/TritonAMDGPUTransforms/*.h",
++        "lib/TritonAMDGPUTransforms/*.cpp",
++    ]),
++    hdrs = glob([
++        "include/TritonAMDGPU/*.h",
++        "include/TritonAMDGPUTransforms/*.h",
++        "lib/TritonAMDGPUTransforms/*.h",
++    ]),
++    copts = _no_unused_variable,
++    includes = [
++        "..",
++        "include",
++        "lib/TritonAMDGPUTransforms",
++    ],
++    deps = [
++        ":triton_conversion_amdgpu_to_llvm_passes_inc_gen",
++        "@llvm-project//mlir:ConvertToLLVM",
++        "@llvm-project//mlir:IR",
++        "@llvm-project//mlir:LLVMCommonConversion",
++        "@llvm-project//mlir:LLVMDialect",
++        "@llvm-project//mlir:Pass",
++        "@llvm-project//mlir:Support",
++        "@llvm-project//mlir:TransformUtils",
++        "@llvm-project//mlir:Transforms",
++        "//:TritonAnalysis",
++        "//:TritonDialects",
++        "//:TritonGPUToLLVM",
++        "//:TritonGPUTransforms",
++    ],
++)
++
++cc_library(
++    name = "TritonAMDGPUToLLVM",
++    srcs = glob([
++        "lib/TritonAMDGPUToLLVM/**/*.h",
++        "lib/TritonAMDGPUToLLVM/**/*.cpp",
++    ]),
++    hdrs = glob([
++        "include/TritonAMDGPUToLLVM/**/*.h",
++    ]) + [
++        "lib/TritonAMDGPUToLLVM/Utility.h",
++    ],
++    copts = _no_unused_variable,
++    includes = [
++        "..",
++        "include",
++        "lib/TritonAMDGPUToLLVM",
++    ],
++    deps = [
++        ":triton_transforms_amdgpu_to_llvm_passes_inc_gen",
++        "@llvm-project//mlir:ConvertToLLVM",
++        "@llvm-project//mlir:IR",
++        "@llvm-project//mlir:LLVMCommonConversion",
++        "@llvm-project//mlir:LLVMDialect",
++        "@llvm-project//mlir:Pass",
++        "@llvm-project//mlir:Support",
++        "@llvm-project//mlir:TransformUtils",
++        "@llvm-project//mlir:Transforms",
++        "//:TritonAnalysis",
++        "//:TritonDialects",
++        "//:TritonGPUToLLVM",
++        ":TritonAMDGPUTransforms",
++    ],
++)
++
++td_library(
++    name = "td_files",
++    srcs = glob(["include/**/*.td"]),
++    includes = ["include"],
++    deps = ["//:td_files"],
++)
++
++gentbl_cc_library(
++    name = "triton_transforms_amdgpu_to_llvm_passes_inc_gen",
++    tbl_outs = [
++        (
++            [
++                "--gen-pass-decls",
++                "--name=TritonAMDGPUToLLVM",
++            ],
++            "include/TritonAMDGPUToLLVM/Passes.h.inc",
++        ),
++    ],
++    tblgen = "@llvm-project//mlir:mlir-tblgen",
++    td_file = "include/TritonAMDGPUToLLVM/Passes.td",
++    deps = [":td_files"],
++)
++
++
++gentbl_cc_library(
++    name = "triton_conversion_amdgpu_to_llvm_passes_inc_gen",
++    tbl_outs = [
++        (
++            [
++                "--gen-pass-decls",
++                "--name=TritonAMDGPU",
++            ],
++            "include/TritonAMDGPUTransforms/Passes.h.inc",
++        ),
++    ],
++    tblgen = "@llvm-project//mlir:mlir-tblgen",
++    td_file = "include/TritonAMDGPUTransforms/Passes.td",
++    deps = [":td_files"],
++)
+diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
+index f59efd6..cf601f0 100644
+--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
+@@ -1132,6 +1132,21 @@ struct FpToFpOpConversion
+     for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
+       inVals.push_back(operands[i][0]);
+     }
++
++    bool isSrcFP16 = srcElementType.isF16();
++    bool isSrcBF16 = srcElementType.isBF16();
++
++    if ((isSrcFP16 || isSrcBF16)
++        && isDstFP32) {
++      SmallVector<Value> outVals;
++      for (Value &v : inVals) {
++        if (isSrcFP16)
++          outVals.push_back(convertFp16ToFp32(loc, rewriter, v));
++        else
++          outVals.push_back(convertBf16ToFp32(loc, rewriter, v));
++      }
++      return outVals;
++    }
+     if (useFP16IntermediateSrc)
+       for (Value &v : inVals)
+         v = convertFp32ToFp16NZ(loc, rewriter, v);
diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl
index 2dfe0dd1bb695..e4181c8371ae5 100644
--- a/third_party/triton/temporary/series.bzl
+++ b/third_party/triton/temporary/series.bzl
@@ -7,4 +7,5 @@ internal patch during the next triton integration process.
 temporary_patch_list = [
     "//third_party/triton/temporary:pipelining.patch",
+    "//third_party/triton/temporary:amd_pr7.patch",
 ]
diff --git a/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/tsl/third_party/gpus/rocm_configure.bzl
index 89a5c21ebf928..6b56f9cfc168b 100644
--- a/third_party/tsl/third_party/gpus/rocm_configure.bzl
+++ b/third_party/tsl/third_party/gpus/rocm_configure.bzl
@@ -707,6 +707,7 @@ def _create_local_rocm_repository(repository_ctx):
         "-DTENSORFLOW_USE_ROCM=1",
         "-D__HIP_PLATFORM_AMD__",
         "-DEIGEN_USE_HIP",
+        "-DUSE_ROCM",
     ])
 
     rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD
index beae58698cd25..de59058805d73 100644
--- a/xla/service/gpu/BUILD
+++ b/xla/service/gpu/BUILD
@@ -597,15 +597,18 @@ cc_library(
         "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
         "@triton//:TritonGPUTransforms",
         "@triton//:TritonLLVMIR",
+    ]) + if_rocm_is_configured([
+        "@triton//:TritonGPUTransforms",
+        "@triton//:TritonLLVMIR",
+        "@triton//third_party/amd:TritonAMDGPUToLLVM",
     ]),
 )
 
 xla_test(
     name = "ir_emitter_triton_test",
-    srcs = if_cuda_is_configured(["ir_emitter_triton_test.cc"]),
+    srcs = if_gpu_is_configured(["ir_emitter_triton_test.cc"]),
     backends = [
-        "gpu_a100",
-        "gpu_h100",
+        "gpu",
     ],
     shard_count = 20,
     tags = ["nomac"],
diff --git a/xla/service/gpu/gemm_fusion.cc b/xla/service/gpu/gemm_fusion.cc
index 05e758a73f3d4..cbf0951636282 100644
--- a/xla/service/gpu/gemm_fusion.cc
+++ b/xla/service/gpu/gemm_fusion.cc
@@ -803,7 +803,11 @@ absl::StatusOr<bool> GemmFusion::Run(
     const absl::flat_hash_set<absl::string_view>& execution_threads) {
   auto cuda_compute_capability =
       std::get_if<se::CudaComputeCapability>(&gpu_version_);
-  if (!cuda_compute_capability || !cuda_compute_capability->IsAtLeastAmpere()) {
+  auto rocm_compute_capability =
+      std::get_if<se::RocmComputeCapability>(&gpu_version_);
+
+  if ((!cuda_compute_capability || !cuda_compute_capability->IsAtLeastAmpere())
+      && !rocm_compute_capability) {
     return absl::FailedPreconditionError(
         "Triton support is only enabled for Ampere GPUs and up.");
   }
diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
index 0da6accd17c09..503b7590fd33b 100644
--- a/xla/service/gpu/gpu_compiler.cc
+++ b/xla/service/gpu/gpu_compiler.cc
@@ -1420,6 +1420,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
       gpu_target_config.device_description.gpu_compute_capability();
   pipeline.AddPass(gpu_version);
   const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version);
+  const auto* rocm_cc = std::get_if<se::RocmComputeCapability>(&gpu_version);
 
   // Rewrite FP8 GEMMs ahead of Triton which currently lacks support for FP8
   // and may rewrite quantized FP8 GEMMs as higher-precision GEMMs.
@@ -1428,6 +1429,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
       cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) {
     pipeline.AddPass<GemmFusion>(gpu_version);
   }
+  if (debug_options.xla_gpu_enable_triton_gemm() && rocm_cc != nullptr) {
+    pipeline.AddPass<GemmFusion>(gpu_version);
+  }
 
   // Rewrite non-FP8 GEMMs.
   pipeline.AddPass<GemmRewriter>(gpu_version, /*f8_rewrite=*/false);
diff --git a/xla/service/gpu/ir_emitter_triton_rocm.cc b/xla/service/gpu/ir_emitter_triton_rocm.cc
index 1a7aa92c62a9e..7c791dcace677 100644
--- a/xla/service/gpu/ir_emitter_triton_rocm.cc
+++ b/xla/service/gpu/ir_emitter_triton_rocm.cc
@@ -14,7 +14,7 @@ limitations under the License.
 ==============================================================================*/
 // TODO(ROCm): Enable and include ROCm Triton passes when ROCm Triton is
 // included in build.
-// #include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
+#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"  // from @llvm-project
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"  // from @llvm-project
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"  // from @llvm-project
@@ -30,6 +30,7 @@ limitations under the License.
 #include "triton/Dialect/Triton/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
+#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
 
 namespace xla {
 namespace gpu {
@@ -77,8 +78,6 @@ absl::Status CreateTritonPipeline(
   pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
   pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
   pm.addPass(mlir::createCSEPass());
-  pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps,
-                                         config.num_ctas, ccAsInt));
   pm.addPass(mt::gpu::createPrefetchPass());
 
   pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
@@ -95,7 +94,7 @@ absl::Status CreateTritonPipeline(
   pm.addPass(mlir::createConvertSCFToCFPass());
   pm.addPass(mlir::createConvertIndexToLLVMPass());
   pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
-  // pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass());
+  pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass());
   pm.addPass(mlir::createArithToLLVMConversionPass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
diff --git a/xla/service/gpu/ir_emitter_triton_test.cc b/xla/service/gpu/ir_emitter_triton_test.cc
index 97b887f40ddf6..cae5454550ec8 100644
--- a/xla/service/gpu/ir_emitter_triton_test.cc
+++ b/xla/service/gpu/ir_emitter_triton_test.cc
@@ -284,7 +284,7 @@ ENTRY e {
 })";
 
   TritonGemmConfig config(16, 16, 32, 1, 1, 1);
-  EXPECT_OK(
+  TF_EXPECT_OK(
       CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"(
 CHECK: tt.func @triton_fn(%[[LHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[RHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[OUT:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) {
 CHECK-DAG: %[[ZERO_KN:.*]] = arith.constant dense<0.000000e+00> : tensor<32x16xf32>
@@ -1083,7 +1083,7 @@ ENTRY main {
                           ParseAndReturnVerifiedModule(kHloText));
 
   TritonGemmConfig config(16, 64, 32, 1, 1, 1);
-  ASSERT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax,
+  TF_ASSERT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax,
                                        "triton_softmax_computation", R"(
 // CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 16)>
 // CHECK-LABEL: tt.func @triton_fn(
@@ -2152,6 +2152,9 @@ ENTRY e {
 
 TEST_F(TritonGemmTestAny,
        DoNotFuseConcatenationOfSplitNonContractingDimension) {
+  if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
+    GTEST_SKIP() << "Not using autotuner on ROCM yet.";
+  }
   if (SkipBF16Tests()) {
     GTEST_SKIP() << "BF16 not supported.";
   }
@@ -3278,6 +3281,9 @@ ENTRY e {
 }
 
 TEST_F(CompareTest, UsingOptinSharedMemoryOnAmpereProducesSameResult) {
+  if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
+    GTEST_SKIP() << "No Optin Shared Memory on AMD.";
+  }
   const se::DeviceDescription dev_info =
       backend().default_stream_executor()->GetDeviceDescription();
   constexpr int kBytesOfSharedMemoryTested = 64 * 1024;
@@ -4504,7 +4510,7 @@ ENTRY e {
 }
 )";
 
   TritonGemmConfig config(32, 32, 32, 1, 1, 1);
-  ASSERT_OK(
+  TF_ASSERT_OK(
      CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"(
 CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32>
 CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32>
@@ -4742,6 +4748,9 @@ CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16>
 }
 
 TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmEndToEnd) {
+  if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
+    GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X6 not supported on ROCM.";
+  }
   const char* kHloText = R"(
 HloModule t
@@ -4835,7 +4844,7 @@ ENTRY e {
 }
 )";
 
   TritonGemmConfig config(32, 32, 32, 1, 1, 1);
-  ASSERT_OK(
+  TF_ASSERT_OK(
      CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"(
 CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32>
 CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32>
@@ -5092,6 +5101,9 @@ CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16>
 }
 
 TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmEndToEnd) {
+  if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
+    GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X3 not supported on ROCM.";
+  }
   const char* kHloText = R"(
 HloModule t

From 72e7d772b7345d77b8177e576b5771f877fdcbb2 Mon Sep 17 00:00:00 2001
From: Zoran Jovanovic
Date: Fri, 17 May 2024 15:30:26 +0000
Subject: [PATCH 2/6] Triton passes for ROCm organized the same as in the ROCm
 compiler.py.

---
 xla/service/gpu/BUILD                     |  1 +
 xla/service/gpu/ir_emitter_triton_rocm.cc | 30 ++++++++++++++---------
 xla/stream_executor/device_description.h  |  7 ++++++
 3 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD
index de59058805d73..234319fec692f 100644
--- a/xla/service/gpu/BUILD
+++ b/xla/service/gpu/BUILD
@@ -601,6 +601,7 @@ cc_library(
         "@triton//:TritonGPUTransforms",
         "@triton//:TritonLLVMIR",
         "@triton//third_party/amd:TritonAMDGPUToLLVM",
+        "@triton//third_party/amd:TritonAMDGPUTransforms",
     ]),
 )

diff --git a/xla/service/gpu/ir_emitter_triton_rocm.cc b/xla/service/gpu/ir_emitter_triton_rocm.cc
index 7c791dcace677..5646bb9bc6991 100644
--- a/xla/service/gpu/ir_emitter_triton_rocm.cc
+++ b/xla/service/gpu/ir_emitter_triton_rocm.cc
@@ -31,6 +31,7 @@ limitations under the License.
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
 #include "triton/Conversion/TritonGPUToLLVM/Passes.h"
+#include "third_party/amd/include/TritonAMDGPUTransforms/Passes.h"
 
 namespace xla {
 namespace gpu {
@@ -55,9 +56,10 @@ absl::Status CreateTritonPipeline(
   const int ccAsInt = 0;
   // TODO(ROCm): Check why some tests fail when threadsPerWarp is set to 64.
   const int threadsPerWarp = 32;
+  auto ccRocm = std::get<se::RocmComputeCapability>(cc);
 
   // Based on make_ttir() in
-  // @triton//:third_party/nvidia/backend/compiler.py
+  // @triton//:third_party/amd/backend/compiler.py
   pm.addPass(mlir::createInlinerPass());
   pm.addPass(mt::createRewriteTensorPointerPass());
   pm.addPass(mt::createCombineOpsPass());
@@ -68,7 +70,7 @@ absl::Status CreateTritonPipeline(
   pm.addPass(mlir::createSymbolDCEPass());
 
   // Based on make_ttgir() in
-  // @triton//:third_party/nvidia/backend/compiler.py
+  // @triton//:third_party/amd/backend/compiler.py
   pm.addPass(mt::createConvertTritonToTritonGPUPass(
       config.num_warps, threadsPerWarp, config.num_ctas, ccAsInt));
   pm.addPass(mt::gpu::createCoalescePass());
@@ -76,21 +78,24 @@ absl::Status CreateTritonPipeline(
   pm.addPass(mt::gpu::createOptimizeThreadLocalityPass());
   pm.addPass(mt::gpu::createAccelerateMatmulPass(ccAsInt));
   pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
+  pm.addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass());
   pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mt::gpu::createPrefetchPass());
-
+  if (config.num_stages == 0 && ccRocm.has_mma_instr_support()) {
+    pm.addPass(mlir::createTritonAMDGPUStreamPipelinePass());
+    pm.addPass(mlir::createCanonicalizerPass());
+  }
   pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
   pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
-  pm.addPass(mt::gpu::createReduceDataDuplicationPass());
-  pm.addPass(mt::gpu::createReorderInstructionsPass());
+  pm.addPass(mlir::createTritonAMDGPUDecomposeConversionsPass());
+  if (config.num_stages == 0) {
+    pm.addPass(mt::gpu::createReorderInstructionsPass());
+  }
   pm.addPass(mlir::createCSEPass());
   pm.addPass(mlir::createSymbolDCEPass());
-  pm.addPass(mlir::createCanonicalizerPass());
 
   // Based on make_llir() in
-  // @triton//:third_party/nvidia/backend/compiler.py
-  // pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass());
+  // @triton//:third_party/amd/backend/compiler.py
+  pm.addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass());
   pm.addPass(mlir::createConvertSCFToCFPass());
   pm.addPass(mlir::createConvertIndexToLLVMPass());
   pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
@@ -102,7 +107,10 @@ absl::Status CreateTritonPipeline(
   // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.
   pm.addPass(mlir::createConvertSCFToCFPass());
   pm.addPass(mlir::createConvertControlFlowToLLVMPass());
-
+  pm.addPass(mlir::createArithToLLVMConversionPass());
+  pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mlir::createCSEPass());
+  pm.addPass(mlir::createSymbolDCEPass());
   // There are no clusters in ROCm for now.
   out_cluster_info.clusterDimX = 1;
   out_cluster_info.clusterDimY = 1;
diff --git a/xla/stream_executor/device_description.h b/xla/stream_executor/device_description.h
index 7a7c35bc55659..fadd74a523838 100644
--- a/xla/stream_executor/device_description.h
+++ b/xla/stream_executor/device_description.h
@@ -214,6 +214,13 @@ class RocmComputeCapability {
 
   bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); }
 
+  bool has_mma_instr_support() const {
+    static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940",
+                                                  "gfx942"};
+    return (gfx_version().find("gfx11") == 0 ||
+            (absl::c_count(kList, gfx_version()) != 0));
+  }
+
   bool has_fp16_atomics_support() const {
     // TODO(rocm): Check. This should be the same as has_fast_fp16_support().
     return gfx9_mi200_or_later();

From 8a1508290f240c516b87f480df7cc1d0139985e4 Mon Sep 17 00:00:00 2001
From: Zoran Jovanovic
Date: Thu, 30 May 2024 08:22:52 +0000
Subject: [PATCH 3/6] [ROCm] Add custom call handling by Triton.

---
 xla/service/gpu/ir_emitter_unnested.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc
index 9ba1be41febde..76fc9e64d7ae2 100644
--- a/xla/service/gpu/ir_emitter_unnested.cc
+++ b/xla/service/gpu/ir_emitter_unnested.cc
@@ -1590,7 +1590,7 @@ absl::Status IrEmitterUnnested::EmitTopKCustomCall(
 
 absl::Status IrEmitterUnnested::EmitTritonCustomCall(
     const HloCustomCallInstruction* instr) {
-#if !GOOGLE_CUDA
+#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
   return absl::UnimplementedError("Triton support requires CUDA");
 #else
   auto generate = [this, &instr]() -> absl::StatusOr<KernelReuseCache::Entry> {
@@ -1615,7 +1615,7 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
   TF_ASSIGN_OR_RETURN(
       auto result,
       CompileTritonToLLVM(hlo_module->config(), hlo_module->name(),
-                          ir_emitter_context_->cuda_compute_capability(),
+                          ir_emitter_context_->gpu_compute_capability(),
                           ir_emitter_context_->gpu_device_info(), gemm_config,
                           triton_module.get(),
                           ir_emitter_context_->llvm_module(), mlir_context));

From 9aadc4d1eafa2d89bfc7c74acf5d344fe9bc90c6 Mon Sep 17 00:00:00 2001
From: Zoran Jovanovic
Date: Thu, 30 May 2024 15:36:33 +0000
Subject: [PATCH 4/6] [ROCm] Fix an issue with Softmax.

---
 xla/service/gpu/gpu_compiler.cc | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
index 503b7590fd33b..88bc5b777380a 100644
--- a/xla/service/gpu/gpu_compiler.cc
+++ b/xla/service/gpu/gpu_compiler.cc
@@ -1453,8 +1453,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
   // ReductionDimensionGrouper, as that makes matching the softmax pattern
   // harder.
   if (debug_options.xla_gpu_enable_triton_softmax_fusion() &&
-      cuda_cc != nullptr &&
-      cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) {
+      ((cuda_cc != nullptr &&
+        cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) ||
+       rocm_cc != nullptr)) {
     pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(simplifier_options);
     pipeline.AddPass<SoftmaxRewriterTriton>(gpu_version);
   }

From ae0f0798aea025519f597f0c56827658d60f82fa Mon Sep 17 00:00:00 2001
From: Zoran Jovanovic
Date: Thu, 30 May 2024 16:10:23 +0000
Subject: [PATCH 5/6] [ROCm] Fix an issue with Softmax 2.
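
This relaxes the same capability gating that GemmFusion::Run uses. For
reference, a minimal self-contained sketch of the intended logic, assuming
se::GpuComputeCapability is a std::variant over the CUDA and ROCm capability
types (as in xla/stream_executor/device_description.h); the TritonSupported
helper below is illustrative only and not part of the patch:

    #include <variant>

    namespace se {
    struct CudaComputeCapability {
      int major = 0;
      int minor = 0;
      bool IsAtLeastAmpere() const { return major >= 8; }
    };
    struct RocmComputeCapability {};
    using GpuComputeCapability =
        std::variant<CudaComputeCapability, RocmComputeCapability>;
    }  // namespace se

    // Triton rewrites run on Ampere-or-newer CUDA devices and on any ROCm
    // device; anything else takes the FailedPrecondition early-return path.
    bool TritonSupported(const se::GpuComputeCapability& gpu_version) {
      const auto* cuda = std::get_if<se::CudaComputeCapability>(&gpu_version);
      const auto* rocm = std::get_if<se::RocmComputeCapability>(&gpu_version);
      return (cuda != nullptr && cuda->IsAtLeastAmpere()) || rocm != nullptr;
    }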
---
 xla/service/gpu/softmax_rewriter_triton.cc | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/xla/service/gpu/softmax_rewriter_triton.cc b/xla/service/gpu/softmax_rewriter_triton.cc
index 2db88bc59379a..27b08692e7c4b 100644
--- a/xla/service/gpu/softmax_rewriter_triton.cc
+++ b/xla/service/gpu/softmax_rewriter_triton.cc
@@ -627,9 +627,12 @@ absl::StatusOr<bool> SoftmaxRewriterTriton::Run(
     const absl::flat_hash_set<absl::string_view>& execution_threads) {
   auto cuda_compute_capability =
       std::get_if<se::CudaComputeCapability>(&gpu_version_);
-  if (!cuda_compute_capability || !cuda_compute_capability->IsAtLeastAmpere()) {
+  auto rocm_compute_capability =
+      std::get_if<se::RocmComputeCapability>(&gpu_version_);
+  if ((!cuda_compute_capability || !cuda_compute_capability->IsAtLeastAmpere())
+      && !rocm_compute_capability) {
     return absl::FailedPreconditionError(
-        "Triton support is only enabled for Ampere GPUs and up.");
+        "Triton support is only enabled for ROCm and Ampere GPUs and up.");
   }
 
   std::vector<DiamondChainDescriptor> diamond_chains =

From 2ecd466b45323ab7026cc16fc65bb9c132764de8 Mon Sep 17 00:00:00 2001
From: Zoran Jovanovic
Date: Tue, 4 Jun 2024 11:28:21 +0000
Subject: [PATCH 6/6] [ROCm] Fix an issue with undefined __oclc_ABI_version
 symbol.

---
 xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index 6f60f4b3a5ac5..92f3524ddfbde 100644
--- a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -712,7 +712,8 @@ std::vector<std::string> GetROCDLPaths(std::string gcn_arch_name,
       new std::vector<std::string>(
          {"opencl.bc", "ocml.bc", "ockl.bc", "oclc_finite_only_off.bc",
           "oclc_daz_opt_off.bc", "oclc_correctly_rounded_sqrt_on.bc",
-          "oclc_unsafe_math_off.bc", "oclc_wavefrontsize64_on.bc"});
+          "oclc_unsafe_math_off.bc", "oclc_wavefrontsize64_on.bc",
+          "oclc_abi_version_500.bc"});
 
   // Construct full path to ROCDL bitcode libraries.
   std::vector<std::string> result;
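
For context, a minimal sketch of how GetROCDLPaths typically expands the list
above into full bitcode paths; the hunk is truncated here, so the loop body,
the rocdl_dir_path parameter name, and the version-specific oclc_isa_version
entry are assumptions based on the surrounding code (tsl::io::JoinPath and
absl::StrCat are existing helpers):

    std::vector<std::string> result;
    result.reserve(rocdl_filenames->size() + 1);
    for (const std::string& filename : *rocdl_filenames) {
      result.push_back(tsl::io::JoinPath(rocdl_dir_path, filename));
    }
    // Each gfx target also links its oclc_isa_version_<...>.bc, derived from
    // gcn_arch_name (e.g. "gfx90a" -> "oclc_isa_version_90a.bc").
    result.push_back(tsl::io::JoinPath(
        rocdl_dir_path,
        absl::StrCat("oclc_isa_version_", amdgpu_version, ".bc")));
    return result;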