diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index b8fb01d7b23cad..83637536a12531 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -653,6 +653,7 @@ def forward( value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -970,9 +971,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] - ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) # embed positions hidden_states = inputs_embeds @@ -1036,17 +1035,32 @@ def forward( attentions=all_self_attns, ) - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask - def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_seen_tokens: int, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "sdpa": + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, + # in order to dispatch on Flash Attention 2. + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] @@ -1054,7 +1068,9 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, curr target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)