Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
danielvegamyhre committed Jan 8, 2025
2 parents 2db5deb + 7184b5b commit 077e8bd
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torchao/prototype/float8nocompile/test/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def model2():
return TestModel()


@pytest.mark.parametrize("input_shape", [(16, 32), (1, 16, 32), (2, 16, 32)])
@pytest.mark.parametrize(
"input_shape", [(16, 32), (1, 16, 32), (2, 16, 32), (128, 8192, 32)]
)
@pytest.mark.parametrize("use_activation_checkpointing", [True, False])
def test_model_weights_and_gradients(
model1, model2, input_shape: tuple[int, int], use_activation_checkpointing: bool
Expand Down

0 comments on commit 077e8bd

Please sign in to comment.