diff --git a/tests/test_model.py b/tests/test_model.py index 136e387fbe..ff5a97efc7 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1341,7 +1341,7 @@ def test_forward_with_output_attentions_and_output_hidden_states( if output_attentions: assert len(outputs.attentions) == n_layers - assert all(attn.shape == (1, 4, 2048, 2048) for attn in outputs.attentions) + assert all(attn.shape == (1, 4, 3, 3) for attn in outputs.attentions) if output_hidden_states: assert len(outputs.hidden_states) == n_layers + 1