Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
dhernandez0 committed Nov 28, 2024
1 parent a35bb9d commit 0f3912b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
23 changes: 13 additions & 10 deletions mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ struct GridwiseAttentionAccelRewritePattern
gemm0OutExpTrs, gemm0OutTrs},
/*bounds=*/ArrayRef<int64_t>{g0Mpt, g0Npt},
/*strides=*/ArrayRef<int64_t>{1, 1},
/*useIndexDiffs=*/true, /*forceUnroll=*/true);
/*forceUnroll=*/true, /*useIndexDiffs=*/true);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(loop.getBody());
Expand Down Expand Up @@ -1105,7 +1105,7 @@ struct GridwiseAttentionAccelRewritePattern
gemm0OutBufferMaxTrs},
/*bounds=*/ArrayRef<int64_t>{g0Mpt, 1},
/*strides=*/ArrayRef<int64_t>{1, 1},
/*useIndexDiffs=*/true, /*forceUnroll=*/true);
/*forceUnroll=*/true, /*useIndexDiffs=*/true);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(loop.getBody());
Expand Down Expand Up @@ -1167,7 +1167,7 @@ struct GridwiseAttentionAccelRewritePattern
ArrayRef<Attribute>{rewriter.getArrayAttr({}), attentionOutAccTrs},
/*bounds=*/ArrayRef<int64_t>{g1Mpt, g1Npt},
/*strides=*/ArrayRef<int64_t>{1, 1},
/*useIndexDiffs=*/true, /*forceUnroll=*/true);
/*forceUnroll=*/true, /*useIndexDiffs=*/true);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(loop.getBody());
Expand Down Expand Up @@ -1230,7 +1230,7 @@ struct GridwiseAttentionAccelRewritePattern
attentionOutAccBufferTrs},
/*bounds=*/ArrayRef<int64_t>{g1Mpt, g1Npt},
/*strides=*/ArrayRef<int64_t>{1, 1},
/*useIndexDiffs=*/true, /*forceUnroll=*/true);
/*forceUnroll=*/true, /*useIndexDiffs=*/true);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(loop.getBody());
Expand Down Expand Up @@ -1336,7 +1336,7 @@ struct GridwiseAttentionAccelRewritePattern
void createFirstGemmNegInfPadding(
PatternRewriter &rewriter, Location loc,
layout::GridCoordinates gridCoords, Value gemm0OutBuffer,
RegsAsMatrixSubTiles gemm0OutSubTileViews) const {
RegsAsMatrixSubTiles gemm0OutSubTileViews, bool isGfx11) const {
MemRefType gemm0OutBufferType = cast<MemRefType>(gemm0OutBuffer.getType());
auto negInfTyped = createConstantFloatOp(
rewriter, loc, gemm0OutBufferType.getElementType(),
Expand All @@ -1346,6 +1346,8 @@ struct GridwiseAttentionAccelRewritePattern
auto tid = rewriter.create<WorkitemIdOp>(loc, rewriter.getIndexType());
int64_t elementsInThreadBuffer = gemm0OutBufferType.getNumElements();
Value zero = rewriter.createOrFold<ConstantIndexOp>(loc, 0);

// TODO: fix forceUnroll=false for gfx1100 (https://github.com/ROCm/rocMLIR-internal/issues/1661)
auto loop = rewriter.create<TransformingForOp>(
loc,
ArrayRef<ValueRange>{{gridCoords.g_block, gridCoords.m_block,
Expand All @@ -1355,7 +1357,7 @@ struct GridwiseAttentionAccelRewritePattern
rewriter.getArrayAttr({})},
/*bounds=*/ArrayRef<int64_t>{1, 1, 1, 1, elementsInThreadBuffer},
/*strides=*/ArrayRef<int64_t>{1, 1, 1, 1, 1},
/*useIndexDiffs=*/true, /*forceUnroll=*/true);
/*forceUnroll=*/!isGfx11, /*useIndexDiffs=*/true);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(loop.getBody());
Expand Down Expand Up @@ -1408,15 +1410,15 @@ struct GridwiseAttentionAccelRewritePattern
thenb.getArrayAttr({})},
/*bounds=*/ArrayRef<int64_t>{1, 1, 1, 1, elementsInThreadBuffer},
/*strides=*/ArrayRef<int64_t>{1, 1, 1, 1, 1},
/*useIndexDiffs=*/true, /*forceUnroll=*/true);
/*forceUnroll=*/true, /*useIndexDiffs=*/true);
{
OpBuilder::InsertionGuard guard(thenb);
thenb.setInsertionPointToStart(loop.getBody());

Block::BlockArgListType lowerCoords = loop.getLowerCoords(0);
Block::BlockArgListType upperCoords = loop.getLowerCoords(1);
auto isInvalid = thenb.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, lowerCoords[2], currentSeqLen);
loc, arith::CmpIPredicate::uge, lowerCoords[2], currentSeqLen);
scf::IfOp ifb = thenb.create<scf::IfOp>(loc, isInvalid,
/*withElseRegion=*/false);
{
Expand Down Expand Up @@ -2214,20 +2216,21 @@ struct GridwiseAttentionAccelRewritePattern
postProcessFirstGemmSplat<ElementwiseMultOp>(
rewriter, loc, gridCoordsGemm0, gemm0OutBuffer, gemm0OutSubTileViews,
ln2Recip.getDefiningOp<arith::ConstantOp>().getValue());
#endif

// Handle padding
bool hasPadding =
op.getPrePadG0M().has_value() || op.getPrePadG0N().has_value();
if (hasPadding) {
bool isGfx11 = arch.contains("gfx11");
createFirstGemmNegInfPadding(rewriter, loc, gridCoordsGemm0,
gemm0OutBuffer,
gemm0OutSubTileViewsTrUnPadded);
gemm0OutSubTileViewsTrUnPadded, isGfx11);
}
// Negative Infinite for extra values (KV cache)
setGemm0OutputOutOfScopeKVCache(
rewriter, loc, gridCoordsGemm0, gemm0OutBuffer,
gemm0OutSubTileViewsTr, currentSeqLen, mLoopIV, gemm0MBlocksLastIter);
#endif

APInt reductionAxis = APInt(64, 1);
APInt nrDimPerThread = APInt(64, gemm0MPerBlock / gemm0MPerThread);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ func.func @gridwise_attn_kvcache(%arg0: memref<1x384x64xf32>, %arg1: memref<1x64
// CHECK: %[[comparison:.+]] = arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index
// CHECK-NEXT: scf.if %[[comparison]] {
// CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[dim0:.+]], %[[dim1:.+]], %[[dim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = []
// CHECK-NEXT: %[[secondComparison:.+]] = arith.cmpi sge, %[[dim2]], %[[currSeqLenIndex]] : index
// CHECK-NEXT: %[[secondComparison:.+]] = arith.cmpi uge, %[[dim2]], %[[currSeqLenIndex]] : index
// CHECK-NEXT: scf.if %[[secondComparison]] {
// CHECK-NEXT: rock.in_bounds_store
rock.gridwise_attention_accel(%0, %arg1, %arg2, %arg4, %arg3) features = mfma|dot|atomic_add preSoftmaxOps = {} {
Expand Down

0 comments on commit 0f3912b

Please sign in to comment.