[AMD] NFC: Unified comment style #5248

Draft · wants to merge 1 commit into main
@@ -11,41 +11,41 @@ using namespace mlir;
using namespace mlir::triton;

// clang-format off
/***
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# WO # W1 # | #
# # # | #
# # # # # | #
# W2 # W3 # .... | #
# # # | SkipElems #
# # # # # | #
# | #
# Slice | #
# . / \ | #
# . / \ | #
# . / \| #
# # # # # # #
# # W0 # W1 # #
# # # # #
# # # # # # tensorStride #
# # W2 # W3 # --------------------------------#
# # # # #
# # # # # # #
# tensorStride # W0 # W1 # #
# ---------------------------------- # # # #
# # # # # # #
# # W2 # W3 # #
# # # # #
# # # # # # ---> lastIdx #
# . #
# . #
# . #
# #
# #
# #
# #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
***/
//===--------------------------------------------------------------------------------===//
// # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
// # W0 # W1 # | #
// # # # | #
// # # # # # | #
// # W2 # W3 # .... | #
// # # # | SkipElems #
// # # # # # | #
// # | #
// # Slice | #
// # . / \ | #
// # . / \ | #
// # . / \| #
// # # # # # # #
// # # W0 # W1 # #
// # # # # #
// # # # # # # tensorStride #
// # # W2 # W3 # --------------------------------#
// # # # # #
// # # # # # # #
// # tensorStride # W0 # W1 # #
// # ---------------------------------- # # # #
// # # # # # # #
// # # W2 # W3 # #
// # # # # #
// # # # # # # ---> lastIdx #
// # . #
// # . #
// # . #
// # #
// # #
// # #
// # #
// # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
//===--------------------------------------------------------------------------------===//
// clang-format on

namespace {
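For intuition only (not part of this patch), the diagram above relates three quantities: SkipElems (elements skipped before the first warp-block), tensorStride (the stride between consecutive warp-block rows of the slice), and lastIdx (the final index reached). A minimal standalone sketch of that traversal, using hypothetical names and plain integers in place of the MLIR values the pass actually builds:

```cpp
#include <cstdint>
#include <vector>

// Collect the starting offset of each warp-block row: skip the leading
// `skipElems` elements, then step by `tensorStride` until `lastIdx`.
std::vector<int64_t> collectBlockRowOffsets(int64_t skipElems,
                                            int64_t tensorStride,
                                            int64_t lastIdx) {
  std::vector<int64_t> offsets;
  for (int64_t idx = skipElems; idx < lastIdx; idx += tensorStride)
    offsets.push_back(idx);
  return offsets;
}
```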
20 changes: 8 additions & 12 deletions third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp
@@ -101,12 +101,10 @@ class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, pred, trueBlock, afterStore);
rewriter.setInsertionPointToStart(trueBlock);
/*
| vialatile | non-tmp | gcn instr gfx94
LLVM::StoreOp | 0 | 0 | (cg) global store
| 0 | 1 | (cs) global store nt
| 1 | 0/1 | (wt) global store sc0 sc1
*/
//               | volatile | non-temporal | GCN instr (gfx94)
// LLVM::StoreOp |    0     |      0       | (cg) global store
//               |    0     |      1       | (cs) global store nt
//               |    1     |     0/1      | (wt) global store sc0 sc1
bool vialatileFlag = isPredicatedStoreWT(callOp);
bool nonTmpFlag = isPredicatedStoreCS(callOp);
auto storeOp = rewriter.create<LLVM::StoreOp>(
@@ -138,12 +136,10 @@ class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, pred, trueBlock, falseBlock);
rewriter.setInsertionPointToStart(trueBlock);
/*
| vialatile | non-tmp | gcn instr gfx94
LLVM::LoadOp | 0 | 0 | (ca) global load
| 0/1 | 1 | (cg) global load nt
| 1 | 0 | (cv) flat load sc0 sc1
*/
//              | volatile | non-temporal | GCN instr (gfx94)
// LLVM::LoadOp |    0     |      0       | (ca) global load
//              |   0/1    |      1       | (cg) global load nt
//              |    1     |      0       | (cv) flat load sc0 sc1
bool vialatileFlag = isPredicatedLoadCV(callOp);
bool nonTmpFlag = isPredicatedLoadCG(callOp);
auto loadOp = rewriter.create<LLVM::LoadOp>(
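As an aside (not part of this patch), the two flag tables above can be read as a small decision function. The sketch below only mirrors the table rows for gfx94x with plain booleans and strings; the real lowering instead feeds the volatile and non-temporal bits to the LLVM::StoreOp / LLVM::LoadOp builders as shown in the diff:

```cpp
#include <string>

// Store policy per the table: volatile wins (wt), then non-temporal (cs),
// otherwise the default cached store (cg).
std::string gcnStorePolicy(bool isVolatile, bool isNonTemporal) {
  if (isVolatile)
    return "global store sc0 sc1"; // (wt), non-temporal bit is ignored
  if (isNonTemporal)
    return "global store nt";      // (cs) streaming store
  return "global store";           // (cg) default cached store
}

// Load policy per the table: non-temporal wins (cg), then volatile (cv),
// otherwise the default cached load (ca).
std::string gcnLoadPolicy(bool isVolatile, bool isNonTemporal) {
  if (isNonTemporal)
    return "global load nt";       // (cg), volatile may be 0 or 1
  if (isVolatile)
    return "flat load sc0 sc1";    // (cv) volatile load
  return "global load";            // (ca) default cached load
}
```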
@@ -82,17 +82,15 @@ bool isKMajor(llvm::ArrayRef<unsigned> order, int opIdx) {
return order[0] == kdim;
}

/**
* @brief checks that swizzle pattern fits into one warp block
* and block size is a multiple of swizzle size along non-K dimension
*
* @param sharedLayout
* @param opIdx operand id 0 or 1
* @param reps number of repetitions: [non-k, k] or [batch, non-k, k]
* @param elemsPerInstr one instruction size
* @param warpsPerBlockNonK number of warps along non-k Dim
* @return bool
*/
/// Checks that the swizzle pattern fits into one warp block
/// and that the block size is a multiple of the swizzle size along the
/// non-K dimension.
///
/// \param sharedLayout
/// \param opIdx operand id, 0 or 1
/// \param reps number of repetitions: [non-k, k] or [batch, non-k, k]
/// \param elemsPerInstr one instruction size
/// \param warpsPerBlockNonK number of warps along the non-K dim
/// \returns true if the swizzle pattern fits into one warp block
bool isSwizzlePatternFitsIntoBlock(const SharedEncodingAttr sharedLayout,
int opIdx, const ArrayRef<int64_t> reps,
const ArrayRef<int64_t> elemsPerInstr,
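A standalone sketch (not part of this patch) of the two conditions the doc comment above names: the swizzle pattern must fit into one warp block, and the block size must be a multiple of the swizzle size along the non-K dimension. The parameter names are hypothetical simplifications of the real arguments (sharedLayout, reps, elemsPerInstr, warpsPerBlockNonK):

```cpp
#include <cstdint>

// Both sizes are measured along the non-K dimension.
bool swizzleFitsIntoBlock(int64_t swizzleSizeNonK, int64_t blockSizeNonK) {
  return swizzleSizeNonK <= blockSizeNonK &&
         blockSizeNonK % swizzleSizeNonK == 0;
}
```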
@@ -13,18 +13,16 @@ Value getWarpIdInBlock(ConversionPatternRewriter &rewriter, Location loc,

bool isSwizzled(gpu::SharedEncodingAttr layout);

/**
* @brief swizzling tensor element indexes according pattern encoded in
* SharedEncodingAttr
*
* @param rewriter
* @param loc
* @param row row of target tensor element related to the start of smemObj
* @param col col of target tensor element related to the start of smemObj
* @param smemObj shared memory object, contains info about tensor in LDS
* @param attr layout attribute, contains swizzling info
* @return swizzled row, col indexes in tensor notation
*/
/// Swizzles tensor element indexes according to the pattern encoded in
/// SharedEncodingAttr
///
/// \param rewriter
/// \param loc
/// \param row row of the target tensor element relative to the start of smemObj
/// \param col col of the target tensor element relative to the start of smemObj
/// \param smemObj shared memory object, contains info about the tensor in LDS
/// \param attr layout attribute, contains swizzling info
/// \returns swizzled row, col indexes in tensor notation
std::pair<mlir::Value, mlir::Value>
swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
Value col, SharedMemoryObject smemObj,
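For intuition (not part of this patch), one common XOR-based swizzle of the kind SharedEncodingAttr describes through its vec, perPhase, and maxPhase parameters can be written in scalar form. The real swizzleIndexes builds the same kind of arithmetic with MLIR Values, and its exact formula is defined by the layout attribute, so treat this as an illustration only:

```cpp
#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t> swizzleRowCol(int64_t row, int64_t col, int64_t vec,
                                          int64_t perPhase, int64_t maxPhase) {
  // Rows that share a phase are XOR-ed with the same mask.
  int64_t phase = (row / perPhase) % maxPhase;
  // Swizzle whole groups of `vec` elements so vectorized accesses stay intact.
  int64_t swizzledCol = ((col / vec) ^ phase) * vec + (col % vec);
  return {row, swizzledCol};
}
```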
@@ -33,43 +33,41 @@ using ::mlir::triton::gpu::SharedEncodingAttr;

namespace SharedToDotOperandMFMA {

/**
* @brief This function maps particular load of mfma dot operand to element
* indexes(row, col)
*
* Whole tensor is broken into "blocks" of warps along "non-K" axis.
* One block could be processed by multiple warps.
* One warp works on a piece of tensor size elemsPerInstr[0] x K.
* Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x
* elemsPerInstr[1].
*
* Total offset of element is a sum of following values:
* 1. Offset of warp-block in tensor
* 2. Offset of warp inside one warp-block
* 3. Offset of tile in one warp
* 4. Offset of one lane data in a tile
* 5. Offset of particular element of tensor processed by one lane
*
* This function computes these offsets for axies independently
* Note that this function returns the offsets of elements in the first
* warp-block. The offsets of elements in later warp-blocks can be computed
* by adding a constant stride to the xor-ed offsets of elements in the
* first warp-block.
*
* @param rewriter
* @param loc
* @param elemsPerInstr operand tile shape consumed by one MFMA instruction
* @param warpId id component of 2d warp grid along non-K axis
* @param laneId lane id in warp [0..63]
* @param numOfElems number of elements accessed by thread per repetition
* @param reps number of instructions repetition to fully cover dot operand
* @param smemStrides strides in LDS tensor
* @param loadVecSize number of elements loaded by one operation
* @param iNonKDim non-K dimension size of one MFMA instruction
* @param iKDim K dimension size of one MFMA instruction
* @return vector (i-th element corresponds to i-th load instruction) of
* 2-element vectors(tensor row and col).
*/
/// Maps a particular load of an MFMA dot operand to element
/// indexes (row, col).
///
/// The whole tensor is broken into "blocks" of warps along the "non-K" axis.
/// One block could be processed by multiple warps.
/// One warp works on a piece of the tensor of size elemsPerInstr[0] x K.
/// Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x
/// elemsPerInstr[1].
///
/// The total offset of an element is the sum of the following values:
/// 1. Offset of the warp-block in the tensor
/// 2. Offset of the warp inside one warp-block
/// 3. Offset of the tile in one warp
/// 4. Offset of one lane's data in a tile
/// 5. Offset of the particular element of the tensor processed by one lane
///
/// This function computes these offsets for each axis independently.
/// Note that this function returns the offsets of elements in the first
/// warp-block. The offsets of elements in later warp-blocks can be computed
/// by adding a constant stride to the xor-ed offsets of elements in the
/// first warp-block.
///
/// \param rewriter
/// \param loc
/// \param elemsPerInstr operand tile shape consumed by one MFMA instruction
/// \param warpId id component of the 2d warp grid along the non-K axis
/// \param laneId lane id in the warp [0..63]
/// \param numOfElems number of elements accessed by a thread per repetition
/// \param reps number of instruction repetitions to fully cover the dot operand
/// \param smemStrides strides in the LDS tensor
/// \param loadVecSize number of elements loaded by one operation
/// \param iNonKDim non-K dimension size of one MFMA instruction
/// \param iKDim K dimension size of one MFMA instruction
/// \returns a vector (the i-th element corresponds to the i-th load
/// instruction) of 2-element vectors (tensor row and col).
llvm::SmallVector<llvm::SmallVector<Value>> computeTensorElemMappingInBlock(
ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value warpId, Value laneId,
@@ -127,17 +125,18 @@ bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) {
return srcEncoding.getMaxPhase() > 1;
}

// Computes offsets for operand B or transposed operand A
// @param rewriter
// @param loc
// @param elemsPerInstr operand tile shape [K, nonK] consumed by one MFMA
// instruction
// @param warpId warp id for the "non K" axis
// @param laneId lane id in warp [0..63]
// @param warpsPerBlock number of warps per horizontal axis
// @param numOfElems number of elements accessed by threads per repetition
// @param reps number of instructions repretition to fully cover dot operand
// @param cSwizzleOffset
/// Computes offsets for operand B or transposed operand A
///
/// \param rewriter
/// \param loc
/// \param elemsPerInstr operand tile shape [K, nonK] consumed by one MFMA
/// instruction
/// \param warpId warp id for the "non K" axis
/// \param laneId lane id in the warp [0..63]
/// \param warpsPerBlock number of warps per horizontal axis
/// \param numOfElems number of elements accessed by threads per repetition
/// \param reps number of instruction repetitions to fully cover the dot operand
/// \param cSwizzleOffset
llvm::SmallVector<Value>
fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value warpId,
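A standalone sketch (not part of this patch) of the five-term offset sum the computeTensorElemMappingInBlock comment describes, plus the constant-stride step it notes for later warp-blocks. Plain integers stand in for the MLIR Values the real code produces, and all helper names are hypothetical:

```cpp
#include <cstdint>

// 1. warp-block in the tensor, 2. warp inside the warp-block, 3. tile inside
// the warp, 4. lane data inside the tile, 5. element processed by the lane.
int64_t elementOffsetInBlock0(int64_t warpBlockOff, int64_t warpOff,
                              int64_t tileOff, int64_t laneOff,
                              int64_t elemOff) {
  return warpBlockOff + warpOff + tileOff + laneOff + elemOff;
}

// Offsets in later warp-blocks are the block-0 offsets plus a constant stride.
int64_t elementOffsetInBlockN(int64_t offsetInBlock0, int64_t blockIdx,
                              int64_t warpBlockStride) {
  return offsetInBlock0 + blockIdx * warpBlockStride;
}
```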
@@ -33,39 +33,37 @@ using ::mlir::triton::gpu::SharedEncodingAttr;

namespace SharedToDotOperandWMMA {

/**
* @brief Following functions maps particular load of wmma dot operand to
* element indexes(row, col). For each WMMA generation separate function is
* used.
*
* Whole tensor is broken into "blocks" of warps along "non-K" axis.
* One block could be processed by multiple warps.
* One warp works on a piece of tensor size elemsPerInstr[0] x K.
* Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x
* elemsPerInstr[1].
*
* Total offset of element is a sum of following values:
* 1. Offset of warp block in tensor
* 2. Offset of warp inside one warp block
* 3. Offset of tile in one warp
* 4. Offset of one lane data in a tile
* 5. Offset of particular element of tensor processed by one lane
*
* This function computes these offsets for axes independently
*
* @param rewriter
* @param loc
* @param elemsPerInstr operand tile shape consumed by one WMMA instruction
* @param warpId id component of 2d warp grid along non-K axis
* @param laneId lane id in warp [0..63]
* @param numOfElems number of elements accessed by thread per repetition
* @param reps number of instructions repetition to fully cover dot operand
* @param smemStrides strides in LDS tensor
* @param loadVecSize number of elements loaded by one operation
* @param iNonKDim non-K dimension of dot operand
* @return vector (i-th element corresponds to i-th load instruction) of
* 2-element vectors(tensor row and col).
*/
/// The following functions map a particular load of a WMMA dot operand to
/// element indexes (row, col). A separate function is used for each WMMA
/// generation.
///
/// The whole tensor is broken into "blocks" of warps along the "non-K" axis.
/// One block could be processed by multiple warps.
/// One warp works on a piece of the tensor of size elemsPerInstr[0] x K.
/// Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x
/// elemsPerInstr[1].
///
/// The total offset of an element is the sum of the following values:
/// 1. Offset of the warp block in the tensor
/// 2. Offset of the warp inside one warp block
/// 3. Offset of the tile in one warp
/// 4. Offset of one lane's data in a tile
/// 5. Offset of the particular element of the tensor processed by one lane
///
/// This function computes these offsets for each axis independently.
///
/// \param rewriter
/// \param loc
/// \param elemsPerInstr operand tile shape consumed by one WMMA instruction
/// \param warpId id component of the 2d warp grid along the non-K axis
/// \param laneId lane id in the warp [0..63]
/// \param numOfElems number of elements accessed by a thread per repetition
/// \param reps number of instruction repetitions to fully cover the dot operand
/// \param smemStrides strides in the LDS tensor
/// \param loadVecSize number of elements loaded by one operation
/// \param iNonKDim non-K dimension of the dot operand
/// \returns a vector (the i-th element corresponds to the i-th load
/// instruction) of 2-element vectors (tensor row and col).
llvm::SmallVector<llvm::SmallVector<Value>>
computeTensorElemMappingInBlockWmma1(
ConversionPatternRewriter &rewriter, Location loc,
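As a usage note (not part of this patch), the returned mapping is meant to be read as: entry i holds the (row, col) pair for the i-th load instruction, which a caller can fold into a linear LDS offset using the tensor strides. The sketch below uses plain integers in place of the MLIR Values the real code produces, and the names are hypothetical:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

std::vector<int64_t>
linearizeLoadOffsets(const std::vector<std::pair<int64_t, int64_t>> &rowCol,
                     int64_t rowStride, int64_t colStride) {
  std::vector<int64_t> offsets;
  offsets.reserve(rowCol.size());
  for (auto [row, col] : rowCol) // one entry per load instruction
    offsets.push_back(row * rowStride + col * colStride);
  return offsets;
}
```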
@@ -43,9 +43,7 @@ struct DecomposeUnsupportedAMDConversions

triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, isShortcut);

/* -------------------------------- */
// Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op`
/* -------------------------------- */
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
OpBuilder builder(cvtOp);
auto srcType = cvtOp.getSrc().getType();