From 1c6cca8689ef050c49152318ea1d281b413c2141 Mon Sep 17 00:00:00 2001
From: pglorio
Date: Tue, 10 Sep 2024 06:27:40 +0000
Subject: [PATCH] circleci fixes

---
 src/transformers/models/zamba/modeling_zamba.py | 11 +++++++++++
 tests/models/zamba/test_modeling_zamba.py       |  2 +-
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py
index b45a59bbd6ea67..ef8e434c870bb7 100644
--- a/src/transformers/models/zamba/modeling_zamba.py
+++ b/src/transformers/models/zamba/modeling_zamba.py
@@ -999,6 +999,9 @@ def forward(
         )
 
         transformer_hidden_states = layer_outputs[0]
+
+        if output_attentions:
+            self_attn_weights = layer_outputs[1]
 
         transformer_hidden_states = self.linear(transformer_hidden_states)
 
@@ -1012,6 +1015,9 @@ def forward(
             use_cache=use_cache,
             cache_position=cache_position,
         )
+
+        if output_attentions:
+            layer_outputs = (layer_outputs[0], self_attn_weights) + layer_outputs[2:]
 
         return layer_outputs
 
@@ -1291,6 +1297,11 @@ def forward(
                 cache_position=cache_position,
             )
             hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                if layer_outputs[1] is not None:
+                    # append attentions only of attention layers. Mamba layers return `None` as the attention weights
+                    all_self_attns += (layer_outputs[1],)
 
         hidden_states = self.final_layernorm(hidden_states)
 
diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py
index e80015f068ee20..b55ef0eae8d016 100644
--- a/tests/models/zamba/test_modeling_zamba.py
+++ b/tests/models/zamba/test_modeling_zamba.py
@@ -348,7 +348,7 @@ def test_initialization(self):
                 elif "D" in name:
                     # check if it's a ones like
                     self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
-                elif "x_proj" in name or "dt_proj_weight" in name:
+                elif "x_proj" in name or "dt_proj_weight" in name or "dt_proj_bias" in name:
                     self.assertIn(
                         ((param.data.mean() * 1e2).round() / 1e2).item(),
                         [0.0, 1.0],