diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 43c1085b5108d7..27226dc7417a9c 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -48,6 +48,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import (
+    LossKwargs,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -331,7 +332,8 @@ def forward(
         output_router_logits: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -360,14 +362,16 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            **kwargs,
         )
         hidden_states = residual + hidden_states
 
@@ -382,9 +386,6 @@ def forward(
         if output_attentions:
             outputs += (self_attn_weights,)
 
-        if use_cache:
-            outputs += (present_key_value,)
-
         if output_router_logits:
             outputs += (router_logits,)
 
@@ -624,6 +625,7 @@ def forward(
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_router_logits = (
@@ -646,19 +648,8 @@ def forward(
             )
             use_cache = False
 
-        # kept for BC (non `Cache` `past_key_values` inputs)
-        return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
-            return_legacy_cache = True
-            if past_key_values is None:
-                past_key_values = DynamicCache()
-            else:
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
-                "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
-                "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
-            )
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -677,11 +668,13 @@ def forward(
 
         hidden_states = inputs_embeds
 
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         all_router_logits = () if output_router_logits else None
-        next_decoder_cache = None
 
         for decoder_layer in self.layers:
             if output_hidden_states:
@@ -698,6 +691,7 @@ def forward(
                     output_router_logits,
                     use_cache,
                     cache_position,
+                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -709,13 +703,12 @@ def forward(
                     output_router_logits=output_router_logits,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                    **flash_attn_kwargs,
                 )
 
             hidden_states = layer_outputs[0]
 
-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
 
@@ -728,23 +721,14 @@ def forward(
         if output_hidden_states:
            all_hidden_states += (hidden_states,)
 
-        next_cache = next_decoder_cache if use_cache else None
-        if return_legacy_cache:
-            next_cache = next_cache.to_legacy_cache()
-
-        if not return_dict:
-            return tuple(
-                v
-                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
-                if v is not None
-            )
-        return MoeModelOutputWithPast(
+        output = MoeModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=next_cache,
+            past_key_values=past_key_values,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
             router_logits=all_router_logits,
         )
+        return output if return_dict else output.to_tuple()
 
     def _update_causal_mask(
         self,
@@ -897,6 +881,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         return causal_mask
 
 
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+ + def load_balancing_loss_func( gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], num_experts: Optional[int] = None, @@ -1030,7 +1017,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1086,6 +1073,7 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1094,7 +1082,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) aux_loss = None if output_router_logits: diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index a925998a522889..2c209bcd54557d 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -27,22 +27,26 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import DynamicCache +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, ) +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, logging, ) -from ..llama.modeling_llama import ( - LlamaForCausalLM, - LlamaForQuestionAnswering, - LlamaForSequenceClassification, - LlamaForTokenClassification, - LlamaRMSNorm, +from ..mistral.modeling_mistral import ( + MistralAttention, + MistralForCausalLM, + MistralForQuestionAnswering, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralModel, + MistralRMSNorm, ) -from ..mistral.modeling_mistral import MistralAttention, MistralModel from .configuration_mixtral import MixtralConfig @@ -220,7 +224,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits -class MixtralRMSNorm(LlamaRMSNorm): +class MixtralRMSNorm(MistralRMSNorm): pass @@ -249,7 +253,8 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -278,14 +283,16 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = residual + hidden_states @@ -300,9 +307,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - if output_router_logits: outputs += (router_logits,) @@ -323,6 +327,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + 
**flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, MoeModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( @@ -345,19 +350,8 @@ def forward( ) use_cache = False - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -376,11 +370,13 @@ def forward( hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -397,6 +393,7 @@ def forward( output_router_logits, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -408,13 +405,12 @@ def forward( output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -427,38 +423,27 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None - ) - return MoeModelOutputWithPast( + output = MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) + return output if return_dict else output.to_tuple() + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
 
 
-class MixtralForCausalLM(LlamaForCausalLM):
+class MixtralForCausalLM(MistralForCausalLM):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
-        self.model = MixtralModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.router_aux_loss_coef = config.router_aux_loss_coef
         self.num_experts = config.num_local_experts
         self.num_experts_per_tok = config.num_experts_per_tok
-        # Initialize weights and apply final processing
-        self.post_init()
 
     def forward(
         self,
@@ -475,7 +460,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
-        **loss_kwargs,
+        **kwargs: Unpack[KwargsForCausalLM],
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
         r"""
         Args:
@@ -531,6 +516,7 @@ def forward(
             output_router_logits=output_router_logits,
             return_dict=return_dict,
             cache_position=cache_position,
+            **kwargs,
         )
 
         hidden_states = outputs[0]
@@ -539,7 +525,7 @@ def forward(
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
 
         aux_loss = None
         if output_router_logits:
@@ -569,13 +555,13 @@ def forward(
         )
 
 
-class MixtralForSequenceClassification(LlamaForSequenceClassification):
+class MixtralForSequenceClassification(MistralForSequenceClassification):
     pass
 
 
-class MixtralForTokenClassification(LlamaForTokenClassification):
+class MixtralForTokenClassification(MistralForTokenClassification):
     pass
 
 
-class MixtralForQuestionAnswering(LlamaForQuestionAnswering):
+class MixtralForQuestionAnswering(MistralForQuestionAnswering):
     pass
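
Note on the typed kwargs plumbing introduced above (an illustrative sketch, not part of the patch): `KwargsForCausalLM` merges the `FlashAttentionKwargs` and `LossKwargs` TypedDicts so a single `**kwargs: Unpack[KwargsForCausalLM]` on `MixtralForCausalLM.forward` can be forwarded both to the attention backend and to `self.loss_function`. The toy example below mimics that pattern with hypothetical field names (`max_length_q`, `num_items_in_batch`); the real field sets live in `modeling_flash_attention_utils.FlashAttentionKwargs` and `utils.LossKwargs`.

    from typing_extensions import TypedDict, Unpack

    class FlashAttentionKwargs(TypedDict, total=False):
        # hypothetical subset of the real flash-attention fields
        max_length_q: int
        max_length_k: int

    class LossKwargs(TypedDict, total=False):
        # hypothetical loss-related field
        num_items_in_batch: int

    # same merging trick as in the diff above
    class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

    def forward(**kwargs: Unpack[KwargsForCausalLM]) -> None:
        # a real model would route the attention kwargs to the decoder layers
        # and the loss kwargs to self.loss_function; here we just show the
        # merged kwargs arriving through one signature
        print(sorted(kwargs))

    forward(max_length_q=128, num_items_in_batch=4)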