diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index 524503f3c3dc..d8f5b092b813 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -206,12 +206,13 @@ def Rock_ReduceOp : } def Rock_AttentionOp : - Rock_Op<"attention", [DeclareOpInterfaceMethods, RockFusionRoot]>, + Rock_Op<"attention", [DeclareOpInterfaceMethods, RockFusionRoot, AttrSizedOperandSegments]>, Arguments<(ins TensorOrMemRefOf<[F32, F16, I8]>:$queries, TensorOrMemRefOf<[F32, F16, I8]>:$keys, TensorOrMemRefOf<[F32, F16]>:$values, Variadic>:$preSoftmaxElemWiseInputs, + Optional>:$currentSeqLen, TensorOrMemRefOf<[F32, F16]>:$out, UnitAttr:$qTransposed, UnitAttr:$kTransposed, @@ -244,6 +245,7 @@ def Rock_AttentionOp : let assemblyFormat = [{ `{` `\n` ` ` `qk` `=` (`tr` $qTransposed^)? $queries `*` (`tr` $kTransposed^)? $keys `:` type($queries) `,` type($keys) `\n` + (`currentSeqLen` `=` `(` $currentSeqLen^ `:` type($currentSeqLen) `)` `\n`)? (`qk` `=` `elementwise` (`otherIns` `(` $preSoftmaxElemWiseInputs^ `:` type($preSoftmaxElemWiseInputs) `)`)? $preSoftmaxBody^ `\n`)? (`tr` $oTransposed^)? $out `=` `softmax` `(` `qk` `)` `*` (`tr` $vTransposed^)? $values `:` type($values) `->` type($out) `\n` `}` attr-dict (`->` type($result)^)? @@ -431,11 +433,12 @@ def Rock_GridwiseGemmAccelOp : // gridwise_attention_accel def Rock_GridwiseAttentionAccelOp : - Rock_Op<"gridwise_attention_accel", [DeclareOpInterfaceMethods, RockFusionRoot]>, + Rock_Op<"gridwise_attention_accel", [DeclareOpInterfaceMethods, RockFusionRoot, AttrSizedOperandSegments]>, Arguments<(ins MemRefRankOf<[F32, F16, I8], [3]>:$queries, MemRefRankOf<[F32, F16, I8], [3]>:$keys, MemRefRankOf<[F32, F16], [3]>:$values, Variadic>:$preSoftmaxElemWiseInputs, + Optional>:$currentSeqLen, MemRefRankOf<[F32, F16], [3]>:$out, StrAttr:$arch, Rock_GemmFeaturesAttr:$features, @@ -677,7 +680,7 @@ def Rock_TransformingForOp : Results<(outs Variadic:$results)> { let summary = "for loop with coordinate transforms"; let description = [{ - Loops over several a rectangular regeon of dimensions `bounds` in several + Loops over a rectangular region of dimensions `bounds` in several iteration domains, which are coordinate spaces that are the upper coordinates for a sequence of coordinate transformations. @@ -779,7 +782,7 @@ def Rock_TransformingForOp : return *(getLowerStarts().getValues().begin() + n); } - // Retreive the block arguments corresponding to the lower coordinates + // Retrieve the block arguments corresponding to the lower coordinates // for a given iteration domain.
Block::BlockArgListType getLowerCoords(uint32_t domain) { uint32_t start = getLowerStart(domain); diff --git a/mlir/include/mlir/Dialect/Rock/utility/builderUtils.h b/mlir/include/mlir/Dialect/Rock/utility/builderUtils.h index 5d51cefed470..9516e66c8300 100644 --- a/mlir/include/mlir/Dialect/Rock/utility/builderUtils.h +++ b/mlir/include/mlir/Dialect/Rock/utility/builderUtils.h @@ -11,7 +11,8 @@ namespace mlir { namespace rock { /// Utility op to emit constant float op Value createConstantFloatOp(OpBuilder &b, Location loc, Type type, - Type elemType, float value); + Type elemType, float value, + APFloat::opStatus expectedStatus = APFloat::opOK); /// Utility op to emit constant int op Value createConstantIntOp(OpBuilder &b, Location loc, Type type, Type elemType, diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index 226c2d45b4ac..20a41fe6a4b1 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -1089,9 +1089,11 @@ struct AttentionRewritePattern : public OpRewritePattern { tosa::MatMulOp firstMatMulOp = maybeFirstMatMul.value(); IntegerAttr numCUAttr = numCu.has_value() ? rewriter.getI32IntegerAttr(numCu.value()) : nullptr; + + // TODO: extract currentSeqLen from tosa rock::AttentionOp attnOp = rewriter.create( loc, outputType, firstMatMulOp.getA(), firstMatMulOp.getB(), op.getB(), - elemwiseOtherArgs, output, + elemwiseOtherArgs, nullptr, output, // TODO(implement transpose fusion support here) /*qTransposed=*/nullptr, /*kTransposed=*/nullptr, diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index d1d1f098d267..50ce3e1ecad8 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -2098,6 +2098,41 @@ LogicalResult AttentionOp::verify() { if (keyN != valueK) { return emitError("reduction dimensions of second gemm do not match"); } + + // check output type + ShapedType oType = getOut().getType(); + int64_t oBatchDim = oType.getShape().size() == 3 ? oType.getShape()[0] : 1; + + ArrayRef oLastDims = oType.getShape().slice(oType.getRank() - 2); + auto [outputSeqLen, outputHeadDim] = + getOTransposed() ? 
std::tuple{oLastDims[1], oLastDims[0]} + : std::tuple{oLastDims[0], oLastDims[1]}; + + if (qType.getShape().size() != oType.getShape().size()) { + return emitError("Number of dimensions do not match (Q and Output)"); + } + if (qBatchDim != oBatchDim) { + return emitError("Batch dimensions do not match (Q and Output)"); + } + if (queryM != outputSeqLen) { + return emitError("Sequence length does not match (Q and Output)"); + } + if (valueN != outputHeadDim) { + return emitError("Head dimensions do not match (V and Output)"); + } + + // check currentSeqLen (KV Cache) + auto currentSeqLen = getCurrentSeqLen(); + if (currentSeqLen) { + ShapedType seqLenType = currentSeqLen.getType(); + if (seqLenType.getShape().size() != 1) { + return emitError("Number of dimensions is not one (currentSeqLen)"); + } + if (seqLenType.getShape()[0] != oBatchDim) { + return emitError( + "Batch dimensions do not match (currentSeqLen and Output)"); + } + } return success(); } diff --git a/mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp b/mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp index d38bec420f6f..4276a488bf37 100644 --- a/mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp @@ -246,8 +246,7 @@ void AffixTuningParameters::affixTuningParametersImpl(AttentionOp op) { if (!isAccel) { op.emitError("Currently, attention op is only supported on GPUs " "with matrix accelerator extentions"); - signalPassFailure(); - return; + return signalPassFailure(); } Attribute params0 = op.getParams0().value_or(nullptr); // set a default one if params is not provided @@ -262,6 +261,7 @@ void AffixTuningParameters::affixTuningParametersImpl(AttentionOp op) { auto attnPerfConfig = AttnPerfConfigAttr::get(perfConfigStrAttr); if (!attnPerfConfig) { op.emitError("perf config string has an incorrect format."); + return signalPassFailure(); } GemmFeatures features = op.getFeatures(); RockAccelTuningParamAttrInterface accelParams0; @@ -283,8 +283,7 @@ void AffixTuningParameters::affixTuningParametersImpl(AttentionOp op) { if (attnPerfConfig.getMPerBlockG0() > attnPerfConfig.getMPerBlockG1()) { op.emitError( "The MPerBlockG0 should be larger or equal to getMPerBlockG1."); - signalPassFailure(); - return; + return signalPassFailure(); } RockAccelTuningParamAttrInterface accelParams1 = deriveGemm1TuningParams(builder, op, attnPerfConfig); @@ -308,8 +307,7 @@ void AffixTuningParameters::affixTuningParametersImpl(AttentionOp op) { /*enableDPerWaveFiltering=*/false); if (isValidBlockwiseGemm0.failed() || isValidBlockwiseGemm1.failed()) { op.emitError("The provided perf config is not valid"); - signalPassFailure(); - return; + return signalPassFailure(); } IntegerAttr blockSizeAttr = builder.getI32IntegerAttr(blockSize); diff --git a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp index bc8f87a3b576..5388befaf007 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp @@ -538,8 +538,9 @@ AttentionRewritePattern::matchAndRewrite(AttentionOp op, prePadG0NAttr = rw.getIndexAttr(gemm0Size.n); } auto newOp = rw.create( - loc, queries, keys, values, adaptor.getPreSoftmaxElemWiseInputs(), out, - op.getArchAttr(), op.getFeaturesAttr(), blockSizeAttr, gridSizeAttr, + loc, queries, keys, values, adaptor.getPreSoftmaxElemWiseInputs(), + op.getCurrentSeqLen(), out, op.getArchAttr(), op.getFeaturesAttr(), + blockSizeAttr, gridSizeAttr, 
/*disableQBypassLDS=*/nullptr, prePadG0MAttr, prePadG0NAttr, params0, params1, op.getFirstGemmIdxAttr()); bool linalgOpFound = false; diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index d1b9da7160a4..f70840e491c8 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -1043,7 +1043,7 @@ struct GridwiseAttentionAccelRewritePattern gemm0OutExpTrs, gemm0OutTrs}, /*bounds=*/ArrayRef{g0Mpt, g0Npt}, /*strides=*/ArrayRef{1, 1}, - /*useIndexDiffs=*/true, /*forceUnroll=*/true); + /*forceUnroll=*/true, /*useIndexDiffs=*/true); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(loop.getBody()); @@ -1105,7 +1105,7 @@ struct GridwiseAttentionAccelRewritePattern gemm0OutBufferMaxTrs}, /*bounds=*/ArrayRef{g0Mpt, 1}, /*strides=*/ArrayRef{1, 1}, - /*useIndexDiffs=*/true, /*forceUnroll=*/true); + /*forceUnroll=*/true, /*useIndexDiffs=*/true); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(loop.getBody()); @@ -1167,7 +1167,7 @@ struct GridwiseAttentionAccelRewritePattern ArrayRef{rewriter.getArrayAttr({}), attentionOutAccTrs}, /*bounds=*/ArrayRef{g1Mpt, g1Npt}, /*strides=*/ArrayRef{1, 1}, - /*useIndexDiffs=*/true, /*forceUnroll=*/true); + /*forceUnroll=*/true, /*useIndexDiffs=*/true); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(loop.getBody()); @@ -1230,7 +1230,7 @@ struct GridwiseAttentionAccelRewritePattern attentionOutAccBufferTrs}, /*bounds=*/ArrayRef{g1Mpt, g1Npt}, /*strides=*/ArrayRef{1, 1}, - /*useIndexDiffs=*/true, /*forceUnroll=*/true); + /*forceUnroll=*/true, /*useIndexDiffs=*/true); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(loop.getBody()); @@ -1330,13 +1330,14 @@ struct GridwiseAttentionAccelRewritePattern // attention kernel will perform softmax normalization on rows. // Therefore, having zeros -- zero not being the minimum representable // value in the element type -- going to affect all the values - // post normalization. Therefore, this function creates a trasnforming + // post normalization. Therefore, this function creates a transforming // for loop that overwrites out of bounds values of first gemm output // to be negative infinity. 
- void createFirstGemmNegInfPadding( - PatternRewriter &rewriter, Location loc, - layout::GridCoordinates gridCoords, Value gemm0OutBuffer, - RegsAsMatrixSubTiles gemm0OutSubTileViews) const { + void createFirstGemmNegInfPadding(PatternRewriter &rewriter, Location loc, + layout::GridCoordinates gridCoords, + Value gemm0OutBuffer, + RegsAsMatrixSubTiles gemm0OutSubTileViews, + bool isGfx11) const { MemRefType gemm0OutBufferType = cast(gemm0OutBuffer.getType()); auto negInfTyped = createConstantFloatOp( rewriter, loc, gemm0OutBufferType.getElementType(), @@ -1346,6 +1347,9 @@ struct GridwiseAttentionAccelRewritePattern auto tid = rewriter.create(loc, rewriter.getIndexType()); int64_t elementsInThreadBuffer = gemm0OutBufferType.getNumElements(); Value zero = rewriter.createOrFold(loc, 0); + + // TODO: fix forceUnroll=false for gfx1100 + // (https://github.com/ROCm/rocMLIR-internal/issues/1661) auto loop = rewriter.create( loc, ArrayRef{{gridCoords.g_block, gridCoords.m_block, @@ -1355,7 +1359,7 @@ struct GridwiseAttentionAccelRewritePattern rewriter.getArrayAttr({})}, /*bounds=*/ArrayRef{1, 1, 1, 1, elementsInThreadBuffer}, /*strides=*/ArrayRef{1, 1, 1, 1, 1}, - /*useIndexDiffs=*/true, /*forceUnroll=*/true); + /*forceUnroll=*/!isGfx11, /*useIndexDiffs=*/true); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(loop.getBody()); @@ -1376,6 +1380,59 @@ struct GridwiseAttentionAccelRewritePattern } } + void setGemm0OutputOutOfScopeKVCache( + PatternRewriter &rewriter, Location loc, + layout::GridCoordinates gridCoords, Value gemm0OutBuffer, + RegsAsMatrixSubTiles gemm0OutSubTileViews, Value currentSeqLen, + Value mLoopIV, Value gemm0MBlocksLastIter) const { + if (currentSeqLen) { + auto isLastIteration = rewriter.create( + loc, arith::CmpIPredicate::eq, mLoopIV, gemm0MBlocksLastIter); + scf::IfOp ifb = rewriter.create(loc, isLastIteration, + /*withElseRegion=*/false); + { + OpBuilder thenb = ifb.getThenBodyBuilder(); + + MemRefType gemm0OutBufferType = + cast(gemm0OutBuffer.getType()); + auto negInfTyped = createConstantFloatOp( + thenb, loc, gemm0OutBufferType.getElementType(), + gemm0OutBufferType.getElementType(), + -std::numeric_limits::infinity()); + // Get current workitem ID. 
+ auto tid = thenb.create(loc, thenb.getIndexType()); + int64_t elementsInThreadBuffer = gemm0OutBufferType.getNumElements(); + Value zero = thenb.createOrFold(loc, 0); + auto loop = thenb.create( + loc, + ArrayRef{{gridCoords.g_block, gridCoords.m_block, + gridCoords.n_block, tid, zero}, + {zero, zero, zero, zero, zero}}, + ArrayRef{gemm0OutSubTileViews.gridSubTile, + thenb.getArrayAttr({})}, + /*bounds=*/ArrayRef{1, 1, 1, 1, elementsInThreadBuffer}, + /*strides=*/ArrayRef{1, 1, 1, 1, 1}, + /*forceUnroll=*/true, /*useIndexDiffs=*/true); + { + OpBuilder::InsertionGuard guard(thenb); + thenb.setInsertionPointToStart(loop.getBody()); + + Block::BlockArgListType lowerCoords = loop.getLowerCoords(0); + Block::BlockArgListType upperCoords = loop.getLowerCoords(1); + auto isInvalid = thenb.create( + loc, arith::CmpIPredicate::uge, lowerCoords[2], currentSeqLen); + scf::IfOp ifb = thenb.create(loc, isInvalid, + /*withElseRegion=*/false); + { + OpBuilder thenb = ifb.getThenBodyBuilder(); + thenb.create(loc, negInfTyped, gemm0OutBuffer, + ValueRange{upperCoords[4]}); + } + } + } + } + } + template void postProcessFirstGemmSplat(PatternRewriter &rewriter, Location loc, layout::GridCoordinates gridCoords, @@ -1627,6 +1684,8 @@ struct GridwiseAttentionAccelRewritePattern ArrayRef outShape = cast(trOut.getType()).getShape(); Type elemTypeOut = cast(trOut.getType()).getElementType(); + TypedValue currentSeqLenTensor = op.getCurrentSeqLen(); + // Gemm0 out is casted to be elemTypeV Type elemTypeQxK = elemTypeV; @@ -1924,17 +1983,85 @@ struct GridwiseAttentionAccelRewritePattern } bool isReverseGrid = succeeded(rock::getReverseGrid(op)); - affine::AffineForOp mLoopOp = - rewriter.create(loc, 0, gemm0MBlocks, 1); + if (isReverseGrid && currentSeqLenTensor) { + return op.emitError( + "reverse grid is not compatible with currentSeqLen\n"); + } + + LoopLikeOpInterface mLoopOp; + Value gemm0MBlocksLastIter; + Value currentSeqLen; + // This is needed for KV Cache support + if (currentSeqLenTensor) { + Value zero = rewriter.createOrFold(loc, 0); + auto gridCoordsGemm0LoadCurrSeqLen = layout::makeGxNGridLayout( + rewriter, loc, bid, zero, gemm0NBlocks, gridSize, arch); + + // add dim 1 for thread_read_into (registers) + ArrayRef inpShape = + cast(currentSeqLenTensor.getType()).getShape(); + SmallVector startNames = {"gemmG"}; + rock::BottomUpTMBuilder addDim(rewriter, startNames, inpShape); + addDim.addDim("dummy", 1, 1); + addDim.passThrough(ArrayRef{0}, ArrayRef{0}); + auto addDimAttr = addDim.get(); + Value currentSeqLenTensorAddDim = rewriter.create( + loc, currentSeqLenTensor, addDimAttr); + Type currentSeqLenElemType = + getElementTypeOrSelf(currentSeqLenTensorAddDim.getType()); + + // create registers + auto privateMemoryAddressSpace = rewriter.getAttr( + gpu::GPUDialect::getPrivateAddressSpace()); + auto memrefType = MemRefType::get({1}, currentSeqLenElemType, AffineMap{}, + privateMemoryAddressSpace); + auto currentSeqLenLoad = rewriter.create(loc, memrefType); + + // load from memory to registers + rewriter.create( + loc, vectorOfBoolShapedLike(currentSeqLenLoad), + currentSeqLenTensorAddDim, currentSeqLenLoad, + /*dynamicValidities=*/ValueRange{}, + /*extraViews=*/rewriter.getArrayAttr({}), + /*extraIndices=*/ + ValueRange{gridCoordsGemm0LoadCurrSeqLen.g_block}, true, true); + + // load from registers + Value currentSeqLenValue = rewriter.create( + loc, currentSeqLenElemType, currentSeqLenLoad, ValueRange{zero}); + + currentSeqLen = rewriter.createOrFold( + loc, rewriter.getIndexType(), 
currentSeqLenValue); + + Value constGemm0MPerBlock = + rewriter.createOrFold(loc, gemm0MPerBlock); + Value constGemm0MPerBlockM1 = + rewriter.createOrFold(loc, + gemm0MPerBlock - 1); + Value numerator = rewriter.create(loc, currentSeqLen, + constGemm0MPerBlockM1); + Value gemm0MBlocksEarlyExit = rewriter.createOrFold( + loc, numerator, constGemm0MPerBlock); + Value one = rewriter.createOrFold(loc, 1); + gemm0MBlocksLastIter = + rewriter.createOrFold(loc, gemm0MBlocksEarlyExit, one); + + mLoopOp = + rewriter.create(loc, zero, gemm0MBlocksEarlyExit, one); + } else { + mLoopOp = rewriter.create(loc, 0, gemm0MBlocks, 1); + } { PatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(mLoopOp.getBody()); + // workaround for mLoopOp.getBody() + assert(mLoopOp->getRegions().size() == 1); + rewriter.setInsertionPointToStart(&mLoopOp->getRegion(0).front()); int64_t kIterationsGemm0 = gemm0K / gemm0KPerBlock; Value kIterationsGemm0Val = rewriter.createOrFold(loc, kIterationsGemm0); Value mIterationsGemm0Val = rewriter.createOrFold(loc, gemm0MBlocks); - Value mLoopIV = mLoopOp.getInductionVar(); + Value mLoopIV = mLoopOp.getSingleInductionVar().value(); if (isReverseGrid) { AffineMap reverseMap = rock::getIdxReversalMap(rewriter); mLoopIV = rewriter.createOrFold( @@ -2085,21 +2212,27 @@ struct GridwiseAttentionAccelRewritePattern // Scale gemm0 output by (1/ln2) // So that we can use exp2 instead of exp. #ifndef ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX - Value ln2Recip = createConstantFloatOp(rewriter, loc, elemTypeQxK, - elemTypeQxK, 1.44269504); + Value ln2Recip = createConstantFloatOp( + rewriter, loc, elemTypeQxK, elemTypeQxK, 1.44269504f, + elemTypeQxK.isF32() ? APFloat::opOK : APFloat::opInexact); postProcessFirstGemmSplat( rewriter, loc, gridCoordsGemm0, gemm0OutBuffer, gemm0OutSubTileViews, ln2Recip.getDefiningOp().getValue()); -#endif // Handle padding bool hasPadding = op.getPrePadG0M().has_value() || op.getPrePadG0N().has_value(); if (hasPadding) { + bool isGfx11 = arch.contains("gfx11"); createFirstGemmNegInfPadding(rewriter, loc, gridCoordsGemm0, gemm0OutBuffer, - gemm0OutSubTileViewsTrUnPadded); + gemm0OutSubTileViewsTrUnPadded, isGfx11); } + // Negative Infinite for extra values (KV cache) + setGemm0OutputOutOfScopeKVCache( + rewriter, loc, gridCoordsGemm0, gemm0OutBuffer, + gemm0OutSubTileViewsTr, currentSeqLen, mLoopIV, gemm0MBlocksLastIter); +#endif APInt reductionAxis = APInt(64, 1); APInt nrDimPerThread = APInt(64, gemm0MPerBlock / gemm0MPerThread); diff --git a/mlir/lib/Dialect/Rock/utility/builderUtils.cpp b/mlir/lib/Dialect/Rock/utility/builderUtils.cpp index 42c8da8e6363..0d4b10deff2b 100644 --- a/mlir/lib/Dialect/Rock/utility/builderUtils.cpp +++ b/mlir/lib/Dialect/Rock/utility/builderUtils.cpp @@ -40,21 +40,22 @@ Value createConstantIntOp(OpBuilder &b, Location loc, Type type, } Value createConstantFloatOp(OpBuilder &b, Location loc, Type type, - Type elementType, float value) { + Type elemType, float value, + APFloat::opStatus expectedStatus) { auto semantics = static_cast(-1); - if (elementType.isF32()) { + if (elemType.isF32()) { semantics = APFloat::S_IEEEsingle; - } else if (elementType.isF16()) { + } else if (elemType.isF16()) { semantics = APFloat::S_IEEEhalf; - } else if (elementType.isBF16()) { + } else if (elemType.isBF16()) { semantics = APFloat::S_BFloat; - } else if (elementType.isFloat8E4M3FNUZ()) { + } else if (elemType.isFloat8E4M3FNUZ()) { semantics = APFloat::S_Float8E4M3FNUZ; - } else if (elementType.isFloat8E5M2FNUZ()) { + } else if 
(elemType.isFloat8E5M2FNUZ()) { semantics = APFloat::S_Float8E5M2FNUZ; - } else if (elementType.isFloat8E4M3FN()) { + } else if (elemType.isFloat8E4M3FN()) { semantics = APFloat::S_Float8E4M3FN; - } else if (elementType.isFloat8E5M2()) { + } else if (elemType.isFloat8E5M2()) { semantics = APFloat::S_Float8E5M2; } else { llvm_unreachable("Unexpected float semantics"); @@ -62,17 +63,19 @@ Value createConstantFloatOp(OpBuilder &b, Location loc, Type type, APFloat apValue(value); bool lostInfo = false; - apValue.convert(APFloat::EnumToSemantics(semantics), - APFloat::rmNearestTiesToEven, &lostInfo); + auto status = apValue.convert(APFloat::EnumToSemantics(semantics), + APFloat::rmNearestTiesToEven, &lostInfo); + + assert(status == expectedStatus); Value retValue; if (auto shapedType = dyn_cast(type)) { - Attribute constValue = b.getFloatAttr(elementType, apValue); + Attribute constValue = b.getFloatAttr(elemType, apValue); + assert(shapedType.getElementType() == elemType); retValue = b.create( loc, SplatElementsAttr::get(shapedType, constValue)); } else { - retValue = - b.create(loc, type, b.getFloatAttr(elementType, value)); + retValue = b.create(loc, type, b.getFloatAttr(elemType, value)); } return retValue; diff --git a/mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-gqa.mlir b/mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-gqa.mlir new file mode 100644 index 000000000000..f741126443ee --- /dev/null +++ b/mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-gqa.mlir @@ -0,0 +1,174 @@ +// RUN: rocmlir-opt --tosa-to-rock %s -verify-diagnostics -o -| FileCheck %s + +// CHECK: rock.attention +func.func @self_attention_gqa(%arg0: tensor<4096xf32> {mhal.read_access}, %arg1: tensor<8192xf32> {mhal.read_access}, %arg2: tensor<4096xf32> {mhal.read_access}) -> (tensor<8192xf32> {mhal.write_access}) attributes {kernel, arch = ""} { + %expanded = tensor.expand_shape %arg1 [[0, 1, 2, 3]] output_shape [2, 4, 32, 32] : tensor<8192xf32> into tensor<2x4x32x32xf32> + %expanded_0 = tensor.expand_shape %arg2 [[0, 1, 2, 3]] output_shape [2, 2, 32, 32] : tensor<4096xf32> into tensor<2x2x32x32xf32> + %expanded_1 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [2, 2, 32, 32] : tensor<4096xf32> into tensor<2x2x32x32xf32> + %expanded_2 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 32, 32] : tensor<4096xf32> into tensor<2x2x1x32x32xf32> + %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x2x2x32x32xf32>}> : () -> tensor<2x2x2x32x32xf32> + %1 = tosa.add %expanded_2, %0 : (tensor<2x2x1x32x32xf32>, tensor<2x2x2x32x32xf32>) -> tensor<2x2x2x32x32xf32> + %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<2x4x32x32xf32> + %expanded_3 = tensor.expand_shape %arg2 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 32, 32] : tensor<4096xf32> into tensor<2x2x1x32x32xf32> + %2 = tosa.add %expanded_3, %0 : (tensor<2x2x1x32x32xf32>, tensor<2x2x2x32x32xf32>) -> tensor<2x2x2x32x32xf32> + %collapsed_4 = tensor.collapse_shape %2 [[0], [1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<2x4x32x32xf32> + %3 = "tosa.const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> + %4 = tosa.transpose %collapsed_4, %3 : (tensor<2x4x32x32xf32>, tensor<4xi32>) -> tensor<2x4x32x32xf32> + %expanded_5 = tensor.expand_shape %arg1 [[0, 1, 2]] output_shape [8, 32, 32] : tensor<8192xf32> into tensor<8x32x32xf32> + %collapsed_6 = tensor.collapse_shape %4 [[0, 1], [2], [3]] : tensor<2x4x32x32xf32> into tensor<8x32x32xf32> + %5 = 
tosa.matmul %expanded_5, %collapsed_6 : (tensor<8x32x32xf32>, tensor<8x32x32xf32>) -> tensor<8x32x32xf32> + %expanded_7 = tensor.expand_shape %5 [[0, 1], [2], [3]] output_shape [2, 4, 32, 32] : tensor<8x32x32xf32> into tensor<2x4x32x32xf32> + %6 = tosa.reduce_max %expanded_7 {axis = 3 : i32} : (tensor<2x4x32x32xf32>) -> tensor<2x4x32x1xf32> + %7 = tosa.sub %expanded_7, %6 : (tensor<2x4x32x32xf32>, tensor<2x4x32x1xf32>) -> tensor<2x4x32x32xf32> + %8 = tosa.exp %7 : (tensor<2x4x32x32xf32>) -> tensor<2x4x32x32xf32> + %9 = tosa.reduce_sum %8 {axis = 3 : i32} : (tensor<2x4x32x32xf32>) -> tensor<2x4x32x1xf32> + %10 = tosa.reciprocal %9 : (tensor<2x4x32x1xf32>) -> tensor<2x4x32x1xf32> + %11 = tosa.mul %8, %10 {shift = 0 : i8} : (tensor<2x4x32x32xf32>, tensor<2x4x32x1xf32>) -> tensor<2x4x32x32xf32> + %collapsed_8 = tensor.collapse_shape %11 [[0, 1], [2], [3]] : tensor<2x4x32x32xf32> into tensor<8x32x32xf32> + %collapsed_9 = tensor.collapse_shape %1 [[0, 1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<8x32x32xf32> + %12 = tosa.matmul %collapsed_8, %collapsed_9 : (tensor<8x32x32xf32>, tensor<8x32x32xf32>) -> tensor<8x32x32xf32> + %expanded_10 = tensor.expand_shape %12 [[0, 1], [2], [3]] output_shape [2, 4, 32, 32] : tensor<8x32x32xf32> into tensor<2x4x32x32xf32> + %collapsed_11 = tensor.collapse_shape %12 [[0, 1, 2]] : tensor<8x32x32xf32> into tensor<8192xf32> + return %collapsed_11 : tensor<8192xf32> +} + + +// CHECK: rock.attention +func.func @self_attention_gqa_bias(%arg0: tensor<4096xf32> {mhal.read_access}, %arg1: tensor<8192xf32> {mhal.read_access}, %arg2: tensor<4096xf32> {mhal.read_access}, %arg3: tensor<8192xf32> {mhal.read_access}) -> (tensor<8192xf32> {mhal.write_access}) attributes {kernel, arch = ""} { + %expanded = tensor.expand_shape %arg3 [[0, 1, 2, 3]] output_shape [2, 4, 32, 32] : tensor<8192xf32> into tensor<2x4x32x32xf32> + %expanded_0 = tensor.expand_shape %arg1 [[0, 1, 2, 3]] output_shape [2, 4, 32, 32] : tensor<8192xf32> into tensor<2x4x32x32xf32> + %expanded_1 = tensor.expand_shape %arg2 [[0, 1, 2, 3]] output_shape [2, 2, 32, 32] : tensor<4096xf32> into tensor<2x2x32x32xf32> + %expanded_2 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [2, 2, 32, 32] : tensor<4096xf32> into tensor<2x2x32x32xf32> + %expanded_3 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 32, 32] : tensor<4096xf32> into tensor<2x2x1x32x32xf32> + %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x2x2x32x32xf32>}> : () -> tensor<2x2x2x32x32xf32> + %1 = tosa.add %expanded_3, %0 : (tensor<2x2x1x32x32xf32>, tensor<2x2x2x32x32xf32>) -> tensor<2x2x2x32x32xf32> + %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<2x4x32x32xf32> + %expanded_4 = tensor.expand_shape %arg2 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 32, 32] : tensor<4096xf32> into tensor<2x2x1x32x32xf32> + %2 = tosa.add %expanded_4, %0 : (tensor<2x2x1x32x32xf32>, tensor<2x2x2x32x32xf32>) -> tensor<2x2x2x32x32xf32> + %collapsed_5 = tensor.collapse_shape %2 [[0], [1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<2x4x32x32xf32> + %3 = "tosa.const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> + %4 = tosa.transpose %collapsed_5, %3 : (tensor<2x4x32x32xf32>, tensor<4xi32>) -> tensor<2x4x32x32xf32> + %expanded_6 = tensor.expand_shape %arg1 [[0, 1, 2]] output_shape [8, 32, 32] : tensor<8192xf32> into tensor<8x32x32xf32> + %collapsed_7 = tensor.collapse_shape %4 [[0, 1], [2], [3]] : tensor<2x4x32x32xf32> into tensor<8x32x32xf32> + %5 = 
tosa.matmul %expanded_6, %collapsed_7 : (tensor<8x32x32xf32>, tensor<8x32x32xf32>) -> tensor<8x32x32xf32> + %expanded_8 = tensor.expand_shape %5 [[0, 1], [2], [3]] output_shape [2, 4, 32, 32] : tensor<8x32x32xf32> into tensor<2x4x32x32xf32> + %6 = tosa.add %expanded_8, %expanded : (tensor<2x4x32x32xf32>, tensor<2x4x32x32xf32>) -> tensor<2x4x32x32xf32> + %7 = tosa.reduce_max %6 {axis = 3 : i32} : (tensor<2x4x32x32xf32>) -> tensor<2x4x32x1xf32> + %8 = tosa.sub %6, %7 : (tensor<2x4x32x32xf32>, tensor<2x4x32x1xf32>) -> tensor<2x4x32x32xf32> + %9 = tosa.exp %8 : (tensor<2x4x32x32xf32>) -> tensor<2x4x32x32xf32> + %10 = tosa.reduce_sum %9 {axis = 3 : i32} : (tensor<2x4x32x32xf32>) -> tensor<2x4x32x1xf32> + %11 = tosa.reciprocal %10 : (tensor<2x4x32x1xf32>) -> tensor<2x4x32x1xf32> + %12 = tosa.mul %9, %11 {shift = 0 : i8} : (tensor<2x4x32x32xf32>, tensor<2x4x32x1xf32>) -> tensor<2x4x32x32xf32> + %collapsed_9 = tensor.collapse_shape %12 [[0, 1], [2], [3]] : tensor<2x4x32x32xf32> into tensor<8x32x32xf32> + %collapsed_10 = tensor.collapse_shape %1 [[0, 1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<8x32x32xf32> + %13 = tosa.matmul %collapsed_9, %collapsed_10 : (tensor<8x32x32xf32>, tensor<8x32x32xf32>) -> tensor<8x32x32xf32> + %expanded_11 = tensor.expand_shape %13 [[0, 1], [2], [3]] output_shape [2, 4, 32, 32] : tensor<8x32x32xf32> into tensor<2x4x32x32xf32> + %collapsed_12 = tensor.collapse_shape %13 [[0, 1, 2]] : tensor<8x32x32xf32> into tensor<8192xf32> + return %collapsed_12 : tensor<8192xf32> +} + +// CHECK: rock.attention +func.func @self_attention_gqa_scale(%arg0: tensor<4096xf32> {mhal.read_access}, %arg1: tensor<8192xf32> {mhal.read_access}, %arg2: tensor<8192xf32> {mhal.read_access}, %arg3: tensor<4096xf32> {mhal.read_access}) -> (tensor<8192xf32> {mhal.write_access}) attributes {kernel, arch = ""} { + %expanded = tensor.expand_shape %arg1 [[0, 1, 2, 3]] output_shape [2, 4, 32, 32] : tensor<8192xf32> into tensor<2x4x32x32xf32> + %expanded_0 = tensor.expand_shape %arg2 [[0, 1, 2, 3]] output_shape [2, 4, 32, 32] : tensor<8192xf32> into tensor<2x4x32x32xf32> + %expanded_1 = tensor.expand_shape %arg3 [[0, 1, 2, 3]] output_shape [2, 2, 32, 32] : tensor<4096xf32> into tensor<2x2x32x32xf32> + %expanded_2 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [2, 2, 32, 32] : tensor<4096xf32> into tensor<2x2x32x32xf32> + %expanded_3 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 32, 32] : tensor<4096xf32> into tensor<2x2x1x32x32xf32> + %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x2x2x32x32xf32>}> : () -> tensor<2x2x2x32x32xf32> + %1 = tosa.add %expanded_3, %0 : (tensor<2x2x1x32x32xf32>, tensor<2x2x2x32x32xf32>) -> tensor<2x2x2x32x32xf32> + %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<2x4x32x32xf32> + %expanded_4 = tensor.expand_shape %arg3 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 32, 32] : tensor<4096xf32> into tensor<2x2x1x32x32xf32> + %2 = tosa.add %expanded_4, %0 : (tensor<2x2x1x32x32xf32>, tensor<2x2x2x32x32xf32>) -> tensor<2x2x2x32x32xf32> + %collapsed_5 = tensor.collapse_shape %2 [[0], [1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<2x4x32x32xf32> + %3 = "tosa.const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> + %4 = tosa.transpose %collapsed_5, %3 : (tensor<2x4x32x32xf32>, tensor<4xi32>) -> tensor<2x4x32x32xf32> + %expanded_6 = tensor.expand_shape %arg2 [[0, 1, 2]] output_shape [8, 32, 32] : tensor<8192xf32> into tensor<8x32x32xf32> + %collapsed_7 = 
tensor.collapse_shape %4 [[0, 1], [2], [3]] : tensor<2x4x32x32xf32> into tensor<8x32x32xf32> + %5 = tosa.matmul %expanded_6, %collapsed_7 : (tensor<8x32x32xf32>, tensor<8x32x32xf32>) -> tensor<8x32x32xf32> + %expanded_8 = tensor.expand_shape %5 [[0, 1], [2], [3]] output_shape [2, 4, 32, 32] : tensor<8x32x32xf32> into tensor<2x4x32x32xf32> + %6 = tosa.mul %expanded_8, %expanded {shift = 0 : i8} : (tensor<2x4x32x32xf32>, tensor<2x4x32x32xf32>) -> tensor<2x4x32x32xf32> + %7 = tosa.reduce_max %6 {axis = 3 : i32} : (tensor<2x4x32x32xf32>) -> tensor<2x4x32x1xf32> + %8 = tosa.sub %6, %7 : (tensor<2x4x32x32xf32>, tensor<2x4x32x1xf32>) -> tensor<2x4x32x32xf32> + %9 = tosa.exp %8 : (tensor<2x4x32x32xf32>) -> tensor<2x4x32x32xf32> + %10 = tosa.reduce_sum %9 {axis = 3 : i32} : (tensor<2x4x32x32xf32>) -> tensor<2x4x32x1xf32> + %11 = tosa.reciprocal %10 : (tensor<2x4x32x1xf32>) -> tensor<2x4x32x1xf32> + %12 = tosa.mul %9, %11 {shift = 0 : i8} : (tensor<2x4x32x32xf32>, tensor<2x4x32x1xf32>) -> tensor<2x4x32x32xf32> + %collapsed_9 = tensor.collapse_shape %12 [[0, 1], [2], [3]] : tensor<2x4x32x32xf32> into tensor<8x32x32xf32> + %collapsed_10 = tensor.collapse_shape %1 [[0, 1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<8x32x32xf32> + %13 = tosa.matmul %collapsed_9, %collapsed_10 : (tensor<8x32x32xf32>, tensor<8x32x32xf32>) -> tensor<8x32x32xf32> + %expanded_11 = tensor.expand_shape %13 [[0, 1], [2], [3]] output_shape [2, 4, 32, 32] : tensor<8x32x32xf32> into tensor<2x4x32x32xf32> + %collapsed_12 = tensor.collapse_shape %13 [[0, 1, 2]] : tensor<8x32x32xf32> into tensor<8192xf32> + return %collapsed_12 : tensor<8192xf32> +} + +// CHECK: rock.attention +func.func @self_attention_gqa_scale_bias(%arg0: tensor<4096xf32> {mhal.read_access}, %arg1: tensor<8192xf32> {mhal.read_access}, %arg2: tensor<8192xf32> {mhal.read_access}, %arg3: tensor<4096xf32> {mhal.read_access}, %arg4: tensor<8192xf32> {mhal.read_access}) -> (tensor<8192xf32> {mhal.write_access}) attributes {kernel, arch = ""} { + %expanded = tensor.expand_shape %arg4 [[0, 1, 2, 3]] output_shape [2, 4, 32, 32] : tensor<8192xf32> into tensor<2x4x32x32xf32> + %expanded_0 = tensor.expand_shape %arg1 [[0, 1, 2, 3]] output_shape [2, 4, 32, 32] : tensor<8192xf32> into tensor<2x4x32x32xf32> + %expanded_1 = tensor.expand_shape %arg2 [[0, 1, 2, 3]] output_shape [2, 4, 32, 32] : tensor<8192xf32> into tensor<2x4x32x32xf32> + %expanded_2 = tensor.expand_shape %arg3 [[0, 1, 2, 3]] output_shape [2, 2, 32, 32] : tensor<4096xf32> into tensor<2x2x32x32xf32> + %expanded_3 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [2, 2, 32, 32] : tensor<4096xf32> into tensor<2x2x32x32xf32> + %expanded_4 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 32, 32] : tensor<4096xf32> into tensor<2x2x1x32x32xf32> + %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x2x2x32x32xf32>}> : () -> tensor<2x2x2x32x32xf32> + %1 = tosa.add %expanded_4, %0 : (tensor<2x2x1x32x32xf32>, tensor<2x2x2x32x32xf32>) -> tensor<2x2x2x32x32xf32> + %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<2x4x32x32xf32> + %expanded_5 = tensor.expand_shape %arg3 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 32, 32] : tensor<4096xf32> into tensor<2x2x1x32x32xf32> + %2 = tosa.add %expanded_5, %0 : (tensor<2x2x1x32x32xf32>, tensor<2x2x2x32x32xf32>) -> tensor<2x2x2x32x32xf32> + %collapsed_6 = tensor.collapse_shape %2 [[0], [1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<2x4x32x32xf32> + %3 = "tosa.const"() <{value = dense<[0, 1, 
3, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> + %4 = tosa.transpose %collapsed_6, %3 : (tensor<2x4x32x32xf32>, tensor<4xi32>) -> tensor<2x4x32x32xf32> + %expanded_7 = tensor.expand_shape %arg2 [[0, 1, 2]] output_shape [8, 32, 32] : tensor<8192xf32> into tensor<8x32x32xf32> + %collapsed_8 = tensor.collapse_shape %4 [[0, 1], [2], [3]] : tensor<2x4x32x32xf32> into tensor<8x32x32xf32> + %5 = tosa.matmul %expanded_7, %collapsed_8 : (tensor<8x32x32xf32>, tensor<8x32x32xf32>) -> tensor<8x32x32xf32> + %expanded_9 = tensor.expand_shape %5 [[0, 1], [2], [3]] output_shape [2, 4, 32, 32] : tensor<8x32x32xf32> into tensor<2x4x32x32xf32> + %6 = tosa.mul %expanded_9, %expanded_0 {shift = 0 : i8} : (tensor<2x4x32x32xf32>, tensor<2x4x32x32xf32>) -> tensor<2x4x32x32xf32> + %7 = tosa.add %6, %expanded : (tensor<2x4x32x32xf32>, tensor<2x4x32x32xf32>) -> tensor<2x4x32x32xf32> + %8 = tosa.reduce_max %7 {axis = 3 : i32} : (tensor<2x4x32x32xf32>) -> tensor<2x4x32x1xf32> + %9 = tosa.sub %7, %8 : (tensor<2x4x32x32xf32>, tensor<2x4x32x1xf32>) -> tensor<2x4x32x32xf32> + %10 = tosa.exp %9 : (tensor<2x4x32x32xf32>) -> tensor<2x4x32x32xf32> + %11 = tosa.reduce_sum %10 {axis = 3 : i32} : (tensor<2x4x32x32xf32>) -> tensor<2x4x32x1xf32> + %12 = tosa.reciprocal %11 : (tensor<2x4x32x1xf32>) -> tensor<2x4x32x1xf32> + %13 = tosa.mul %10, %12 {shift = 0 : i8} : (tensor<2x4x32x32xf32>, tensor<2x4x32x1xf32>) -> tensor<2x4x32x32xf32> + %collapsed_10 = tensor.collapse_shape %13 [[0, 1], [2], [3]] : tensor<2x4x32x32xf32> into tensor<8x32x32xf32> + %collapsed_11 = tensor.collapse_shape %1 [[0, 1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<8x32x32xf32> + %14 = tosa.matmul %collapsed_10, %collapsed_11 : (tensor<8x32x32xf32>, tensor<8x32x32xf32>) -> tensor<8x32x32xf32> + %expanded_12 = tensor.expand_shape %14 [[0, 1], [2], [3]] output_shape [2, 4, 32, 32] : tensor<8x32x32xf32> into tensor<2x4x32x32xf32> + %collapsed_13 = tensor.collapse_shape %14 [[0, 1, 2]] : tensor<8x32x32xf32> into tensor<8192xf32> + return %collapsed_13 : tensor<8192xf32> +} + +// CHECK: rock.attention +func.func @self_attention_gqa_scale_bias_kvcache(%arg0: tensor<4096xf32> {mhal.read_access}, %arg1: tensor<256xf32> {mhal.read_access}, %arg2: tensor<256xf32> {mhal.read_access}, %arg3: tensor<4096xf32> {mhal.read_access}, %arg4: tensor<256xf32> {mhal.read_access}) -> (tensor<256xf32> {mhal.write_access}) attributes {kernel, arch = ""} { + %expanded = tensor.expand_shape %arg4 [[0, 1, 2, 3]] output_shape [2, 4, 1, 32] : tensor<256xf32> into tensor<2x4x1x32xf32> + %expanded_0 = tensor.expand_shape %arg1 [[0, 1, 2, 3]] output_shape [2, 4, 1, 32] : tensor<256xf32> into tensor<2x4x1x32xf32> + %expanded_1 = tensor.expand_shape %arg2 [[0, 1, 2, 3]] output_shape [2, 4, 1, 32] : tensor<256xf32> into tensor<2x4x1x32xf32> + %expanded_2 = tensor.expand_shape %arg3 [[0, 1, 2, 3]] output_shape [2, 2, 32, 32] : tensor<4096xf32> into tensor<2x2x32x32xf32> + %expanded_3 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [2, 2, 32, 32] : tensor<4096xf32> into tensor<2x2x32x32xf32> + %expanded_4 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 32, 32] : tensor<4096xf32> into tensor<2x2x1x32x32xf32> + %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x2x2x32x32xf32>}> : () -> tensor<2x2x2x32x32xf32> + %1 = tosa.add %expanded_4, %0 : (tensor<2x2x1x32x32xf32>, tensor<2x2x2x32x32xf32>) -> tensor<2x2x2x32x32xf32> + %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<2x4x32x32xf32> + %expanded_5 = 
tensor.expand_shape %arg3 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 32, 32] : tensor<4096xf32> into tensor<2x2x1x32x32xf32> + %2 = tosa.add %expanded_5, %0 : (tensor<2x2x1x32x32xf32>, tensor<2x2x2x32x32xf32>) -> tensor<2x2x2x32x32xf32> + %collapsed_6 = tensor.collapse_shape %2 [[0], [1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<2x4x32x32xf32> + %3 = "tosa.const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> + %4 = tosa.transpose %collapsed_6, %3 : (tensor<2x4x32x32xf32>, tensor<4xi32>) -> tensor<2x4x32x32xf32> + %expanded_7 = tensor.expand_shape %arg2 [[0, 1, 2]] output_shape [8, 1, 32] : tensor<256xf32> into tensor<8x1x32xf32> + %collapsed_8 = tensor.collapse_shape %4 [[0, 1], [2], [3]] : tensor<2x4x32x32xf32> into tensor<8x32x32xf32> + %5 = tosa.matmul %expanded_7, %collapsed_8 : (tensor<8x1x32xf32>, tensor<8x32x32xf32>) -> tensor<8x1x32xf32> + %expanded_9 = tensor.expand_shape %5 [[0, 1], [2], [3]] output_shape [2, 4, 1, 32] : tensor<8x1x32xf32> into tensor<2x4x1x32xf32> + %6 = tosa.mul %expanded_9, %expanded_0 {shift = 0 : i8} : (tensor<2x4x1x32xf32>, tensor<2x4x1x32xf32>) -> tensor<2x4x1x32xf32> + %7 = tosa.add %6, %expanded : (tensor<2x4x1x32xf32>, tensor<2x4x1x32xf32>) -> tensor<2x4x1x32xf32> + %8 = tosa.reduce_max %7 {axis = 3 : i32} : (tensor<2x4x1x32xf32>) -> tensor<2x4x1x1xf32> + %9 = tosa.sub %7, %8 : (tensor<2x4x1x32xf32>, tensor<2x4x1x1xf32>) -> tensor<2x4x1x32xf32> + %10 = tosa.exp %9 : (tensor<2x4x1x32xf32>) -> tensor<2x4x1x32xf32> + %11 = tosa.reduce_sum %10 {axis = 3 : i32} : (tensor<2x4x1x32xf32>) -> tensor<2x4x1x1xf32> + %12 = tosa.reciprocal %11 : (tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32> + %13 = tosa.mul %10, %12 {shift = 0 : i8} : (tensor<2x4x1x32xf32>, tensor<2x4x1x1xf32>) -> tensor<2x4x1x32xf32> + %collapsed_10 = tensor.collapse_shape %13 [[0, 1], [2], [3]] : tensor<2x4x1x32xf32> into tensor<8x1x32xf32> + %collapsed_11 = tensor.collapse_shape %1 [[0, 1, 2], [3], [4]] : tensor<2x2x2x32x32xf32> into tensor<8x32x32xf32> + %14 = tosa.matmul %collapsed_10, %collapsed_11 : (tensor<8x1x32xf32>, tensor<8x32x32xf32>) -> tensor<8x1x32xf32> + %expanded_12 = tensor.expand_shape %14 [[0, 1], [2], [3]] output_shape [2, 4, 1, 32] : tensor<8x1x32xf32> into tensor<2x4x1x32xf32> + %collapsed_13 = tensor.collapse_shape %14 [[0, 1, 2]] : tensor<8x1x32xf32> into tensor<256xf32> + return %collapsed_13 : tensor<256xf32> +} diff --git a/mlir/test/Dialect/Rock/gemm_to_gridwise.mlir b/mlir/test/Dialect/Rock/gemm_to_gridwise.mlir index 43b118ee7943..ce6883807b9e 100644 --- a/mlir/test/Dialect/Rock/gemm_to_gridwise.mlir +++ b/mlir/test/Dialect/Rock/gemm_to_gridwise.mlir @@ -203,3 +203,21 @@ func.func @rock_attention_tr_padded(%arg0: memref<1x49x7xf32>, %arg1: memref<1x7 } return } + +// CHECK-LABEL: func.func @rock_attention_kvcache +// CHECK-SAME: (%[[q:.*]]: memref<1x64x1024xf32>, %[[k:.*]]: memref<1x64x1024xf32>, %[[v:.*]]: memref<1x1024x64xf32>, %[[o:.*]]: memref<1x1024x64xf32>, %[[currentSeqLen:.*]]: memref<1xi32>) +func.func @rock_attention_kvcache(%arg0: memref<1x64x1024xf32>, %arg1: memref<1x64x1024xf32>, %arg2: memref<1x1024x64xf32>, %arg3: memref<1x1024x64xf32>, %arg4: memref<1xi32>) attributes {kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx908", block_size = 64 : i32, grid_size = 1024 : i32} { + // CHECK: rock.gridwise_attention_accel(%[[q]], %[[k]], %[[v]], %[[currentSeqLen]], %[[o]]) + rock.attention{ + qk = tr %arg0 * %arg1 : memref<1x64x1024xf32>, memref<1x64x1024xf32> + currentSeqLen = (%arg4 : memref<1xi32>) + %arg3 = softmax(qk) * %arg2 : 
memref<1x1024x64xf32> -> memref<1x1024x64xf32> + } { + arch = "amdgcn-amd-amdhsa:gfx908", + features = #rock, + params0 = #xldops_attn_params_g0, + params1 = #xldops_attn_params_g1, + firstGemmIdx = 0 : i32 + } + return +} diff --git a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir index 749efaf8d050..bfd2208f875e 100644 --- a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir @@ -288,3 +288,39 @@ func.func @gridwise_attn_grid_reversed(%arg0: memref<1x384x64xf32>, %arg1: memre } : memref<1x64x384xf32>, memref<1x64x384xf32>, memref<1x384x64xf32>, memref<1x384x64xf32> return } + +// ----- + +// CHECK: @gridwise_attn_kvcache +func.func @gridwise_attn_kvcache(%arg0: memref<1x384x64xf32>, %arg1: memref<1x64x384xf32>, %arg2: memref<1x384x64xf32>, %arg3: memref<1x384x64xf32>, %arg4: memref<1xi32>) attributes {block_size = 64 : i32, grid_size = 24 : i32, kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx908:sramecc+:xnack-"} { + %0 = rock.transform %arg0 by (d0, d2, d1)> by [ ["gemmG"] at [0]>, ["gemm0K", "gemm0M"] at [2, 1]>] bounds = [1, 64, 384] -> [1, 384, 64]> : memref<1x384x64xf32> to memref<1x64x384xf32> + // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[c31:.+]] = arith.constant 31 : index + // CHECK-DAG: %[[c32:.+]] = arith.constant 32 : index + // CHECK: %[[currSeqLenTensor:.+]] = rock.transform %arg4 by #{{.+}} : memref<1xi32> to memref<1x1xi32> + // CHECK: %[[registers:.+]] = rock.alloc() : memref<1xi32, #gpu.address_space> + // CHECK-NEXT: rock.threadwise_read_into {forceUnroll, useIndexDiffs} [](%[[currSeqLenTensor]]) [%{{.+}}] -> %[[registers]] : memref<1x1xi32> -> memref<1xi32, #gpu.address_space>, vector<1xi1> + // CHECK-NEXT: %[[currSeqLen:.+]] = rock.in_bounds_load %[[registers]][%[[c0]]] : memref<1xi32, #gpu.address_space>, index -> i32 + // CHECK-NEXT: %[[currSeqLenIndex:.+]] = arith.index_cast %[[currSeqLen]] : i32 to index + // CHECK: %[[num:.+]] = arith.addi %[[currSeqLenIndex]], %[[c31]] : index + // CHECK-NEXT: %[[numIter:.+]] = arith.divui %[[num]], %[[c32]] : index + // CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index + // CHECK-NEXT: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { + // CHECK: %[[comparison:.+]] = arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index + // CHECK-NEXT: scf.if %[[comparison]] { + // CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[dim0:.+]], %[[dim1:.+]], %[[dim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = [] + // CHECK-NEXT: %[[secondComparison:.+]] = arith.cmpi uge, %[[dim2]], %[[currSeqLenIndex]] : index + // CHECK-NEXT: scf.if %[[secondComparison]] { + // CHECK-NEXT: rock.in_bounds_store + rock.gridwise_attention_accel(%0, %arg1, %arg2, %arg4, %arg3) features = mfma|dot|atomic_add preSoftmaxOps = {} { + arch = "amdgcn-amd-amdhsa:gfx908:sramecc+:xnack-", + blockSize = 64 : i32, + gridSize = 24 : i32, + operandSegmentSizes = array, + params0 = #rock.xdlops_gemm_derived_params, + params1 = #rock.xdlops_gemm_derived_params, + firstGemmIdx = 0 : i32 + } : memref<1x64x384xf32>, memref<1x64x384xf32>, memref<1x384x64xf32>, memref<1xi32>, memref<1x384x64xf32> + return +} diff --git a/mlir/test/e2e/PrAttentionF16.toml b/mlir/test/e2e/PrAttentionF16.toml index cb25d59f2bf8..e17e6cdb0a3e 100644 --- a/mlir/test/e2e/PrAttentionF16.toml +++ b/mlir/test/e2e/PrAttentionF16.toml @@ 
-1,6 +1,6 @@ directory = "PrAttentionF16" prefix = "rocmlir-gen" -suffix = "--operation attention -t f16 --arch %arch -pv %random_data %rocmlir_gen_flags -relDiff_threshold 0.02 -absDiff_threshold 0.02 -RMS_threshold 0.015 | rocmlir-driver -c | mlir-cpu-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" +suffix = "--operation attention -t f16 --arch %arch -pv %random_data %rocmlir_gen_flags -relDiff_threshold 0.02 -absDiff_threshold 0.1 -RMS_threshold 0.015 | rocmlir-driver -c | mlir-cpu-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" [[axis]] name = "operation" @@ -65,3 +65,15 @@ config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-at # cross attention [[suite.test]] config = "-seq_len_q 128 -seq_len_k 27 -head_dim_qk 64 -head_dim_v 32 --with-attn-scale --with-attn-bias" + +# GQA +[[suite.test]] +config = "-num_heads_q 4 -num_heads_kv 2 -seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias" + +# GQA + KV Cache +[[suite.test]] +config = "-rand 1 -current_seq_len=17 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias" + +# GQA + KV Cache batch=3 +[[suite.test]] +config = "-rand 1 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/PrAttentionF32.toml b/mlir/test/e2e/PrAttentionF32.toml index 8319f0aa1b56..e8ddbc855502 100644 --- a/mlir/test/e2e/PrAttentionF32.toml +++ b/mlir/test/e2e/PrAttentionF32.toml @@ -1,6 +1,6 @@ directory = "PrAttentionF32" prefix = "rocmlir-gen" -suffix = "--operation attention -t f32 --arch %arch -pv %random_data %rocmlir_gen_flags -relDiff_threshold 0.000005 | rocmlir-driver -c | mlir-cpu-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" +suffix = "--operation attention -t f32 --arch %arch -pv %random_data %rocmlir_gen_flags -relDiff_threshold 0.00005 | rocmlir-driver -c | mlir-cpu-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" [[axis]] name = "operation" @@ -45,3 +45,14 @@ config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-at [[suite.test]] config = "-seq_len_q 128 -seq_len_k 27 -head_dim_qk 64 -head_dim_v 32 --with-attn-scale --with-attn-bias" +# GQA +[[suite.test]] +config = "-num_heads_q 4 -num_heads_kv 2 -seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias" + +# GQA + KV Cache +[[suite.test]] 
+config = "-rand 1 -current_seq_len=17 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias" + +# GQA + KV Cache batch=3 +[[suite.test]] +config = "-rand 1 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/PrAttentionI8.toml b/mlir/test/e2e/PrAttentionI8.toml index c90da602b3bb..bf178f44e29d 100644 --- a/mlir/test/e2e/PrAttentionI8.toml +++ b/mlir/test/e2e/PrAttentionI8.toml @@ -54,3 +54,14 @@ config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-at [[suite.test]] config = "-seq_len_q 128 -seq_len_k 27 -head_dim_qk 64 -head_dim_v 32 --with-attn-scale --with-attn-bias" +# GQA +[[suite.test]] +config = "-num_heads_q 4 -num_heads_kv 2 -seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias" + +# GQA + KV Cache batch=1 +[[suite.test]] +config = "-rand 1 -current_seq_len=17 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias" + +# GQA + KV Cache batch=3 +[[suite.test]] +config = "-rand 1 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias" diff --git a/mlir/test/fusion/pr-e2e/attention/mixr-attention-gqa-bias.mlir b/mlir/test/fusion/pr-e2e/attention/mixr-attention-gqa-bias.mlir new file mode 100644 index 000000000000..da073a6e3b65 --- /dev/null +++ b/mlir/test/fusion/pr-e2e/attention/mixr-attention-gqa-bias.mlir @@ -0,0 +1,21 @@ +// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_attention_wrapper -relDiff_threshold 0.000004 --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// ALLOW_RETRIES: 2 +// CHECK: [1 1 1] +module { + func.func private @mlir_attention(%v: !migraphx.shaped<2x2x32x32xf32, 2048x1024x32x1> {mhal.read_access}, + %q: !migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> {mhal.read_access}, + %k: !migraphx.shaped<2x2x32x32xf32, 2048x1024x32x1> {mhal.read_access}, + %bias: !migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> {mhal.read_access}) + -> (!migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> {mhal.write_access}) { + %vbroadcast = migraphx.multibroadcast %v {out_dyn_dims = [], out_lens = [2, 2, 2, 32, 32]} : <2x2x32x32xf32, 2048x1024x32x1> -> <2x2x2x32x32xf32, 2048x1024x0x32x1> + %vreshaped = migraphx.reshape %vbroadcast {dims = [2, 4, 32, 32]} : <2x2x2x32x32xf32, 2048x1024x0x32x1> -> <2x4x32x32xf32, 2048x1024x32x1> + %kbroadcast = migraphx.multibroadcast %k {out_dyn_dims = [], out_lens = [2, 2, 2, 32, 32]} : <2x2x32x32xf32, 2048x1024x32x1> -> <2x2x2x32x32xf32, 2048x1024x0x32x1> + %kreshaped = migraphx.reshape %kbroadcast {dims = [2, 4, 32, 32]} : <2x2x2x32x32xf32, 2048x1024x0x32x1> -> <2x4x32x32xf32, 2048x1024x32x1> + %kt = migraphx.transpose %kreshaped 
{permutation = [0, 1, 3, 2]} : <2x4x32x32xf32, 2048x1024x32x1> -> <2x4x32x32xf32, 2048x1024x32x1> + %qk = migraphx.dot %q, %kt : <2x4x32x32xf32, 4096x1024x32x1>, <2x4x32x32xf32, 2048x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + %qk_biased = migraphx.add %qk, %bias : <2x4x32x32xf32, 4096x1024x32x1>, <2x4x32x32xf32, 4096x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + %att = migraphx.softmax %qk_biased {axis = 3 : i64} : <2x4x32x32xf32, 4096x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + %res = migraphx.dot %att, %vreshaped : <2x4x32x32xf32, 4096x1024x32x1>, <2x4x32x32xf32, 2048x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + return %res : !migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> + } +} diff --git a/mlir/test/fusion/pr-e2e/attention/mixr-attention-gqa-scale-bias.mlir b/mlir/test/fusion/pr-e2e/attention/mixr-attention-gqa-scale-bias.mlir new file mode 100644 index 000000000000..3a7619cfb568 --- /dev/null +++ b/mlir/test/fusion/pr-e2e/attention/mixr-attention-gqa-scale-bias.mlir @@ -0,0 +1,23 @@ +// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_attention_wrapper -relDiff_threshold 0.000004 --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// ALLOW_RETRIES: 2 +// CHECK: [1 1 1] +module { + func.func private @mlir_attention(%v: !migraphx.shaped<2x2x32x32xf32, 2048x1024x32x1> {mhal.read_access}, + %scale: !migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> {mhal.read_access}, + %q: !migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> {mhal.read_access}, + %k: !migraphx.shaped<2x2x32x32xf32, 2048x1024x32x1> {mhal.read_access}, + %bias: !migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> {mhal.read_access}) + -> (!migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> {mhal.write_access}) { + %vbroadcast = migraphx.multibroadcast %v {out_dyn_dims = [], out_lens = [2, 2, 2, 32, 32]} : <2x2x32x32xf32, 2048x1024x32x1> -> <2x2x2x32x32xf32, 2048x1024x0x32x1> + %vreshaped = migraphx.reshape %vbroadcast {dims = [2, 4, 32, 32]} : <2x2x2x32x32xf32, 2048x1024x0x32x1> -> <2x4x32x32xf32, 2048x1024x32x1> + %kbroadcast = migraphx.multibroadcast %k {out_dyn_dims = [], out_lens = [2, 2, 2, 32, 32]} : <2x2x32x32xf32, 2048x1024x32x1> -> <2x2x2x32x32xf32, 2048x1024x0x32x1> + %kreshaped = migraphx.reshape %kbroadcast {dims = [2, 4, 32, 32]} : <2x2x2x32x32xf32, 2048x1024x0x32x1> -> <2x4x32x32xf32, 2048x1024x32x1> + %kt = migraphx.transpose %kreshaped {permutation = [0, 1, 3, 2]} : <2x4x32x32xf32, 2048x1024x32x1> -> <2x4x32x32xf32, 2048x1024x32x1> + %qk = migraphx.dot %q, %kt : <2x4x32x32xf32, 4096x1024x32x1>, <2x4x32x32xf32, 2048x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + %qk_scaled = migraphx.mul %qk, %scale : <2x4x32x32xf32, 4096x1024x32x1>, <2x4x32x32xf32, 4096x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + %qk_biased = migraphx.add %qk_scaled, %bias : <2x4x32x32xf32, 4096x1024x32x1>, <2x4x32x32xf32, 4096x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + %att = migraphx.softmax %qk_biased {axis = 3 : i64} : <2x4x32x32xf32, 
4096x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + %res = migraphx.dot %att, %vreshaped : <2x4x32x32xf32, 4096x1024x32x1>, <2x4x32x32xf32, 2048x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + return %res : !migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> + } +} diff --git a/mlir/test/fusion/pr-e2e/attention/mixr-attention-gqa-scale.mlir b/mlir/test/fusion/pr-e2e/attention/mixr-attention-gqa-scale.mlir new file mode 100644 index 000000000000..4ec693d64664 --- /dev/null +++ b/mlir/test/fusion/pr-e2e/attention/mixr-attention-gqa-scale.mlir @@ -0,0 +1,21 @@ +// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_attention_wrapper -relDiff_threshold 0.000004 --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// ALLOW_RETRIES: 2 +// CHECK: [1 1 1] +module { + func.func private @mlir_attention(%v: !migraphx.shaped<2x2x32x32xf32, 2048x1024x32x1> {mhal.read_access}, + %scale: !migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> {mhal.read_access}, + %q: !migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> {mhal.read_access}, + %k: !migraphx.shaped<2x2x32x32xf32, 2048x1024x32x1> {mhal.read_access}) + -> (!migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> {mhal.write_access}) { + %vbroadcast = migraphx.multibroadcast %v {out_dyn_dims = [], out_lens = [2, 2, 2, 32, 32]} : <2x2x32x32xf32, 2048x1024x32x1> -> <2x2x2x32x32xf32, 2048x1024x0x32x1> + %vreshaped = migraphx.reshape %vbroadcast {dims = [2, 4, 32, 32]} : <2x2x2x32x32xf32, 2048x1024x0x32x1> -> <2x4x32x32xf32, 2048x1024x32x1> + %kbroadcast = migraphx.multibroadcast %k {out_dyn_dims = [], out_lens = [2, 2, 2, 32, 32]} : <2x2x32x32xf32, 2048x1024x32x1> -> <2x2x2x32x32xf32, 2048x1024x0x32x1> + %kreshaped = migraphx.reshape %kbroadcast {dims = [2, 4, 32, 32]} : <2x2x2x32x32xf32, 2048x1024x0x32x1> -> <2x4x32x32xf32, 2048x1024x32x1> + %kt = migraphx.transpose %kreshaped {permutation = [0, 1, 3, 2]} : <2x4x32x32xf32, 2048x1024x32x1> -> <2x4x32x32xf32, 2048x1024x32x1> + %qk = migraphx.dot %q, %kt : <2x4x32x32xf32, 4096x1024x32x1>, <2x4x32x32xf32, 2048x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + %qk_scaled = migraphx.mul %qk, %scale : <2x4x32x32xf32, 4096x1024x32x1>, <2x4x32x32xf32, 4096x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + %att = migraphx.softmax %qk_scaled {axis = 3 : i64} : <2x4x32x32xf32, 4096x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + %res = migraphx.dot %att, %vreshaped : <2x4x32x32xf32, 4096x1024x32x1>, <2x4x32x32xf32, 2048x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + return %res : !migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> + } +} diff --git a/mlir/test/fusion/pr-e2e/attention/mixr-attention-gqa.mlir b/mlir/test/fusion/pr-e2e/attention/mixr-attention-gqa.mlir new file mode 100644 index 000000000000..34e4ccb2702f --- /dev/null +++ b/mlir/test/fusion/pr-e2e/attention/mixr-attention-gqa.mlir @@ -0,0 +1,19 @@ +// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver 
-host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_attention_wrapper -relDiff_threshold 0.000004 --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// ALLOW_RETRIES: 2 +// CHECK: [1 1 1] +module { + func.func private @mlir_attention(%v: !migraphx.shaped<2x2x32x32xf32, 2048x1024x32x1> {mhal.read_access}, + %q: !migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> {mhal.read_access}, + %k: !migraphx.shaped<2x2x32x32xf32, 2048x1024x32x1> {mhal.read_access}) + -> (!migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> {mhal.write_access}) { + %vbroadcast = migraphx.multibroadcast %v {out_dyn_dims = [], out_lens = [2, 2, 2, 32, 32]} : <2x2x32x32xf32, 2048x1024x32x1> -> <2x2x2x32x32xf32, 2048x1024x0x32x1> + %vreshaped = migraphx.reshape %vbroadcast {dims = [2, 4, 32, 32]} : <2x2x2x32x32xf32, 2048x1024x0x32x1> -> <2x4x32x32xf32, 2048x1024x32x1> + %kbroadcast = migraphx.multibroadcast %k {out_dyn_dims = [], out_lens = [2, 2, 2, 32, 32]} : <2x2x32x32xf32, 2048x1024x32x1> -> <2x2x2x32x32xf32, 2048x1024x0x32x1> + %kreshaped = migraphx.reshape %kbroadcast {dims = [2, 4, 32, 32]} : <2x2x2x32x32xf32, 2048x1024x0x32x1> -> <2x4x32x32xf32, 2048x1024x32x1> + %kt = migraphx.transpose %kreshaped {permutation = [0, 1, 3, 2]} : <2x4x32x32xf32, 2048x1024x32x1> -> <2x4x32x32xf32, 2048x1024x32x1> + %qk = migraphx.dot %q, %kt : <2x4x32x32xf32, 4096x1024x32x1>, <2x4x32x32xf32, 2048x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + %att = migraphx.softmax %qk {axis = 3 : i64} : <2x4x32x32xf32, 4096x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + %res = migraphx.dot %att, %vreshaped : <2x4x32x32xf32, 4096x1024x32x1>, <2x4x32x32xf32, 2048x1024x32x1> -> <2x4x32x32xf32, 4096x1024x32x1> + return %res : !migraphx.shaped<2x4x32x32xf32, 4096x1024x32x1> + } +} diff --git a/mlir/test/rocmlir-gen/attention-kernel-gqa-kvcache.mlir b/mlir/test/rocmlir-gen/attention-kernel-gqa-kvcache.mlir new file mode 100644 index 000000000000..83ebfe1e237a --- /dev/null +++ b/mlir/test/rocmlir-gen/attention-kernel-gqa-kvcache.mlir @@ -0,0 +1,141 @@ +// RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -current_seq_len=33 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1024 -seq_len_k 1024 -head_dim_qk 32 -head_dim_v 32 --with-attn-scale -t f32 -pv --apply-bufferization-pipeline=false | rocmlir-opt | FileCheck %s --enable-var-scope --check-prefixes=CHECK_SCALE + +// CHECK_SCALE: module attributes {mhal.arch = "[[$ARCH:.*]]"} + +// CHECK_SCALE-LABEL: func.func @rock_attention +// CHECK_SCALE-SAME: (%[[queriesRaw:.*0]]: memref<131072xf32>, +// CHECK_SCALE-SAME: %[[keysRaw:.*1]]: memref<65536xf32>, +// CHECK_SCALE-SAME: %[[valuesRaw:.*2]]: memref<65536xf32>, +// CHECK_SCALE-SAME: %[[scaleRaw:.*3]]: memref<4194304xf32>, +// CHECK_SCALE-SAME: %[[currentSeqLenRaw:.*4]]: memref<1xi32>, +// CHECK_SCALE-SAME: %[[outputRaw:.*5]]: memref<131072xf32>) +// CHECK_SCALE-SAME: attributes {kernel, mhal.arch = "[[$ARCH]]"} +// CHECK_SCALE-NEXT: %[[queries:.*]] = rock.transform %[[queriesRaw]] {{.*}} : memref<131072xf32> to memref<4x1024x32xf32> +// CHECK_SCALE-NEXT: %[[keysGQA:.*]] = rock.transform 
%[[keysRaw]] {{.*}} : memref<65536xf32> to memref<2x32x1024xf32> +// CHECK_SCALE-NEXT: %[[valuesGQA:.*]] = rock.transform %[[valuesRaw]] {{.*}} : memref<65536xf32> to memref<2x1024x32xf32> +// CHECK_SCALE-NEXT: %[[scale:.*]] = rock.transform %[[scaleRaw]] {{.*}} : memref<4194304xf32> to memref<4x1024x1024xf32> +// CHECK_SCALE-NEXT: %[[currentSeqLen:.*]] = rock.transform %[[currentSeqLenRaw]] {{.*}} : memref<1xi32> to memref<1xi32> +// CHECK_SCALE-NEXT: %[[output:.*]] = rock.transform %[[outputRaw]] {{.*}} : memref<131072xf32> to memref<4x1024x32xf32> +// CHECK_SCALE-NEXT: %[[currentSeqLenAddDim:.*]] = rock.transform %[[currentSeqLen]] {{.*}} : memref<1xi32> to memref<1x1xi32> +// CHECK_SCALE-NEXT: %[[currentSeqLenBroadcast:.*]] = rock.transform %[[currentSeqLenAddDim]] {{.*}} : memref<1x1xi32> to memref<1x4xi32> +// CHECK_SCALE-NEXT: %[[currentSeqLenMerge:.*]] = rock.transform %[[currentSeqLenBroadcast]] {{.*}} : memref<1x4xi32> to memref<4xi32> +// CHECK_SCALE-NEXT: %[[keysAddDim:.*]] = rock.transform %[[keysGQA]] {{.*}} : memref<2x32x1024xf32> to memref<2x1x32x1024xf32> +// CHECK_SCALE-NEXT: %[[keysBroadcast:.*]] = rock.transform %[[keysAddDim]] {{.*}} : memref<2x1x32x1024xf32> to memref<2x2x32x1024xf32> +// CHECK_SCALE-NEXT: %[[keys:.*]] = rock.transform %[[keysBroadcast]] {{.*}} : memref<2x2x32x1024xf32> to memref<4x32x1024xf32> +// CHECK_SCALE-NEXT: %[[valuesAddDim:.*]] = rock.transform %[[valuesGQA]] {{.*}} : memref<2x1024x32xf32> to memref<2x1x1024x32xf32> +// CHECK_SCALE-NEXT: %[[valuesBroadcast:.*]] = rock.transform %[[valuesAddDim]] {{.*}} : memref<2x1x1024x32xf32> to memref<2x2x1024x32xf32> +// CHECK_SCALE-NEXT: %[[values:.*]] = rock.transform %[[valuesBroadcast]] {{.*}} : memref<2x2x1024x32xf32> to memref<4x1024x32xf32> + +// CHECK_SCALE-NEXT: rock.attention +// CHECK_SCALE-NEXT: qk = %[[queries]] * %[[keys]] +// CHECK_SCALE-NEXT: currentSeqLen = (%[[currentSeqLenMerge]] : memref<4xi32>) +// CHECK_SCALE-NEXT: qk = elementwise otherIns(%[[scale]] +// CHECK_SCALE: %[[output]] = softmax(qk) * %[[values]] +// CHECK_SCALE: return + +// CHECK_SCALE-LABEL: func.func @host_naive_attention +// CHECK_SCALE: %[[keysExpanded:.*]] = tensor.expand_shape {{.*}} output_shape [2, 1, 32, 1024] : tensor<2x32x1024xf32> into tensor<2x1x32x1024xf32> +// CHECK_SCALE: %[[keysAdd:.*]] = tosa.add %{{.*}}, %[[keysExpanded]] : (tensor<2x2x32x1024xf32>, tensor<2x1x32x1024xf32>) -> tensor<2x2x32x1024xf32> +// CHECK_SCALE: %[[keysTensor:.*]] = tensor.collapse_shape %[[keysAdd]] {{.*}} : tensor<2x2x32x1024xf32> into tensor<4x32x1024xf32> +// CHECK_SCALE: %[[valuesExpanded:.*]] = tensor.expand_shape {{.*}} output_shape [2, 1, 1024, 32] : tensor<2x1024x32xf32> into tensor<2x1x1024x32xf32> +// CHECK_SCALE: %[[valuesAdd:.*]] = tosa.add %{{.*}}, %[[valuesExpanded]] : (tensor<2x2x1024x32xf32>, tensor<2x1x1024x32xf32>) -> tensor<2x2x1024x32xf32> +// CHECK_SCALE: %[[valuesTensor:.*]] = tensor.collapse_shape %[[valuesAdd]] {{.*}} : tensor<2x2x1024x32xf32> into [[valuesShape:tensor<.*>]] +// CHECK_SCALE: %[[qkTensorOrig:.*]] = tosa.matmul %[[queriesTensor:.*]], %[[keysTensor]] : ([[queriesShape:tensor<.*>]], [[keysShape:tensor<.*>]]) -> [[squareShape:tensor<.*>]] + +// CHECK_SCALE: %[[currSeqLenTensorDumbReshaped:.*]] = tosa.reshape %[[currSeqLenTensor:.*]] {new_shape = array} : (tensor<1xi32>) -> tensor<1xi32> +// CHECK_SCALE: %[[currSeqLenTensorReshaped:.*]] = tosa.reshape %[[currSeqLenTensorDumbReshaped]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1x1x1xi32> + +// CHECK_SCALE: %[[scaledFirstReshaped:.*]] = 
tosa.reshape %[[scaledTensorRaw:.*]] {new_shape = array} : (tensor<4194304xf32>) -> tensor<4x1024x1024xf32> +// CHECK_SCALE: %[[scaledReshaped:.*]] = tosa.reshape %[[scaledFirstReshaped:.*]] {new_shape = array} : (tensor<4x1024x1024xf32>) -> tensor<1x4x1024x1024xf32> +// CHECK_SCALE: %[[range2:.*]] = "tosa.const"() <{value = {{.*}} : tensor<1024xi32>}> : () -> tensor<1024xi32> +// CHECK_SCALE: %[[zero2:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x4x1024x1024xi32>}> : () -> tensor<1x4x1024x1024xi32> +// CHECK_SCALE: %[[rangeBroadcast2:.*]] = tosa.add %[[zero2]], %[[range2]] : (tensor<1x4x1024x1024xi32>, tensor<1024xi32>) -> tensor<1x4x1024x1024xi32> +// CHECK_SCALE: %[[currSeqLenTensorBroadcast2:.*]] = tosa.add %[[zero2]], %[[currSeqLenTensorReshaped]] : (tensor<1x4x1024x1024xi32>, tensor<1x1x1x1xi32>) -> tensor<1x4x1024x1024xi32> +// CHECK_SCALE: %[[mask2:.*]] = tosa.greater_equal %[[rangeBroadcast2]], %[[currSeqLenTensorBroadcast2]] : (tensor<1x4x1024x1024xi32>, tensor<1x4x1024x1024xi32>) -> tensor<1x4x1024x1024xi1> +// CHECK_SCALE: %[[one:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x4x1024x1024xf32>}> : () -> tensor<1x4x1024x1024xf32> +// CHECK_SCALE: %[[scaleTensorBeforeReshape:.*]] = tosa.select %[[mask2]], %[[one]], %[[scaledReshaped]] : (tensor<1x4x1024x1024xi1>, tensor<1x4x1024x1024xf32>, tensor<1x4x1024x1024xf32>) -> tensor<1x4x1024x1024xf32> +// CHECK_SCALE: %[[scaleTensor:.*]] = tosa.reshape %[[scaleTensorBeforeReshape]] {new_shape = array} : (tensor<1x4x1024x1024xf32>) -> tensor<4x1024x1024xf32> +// CHECK_SCALE: %[[sqkTensor:.*]] = tosa.mul %[[qkTensorOrig]], %[[scaleTensor]] {{.*}} : ([[squareShape]], [[squareShape]]) -> [[squareShape]] + +// CHECK_SCALE: %[[qkTensorReshaped:.*]] = tosa.reshape %[[sqkTensor]] {new_shape = array} : (tensor<4x1024x1024xf32>) -> tensor<1x4x1024x1024xf32> +// CHECK_SCALE: %[[range:.*]] = "tosa.const"() <{value = {{.*}} : tensor<1024xi32>}> : () -> tensor<1024xi32> +// CHECK_SCALE: %[[zero:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x4x1024x1024xi32>}> : () -> tensor<1x4x1024x1024xi32> +// CHECK_SCALE: %[[rangeBroadcast:.*]] = tosa.add %[[zero]], %[[range]] : (tensor<1x4x1024x1024xi32>, tensor<1024xi32>) -> tensor<1x4x1024x1024xi32> +// CHECK_SCALE: %[[currSeqLenTensorBroadcast:.*]] = tosa.add %[[zero]], %[[currSeqLenTensorReshaped]] : (tensor<1x4x1024x1024xi32>, tensor<1x1x1x1xi32>) -> tensor<1x4x1024x1024xi32> +// CHECK_SCALE: %[[mask:.*]] = tosa.greater_equal %[[rangeBroadcast]], %[[currSeqLenTensorBroadcast]] : (tensor<1x4x1024x1024xi32>, tensor<1x4x1024x1024xi32>) -> tensor<1x4x1024x1024xi1> +// CHECK_SCALE: %[[negInf:.*]] = "tosa.const"() <{value = dense<0xFF800000> : tensor<1x4x1024x1024xf32>}> : () -> tensor<1x4x1024x1024xf32> +// CHECK_SCALE: %[[qkTensorBeforeReshape:.*]] = tosa.select %[[mask]], %[[negInf]], %[[qkTensorReshaped]] : (tensor<1x4x1024x1024xi1>, tensor<1x4x1024x1024xf32>, tensor<1x4x1024x1024xf32>) -> tensor<1x4x1024x1024xf32> +// CHECK_SCALE: %[[qkTensor:.*]] = tosa.reshape %[[qkTensorBeforeReshape]] {new_shape = array} : (tensor<1x4x1024x1024xf32>) -> tensor<4x1024x1024xf32> + +// CHECK_SCALE-DAG: %[[sqkMaxs:.*]] = tosa.reduce_max %[[qkTensor]] {{.*}} : ([[squareShape]]) -> [[reducedShape:tensor<.*>]] +// CHECK_SCALE-DAG: %[[normilizedSqkTensor:.*]] = tosa.sub %[[qkTensor]], %[[sqkMaxs]] : ([[squareShape]], [[reducedShape]]) -> [[squareShape]] +// CHECK_SCALE-DAG: %[[expsTensor:.*]] = tosa.exp %[[normilizedSqkTensor]] : ([[squareShape]]) -> [[squareShape]] +// CHECK_SCALE-DAG: 
%[[expsSumsTensor:.*]] = tosa.reduce_sum %[[expsTensor]] {{.*}} : ([[squareShape]]) -> [[reducedShape]] +// CHECK_SCALE-DAG: %[[invExpsSums:.*]] = tosa.reciprocal %[[expsSumsTensor]] : ([[reducedShape]]) -> [[reducedShape]] +// CHECK_SCALE-DAG: %[[softmaxTensor:.*]] = tosa.mul %[[expsTensor]], %[[invExpsSums]] {{.*}} : ([[squareShape]], [[reducedShape]]) -> [[squareShape]] +// CHECK_SCALE-DAG: %[[resultTensor:.*]] = tosa.matmul %[[softmaxTensor]], %[[valuesTensor]] : ([[squareShape]], [[valuesShape]]) -> tensor<4x1024x32xf32> +// CHECK_SCALE: return + +// ---- + +// RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -current_seq_len=33 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1024 -seq_len_k 1024 -head_dim_qk 32 -head_dim_v 32 -t f32 -pv --apply-bufferization-pipeline=false | rocmlir-opt | FileCheck %s --enable-var-scope --check-prefixes=CHECK_NO_SCALE + +// CHECK_NO_SCALE: module attributes {mhal.arch = "[[$ARCH:.*]]"} + +// CHECK_NO_SCALE-LABEL: func.func @rock_attention +// CHECK_NO_SCALE-SAME: (%[[queriesRaw:.*0]]: memref<131072xf32>, +// CHECK_NO_SCALE-SAME: %[[keysRaw:.*1]]: memref<65536xf32>, +// CHECK_NO_SCALE-SAME: %[[valuesRaw:.*2]]: memref<65536xf32>, +// CHECK_NO_SCALE-SAME: %[[currentSeqLenRaw:.*3]]: memref<1xi32>, +// CHECK_NO_SCALE-SAME: %[[outputRaw:.*4]]: memref<131072xf32>) +// CHECK_NO_SCALE-SAME: attributes {kernel, mhal.arch = "[[$ARCH]]"} +// CHECK_NO_SCALE-NEXT: %[[queries:.*]] = rock.transform %[[queriesRaw]] {{.*}} : memref<131072xf32> to memref<4x1024x32xf32> +// CHECK_NO_SCALE-NEXT: %[[keysGQA:.*]] = rock.transform %[[keysRaw]] {{.*}} : memref<65536xf32> to memref<2x32x1024xf32> +// CHECK_NO_SCALE-NEXT: %[[valuesGQA:.*]] = rock.transform %[[valuesRaw]] {{.*}} : memref<65536xf32> to memref<2x1024x32xf32> +// CHECK_NO_SCALE-NEXT: %[[currentSeqLen:.*]] = rock.transform %[[currentSeqLenRaw]] {{.*}} : memref<1xi32> to memref<1xi32> +// CHECK_NO_SCALE-NEXT: %[[output:.*]] = rock.transform %[[outputRaw]] {{.*}} : memref<131072xf32> to memref<4x1024x32xf32> +// CHECK_NO_SCALE-NEXT: %[[currentSeqLenAddDim:.*]] = rock.transform %[[currentSeqLen]] {{.*}} : memref<1xi32> to memref<1x1xi32> +// CHECK_NO_SCALE-NEXT: %[[currentSeqLenBroadcast:.*]] = rock.transform %[[currentSeqLenAddDim]] {{.*}} : memref<1x1xi32> to memref<1x4xi32> +// CHECK_NO_SCALE-NEXT: %[[currentSeqLenMerge:.*]] = rock.transform %[[currentSeqLenBroadcast]] {{.*}} : memref<1x4xi32> to memref<4xi32> +// CHECK_NO_SCALE-NEXT: %[[keysAddDim:.*]] = rock.transform %[[keysGQA]] {{.*}} : memref<2x32x1024xf32> to memref<2x1x32x1024xf32> +// CHECK_NO_SCALE-NEXT: %[[keysBroadcast:.*]] = rock.transform %[[keysAddDim]] {{.*}} : memref<2x1x32x1024xf32> to memref<2x2x32x1024xf32> +// CHECK_NO_SCALE-NEXT: %[[keys:.*]] = rock.transform %[[keysBroadcast]] {{.*}} : memref<2x2x32x1024xf32> to memref<4x32x1024xf32> +// CHECK_NO_SCALE-NEXT: %[[valuesAddDim:.*]] = rock.transform %[[valuesGQA]] {{.*}} : memref<2x1024x32xf32> to memref<2x1x1024x32xf32> +// CHECK_NO_SCALE-NEXT: %[[valuesBroadcast:.*]] = rock.transform %[[valuesAddDim]] {{.*}} : memref<2x1x1024x32xf32> to memref<2x2x1024x32xf32> +// CHECK_NO_SCALE-NEXT: %[[values:.*]] = rock.transform %[[valuesBroadcast]] {{.*}} : memref<2x2x1024x32xf32> to memref<4x1024x32xf32> + +// CHECK_NO_SCALE-NEXT: rock.attention +// CHECK_NO_SCALE-NEXT: qk = %[[queries]] * %[[keys]] +// CHECK_NO_SCALE-NEXT: currentSeqLen = (%[[currentSeqLenMerge]] : memref<4xi32>) +// CHECK_NO_SCALE: %[[output]] = softmax(qk) * %[[values]] +// CHECK_NO_SCALE: return + +// 
CHECK_NO_SCALE-LABEL: func.func @host_naive_attention +// CHECK_NO_SCALE: %[[keysExpanded:.*]] = tensor.expand_shape {{.*}} output_shape [2, 1, 32, 1024] : tensor<2x32x1024xf32> into tensor<2x1x32x1024xf32> +// CHECK_NO_SCALE: %[[keysAdd:.*]] = tosa.add %{{.*}}, %[[keysExpanded]] : (tensor<2x2x32x1024xf32>, tensor<2x1x32x1024xf32>) -> tensor<2x2x32x1024xf32> +// CHECK_NO_SCALE: %[[keysTensor:.*]] = tensor.collapse_shape %[[keysAdd]] {{.*}} : tensor<2x2x32x1024xf32> into tensor<4x32x1024xf32> +// CHECK_NO_SCALE: %[[valuesExpanded:.*]] = tensor.expand_shape {{.*}} output_shape [2, 1, 1024, 32] : tensor<2x1024x32xf32> into tensor<2x1x1024x32xf32> +// CHECK_NO_SCALE: %[[valuesAdd:.*]] = tosa.add %{{.*}}, %[[valuesExpanded]] : (tensor<2x2x1024x32xf32>, tensor<2x1x1024x32xf32>) -> tensor<2x2x1024x32xf32> +// CHECK_NO_SCALE: %[[valuesTensor:.*]] = tensor.collapse_shape %[[valuesAdd]] {{.*}} : tensor<2x2x1024x32xf32> into [[valuesShape:tensor<.*>]] +// CHECK_NO_SCALE: %[[qkTensorOrig:.*]] = tosa.matmul %[[queriesTensor:.*]], %[[keysTensor:.*]] : ([[queriesShape:tensor<.*>]], [[keysShape:tensor<.*>]]) -> [[squareShape:tensor<.*>]] + +// CHECK_NO_SCALE: %[[currSeqLenTensorDumbReshaped:.*]] = tosa.reshape %[[currSeqLenTensor:.*]] {new_shape = array} : (tensor<1xi32>) -> tensor<1xi32> +// CHECK_NO_SCALE: %[[currSeqLenTensorReshaped:.*]] = tosa.reshape %[[currSeqLenTensorDumbReshaped]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1x1x1xi32> +// CHECK_NO_SCALE: %[[qkTensorReshaped:.*]] = tosa.reshape %[[qkTensorOrig]] {new_shape = array} : (tensor<4x1024x1024xf32>) -> tensor<1x4x1024x1024xf32> +// CHECK_NO_SCALE: %[[range:.*]] = "tosa.const"() <{value = {{.*}} : tensor<1024xi32>}> : () -> tensor<1024xi32> +// CHECK_NO_SCALE: %[[zero:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x4x1024x1024xi32>}> : () -> tensor<1x4x1024x1024xi32> +// CHECK_NO_SCALE: %[[rangeBroadcast:.*]] = tosa.add %[[zero]], %[[range]] : (tensor<1x4x1024x1024xi32>, tensor<1024xi32>) -> tensor<1x4x1024x1024xi32> +// CHECK_NO_SCALE: %[[currSeqLenTensorBroadcast:.*]] = tosa.add %[[zero]], %[[currSeqLenTensorReshaped]] : (tensor<1x4x1024x1024xi32>, tensor<1x1x1x1xi32>) -> tensor<1x4x1024x1024xi32> +// CHECK_NO_SCALE: %[[mask:.*]] = tosa.greater_equal %[[rangeBroadcast]], %[[currSeqLenTensorBroadcast]] : (tensor<1x4x1024x1024xi32>, tensor<1x4x1024x1024xi32>) -> tensor<1x4x1024x1024xi1> +// CHECK_NO_SCALE: %[[negInf:.*]] = "tosa.const"() <{value = dense<0xFF800000> : tensor<1x4x1024x1024xf32>}> : () -> tensor<1x4x1024x1024xf32> +// CHECK_NO_SCALE: %[[qkTensorBeforeReshape:.*]] = tosa.select %[[mask]], %[[negInf]], %[[qkTensorReshaped]] : (tensor<1x4x1024x1024xi1>, tensor<1x4x1024x1024xf32>, tensor<1x4x1024x1024xf32>) -> tensor<1x4x1024x1024xf32> +// CHECK_NO_SCALE: %[[qkTensor:.*]] = tosa.reshape %[[qkTensorBeforeReshape]] {new_shape = array} : (tensor<1x4x1024x1024xf32>) -> tensor<4x1024x1024xf32> + +// CHECK_NO_SCALE-DAG: %[[sqkMaxs:.*]] = tosa.reduce_max %[[qkTensor]] {{.*}} : ([[squareShape]]) -> [[reducedShape:tensor<.*>]] +// CHECK_NO_SCALE-DAG: %[[normilizedQkTensor:.*]] = tosa.sub %[[qkTensor]], %[[sqkMaxs]] : ([[squareShape]], [[reducedShape]]) -> [[squareShape]] +// CHECK_NO_SCALE-DAG: %[[expsTensor:.*]] = tosa.exp %[[normilizedQkTensor]] : ([[squareShape]]) -> [[squareShape]] +// CHECK_NO_SCALE-DAG: %[[expsSumsTensor:.*]] = tosa.reduce_sum %[[expsTensor]] {{.*}} : ([[squareShape]]) -> [[reducedShape]] +// CHECK_NO_SCALE-DAG: %[[invExpsSums:.*]] = tosa.reciprocal %[[expsSumsTensor]] : ([[reducedShape]]) -> 
[[reducedShape]] +// CHECK_NO_SCALE-DAG: %[[softmaxTensor:.*]] = tosa.mul %[[expsTensor]], %[[invExpsSums]] {{.*}} : ([[squareShape]], [[reducedShape]]) -> [[squareShape]] +// CHECK_NO_SCALE-DAG: %[[resultTensor:.*]] = tosa.matmul %[[softmaxTensor]], %[[valuesTensor:.*]] : ([[squareShape]], [[valuesShape:tensor<.*>]]) -> tensor<4x1024x32xf32> +// CHECK_NO_SCALE: return diff --git a/mlir/test/rocmlir-gen/attention-kernel-gqa.mlir b/mlir/test/rocmlir-gen/attention-kernel-gqa.mlir new file mode 100644 index 000000000000..0d58e448693f --- /dev/null +++ b/mlir/test/rocmlir-gen/attention-kernel-gqa.mlir @@ -0,0 +1,91 @@ +// RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1024 -seq_len_k 1024 -head_dim_qk 32 -head_dim_v 32 --with-attn-scale -t f32 -pv --apply-bufferization-pipeline=false | rocmlir-opt | FileCheck %s --enable-var-scope --check-prefixes=CHECK_SCALE + +// CHECK_SCALE: module attributes {mhal.arch = "[[$ARCH:.*]]"} + +// CHECK_SCALE-LABEL: func.func @rock_attention +// CHECK_SCALE-SAME: (%[[queriesRaw:.*0]]: memref<131072xf32>, +// CHECK_SCALE-SAME: %[[keysRaw:.*1]]: memref<65536xf32>, +// CHECK_SCALE-SAME: %[[valuesRaw:.*2]]: memref<65536xf32>, +// CHECK_SCALE-SAME: %[[scaleRaw:.*3]]: memref<4194304xf32>, +// CHECK_SCALE-SAME: %[[outputRaw:.*4]]: memref<131072xf32>) +// CHECK_SCALE-SAME: attributes {kernel, mhal.arch = "[[$ARCH]]"} +// CHECK_SCALE-NEXT: %[[queries:.*]] = rock.transform %[[queriesRaw]] {{.*}} : memref<131072xf32> to memref<4x1024x32xf32> +// CHECK_SCALE-NEXT: %[[keysGQA:.*]] = rock.transform %[[keysRaw]] {{.*}} : memref<65536xf32> to memref<2x32x1024xf32> +// CHECK_SCALE-NEXT: %[[valuesGQA:.*]] = rock.transform %[[valuesRaw]] {{.*}} : memref<65536xf32> to memref<2x1024x32xf32> +// CHECK_SCALE-NEXT: %[[scale:.*]] = rock.transform %[[scaleRaw]] {{.*}} : memref<4194304xf32> to memref<4x1024x1024xf32> +// CHECK_SCALE-NEXT: %[[output:.*]] = rock.transform %[[outputRaw]] {{.*}} : memref<131072xf32> to memref<4x1024x32xf32> +// CHECK_SCALE-NEXT: %[[keysAddDim:.*]] = rock.transform %[[keysGQA]] {{.*}} : memref<2x32x1024xf32> to memref<2x1x32x1024xf32> +// CHECK_SCALE-NEXT: %[[keysBroadcast:.*]] = rock.transform %[[keysAddDim]] {{.*}} : memref<2x1x32x1024xf32> to memref<2x2x32x1024xf32> +// CHECK_SCALE-NEXT: %[[keys:.*]] = rock.transform %[[keysBroadcast]] {{.*}} : memref<2x2x32x1024xf32> to memref<4x32x1024xf32> +// CHECK_SCALE-NEXT: %[[valuesAddDim:.*]] = rock.transform %[[valuesGQA]] {{.*}} : memref<2x1024x32xf32> to memref<2x1x1024x32xf32> +// CHECK_SCALE-NEXT: %[[valuesBroadcast:.*]] = rock.transform %[[valuesAddDim]] {{.*}} : memref<2x1x1024x32xf32> to memref<2x2x1024x32xf32> +// CHECK_SCALE-NEXT: %[[values:.*]] = rock.transform %[[valuesBroadcast]] {{.*}} : memref<2x2x1024x32xf32> to memref<4x1024x32xf32> + +// CHECK_SCALE-NEXT: rock.attention +// CHECK_SCALE-NEXT: qk = %[[queries]] * %[[keys]] +// CHECK_SCALE-NEXT: qk = elementwise otherIns(%[[scale]] +// CHECK_SCALE: %[[output]] = softmax(qk) * %[[values]] +// CHECK_SCALE: return + +// CHECK_SCALE-LABEL: func.func @host_naive_attention +// CHECK_SCALE: %[[keysExpanded:.*]] = tensor.expand_shape {{.*}} output_shape [2, 1, 32, 1024] : tensor<2x32x1024xf32> into tensor<2x1x32x1024xf32> +// CHECK_SCALE: %[[keysAdd:.*]] = tosa.add %{{.*}}, %[[keysExpanded]] : (tensor<2x2x32x1024xf32>, tensor<2x1x32x1024xf32>) -> tensor<2x2x32x1024xf32> +// CHECK_SCALE: %[[keysTensor:.*]] = tensor.collapse_shape %[[keysAdd]] {{.*}} : tensor<2x2x32x1024xf32> into 
tensor<4x32x1024xf32> +// CHECK_SCALE: %[[valuesExpanded:.*]] = tensor.expand_shape {{.*}} output_shape [2, 1, 1024, 32] : tensor<2x1024x32xf32> into tensor<2x1x1024x32xf32> +// CHECK_SCALE: %[[valuesAdd:.*]] = tosa.add %{{.*}}, %[[valuesExpanded]] : (tensor<2x2x1024x32xf32>, tensor<2x1x1024x32xf32>) -> tensor<2x2x1024x32xf32> +// CHECK_SCALE: %[[valuesTensor:.*]] = tensor.collapse_shape %[[valuesAdd]] {{.*}} : tensor<2x2x1024x32xf32> into tensor<4x1024x32xf32> +// CHECK_SCALE: %[[qkTensor:.*]] = tosa.matmul %[[queriesTensor:.*]], %[[keysTensor]] : ([[queriesShape:tensor<.*>]], [[keysShape:tensor<.*>]]) -> [[squareShape:tensor<.*>]] +// CHECK_SCALE-DAG: %[[sqkTensor:.*]] = tosa.mul %[[qkTensor]], %[[scaleTensor:.*]] {{.*}} : ([[squareShape]], [[squareShape]]) -> [[squareShape]] +// CHECK_SCALE-DAG: %[[sqkMaxs:.*]] = tosa.reduce_max %[[sqkTensor]] {{.*}} : ([[squareShape]]) -> [[reducedShape:tensor<.*>]] +// CHECK_SCALE-DAG: %[[normilizedSqkTensor:.*]] = tosa.sub %[[sqkTensor]], %[[sqkMaxs]] : ([[squareShape]], [[reducedShape]]) -> [[squareShape]] +// CHECK_SCALE-DAG: %[[expsTensor:.*]] = tosa.exp %[[normilizedSqkTensor]] : ([[squareShape]]) -> [[squareShape]] +// CHECK_SCALE-DAG: %[[expsSumsTensor:.*]] = tosa.reduce_sum %[[expsTensor]] {{.*}} : ([[squareShape]]) -> [[reducedShape]] +// CHECK_SCALE-DAG: %[[invExpsSums:.*]] = tosa.reciprocal %[[expsSumsTensor]] : ([[reducedShape]]) -> [[reducedShape]] +// CHECK_SCALE-DAG: %[[softmaxTensor:.*]] = tosa.mul %[[expsTensor]], %[[invExpsSums]] {{.*}} : ([[squareShape]], [[reducedShape]]) -> [[squareShape]] +// CHECK_SCALE-DAG: %[[resultTensor:.*]] = tosa.matmul %[[softmaxTensor]], %[[valuesTensor]] : ([[squareShape]], [[valuesShape:tensor<.*>]]) -> [[valuesShape]] +// CHECK_SCALE: return + +// ---- + +// RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1024 -seq_len_k 1024 -head_dim_qk 32 -head_dim_v 32 -t f32 -pv --apply-bufferization-pipeline=false | rocmlir-opt | FileCheck %s --enable-var-scope --check-prefixes=CHECK_NO_SCALE + +// CHECK_NO_SCALE: module attributes {mhal.arch = "[[$ARCH:.*]]"} + +// CHECK_NO_SCALE-LABEL: func.func @rock_attention +// CHECK_NO_SCALE-SAME: (%[[queriesRaw:.*0]]: memref<131072xf32>, +// CHECK_NO_SCALE-SAME: %[[keysRaw:.*1]]: memref<65536xf32>, +// CHECK_NO_SCALE-SAME: %[[valuesRaw:.*2]]: memref<65536xf32>, +// CHECK_NO_SCALE-SAME: %[[outputRaw:.*3]]: memref<131072xf32>) +// CHECK_NO_SCALE-SAME: attributes {kernel, mhal.arch = "[[$ARCH]]"} +// CHECK_NO_SCALE-NEXT: %[[queries:.*]] = rock.transform %[[queriesRaw]] {{.*}} : memref<131072xf32> to memref<4x1024x32xf32> +// CHECK_NO_SCALE-NEXT: %[[keysGQA:.*]] = rock.transform %[[keysRaw]] {{.*}} : memref<65536xf32> to memref<2x32x1024xf32> +// CHECK_NO_SCALE-NEXT: %[[valuesGQA:.*]] = rock.transform %[[valuesRaw]] {{.*}} : memref<65536xf32> to memref<2x1024x32xf32> +// CHECK_NO_SCALE-NEXT: %[[output:.*]] = rock.transform %[[outputRaw]] {{.*}} : memref<131072xf32> to memref<4x1024x32xf32> +// CHECK_NO_SCALE-NEXT: %[[keysAddDim:.*]] = rock.transform %[[keysGQA]] {{.*}} : memref<2x32x1024xf32> to memref<2x1x32x1024xf32> +// CHECK_NO_SCALE-NEXT: %[[keysBroadcast:.*]] = rock.transform %[[keysAddDim]] {{.*}} : memref<2x1x32x1024xf32> to memref<2x2x32x1024xf32> +// CHECK_NO_SCALE-NEXT: %[[keys:.*]] = rock.transform %[[keysBroadcast]] {{.*}} : memref<2x2x32x1024xf32> to memref<4x32x1024xf32> +// CHECK_NO_SCALE-NEXT: %[[valuesAddDim:.*]] = rock.transform %[[valuesGQA]] {{.*}} : memref<2x1024x32xf32> to 
memref<2x1x1024x32xf32> +// CHECK_NO_SCALE-NEXT: %[[valuesBroadcast:.*]] = rock.transform %[[valuesAddDim]] {{.*}} : memref<2x1x1024x32xf32> to memref<2x2x1024x32xf32> +// CHECK_NO_SCALE-NEXT: %[[values:.*]] = rock.transform %[[valuesBroadcast]] {{.*}} : memref<2x2x1024x32xf32> to memref<4x1024x32xf32> + +// CHECK_NO_SCALE-NEXT: rock.attention +// CHECK_NO_SCALE-NEXT: qk = %[[queries]] * %[[keys]] +// CHECK_NO_SCALE: %[[output]] = softmax(qk) * %[[values]] +// CHECK_NO_SCALE: return + +// CHECK_NO_SCALE-LABEL: func.func @host_naive_attention +// CHECK_NO_SCALE: %[[keysExpanded:.*]] = tensor.expand_shape {{.*}} output_shape [2, 1, 32, 1024] : tensor<2x32x1024xf32> into tensor<2x1x32x1024xf32> +// CHECK_NO_SCALE: %[[keysAdd:.*]] = tosa.add %{{.*}}, %[[keysExpanded]] : (tensor<2x2x32x1024xf32>, tensor<2x1x32x1024xf32>) -> tensor<2x2x32x1024xf32> +// CHECK_NO_SCALE: %[[keysTensor:.*]] = tensor.collapse_shape %[[keysAdd]] {{.*}} : tensor<2x2x32x1024xf32> into tensor<4x32x1024xf32> +// CHECK_NO_SCALE: %[[valuesExpanded:.*]] = tensor.expand_shape {{.*}} output_shape [2, 1, 1024, 32] : tensor<2x1024x32xf32> into tensor<2x1x1024x32xf32> +// CHECK_NO_SCALE: %[[valuesAdd:.*]] = tosa.add %{{.*}}, %[[valuesExpanded]] : (tensor<2x2x1024x32xf32>, tensor<2x1x1024x32xf32>) -> tensor<2x2x1024x32xf32> +// CHECK_NO_SCALE: %[[valuesTensor:.*]] = tensor.collapse_shape %[[valuesAdd]] {{.*}} : tensor<2x2x1024x32xf32> into tensor<4x1024x32xf32> +// CHECK_NO_SCALE: %[[qkTensor:.*]] = tosa.matmul %[[queriesTensor:.*]], %[[keysTensor:.*]] : ([[queriesShape:tensor<.*>]], [[keysShape:tensor<.*>]]) -> [[squareShape:tensor<.*>]] +// CHECK_NO_SCALE-DAG: %[[sqkMaxs:.*]] = tosa.reduce_max %[[qkTensor]] {{.*}} : ([[squareShape]]) -> [[reducedShape:tensor<.*>]] +// CHECK_NO_SCALE-DAG: %[[normilizedQkTensor:.*]] = tosa.sub %[[qkTensor]], %[[sqkMaxs]] : ([[squareShape]], [[reducedShape]]) -> [[squareShape]] +// CHECK_NO_SCALE-DAG: %[[expsTensor:.*]] = tosa.exp %[[normilizedQkTensor]] : ([[squareShape]]) -> [[squareShape]] +// CHECK_NO_SCALE-DAG: %[[expsSumsTensor:.*]] = tosa.reduce_sum %[[expsTensor]] {{.*}} : ([[squareShape]]) -> [[reducedShape]] +// CHECK_NO_SCALE-DAG: %[[invExpsSums:.*]] = tosa.reciprocal %[[expsSumsTensor]] : ([[reducedShape]]) -> [[reducedShape]] +// CHECK_NO_SCALE-DAG: %[[softmaxTensor:.*]] = tosa.mul %[[expsTensor]], %[[invExpsSums]] {{.*}} : ([[squareShape]], [[reducedShape]]) -> [[squareShape]] +// CHECK_NO_SCALE-DAG: %[[resultTensor:.*]] = tosa.matmul %[[softmaxTensor]], %[[valuesTensor:.*]] : ([[squareShape]], [[valuesShape:tensor<.*>]]) -> [[valuesShape]] +// CHECK_NO_SCALE: return diff --git a/mlir/test/rocmlir-gen/problem-key.mlir b/mlir/test/rocmlir-gen/problem-key.mlir index 255544da7d8b..cc93d00742c2 100644 --- a/mlir/test/rocmlir-gen/problem-key.mlir +++ b/mlir/test/rocmlir-gen/problem-key.mlir @@ -4,8 +4,16 @@ // CHECK_2: -t f16 -transQ false -transK false -transV false -transO false -g 4 -seq_len_q 256 -seq_len_k 512 -head_dim_qk 64 -head_dim_v 32 // RUN: rocmlir-gen --arch gfx942 --operation attention -seq_len_q 256 -seq_len_k 512 -head_dim_qk 64 -head_dim_v 32 -t i8 -g 8 | rocmlir-gen --emit-tuning-key - | FileCheck %s --check-prefixes=CHECK_3 // CHECK_3: -t i8 -transQ false -transK false -transV false -transO false -g 8 -seq_len_q 256 -seq_len_k 512 -head_dim_qk 64 -head_dim_v 32 +// RUN: rocmlir-gen --arch gfx942 --operation attention -num_heads_q 4 -num_heads_kv 4 -seq_len_q 256 -seq_len_k 512 -head_dim_qk 64 -head_dim_v 32 -t i8 -g 8 | rocmlir-gen --emit-tuning-key - | FileCheck %s 
--check-prefixes=CHECK_4 +// CHECK_4: -t i8 -transQ false -transK false -transV false -transO false -g 32 -seq_len_q 256 -seq_len_k 512 -head_dim_qk 64 -head_dim_v 32 +// RUN: rocmlir-gen --arch gfx942 --operation attention -num_heads_q 4 -num_heads_kv 2 -seq_len_q 256 -seq_len_k 512 -head_dim_qk 64 -head_dim_v 32 -t i8 -g 8 | rocmlir-gen --emit-tuning-key - | FileCheck %s --check-prefixes=CHECK_5 +// CHECK_5: -t i8 -transQ false -transK false -transV false -transO false -g 32 -seq_len_q 256 -seq_len_k 512 -head_dim_qk 64 -head_dim_v 32 +// RUN: rocmlir-gen --arch gfx942 --operation attention -current_seq_len=16 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 256 -seq_len_k 512 -head_dim_qk 64 -head_dim_v 32 -t i8 -g 1 | rocmlir-gen --emit-tuning-key - | FileCheck %s --check-prefixes=CHECK_6 +// CHECK_6: -t i8 -transQ false -transK false -transV false -transO false -g 4 -seq_len_q 256 -seq_len_k 512 -head_dim_qk 64 -head_dim_v 32 +// RUN: rocmlir-gen --arch gfx942 --operation attention -current_seq_len=16,16,17,1,30,40,38,12 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 256 -seq_len_k 512 -head_dim_qk 64 -head_dim_v 32 -t i8 -g 8 | rocmlir-gen --emit-tuning-key - | FileCheck %s --check-prefixes=CHECK_7 +// CHECK_7: -t i8 -transQ false -transK false -transV false -transO false -g 32 -seq_len_q 256 -seq_len_k 512 -head_dim_qk 64 -head_dim_v 32 // Checking numCU -// RUN: rocmlir-gen --arch gfx942 --num_cu 304 --operation attention -seq_len_q 256 -seq_len_k 512 -head_dim_qk 64 -head_dim_v 32 -t f16 -g 4 | rocmlir-gen --emit-tuning-key - | FileCheck %s --check-prefixes=CHECK_4 -// CHECK_4: 304 +// RUN: rocmlir-gen --arch gfx942 --num_cu 304 --operation attention -seq_len_q 256 -seq_len_k 512 -head_dim_qk 64 -head_dim_v 32 -t f16 -g 4 | rocmlir-gen --emit-tuning-key - | FileCheck %s --check-prefixes=CHECK_NUMCU +// CHECK_NUMCU: 304 diff --git a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp index ab993d97b043..9e45182a4289 100644 --- a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp +++ b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp @@ -508,6 +508,23 @@ static llvm::cl::opt emitTuningKey( // Attention related args // ---------------------- +static llvm::cl::opt + numHeadsQ("num_heads_q", + llvm::cl::desc("number of heads of Q in attention()"), + llvm::cl::value_desc("positive integer"), llvm::cl::init(1)); + +static llvm::cl::opt + numHeadsKV("num_heads_kv", + llvm::cl::desc("number of heads of K,V in attention()"), + llvm::cl::value_desc("positive integer"), llvm::cl::init(1)); + +static llvm::cl::list + currentSeqLen("current_seq_len", + llvm::cl::desc("List of sequence lengths of K and V (related " + "to KV-cache) in attention()"), + llvm::cl::value_desc("list of positive integers"), + llvm::cl::CommaSeparated); + static llvm::cl::opt sequenceLengthQ( "seq_len_q", llvm::cl::desc("sequence length of Q in attention()"), llvm::cl::value_desc("positive integer"), llvm::cl::init(-1)); @@ -819,6 +836,7 @@ struct AttentionQuantizedArgIndex { static const size_t quantScale = 4; static const size_t scale = 5; static const size_t bias = 6; + static const size_t currentSeqLen = 7; }; struct AttentionArgIndex { @@ -827,6 +845,7 @@ struct AttentionArgIndex { static const size_t v = 2; static const size_t scale = 3; static const size_t bias = 4; + static const size_t currentSeqLen = 5; }; struct GenParams { @@ -1025,6 +1044,8 @@ static void populateDefaults() { groupSize = 1; sequenceLengthQ = 1024; sequenceLengthK = 1024; + numHeadsQ = 1; + numHeadsKV = 1; headDimQK = 32; headDimV = 32; } @@ 
-2156,14 +2177,20 @@ static func::FuncOp createGpuGemmKernel(ModuleOp module, static void getAttentionTypes(SmallVectorImpl &result, ArrayRef elemTypes) { - SmallVector qDims{groupSize, sequenceLengthQ, headDimQK}; - SmallVector transposedQDims{groupSize, headDimQK, sequenceLengthQ}; - SmallVector kDims{groupSize, sequenceLengthK, headDimQK}; - SmallVector transposedKDims{groupSize, headDimQK, sequenceLengthK}; - SmallVector vDims{groupSize, sequenceLengthK, headDimV}; - SmallVector transposedVDims{groupSize, headDimV, sequenceLengthK}; - SmallVector oDims{groupSize, sequenceLengthQ, headDimV}; - SmallVector transposedODims{groupSize, headDimV, sequenceLengthQ}; + SmallVector qDims{groupSize * numHeadsQ, sequenceLengthQ, headDimQK}; + SmallVector transposedQDims{groupSize * numHeadsQ, headDimQK, + sequenceLengthQ}; + SmallVector kDims{groupSize * numHeadsKV, sequenceLengthK, + headDimQK}; + SmallVector transposedKDims{groupSize * numHeadsKV, headDimQK, + sequenceLengthK}; + SmallVector vDims{groupSize * numHeadsKV, sequenceLengthK, headDimV}; + SmallVector transposedVDims{groupSize * numHeadsKV, headDimV, + sequenceLengthK}; + SmallVector oDims{groupSize * numHeadsQ, sequenceLengthQ, headDimV}; + SmallVector transposedODims{groupSize * numHeadsQ, headDimV, + sequenceLengthQ}; + bool isQuantized = elemTypes[0] == IntegerType::get(elemTypes[0].getContext(), 8); @@ -2177,6 +2204,9 @@ static void getAttentionTypes(SmallVectorImpl &result, : AttentionArgIndex::scale; const size_t biasIndex = isQuantized ? AttentionQuantizedArgIndex::bias : AttentionArgIndex::bias; + const size_t currentSeqLenIndex = + isQuantized ? AttentionQuantizedArgIndex::currentSeqLen + : AttentionArgIndex::currentSeqLen; const size_t outputIndex = biasIndex; MemRefType qType = MemRefType::get(transposeQ ? transposedQDims : qDims, @@ -2202,15 +2232,23 @@ static void getAttentionTypes(SmallVectorImpl &result, result.push_back(qsType); } if (hasAttnScale) { - SmallVector scaleDims{groupSize, sequenceLengthQ, sequenceLengthK}; + SmallVector scaleDims{groupSize * numHeadsQ, sequenceLengthQ, + sequenceLengthK}; MemRefType sType = MemRefType::get(scaleDims, elemTypes[scaleIndex]); result.push_back(sType); } if (hasAttnBias) { - SmallVector biasDims{groupSize, sequenceLengthQ, sequenceLengthK}; + SmallVector biasDims{groupSize * numHeadsQ, sequenceLengthQ, + sequenceLengthK}; MemRefType bType = MemRefType::get(biasDims, elemTypes[biasIndex]); result.push_back(bType); } + if (!currentSeqLen.empty()) { + SmallVector currentSeqDims{groupSize}; + MemRefType currSeqLenType = + MemRefType::get(currentSeqDims, elemTypes[currentSeqLenIndex]); + result.push_back(currSeqLenType); + } MemRefType outType = MemRefType::get(transposeO ? 
transposedODims : oDims, elemTypes[outputIndex]); result.push_back(outType); @@ -2243,7 +2281,8 @@ getAttentionDimNames(SmallVectorImpl> &result, result.emplace_back(SmallVector{gName, seqQName, seqKName}); if (hasAttnBias) result.emplace_back(SmallVector{gName, seqQName, seqKName}); - + if (!currentSeqLen.empty()) + result.emplace_back(SmallVector{gName}); if (transposeO) result.emplace_back(SmallVector{gName, headVName, seqQName}); else @@ -2278,6 +2317,182 @@ Value addTensorArgToBlock(OpBuilder &builder, Location loc, return funcArgTensor; } +template +static Value maskKVCacheTosa(OpBuilder builder, Location loc, Value inputTensor, + Value currentSeqLenVal, T initValue) { + // inputTensor is [B*NUM_HEADS, SEQ_LEN_Q, SEQ_LEN_KV], we want to reshape to + // [B, NUM_HEADS, SEQ_LEN_Q, SEQ_LEN_KV] + auto origType = cast(inputTensor.getType()); + ArrayRef origShape = origType.getShape(); + SmallVector newShape = {origShape[0] / numHeadsQ, numHeadsQ, + origShape[1], origShape[2]}; + inputTensor = createOpAndInfer( + builder, loc, origType.getElementType(), inputTensor, newShape); + + auto inpType = cast(inputTensor.getType()); + ArrayRef inpShape = inpType.getShape(); + assert(static_cast(currentSeqLen.size()) == inpShape[0] && + "Number of current sequence lengths must match batch dimension"); + for (auto v : currentSeqLen) + assert(v > 0 && v <= inpShape[3]); + + // create range 0 .. inpShape[3] - 1 + llvm::SmallVector range; + range.reserve(inpShape[3]); + for (int i = 0; i < inpShape[3]; i++) + range.push_back(i); + DenseElementsAttr rangeAttr = DenseIntElementsAttr::get( + RankedTensorType::get({inpShape[3]}, builder.getI32Type()), range); + Value rangeVal = + builder.create(loc, rangeAttr.getType(), rangeAttr); + + // broadcast range to inputTensor shape + auto outType = RankedTensorType::get(inpShape, builder.getI32Type()); + auto zeroValue = cast(builder.getZeroAttr(outType)); + auto zeroTensor = builder.create(loc, outType, zeroValue); + auto rangeBroadcast = createOpAndInfer( + builder, loc, builder.getI32Type(), zeroTensor, rangeVal); + + // broadcast currentSeqLen + auto currentSeqLenBroadcast = createOpAndInfer( + builder, loc, builder.getI32Type(), zeroTensor, currentSeqLenVal); + + // create mask + auto mask = createOpAndInfer( + builder, loc, builder.getIntegerType(1), rangeBroadcast, + currentSeqLenBroadcast); + + // create a tensor with a single value and broadcast it + DenseElementsAttr initValueAttr; + if constexpr (std::is_same_v) { + assert(inpType.getElementType() == builder.getI32Type()); + initValueAttr = DenseIntElementsAttr::get( + RankedTensorType::get(inpShape, inpType.getElementType()), initValue); + } else if constexpr (std::is_same_v) { + assert(inpType.getElementType() == builder.getF32Type() || + inpType.getElementType() == builder.getF16Type()); + llvm::APFloat fpVal(initValue); + if (inpType.getElementType() == builder.getF16Type()) { + bool losesInfo = false; + auto status = + fpVal.convert(llvm::APFloat::IEEEhalf(), + llvm::APFloat::rmNearestTiesToEven, &losesInfo); + assert(status == llvm::APFloat::opOK); + } + initValueAttr = DenseFPElementsAttr::get( + RankedTensorType::get(inpShape, inpType.getElementType()), fpVal); + } else { + static_assert(!std::is_same_v, + "Unsupported type for MLIR type mapping"); + } + Value initVal = builder.create(loc, initValueAttr.getType(), + initValueAttr); + + // mask is 1 for positions that should be overwritten with initValue + auto result = createOpAndInfer( + builder, loc, inpType.getElementType(), mask, initVal, inputTensor);
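+  // Note on initValue: callers pick it so that masked positions (column index
+  // >= currentSeqLen[b]) become neutral for whatever consumes this tensor:
+  // 1.0 for the attention scale (multiplicative identity), 0.0 for the bias
+  // (additive identity), and -infinity for the QK logits so that exp(-inf)
+  // contributes nothing to the softmax denominator.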
+ + // reshape result back to [B*NUM_HEADS, SEQ_LEN_Q, SEQ_LEN_KV] + auto resultReshaped = createOpAndInfer( + builder, loc, inpType.getElementType(), result, origShape); + + return resultReshaped; +} + +static Value broadcastGQATosa(OpBuilder builder, Location loc, + Value inputTensor) { + assert(numHeadsQ % numHeadsKV == 0); + + if (numHeadsQ == numHeadsKV) + return inputTensor; + + int64_t numRepeat = numHeadsQ / numHeadsKV; + + auto inpType = cast(inputTensor.getType()); + ArrayRef inpShape = inpType.getShape(); + + // add one dimension + SmallVector reassocIndices = {{0, 1}, {2}, {3}}; + SmallVector expandedShape = {inpShape[0], 1, inpShape[1], + inpShape[2]}; + auto newType = RankedTensorType::get(expandedShape, inpType.getElementType()); + auto expandedValue = builder.create( + loc, newType, inputTensor, reassocIndices); + + // broadcast + SmallVector outShape = {inpShape[0], numRepeat, inpShape[1], + inpShape[2]}; + auto outType = RankedTensorType::get(outShape, inpType.getElementType()); + + auto zeroValue = cast(builder.getZeroAttr(outType)); + auto zeroTensor = builder.create(loc, outType, zeroValue); + auto addWithZero = createOpAndInfer( + builder, loc, inpType.getElementType(), zeroTensor, expandedValue); + + // collapse + return builder.create(loc, addWithZero, + reassocIndices); +} + +static Value broadcastKVCacheRock(OpBuilder builder, Location loc, + Value inputTensor) { + ArrayRef inpShape = + cast(inputTensor.getType()).getShape(); + assert(static_cast(currentSeqLen.size()) == inpShape[0] && + "Number of current sequence lengths must match batch dimension"); + SmallVector startNames = {"gemmG"}; + rock::BottomUpTMBuilder addDim(builder, startNames, inpShape); + addDim.addDim("seqLen", 1, 1); + addDim.passThrough(ArrayRef{0}, ArrayRef{0}); + auto addDimAttr = addDim.get(); + Value matrixAddDim = + builder.create(loc, inputTensor, addDimAttr); + + auto broadcaster = rock::BottomUpTMBuilder::above(addDim, addDimAttr); + broadcaster.broadcast({1}, {numHeadsQ}); + broadcaster.passThrough(ArrayRef{0}, ArrayRef{0}); + auto broadcasterAttr = broadcaster.get(); + Value tensorBroadcast = + builder.create(loc, matrixAddDim, broadcasterAttr); + + auto merger = rock::BottomUpTMBuilder::above(broadcaster, broadcasterAttr); + merger.merge("gemmG", 0, {"gemmG", "seqLen"}); + auto mergerAttr = merger.get(); + return builder.create(loc, tensorBroadcast, mergerAttr); +} + +static Value broadcastGQARock(OpBuilder builder, Location loc, + Value inputTensor) { + assert(numHeadsQ % numHeadsKV == 0); + + if (numHeadsQ == numHeadsKV) + return inputTensor; + + int64_t numRepeats = numHeadsQ / numHeadsKV; + ArrayRef inpShape = + cast(inputTensor.getType()).getShape(); + SmallVector startNames = {"gemmG", "seqLen", "headDim"}; + rock::BottomUpTMBuilder addDim(builder, startNames, inpShape); + addDim.addDim("broadcastDim", 1, 1); + addDim.passThrough({0, 2, 3}, {0, 1, 2}); + auto addDimAttr = addDim.get(); + Value matrixAddDim = + builder.create(loc, inputTensor, addDimAttr); + + auto broadcaster = rock::BottomUpTMBuilder::above(addDim, addDimAttr); + broadcaster.broadcast({1}, {numRepeats}); + broadcaster.passThrough({0, 2, 3}, {0, 2, 3}); + auto broadcasterAttr = broadcaster.get(); + Value tensorBroadcast = + builder.create(loc, matrixAddDim, broadcasterAttr); + + auto merger = rock::BottomUpTMBuilder::above(broadcaster, broadcasterAttr); + merger.merge("gemmG", 0, {"gemmG", "broadcastDim"}); + merger.passThrough({1, 2}, {2, 3}); + auto mergerAttr = merger.get(); + return builder.create(loc, 
tensorBroadcast, mergerAttr); +} + static func::FuncOp createGpuAttentionKernel(ModuleOp module, const GenParams ¶ms) { MLIRContext *ctx = module.getContext(); @@ -2325,6 +2540,7 @@ static func::FuncOp createGpuAttentionKernel(ModuleOp module, Value scale; Value bias; Value output; + Value currentSeqLenTensor; SmallVector elemwiseInputs; unsigned optionalArgsCounter = 3; @@ -2342,15 +2558,22 @@ static func::FuncOp createGpuAttentionKernel(ModuleOp module, bias = unflattenedArgs[optionalArgsCounter++]; elemwiseInputs.push_back(bias); } + if (!currentSeqLen.empty()) { + currentSeqLenTensor = broadcastKVCacheRock( + builder, loc, unflattenedArgs[optionalArgsCounter++]); + } output = unflattenedArgs[optionalArgsCounter]; + keys = broadcastGQARock(builder, loc, keys); + values = broadcastGQARock(builder, loc, values); + IntegerAttr numCUAttr = (num_cu.getNumOccurrences() > 0 ? builder.getI32IntegerAttr(num_cu) : nullptr); auto attention = builder.create( - loc, TypeRange{}, queries, keys, values, elemwiseInputs, output, - transposeQ, transposeK, transposeV, transposeO, archAttr, params.features, - numCUAttr, + loc, TypeRange{}, queries, keys, values, elemwiseInputs, + currentSeqLenTensor, output, transposeQ, transposeK, transposeV, + transposeO, archAttr, params.features, numCUAttr, /*params0=*/nullptr, /*params1=*/nullptr, /*firstGemmIdx=*/0); { Block *preSoftmaxElemwiseBlock = @@ -2575,12 +2798,37 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module, if (transposeV) { valuesTensor = transposeMatrix(builder, loc, valuesTensor, {0, 2, 1}); } + // GQA + keysTensor = broadcastGQATosa(builder, loc, keysTensor); + valuesTensor = broadcastGQATosa(builder, loc, valuesTensor); + Type firstGemmOutElemType = params.types[0]; if (isQuantized) { firstGemmOutElemType = IntegerType::get(ctx, 32); } Value qkTensor = createOpAndInfer( builder, loc, firstGemmOutElemType, queriesTensor, keysTensor); + + // get currentSeqLenTensor + Value currentSeqLenTensor; + if (!currentSeqLen.empty()) { + unsigned seqLenCounter = 3; + if (isQuantized) + seqLenCounter += 2; + if (hasAttnScale) + seqLenCounter++; + if (hasAttnBias) + seqLenCounter++; + auto currentSeqLenTensorRaw = getTensorForBlockArg(seqLenCounter); + auto type = cast(currentSeqLenTensorRaw.getType()); + ArrayRef shape = type.getShape(); + assert(shape.size() == 1); + + currentSeqLenTensor = createOpAndInfer( + builder, loc, type.getElementType(), currentSeqLenTensorRaw, + builder.getDenseI64ArrayAttr({shape[0], 1, 1, 1})); + } + unsigned optionalArgsCounter = 3; if (isQuantized) { auto quantBiasI8 = getTensorForBlockArg(optionalArgsCounter++); @@ -2597,6 +2845,10 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module, } if (hasAttnScale) { auto scaleTensor = getTensorForBlockArg(optionalArgsCounter++); + if (!currentSeqLen.empty()) + scaleTensor = + maskKVCacheTosa(builder, loc, scaleTensor, currentSeqLenTensor, 1.0f); + qkTensor = createOpAndInfer( builder, loc, cast(scaleTensor.getType()).getElementType(), qkTensor, scaleTensor, /*shift=*/0); @@ -2604,11 +2856,20 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module, if (hasAttnBias) { auto biasTensor = getTensorForBlockArg(optionalArgsCounter++); + if (!currentSeqLen.empty()) + biasTensor = + maskKVCacheTosa(builder, loc, biasTensor, currentSeqLenTensor, 0.0f); + qkTensor = createOpAndInfer( builder, loc, cast(biasTensor.getType()).getElementType(), qkTensor, biasTensor); } + if (currentSeqLenTensor) { + qkTensor = maskKVCacheTosa(builder, loc, 
qkTensor, currentSeqLenTensor, + -std::numeric_limits::infinity()); + } + constexpr int64_t reductionAxis = 2; auto qkMaxs = createOpAndInfer( builder, loc, cast(qkTensor.getType()).getElementType(), @@ -2632,9 +2893,11 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module, #ifdef ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX softmaxTensor = qkTensor; #endif + auto resultOutElementType = + cast(softmaxTensor.getType()).getElementType(); Value resultTensor = createOpAndInfer( - builder, loc, cast(softmaxTensor.getType()).getElementType(), - softmaxTensor, valuesTensor); + builder, loc, resultOutElementType, softmaxTensor, valuesTensor); + if (transposeO) { resultTensor = transposeMatrix(builder, loc, resultTensor, {0, 2, 1}); } @@ -3036,7 +3299,6 @@ static void insertValidationCalls(const GenParams &genParams, OpBuilder &b, } // generate all sub-kernels, and get corresponding gemmId std::string kernelBaseName = genConfig.kernelBaseName; - llvm::errs() << kernelBaseName << "\n"; for (int i = kernelStart; i < kernelCount; ++i) { convGenerator.setKernelName(kernelBaseName + "_" + std::to_string(i)); if (failed(convGenerator.genConvModule(module, i, true, @@ -3145,6 +3407,7 @@ static void insertValidationCalls(const GenParams &genParams, OpBuilder &b, root0.func.getName().str() + "_verify" + std::to_string(outIdx); auto verifierFunc = createVerifierFunc(module, root0, testType, valType, funcName); + b.create(loc, verifierFunc, ValueRange{testResult, valResult}); } @@ -3196,6 +3459,7 @@ static LogicalResult populateHostHarnessLogic( b.create(loc, seedFunc, seedConst); } + bool isAttention = false; SmallVector outIndices; if (genParams.operation.has_value()) { switch (genParams.operation.value()) { @@ -3210,6 +3474,7 @@ static LogicalResult populateHostHarnessLogic( outIndices.push_back(0); break; case rock::KernelType::Attention: + isAttention = true; int32_t optionalArgsCounter{3}; bool isQuantized = genParams.types[0] == b.getI8Type(); if (isQuantized) @@ -3218,6 +3483,8 @@ static LogicalResult populateHostHarnessLogic( ++optionalArgsCounter; if (hasAttnBias) ++optionalArgsCounter; + if (!currentSeqLen.empty()) + ++optionalArgsCounter; outIndices.push_back(optionalArgsCounter); } } else { @@ -3243,7 +3510,18 @@ static LogicalResult populateHostHarnessLogic( } auto lvar = b.create(loc, paramMRType); localVars.push_back(lvar); - if (!isRandom) { + + if (!currentSeqLen.empty() && isAttention && + idx == root0.params.size() - 2) { + // fill with currentSeqLen + // as it's very small, just define constant and store directly + for (auto pair : llvm::enumerate(currentSeqLen)) { + Value index = b.create(loc, pair.index()); + Value value = + b.create(loc, pair.value(), b.getI32Type()); + b.create(loc, value, lvar, ValueRange{index}); + } + } else if (!isRandom) { SmallVector initPattern = getTensorInitPattern(elemType); if (failed(populateTensorFillLogic(b, loc, initPattern, elemType, lvar))) return failure(); @@ -3533,7 +3811,7 @@ static void generateKernel(MLIRContext *context, GenParams &genParams, // We only support first-gemm i8 version of attention // This will be changed when we support both gemms of i8. 
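+  // Note: the element-type tables below now always reserve a trailing i32
+  // entry for currentSeqLen; the corresponding memref<groupSize x i32> kernel
+  // argument is only materialized when -current_seq_len is passed (see
+  // getAttentionTypes).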
if (elemType == IntegerType::get(context, 8)) { - constexpr size_t maxNumArgs{7}; + constexpr size_t maxNumArgs{8}; genParams.types.resize(maxNumArgs); genParams.types[AttentionQuantizedArgIndex::q] = IntegerType::get(context, 8); @@ -3549,6 +3827,8 @@ static void generateKernel(MLIRContext *context, GenParams &genParams, Float16Type::get(context); genParams.types[AttentionQuantizedArgIndex::bias] = Float16Type::get(context); + genParams.types[AttentionQuantizedArgIndex::currentSeqLen] = + IntegerType::get(context, 32); } else { constexpr size_t maxNumArgs{5}; // Note: In the current implementation, all operands have the same type. @@ -3556,6 +3836,8 @@ static void generateKernel(MLIRContext *context, GenParams &genParams, for (size_t argIdx{0}; argIdx < maxNumArgs; ++argIdx) { genParams.types.push_back(elemType); } + // extra operand: currentSeqLen + genParams.types.push_back(IntegerType::get(context, 32)); } genParams.convConfig = std::nullopt; (void)createGpuAttentionKernel(module, genParams);
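For reviewers who want the intended numerics in one place, the following is a scalar C++ sketch of what the new -num_heads_q, -num_heads_kv and -current_seq_len options are meant to compute. It is not code from this patch: the function name and the flat row-major layouts are assumptions made for illustration, the patch itself builds the equivalent computation from tosa ops on the host side and rock.transform views on the kernel side, and the optional scale/bias inputs, the transposes and the i8 path are left out.

// Scalar reference for GQA attention with per-batch KV-cache masking.
// Hypothetical helper, not part of this patch.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// q:   [g*numHeadsQ,  seqQ, dQK]    k:   [g*numHeadsKV, seqK, dQK]
// v:   [g*numHeadsKV, seqK, dV]     out: [g*numHeadsQ,  seqQ, dV]
// cur: [g] current KV length per batch; key/value rows j >= cur[b] are masked.
void naiveGqaKvCacheAttention(const std::vector<float> &q,
                              const std::vector<float> &k,
                              const std::vector<float> &v,
                              const std::vector<int32_t> &cur,
                              std::vector<float> &out, int64_t g,
                              int64_t numHeadsQ, int64_t numHeadsKV,
                              int64_t seqQ, int64_t seqK, int64_t dQK,
                              int64_t dV) {
  assert(numHeadsQ % numHeadsKV == 0);
  const int64_t rep = numHeadsQ / numHeadsKV; // Q heads sharing one KV head
  out.assign(g * numHeadsQ * seqQ * dV, 0.0f);
  std::vector<float> row(seqK);
  for (int64_t b = 0; b < g; ++b) {
    for (int64_t h = 0; h < numHeadsQ; ++h) {
      const int64_t gq = b * numHeadsQ + h;         // flattened Q/output head
      const int64_t gkv = b * numHeadsKV + h / rep; // flattened shared KV head
      for (int64_t i = 0; i < seqQ; ++i) {
        // QK^T row with KV-cache masking: positions past cur[b] become -inf,
        // mirroring the tosa.greater_equal + tosa.select pair in the verifier.
        for (int64_t j = 0; j < seqK; ++j) {
          float acc = 0.0f;
          for (int64_t d = 0; d < dQK; ++d)
            acc += q[(gq * seqQ + i) * dQK + d] * k[(gkv * seqK + j) * dQK + d];
          row[j] = j < cur[b] ? acc : -std::numeric_limits<float>::infinity();
        }
        // Numerically stable softmax; exp(-inf) == 0, so masked columns drop out.
        float m = row[0];
        for (float x : row)
          m = std::max(m, x);
        float sum = 0.0f;
        for (float &x : row) {
          x = std::exp(x - m);
          sum += x;
        }
        // softmax(QK^T) * V
        for (int64_t d = 0; d < dV; ++d) {
          float o = 0.0f;
          for (int64_t j = 0; j < seqK; ++j)
            o += (row[j] / sum) * v[(gkv * seqK + j) * dV + d];
          out[(gq * seqQ + i) * dV + d] = o;
        }
      }
    }
  }
}

For attention-kernel-gqa-kvcache.mlir above this corresponds to g = 1, numHeadsQ = 4, numHeadsKV = 2 and cur = {33}: each Q head reads one of the two shared KV heads, and only the first 33 key/value rows contribute to the output.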