[wip] Python-only float8 data type + bare bones UEX
Summary:

Note: this is WIP and does not represent PyTorch's opinion on how we
will integrate float8. At this point, this is a prototype to get some light
feedback.

TODOs that need to be implemented before review of this PR:
* [done] emulated float8 casts to and from float32
* [not done] move https://github.com/albanD/subclass_zoo/blob/fp8_v2/fp8_subclass_v2.py here and hook up to accurate casts
* [not done] basic dynamic scaling (with the assumption that the real UEX can switch to delayed scaling); see the sketch after the test plan
* [not done] numerical testing that this all works for fw + bw with a toy model

TODOs to be done in next couple of PRs before sending to NVIDIA:
* [not done] example of how an integration with distributed would work

What is out of scope for this POC:
a. hooking up to real float8 ops (saved for later, just needs someone to do it)
b. real UEX (saved for later and will need a lot of design discussion)

Test plan:

```
python protoquant/float8/test.py
```
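
For context on the "basic dynamic scaling" TODO above: dynamic (just-in-time) scaling picks a per-tensor scale from the current tensor's absolute max, while delayed scaling reuses amax statistics from earlier iterations. Below is a minimal sketch of the dynamic case; `float8_utils.py` is not shown in this diff, so the helper name and the fp8 max values here are assumptions, not the actual implementation.

```
import torch

# Assumed max representable magnitudes for the two fp8 flavors
# (e4m3 without inf: 448, e5m2: 57344); the real constants live in
# float8_utils.py, which is not part of this diff.
E4M3_MAX_POS = 448.0
E5M2_MAX_POS = 57344.0

def tensor_to_scale_sketch(x: torch.Tensor, flavor_max_pos: float) -> torch.Tensor:
    # Pick a scale so the largest magnitude in `x` maps to the largest
    # representable fp8 value; multiply by this scale before casting to fp8.
    amax = torch.max(torch.abs(x))
    return torch.tensor(flavor_max_pos) / torch.clamp(amax, min=1e-12)
```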
vkuzo committed May 11, 2023
1 parent 46e75e3 commit 459a2ca
Showing 6 changed files with 738 additions and 0 deletions.
60 changes: 60 additions & 0 deletions protoquant/float8/e2e_example.py
@@ -0,0 +1,60 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from float8_utils import E4M3, E5M2
from float8_tensor import Float8Tensor
from float8_linear import Float8Linear

aten = torch.ops.aten

torch.manual_seed(0)

# TODO(before land): move this to unit tests (in progress)
if __name__ == "__main__":

    # test addition
    print("\nExample of addition\n")
    x1_fp32, x1_s = torch.randn(4), torch.tensor(1.0)
    x2_fp32, x2_s = torch.randn(4), torch.tensor(1.0)
    x1_fp8 = Float8Tensor.from_float32(x1_fp32, x1_s, E5M2)
    x2_fp8 = Float8Tensor.from_float32(x2_fp32, x2_s, E5M2)
    x3_fp8 = x1_fp8 + x2_fp8
    print('x1', x1_fp8, '\nx2', x2_fp8, '\nx1+x2', x3_fp8)


print("\nExample of fp8 linear fw + bw\n")

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(2, 3, bias=False)
            self.fc2 = nn.Linear(3, 4, bias=False)
            self.fc3 = nn.Linear(3, 4, bias=False)
            self.fc4 = nn.Linear(4, 2, bias=False)

        def forward(self, x0):
            x1 = self.fc1(x0)
            x2 = self.fc2(x1)
            x3 = self.fc3(x1)
            # test gradient addition
            # Note: cat happens in fp32, for now
            c = torch.cat([x2, x3])
            x4 = self.fc4(c)
            return x4

    m = M()
    m.fc1 = Float8Linear.from_float(m.fc1)
    m.fc2 = Float8Linear.from_float(m.fc2)
    m.fc3 = Float8Linear.from_float(m.fc3)
    m.fc4 = Float8Linear.from_float(m.fc4)

    print(m)

    x = Float8Tensor.from_float32(torch.randn(1, 2), torch.tensor(1.0), E4M3)
    y = m(x)
    print(y)
    s = y.sum()
    print('before grad', m.fc1.weight.grad, m.fc2.weight.grad, m.fc3.weight.grad, m.fc4.weight.grad)
    s.backward()
    print('after grad', m.fc1.weight.grad, m.fc2.weight.grad, m.fc3.weight.grad, m.fc4.weight.grad)
60 changes: 60 additions & 0 deletions protoquant/float8/float8_aten_api.py
@@ -0,0 +1,60 @@
"""
This file defines the aten functions for float8. Today, all of these functions
are emulated. In the future, they should be calling NVIDIA's float8 kernels.
"""

import torch
from torch.library import Library

from float8_utils import (
    float32_to_float8,
    float8_to_float32,
    E4M3,
    E5M2,
    tensor_to_scale,
)


# TODO clean up var names
def mm_float8(m1, s1, flavor1, m2, s2, flavor2, sout, flavorout):
    # naive implementation: dq -> op -> q
    # TODO(future): hook up to real kernel
    full_m1 = float8_to_float32(m1, flavor1) / s1
    full_m2 = float8_to_float32(m2, flavor2) / s2
    full_out = torch.mm(full_m1, full_m2)
    # TODO(future): switch to delayed scaling
    sout.fill_(tensor_to_scale(full_out, flavorout))
    out = full_out * sout
    return float32_to_float8(out, flavorout)

def add_float8_e5m2(m1, s1, m2, s2, s3):
    # for now this is only implemented for e5m2 because we only care about
    # this for adding gradients
    # naive implementation: dq -> op -> q
    # TODO(future): hook up to real kernel
    m1_float32 = float8_to_float32(m1, E5M2) / s1
    m2_float32 = float8_to_float32(m2, E5M2) / s2
    m3_float32 = m1_float32 + m2_float32
    return float32_to_float8(m3_float32 * s3, E5M2)

#
# ATen op placeholders
#

# Register the aten-level functions we need.
# These are mostly placeholders and may need to be implemented in C++ as needed.
lib = Library("aten", "FRAGMENT")

# Define our new custom functions
# Since all my Tensors are on CPU, I register everything there.
lib.define("float32_to_float8(Tensor t, int flavor) -> Tensor")
lib.impl("float32_to_float8", float32_to_float8, "CPU")

lib.define("float8_to_float32(Tensor t, int flavor) -> Tensor")
lib.impl("float8_to_float32", float8_to_float32, "CPU")

lib.define("mm_float8(Tensor m1, Tensor s1, int flavor1, Tensor m2, Tensor s2, int flavor2, Tensor sout, int flavorout) -> Tensor")
lib.impl("mm_float8", mm_float8, "CPU")

lib.define("add_float8_e5m2(Tensor m1, Tensor s1, Tensor m2, Tensor s2, Tensor s3) -> Tensor")
lib.impl("add_float8_e5m2", add_float8_e5m2, "CPU")
122 changes: 122 additions & 0 deletions protoquant/float8/float8_linear.py
@@ -0,0 +1,122 @@
"""
A simple manual UEX for a float8 version of `torch.nn.Linear`.
Note: this UEX is not intended for real usage. It merely demonstrates
an example of how features such as casting to and from float8 as well
as stateful scaling can be implemented. For now, we expect framework
owners to implement their own UEX.
"""

import torch

import float8_aten_api

from float8_utils import E4M3, E5M2, tensor_to_scale
from float8_tensor import Float8Tensor

class float8_linear_no_bias(torch.autograd.Function):
"""
Like F.linear, but with X, W, and Y in float8
TODO(future) add logic for bias
"""

    @staticmethod
    def forward(
        ctx,
        x_fp8,
        w_t_fp8,
        fp8_s_out,
        fp8_s_dL_dX,
        fp8_s_dL_dW,
        fp8_s_dL_dY,
    ):
        ctx.save_for_backward(x_fp8, w_t_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY)

        res_bits = torch.ops.aten.mm_float8(
            x_fp8._data, x_fp8._scale, x_fp8._flavor,
            w_t_fp8._data.t(), w_t_fp8._scale, w_t_fp8._flavor,
            fp8_s_out, E4M3)

        res = Float8Tensor(res_bits, fp8_s_out, E4M3)
        # scale update would also happen here, for now no-op
        return res

    @staticmethod
    def backward(ctx, go):
        x_fp8, w_t_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY = \
            ctx.saved_tensors

        if not isinstance(go, Float8Tensor):
            # TODO(future): switch to delayed scaling
            fp8_s_dL_dY.fill_(tensor_to_scale(go, E5M2))
            go_fp8 = Float8Tensor(
                torch.ops.aten.float32_to_float8(go * fp8_s_dL_dY, E5M2),
                fp8_s_dL_dY,
                E5M2)
        else:
            go_fp8 = go

        dL_dX_bits = torch.ops.aten.mm_float8(
            go_fp8._data, go_fp8._scale, go_fp8._flavor,
            w_t_fp8._data, w_t_fp8._scale, w_t_fp8._flavor,
            fp8_s_dL_dX, E5M2)
        dL_dX_fp8 = Float8Tensor(dL_dX_bits, fp8_s_dL_dX, E5M2)

        dL_dW_bits = torch.ops.aten.mm_float8(
            x_fp8._data.t(), x_fp8._scale, x_fp8._flavor,
            go_fp8._data, go_fp8._scale, go_fp8._flavor,
            fp8_s_dL_dW, E5M2).t()
        dL_dW_fp8 = Float8Tensor(dL_dW_bits, fp8_s_dL_dW, E5M2)

        # scale update would also happen here, for now no-op
        return dL_dX_fp8, dL_dW_fp8, None, None, None, None


class Float8Linear(torch.nn.Linear):
"""
A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks
scales in way friendly to delayed scaling.
"""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # While this module currently implements just-in-time scaling,
        # the scales are stored in buffers as a placeholder for delayed
        # scaling such as the mechanism described in
        # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8,
        # or PTQ calibration.
        self.register_buffer('fp8_s_in', torch.tensor(1.0))
        self.register_buffer('fp8_s_weight', torch.tensor(1.0))
        self.register_buffer('fp8_s_out', torch.tensor(1.0))
        self.register_buffer('fp8_s_dL_dX', torch.tensor(1.0))
        self.register_buffer('fp8_s_dL_dW', torch.tensor(1.0))
        self.register_buffer('fp8_s_dL_dY', torch.tensor(1.0))

    def forward(self, x):
        if not isinstance(x, Float8Tensor):
            # TODO(future): switch to delayed scaling
            self.fp8_s_in.fill_(tensor_to_scale(x, E4M3))
            x_fp8 = Float8Tensor.from_float32(x, self.fp8_s_in, E4M3)
        else:
            x_fp8 = x

        # TODO(future): switch to delayed scaling
        self.fp8_s_weight.fill_(tensor_to_scale(self.weight, E4M3))
        w_t_fp8 = Float8Tensor.from_float32(self.weight, self.fp8_s_weight, E4M3)

        y_fp8 = float8_linear_no_bias.apply(
            x_fp8, w_t_fp8, self.fp8_s_out, self.fp8_s_dL_dX,
            self.fp8_s_dL_dW, self.fp8_s_dL_dY)

        # For now, hardcode returning Float8Tensor (propagate as much as we can).
        # This can be changed to return a different dtype, if needed.
        return y_fp8

    @classmethod
    def from_float(cls, mod):
        """
        Create an nn.Linear with fp8 compute from a regular nn.Linear
        """
        assert mod.bias is None, 'bias support not implemented yet'
        new_mod = cls(mod.in_features, mod.out_features, bias=False)
        new_mod.weight = mod.weight
        return new_mod
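
For readers unfamiliar with the delayed scaling scheme that the buffers above anticipate (see the Transformer Engine link in the comment), the idea is to derive the current scale from amax values recorded in previous iterations rather than from the current tensor. A rough sketch of that idea follows; it is not part of this PR and all names here are hypothetical.

```
import torch

class DelayedScaleSketch:
    # Hypothetical helper, not part of this PR: keeps a rolling amax history
    # and derives the next scale from it, instead of computing amax
    # just-in-time on every tensor.
    def __init__(self, flavor_max_pos: float, history_len: int = 16):
        self.flavor_max_pos = flavor_max_pos
        self.amax_history = torch.zeros(history_len)

    def update(self, x: torch.Tensor) -> None:
        # record the current iteration's amax for use in later iterations
        self.amax_history = torch.roll(self.amax_history, 1)
        self.amax_history[0] = torch.max(torch.abs(x.detach()))

    def scale(self) -> torch.Tensor:
        # e.g. use the max over the history window
        amax = torch.max(self.amax_history)
        return self.flavor_max_pos / torch.clamp(amax, min=1e-12)
```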
141 changes: 141 additions & 0 deletions protoquant/float8/float8_tensor.py
@@ -0,0 +1,141 @@
from enum import Enum
import torch
from torch.utils._pytree import tree_map

from float8_utils import E4M3, E5M2

aten = torch.ops.aten

class Float8ConstrFunc(torch.autograd.Function):
"""
A differentiable conversion between fp32 and fp8
TODO(future): split into two for cleaner code
"""
@staticmethod
def forward(ctx, tensor, scale: float=None, flavor=E4M3):
if isinstance(tensor, Float8Tensor):
ctx.inp_is_float8 = True
return torch.ops.aten.float8_to_float32(tensor._data, tensor._flavor) / tensor._scale
else:
ctx.inp_is_float8 = False
tensor_scaled = tensor * scale
bits_fp8 = torch.ops.aten.float32_to_float8(tensor_scaled, flavor)
return Float8Tensor(bits_fp8, scale, flavor)

    @staticmethod
    def backward(ctx, g):
        # Assume that we always want to scale the gradients
        # back to full precision. We could do something else
        if isinstance(g, Float8Tensor) and not ctx.inp_is_float8:
            return g.to_float32(), None, None
        elif ctx.inp_is_float8:
            return Float8Tensor.from_float32(g), None, None
        else:
            return g, None, None


class Float8Tensor(torch.Tensor):
"""
A Python-only FP8 tensor. Contains:
* `_data`: the underlying e4m3 or e5m2 data
* `_scale`: the scale used to scale the original fp32 tensor. We multiply
by scale to go from fp32 range to fp8 range, and divide by scale to go
from fp8 range to fp32 range.
* `_flavor`: either E4M3 or E5M2
The current purpose of this object is 99% to bundle raw data + fp8 metadata
together for easy passing through PyTorch systems, and 1% to implement
gradient addition (since that has to happen outside of user code).
The addition operation is defined inline and uses a naive
version of stateless scaling. This allows e5m2 gradients to be added.
TODO(future): verify this is numericaly accurate, optionally replace
with something better.
It would probably make sense to also define fp8 path for data shuffling
ops like cat, transpose, view, etc inline so we don't have to fall back
to fp32 for them.
"""

    def __new__(cls, data, scale, flavor):
        # This is a non-differentiable constructor!
        assert not data.requires_grad
        # TODO(future): make bits8 easier to work with and switch to using it
        # assert data.dtype == torch.bits8
        assert scale.dtype == torch.float32
        assert scale.nelement() == 1

        self = torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            strides=data.stride(),
            storage_offset=data.storage_offset(),
            dtype=torch.float32,
            layout=data.layout,
            requires_grad=data.requires_grad,
            device=data.device,
        )
        self._data = data
        self._scale = scale
        self._flavor = flavor

        return self

    def __repr__(self):
        return f"Float8Tensor(flavor={self._flavor}, scale={self._scale}, as_float32={self.to_float32()})"

    def to_float32(self):
        return Float8ConstrFunc.apply(self)

    @classmethod
    def from_float32(cls, tensor, scale, flavor):
        return Float8ConstrFunc.apply(tensor, scale, flavor)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # Note: unlike many other subclasses, this subclass only propagates
        # itself for addition (for gradient addition in backward). For all
        # other ops, it self-converts to fp32. The user/framework is
        # assumed to take care of defining where fp8 operations occur in the
        # forward pass and how scaling is calculated. In this example, that is
        # done by the `Float8Linear` class.
        # Vasiliy: the main reason I went with ^ is because NVIDIA is
        # doing stateful delayed scaling, and I don't know of a safe
        # way to enable that without either full program capture or punting it
        # to the user. This prototype takes the "punt it to the user" approach.
        # IMO for now let's just write out the scale stuff manually so we can
        # focus on other things, and revisit later if needed.

        # override addition so we can add e5m2 gradients
        if (
            func is aten.add.Tensor
            and isinstance(args[0], Float8Tensor)
            and isinstance(args[1], Float8Tensor)
        ):
            x1_fp8, x2_fp8 = args[0], args[1]
            assert x1_fp8._flavor == E5M2 and x2_fp8._flavor == E5M2
            # naive scale calculation: max of incoming two scales
            x3_scale = torch.max(x1_fp8._scale, x2_fp8._scale)
            res_bits = torch.ops.aten.add_float8_e5m2(
                x1_fp8._data, x1_fp8._scale,
                x2_fp8._data, x2_fp8._scale,
                x3_scale)
            res = Float8Tensor(res_bits, x3_scale, x1_fp8._flavor)
            return res

        # for all other ops, fall back to fp32
        # TODO(future): add support for fp16/bf16

        def maybe_unwrap(t):
            if isinstance(t, Float8Tensor):
                return t.to_float32()
            return t

        args = tree_map(maybe_unwrap, args)
        if kwargs is not None:
            kwargs = tree_map(maybe_unwrap, kwargs)
        out = super().__torch_dispatch__(func, types, args, kwargs)
        return out

    # Do not force the Float8Tensor type on the returned tensor
    __torch_function__ = torch._C._disabled_torch_function_impl
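
A small example of the dispatch behavior described in the comment above, using the modules from this PR (values are illustrative, and it assumes `float8_utils.py` and `float8_aten_api.py` behave as the rest of the diff implies): addition of two e5m2 tensors stays in fp8 via the override, while other ops such as `torch.cat` unwrap to fp32 through the fallback path.

```
import torch

import float8_aten_api  # registers the emulated fp8 ops
from float8_utils import E5M2
from float8_tensor import Float8Tensor

a = Float8Tensor.from_float32(torch.randn(4), torch.tensor(1.0), E5M2)
b = Float8Tensor.from_float32(torch.randn(4), torch.tensor(1.0), E5M2)

print(type(a + b))              # Float8Tensor, via the aten.add.Tensor override
print(type(torch.cat([a, b])))  # plain float32 torch.Tensor, via the fallback
```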