Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add axiswise scaling to Float8Linear #920

Merged
20 commits merged on Oct 7, 2024
32 changes: 27 additions & 5 deletions benchmarks/float8/bench_linear_float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@

import torch
import torch.utils.benchmark as benchmark
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingType,
ScalingGranularity,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
linear_requires_sync,
Expand Down Expand Up @@ -107,35 +112,49 @@ def main(
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
scaling_granularity: str = "tensorwise",
):
device = "cuda"
print(f"Compile is set to | {compile}")

scaling_type_input = ScalingType(scaling_type_input)
scaling_type_weight = ScalingType(scaling_type_weight)
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
scaling_granularity = ScalingGranularity(scaling_granularity)

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_input=CastConfig(scaling_type=scaling_type_input)
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
Expand Down Expand Up @@ -167,7 +186,7 @@ def main(
copy.deepcopy(linear_ref),
config=config,
)
scaling_repr = linear_float8.scaling_repr()
scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}"

if fast_accum:
linear_float8.forward_config = ScaledMMConfig(False, True, False)
Expand Down Expand Up @@ -310,6 +329,7 @@ def invoke_main() -> None:
parser.add_argument("--scaling_type_input", type=str, required=False)
parser.add_argument("--scaling_type_weight", type=str, required=False)
parser.add_argument("--scaling_type_grad_output", type=str, required=False)
parser.add_argument("--scaling_granularity", type=str, required=False)
args = parser.parse_args()
output_path = Path(args.output_path) if args.output_path is not None else None
kwargs = {}
Expand All @@ -327,6 +347,8 @@ def invoke_main() -> None:
kwargs["scaling_type_weight"] = args.scaling_type_weight
if args.scaling_type_grad_output is not None:
kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
if args.scaling_granularity is not None:
kwargs["scaling_granularity"] = args.scaling_granularity
main(
output_path,
not args.disable_compile,
Expand Down
13 changes: 11 additions & 2 deletions benchmarks/float8/bench_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import torch.nn as nn
import torch.utils.benchmark as benchmark

from torchao.float8.config import ScalingGranularity

from utils import (
get_name_to_shapes_iter,
profiler_output_to_filtered_time_by_kernel_name,
Expand Down Expand Up @@ -75,6 +77,7 @@ def run(
K: Optional[int] = None,
N: Optional[int] = None,
use_gpu_kernel_time: bool = False,
scaling_granularity: str = "tensorwise",
):
device = "cuda"

Expand All @@ -84,6 +87,7 @@ def run(
dtype = torch.bfloat16
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
fast_accum_vals = [True, False]
scaling_granularity = ScalingGranularity(scaling_granularity)

for idx, (fast_accum, (name, (M, K, N))) in enumerate(itertools.product(fast_accum_vals, name_to_shapes)):
if n_limit is not None and idx >= n_limit:
Expand All @@ -109,8 +113,13 @@ def run(
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
A = torch.zeros(M, K, device=device, dtype=d1)
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)
if scaling_granularity == ScalingGranularity.TENSORWISE:
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)
else:
assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)

def do_matmul(A, B):
nonlocal scale_a
Expand Down
27 changes: 23 additions & 4 deletions benchmarks/float8/profile_linear_float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingType,
ScalingGranularity,
)
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
Expand Down Expand Up @@ -252,6 +257,7 @@ def main(
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
scaling_granularity: str = "tensorwise",
model_type: str = "linear",
dtype_filter: str = "both",
add_inductor_metadata_to_trace: bool = True,
Expand All @@ -263,28 +269,41 @@ def main(
scaling_type_input = ScalingType(scaling_type_input)
scaling_type_weight = ScalingType(scaling_type_weight)
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
scaling_granularity = ScalingGranularity(scaling_granularity)

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_input=CastConfig(scaling_type=scaling_type_input)
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
Expand Down
33 changes: 30 additions & 3 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ def _test_linear_impl(
"scaling_type_grad_output",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@pytest.mark.parametrize(
"scaling_granularity",
[ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE],
)
@pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("linear_bias", [False, True])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
Expand All @@ -334,33 +338,56 @@ def test_linear(
scaling_type_input: ScalingType,
scaling_type_weight: ScalingType,
scaling_type_grad_output: ScalingType,
scaling_granularity: ScalingGranularity,
linear_dtype: torch.dtype,
linear_bias: bool,
):
if scaling_granularity is ScalingGranularity.AXISWISE:
if (
scaling_type_input != ScalingType.DYNAMIC or
scaling_type_weight != ScalingType.DYNAMIC or
scaling_type_grad_output != ScalingType.DYNAMIC or
linear_dtype != torch.bfloat16 or
(not is_cuda_9_0)
):
pytest.skip()

x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)

if scaling_type_input is ScalingType.STATIC:
cast_config_input = CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_input = CastConfig(scaling_type=scaling_type_input)
cast_config_input = CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight = CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_weight = CastConfig(scaling_type=scaling_type_weight)
cast_config_weight = CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output = CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output)
cast_config_grad_output = CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
Expand Down
Loading
Loading