From 65f6a425d91a99bf1beef8b45aa298bfcffe9ff9 Mon Sep 17 00:00:00 2001 From: pglorio Date: Fri, 20 Dec 2024 20:36:02 +0000 Subject: [PATCH] make fixup fixes --- .../models/zamba/modeling_zamba.py | 194 +----------------- tests/models/zamba/test_modeling_zamba.py | 6 +- 2 files changed, 8 insertions(+), 192 deletions(-) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 384aa070243e87..edcbcc6d883ebf 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -20,7 +20,7 @@ """PyTorch Zamba model.""" import math -from typing import Callable, Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -45,7 +45,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -207,6 +206,7 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.") + def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -232,6 +232,7 @@ def eager_attention_forward( return attn_output, attn_weights + class ZambaAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -251,12 +252,11 @@ def __init__(self, config: ZambaConfig, layer_idx: int): self.config = config self.layer_idx = layer_idx - #### self.hidden_size = config.hidden_size self.attention_hidden_size = config.attention_hidden_size - #### self.num_heads = config.num_attention_heads self.head_dim = config.attention_head_dim self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = (self.head_dim / 2)**-0.5 + self.max_position_embeddings = config.max_position_embeddings + self.scaling = (self.head_dim / 2) ** -0.5 self.is_causal = True self.attention_dropout = config.attention_dropout @@ -304,194 +304,12 @@ def forward( scaling=self.scaling, **kwargs, ) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights -# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: -# Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward -# dropped use_sliding_windows from the arguments of self._flash_attention_forward -# class ZambaFlashAttention2(ZambaAttention): -# """ -# Zamba flash attention module. This module inherits from `ZambaAttention` as the weights of the module stays -# untouched. The only required change would be on the forward pass where it needs to correctly call the public API of -# flash attention and deal with padding tokens in case the input contains any of them. -# """ - -# def __init__(self, *args, **kwargs): -# super().__init__(*args, **kwargs) - -# # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. -# # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. -# # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). -# self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - -# def forward( -# self, -# hidden_states: torch.Tensor, -# layer_idx: int, -# attention_mask: Optional[torch.Tensor] = None, -# position_ids: Optional[torch.LongTensor] = None, -# past_key_value: Optional[ZambaHybridDynamicCache] = None, -# output_attentions: bool = False, -# use_cache: bool = False, -# cache_position: Optional[torch.LongTensor] = None, -# **kwargs, -# ): -# bsz, q_len, _ = hidden_states.size() - -# query_states = self.q_proj(hidden_states) -# key_states = self.k_proj(hidden_states) -# value_states = self.v_proj(hidden_states) - -# # Flash attention requires the input to have the shape -# # batch_size x seq_length x head_dim x hidden_dim -# # therefore we just need to keep the original shape -# query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) -# key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) -# value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - -# if past_key_value is not None: -# key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) - -# # repeat k/v heads if n_kv_heads < n_heads -# key_states = repeat_kv(key_states, self.num_key_value_groups) -# value_states = repeat_kv(value_states, self.num_key_value_groups) -# dropout_rate = 0.0 if not self.training else self.attention_dropout - -# # In PEFT, usually we cast the layer norms in float32 for training stability reasons -# # therefore the input hidden states gets silently casted in float32. Hence, we need -# # cast them back in float16 just to be sure everything works as expected. -# input_dtype = query_states.dtype -# if input_dtype == torch.float32: -# if torch.is_autocast_enabled(): -# target_dtype = torch.get_autocast_gpu_dtype() -# # Handle the case where the model is quantized -# elif hasattr(self.config, "_pre_quantization_dtype"): -# target_dtype = self.config._pre_quantization_dtype -# else: -# target_dtype = self.q_proj.weight.dtype - -# logger.warning_once( -# f"The input hidden states seems to be silently casted in float32, this might be related to" -# f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" -# f" {target_dtype}." -# ) - -# query_states = query_states.to(target_dtype) -# key_states = key_states.to(target_dtype) -# value_states = value_states.to(target_dtype) - -# # Reashape to the expected shape for Flash Attention -# query_states = query_states.transpose(1, 2) -# key_states = key_states.transpose(1, 2) -# value_states = value_states.transpose(1, 2) -# softmax_scale = 1 / math.sqrt(self.head_dim / 2) - -# attn_output = _flash_attention_forward( -# query_states, -# key_states, -# value_states, -# attention_mask, -# q_len, -# dropout=dropout_rate, -# softmax_scale=softmax_scale, -# ) - -# attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous() -# attn_output = self.o_proj(attn_output) - -# if not output_attentions: -# attn_weights = None - -# return attn_output, attn_weights, past_key_value - - -# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: -# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention -# class ZambaSdpaAttention(ZambaAttention): -# """ -# Zamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from -# `ZambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to -# SDPA API. -# """ - -# def forward( -# self, -# hidden_states: torch.Tensor, -# layer_idx: int, -# attention_mask: Optional[torch.Tensor] = None, -# position_ids: Optional[torch.LongTensor] = None, -# past_key_value: Optional[ZambaHybridDynamicCache] = None, -# output_attentions: bool = False, -# use_cache: bool = False, -# cache_position: Optional[torch.LongTensor] = None, -# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: -# if output_attentions: -# # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. -# logger.warning_once( -# "ZambaModel is using ZambaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " -# 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' -# ) -# return super().forward( -# hidden_states=hidden_states, -# attention_mask=attention_mask, -# position_ids=position_ids, -# past_key_value=past_key_value, -# output_attentions=output_attentions, -# use_cache=use_cache, -# ) - -# bsz, q_len, _ = hidden_states.size() - -# query_states = self.q_proj(hidden_states) -# key_states = self.k_proj(hidden_states) -# value_states = self.v_proj(hidden_states) - -# query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) -# key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) -# value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - -# if past_key_value is not None: -# key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) - -# key_states = repeat_kv(key_states, self.num_key_value_groups) -# value_states = repeat_kv(value_states, self.num_key_value_groups) - -# causal_mask = attention_mask -# if attention_mask is not None: -# causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - -# # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, -# # Reference: https://github.com/pytorch/pytorch/issues/112577. -# if query_states.device.type == "cuda" and attention_mask is not None: -# query_states = query_states.contiguous() -# key_states = key_states.contiguous() -# value_states = value_states.contiguous() - -# softmax_scale = 1 / math.sqrt(self.head_dim / 2) - -# attn_output = torch.nn.functional.scaled_dot_product_attention( -# query_states, -# key_states, -# value_states, -# attn_mask=causal_mask, -# dropout_p=self.attention_dropout if self.training else 0.0, -# # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. -# is_causal=self.is_causal and attention_mask is None and q_len > 1, -# scale=softmax_scale, -# ) - -# attn_output = attn_output.transpose(1, 2).contiguous() -# attn_output = attn_output.view(bsz, q_len, self.attention_hidden_size) - -# attn_output = self.o_proj(attn_output) - -# return attn_output, None, past_key_value - - class ZambaMambaMixer(nn.Module): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index a6dd516f98a412..ee47f98a1f4133 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -46,7 +46,7 @@ ZambaModel, ) from transformers.models.zamba.modeling_zamba import ( - HybridMambaAttentionDynamicCache, + ZambaHybridDynamicCache, ) @@ -215,9 +215,7 @@ def create_and_check_decoder_model_past_large_inputs( # first forward pass # Attention: Zamba needs the cache to be initialized to return a cache! - past_key_values = HybridMambaAttentionDynamicCache( - config, input_ids.shape[0], model.dtype, device=model.device - ) + past_key_values = ZambaHybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device) outputs = model( input_ids, attention_mask=input_mask,