[Frontend][Backend] Implement support for scale_dot(-, bf16) (triton-lang#4996)

In passing, we also improve a few other things:
- `scaled_dot` now accepts both uint8/uint16 and fp8/bf16 tensors as inputs (previously
you had to cast inputs to uint8, which was awkward when extending the op to bf16).
- Add `scaled_dot` to the docs and improve the docs overall (they have not been
rendered yet; they might need a few further tweaks).
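
A minimal usage sketch of what the new path allows (not code from this commit; pointer arithmetic, block sizes, and the scale indexing are illustrative assumptions, and masking is omitted): the rhs of `tl.dot_scaled` can now be a plain bf16 tile, while the lhs stays in an mxfp format with e8m0 scales.

import triton
import triton.language as tl

@triton.jit
def scaled_dot_bf16_rhs(a_ptr, a_scale_ptr, b_ptr, out_ptr,
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # lhs: e4m3 fp8 elements; the pointer may be uint8 or an fp8 dtype (both accepted now).
    a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    # One e8m0 scale byte per 32 contiguous K elements of the lhs.
    offs_s = tl.arange(0, BLOCK_K // 32)
    a_scale = tl.load(a_scale_ptr + offs_m[:, None] * (BLOCK_K // 32) + offs_s[None, :])
    # rhs: a plain bf16 (or uint16) tile with no scale -- the case added by this commit.
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    acc = tl.dot_scaled(a, a_scale, "e4m3", b, None, "bf16")
    tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)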
lezcano authored Oct 30, 2024
1 parent 0591b37 commit 23c9ec1
Showing 14 changed files with 128 additions and 114 deletions.
1 change: 1 addition & 0 deletions docs/python-api/triton.language.rst
@@ -59,6 +59,7 @@ Linear Algebra Ops
:nosignatures:

dot
dot_scaled


Memory/Pointer Ops
9 changes: 5 additions & 4 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -119,15 +119,16 @@ def TT_InputPrecisionAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}

// Type for F8F6F4 kind of floats.
def TT_F8F6F4TypeAttr : I32EnumAttr<
"F8F6F4Type", "",
// Type for ScaleDotElemType kind of floats.
def TT_ScaleDotElemTypeAttr : I32EnumAttr<
"ScaleDotElemType", "",
[
I32EnumAttrCase<"E4M3", 0, "e4m3">,
I32EnumAttrCase<"E5M2", 1, "e5m2">,
I32EnumAttrCase<"E2M3", 2, "e2m3">,
I32EnumAttrCase<"E3M2", 3, "e3m2">,
I32EnumAttrCase<"E2M1", 4, "e2m1">
I32EnumAttrCase<"E2M1", 4, "e2m1">,
I32EnumAttrCase<"BF16", 5, "bf16">

]>{
let cppNamespace = "::mlir::triton";
16 changes: 8 additions & 8 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -685,15 +685,15 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,

let arguments = (
ins
// inputs are integer types as they are packed types and we currently
// don't have a representation for those.
TT_IntTensor:$lhs,
TT_IntTensor:$rhs,
// inputs are floats if we have a type for them, otherwise (fp4),
// they are packed in pairs in an I8Tensor
RankedTensorOf<[TT_Float,I8]>:$lhs,
RankedTensorOf<[TT_Float,I8]>:$rhs,
TT_FloatTensor:$c,
TT_IntTensor:$lhs_scale,
Optional<TT_IntTensor>:$rhs_scale,
TT_F8F6F4TypeAttr:$lhs_type,
TT_F8F6F4TypeAttr:$rhs_type
RankedTensorOf<[I8]>:$lhs_scale,
Optional<RankedTensorOf<[I8]>>:$rhs_scale,
TT_ScaleDotElemTypeAttr:$lhs_type,
TT_ScaleDotElemTypeAttr:$rhs_type
);

let results = (outs TT_FloatTensor:$d);
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -268,7 +268,7 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<In
let arguments = (ins
TT_Tensor:$src,
TT_Tensor:$scale,
TT_F8F6F4TypeAttr:$fp_type);
TT_ScaleDotElemTypeAttr:$fp_type);
let results = (outs TT_Tensor:$result);

let assemblyFormat = [{
8 changes: 4 additions & 4 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -34,13 +34,13 @@ LogicalResult UpcastMXFPOp::verify() {
"operands must have the same number of dimensions, at least 2");
}

if (!(fpType == F8F6F4Type::E2M1 || fpType == F8F6F4Type::E4M3 ||
fpType == F8F6F4Type::E5M2)) {
if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
fpType == ScaleDotElemType::E5M2)) {
return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
}

// Change to support fp8 types
const auto elems_packed = fpType == F8F6F4Type::E2M1 ? 2 : 1;
const auto elems_packed = fpType == ScaleDotElemType::E2M1 ? 2 : 1;

if (xShape.back() != (32 / elems_packed) * scaleShape.back()) {
return emitOpError("last dimension of first operand must be 16 times "
@@ -93,7 +93,7 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
return emitOptionalError(loc, "expected a dotOperand encoding");
}

if (typeEncoded == F8F6F4Type::E2M1) {
if (typeEncoded == ScaleDotElemType::E2M1) {
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
auto newVEncoding = DotOperandEncodingAttr::get(
ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
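To make the verifier's shape constraint concrete (an illustrative sketch, not code from the commit): each e8m0 scale covers 32 logical elements, and e2m1 packs two fp4 values per byte, so the packed last dimension must be 16x the scale's last dimension for e2m1 and 32x for the fp8 types.

# Illustrative mirror of the constraint in UpcastMXFPOp::verify (hypothetical shapes).
def check_mxfp_shapes(x_last_dim: int, scale_last_dim: int, fp_type: str) -> None:
    elems_packed = 2 if fp_type == "e2m1" else 1  # two fp4 values share one byte
    assert x_last_dim == (32 // elems_packed) * scale_last_dim

check_mxfp_shapes(x_last_dim=64, scale_last_dim=4, fp_type="e2m1")   # 64 bytes = 128 fp4 values = 4 groups of 32
check_mxfp_shapes(x_last_dim=128, scale_last_dim=4, fp_type="e4m3")  # 128 fp8 values = 4 groups of 32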
61 changes: 26 additions & 35 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -415,22 +415,12 @@ class ScaledBlockedToMMAv2
auto aType = dotOp.getLhsType();
auto bType = dotOp.getRhsType();

auto enumToType = [&rewriter](F8F6F4Type type) {
switch (type) {
case F8F6F4Type::E4M3:
return rewriter.getFloat8E4M3FNType();
case F8F6F4Type::E5M2:
return rewriter.getFloat8E5M2Type();
default:
llvm_unreachable("unexpected type");
}
};

assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2 ||
aType == F8F6F4Type::E2M1) &&
assert((aType == ScaleDotElemType::E4M3 ||
aType == ScaleDotElemType::E5M2 ||
aType == ScaleDotElemType::E2M1) &&
"NYI: lhs supports fp4 or fp8");
assert(bType == F8F6F4Type::E4M3 ||
bType == F8F6F4Type::E5M2 && "NYI: rhs supports fp8");
assert(bType == ScaleDotElemType::E4M3 || bType == ScaleDotElemType::E5M2 ||
bType == ScaleDotElemType::BF16 && "NYI: rhs supports fp8 and bf16");

// TODO run accelerate matmul on A and B first to choose their layouts
// Set return type
@@ -454,11 +444,12 @@ class ScaledBlockedToMMAv2
auto newAcc =
rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);

auto toMMABf16 = [&newRetType, &rewriter, &ctx, &enumToType](
TypedValue<RankedTensorType> v, int idx,
F8F6F4Type type) -> TypedValue<RankedTensorType> {
auto toMMABf16 =
[&newRetType, &rewriter,
&ctx](TypedValue<RankedTensorType> v, int idx,
ScaleDotElemType type) -> TypedValue<RankedTensorType> {
auto vType = v.getType();
if (type == F8F6F4Type::E2M1) {
if (type == ScaleDotElemType::E2M1) {
// A bit too dynamically typed...
// perhaps return ints in both cases?

@@ -469,23 +460,23 @@
vType.getShape(), vType.getElementType(), newVEncoding);
return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
} else {
assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3);
assert(type == ScaleDotElemType::E5M2 ||
type == ScaleDotElemType::E4M3 ||
type == ScaleDotElemType::BF16);
auto newVEncoding = DotOperandEncodingAttr::get(
ctx, idx, newRetType.getEncoding(), /*kWidth=*/8);
auto newVType = RankedTensorType::get(
vType.getShape(), vType.getElementType(), newVEncoding);
v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);

// Bitcast
auto vTypeFp8 = RankedTensorType::get(vType.getShape(),
enumToType(type), newVEncoding);
v = cast<TypedValue<RankedTensorType>>(
rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());

// Convert to bf16
auto vTypeBf16 = RankedTensorType::get(
vType.getShape(), rewriter.getBF16Type(), newVEncoding);
return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
if (type == ScaleDotElemType::BF16) {
return v;
} else {
// Convert to bf16
auto vTypeBf16 = RankedTensorType::get(
vType.getShape(), rewriter.getBF16Type(), newVEncoding);
return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
}
}
};
a = toMMABf16(a, 0, aType);
@@ -515,11 +506,11 @@ class ScaledBlockedToMMAv2
auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get(
ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, CTALayout);

auto newScaleType = RankedTensorType::get(scale.getType().getShape(),
scale.getType().getElementType(),
newScaleEncoding);
scale =
rewriter.create<ConvertLayoutOp>(scale.getLoc(), newScaleType, scale);
auto newScaleDotElemType = RankedTensorType::get(
scale.getType().getShape(), scale.getType().getElementType(),
newScaleEncoding);
scale = rewriter.create<ConvertLayoutOp>(scale.getLoc(),
newScaleDotElemType, scale);

auto scaledA = rewriter.create<triton::gpu::UpcastMXFPOp>(
dotOp.getLoc(), a, scale, dotOp.getLhsType());
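The gist of the updated lowering, expressed as a Python sketch with stand-in helpers (the real logic builds MLIR ops in C++; `convert_layout` and `fp_to_fp_bf16` below are placeholders, not real APIs). The uint8-to-fp8 bitcast that used to live here is gone because operands now arrive with a float element type.

# Conceptual sketch of toMMABf16 in ScaledBlockedToMMAv2 after this change.
def convert_layout(v):       # stand-in for ttg.convert_layout into a dot-operand layout
    return v

def fp_to_fp_bf16(v):        # stand-in for tt.fp_to_fp converting fp8 elements to bf16
    return v

def to_mma_bf16(v, fmt):
    if fmt == "e2m1":
        # Packed fp4: only the layout changes here; upcast_mxfp performs the value conversion.
        return convert_layout(v)
    v = convert_layout(v)
    if fmt == "bf16":
        return v             # new: the operand is already bf16, so no conversion is emitted
    return fp_to_fp_bf16(v)  # e4m3 / e5m2: convert the element type to bf16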
19 changes: 10 additions & 9 deletions python/src/ir.cc
@@ -205,12 +205,13 @@ void init_triton_ir(py::module &&m) {
.value("IEEE", InputPrecision::IEEE)
.export_values();

py::enum_<F8F6F4Type>(m, "F8F6F4TY", py::module_local())
.value("E4M3", F8F6F4Type::E4M3)
.value("E5M2", F8F6F4Type::E5M2)
.value("E2M3", F8F6F4Type::E2M3)
.value("E3M2", F8F6F4Type::E3M2)
.value("E2M1", F8F6F4Type::E2M1)
py::enum_<ScaleDotElemType>(m, "ScaleDotElemTypeTY", py::module_local())
.value("E4M3", ScaleDotElemType::E4M3)
.value("E5M2", ScaleDotElemType::E5M2)
.value("E2M3", ScaleDotElemType::E2M3)
.value("E3M2", ScaleDotElemType::E3M2)
.value("E2M1", ScaleDotElemType::E2M1)
.value("BF16", ScaleDotElemType::BF16)
.export_values();

py::class_<MLIRContext>(m, "context", py::module_local())
@@ -1423,9 +1424,9 @@ void init_triton_ir(py::module &&m) {
})
.def("create_dot_scaled",
[](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale,
F8F6F4Type lhs_format, mlir::Value &rhs,
std::optional<mlir::Value> &rhs_scale, F8F6F4Type rhs_format,
mlir::Value &c) -> mlir::Value {
ScaleDotElemType lhs_format, mlir::Value &rhs,
std::optional<mlir::Value> &rhs_scale,
ScaleDotElemType rhs_format, mlir::Value &c) -> mlir::Value {
return self.create<DotScaledOp>(
c.getType(), lhs, rhs, c, lhs_scale,
rhs_scale.value_or(Value()), lhs_format, rhs_format);
31 changes: 16 additions & 15 deletions python/test/unit/language/test_core.py
@@ -3330,7 +3330,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
for col_a, col_b in itertools.product([True, False], repeat=2)
for type_a in ["e2m1", "e4m3", "e5m2"]
for type_b in ["e4m3", "e5m2"]
for type_b in ["e4m3", "e5m2", "bf16"]
for mma in ([32, 16] if is_hip() else [16])
for kpack in ([1, 2] if is_hip() else [1])])
def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack, device):
@@ -3351,7 +3351,7 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr,
type_b: tl.constexpr):
tl.static_assert(type_b == "e4m3" or type_b == "e5m2", "type_b must be fp8")
tl.static_assert((type_b == "e4m3" or type_b == "e5m2") or type_b == "bf16", "type_b must be fp8 or bf16")
IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2"
DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2
PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR
@@ -3442,7 +3442,7 @@ def mxfp_to_bf16_kernel(

def dot_scale_ref(x, scale, y, type_x, type_y):
e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x]
type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y]
type_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type_y]

comp_dtype = torch.bfloat16

@@ -3455,7 +3455,7 @@ def dot_scale_ref(x, scale, y, type_x, type_y):
mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps)
assert x_upcast.isfinite().all()

y_upcast = y.view(type_fp8_y).to(comp_dtype)
y_upcast = y.view(type_y).to(comp_dtype)

class AccumulateInFp32:

@@ -3467,28 +3467,30 @@ def __exit__(self, exc_type, exc_val, exc_tb):
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value

with AccumulateInFp32():
return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype))
return torch.matmul(x_upcast, y_upcast)

torch.manual_seed(0)

def create_uint8(shape, col_major=False, max_val=255):
def make_arg(shape, ty, col_major=False, max_val=255):
if col_major:
shape = shape[:-2] + (shape[-1], shape[-2])
ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device)
if ty == "bf16":
ret = torch.randn(shape, dtype=torch.bfloat16, device=device)
# Clamp to avoid relative error issues
ret.clamp_(-2**15, 2**15 - 1)
else:
ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device)
if col_major:
ret = ret.mT
return ret

DIV_FACTOR = 2 if type_a == "e2m1" else 1
x = create_uint8((M, K // DIV_FACTOR), col_major=col_a)
y = create_uint8((K, N), col_major=col_b)
x = make_arg((M, K // DIV_FACTOR), type_a, col_major=col_a)
y = make_arg((K, N), type_b, col_major=col_b)

# sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright)
# We substract a reasonably high number (64) so that the sum of all the mxfp elements does not overflow
m_bytes = int(type_a[1])
bias_type_a = 1 << (m_bytes - 1) - 1
max_exponent_type_a = (1 << m_bytes) - 1 - bias_type_a
scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64)
# Max scale= 2**15
scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15)

def make_finite(x, dtype):
# e5m2 has too many non-finite values when sampled uniformly (1 / 32) and
@@ -3513,7 +3515,6 @@ def make_finite(x, dtype):

z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b)

# generous rtol as we are sampling the whole range of floats
torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2)

# make sure ld/st are vectorized
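For reference (not part of the commit), an e8m0 scale byte encodes a power of two with bias 127 -- which is why `max_val=127 + 15` caps the scales at 2**15 -- and each scale applies to a group of 32 elements along K. A small illustrative decoder:

import torch

def apply_e8m0_scales(x_bf16: torch.Tensor, scale_u8: torch.Tensor) -> torch.Tensor:
    # x_bf16: (M, K) values already upcast to bf16; scale_u8: (M, K // 32) e8m0 bytes.
    scales = torch.exp2(scale_u8.to(torch.float32) - 127)  # e8m0: power of two, bias 127
    scales = scales.repeat_interleave(32, dim=1)            # one scale per 32 K elements
    return (x_bf16.to(torch.float32) * scales).to(torch.bfloat16)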
14 changes: 8 additions & 6 deletions python/triton/language/core.py
@@ -1556,15 +1556,17 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
lhs and rhs use microscaling formats described here:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
:param lhs: The first tensor to be multiplied.
:type lhs: 2D tensor of f8, f6 or f4 format packed in int32 format.
:type lhs: 2D tensor representing fp4 or fp8 elements packed into uint8 for fp4 inputs, or in uint8 or the corresponding fp8 type for fp8 inputs.
:param lhs_scale: Scale factor for lhs tensor.
:type lhs_scale: ue8m0 float8 type (currently represented as an int8 tensor).
:param lhs_format: format of the lhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}.
:type lhs_scale: e8m0 type represented as an uint8 tensor.
:param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code: `e5m2`}.
:type lhs_format: str
:param rhs: The second tensor to be multiplied.
:type rhs: 2D tensor of f8, f6 or f4 format packed in int32 format.
:type rhs: 2D tensor representing fp8 or bf16 elements in uint8 or the corresponding fp8 type for fp8 inputs or bf16 for bf16 inputs.
:param rhs_scale: Scale factor for rhs tensor.
:type rhs_scale: ue8m0 float8 type (currently represented as an int8 tensor).
:param rhs_format: format of the rhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}.
:type rhs_scale: e8m0 type represented as an uint8 tensor.
:param rhs_format: format of the rhs tensor. Available formats: {:code:`e4m3`, :code: `e5m2`, :code:`bf16`}.
:type rhs_format: str
:param acc: The accumulator tensor. If not None, the result is added to this tensor.
"""
out_dtype = _constexpr_to_value(out_dtype)
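On the host side, the relaxed typing documented above means either representation can be handed to a kernel. A hedged sketch (tensor names and shapes are hypothetical; the uint8-to-fp8 view mirrors what the test above does):

import torch

# fp8 lhs: raw uint8 bytes or an fp8-typed view of the same storage are both accepted.
a_bytes = torch.randint(256, (128, 64), dtype=torch.uint8, device="cuda")
a_fp8 = a_bytes.view(torch.float8_e4m3fn)

# e8m0 scales: one uint8 per 32 lhs elements along K.
a_scale = torch.randint(256, (128, 64 // 32), dtype=torch.uint8, device="cuda")

# bf16 rhs: a plain bfloat16 tensor (uint16 storage of the same bytes also works).
b = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")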
49 changes: 32 additions & 17 deletions python/triton/language/semantic.py
@@ -1527,33 +1527,48 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona
ret_ty)


def _str_to_fp_type(float_format: Optional[str]):
if float_format == 'e4m3':
return ir.F8F6F4TY.E4M3
if float_format == 'e5m2':
return ir.F8F6F4TY.E5M2
if float_format == 'e2m3':
return ir.F8F6F4TY.E2M3
if float_format == 'e3m2':
return ir.F8F6F4TY.E3M2
if float_format == 'e2m1':
return ir.F8F6F4TY.E2M1
raise ValueError(f"Invalid float format: {float_format}.")


def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_scale: Optional[tl.tensor],
rhs_format, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
def _str_to_fp_type(float_format: str):
ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None)
if ty_enum is None:
raise ValueError(f"Invalid float format: {float_format}.")
return ty_enum


def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder):
"""
If float_format is subbyte, make sure it's packed as uint8 and return it.
Otherwise, return a tensor (perhaps bitcasting) of the specified float format.
"""
triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16}.get(float_format)
if triton_ty is None:
assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}"
assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}"
return val
if val.dtype == triton_ty:
return val
else:
unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16}[float_format]
assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}"
return bitcast(val, triton_ty, builder)


def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor],
rhs_format: str, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
#TODO: validate types.
lhs_rank = len(lhs.shape)
rhs_rank = len(rhs.shape)
assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
lhs_format: str = lhs_format.value
rhs_format: str = rhs_format.value
lhs_format_enum = _str_to_fp_type(lhs_format)
rhs_format_enum = _str_to_fp_type(rhs_format)
assert lhs_format in ("e2m1", "e4m3", "e5m2"), f"NYI: lhs_format {lhs_format}"
assert rhs_format in ("e4m3", "e5m2"), f"NYI: rhs_format {rhs_format}"
assert rhs_format in ("e4m3", "e5m2", "bf16"), f"NYI: rhs_format {rhs_format}"
rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None
assert rhs_scale_is_none, "NYI: rhs_scale not supported"
lhs = _bitcast_to_fp_type(lhs, lhs_format, builder)
rhs = _bitcast_to_fp_type(rhs, rhs_format, builder)

M = lhs.type.shape[-2]
K, N = rhs.type.shape[-2:]
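Put another way, `_bitcast_to_fp_type` encodes the following acceptance rules (an illustrative summary, not the implementation, which operates on `tl.tensor` values):

# Accepted element dtypes per format: the packed unsigned representation or the matching float type.
ACCEPTED_DTYPES = {
    "e2m1": ("uint8",),               # fp4 pairs stay packed in uint8
    "e4m3": ("uint8", "float8e4nv"),  # uint8 is bitcast to the fp8 type
    "e5m2": ("uint8", "float8e5"),
    "bf16": ("uint16", "bfloat16"),   # uint16 is bitcast to bfloat16
}

def is_valid_operand(dtype_name: str, fmt: str) -> bool:
    return dtype_name in ACCEPTED_DTYPES[fmt]

assert is_valid_operand("uint16", "bf16")
assert not is_valid_operand("uint8", "bf16")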