[AMD] inThreadTranspose: Transpose between global load and local store for non-TN layouts: part 2 of 4 #5223

Draft · wants to merge 7 commits into main
1 change: 1 addition & 0 deletions bin/RegisterTritonDialects.h
@@ -67,6 +67,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUConvertToBufferOps();
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
mlir::registerTritonAMDGPUInThreadTranspose();

// TODO: register Triton & TritonGPU passes
registry
3 changes: 2 additions & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1145,7 +1145,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
Type elemLlvmTy, std::optional<int32_t> maxVecElems, Value shmemBase,
ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
const TargetInfoBase &target, bool crossGrain,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);

inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
@@ -1321,6 +1321,7 @@ void storeDistributedToShared(
triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
bool crossGrain = false,
std::pair<size_t, Type> *const llvmOpCount = nullptr);

inline Value getStructFromSharedMemoryObject(Location loc,
@@ -37,12 +37,18 @@ namespace mlir::triton::gpu {
// to compute the linear layout for MMAv3 (i.e. Hopper) shared layouts (i.e.
// shared layouts with hasLeadingOffset == true) but is otherwise unused.
//
// inThreadTranspose is a flag indicating whether the transpose should be
// performed while the data still resides in thread-local registers. It is set
// to true on the AMD platform when a non-KContig matrix is about to be written
// to LDS (shared memory) and is otherwise unused. More details are provided in
// the transpose2D() function in LinearLayoutConversions.cpp.
// Returns std::nullopt if the given layout can't be converted to an LL.
// TODO(jlebar): Remove the std::optional once all layouts are supported.
//
std::optional<LinearLayout>
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth = std::nullopt);
std::optional<int32_t> elemBitWidth = std::nullopt,
bool inThreadTranspose = false);

// Given a linear layout where the input dimensions contain a "block" dimension,
// this method sets the "block" dimension to 0 and removes the corresponding
37 changes: 20 additions & 17 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -282,29 +282,32 @@ compared to 1*64 when the hasLeadingOffset is false.
if (needTrans)
kDimNum = 1 - kDimNum;
bool isKDimInner = (order[0] == kDimNum);
if (isKDimInner) {
const int numBanks = 32;
const int bankBitWidth = 32;
const int SIMDWidth = 16;
const int numBanks = 32;
const int bankBitWidth = 32;
const int SIMDWidth = 16;

// number of inner dimension rows per one pattern repeat
int innerDimLength = shape[order[0]];
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
// number of inner dimension rows per one pattern repeat
unsigned innerDim = isKDimInner ? order[0] : order[1];
int innerDimLength = shape[innerDim];
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;

int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
// vecSize is set to kWidth of the dotop layout
int vecSize = dotOpEnc.getKWidth();
int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
// vecSize is set to kWidth of the dotop layout
int vecSize = dotOpEnc.getKWidth();
int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize);

// TODO (zhanglx): figure out better parameters for mfma4
if (mfmaEnc.getMDim() == 4)
maxPhase = 4;
// TODO (zhanglx): figure out better parameters for mfma4
if (mfmaEnc.getMDim() == 4)
maxPhase = 4;

if (isKDimInner) {
return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
} else {
// Do not swizzle in case k dimension is not innermost.
// In this case accesses will go in different banks even without swizzling.
return get(context, 1, 1, 1, order, CTALayout);
// Swap the order because the blocked layout is non-KContig but in LDS it will be KContig.
SmallVector<unsigned int> newOrder(order);
std::swap(newOrder[0], newOrder[1]);
// TODO: set inThreadTranspose to true since we want to use special swizzling
return $_get(context, vecSize, perPhase, maxPhase, newOrder, CTALayout, false);
}
}

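A minimal standalone sketch (plain C++17, no MLIR) of the swizzle-parameter arithmetic in the hunk above. The inputs are illustrative assumptions, not values taken from this PR: an fp16 operand (typeWidthInBit = 16), kWidth = 4, and a few inner-dimension lengths.

#include <algorithm>
#include <cstdio>

int main() {
  const int numBanks = 32, bankBitWidth = 32, SIMDWidth = 16;
  const int typeWidthInBit = 16; // assumed fp16 element type
  const int vecSize = 4;         // assumed dotOpEnc.getKWidth()

  for (int innerDimLength : {32, 64, 128}) {
    int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; // 64
    int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
    int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
    std::printf("innerDimLength=%3d -> perPhase=%d, vecSize=%d, maxPhase=%d\n",
                innerDimLength, perPhase, vecSize, maxPhase);
  }
  return 0;
}

For example, innerDimLength = 32 yields perPhase = 2 and maxPhase = 8, while 64 and 128 both yield perPhase = 1 and maxPhase = 16.
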
12 changes: 11 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -23,6 +23,16 @@ void lowerDistributedToShared(
auto srcTy = cast<RankedTensorType>(src.getType());
auto dstTy = cast<MemDescType>(dst.getType());
auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
bool crossGrain = false;
// Only set crossGrain for blocked->shared conversions. This is not a problem
// on the NV path because, for non-KContig tensors, the blocked and shared
// layouts still have the same order.
if (auto blocked = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding())) {
auto rank = blocked.getOrder().size();
auto inOrd = blocked.getOrder();
// The tensor must be 2D and the blocked and shared orders must differ.
crossGrain = (rank == 2) && (inOrd[0] != outOrd[0]);
}
assert(srcTy.getShape().size() <= 2 ||
(srcTy.getShape().size() == 3 && outOrd[2] == 0) &&
"Unexpected rank of ConvertLayout(blocked->shared)");
@@ -32,7 +42,7 @@
auto dstStrides = smemObj.getStrides();
auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
loc, rewriter, targetInfo, llvmOpCount);
loc, rewriter, targetInfo, crossGrain, llvmOpCount);
}

struct GlobalScratchAllocOpConversion
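A minimal standalone sketch (plain C++17) of the crossGrain check in lowerDistributedToShared above: it fires only for 2-D blocked->shared conversions whose fastest dimensions disagree. The concrete orders below are illustrative assumptions; the in-thread-transpose path produces such a mismatch because the shared encoding's order is swapped to be K-contiguous while the blocked order follows the global-load direction.

#include <array>
#include <cstdio>

// Rank is fixed at 2 here; the real code also verifies the rank.
static bool isCrossGrain(const std::array<unsigned, 2> &blockedOrder,
                         const std::array<unsigned, 2> &sharedOrder) {
  return blockedOrder[0] != sharedOrder[0];
}

int main() {
  // Assumed non-KContig operand: blocked order [1, 0] (fastest dim follows the
  // global load), shared order swapped to [0, 1] (K-contiguous in LDS).
  std::printf("non-KContig case: %d\n", isCrossGrain({1, 0}, {0, 1})); // 1
  // Assumed KContig operand: both orders agree, so the plain callback is used.
  std::printf("KContig case:     %d\n", isCrossGrain({1, 0}, {1, 0})); // 0
  return 0;
}
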
92 changes: 70 additions & 22 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -162,7 +162,7 @@ bool emitTransferBetweenRegistersAndShared(
RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
Type elemLlvmTy, std::optional<int32_t> maxVecElems, Value shmemBase,
ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
const TargetInfoBase &target, bool crossGrain,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
MLIRContext *ctx = rewriter.getContext();

@@ -174,8 +174,12 @@
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");

std::optional<LinearLayout> regLayout =
triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
std::optional<LinearLayout> regLayout = LinearLayout::empty();
auto regEncoding = registerTy.getEncoding();
// Setting elemBitWidth to std::nullopt is fine because it is only used for
// shared layouts.
regLayout =
triton::gpu::toLinearLayout(shape, regEncoding, std::nullopt, crossGrain);
std::optional<LinearLayout> sharedLayout = triton::gpu::toLinearLayout(
shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
if (!regLayout.has_value() || !sharedLayout.has_value()) {
@@ -280,7 +284,7 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
SmallVector<Value> ret;
bool success = emitTransferBetweenRegistersAndShared(
dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.getBase(),
smemObj.getStrides(), loc, rewriter, target,
smemObj.getStrides(), loc, rewriter, target, /*crossGrain = */ false,
[&](VectorType vecTy, Value vecAddr) {
auto vecVal = load(vecTy, vecAddr);
vecVal.setAlignment(vecTy.getNumElements() *
@@ -301,26 +305,70 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy,
ArrayRef<Value> srcVals, Value smemBase,
ArrayRef<Value> dstStrides, Location loc,
RewriterBase &rewriter,
const TargetInfoBase &target,
const TargetInfoBase &target, bool crossGrain,
std::pair<size_t, Type> *const llvmOpCount) {
bool success = emitTransferBetweenRegistersAndShared(
bool success;
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback;
if (!crossGrain) {
// Callback for every case except the non-KContig dotOperand blocked->shared
// path on the AMD platform.
perVectorCallback = [&](VectorType vecTy, Value vecAddr) {
ArrayRef<Value> vals = srcVals.take_front(vecTy.getNumElements());
srcVals = srcVals.drop_front(vecTy.getNumElements());

Value vec = undef(vecTy);
for (int i = 0; i < vals.size(); i++) {
vec = insert_element(vec, vals[i], i32_val(i));
}
store(vec, vecAddr)
.setAlignment(vecTy.getNumElements() *
elemLlvmTy.getIntOrFloatBitWidth() / 8);
if (llvmOpCount) {
++(llvmOpCount->first);
llvmOpCount->second = vecTy;
}
};
} else {
// This branch is only for inThreadTranspose on the AMD path, where we want to
// transpose during the blocked->shared transfer.
// For example, a thread's registers hold a [4, 8] tile of the matrix that is
// contiguous along the dim of 8. We want perVectorCallback to access the
// column of 4 elements, 8 times, instead of the row of 8 elements, 4 times,
// as the callback above does. For this example, the variables accessed or
// derived below are:
// sizePerThread: [4, 8]
// order: [1, 0]
// numElemsPerIter: 4 x 8 = 32
// colIndex: initialized as 0, increment to 8 every time callback is called
// innerVecSize: 8, since it is the vector size of inner dimension
auto blockedEncoding = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
auto sizePerThread = blockedEncoding.getSizePerThread();
auto order = blockedEncoding.getOrder();
unsigned int numElemsPerIter = product<unsigned>(sizePerThread);
unsigned int colIndex = 0;
unsigned int innerVecSize = sizePerThread[order[0]];
perVectorCallback = [&](VectorType vecTy, Value vecAddr) {
Value vec = undef(vecTy);
auto startPos = colIndex / innerVecSize *
numElemsPerIter + // start pos of different iter
colIndex % innerVecSize; // start pos of single iter
for (int i = 0; i < vecTy.getNumElements(); i++) {
auto idx = startPos + i * innerVecSize; // iterate within a vector
vec = insert_element(vec, srcVals[idx], i32_val(i));
}
colIndex++;
store(vec, vecAddr)
.setAlignment(vecTy.getNumElements() *
elemLlvmTy.getIntOrFloatBitWidth() / 8);
if (llvmOpCount) {
++(llvmOpCount->first);
llvmOpCount->second = vecTy;
}
};
}
success = emitTransferBetweenRegistersAndShared(
srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
ArrayRef<Value> vals = srcVals.take_front(vecTy.getNumElements());
srcVals = srcVals.drop_front(vecTy.getNumElements());

Value vec = undef(vecTy);
for (int i = 0; i < vals.size(); i++) {
vec = insert_element(vec, vals[i], i32_val(i));
}
store(vec, vecAddr)
.setAlignment(vecTy.getNumElements() *
elemLlvmTy.getIntOrFloatBitWidth() / 8);
if (llvmOpCount) {
++(llvmOpCount->first);
llvmOpCount->second = vecTy;
}
});
dstStrides, loc, rewriter, target, crossGrain, perVectorCallback);

if (!success)
llvm::report_fatal_error("Failed to emit transfer from register to shared");
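A minimal standalone sketch (plain C++17, no MLIR) of the crossGrain index arithmetic above for sizePerThread = [4, 8] and order = [1, 0]: each callback gathers a column of 4 register values instead of a contiguous row of 8. It assumes, for illustration only, that the thread holds two repetitions of the [4, 8] tile (64 values); the names are not part of the Triton API.

#include <cstdio>

int main() {
  const unsigned sizePerThread[2] = {4, 8};
  const unsigned order[2] = {1, 0};
  const unsigned numElemsPerIter = sizePerThread[0] * sizePerThread[1]; // 32
  const unsigned innerVecSize = sizePerThread[order[0]];                // 8
  const unsigned vecNumElements = sizePerThread[order[1]];              // 4

  unsigned colIndex = 0;
  // Two repetitions of 32 registers each, 8 callback invocations per repeat.
  for (unsigned call = 0; call < 16; ++call) {
    unsigned startPos = colIndex / innerVecSize * numElemsPerIter + // repeat base
                        colIndex % innerVecSize;                    // column offset
    std::printf("callback %2u gathers srcVals[", call);
    for (unsigned i = 0; i < vecNumElements; ++i) {
      unsigned idx = startPos + i * innerVecSize; // walk down one K column
      std::printf(i ? ", %u" : "%u", idx);
    }
    std::printf("]\n");
    ++colIndex;
  }
  return 0;
}

The first callback gathers srcVals[0, 8, 16, 24], the second srcVals[1, 9, 17, 25], and so on, so consecutive stores write K-contiguous data.
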
99 changes: 98 additions & 1 deletion lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -43,6 +43,60 @@ SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
return ret;
}

// For sizePerThread = [4, 8], the regular linear layout will express it as
// the following
// - register=1 -> (1, 0)
// register=2 -> (2, 0)
// register=4 -> (4, 0)
// register=8 -> (0, 1)
// register=16 -> (0, 2)
// where out dims are: [dim1 (size 8), dim0 (size 4)]
// Viewed in binary form, this is an identity matrix. If we instead traverse
// starting from the dim of 4, it becomes the following
// - register=1 -> (0, 1)
// register=2 -> (0, 2)
// register=4 -> (1, 0)
// register=8 -> (2, 0)
// register=16 -> (4, 0)
// where out dims are: [dim1 (size 8), dim0 (size 4)]
// Only the register layout generation changes: the register layout is created
// by the newly introduced transpose2D, while the lane and warp layouts still
// come from identityStandardND.
// Note that simply reversing the for-loop in identityStandardND will not work,
// because that would change the most-minor dimension from dim1 to dim0 while
// still keeping the matrix an identity.
LinearLayout transpose2D(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) {
assert(shape.size() == order.size());
assert((order.size() == 2) && "only support dim of 2 now");

MLIRContext *ctx = inDimName.getContext();
StringAttr kRegister = S("register");

std::vector<std::vector<int32_t>> bases;
// traverse 2nd dimension (K-dim in GEMM case)
int dim = order[1];
for (int basis = 1; basis < shape[dim]; basis <<= 1) {
bases.push_back({0, basis});
}
// traverse 1st dimension (N-dim for a non-KContig B tensor in GEMM);
// this is the dimension that is contiguous when loaded from global memory
dim = order[0];
for (int basis = 1; basis < shape[dim]; basis <<= 1) {
bases.push_back({basis, 0});
}

auto dimMinor = "dim" + std::to_string(order[0]);
auto dimMajor = "dim" + std::to_string(order[1]);
StringAttr kDimMinor = S(dimMinor);
StringAttr kDimMajor = S(dimMajor);
auto ret = LinearLayout(
{{kRegister, bases}},
{{kDimMinor, shape[order[0]]}, {kDimMajor, shape[order[1]]}}, false);

return ret;
}
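
A minimal standalone sketch (plain C++17, no MLIR) that rebuilds the basis list exactly as transpose2D does for the sizePerThread = [4, 8], order = [1, 0] example in the comment above, and prints the resulting register-to-coordinate mapping. The names are illustrative and not part of the Triton API; linear layouts combine bases with XOR (they are linear over GF(2)).

#include <cstdio>
#include <utility>
#include <vector>

int main() {
  const unsigned shape[2] = {4, 8}; // sizePerThread
  const unsigned order[2] = {1, 0}; // dim1 is fastest

  // Bases stored as (dim1, dim0) pairs, matching transpose2D's out-dim order.
  std::vector<std::pair<int, int>> bases;
  // First traverse order[1] (dim0, the K dim of size 4).
  for (int b = 1; b < (int)shape[order[1]]; b <<= 1)
    bases.push_back({0, b});
  // Then traverse order[0] (dim1, size 8, contiguous in global memory).
  for (int b = 1; b < (int)shape[order[0]]; b <<= 1)
    bases.push_back({b, 0});

  for (int reg = 0; reg < 32; ++reg) {
    int d1 = 0, d0 = 0;
    for (size_t i = 0; i < bases.size(); ++i)
      if (reg & (1 << i)) {
        d1 ^= bases[i].first;
        d0 ^= bases[i].second;
      }
    std::printf("register %2d -> (dim1=%d, dim0=%d)\n", reg, d1, d0);
  }
  // Registers 0..3 now walk down dim0 (the K column), which is what lets them
  // be packed into a single vector for the LDS write.
  return 0;
}
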

// Make a LinearLayout that maps a block-id to an N-dimensional index.
//
// The tensor is split up into CTAsPerCGA pieces, which are distributed among
@@ -239,6 +293,45 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape);
}

// This function converts a BlockedEncodingAttr to a linear layout in a special
// way. It accompanies the AMDGPUInThreadTranspose pass, which transposes a
// non-KContig tensor into KContig form prior to writing it into LDS (shared
// memory). This conversion treats sizePerThread as a 2D matrix and uses a
// different access pattern.
//
// For example, consider the following blocked layout generated by
// AMDGPUInThreadTranspose: #blocked1 = #ttg.blocked<{sizePerThread =
// [4, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>.
// Here, since sizePerThread is 2D, there are two ways to traverse it: along
// the dim of 8 or the dim of 4. The regular toLinearLayout() walks it in the
// leading order, i.e. the dim of 8, but since we want to transpose in-thread,
// we iterate along the 2nd order, i.e. the dim of 4, so that the 4 elements
// can be packed into a single vector; the AMD backend LLVM compiler then packs
// them into consecutive VGPRs and writes data contiguous in the K dimension
// into LDS. This guarantees vectorized ds_read, and ds_write can be vectorized
// to 64-bit or 32-bit depending on the block size and number of warps.
//
// The function is named ThreadRake because each thread rakes through multiple
// rows at the same time, as opposed to each warp raking through a cluster of
// rows, or the usual Triton way, which iterates through every available warp
// and then tiles the pattern over the entire block.
LinearLayout blockedToLinearLayoutThreadRake(ArrayRef<int64_t> shape,
BlockedEncodingAttr blocked) {
MLIRContext *ctx = blocked.getContext();
int rank = shape.size();
auto outDimNames = standardOutDimNames(ctx, rank);
const auto &order = blocked.getOrder();
auto sizePerThread = blocked.getSizePerThread();

auto ctaLayout =
transpose2D(S("register"), sizePerThread, order) *
identityStandardND(S("lane"), blocked.getThreadsPerWarp(), order) *
identityStandardND(S("warp"), blocked.getWarpsPerCTA(), order);

return combineCtaCgaWithShape(ctaLayout, blocked.getCTALayout(), shape);
}

} // anonymous namespace

std::optional<LinearLayout>
@@ -755,9 +848,13 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

std::optional<LinearLayout>
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth /*= std::nullopt*/) {
std::optional<int32_t> elemBitWidth /*= std::nullopt*/,
bool inThreadTranspose /*= false*/) {
// Layouts are distributed or shared
if (auto distributed = dyn_cast<DistributedEncodingTrait>(layout)) {
auto blocked = dyn_cast<BlockedEncodingAttr>(distributed);
if (blocked && inThreadTranspose)
return blockedToLinearLayoutThreadRake(shape, blocked);
return distributed.toLinearLayout(shape);
} else if (auto shared = dyn_cast<SharedEncodingAttr>(layout)) {
if (shared.getHasLeadingOffset()) {
6 changes: 3 additions & 3 deletions test/TritonGPU/amd/amd-instruction-sched.mlir
@@ -50,7 +50,7 @@ module {
// INSTR_COUNT_NS1-SAME: isBufferLoadsAEnabled = false
// INSTR_COUNT_NS1-SAME: isBufferLoadsBEnabled = false
// INSTR_COUNT_NS1-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>>
// INSTR_COUNT_NS1-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>>
// INSTR_COUNT_NS1-SAME: numDsReadsB = #amdgpu.InstCounter<8, vector<4xf16>>
// INSTR_COUNT_NS1-SAME: numDsWritesA = #amdgpu.InstCounter<0, none>
// INSTR_COUNT_NS1-SAME: numDsWritesB = #amdgpu.InstCounter<0, none>
// INSTR_COUNT_NS1-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>>
@@ -61,9 +61,9 @@
// INSTR_COUNT_NS2-SAME: isBufferLoadsAEnabled = false
// INSTR_COUNT_NS2-SAME: isBufferLoadsBEnabled = false
// INSTR_COUNT_NS2-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>>
// INSTR_COUNT_NS2-SAME: numDsReadsB = #amdgpu.InstCounter<8, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numDsWritesA = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numDsWritesB = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numDsWritesB = #amdgpu.InstCounter<16, vector<1xf16>>
// INSTR_COUNT_NS2-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>>
Expand Down
Loading
Loading