[wip] make scaling configurable by gemm-argument
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 3eaa2df5258fc3795eaf9a86c892248889d06cb2
ghstack-comment-id: 2372563439
Pull Request resolved: #940
vkuzo committed Sep 27, 2024
1 parent 823f64d commit f8dbc15
Showing 8 changed files with 566 additions and 239 deletions.
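For orientation before the diff: a Float8 linear layer runs three gemms (output, grad_input, grad_weight), each consuming two arguments, and this change lets each argument of each gemm carry its own scaling granularity. The sketch below is illustrative only and is simplified from the tuples unpacked in test_linear further down (it omits the original-precision flags those tuples also carry); the variable name is hypothetical and the torchao.float8.config import path is assumed.

# Illustrative sketch only (not this commit's API): one granularity per
# gemm argument in a Float8 linear forward/backward.
# Assumes ScalingGranularity lives in torchao.float8.config.
from torchao.float8.config import ScalingGranularity

AXISWISE = ScalingGranularity.AXISWISE
TENSORWISE = ScalingGranularity.TENSORWISE

example_granularities_by_gemm = [
    # gemm 1, output = input @ weight_t:           (input, weight)
    (AXISWISE, AXISWISE),
    # gemm 2, grad_input = grad_output @ weight:   (grad_output, weight)
    (AXISWISE, TENSORWISE),
    # gemm 3, grad_weight = input_t @ grad_output: (input, grad_output)
    (TENSORWISE, TENSORWISE),
]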
92 changes: 55 additions & 37 deletions benchmarks/float8/profile_linear_float8.py
@@ -33,6 +33,10 @@
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from torchao.testing.float8.test_utils import (
scaling_granularities_by_gemm_lcw_recipe,
get_test_float8_linear_config,
)
from torch.profiler import profile, ProfilerActivity, record_function
from utils import (
kernel_name_to_category,
@@ -258,6 +262,8 @@ def main(
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
scaling_granularity: str = "tensorwise",
# TODO(future PR): clean up the override, it's confusing
recipe_override: Optional[str] = None,
model_type: str = "linear",
dtype_filter: str = "both",
add_inductor_metadata_to_trace: bool = True,
@@ -271,45 +277,57 @@
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
scaling_granularity = ScalingGranularity(scaling_granularity)

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
if recipe_override is None:

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
cast_config_weight=cast_config_weight,
cast_config_grad_output=cast_config_grad_output,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
cast_config_weight=cast_config_weight,
cast_config_grad_output=cast_config_grad_output,
)
elif recipe_override == "lcw":
scaling_granularities_by_gemm = scaling_granularities_by_gemm_lcw_recipe
config = get_test_float8_linear_config(
scaling_type_input,
scaling_type_weight,
scaling_type_grad_output,
scaling_granularities_by_gemm,
False, # emulate
)

scaling_repr = "_".join(
[
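Aside: the three STATIC/non-STATIC branches above follow one pattern, so a small helper could express them once. This is a hedged refactor sketch, not part of the commit; the CastConfig fields are exactly the ones used in main() above, and the import path is assumed.

# Hypothetical helper, not in this commit. Assumes CastConfig, ScalingType,
# and ScalingGranularity are importable from torchao.float8.config.
import torch
from torchao.float8.config import CastConfig, ScalingGranularity, ScalingType

def make_cast_config(
    scaling_type: ScalingType,
    scaling_granularity: ScalingGranularity,
) -> CastConfig:
    # Static scaling needs a placeholder scale tensor, mirroring the
    # branches in main(); dynamic and delayed scaling do not.
    if scaling_type is ScalingType.STATIC:
        return CastConfig(
            scaling_type=scaling_type,
            scaling_granularity=scaling_granularity,
            static_scale=torch.tensor([1.0], device="cuda"),
        )
    return CastConfig(
        scaling_type=scaling_type,
        scaling_granularity=scaling_granularity,
    )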
129 changes: 74 additions & 55 deletions test/float8/test_base.py
@@ -8,6 +8,7 @@
import itertools
import random
import re
from typing import List, Tuple
import unittest
import warnings

@@ -51,6 +52,10 @@
FP8_TYPES,
tensor_to_scale,
)
from torchao.testing.float8.test_utils import (
scaling_granularities_by_gemm,
get_test_float8_linear_config,
)

random.seed(0)
torch.manual_seed(0)
@@ -59,6 +64,8 @@
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)



def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
assert torch.all(a._scale == b._scale).item(), "scales are not identical"
assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -211,31 +218,52 @@ def test_axiswise_reshape(self):
a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)

@pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
@pytest.mark.parametrize(
"a_granularity,b_granularity",
[
(ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE),
(ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE),
(ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE),
]
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
def test_axiswise_gemm(self, a_shape):
def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")

linear_mm_config = LinearMMConfig()

if a_granularity is ScalingGranularity.AXISWISE:
a_axiswise_dim = -1
else:
assert a_granularity is ScalingGranularity.TENSORWISE
a_axiswise_dim = None
a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-1,
scaling_granularity=a_granularity,
axiswise_dim=a_axiswise_dim,
)
a_fp8 = a_fp8.reshape(-1, a_shape[-1])

b_axiswise_dim = 1 if b_granularity is ScalingGranularity.AXISWISE else None
if b_granularity is ScalingGranularity.AXISWISE:
b_axiswise_dim = 1 # will be transposed
else:
assert b_granularity is ScalingGranularity.TENSORWISE
b_axiswise_dim = None
b_fp8 = hp_tensor_to_float8_dynamic(
b,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=1, # will be transposed
scaling_granularity=b_granularity,
axiswise_dim=b_axiswise_dim,
)

c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
a = a.reshape(-1, a_shape[-1])
c_ref = torch.mm(a, b.t())
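For intuition on the new parametrization above: tensorwise scaling derives a single scale from the whole tensor, while axiswise scaling keeps one scale per slice along axiswise_dim. A minimal plain-PyTorch sketch of the idea follows; it is not torchao's implementation and omits the clamping/eps handling a real kernel needs.

import torch

x = torch.randn(16, 32, dtype=torch.bfloat16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max

# Tensorwise: one scale for the whole tensor.
scale_tensorwise = fp8_max / x.abs().max().float()

# Axiswise: one scale per row, reduced over the last dim (keepdim so the
# scale broadcasts against x).
scale_axiswise = fp8_max / x.abs().amax(dim=-1, keepdim=True).float()

x_fp8_tensorwise = (x.float() * scale_tensorwise).to(torch.float8_e4m3fn)
x_fp8_axiswise = (x.float() * scale_axiswise).to(torch.float8_e4m3fn)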
@@ -316,26 +344,33 @@ def _test_linear_impl(
# verify initialization flags got updated
assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
# @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@pytest.mark.parametrize("emulate", [False] if is_cuda_8_9 else [True])
# @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize("x_shape", [(16, 16),])
@pytest.mark.parametrize(
"scaling_type_input",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
# [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
[ScalingType.DYNAMIC]
)
@pytest.mark.parametrize(
"scaling_type_weight",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
# [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
[ScalingType.DYNAMIC]
)
@pytest.mark.parametrize(
"scaling_type_grad_output",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
# [ScalingType.DELAYED, ScalingType.DYNAMIC],
[ScalingType.DYNAMIC]
)
@pytest.mark.parametrize(
"scaling_granularity",
[ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE],
"scaling_granularities_by_gemm",
scaling_granularities_by_gemm
)
@pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("linear_bias", [False, True])
# @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("linear_dtype", [torch.bfloat16, ])
# @pytest.mark.parametrize("linear_bias", [False, True])
@pytest.mark.parametrize("linear_bias", [False, ])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_linear(
self,
@@ -344,7 +379,7 @@ def test_linear(
scaling_type_input: ScalingType,
scaling_type_weight: ScalingType,
scaling_type_grad_output: ScalingType,
scaling_granularity: ScalingGranularity,
scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]],
linear_dtype: torch.dtype,
linear_bias: bool,
):
@@ -357,7 +392,23 @@
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
)
pytest.skip()
if scaling_granularity is ScalingGranularity.AXISWISE:

(
(scaling_granularity_input, scaling_granularity_weight, original_prec_input, original_prec_weight),
(scaling_granularity_grad_output, scaling_granularity_weight_for_grad_input, original_prec_grad_output, original_prec_weight_for_grad_input),
(scaling_granularity_input_for_grad_weight, scaling_granularity_grad_output_for_grad_weight, original_prec_input_for_grad_weight, original_prec_grad_output_for_grad_weight),
) = scaling_granularities_by_gemm

has_any_axiswise_scaling = (
scaling_granularity_input is ScalingGranularity.AXISWISE or
scaling_granularity_weight is ScalingGranularity.AXISWISE or
scaling_granularity_grad_output is ScalingGranularity.AXISWISE or
scaling_granularity_input_for_grad_weight is ScalingGranularity.AXISWISE or
scaling_granularity_weight_for_grad_input is ScalingGranularity.AXISWISE or
scaling_granularity_grad_output_for_grad_weight is ScalingGranularity.AXISWISE
)

if has_any_axiswise_scaling:
if (
scaling_type_input != ScalingType.DYNAMIC or
scaling_type_weight != ScalingType.DYNAMIC or
@@ -370,46 +421,14 @@
x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)

if scaling_type_input is ScalingType.STATIC:
cast_config_input = CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_input = CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight = CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_weight = CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output = CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_grad_output = CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
cast_config_weight=cast_config_weight,
cast_config_grad_output=cast_config_grad_output,
emulate=emulate,
config = get_test_float8_linear_config(
scaling_type_input,
scaling_type_weight,
scaling_type_grad_output,
scaling_granularities_by_gemm,
emulate,
)

self._test_linear_impl(
x,
m_ref,
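The get_test_float8_linear_config helper and the recipe lists come from torchao.testing.float8.test_utils, one of the changed files not rendered on this page. Based only on the call sites above, a hedged usage sketch follows; the indexing and argument order mirror test_linear, the import paths are assumptions, and the helper's exact behavior is not shown in this diff.

from torchao.float8.config import ScalingType
from torchao.testing.float8.test_utils import (
    get_test_float8_linear_config,
    scaling_granularities_by_gemm,
)

# Each element of scaling_granularities_by_gemm appears to be one recipe:
# three (granularity_a, granularity_b, original_prec_a, original_prec_b)
# tuples, one per gemm, as unpacked in test_linear above.
recipe = scaling_granularities_by_gemm[0]

config = get_test_float8_linear_config(
    ScalingType.DYNAMIC,  # scaling_type_input
    ScalingType.DYNAMIC,  # scaling_type_weight
    ScalingType.DYNAMIC,  # scaling_type_grad_output
    recipe,
    False,  # emulate
)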