[AMD] Adjusted ordering of local stores and global loads for GEMMs
ravil-mobile committed Nov 25, 2024
1 parent 3fc21bb commit de1d3e2
Showing 5 changed files with 49 additions and 15 deletions.
2 changes: 1 addition & 1 deletion third_party/amd/backend/compiler.py
@@ -248,7 +248,7 @@ def make_ttgir(mod, metadata, options):
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
if amd.has_matrix_core_feature(options.arch):
-   amd.passes.ttgpuir.add_reorder_instructions(pm)
+   amd.passes.ttgpuir.add_reorder_instructions(pm, options.num_stages, stream_prefetch)

if use_buffer_ops:
    amd.passes.ttgpuir.add_canonicalize_pointers(pm)
4 changes: 3 additions & 1 deletion third_party/amd/include/TritonAMDGPUTransforms/Passes.h
@@ -17,7 +17,9 @@ createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(),

std::unique_ptr<Pass> createTritonAMDGPUCanonicalizeLoopsPass();

-std::unique_ptr<Pass> createTritonAMDGPUReorderInstructionsPass();
+std::unique_ptr<Pass>
+createTritonAMDGPUReorderInstructionsPass(int32_t numStages,
+                                          bool streamPrefetch);

std::unique_ptr<Pass> createTritonAMDGPUVerifier();

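With this signature change, any C++ caller of the header must now supply both parameters when constructing the pass. Below is a minimal sketch of such a call site, assuming a standard MLIR pass-manager setup; the helper function is illustrative and not part of this commit.

#include <cstdint>

#include "mlir/Pass/PassManager.h"
#include "TritonAMDGPUTransforms/Passes.h"

// Illustrative only: appends the reorder pass, built with the new two-argument
// factory, to an existing pass pipeline.
static void addReorderInstructions(mlir::OpPassManager &pm, int32_t numStages,
                                   bool streamPrefetch) {
  pm.addPass(mlir::createTritonAMDGPUReorderInstructionsPass(numStages,
                                                             streamPrefetch));
}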
9 changes: 8 additions & 1 deletion third_party/amd/include/TritonAMDGPUTransforms/Passes.td
@@ -109,9 +109,16 @@ def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", "
"conversions from shared memory before their first use) and (2) promote LLVM instruction "
"order more friendly to `ptxas`.";

let constructor = "mlir::createTritonAMDGPUReorderInstructionsPass()";
let constructor = "mlir::createTritonAMDGPUReorderInstructionsPass(2, false)";

let dependentDialects = [];

let options = [
Option<"numStages", "num_stages", "int32_t", /*default*/"2",
"number of pipeline stages">,
Option<"streamPrefetch", "local_prefetch", "bool", /*default*/"false",
"indicates whether stream prefetch is enabled">,
];
}

def TritonAMDGPUConvertToBufferOps : Pass<"tritonamdgpu-convert-buffer-ops", "mlir::ModuleOp"> {
44 changes: 34 additions & 10 deletions third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp
@@ -209,14 +209,26 @@ static void moveUpTranspose(triton::FuncOp funcOp) {
}

// Schedule global load and local store ops for better GEMM performance.
-static void scheduleGlobalLoadLocalStore(scf::ForOp forOp) {
+static void
+scheduleGlobalLoadLocalStore(scf::ForOp forOp,
+                             const bool independentGlobalLoadStages) {
  SmallVector<Operation *> moveOps;
-  // Move global loads early to prefetch. This may increase register pressure
-  // but it enables issuing global loads early.
-  forOp.walk([&](triton::LoadOp op) { moveOps.push_back(op); });
-  // Move local_stores early if dependence distance greater than one iteration.
-  // Best perf on GEMM when these precede global loads.
-  forOp.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); });
+
+  if (independentGlobalLoadStages) {
+    // Move local stores early to prefetch. It results in moving the
+    // corresponding memory fence to the very top of the current basic block.
+    // This results in better instruction interleaving
+    // - i.e., `ds_write`, `global/buffer_loads`.
+    forOp.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); });
+    forOp.walk([&](triton::LoadOp op) { moveOps.push_back(op); });
+  } else {
+    // Move global loads early to prefetch. This may increase register pressure
+    // but it enables issuing global loads early.
+    forOp.walk([&](triton::LoadOp op) { moveOps.push_back(op); });
+    // Move local_stores early if dependence distance greater than one
+    // iteration. Best perf on GEMM when these precede global loads.
+    forOp.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); });
+  }

  for (auto op : llvm::reverse(moveOps)) {
    // Gather use-def chain in block.
@@ -360,6 +372,13 @@ namespace {
struct TritonAMDGPUReorderInstructionsPass
    : public TritonAMDGPUReorderInstructionsBase<
          TritonAMDGPUReorderInstructionsPass> {
+
+  explicit TritonAMDGPUReorderInstructionsPass(int32_t numStages,
+                                               bool streamPrefetch) {
+    this->numStages = numStages;
+    this->streamPrefetch = streamPrefetch;
+  }
+
  void runOnOperation() override {
    ModuleOp m = getOperation();
    for (auto funcOp : m.getOps<triton::FuncOp>()) {
@@ -370,10 +389,12 @@ struct TritonAMDGPUReorderInstructionsPass

      moveUpTranspose(funcOp);

+      const bool independentGlobalLoadStages =
+          this->numStages > 2 || this->streamPrefetch;
      SmallVector<scf::ForOp> leafForOps = triton::AMD::getLeafForOps(funcOp);
      for (auto forOp : leafForOps) {
        if (isPureMatmulProblem(forOp)) {
-          scheduleGlobalLoadLocalStore(forOp);
+          scheduleGlobalLoadLocalStore(forOp, independentGlobalLoadStages);
          sinkSecondLoad(forOp);
        }
      }
@@ -382,6 +403,9 @@
};
} // namespace

-std::unique_ptr<Pass> mlir::createTritonAMDGPUReorderInstructionsPass() {
-  return std::make_unique<TritonAMDGPUReorderInstructionsPass>();
+std::unique_ptr<Pass>
+mlir::createTritonAMDGPUReorderInstructionsPass(int32_t numStages,
+                                                bool streamPrefetch) {
+  return std::make_unique<TritonAMDGPUReorderInstructionsPass>(numStages,
+                                                               streamPrefetch);
}
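Taken together, the new parameters change only the order in which `scheduleGlobalLoadLocalStore` collects the memory operations it later hoists. The following is a tiny stand-alone illustration of that decision, not the pass itself; the strings are just labels for the load and local-store ops, and the final placement of each op still depends on its use-def chain.

#include <cstdint>
#include <string>
#include <vector>

// Illustration only: which memory ops the pass queues for hoisting, and in
// what order, as a function of the two new parameters.
std::vector<std::string> collectionOrder(int32_t numStages, bool streamPrefetch) {
  const bool independentGlobalLoadStages = numStages > 2 || streamPrefetch;
  if (independentGlobalLoadStages)
    return {"local_store", "global_load"}; // stores (and their fence) go first
  return {"global_load", "local_store"};   // default: global loads issued first
}

With the defaults declared above (`num_stages = 2`, `local_prefetch = false`), the predicate is false, so the pass keeps its previous global-loads-first behavior.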
5 changes: 3 additions & 2 deletions third_party/amd/python/triton_amd.cc
@@ -70,8 +70,9 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
                     mlir::createTritonAMDGPUCanonicalizePointersPass);
  ADD_PASS_WRAPPER_0("add_convert_to_buffer_ops",
                     mlir::createTritonAMDGPUConvertToBufferOpsPass);
-  ADD_PASS_WRAPPER_0("add_reorder_instructions",
-                     mlir::createTritonAMDGPUReorderInstructionsPass);
+  ADD_PASS_WRAPPER_2("add_reorder_instructions",
+                     mlir::createTritonAMDGPUReorderInstructionsPass, int32_t,
+                     bool);
  ADD_PASS_WRAPPER_2("add_stream_pipelinev2",
                     mlir::createTritonAMDGPUStreamPipelineV2Pass, int, int);
}
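Switching from `ADD_PASS_WRAPPER_0` to `ADD_PASS_WRAPPER_2` is what exposes the two extra arguments to Python, matching the updated call in `compiler.py` above. The wrapper macros are defined elsewhere in the Triton bindings; the sketch below is only a rough approximation of what the two-argument registration amounts to, not the actual macro body.

#include <cstdint>
#include <pybind11/pybind11.h>

#include "mlir/Pass/PassManager.h"
#include "TritonAMDGPUTransforms/Passes.h"

namespace py = pybind11;

// Rough sketch: registers a Python-visible function that builds the pass with
// the two new arguments and appends it to the given pass manager.
void registerReorderInstructions(py::module &m) {
  m.def("add_reorder_instructions",
        [](mlir::PassManager &pm, int32_t numStages, bool streamPrefetch) {
          pm.addPass(mlir::createTritonAMDGPUReorderInstructionsPass(
              numStages, streamPrefetch));
        });
}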
