Skip to content

Commit

Permalink
[wip] make scaling configurable by gemm-argument
Browse files Browse the repository at this point in the history
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 71d847f5c6666ff307062f8cea25c50c398a1774
ghstack-comment-id: 2372563439
Pull Request resolved: #940
  • Loading branch information
vkuzo committed Oct 4, 2024
1 parent 3b6ca01 commit 9c4fd9d
Show file tree
Hide file tree
Showing 10 changed files with 589 additions and 398 deletions.
83 changes: 46 additions & 37 deletions benchmarks/float8/profile_linear_float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
Float8LinearConfig,
ScalingType,
ScalingGranularity,
_Float8LinearRecipeName,
_recipe_name_to_linear_config,
)
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
Expand Down Expand Up @@ -258,6 +260,7 @@ def main(
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
scaling_granularity: str = "tensorwise",
recipe_name: Optional[str] = None,
model_type: str = "linear",
dtype_filter: str = "both",
add_inductor_metadata_to_trace: bool = True,
Expand All @@ -271,45 +274,51 @@ def main(
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
scaling_granularity = ScalingGranularity(scaling_granularity)

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
if recipe_name is None:

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
cast_config_weight=cast_config_weight,
cast_config_grad_output=cast_config_grad_output,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
cast_config_weight=cast_config_weight,
cast_config_grad_output=cast_config_grad_output,
)
elif recipe_name is not None:
recipe_name = _Float8LinearRecipeName(recipe_name)
config = _recipe_name_to_linear_config(recipe_name)

scaling_repr = "_".join(
[
Expand Down
126 changes: 65 additions & 61 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import itertools
import random
import re
from typing import List, Tuple
import unittest
import warnings

Expand All @@ -27,6 +28,8 @@
Float8LinearConfig,
ScalingGranularity,
ScalingType,
_Float8LinearRecipeName,
_recipe_name_to_linear_config,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
Expand All @@ -35,7 +38,10 @@
sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_python_api import addmm_float8_unwrapped
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_scaling_utils import (
hp_tensor_to_float8_dynamic,
get_maybe_axiswise_dim,
)
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
Expand All @@ -51,6 +57,7 @@
FP8_TYPES,
tensor_to_scale,
)
from torchao.testing.float8.test_utils import get_test_float8_linear_config

random.seed(0)
torch.manual_seed(0)
Expand All @@ -59,6 +66,8 @@
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)



def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
assert torch.all(a._scale == b._scale).item(), "scales are not identical"
assert torch.all(a._data == b._data).item(), "data is not identical"
Expand Down Expand Up @@ -205,9 +214,17 @@ def test_axiswise_reshape(self):
a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)

@pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
@pytest.mark.parametrize(
"a_granularity,b_granularity",
[
(ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE),
(ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE),
(ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE),
]
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
def test_axiswise_gemm(self, a_shape):
def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")

Expand All @@ -218,18 +235,20 @@ def test_axiswise_gemm(self, a_shape):
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-1,
scaling_granularity=a_granularity,
axiswise_dim=get_maybe_axiswise_dim(-1, a_granularity),
)
a_fp8 = a_fp8.reshape(-1, a_shape[-1])

b_fp8 = hp_tensor_to_float8_dynamic(
b,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-1, # will be transposed
scaling_granularity=b_granularity,
axiswise_dim=get_maybe_axiswise_dim(-1, b_granularity),
)

c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
a = a.reshape(-1, a_shape[-1])
c_ref = torch.mm(a, b.t())
Expand Down Expand Up @@ -322,79 +341,64 @@ def _test_linear_impl(
)
@pytest.mark.parametrize(
"scaling_type_grad_output",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@pytest.mark.parametrize(
"scaling_granularity",
[ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE],
[ScalingType.DELAYED, ScalingType.DYNAMIC],
)
@pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("linear_bias", [False, True])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_linear(
def test_linear_from_config_params(
self,
x_shape,
emulate: bool,
scaling_type_input: ScalingType,
scaling_type_weight: ScalingType,
scaling_type_grad_output: ScalingType,
scaling_granularity: ScalingGranularity,
linear_dtype: torch.dtype,
linear_bias: bool,
):
if scaling_granularity is ScalingGranularity.AXISWISE:
if (
scaling_type_input != ScalingType.DYNAMIC or
scaling_type_weight != ScalingType.DYNAMIC or
scaling_type_grad_output != ScalingType.DYNAMIC or
linear_dtype != torch.bfloat16 or
(not is_cuda_9_0)
):
pytest.skip()

x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)

if scaling_type_input is ScalingType.STATIC:
cast_config_input = CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_input = CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight = CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_weight = CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output = CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_grad_output = CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)
config = get_test_float8_linear_config(
scaling_type_input,
scaling_type_weight,
scaling_type_grad_output,
emulate,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
cast_config_weight=cast_config_weight,
cast_config_grad_output=cast_config_grad_output,
emulate=emulate,
self._test_linear_impl(
x,
m_ref,
config,
)

# Note: there are now too many config combinations to test all of
# them, so this function factors out some of the recipes which are annoying
# to combine with the main testing function.
# TODO(future PR): make this cleaner.
@pytest.mark.parametrize(
"recipe_name",
[_Float8LinearRecipeName.ALL_AXISWISE, _Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
)
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize("linear_bias", [True, False])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_linear_from_recipe(
self,
recipe_name,
x_shape,
linear_bias: bool,
):
if torch.cuda.get_device_capability() < (9, 0):
warnings.warn(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
)
pytest.skip()

linear_dtype = torch.bfloat16
x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
config = _recipe_name_to_linear_config(recipe_name)
self._test_linear_impl(
x,
m_ref,
Expand Down
Loading

0 comments on commit 9c4fd9d

Please sign in to comment.