Support for tt.dot_scaled operator #2804

Open · wants to merge 10 commits into base: main
34 changes: 27 additions & 7 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -1,3 +1,4 @@
#include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
@@ -114,17 +115,36 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
} else {
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
auto newVEncoding = DotOperandEncodingAttr::get(
ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
oldEncoding.getKWidth() * 2);
// Figure out the K dimension for input A/B: the return type is the
// upcasted A/B type, so we need to update the corresponding dim size.

const int opIdx = oldEncoding.getOpIdx();
const bool hasBatch = xShape.size() == 3;
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
newShape[kIdx] *= 2;
retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx),
newVEncoding);
Type elemType = FloatType::getBF16(ctx);

// Note: for Intel, the dot operand layout's kWidth parameter must
// match the parent DPAS layout's opsPerChannel, so we need to
// materialize a new DPAS layout.
Attribute newVEncoding;
if (auto dpasEncoding =
dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
auto newDpasEncoding = intel::DpasEncodingAttr::get(
ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
dpasEncoding.getExecutionSize(),
intel::DpasEncodingAttr::getOpsPerChannel(elemType),
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
dpasEncoding.getSubGroupSize());
newVEncoding = DotOperandEncodingAttr::get(
ctx, oldEncoding.getOpIdx(), newDpasEncoding,
newDpasEncoding.getOpsPerChannel());
} else {
// Figure out the K dimension for input A/B: the return type is the
// upcasted A/B type, so we need to update the corresponding dim size.
newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
oldEncoding.getParent(),
oldEncoding.getKWidth() * 2);
}
retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
}
inferredReturnTypes.push_back(retTy);
} else {
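The K-dimension bookkeeping in the non-Intel branch above can be illustrated with a small standalone sketch (a hypothetical helper, not code from this PR): an e2m1 operand packs two fp4 values per i8, so upcasting to bf16 doubles the K extent.

    #include <cstdint>
    #include <vector>

    // Hedged sketch of the shape logic in UpcastMXFPOp::inferReturnTypes.
    // For opIdx 0 (A: [batch,] M x K) the K dim is the last one; for
    // opIdx 1 (B: [batch,] K x N) it is the one before N.
    std::vector<int64_t> upcastedShape(std::vector<int64_t> shape, int opIdx) {
      const bool hasBatch = shape.size() == 3;
      const int kIdx = (opIdx == 0 ? 1 : 0) + (hasBatch ? 1 : 0);
      shape[kIdx] *= 2; // two packed fp4 values become two bf16 values
      return shape;
    }

For A = tensor<128x32xi8> (opIdx 0, no batch) this yields 128x64, matching the tensor<128x64xbf16> upcast result in the test below.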
29 changes: 29 additions & 0 deletions test/TritonIntelGPU/accelerate-matmul-pvc.mlir
@@ -201,3 +201,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>

module attributes {"triton_gpu.target" = "xpu", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32, "triton_intel_gpu.min_sg_size" = 16 : i32, "triton_intel_gpu.support_dpas"} {
// CHECK: [[BLOCKED:#.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
// CHECK: [[BLOCKED1:#.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK: [[BLOCKED2:#.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: [[DPAS:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
// CHECK: [[DPAS1:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 32], B = [32, 16], C = [8, 16]}>
// CHECK: dot_scaled
tt.func @dot_scaled(%a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> {
// CHECK: [[CST:%.*]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32, [[BLOCKED2]]>
// CHECK: [[C:%.*]] = triton_gpu.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]>
// CHECK: [[CVT_ARG0:%.*]] = triton_gpu.convert_layout %arg0 : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
// CHECK: [[CVT_ARG1:%.*]] = triton_gpu.convert_layout %arg1 : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED1]]>
// CHECK: [[A:%.*]] = triton_gpu.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>, tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
// CHECK: [[B:%.*]] = triton_gpu.convert_layout %arg2 : tensor<64x128xbf16, [[BLOCKED2]]> -> tensor<64x128xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
// CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x128xf32, [[DPAS]]>
// CHECK: [[RES:%.*]] = triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, [[DPAS]]> -> tensor<128x128xf32, [[BLOCKED2]]>

%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
%result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
tt.return %result : tensor<128x128xf32, #blocked>
}
}
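Note how the CHECK lines encode the layout contract from Ops.cpp above: the i8 (e2m1) operand carries a dot operand layout whose kWidth = 4 matches the opsPerChan = 4 of the materialized [[DPAS1]], while the upcast bf16 result and the other operand use [[DPAS]] with opsPerChan = 2 and kWidth = 2.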
@@ -3,6 +3,10 @@

#include "triton/Dialect/TritonGPU/IR/Attributes.h"

namespace mlir {
class ModuleOp;
}

#define GET_ATTRDEF_CLASSES
#include "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.h.inc"

Expand Down
@@ -74,7 +74,6 @@ along the row (resp. col) dimension.
);

let extraClassDeclaration = extraDistributedDeclaration # [{

SmallVector<unsigned> getDPASInstShapeA() const;
SmallVector<unsigned> getDPASInstShapeB() const;
SmallVector<unsigned> getDPASInstShapeC() const;
@@ -91,7 +90,30 @@ along the row (resp. col) dimension.
return true;
}

SmallVector<unsigned> getContigPerThread();
SmallVector<unsigned> getContigPerThread() const;

struct DPASCapability {
DPASCapability(unsigned minSGSize) : executionSize(minSGSize) {}
DPASCapability() = default;

bool isPVC() const {
return executionSize == 16;
}
bool isFalconShore() const {
return executionSize == 16;
}
bool isATSM() const {
return executionSize == 8;
}

static constexpr unsigned systolicDepth = 8u;
static constexpr unsigned repeatCount = 8u;
static constexpr unsigned opsChanBitWidths = 32u;
unsigned executionSize = 0u;
};

static DPASCapability getDPASCapability(mlir::ModuleOp mod);
static unsigned getOpsPerChannel(Type elemType);
}];

let hasCustomAssemblyFormat = 1;
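A minimal usage sketch of the new declarations (illustrative only, not code from this PR; assumes a valid `mod` and `elemType` in scope):

    // Query the module-level DPAS capability and derive opsPerChannel
    // for an element type.
    DpasEncodingAttr::DPASCapability caps =
        DpasEncodingAttr::getDPASCapability(mod);
    if (caps.isPVC()) {
      // executionSize == 16; systolicDepth and repeatCount are fixed at 8.
    }
    unsigned opsPerChan = DpasEncodingAttr::getOpsPerChannel(elemType);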
47 changes: 33 additions & 14 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
@@ -357,7 +357,7 @@ SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
return elemsPerThread;
};

SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() const {
size_t rank = getWarpsPerCTA().size();
assert(rank == 2 || rank == 3);
SmallVector<unsigned> contigPerThread(rank, 1);
@@ -381,6 +381,32 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
"be smaller than the threads required per row.");
}

DpasEncodingAttr::DPASCapability
DpasEncodingAttr::getDPASCapability(ModuleOp mod) {
assert(mod && "expected a valid module");
if (!mod->hasAttrOfType<IntegerAttr>(
triton::gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName()))
return DPASCapability();

unsigned minSGSize =
mod->getAttrOfType<IntegerAttr>(
triton::gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())
.getInt();
assert((minSGSize == 8 || minSGSize == 16) && "unsupported minSGSize");
return DPASCapability(minSGSize);
Comment on lines +387 to +396 (Contributor review suggestion):

Suggested change:

    if (auto minSGSizeAttr = mod->getAttrOfType<IntegerAttr>(
            triton::gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())) {
      unsigned minSGSize = minSGSizeAttr.getInt();
      assert((minSGSize == 8 || minSGSize == 16) && "unsupported minSGSize");
      return DPASCapability(minSGSize);
    }
    return DPASCapability();
}

unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) {
if (!elemType.isIntOrFloat())
llvm::report_fatal_error("unsupported type for DpasEncodingAttr");

unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth();
if (elemType.isFloat8E5M2() || elemType.isFloat8E4M3FN())
dpasElemBitWidths *= 2; // We are upcasting FP8 to FP16.

return DPASCapability::opsChanBitWidths / dpasElemBitWidths;
}
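// Worked example (illustration, not part of the PR): with
// DPASCapability::opsChanBitWidths = 32,
//   bf16 / f16 (16 bits) -> 32 / 16 = 2 ops per channel,
//   i8 (8 bits)          -> 32 / 8  = 4,
//   fp8 e5m2 / e4m3fn    -> widened to 16 bits, so 32 / 16 = 2.
// This matches the test above: opsPerChan = 4 for the i8 operand's DPAS
// layout, and opsPerChan = 2 once the data is upcast to bf16.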

LogicalResult DpasEncodingAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
unsigned repeatCount, unsigned systolicDepth, unsigned executionSize,
@@ -469,18 +495,14 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
llvm::ArrayRef<unsigned> rC = shapeC;
auto warpsPerCTA = getWarpsPerCTA();
auto repCluster = getRepCluster();
printer << "<{"
<< "repeatCount = " << getRepeatCount() << ", "
printer << "<{" << "repeatCount = " << getRepeatCount() << ", "
<< "systolicDepth = " << getSystolicDepth() << ", "
<< "executionSize = " << getExecutionSize() << ", "
<< "opsPerChan = " << getOpsPerChannel() << ", "
<< "threadsPerWarp = " << getSubGroupSize() << ", "
<< "warpsPerCTA = [" << llvm::ArrayRef<unsigned>(warpsPerCTA) << "], "
<< "repCluster = [" << repCluster << "], "
<< "A = [" << rA << "], "
<< "B = [" << rB << "], "
<< "C = [" << rC << "]"
<< "}>";
<< "repCluster = [" << repCluster << "], " << "A = [" << rA << "], "
<< "B = [" << rB << "], " << "C = [" << rC << "]" << "}>";
}

std::optional<LinearLayout>
@@ -553,13 +575,10 @@ Attribute WarpEncodingAttr::parse(AsmParser &parser, Type type) {
void WarpEncodingAttr::print(mlir::AsmPrinter &printer) const {
auto threadsPerWarp = getThreadsPerWarp();
auto sizePerThread = getSizePerThread();
printer << "<{"
<< "sizePerThread = [" << llvm::ArrayRef<unsigned>(sizePerThread)
<< "]"
printer << "<{" << "sizePerThread = ["
<< llvm::ArrayRef<unsigned>(sizePerThread) << "]"
<< ", threadsPerWarp = [" << llvm::ArrayRef<unsigned>(threadsPerWarp)
<< "]"
<< ", order = [" << getOrder() << "]"
<< "}>";
<< "]" << ", order = [" << getOrder() << "]" << "}>";
}

//===----------------------------------------------------------------------===//