From d0ed91720f46ef9d5b47e6110dddb2a3e71482b7 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Fri, 13 Dec 2024 18:42:07 +0100 Subject: [PATCH] fix cache handling and causality for cross attention --- .../moonshine/configuration_moonshine.py | 236 ++ .../moonshine/convert_usefulsensors_to_hf.py | 145 ++ .../models/moonshine/modeling_moonshine.py | 2078 +++++++++++++++++ .../models/moonshine/modular_moonshine.py | 26 +- 4 files changed, 2469 insertions(+), 16 deletions(-) create mode 100644 src/transformers/models/moonshine/configuration_moonshine.py create mode 100644 src/transformers/models/moonshine/convert_usefulsensors_to_hf.py create mode 100644 src/transformers/models/moonshine/modeling_moonshine.py diff --git a/src/transformers/models/moonshine/configuration_moonshine.py b/src/transformers/models/moonshine/configuration_moonshine.py new file mode 100644 index 00000000000000..44f043d6d3c3dd --- /dev/null +++ b/src/transformers/models/moonshine/configuration_moonshine.py @@ -0,0 +1,236 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/moonshine/modular_moonshine.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_moonshine.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + +from ...configuration_utils import PretrainedConfig + + +class MoonshineConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MoonshineModel`]. It is used to instantiate a Moonshine + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Moonshine + [UsefulSensors/moonshine](https://huggingface.co/UsefulSensors/moonshine). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32768): + Vocabulary size of the Moonshine model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MoonshineModel`]. + hidden_size (`int`, *optional*, defaults to 288): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder and decoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder and decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. + decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. TODO: check this + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + decoder_start_token_id (`int`, *optional*, defaults to 1): + Corresponds to the "<|startoftranscript|>" token, which is automatically used when no `decoder_input_ids` + are provided to the `generate` function. It is used to guide the model`s generation process depending on + the task. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the model is used as an encoder/decoder or not. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. TODO: check this + partial_rotary_factor (`float`, *optional*, defaults to 0.5): + Percentage of the query and keys which will have rotary embedding. TODO: check this + ff_mult (`int`, *optional*, defaults to 4): + Factor by which to scale the intermediate size. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + qk_layernorm (`bool`, *optional*, defaults to `False`): + Whether or not to normalize the Queries and Keys after projecting the hidden states. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + bos_token_id (`int`, *optional*, defaults to 1): + Denotes beginning of sequences token id. + eos_token_id (`int`, *optional*, defaults to 2): + Denotes end of sequences token id. + apply_spec_augment (`bool`, *optional*, defaults to `False`): + Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see + [SpecAugment: A Simple Data Augmentation Method for Automatic Speech + Recognition](https://arxiv.org/abs/1904.08779). + mask_time_prob (`float`, *optional*, defaults to 0.05): + Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking + procecure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If + reasoning from the propability of each feature vector to be chosen as the start of the vector span to be + masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the + actual percentage of masked vectors. This is only relevant if `apply_spec_augment == True`. + mask_time_length (`int`, *optional*, defaults to 10): + Length of vector span along the time axis. + mask_time_min_masks (`int`, *optional*, defaults to 2),: + The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step, + irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length < + mask_time_min_masks'' + mask_feature_prob (`float`, *optional*, defaults to 0.0): + Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The + masking procecure generates `mask_feature_prob*len(feature_axis)/mask_time_length` independent masks over + the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector + span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap + may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is + True`. + mask_feature_length (`int`, *optional*, defaults to 10): + Length of vector span along the feature axis. + mask_feature_min_masks (`int`, *optional*, defaults to 0),: + The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time + step, irrespectively of `mask_feature_prob`. Only relevant if + `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`. + + Example: + + ```python + >>> from transformers import MoonshineModel, MoonshineConfig + + >>> # Initializing a Moonshine style configuration + >>> configuration = MoonshineConfig().from_pretrained("UsefulSensors/moonshine") + + >>> # Initializing a model from the configuration + >>> model = MoonshineModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "moonshine" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32768, + hidden_size=288, + intermediate_size=None, + num_hidden_layers=6, + num_attention_heads=8, + num_key_value_heads=None, + encoder_hidden_act="gelu", + decoder_hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + layer_norm_eps=1e-5, + decoder_start_token_id=1, + use_cache=True, + is_encoder_decoder=True, + rope_theta=10000.0, + partial_rotary_factor=0.5, + attention_bias=False, + attention_dropout=0.0, + qk_layernorm=False, + rope_scaling=None, + ff_mult=4, + bos_token_id=1, + eos_token_id=2, + apply_spec_augment=False, + mask_time_prob=0.05, + mask_time_length=10, + mask_time_min_masks=2, + mask_feature_prob=0.0, + mask_feature_length=10, + mask_feature_min_masks=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.encoder_hidden_act = encoder_hidden_act + self.decoder_hidden_act = decoder_hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.is_encoder_decoder = is_encoder_decoder + self.rope_theta = rope_theta + self.partial_rotary_factor = partial_rotary_factor + + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.qk_layernorm = qk_layernorm + self.rope_scaling = rope_scaling + self.ff_mult = ff_mult + + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 + self.apply_spec_augment = apply_spec_augment + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_time_min_masks = mask_time_min_masks + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + self.mask_feature_min_masks = mask_feature_min_masks + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) diff --git a/src/transformers/models/moonshine/convert_usefulsensors_to_hf.py b/src/transformers/models/moonshine/convert_usefulsensors_to_hf.py new file mode 100644 index 00000000000000..cf0010fd552ddd --- /dev/null +++ b/src/transformers/models/moonshine/convert_usefulsensors_to_hf.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python +"""Converts a Moonshine model in Useful Sensors format to Hugging Face format.""" +# Copyright 2022 The HuggingFace Inc. team and the OpenAI team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from huggingface_hub import hf_hub_download + +import h5py +import torch +import numpy as np +import re + +from transformers.models.moonshine.modeling_moonshine import MoonshineConfig +from transformers.models.moonshine.modeling_moonshine import MoonshineDecoder + + +# Copied from https://github.com/usefulsensors/moonshine/blob/a1d77cc573b0471ac4602b86f67b3f48d67df1a9/moonshine/model.py +def _get_weights(model_name): + repo = "UsefulSensors/moonshine" + + return ( + hf_hub_download(repo, f"{x}.weights.h5", subfolder=model_name) + for x in ("preprocessor", "encoder", "decoder") + ) + + +def _read_h5_weights(group, current_key="", weights={}): + for key in group.keys(): + full_key = f"{current_key}.{key}" if current_key else key + if isinstance(group[key], h5py.Dataset): + w = np.array(group[key]) + w = torch.from_numpy(w) + if len(w.shape) > 1: + if len(w.shape) == 3: + hidden_size = max(list(w.shape)) + try: + w = w.reshape(hidden_size, hidden_size) + except RuntimeError: + # meaning its a conv layers + pass + w = w.transpose(0, -1) + weights[full_key] = w + else: + _read_h5_weights(group[key], full_key, weights) + return weights + + +def _convert_layer_names(name, gated_mlp=False): + name = re.sub(r'layers\.functional(?:_(\d+))?\.layers', lambda m: f'layers.{m.group(1) if m.group(1) else "0"}', name, count=1) + if gated_mlp: + name = re.sub(r"functional\.layers\.dense\.", "mlp.up_proj.", name) + name = re.sub(r"functional\.layers\.dense_1\.", "mlp.down_proj.", name) + else: + name = re.sub(r"functional\.layers\.sequential\.layers\.dense\.", "mlp.fc1.", name) + name = re.sub(r"functional\.layers\.sequential\.layers\.dense_1\.", "mlp.fc2.", name) + name = re.sub(r"layers\.sequential\.layers\.conv1d\.", "conv1.", name) + name = re.sub(r"layers\.sequential\.layers\.conv1d_1\.", "conv2.", name) + name = re.sub(r"layers\.sequential\.layers\.conv1d_2\.", "conv3.", name) + name = re.sub(r"layers\.sequential\.layers\.group_normalization\.", "groupnorm.", name) + name = re.sub(r"mha_with_rope\.key_dense", "self_attn.k_proj", name) + name = re.sub(r"mha_with_rope\.query_dense", "self_attn.q_proj", name) + name = re.sub(r"mha_with_rope\.value_dense", "self_attn.v_proj", name) + name = re.sub(r"mha_with_rope\.output_dense", "self_attn.dense", name) + name = re.sub(r"mha_precomputed_kv\.key_dense", "encoder_attn.k_proj", name) + name = re.sub(r"mha_precomputed_kv\.query_dense", "encoder_attn.q_proj", name) + name = re.sub(r"mha_precomputed_kv\.value_dense", "encoder_attn.v_proj", name) + name = re.sub(r"mha_precomputed_kv\.output_dense", "encoder_attn.dense", name) + name = re.sub(r"mha_causal_with_rope\.key_dense", "self_attn.k_proj", name) + name = re.sub(r"mha_causal_with_rope\.query_dense", "self_attn.q_proj", name) + name = re.sub(r"mha_causal_with_rope\.value_dense", "self_attn.v_proj", name) + name = re.sub(r"mha_causal_with_rope\.output_dense", "self_attn.dense", name) + name = re.sub(r"layer_normalization\.", "input_layernorm.", name) + name = re.sub(r"layer_normalization_1\.", "post_attention_layernorm.", name) + name = re.sub(r"layer_normalization_2\.", "final_layernorm.", name) + name = re.sub(r"vars\.0", "weight", name) + name = re.sub(r"vars\.1", "bias", name) + name = re.sub(r"layers\.reversible_embedding", "embed_tokens", name) + + return name + + +def _convert_weights(weights, encoder=True): + if "layers.rotary_embedding.vars.0" in weights: + weights.pop("layers.rotary_embedding.vars.0") + + converted_weights = {} + if encoder: + converted_weights["layer_norm.weight"] = weights.pop("layers.layer_normalization.vars.0") + else: + converted_weights["norm.weight"] = weights.pop("layers.layer_normalization.vars.0") + + for name, w in weights.items(): + if encoder: + new_name = _convert_layer_names(name) + else: + new_name = _convert_layer_names(name, gated_mlp=True) + converted_weights[new_name] = w + + return converted_weights + + +def convert_usefulsensors_moonshine_to_hf(model_name, pytorch_dump_folder_path): + preprocessor_weights_path, encoder_weights_path, decoder_weights_path = _get_weights(model_name) + + with h5py.File(preprocessor_weights_path, 'r') as f: + loaded_preprocessor_weights = _read_h5_weights(f, weights={}) + + with h5py.File(encoder_weights_path, 'r') as f: + loaded_encoder_weights = _read_h5_weights(f, weights={}) + + with h5py.File(decoder_weights_path, 'r') as f: + loaded_decoder_weights = _read_h5_weights(f, weights={}) + + encoder_state_dict = {**loaded_encoder_weights, **loaded_preprocessor_weights} + encoder_state_dict = _convert_weights(encoder_state_dict) + + converted_decoder_weights = _convert_weights(loaded_decoder_weights, encoder=False) + converted_decoder_weights['embed_tokens.weight'] = converted_decoder_weights['embed_tokens.weight'].T + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # # Required parameters + parser.add_argument("--model_name", type=str, help="Path to the downloaded checkpoints") + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + args = parser.parse_args() + + convert_usefulsensors_moonshine_to_hf( + args.model_name, args.pytorch_dump_folder_path + ) + diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py new file mode 100644 index 00000000000000..9f46520f40d86d --- /dev/null +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -0,0 +1,2078 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/moonshine/modular_moonshine.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_moonshine.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + +import copy +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from packaging import version +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + get_torch_version, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_moonshine import MoonshineConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MoonshineConfig" + + +class MoonshineRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[MoonshineConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`MoonshineRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.stack((freqs, freqs), dim=-1) + emb = emb.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class MoonshineNonGatedMLP(nn.Module): + def __init__(self, config: MoonshineConfig, hidden_act: str): + super().__init__() + config = copy.deepcopy(config) + config.hidden_act = hidden_act + if config.intermediate_size is None: + config.intermediate_size = config.hidden_size * config.ff_mult + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MoonshineGatedMLP(nn.Module): + def __init__(self, config: MoonshineConfig, hidden_act: str): + super().__init__() + config = copy.deepcopy(config) + config.hidden_act = hidden_act + if config.intermediate_size is None: + config.intermediate_size = config.hidden_size * config.ff_mult * 2 + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + hidden_state = self.up_proj(hidden_state) + hidden_state, gate = hidden_state.chunk(2, dim=-1) + hidden_state = self.act_fn(gate) * hidden_state + return self.down_proj(hidden_state) + + +class MoonshineMLP: + def __new__(cls, config: MoonshineConfig, hidden_act: str): + if hidden_act == "gelu": + return MoonshineNonGatedMLP(config, hidden_act) + elif hidden_act == "silu": + return MoonshineGatedMLP(config, hidden_act) + else: + raise ValueError(f"Unsupported activation function: {hidden_act}, please use 'gelu' or 'silu'") + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def rotate_every_two(x: torch.Tensor) -> torch.Tensor: + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_every_two(q) * sin) + k_embed = (k * cos) + (rotate_every_two(k) * sin) + return q_embed, k_embed + + +class MoonshineAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None, is_causal: bool = False): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.rope_theta = config.rope_theta + self.rotary_ndims = max(config.hidden_size // config.num_attention_heads // 2, 32) + + self.is_causal = is_causal + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + self.qk_layernorm = config.qk_layernorm + if self.qk_layernorm: + self.q_layernorm = nn.LayerNorm( + config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + self.k_layernorm = nn.LayerNorm( + config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + + self.rotary_emb = MoonshineRotaryEmbedding( + dim=self.rotary_ndims, + max_position_embeddings=config.max_position_embeddings, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + key_value_states: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if is_cross_attention and past_key_value is not None: + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + if not is_cross_attention: + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + if not is_cross_attention: + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Queries and keys upcast to fp32 is required by Moonshine-2 to avoid overflow + attn_weights = torch.matmul( + query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights += causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.dense(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MoonshineFlashAttention2(MoonshineAttention): + """ + Moonshine flash attention module. This module inherits from `MoonshineAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + key_value_states: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if is_cross_attention and past_key_value is not None: + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + if not is_cross_attention: + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + if not is_cross_attention: + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_dropout = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + + if query_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=attn_dropout, + softmax_scale=None, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.dense(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MoonshineSdpaAttention(MoonshineAttention): + """ + SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MoonshineAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from MoonshineAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + key_value_states: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if is_cross_attention and past_key_value is not None: + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + if not is_cross_attention: + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + if not is_cross_attention: + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.dense(attn_output) + + return attn_output, None, past_key_value + + +MOONSHINE_ATTENTION_CLASSES = { + "eager": MoonshineAttention, + "flash_attention_2": MoonshineFlashAttention2, + "sdpa": MoonshineSdpaAttention, +} + + +class MoonshineEncoderLayer(nn.Module): + def __init__(self, config: MoonshineConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MOONSHINE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = MoonshineMLP(config, config.encoder_hidden_act) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MoonshineDecoderLayer(nn.Module): + def __init__(self, config: MoonshineConfig, layer_idx: int = None): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MOONSHINE_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx, is_causal=True + ) + self.encoder_attn = MOONSHINE_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx, is_causal=False + ) + + self.mlp = MoonshineMLP(config, config.decoder_hidden_act) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False) + self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + encoder_position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + encoder_position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + encoder_position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, encoder_seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_ids=encoder_position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=encoder_position_embeddings, + ) + hidden_states = residual + hidden_states + + # add cross-attn to positions 1 of present_key_value tuple + present_key_value = (present_key_value, cross_attn_present_key_value) + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +MOONSHINE_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MoonshineConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Moonshine Model outputting raw hidden-states without any specific head on top.", + MOONSHINE_START_DOCSTRING, +) +class MoonshinePreTrainedModel(PreTrainedModel): + config_class = MoonshineConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MoonshineDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MOONSHINE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Moonshine Model outputting raw hidden-states without any specific head on top.", + MOONSHINE_START_DOCSTRING, +) +class MoonshineEncoder(MoonshinePreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoonshineDecoderLayer`] + + Args: + config: MoonshineConfig + """ + + main_input_name = "input_features" + + def __init__(self, config: MoonshineConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.conv1 = nn.Conv1d(1, embed_dim, kernel_size=127, stride=64, bias=False) + self.conv2 = nn.Conv1d(embed_dim, 2 * embed_dim, kernel_size=7, stride=3) + self.conv3 = nn.Conv1d(2 * embed_dim, embed_dim, kernel_size=3, stride=2) + self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=1e-5) + + self.rotary_emb = MoonshineRotaryEmbedding( + dim=max(config.hidden_size // config.num_attention_heads // 2, 32), + max_position_embeddings=config.max_position_embeddings, + ) + + self.layers = nn.ModuleList([MoonshineEncoderLayer(config, idx) for idx in range(config.num_hidden_layers)]) + self.layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps, bias=False) + + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + @add_start_docstrings_to_model_forward(MOONSHINE_INPUTS_DOCSTRING) + def forward( + self, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_features is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.preprocess(input_features) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def preprocess(self, input_features: torch.FloatTensor): + input_features = input_features.unsqueeze(1) + inputs_embeds = nn.functional.tanh(self.conv1(input_features)) + inputs_embeds = self.groupnorm(inputs_embeds) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + inputs_embeds = nn.functional.gelu(self.conv3(inputs_embeds)) + inputs_embeds = inputs_embeds.permute(0, 2, 1) + return inputs_embeds + + +@add_start_docstrings( + "The bare Moonshine Model outputting raw hidden-states without any specific head on top.", + MOONSHINE_START_DOCSTRING, +) +class MoonshineDecoder(MoonshinePreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoonshineDecoderLayer`] + + Args: + config: MoonshineConfig + """ + + def __init__(self, config: MoonshineConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MoonshineDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False) + self.rotary_emb = MoonshineRotaryEmbedding( + dim=max(config.hidden_size // config.num_attention_heads // 2, 32), + max_position_embeddings=config.max_position_embeddings, + ) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MOONSHINE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + encoder_position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if encoder_position_ids is None: + encoder_position_ids = torch.arange( + encoder_hidden_states.shape[1], device=encoder_hidden_states.device + ).unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + encoder_position_embeddings = self.rotary_emb(encoder_hidden_states, encoder_position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + encoder_hidden_states, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + position_ids=position_ids, + encoder_position_ids=encoder_position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + encoder_position_embeddings=encoder_position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = past_key_values if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.sum(-1).detach().tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. + if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +@add_start_docstrings( + "The bare Moonshine Model outputting raw hidden-states without any specific head on top.", + MOONSHINE_START_DOCSTRING, +) +class MoonshineModel(MoonshinePreTrainedModel): + def __init__(self, config: MoonshineConfig): + super().__init__(config) + self.encoder = MoonshineEncoder(config) + self.decoder = MoonshineDecoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def freeze_encoder(self): + """ + Calling this function will disable the gradient computation for the Moonshine encoder so that its parameters will + not be updated during training. + """ + self.encoder._freeze_parameters() + + def _mask_input_features( + self, + input_features: torch.FloatTensor, + attention_mask: Optional[torch.LongTensor] = None, + ): + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://arxiv.org/abs/1904.08779). + """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return input_features + + # generate indices & apply SpecAugment along time axis + batch_size, hidden_size, sequence_length = input_features.size() + + if self.config.mask_time_prob > 0 and self.training: + # generate indices & apply SpecAugment along time axis + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool) + mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1) + input_features[mask_time_indices] = 0 + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool) + input_features[mask_feature_indices] = 0 + + return input_features + + @add_start_docstrings_to_model_forward(MOONSHINE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: + r""" + Returns: + + Example: + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, WhisperModel + >>> from datasets import load_dataset + + >>> model = WhisperModel.from_pretrained("openai/whisper-base") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id + >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state + >>> list(last_hidden_state.shape) + [1, 2, 512] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = self.encoder( + input_features, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + position_ids=decoder_position_ids, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["proj_out.weight"] + + def __init__(self, config: MoonshineConfig): + super().__init__(config) + self.model = MoonshineModel(config) + self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def get_output_embeddings(self): + return self.proj_out + + def set_output_embeddings(self, new_embeddings): + self.proj_out = new_embeddings + + def get_input_embeddings(self) -> nn.Module: + return self.model.get_input_embeddings() + + @property + def encoder(self): + return self.get_encoder() + + @add_start_docstrings_to_model_forward(MOONSHINE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` + or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is + only computed for the tokens with labels in `[0, ..., config.vocab_size]`. `sequence_length` should be smaller than or equal to `config.max_target_positions`. + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + + >>> generated_ids = model.generate(inputs=input_features) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> transcription + ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if labels.shape[1] > self.max_target_positions: + raise ValueError( + f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens." + ) + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + decoder_position_ids=decoder_position_ids, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + lm_logits = self.proj_out(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index be24c7837e8360..bea8370b7313ea 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -462,6 +462,8 @@ def forward( value_states = self.v_proj(current_states) key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if is_cross_attention and past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, {"cache_position": cache_position}) if self.qk_layernorm: query_states = self.q_layernorm(query_states) @@ -505,11 +507,7 @@ def forward( "partial_rotation_size": self.rotary_ndims, "cache_position": cache_position, } - else: - cache_kwargs = { - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -587,6 +585,8 @@ def forward( value_states = self.v_proj(current_states) key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if is_cross_attention and past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, {"cache_position": cache_position}) if self.qk_layernorm: query_states = self.q_layernorm(query_states) @@ -630,11 +630,7 @@ def forward( "partial_rotation_size": self.rotary_ndims, "cache_position": cache_position, } - else: - cache_kwargs = { - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -732,6 +728,8 @@ def forward( value_states = self.v_proj(current_states) key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if is_cross_attention and past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, {"cache_position": cache_position}) if self.qk_layernorm: query_states = self.q_layernorm(query_states) @@ -775,11 +773,7 @@ def forward( "partial_rotation_size": self.rotary_ndims, "cache_position": cache_position, } - else: - cache_kwargs = { - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -839,7 +833,7 @@ def __init__(self, config: MoonshineConfig, layer_idx: int = None): self.hidden_size = config.hidden_size self.self_attn = MOONSHINE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx, is_causal=True) - self.encoder_attn = MOONSHINE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx, is_causal=True) + self.encoder_attn = MOONSHINE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx, is_causal=False) self.mlp = MoonshineMLP(config, config.decoder_hidden_act) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)