add MoonshineForConditionalGeneration
eustlb authored and Eustache Le Bihan committed Dec 15, 2024
1 parent 7a6935a commit 8fda426
Showing 1 changed file with 110 additions and 64 deletions.
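As a quick orientation before the diff, here is a minimal usage sketch of the generation head this commit introduces. The checkpoint id `UsefulSensors/moonshine-tiny`, the `AutoProcessor` call, and the 16 kHz sampling rate are assumptions for illustration, not taken from this commit.

```python
# Hedged sketch: how MoonshineForConditionalGeneration is expected to be used once
# this commit lands. Checkpoint id and sampling rate are assumptions, not from the diff.
import numpy as np
from transformers import AutoProcessor, MoonshineForConditionalGeneration

processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")

# One second of silence at an assumed 16 kHz stands in for a real recording.
waveform = np.zeros(16000, dtype=np.float32)
inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")

# Standard encoder-decoder generation: the encoder consumes the raw audio,
# the decoder autoregressively emits text token ids.
generated_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```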
174 changes: 110 additions & 64 deletions src/transformers/models/moonshine/modular_moonshine.py
@@ -991,17 +991,11 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()


class MoonshineEncoder(MoonshinePreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
[`MoonshineEncoderLayer`].
Args:
config: MoonshineConfig
"""
class MoonshineEncoder(LlamaModel, MoonshinePreTrainedModel):
main_input_name = "input_features"

def __init__(self, config: MoonshineConfig):
super().__init__(config)
MoonshinePreTrainedModel.__init__(self, config)
self.config = config
embed_dim = config.hidden_size

@@ -1031,85 +1025,138 @@ def get_input_embeddings(self) -> nn.Module:

def set_input_embeddings(self, value: nn.Module):
self.conv1 = value


def preprocess(self, input_features: torch.FloatTensor):
input_features = input_features.unsqueeze(1)
inputs_embeds = nn.functional.tanh(self.conv1(input_features))
inputs_embeds = self.groupnorm(inputs_embeds)
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
inputs_embeds = nn.functional.gelu(self.conv3(inputs_embeds))
inputs_embeds = inputs_embeds.permute(0, 2, 1)
return inputs_embeds

def forward(
self,
input_features,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
r"""
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, 1, sequence_length)`):
Float values of the raw speech waveform. The raw waveform can be obtained by loading a `.flac` or
`.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile
library (`pip install soundfile`).
attention_mask (`torch.Tensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or
`(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

inputs_embeds = nn.functional.tanh(self.conv1(input_features))
inputs_embeds = self.groupnorm(inputs_embeds)
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
inputs_embeds = nn.functional.gelu(self.conv3(inputs_embeds))
inputs_embeds = inputs_embeds.permute(0, 2, 1)
if (input_features is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_features or inputs_embeds")

if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False

if inputs_embeds is None:
inputs_embeds = self.preprocess(input_features)

# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache or past_key_values is not None:
if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif not isinstance(past_key_values, EncoderDecoderCache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds

position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
embed_pos = self.rotary_emb(hidden_states, position_ids)

encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for encoder_layer in self.layers:
for decoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
decoder_layer.__call__,
hidden_states,
attention_mask,
causal_mask,
position_ids,
past_key_values,
output_attentions,
position_embeddings=embed_pos,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = encoder_layer(
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
position_embeddings=embed_pos,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]

if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
all_self_attns += (layer_outputs[1],)

hidden_states = self.layer_norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
all_hidden_states += (hidden_states,)

next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)


@@ -1461,6 +1508,10 @@ def set_output_embeddings(self, new_embeddings):
def get_input_embeddings(self) -> nn.Module:
return self.model.get_input_embeddings()

@property
def encoder(self):
return self.get_encoder()

@add_start_docstrings_to_model_forward(MOONSHINE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -1559,9 +1610,4 @@ def forward(
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)





)
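
For readers skimming the encoder hunk: the new `preprocess` method factors the conv frontend out of `forward`, turning raw audio of shape `(batch, num_samples)` into `(batch, frames, hidden_size)` embeddings before the Llama-style layer stack runs. The standalone sketch below mirrors that structure; the kernel sizes, strides, and `hidden_size` are assumed values for illustration, not read from this diff's config.

```python
# Standalone sketch of the conv frontend that MoonshineEncoder.preprocess applies.
# Only the tanh / groupnorm / gelu / permute structure comes from the diff; the
# kernel sizes, strides, and hidden_size below are assumptions for illustration.
import torch
import torch.nn as nn

hidden_size = 288  # assumed embedding dimension

conv1 = nn.Conv1d(1, hidden_size, kernel_size=127, stride=64, bias=False)
groupnorm = nn.GroupNorm(num_groups=1, num_channels=hidden_size)
conv2 = nn.Conv1d(hidden_size, 2 * hidden_size, kernel_size=7, stride=3)
conv3 = nn.Conv1d(2 * hidden_size, hidden_size, kernel_size=3, stride=2)

def preprocess(input_features: torch.Tensor) -> torch.Tensor:
    x = input_features.unsqueeze(1)          # (batch, num_samples) -> (batch, 1, num_samples)
    x = torch.tanh(conv1(x))
    x = groupnorm(x)
    x = nn.functional.gelu(conv2(x))
    x = nn.functional.gelu(conv3(x))
    return x.permute(0, 2, 1)                # (batch, channels, frames) -> (batch, frames, channels)

audio = torch.randn(2, 16000)                # two clips, one second at an assumed 16 kHz
print(preprocess(audio).shape)               # torch.Size([2, 40, 288]) with these assumed values
```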
