-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[wip] Python-only float8 data type + bare bones UEX
Summary: Note: this is WIP and does not represent PyTorch's opinion on how we will integrate float8. At this point, this is a prototype to get some light feedback. TODOs that need to be implemented before review of this PR: * [done] emulated float8 casts to and from float32 * [not done] move https://github.com/albanD/subclass_zoo/blob/fp8_v2/fp8_subclass_v2.py here and hook up to accurate casts * [not done] basic dynamic scaling (with the assumption that the real UEX can switch to delayed scaling) * [not done] numerical testing that this all works for fw + bw with a toy model TODOs to be done in next couple of PRs before sending to NVIDIA: * [not done] example of how an integration with distributed would work What is out of scope for this POC a. hooking up to real float8 ops (saved for later, just needs someone to do it) b. real UEX (saved for later and will need a lot of design discussion) Test plan: ``` python protoquant/float8/test.py ```
- Loading branch information
Showing
6 changed files
with
738 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from float8_utils import E4M3, E5M2 | ||
from float8_tensor import Float8Tensor | ||
from float8_linear import Float8Linear | ||
|
||
aten = torch.ops.aten | ||
|
||
torch.manual_seed(0) | ||
|
||
# TODO(before land): move this to unit tests (in progress) | ||
if __name__ == "__main__": | ||
|
||
# test addition | ||
print("\nExample of addition\n") | ||
x1_fp32, x1_s = torch.randn(4), torch.tensor(1.0) | ||
x2_fp32, x2_s = torch.randn(4), torch.tensor(1.0) | ||
x1_fp8 = Float8Tensor.from_float32(x1_fp32, x1_s, E5M2) | ||
x2_fp8 = Float8Tensor.from_float32(x2_fp32, x2_s, E5M2) | ||
x3_fp8 = x1_fp8 + x2_fp8 | ||
print('x1', x1_fp8, '\nx2', x2_fp8, '\nx1+x2', x3_fp8) | ||
|
||
|
||
print("\nExample of fp8 linear fw + bw\n") | ||
|
||
class M(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.fc1 = nn.Linear(2, 3, bias=False) | ||
self.fc2 = nn.Linear(3, 4, bias=False) | ||
self.fc3 = nn.Linear(3, 4, bias=False) | ||
self.fc4 = nn.Linear(4, 2, bias=False) | ||
|
||
def forward(self, x0): | ||
x1 = self.fc1(x0) | ||
x2 = self.fc2(x1) | ||
x3 = self.fc3(x1) | ||
# test gradient addition | ||
# Note: cat happens in fp32, for now | ||
c = torch.cat([x2, x3]) | ||
x4 = self.fc4(c) | ||
return x4 | ||
|
||
m = M() | ||
m.fc1 = Float8Linear.from_float(m.fc1) | ||
m.fc2 = Float8Linear.from_float(m.fc2) | ||
m.fc3 = Float8Linear.from_float(m.fc3) | ||
m.fc4 = Float8Linear.from_float(m.fc4) | ||
|
||
print(m) | ||
|
||
x = Float8Tensor.from_float32(torch.randn(1, 2), torch.tensor(1.0), E4M3) | ||
y = m(x) | ||
print(y) | ||
s = y.sum() | ||
print('before grad', m.fc1.weight.grad, m.fc2.weight.grad, m.fc3.weight.grad, m.fc4.weight.grad) | ||
s.backward() | ||
print('after grad', m.fc1.weight.grad, m.fc2.weight.grad, m.fc3.weight.grad, m.fc4.weight.grad) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
""" | ||
This file defines the aten functions for float8. Today, all of these functions | ||
are emulated. In the future, they should be calling NVIDIA's float8 kernels. | ||
""" | ||
|
||
import torch | ||
from torch.library import Library | ||
|
||
from float8_utils import ( | ||
float32_to_float8, | ||
float8_to_float32, | ||
E4M3, | ||
E5M2, | ||
tensor_to_scale, | ||
) | ||
|
||
|
||
# TODO clean up var names | ||
def mm_float8(m1, s1, flavor1, m2, s2, flavor2, sout, flavorout): | ||
# naive implementation: dq -> op -> q | ||
# TODO(future): hook up to real kernel | ||
full_m1 = float8_to_float32(m1, flavor1) / s1 | ||
full_m2 = float8_to_float32(m2, flavor2) / s2 | ||
full_out = torch.mm(full_m1, full_m2) | ||
# TODO(future): switch to delayed scaling | ||
sout.fill_(tensor_to_scale(full_out, flavorout)) | ||
out = full_out * sout | ||
return float32_to_float8(out, flavorout) | ||
|
||
def add_float8_e5m2(m1, s1, m2, s2, s3): | ||
# for now this is only implemented for e5m2 because we only care about | ||
# this for adding gradients | ||
# naive implementation: dq -> op -> q | ||
# TODO(future): hook up to real kernel | ||
m1_float32 = float8_to_float32(m1, E5M2) / s1 | ||
m2_float32 = float8_to_float32(m2, E5M2) / s2 | ||
m3_float32 = m1_float32 + m2_float32 | ||
return float32_to_float8(m3_float32 * s3, E5M2) | ||
|
||
# | ||
# ATen op placeholders | ||
# | ||
|
||
# Register the aten level functions we need. | ||
# These are mostly placeholder and might need to be implemented in c++ as needed | ||
lib = Library("aten", "FRAGMENT") | ||
|
||
# Define our new custom functions | ||
# Since all my Tensors are on CPU, I register everything there. | ||
lib.define("float32_to_float8(Tensor t, int flavor) -> Tensor") | ||
lib.impl("float32_to_float8", float32_to_float8, "CPU") | ||
|
||
lib.define("float8_to_float32(Tensor t, int flavor) -> Tensor") | ||
lib.impl("float8_to_float32", float8_to_float32, "CPU") | ||
|
||
lib.define("mm_float8(Tensor m1, Tensor s1, int flavor1, Tensor m2, Tensor s2, int flavor2, Tensor sout, int flavorout) -> Tensor") | ||
lib.impl("mm_float8", mm_float8, "CPU") | ||
|
||
lib.define("add_float8_e5m2(Tensor m1, Tensor s1, Tensor m2, Tensor s2, Tensor s3) -> Tensor") | ||
lib.impl("add_float8_e5m2", add_float8_e5m2, "CPU") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
""" | ||
A simple manual UEX for a float8 version of `torch.nn.Linear`. | ||
Note: this UEX is not intended for real usage. It merely demonstrates | ||
an example of how features such as casting to and from float8 as well | ||
as stateful scaling can be implemented. For now, we expect framework | ||
owners to implement their own UEX. | ||
""" | ||
|
||
import torch | ||
|
||
import float8_aten_api | ||
|
||
from float8_utils import E4M3, E5M2, tensor_to_scale | ||
from float8_tensor import Float8Tensor | ||
|
||
class float8_linear_no_bias(torch.autograd.Function): | ||
""" | ||
Like F.linear, but with X, W, and Y in float8 | ||
TODO(future) add logic for bias | ||
""" | ||
|
||
@staticmethod | ||
def forward( | ||
ctx, | ||
x_fp8, | ||
w_t_fp8, | ||
fp8_s_out, | ||
fp8_s_dL_dX, | ||
fp8_s_dL_dW, | ||
fp8_s_dL_dY, | ||
): | ||
ctx.save_for_backward(x_fp8, w_t_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY) | ||
|
||
res_bits = torch.ops.aten.mm_float8( | ||
x_fp8._data, x_fp8._scale, x_fp8._flavor, | ||
w_t_fp8._data.t(), w_t_fp8._scale, w_t_fp8._flavor, | ||
fp8_s_out, E4M3) | ||
|
||
res = Float8Tensor(res_bits, fp8_s_out, E4M3) | ||
# scale update would also happen here, for now no-op | ||
return res | ||
|
||
@staticmethod | ||
def backward(ctx, go): | ||
x_fp8, w_t_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY = \ | ||
ctx.saved_tensors | ||
|
||
if not isinstance(go, Float8Tensor): | ||
# TODO(future): switch to delayed scaling | ||
fp8_s_dL_dY.fill_(tensor_to_scale(go, E5M2)) | ||
go_fp8 = Float8Tensor( | ||
torch.ops.aten.float32_to_float8(go * fp8_s_dL_dY, E5M2), | ||
fp8_s_dL_dY, | ||
E5M2) | ||
else: | ||
go_fp8 = go | ||
|
||
dL_dX_bits = torch.ops.aten.mm_float8( | ||
go_fp8._data, go_fp8._scale, go_fp8._flavor, | ||
w_t_fp8._data, w_t_fp8._scale, w_t_fp8._flavor, | ||
fp8_s_dL_dX, E5M2) | ||
dL_dX_fp8 = Float8Tensor(dL_dX_bits, fp8_s_dL_dX, E5M2) | ||
|
||
dL_dW_bits = torch.ops.aten.mm_float8( | ||
x_fp8._data.t(), x_fp8._scale, x_fp8._flavor, | ||
go_fp8._data, go_fp8._scale, go_fp8._flavor, | ||
fp8_s_dL_dW, E5M2).t() | ||
dL_dW_fp8 = Float8Tensor(dL_dW_bits, fp8_s_dL_dW, E5M2) | ||
|
||
# scale update would also happen here, for now no-op | ||
return dL_dX_fp8, dL_dW_fp8, None, None, None, None | ||
|
||
|
||
class Float8Linear(torch.nn.Linear): | ||
""" | ||
A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks | ||
scales in way friendly to delayed scaling. | ||
""" | ||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
# While this module currently implements just-in-time scaling, | ||
# the scales are stored in buffers as a placeholder for delayed | ||
# scaling such as the mechanism described in | ||
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8, | ||
# or PTQ calibration. | ||
self.register_buffer('fp8_s_in', torch.tensor(1.0)) | ||
self.register_buffer('fp8_s_weight', torch.tensor(1.0)) | ||
self.register_buffer('fp8_s_out', torch.tensor(1.0)) | ||
self.register_buffer('fp8_s_dL_dX', torch.tensor(1.0)) | ||
self.register_buffer('fp8_s_dL_dW', torch.tensor(1.0)) | ||
self.register_buffer('fp8_s_dL_dY', torch.tensor(1.0)) | ||
|
||
def forward(self, x): | ||
if not isinstance(x, Float8Tensor): | ||
# TODO(future): switch to delayed scaling | ||
self.fp8_s_in.fill_(tensor_to_scale(x, E4M3)) | ||
x_fp8 = Float8Tensor.from_float32(x, self.fp8_s_in, E4M3) | ||
else: | ||
x_fp8 = x | ||
|
||
# TODO(future): switch to delayed scaling | ||
self.fp8_s_weight.fill_(tensor_to_scale(self.weight, E4M3)) | ||
w_t_fp8 = Float8Tensor.from_float32(self.weight, self.fp8_s_weight, E4M3) | ||
|
||
y_fp8 = float8_linear_no_bias.apply( | ||
x_fp8, w_t_fp8, self.fp8_s_out, self.fp8_s_dL_dX, | ||
self.fp8_s_dL_dW, self.fp8_s_dL_dY) | ||
|
||
# For now, hardcode returning Float8Tensor (propagate as much as we can). | ||
# This can be changed to return a different dtype, if needed. | ||
return y_fp8 | ||
|
||
@classmethod | ||
def from_float(cls, mod): | ||
""" | ||
Create an nn.Linear with fp8 compute from a regular nn.Linear | ||
""" | ||
assert mod.bias is None, 'bias support not implemented yet' | ||
new_mod = cls(mod.in_features, mod.out_features, bias=False) | ||
new_mod.weight = mod.weight | ||
return new_mod |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
from enum import Enum | ||
import torch | ||
from torch.utils._pytree import tree_map | ||
|
||
from float8_utils import E4M3, E5M2 | ||
|
||
aten = torch.ops.aten | ||
|
||
class Float8ConstrFunc(torch.autograd.Function): | ||
""" | ||
A differentiable conversion between fp32 and fp8 | ||
TODO(future): split into two for cleaner code | ||
""" | ||
@staticmethod | ||
def forward(ctx, tensor, scale: float=None, flavor=E4M3): | ||
if isinstance(tensor, Float8Tensor): | ||
ctx.inp_is_float8 = True | ||
return torch.ops.aten.float8_to_float32(tensor._data, tensor._flavor) / tensor._scale | ||
else: | ||
ctx.inp_is_float8 = False | ||
tensor_scaled = tensor * scale | ||
bits_fp8 = torch.ops.aten.float32_to_float8(tensor_scaled, flavor) | ||
return Float8Tensor(bits_fp8, scale, flavor) | ||
|
||
@staticmethod | ||
def backward(ctx, g): | ||
# Assume that we always want to scale the gradients | ||
# back to full precision. We could do something else | ||
if isinstance(g, Float8Tensor) and not ctx.inp_is_float8: | ||
return g.to_float32(), None, None | ||
elif ctx.inp_is_float8: | ||
return Float8Tensor.from_float32(g), None, None | ||
else: | ||
return g, None, None | ||
|
||
|
||
class Float8Tensor(torch.Tensor): | ||
""" | ||
A Python-only FP8 tensor. Contains: | ||
* `_data`: the underlying e4m3 or e5m2 data | ||
* `_scale`: the scale used to scale the original fp32 tensor. We multiply | ||
by scale to go from fp32 range to fp8 range, and divide by scale to go | ||
from fp8 range to fp32 range. | ||
* `_flavor`: either E4M3 or E5M2 | ||
The current purpose of this object is 99% to bundle raw data + fp8 metadata | ||
together for easy passing through PyTorch systems, and 1% to implement | ||
gradient addition (since that has to happen outside of user code). | ||
The addition operation is defined inline and uses a naive | ||
version of stateless scaling. This allows e5m2 gradients to be added. | ||
TODO(future): verify this is numericaly accurate, optionally replace | ||
with something better. | ||
It would probably make sense to also define fp8 path for data shuffling | ||
ops like cat, transpose, view, etc inline so we don't have to fall back | ||
to fp32 for them. | ||
""" | ||
|
||
def __new__(cls, data, scale, flavor): | ||
# This is a non-differentiable constructor! | ||
assert not data.requires_grad | ||
# TODO(future): make bits8 easier to work with and switch to using it | ||
# assert data.dtype == torch.bits8 | ||
assert scale.dtype == torch.float32 | ||
assert scale.nelement() == 1 | ||
|
||
self = torch.Tensor._make_wrapper_subclass( | ||
cls, | ||
data.size(), | ||
strides=data.stride(), | ||
storage_offset=data.storage_offset(), | ||
dtype=torch.float32, | ||
layout=data.layout, | ||
requires_grad=data.requires_grad, | ||
device=data.device, | ||
) | ||
self._data = data | ||
self._scale = scale | ||
self._flavor = flavor | ||
|
||
return self | ||
|
||
def __repr__(self): | ||
return f"Float8Tensor(flavor={self._flavor}, scale={self._scale}, as_float32={self.to_float32()}" | ||
|
||
def to_float32(self): | ||
return Float8ConstrFunc.apply(self) | ||
|
||
@classmethod | ||
def from_float32(cls, tensor, scale, flavor): | ||
return Float8ConstrFunc.apply(tensor, scale, flavor) | ||
|
||
@classmethod | ||
def __torch_dispatch__(cls, func, types, args, kwargs=None): | ||
# Note: unlike many other subclasses, this subclass's only propagates | ||
# itself for addition (for gradient addition in backward). For all | ||
# other ops, it self-converts to fp32. The user/framework is | ||
# assumed to take care of defining where fp8 operations occur in the | ||
# forward pass and how scaling is calculated. In this example, that is | ||
# done by the `FP8Linear` class. | ||
# Vasiliy: the main reason I went with ^ is because NVIDIA is | ||
# doing stateful delayed scaling, and I don't know of a safe | ||
# way to enable that without either full program capture or punting it | ||
# to the user. This prototype takes the "punt it to the user" approach. | ||
# IMO for now let's just write out the scale stuff manually so we can | ||
# focus on other things, and revisit later if needed. | ||
|
||
# override addition so we can add e5m2 gradients | ||
if ( | ||
func is aten.add.Tensor | ||
and isinstance(args[0], Float8Tensor) | ||
and isinstance(args[1], Float8Tensor) | ||
): | ||
x1_fp8, x2_fp8 = args[0], args[1] | ||
assert x1_fp8._flavor == E5M2 and x2_fp8._flavor == E5M2 | ||
# naive scale calculation: max of incoming two scales | ||
x3_scale = torch.max(x1_fp8._scale, x2_fp8._scale) | ||
res_bits = torch.ops.aten.add_float8_e5m2( | ||
x1_fp8._data, x1_fp8._scale, | ||
x2_fp8._data, x2_fp8._scale, | ||
x3_scale) | ||
res = Float8Tensor(res_bits, x3_scale, x1_fp8._flavor) | ||
return res | ||
|
||
# for all other ops, fall back to fp32 | ||
# TODO(future): add support for fp16/bf16 | ||
|
||
def maybe_unwrap(t): | ||
if isinstance(t, Float8Tensor): | ||
return t.to_float32() | ||
return t | ||
|
||
args = tree_map(maybe_unwrap, args) | ||
if kwargs is not None: | ||
kwargs = tree_map(maybe_unwrap, kwargs) | ||
out = super().__torch_dispatch__(func, types, args, kwargs) | ||
return out | ||
|
||
# Do not force the Float8Tensor type on the returned tensor | ||
__torch_function__ = torch._C._disabled_torch_function_impl |
Oops, something went wrong.