From 2bf8656fd0a1392bc082f70a822e4df8c6d43681 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Mon, 23 Dec 2024 09:02:17 -0800
Subject: [PATCH] address comments

---
 .../prototype/float8nocompile/benchmark/benchmark.py       | 10 ++++++----
 .../float8nocompile/kernels/fp8_dynamic_tensorwise.py      |  4 ++--
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/torchao/prototype/float8nocompile/benchmark/benchmark.py b/torchao/prototype/float8nocompile/benchmark/benchmark.py
index 72ecd030c..e6c39cb43 100644
--- a/torchao/prototype/float8nocompile/benchmark/benchmark.py
+++ b/torchao/prototype/float8nocompile/benchmark/benchmark.py
@@ -59,7 +59,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 def get_configs() -> List[ExperimentConfig]:
     layer_sizes = [[4096, 4096]]
     input_shapes = [(2**4, 4096), (2**8, 4096), (2**12, 4096), (2**16, 4096)]
-    high_precision_dtypes = [torch.float32, torch.bfloat16]
+    high_precision_dtypes = [torch.bfloat16]
     configs = []
     for layer_size, input_shape, high_precision_dtype in itertools.product(
         layer_sizes, input_shapes, high_precision_dtypes
@@ -133,7 +133,7 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
 
 def print_results(experiments: List[Experiment]):
     headers = [
-        "input_size",
+        "input_shape",
         "high_precision_dtype",
         "eager_time",
         "compiled_time",
@@ -141,10 +141,12 @@
     ]
     rows = []
     for experiment in experiments:
-        input_size = experiment.config.input_shape[0] * experiment.config.input_shape[1]
+        input_shape = (
+            f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})"
+        )
         rows.append(
             [
-                f"{input_size:.2e}",
+                input_shape,
                 experiment.config.high_precision_dtype,
                 experiment.result.eager_time,
                 experiment.result.compiled_time,
diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
index eec4f2838..cedf68e1a 100644
--- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
+++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
@@ -65,7 +65,7 @@ def _block_amax_atomic(
     block_offs = block_start + tl.arange(0, BLOCK_SIZE)
     block_mask = block_offs < num_elements
     vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
-    block_amax = tl.max(tl.abs(vals), axis=0)
+    block_amax = tl.max(tl.abs(vals))
     tl.atomic_max(amax_ptr, block_amax)
 
 
@@ -124,7 +124,7 @@ def _block_amax_reduction(
     block_offs = block_start + tl.arange(0, BLOCK_SIZE)
     block_mask = block_offs < num_elements
     vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
-    block_amax = tl.max(tl.abs(vals), axis=0)
+    block_amax = tl.max(tl.abs(vals))
     tl.store(block_amaxes_ptr + block_id, block_amax)