From 58dec654c61aa02483179253d295bf697db532d1 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 19 Jul 2023 15:51:28 -0700 Subject: [PATCH] Python-only float8 data type + bare bones UEX Summary: This is a copy of https://github.com/facebookexperimental/protoquant/pull/23 Many things will change based on recent discussions! Test Plan: ``` python float8_playground/test.py ``` Reviewers: Subscribers: Tasks: Tags: --- .gitignore | 1 + float8_playground/float8_aten_api.py | 62 +++++++++ float8_playground/float8_linear.py | 122 +++++++++++++++++ float8_playground/float8_tensor.py | 141 ++++++++++++++++++++ float8_playground/float8_utils.py | 178 +++++++++++++++++++++++++ float8_playground/test.py | 192 +++++++++++++++++++++++++++ 6 files changed, 696 insertions(+) create mode 100644 .gitignore create mode 100644 float8_playground/float8_aten_api.py create mode 100644 float8_playground/float8_linear.py create mode 100644 float8_playground/float8_tensor.py create mode 100644 float8_playground/float8_utils.py create mode 100644 float8_playground/test.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..f1e26ba7 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +float8_playground/__pycache__/* diff --git a/float8_playground/float8_aten_api.py b/float8_playground/float8_aten_api.py new file mode 100644 index 00000000..933d1433 --- /dev/null +++ b/float8_playground/float8_aten_api.py @@ -0,0 +1,62 @@ +""" +This file defines the aten functions for float8. Today, all of these functions +are emulated. In the future, they should be calling NVIDIA's float8 kernels. +""" + +import torch +from torch.library import Library + +from float8_utils import ( + float32_to_float8, + float8_to_float32, + E4M3, + E5M2, + tensor_to_scale, +) + + +def mm_float8(m1, s1, flavor1, m2, s2, flavor2, s3, flavor3): + # naive implementation: dq -> op -> q + # TODO(future): hook up to real kernel + m1_fp32 = float8_to_float32(m1, flavor1) / s1 + m2_fp32 = float8_to_float32(m2, flavor2) / s2 + m3_fp32 = torch.mm(m1_fp32, m2_fp32) + # TODO(future): switch to delayed scaling + s3.fill_(tensor_to_scale(m3_fp32, flavor3)) + m3_fp32_scaled = m3_fp32 * s3 + return float32_to_float8(m3_fp32_scaled, flavor3) + +def add_float8_e5m2(m1, s1, m2, s2, s3): + # for now this is only implemented for e5m2 because we only care about + # this for adding gradients + # naive implementation: dq -> op -> q + # TODO(future): hook up to real kernel + # TODO(future): make this more accurate, accuracy is pretty low, + # can probably just calculate s3 dynamically since this is an edge case + # unlikely to affect e2e performance + m1_float32 = float8_to_float32(m1, E5M2) / s1 + m2_float32 = float8_to_float32(m2, E5M2) / s2 + m3_float32 = m1_float32 + m2_float32 + return float32_to_float8(m3_float32 * s3, E5M2) + +# +# ATen op placeholders +# + +# Register the aten level functions we need. 
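+# A rough usage sketch (assuming a scale `s` computed via tensor_to_scale),
+# mirroring how float8_linear.py and float8_tensor.py call these ops once
+# they are registered below:
+#
+#   x_fp8_bits = torch.ops.aten.float32_to_float8(x * s, E4M3)
+#   x_approx = torch.ops.aten.float8_to_float32(x_fp8_bits, E4M3) / s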
+# These are mostly placeholder and might need to be implemented in c++ as needed +lib = Library("aten", "FRAGMENT") + +# For now register on CPU, +# TODO(future) add GPU and test there +lib.define("float32_to_float8(Tensor t, int flavor) -> Tensor") +lib.impl("float32_to_float8", float32_to_float8, "CPU") + +lib.define("float8_to_float32(Tensor t, int flavor) -> Tensor") +lib.impl("float8_to_float32", float8_to_float32, "CPU") + +lib.define("mm_float8(Tensor m1, Tensor s1, int flavor1, Tensor m2, Tensor s2, int flavor2, Tensor s3, int flavor3) -> Tensor") +lib.impl("mm_float8", mm_float8, "CPU") + +lib.define("add_float8_e5m2(Tensor m1, Tensor s1, Tensor m2, Tensor s2, Tensor s3) -> Tensor") +lib.impl("add_float8_e5m2", add_float8_e5m2, "CPU") diff --git a/float8_playground/float8_linear.py b/float8_playground/float8_linear.py new file mode 100644 index 00000000..c2382ab8 --- /dev/null +++ b/float8_playground/float8_linear.py @@ -0,0 +1,122 @@ +""" +A simple manual UEX for a float8 version of `torch.nn.Linear`. + +Note: this UEX is not intended for real usage. It merely demonstrates +an example of how features such as casting to and from float8 as well +as stateful scaling can be implemented. For now, we expect framework +owners to implement their own UEX. +""" + +import torch + +import float8_aten_api + +from float8_utils import E4M3, E5M2, tensor_to_scale +from float8_tensor import Float8Tensor + +class float8_linear_no_bias(torch.autograd.Function): + """ + Like F.linear, but with X, W, and Y in float8 + TODO(future) add logic for bias + """ + + @staticmethod + def forward( + ctx, + x_fp8, + w_fp8, + fp8_s_out, + fp8_s_dL_dX, + fp8_s_dL_dW, + fp8_s_dL_dY, + ): + ctx.save_for_backward(x_fp8, w_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY) + + res_bits = torch.ops.aten.mm_float8( + x_fp8._data, x_fp8._scale, x_fp8._flavor, + w_fp8._data.t(), w_fp8._scale, w_fp8._flavor, + fp8_s_out, E4M3) + + res = Float8Tensor(res_bits, fp8_s_out, E4M3) + # scale update would also happen here, for now no-op + return res + + @staticmethod + def backward(ctx, go): + x_fp8, w_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY = \ + ctx.saved_tensors + + if not isinstance(go, Float8Tensor): + # TODO(future): switch to delayed scaling + fp8_s_dL_dY.fill_(tensor_to_scale(go, E5M2)) + go_fp8 = Float8Tensor( + torch.ops.aten.float32_to_float8(go * fp8_s_dL_dY, E5M2), + fp8_s_dL_dY, + E5M2) + else: + go_fp8 = go + + dL_dX_bits = torch.ops.aten.mm_float8( + go_fp8._data, go_fp8._scale, go_fp8._flavor, + w_fp8._data, w_fp8._scale, w_fp8._flavor, + fp8_s_dL_dX, E5M2) + dL_dX_fp8 = Float8Tensor(dL_dX_bits, fp8_s_dL_dX, E5M2) + + dL_dW_bits = torch.ops.aten.mm_float8( + x_fp8._data.t(), x_fp8._scale, x_fp8._flavor, + go_fp8._data, go_fp8._scale, go_fp8._flavor, + fp8_s_dL_dW, E5M2).t() + dL_dW_fp8 = Float8Tensor(dL_dW_bits, fp8_s_dL_dW, E5M2) + + # scale update would also happen here, for now no-op + return dL_dX_fp8, dL_dW_fp8, None, None, None, None + + +class Float8Linear(torch.nn.Linear): + """ + A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks + scales in way friendly to delayed scaling. 
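+
+    A rough usage sketch (mirroring test.py in this PR; the API is expected
+    to change):
+
+        m = Float8Linear.from_float(torch.nn.Linear(4, 4, bias=False))
+        y = m(torch.randn(4, 4))   # returns a Float8Tensor
+        y.sum().backward()         # gradients flow through the fp8 mms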
+ """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # While this module currently implements just-in-time scaling, + # the scales are stored in buffers as a placeholder for delayed + # scaling such as the mechanism described in + # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8, + # or PTQ calibration. + self.register_buffer('fp8_s_in', torch.tensor(1.0)) + self.register_buffer('fp8_s_weight', torch.tensor(1.0)) + self.register_buffer('fp8_s_out', torch.tensor(1.0)) + self.register_buffer('fp8_s_dL_dX', torch.tensor(1.0)) + self.register_buffer('fp8_s_dL_dW', torch.tensor(1.0)) + self.register_buffer('fp8_s_dL_dY', torch.tensor(1.0)) + + def forward(self, x): + if not isinstance(x, Float8Tensor): + # TODO(future): switch to delayed scaling + self.fp8_s_in.fill_(tensor_to_scale(x, E4M3)) + x_fp8 = Float8Tensor.from_float32(x, self.fp8_s_in, E4M3) + else: + x_fp8 = x + + # TODO(future): switch to delayed scaling + self.fp8_s_weight.fill_(tensor_to_scale(self.weight, E4M3)) + w_fp8 = Float8Tensor.from_float32(self.weight, self.fp8_s_weight, E4M3) + + y_fp8 = float8_linear_no_bias.apply( + x_fp8, w_fp8, self.fp8_s_out, self.fp8_s_dL_dX, + self.fp8_s_dL_dW, self.fp8_s_dL_dY) + + # For now, hardcode returning Float8Tensor (propagate as much as we can). + # This can be changed to return a different dtype, if needed. + return y_fp8 + + @classmethod + def from_float(cls, mod): + """ + Create an nn.Linear with fp8 compute from a regular nn.Linear + """ + assert mod.bias is None, 'bias support not implemented yet' + new_mod = cls(mod.in_features, mod.out_features, bias=False) + new_mod.weight = mod.weight + return new_mod diff --git a/float8_playground/float8_tensor.py b/float8_playground/float8_tensor.py new file mode 100644 index 00000000..4c39f99b --- /dev/null +++ b/float8_playground/float8_tensor.py @@ -0,0 +1,141 @@ +from enum import Enum +import torch +from torch.utils._pytree import tree_map + +from float8_utils import E4M3, E5M2 + +aten = torch.ops.aten + +class Float8ConstrFunc(torch.autograd.Function): + """ + A differentiable conversion between fp32 and fp8 + TODO(future): split into two for cleaner code + """ + @staticmethod + def forward(ctx, tensor, scale: float=None, flavor=E4M3): + if isinstance(tensor, Float8Tensor): + ctx.inp_is_float8 = True + return torch.ops.aten.float8_to_float32(tensor._data, tensor._flavor) / tensor._scale + else: + ctx.inp_is_float8 = False + tensor_scaled = tensor * scale + bits_fp8 = torch.ops.aten.float32_to_float8(tensor_scaled, flavor) + return Float8Tensor(bits_fp8, scale, flavor) + + @staticmethod + def backward(ctx, g): + # Assume that we always want to scale the gradients + # back to full precision. We could do something else + if isinstance(g, Float8Tensor) and not ctx.inp_is_float8: + return g.to_float32(), None, None + elif ctx.inp_is_float8: + return Float8Tensor.from_float32(g), None, None + else: + return g, None, None + + +class Float8Tensor(torch.Tensor): + """ + A Python-only FP8 tensor. Contains: + * `_data`: the underlying e4m3 or e5m2 data + * `_scale`: the scale used to scale the original fp32 tensor. We multiply + by scale to go from fp32 range to fp8 range, and divide by scale to go + from fp8 range to fp32 range. 
+    * `_flavor`: either E4M3 or E5M2
+
+    The current purpose of this object is 99% to bundle raw data + fp8 metadata
+    together for easy passing through PyTorch systems, and 1% to implement
+    gradient addition (since that has to happen outside of user code).
+
+    The addition operation is defined inline and uses a naive
+    version of stateless scaling. This allows e5m2 gradients to be added.
+    TODO(future): verify this is numerically accurate, optionally replace
+    with something better.
+
+    It would probably make sense to also define the fp8 path for data shuffling
+    ops like cat, transpose, view, etc. inline so we don't have to fall back
+    to fp32 for them.
+    """
+
+    def __new__(cls, data, scale, flavor):
+        # This is a non-differentiable constructor!
+        assert not data.requires_grad
+        # TODO(future): make bits8 easier to work with and switch to using it
+        # assert data.dtype == torch.bits8
+        assert scale.dtype == torch.float32
+        assert scale.nelement() == 1
+
+        self = torch.Tensor._make_wrapper_subclass(
+            cls,
+            data.size(),
+            strides=data.stride(),
+            storage_offset=data.storage_offset(),
+            dtype=torch.float32,
+            layout=data.layout,
+            requires_grad=data.requires_grad,
+            device=data.device,
+        )
+        self._data = data
+        self._scale = scale
+        self._flavor = flavor
+
+        return self
+
+    def __repr__(self):
+        return f"Float8Tensor(flavor={self._flavor}, scale={self._scale}, as_float32={self.to_float32()})"
+
+    def to_float32(self):
+        return Float8ConstrFunc.apply(self)
+
+    @classmethod
+    def from_float32(cls, tensor, scale, flavor):
+        return Float8ConstrFunc.apply(tensor, scale, flavor)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        # Note: unlike many other subclasses, this subclass only propagates
+        # itself for addition (for gradient addition in backward). For all
+        # other ops, it self-converts to fp32. The user/framework is
+        # assumed to take care of defining where fp8 operations occur in the
+        # forward pass and how scaling is calculated. In this example, that is
+        # done by the `Float8Linear` class.
+        # Vasiliy: the main reason I went with ^ is because NVIDIA is
+        # doing stateful delayed scaling, and I don't know of a safe
+        # way to enable that without either full program capture or punting it
+        # to the user. This prototype takes the "punt it to the user" approach.
+        # IMO for now let's just write out the scale stuff manually so we can
+        # focus on other things, and revisit later if needed.
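+        # (Concretely: adding two Float8Tensors below stays in float8, while an
+        #  op like torch.cat on Float8Tensors dequantizes its inputs to fp32 via
+        #  the tree_map fallback at the bottom of this function.)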
+ + # override addition so we can add e5m2 gradients + if ( + func is aten.add.Tensor + and isinstance(args[0], Float8Tensor) + and isinstance(args[1], Float8Tensor) + ): + x1_fp8, x2_fp8 = args[0], args[1] + assert x1_fp8._flavor == E5M2 and x2_fp8._flavor == E5M2 + # naive scale calculation: max of incoming two scales + x3_scale = torch.max(x1_fp8._scale, x2_fp8._scale) + res_bits = torch.ops.aten.add_float8_e5m2( + x1_fp8._data, x1_fp8._scale, + x2_fp8._data, x2_fp8._scale, + x3_scale) + res = Float8Tensor(res_bits, x3_scale, x1_fp8._flavor) + return res + + # for all other ops, fall back to fp32 + # TODO(future): add support for fp16/bf16 + + def maybe_unwrap(t): + if isinstance(t, Float8Tensor): + return t.to_float32() + return t + + args = tree_map(maybe_unwrap, args) + if kwargs is not None: + kwargs = tree_map(maybe_unwrap, kwargs) + out = super().__torch_dispatch__(func, types, args, kwargs) + return out + + # Do not force the Float8Tensor type on the returned tensor + __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/float8_playground/float8_utils.py b/float8_playground/float8_utils.py new file mode 100644 index 00000000..b9bf9ee7 --- /dev/null +++ b/float8_playground/float8_utils.py @@ -0,0 +1,178 @@ +import torch + +# This file reproduces the emulated fp8 <-> fp32 casts from +# https://github.com/pytorch/FBGEMM/pull/974/files with plain PyTorch. +# This implements the fp8 format spec from https://arxiv.org/pdf/2209.05433.pdf +# +# TODO(future PR): hook up to NVIDIA's casts on gpu, and +# import the fbgemm emulator directly on cpu. We'll also need to ensure +# the two are aligned. +# +# Helpful visualizer for debugging (only supports fp32): +# https://www.h-schmidt.net/FloatConverter/IEEE754.html + +# define the e4m3/e5m2 constants +E4M3_EBITS = 4 +E4M3_EXP_BIAS = 7 +E4M3_MAX_POS = 448.0 + +E5M2_EBITS = 5 +E5M2_EXP_BIAS = 15 +E5M2_MAX_POS = 57344.0 + +# avoid division by zero when calculating scale +# TODO: align this value with NVIDIA's assumptions (current value is a guess) +EPS = 1e-12 + +# enum without using an enum, for brevity +# TODO(future): make this into an enum if needed +E4M3 = 0 +E5M2 = 1 + + +def _float_to_hfp8( + val_fp: torch.Tensor, # fp32 values + ebits: int, # exponent bits, mbits = 8 - ebits + exponent_bias: int, # exponent bias to use in the fp8 encoding + max_pos: float, # maximum positive number for fp8 encoding +): + mbits = 7 - ebits + + val_out = val_fp.clone().detach() + + # S{1}, E{8}, M{23} + sign_bit = (val_out.view(torch.int32) & 0x80000000).view(torch.float) + + # set sign bit to 0 + # 0{1}, E{8}, M{23} + val_out = (val_out.view(torch.int32) & 0x7FFFFFFF).view(torch.float) + + # ensure abs(val_out) <= max_pos) + val_out = torch.clamp(val_out, max=max_pos) + + smallest_normal = torch.zeros_like(val_fp).to(torch.int32) + smallest_normal = ((smallest_normal + 127 - exponent_bias + 1) << 23).view(torch.float) + + # normal and denormal paths split below, record + # which element need which path + is_normal_mask = torch.ge(val_out, smallest_normal) + + # + # normal path + # + + # Use round to nearest even. 
+    # We make use of the standard rounding mechanism in FP32 rather than
+    # rounding the mantissa ourselves and handling tie-to-even and exponent
+    # increments. We want to round off 23-mbits of the FP32 value val_in.
+    # This can be done by adding a power of 2 exactly 23-mbits larger than
+    # the exponent of val_in. This forces val_in to be moved to the right,
+    # with rounding happening exactly at the location corresponding to
+    # having mbits of explicit mantissa left.
+    n_bouncer = ((val_out.view(torch.int32) & 0xFF800000) + ((23 - mbits) << 23)).view(torch.float)
+    n_val_out = (n_bouncer + val_out) - n_bouncer
+
+    # Adding the bouncer rounds off bits, and subtracting the bouncer
+    # leaves the desired value, albeit in FP32 encoding.
+    # All we need is to change the exponent encoding to use "bias".
+    n_val_out_i = (n_val_out.view(torch.int32) - ((127 - exponent_bias) << 23)) << (8 - ebits)
+    n_val_out_i = (n_val_out_i | sign_bit.view(torch.int32)) >> 24
+    n_val_out = n_val_out_i.view(torch.float)
+
+    #
+    # denormal path
+    #
+
+    # When the value is in the denormal range, IEEE numbers essentially become
+    # fixed point numbers. The lsb is the smallest non-zero number,
+    # 2^(1-bias-mbits). Hence, we define the bouncer so that its lsb is this
+    # smallest non-zero number. Adding the input to this bouncer forces rounding
+    # to occur appropriately. Also, in this situation, after adding the bouncer,
+    # the 8 least significant bits of the sum are already the HFP8 encoding of
+    # the desired result. We just need to restore the sign bit.
+    # bouncer.I = (127 + (23 + (1 - exponent_bias - mbits))) << 23;
+    # val_out.F = bouncer.F + val_out.F;
+    # val_out.I = val_out.I | (sign_bit >> 24);
+
+    dn_bouncer = ((torch.zeros_like(val_out).view(torch.int32) + 127 + (23 + (1 - exponent_bias - mbits))) << 23).view(torch.float)
+    dn_val_out = dn_bouncer + val_out
+    dn_val_out_i = dn_val_out.view(torch.int32) | (sign_bit.view(torch.int32) >> 24)
+    dn_val_out = dn_val_out_i.view(torch.float)
+
+    #
+    # combine normal and denormal paths
+    #
+    val_out = torch.where(is_normal_mask, n_val_out, dn_val_out)
+    # take the 8 least significant bits
+    orig_shape = val_fp.shape
+    val_out = val_out.view(torch.uint8)
+    val_out = val_out.reshape(-1, 4)
+    val_out = torch.tensor_split(val_out, 4, dim=-1)[0]
+    val_out = val_out.reshape(orig_shape)
+    return val_out
+
+
+def _hfp8_to_float(
+    hfp8_val: torch.Tensor,
+    ebits: int,
+    exponent_bias: int,
+):
+    assert hfp8_val.dtype == torch.uint8
+
+    sign_i = (hfp8_val & 0x80).to(torch.int32) << 24
+
+    val_out_i = (hfp8_val & 0x7F).to(torch.int32) << (24 - (8 - ebits))
+    # so that the mantissa bits start at the mantissa bit positions of FP32
+    # encoding
+
+    # Let the hfp8 mantissa bits correspond to the value frac, 0 <= frac < 1.
+    # So if the hfp8 value is a normal number, its value is 2^e x (1+frac),
+    # where e is its (true, unbiased) exponent.
+    # If the hfp8 value is denormal, the value is 2^(1-bias) x frac.
+
+    # However, the bit pattern in the 8-bit exponent field of val_out.F
+    # is bias+e when hfp8 is normal, and 0 when hfp8 is subnormal.
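+    # (Worked example, e4m3 with bias 7: the min normal 2^-6 is encoded with
+    #  exponent field 7 + (-6) = 1, so after the shift above the FP32 exponent
+    #  field of val_out.F is also 1.)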
+ # So, as an FP32 value, when hfp8 is normal, val_out.F represents the value + # of 2^(bias+e-127) * (1+frac) + # And when hfp8 is subnormal, val_out.F is also subnormal, and represents the + # value of 2^(-126) * frac In either case, val_out.F corresponds to + # 2^(bias-127) * (value of hfp8 input) Thus, if we multiply val_out.F by + # 2^(127-bias), we obtain the hfp8 value as an FP32 number + + # multiplier.I = (127 + (127 - exponent_bias)) + # << 23; // multiplier.F is 2^(127-bias) + # val_out.F *= multiplier.F; + # val_out.I |= sign.I; + # return val_out.F; + + multiplier_i = (torch.zeros_like(hfp8_val).to(torch.int32) + 127 + (127 - exponent_bias)) << 23 # multiplier_f is 2^(127-bias) + val_out_f = val_out_i.view(torch.float) + val_out_f *= multiplier_i.view(torch.float) + val_out_f = (val_out_f.view(torch.int32) | sign_i).view(torch.float) + return val_out_f + +def float32_to_float8(x, flavor): + if flavor == E4M3: + return _float_to_hfp8(x, E4M3_EBITS, E4M3_EXP_BIAS, E4M3_MAX_POS) + else: # e5m2 + return _float_to_hfp8(x, E5M2_EBITS, E5M2_EXP_BIAS, E5M2_MAX_POS) + +def float8_to_float32(x, flavor): + if flavor == E4M3: + return _hfp8_to_float(x, E4M3_EBITS, E4M3_EXP_BIAS) + else: # e5m2 + return _hfp8_to_float(x, E5M2_EBITS, E5M2_EXP_BIAS) + +def amax_to_scale(amax, flavor): + if flavor == E4M3: + return E4M3_MAX_POS / torch.clamp(amax, min=EPS) + else: # e5m2 + return E5M2_MAX_POS / torch.clamp(amax, min=EPS) + +def tensor_to_scale(x, flavor): + amax = torch.max(torch.abs(x)) + return amax_to_scale(amax, flavor) + +def compute_error(x, y): + Ps = torch.norm(x) + Pn = torch.norm(x - y) + return 20 * torch.log10(Ps / Pn) diff --git a/float8_playground/test.py b/float8_playground/test.py new file mode 100644 index 00000000..0fb1e13d --- /dev/null +++ b/float8_playground/test.py @@ -0,0 +1,192 @@ +import copy +import random +import unittest + +import torch +import torch.nn as nn + +from float8_utils import ( + float32_to_float8, + float8_to_float32, + E4M3, + E5M2, + compute_error, + tensor_to_scale, +) +from float8_tensor import Float8Tensor +from float8_linear import Float8Linear + +random.seed(0) +torch.manual_seed(0) + +class Float8CastsUnitTest(unittest.TestCase): + """ + Test the casts between fp32 and fp8 (e4m3 and e5m2) + """ + + def _compare_many_exact(self, flavor, x_fp32, comp_name): + x_fp8 = float32_to_float8(x_fp32, flavor) + x_fp8_fp32 = float8_to_float32(x_fp8, flavor) + torch.testing.assert_close(x_fp32, x_fp8_fp32) + + def _compare_many_approx(self, flavor, x_fp32, comp_name): + if flavor == E4M3: + sqnr_target = 25.0 + else: # e5m2 + sqnr_target = 23.0 + + x_fp8 = float32_to_float8(x_fp32, flavor) + x_fp8_fp32 = float8_to_float32(x_fp8, flavor) + + # sign should always be the same + torch.testing.assert_close( + torch.sign(x_fp32), + torch.sign(x_fp8_fp32), + atol=0, rtol=0) + + # for now just measure that sqnr is somewhat reasonable + # TODO(future): make this significantly more robust, this is about + # 2/10 on the scale of "robust enough" + sqnr = compute_error(x_fp32, x_fp8_fp32) + assert sqnr >= sqnr_target + + + def _compare_one(self, flavor, bits_str, expected_fp32, comp_name): + fp8_bits_ref = torch.tensor([int(bits_str, 2)], dtype=torch.uint8) + + fp32_tensor = torch.tensor([expected_fp32], dtype=torch.float) + fp8_bits = float32_to_float8(fp32_tensor, flavor) + torch.testing.assert_close(fp8_bits, fp8_bits_ref, atol=0, rtol=0) + + fp32_from_fp8_tensor = float8_to_float32(fp8_bits, flavor) + torch.testing.assert_close(fp32_tensor, fp32_from_fp8_tensor, 
atol=0, rtol=0) + + def test_e4m3_numerics_single(self): + # ensure that our format matches https://arxiv.org/pdf/2209.05433.pdf, Table 1 + + flavor = E4M3 + # e4m3 does not support infinity + self._compare_one(flavor, "00000000", 0.0, "zero") + self._compare_one(flavor, "10000000", -0.0, "neg_zero") + self._compare_one(flavor, "01111110", 448.0, "max_normal") + self._compare_one(flavor, "11111110", -448.0, "neg_max_normal") + self._compare_one(flavor, "00001000", 2 ** -6, "min_normal") + self._compare_one(flavor, "10001000", -1 * (2 ** -6), "neg_min_normal") + self._compare_one(flavor, "00000111", 0.875 * (2 ** -6), "max_subnorm") + self._compare_one(flavor, "10000111", -0.875 * (2 ** -6), "neg_max_subnorm") + self._compare_one(flavor, "00000001", 2 ** -9, "min_subnorm") + self._compare_one(flavor, "10000001", -1 * (2 ** -9), "neg_min_subnorm") + + def test_e5m2_numerics_single(self): + flavor = E5M2 + # e5m2 infinity (below) is off by one, TODO(future PR) debug or just move + # to NVIDIA's intrinsic casts + # _compare_one(flavor, "01111100", float("inf"), "inf") + # _compare_one(flavor, "11111100", -1 * float("inf"), "neg_inf") + self._compare_one(flavor, "00000000", 0.0, "zero") + self._compare_one(flavor, "10000000", -0.0, "neg_zero") + self._compare_one(flavor, "01111011", 57344.0, "max_normal") + self._compare_one(flavor, "11111011", -57344.0, "neg_max_normal") + self._compare_one(flavor, "00000100", 2 ** -14, "min_normal") + self._compare_one(flavor, "10000100", -1 * (2 ** -14), "neg_min_normal") + self._compare_one(flavor, "00000011", 0.75 * (2 ** -14), "max_subnorm") + self._compare_one(flavor, "10000011", -0.75 * (2 ** -14), "neg_max_subnorm") + self._compare_one(flavor, "00000001", 2 ** -16, "min_subnorm") + self._compare_one(flavor, "10000001", -1 * (2 ** -16), "neg_min_subnorm") + + def test_e4m3_numerics_multiple(self): + # test special cases + x = torch.tensor([ + 0.0, + -0.0, + 448.0, + -448.0, + 2 ** -6, + -1 * (2 ** 6), + 0.875 * (2 ** 6), + -0.875 * (2 ** 6), + 2 ** -9, + -1 * (2 ** -9), + ]) + self._compare_many_exact(E4M3, x, 'special_cases') + + # test normal values + shapes + for _ in range(10): + x = torch.randn(1, 2, 3, 4) * random.uniform(0.1, 300.0) + x.clamp_(min=-448.0, max=448.0) + self._compare_many_approx(E4M3, x, 'normal_cases') + + def test_e5m2_numerics_multiple(self): + # test special cases + x = torch.tensor([ + 0.0, + -0.0, + 57344.0, + -57344.0, + 2 ** -14, + -1 * (2 ** -14), + 0.75 * (2 ** -14), + -0.75 * (2 ** -14), + 2 ** -16, + -1 * (2 ** -16), + ]) + self._compare_many_exact(E5M2, x, 'special_cases') + + # test normal values + shapes + for _ in range(10): + x = torch.randn(1, 2, 3, 4) * random.uniform(0.1, 30000.0) + x.clamp_(min=-57344.0, max=57344.0) + self._compare_many_approx(E5M2, x, 'normal_cases') + +class Float8TensorUnitTest(unittest.TestCase): + def test_add(self): + x1_fp32 = torch.randn(4, 4) + x1_s = tensor_to_scale(x1_fp32, E5M2) + x2_fp32 = torch.randn(4, 4) + x2_s = tensor_to_scale(x2_fp32, E5M2) + x1_fp8 = Float8Tensor.from_float32(x1_fp32, x1_s, E5M2) + x2_fp8 = Float8Tensor.from_float32(x2_fp32, x2_s, E5M2) + x3_fp8 = x1_fp8 + x2_fp8 + x3_fp32 = x3_fp8.to_float32() + x3_fp32_ref = x1_fp32 + x2_fp32 + sqnr = compute_error(x3_fp32_ref, x3_fp32) + # TODO(future): make this more accurate, accuracy is pretty low + self.assertTrue(sqnr >= 10.0) + +class Float8LinearUnitTest(unittest.TestCase): + + def test_e2e(self): + m_ref = nn.Linear(4, 4, bias=False) + m_fp8 = Float8Linear.from_float(copy.deepcopy(m_ref)) + + x = torch.randn(4, 
+                           4)
+
+        y_fp8 = m_fp8(x)
+        y_fp8.sum().backward()
+        y_ref = m_ref(x)
+        y_ref.sum().backward()
+
+        y_sqnr = compute_error(y_ref, y_fp8)
+        g_sqnr = compute_error(m_ref.weight.grad, m_fp8.weight.grad)
+
+        # verify sqnr is reasonable
+        self.assertTrue(y_sqnr >= 27.0)
+        self.assertTrue(g_sqnr >= 27.0)
+
+        # verify all of the scales got updated
+        for buffer_name in (
+            'fp8_s_in',
+            'fp8_s_weight',
+            'fp8_s_out',
+            'fp8_s_dL_dX',
+            'fp8_s_dL_dW',
+            'fp8_s_dL_dY',
+        ):
+            buffer_value = getattr(m_fp8, buffer_name)
+            self.assertTrue(
+                torch.ne(buffer_value, torch.tensor(1.0)),
+                f"{buffer_name} not filled")
+
+
+if __name__ == '__main__':
+    unittest.main()