
Commit

fix some bug
zhenglongjiepheonix committed May 16, 2024
1 parent: 18fa186 · commit: 9b2c104
Showing 1 changed file with 6 additions and 2 deletions.
src/transformers/models/mixtral/modeling_mixtral.py (8 changes: 6 additions, 2 deletions)
@@ -919,7 +919,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
             current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
-            final_hidden_states[top_x].index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return final_hidden_states, router_logits
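
This first change fixes how each expert's output is written back into the shared output buffer: `final_hidden_states[top_x]` uses advanced indexing, which returns a copy, so calling the in-place `index_add_` on that copy never touches `final_hidden_states` itself. Below is a minimal sketch of the difference, using toy shapes and an illustrative `top_x` rather than the model's real tensors:

    import torch

    hidden_dim = 4
    final_hidden_states = torch.zeros(6, hidden_dim)      # one row per token
    top_x = torch.tensor([0, 1])                          # tokens routed to this expert
    current_hidden_states = torch.ones(2, hidden_dim)     # this expert's output for those tokens

    # Buggy pattern: advanced indexing returns a copy (not a view), so the
    # in-place add mutates a temporary tensor and the buffer stays all zeros.
    final_hidden_states[top_x].index_add_(0, top_x, current_hidden_states)
    print(final_hidden_states.sum())   # tensor(0.)

    # Fixed pattern: index_add_ on the buffer itself scatters the expert
    # output into the right token rows in place.
    final_hidden_states.index_add_(0, top_x, current_hidden_states)
    print(final_hidden_states.sum())   # tensor(8.)

The buggy pattern fails silently rather than raising an error, which is why the only symptom is an output that never receives the experts' contributions.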

@@ -1315,7 +1315,11 @@ def _update_causal_mask(
         using_static_cache = isinstance(past_key_values, StaticCache)
         using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

-        if self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache):
+        if (
+            self.config._attn_implementation == "sdpa"
+            and not (using_static_cache or using_sliding_window_cache)
+            and not output_attentions
+        ):
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
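The second change tightens the condition around the SDPA causal-mask shortcut in `_update_causal_mask`. When `output_attentions=True`, `torch.nn.functional.scaled_dot_product_attention` cannot return attention weights, so the SDPA attention classes in transformers fall back to the eager implementation, which needs the explicit 4D causal mask; the fast path that skips building the mask therefore must not fire in that case. A condensed sketch of the resulting guard (the helper name `can_skip_causal_mask` is hypothetical, introduced only for illustration):

    def can_skip_causal_mask(
        attn_implementation: str,
        using_static_cache: bool,
        using_sliding_window_cache: bool,
        output_attentions: bool,
    ) -> bool:
        # The "skip building a 4D causal mask" fast path is only safe for
        # plain SDPA: static/sliding-window caches and the eager fallback
        # used when attention weights are requested all need the full mask.
        return (
            attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        )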
