[AMD][gfx12] Support Shared to dot WMMAv2 operand conversion (triton-lang#4467)

- Added a separate method for tensor element mapping specific to WMMAv2 (see the MLIR sketch below)
- Added a lit test to verify the number of LLVM instructions generated

Signed-off-by: Ilya Veselov <[email protected]>
joviliast authored Aug 8, 2024
1 parent 46788a5 commit 1402578
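
As a quick orientation before the diff: the new path, adapted from the lit test added below, loads a shared-memory operand straight into a version-2 WMMA dot-operand layout with kWidth = 8 (the #shared/#mma2 labels and the %mem/%a names are local to this sketch, not part of the patch):

#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#mma2 = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}>
// Operand A (opIdx = 0): each lane now loads contiguous vector<8xf16> chunks,
// instead of the replicated vector<16xf16> loads used for WMMA version 1.
%a = triton_gpu.local_load %mem : !tt.memdesc<64x64xf16, #shared>
    -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>>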
Showing 2 changed files with 80 additions and 41 deletions.
47 changes: 29 additions & 18 deletions test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
@@ -1,34 +1,45 @@
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s

#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
#mma1 = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
#mma2 = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: wmma_dot_operand
tt.func @wmma_dot_operand(%arg0: !tt.memdesc<64x64xf16, #shared>) {
// CHECK-LABEL: wmma1_dot_operand
tt.func @wmma1_dot_operand(%arg0: !tt.memdesc<64x64xf16, #shared>) {
// 2 CTA * 4 rep * load_per_thread_per_instr
// CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16>
%0 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
%0 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>
// CHECK-COUNT-128: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16>
%1 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
%1 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>
tt.return
}

// CHECK-LABEL: wmma2_dot_operand
tt.func @wmma2_dot_operand(%arg0: !tt.memdesc<64x64xf16, #shared>) {
// 2 CTA * 4 rep * load_per_thread_per_instr
// CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16>
%0 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>>
// CHECK-COUNT-64: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16>
%1 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>>
tt.return
}

// CHECK-LABEL: wmma_dot
tt.func @wmma_dot(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma>) {
tt.func @wmma_dot(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) {
// CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
// CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
// CHECK: llvm.mlir.undef : vector<16xf16>
// CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xf16>
// CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
%0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xf16, #mma>
%0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xf16, #mma1>
// CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16>
// CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
// CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
tt.return
}

// CHECK-LABEL: wmma_dot_bf16
tt.func @wmma_dot_bf16(%arg0: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xbf16, #mma>) {
tt.func @wmma_dot_bf16(%arg0: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xbf16, #mma1>) {
// CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
// CHECK: llvm.bitcast %{{.*}} : vector<16xbf16> to vector<16xi16>
// CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
@@ -37,12 +48,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.mlir.undef : vector<16xbf16>
// CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xbf16>
// CHECK: rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16>
%0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xbf16, #mma>
%0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xbf16, #mma1>
tt.return
}

// CHECK-LABEL: wmma_dot_int8_32
tt.func @wmma_dot_int8_32(%arg0: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma>) {
tt.func @wmma_dot_int8_32(%arg0: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) {
// CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
// CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi8>
// CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32>
@@ -51,13 +62,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32>
// CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: rocdl.wmma.i32.16x16x16.iu8 {{.*}} : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
%0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xi32, #mma>
%0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1>
// CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
tt.return
}

// CHECK-LABEL: wmma_dot_int4_32
tt.func @wmma_dot_int4_32(%arg0: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma>) {
tt.func @wmma_dot_int4_32(%arg0: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) {
// CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4)>
// CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi4>
// CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32>
@@ -66,7 +77,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32>
// CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: rocdl.wmma.i32.16x16x16.iu4 {{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
%0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xi32, #mma>
%0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1>
// CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
tt.return
}
@@ -75,27 +86,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// -----

#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0], hasLeadingOffset = false}>
#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 1, 4]}>
#mma1 = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 1, 4]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: wmma_dot_operand3d
tt.func @wmma_dot_operand3d(%arg0: !tt.memdesc<4x16x32xf16, #shared>) {
// CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16>
%0 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
%0 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>
// CHECK-COUNT-32: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16>
%1 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
%1 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>
tt.return
}

// CHECK-LABEL: wmma_dot3d
tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma>) {
tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma1>) {
// CHECK-COUNT-32: llvm.extractvalue %arg0
// CHECK-COUNT-32: llvm.insertelement
// CHECK-COUNT-32: llvm.extractvalue %arg1
// CHECK-COUNT-32: llvm.insertelement
// CHECK-COUNT-8: llvm.extractvalue %arg2
// CHECK-COUNT-8: llvm.insertelement
// CHECK-COUNT-2: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
%0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<2x16x16xf16, #mma>
%0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<2x16x16xf16, #mma1>
// CHECK-COUNT-8: llvm.extractelement
// CHECK-COUNT-8: llvm.insertvalue
tt.return
@@ -33,8 +33,9 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
namespace SharedToDotOperandWMMA {

/**
* @brief This function maps particular load of wmma dot operand to element
* indexes(row, col)
 * @brief The following functions map a particular load of a WMMA dot operand
 * to element indices (row, col). A separate function is used for each WMMA
 * generation.
*
* Whole tensor is broken into "blocks" of warps along "non-K" axis.
* One block could be processed by multiple warps.
@@ -64,7 +65,8 @@ namespace SharedToDotOperandWMMA {
* @return vector (i-th element corresponds to i-th load instruction) of
* 2-element vectors(tensor row and col).
*/
llvm::SmallVector<llvm::SmallVector<Value>> computeTensorElemMappingInBlock(
llvm::SmallVector<llvm::SmallVector<Value>>
computeTensorElemMappingInBlockWmma1(
ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value warpId, Value laneId,
int numOfElems, ArrayRef<int64_t> reps, ArrayRef<Value> smemOffsets,
@@ -75,28 +77,55 @@ llvm::SmallVector<llvm::SmallVector<Value>> computeTensorElemMappingInBlock(
const int loadsPerThread = numOfElems / loadVecSize;
llvm::SmallVector<llvm::SmallVector<Value>> mapping(numK * loadsPerThread);

Value _0 = i32_val(0);
Value nonKDim = i32_val(iNonKDim);
Value warpVOffset = mul(warpId, i32_val(elemsPerInstr[0]));

Value elemsPerInstrV = i32_val(elemsPerInstr[0]);
Value warpVOffset = mul(warpId, elemsPerInstrV);
Value sliceVOffset = add(urem(laneId, elemsPerInstrV), warpVOffset);
auto rank = smemOffsets.size();
Value row = add(sliceVOffset, smemOffsets[rank - 2]);

for (int tile = 0; tile < numK; ++tile) {
Value tileVOffset = _0;
Value tileHOffset = i32_val(tile * elemsPerInstr[1]);

Value laneVOffset = laneId;
Value laneHOffset = _0;

for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
Value elemVOffset = _0;
Value elemHOffset = i32_val(loadId * loadVecSize);
Value sliceHOffset = add(tileHOffset, elemHOffset);

Value col = add(sliceHOffset, smemOffsets[rank - 1]);
mapping[loadsPerThread * tile + loadId] = {row, col};
}
}

return mapping;
}

llvm::SmallVector<llvm::SmallVector<Value>>
computeTensorElemMappingInBlockWmma2(
ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value warpId, Value laneId,
int numOfElems, ArrayRef<int64_t> reps, ArrayRef<Value> smemOffsets,
int loadVecSize, unsigned iNonKDim, [[maybe_unused]] unsigned iKDim) {
assert(reps.size() == 3);
assert(elemsPerInstr.size() == 2);
auto numK = reps[2];
const int loadsPerThread = numOfElems / loadVecSize;
llvm::SmallVector<llvm::SmallVector<Value>> mapping(numK * loadsPerThread);

Value sliceVOffset =
add(add(add(tileVOffset, laneVOffset), elemVOffset), warpVOffset);
Value sliceHOffset = add(add(tileHOffset, laneHOffset), elemHOffset);
Value rowsPerInstr = i32_val(elemsPerInstr[0]);
Value colsPerInstr = i32_val(elemsPerInstr[1]);
Value elemsPerThread = i32_val(elemsPerInstr[1] / 2);
Value warpVOffset = mul(warpId, rowsPerInstr);
Value sliceVOffset = add(urem(laneId, rowsPerInstr), warpVOffset);

auto rank = smemOffsets.size();
Value row = add(sliceVOffset, smemOffsets[rank - 2]);
Value laneHOffset = mul(udiv(laneId, colsPerInstr), elemsPerThread);

for (int tile = 0; tile < numK; ++tile) {
Value tileHOffset = add(laneHOffset, i32_val(tile * elemsPerInstr[1]));
for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
Value elemHOffset = i32_val(loadId * loadVecSize);
Value sliceHOffset = add(tileHOffset, elemHOffset);

Value row = add(sliceVOffset, smemOffsets[rank - 2]);
Value col = add(sliceHOffset, smemOffsets[rank - 1]);

mapping[loadsPerThread * tile + loadId] = {row, col};
@@ -116,8 +145,9 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1;

auto wmmaLayout = cast<AMDWmmaEncodingAttr>(encoding.getParent());
// TODO: support 2nd gen of WMMA
assert(wmmaLayout.getVersion() == 1);
auto computeTensorElemMappingInBlock =
wmmaLayout.getVersion() == 1 ? computeTensorElemMappingInBlockWmma1
: computeTensorElemMappingInBlockWmma2;
assert(wmmaLayout.getMNKDimPerInstr()[nonKDimIdx] == 16);
auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();

@@ -141,16 +171,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
auto repB = numReps[0];

unsigned iWaveSize = triton::gpu::getWarpSize(wmmaLayout);
unsigned iNumLanes = iWaveSize / 2;
assert(iWaveSize == 32);
Value waveSize = i32_val(iWaveSize);
Value numLanes = i32_val(iNumLanes);
Value linearWaveId = udiv(thread, waveSize);
Value lane = urem(thread, numLanes); // share elem between two threads

unsigned numElemsPerThreadPerRep = wmmaInstrK;
unsigned numElemsPerThreadPerRep =
wmmaLayout.getSizePerThreadForOperands(opIdx)[kDimIdx];

Value warp = udiv(thread, waveSize);
Value lane = urem(thread, waveSize);
unsigned int maxNumWarps = shape[nonKDimIdx] / wmmaInstrNonK;
int warpsPerBlockNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps);
int warpsPerBatch =
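To make the arithmetic in computeTensorElemMappingInBlockWmma2 above easier to follow, here is a minimal standalone sketch using plain integers instead of MLIR Values (shared-memory offsets omitted; the name wmma2ElemMapping and the small driver are illustrative only, not part of the patch). The key difference from version 1 is that the upper half-wave (lanes 16-31) takes the second 8-element half of each 16-wide K slice instead of replicating the lower half-wave's elements, which is why convertLayout now dispatches to a per-version mapping function.

#include <cstdio>
#include <initializer_list>
#include <utility>
#include <vector>

// Sketch of the WMMA v2 (row, col) mapping for one block: a 16x16 instruction
// tile with kWidth = 8, so each lane owns a contiguous half of the K slice.
std::vector<std::pair<int, int>>
wmma2ElemMapping(int warpId, int laneId, int numK, int loadVecSize,
                 int rowsPerInstr = 16, int colsPerInstr = 16) {
  int elemsPerThread = colsPerInstr / 2;       // 8 contiguous K elements per lane
  int loadsPerThread = elemsPerThread / loadVecSize;
  int row = warpId * rowsPerInstr + laneId % rowsPerInstr;
  int laneHOffset = (laneId / colsPerInstr) * elemsPerThread; // 0 or 8
  std::vector<std::pair<int, int>> mapping;
  for (int tile = 0; tile < numK; ++tile)
    for (int loadId = 0; loadId < loadsPerThread; ++loadId)
      mapping.push_back(
          {row, laneHOffset + tile * colsPerInstr + loadId * loadVecSize});
  return mapping;
}

int main() {
  // Lanes 0 and 16 of warp 0 read the two halves of each 16-wide K tile.
  for (int lane : {0, 16})
    for (auto [r, c] : wmma2ElemMapping(/*warpId=*/0, lane, /*numK=*/2,
                                        /*loadVecSize=*/8))
      std::printf("lane %2d -> row %d, col %2d\n", lane, r, c);
  return 0;
}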
