[wip] Python-only float8 data type + bare bones UEX
Summary:

Note: this is WIP and does not represent PyTorch's opinion on how we
will integrate float8. At this point, this is a prototype to get some light
feedback.

TODOs that need to be implemented before review of this PR:
* [done] emulated float8 casts to and from float32
* [not done] move https://github.com/albanD/subclass_zoo/blob/fp8_v2/fp8_subclass_v2.py here and hook up to accurate casts
* [not done] basic dynamic scaling (with the assumption that the real UEX can switch to delayed scaling)
* [not done] numerical testing that this all works for fw + bw with a toy model

TODOs to be done in the next couple of PRs before sending to NVIDIA:
* [not done] example of how an integration with distributed would work

What is out of scope for this POC:
a. hooking up to real float8 ops (saved for later, just needs someone to do it)
b. real UEX (saved for later and will need a lot of design discussion)

Test plan:

```
python protoquant/float8/test.py
```
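
For orientation, a minimal round-trip sketch using the casts added in this diff (the `protoquant.float8.float8_casts` import path is an assumption; the unit test below imports `float8_casts` relatively):

```
import torch

from protoquant.float8.float8_casts import (
    float32_to_float8_e4m3,
    float8_e4m3_to_float32,
)

x = torch.randn(16)                          # fp32 input
x_fp8 = float32_to_float8_e4m3(x)            # uint8 tensor of e4m3 bit patterns
x_roundtrip = float8_e4m3_to_float32(x_fp8)  # emulated cast back to fp32
```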
vkuzo committed May 10, 2023
1 parent 46e75e3 commit 4befd60
Showing 2 changed files with 313 additions and 0 deletions.
153 changes: 153 additions & 0 deletions protoquant/float8/float8_casts.py
@@ -0,0 +1,153 @@
import torch

# This file reproduces the emulated fp8 <-> fp32 casts from
# https://github.com/pytorch/FBGEMM/pull/974/files with plain PyTorch.
# This implements the fp8 format spec from https://arxiv.org/pdf/2209.05433.pdf
#
# TODO(future PR): hook up to NVIDIA's casts on gpu, and
# import the fbgemm emulator directly on cpu. We'll also need to ensure
# the two are aligned.
#
# Helpful visualizer for debugging (only supports fp32):
# https://www.h-schmidt.net/FloatConverter/IEEE754.html

E4M3_EBITS = 4
E4M3_EXP_BIAS = 7
E4M3_MAX_POS = 448.0

E5M2_EBITS = 5
E5M2_EXP_BIAS = 15
E5M2_MAX_POS = 57344.0
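
# For reference, the constants above follow the formats in the paper linked
# above: both formats use 1 sign bit. e4m3 has 4 exponent bits, 3 mantissa
# bits, and bias 7, so its largest finite value is 0b0.1111.110 =
# 2^(15 - 7) * 1.75 = 448.0. e5m2 has 5 exponent bits, 2 mantissa bits, and
# bias 15, so its largest normal value is 0b0.11110.11 =
# 2^(30 - 15) * 1.75 = 57344.0 (e5m2 reserves the all-ones exponent for
# inf/nan; e4m3 reserves only the all-ones bit pattern for nan).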


def _float_to_hfp8(
    val_fp: torch.Tensor,  # fp32 values
    ebits: int,  # exponent bits; mbits = 7 - ebits
    exponent_bias: int,  # exponent bias to use in the fp8 encoding
    max_pos: float,  # maximum positive number for fp8 encoding
):
    mbits = 7 - ebits

    val_out = val_fp.clone().detach()

    # S{1}, E{8}, M{23}
    sign_bit = (val_out.view(torch.int32) & 0x80000000).view(torch.float)

    # set sign bit to 0
    # 0{1}, E{8}, M{23}
    val_out = (val_out.view(torch.int32) & 0x7FFFFFFF).view(torch.float)

    # ensure abs(val_out) <= max_pos
    val_out = torch.clamp(val_out, max=max_pos)

    smallest_normal = torch.zeros_like(val_fp).to(torch.int32)
    smallest_normal = ((smallest_normal + 127 - exponent_bias + 1) << 23).view(torch.float)

    # the normal and denormal paths split below; record
    # which elements need which path
    is_normal_mask = torch.ge(val_out, smallest_normal)

    #
    # normal path
    #

    # Use round-to-nearest-even. We rely on the standard FP32 rounding
    # mechanism rather than rounding the mantissa by hand and handling
    # tie-to-even and exponent increments ourselves. We want to round off the
    # low 23 - mbits bits of the FP32 input value. This can be done by adding
    # a power of 2 whose exponent is exactly 23 - mbits larger than the
    # exponent of the input: the addition shifts the input to the right and
    # rounds it at the position that leaves exactly mbits of explicit
    # mantissa.
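    # Worked example (illustrative): for e4m3, mbits = 3, so for an input with
    # exponent 0 such as 1.3125 the bouncer below is 2^(0 + 23 - 3) = 2^20.
    # In FP32, 2^20 + 1.3125 rounds to 2^20 + 1.25 (the tie is broken to
    # even), and subtracting 2^20 leaves 1.25, i.e. the input rounded to
    # 3 explicit mantissa bits.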
    n_bouncer = ((val_out.view(torch.int32) & 0xFF800000) + ((23 - mbits) << 23)).view(torch.float)
    n_val_out = (n_bouncer + val_out) - n_bouncer

    # adding the bouncer rounds off bits, and subtracting the bouncer leaves
    # the desired value, albeit in FP32 encoding. All that remains is to
    # re-encode the exponent using the fp8 bias.
    n_val_out_i = (n_val_out.view(torch.int32) - ((127 - exponent_bias) << 23)) << (8 - ebits)
    n_val_out_i = (n_val_out_i | sign_bit.view(torch.int32)) >> 24
    n_val_out = n_val_out_i.view(torch.float)

    #
    # denormal path
    #

    # When the value is in the denormal range, an IEEE number essentially
    # becomes a fixed-point number whose lsb is the smallest non-zero value,
    # 2^(1 - bias - mbits). Hence, we define the bouncer so that its lsb is
    # this smallest non-zero value; adding the input to this bouncer forces
    # rounding to occur at the right position. Moreover, after adding the
    # bouncer, the 8 least significant bits of the sum are already the HFP8
    # encoding of the desired result; we only need to restore the sign bit.
    # Reference C code:
    # bouncer.I = (127 + (23 + (1 - exponent_bias - mbits))) << 23;
    # val_out.F = bouncer.F + val_out.F;
    # val_out.I = val_out.I | (sign_bit >> 24);
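    # Worked example (illustrative): for e4m3, exponent_bias = 7 and
    # mbits = 3, so the bouncer below is 2^(23 + 1 - 7 - 3) = 2^14 = 16384.0,
    # whose FP32 lsb is 2^(14 - 23) = 2^-9, the smallest e4m3 subnormal. For
    # an input of 3 * 2^-9, the sum is 16384 + 3 * 2^-9 exactly, whose low 8
    # bits are 0b00000011 -- the e4m3 encoding of 3 * 2^-9.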

    dn_bouncer = ((torch.zeros_like(val_out).view(torch.int32) + 127 + (23 + (1 - exponent_bias - mbits))) << 23).view(torch.float)
    dn_val_out = dn_bouncer + val_out
    dn_val_out_i = dn_val_out.view(torch.int32) | (sign_bit.view(torch.int32) >> 24)
    dn_val_out = dn_val_out_i.view(torch.float)

    #
    # combine normal and denormal paths
    #
    val_out = torch.where(is_normal_mask, n_val_out, dn_val_out)
    # take the 8 least significant bits
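    # (viewing the float32 storage as uint8 and keeping byte 0 of each 4-byte
    # group assumes a little-endian host, where byte 0 is the least
    # significant byte)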
    orig_shape = val_fp.shape
    val_out = val_out.view(torch.uint8)
    val_out = val_out.reshape(-1, 4)
    val_out = torch.tensor_split(val_out, 4, dim=-1)[0]
    val_out = val_out.reshape(orig_shape)
    return val_out


def _hfp8_to_float(
    hfp8_val: torch.Tensor,
    ebits: int,
    exponent_bias: int,
):
    assert hfp8_val.dtype == torch.uint8

    sign_i = (hfp8_val & 0x80).to(torch.int32) << 24

    # shift so that the fp8 mantissa bits land at the mantissa bit positions
    # of the FP32 encoding
    val_out_i = (hfp8_val & 0x7F).to(torch.int32) << (24 - (8 - ebits))

    # Let the hfp8 mantissa bits correspond to the value frac, 0 <= frac < 1.
    # If the hfp8 value is a normal number, its value is 2^e * (1 + frac),
    # where e is its (true, unbiased) exponent.
    # If the hfp8 value is denormal, its value is 2^(1 - bias) * frac.

    # However, the bit pattern in the 8-bit exponent field of val_out.F
    # is bias + e when hfp8 is normal, and 0 when hfp8 is subnormal.
    # So, as an FP32 value, when hfp8 is normal, val_out.F represents
    # 2^(bias + e - 127) * (1 + frac).
    # When hfp8 is subnormal, val_out.F is also subnormal and represents
    # 2^(-126) * frac. In either case, val_out.F corresponds to
    # 2^(bias - 127) * (value of the hfp8 input). Thus, multiplying val_out.F
    # by 2^(127 - bias) yields the hfp8 value as an FP32 number.

    # Reference C code:
    # multiplier.I = (127 + (127 - exponent_bias)) << 23;  // multiplier.F is 2^(127 - bias)
    # val_out.F *= multiplier.F;
    # val_out.I |= sign.I;
    # return val_out.F;
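    # Worked example (illustrative): decoding the e4m3 min normal 0b00001000
    # (ebits = 4, bias = 7) shifts the payload to 0x00800000, which as FP32 is
    # 2^-126; multiplying by 2^(127 - 7) = 2^120 gives 2^-6, the expected
    # value.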

    # multiplier_f is 2^(127 - bias)
    multiplier_i = (torch.zeros_like(hfp8_val).to(torch.int32) + 127 + (127 - exponent_bias)) << 23
    val_out_f = val_out_i.view(torch.float)
    val_out_f *= multiplier_i.view(torch.float)
    val_out_f = (val_out_f.view(torch.int32) | sign_i).view(torch.float)
    return val_out_f

def float32_to_float8_e4m3(x):
    return _float_to_hfp8(x, E4M3_EBITS, E4M3_EXP_BIAS, E4M3_MAX_POS)

def float32_to_float8_e5m2(x):
    return _float_to_hfp8(x, E5M2_EBITS, E5M2_EXP_BIAS, E5M2_MAX_POS)

def float8_e4m3_to_float32(x):
    return _hfp8_to_float(x, E4M3_EBITS, E4M3_EXP_BIAS)

def float8_e5m2_to_float32(x):
    return _hfp8_to_float(x, E5M2_EBITS, E5M2_EXP_BIAS)
160 changes: 160 additions & 0 deletions protoquant/float8/test.py
@@ -0,0 +1,160 @@
import random
import unittest

import torch

from float8_casts import (
    float32_to_float8_e4m3,
    float32_to_float8_e5m2,
    float8_e4m3_to_float32,
    float8_e5m2_to_float32,
)

random.seed(0)
torch.manual_seed(0)

class Fp8CastsUnitTest(unittest.TestCase):
"""
Test the casts between fp32 and fp8 (e4m3 and e5m2)
"""

    def _compare_many_exact(self, flavor, x_fp32, comp_name):
        if flavor == 'e4m3':
            to_fp8 = float32_to_float8_e4m3
            from_fp8 = float8_e4m3_to_float32
        else:  # e5m2
            to_fp8 = float32_to_float8_e5m2
            from_fp8 = float8_e5m2_to_float32

        x_fp8 = to_fp8(x_fp32)
        x_fp8_fp32 = from_fp8(x_fp8)

        torch.testing.assert_close(x_fp32, x_fp8_fp32)

    def _compare_many_approx(self, flavor, x_fp32, comp_name):
        if flavor == 'e4m3':
            sqnr_target = 25.0
            to_fp8 = float32_to_float8_e4m3
            from_fp8 = float8_e4m3_to_float32
        else:  # e5m2
            sqnr_target = 23.0
            to_fp8 = float32_to_float8_e5m2
            from_fp8 = float8_e5m2_to_float32

        x_fp8 = to_fp8(x_fp32)
        x_fp8_fp32 = from_fp8(x_fp8)

        # sign should always be the same
        torch.testing.assert_close(
            torch.sign(x_fp32),
            torch.sign(x_fp8_fp32),
            atol=0, rtol=0)

        # for now just measure that sqnr is somewhat reasonable
        # TODO(future): make this significantly more robust, this is about
        # 2/10 on the scale of "robust enough"
        def compute_error(x, y):
            Ps = torch.norm(x)
            Pn = torch.norm(x - y)
            return 20 * torch.log10(Ps / Pn)

        sqnr = compute_error(x_fp32, x_fp8_fp32)
        assert sqnr >= sqnr_target


    def _compare_one(self, flavor, bits_str, expected_fp32, comp_name):
        if flavor == 'e4m3':
            to_fp8 = float32_to_float8_e4m3
            from_fp8 = float8_e4m3_to_float32
        else:  # e5m2
            to_fp8 = float32_to_float8_e5m2
            from_fp8 = float8_e5m2_to_float32

        fp8_bits_ref = torch.tensor([int(bits_str, 2)], dtype=torch.uint8)

        fp32_tensor = torch.tensor([expected_fp32], dtype=torch.float)
        fp8_bits = to_fp8(fp32_tensor)
        torch.testing.assert_close(fp8_bits, fp8_bits_ref, atol=0, rtol=0)

        fp32_from_fp8_tensor = from_fp8(fp8_bits)
        torch.testing.assert_close(fp32_tensor, fp32_from_fp8_tensor, atol=0, rtol=0)

    def test_e4m3_numerics_single(self):
        # ensure that our format matches https://arxiv.org/pdf/2209.05433.pdf, Table 1

        flavor = 'e4m3'
        # e4m3 does not support infinity
        self._compare_one(flavor, "00000000", 0.0, "zero")
        self._compare_one(flavor, "10000000", -0.0, "neg_zero")
        self._compare_one(flavor, "01111110", 448.0, "max_normal")
        self._compare_one(flavor, "11111110", -448.0, "neg_max_normal")
        self._compare_one(flavor, "00001000", 2 ** -6, "min_normal")
        self._compare_one(flavor, "10001000", -1 * (2 ** -6), "neg_min_normal")
        self._compare_one(flavor, "00000111", 0.875 * (2 ** -6), "max_subnorm")
        self._compare_one(flavor, "10000111", -0.875 * (2 ** -6), "neg_max_subnorm")
        self._compare_one(flavor, "00000001", 2 ** -9, "min_subnorm")
        self._compare_one(flavor, "10000001", -1 * (2 ** -9), "neg_min_subnorm")

    def test_e5m2_numerics_single(self):
        flavor = 'e5m2'
        # e5m2 infinity (below) is off by one; TODO(future PR): debug, or just
        # move to NVIDIA's intrinsic casts
        # self._compare_one(flavor, "01111100", float("inf"), "inf")
        # self._compare_one(flavor, "11111100", -1 * float("inf"), "neg_inf")
        self._compare_one(flavor, "00000000", 0.0, "zero")
        self._compare_one(flavor, "10000000", -0.0, "neg_zero")
        self._compare_one(flavor, "01111011", 57344.0, "max_normal")
        self._compare_one(flavor, "11111011", -57344.0, "neg_max_normal")
        self._compare_one(flavor, "00000100", 2 ** -14, "min_normal")
        self._compare_one(flavor, "10000100", -1 * (2 ** -14), "neg_min_normal")
        self._compare_one(flavor, "00000011", 0.75 * (2 ** -14), "max_subnorm")
        self._compare_one(flavor, "10000011", -0.75 * (2 ** -14), "neg_max_subnorm")
        self._compare_one(flavor, "00000001", 2 ** -16, "min_subnorm")
        self._compare_one(flavor, "10000001", -1 * (2 ** -16), "neg_min_subnorm")

    def test_e4m3_numerics_multiple(self):
        # test special cases
        x = torch.tensor([
            0.0,
            -0.0,
            448.0,
            -448.0,
            2 ** -6,
            -1 * (2 ** -6),
            0.875 * (2 ** -6),
            -0.875 * (2 ** -6),
            2 ** -9,
            -1 * (2 ** -9),
        ])
        self._compare_many_exact('e4m3', x, 'special_cases')

        # test normal values + shapes
        for _ in range(10):
            x = torch.randn(1, 2, 3, 4) * random.uniform(0.1, 300.0)
            x.clamp_(min=-448.0, max=448.0)
            self._compare_many_approx('e4m3', x, 'normal_cases')

    def test_e5m2_numerics_multiple(self):
        # test special cases
        x = torch.tensor([
            0.0,
            -0.0,
            57344.0,
            -57344.0,
            2 ** -14,
            -1 * (2 ** -14),
            0.75 * (2 ** -14),
            -0.75 * (2 ** -14),
            2 ** -16,
            -1 * (2 ** -16),
        ])
        self._compare_many_exact('e5m2', x, 'special_cases')

        # test normal values + shapes
        for _ in range(10):
            x = torch.randn(1, 2, 3, 4) * random.uniform(0.1, 30000.0)
            x.clamp_(min=-57344.0, max=57344.0)
            self._compare_many_approx('e5m2', x, 'normal_cases')

if __name__ == '__main__':
    unittest.main()
