Skip to content

Commit

Permalink
[AMD] Use Linear Layout conversions for AMDWmma
Browse files Browse the repository at this point in the history
Enable LL conversions for WMMA as well as for MFMA layouts.

See also: #5210

Signed-off-by: Ilya Veselov <[email protected]>
  • Loading branch information
joviliast committed Nov 26, 2024
1 parent 68a08dd commit 8287311
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 2 deletions.
7 changes: 5 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,13 +374,16 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere)
// -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be
// completed before we can remove the layoutIsOK check:
// 1. Support for AMD's WMMA
// 1. Support for AMD's WMMA dot operand
std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
layout = dotOperand.getParent();
if (isa<AMDWmmaEncodingAttr>(layout)) {
return false;
}
}

if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
if (isa<MmaEncodingTrait>(layout)) {
return !useLegacyMMAConversion;
}
if (isa<BlockedEncodingAttr, LinearEncodingAttr>(layout)) {
Expand Down
65 changes: 65 additions & 0 deletions test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#mma1 = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
#mma2 = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}>
Expand Down Expand Up @@ -97,6 +98,70 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
tt.return
}

// Layout conversion #blocked -> WMMA v1 on a 128x16 i32 tensor.
// The lowered IR unpacks the blocked operand (16-element i32 struct per
// thread) and repacks into the WMMA v1 register struct (32 i32 elements
// per thread) — counts are pinned by the CHECK-COUNT directives below.
// CHECK-LABEL: blocked_to_wmma1
tt.func @blocked_to_wmma1(%arg0: tensor<128x16xi32, #blocked>) {
// CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
%0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma1>
tt.return
}

// Conversion between slice layouts: slice(dim 0) of #blocked -> slice(dim 0)
// of WMMA v1 on a 1-D 16xi32 tensor. The source is unpacked from a
// 16-element struct; the sliced WMMA result holds a single i32 per thread.
// CHECK-LABEL: slice_blocked_to_wmma1
tt.func @slice_blocked_to_wmma1(%arg0: tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) {
// CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)>
%0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
tt.return
}

// Reverse direction of blocked_to_wmma1: WMMA v1 -> #blocked on a 128x16
// i32 tensor. Extract/insert struct sizes mirror the forward test
// (32-element WMMA struct in, 16-element blocked struct out).
// CHECK-LABEL: wmma1_to_blocked
tt.func @wmma1_to_blocked(%arg0: tensor<128x16xi32, #mma1>) {
// CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
%0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma1> -> tensor<128x16xi32, #blocked>
tt.return
}

// Reverse direction of slice_blocked_to_wmma1: slice(dim 0) of WMMA v1
// (one i32 per thread) -> slice(dim 0) of #blocked (16-element struct).
// CHECK-LABEL: slice_wmma1_to_blocked
tt.func @slice_wmma1_to_blocked(%arg0: tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>) {
// CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)>
// CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
%0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
tt.return
}

// Same as blocked_to_wmma1 but targeting the WMMA version-2 layout
// (#mma2); the expected per-thread struct sizes are identical here.
// CHECK-LABEL: blocked_to_wmma2
tt.func @blocked_to_wmma2(%arg0: tensor<128x16xi32, #blocked>) {
// CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
%0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma2>
tt.return
}

// Same as slice_blocked_to_wmma1 but with a WMMA version-2 (#mma2) parent
// for the destination slice layout.
// CHECK-LABEL: slice_blocked_to_wmma2
tt.func @slice_blocked_to_wmma2(%arg0: tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) {
// CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)>
%0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma2}>>
tt.return
}

// Reverse direction of blocked_to_wmma2: WMMA v2 -> #blocked on a 128x16
// i32 tensor (32-element WMMA struct in, 16-element blocked struct out).
// CHECK-LABEL: wmma2_to_blocked
tt.func @wmma2_to_blocked(%arg0: tensor<128x16xi32, #mma2>) {
// CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
%0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma2> -> tensor<128x16xi32, #blocked>
tt.return
}

// Reverse direction of slice_blocked_to_wmma2: slice(dim 0) of WMMA v2
// (one i32 per thread) -> slice(dim 0) of #blocked (16-element struct).
// CHECK-LABEL: slice_wmma2_to_blocked
tt.func @slice_wmma2_to_blocked(%arg0: tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma2}>>) {
// CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)>
// CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
%0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma2}>> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
tt.return
}
}

// -----
Expand Down

0 comments on commit 8287311

Please sign in to comment.