From 98b4f9e06824d0c140b8c6d1ad2f9372406113ff Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Tue, 26 Nov 2024 17:14:02 +0000 Subject: [PATCH] [AMD] fixed the ReorderInstructions pass --- .../amd/amd-reorder-instructions.mlir | 36 +++++++++---------- .../ReorderInstructions.cpp | 6 ++-- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index 708d75a232c6..385ddab306d4 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -267,30 +267,30 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: tt.func @matmul_loop_mb // CHECK: %{{.*}}:8 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) -// Stage 0 -// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG6]], %{{.*}} -// CHECK: %[[MULI_29:.*]] = arith.muli %{{.*}}, %{{.*}} -// CHECK: %[[SUBI_30:.*]] = arith.subi %{{.*}}, %[[MULI_29]] -// CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_30]] -// CHECK: %[[SPLAT_32:.*]] = tt.splat %[[CMPI_31]] -// CHECK: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_32]] -// CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG7]], %{{.*}} -// CHECK: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_31]] -// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_35]] // Stage 1 -// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG9]], %{{.*}} -// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} -// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} -// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_40]] -// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_41]] +// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}} +// CHECK: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_31:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_31]] +// CHECK: %[[MEMDESC_SUBVIEW_32:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_32]] +// Stage 1 +// CHECK: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[MULI_34:.*]] = arith.muli %{{.*}}, %{{.*}} +// CHECK: %[[SUBI_35:.*]] = arith.subi %{{.*}}, %[[MULI_34]] +// CHECK: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_35]] +// CHECK: %[[SPLAT_37:.*]] = tt.splat %[[CMPI_36]] +// CHECK: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_37]] +// CHECK: %[[ADDPTR_39:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[SPLAT_40:.*]] = tt.splat %[[CMPI_36]] +// CHECK: %[[LOAD_41:.*]] = tt.load %[[ADDPTR_39]], %[[SPLAT_40]] // Stage 2 // CHECK: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[ARG10]] // CHECK: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]] // CHECK: %[[MULF_44:.*]] = arith.mulf %[[LOCAL_LOAD_43]], %{{.*}} // CHECK: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_42]], %[[MULF_44]], %[[ARG8]] -// CHECK: scf.yield %[[ADDPTR_28]], %[[ADDPTR_34]], %[[DOT_45]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_33]], %[[LOAD_36]] +// CHECK: scf.yield %[[ADDPTR_33]], %[[ADDPTR_39]], %[[DOT_45]], %[[SELECT_30]], %[[MEMDESC_SUBVIEW_31]], %[[MEMDESC_SUBVIEW_32]], %[[LOAD_38]], %[[LOAD_41]] // CHECK: } tt.func @matmul_loop_mb(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index 0837f16dcf7c..0f62d32ca284 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -216,12 +216,12 @@ static void moveUpTranspose(triton::FuncOp funcOp) { // Schedule global load and local store ops for better GEMM performance. static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) { SmallVector moveOps; - // Move global loads early to prefetch. This may increase register pressure - // but it enables issuing global loads early. - funcOp.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); // Move local_stores early if dependence distance greater than one iteration. // Best perf on GEMM when these precede global loads. funcOp.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); + // Move global loads early to prefetch. This may increase register pressure + // but it enables issuing global loads early. + funcOp.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); for (auto op : llvm::reverse(moveOps)) { // Gather use-def chain in block.