From 47383779c0975fd7ec213bc7b07c34ca5a1e5abb Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 8 Jan 2025 16:09:17 -0800 Subject: [PATCH] fix bug in tl.store mask for kernel _to_fp8_row_major_t_and_non_t (#1516) --- .../float8nocompile/kernels/fp8_dynamic_tensorwise.py | 4 ++-- torchao/prototype/float8nocompile/test/train_test.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py index 4400a587c..630e80e09 100644 --- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py +++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py @@ -375,8 +375,8 @@ def _to_fp8_row_major_t_and_non_t( block_col_offs[:, None] * row_major_t_out_stride_row + block_row_offs[None, :] * row_major_t_out_stride_col ) - mask = (block_row_offs[:, None] < row_major_t_num_rows) & ( - block_col_offs[None, :] < row_major_t_num_cols + mask = (block_col_offs[:, None] < row_major_t_num_rows) & ( + block_row_offs[None, :] < row_major_t_num_cols ) tl.store(row_major_t_out_ptr + row_major_t_offs, fp8_vals.trans(1, 0), mask=mask) diff --git a/torchao/prototype/float8nocompile/test/train_test.py b/torchao/prototype/float8nocompile/test/train_test.py index ccb219a21..40fc2787c 100644 --- a/torchao/prototype/float8nocompile/test/train_test.py +++ b/torchao/prototype/float8nocompile/test/train_test.py @@ -36,7 +36,9 @@ def model2(): return TestModel() -@pytest.mark.parametrize("input_shape", [(16, 32), (1, 16, 32), (2, 16, 32)]) +@pytest.mark.parametrize( + "input_shape", [(16, 32), (1, 16, 32), (2, 16, 32), (128, 8192, 32)] +) def test_model_weights_and_gradients(model1, model2, input_shape: tuple[int, int]): assert torch.cuda.is_available() device = torch.device("cuda")