diff --git a/xla/service/gpu/fusions/triton/BUILD b/xla/service/gpu/fusions/triton/BUILD
index a53ca3cc90932..3f85bdf093c5a 100644
--- a/xla/service/gpu/fusions/triton/BUILD
+++ b/xla/service/gpu/fusions/triton/BUILD
@@ -71,8 +71,6 @@ cc_library(
     ]),
    hdrs = ["compilation_pipeline.h"],
     deps = [
-        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
-        "//xla/stream_executor:device_description",
         "@com_google_absl//absl/status",
         "@llvm-project//mlir:Pass",
     ] + if_gpu_is_configured([
@@ -85,6 +83,7 @@ cc_library(
         "@llvm-project//mlir:Transforms",
         "//xla/service:hlo_module_config",
         "//xla/service/gpu:matmul_utils",
+        "//xla/stream_executor:device_description",
         "@triton//:TritonDialects",
         "@triton//:TritonGPUToLLVM",
         "@triton//:TritonGPUTransforms",
diff --git a/xla/service/gpu/fusions/triton/compilation_pipeline.h b/xla/service/gpu/fusions/triton/compilation_pipeline.h
index 8e40565a05626..9db6fc01e9e9f 100644
--- a/xla/service/gpu/fusions/triton/compilation_pipeline.h
+++ b/xla/service/gpu/fusions/triton/compilation_pipeline.h
@@ -16,10 +16,10 @@ limitations under the License.
 #ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_COMPILATION_PIPELINE_H_
 #define XLA_SERVICE_GPU_FUSIONS_TRITON_COMPILATION_PIPELINE_H_
 
+#include <string>
+
 #include "absl/status/status.h"
 #include "mlir/Pass/PassManager.h"
-#include "xla/service/gpu/model/tiled_hlo_computation.h"
-#include "xla/stream_executor/device_description.h"
 
 namespace mlir::triton::nvidia_gpu {
 
@@ -41,9 +41,8 @@ namespace gpu {
 // parameter which would give a hint to Triton which cluster dims we prefer to
 // use, but that's not the case currently.
 absl::Status CreateTritonPipeline(
-    mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
-    const BlockLevelParameters& block_level_parameters,
-    mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info);
+    mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
+    int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info);
 
 }  // namespace gpu
 }  // namespace xla
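Note: the pipeline entry point no longer sees `BlockLevelParameters` or `se::GpuComputeCapability`; the capability travels as a serialized `arch_name` string and the tiling knobs become plain ints. A minimal sketch of a caller against the new declaration (the helper name and the concrete values are illustrative, not part of this change):

```cpp
#include "absl/status/status.h"
#include "mlir/Pass/PassManager.h"
#include "xla/service/gpu/fusions/triton/compilation_pipeline.h"

// Hypothetical helper, for illustration only. "8.0" follows the CUDA
// "major.minor" arch format; the warp/CTA/stage counts are made-up defaults.
absl::Status BuildExamplePipeline(mlir::OpPassManager& pm) {
  mlir::triton::nvidia_gpu::ClusterInfo cluster_info;
  return xla::gpu::CreateTritonPipeline(&pm, /*arch_name=*/"8.0",
                                        /*num_warps=*/4, /*num_ctas=*/1,
                                        /*num_stages=*/3, cluster_info);
}
```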
#include "mlir/Transforms/Passes.h" #include "xla/service/gpu/fusions/triton/xla_triton_passes.h" #include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h" -#include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/device_description.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" @@ -42,94 +42,90 @@ namespace mt = ::mlir::triton; namespace mt_xla = ::mlir::triton::xla; absl::Status CreateTritonPipeline( - mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, - const BlockLevelParameters& block_level_parameters, - mt::nvidia_gpu::ClusterInfo& out_cluster_info) { - auto ccCuda = std::get(cc); - const int ccAsInt = ccCuda.major * 10 + ccCuda.minor; + mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas, + int num_stages, mt::nvidia_gpu::ClusterInfo& out_cluster_info) { + auto cc = se::CudaComputeCapability(std::move(arch_name)); + const int ccAsInt = cc.major * 10 + cc.minor; const int threadsPerWarp = 32; // Based on make_ttir() in // @triton//:third_party/nvidia/backend/compiler.py - pm.addPass(mlir::createInlinerPass()); - pm.addPass(mt::createRewriteTensorPointerPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mt::createCombineOpsPass()); - pm.addPass(mt::createReorderBroadcastPass()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createLoopInvariantCodeMotionPass()); - pm.addPass(mlir::createSymbolDCEPass()); - pm.addPass(mt::createLoopUnrollPass()); + pm->addPass(mlir::createInlinerPass()); + pm->addPass(mt::createRewriteTensorPointerPass()); + pm->addPass(mlir::createCanonicalizerPass()); + pm->addPass(mt::createCombineOpsPass()); + pm->addPass(mt::createReorderBroadcastPass()); + pm->addPass(mlir::createCSEPass()); + pm->addPass(mlir::createLoopInvariantCodeMotionPass()); + pm->addPass(mlir::createSymbolDCEPass()); + pm->addPass(mt::createLoopUnrollPass()); // Based on make_ttgir() in // @triton//:third_party/nvidia/backend/compiler.py - pm.addPass(mt::createConvertTritonToTritonGPUPass( - absl::StrFormat("cuda:%u", ccAsInt), block_level_parameters.num_warps, - threadsPerWarp, block_level_parameters.num_ctas)); - pm.addPass(mt_xla::CreateSparseAddEncodingPass( - block_level_parameters.num_warps, threadsPerWarp, - block_level_parameters.num_ctas)); - pm.addPass(mt::gpu::createTritonGPUCoalesce()); - if (ccCuda.IsAtLeastAmpere()) { - pm.addPass(mt::gpu::createTritonGPUF32DotTC()); + pm->addPass(mt::createConvertTritonToTritonGPUPass( + absl::StrFormat("cuda:%u", ccAsInt), num_warps, threadsPerWarp, + num_ctas)); + pm->addPass( + mt_xla::CreateSparseAddEncodingPass(num_warps, threadsPerWarp, num_ctas)); + pm->addPass(mt::gpu::createTritonGPUCoalesce()); + if (cc.IsAtLeastAmpere()) { + pm->addPass(mt::gpu::createTritonGPUF32DotTC()); } - pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info)); - pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); - pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality()); - pm.addPass(mt_xla::CreateSparseBlockedToMMAPass()); - pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul()); - pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); - pm.addPass( - mt::gpu::createTritonGPUOptimizeDotOperands({ccCuda.IsAtLeastAmpere()})); - pm.addPass(mlir::createCSEPass()); + pm->addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info)); + pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); + pm->addPass(mt::gpu::createTritonGPUOptimizeThreadLocality()); + 
+  pm->addPass(mt_xla::CreateSparseBlockedToMMAPass());
+  pm->addPass(mt::gpu::createTritonGPUAccelerateMatmul());
+  pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
+  pm->addPass(
+      mt::gpu::createTritonGPUOptimizeDotOperands({cc.IsAtLeastAmpere()}));
+  pm->addPass(mlir::createCSEPass());
 
   // Even though we don't run on pre-Ampere architectures anymore, we keep this
   // check for consistency with the upstream pipeline
-  if (ccCuda.IsAtLeastAmpere()) {
-    pm.addPass(mt::gpu::createTritonGPUCombineTensorSelectAndIf());
-    pm.addPass(mt::gpu::createTritonGPULoopScheduling(
-        {block_level_parameters.num_stages}));
-    pm.addPass(
-        mt::gpu::createTritonGPUPipeline({block_level_parameters.num_stages}));
+  if (cc.IsAtLeastAmpere()) {
+    pm->addPass(mt::gpu::createTritonGPUCombineTensorSelectAndIf());
+    pm->addPass(mt::gpu::createTritonGPULoopScheduling({num_stages}));
+    pm->addPass(mt::gpu::createTritonGPUPipeline({num_stages}));
   }
-  pm.addPass(mt::gpu::createTritonGPUPrefetch());
-  pm.addPass(
-      mt::gpu::createTritonGPUOptimizeDotOperands({ccCuda.IsAtLeastAmpere()}));
-  pm.addPass(mt::gpu::createTritonGPUCoalesceAsyncCopy());
-  pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
-  pm.addPass(mt_xla::CreateSparseRemoveLayoutConversionPass());
-  pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication());
-  pm.addPass(mt::gpu::createTritonGPUReorderInstructions());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createSymbolDCEPass());
-  if (ccCuda.IsAtLeastHopper()) {
-    pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt));
-    pm.addPass(mlir::createTritonNvidiaGPUTMALoweringPass());
+  pm->addPass(mt::gpu::createTritonGPUPrefetch());
+  pm->addPass(
+      mt::gpu::createTritonGPUOptimizeDotOperands({cc.IsAtLeastAmpere()}));
+  pm->addPass(mt::gpu::createTritonGPUCoalesceAsyncCopy());
+  pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
+  pm->addPass(mt_xla::CreateSparseRemoveLayoutConversionPass());
+  pm->addPass(mt::gpu::createTritonGPUReduceDataDuplication());
+  pm->addPass(mt::gpu::createTritonGPUReorderInstructions());
+  pm->addPass(mlir::createCSEPass());
+  pm->addPass(mlir::createSymbolDCEPass());
+  if (cc.IsAtLeastHopper()) {
+    pm->addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt));
+    pm->addPass(mlir::createTritonNvidiaGPUTMALoweringPass());
   }
-  pm.addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mlir::createCanonicalizerPass());
 
   // Based on make_llir() in
   // @triton//:third_party/nvidia/backend/compiler.py
-  pm.addPass(mt::NVIDIA::createDecomposeUnsupportedConversionsPass());
+  pm->addPass(mt::NVIDIA::createDecomposeUnsupportedConversionsPass());
   // This pass reduces Hopper compile time extensively: b/344841434.
-  if (ccCuda.IsAtLeastHopper()) {
-    pm.addPass(mt_xla::CreatePreventMmaV3LoopUnrollingPass());
+  if (cc.IsAtLeastHopper()) {
+    pm->addPass(mt_xla::CreatePreventMmaV3LoopUnrollingPass());
   }
-  pm.addPass(mlir::createConvertSCFToCFPass());
-  pm.addPass(mlir::createConvertIndexToLLVMPass());
-  pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
-  pm.addPass(mt::gpu::createTritonGPUGlobalScratchAllocationPass());
-  pm.addPass(mt_xla::CreateSparseLocalLoadToLLVMPass());
-  pm.addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt));
+  pm->addPass(mlir::createConvertSCFToCFPass());
+  pm->addPass(mlir::createConvertIndexToLLVMPass());
+  pm->addPass(mt::gpu::createAllocateSharedMemoryPass());
+  pm->addPass(mt::gpu::createTritonGPUGlobalScratchAllocationPass());
+  pm->addPass(mt_xla::CreateSparseLocalLoadToLLVMPass());
+  pm->addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt));
   // The triton_xla.sparse_dot ops need to be rewritten after
   // ModuleAxisInfoAnalysis inside convert-triton-gpu-to-llvm.
-  pm.addPass(mt_xla::CreateSparseDotOpToLLVMPass());
-  pm.addPass(mt::createConvertNVGPUToLLVMPass());
-  pm.addPass(mt_xla::CreateSparseWGMMAOpToLLVMPass());
-  pm.addPass(mlir::createArithToLLVMConversionPass());
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createSymbolDCEPass());
+  pm->addPass(mt_xla::CreateSparseDotOpToLLVMPass());
+  pm->addPass(mt::createConvertNVGPUToLLVMPass());
+  pm->addPass(mt_xla::CreateSparseWGMMAOpToLLVMPass());
+  pm->addPass(mlir::createArithToLLVMConversionPass());
+  pm->addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mlir::createCSEPass());
+  pm->addPass(mlir::createSymbolDCEPass());
   // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.
 
   return absl::OkStatus();
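The CUDA path above rebuilds the capability with `se::CudaComputeCapability(std::move(arch_name))`, so the handoff relies on `ToString()` and the string constructor round-tripping. A small sanity sketch of that invariant (values illustrative; assumes `ToString()` produces the "major.minor" form documented in device_description.h below):

```cpp
#include <cassert>

#include "xla/stream_executor/device_description.h"

void CudaArchNameRoundTrip() {
  stream_executor::CudaComputeCapability hopper{9, 0};
  // ToString() serializes as "major.minor"; the string constructor
  // parses that form back into {major, minor}.
  stream_executor::CudaComputeCapability parsed(hopper.ToString());
  assert(parsed.major == 9 && parsed.minor == 0);
}
```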
diff --git a/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
index 187d96657e34a..3d41babfd8ff6 100644
--- a/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
+++ b/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
@@ -14,6 +14,9 @@ limitations under the License.
 ==============================================================================*/
 // TODO(ROCm): Enable and include ROCm Triton passes when ROCm Triton is
 // included in build.
+#include <string>
+#include <utility>
+
 #include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
 #include "third_party/amd/include/TritonAMDGPUTransforms/Passes.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
@@ -24,8 +27,8 @@ limitations under the License.
 #include "mlir/Transforms/Passes.h"
 #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 #include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/model/tiled_hlo_computation.h"
 #include "xla/service/hlo_module_config.h"
+#include "xla/stream_executor/device_description.h"
 #include "tsl/platform/rocm_rocdl_path.h"
 #include "triton/Conversion/TritonGPUToLLVM/Passes.h"
 #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
@@ -53,80 +56,76 @@ using ::mlir::Value;
 using mlir::ValueRange;
 
 absl::Status CreateTritonPipeline(
-    mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
-    const BlockLevelParameters& block_level_parameters,
-    mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
+    mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
+    int num_stages, mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
   // TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64.
   const int threadsPerWarp = 32;
-  auto ccRocm = std::get<se::RocmComputeCapability>(cc);
+  auto cc = se::RocmComputeCapability(std::move(arch_name));
 
   // Based on make_ttir() in
   // @triton//:third_party/amd/backend/compiler.py
-  pm.addPass(mlir::createInlinerPass());
-  pm.addPass(mt::createRewriteTensorPointerPass());
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mt::createCombineOpsPass());
-  pm.addPass(mt::createReorderBroadcastPass());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createLoopInvariantCodeMotionPass());
-  pm.addPass(mlir::createSymbolDCEPass());
-  pm.addPass(mt::createLoopUnrollPass());
+  pm->addPass(mlir::createInlinerPass());
+  pm->addPass(mt::createRewriteTensorPointerPass());
+  pm->addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mt::createCombineOpsPass());
+  pm->addPass(mt::createReorderBroadcastPass());
+  pm->addPass(mlir::createCSEPass());
+  pm->addPass(mlir::createLoopInvariantCodeMotionPass());
+  pm->addPass(mlir::createSymbolDCEPass());
+  pm->addPass(mt::createLoopUnrollPass());
 
   // Based on make_ttgir() in
   // @triton//:third_party/amd/backend/compiler.py
-  pm.addPass(mt::createConvertTritonToTritonGPUPass(
-      absl::StrCat("hip:", ccRocm.gfx_version()),
-      block_level_parameters.num_warps, threadsPerWarp,
-      block_level_parameters.num_ctas));
-  pm.addPass(mt::gpu::createTritonGPUCoalesce());
-  pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
-  pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality());
-  pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul());
-  pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
+  pm->addPass(mt::createConvertTritonToTritonGPUPass(
+      absl::StrCat("hip:", cc.gfx_version()), num_warps, threadsPerWarp,
+      num_ctas));
+  pm->addPass(mt::gpu::createTritonGPUCoalesce());
+  pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
+  pm->addPass(mt::gpu::createTritonGPUOptimizeThreadLocality());
+  pm->addPass(mt::gpu::createTritonGPUAccelerateMatmul());
+  pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
   // TODO ROCm Check if we want to compare MI100 and greater
-  pm.addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass());
-  pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
-  if (block_level_parameters.num_stages == kAmdDoubleBuffering &&
-      ccRocm.has_amd_matrix_core()) {
-    pm.addPass(mlir::createTritonAMDGPUStreamPipelinePass(
-        block_level_parameters.num_stages, /*stream_prefetch=*/true));
-    pm.addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass());
+  pm->addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
+  if (num_stages == kAmdDoubleBuffering && cc.has_amd_matrix_core()) {
+    pm->addPass(mlir::createTritonAMDGPUStreamPipelinePass(
+        num_stages, /*stream_prefetch=*/true));
+    pm->addPass(mlir::createCanonicalizerPass());
   }
-  pm.addPass(mt::createTritonAMDGPUInsertInstructionSchedHintsPass());
-  pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
-  pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
-  pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication());
-  if (block_level_parameters.num_stages != kAmdDoubleBuffering) {
-    pm.addPass(mt::gpu::createTritonGPUReorderInstructions());
+  pm->addPass(mt::createTritonAMDGPUInsertInstructionSchedHintsPass());
+  pm->addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
+  pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
+  pm->addPass(mt::gpu::createTritonGPUReduceDataDuplication());
+  if (num_stages != kAmdDoubleBuffering) {
+    pm->addPass(mt::gpu::createTritonGPUReorderInstructions());
   }
-  pm.addPass(mlir::createTritonAMDGPUCanonicalizePointersPass());
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createSymbolDCEPass());
+  pm->addPass(mlir::createTritonAMDGPUCanonicalizePointersPass());
+  pm->addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mlir::createCSEPass());
+  pm->addPass(mlir::createSymbolDCEPass());
 
   // Based on make_llir() in
   // @triton//:third_party/amd/backend/compiler.py
-  pm.addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass(
-      ccRocm.gfx_version()));
+  pm->addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass(
+      cc.gfx_version()));
   const int custom_lds_size = 0;
-  pm.addPass(mlir::triton::AMD::createOptimizeLDSUsagePass(ccRocm.gfx_version(),
-                                                           custom_lds_size));
-  pm.addPass(mlir::createConvertSCFToCFPass());
-  pm.addPass(mlir::createConvertIndexToLLVMPass());
-  pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
-  pm.addPass(
-      mt::createConvertTritonAMDGPUToLLVMPass(ccRocm.gfx_version(), true));
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::createCSEPass());
+  pm->addPass(mlir::triton::AMD::createOptimizeLDSUsagePass(cc.gfx_version(),
+                                                            custom_lds_size));
+  pm->addPass(mlir::createConvertSCFToCFPass());
+  pm->addPass(mlir::createConvertIndexToLLVMPass());
+  pm->addPass(mt::gpu::createAllocateSharedMemoryPass());
+  pm->addPass(mt::createConvertTritonAMDGPUToLLVMPass(cc.gfx_version(), true));
+  pm->addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mlir::createCSEPass());
   // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.
-  pm.addPass(mlir::createConvertControlFlowToLLVMPass());
-  pm.addPass(mlir::createArithToLLVMConversionPass());
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createSymbolDCEPass());
-  pm.addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass(
-      ccRocm.gfx_version(), block_level_parameters.num_stages, "default"));
-  pm.addPass(mt::createConvertBuiltinFuncToLLVMPass(/*ftz=*/true));
+  pm->addPass(mlir::createConvertControlFlowToLLVMPass());
+  pm->addPass(mlir::createArithToLLVMConversionPass());
+  pm->addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mlir::createCSEPass());
+  pm->addPass(mlir::createSymbolDCEPass());
+  pm->addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass(
+      cc.gfx_version(), num_stages, "default"));
+  pm->addPass(mt::createConvertBuiltinFuncToLLVMPass(/*ftz=*/true));
   // There is no clusters in ROCm for now.
   out_cluster_info.clusterDimX = 1;
   out_cluster_info.clusterDimY = 1;
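On the ROCm path the serialized capability is the full gcn arch name, and the `RocmComputeCapability::ToString()` added in device_description.h below returns `gcn_arch_name()` verbatim, so the round trip is the identity. A sketch, with a representative arch string:

```cpp
#include <cassert>
#include <string>

#include "xla/stream_executor/device_description.h"

void RocmArchNameRoundTrip() {
  // Representative gcn arch name, including target feature flags.
  stream_executor::RocmComputeCapability cc("gfx90a:sramecc+:xnack-");
  // Nothing is lost when CreateTritonPipeline() rebuilds the capability
  // from the string.
  assert(cc.ToString() == std::string("gfx90a:sramecc+:xnack-"));
  assert(cc.gfx_version() == "gfx90a");  // Portion before the first ':'.
}
```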
diff --git a/xla/service/gpu/fusions/triton/compilation_pipeline_stub.cc b/xla/service/gpu/fusions/triton/compilation_pipeline_stub.cc
index 220d5a3147d14..338a1fe5cd604 100644
--- a/xla/service/gpu/fusions/triton/compilation_pipeline_stub.cc
+++ b/xla/service/gpu/fusions/triton/compilation_pipeline_stub.cc
@@ -13,19 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <string>
+
 #include "absl/status/status.h"
 #include "mlir/Pass/PassManager.h"
 #include "xla/service/gpu/fusions/triton/compilation_pipeline.h"
-#include "xla/service/gpu/model/tiled_hlo_computation.h"
-#include "xla/stream_executor/device_description.h"
 
 namespace xla {
 namespace gpu {
 
 absl::Status CreateTritonPipeline(
-    mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
-    const BlockLevelParameters& block_level_parameters,
-    mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info) {
+    mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
+    int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info) {
   return absl::UnimplementedError("not supported for this build configuration");
 }
diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
index 31a8307e45360..97da071c5d362 100644
--- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
+++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
@@ -1222,6 +1222,8 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
     mlir::ModuleOp triton_module, llvm::Module* llvm_module,
     mlir::MLIRContext& mlir_context, bool emit_kernel) {
   const auto& cc = device_info.gpu_compute_capability();
+  const std::string arch_name =
+      std::visit([](auto& cc) { return cc.ToString(); }, cc);
   if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
     auto ccCuda = std::get<se::CudaComputeCapability>(cc);
     if (!ccCuda.IsAtLeastAmpere()) {
@@ -1281,7 +1283,9 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
   pm.addPass(CreateSimplifyAffinePass());
 
   mlir::triton::nvidia_gpu::ClusterInfo cluster_info;
-  if (!CreateTritonPipeline(pm, cc, block_level_parameters, cluster_info)
+  if (!CreateTritonPipeline(&pm, arch_name, block_level_parameters.num_warps,
+                            block_level_parameters.num_ctas,
+                            block_level_parameters.num_stages, cluster_info)
           .ok()) {
     return Internal("Failed to create Triton pipeline.");
   }
diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub_test.cc
index f063bc6460fc9..4e23149ba2431 100644
--- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub_test.cc
+++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub_test.cc
@@ -49,7 +49,7 @@ TEST(TritonStub, CallStubApi) {
   mlir::OpPassManager pm;
   ::mlir::triton::nvidia_gpu::ClusterInfo cluster_info;
 
-  EXPECT_FALSE(CreateTritonPipeline(pm, {}, {}, cluster_info).ok());
+  EXPECT_FALSE(CreateTritonPipeline(&pm, "", 1, 1, 1, cluster_info).ok());
   EXPECT_EQ(GetLibdevicePath({}, {}), "");
 
   EmitterLocOpBuilder builder(&context);
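The `std::visit` added to CompileTritonToLLVM works because both alternatives of the `se::GpuComputeCapability` variant now expose a `ToString()` returning `std::string`, giving the generic lambda a single deduced return type. A self-contained illustration of the pattern, with hypothetical stand-in types:

```cpp
#include <iostream>
#include <string>
#include <variant>

// Hypothetical stand-ins for se::CudaComputeCapability and
// se::RocmComputeCapability; only the shared ToString() matters here.
struct FakeCudaCC {
  std::string ToString() const { return "9.0"; }
};
struct FakeRocmCC {
  std::string ToString() const { return "gfx942"; }
};

int main() {
  std::variant<FakeCudaCC, FakeRocmCC> cc = FakeCudaCC{};
  // The lambda is instantiated for each alternative; both instantiations
  // return std::string, so the visit is well-formed.
  const std::string arch_name =
      std::visit([](const auto& c) { return c.ToString(); }, cc);
  std::cout << arch_name << "\n";  // Prints "9.0".
  return 0;
}
```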
diff --git a/xla/stream_executor/device_description.h b/xla/stream_executor/device_description.h
index f6bb2e4a41ad4..396ce94a876db 100644
--- a/xla/stream_executor/device_description.h
+++ b/xla/stream_executor/device_description.h
@@ -23,6 +23,7 @@ limitations under the License.
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -59,7 +60,7 @@ struct CudaComputeCapability {
     this->minor = minor;
   }
   // cuda arch format "major.minor", example: "8.6".
-  explicit CudaComputeCapability(const std::string &cuda_arch_name) {
+  explicit CudaComputeCapability(std::string cuda_arch_name) {
     std::vector<std::string> split = absl::StrSplit(cuda_arch_name, '.');
     assert(split.size() == 2);
     this->major = std::stoi(split[0]);
@@ -236,6 +237,8 @@ class RocmComputeCapability {
 
   bool has_fp8_support() const { return gfx9_mi300(); }
 
+  std::string ToString() const { return gcn_arch_name(); }
+
   RocmComputeCapabilityProto ToProto() const {
     RocmComputeCapabilityProto proto;
     proto.set_gcn_arch_name(gcn_arch_name_);
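Taking `cuda_arch_name` by value instead of `const std::string&` is what enables the `std::move(arch_name)` calls in the pipeline files above: a caller that is done with its string hands over the buffer instead of copying. Both call styles, sketched (the function is illustrative):

```cpp
#include <string>
#include <utility>

#include "xla/stream_executor/device_description.h"

void ConstructFromArchName() {
  std::string arch = "8.6";
  // Copying call: arch still holds "8.6" afterwards.
  stream_executor::CudaComputeCapability copied(arch);
  // Moving call: avoids the copy; arch is left valid but unspecified.
  stream_executor::CudaComputeCapability moved(std::move(arch));
}
```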