[xla:gpu] CreateTritonPipeline no longer depends on internal XLA GPU abstractions

Neither `BlockLevelParameters` nor `se::GpuComputeCapability` was strictly
necessary, so both are replaced with simpler types (plain `int` tiling
parameters and a `std::string` architecture name) that do not require JAX to
depend on XLA:GPU internals.

Note also that `mlir::OpPassManager` is now passed by pointer, which makes it
easier to call into `CreateTritonPipeline` through MLIR C API abstractions,
since those generally store pointers to their C++ counterparts.

See jax-ml/jax#25196.

PiperOrigin-RevId: 705210005
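
Below is a minimal sketch (not part of the commit) of the kind of caller this
change enables: a C-API-facing wrapper that forwards the unwrapped pass-manager
pointer directly. The wrapper name and parameter values are hypothetical;
`MlirOpPassManager` and `unwrap` come from the MLIR C API support headers.

// Hypothetical wrapper, for illustration only: MlirOpPassManager (the MLIR C
// API handle) stores a pointer to the underlying mlir::OpPassManager, and
// unwrap() from mlir/CAPI/Pass.h recovers that pointer, so it can be forwarded
// to the new pointer-taking signature as-is.
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "mlir-c/Pass.h"
#include "mlir/CAPI/Pass.h"
#include "xla/service/gpu/fusions/triton/compilation_pipeline.h"

absl::Status BuildTritonPipelineFromCApi(
    MlirOpPassManager c_pm, std::string arch_name,
    mlir::triton::nvidia_gpu::ClusterInfo& cluster_info) {
  // The tiling parameters below are placeholders, not values from the commit.
  return xla::gpu::CreateTritonPipeline(unwrap(c_pm), std::move(arch_name),
                                        /*num_warps=*/4, /*num_ctas=*/1,
                                        /*num_stages=*/3, cluster_info);
}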
superbobry authored and Google-ML-Automation committed Dec 11, 2024
1 parent 6a22ff1 commit 209cbfa
Showing 8 changed files with 142 additions and 143 deletions.
3 changes: 1 addition & 2 deletions xla/service/gpu/fusions/triton/BUILD
@@ -71,8 +71,6 @@ cc_library(
     ]),
     hdrs = ["compilation_pipeline.h"],
     deps = [
-        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
-        "//xla/stream_executor:device_description",
         "@com_google_absl//absl/status",
         "@llvm-project//mlir:Pass",
     ] + if_gpu_is_configured([
@@ -85,6 +83,7 @@ cc_library(
         "@llvm-project//mlir:Transforms",
         "//xla/service:hlo_module_config",
         "//xla/service/gpu:matmul_utils",
+        "//xla/stream_executor:device_description",
         "@triton//:TritonDialects",
         "@triton//:TritonGPUToLLVM",
         "@triton//:TritonGPUTransforms",
9 changes: 4 additions & 5 deletions xla/service/gpu/fusions/triton/compilation_pipeline.h
@@ -16,10 +16,10 @@ limitations under the License.
 #ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_COMPILATION_PIPELINE_H_
 #define XLA_SERVICE_GPU_FUSIONS_TRITON_COMPILATION_PIPELINE_H_
 
+#include <string>
+
 #include "absl/status/status.h"
 #include "mlir/Pass/PassManager.h"
-#include "xla/service/gpu/model/tiled_hlo_computation.h"
-#include "xla/stream_executor/device_description.h"
 
 namespace mlir::triton::nvidia_gpu {
@@ -41,9 +41,8 @@ namespace gpu {
 // parameter which would give a hint to Triton which cluster dims we prefer to
 // use, but that's not the case currently.
 absl::Status CreateTritonPipeline(
-    mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
-    const BlockLevelParameters& block_level_parameters,
-    mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info);
+    mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
+    int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info);
 
 }  // namespace gpu
 }  // namespace xla
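
For reference, a minimal direct-call sketch of the new interface (hypothetical,
not from this commit): the "8.0" arch_name spelling accepted by
se::CudaComputeCapability's string constructor and the tiling values are
assumptions.

// Hypothetical usage of the new signature. "8.0" as the arch_name spelling
// (Ampere) and the tiling parameters are assumptions, not commit values.
mlir::MLIRContext context;
mlir::PassManager pm(&context);  // mlir::PassManager derives from OpPassManager.
mlir::triton::nvidia_gpu::ClusterInfo cluster_info;  // Defined by Triton.
absl::Status status = xla::gpu::CreateTritonPipeline(
    &pm, /*arch_name=*/"8.0", /*num_warps=*/4, /*num_ctas=*/1,
    /*num_stages=*/3, cluster_info);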
134 changes: 65 additions & 69 deletions xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include <string>
+#include <utility>
 
 #include "nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h"
 #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h"
@@ -26,7 +27,6 @@ limitations under the License.
 #include "mlir/Transforms/Passes.h"
 #include "xla/service/gpu/fusions/triton/xla_triton_passes.h"
 #include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h"
-#include "xla/service/gpu/model/tiled_hlo_computation.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/stream_executor/device_description.h"
 #include "triton/Conversion/TritonGPUToLLVM/Passes.h"
@@ -42,94 +42,90 @@ namespace mt = ::mlir::triton;
 namespace mt_xla = ::mlir::triton::xla;
 
 absl::Status CreateTritonPipeline(
-    mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
-    const BlockLevelParameters& block_level_parameters,
-    mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
-  auto ccCuda = std::get<se::CudaComputeCapability>(cc);
-  const int ccAsInt = ccCuda.major * 10 + ccCuda.minor;
+    mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
+    int num_stages, mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
+  auto cc = se::CudaComputeCapability(std::move(arch_name));
+  const int ccAsInt = cc.major * 10 + cc.minor;
   const int threadsPerWarp = 32;
 
   // Based on make_ttir() in
   // @triton//:third_party/nvidia/backend/compiler.py
-  pm.addPass(mlir::createInlinerPass());
-  pm.addPass(mt::createRewriteTensorPointerPass());
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mt::createCombineOpsPass());
-  pm.addPass(mt::createReorderBroadcastPass());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createLoopInvariantCodeMotionPass());
-  pm.addPass(mlir::createSymbolDCEPass());
-  pm.addPass(mt::createLoopUnrollPass());
+  pm->addPass(mlir::createInlinerPass());
+  pm->addPass(mt::createRewriteTensorPointerPass());
+  pm->addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mt::createCombineOpsPass());
+  pm->addPass(mt::createReorderBroadcastPass());
+  pm->addPass(mlir::createCSEPass());
+  pm->addPass(mlir::createLoopInvariantCodeMotionPass());
+  pm->addPass(mlir::createSymbolDCEPass());
+  pm->addPass(mt::createLoopUnrollPass());
 
   // Based on make_ttgir() in
   // @triton//:third_party/nvidia/backend/compiler.py
-  pm.addPass(mt::createConvertTritonToTritonGPUPass(
-      absl::StrFormat("cuda:%u", ccAsInt), block_level_parameters.num_warps,
-      threadsPerWarp, block_level_parameters.num_ctas));
-  pm.addPass(mt_xla::CreateSparseAddEncodingPass(
-      block_level_parameters.num_warps, threadsPerWarp,
-      block_level_parameters.num_ctas));
-  pm.addPass(mt::gpu::createTritonGPUCoalesce());
-  if (ccCuda.IsAtLeastAmpere()) {
-    pm.addPass(mt::gpu::createTritonGPUF32DotTC());
+  pm->addPass(mt::createConvertTritonToTritonGPUPass(
+      absl::StrFormat("cuda:%u", ccAsInt), num_warps, threadsPerWarp,
+      num_ctas));
+  pm->addPass(
+      mt_xla::CreateSparseAddEncodingPass(num_warps, threadsPerWarp, num_ctas));
+  pm->addPass(mt::gpu::createTritonGPUCoalesce());
+  if (cc.IsAtLeastAmpere()) {
+    pm->addPass(mt::gpu::createTritonGPUF32DotTC());
   }
-  pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info));
-  pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
-  pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality());
-  pm.addPass(mt_xla::CreateSparseBlockedToMMAPass());
-  pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul());
-  pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
-  pm.addPass(
-      mt::gpu::createTritonGPUOptimizeDotOperands({ccCuda.IsAtLeastAmpere()}));
-  pm.addPass(mlir::createCSEPass());
+  pm->addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info));
+  pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
+  pm->addPass(mt::gpu::createTritonGPUOptimizeThreadLocality());
+  pm->addPass(mt_xla::CreateSparseBlockedToMMAPass());
+  pm->addPass(mt::gpu::createTritonGPUAccelerateMatmul());
+  pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
+  pm->addPass(
+      mt::gpu::createTritonGPUOptimizeDotOperands({cc.IsAtLeastAmpere()}));
+  pm->addPass(mlir::createCSEPass());
 
   // Even though we don't run on pre-Ampere architectures anymore, we keep this
   // check for consistency with the upstream pipeline
-  if (ccCuda.IsAtLeastAmpere()) {
-    pm.addPass(mt::gpu::createTritonGPUCombineTensorSelectAndIf());
-    pm.addPass(mt::gpu::createTritonGPULoopScheduling(
-        {block_level_parameters.num_stages}));
-    pm.addPass(
-        mt::gpu::createTritonGPUPipeline({block_level_parameters.num_stages}));
+  if (cc.IsAtLeastAmpere()) {
+    pm->addPass(mt::gpu::createTritonGPUCombineTensorSelectAndIf());
+    pm->addPass(mt::gpu::createTritonGPULoopScheduling({num_stages}));
+    pm->addPass(mt::gpu::createTritonGPUPipeline({num_stages}));
   }
-  pm.addPass(mt::gpu::createTritonGPUPrefetch());
-  pm.addPass(
-      mt::gpu::createTritonGPUOptimizeDotOperands({ccCuda.IsAtLeastAmpere()}));
-  pm.addPass(mt::gpu::createTritonGPUCoalesceAsyncCopy());
-  pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
-  pm.addPass(mt_xla::CreateSparseRemoveLayoutConversionPass());
-  pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication());
-  pm.addPass(mt::gpu::createTritonGPUReorderInstructions());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createSymbolDCEPass());
-  if (ccCuda.IsAtLeastHopper()) {
-    pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt));
-    pm.addPass(mlir::createTritonNvidiaGPUTMALoweringPass());
+  pm->addPass(mt::gpu::createTritonGPUPrefetch());
+  pm->addPass(
+      mt::gpu::createTritonGPUOptimizeDotOperands({cc.IsAtLeastAmpere()}));
+  pm->addPass(mt::gpu::createTritonGPUCoalesceAsyncCopy());
+  pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
+  pm->addPass(mt_xla::CreateSparseRemoveLayoutConversionPass());
+  pm->addPass(mt::gpu::createTritonGPUReduceDataDuplication());
+  pm->addPass(mt::gpu::createTritonGPUReorderInstructions());
+  pm->addPass(mlir::createCSEPass());
+  pm->addPass(mlir::createSymbolDCEPass());
+  if (cc.IsAtLeastHopper()) {
+    pm->addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt));
+    pm->addPass(mlir::createTritonNvidiaGPUTMALoweringPass());
   }
-  pm.addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mlir::createCanonicalizerPass());
 
   // Based on make_llir() in
   // @triton//:third_party/nvidia/backend/compiler.py
-  pm.addPass(mt::NVIDIA::createDecomposeUnsupportedConversionsPass());
+  pm->addPass(mt::NVIDIA::createDecomposeUnsupportedConversionsPass());
   // This pass reduces Hopper compile time extensively: b/344841434.
-  if (ccCuda.IsAtLeastHopper()) {
-    pm.addPass(mt_xla::CreatePreventMmaV3LoopUnrollingPass());
+  if (cc.IsAtLeastHopper()) {
+    pm->addPass(mt_xla::CreatePreventMmaV3LoopUnrollingPass());
   }
-  pm.addPass(mlir::createConvertSCFToCFPass());
-  pm.addPass(mlir::createConvertIndexToLLVMPass());
-  pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
-  pm.addPass(mt::gpu::createTritonGPUGlobalScratchAllocationPass());
-  pm.addPass(mt_xla::CreateSparseLocalLoadToLLVMPass());
-  pm.addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt));
+  pm->addPass(mlir::createConvertSCFToCFPass());
+  pm->addPass(mlir::createConvertIndexToLLVMPass());
+  pm->addPass(mt::gpu::createAllocateSharedMemoryPass());
+  pm->addPass(mt::gpu::createTritonGPUGlobalScratchAllocationPass());
+  pm->addPass(mt_xla::CreateSparseLocalLoadToLLVMPass());
+  pm->addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt));
   // The triton_xla.sparse_dot ops need to be rewritten after
   // ModuleAxisInfoAnalysis inside convert-triton-gpu-to-llvm.
-  pm.addPass(mt_xla::CreateSparseDotOpToLLVMPass());
-  pm.addPass(mt::createConvertNVGPUToLLVMPass());
-  pm.addPass(mt_xla::CreateSparseWGMMAOpToLLVMPass());
-  pm.addPass(mlir::createArithToLLVMConversionPass());
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createSymbolDCEPass());
+  pm->addPass(mt_xla::CreateSparseDotOpToLLVMPass());
+  pm->addPass(mt::createConvertNVGPUToLLVMPass());
+  pm->addPass(mt_xla::CreateSparseWGMMAOpToLLVMPass());
+  pm->addPass(mlir::createArithToLLVMConversionPass());
+  pm->addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mlir::createCSEPass());
+  pm->addPass(mlir::createSymbolDCEPass());
   // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.
 
   return absl::OkStatus();
117 changes: 58 additions & 59 deletions xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
@@ -14,6 +14,9 @@ limitations under the License.
 ==============================================================================*/
 // TODO(ROCm): Enable and include ROCm Triton passes when ROCm Triton is
 // included in build.
+#include <string>
+#include <utility>
+
 #include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
 #include "third_party/amd/include/TritonAMDGPUTransforms/Passes.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
@@ -24,8 +27,8 @@ limitations under the License.
 #include "mlir/Transforms/Passes.h"
 #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 #include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/model/tiled_hlo_computation.h"
 #include "xla/service/hlo_module_config.h"
+#include "xla/stream_executor/device_description.h"
 #include "tsl/platform/rocm_rocdl_path.h"
 #include "triton/Conversion/TritonGPUToLLVM/Passes.h"
 #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
@@ -53,80 +56,76 @@ using ::mlir::Value;
 using mlir::ValueRange;
 
 absl::Status CreateTritonPipeline(
-    mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
-    const BlockLevelParameters& block_level_parameters,
-    mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
+    mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
+    int num_stages, mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
   // TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64.
   const int threadsPerWarp = 32;
-  auto ccRocm = std::get<se::RocmComputeCapability>(cc);
+  auto cc = se::RocmComputeCapability(std::move(arch_name));
 
   // Based on make_ttir() in
   // @triton//:third_party/amd/backend/compiler.py
-  pm.addPass(mlir::createInlinerPass());
-  pm.addPass(mt::createRewriteTensorPointerPass());
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mt::createCombineOpsPass());
-  pm.addPass(mt::createReorderBroadcastPass());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createLoopInvariantCodeMotionPass());
-  pm.addPass(mlir::createSymbolDCEPass());
-  pm.addPass(mt::createLoopUnrollPass());
+  pm->addPass(mlir::createInlinerPass());
+  pm->addPass(mt::createRewriteTensorPointerPass());
+  pm->addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mt::createCombineOpsPass());
+  pm->addPass(mt::createReorderBroadcastPass());
+  pm->addPass(mlir::createCSEPass());
+  pm->addPass(mlir::createLoopInvariantCodeMotionPass());
+  pm->addPass(mlir::createSymbolDCEPass());
+  pm->addPass(mt::createLoopUnrollPass());
 
   // Based on make_ttgir() in
   // @triton//:third_party/amd/backend/compiler.py
-  pm.addPass(mt::createConvertTritonToTritonGPUPass(
-      absl::StrCat("hip:", ccRocm.gfx_version()),
-      block_level_parameters.num_warps, threadsPerWarp,
-      block_level_parameters.num_ctas));
-  pm.addPass(mt::gpu::createTritonGPUCoalesce());
-  pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
-  pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality());
-  pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul());
-  pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
+  pm->addPass(mt::createConvertTritonToTritonGPUPass(
+      absl::StrCat("hip:", cc.gfx_version()), num_warps, threadsPerWarp,
+      num_ctas));
+  pm->addPass(mt::gpu::createTritonGPUCoalesce());
+  pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
+  pm->addPass(mt::gpu::createTritonGPUOptimizeThreadLocality());
+  pm->addPass(mt::gpu::createTritonGPUAccelerateMatmul());
+  pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
   // TODO ROCm Check if we want to compare MI100 and greater
-  pm.addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass());
-  pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
-  if (block_level_parameters.num_stages == kAmdDoubleBuffering &&
-      ccRocm.has_amd_matrix_core()) {
-    pm.addPass(mlir::createTritonAMDGPUStreamPipelinePass(
-        block_level_parameters.num_stages, /*stream_prefetch=*/true));
-    pm.addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass());
+  pm->addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
+  if (num_stages == kAmdDoubleBuffering && cc.has_amd_matrix_core()) {
+    pm->addPass(mlir::createTritonAMDGPUStreamPipelinePass(
+        num_stages, /*stream_prefetch=*/true));
+    pm->addPass(mlir::createCanonicalizerPass());
   }
-  pm.addPass(mt::createTritonAMDGPUInsertInstructionSchedHintsPass());
-  pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
-  pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
-  pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication());
-  if (block_level_parameters.num_stages != kAmdDoubleBuffering) {
-    pm.addPass(mt::gpu::createTritonGPUReorderInstructions());
+  pm->addPass(mt::createTritonAMDGPUInsertInstructionSchedHintsPass());
+  pm->addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
+  pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
+  pm->addPass(mt::gpu::createTritonGPUReduceDataDuplication());
+  if (num_stages != kAmdDoubleBuffering) {
+    pm->addPass(mt::gpu::createTritonGPUReorderInstructions());
   }
-  pm.addPass(mlir::createTritonAMDGPUCanonicalizePointersPass());
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createSymbolDCEPass());
+  pm->addPass(mlir::createTritonAMDGPUCanonicalizePointersPass());
+  pm->addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mlir::createCSEPass());
+  pm->addPass(mlir::createSymbolDCEPass());
 
   // Based on make_llir() in
   // @triton//:third_party/amd/backend/compiler.py
-  pm.addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass(
-      ccRocm.gfx_version()));
+  pm->addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass(
+      cc.gfx_version()));
   const int custom_lds_size = 0;
-  pm.addPass(mlir::triton::AMD::createOptimizeLDSUsagePass(ccRocm.gfx_version(),
-                                                           custom_lds_size));
-  pm.addPass(mlir::createConvertSCFToCFPass());
-  pm.addPass(mlir::createConvertIndexToLLVMPass());
-  pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
-  pm.addPass(
-      mt::createConvertTritonAMDGPUToLLVMPass(ccRocm.gfx_version(), true));
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::createCSEPass());
+  pm->addPass(mlir::triton::AMD::createOptimizeLDSUsagePass(cc.gfx_version(),
+                                                            custom_lds_size));
+  pm->addPass(mlir::createConvertSCFToCFPass());
+  pm->addPass(mlir::createConvertIndexToLLVMPass());
+  pm->addPass(mt::gpu::createAllocateSharedMemoryPass());
+  pm->addPass(mt::createConvertTritonAMDGPUToLLVMPass(cc.gfx_version(), true));
+  pm->addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mlir::createCSEPass());
   // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.
-  pm.addPass(mlir::createConvertControlFlowToLLVMPass());
-  pm.addPass(mlir::createArithToLLVMConversionPass());
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createSymbolDCEPass());
-  pm.addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass(
-      ccRocm.gfx_version(), block_level_parameters.num_stages, "default"));
-  pm.addPass(mt::createConvertBuiltinFuncToLLVMPass(/*ftz=*/true));
+  pm->addPass(mlir::createConvertControlFlowToLLVMPass());
+  pm->addPass(mlir::createArithToLLVMConversionPass());
+  pm->addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mlir::createCSEPass());
+  pm->addPass(mlir::createSymbolDCEPass());
+  pm->addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass(
+      cc.gfx_version(), num_stages, "default"));
+  pm->addPass(mt::createConvertBuiltinFuncToLLVMPass(/*ftz=*/true));
   // There is no clusters in ROCm for now.
   out_cluster_info.clusterDimX = 1;
   out_cluster_info.clusterDimY = 1;
9 changes: 4 additions & 5 deletions xla/service/gpu/fusions/triton/compilation_pipeline_stub.cc
@@ -13,19 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <string>
+
 #include "absl/status/status.h"
 #include "mlir/Pass/PassManager.h"
 #include "xla/service/gpu/fusions/triton/compilation_pipeline.h"
-#include "xla/service/gpu/model/tiled_hlo_computation.h"
-#include "xla/stream_executor/device_description.h"
 
 namespace xla {
 namespace gpu {
 
 absl::Status CreateTritonPipeline(
-    mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
-    const BlockLevelParameters& block_level_parameters,
-    mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info) {
+    mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
+    int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info) {
   return absl::UnimplementedError("not supported for this build configuration");
 }