
Commit

small tweaks
drisspg committed Oct 11, 2024
1 parent 3f7e8a9 commit f75fbb4
Showing 2 changed files with 103 additions and 29 deletions.
130 changes: 102 additions & 28 deletions benchmarks/fp8_matmul.py
@@ -1,6 +1,6 @@
import itertools
-from dataclasses import dataclass
-from typing import List, Optional
+from dataclasses import dataclass, replace
+from typing import List, Optional, Union
import torch
from tabulate import tabulate
import pandas as pd
@@ -44,7 +44,10 @@ def is_col_major(stride):


def get_fp8_matmul(
-    A: torch.Tensor, B: torch.Tensor, scaling_strategy: ScalingStrategy, fp8_kernel: FP8Kernel
+    A: torch.Tensor,
+    B: torch.Tensor,
+    scaling_strategy: ScalingStrategy,
+    fp8_kernel: FP8Kernel,
):
A_fp8 = A.to(torch.float8_e4m3fn)
B_fp8 = B.to(torch.float8_e4m3fn)
@@ -66,7 +69,12 @@ def get_fp8_matmul(
return lambda: matmul_device_tma_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16)
elif fp8_kernel == FP8Kernel.SCALED_MM:
return lambda: addmm_float8_unwrapped_inference(
-            A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
+            A_fp8,
+            a_scale,
+            B_fp8,
+            b_scale,
+            output_dtype=torch.bfloat16,
+            use_fast_accum=True,
)
else:
raise ValueError(f"Invalid FP8 kernel: {fp8_kernel}")
@@ -80,13 +88,14 @@ class ExperimentConfig:
scaling_strategy: ScalingStrategy
fp8_kernel: FP8Kernel
compile: bool
calc_bfloat16: bool = False


@dataclass(frozen=True)
class ExperimentResult:
-    bf16_time: float
+    bf16_time: Optional[float]
fp8_time: float
-    bf16_tflops: float
+    bf16_tflops: Optional[float]
fp8_tflops: float


@@ -103,42 +112,53 @@ def calculate_tflops(M: int, N: int, K: int, time_us: float) -> float:
return tflops


-def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+def run_experiment(config: ExperimentConfig, check_correctness: bool = False) -> ExperimentResult:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16)
B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16)

-    bf16_matmul = lambda x, y: torch.matmul(x, y)
+    bf16_matmul = lambda x, y: torch.matmul(x, y) if config.calc_bfloat16 else None
fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel)

if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM:
-        bf16_matmul = torch.compile(bf16_matmul)
+        if config.calc_bfloat16:
+            bf16_matmul = torch.compile(bf16_matmul)
fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune")

# Warmup phase
warmup_iterations = 5
for _ in range(warmup_iterations):
-        _ = bf16_matmul(A, B)
+        if config.calc_bfloat16:
+            _ = bf16_matmul(A, B)
_ = fp8_matmul()
torch.cuda.synchronize()

# Actual benchmarking
-    bf16_time = benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B))
+    bf16_time = (
+        benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B))
+        if config.calc_bfloat16
+        else None
+    )
fp8_time = benchmark_cuda_function_in_microseconds(fp8_matmul)

# Calculate TFLOPS
-    bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time)
+    bf16_tflops = (
+        calculate_tflops(config.M, config.N, config.K, bf16_time) if config.calc_bfloat16 else None
+    )
fp8_tflops = calculate_tflops(config.M, config.N, config.K, fp8_time)

# Baseline fp8_matmul correctness
-    scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM)
-    out_base = scaled_mm_base()
-    out = fp8_matmul()
-    # Failing on one sample with large N
-    torch.testing.assert_close(out, out_base)
+    if check_correctness:
+        scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM)
+        out_base = scaled_mm_base()
+        out = fp8_matmul()
+        torch.testing.assert_close(out, out_base)

return ExperimentResult(
-        bf16_time=bf16_time, fp8_time=fp8_time, bf16_tflops=bf16_tflops, fp8_tflops=fp8_tflops
+        bf16_time=bf16_time,
+        fp8_time=fp8_time,
+        bf16_tflops=bf16_tflops,
+        fp8_tflops=fp8_tflops,
)


@@ -161,8 +181,10 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = Non
for experiment in experiments:
config = experiment.config
result = experiment.result
-        speedup = result.bf16_time / result.fp8_time
-        tflops_ratio = result.fp8_tflops / result.bf16_tflops
+        speedup = result.bf16_time / result.fp8_time if result.bf16_time is not None else None
+        tflops_ratio = (
+            result.fp8_tflops / result.bf16_tflops if result.bf16_tflops is not None else None
+        )
rows.append(
[
config.M,
Expand All @@ -171,12 +193,12 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = Non
config.scaling_strategy,
config.fp8_kernel,
config.compile,
f"{result.bf16_time:.4f}",
f"{result.bf16_time:.4f}" if result.bf16_time is not None else "N/A",
f"{result.fp8_time:.4f}",
f"{speedup:.2f}x",
f"{result.bf16_tflops:.2f}",
f"{speedup:.2f}x" if speedup is not None else "N/A",
f"{result.bf16_tflops:.2f}" if result.bf16_tflops is not None else "N/A",
f"{result.fp8_tflops:.2f}",
f"{tflops_ratio:.2f}x",
f"{tflops_ratio:.2f}x" if tflops_ratio is not None else "N/A",
]
)
print(tabulate(rows, headers=headers, floatfmt=".4f"))
@@ -206,7 +228,43 @@ def get_configs_varying_k(M: int = 8192, N: int = 8192) -> List[ExperimentConfig
):
configs.append(
ExperimentConfig(
-                M=M, K=K, N=N, scaling_strategy=strategy, compile=compile, fp8_kernel=kernel
+                M=M,
+                K=K,
+                N=N,
+                scaling_strategy=strategy,
+                compile=compile,
+                fp8_kernel=kernel,
)
)
return configs


def get_configs_varying_k_big() -> List[ExperimentConfig]:
M = [1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384]
K = [1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384]
N = [1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384]
shapes = itertools.product(M, K, N)
scaling_strategies = [ScalingStrategy.PER_ROW]
compile_options = [False]
configs = []
fp8_kernels = [
FP8Kernel.SCALED_MM,
# FP8Kernel.PERSISTENT,
# FP8Kernel.PERSISTENT_TMA,
FP8Kernel.DEVICE_TMA,
]

for (M, K, N), strategy, compile, kernel in itertools.product(
shapes, scaling_strategies, compile_options, fp8_kernels
):
configs.append(
ExperimentConfig(
M=M,
K=K,
N=N,
scaling_strategy=strategy,
compile=compile,
fp8_kernel=kernel,
)
)
return configs
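As a quick sanity check on the size of this new sweep (and a plausible motivation for the reduced profiling budget in transformer_nuggets/utils/benchmark.py below), the product of the option lists in get_configs_varying_k_big works out as follows; the numbers are taken directly from the lists above:

```python
# 9 choices each for M, K, N; one scaling strategy; compile disabled; two kernels.
num_configs = 9 * 9 * 9 * 1 * 1 * 2
print(num_configs)  # 1458 experiment configurations
```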
@@ -250,17 +308,33 @@ def plot_tflops_comparison(df, save_path: Path):
print(f"TFLOPS comparison plot saved as {graph_path}")


-def main(save_path: Optional[str] = None, M: int = 8192, N: int = 8192, graph: bool = False):
+def main(
+    save_path: Optional[Union[Path, str]] = None,
+    M: int = 8192,
+    N: int = 8192,
+    graph: bool = False,
+    big: bool = False,
+    calc_bfloat16: bool = True,  # New parameter added
+):
"""Benchmark FP8 MatMul with different configurations and optionally graph results.
Args:
save_path (Optional[str], optional): Path to save the results. Defaults to None.
M (int, optional): Number of rows in the first matrix. Defaults to 8192.
N (int, optional): Number of columns in the second matrix. Defaults to 8192.
-        graph_results (bool, optional): Whether to create a graph of the results. Defaults to False.
+        graph (bool, optional): Whether to create a graph of the results. Defaults to False.
big (bool, optional): Whether to use larger matrix sizes. Defaults to False.
calc_bfloat16 (bool, optional): Whether to calculate bfloat16 results. Defaults to True.
"""
torch.random.manual_seed(123)
-    configs = get_configs_varying_k(M, N)
+    if not big:
+        configs = get_configs_varying_k(M, N)
+    else:
+        configs = get_configs_varying_k_big()

# Update configs with calc_bfloat16 option
configs = [replace(config, calc_bfloat16=calc_bfloat16) for config in configs]

results = []
if save_path is not None:
save_path = Path(save_path)
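A minimal sketch of the dataclasses.replace pattern that the updated main() uses to push the new calc_bfloat16 flag into every generated config; ToyConfig and toy_configs are simplified stand-ins for ExperimentConfig and the config generators, not code from this file:

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class ToyConfig:
    M: int
    K: int
    N: int
    calc_bfloat16: bool = False


def toy_configs() -> List[ToyConfig]:
    return [ToyConfig(M=1024, K=k, N=1024) for k in (1024, 2048, 4096)]


# replace() copies each frozen config with one field overridden, mirroring
# `configs = [replace(config, calc_bfloat16=calc_bfloat16) for config in configs]`.
configs = [replace(cfg, calc_bfloat16=True) for cfg in toy_configs()]
assert all(cfg.calc_bfloat16 for cfg in configs)
```

With this in place, a call such as `main(M=4096, N=4096, big=True, calc_bfloat16=False)` would run only the FP8 kernels over the large shape sweep and report "N/A" for the bf16 columns.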
2 changes: 1 addition & 1 deletion transformer_nuggets/utils/benchmark.py
@@ -55,7 +55,7 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->
def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
"""Thin wrapper around do_bench_using_profiling"""
no_args = lambda: func(*args, **kwargs)
-    time = do_bench_using_profiling(no_args)
+    time = do_bench_using_profiling(no_args, rep=25)
return time * 1e3


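A rough usage sketch of the timing helper as the FP8 benchmark drives it, assuming a CUDA device and that the module imports as transformer_nuggets.utils.benchmark; the TFLOPS arithmetic uses the standard 2*M*N*K GEMM FLOP count, which the (elided) calculate_tflops presumably also uses, and lowering rep to 25 is understood here as shrinking the profiling budget per measurement so the bigger sweeps finish sooner:

```python
import torch

from transformer_nuggets.utils.benchmark import benchmark_cuda_function_in_microseconds

if torch.cuda.is_available():
    M = N = K = 4096
    A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

    # The helper forwards *args/**kwargs to the function it times and returns microseconds.
    time_us = benchmark_cuda_function_in_microseconds(torch.matmul, A, B)

    tflops = (2 * M * N * K) / (time_us * 1e-6) / 1e12  # 2*M*N*K FLOPs per GEMM
    print(f"bf16 matmul: {time_us:.1f} us, {tflops:.1f} TFLOPS")
```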
