Skip to content

Commit

Permalink
optimize consumer release
Browse files Browse the repository at this point in the history
  • Loading branch information
htyu committed Dec 16, 2024
1 parent ab5f7c2 commit d9bb47d
Showing 1 changed file with 4 additions and 40 deletions.
44 changes: 4 additions & 40 deletions lib/Dialect/TritonGPU/Transforms/WSLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,44 +149,8 @@ void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op,
/// Lowers a ttng::ConsumerReleaseOp to a ttng::MBarrierArriveOp issued on the
/// given "buffer empty" barrier.
///
/// The arrive is created without a predicate and without a remote CTA id
/// (both operands are passed as null), so no per-thread selection or CTA
/// routing arithmetic is emitted here.
///
/// NOTE(review): with no predicate, presumably every thread that executes the
/// lowered op arrives on the barrier, so the barrier's expected arrival count
/// has to be initialized accordingly (the same change initializes the empty
/// barrier with THREADS_PER_TASK) -- confirm against the barrier init site.
///
/// \param builder     Builder used to create the arrive op.
/// \param op          The consumer-release op being lowered. It supplies the
///                    location and must carry the "async_task_id" attribute,
///                    which is copied onto the new arrive op.
/// \param bufferEmpty Barrier value the arrive is issued on.
/// \param numCTAs     Not used by this lowering; kept so the signature stays
///                    compatible with existing callers.
void processConsumerReleaseOp(OpBuilder &builder, ttng::ConsumerReleaseOp op,
                              Value bufferEmpty, int numCTAs) {
  auto loc = op.getLoc();

  // bufferEmpty arrive: unpredicated and not targeted at a remote CTA.
  // The trailing `false, 0` arguments are forwarded unchanged.
  auto arriveOp = builder.create<ttng::MBarrierArriveOp>(
      loc, bufferEmpty,
      nullptr, // no predicate
      nullptr, // no remote CTA id
      false, 0);

  // Propagate the async task ids from the op being lowered to its
  // replacement.
  assert(op.getOperation()->hasAttr("async_task_id"));
  setAsyncTaskIds(arriveOp, getAsyncTaskIds(op.getOperation()));
}
Expand Down Expand Up @@ -230,8 +194,8 @@ void lowerTokenOperations(Operation *parentOp, int numCTAs,

Value barrierEmptyView = builder.create<ttg::MemDescSubviewOp>(
loc, singleBarrierMemDescType, bufferEmptyArray, idx);
unsigned bufferEmptyCount = numCTAs;
builder.create<ttng::InitBarrierOp>(loc, barrierEmptyView, numCTAs);
builder.create<ttng::InitBarrierOp>(loc, barrierEmptyView,
THREADS_PER_TASK);
}

if (numCTAs == 1) {
Expand Down

0 comments on commit d9bb47d

Please sign in to comment.