Commit: mixtral
Cyrilvallez committed Dec 16, 2024
1 parent b275fdc commit 71eb6a2
Showing 2 changed files with 64 additions and 90 deletions.
60 changes: 24 additions & 36 deletions src/transformers/models/mixtral/modeling_mixtral.py
@@ -48,6 +48,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import (
+    LossKwargs,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -331,7 +332,8 @@ def forward(
         output_router_logits: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -360,14 +362,16 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)

         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            **kwargs,
         )
         hidden_states = residual + hidden_states

@@ -382,9 +386,6 @@ def forward(
         if output_attentions:
             outputs += (self_attn_weights,)

-        if use_cache:
-            outputs += (present_key_value,)
-
         if output_router_logits:
             outputs += (router_logits,)

@@ -624,6 +625,7 @@ def forward(
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_router_logits = (
@@ -646,19 +648,8 @@
             )
             use_cache = False

-        # kept for BC (non `Cache` `past_key_values` inputs)
-        return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
-            return_legacy_cache = True
-            if past_key_values is None:
-                past_key_values = DynamicCache()
-            else:
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                logger.warning_once(
-                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
-                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
-                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
-                )
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()

         if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
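
Note: the hunk above removes the legacy tuple-of-tuples cache path. Callers now pass an explicit Cache object (or nothing), and a DynamicCache is created only when use_cache=True and no cache was supplied. A minimal, hedged sketch of the caller side; the tiny random-weight config exists only to make the snippet self-contained and is not part of this commit:

import torch
from transformers import DynamicCache, MixtralConfig, MixtralForCausalLM

# Tiny config so the example runs without downloading a checkpoint.
config = MixtralConfig(
    vocab_size=128, hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=2, num_local_experts=4, num_experts_per_tok=2,
)
model = MixtralForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
cache = DynamicCache()  # an explicit Cache instance instead of a tuple of tuples
with torch.no_grad():
    out = model(input_ids, past_key_values=cache, use_cache=True)
# The cache is updated in place and comes back as out.past_key_values.
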
@@ -677,11 +668,13 @@

         hidden_states = inputs_embeds

+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         all_router_logits = () if output_router_logits else None
-        next_decoder_cache = None

         for decoder_layer in self.layers:
             if output_hidden_states:
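
Note: rotary cos/sin are now computed once per forward pass and handed to every decoder layer as position_embeddings, rather than recomputed inside each attention module. A hedged, self-contained sketch of the idea; this toy helper is illustrative and not the transformers rotary-embedding class:

import torch

def rotary_cos_sin(position_ids: torch.Tensor, head_dim: int, base: float = 10000.0):
    # Standard RoPE frequencies; returns (cos, sin) of shape (batch, seq_len, head_dim).
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = position_ids[:, :, None].float() * inv_freq[None, None, :]
    emb = torch.cat((angles, angles), dim=-1)
    return emb.cos(), emb.sin()

position_ids = torch.arange(8)[None, :]              # (1, 8)
position_embeddings = rotary_cos_sin(position_ids, head_dim=16)

# Every layer then receives the same precomputed pair, e.g.:
# for decoder_layer in self.layers:
#     layer_outputs = decoder_layer(hidden_states, position_embeddings=position_embeddings, ...)
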
@@ -698,6 +691,7 @@
                     output_router_logits,
                     use_cache,
                     cache_position,
+                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -709,13 +703,12 @@
                     output_router_logits=output_router_logits,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                    **flash_attn_kwargs,
                 )

             hidden_states = layer_outputs[0]

-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

@@ -728,23 +721,14 @@
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        next_cache = next_decoder_cache if use_cache else None
-        if return_legacy_cache:
-            next_cache = next_cache.to_legacy_cache()
-
-        if not return_dict:
-            return tuple(
-                v
-                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
-                if v is not None
-            )
-        return MoeModelOutputWithPast(
+        output = MoeModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=next_cache,
+            past_key_values=past_key_values,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
             router_logits=all_router_logits,
         )
+        return output if return_dict else output.to_tuple()

     def _update_causal_mask(
         self,
@@ -897,6 +881,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         return causal_mask


+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
 def load_balancing_loss_func(
     gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
     num_experts: Optional[int] = None,
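
Note: KwargsForCausalLM merges two TypedDicts so a single **kwargs parameter can carry both flash-attention options and loss options through the whole forward path. A self-contained analogue of the pattern; the classes and field names below are illustrative stand-ins, not the library's definitions:

from typing import TypedDict
from typing_extensions import Unpack


class AttnKwargs(TypedDict, total=False):
    # stand-in for FlashAttentionKwargs
    max_length_q: int


class LossOpts(TypedDict, total=False):
    # stand-in for LossKwargs
    num_items_in_batch: int


class KwargsForCausalLM(AttnKwargs, LossOpts): ...  # inheritance merges the optional keys


def forward(**kwargs: Unpack[KwargsForCausalLM]) -> None:
    # Type checkers now know which optional keys may appear; at runtime the dict
    # is simply forwarded untouched to the attention and loss code.
    print(sorted(kwargs))


forward(num_items_in_batch=3, max_length_q=8)  # prints ['max_length_q', 'num_items_in_batch']
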
@@ -1030,7 +1017,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
-        **loss_kwargs,
+        **kwargs: Unpack[KwargsForCausalLM],
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1086,6 +1073,7 @@ def forward(
             output_router_logits=output_router_logits,
             return_dict=return_dict,
             cache_position=cache_position,
+            **kwargs,
         )

         hidden_states = outputs[0]
@@ -1094,7 +1082,7 @@

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         aux_loss = None
         if output_router_logits:

0 comments on commit 71eb6a2
