diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index b651840412023a..e89407c7eeaa10 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -5351,7 +5351,6 @@
     )
     from .models.gpt_neo import GPTNeoConfig
     from .models.gpt_neox import GPTNeoXConfig
-    from .models.moonshine import MoonshineConfig
    from .models.gpt_neox_japanese import (
        GPTNeoXJapaneseConfig,
    )
@@ -5499,6 +5498,7 @@
     from .models.mobilevitv2 import (
         MobileViTV2Config,
     )
+    from .models.moonshine import MoonshineConfig
     from .models.moshi import (
         MoshiConfig,
         MoshiDepthConfig,
     )
@@ -6022,7 +6022,6 @@
     from .models.gemma import GemmaTokenizerFast
     from .models.gpt2 import GPT2TokenizerFast
     from .models.gpt_neox import GPTNeoXTokenizerFast
-    from .models.moonshine import MoonshineTokenizer
     from .models.gpt_neox_japanese import GPTNeoXJapaneseTokenizer
     from .models.herbert import HerbertTokenizerFast
     from .models.layoutlm import LayoutLMTokenizerFast
@@ -6037,6 +6036,7 @@
     from .models.mbart import MBartTokenizerFast
     from .models.mbart50 import MBart50TokenizerFast
     from .models.mobilebert import MobileBertTokenizerFast
+    from .models.moonshine import MoonshineTokenizer
     from .models.mpnet import MPNetTokenizerFast
     from .models.mt5 import MT5TokenizerFast
     from .models.mvp import MvpTokenizerFast
@@ -7116,14 +7116,6 @@
         GPTNeoXModel,
         GPTNeoXPreTrainedModel,
     )
-    from .models.moonshine import (
-        MoonshineForCausalLM,
-        MoonshineForQuestionAnswering,
-        MoonshineForSequenceClassification,
-        MoonshineForTokenClassification,
-        MoonshineModel,
-        MoonshinePreTrainedModel,
-    )
     from .models.gpt_neox_japanese import (
         GPTNeoXJapaneseForCausalLM,
         GPTNeoXJapaneseModel,
@@ -7463,6 +7455,14 @@
         MobileViTV2Model,
         MobileViTV2PreTrainedModel,
     )
+    from .models.moonshine import (
+        MoonshineForCausalLM,
+        MoonshineForQuestionAnswering,
+        MoonshineForSequenceClassification,
+        MoonshineForTokenClassification,
+        MoonshineModel,
+        MoonshinePreTrainedModel,
+    )
     from .models.moshi import (
         MoshiForCausalLM,
         MoshiForConditionalGeneration,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index ded8e359e30443..a95d4f348ab1c4 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -103,7 +103,6 @@
     gpt_bigcode,
     gpt_neo,
     gpt_neox,
-    moonshine,
     gpt_neox_japanese,
     gpt_sw3,
     gptj,
@@ -163,6 +162,7 @@
     mobilenet_v2,
     mobilevit,
     mobilevitv2,
+    moonshine,
     moshi,
     mpnet,
     mpt,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 2696bda883a628..53457d1fb08b81 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -121,7 +121,6 @@
         ("gpt_bigcode", "GPTBigCodeConfig"),
         ("gpt_neo", "GPTNeoConfig"),
         ("gpt_neox", "GPTNeoXConfig"),
-        ("moonshine", "MoonshineConfig"),
         ("gpt_neox_japanese", "GPTNeoXJapaneseConfig"),
         ("gptj", "GPTJConfig"),
         ("gptsan-japanese", "GPTSanJapaneseConfig"),
@@ -181,6 +180,7 @@
         ("mobilenet_v2", "MobileNetV2Config"),
         ("mobilevit", "MobileViTConfig"),
         ("mobilevitv2", "MobileViTV2Config"),
+        ("moonshine", "MoonshineConfig"),
         ("moshi", "MoshiConfig"),
         ("mpnet", "MPNetConfig"),
         ("mpt", "MptConfig"),
@@ -426,7 +426,6 @@
         ("gpt_bigcode", "GPTBigCode"),
         ("gpt_neo", "GPT Neo"),
         ("gpt_neox", "GPT NeoX"),
-        ("moonshine", "moonshine"),
         ("gpt_neox_japanese", "GPT NeoX Japanese"),
         ("gptj", "GPT-J"),
         ("gptsan-japanese", "GPTSAN-japanese"),
@@ -496,6 +495,7 @@
         ("mobilenet_v2", "MobileNetV2"),
         ("mobilevit", "MobileViT"),
("mobilevitv2", "MobileViTV2"), + ("moonshine", "moonshine"), ("moshi", "Moshi"), ("mpnet", "MPNet"), ("mpt", "MPT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index b8258f972ff4b4..5cdcf88812ee03 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -118,7 +118,6 @@ ("gpt_bigcode", "GPTBigCodeModel"), ("gpt_neo", "GPTNeoModel"), ("gpt_neox", "GPTNeoXModel"), - ("moonshine", "MoonshineModel"), ("gpt_neox_japanese", "GPTNeoXJapaneseModel"), ("gptj", "GPTJModel"), ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), @@ -171,6 +170,7 @@ ("mobilenet_v2", "MobileNetV2Model"), ("mobilevit", "MobileViTModel"), ("mobilevitv2", "MobileViTV2Model"), + ("moonshine", "MoonshineModel"), ("moshi", "MoshiModel"), ("mpnet", "MPNetModel"), ("mpt", "MptModel"), @@ -409,7 +409,6 @@ ("gpt_bigcode", "GPTBigCodeForCausalLM"), ("gpt_neo", "GPTNeoForCausalLM"), ("gpt_neox", "GPTNeoXForCausalLM"), - ("moonshine", "MoonshineForCausalLM"), ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), ("gptj", "GPTJForCausalLM"), ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), @@ -426,6 +425,7 @@ ("mega", "MegaForMaskedLM"), ("megatron-bert", "MegatronBertForCausalLM"), ("mobilebert", "MobileBertForMaskedLM"), + ("moonshine", "MoonshineForConditionalGeneration"), ("mpnet", "MPNetForMaskedLM"), ("mpt", "MptForCausalLM"), ("mra", "MraForMaskedLM"), @@ -496,7 +496,6 @@ ("gpt_bigcode", "GPTBigCodeForCausalLM"), ("gpt_neo", "GPTNeoForCausalLM"), ("gpt_neox", "GPTNeoXForCausalLM"), - ("moonshine", "MoonshineForCausalLM"), ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), ("gptj", "GPTJForCausalLM"), ("granite", "GraniteForCausalLM"), @@ -954,7 +953,6 @@ ("gpt_bigcode", "GPTBigCodeForSequenceClassification"), ("gpt_neo", "GPTNeoForSequenceClassification"), ("gpt_neox", "GPTNeoXForSequenceClassification"), - ("moonshine", "MoonshineForSequenceClassification"), ("gptj", "GPTJForSequenceClassification"), ("ibert", "IBertForSequenceClassification"), ("jamba", "JambaForSequenceClassification"), @@ -1043,7 +1041,6 @@ ("gpt2", "GPT2ForQuestionAnswering"), ("gpt_neo", "GPTNeoForQuestionAnswering"), ("gpt_neox", "GPTNeoXForQuestionAnswering"), - ("moonshine", "MoonshineForQuestionAnswering"), ("gptj", "GPTJForQuestionAnswering"), ("ibert", "IBertForQuestionAnswering"), ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), @@ -1147,7 +1144,6 @@ ("gpt_bigcode", "GPTBigCodeForTokenClassification"), ("gpt_neo", "GPTNeoForTokenClassification"), ("gpt_neox", "GPTNeoXForTokenClassification"), - ("moonshine", "MoonshineForTokenClassification"), ("ibert", "IBertForTokenClassification"), ("layoutlm", "LayoutLMForTokenClassification"), ("layoutlmv2", "LayoutLMv2ForTokenClassification"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index c1f23bc1cb3f18..b4ceffec37c5b0 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -79,6 +79,7 @@ ("mctct", "MCTCTProcessor"), ("mgp-str", "MgpstrProcessor"), ("mllama", "MllamaProcessor"), + ("moonshine", "Wav2Vec2Processor"), ("oneformer", "OneFormerProcessor"), ("owlv2", "Owlv2Processor"), ("owlvit", "OwlViTProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index ea51f0c1302ab7..fc1fe2cc936868 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ 
b/src/transformers/models/auto/tokenization_auto.py @@ -310,8 +310,8 @@ ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)), ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)), - ("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("moonshine", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)), ("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("mra", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/moonshine/configuration_moonshine.py b/src/transformers/models/moonshine/configuration_moonshine.py index 44f043d6d3c3dd..a0a040d1d3d4c4 100644 --- a/src/transformers/models/moonshine/configuration_moonshine.py +++ b/src/transformers/models/moonshine/configuration_moonshine.py @@ -4,7 +4,6 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_moonshine.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 - from ...configuration_utils import PretrainedConfig @@ -42,8 +41,6 @@ class MoonshineConfig(PretrainedConfig): The non-linear activation function (function or string) in the encoder. decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. TODO: check this initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-5): @@ -56,10 +53,8 @@ class MoonshineConfig(PretrainedConfig): Whether or not the model should return the last key/values attentions (not used by all models). is_encoder_decoder (`bool`, *optional*, defaults to `True`): Whether the model is used as an encoder/decoder or not. - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. TODO: check this - partial_rotary_factor (`float`, *optional*, defaults to 0.5): - Percentage of the query and keys which will have rotary embedding. TODO: check this + min_rotary_ndims (`int`, *optional*, defaults to 32): + The minimum number of dimensions of the RoPE. ff_mult (`int`, *optional*, defaults to 4): Factor by which to scale the intermediate size. attention_bias (`bool`, *optional*, defaults to `False`): @@ -68,43 +63,6 @@ class MoonshineConfig(PretrainedConfig): The dropout ratio for the attention probabilities. qk_layernorm (`bool`, *optional*, defaults to `False`): Whether or not to normalize the Queries and Keys after projecting the hidden states. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. 
Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', - 'llama3'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE bos_token_id (`int`, *optional*, defaults to 1): Denotes beginning of sequences token id. 
        eos_token_id (`int`, *optional*, defaults to 2):
@@ -167,18 +125,15 @@ def __init__(
         num_key_value_heads=None,
         encoder_hidden_act="gelu",
         decoder_hidden_act="silu",
-        max_position_embeddings=2048,
         initializer_range=0.02,
         layer_norm_eps=1e-5,
         decoder_start_token_id=1,
         use_cache=True,
         is_encoder_decoder=True,
-        rope_theta=10000.0,
-        partial_rotary_factor=0.5,
+        min_rotary_ndims=32,
         attention_bias=False,
         attention_dropout=0.0,
         qk_layernorm=False,
-        rope_scaling=None,
         ff_mult=4,
         bos_token_id=1,
         eos_token_id=2,
@@ -203,19 +158,15 @@ def __init__(
         self.num_key_value_heads = num_key_value_heads
         self.encoder_hidden_act = encoder_hidden_act
         self.decoder_hidden_act = decoder_hidden_act
-        self.max_position_embeddings = max_position_embeddings
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
         self.decoder_start_token_id = decoder_start_token_id
         self.use_cache = use_cache
         self.is_encoder_decoder = is_encoder_decoder
-        self.rope_theta = rope_theta
-        self.partial_rotary_factor = partial_rotary_factor
-
+        self.min_rotary_ndims = min_rotary_ndims
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         self.qk_layernorm = qk_layernorm
-        self.rope_scaling = rope_scaling
         self.ff_mult = ff_mult

         # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py
index 9f46520f40d86d..ab277c38566d1a 100644
--- a/src/transformers/models/moonshine/modeling_moonshine.py
+++ b/src/transformers/models/moonshine/modeling_moonshine.py
@@ -4,7 +4,6 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_moonshine.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-
 import copy
 import math
 from typing import List, Optional, Tuple, Union
@@ -256,7 +255,8 @@ def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None, is_
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.rope_theta = config.rope_theta
-        self.rotary_ndims = max(config.hidden_size // config.num_attention_heads // 2, 32)
+
+        self.rotary_ndims = max(config.hidden_size // config.num_attention_heads // 2, config.min_rotary_ndims)

         self.is_causal = is_causal
@@ -278,11 +278,7 @@ def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None, is_
             self.k_layernorm = nn.LayerNorm(
                 config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
             )
-
-        self.rotary_emb = MoonshineRotaryEmbedding(
-            dim=self.rotary_ndims,
-            max_position_embeddings=config.max_position_embeddings,
-        )
+        self.rotary_emb = MoonshineRotaryEmbedding(dim=self.rotary_ndims)

     def forward(
         self,
@@ -356,15 +352,14 @@ def forward(
             key_states[..., : self.rotary_ndims],
             key_states[..., self.rotary_ndims :],
         )
-        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+        # [batch_size, seq_length, num_heads, self.rotary_ndims]
         query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)

         # [batch_size, seq_length, num_heads, head_dim]
         query_states = torch.cat((query_rot, query_pass), dim=-1)
         key_states = torch.cat((key_rot, key_pass), dim=-1)

-        if past_key_value is not None:
-            if not is_cross_attention:
+            if past_key_value is not None:
                 cache_kwargs = {
                     "sin": sin,
                     "cos": cos,
@@ -497,15 +492,14 @@ def forward(
             key_states[..., : self.rotary_ndims],
             key_states[..., self.rotary_ndims :],
         )
-        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+        # [batch_size, seq_length, num_heads, self.rotary_ndims]
         query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)

         # [batch_size, seq_length, num_heads, head_dim]
         query_states = torch.cat((query_rot, query_pass), dim=-1)
         key_states = torch.cat((key_rot, key_pass), dim=-1)

-        if past_key_value is not None:
-            if not is_cross_attention:
+            if past_key_value is not None:
                 cache_kwargs = {
                     "sin": sin,
                     "cos": cos,
@@ -655,15 +649,14 @@ def forward(
             key_states[..., : self.rotary_ndims],
             key_states[..., self.rotary_ndims :],
         )
-        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+        # [batch_size, seq_length, num_heads, self.rotary_ndims]
         query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)

         # [batch_size, seq_length, num_heads, head_dim]
         query_states = torch.cat((query_rot, query_pass), dim=-1)
         key_states = torch.cat((key_rot, key_pass), dim=-1)

-        if past_key_value is not None:
-            if not is_cross_attention:
+            if past_key_value is not None:
                 cache_kwargs = {
                     "sin": sin,
                     "cos": cos,
@@ -1046,7 +1039,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
         config: MoonshineConfig
     """

-    main_input_name = "input_features"
+    main_input_name = "input_values"

     def __init__(self, config: MoonshineConfig):
         super().__init__(config)
@@ -1059,8 +1052,7 @@ def __init__(self, config: MoonshineConfig):
         self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=1e-5)

         self.rotary_emb = MoonshineRotaryEmbedding(
-            dim=max(config.hidden_size // config.num_attention_heads // 2, 32),
-            max_position_embeddings=config.max_position_embeddings,
+            dim=max(config.hidden_size // config.num_attention_heads // 2, config.min_rotary_ndims)
         )

         self.layers = nn.ModuleList([MoonshineEncoderLayer(config, idx) for idx in range(config.num_hidden_layers)])
@@ -1078,8 +1070,7 @@ def set_input_embeddings(self, value: nn.Module):
     @add_start_docstrings_to_model_forward(MOONSHINE_INPUTS_DOCSTRING)
     def forward(
         self,
-        input_features: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
+        input_values: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -1097,7 +1088,7 @@ def forward(
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if (input_features is None) ^ (inputs_embeds is not None):
+        if (input_values is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

         if self.gradient_checkpointing and self.training and use_cache:
@@ -1107,7 +1098,7 @@ def forward(
             use_cache = False

         if inputs_embeds is None:
-            inputs_embeds = self.preprocess(input_features)
+            inputs_embeds = self.preprocess(input_values)

         # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
@@ -1131,9 +1122,6 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
         hidden_states = inputs_embeds

         # create position embeddings to be shared across the decoder layers
@@ -1144,15 +1132,15 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None

-        for decoder_layer in self.layers:
+        for encoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    encoder_layer.__call__,
                     hidden_states,
-                    causal_mask,
+                    None,
                     position_ids,
                     past_key_values,
                     output_attentions,
@@ -1161,9 +1149,8 @@ def forward(
                     position_embeddings,
                 )
             else:
-                layer_outputs = decoder_layer(
+                layer_outputs = encoder_layer(
                     hidden_states,
-                    attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
@@ -1326,9 +1313,9 @@ def _freeze_parameters(self):
             param.requires_grad = False
         self._requires_grad = False

-    def preprocess(self, input_features: torch.FloatTensor):
-        input_features = input_features.unsqueeze(1)
-        inputs_embeds = nn.functional.tanh(self.conv1(input_features))
+    def preprocess(self, input_values: torch.FloatTensor):
+        input_values = input_values.unsqueeze(1)
+        inputs_embeds = nn.functional.tanh(self.conv1(input_values))
         inputs_embeds = self.groupnorm(inputs_embeds)
         inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
         inputs_embeds = nn.functional.gelu(self.conv3(inputs_embeds))
@@ -1359,8 +1346,7 @@ def __init__(self, config: MoonshineConfig):
         )
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
         self.rotary_emb = MoonshineRotaryEmbedding(
-            dim=max(config.hidden_size // config.num_attention_heads // 2, 32),
-            max_position_embeddings=config.max_position_embeddings,
+            dim=max(config.hidden_size // config.num_attention_heads // 2, config.min_rotary_ndims)
         )

         self.gradient_checkpointing = False
@@ -1839,7 +1825,7 @@ def _mask_input_features(
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_features: Optional[torch.FloatTensor] = None,
+        input_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.LongTensor] = None,
@@ -1863,18 +1849,18 @@ def forward(
         ```python
         >>> import torch
-        >>> from transformers import AutoFeatureExtractor, WhisperModel
+        >>> from transformers import AutoFeatureExtractor, MoonshineModel
         >>> from datasets import load_dataset

-        >>> model = WhisperModel.from_pretrained("openai/whisper-base")
-        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
+        >>> model = MoonshineModel.from_pretrained("UsefulSensors/moonshine-tiny")
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-tiny")
         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
-        >>> input_features = inputs.input_features
+        >>> input_values = inputs.input_values
         >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
-        >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
+        >>> last_hidden_state = model(input_values, decoder_input_ids=decoder_input_ids).last_hidden_state
         >>> list(last_hidden_state.shape)
-        [1, 2, 512]
+        [1, 2, 288]
         ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1884,10 +1870,10 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if encoder_outputs is None:
-            input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
+            input_values = self._mask_input_values(input_values, attention_mask=attention_mask)

             encoder_outputs = self.encoder(
-                input_features,
+                input_values,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
@@ -1981,7 +1967,7 @@ def encoder(self):
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_features: Optional[torch.FloatTensor] = None,
+        input_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.LongTensor] = None,
@@ -2000,7 +1986,7 @@ def forward(
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
             or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
-            only computed for the tokens with labels in `[0, ..., config.vocab_size]`. `sequence_length` should be smaller than or equal to `config.max_target_positions`.
+            only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        Returns:

@@ -2008,18 +1994,18 @@ def forward(
         ```python
         >>> import torch
-        >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
+        >>> from transformers import AutoProcessor, MoonshineForConditionalGeneration
         >>> from datasets import load_dataset

-        >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
-        >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+        >>> processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine")
+        >>> model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine")

         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

         >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
-        >>> input_features = inputs.input_features
+        >>> input_values = inputs.input_values

-        >>> generated_ids = model.generate(inputs=input_features)
+        >>> generated_ids = model.generate(input_values, max_new_tokens=100)

         >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
         >>> transcription
@@ -2028,17 +2014,13 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if labels is not None:
-            if labels.shape[1] > self.max_target_positions:
-                raise ValueError(
-                    f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens."
-                )
             if decoder_input_ids is None and decoder_inputs_embeds is None:
                 decoder_input_ids = shift_tokens_right(
                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
                 )

         outputs = self.model(
-            input_features,
+            input_values,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             encoder_outputs=encoder_outputs,
diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py
index 6d816249eca7f1..6aa9f96a36a3c8 100644
--- a/src/transformers/models/moonshine/modular_moonshine.py
+++ b/src/transformers/models/moonshine/modular_moonshine.py
@@ -1,48 +1,35 @@
-from ...configuration_utils import PretrainedConfig
-from ..phi.modeling_phi import PhiAttention, PhiFlashAttention2, PhiSdpaAttention, PhiMLP, PhiRotaryEmbedding
-from ..llama.modeling_llama import LlamaDecoderLayer, LlamaModel
-from ..mistral.modeling_mistral import MistralMLP
-from ..whisper.modeling_whisper import WhisperModel
+import copy
+import math
+from typing import List, Optional, Tuple, Union

+import torch
+import torch.nn as nn
 from torch.nn import CrossEntropyLoss

-from typing import List, Optional, Tuple, Union
-from ...processing_utils import Unpack
-
+from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...configuration_utils import PretrainedConfig
 from ...generation import GenerationMixin
-
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPast,
     BaseModelOutputWithPastAndCrossAttentions,
-    CausalLMOutputWithCrossAttentions,
     Seq2SeqLMOutput,
     Seq2SeqModelOutput,
-    SequenceClassifierOutput,
 )
-from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
-
 from ...modeling_utils import PreTrainedModel
-
+from ...processing_utils import Unpack
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
+from ..llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+from ..phi.modeling_phi import PhiAttention, PhiFlashAttention2, PhiMLP, PhiRotaryEmbedding, PhiSdpaAttention
+from ..whisper.modeling_whisper import WhisperModel

-from typing import Optional, Tuple
-
-from ...activations import ACT2FN
-
-import copy
-import math

 logger = logging.get_logger(__name__)
@@ -83,8 +70,6 @@ class MoonshineConfig(PretrainedConfig):
             The non-linear activation function (function or string) in the encoder.
         decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
-        max_position_embeddings (`int`, *optional*, defaults to 2048):
-            The maximum sequence length that this model might ever be used with. TODO: check this
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (`float`, *optional*, defaults to 1e-5):
@@ -97,10 +82,8 @@ class MoonshineConfig(PretrainedConfig):
             Whether or not the model should return the last key/values attentions (not used by all models).
         is_encoder_decoder (`bool`, *optional*, defaults to `True`):
             Whether the model is used as an encoder/decoder or not.
-        rope_theta (`float`, *optional*, defaults to 10000.0):
-            The base period of the RoPE embeddings. TODO: check this
-        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
-            Percentage of the query and keys which will have rotary embedding. TODO: check this
+        min_rotary_ndims (`int`, *optional*, defaults to 32):
+            The minimum number of dimensions of the RoPE.
         ff_mult (`int`, *optional*, defaults to 4):
             Factor by which to scale the intermediate size.
         attention_bias (`bool`, *optional*, defaults to `False`):
@@ -109,43 +92,6 @@ class MoonshineConfig(PretrainedConfig):
             The dropout ratio for the attention probabilities.
         qk_layernorm (`bool`, *optional*, defaults to `False`):
             Whether or not to normalize the Queries and Keys after projecting the hidden states.
-        rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
-            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
-            accordingly.
-            Expected contents:
-                `rope_type` (`str`):
-                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
-                    'llama3'], with 'default' being the original RoPE implementation.
-                `factor` (`float`, *optional*):
-                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
-                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
-                    original maximum pre-trained length.
-                `original_max_position_embeddings` (`int`, *optional*):
-                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
-                    pretraining.
-                `attention_factor` (`float`, *optional*):
-                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
-                    computation. If unspecified, it defaults to value recommended by the implementation, using the
-                    `factor` field to infer the suggested value.
-                `beta_fast` (`float`, *optional*):
-                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
-                    ramp function. If unspecified, it defaults to 32.
-                `beta_slow` (`float`, *optional*):
-                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
-                    ramp function. If unspecified, it defaults to 1.
-                `short_factor` (`List[float]`, *optional*):
-                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
-                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
-                    size divided by the number of attention heads divided by 2
-                `long_factor` (`List[float]`, *optional*):
-                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
-                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
-                    size divided by the number of attention heads divided by 2
-                `low_freq_factor` (`float`, *optional*):
-                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
-                `high_freq_factor` (`float`, *optional*):
-                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
         bos_token_id (`int`, *optional*, defaults to 1):
             Denotes beginning of sequences token id.
         eos_token_id (`int`, *optional*, defaults to 2):
@@ -208,18 +154,15 @@ def __init__(
         num_key_value_heads=None,
         encoder_hidden_act="gelu",
         decoder_hidden_act="silu",
-        max_position_embeddings=2048,
         initializer_range=0.02,
         layer_norm_eps=1e-5,
         decoder_start_token_id=1,
         use_cache=True,
         is_encoder_decoder=True,
-        rope_theta=10000.0,
-        partial_rotary_factor=0.5,
+        min_rotary_ndims=32,
         attention_bias=False,
         attention_dropout=0.0,
         qk_layernorm=False,
-        rope_scaling=None,
         ff_mult=4,
         bos_token_id=1,
         eos_token_id=2,
@@ -244,19 +187,15 @@ def __init__(
         self.num_key_value_heads = num_key_value_heads
         self.encoder_hidden_act = encoder_hidden_act
         self.decoder_hidden_act = decoder_hidden_act
-        self.max_position_embeddings = max_position_embeddings
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
         self.decoder_start_token_id = decoder_start_token_id
         self.use_cache = use_cache
         self.is_encoder_decoder = is_encoder_decoder
-        self.rope_theta = rope_theta
-        self.partial_rotary_factor = partial_rotary_factor
-
+        self.min_rotary_ndims = min_rotary_ndims
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         self.qk_layernorm = qk_layernorm
-        self.rope_scaling = rope_scaling
         self.ff_mult = ff_mult

         # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
@@ -356,7 +295,7 @@ def forward(self, x, position_ids):
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.stack((freqs, freqs), dim=-1)
-            emb = emb.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
+            emb = emb.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
             cos = emb.cos()
             sin = emb.sin()
@@ -394,11 +333,11 @@ def forward(self, hidden_state):
         hidden_state, gate = hidden_state.chunk(2, dim=-1)
         hidden_state = self.act_fn(gate) * hidden_state
         return self.down_proj(hidden_state)
-
+

 class MoonshineMLP:
     def __new__(cls, config: MoonshineConfig, hidden_act: str):
-        if hidden_act == "gelu":
+        if hidden_act == "gelu":
             return MoonshineNonGatedMLP(config, hidden_act)
         elif hidden_act == "silu":
             return MoonshineGatedMLP(config, hidden_act)
@@ -413,15 +352,12 @@ def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None, is_
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
-        self.rotary_ndims = max(config.hidden_size // config.num_attention_heads // 2, 32)
-        self.rotary_emb = MoonshineRotaryEmbedding(
-            dim=self.rotary_ndims,
-            max_position_embeddings=config.max_position_embeddings,
-        )
+        self.rotary_ndims = max(config.hidden_size // config.num_attention_heads // 2, config.min_rotary_ndims)
+        self.rotary_emb = MoonshineRotaryEmbedding(dim=self.rotary_ndims)

         self.is_causal = is_causal
-
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -437,7 +373,7 @@ def forward(
         bsz, q_len, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states)
-
+
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
@@ -463,7 +399,9 @@ def forward(
             key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
             value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
             if is_cross_attention and past_key_value is not None:
-                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, {"cache_position": cache_position})
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+                )

         if self.qk_layernorm:
             query_states = self.q_layernorm(query_states)
@@ -471,7 +409,7 @@ def forward(
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

-        if not is_cross_attention:
+        if not is_cross_attention:
             if position_embeddings is None:
                 logger.warning_once(
                     "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -492,7 +430,7 @@ def forward(
                 key_states[..., : self.rotary_ndims],
                 key_states[..., self.rotary_ndims :],
             )
-            # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+            # [batch_size, seq_length, num_heads, self.rotary_ndims]
             query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)

             # [batch_size, seq_length, num_heads, head_dim]
@@ -506,7 +444,9 @@ def forward(
                     "partial_rotation_size": self.rotary_ndims,
                     "cache_position": cache_position,
                 }
-                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx, cache_kwargs
+                )

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -559,7 +499,7 @@ def forward(
         bsz, q_len, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states)
-
+
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
@@ -585,7 +525,9 @@ def forward(
             key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
             value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
             if is_cross_attention and past_key_value is not None:
-                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, {"cache_position": cache_position})
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+                )

         if self.qk_layernorm:
             query_states = self.q_layernorm(query_states)
@@ -593,7 +535,7 @@ def forward(
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

-        if not is_cross_attention:
+        if not is_cross_attention:
             if position_embeddings is None:
                 logger.warning_once(
                     "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -614,7 +556,7 @@ def forward(
                 key_states[..., : self.rotary_ndims],
                 key_states[..., self.rotary_ndims :],
             )
-            # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+            # [batch_size, seq_length, num_heads, self.rotary_ndims]
             query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)

             # [batch_size, seq_length, num_heads, head_dim]
@@ -628,7 +570,9 @@ def forward(
                     "partial_rotation_size": self.rotary_ndims,
                     "cache_position": cache_position,
                 }
-                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx, cache_kwargs
+                )

         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
         # to be able to avoid many of these transpose/reshape/view.
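The recurring pattern in the attention hunks above (and in the SDPA hunks just below) is partial rotary embedding: each head is split into a rotary slice of `rotary_ndims` channels and a pass-through remainder, only the rotary slice is rotated, and the two are concatenated back. A minimal sketch of that pattern, with the caveat that the `rotate_half` helper and the tensor sizes here are simplified stand-ins for the real `apply_rotary_pos_emb`; only the `rotary_ndims = max(head_dim // 2, min_rotary_ndims)` rule is taken from this PR:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Simplified interleaved RoPE helper: pair up channels, swap and negate.
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_partial_rope(q, k, cos, sin, rotary_ndims):
    # Only the first `rotary_ndims` channels of each head receive rotary
    # position information; the remaining channels pass through unchanged.
    q_rot, q_pass = q[..., :rotary_ndims], q[..., rotary_ndims:]
    k_rot, k_pass = k[..., :rotary_ndims], k[..., rotary_ndims:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)

# Moonshine-tiny-like dimensions (hidden_size=288 is grounded in the diff's
# [1, 2, 288] doctest output; 8 heads is an assumption for illustration).
hidden_size, num_heads, min_rotary_ndims = 288, 8, 32
head_dim = hidden_size // num_heads                  # 36
rotary_ndims = max(head_dim // 2, min_rotary_ndims)  # max(18, 32) = 32

q = torch.randn(1, num_heads, 10, head_dim)
k = torch.randn(1, num_heads, 10, head_dim)
cos, sin = torch.ones(10, rotary_ndims), torch.zeros(10, rotary_ndims)
q2, k2 = apply_partial_rope(q, k, cos, sin, rotary_ndims)
assert q2.shape == q.shape and torch.allclose(q2, q)  # cos=1, sin=0: identity rotation
```

With these toy sizes the floor is what binds: `head_dim // 2` is 18, so `min_rotary_ndims=32` sets the rotated width, which is exactly why the PR replaces the hard-coded `32` with a config field.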
@@ -701,7 +645,7 @@ def forward(
         bsz, q_len, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states)
-
+
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
@@ -727,7 +671,9 @@ def forward(
             key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
             value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
             if is_cross_attention and past_key_value is not None:
-                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, {"cache_position": cache_position})
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+                )

         if self.qk_layernorm:
             query_states = self.q_layernorm(query_states)
@@ -735,7 +681,7 @@ def forward(
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

-        if not is_cross_attention:
+        if not is_cross_attention:
             if position_embeddings is None:
                 logger.warning_once(
                     "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -756,7 +702,7 @@ def forward(
                 key_states[..., : self.rotary_ndims],
                 key_states[..., self.rotary_ndims :],
             )
-            # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+            # [batch_size, seq_length, num_heads, self.rotary_ndims]
             query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)

             # [batch_size, seq_length, num_heads, head_dim]
@@ -770,7 +716,9 @@ def forward(
                     "partial_rotation_size": self.rotary_ndims,
                     "cache_position": cache_position,
                 }
-                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx, cache_kwargs
+                )

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -818,7 +766,7 @@ class MoonshineEncoderLayer(LlamaDecoderLayer):
     def __init__(self, config: MoonshineConfig, layer_idx: int):
         super().__init__(config, layer_idx)
-
+
         self.mlp = MoonshineMLP(config, config.encoder_hidden_act)
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
@@ -829,10 +777,14 @@ def __init__(self, config: MoonshineConfig, layer_idx: int = None):
         super().__init__()
         self.hidden_size = config.hidden_size

-        self.self_attn = MOONSHINE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx, is_causal=True)
-        self.encoder_attn = MOONSHINE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx, is_causal=False)
+        self.self_attn = MOONSHINE_ATTENTION_CLASSES[config._attn_implementation](
+            config=config, layer_idx=layer_idx, is_causal=True
+        )
+        self.encoder_attn = MOONSHINE_ATTENTION_CLASSES[config._attn_implementation](
+            config=config, layer_idx=layer_idx, is_causal=False
+        )

-        self.mlp = MoonshineMLP(config, config.decoder_hidden_act)
+        self.mlp = MoonshineMLP(config, config.decoder_hidden_act)
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
         self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
@@ -935,7 +887,7 @@ def forward(
             outputs += (present_key_value,)

         return outputs
-
+

 MOONSHINE_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -983,7 +935,7 @@ def _init_weights(self, module):

 class MoonshineEncoder(LlamaModel, MoonshinePreTrainedModel):
-    main_input_name = "input_features"
+    main_input_name = "input_values"

     def __init__(self, config: MoonshineConfig):
         MoonshinePreTrainedModel.__init__(self, config)
@@ -996,9 +948,8 @@ def __init__(self, config: MoonshineConfig):
         self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=1e-5)

         self.rotary_emb = MoonshineRotaryEmbedding(
-            dim=max(config.hidden_size // config.num_attention_heads // 2, 32),
-            max_position_embeddings=config.max_position_embeddings,
-        )
+            dim=max(config.hidden_size // config.num_attention_heads // 2, config.min_rotary_ndims)
+        )

         self.layers = nn.ModuleList([MoonshineEncoderLayer(config, idx) for idx in range(config.num_hidden_layers)])
         self.layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps, bias=False)
@@ -1016,19 +967,19 @@ def get_input_embeddings(self) -> nn.Module:

     def set_input_embeddings(self, value: nn.Module):
         self.conv1 = value
-
-    def preprocess(self, input_features: torch.FloatTensor):
-        input_features = input_features.unsqueeze(1)
-        inputs_embeds = nn.functional.tanh(self.conv1(input_features))
+
+    def preprocess(self, input_values: torch.FloatTensor):
+        input_values = input_values.unsqueeze(1)
+        inputs_embeds = nn.functional.tanh(self.conv1(input_values))
         inputs_embeds = self.groupnorm(inputs_embeds)
         inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
         inputs_embeds = nn.functional.gelu(self.conv3(inputs_embeds))
-        inputs_embeds = inputs_embeds.permute(0, 2, 1)
+        inputs_embeds = inputs_embeds.permute(0, 2, 1)
         return inputs_embeds
-
+
     def forward(
         self,
-        input_features: Optional[torch.FloatTensor] = None,
+        input_values: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -1046,7 +997,7 @@ def forward(
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if (input_features is None) ^ (inputs_embeds is not None):
+        if (input_values is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

         if self.gradient_checkpointing and self.training and use_cache:
@@ -1056,7 +1007,7 @@ def forward(
             use_cache = False

         if inputs_embeds is None:
-            inputs_embeds = self.preprocess(input_features)
+            inputs_embeds = self.preprocess(input_values)

         # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
@@ -1070,7 +1021,7 @@ def forward(
                 "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                 "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
            )
-            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)

         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -1226,8 +1177,7 @@ def __init__(self, config: MoonshineConfig):
         super().__init__(config)
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
         self.rotary_emb = MoonshineRotaryEmbedding(
-            dim= max(config.hidden_size // config.num_attention_heads // 2, 32),
-            max_position_embeddings=config.max_position_embeddings,
+            dim=max(config.hidden_size // config.num_attention_heads // 2, config.min_rotary_ndims)
         )

     def forward(
@@ -1363,7 +1313,11 @@ def forward(
             next_cache = next_cache.to_legacy_cache()

         if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None)
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
@@ -1371,7 +1325,7 @@ def forward(
             attentions=all_self_attns,
             cross_attentions=all_cross_attentions,
         )
-
+

 class MoonshineModel(WhisperModel):
     def __init__(self, config: MoonshineConfig):
@@ -1381,7 +1335,7 @@ def __init__(self, config: MoonshineConfig):

     def forward(
         self,
-        input_features: Optional[torch.FloatTensor] = None,
+        input_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.LongTensor] = None,
@@ -1401,18 +1355,18 @@ def forward(
         Example:
         ```python
         >>> import torch
-        >>> from transformers import AutoFeatureExtractor, WhisperModel
+        >>> from transformers import AutoFeatureExtractor, MoonshineModel
         >>> from datasets import load_dataset

-        >>> model = WhisperModel.from_pretrained("openai/whisper-base")
-        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
+        >>> model = MoonshineModel.from_pretrained("UsefulSensors/moonshine-tiny")
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-tiny")
         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
-        >>> input_features = inputs.input_features
+        >>> input_values = inputs.input_values
         >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
-        >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
+        >>> last_hidden_state = model(input_values, decoder_input_ids=decoder_input_ids).last_hidden_state
         >>> list(last_hidden_state.shape)
-        [1, 2, 512]
+        [1, 2, 288]
         ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1422,10 +1376,10 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if encoder_outputs is None:
-            input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
+            input_values = self._mask_input_values(input_values, attention_mask=attention_mask)

             encoder_outputs = self.encoder(
-                input_features,
+                input_values,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                return_dict=return_dict,
@@ -1493,16 +1447,16 @@ def set_output_embeddings(self, new_embeddings):

     def get_input_embeddings(self) -> nn.Module:
         return self.model.get_input_embeddings()
-
+
     @property
     def encoder(self):
         return self.get_encoder()
-
+
     @add_start_docstrings_to_model_forward(MOONSHINE_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_features: Optional[torch.FloatTensor] = None,
+        input_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.LongTensor] = None,
@@ -1521,7 +1475,7 @@ def forward(
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
             or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
-            only computed for the tokens with labels in `[0, ..., config.vocab_size]`. `sequence_length` should be smaller than or equal to `config.max_target_positions`.
+            only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

         Returns:

         Example:

         ```python
         >>> import torch
-        >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
+        >>> from transformers import AutoProcessor, MoonshineForConditionalGeneration
         >>> from datasets import load_dataset

-        >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
-        >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+        >>> processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine")
+        >>> model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine")

         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

         >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
-        >>> input_features = inputs.input_features
+        >>> input_values = inputs.input_values

-        >>> generated_ids = model.generate(inputs=input_features)
+        >>> generated_ids = model.generate(input_values, max_new_tokens=100)

         >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
         >>> transcription
@@ -1549,17 +1503,13 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if labels is not None:
-            if labels.shape[1] > self.max_target_positions:
-                raise ValueError(
-                    f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens."
-                )
             if decoder_input_ids is None and decoder_inputs_embeds is None:
                 decoder_input_ids = shift_tokens_right(
                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
                 )

         outputs = self.model(
-            input_features,
+            input_values,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             encoder_outputs=encoder_outputs,
@@ -1596,4 +1546,4 @@ def forward(
             encoder_last_hidden_state=outputs.encoder_last_hidden_state,
             encoder_hidden_states=outputs.encoder_hidden_states,
             encoder_attentions=outputs.encoder_attentions,
-        )
\ No newline at end of file
+        )
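One detail of modular_moonshine.py worth calling out is that `MoonshineMLP` is not an `nn.Module` itself: its `__new__` returns a gated MLP for `silu` (the decoder default) or a non-gated MLP for `gelu` (the encoder default), so one constructor call gives the encoder and decoder different feed-forward blocks. A stripped-down sketch of the dispatch; the module bodies here are simplified stand-ins for the real `MoonshineNonGatedMLP`/`MoonshineGatedMLP`, though the gated forward follows the `chunk` / `act_fn(gate) * hidden_state` lines visible in the diff:

```python
import torch
import torch.nn as nn

class NonGatedMLP(nn.Module):
    # Plain two-layer MLP, used when hidden_act == "gelu" (encoder).
    def __init__(self, hidden_size: int, ff_mult: int = 4):
        super().__init__()
        self.up = nn.Linear(hidden_size, hidden_size * ff_mult)
        self.down = nn.Linear(hidden_size * ff_mult, hidden_size)

    def forward(self, x):
        return self.down(nn.functional.gelu(self.up(x)))

class GatedMLP(nn.Module):
    # Gated (SwiGLU-style) MLP, used when hidden_act == "silu" (decoder):
    # the up-projection is split into a value half and a gate half.
    def __init__(self, hidden_size: int, ff_mult: int = 4):
        super().__init__()
        self.up = nn.Linear(hidden_size, hidden_size * ff_mult * 2)
        self.down = nn.Linear(hidden_size * ff_mult, hidden_size)

    def forward(self, x):
        x, gate = self.up(x).chunk(2, dim=-1)
        return self.down(nn.functional.silu(gate) * x)

class MLP:
    # Factory in the style of MoonshineMLP.__new__: returns an instance of a
    # different class depending on the activation name.
    def __new__(cls, hidden_size: int, hidden_act: str):
        if hidden_act == "gelu":
            return NonGatedMLP(hidden_size)
        elif hidden_act == "silu":
            return GatedMLP(hidden_size)
        raise ValueError(f"Unsupported activation: {hidden_act}")

x = torch.randn(1, 2, 288)
assert isinstance(MLP(288, "gelu"), NonGatedMLP)
assert MLP(288, "silu")(x).shape == x.shape
```

Because `__new__` returns an instance of another class, `MoonshineMLP` acts as a pure factory: callers write a single constructor call and the activation string decides the architecture, which is how `MoonshineEncoderLayer` and `MoonshineDecoderLayer` in the diff pick their respective feed-forward blocks via `config.encoder_hidden_act` and `config.decoder_hidden_act`.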