[wip] Python-only float8 data type + bare bones UEX
Summary:

Note: this is WIP and does not represent PyTorch's opinion on how we will integrate float8. At this point, this is a prototype to get some light feedback.

TODOs that need to be implemented before review of this PR:
* [done] emulated float8 casts to and from float32
* [not done] move https://github.com/albanD/subclass_zoo/blob/fp8_v2/fp8_subclass_v2.py here and hook up to accurate casts
* [not done] basic dynamic scaling (with the assumption that the real UEX can switch to delayed scaling; see the sketch below)
* [not done] numerical testing that this all works for fw + bw with a toy model

TODOs to be done in the next couple of PRs before sending to NVIDIA:
* [not done] example of how an integration with distributed would work

What is out of scope for this POC:
a. hooking up to real float8 ops (saved for later, just needs someone to do it)
b. real UEX (saved for later and will need a lot of design discussion)

Test plan:
```
python protoquant/float8/test.py
```
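For context on the "basic dynamic scaling" TODO above: dynamic scaling generally means picking a per-tensor scale from the current absolute maximum so that the tensor's range maps onto the representable fp8 range before casting, and dividing the scale back out after the cast. Below is a minimal sketch of that idea on top of the casts added in this commit; the helper names are illustrative only (not part of the commit), and the import path assumes the module is importable as `float8_casts`, as in the test file.

```python
import torch

# assumed import path, mirroring the test file in this commit
from float8_casts import (
    E4M3_MAX_POS,
    float32_to_float8_e4m3,
    float8_e4m3_to_float32,
)


def dynamic_scale_to_e4m3(x: torch.Tensor):
    # hypothetical helper (not part of this commit): choose a per-tensor scale
    # so that amax(x) maps to the largest representable e4m3 value, then cast
    amax = x.abs().max().clamp(min=1e-12)
    scale = E4M3_MAX_POS / amax
    return float32_to_float8_e4m3(x * scale), scale


def e4m3_to_float32_unscaled(x_fp8: torch.Tensor, scale: torch.Tensor):
    # hypothetical helper: cast back to fp32 and undo the scaling
    return float8_e4m3_to_float32(x_fp8) / scale


x = torch.randn(4, 4) * 1000.0             # values far outside the e4m3 range
x_fp8, scale = dynamic_scale_to_e4m3(x)
x_approx = e4m3_to_float32_unscaled(x_fp8, scale)
```

With delayed scaling (mentioned in the TODO above), `scale` would instead be derived from amax statistics collected over previous iterations rather than from the current tensor.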
Showing 2 changed files with 313 additions and 0 deletions.
@@ -0,0 +1,153 @@
import torch

# This file reproduces the emulated fp8 <-> fp32 casts from
# https://github.com/pytorch/FBGEMM/pull/974/files with plain PyTorch.
# This implements the fp8 format spec from https://arxiv.org/pdf/2209.05433.pdf
#
# TODO(future PR): hook up to NVIDIA's casts on gpu, and
# import the fbgemm emulator directly on cpu. We'll also need to ensure
# the two are aligned.
#
# Helpful visualizer for debugging (only supports fp32):
# https://www.h-schmidt.net/FloatConverter/IEEE754.html

E4M3_EBITS = 4
E4M3_EXP_BIAS = 7
E4M3_MAX_POS = 448.0

E5M2_EBITS = 5
E5M2_EXP_BIAS = 15
E5M2_MAX_POS = 57344.0


def _float_to_hfp8(
    val_fp: torch.Tensor,  # fp32 values
    ebits: int,  # exponent bits, mbits = 7 - ebits
    exponent_bias: int,  # exponent bias to use in the fp8 encoding
    max_pos: float,  # maximum positive number for fp8 encoding
):
    mbits = 7 - ebits

    val_out = val_fp.clone().detach()

    # S{1}, E{8}, M{23}
    sign_bit = (val_out.view(torch.int32) & 0x80000000).view(torch.float)

    # set sign bit to 0
    # 0{1}, E{8}, M{23}
    val_out = (val_out.view(torch.int32) & 0x7FFFFFFF).view(torch.float)

    # ensure abs(val_out) <= max_pos
    val_out = torch.clamp(val_out, max=max_pos)

    smallest_normal = torch.zeros_like(val_fp).to(torch.int32)
    smallest_normal = ((smallest_normal + 127 - exponent_bias + 1) << 23).view(torch.float)

    # the normal and denormal paths split below; record
    # which elements need which path
    is_normal_mask = torch.ge(val_out, smallest_normal)

    #
    # normal path
    #

    # Use round to nearest even. We make use of the standard rounding mechanism
    # in FP32 rather than rounding the mantissa and handling tie-to-even and
    # incrementing the exponent. We want to round off 23-mbits of the FP32 value
    # val_in. This can be done by adding a power of 2 exactly 23-mbits larger
    # than the exponent of val_in. This forces val_in to be moved to the right
    # and rounded exactly at the location corresponding to having mbits of
    # explicit mantissa left.
    n_bouncer = ((val_out.view(torch.int32) & 0xFF800000) + ((23 - mbits) << 23)).view(torch.float)
    n_val_out = (n_bouncer + val_out) - n_bouncer

    # adding the bouncer rounds off bits, and subtracting the bouncer
    # leaves the desired value, albeit in FP32 encoding.
    # All we need is to change the exponent encoding to use "bias".
    n_val_out_i = (n_val_out.view(torch.int32) - ((127 - exponent_bias) << 23)) << (8 - ebits)
    n_val_out_i = (n_val_out_i | sign_bit.view(torch.int32)) >> 24
    n_val_out = n_val_out_i.view(torch.float)

    #
    # denormal path
    #

    # When the value is in the denormal range, IEEE numbers essentially become
    # fixed point numbers. The lsb is the smallest non-zero number,
    # 2^(1-bias-mbits). Hence, we define the bouncer so that its lsb is this
    # smallest non-zero number. Adding the input to this bouncer forces rounding
    # to occur appropriately. Also, in this situation, after adding the bouncer,
    # the 8 least significant bits of the sum are already the HFP8 encoding of
    # the desired result. We just need to restore the sign bit.
    # bouncer.I = (127 + (23 + (1 - exponent_bias - mbits))) << 23;
    # val_out.F = bouncer.F + val_out.F;
    # val_out.I = val_out.I | (sign_bit >> 24);

    dn_bouncer = ((torch.zeros_like(val_out).view(torch.int32) + 127 + (23 + (1 - exponent_bias - mbits))) << 23).view(torch.float)
    dn_val_out = dn_bouncer + val_out
    dn_val_out_i = dn_val_out.view(torch.int32) | (sign_bit.view(torch.int32) >> 24)
    dn_val_out = dn_val_out_i.view(torch.float)

    #
    # combine normal and denormal paths
    #
    val_out = torch.where(is_normal_mask, n_val_out, dn_val_out)
    # take the 8 least significant bits
    orig_shape = val_fp.shape
    val_out = val_out.view(torch.uint8)
    val_out = val_out.reshape(-1, 4)
    val_out = torch.tensor_split(val_out, 4, dim=-1)[0]
    val_out = val_out.reshape(orig_shape)
    return val_out
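The "bouncer" rounding used in the normal path above can be seen in isolation. Below is a minimal sketch (not part of the commit) of the same mechanism for the simplest case, rounding non-negative float32 values to the nearest integer: at magnitude 2^23 the float32 ulp is 1.0, so adding the bouncer discards the fractional bits under the default round-to-nearest-even mode, and subtracting it restores the value.

```python
import torch

x = torch.tensor([1.25, 2.5, 3.5], dtype=torch.float32)
bouncer = torch.tensor(2.0 ** 23, dtype=torch.float32)  # ulp at 2^23 is exactly 1.0
rounded = (x + bouncer) - bouncer  # valid for 0 <= x < 2^23
print(rounded)  # tensor([1., 2., 4.]): 1.25 rounds down, ties 2.5 and 3.5 round to even
```

In `_float_to_hfp8` the same trick is applied per element with a bouncer sized from each value's own exponent, so the rounding happens at the bit position that leaves `mbits` of explicit mantissa.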


def _hfp8_to_float(
    hfp8_val: torch.Tensor,
    ebits: int,
    exponent_bias: int,
):
    assert hfp8_val.dtype == torch.uint8

    sign_i = (hfp8_val & 0x80).to(torch.int32) << 24

    # shift left so that the mantissa bits start at the mantissa bit positions
    # of the FP32 encoding
    val_out_i = (hfp8_val & 0x7F).to(torch.int32) << (24 - (8 - ebits))

    # Let the hfp8 mantissa bits correspond to the value frac, 0 <= frac < 1.
    # So if the hfp8 value is a normal number, its value is 2^e x (1+frac),
    # where e is its (true, unbiased) exponent.
    # If the hfp8 value is denormal, the value is 2^(1-bias) x frac.

    # However, the bit pattern in the 8-bit exponent field of val_out.F
    # is bias+e when hfp8 is normal, and 0 when hfp8 is subnormal.
    # So, as an FP32 value, when hfp8 is normal, val_out.F represents the value
    # of 2^(bias+e-127) * (1+frac).
    # And when hfp8 is subnormal, val_out.F is also subnormal, and represents
    # the value of 2^(-126) * frac. In either case, val_out.F corresponds to
    # 2^(bias-127) * (value of the hfp8 input). Thus, if we multiply val_out.F
    # by 2^(127-bias), we obtain the hfp8 value as an FP32 number.

    # multiplier.I = (127 + (127 - exponent_bias))
    #     << 23;  // multiplier.F is 2^(127-bias)
    # val_out.F *= multiplier.F;
    # val_out.I |= sign.I;
    # return val_out.F;

    multiplier_i = (torch.zeros_like(hfp8_val).to(torch.int32) + 127 + (127 - exponent_bias)) << 23  # multiplier_f is 2^(127-bias)
    val_out_f = val_out_i.view(torch.float)
    val_out_f *= multiplier_i.view(torch.float)
    val_out_f = (val_out_f.view(torch.int32) | sign_i).view(torch.float)
    return val_out_f


def float32_to_float8_e4m3(x):
    return _float_to_hfp8(x, E4M3_EBITS, E4M3_EXP_BIAS, E4M3_MAX_POS)


def float32_to_float8_e5m2(x):
    return _float_to_hfp8(x, E5M2_EBITS, E5M2_EXP_BIAS, E5M2_MAX_POS)


def float8_e4m3_to_float32(x):
    return _hfp8_to_float(x, E4M3_EBITS, E4M3_EXP_BIAS)


def float8_e5m2_to_float32(x):
    return _hfp8_to_float(x, E5M2_EBITS, E5M2_EXP_BIAS)
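As a worked example of the decode path (a sketch, not part of the commit): for e4m3 the max-normal bit pattern 0b01111110 has exponent field 0b1111 = 15 and mantissa field 0b110, so with bias 7 its value is 2^(15-7) * (1 + 6/8) = 448.0, i.e. E4M3_MAX_POS. The arithmetic can be checked against `float8_e4m3_to_float32` (assuming the module is importable as `float8_casts`, as in the test file):

```python
import torch

from float8_casts import float8_e4m3_to_float32  # assumed import path

bits = 0b01111110                        # sign=0, exponent=0b1111, mantissa=0b110
exp_field = (bits >> 3) & 0xF            # 15
mant_field = bits & 0x7                  # 6
manual = 2.0 ** (exp_field - 7) * (1.0 + mant_field / 8.0)  # 256 * 1.75 = 448.0
decoded = float8_e4m3_to_float32(torch.tensor([bits], dtype=torch.uint8))
print(manual, decoded.item())            # 448.0 448.0
```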
@@ -0,0 +1,160 @@
import random
import unittest

import torch

from float8_casts import (
    float32_to_float8_e4m3,
    float32_to_float8_e5m2,
    float8_e4m3_to_float32,
    float8_e5m2_to_float32,
)

random.seed(0)
torch.manual_seed(0)


class Fp8CastsUnitTest(unittest.TestCase):
    """
    Test the casts between fp32 and fp8 (e4m3 and e5m2)
    """

    def _compare_many_exact(self, flavor, x_fp32, comp_name):
        if flavor == 'e4m3':
            to_fp8 = float32_to_float8_e4m3
            from_fp8 = float8_e4m3_to_float32
        else:  # e5m2
            to_fp8 = float32_to_float8_e5m2
            from_fp8 = float8_e5m2_to_float32

        x_fp8 = to_fp8(x_fp32)
        x_fp8_fp32 = from_fp8(x_fp8)

        torch.testing.assert_close(x_fp32, x_fp8_fp32)

    def _compare_many_approx(self, flavor, x_fp32, comp_name):
        if flavor == 'e4m3':
            sqnr_target = 25.0
            to_fp8 = float32_to_float8_e4m3
            from_fp8 = float8_e4m3_to_float32
        else:  # e5m2
            sqnr_target = 23.0
            to_fp8 = float32_to_float8_e5m2
            from_fp8 = float8_e5m2_to_float32

        x_fp8 = to_fp8(x_fp32)
        x_fp8_fp32 = from_fp8(x_fp8)

        # sign should always be the same
        torch.testing.assert_close(
            torch.sign(x_fp32),
            torch.sign(x_fp8_fp32),
            atol=0, rtol=0)

        # for now just measure that sqnr is somewhat reasonable
        # TODO(future): make this significantly more robust, this is about
        # 2/10 on the scale of "robust enough"
        def compute_error(x, y):
            Ps = torch.norm(x)
            Pn = torch.norm(x - y)
            return 20 * torch.log10(Ps / Pn)

        sqnr = compute_error(x_fp32, x_fp8_fp32)
        assert sqnr >= sqnr_target

    def _compare_one(self, flavor, bits_str, expected_fp32, comp_name):
        if flavor == 'e4m3':
            to_fp8 = float32_to_float8_e4m3
            from_fp8 = float8_e4m3_to_float32
        else:  # e5m2
            to_fp8 = float32_to_float8_e5m2
            from_fp8 = float8_e5m2_to_float32

        fp8_bits_ref = torch.tensor([int(bits_str, 2)], dtype=torch.uint8)

        fp32_tensor = torch.tensor([expected_fp32], dtype=torch.float)
        fp8_bits = to_fp8(fp32_tensor)
        torch.testing.assert_close(fp8_bits, fp8_bits_ref, atol=0, rtol=0)

        fp32_from_fp8_tensor = from_fp8(fp8_bits)
        torch.testing.assert_close(fp32_tensor, fp32_from_fp8_tensor, atol=0, rtol=0)

    def test_e4m3_numerics_single(self):
        # ensure that our format matches https://arxiv.org/pdf/2209.05433.pdf, Table 1

        flavor = 'e4m3'
        # e4m3 does not support infinity
        self._compare_one(flavor, "00000000", 0.0, "zero")
        self._compare_one(flavor, "10000000", -0.0, "neg_zero")
        self._compare_one(flavor, "01111110", 448.0, "max_normal")
        self._compare_one(flavor, "11111110", -448.0, "neg_max_normal")
        self._compare_one(flavor, "00001000", 2 ** -6, "min_normal")
        self._compare_one(flavor, "10001000", -1 * (2 ** -6), "neg_min_normal")
        self._compare_one(flavor, "00000111", 0.875 * (2 ** -6), "max_subnorm")
        self._compare_one(flavor, "10000111", -0.875 * (2 ** -6), "neg_max_subnorm")
        self._compare_one(flavor, "00000001", 2 ** -9, "min_subnorm")
        self._compare_one(flavor, "10000001", -1 * (2 ** -9), "neg_min_subnorm")

    def test_e5m2_numerics_single(self):
        flavor = 'e5m2'
        # e5m2 infinity (below) is off by one, TODO(future PR) debug or just move
        # to NVIDIA's intrinsic casts
        # _compare_one(flavor, "01111100", float("inf"), "inf")
        # _compare_one(flavor, "11111100", -1 * float("inf"), "neg_inf")
        self._compare_one(flavor, "00000000", 0.0, "zero")
        self._compare_one(flavor, "10000000", -0.0, "neg_zero")
        self._compare_one(flavor, "01111011", 57344.0, "max_normal")
        self._compare_one(flavor, "11111011", -57344.0, "neg_max_normal")
        self._compare_one(flavor, "00000100", 2 ** -14, "min_normal")
        self._compare_one(flavor, "10000100", -1 * (2 ** -14), "neg_min_normal")
        self._compare_one(flavor, "00000011", 0.75 * (2 ** -14), "max_subnorm")
        self._compare_one(flavor, "10000011", -0.75 * (2 ** -14), "neg_max_subnorm")
        self._compare_one(flavor, "00000001", 2 ** -16, "min_subnorm")
        self._compare_one(flavor, "10000001", -1 * (2 ** -16), "neg_min_subnorm")

    def test_e4m3_numerics_multiple(self):
        # test special cases
        x = torch.tensor([
            0.0,
            -0.0,
            448.0,
            -448.0,
            2 ** -6,
            -1 * (2 ** -6),
            0.875 * (2 ** -6),
            -0.875 * (2 ** -6),
            2 ** -9,
            -1 * (2 ** -9),
        ])
        self._compare_many_exact('e4m3', x, 'special_cases')

        # test normal values + shapes
        for _ in range(10):
            x = torch.randn(1, 2, 3, 4) * random.uniform(0.1, 300.0)
            x.clamp_(min=-448.0, max=448.0)
            self._compare_many_approx('e4m3', x, 'normal_cases')

    def test_e5m2_numerics_multiple(self):
        # test special cases
        x = torch.tensor([
            0.0,
            -0.0,
            57344.0,
            -57344.0,
            2 ** -14,
            -1 * (2 ** -14),
            0.75 * (2 ** -14),
            -0.75 * (2 ** -14),
            2 ** -16,
            -1 * (2 ** -16),
        ])
        self._compare_many_exact('e5m2', x, 'special_cases')

        # test normal values + shapes
        for _ in range(10):
            x = torch.randn(1, 2, 3, 4) * random.uniform(0.1, 30000.0)
            x.clamp_(min=-57344.0, max=57344.0)
            self._compare_many_approx('e5m2', x, 'normal_cases')


if __name__ == '__main__':
    unittest.main()
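The `compute_error` helper in `_compare_many_approx` above is a signal-to-quantization-noise ratio in decibels, 20 * log10(||x|| / ||x - y||), so the `sqnr_target` thresholds of 23-25 dB correspond to an aggregate relative error of roughly 5-7%. A quick sanity check of the metric itself (a sketch, not part of the commit):

```python
import torch

x = torch.ones(1000)
y = x * 1.01  # a uniform 1% error everywhere
sqnr = 20 * torch.log10(torch.norm(x) / torch.norm(x - y))
print(sqnr)  # ~40 dB, comfortably above the 23-25 dB targets used in the tests
```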