Add GQA and KV Cache #1696

Open · wants to merge 8 commits into base: develop
11 changes: 7 additions & 4 deletions mlir/include/mlir/Dialect/Rock/IR/RockOps.td
@@ -206,12 +206,13 @@
}

def Rock_AttentionOp :
Rock_Op<"attention", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot]>,
Rock_Op<"attention", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot, AttrSizedOperandSegments]>,
Arguments<(ins
TensorOrMemRefOf<[F32, F16, I8]>:$queries,
TensorOrMemRefOf<[F32, F16, I8]>:$keys,
TensorOrMemRefOf<[F32, F16]>:$values,
Variadic<TensorOrMemRefOf<[F32, F16, I8]>>:$preSoftmaxElemWiseInputs,
+Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
TensorOrMemRefOf<[F32, F16]>:$out,
UnitAttr:$qTransposed,
UnitAttr:$kTransposed,
@@ -244,6 +245,7 @@ def Rock_AttentionOp :
let assemblyFormat = [{
`{` `\n`
` ` `qk` `=` (`tr` $qTransposed^)? $queries `*` (`tr` $kTransposed^)? $keys `:` type($queries) `,` type($keys) `\n`
+(`currentSeqLen` `=` `(` $currentSeqLen^ `:` type($currentSeqLen) `)` `\n`)?
(`qk` `=` `elementwise` (`otherIns` `(` $preSoftmaxElemWiseInputs^ `:` type($preSoftmaxElemWiseInputs) `)`)? $preSoftmaxBody^ `\n`)?
(`tr` $oTransposed^)? $out `=` `softmax` `(` `qk` `)` `*` (`tr` $vTransposed^)? $values `:` type($values) `->` type($out) `\n`
`}` attr-dict (`->` type($result)^)?
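For reference, the custom syntax above would print roughly as follows when the optional operand is present. This is a hypothetical example: the SSA names and shapes are invented for illustration and the attribute dictionary is omitted; only the line order comes from the format string in this diff.

// Hypothetical printed form of rock.attention with a KV-cache length.
// %q : 1x384x32, %k : 1x32x384, so qk is 384x384; %len holds one
// valid-sequence length per batch element.
rock.attention {
  qk = %q * %k : memref<1x384x32xf32>, memref<1x32x384xf32>
  currentSeqLen = (%len : memref<1xi32>)
  %out = softmax(qk) * %v : memref<1x384x64xf32> -> memref<1x384x64xf32>
}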
@@ -431,11 +433,12 @@ def Rock_GridwiseGemmAccelOp :

// gridwise_attention_accel
def Rock_GridwiseAttentionAccelOp :
Rock_Op<"gridwise_attention_accel", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot]>,
Rock_Op<"gridwise_attention_accel", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot, AttrSizedOperandSegments]>,
Arguments<(ins MemRefRankOf<[F32, F16, I8], [3]>:$queries,
MemRefRankOf<[F32, F16, I8], [3]>:$keys,
MemRefRankOf<[F32, F16], [3]>:$values,
Variadic<TensorOrMemRefOf<[F32, F16, I8]>>:$preSoftmaxElemWiseInputs,
+Optional<MemRefRankOf<[I32], [1]>>:$currentSeqLen,
MemRefRankOf<[F32, F16], [3]>:$out,
StrAttr:$arch,
Rock_GemmFeaturesAttr:$features,
@@ -677,7 +680,7 @@ def Rock_TransformingForOp :
Results<(outs Variadic<AnyType>:$results)> {
let summary = "for loop with coordinate transforms";
let description = [{
-Loops over several a rectangular regeon of dimensions `bounds` in several
+Loops over several a rectangular region of dimensions `bounds` in several
iteration domains, which are coordinate spaces that are the upper coordinates
for a sequence of coordinate transformations.

@@ -779,7 +782,7 @@ def Rock_TransformingForOp :
return *(getLowerStarts().getValues<uint32_t>().begin() + n);
}

-// Retreive the block arguments corresponding to the lower coordinates
+// Retrieve the block arguments corresponding to the lower coordinates
// for a given iteration domain.
Block::BlockArgListType getLowerCoords(uint32_t domain) {
uint32_t start = getLowerStart(domain);
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Rock/utility/builderUtils.h
@@ -11,7 +11,8 @@ namespace mlir {
namespace rock {
/// Utility op to emit constant float op
Value createConstantFloatOp(OpBuilder &b, Location loc, Type type,
-Type elemType, float value);
+Type elemType, float value,
+APFloat::opStatus expectedStatus = APFloat::opOK);

/// Utility op to emit constant int op
Value createConstantIntOp(OpBuilder &b, Location loc, Type type, Type elemType,
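A minimal usage sketch of the widened helper (the call sites below are assumptions for illustration; only the declaration comes from this diff). The defaulted APFloat::opOK keeps the previous exact-conversion expectation for existing callers, while a caller materializing a value known to round in the target element type can state the status it expects:

// Sketch; assumes an OpBuilder `b`, Location `loc`, result type `resType`,
// and element type `elemType` (e.g. f16) are already in scope.
// 0.0f converts exactly to any float element type: the opOK default applies.
Value zero = rock::createConstantFloatOp(b, loc, resType, elemType, 0.0f);
// 0.1f is not exactly representable in f16; the conversion reports
// APFloat::opInexact, which the caller now declares up front.
Value tenth = rock::createConstantFloatOp(b, loc, resType, elemType, 0.1f,
                                          APFloat::opInexact);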
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/TosaToRock/TosaToRock.cpp
@@ -1089,9 +1089,11 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
tosa::MatMulOp firstMatMulOp = maybeFirstMatMul.value();
IntegerAttr numCUAttr =
numCu.has_value() ? rewriter.getI32IntegerAttr(numCu.value()) : nullptr;

+// TODO: extract currentSeqLen from tosa
rock::AttentionOp attnOp = rewriter.create<rock::AttentionOp>(
loc, outputType, firstMatMulOp.getA(), firstMatMulOp.getB(), op.getB(),
-elemwiseOtherArgs, output,
+elemwiseOtherArgs, nullptr, output,
// TODO(implement transpose fusion support here)
/*qTransposed=*/nullptr,
/*kTransposed=*/nullptr,
35 changes: 35 additions & 0 deletions mlir/lib/Dialect/Rock/IR/RockDialect.cpp
@@ -2098,6 +2098,41 @@
if (keyN != valueK) {
return emitError("reduction dimensions of second gemm do not match");
}

+// check output type
+ShapedType oType = getOut().getType();
+int64_t oBatchDim = oType.getShape().size() == 3 ? oType.getShape()[0] : 1;
+
+ArrayRef<int64_t> oLastDims = oType.getShape().slice(oType.getRank() - 2);
+auto [outputSeqLen, outputHeadDim] =
+    getOTransposed() ? std::tuple{oLastDims[1], oLastDims[0]}
+                     : std::tuple{oLastDims[0], oLastDims[1]};
+
+if (qType.getShape().size() != oType.getShape().size()) {
+  return emitError("Number of dimensions do not match (Q and Output)");
+}
+if (qBatchDim != oBatchDim) {
+  return emitError("Batch dimensions do not match (Q and Output)");
+}
+if (queryM != outputSeqLen) {
+  return emitError("Sequence length does not match (Q and Output)");
+}
+if (valueN != outputHeadDim) {
+  return emitError("Head dimensions do not match (V and Output)");
+}
+
+// check currentSeqLen (KV Cache)
+auto currentSeqLen = getCurrentSeqLen();
+if (currentSeqLen) {
+  ShapedType seqLenType = currentSeqLen.getType();
+  if (seqLenType.getShape().size() != 1) {
+    return emitError("Number of dimensions is not one (currentSeqLen)");
+  }
+  if (seqLenType.getShape()[0] != oBatchDim) {
+    return emitError(
+        "Batch dimensions do not match (currentSeqLen and Output)");
+  }
+}
return success();
}

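To make the new checks concrete, a few hypothetical shape combinations (invented for illustration) and how this verifier treats them:

// out:           memref<8x384x64xf32> -> oBatchDim == 8
// currentSeqLen: memref<8xi32>        -> accepted: rank 1, size matches batch
// currentSeqLen: memref<8x1xi32>      -> rejected: "Number of dimensions is not one"
// currentSeqLen: memref<4xi32>        -> rejected: "Batch dimensions do not match"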
10 changes: 4 additions & 6 deletions mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp
@@ -246,8 +246,7 @@
if (!isAccel) {
op.emitError("Currently, attention op is only supported on GPUs "
"with matrix accelerator extentions");
-signalPassFailure();
-return;
+return signalPassFailure();
}
Attribute params0 = op.getParams0().value_or(nullptr);
// set a default one if params is not provided
@@ -262,6 +261,7 @@
auto attnPerfConfig = AttnPerfConfigAttr::get(perfConfigStrAttr);
if (!attnPerfConfig) {
op.emitError("perf config string has an incorrect format.");
+return signalPassFailure();
}
GemmFeatures features = op.getFeatures();
RockAccelTuningParamAttrInterface accelParams0;
@@ -283,8 +283,7 @@
if (attnPerfConfig.getMPerBlockG0() > attnPerfConfig.getMPerBlockG1()) {
op.emitError(
"The MPerBlockG0 should be larger or equal to getMPerBlockG1.");
-signalPassFailure();
-return;
+return signalPassFailure();
}
RockAccelTuningParamAttrInterface accelParams1 =
deriveGemm1TuningParams(builder, op, attnPerfConfig);
@@ -308,8 +307,7 @@
/*enableDPerWaveFiltering=*/false);
if (isValidBlockwiseGemm0.failed() || isValidBlockwiseGemm1.failed()) {
op.emitError("The provided perf config is not valid");
-signalPassFailure();
-return;
+return signalPassFailure();
}

IntegerAttr blockSizeAttr = builder.getI32IntegerAttr(blockSize);
5 changes: 3 additions & 2 deletions mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp
@@ -538,8 +538,9 @@ AttentionRewritePattern::matchAndRewrite(AttentionOp op,
prePadG0NAttr = rw.getIndexAttr(gemm0Size.n);
}
auto newOp = rw.create<GridwiseAttentionAccelOp>(
-loc, queries, keys, values, adaptor.getPreSoftmaxElemWiseInputs(), out,
-op.getArchAttr(), op.getFeaturesAttr(), blockSizeAttr, gridSizeAttr,
+loc, queries, keys, values, adaptor.getPreSoftmaxElemWiseInputs(),
+op.getCurrentSeqLen(), out, op.getArchAttr(), op.getFeaturesAttr(),
+blockSizeAttr, gridSizeAttr,
/*disableQBypassLDS=*/nullptr, prePadG0MAttr, prePadG0NAttr, params0,
params1, op.getFirstGemmIdxAttr());
bool linalgOpFound = false;