From 459a2ca8db7187b2f3fa1d5d7c36d7aefdc60ee4 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Wed, 10 May 2023 02:27:37 +0000 Subject: [PATCH] [wip] Python-only float8 data type + bare bones UEX Summary: Note: this is WIP and does not represent PyTorch's opinion on how we will integrate float8. At this point, this is a prototype to get some light feedback. TODOs that need to be implemented before review of this PR: * [done] emulated float8 casts to and from float32 * [not done] move https://github.com/albanD/subclass_zoo/blob/fp8_v2/fp8_subclass_v2.py here and hook up to accurate casts * [not done] basic dynamic scaling (with the assumption that the real UEX can switch to delayed scaling) * [not done] numerical testing that this all works for fw + bw with a toy model TODOs to be done in next couple of PRs before sending to NVIDIA: * [not done] example of how an integration with distributed would work What is out of scope for this POC a. hooking up to real float8 ops (saved for later, just needs someone to do it) b. real UEX (saved for later and will need a lot of design discussion) Test plan: ``` python protoquant/float8/test.py ``` --- protoquant/float8/e2e_example.py | 60 +++++++++ protoquant/float8/float8_aten_api.py | 60 +++++++++ protoquant/float8/float8_linear.py | 122 ++++++++++++++++++ protoquant/float8/float8_tensor.py | 141 +++++++++++++++++++++ protoquant/float8/float8_utils.py | 178 +++++++++++++++++++++++++++ protoquant/float8/test.py | 177 ++++++++++++++++++++++++++ 6 files changed, 738 insertions(+) create mode 100644 protoquant/float8/e2e_example.py create mode 100644 protoquant/float8/float8_aten_api.py create mode 100644 protoquant/float8/float8_linear.py create mode 100644 protoquant/float8/float8_tensor.py create mode 100644 protoquant/float8/float8_utils.py create mode 100644 protoquant/float8/test.py diff --git a/protoquant/float8/e2e_example.py b/protoquant/float8/e2e_example.py new file mode 100644 index 0000000..d2ed88b --- /dev/null +++ b/protoquant/float8/e2e_example.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from float8_utils import E4M3, E5M2 +from float8_tensor import Float8Tensor +from float8_linear import Float8Linear + +aten = torch.ops.aten + +torch.manual_seed(0) + +# TODO(before land): move this to unit tests (in progress) +if __name__ == "__main__": + + # test addition + print("\nExample of addition\n") + x1_fp32, x1_s = torch.randn(4), torch.tensor(1.0) + x2_fp32, x2_s = torch.randn(4), torch.tensor(1.0) + x1_fp8 = Float8Tensor.from_float32(x1_fp32, x1_s, E5M2) + x2_fp8 = Float8Tensor.from_float32(x2_fp32, x2_s, E5M2) + x3_fp8 = x1_fp8 + x2_fp8 + print('x1', x1_fp8, '\nx2', x2_fp8, '\nx1+x2', x3_fp8) + + + print("\nExample of fp8 linear fw + bw\n") + + class M(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(2, 3, bias=False) + self.fc2 = nn.Linear(3, 4, bias=False) + self.fc3 = nn.Linear(3, 4, bias=False) + self.fc4 = nn.Linear(4, 2, bias=False) + + def forward(self, x0): + x1 = self.fc1(x0) + x2 = self.fc2(x1) + x3 = self.fc3(x1) + # test gradient addition + # Note: cat happens in fp32, for now + c = torch.cat([x2, x3]) + x4 = self.fc4(c) + return x4 + + m = M() + m.fc1 = Float8Linear.from_float(m.fc1) + m.fc2 = Float8Linear.from_float(m.fc2) + m.fc3 = Float8Linear.from_float(m.fc3) + m.fc4 = Float8Linear.from_float(m.fc4) + + print(m) + + x = Float8Tensor.from_float32(torch.randn(1, 2), torch.tensor(1.0), E4M3) + y = m(x) + print(y) + s = y.sum() + print('before grad', m.fc1.weight.grad, m.fc2.weight.grad, m.fc3.weight.grad, m.fc4.weight.grad) + s.backward() + print('after grad', m.fc1.weight.grad, m.fc2.weight.grad, m.fc3.weight.grad, m.fc4.weight.grad) diff --git a/protoquant/float8/float8_aten_api.py b/protoquant/float8/float8_aten_api.py new file mode 100644 index 0000000..1350786 --- /dev/null +++ b/protoquant/float8/float8_aten_api.py @@ -0,0 +1,60 @@ +""" +This file defines the aten functions for float8. Today, all of these functions +are emulated. In the future, they should be calling NVIDIA's float8 kernels. +""" + +import torch +from torch.library import Library + +from float8_utils import ( + float32_to_float8, + float8_to_float32, + E4M3, + E5M2, + tensor_to_scale, +) + + +# TODO clean up var names +def mm_float8(m1, s1, flavor1, m2, s2, flavor2, sout, flavorout): + # naive implementation: dq -> op -> q + # TODO(future): hook up to real kernel + full_m1 = float8_to_float32(m1, flavor1) / s1 + full_m2 = float8_to_float32(m2, flavor2) / s2 + full_out = torch.mm(full_m1, full_m2) + # TODO(future): switch to delayed scaling + sout.fill_(tensor_to_scale(full_out, flavorout)) + out = full_out * sout + return float32_to_float8(out, flavorout) + +def add_float8_e5m2(m1, s1, m2, s2, s3): + # for now this is only implemented for e5m2 because we only care about + # this for adding gradients + # naive implementation: dq -> op -> q + # TODO(future): hook up to real kernel + m1_float32 = float8_to_float32(m1, E5M2) / s1 + m2_float32 = float8_to_float32(m2, E5M2) / s2 + m3_float32 = m1_float32 + m2_float32 + return float32_to_float8(m3_float32 * s3, E5M2) + +# +# ATen op placeholders +# + +# Register the aten level functions we need. +# These are mostly placeholder and might need to be implemented in c++ as needed +lib = Library("aten", "FRAGMENT") + +# Define our new custom functions +# Since all my Tensors are on CPU, I register everything there. +lib.define("float32_to_float8(Tensor t, int flavor) -> Tensor") +lib.impl("float32_to_float8", float32_to_float8, "CPU") + +lib.define("float8_to_float32(Tensor t, int flavor) -> Tensor") +lib.impl("float8_to_float32", float8_to_float32, "CPU") + +lib.define("mm_float8(Tensor m1, Tensor s1, int flavor1, Tensor m2, Tensor s2, int flavor2, Tensor sout, int flavorout) -> Tensor") +lib.impl("mm_float8", mm_float8, "CPU") + +lib.define("add_float8_e5m2(Tensor m1, Tensor s1, Tensor m2, Tensor s2, Tensor s3) -> Tensor") +lib.impl("add_float8_e5m2", add_float8_e5m2, "CPU") diff --git a/protoquant/float8/float8_linear.py b/protoquant/float8/float8_linear.py new file mode 100644 index 0000000..1ce39ef --- /dev/null +++ b/protoquant/float8/float8_linear.py @@ -0,0 +1,122 @@ +""" +A simple manual UEX for a float8 version of `torch.nn.Linear`. + +Note: this UEX is not intended for real usage. It merely demonstrates +an example of how features such as casting to and from float8 as well +as stateful scaling can be implemented. For now, we expect framework +owners to implement their own UEX. +""" + +import torch + +import float8_aten_api + +from float8_utils import E4M3, E5M2, tensor_to_scale +from float8_tensor import Float8Tensor + +class float8_linear_no_bias(torch.autograd.Function): + """ + Like F.linear, but with X, W, and Y in float8 + TODO(future) add logic for bias + """ + + @staticmethod + def forward( + ctx, + x_fp8, + w_t_fp8, + fp8_s_out, + fp8_s_dL_dX, + fp8_s_dL_dW, + fp8_s_dL_dY, + ): + ctx.save_for_backward(x_fp8, w_t_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY) + + res_bits = torch.ops.aten.mm_float8( + x_fp8._data, x_fp8._scale, x_fp8._flavor, + w_t_fp8._data.t(), w_t_fp8._scale, w_t_fp8._flavor, + fp8_s_out, E4M3) + + res = Float8Tensor(res_bits, fp8_s_out, E4M3) + # scale update would also happen here, for now no-op + return res + + @staticmethod + def backward(ctx, go): + x_fp8, w_t_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY = \ + ctx.saved_tensors + + if not isinstance(go, Float8Tensor): + # TODO(future): switch to delayed scaling + fp8_s_dL_dY.fill_(tensor_to_scale(go, E5M2)) + go_fp8 = Float8Tensor( + torch.ops.aten.float32_to_float8(go * fp8_s_dL_dY, E5M2), + fp8_s_dL_dY, + E5M2) + else: + go_fp8 = go + + dL_dX_bits = torch.ops.aten.mm_float8( + go_fp8._data, go_fp8._scale, go_fp8._flavor, + w_t_fp8._data, w_t_fp8._scale, w_t_fp8._flavor, + fp8_s_dL_dX, E5M2) + dL_dX_fp8 = Float8Tensor(dL_dX_bits, fp8_s_dL_dX, E5M2) + + dL_dW_bits = torch.ops.aten.mm_float8( + x_fp8._data.t(), x_fp8._scale, x_fp8._flavor, + go_fp8._data, go_fp8._scale, go_fp8._flavor, + fp8_s_dL_dW, E5M2).t() + dL_dW_fp8 = Float8Tensor(dL_dW_bits, fp8_s_dL_dW, E5M2) + + # scale update would also happen here, for now no-op + return dL_dX_fp8, dL_dW_fp8, None, None, None, None + + +class Float8Linear(torch.nn.Linear): + """ + A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks + scales in way friendly to delayed scaling. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # While this module currently implements just-in-time scaling, + # the scales are stored in buffers as a placeholder for delayed + # scaling such as the mechanism described in + # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8, + # or PTQ calibration. + self.register_buffer('fp8_s_in', torch.tensor(1.0)) + self.register_buffer('fp8_s_weight', torch.tensor(1.0)) + self.register_buffer('fp8_s_out', torch.tensor(1.0)) + self.register_buffer('fp8_s_dL_dX', torch.tensor(1.0)) + self.register_buffer('fp8_s_dL_dW', torch.tensor(1.0)) + self.register_buffer('fp8_s_dL_dY', torch.tensor(1.0)) + + def forward(self, x): + if not isinstance(x, Float8Tensor): + # TODO(future): switch to delayed scaling + self.fp8_s_in.fill_(tensor_to_scale(x, E4M3)) + x_fp8 = Float8Tensor.from_float32(x, self.fp8_s_in, E4M3) + else: + x_fp8 = x + + # TODO(future): switch to delayed scaling + self.fp8_s_weight.fill_(tensor_to_scale(self.weight, E4M3)) + w_t_fp8 = Float8Tensor.from_float32(self.weight, self.fp8_s_weight, E4M3) + + y_fp8 = float8_linear_no_bias.apply( + x_fp8, w_t_fp8, self.fp8_s_out, self.fp8_s_dL_dX, + self.fp8_s_dL_dW, self.fp8_s_dL_dY) + + # For now, hardcode returning Float8Tensor (propagate as much as we can). + # This can be changed to return a different dtype, if needed. + return y_fp8 + + @classmethod + def from_float(cls, mod): + """ + Create an nn.Linear with fp8 compute from a regular nn.Linear + """ + assert mod.bias is None, 'bias support not implemented yet' + new_mod = cls(mod.in_features, mod.out_features, bias=False) + new_mod.weight = mod.weight + return new_mod diff --git a/protoquant/float8/float8_tensor.py b/protoquant/float8/float8_tensor.py new file mode 100644 index 0000000..4c39f99 --- /dev/null +++ b/protoquant/float8/float8_tensor.py @@ -0,0 +1,141 @@ +from enum import Enum +import torch +from torch.utils._pytree import tree_map + +from float8_utils import E4M3, E5M2 + +aten = torch.ops.aten + +class Float8ConstrFunc(torch.autograd.Function): + """ + A differentiable conversion between fp32 and fp8 + TODO(future): split into two for cleaner code + """ + @staticmethod + def forward(ctx, tensor, scale: float=None, flavor=E4M3): + if isinstance(tensor, Float8Tensor): + ctx.inp_is_float8 = True + return torch.ops.aten.float8_to_float32(tensor._data, tensor._flavor) / tensor._scale + else: + ctx.inp_is_float8 = False + tensor_scaled = tensor * scale + bits_fp8 = torch.ops.aten.float32_to_float8(tensor_scaled, flavor) + return Float8Tensor(bits_fp8, scale, flavor) + + @staticmethod + def backward(ctx, g): + # Assume that we always want to scale the gradients + # back to full precision. We could do something else + if isinstance(g, Float8Tensor) and not ctx.inp_is_float8: + return g.to_float32(), None, None + elif ctx.inp_is_float8: + return Float8Tensor.from_float32(g), None, None + else: + return g, None, None + + +class Float8Tensor(torch.Tensor): + """ + A Python-only FP8 tensor. Contains: + * `_data`: the underlying e4m3 or e5m2 data + * `_scale`: the scale used to scale the original fp32 tensor. We multiply + by scale to go from fp32 range to fp8 range, and divide by scale to go + from fp8 range to fp32 range. + * `_flavor`: either E4M3 or E5M2 + + The current purpose of this object is 99% to bundle raw data + fp8 metadata + together for easy passing through PyTorch systems, and 1% to implement + gradient addition (since that has to happen outside of user code). + + The addition operation is defined inline and uses a naive + version of stateless scaling. This allows e5m2 gradients to be added. + TODO(future): verify this is numericaly accurate, optionally replace + with something better. + + It would probably make sense to also define fp8 path for data shuffling + ops like cat, transpose, view, etc inline so we don't have to fall back + to fp32 for them. + """ + + def __new__(cls, data, scale, flavor): + # This is a non-differentiable constructor! + assert not data.requires_grad + # TODO(future): make bits8 easier to work with and switch to using it + # assert data.dtype == torch.bits8 + assert scale.dtype == torch.float32 + assert scale.nelement() == 1 + + self = torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=torch.float32, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + self._data = data + self._scale = scale + self._flavor = flavor + + return self + + def __repr__(self): + return f"Float8Tensor(flavor={self._flavor}, scale={self._scale}, as_float32={self.to_float32()}" + + def to_float32(self): + return Float8ConstrFunc.apply(self) + + @classmethod + def from_float32(cls, tensor, scale, flavor): + return Float8ConstrFunc.apply(tensor, scale, flavor) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + # Note: unlike many other subclasses, this subclass's only propagates + # itself for addition (for gradient addition in backward). For all + # other ops, it self-converts to fp32. The user/framework is + # assumed to take care of defining where fp8 operations occur in the + # forward pass and how scaling is calculated. In this example, that is + # done by the `FP8Linear` class. + # Vasiliy: the main reason I went with ^ is because NVIDIA is + # doing stateful delayed scaling, and I don't know of a safe + # way to enable that without either full program capture or punting it + # to the user. This prototype takes the "punt it to the user" approach. + # IMO for now let's just write out the scale stuff manually so we can + # focus on other things, and revisit later if needed. + + # override addition so we can add e5m2 gradients + if ( + func is aten.add.Tensor + and isinstance(args[0], Float8Tensor) + and isinstance(args[1], Float8Tensor) + ): + x1_fp8, x2_fp8 = args[0], args[1] + assert x1_fp8._flavor == E5M2 and x2_fp8._flavor == E5M2 + # naive scale calculation: max of incoming two scales + x3_scale = torch.max(x1_fp8._scale, x2_fp8._scale) + res_bits = torch.ops.aten.add_float8_e5m2( + x1_fp8._data, x1_fp8._scale, + x2_fp8._data, x2_fp8._scale, + x3_scale) + res = Float8Tensor(res_bits, x3_scale, x1_fp8._flavor) + return res + + # for all other ops, fall back to fp32 + # TODO(future): add support for fp16/bf16 + + def maybe_unwrap(t): + if isinstance(t, Float8Tensor): + return t.to_float32() + return t + + args = tree_map(maybe_unwrap, args) + if kwargs is not None: + kwargs = tree_map(maybe_unwrap, kwargs) + out = super().__torch_dispatch__(func, types, args, kwargs) + return out + + # Do not force the Float8Tensor type on the returned tensor + __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/protoquant/float8/float8_utils.py b/protoquant/float8/float8_utils.py new file mode 100644 index 0000000..b9bf9ee --- /dev/null +++ b/protoquant/float8/float8_utils.py @@ -0,0 +1,178 @@ +import torch + +# This file reproduces the emulated fp8 <-> fp32 casts from +# https://github.com/pytorch/FBGEMM/pull/974/files with plain PyTorch. +# This implements the fp8 format spec from https://arxiv.org/pdf/2209.05433.pdf +# +# TODO(future PR): hook up to NVIDIA's casts on gpu, and +# import the fbgemm emulator directly on cpu. We'll also need to ensure +# the two are aligned. +# +# Helpful visualizer for debugging (only supports fp32): +# https://www.h-schmidt.net/FloatConverter/IEEE754.html + +# define the e4m3/e5m2 constants +E4M3_EBITS = 4 +E4M3_EXP_BIAS = 7 +E4M3_MAX_POS = 448.0 + +E5M2_EBITS = 5 +E5M2_EXP_BIAS = 15 +E5M2_MAX_POS = 57344.0 + +# avoid division by zero when calculating scale +# TODO: align this value with NVIDIA's assumptions (current value is a guess) +EPS = 1e-12 + +# enum without using an enum, for brevity +# TODO(future): make this into an enum if needed +E4M3 = 0 +E5M2 = 1 + + +def _float_to_hfp8( + val_fp: torch.Tensor, # fp32 values + ebits: int, # exponent bits, mbits = 8 - ebits + exponent_bias: int, # exponent bias to use in the fp8 encoding + max_pos: float, # maximum positive number for fp8 encoding +): + mbits = 7 - ebits + + val_out = val_fp.clone().detach() + + # S{1}, E{8}, M{23} + sign_bit = (val_out.view(torch.int32) & 0x80000000).view(torch.float) + + # set sign bit to 0 + # 0{1}, E{8}, M{23} + val_out = (val_out.view(torch.int32) & 0x7FFFFFFF).view(torch.float) + + # ensure abs(val_out) <= max_pos) + val_out = torch.clamp(val_out, max=max_pos) + + smallest_normal = torch.zeros_like(val_fp).to(torch.int32) + smallest_normal = ((smallest_normal + 127 - exponent_bias + 1) << 23).view(torch.float) + + # normal and denormal paths split below, record + # which element need which path + is_normal_mask = torch.ge(val_out, smallest_normal) + + # + # normal path + # + + # Use round to nearest even. We make use of the standard rounding mechanism + # in FP32 rather than rounding the mantissa and handling tie-to-even and + # incrementing exponent We want to round of 23-mbits of the FP32 value + # val_in This can be done by adding a power of 2 exactly 23-mbits larger + # than the exponent of val_in This forces val_in to be moved to the right + # and rounding exact at the location corresponding to having mbits of + # explicit mantissa left + n_bouncer = ((val_out.view(torch.int32) & 0xFF800000) + ((23 - mbits) << 23)).view(torch.float) + n_val_out = (n_bouncer + val_out) - n_bouncer + + # adding the bouncer rounds off bits, and subtracting bouncer + # leaves the desired value, albeit in FP32 encoding + # All we need is to change the exponent encoding to using "bias" + n_val_out_i = (n_val_out.view(torch.int32) - ((127 - exponent_bias) << 23)) << (8 - ebits) + n_val_out_i = (n_val_out_i | sign_bit.view(torch.int32)) >> 24 + n_val_out = n_val_out_i.view(torch.float) + + # + # denormal path + # + + # When the value is in the denormal range, IEEE numbers essentially becomes + # a fixed point number. The lsb is the smallest non-zero number + # 2^(1-bias-mbits) Hence, we define the bouncer so that its lsb is this + # smallest non-zero number Adding the input to this bouncer forces rounding + # to occur appropriately Also, in this situation, after adding the bouncer, + # the 8 least significant bits of the sum is already the HFP8 encoding of + # the desired result. Just need to restore the sign bit + # bouncer.I = (127 + (23 + (1 - exponent_bias - mbits))) << 23; + # val_out.F = bouncer.F + val_out.F; + # val_out.I = val_out.I | (sign_bit >> 24); + + dn_bouncer = ((torch.zeros_like(val_out).view(torch.int32) + 127 + (23 + (1 - exponent_bias - mbits))) << 23).view(torch.float) + dn_val_out = dn_bouncer + val_out + dn_val_out_i = dn_val_out.view(torch.int32) | (sign_bit.view(torch.int32) >> 24) + dn_val_out = dn_val_out_i.view(torch.float) + + # + # combine normal and denormal paths + # + val_out = torch.where(is_normal_mask, n_val_out, dn_val_out) + # take the 8 least significant bits + orig_shape = val_fp.shape + val_out = val_out.view(torch.uint8) + val_out = val_out.reshape(-1, 4) + val_out = torch.tensor_split(val_out, 4, dim=-1)[0] + val_out = val_out.reshape(orig_shape) + return val_out + + +def _hfp8_to_float( + hfp8_val: torch.Tensor, + ebits: int, + exponent_bias: int, +): + assert hfp8_val.dtype == torch.uint8 + + sign_i = (hfp8_val & 0x80).to(torch.int32) << 24 + + val_out_i = (hfp8_val & 0x7F).to(torch.int32) << (24 - (8 - ebits)) + # so that the mantissa bits start at the mantissa bit positions of FP32 + # encoding + + # Let the hfp8 mantissa bits correspond to the value frac, 0 <= frac < 1 + # So if the hfp8 value is a normal number, it's value is 2^e x (1+frac) + # where e is its (true, unbiased) exponent + # If the hfp8 value is denormal, the value is 2^(1-bias) x frac + + # However, the bit pattern in the 8-bit exponent field of val_out.F + # is bias+e when hfp8 is normal, and 0 when hfp8 is subnormal. + # So, as an FP32 value, when hfp8 is normal, val_out.F represents the value + # of 2^(bias+e-127) * (1+frac) + # And when hfp8 is subnormal, val_out.F is also subnormal, and represents the + # value of 2^(-126) * frac In either case, val_out.F corresponds to + # 2^(bias-127) * (value of hfp8 input) Thus, if we multiply val_out.F by + # 2^(127-bias), we obtain the hfp8 value as an FP32 number + + # multiplier.I = (127 + (127 - exponent_bias)) + # << 23; // multiplier.F is 2^(127-bias) + # val_out.F *= multiplier.F; + # val_out.I |= sign.I; + # return val_out.F; + + multiplier_i = (torch.zeros_like(hfp8_val).to(torch.int32) + 127 + (127 - exponent_bias)) << 23 # multiplier_f is 2^(127-bias) + val_out_f = val_out_i.view(torch.float) + val_out_f *= multiplier_i.view(torch.float) + val_out_f = (val_out_f.view(torch.int32) | sign_i).view(torch.float) + return val_out_f + +def float32_to_float8(x, flavor): + if flavor == E4M3: + return _float_to_hfp8(x, E4M3_EBITS, E4M3_EXP_BIAS, E4M3_MAX_POS) + else: # e5m2 + return _float_to_hfp8(x, E5M2_EBITS, E5M2_EXP_BIAS, E5M2_MAX_POS) + +def float8_to_float32(x, flavor): + if flavor == E4M3: + return _hfp8_to_float(x, E4M3_EBITS, E4M3_EXP_BIAS) + else: # e5m2 + return _hfp8_to_float(x, E5M2_EBITS, E5M2_EXP_BIAS) + +def amax_to_scale(amax, flavor): + if flavor == E4M3: + return E4M3_MAX_POS / torch.clamp(amax, min=EPS) + else: # e5m2 + return E5M2_MAX_POS / torch.clamp(amax, min=EPS) + +def tensor_to_scale(x, flavor): + amax = torch.max(torch.abs(x)) + return amax_to_scale(amax, flavor) + +def compute_error(x, y): + Ps = torch.norm(x) + Pn = torch.norm(x - y) + return 20 * torch.log10(Ps / Pn) diff --git a/protoquant/float8/test.py b/protoquant/float8/test.py new file mode 100644 index 0000000..162e7c3 --- /dev/null +++ b/protoquant/float8/test.py @@ -0,0 +1,177 @@ +import copy +import random +import unittest + +import torch +import torch.nn as nn + +from float8_utils import ( + float32_to_float8, + float8_to_float32, + E4M3, + E5M2, + compute_error, +) + +from float8_linear import Float8Linear + +random.seed(0) +torch.manual_seed(0) + +class Fp8CastsUnitTest(unittest.TestCase): + """ + Test the casts between fp32 and fp8 (e4m3 and e5m2) + """ + + def _compare_many_exact(self, flavor, x_fp32, comp_name): + x_fp8 = float32_to_float8(x_fp32, flavor) + x_fp8_fp32 = float8_to_float32(x_fp8, flavor) + torch.testing.assert_close(x_fp32, x_fp8_fp32) + + def _compare_many_approx(self, flavor, x_fp32, comp_name): + if flavor == E4M3: + sqnr_target = 25.0 + else: # e5m2 + sqnr_target = 23.0 + + x_fp8 = float32_to_float8(x_fp32, flavor) + x_fp8_fp32 = float8_to_float32(x_fp8, flavor) + + # sign should always be the same + torch.testing.assert_close( + torch.sign(x_fp32), + torch.sign(x_fp8_fp32), + atol=0, rtol=0) + + # for now just measure that sqnr is somewhat reasonable + # TODO(future): make this significantly more robust, this is about + # 2/10 on the scale of "robust enough" + sqnr = compute_error(x_fp32, x_fp8_fp32) + assert sqnr >= sqnr_target + + + def _compare_one(self, flavor, bits_str, expected_fp32, comp_name): + fp8_bits_ref = torch.tensor([int(bits_str, 2)], dtype=torch.uint8) + + fp32_tensor = torch.tensor([expected_fp32], dtype=torch.float) + fp8_bits = float32_to_float8(fp32_tensor, flavor) + torch.testing.assert_close(fp8_bits, fp8_bits_ref, atol=0, rtol=0) + + fp32_from_fp8_tensor = float8_to_float32(fp8_bits, flavor) + torch.testing.assert_close(fp32_tensor, fp32_from_fp8_tensor, atol=0, rtol=0) + + def test_e4m3_numerics_single(self): + # ensure that our format matches https://arxiv.org/pdf/2209.05433.pdf, Table 1 + + flavor = E4M3 + # e4m3 does not support infinity + self._compare_one(flavor, "00000000", 0.0, "zero") + self._compare_one(flavor, "10000000", -0.0, "neg_zero") + self._compare_one(flavor, "01111110", 448.0, "max_normal") + self._compare_one(flavor, "11111110", -448.0, "neg_max_normal") + self._compare_one(flavor, "00001000", 2 ** -6, "min_normal") + self._compare_one(flavor, "10001000", -1 * (2 ** -6), "neg_min_normal") + self._compare_one(flavor, "00000111", 0.875 * (2 ** -6), "max_subnorm") + self._compare_one(flavor, "10000111", -0.875 * (2 ** -6), "neg_max_subnorm") + self._compare_one(flavor, "00000001", 2 ** -9, "min_subnorm") + self._compare_one(flavor, "10000001", -1 * (2 ** -9), "neg_min_subnorm") + + def test_e5m2_numerics_single(self): + flavor = E5M2 + # e5m2 infinity (below) is off by one, TODO(future PR) debug or just move + # to NVIDIA's intrinsic casts + # _compare_one(flavor, "01111100", float("inf"), "inf") + # _compare_one(flavor, "11111100", -1 * float("inf"), "neg_inf") + self._compare_one(flavor, "00000000", 0.0, "zero") + self._compare_one(flavor, "10000000", -0.0, "neg_zero") + self._compare_one(flavor, "01111011", 57344.0, "max_normal") + self._compare_one(flavor, "11111011", -57344.0, "neg_max_normal") + self._compare_one(flavor, "00000100", 2 ** -14, "min_normal") + self._compare_one(flavor, "10000100", -1 * (2 ** -14), "neg_min_normal") + self._compare_one(flavor, "00000011", 0.75 * (2 ** -14), "max_subnorm") + self._compare_one(flavor, "10000011", -0.75 * (2 ** -14), "neg_max_subnorm") + self._compare_one(flavor, "00000001", 2 ** -16, "min_subnorm") + self._compare_one(flavor, "10000001", -1 * (2 ** -16), "neg_min_subnorm") + + def test_e4m3_numerics_multiple(self): + # test special cases + x = torch.tensor([ + 0.0, + -0.0, + 448.0, + -448.0, + 2 ** -6, + -1 * (2 ** 6), + 0.875 * (2 ** 6), + -0.875 * (2 ** 6), + 2 ** -9, + -1 * (2 ** -9), + ]) + self._compare_many_exact(E4M3, x, 'special_cases') + + # test normal values + shapes + for _ in range(10): + x = torch.randn(1, 2, 3, 4) * random.uniform(0.1, 300.0) + x.clamp_(min=-448.0, max=448.0) + self._compare_many_approx(E4M3, x, 'normal_cases') + + def test_e5m2_numerics_multiple(self): + # test special cases + x = torch.tensor([ + 0.0, + -0.0, + 57344.0, + -57344.0, + 2 ** -14, + -1 * (2 ** -14), + 0.75 * (2 ** -14), + -0.75 * (2 ** -14), + 2 ** -16, + -1 * (2 ** -16), + ]) + self._compare_many_exact(E5M2, x, 'special_cases') + + # test normal values + shapes + for _ in range(10): + x = torch.randn(1, 2, 3, 4) * random.uniform(0.1, 30000.0) + x.clamp_(min=-57344.0, max=57344.0) + self._compare_many_approx(E5M2, x, 'normal_cases') + +class Float8LinearUnitTest(unittest.TestCase): + + def test_e2e(self): + m_ref = nn.Linear(4, 4, bias=False) + m_fp8 = Float8Linear.from_float(copy.deepcopy(m_ref)) + + x = torch.randn(4, 4) + + y_fp8 = m_fp8(x) + y_fp8.sum().backward() + y_ref = m_ref(x) + y_ref.sum().backward() + + y_sqnr = compute_error(y_ref, y_fp8) + g_sqnr = compute_error(m_ref.weight.grad, m_fp8.weight.grad) + + # verify sqnr is reasonable + print('y_sqnr', y_sqnr, 'g_sqnr', g_sqnr) + self.assertTrue(y_sqnr >= 27.0) + self.assertTrue(g_sqnr >= 27.0) + + # verify all of the scales got updated + for buffer_name in ( + 'fp8_s_in', + 'fp8_s_weight', + 'fp8_s_out', + 'fp8_s_dL_dX', + 'fp8_s_dL_dW', + 'fp8_s_dL_dY', + ): + buffer_value = getattr(m_fp8, buffer_name) + self.assertTrue( + torch.ne(buffer_value, torch.tensor(1.0)), + f"{buffer_name} not filled") + + +if __name__ == '__main__': + unittest.main()