Skip to content

Commit

Permalink
add rope_theta
Browse files Browse the repository at this point in the history
  • Loading branch information
eustlb committed Dec 16, 2024
1 parent 3d52b1e commit 8f82a40
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 23 deletions.
4 changes: 4 additions & 0 deletions src/transformers/models/moonshine/configuration_moonshine.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class MoonshineConfig(PretrainedConfig):
Whether the model is used as an encoder/decoder or not.
min_rotary_ndims (`int`, *optional*, defaults to 32):
The minimum number of dimensions of the RoPE.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
ff_mult (`int`, *optional*, defaults to 4):
Factor by which to scale the intermediate size.
attention_bias (`bool`, *optional*, defaults to `False`):
Expand Down Expand Up @@ -129,6 +131,7 @@ def __init__(
layer_norm_eps=1e-5,
decoder_start_token_id=1,
use_cache=True,
rope_theta=10000.0,
is_encoder_decoder=True,
min_rotary_ndims=32,
attention_bias=False,
Expand Down Expand Up @@ -162,6 +165,7 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.decoder_start_token_id = decoder_start_token_id
self.use_cache = use_cache
self.rope_theta = rope_theta
self.is_encoder_decoder = is_encoder_decoder
self.min_rotary_ndims = min_rotary_ndims
self.attention_bias = attention_bias
Expand Down
45 changes: 22 additions & 23 deletions src/transformers/models/moonshine/convert_usefulsensors_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,22 @@
# limitations under the License.

import argparse
from huggingface_hub import hf_hub_download
import re

import h5py
import torch
import numpy as np
import re
import torch
from huggingface_hub import hf_hub_download

from transformers.models.moonshine.modeling_moonshine import MoonshineConfig
from transformers.models.moonshine.modeling_moonshine import MoonshineDecoder
from transformers.models.moonshine.modeling_moonshine import MoonshineConfig, MoonshineForConditionalGeneration


# Copied from https://github.com/usefulsensors/moonshine/blob/a1d77cc573b0471ac4602b86f67b3f48d67df1a9/moonshine/model.py
def _get_weights(model_name):
repo = "UsefulSensors/moonshine"

return (
hf_hub_download(repo, f"{x}.weights.h5", subfolder=model_name)
for x in ("preprocessor", "encoder", "decoder")
hf_hub_download(repo, f"{x}.weights.h5", subfolder=model_name) for x in ("preprocessor", "encoder", "decoder")
)


Expand All @@ -58,7 +56,12 @@ def _read_h5_weights(group, current_key="", weights={}):


def _convert_layer_names(name, gated_mlp=False):
name = re.sub(r'layers\.functional(?:_(\d+))?\.layers', lambda m: f'layers.{m.group(1) if m.group(1) else "0"}', name, count=1)
name = re.sub(
r"layers\.functional(?:_(\d+))?\.layers",
lambda m: f'layers.{m.group(1) if m.group(1) else "0"}',
name,
count=1,
)
if gated_mlp:
name = re.sub(r"functional\.layers\.dense\.", "mlp.up_proj.", name)
name = re.sub(r"functional\.layers\.dense_1\.", "mlp.down_proj.", name)
Expand Down Expand Up @@ -113,32 +116,32 @@ def _convert_weights(weights, encoder=True):

def convert_usefulsensors_moonshine_to_hf(model_name, pytorch_dump_folder_path):
preprocessor_weights_path, encoder_weights_path, decoder_weights_path = _get_weights(model_name)
with h5py.File(preprocessor_weights_path, 'r') as f:

with h5py.File(preprocessor_weights_path, "r") as f:
loaded_preprocessor_weights = _read_h5_weights(f, weights={})

with h5py.File(encoder_weights_path, 'r') as f:
with h5py.File(encoder_weights_path, "r") as f:
loaded_encoder_weights = _read_h5_weights(f, weights={})

with h5py.File(decoder_weights_path, 'r') as f:
with h5py.File(decoder_weights_path, "r") as f:
loaded_decoder_weights = _read_h5_weights(f, weights={})

encoder_state_dict = {**loaded_encoder_weights, **loaded_preprocessor_weights}
encoder_state_dict = _convert_weights(encoder_state_dict)

converted_decoder_weights = _convert_weights(loaded_decoder_weights, encoder=False)
converted_decoder_weights['embed_tokens.weight'] = converted_decoder_weights['embed_tokens.weight'].T
converted_decoder_weights["embed_tokens.weight"] = converted_decoder_weights["embed_tokens.weight"].T

final_weights = {}
for k, v in encoder_state_dict.items():
final_weights[f"model.encoder.{k}"] = v

for k, v in converted_decoder_weights.items():
final_weights[f"model.decoder.{k}"] = v

if model_name == 'tiny':
if model_name == "tiny":
config = MoonshineConfig()
elif model_name == 'base':
elif model_name == "base":
config = MoonshineConfig(
hidden_size=416,
num_hidden_layers=8,
Expand All @@ -147,22 +150,18 @@ def convert_usefulsensors_moonshine_to_hf(model_name, pytorch_dump_folder_path):
else:
raise ValueError(f"Unknown model name {model_name}")

final_weights['proj_out.weight'] = converted_decoder_weights['embed_tokens.weight']
final_weights["proj_out.weight"] = converted_decoder_weights["embed_tokens.weight"]

model = MoonshineForConditionalGeneration(config)
model.load_state_dict(final_weights)
model.save_pretrained(pytorch_dump_folder_path)



if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--model_name", type=str, help="Path to the downloaded checkpoints")
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()

convert_usefulsensors_moonshine_to_hf(
args.model_name, args.pytorch_dump_folder_path
)

convert_usefulsensors_moonshine_to_hf(args.model_name, args.pytorch_dump_folder_path)
4 changes: 4 additions & 0 deletions src/transformers/models/moonshine/modular_moonshine.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class MoonshineConfig(PretrainedConfig):
Whether the model is used as an encoder/decoder or not.
min_rotary_ndims (`int`, *optional*, defaults to 32):
The minimum number of dimensions of the RoPE.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
ff_mult (`int`, *optional*, defaults to 4):
Factor by which to scale the intermediate size.
attention_bias (`bool`, *optional*, defaults to `False`):
Expand Down Expand Up @@ -158,6 +160,7 @@ def __init__(
layer_norm_eps=1e-5,
decoder_start_token_id=1,
use_cache=True,
rope_theta=10000.0,
is_encoder_decoder=True,
min_rotary_ndims=32,
attention_bias=False,
Expand Down Expand Up @@ -191,6 +194,7 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.decoder_start_token_id = decoder_start_token_id
self.use_cache = use_cache
self.rope_theta = rope_theta
self.is_encoder_decoder = is_encoder_decoder
self.min_rotary_ndims = min_rotary_ndims
self.attention_bias = attention_bias
Expand Down

0 comments on commit 8f82a40

Please sign in to comment.