diff --git a/benchmarks/fp8_matmul.py b/benchmarks/fp8_matmul.py index 0217fdf..0323cb6 100644 --- a/benchmarks/fp8_matmul.py +++ b/benchmarks/fp8_matmul.py @@ -1,6 +1,6 @@ import itertools -from dataclasses import dataclass -from typing import List, Optional +from dataclasses import dataclass, replace +from typing import List, Optional, Union import torch from tabulate import tabulate import pandas as pd @@ -44,7 +44,10 @@ def is_col_major(stride): def get_fp8_matmul( - A: torch.Tensor, B: torch.Tensor, scaling_strategy: ScalingStrategy, fp8_kernel: FP8Kernel + A: torch.Tensor, + B: torch.Tensor, + scaling_strategy: ScalingStrategy, + fp8_kernel: FP8Kernel, ): A_fp8 = A.to(torch.float8_e4m3fn) B_fp8 = B.to(torch.float8_e4m3fn) @@ -66,7 +69,12 @@ def get_fp8_matmul( return lambda: matmul_device_tma_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16) elif fp8_kernel == FP8Kernel.SCALED_MM: return lambda: addmm_float8_unwrapped_inference( - A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True + A_fp8, + a_scale, + B_fp8, + b_scale, + output_dtype=torch.bfloat16, + use_fast_accum=True, ) else: raise ValueError(f"Invalid FP8 kernel: {fp8_kernel}") @@ -80,13 +88,14 @@ class ExperimentConfig: scaling_strategy: ScalingStrategy fp8_kernel: FP8Kernel compile: bool + calc_bfloat16: bool = False @dataclass(frozen=True) class ExperimentResult: - bf16_time: float + bf16_time: Optional[float] fp8_time: float - bf16_tflops: float + bf16_tflops: Optional[float] fp8_tflops: float @@ -103,42 +112,53 @@ def calculate_tflops(M: int, N: int, K: int, time_us: float) -> float: return tflops -def run_experiment(config: ExperimentConfig) -> ExperimentResult: +def run_experiment(config: ExperimentConfig, check_correctness: bool = False) -> ExperimentResult: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16) B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16) - bf16_matmul = lambda x, y: torch.matmul(x, y) + bf16_matmul = lambda x, y: torch.matmul(x, y) if config.calc_bfloat16 else None fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel) if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM: - bf16_matmul = torch.compile(bf16_matmul) + if config.calc_bfloat16: + bf16_matmul = torch.compile(bf16_matmul) fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune") # Warmup phase warmup_iterations = 5 for _ in range(warmup_iterations): - _ = bf16_matmul(A, B) + if config.calc_bfloat16: + _ = bf16_matmul(A, B) _ = fp8_matmul() torch.cuda.synchronize() # Actual benchmarking - bf16_time = benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B)) + bf16_time = ( + benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B)) + if config.calc_bfloat16 + else None + ) fp8_time = benchmark_cuda_function_in_microseconds(fp8_matmul) # Calculate TFLOPS - bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time) + bf16_tflops = ( + calculate_tflops(config.M, config.N, config.K, bf16_time) if config.calc_bfloat16 else None + ) fp8_tflops = calculate_tflops(config.M, config.N, config.K, fp8_time) # Baseline fp8_matmul correctness - scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM) - out_base = scaled_mm_base() - out = fp8_matmul() - # Failing on one sample with large N - torch.testing.assert_close(out, out_base) + if check_correctness: + scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM) + out_base = scaled_mm_base() + out = fp8_matmul() + torch.testing.assert_close(out, out_base) return ExperimentResult( - bf16_time=bf16_time, fp8_time=fp8_time, bf16_tflops=bf16_tflops, fp8_tflops=fp8_tflops + bf16_time=bf16_time, + fp8_time=fp8_time, + bf16_tflops=bf16_tflops, + fp8_tflops=fp8_tflops, ) @@ -161,8 +181,10 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = Non for experiment in experiments: config = experiment.config result = experiment.result - speedup = result.bf16_time / result.fp8_time - tflops_ratio = result.fp8_tflops / result.bf16_tflops + speedup = result.bf16_time / result.fp8_time if result.bf16_time is not None else None + tflops_ratio = ( + result.fp8_tflops / result.bf16_tflops if result.bf16_tflops is not None else None + ) rows.append( [ config.M, @@ -171,12 +193,12 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = Non config.scaling_strategy, config.fp8_kernel, config.compile, - f"{result.bf16_time:.4f}", + f"{result.bf16_time:.4f}" if result.bf16_time is not None else "N/A", f"{result.fp8_time:.4f}", - f"{speedup:.2f}x", - f"{result.bf16_tflops:.2f}", + f"{speedup:.2f}x" if speedup is not None else "N/A", + f"{result.bf16_tflops:.2f}" if result.bf16_tflops is not None else "N/A", f"{result.fp8_tflops:.2f}", - f"{tflops_ratio:.2f}x", + f"{tflops_ratio:.2f}x" if tflops_ratio is not None else "N/A", ] ) print(tabulate(rows, headers=headers, floatfmt=".4f")) @@ -206,7 +228,43 @@ def get_configs_varying_k(M: int = 8192, N: int = 8192) -> List[ExperimentConfig ): configs.append( ExperimentConfig( - M=M, K=K, N=N, scaling_strategy=strategy, compile=compile, fp8_kernel=kernel + M=M, + K=K, + N=N, + scaling_strategy=strategy, + compile=compile, + fp8_kernel=kernel, + ) + ) + return configs + + +def get_configs_varying_k_big() -> List[ExperimentConfig]: + M = [1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384] + K = [1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384] + N = [1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384] + shapes = itertools.product(M, K, N) + scaling_strategies = [ScalingStrategy.PER_ROW] + compile_options = [False] + configs = [] + fp8_kernels = [ + FP8Kernel.SCALED_MM, + # FP8Kernel.PERSISTENT, + # FP8Kernel.PERSISTENT_TMA, + FP8Kernel.DEVICE_TMA, + ] + + for (M, K, N), strategy, compile, kernel in itertools.product( + shapes, scaling_strategies, compile_options, fp8_kernels + ): + configs.append( + ExperimentConfig( + M=M, + K=K, + N=N, + scaling_strategy=strategy, + compile=compile, + fp8_kernel=kernel, ) ) return configs @@ -250,17 +308,33 @@ def plot_tflops_comparison(df, save_path: Path): print(f"TFLOPS comparison plot saved as {graph_path}") -def main(save_path: Optional[str] = None, M: int = 8192, N: int = 8192, graph: bool = False): +def main( + save_path: Optional[Union[Path, str]] = None, + M: int = 8192, + N: int = 8192, + graph: bool = False, + big: bool = False, + calc_bfloat16: bool = True, # New parameter added +): """Benchmark FP8 MatMul with different configurations and optionally graph results. Args: save_path (Optional[str], optional): Path to save the results. Defaults to None. M (int, optional): Number of rows in the first matrix. Defaults to 8192. N (int, optional): Number of columns in the second matrix. Defaults to 8192. - graph_results (bool, optional): Whether to create a graph of the results. Defaults to False. + graph (bool, optional): Whether to create a graph of the results. Defaults to False. + big (bool, optional): Whether to use larger matrix sizes. Defaults to False. + calc_bfloat16 (bool, optional): Whether to calculate bfloat16 results. Defaults to True. """ torch.random.manual_seed(123) - configs = get_configs_varying_k(M, N) + if not big: + configs = get_configs_varying_k(M, N) + else: + configs = get_configs_varying_k_big() + + # Update configs with calc_bfloat16 option + configs = [replace(config, calc_bfloat16=calc_bfloat16) for config in configs] + results = [] if save_path is not None: save_path = Path(save_path) diff --git a/transformer_nuggets/utils/benchmark.py b/transformer_nuggets/utils/benchmark.py index 32a0574..523059d 100644 --- a/transformer_nuggets/utils/benchmark.py +++ b/transformer_nuggets/utils/benchmark.py @@ -55,7 +55,7 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float: """Thin wrapper around do_bench_using_profiling""" no_args = lambda: func(*args, **kwargs) - time = do_bench_using_profiling(no_args) + time = do_bench_using_profiling(no_args, rep=25) return time * 1e3