diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 67bd31fdaeede5..ce984574963afc 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -101,6 +101,7 @@ FlashAttention-2 is currently supported for the following architectures: * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) +* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2) You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request. @@ -304,7 +305,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel) * [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel) * [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel) - +* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2) FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models. diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index 2acd266a425c62..5463ca6e8f5610 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -19,7 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math + from ...configuration_utils import PretrainedConfig @@ -30,19 +30,16 @@ class Zamba2Config(PretrainedConfig): Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Zamba2 model. + [Zyphra/Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - - Args: vocab_size (`int`, *optional*, defaults to 32000): Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Zamba2Model`] max_position_embeddings (`int`, *optional*, defaults to 4096): The maximum sequence length that this model might ever be used with. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the - model has a output word embedding layer. hidden_size (`int`, *optional*, defaults to 2560): Dimension of the hidden representations. num_hidden_layers (`int`, *optional*, defaults to 54): @@ -52,7 +49,7 @@ class Zamba2Config(PretrainedConfig): mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents. mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel. 
mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. - mamba_ngroups (`int`, *optional*, defaults to 8): + mamba_ngroups (`int`, *optional*, defaults to 1): Number of groups for the evolution matrices of mamba 2. time_step_min (`float`, *optional*, defaults to 0.001): Minimum `time_step` used to bound `dt_proj.bias`. @@ -62,16 +59,10 @@ class Zamba2Config(PretrainedConfig): Minimum clamping value of the `dt_proj.bias` layer initialization. time_step_limit (`tuple`, *optional*): Accepted range of time step values. - mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): - Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` n_mamba_heads (`int`, *optional*, defaults to 1): Number of heads for the evolution matrices of mamba 2. use_conv_bias (`bool`, *optional*, defaults to `True`): Whether or not to use bias in the convolution layer of the mixer block. - mamba_proj_bias (`bool`, *optional*, defaults to `False`): - Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block - hidden_mamba_act (`str`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) adjacent to the mamba conv. chunk_size (`int`, *optional*, defaults to 256): Size of the chunks that will comprise the sequence. add_bias_linear (`bool`, *optional*, defaults to `False`): @@ -101,11 +92,11 @@ class Zamba2Config(PretrainedConfig): Rank of the LoRA in the shared MLP and shared attention layers. use_mem_rope (`bool`, *optional*, defaults to `False`): If True, includes RoPE in the shared attention layers. - rope_theta (`float`, *optional*, defaults to 10000.0): + rope_theta (`float`, *optional*, defaults to `10000.0`): The base period of the RoPE embeddings. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-5): + rms_norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only @@ -122,6 +113,16 @@ class Zamba2Config(PretrainedConfig): The id of the "beginning-of-sequence" token. eos_token_id (`int`, *optional*, defaults to 2): The id of the "end-of-sequence" token. + use_long_context (`bool`, *optional*, defaults to `False`): + Activates the context-extended version of Zamba by modifying RoPE. 
+ ```python + >>> from transformers import Zamba2Model, Zamba2Config + >>> # Initializing a Zamba2-2.7B style configuration + >>> configuration = Zamba2Config() + >>> # Initializing a model from the Zamba2-2.7B style configuration + >>> model = Zamba2Model(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config """ model_type = "zamba2" @@ -131,7 +132,6 @@ def __init__( self, vocab_size=32000, max_position_embeddings=4096, - tie_word_embeddings=True, hidden_size=2560, num_hidden_layers=54, layers_block_type=None, @@ -143,10 +143,7 @@ def __init__( time_step_max=0.1, time_step_floor=1e-4, time_step_limit=None, - mamba_dt_rank="auto", n_mamba_heads=1, - mamba_proj_bias=False, - hidden_mamba_act="silu", use_conv_bias=True, chunk_size=256, add_bias_linear=False, @@ -175,13 +172,10 @@ def __init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, **kwargs, ) - self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings - self.tie_word_embeddings = tie_word_embeddings self.hidden_size = hidden_size if intermediate_size is None: self.intermediate_size = 4 * hidden_size @@ -199,17 +193,13 @@ def __init__( self.mamba_d_state = mamba_d_state self.mamba_d_conv = mamba_d_conv self.mamba_expand = mamba_expand - self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank self.add_bias_linear = add_bias_linear - self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads self.mamba_ngroups = mamba_ngroups self.n_mamba_heads = n_mamba_heads - self.mamba_proj_bias = mamba_proj_bias - self.hidden_mamba_act = hidden_mamba_act + self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads self.use_conv_bias = use_conv_bias self.chunk_size = chunk_size self.time_step_limit = time_step_limit - self.use_shared_mlp_lora = use_shared_mlp_lora self.use_shared_attention_lora = use_shared_attention_lora self.lora_rank = lora_rank @@ -219,21 +209,12 @@ def __init__( self.time_step_floor = time_step_floor if use_long_context: self.max_position_embeddings = 16384 - - # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads - self.num_attention_heads = num_attention_heads self.kv_channels = self.hidden_size // self.num_attention_heads self.num_query_groups = self.num_attention_heads - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - - self.use_cache = use_cache - self.num_logits_to_keep = num_logits_to_keep - # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) if layers_block_type is None: self.layers_block_type = ( @@ -246,4 +227,8 @@ def __init__( + ["mamba"] * 2 ) else: - self.layers_block_type = layers_block_type \ No newline at end of file + self.layers_block_type = layers_block_type + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index a92c59ac81ef2b..e78c155a1f2ed2 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -56,12 +56,8 @@ if is_mamba_ssm_available(): - #### from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, 
selective_scan_fn + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from mamba_ssm.ops.triton.ssd_combined import ( #### added - mamba_chunk_scan_combined, - mamba_split_conv1d_scan_combined, - ) else: selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None @@ -70,7 +66,32 @@ else: causal_conv1d_update, causal_conv1d_fn = None, None -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) #### added +is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) +) + + +logger = logging.get_logger(__name__) + + +class Zamba2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Zamba2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Zamba2DynamicCache(DynamicCache): @@ -191,29 +212,6 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") -logger = logging.get_logger(__name__) - - -class Zamba2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Zamba2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - class MambaRMSNormGated(torch.nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() @@ -273,6 +271,92 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). 
+ For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config, batch_size, dtype=torch.float16, device=None): + self.dtype = dtype + self.layers_block_type = config.layers_block_type + self.has_previous_state = False # only used by mamba + intermediate_size = config.mamba_expand * config.hidden_size + ssm_state_size = config.mamba_d_state + conv_kernel_size = config.mamba_d_conv + self.n_mamba_heads = config.n_mamba_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + self._modules = {} + self._parameters = {} + self._buffers = {} + for i in range(config.num_hidden_layers): + self.conv_states += [ + torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + ] + cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size) + self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] + if self.layers_block_type[i] == "hybrid": + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + def layer_type_list(config: Zamba2Config): """ Returns list of layer ids containing hybrid layers @@ -518,7 +602,6 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - self.is_causal = True def forward( self, @@ -613,9 +696,9 @@ def forward( value_states, attention_mask, q_len, - is_causal=self.is_causal, dropout=dropout_rate, softmax_scale=softmax_scale, + is_causal=self.is_causal, ) attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous() @@ -743,7 +826,6 @@ def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): Assumes that we only have tensors of either size 4 or 3 """ - # pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if input_tensor.ndim == 4 else (0, 0, 0, pad_size, 0, 0) return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) @@ -759,7 +841,6 @@ def reshape_into_chunks(input_tensor, pad_size, chunk_size): # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] 
input_tensor = pad_tensor_by_size(input_tensor, pad_size) - # if len(input_tensor.shape) == 3: if input_tensor.ndim == 3: # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) @@ -806,7 +887,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): self.conv_kernel_size = config.mamba_d_conv self.intermediate_size = int(config.mamba_expand * self.hidden_size) self.layer_idx = layer_idx - self.use_conv_bias = config.use_conv_bias # add this with default True + self.use_conv_bias = config.use_conv_bias self.activation = "silu" self.act = nn.SiLU() @@ -815,9 +896,9 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): self.num_heads = self.config.n_mamba_heads self.chunk_size = config.chunk_size - self.time_step_limit = config.time_step_limit # add this with default (0.0, float("inf")) - self.time_step_min = config.time_step_min # add this, with same default as zamba1 - self.time_step_max = config.time_step_max # add this, with same default as zamba1 + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size self.conv1d = nn.Conv1d( @@ -926,7 +1007,7 @@ def cuda_kernels_forward( # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states) A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) - dt_limit_kwargs = {} if self.time_step_limit == None else {"dt_limit": self.time_step_limit} + dt_limit_kwargs = {} if self.time_step_limit is None else {"dt_limit": self.time_step_limit} if attention_mask is not None: input_not_masked = torch.all(attention_mask == 1) else: @@ -941,7 +1022,7 @@ def cuda_kernels_forward( A, D=self.D, chunk_size=self.chunk_size, - seq_idx=None, # was seq_idx + seq_idx=None, activation=self.activation, rmsnorm_weight=self.norm.weight, rmsnorm_eps=self.norm.variance_epsilon, @@ -1379,7 +1460,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2DynamicCache] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1390,7 +1471,7 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
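Editor's note on the cache contract touched by the surrounding hunks: Zamba2 only returns a cache when the caller passes an initialized one in, and that cache pairs growing key/value tensors for the shared attention layers with fixed-size `conv_states`/`ssm_states` for the mamba layers. The snippet below is a minimal sketch of that calling pattern, mirroring the helper in tests/models/zamba2/test_modeling_zamba2.py further down in this diff; the `Zyphra/Zamba2-2.7B` checkpoint name and the import path for `Zamba2DynamicCache` are illustrative assumptions, not part of this PR.

```python
# Minimal sketch of priming the Zamba2 cache before a forward pass.
# Assumptions (not taken from this PR): the "Zyphra/Zamba2-2.7B" checkpoint and the
# import path of Zamba2DynamicCache; the constructor call mirrors the test helper below.
import torch
from transformers import AutoTokenizer, Zamba2ForCausalLM
from transformers.models.zamba2.modeling_zamba2 import Zamba2DynamicCache

model_id = "Zyphra/Zamba2-2.7B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Zamba2ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("When did the Roman empire fall?", return_tensors="pt").to(model.device)

# Zamba2 warns and returns no cache unless an initialized cache object is provided:
# attention layers store (batch, num_heads, seq_len, head_dim) key/value tensors,
# mamba layers store fixed-size conv_states and ssm_states.
past_key_values = Zamba2DynamicCache(
    model.config, inputs["input_ids"].shape[0], model.dtype, device=model.device
)

with torch.no_grad():
    outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)

# The returned cache can be fed back in with only the newly generated token ids.
print(type(outputs.past_key_values).__name__, outputs.logits.shape)
```

The same pattern applies to the `HybridMambaAttentionDynamicCache` type hints introduced above; only the class name differs.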
@@ -1553,9 +1634,9 @@ def _init_weights(self, module): module.A_log._no_weight_decay = True module.D._no_weight_decay = True - # num_heads = int(self.config.mamba_expand * self.config.hidden_size) // self.config.mamba_headdim dt = torch.exp( - torch.rand(self.config.n_mamba_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + torch.rand(self.config.n_mamba_heads) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + math.log(self.config.time_step_min) ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 @@ -1604,14 +1685,14 @@ def _init_weights(self, module): config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`Zamba2DynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - A Zamba2DynamicCache object containing pre-computed hidden-states (keys and values in the + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and `(batch_size, d_inner, d_state)` respectively. - See the `Zamba2DynamicCache` class for more details. + See the `HybridMambaAttentionDynamicCache` class for more details. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all @@ -1692,8 +1773,12 @@ def __init__(self, config: Zamba2Config): lora_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: - tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_B_list.' + str(lora_id) + '.weight') + tied_keys_lora.append( + "shared_transf.feed_forward.gate_up_proj_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.feed_forward.gate_up_proj_lora_B_list." + str(lora_id) + ".weight" + ) lora_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] if config.use_shared_attention_lora: @@ -1701,15 +1786,27 @@ def __init__(self, config: Zamba2Config): lora_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: - tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_B_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_B_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_B_list.' 
+ str(lora_id) + '.weight') + tied_keys_lora.append( + "shared_transf.self_attn.linear_q_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_k_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_v_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_q_lora_B_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_k_lora_B_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_v_lora_B_list." + str(lora_id) + ".weight" + ) lora_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] - layers.append(Zamba2HybridLayer(next(blocks), next(linear_layers), next(mamba_layers))) + layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) self.layers = nn.ModuleList(layers) @@ -1734,7 +1831,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Zamba2DynamicCache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1771,7 +1868,7 @@ def forward( if use_cache and past_key_values is None: logger.warning_once( - "Zamba2 requires an initialized `Zamba2DynamicCache` to return a cache. None was " + "Zamba2 requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " "provided, so no cache will be returned." ) @@ -1917,7 +2014,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Zamba2DynamicCache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 7facf6c1701656..3b254481a908e1 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -73,19 +73,16 @@ class Zamba2Config(PretrainedConfig): Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Zamba2 model. + [Zyphra/Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - - Args: vocab_size (`int`, *optional*, defaults to 32000): Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Zamba2Model`] max_position_embeddings (`int`, *optional*, defaults to 4096): The maximum sequence length that this model might ever be used with. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the - model has a output word embedding layer. 
hidden_size (`int`, *optional*, defaults to 2560): Dimension of the hidden representations. num_hidden_layers (`int`, *optional*, defaults to 54): @@ -95,7 +92,7 @@ class Zamba2Config(PretrainedConfig): mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents. mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel. mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. - mamba_ngroups (`int`, *optional*, defaults to 8): + mamba_ngroups (`int`, *optional*, defaults to 1): Number of groups for the evolution matrices of mamba 2. time_step_min (`float`, *optional*, defaults to 0.001): Minimum `time_step` used to bound `dt_proj.bias`. @@ -105,16 +102,10 @@ class Zamba2Config(PretrainedConfig): Minimum clamping value of the `dt_proj.bias` layer initialization. time_step_limit (`tuple`, *optional*): Accepted range of time step values. - mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): - Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` n_mamba_heads (`int`, *optional*, defaults to 1): Number of heads for the evolution matrices of mamba 2. use_conv_bias (`bool`, *optional*, defaults to `True`): Whether or not to use bias in the convolution layer of the mixer block. - mamba_proj_bias (`bool`, *optional*, defaults to `False`): - Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block - hidden_mamba_act (`str`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) adjacent to the mamba conv. chunk_size (`int`, *optional*, defaults to 256): Size of the chunks that will comprise the sequence. add_bias_linear (`bool`, *optional*, defaults to `False`): @@ -144,11 +135,11 @@ class Zamba2Config(PretrainedConfig): Rank of the LoRA in the shared MLP and shared attention layers. use_mem_rope (`bool`, *optional*, defaults to `False`): If True, includes RoPE in the shared attention layers. - rope_theta (`float`, *optional*, defaults to 10000.0): + rope_theta (`float`, *optional*, defaults to `10000.0`): The base period of the RoPE embeddings. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-5): + rms_norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only @@ -165,6 +156,16 @@ class Zamba2Config(PretrainedConfig): The id of the "beginning-of-sequence" token. eos_token_id (`int`, *optional*, defaults to 2): The id of the "end-of-sequence" token. + use_long_context (`bool`, *optional*, defaults to `False`): + Activates the context-extended version of Zamba by modifying RoPE. 
+ ```python + >>> from transformers import Zamba2Model, Zamba2Config + >>> # Initializing a Zamba2-2.7B style configuration + >>> configuration = Zamba2Config() + >>> # Initializing a model from the Zamba2-2.7B style configuration + >>> model = Zamba2Model(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config """ model_type = "zamba2" @@ -174,7 +175,6 @@ def __init__( self, vocab_size=32000, max_position_embeddings=4096, - tie_word_embeddings=True, hidden_size=2560, num_hidden_layers=54, layers_block_type=None, @@ -186,10 +186,7 @@ def __init__( time_step_max=0.1, time_step_floor=1e-4, time_step_limit=None, - mamba_dt_rank="auto", n_mamba_heads=1, - mamba_proj_bias=False, - hidden_mamba_act="silu", use_conv_bias=True, chunk_size=256, add_bias_linear=False, @@ -218,13 +215,10 @@ def __init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, **kwargs, ) - self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings - self.tie_word_embeddings = tie_word_embeddings self.hidden_size = hidden_size if intermediate_size is None: self.intermediate_size = 4 * hidden_size @@ -242,17 +236,13 @@ def __init__( self.mamba_d_state = mamba_d_state self.mamba_d_conv = mamba_d_conv self.mamba_expand = mamba_expand - self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank self.add_bias_linear = add_bias_linear self.mamba_ngroups = mamba_ngroups self.n_mamba_heads = n_mamba_heads self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads - self.mamba_proj_bias = mamba_proj_bias - self.hidden_mamba_act = hidden_mamba_act self.use_conv_bias = use_conv_bias self.chunk_size = chunk_size self.time_step_limit = time_step_limit - self.use_shared_mlp_lora = use_shared_mlp_lora self.use_shared_attention_lora = use_shared_attention_lora self.lora_rank = lora_rank @@ -262,21 +252,12 @@ def __init__( self.time_step_floor = time_step_floor if use_long_context: self.max_position_embeddings = 16384 - - # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads - self.num_attention_heads = num_attention_heads self.kv_channels = self.hidden_size // self.num_attention_heads self.num_query_groups = self.num_attention_heads - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - - self.use_cache = use_cache - self.num_logits_to_keep = num_logits_to_keep - # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) if layers_block_type is None: self.layers_block_type = ( @@ -290,6 +271,14 @@ def __init__( ) else: self.layers_block_type = layers_block_type + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + +class Zamba2RMSNorm(ZambaRMSNorm): + pass def count_mem_blocks_in_config(config: Zamba2Config): @@ -488,10 +477,6 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") -class Zamba2RMSNorm(ZambaRMSNorm): - pass - - class MambaRMSNormGated(torch.nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() @@ -1649,7 +1634,8 @@ def _init_weights(self, module): module.D._no_weight_decay = True dt = torch.exp( - 
torch.rand(self.config.n_mamba_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + torch.rand(self.config.n_mamba_heads) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + math.log(self.config.time_step_min) ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 @@ -1783,8 +1769,12 @@ def __init__(self, config: Zamba2Config): lora_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: - tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_B_list.' + str(lora_id) + '.weight') + tied_keys_lora.append( + "shared_transf.feed_forward.gate_up_proj_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.feed_forward.gate_up_proj_lora_B_list." + str(lora_id) + ".weight" + ) lora_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] if config.use_shared_attention_lora: @@ -1792,12 +1782,24 @@ def __init__(self, config: Zamba2Config): lora_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: - tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_B_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_B_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_B_list.' + str(lora_id) + '.weight') + tied_keys_lora.append( + "shared_transf.self_attn.linear_q_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_k_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_v_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_q_lora_B_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_k_lora_B_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_v_lora_B_list." 
+ str(lora_id) + ".weight" + ) lora_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index cdfb84b40bbb4d..e3ca547923b481 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -80,7 +80,7 @@ def __init__( num_labels=3, num_choices=4, scope=None, - layers_block_type = ['mamba', 'hybrid'], + layers_block_type=["mamba", "hybrid"], num_mem_blocks=1, ): self.parent = parent @@ -137,7 +137,7 @@ def get_config(self): vocab_size=self.vocab_size, hidden_size=self.hidden_size, mamba_dt_rank=self.mamba_dt_rank, - mamba_d_state = self.mamba_d_state, + mamba_d_state=self.mamba_d_state, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads, n_mamba_heads=self.n_mamba_heads, @@ -221,9 +221,7 @@ def create_and_check_decoder_model_past_large_inputs( # first forward pass # Attention: Zamba2 needs the cache to be initialized to return a cache! - past_key_values = Zamba2DynamicCache( - config, input_ids.shape[0], model.dtype, device=model.device - ) + past_key_values = Zamba2DynamicCache(config, input_ids.shape[0], model.dtype, device=model.device) outputs = model( input_ids, attention_mask=input_mask, @@ -600,7 +598,9 @@ def test_simple_batched_generate_with_padding(self): self.model.to(torch_device) inputs = self.tokenizer( - ["Hey how are you doing on this lovely evening?", "When did the Roman empire "], padding=True, return_tensors="pt" + ["Hey how are you doing on this lovely evening?", "When did the Roman empire "], + padding=True, + return_tensors="pt", ).to(torch_device) out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10) output_sentences = self.tokenizer.batch_decode(out) @@ -608,14 +608,16 @@ def test_simple_batched_generate_with_padding(self): output_sentences[0], " Hey how are you doing on this lovely evening?\n\nI'm doing well, thanks for", ) - + self.assertEqual( output_sentences[1], "[PAD][PAD][PAD][PAD] When did the Roman empire 1st fall?\nThe Roman Empire fell in", ) with torch.no_grad(): - logits = self.model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits.to(dtype=torch.float32) + logits = self.model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits.to( + dtype=torch.float32 + ) EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor( [