From d9bb47df64d150e5fa89516164b0645a5c771faa Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Sat, 14 Dec 2024 00:59:27 -0800 Subject: [PATCH] optimize consumer release --- .../TritonGPU/Transforms/WSLowering.cpp | 44 ++----------------- 1 file changed, 4 insertions(+), 40 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp b/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp index c2bf31fc5..adc009954 100644 --- a/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp @@ -149,44 +149,8 @@ void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op, void processConsumerReleaseOp(OpBuilder &builder, ttng::ConsumerReleaseOp op, Value bufferEmpty, int numCTAs) { auto loc = op.getLoc(); - Value _0 = builder.create(loc, 0, 32); - Value _4 = builder.create(loc, 4, 32); - Value _8 = builder.create(loc, 8, 32); - Value _32 = builder.create(loc, 32, 32); - Value _threadPerTask = - builder.create(loc, THREADS_PER_TASK, 32); - - // threadId = threadId % THREADS_PER_TASK - Value threadId = builder.create( - loc, createThreadIdOp(builder, loc), _threadPerTask); - // k = threadId / 8 - Value k = builder.create(loc, threadId, _8); - // row = k / 4 - Value row = builder.create(loc, k, _4); - // col = k % 4 - Value col = builder.create(loc, k, _4); - // remoteCTAId = (col ^ row) * 4 + col - Value remoteCTAId = builder.create( - loc, - Value{builder.create( - loc, Value{builder.create(loc, col, row)}, _4)}, - col); - - // pred0 = threadId % 8 == 0 - Value pred0 = builder.create( - loc, arith::CmpIPredicate::eq, - builder.create(loc, threadId, _8), _0); - // pred1 = remoteCTAId < numCTAs - Value pred1 = builder.create( - loc, arith::CmpIPredicate::ult, remoteCTAId, - builder.create(loc, numCTAs, 32)); - - // pred = pred0 & pred1 - Value pred = builder.create(loc, pred0, pred1); - // bufferEmpty arrive - auto arriveOp = builder.create(loc, bufferEmpty, pred, - remoteCTAId, false, 0); - + auto arriveOp = builder.create( + loc, bufferEmpty, nullptr, nullptr, false, 0); assert(op.getOperation()->hasAttr("async_task_id")); setAsyncTaskIds(arriveOp, getAsyncTaskIds(op.getOperation())); } @@ -230,8 +194,8 @@ void lowerTokenOperations(Operation *parentOp, int numCTAs, Value barrierEmptyView = builder.create( loc, singleBarrierMemDescType, bufferEmptyArray, idx); - unsigned bufferEmptyCount = numCTAs; - builder.create(loc, barrierEmptyView, numCTAs); + builder.create(loc, barrierEmptyView, + THREADS_PER_TASK); } if (numCTAs == 1) {