Skip to content

Commit

Permalink
Fix modular model converter
Browse files Browse the repository at this point in the history
  • Loading branch information
pglorio committed Nov 11, 2024
1 parent ffc2a58 commit 9adf85e
Showing 1 changed file with 11 additions and 99 deletions.
110 changes: 11 additions & 99 deletions src/transformers/models/zamba2/modeling_zamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@


if is_mamba_ssm_available():
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
else:
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
Expand All @@ -65,10 +65,8 @@
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None

is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)

is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -271,92 +269,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class HybridMambaAttentionDynamicCache(DynamicCache):
"""
A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
(which has a constant shape regardless of seq_len).
This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
"""

def __init__(self, config, batch_size, dtype=torch.float16, device=None):
self.dtype = dtype
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
intermediate_size = config.mamba_expand * config.hidden_size
ssm_state_size = config.mamba_d_state
conv_kernel_size = config.mamba_d_conv
self.n_mamba_heads = config.n_mamba_heads
self.conv_states = []
self.ssm_states = []
self.transformer_layers = []
self._modules = {}
self._parameters = {}
self._buffers = {}
for i in range(config.num_hidden_layers):
self.conv_states += [
torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
]
cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size)
self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)]
if self.layers_block_type[i] == "hybrid":
self.transformer_layers.append(i)

self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Update the cache
if self.key_cache[layer_idx].shape[-1] == 0:
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)

return self.key_cache[layer_idx], self.value_cache[layer_idx]

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

device = self.conv_states[layer_idx].device
self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
device = self.ssm_states[layer_idx].device
self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# take any layer that contains cache and not empty tensor
layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")

@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")


def layer_type_list(config: Zamba2Config):
"""
Returns list of layer ids containing hybrid layers
Expand Down Expand Up @@ -1460,7 +1372,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
past_key_value: Optional[Zamba2DynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
Expand All @@ -1471,7 +1383,7 @@ def forward(
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
Expand Down Expand Up @@ -1685,14 +1597,14 @@ def _init_weights(self, module):
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
past_key_values (`Zamba2DynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
A Zamba2DynamicCache object containing pre-computed hidden-states (keys and values in the
self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
`(batch_size, d_inner, d_state)` respectively.
See the `HybridMambaAttentionDynamicCache` class for more details.
See the `Zamba2DynamicCache` class for more details.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
Expand Down Expand Up @@ -1831,7 +1743,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
past_key_values: Optional[Zamba2DynamicCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
Expand Down Expand Up @@ -1868,7 +1780,7 @@ def forward(

if use_cache and past_key_values is None:
logger.warning_once(
"Zamba2 requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
"Zamba2 requires an initialized `Zamba2DynamicCache` to return a cache. None was "
"provided, so no cache will be returned."
)

Expand Down Expand Up @@ -2014,7 +1926,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
past_key_values: Optional[Zamba2DynamicCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
Expand Down

0 comments on commit 9adf85e

Please sign in to comment.