[AMD] Fixed instruction reorder
ravil-mobile committed Nov 25, 2024
1 parent 3fc21bb commit 9255bc6
Showing 2 changed files with 22 additions and 21 deletions.
36 changes: 18 additions & 18 deletions test/TritonGPU/amd/amd-reorder-instructions.mlir
@@ -267,30 +267,30 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

// CHECK-LABEL: tt.func @matmul_loop_mb
// CHECK: %{{.*}}:8 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}})
// Stage 0
- // CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG6]], %{{.*}}
- // CHECK: %[[MULI_29:.*]] = arith.muli %{{.*}}, %{{.*}}
- // CHECK: %[[SUBI_30:.*]] = arith.subi %{{.*}}, %[[MULI_29]]
- // CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_30]]
- // CHECK: %[[SPLAT_32:.*]] = tt.splat %[[CMPI_31]]
- // CHECK: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_32]]
- // CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG7]], %{{.*}}
- // CHECK: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_31]]
- // CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_35]]
- // Stage 1
- // CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG9]], %{{.*}}
- // CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}}
- // CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}}
- // CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
- // CHECK: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_40]]
- // CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
- // CHECK: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_41]]
+ // CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG9]], %{{.*}}
+ // CHECK: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}}
+ // CHECK: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}}
+ // CHECK: %[[MEMDESC_SUBVIEW_31:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}]
+ // CHECK: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_31]]
+ // CHECK: %[[MEMDESC_SUBVIEW_32:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}]
+ // CHECK: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_32]]
+ // Stage 1
+ // CHECK: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG6]], %{{.*}}
+ // CHECK: %[[MULI_34:.*]] = arith.muli %{{.*}}, %{{.*}}
+ // CHECK: %[[SUBI_35:.*]] = arith.subi %{{.*}}, %[[MULI_34]]
+ // CHECK: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_35]]
+ // CHECK: %[[SPLAT_37:.*]] = tt.splat %[[CMPI_36]]
+ // CHECK: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_37]]
+ // CHECK: %[[ADDPTR_39:.*]] = tt.addptr %[[ARG7]], %{{.*}}
+ // CHECK: %[[SPLAT_40:.*]] = tt.splat %[[CMPI_36]]
+ // CHECK: %[[LOAD_41:.*]] = tt.load %[[ADDPTR_39]], %[[SPLAT_40]]
// Stage 2
// CHECK: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[ARG10]]
// CHECK: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]]
// CHECK: %[[MULF_44:.*]] = arith.mulf %[[LOCAL_LOAD_43]], %{{.*}}
// CHECK: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_42]], %[[MULF_44]], %[[ARG8]]
- // CHECK: scf.yield %[[ADDPTR_28]], %[[ADDPTR_34]], %[[DOT_45]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_33]], %[[LOAD_36]]
+ // CHECK: scf.yield %[[ADDPTR_33]], %[[ADDPTR_39]], %[[DOT_45]], %[[SELECT_30]], %[[MEMDESC_SUBVIEW_31]], %[[MEMDESC_SUBVIEW_32]], %[[LOAD_38]], %[[LOAD_41]]
// CHECK: }

tt.func @matmul_loop_mb(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> {
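Read as a schedule, the updated CHECK lines describe a three-stage software-pipelined loop: Stage 0 local-stores the tiles fetched on the previous iteration into a fresh shared-memory slot, Stage 1 issues the next masked global loads, and Stage 2 computes the dot product on tiles fetched two iterations earlier. The standalone C++ toy below models only that shape; it is a sketch for orientation, it assumes two buffer slots, every name in it is invented here, and none of it is Triton code.

    #include <array>
    #include <cstdio>

    // Toy stand-ins; none of these names come from Triton.
    using Tile = float;
    static Tile loadA(int k) { return static_cast<Tile>(k); }
    static Tile loadB(int k) { return static_cast<Tile>(k); }
    static Tile dotAcc(Tile a, Tile b) { return a * b; }

    int main() {
      constexpr int kBuffers = 2; // double-buffered shared memory (the subviews)
      constexpr int kIters = 6;   // compute iterations
      std::array<Tile, kBuffers> bufA{}, bufB{};

      // Prologue: run the load stream two iterations ahead of the compute.
      Tile regA = loadA(0), regB = loadB(0);
      bufA[0] = regA;
      bufB[0] = regB;
      regA = loadA(1);
      regB = loadB(1);
      int storeSlot = 0, readSlot = 0;
      Tile acc = 0;

      // Steady state, in the order the updated pass schedules it:
      for (int i = 2; i < kIters; ++i) {
        // Stage 0: local_store the tiles fetched on the previous iteration.
        storeSlot = (storeSlot + 1) % kBuffers;
        bufA[storeSlot] = regA;
        bufB[storeSlot] = regB;
        // Stage 1: global loads for a future iteration (masked in the real IR
        // so the final iterations do not read out of bounds).
        regA = loadA(i);
        regB = loadB(i);
        // Stage 2: local_load the slot stored one iteration ago, i.e. data
        // fetched two iterations ago, and accumulate.
        acc += dotAcc(bufA[readSlot], bufB[readSlot]);
        readSlot = storeSlot;
      }

      // Epilogue: drain the two tiles still in flight.
      acc += dotAcc(bufA[storeSlot], bufB[storeSlot]);
      acc += dotAcc(regA, regB);

      std::printf("acc = %g\n", acc); // 0+1+4+9+16+25 = 55
    }

It prints acc = 55, the same sum an unpipelined loop over k = 0..5 would produce, which is the invariant any reordering of the stages has to preserve.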
@@ -211,13 +211,14 @@ static void moveUpTranspose(triton::FuncOp funcOp) {
// Schedule global load and local store ops for better GEMM performance.
static void scheduleGlobalLoadLocalStore(scf::ForOp forOp) {
SmallVector<Operation *> moveOps;
- // Move global loads early to prefetch. This may increase register pressure
- // but it enables issuing global loads early.
- forOp.walk([&](triton::LoadOp op) { moveOps.push_back(op); });
// Move local_stores early if dependence distance greater than one iteration.
+ // Best perf on GEMM when these precede global loads.
forOp.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); });

+ // Move global loads early to prefetch. This may increase register pressure
+ // but it enables issuing global loads early.
+ forOp.walk([&](triton::LoadOp op) { moveOps.push_back(op); });

for (auto op : llvm::reverse(moveOps)) {
// Gather use-def chain in block.
Block *block = op->getBlock();
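The fix itself is just the order of the two walks. Both walks append to moveOps, and the loop below them visits moveOps in reverse while hoisting each op as early in the block as its dependencies allow; ops collected first are therefore hoisted last and land earliest. Walking LocalStoreOp before LoadOp is what puts local stores ahead of global loads, matching the updated test; with the walks in the old order, global loads were hoisted last and ended up first, which is the schedule the removed CHECK lines show. Here is a minimal standalone model of that interaction; it ignores the use-def-chain analysis the real pass performs, and every name in it is made up for illustration.

    #include <algorithm>
    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
      // A toy "block": each string stands for one operation, in program order.
      std::vector<std::string> block = {"local_store A", "global_load B",
                                        "local_store C", "global_load D",
                                        "dot E"};

      // Collect ops to move: local stores first, then global loads, mirroring
      // the order of the two forOp.walk calls after this commit.
      std::vector<std::string> moveOps;
      for (const auto &op : block)
        if (op.rfind("local_store", 0) == 0)
          moveOps.push_back(op);
      for (const auto &op : block)
        if (op.rfind("global_load", 0) == 0)
          moveOps.push_back(op);

      // Hoist each op to the front of the block, visiting moveOps in reverse.
      // The op hoisted last lands earliest, so local stores end up before
      // global loads.
      for (auto it = moveOps.rbegin(); it != moveOps.rend(); ++it) {
        auto pos = std::find(block.begin(), block.end(), *it);
        std::rotate(block.begin(), pos, pos + 1); // move *pos to the front
      }

      // Prints: local_store A, local_store C, global_load B, global_load D, dot E
      for (const auto &op : block)
        std::printf("%s\n", op.c_str());
    }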
