Skip to content

Commit

Permalink
CircleCI fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pglorio committed Sep 10, 2024
1 parent 3788196 commit 1c6cca8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
11 changes: 11 additions & 0 deletions src/transformers/models/zamba/modeling_zamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,9 @@ def forward(
)

transformer_hidden_states = layer_outputs[0]

if output_attentions:
self_attn_weights = layer_outputs[1]

transformer_hidden_states = self.linear(transformer_hidden_states)

Expand All @@ -1012,6 +1015,9 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
)

if output_attentions:
layer_outputs[1] = self_attn_weights

return layer_outputs

Expand Down Expand Up @@ -1291,6 +1297,11 @@ def forward(
cache_position=cache_position,
)
hidden_states = layer_outputs[0]

if output_attentions:
if layer_outputs[1] is not None:
# append attentions only of attention layers. Mamba layers return `None` as the attention weights
all_self_attns += (layer_outputs[1],)

hidden_states = self.final_layernorm(hidden_states)

Expand Down
2 changes: 1 addition & 1 deletion tests/models/zamba/test_modeling_zamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def test_initialization(self):
elif "D" in name:
# check if it's a ones like
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
elif "x_proj" in name or "dt_proj_weight" in name:
elif "x_proj" in name or "dt_proj_weight" in name or "dt_proj_bias" in name:
self.assertIn(
((param.data.mean() * 1e2).round() / 1e2).item(),
[0.0, 1.0],
Expand Down

0 comments on commit 1c6cca8

Please sign in to comment.