add axiswise granularity to Float8Tensor
Summary:

This is a copy-paste of pytorch-labs/float8_experimental#352
which never landed.

Test Plan:

```

```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: e998d637e0593760ad5a25d0c852d7a2706c8d1a
ghstack-comment-id: 2368837836
Pull Request resolved: #919
vkuzo committed Sep 23, 2024
1 parent 53b6b78 commit 4473ac5
Showing 6 changed files with 278 additions and 23 deletions.
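For orientation, here is a minimal usage sketch (an editor's addition, not part of the commit) mirroring the new `test_axiswise_dynamic_cast` test below; the exact `_scale` shape is inferred from the reshape test in this diff and should be treated as an assumption:

```python
import torch

from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig

# Cast a bfloat16 tensor to float8 with one scale per row,
# i.e. axiswise scaling along the last dimension.
a = torch.randn(16, 32, dtype=torch.bfloat16)
a_fp8 = hp_tensor_to_float8_dynamic(
    a,
    torch.float8_e4m3fn,  # the tests refer to this via an `e4m3_dtype` alias
    LinearMMConfig(),
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=-1,
)
print(a_fp8._data.shape)   # torch.Size([16, 32])
print(a_fp8._scale.shape)  # expected [16, 1]: the scaled dim is reduced to size 1
a_hp = a_fp8.to_original_precision()  # dequantize back to bfloat16
```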
122 changes: 117 additions & 5 deletions test/float8/test_base.py
@@ -22,14 +22,20 @@
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingGranularity,
ScalingType,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_python_api import addmm_float8_unwrapped
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
@@ -51,14 +57,15 @@


is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
assert torch.all(a._scale == b._scale).item(), "scales are not identical"
assert torch.all(a._data == b._data).item(), "data is not identical"
return True


class TestFloat8Tensor(unittest.TestCase):
class TestFloat8Tensor:
def test_preserves_dtype(self) -> None:
# hp means high precision, lp means low precision
hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
@@ -68,7 +75,7 @@ def test_preserves_dtype(self) -> None:
x1_s = tensor_to_scale(x1_hp, lp_dtype)
x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
x3_hp = x2_lp.to_original_precision()
self.assertTrue(x3_hp.dtype == hp_dtype)
assert x3_hp.dtype == hp_dtype

def test_differentiable_casts(self) -> None:
lp_dtypes = (e4m3_dtype, e5m2_dtype)
@@ -103,7 +110,7 @@ def test_index_put(self):
fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)

with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
b[index] = fp8_a
fp8_b[index] = a
fp8_b_bad[index] = fp8_a
@@ -117,7 +124,7 @@ def test_copy_(self):
b = torch.empty(16, dtype=torch.bfloat16)
b.copy_(fp8_a) # Should work
torch.testing.assert_close(b, fp8_a.to_original_precision())
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
fp8_a.copy_(b) # Should fail

fp8_b = Float8Tensor(
@@ -129,6 +136,111 @@ def test_copy_(self):
fp8_b.copy_(fp8_a)
torch.testing.assert_close(fp8_a._data, fp8_b._data)

@pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
@pytest.mark.parametrize("dim_name", ["first", "last"])
def test_axiswise_dynamic_cast(self, shape, dim_name):
a = torch.randn(*shape, dtype=torch.bfloat16)

if dim_name == "first":
dim = 0
elif dim_name == "last":
dim = len(a.shape) - 1

linear_mm_config = LinearMMConfig()
a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=dim,
)
a_dq = a_fp8.to_original_precision()
sqnr = compute_error(a, a_dq)
assert sqnr >= 25.0

def test_axiswise_reshape(self):
a = torch.randn(3, 5, 7, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()

# if we scale across dim0, we can only reshape to [3, -1]
a_fp8_d0 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=0,
)
assert list(a_fp8_d0._data.shape) == [3, 5, 7]
assert list(a_fp8_d0._scale.shape) == [1, 5, 7]

a_fp8_d0_r = a_fp8_d0.reshape(3, -1)
assert list(a_fp8_d0_r.shape) == [3, 5 * 7]
assert list(a_fp8_d0_r._scale.shape) == [1, 5 * 7]
# verify numerics did not change
assert torch.allclose(
a_fp8_d0.to_original_precision(),
a_fp8_d0_r.to_original_precision().reshape(3, 5, 7),
atol=0,
rtol=0,
)
with pytest.raises(RuntimeError):
a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7)

# if we scale across dim2, we can only reshape to [-1, 7]
a_fp8_d2 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=2,
)
assert list(a_fp8_d2._data.shape) == [3, 5, 7]
assert list(a_fp8_d2._scale.shape) == [3, 5, 1]

a_fp8_d2_r = a_fp8_d2.reshape(-1, 7)
assert list(a_fp8_d2_r.shape) == [3 * 5, 7]
assert list(a_fp8_d2_r._scale.shape) == [3 * 5, 1]
# verify numerics did not change
assert torch.allclose(
a_fp8_d2.to_original_precision(),
a_fp8_d2_r.to_original_precision().reshape(3, 5, 7),
atol=0,
rtol=0,
)
with pytest.raises(RuntimeError):
a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)

@pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
def test_axiswise_gemm(self, a_shape):
a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")

linear_mm_config = LinearMMConfig()

a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-1,
)
a_fp8 = a_fp8.reshape(-1, a_shape[-1])
b_fp8 = hp_tensor_to_float8_dynamic(
b,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=1, # will be transposed
)
c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
a = a.reshape(-1, a_shape[-1])
c_ref = torch.mm(a, b.t())
sqnr = compute_error(c_ref, c_fp8_compute)
assert sqnr >= 25.0


class TestFloat8Linear:
12 changes: 12 additions & 0 deletions torchao/float8/config.py
@@ -26,6 +26,18 @@ def short_str(self):
return "sta"


class ScalingGranularity(enum.Enum):
"""
Defines the granularity of scaling strategies for casting to float8
"""

# A single scaling factor for the entire tensor
TENSORWISE = "tensorwise"
# Scaling factors computed along one axis of the tensor, reducing it to
# size 1.
AXISWISE = "axiswise"


@dataclass(frozen=True)
class CastConfig:
"""
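As a quick illustration of the new enum (an editor's sketch, not code from this commit), the difference shows up in the shape of the amax reduction used to derive a scale:

```python
import torch

x = torch.randn(3, 5, 7)

# TENSORWISE: a single amax (and therefore a single scale) for the whole tensor.
amax_tensorwise = torch.amax(torch.abs(x))                     # shape [] (scalar)

# AXISWISE: amax computed along one axis, reducing it to size 1,
# e.g. along dim 0 -> one scale per [5, 7] slice, matching the
# [1, 5, 7] `_scale` shape asserted in test_axiswise_reshape above.
amax_axiswise = torch.amax(torch.abs(x), dim=0, keepdim=True)  # shape [1, 5, 7]
```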
93 changes: 92 additions & 1 deletion torchao/float8/float8_ops.py
@@ -19,6 +19,15 @@
FLOAT8_OPS_TABLE: Dict[Any, Any] = {}


def _assert_tensorwise_scale(aten_op, scale):
assert (
# TODO(future PR): figure out why tensorwise scaling can have
# both rank 0 and rank 1
len(scale.shape)
in (0, 1)
), f"{aten_op} with axiswise scaling is not supported yet"


def implements(aten_ops):
"""Register aten ops to the float8 op table"""

@@ -45,6 +54,7 @@ def decorator(func):
]
)
def float8_desugar_op(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
new_data = aten_op(args[0]._data, *args[1:], **kwargs)
return Float8Tensor(
new_data,
@@ -55,10 +65,82 @@ def float8_desugar_op(aten_op, args, kwargs=None):
)


@implements(
[
aten.t.default,
aten.transpose.int,
]
)
def float8_desugar_data_and_scale(aten_op, args, kwargs=None):
new_data = aten_op(args[0]._data, *args[1:], **kwargs)
new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)

if aten_op == aten.transpose.int:
_assert_tensorwise_scale(aten_op, args[0]._scale)

old_axiswise_dim = args[0]._axiswise_dim
new_axiswise_dim = old_axiswise_dim
if old_axiswise_dim is not None:
if old_axiswise_dim == 0:
new_axiswise_dim = -1
else:
new_axiswise_dim = 0

return Float8Tensor(
new_data,
new_scale,
args[0]._orig_dtype,
args[0]._linear_mm_config,
args[0]._gemm_input_role,
new_axiswise_dim,
)


@implements([aten.view.default])
def float8_view(aten_op, args, kwargs=None):
if len(args[0]._scale.shape) < 2:
# tensorwise scaling
return float8_desugar_op(aten_op, args, kwargs)

t, new_shape = args[0], args[1]
# for now, only support reshaping to [-1, dim] or [dim, -1]
axiswise_dim = t._axiswise_dim
if len(new_shape) == 2:

if axiswise_dim == 0:
new_data = aten_op(t._data, new_shape, **kwargs)
new_scale_shape = [1, new_shape[-1]]
new_scale = aten_op(t._scale, new_scale_shape, **kwargs)
return Float8Tensor(
new_data,
new_scale,
t._orig_dtype,
t._linear_mm_config,
t._gemm_input_role,
t._axiswise_dim,
)
elif axiswise_dim == -1 or axiswise_dim == (len(t.shape) - 1):
new_data = aten_op(t._data, new_shape, **kwargs)
new_scale_shape = [new_shape[0], 1]
new_scale = aten_op(t._scale, new_scale_shape, **kwargs)
new_axiswise_dim = -1
return Float8Tensor(
new_data,
new_scale,
t._orig_dtype,
t._linear_mm_config,
t._gemm_input_role,
new_axiswise_dim,
)
raise AssertionError(
f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} t._axiswise_dim {t._axiswise_dim} new_shape {new_shape} is not supported yet."
)


@implements([aten.split.Tensor])
def float8_split(aten_op, args, kwargs=None):
new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)

_assert_tensorwise_scale(aten_op, args[0]._scale)
def make_float8(data):
return Float8Tensor(
data,
@@ -102,6 +184,7 @@ def float8_cat(aten_op, args, kwargs=None):
assert (
chunk._gemm_input_role is gemm_input_role
), "Expecting all chunks to have the same gemm_input_role as a result of a split"
_assert_tensorwise_scale(aten_op, chunk._scale)
chunk_data.append(chunk._data.view(torch.uint8))

new_data = aten_op(chunk_data, *args[1:], **kwargs)
@@ -118,6 +201,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None):
"addmm" -> out
"hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut"
"""
_assert_tensorwise_scale(aten_op, args[0]._scale)

def unwrap(x):
if isinstance(x, Float8Tensor):
@@ -230,6 +314,7 @@ def float8_addmm(aten_op, args, kwargs=None):

@implements([aten.is_same_size.default])
def float8_is_same_size(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
return args[0].shape == args[1].shape


@@ -239,6 +324,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
when the input is a Float8Tensor, presenting as a fp32
tensor.
"""
_assert_tensorwise_scale(aten_op, args[0]._scale)
assert isinstance(args[0], Float8Tensor)
assert (
len(kwargs) == 1 and "dtype" in kwargs
@@ -266,6 +352,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
"""
override funcol with FP8 handling
"""
_assert_tensorwise_scale(aten_op, args[0]._scale)
fp8_input = args[0]
assert isinstance(
fp8_input, Float8Tensor
@@ -285,6 +372,7 @@ def allgather_fp8(aten_op, args, kwargs=None):

@implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default])
def wait_tensor_fp8(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
fp8_input = args[0]
assert isinstance(fp8_input, Float8Tensor)

@@ -305,6 +393,7 @@ def index_put_fp8(aten_op, args, kwargs=None):
fp8_values = args[2]
assert isinstance(fp8_self, Float8Tensor)
assert isinstance(fp8_values, Float8Tensor)
_assert_tensorwise_scale(aten_op, args[0]._scale)
assert fp8_self._scale == fp8_values._scale
assert fp8_self.dtype == fp8_values.dtype
assert fp8_self._orig_dtype == fp8_values._orig_dtype
@@ -335,8 +424,10 @@ def copy_fp8(aten_op, args, kwargs=None):

if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
src_hp = src.to_original_precision()
_assert_tensorwise_scale(aten_op, src._scale)
return aten_op(self, src_hp, *args[2:], **kwargs)
elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
_assert_tensorwise_scale(aten_op, src._scale)
assert (
self._orig_dtype == src._orig_dtype
), "Expecting both Float8Tensors to be of the same dtype"
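To make the reshape restriction in `float8_view` above concrete, here is a small stand-alone illustration with plain tensors (an editor's sketch, not part of the commit): an axiswise scale can follow a reshape only when the scaled dimension stays intact, which is why only `[dim, -1]` / `[-1, dim]` target shapes are accepted and anything else raises.

```python
import torch

data = torch.randn(3, 5, 7)
# Axiswise scale along dim 0: one value per [5, 7] slice, shape [1, 5, 7].
scale = torch.amax(torch.abs(data), dim=0, keepdim=True)

# Collapsing the non-scaled dims keeps data and scale consistent:
data_r = data.reshape(3, -1)     # [3, 35]
scale_r = scale.reshape(1, -1)   # [1, 35]
assert torch.equal(data_r / scale_r, (data / scale).reshape(3, -1))

# A reshape like (-1, 7) mixes the scaled dim 0 with dim 1, and there is no
# reshape of the [1, 5, 7] scale that stays aligned with the reshaped data,
# so float8_view raises for that case instead of silently changing numerics.
```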