Skip to content

Commit

Permalink
fix remaining failing tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
fxmarty committed Nov 6, 2023
1 parent d0a344d commit c1c79ae
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion optimum/bettertransformer/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ def falcon_forward(
value_layer_,
attention_mask,
0.0,
- is_causal=self.is_causal and attention_mask is None,
+ is_causal=self.is_causal and attention_mask is None and query_length > 1,
)
attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
attn_output = attn_output.permute(0, 2, 1, 3)
Expand Down
4 changes: 3 additions & 1 deletion optimum/bettertransformer/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ def _revert(self, module: torch.nn.Module) -> torch.nn.Module:
continue

if module not in self.keys_to_ignore:
- parameter = current_weight[i * split_index : (i + 1) * split_index]
+ # TODO: remove the clone once https://github.com/huggingface/transformers/pull/27314 & https://github.com/huggingface/safetensors/pull/379 are released.
+ # Safetensors is bugged when using views of tensors.
+ parameter = current_weight[i * split_index : (i + 1) * split_index].clone()
if isinstance(recurse_getattr(module, subparam_name), torch.nn.Parameter):
parameter = torch.nn.Parameter(parameter)
recurse_setattr(module, subparam_name, parameter)
Expand Down
1 change: 0 additions & 1 deletion tests/bettertransformer/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def test_logits_with_cache(self, test_name: str, model_type: str, batch_size: in
result_vanilla = model(input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values)

model = BetterTransformer.transform(model)

result_bettertransformer = model(
input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values
)
Expand Down

0 comments on commit c1c79ae

Please sign in to comment.