float8nocompile: add benchmark script
danielvegamyhre committed Dec 20, 2024
1 parent 29de3e0 commit e74d63a
Showing 1 changed file with 177 additions and 0 deletions.
torchao/prototype/float8nocompile/benchmark/benchmark.py
@@ -0,0 +1,177 @@
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

import itertools
from dataclasses import dataclass
from typing import Callable, List

import torch
from tabulate import tabulate
from torch import nn
from torch._inductor.utils import do_bench_using_profiling
from torch.nn import functional as F
from tqdm import tqdm

from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
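    """Parameters for one benchmark run: the high precision dtype to cast
    from, the model layer sizes, and the input shape."""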
high_precision_dtype: torch.dtype
layer_sizes: list[int]
    input_shape: tuple[int, int]


@dataclass(frozen=True)
class ExperimentResult:
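    """Timings for each implementation, in microseconds."""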
float8nocompile_time: float
eager_time: float
compiled_time: float


@dataclass(frozen=True)
class Experiment:
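    """Pairs a config with the timings it produced."""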
config: ExperimentConfig
result: ExperimentResult


class TestModel(nn.Module):
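    """Simple stack of bias-free linear layers used as the benchmark model."""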
    def __init__(self, layer_sizes=(32, 64, 32)):
super().__init__()
self.layers = nn.Sequential(
*[
nn.Linear(layer_sizes[i], layer_sizes[i + 1], bias=False)
for i in range(len(layer_sizes) - 1)
]
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layers(x)


def get_configs() -> List[ExperimentConfig]:
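    """Build one config per combination of layer sizes, input shape, and
    high precision dtype."""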
layer_sizes = [[4096, 4096]]
input_shapes = [(2**4, 4096), (2**8, 4096), (2**12, 4096), (2**16, 4096)]
high_precision_dtypes = [torch.float32, torch.bfloat16]
configs = []
for layer_size, input_shape, high_precision_dtype in itertools.product(
layer_sizes, input_shapes, high_precision_dtypes
):
configs.append(
ExperimentConfig(
layer_sizes=layer_size,
input_shape=input_shape,
high_precision_dtype=high_precision_dtype,
)
)
return configs


def forward_backward(model, input_tensor):
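    """Run one forward pass plus a backward pass against a dummy MSE loss."""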
output = model(input_tensor)
loss = F.mse_loss(output, torch.zeros_like(output))
loss.backward()


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
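    """Benchmark eager, compiled, and float8nocompile variants of the model."""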
# eager float8 baseline
eager_float8_model = convert_to_float8_training(
TestModel(config.layer_sizes).to(device)
)

# compiled float8 baseline
compiled_float8_model = torch.compile(eager_float8_model, fullgraph=True)

# float8nocompile triton implementation
float8nocompile_model = convert_to_float8_nocompile_training(
TestModel(config.layer_sizes).to(device)
)

# define test inputs
input_tensor = torch.randn(
*config.input_shape,
requires_grad=True,
dtype=config.high_precision_dtype,
device=device,
)
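    # give each model its own leaf tensor so gradients don't leak across runs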
input_eager = input_tensor.clone().detach().requires_grad_(True)
input_compiled = input_tensor.clone().detach().requires_grad_(True)
input_triton = input_tensor.clone().detach().requires_grad_(True)

# benchmark forward + backward for each model
eager_time = benchmark_cuda_function_in_microseconds(
forward_backward,
eager_float8_model,
input_eager,
)

compiled_time = benchmark_cuda_function_in_microseconds(
forward_backward,
compiled_float8_model,
input_compiled,
)

float8nocompile_time = benchmark_cuda_function_in_microseconds(
forward_backward,
float8nocompile_model,
input_triton,
)

return ExperimentResult(
eager_time=eager_time,
compiled_time=compiled_time,
float8nocompile_time=float8nocompile_time,
)


def print_results(experiments: List[Experiment]):
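    """Print one row per experiment; all timings are in microseconds."""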
headers = [
"input_size (numel)",
"high_precision_dtype",
"eager_time (us)",
"compiled_time (us)",
"float8nocompile_time (us)",
]
rows = []
for experiment in experiments:
input_size = experiment.config.input_shape[0] * experiment.config.input_shape[1]
rows.append(
[
f"{input_size:.2e}",
experiment.config.high_precision_dtype,
experiment.result.eager_time,
experiment.result.compiled_time,
experiment.result.float8nocompile_time,
]
)
print(tabulate(rows, headers=headers))


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
"""Thin wrapper around do_bench_using_profiling"""
no_args = lambda: func(*args, **kwargs)
time = do_bench_using_profiling(no_args)
return time * 1e3


def main():
torch.random.manual_seed(123)
configs = get_configs()
results = []
for config in tqdm(configs):
result = run_experiment(config)
results.append(Experiment(config=config, result=result))

    # print the results table with tabulate
print_results(results)


if __name__ == "__main__":
main()
