make fixup fixes
pglorio committed Dec 20, 2024
1 parent fc23f0d commit 65f6a42
Showing 2 changed files with 8 additions and 192 deletions.
194 changes: 6 additions & 188 deletions src/transformers/models/zamba/modeling_zamba.py
@@ -20,7 +20,7 @@
"""PyTorch Zamba model."""

import math
from typing import Callable, Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
@@ -45,7 +45,6 @@
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
@@ -207,6 +206,7 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.")


def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
@@ -232,6 +232,7 @@ def eager_attention_forward

return attn_output, attn_weights


class ZambaAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
@@ -251,12 +252,11 @@ def __init__(self, config: ZambaConfig, layer_idx: int):
self.config = config
self.layer_idx = layer_idx

#### self.hidden_size = config.hidden_size
self.attention_hidden_size = config.attention_hidden_size
#### self.num_heads = config.num_attention_heads
self.head_dim = config.attention_head_dim
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = (self.head_dim / 2)**-0.5
self.max_position_embeddings = config.max_position_embeddings
self.scaling = (self.head_dim / 2) ** -0.5
self.is_causal = True
self.attention_dropout = config.attention_dropout

@@ -304,194 +304,12 @@ def forward(
scaling=self.scaling,
**kwargs,
)

attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights


# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
# Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward
# dropped use_sliding_windows from the arguments of self._flash_attention_forward
# class ZambaFlashAttention2(ZambaAttention):
# """
# Zamba flash attention module. This module inherits from `ZambaAttention` as the weights of the module stays
# untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
# flash attention and deal with padding tokens in case the input contains any of them.
# """

# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)

# # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
# self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

# def forward(
# self,
# hidden_states: torch.Tensor,
# layer_idx: int,
# attention_mask: Optional[torch.Tensor] = None,
# position_ids: Optional[torch.LongTensor] = None,
# past_key_value: Optional[ZambaHybridDynamicCache] = None,
# output_attentions: bool = False,
# use_cache: bool = False,
# cache_position: Optional[torch.LongTensor] = None,
# **kwargs,
# ):
# bsz, q_len, _ = hidden_states.size()

# query_states = self.q_proj(hidden_states)
# key_states = self.k_proj(hidden_states)
# value_states = self.v_proj(hidden_states)

# # Flash attention requires the input to have the shape
# # batch_size x seq_length x head_dim x hidden_dim
# # therefore we just need to keep the original shape
# query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# if past_key_value is not None:
# key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)

# # repeat k/v heads if n_kv_heads < n_heads
# key_states = repeat_kv(key_states, self.num_key_value_groups)
# value_states = repeat_kv(value_states, self.num_key_value_groups)
# dropout_rate = 0.0 if not self.training else self.attention_dropout

# # In PEFT, usually we cast the layer norms in float32 for training stability reasons
# # therefore the input hidden states gets silently casted in float32. Hence, we need
# # cast them back in float16 just to be sure everything works as expected.
# input_dtype = query_states.dtype
# if input_dtype == torch.float32:
# if torch.is_autocast_enabled():
# target_dtype = torch.get_autocast_gpu_dtype()
# # Handle the case where the model is quantized
# elif hasattr(self.config, "_pre_quantization_dtype"):
# target_dtype = self.config._pre_quantization_dtype
# else:
# target_dtype = self.q_proj.weight.dtype

# logger.warning_once(
# f"The input hidden states seems to be silently casted in float32, this might be related to"
# f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
# f" {target_dtype}."
# )

# query_states = query_states.to(target_dtype)
# key_states = key_states.to(target_dtype)
# value_states = value_states.to(target_dtype)

# # Reshape to the expected shape for Flash Attention
# query_states = query_states.transpose(1, 2)
# key_states = key_states.transpose(1, 2)
# value_states = value_states.transpose(1, 2)
# softmax_scale = 1 / math.sqrt(self.head_dim / 2)

# attn_output = _flash_attention_forward(
# query_states,
# key_states,
# value_states,
# attention_mask,
# q_len,
# dropout=dropout_rate,
# softmax_scale=softmax_scale,
# )

# attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous()
# attn_output = self.o_proj(attn_output)

# if not output_attentions:
# attn_weights = None

# return attn_output, attn_weights, past_key_value


# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention
# class ZambaSdpaAttention(ZambaAttention):
# """
# Zamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
# `ZambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
# SDPA API.
# """

# def forward(
# self,
# hidden_states: torch.Tensor,
# layer_idx: int,
# attention_mask: Optional[torch.Tensor] = None,
# position_ids: Optional[torch.LongTensor] = None,
# past_key_value: Optional[ZambaHybridDynamicCache] = None,
# output_attentions: bool = False,
# use_cache: bool = False,
# cache_position: Optional[torch.LongTensor] = None,
# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# if output_attentions:
# # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
# logger.warning_once(
# "ZambaModel is using ZambaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
# 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
# )
# return super().forward(
# hidden_states=hidden_states,
# attention_mask=attention_mask,
# position_ids=position_ids,
# past_key_value=past_key_value,
# output_attentions=output_attentions,
# use_cache=use_cache,
# )

# bsz, q_len, _ = hidden_states.size()

# query_states = self.q_proj(hidden_states)
# key_states = self.k_proj(hidden_states)
# value_states = self.v_proj(hidden_states)

# query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# if past_key_value is not None:
# key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)

# key_states = repeat_kv(key_states, self.num_key_value_groups)
# value_states = repeat_kv(value_states, self.num_key_value_groups)

# causal_mask = attention_mask
# if attention_mask is not None:
# causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

# # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# # Reference: https://github.com/pytorch/pytorch/issues/112577.
# if query_states.device.type == "cuda" and attention_mask is not None:
# query_states = query_states.contiguous()
# key_states = key_states.contiguous()
# value_states = value_states.contiguous()

# softmax_scale = 1 / math.sqrt(self.head_dim / 2)

# attn_output = torch.nn.functional.scaled_dot_product_attention(
# query_states,
# key_states,
# value_states,
# attn_mask=causal_mask,
# dropout_p=self.attention_dropout if self.training else 0.0,
# # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
# is_causal=self.is_causal and attention_mask is None and q_len > 1,
# scale=softmax_scale,
# )

# attn_output = attn_output.transpose(1, 2).contiguous()
# attn_output = attn_output.view(bsz, q_len, self.attention_hidden_size)

# attn_output = self.o_proj(attn_output)

# return attn_output, None, past_key_value


class ZambaMambaMixer(nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
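
For context on the scale that survives this cleanup: the deleted ZambaFlashAttention2 and ZambaSdpaAttention variants existed largely to pass Zamba's non-standard softmax scale, 1 / sqrt(head_dim / 2), to their backends, and the remaining ZambaAttention forwards the same value via `scaling=self.scaling`. Below is a minimal sketch of eager attention with that scale; the function name, shapes, and masking are illustrative assumptions, not the actual `eager_attention_forward` from the file.

```python
import math

import torch
import torch.nn.functional as F


def scaled_eager_attention(query, key, value, attn_mask=None):
    # Zamba uses (head_dim / 2) ** -0.5 rather than the usual head_dim ** -0.5;
    # this equals the softmax_scale = 1 / math.sqrt(head_dim / 2) that the
    # deleted flash/SDPA classes passed to their backends.
    head_dim = query.shape[-1]
    scaling = (head_dim / 2) ** -0.5
    assert math.isclose(scaling, 1 / math.sqrt(head_dim / 2))

    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scaling
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value)


# Toy shapes: (batch, num_heads, seq_len, head_dim)
q = k = v = torch.randn(1, 2, 4, 8)
print(scaled_eager_attention(q, k, v).shape)  # torch.Size([1, 2, 4, 8])
```
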
6 changes: 2 additions & 4 deletions tests/models/zamba/test_modeling_zamba.py
@@ -46,7 +46,7 @@
ZambaModel,
)
from transformers.models.zamba.modeling_zamba import (
HybridMambaAttentionDynamicCache,
ZambaHybridDynamicCache,
)


@@ -215,9 +215,7 @@ def create_and_check_decoder_model_past_large_inputs(

# first forward pass
# Attention: Zamba needs the cache to be initialized to return a cache!
past_key_values = HybridMambaAttentionDynamicCache(
config, input_ids.shape[0], model.dtype, device=model.device
)
past_key_values = ZambaHybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device)
outputs = model(
input_ids,
attention_mask=input_mask,
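
For the renamed cache class exercised by the updated test, here is a usage sketch: the hybrid cache must be constructed explicitly before the first forward pass for the model to return one. The checkpoint name is an assumption, and the snippet mirrors the test rather than an official recipe.

```python
import torch
from transformers import AutoTokenizer, ZambaForCausalLM
from transformers.models.zamba.modeling_zamba import ZambaHybridDynamicCache

checkpoint = "Zyphra/Zamba-7B-v1"  # assumed checkpoint name; substitute your own
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = ZambaForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello, Zamba!", return_tensors="pt").to(model.device)

# As in the test: Zamba needs the cache to be initialized in order to return a cache.
past_key_values = ZambaHybridDynamicCache(
    model.config, inputs.input_ids.shape[0], model.dtype, device=model.device
)
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
print(type(outputs.past_key_values).__name__)  # ZambaHybridDynamicCache
```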
