diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 2b468a8c442b95..d9e58efaa9165f 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -919,7 +919,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
             current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
-            final_hidden_states[top_x].index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return final_hidden_states, router_logits
 
@@ -1315,7 +1315,11 @@ def _update_causal_mask(
         using_static_cache = isinstance(past_key_values, StaticCache)
         using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
 
-        if self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache):
+        if (
+            self.config._attn_implementation == "sdpa"
+            and not (using_static_cache or using_sliding_window_cache)
+            and not output_attentions
+        ):
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
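
For context only (not part of the patch above): a minimal sketch of why the first hunk matters, using toy shapes and made-up routing indices. Advanced indexing such as `final_hidden_states[top_x]` returns a copy, so an in-place `index_add_` on that copy never reaches the real buffer; the fixed form accumulates the expert outputs directly into `final_hidden_states`.

import torch

hidden_dim = 4
num_tokens = 6
final_hidden_states = torch.zeros(num_tokens, hidden_dim)

# Hypothetical routing: pretend the current expert was assigned tokens 1, 3 and 5.
top_x = torch.tensor([1, 3, 5])
current_hidden_states = torch.ones(3, hidden_dim)

# Buggy pattern: final_hidden_states[top_x] is a copy (advanced indexing), so an
# in-place index_add_ on it would update a temporary and then discard it.
# Fixed pattern from the diff: accumulate on the full buffer in place.
final_hidden_states.index_add_(0, top_x, current_hidden_states)

print(final_hidden_states)  # rows 1, 3 and 5 now hold the accumulated expert output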