added inheritances in modular, renamed zamba cache
pglorio committed Nov 19, 2024
1 parent 9adf85e commit 904da4e
Showing 3 changed files with 328 additions and 321 deletions.
60 changes: 30 additions & 30 deletions src/transformers/models/zamba/modeling_zamba.py
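In short, the change below renames Zamba's hybrid cache class from HybridMambaAttentionDynamicCache to ZambaHybridDynamicCache and promotes the mamba size parameters to instance attributes. A minimal sketch of the resulting import change (the module path is the one in the diff header above):

# Before this commit:
# from transformers.models.zamba.modeling_zamba import HybridMambaAttentionDynamicCache
# After this commit:
from transformers.models.zamba.modeling_zamba import ZambaHybridDynamicCache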
@@ -113,7 +113,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-class HybridMambaAttentionDynamicCache(DynamicCache):
+class ZambaHybridDynamicCache(DynamicCache):
"""
A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
(which has a constant shape regardless of seq_len).
Expand All @@ -131,9 +131,9 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
self.dtype = dtype
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
-intermediate_size = config.mamba_expand * config.hidden_size
-ssm_state_size = config.mamba_d_state
-conv_kernel_size = config.mamba_d_conv
+self.intermediate_size = config.mamba_expand * config.hidden_size
+self.ssm_state_size = config.mamba_d_state
+self.conv_kernel_size = config.mamba_d_conv
self.n_mamba_heads = config.n_mamba_heads
self.conv_states = []
self.ssm_states = []
Expand All @@ -143,9 +143,9 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
self._buffers = {}
for i in range(config.num_hidden_layers):
self.conv_states += [
-torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
+torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype)
]
-cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size)
+cache_shape = (batch_size, self.n_mamba_heads, self.intermediate_size // self.n_mamba_heads, self.ssm_state_size)
self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)]
if self.layers_block_type[i] == "hybrid":
self.transformer_layers.append(i)
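For context, a minimal sketch of the per-layer state shapes allocated in the loop above; the config values are illustrative stand-ins, not taken from any released checkpoint:

import torch

# Illustrative sizes (assumptions, not a real ZambaConfig)
batch_size, hidden_size = 2, 64
mamba_expand, mamba_d_state, mamba_d_conv, n_mamba_heads = 2, 16, 4, 4
intermediate_size = mamba_expand * hidden_size  # 128

# Conv state: fixed-width window, independent of sequence length
conv_state = torch.zeros(batch_size, intermediate_size, mamba_d_conv)
# SSM state: one recurrent state per mamba head, also independent of seq_len
ssm_state = torch.zeros(
    batch_size, n_mamba_heads, intermediate_size // n_mamba_heads, mamba_d_state
)
print(conv_state.shape)  # torch.Size([2, 128, 4])
print(ssm_state.shape)   # torch.Size([2, 4, 32, 16])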
@@ -196,12 +196,12 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:

# Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.to_legacy_cache
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.")

@classmethod
# Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.from_legacy_cache
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.")


class ZambaAttention(nn.Module):
@@ -249,7 +249,7 @@ def forward(
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
-past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
@@ -327,7 +327,7 @@ def forward(
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
-past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
@@ -417,7 +417,7 @@ def forward(
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
-past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
@@ -569,7 +569,7 @@ def __init__(self, config: ZambaConfig, layer_idx):
)

def cuda_kernels_forward(
-self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None
+self, hidden_states: torch.Tensor, cache_params: ZambaHybridDynamicCache = None, attention_mask=None
):
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1
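As a side note, the flag above gates the fast decode path: cached convolution/SSM states are reused only when exactly one new token follows a previous forward pass. A self-contained sketch of that condition (the stub class is hypothetical):

from typing import Optional

class CacheStub:
    # Hypothetical stand-in for ZambaHybridDynamicCache, keeping only the
    # flag that the fast-path check reads.
    def __init__(self, has_previous_state: bool):
        self.has_previous_state = has_previous_state

def use_precomputed_states(cache: Optional[CacheStub], seq_len: int) -> bool:
    return cache is not None and cache.has_previous_state and seq_len == 1

print(use_precomputed_states(CacheStub(True), 1))   # True: single-token decode
print(use_precomputed_states(CacheStub(False), 7))  # False: prefill pass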
@@ -665,7 +665,7 @@ def cuda_kernels_forward(
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
return contextualized_states

-def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None):
+def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated linear projection
Expand All @@ -676,7 +676,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
gate = gate.squeeze(2)
gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1)

-use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
+use_cache = isinstance(cache_params, ZambaHybridDynamicCache)
# 2. Convolution sequence transformation
if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
if self.training:
@@ -758,7 +758,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
)
return contextualized_states

-def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None):
+def forward(self, hidden_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None):
if self.use_fast_kernels:
if not is_fast_path_available or "cuda" not in self.x_proj_weight.device.type:
raise ValueError(
@@ -801,7 +801,7 @@ def forward(
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
-past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
Expand All @@ -816,7 +816,7 @@ def forward(
(see fig. 2 in https://arxiv.org/pdf/2405.16712).
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
-past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@@ -869,7 +869,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
-past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
Expand All @@ -880,7 +880,7 @@ def forward(
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
-past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@@ -937,7 +937,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
-past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
Expand All @@ -950,7 +950,7 @@ def forward(
layer_idx (`int`): layer number.
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
-past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@@ -1026,7 +1026,7 @@ class ZambaPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = False
_supports_sdpa = False
-_supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache
+_supports_cache_class = True # Note: only supports ZambaHybridDynamicCache
_is_stateful = True

def _init_weights(self, module):
@@ -1120,14 +1120,14 @@ def _check_and_enable_flash_attn_2(
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
-past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
+past_key_values (`ZambaHybridDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+A ZambaHybridDynamicCache object containing pre-computed hidden-states (keys and values in the
self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
`(batch_size, d_inner, d_state)` respectively.
-See the `HybridMambaAttentionDynamicCache` class for more details.
+See the `ZambaHybridDynamicCache` class for more details.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
@@ -1225,7 +1225,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
-past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+past_key_values: Optional[ZambaHybridDynamicCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -1262,7 +1262,7 @@ def forward(

if use_cache and past_key_values is None:
logger.warning_once(
"Zamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
"Zamba requires an initialized `ZambaHybridDynamicCache` to return a cache. None was "
"provided, so no cache will be returned."
)

@@ -1409,7 +1409,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
-past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+past_key_values: Optional[ZambaHybridDynamicCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -1503,7 +1503,7 @@ def prepare_inputs_for_generation(
use_cache=True,
**kwargs,
):
-# Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
+# Overwitten -- has a unique cache type, `ZambaHybridDynamicCache`

empty_past_kv = past_key_values is None

Expand All @@ -1517,7 +1517,7 @@ def prepare_inputs_for_generation(
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
else:
-past_key_values = HybridMambaAttentionDynamicCache(
+past_key_values = ZambaHybridDynamicCache(
self.config, input_ids.shape[0], dtype=self.dtype, device=self.device
)
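Finally, a hedged usage sketch of the renamed cache at generation time, mirroring the fallback above; the checkpoint id is an assumption for illustration:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.zamba.modeling_zamba import ZambaHybridDynamicCache

model_id = "Zyphra/Zamba-7B-v1"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# Build the cache explicitly, as prepare_inputs_for_generation would otherwise
# do when past_key_values is None.
cache = ZambaHybridDynamicCache(
    model.config, inputs.input_ids.shape[0], dtype=model.dtype, device=model.device
)
output = model.generate(**inputs, past_key_values=cache, max_new_tokens=8)
print(tokenizer.decode(output[0]))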

