[wip] Python-only float8 data type + bare bones UEX #23

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
protoquant/float8/float8_aten_api.py (new file, 62 additions)
"""
This file defines the aten functions for float8. Today, all of these functions
are emulated. In the future, they should call NVIDIA's float8 kernels.
"""

import torch
from torch.library import Library

from float8_utils import (
float32_to_float8,
float8_to_float32,
E4M3,
E5M2,
tensor_to_scale,
)


def mm_float8(m1, s1, flavor1, m2, s2, flavor2, s3, flavor3):
# naive implementation: dq -> op -> q
# TODO(future): hook up to real kernel
m1_fp32 = float8_to_float32(m1, flavor1) / s1
m2_fp32 = float8_to_float32(m2, flavor2) / s2
m3_fp32 = torch.mm(m1_fp32, m2_fp32)
# TODO(future): switch to delayed scaling
s3.fill_(tensor_to_scale(m3_fp32, flavor3))
m3_fp32_scaled = m3_fp32 * s3
return float32_to_float8(m3_fp32_scaled, flavor3)

def add_float8_e5m2(m1, s1, m2, s2, s3):
# for now this is only implemented for e5m2 because we only care about
# this for adding gradients
# naive implementation: dq -> op -> q
# TODO(future): hook up to real kernel
# TODO(future): make this more accurate, accuracy is pretty low,
# can probably just calculate s3 dynamically since this is an edge case
# unlikely to affect e2e performance
m1_float32 = float8_to_float32(m1, E5M2) / s1
m2_float32 = float8_to_float32(m2, E5M2) / s2
m3_float32 = m1_float32 + m2_float32
return float32_to_float8(m3_float32 * s3, E5M2)

#
# ATen op placeholders
#

# Register the aten-level functions we need.
# These are mostly placeholders and might need to be implemented in C++ as needed.
lib = Library("aten", "FRAGMENT")

# For now, register on CPU only.
# TODO(future): add GPU and test there
lib.define("float32_to_float8(Tensor t, int flavor) -> Tensor")
lib.impl("float32_to_float8", float32_to_float8, "CPU")

lib.define("float8_to_float32(Tensor t, int flavor) -> Tensor")
lib.impl("float8_to_float32", float8_to_float32, "CPU")

lib.define("mm_float8(Tensor m1, Tensor s1, int flavor1, Tensor m2, Tensor s2, int flavor2, Tensor s3, int flavor3) -> Tensor")
lib.impl("mm_float8", mm_float8, "CPU")

lib.define("add_float8_e5m2(Tensor m1, Tensor s1, Tensor m2, Tensor s2, Tensor s3) -> Tensor")
lib.impl("add_float8_e5m2", add_float8_e5m2, "CPU")
protoquant/float8/float8_linear.py (new file, 122 additions)
"""
A simple manual UEX for a float8 version of `torch.nn.Linear`.

Note: this UEX is not intended for real usage. It merely demonstrates how
features such as casting to and from float8, as well as stateful scaling,
can be implemented. For now, we expect framework owners to implement their
own UEX.
"""

import torch

import float8_aten_api

from float8_utils import E4M3, E5M2, tensor_to_scale
from float8_tensor import Float8Tensor

class float8_linear_no_bias(torch.autograd.Function):
"""
Like F.linear, but with X, W, and Y in float8
TODO(future) add logic for bias
"""

@staticmethod
def forward(
ctx,
x_fp8,
w_fp8,
fp8_s_out,
fp8_s_dL_dX,
fp8_s_dL_dW,
fp8_s_dL_dY,
):
ctx.save_for_backward(x_fp8, w_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY)

res_bits = torch.ops.aten.mm_float8(
x_fp8._data, x_fp8._scale, x_fp8._flavor,
w_fp8._data.t(), w_fp8._scale, w_fp8._flavor,
fp8_s_out, E4M3)

res = Float8Tensor(res_bits, fp8_s_out, E4M3)
# scale update would also happen here, for now no-op
return res

@staticmethod
def backward(ctx, go):
x_fp8, w_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY = \
ctx.saved_tensors

if not isinstance(go, Float8Tensor):
# TODO(future): switch to delayed scaling
fp8_s_dL_dY.fill_(tensor_to_scale(go, E5M2))
go_fp8 = Float8Tensor(
torch.ops.aten.float32_to_float8(go * fp8_s_dL_dY, E5M2),
fp8_s_dL_dY,
E5M2)
else:
go_fp8 = go

dL_dX_bits = torch.ops.aten.mm_float8(
go_fp8._data, go_fp8._scale, go_fp8._flavor,
w_fp8._data, w_fp8._scale, w_fp8._flavor,
fp8_s_dL_dX, E5M2)
dL_dX_fp8 = Float8Tensor(dL_dX_bits, fp8_s_dL_dX, E5M2)

dL_dW_bits = torch.ops.aten.mm_float8(
x_fp8._data.t(), x_fp8._scale, x_fp8._flavor,
go_fp8._data, go_fp8._scale, go_fp8._flavor,
fp8_s_dL_dW, E5M2).t()
dL_dW_fp8 = Float8Tensor(dL_dW_bits, fp8_s_dL_dW, E5M2)

# scale update would also happen here, for now no-op
return dL_dX_fp8, dL_dW_fp8, None, None, None, None


class Float8Linear(torch.nn.Linear):
"""
    A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks
    scales in a way friendly to delayed scaling.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# While this module currently implements just-in-time scaling,
# the scales are stored in buffers as a placeholder for delayed
# scaling such as the mechanism described in
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8,
# or PTQ calibration.
self.register_buffer('fp8_s_in', torch.tensor(1.0))
self.register_buffer('fp8_s_weight', torch.tensor(1.0))
self.register_buffer('fp8_s_out', torch.tensor(1.0))
self.register_buffer('fp8_s_dL_dX', torch.tensor(1.0))
self.register_buffer('fp8_s_dL_dW', torch.tensor(1.0))
self.register_buffer('fp8_s_dL_dY', torch.tensor(1.0))

def forward(self, x):
if not isinstance(x, Float8Tensor):
# TODO(future): switch to delayed scaling
self.fp8_s_in.fill_(tensor_to_scale(x, E4M3))
x_fp8 = Float8Tensor.from_float32(x, self.fp8_s_in, E4M3)
else:
x_fp8 = x

# TODO(future): switch to delayed scaling
self.fp8_s_weight.fill_(tensor_to_scale(self.weight, E4M3))
w_fp8 = Float8Tensor.from_float32(self.weight, self.fp8_s_weight, E4M3)

y_fp8 = float8_linear_no_bias.apply(
x_fp8, w_fp8, self.fp8_s_out, self.fp8_s_dL_dX,
self.fp8_s_dL_dW, self.fp8_s_dL_dY)

# For now, hardcode returning Float8Tensor (propagate as much as we can).
# This can be changed to return a different dtype, if needed.
return y_fp8

@classmethod
def from_float(cls, mod):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear
"""
assert mod.bias is None, 'bias support not implemented yet'
new_mod = cls(mod.in_features, mod.out_features, bias=False)
new_mod.weight = mod.weight
return new_mod
protoquant/float8/float8_tensor.py (new file, 141 additions)
from enum import Enum
import torch
from torch.utils._pytree import tree_map

from float8_utils import E4M3, E5M2

aten = torch.ops.aten

class Float8ConstrFunc(torch.autograd.Function):
"""
A differentiable conversion between fp32 and fp8
TODO(future): split into two for cleaner code
"""
@staticmethod
def forward(ctx, tensor, scale: float=None, flavor=E4M3):
        if isinstance(tensor, Float8Tensor):
            ctx.inp_is_float8 = True
            # save the input's scale and flavor so backward can re-cast the
            # incoming gradient to float8 with the same metadata
            ctx.scale, ctx.flavor = tensor._scale, tensor._flavor
            return torch.ops.aten.float8_to_float32(tensor._data, tensor._flavor) / tensor._scale
else:
ctx.inp_is_float8 = False
tensor_scaled = tensor * scale
bits_fp8 = torch.ops.aten.float32_to_float8(tensor_scaled, flavor)
return Float8Tensor(bits_fp8, scale, flavor)

@staticmethod
def backward(ctx, g):
        # Assume that we always want to scale the gradients
        # back to full precision. We could do something else here in the future.
if isinstance(g, Float8Tensor) and not ctx.inp_is_float8:
return g.to_float32(), None, None
        elif ctx.inp_is_float8:
            # re-cast the gradient to float8, reusing the scale and flavor
            # saved in the forward pass
            return Float8Tensor.from_float32(g, ctx.scale, ctx.flavor), None, None
else:
return g, None, None


class Float8Tensor(torch.Tensor):
"""
A Python-only FP8 tensor. Contains:
* `_data`: the underlying e4m3 or e5m2 data
* `_scale`: the scale used to scale the original fp32 tensor. We multiply
by scale to go from fp32 range to fp8 range, and divide by scale to go
from fp8 range to fp32 range.
* `_flavor`: either E4M3 or E5M2
The current purpose of this object is 99% to bundle raw data + fp8 metadata
together for easy passing through PyTorch systems, and 1% to implement
gradient addition (since that has to happen outside of user code).
The addition operation is defined inline and uses a naive
version of stateless scaling. This allows e5m2 gradients to be added.
    TODO(future): verify this is numerically accurate, optionally replace
    with something better.
    It would probably make sense to also define an fp8 path for data shuffling
    ops like cat, transpose, view, etc. inline so we don't have to fall back
    to fp32 for them.
"""

def __new__(cls, data, scale, flavor):
# This is a non-differentiable constructor!
assert not data.requires_grad
# TODO(future): make bits8 easier to work with and switch to using it
# assert data.dtype == torch.bits8
assert scale.dtype == torch.float32
assert scale.nelement() == 1

self = torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
strides=data.stride(),
storage_offset=data.storage_offset(),
dtype=torch.float32,
layout=data.layout,
requires_grad=data.requires_grad,
device=data.device,
)
self._data = data
self._scale = scale
self._flavor = flavor

return self

def __repr__(self):
return f"Float8Tensor(flavor={self._flavor}, scale={self._scale}, as_float32={self.to_float32()}"

def to_float32(self):
return Float8ConstrFunc.apply(self)

@classmethod
def from_float32(cls, tensor, scale, flavor):
return Float8ConstrFunc.apply(tensor, scale, flavor)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # Note: unlike many other subclasses, this subclass only propagates
        # itself for addition (for gradient addition in the backward pass). For all
        # other ops, it self-converts to fp32. The user/framework is
        # assumed to take care of defining where fp8 operations occur in the
        # forward pass and how scaling is calculated. In this example, that is
        # done by the `Float8Linear` class.
# Vasiliy: the main reason I went with ^ is because NVIDIA is
# doing stateful delayed scaling, and I don't know of a safe
# way to enable that without either full program capture or punting it
# to the user. This prototype takes the "punt it to the user" approach.
# IMO for now let's just write out the scale stuff manually so we can
# focus on other things, and revisit later if needed.

# override addition so we can add e5m2 gradients
if (
func is aten.add.Tensor
and isinstance(args[0], Float8Tensor)
and isinstance(args[1], Float8Tensor)
):
x1_fp8, x2_fp8 = args[0], args[1]
assert x1_fp8._flavor == E5M2 and x2_fp8._flavor == E5M2
# naive scale calculation: max of incoming two scales
x3_scale = torch.max(x1_fp8._scale, x2_fp8._scale)
res_bits = torch.ops.aten.add_float8_e5m2(
x1_fp8._data, x1_fp8._scale,
x2_fp8._data, x2_fp8._scale,
x3_scale)
res = Float8Tensor(res_bits, x3_scale, x1_fp8._flavor)
return res

# for all other ops, fall back to fp32
# TODO(future): add support for fp16/bf16

def maybe_unwrap(t):
if isinstance(t, Float8Tensor):
return t.to_float32()
return t

args = tree_map(maybe_unwrap, args)
if kwargs is not None:
kwargs = tree_map(maybe_unwrap, kwargs)
out = super().__torch_dispatch__(func, types, args, kwargs)
return out

# Do not force the Float8Tensor type on the returned tensor
__torch_function__ = torch._C._disabled_torch_function_impl
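
Finally, a minimal sketch (not part of this PR) of the subclass behavior described in the `__torch_dispatch__` comment above: two e5m2 tensors add without leaving float8, while any other op would unwrap to fp32. It assumes importing float8_aten_api registers the emulated ops and that the float8_utils helpers behave as imported above.

import torch

import float8_aten_api  # registers the emulated aten ops on import
from float8_tensor import Float8Tensor
from float8_utils import E5M2, tensor_to_scale

g1 = torch.randn(8, 8)
g2 = torch.randn(8, 8)

# per-tensor scales, stored the same way Float8Linear stores its buffers
s1 = torch.tensor(1.0)
s1.fill_(tensor_to_scale(g1, E5M2))
s2 = torch.tensor(1.0)
s2.fill_(tensor_to_scale(g2, E5M2))

g1_fp8 = Float8Tensor.from_float32(g1, s1, E5M2)
g2_fp8 = Float8Tensor.from_float32(g2, s2, E5M2)

# hits the aten.add.Tensor override in __torch_dispatch__ above,
# so the result stays in e5m2 with scale max(s1, s2)
g_sum = g1_fp8 + g2_fp8
assert isinstance(g_sum, Float8Tensor)
print(g_sum.to_float32())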