
Commit

address comments
danielvegamyhre committed Dec 23, 2024
1 parent 1c06b47 commit 2bf8656
Showing 2 changed files with 8 additions and 6 deletions.
torchao/prototype/float8nocompile/benchmark/benchmark.py (10 changes: 6 additions & 4 deletions)
@@ -59,7 +59,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 def get_configs() -> List[ExperimentConfig]:
     layer_sizes = [[4096, 4096]]
     input_shapes = [(2**4, 4096), (2**8, 4096), (2**12, 4096), (2**16, 4096)]
-    high_precision_dtypes = [torch.float32, torch.bfloat16]
+    high_precision_dtypes = [torch.bfloat16]
     configs = []
     for layer_size, input_shape, high_precision_dtype in itertools.product(
         layer_sizes, input_shapes, high_precision_dtypes
@@ -133,18 +133,20 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:

 def print_results(experiments: List[Experiment]):
     headers = [
-        "input_size",
+        "input_shape",
         "high_precision_dtype",
         "eager_time",
         "compiled_time",
         "float8nocompile",
     ]
     rows = []
     for experiment in experiments:
-        input_size = experiment.config.input_shape[0] * experiment.config.input_shape[1]
+        input_shape = (
+            f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})"
+        )
         rows.append(
             [
-                f"{input_size:.2e}",
+                input_shape,
                 experiment.config.high_precision_dtype,
                 experiment.result.eager_time,
                 experiment.result.compiled_time,
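The print_results change reports the literal input shape instead of the flattened element count, so a row reads "(4096, 4096)" rather than "1.68e+07". A minimal sketch of the before/after formatting, using an assumed ExperimentConfig-like stand-in rather than the repo's actual class:

    from dataclasses import dataclass
    from typing import Tuple

    import torch

    @dataclass
    class ExperimentConfig:  # field names inferred from the diff; other fields omitted
        input_shape: Tuple[int, int]
        high_precision_dtype: torch.dtype

    config = ExperimentConfig(input_shape=(2**12, 4096), high_precision_dtype=torch.bfloat16)

    # Old column value: flattened element count in scientific notation.
    input_size = config.input_shape[0] * config.input_shape[1]
    print(f"{input_size:.2e}")   # 1.68e+07

    # New column value: the actual (rows, cols) of the input.
    input_shape = f"({config.input_shape[0]}, {config.input_shape[1]})"
    print(input_shape)           # (4096, 4096)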
Second changed file (2 additions & 2 deletions):
@@ -65,7 +65,7 @@ def _block_amax_atomic(
     block_offs = block_start + tl.arange(0, BLOCK_SIZE)
     block_mask = block_offs < num_elements
     vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
-    block_amax = tl.max(tl.abs(vals), axis=0)
+    block_amax = tl.max(tl.abs(vals))
     tl.atomic_max(amax_ptr, block_amax)


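Dropping the explicit axis=0 does not change behavior here: for a 1-D block of values, tl.max with no axis reduces over all elements, which gives the same result. In this atomic variant each program reduces |vals| over its BLOCK_SIZE chunk and folds the result into one global amax via tl.atomic_max, equivalent to a single tensorwise reduction. A plain-PyTorch reference sketch of what the kernel computes (not the repo's code):

    import torch

    def reference_global_amax(x: torch.Tensor) -> torch.Tensor:
        # Each Triton program computes max(|vals|) over its chunk; tl.atomic_max
        # then folds the per-block maxima into one scalar. Net effect:
        return x.abs().max()

    # For the atomic pattern to be correct, the buffer behind amax_ptr presumably
    # starts at 0.0 before launch (safe here, since |vals| >= 0).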
@@ -124,7 +124,7 @@ def _block_amax_reduction(
     block_offs = block_start + tl.arange(0, BLOCK_SIZE)
     block_mask = block_offs < num_elements
     vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
-    block_amax = tl.max(tl.abs(vals), axis=0)
+    block_amax = tl.max(tl.abs(vals))
     tl.store(block_amaxes_ptr + block_id, block_amax)


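The reduction variant instead writes one amax per block to block_amaxes_ptr and leaves the final reduction to a second pass. A sketch of that two-pass shape in plain PyTorch, plus the usual tensorwise float8 scale derived from the amax; the helper names and the scale formula below are illustrative assumptions, not code from this commit:

    import torch
    import torch.nn.functional as F

    def two_pass_amax(x: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
        flat = x.abs().flatten()
        pad = (-flat.numel()) % block_size               # pad so the tail block is full
        flat = F.pad(flat, (0, pad))
        block_amaxes = flat.view(-1, block_size).max(dim=1).values  # pass 1: per-block amax
        return block_amaxes.max()                        # pass 2: reduce per-block maxima

    def amax_to_float8_scale(amax: torch.Tensor) -> torch.Tensor:
        # Standard tensorwise dynamic scaling: map amax onto the fp8 max representable value.
        return torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=1e-12)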
