From 14025786d108596cfd99700caa4f438938c2ceba Mon Sep 17 00:00:00 2001 From: Ilya V <152324710+joviliast@users.noreply.github.com> Date: Thu, 8 Aug 2024 17:17:37 +0200 Subject: [PATCH] [AMD][gfx12] Support Shared to dot WMMAv2 operand conversion (#4467) - Added separate method for tensor element mapping specific for WMMAv2 - Added a lit test to verify number of llvm instructions Signed-off-by: Ilya Veselov --- .../amd/tritongpu_wmma_dot_to_llvm.mlir | 47 +++++++----- .../SharedToDotOperandWMMA.cpp | 74 +++++++++++++------ 2 files changed, 80 insertions(+), 41 deletions(-) diff --git a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir index b030c095a0e5..07364eabc345 100644 --- a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir @@ -1,26 +1,37 @@ // RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#mma1 = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#mma2 = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: wmma_dot_operand - tt.func @wmma_dot_operand(%arg0: !tt.memdesc<64x64xf16, #shared>) { + // CHECK-LABEL: wmma1_dot_operand + tt.func @wmma1_dot_operand(%arg0: !tt.memdesc<64x64xf16, #shared>) { // 2 CTA * 4 rep * load_per_thread_per_instr // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> - %0 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + %0 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> // CHECK-COUNT-128: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> - %1 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + %1 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: wmma2_dot_operand + tt.func @wmma2_dot_operand(%arg0: !tt.memdesc<64x64xf16, #shared>) { + // 2 CTA * 4 rep * load_per_thread_per_instr + // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16> + %0 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> + // CHECK-COUNT-64: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> + %1 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> tt.return } // CHECK-LABEL: wmma_dot - tt.func @wmma_dot(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma>) { + tt.func @wmma_dot(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) { // CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK: llvm.mlir.undef : vector<16xf16> // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xf16> // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xf16, #mma> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xf16, #mma1> // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16> // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> @@ -28,7 +39,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // CHECK-LABEL: wmma_dot_bf16 - tt.func @wmma_dot_bf16(%arg0: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xbf16, #mma>) { + tt.func @wmma_dot_bf16(%arg0: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xbf16, #mma1>) { // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> // CHECK: llvm.bitcast %{{.*}} : vector<16xbf16> to vector<16xi16> // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> @@ -37,12 +48,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.mlir.undef : vector<16xbf16> // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xbf16> // CHECK: rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xbf16, #mma> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xbf16, #mma1> tt.return } // CHECK-LABEL: wmma_dot_int8_32 - tt.func @wmma_dot_int8_32(%arg0: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma>) { + tt.func @wmma_dot_int8_32(%arg0: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi8> // CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32> @@ -51,13 +62,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: rocdl.wmma.i32.16x16x16.iu8 {{.*}} : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32> - %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xi32, #mma> + %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> tt.return } // CHECK-LABEL: wmma_dot_int4_32 - tt.func @wmma_dot_int4_32(%arg0: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma>) { + tt.func @wmma_dot_int4_32(%arg0: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4)> // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi4> // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32> @@ -66,7 +77,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: rocdl.wmma.i32.16x16x16.iu4 {{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> - %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xi32, #mma> + %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> tt.return } @@ -75,19 +86,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0], hasLeadingOffset = false}> -#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 1, 4]}> +#mma1 = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 1, 4]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: wmma_dot_operand3d tt.func @wmma_dot_operand3d(%arg0: !tt.memdesc<4x16x32xf16, #shared>) { // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> - %0 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + %0 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> // CHECK-COUNT-32: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> - %1 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + %1 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> tt.return } // CHECK-LABEL: wmma_dot3d - tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma>) { + tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma1>) { // CHECK-COUNT-32: llvm.extractvalue %arg0 // CHECK-COUNT-32: llvm.insertelement // CHECK-COUNT-32: llvm.extractvalue %arg1 @@ -95,7 +106,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-COUNT-8: llvm.extractvalue %arg2 // CHECK-COUNT-8: llvm.insertelement // CHECK-COUNT-2: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<2x16x16xf16, #mma> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<2x16x16xf16, #mma1> // CHECK-COUNT-8: llvm.extractelement // CHECK-COUNT-8: llvm.insertvalue tt.return diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp index 51252b6372ae..6cdcddad396e 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -33,8 +33,9 @@ using ::mlir::triton::gpu::SharedEncodingAttr; namespace SharedToDotOperandWMMA { /** - * @brief This function maps particular load of wmma dot operand to element - * indexes(row, col) + * @brief Following functions maps particular load of wmma dot operand to + * element indexes(row, col). For each WMMA generation separate function is + * used. * * Whole tensor is broken into "blocks" of warps along "non-K" axis. * One block could be processed by multiple warps. @@ -64,7 +65,8 @@ namespace SharedToDotOperandWMMA { * @return vector (i-th element corresponds to i-th load instruction) of * 2-element vectors(tensor row and col). */ -llvm::SmallVector> computeTensorElemMappingInBlock( +llvm::SmallVector> +computeTensorElemMappingInBlockWmma1( ConversionPatternRewriter &rewriter, Location loc, const ArrayRef &elemsPerInstr, Value warpId, Value laneId, int numOfElems, ArrayRef reps, ArrayRef smemOffsets, @@ -75,28 +77,55 @@ llvm::SmallVector> computeTensorElemMappingInBlock( const int loadsPerThread = numOfElems / loadVecSize; llvm::SmallVector> mapping(numK * loadsPerThread); - Value _0 = i32_val(0); - Value nonKDim = i32_val(iNonKDim); - Value warpVOffset = mul(warpId, i32_val(elemsPerInstr[0])); - + Value elemsPerInstrV = i32_val(elemsPerInstr[0]); + Value warpVOffset = mul(warpId, elemsPerInstrV); + Value sliceVOffset = add(urem(laneId, elemsPerInstrV), warpVOffset); auto rank = smemOffsets.size(); + Value row = add(sliceVOffset, smemOffsets[rank - 2]); for (int tile = 0; tile < numK; ++tile) { - Value tileVOffset = _0; Value tileHOffset = i32_val(tile * elemsPerInstr[1]); - Value laneVOffset = laneId; - Value laneHOffset = _0; - for (int loadId = 0; loadId < loadsPerThread; ++loadId) { - Value elemVOffset = _0; Value elemHOffset = i32_val(loadId * loadVecSize); + Value sliceHOffset = add(tileHOffset, elemHOffset); + + Value col = add(sliceHOffset, smemOffsets[rank - 1]); + mapping[loadsPerThread * tile + loadId] = {row, col}; + } + } + + return mapping; +} + +llvm::SmallVector> +computeTensorElemMappingInBlockWmma2( + ConversionPatternRewriter &rewriter, Location loc, + const ArrayRef &elemsPerInstr, Value warpId, Value laneId, + int numOfElems, ArrayRef reps, ArrayRef smemOffsets, + int loadVecSize, unsigned iNonKDim, [[maybe_unused]] unsigned iKDim) { + assert(reps.size() == 3); + assert(elemsPerInstr.size() == 2); + auto numK = reps[2]; + const int loadsPerThread = numOfElems / loadVecSize; + llvm::SmallVector> mapping(numK * loadsPerThread); - Value sliceVOffset = - add(add(add(tileVOffset, laneVOffset), elemVOffset), warpVOffset); - Value sliceHOffset = add(add(tileHOffset, laneHOffset), elemHOffset); + Value rowsPerInstr = i32_val(elemsPerInstr[0]); + Value colsPerInstr = i32_val(elemsPerInstr[1]); + Value elemsPerThread = i32_val(elemsPerInstr[1] / 2); + Value warpVOffset = mul(warpId, rowsPerInstr); + Value sliceVOffset = add(urem(laneId, rowsPerInstr), warpVOffset); + + auto rank = smemOffsets.size(); + Value row = add(sliceVOffset, smemOffsets[rank - 2]); + Value laneHOffset = mul(udiv(laneId, colsPerInstr), elemsPerThread); + + for (int tile = 0; tile < numK; ++tile) { + Value tileHOffset = add(laneHOffset, i32_val(tile * elemsPerInstr[1])); + for (int loadId = 0; loadId < loadsPerThread; ++loadId) { + Value elemHOffset = i32_val(loadId * loadVecSize); + Value sliceHOffset = add(tileHOffset, elemHOffset); - Value row = add(sliceVOffset, smemOffsets[rank - 2]); Value col = add(sliceHOffset, smemOffsets[rank - 1]); mapping[loadsPerThread * tile + loadId] = {row, col}; @@ -116,8 +145,9 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1; auto wmmaLayout = cast(encoding.getParent()); - // TODO: support 2nd gen of WMMA - assert(wmmaLayout.getVersion() == 1); + auto computeTensorElemMappingInBlock = + wmmaLayout.getVersion() == 1 ? computeTensorElemMappingInBlockWmma1 + : computeTensorElemMappingInBlockWmma2; assert(wmmaLayout.getMNKDimPerInstr()[nonKDimIdx] == 16); auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); @@ -141,16 +171,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, auto repB = numReps[0]; unsigned iWaveSize = triton::gpu::getWarpSize(wmmaLayout); - unsigned iNumLanes = iWaveSize / 2; assert(iWaveSize == 32); Value waveSize = i32_val(iWaveSize); - Value numLanes = i32_val(iNumLanes); Value linearWaveId = udiv(thread, waveSize); - Value lane = urem(thread, numLanes); // share elem between two threads - unsigned numElemsPerThreadPerRep = wmmaInstrK; + unsigned numElemsPerThreadPerRep = + wmmaLayout.getSizePerThreadForOperands(opIdx)[kDimIdx]; - Value warp = udiv(thread, waveSize); + Value lane = urem(thread, waveSize); unsigned int maxNumWarps = shape[nonKDimIdx] / wmmaInstrNonK; int warpsPerBlockNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps); int warpsPerBatch =