From 9adf85e00b42a7dca00ff13c8646f504dad6c2c4 Mon Sep 17 00:00:00 2001
From: pglorio
Date: Mon, 11 Nov 2024 06:31:49 +0000
Subject: [PATCH] Fix modular model converter

---
 .../models/zamba2/modeling_zamba2.py | 110 ++----------------
 1 file changed, 11 insertions(+), 99 deletions(-)

diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py
index e78c155a1f2ed2..79413fdf2673aa 100644
--- a/src/transformers/models/zamba2/modeling_zamba2.py
+++ b/src/transformers/models/zamba2/modeling_zamba2.py
@@ -56,7 +56,7 @@
 
 
 if is_mamba_ssm_available():
-    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
+    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
     from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
@@ -65,10 +65,8 @@
     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 else:
     causal_conv1d_update, causal_conv1d_fn = None, None
-
-is_fast_path_available = all(
-    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
-)
+
+is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
 
 logger = logging.get_logger(__name__)
 
@@ -271,92 +269,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-class HybridMambaAttentionDynamicCache(DynamicCache):
-    """
-    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
-    (which has a constant shape regardless of seq_len).
-
-    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
-    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
-    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
-    while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
-    For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
-    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
-    and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
- """ - - def __init__(self, config, batch_size, dtype=torch.float16, device=None): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - intermediate_size = config.mamba_expand * config.hidden_size - ssm_state_size = config.mamba_d_state - conv_kernel_size = config.mamba_d_conv - self.n_mamba_heads = config.n_mamba_heads - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - for i in range(config.num_hidden_layers): - self.conv_states += [ - torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) - ] - cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size) - self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] - if self.layers_block_type[i] == "hybrid": - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. 
-        # take any layer that contains cache and not empty tensor
-        layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
-
-    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
-        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
-
-    @classmethod
-    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
-        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
-
-
 def layer_type_list(config: Zamba2Config):
     """
     Returns list of layer ids containing hybrid layers
@@ -1460,7 +1372,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         causal_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+        past_key_value: Optional[Zamba2DynamicCache] = None,
        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
@@ -1471,7 +1383,7 @@ def forward(
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                 `(batch, sequence_length)` where padding elements are indicated by 0.
-            past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+            past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -1685,14 +1597,14 @@ def _init_weights(self, module):
             config.n_positions - 1]`.
 
             [What are position IDs?](../glossary#position-ids)
-        past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
+        past_key_values (`Zamba2DynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            A Zamba2DynamicCache object containing pre-computed hidden-states (keys and values in the
             self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
             `past_key_values` input) to speed up sequential decoding.
             Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
             Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
             `(batch_size, d_inner, d_state)` respectively.
-            See the `HybridMambaAttentionDynamicCache` class for more details.
+            See the `Zamba2DynamicCache` class for more details.
 
             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
@@ -1831,7 +1743,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+        past_key_values: Optional[Zamba2DynamicCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -1868,7 +1780,7 @@ def forward(
 
         if use_cache and past_key_values is None:
             logger.warning_once(
-                "Zamba2 requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
+                "Zamba2 requires an initialized `Zamba2DynamicCache` to return a cache. None was "
                 "provided, so no cache will be returned."
             )
 
@@ -2014,7 +1926,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+        past_key_values: Optional[Zamba2DynamicCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
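
The hunks above rename the cache class to `Zamba2DynamicCache`, and the warning added to the model forward states that a cache is only returned when the caller passes an initialized one. A minimal usage sketch follows; it assumes `Zamba2DynamicCache` keeps the constructor signature of the removed `HybridMambaAttentionDynamicCache` (config, batch_size, dtype, device), that it is importable from the patched module, and that the checkpoint id is a placeholder.

# Hypothetical usage sketch, not part of the patch: the cache class name, its
# constructor signature, and the checkpoint id are assumptions inferred from the
# removed HybridMambaAttentionDynamicCache and the warning message above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.zamba2.modeling_zamba2 import Zamba2DynamicCache  # assumed export

model_id = "Zyphra/Zamba2-2.7B"  # placeholder checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Hybrid attention/Mamba2 models cache", return_tensors="pt")
# Pre-allocate the hybrid cache (attention key/value tensors plus mamba conv and ssm
# states) so that `use_cache=True` has a cache to fill and reuse during decoding.
cache = Zamba2DynamicCache(model.config, batch_size=inputs.input_ids.shape[0], dtype=model.dtype)
outputs = model.generate(**inputs, past_key_values=cache, use_cache=True, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))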