Skip to content

Commit

Permalink
Add Moonshine Preprocessor
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702332422
  • Loading branch information
talumbau authored and copybara-github committed Dec 3, 2024
1 parent d4e358e commit 10e9c59
Show file tree
Hide file tree
Showing 6 changed files with 321 additions and 0 deletions.
14 changes: 14 additions & 0 deletions ai_edge_torch/generative/examples/moonshine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of converting a Moonshine model to multi-signature tflite model."""

import os
import pathlib

from absl import app
from absl import flags
import ai_edge_torch
from ai_edge_torch.generative.examples.moonshine import moonshine
from ai_edge_torch.generative.utilities import converter
import torch

# Command-line flags for the conversion script.
#
# --checkpoint_path: where the Moonshine checkpoint lives (a file or a
#   directory); defaults to a folder under the current user's home directory.
_CHECKPOINT_PATH = flags.DEFINE_string(
    name='checkpoint_path',
    default=os.path.join(pathlib.Path.home(), 'Downloads/llm_data/moonshine'),
    help='The path to the model checkpoint, or directory holding the checkpoint.',
)
# --tflite_path: export location. NOTE: main() joins a fixed file name onto
#   this value, so it is used as a directory.
_TFLITE_PATH = flags.DEFINE_string(
    name='tflite_path',
    default='/tmp/',
    help='The tflite file path to export.',
)


def main(_):
  """Converts the Moonshine audio preprocessor to TFLite and exports it.

  Builds the preprocessor from the checkpoint given by --checkpoint_path,
  converts it with a random sample input, and writes
  `moonshine_preprocessor.tflite` under --tflite_path.

  Args:
    _: Unused positional argument supplied by `app.run`.
  """
  preprocessor = moonshine.build_preprocessor(_CHECKPOINT_PATH.value)
  export_name = 'moonshine_preprocessor.tflite'
  # Random sample input of shape (1, 1, 159414) used to drive the conversion.
  sample_audio = torch.randn((1, 1, 159414), dtype=torch.float)
  converted = ai_edge_torch.convert(
      preprocessor, (sample_audio,), quant_config=None
  )
  converted.export(os.path.join(_TFLITE_PATH.value, export_name))


if __name__ == '__main__':
  app.run(main)
Binary file not shown.
Binary file not shown.
103 changes: 103 additions & 0 deletions ai_edge_torch/generative/examples/moonshine/moonshine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of building the Moonshine model."""

import os
import pathlib
from typing import Optional, Tuple
from absl import app
from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import builder
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.layers.normalization as normalization
import ai_edge_torch.generative.utilities.moonshine_loader as loading_utils
import h5py
import torch
from torch import nn
import torch.nn as nn

# Checkpoint name prefixes for the preprocessor's layers. Each value is the
# HDF5 group path of one layer's variables; the loader
# (moonshine_loader.ModelLoader) appends "/0" for the weight and "/1" for the
# optional bias when looking tensors up.
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    conv1D_0="layers/sequential/layers/conv1d/vars",
    conv1D_1="layers/sequential/layers/conv1d_1/vars",
    conv1D_2="layers/sequential/layers/conv1d_2/vars",
    group_norm="layers/sequential/layers/group_normalization/vars",
)


class AudioPreprocessor(nn.Module):
  """Convolutional front end of the Moonshine example.

  The forward pass applies, in order: a bias-free strided Conv1d, tanh, a
  single-group group normalization, a second Conv1d followed by GELU, and a
  third Conv1d followed by GELU.
  """

  def __init__(self, dim):
    """Creates the preprocessor layers.

    Args:
      dim: Number of output channels of the first and last convolutions. The
        middle convolution produces `2 * dim` channels.
    """
    super().__init__()
    wide = 2 * dim

    # Stage 1: single input channel -> dim channels, no bias term.
    self.conv1 = nn.Conv1d(1, dim, kernel_size=127, stride=64, bias=False)
    self.tanh = nn.Tanh()
    self.group_norm = normalization.GroupNorm(group_num=1, dim=dim, eps=1e-5)

    # Stage 2: dim -> 2 * dim channels. padding=0 is equivalent to
    # padding="valid".
    self.conv2 = nn.Conv1d(dim, wide, kernel_size=7, stride=3, padding=0)
    self.gelu1 = nn.GELU()

    # Stage 3: 2 * dim -> dim channels. padding=0 is equivalent to
    # padding="valid".
    self.conv3 = nn.Conv1d(wide, dim, kernel_size=3, stride=2, padding=0)
    self.gelu2 = nn.GELU()

  def forward(self, inputs):
    """Runs the preprocessing stack on `inputs`.

    Args:
      inputs: Tensor fed to the first Conv1d, which has one input channel
        (the example scripts pass shape (1, 1, 159414)).

    Returns:
      The output of the final GELU.
    """
    normalized = self.group_norm(self.tanh(self.conv1(inputs)))
    widened = self.gelu1(self.conv2(normalized))
    return self.gelu2(self.conv3(widened))


def build_preprocessor(checkpoint_path: str, **kwargs) -> nn.Module:
  """Builds an `AudioPreprocessor` and loads its weights from a checkpoint.

  Args:
    checkpoint_path: Path handed to the checkpoint loader (a file or a
      directory).
    **kwargs: Accepted but currently unused.

  Returns:
    The `AudioPreprocessor` (dim=416) with checkpoint weights loaded in strict
    mode.
  """
  preprocessor = AudioPreprocessor(dim=416)
  loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES).load(
      preprocessor, strict=True
  )
  return preprocessor


def main(_):
  """Checks the preprocessor output against stored golden data.

  Loads the checkpoint from `~/Downloads/llm_data/moonshine`, runs the
  preprocessor on `data/pp_input.pt` (located next to this file), and compares
  the result with `data/pp_output.pt`.

  Args:
    _: Unused positional argument supplied by `app.run`.

  Raises:
    AssertionError: If the output does not match the golden tensor within
      atol=1e-4 / rtol=1e-4.
  """
  # TODO(b/375421767) Remove golden checks once full model is implemented.
  checkpoint_dir = os.path.join(
      pathlib.Path.home(), "Downloads/llm_data/moonshine"
  )

  data_dir = pathlib.Path(__file__).parent.resolve() / "data"
  input_path = data_dir / "pp_input.pt"
  golden_path = data_dir / "pp_output.pt"

  ap = build_preprocessor(checkpoint_dir)
  ap.eval()
  inputs = torch.load(input_path).reshape((1, 1, 159414))
  out = ap(inputs)
  # The golden tensor has its last two axes swapped relative to the module
  # output, so transpose it before comparing.
  golden = torch.load(golden_path).transpose(1, 2)
  # Raise explicitly instead of using `assert`, which is stripped under -O.
  if not torch.allclose(out, golden, atol=1e-4, rtol=1e-4):
    raise AssertionError(
        "AudioPreprocessor output does not match the golden output."
    )


if __name__ == "__main__":
  app.run(main)
154 changes: 154 additions & 0 deletions ai_edge_torch/generative/utilities/moonshine_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Common utility functions for data loading etc.
from dataclasses import dataclass
import glob
import os
from typing import Callable, Dict

import h5py
import torch


def transpose_if_needed(t):
  """Reverses the axes of a rank-3 tensor; returns lower-rank tensors as is.

  We assume the file is from Keras, i.e. channel last format.

  Args:
    t: Tensor loaded from the checkpoint.

  Returns:
    `t` itself when it has at most two dimensions, otherwise
    `t.permute(2, 1, 0)`. The permutation names exactly three axes, so
    tensors with more than three dimensions are not supported.
  """
  rank = len(t.shape)
  if rank <= 2:
    return t
  return t.permute(2, 1, 0)


def load_h5_statedict(full_path: str):
  """Loads the HDF5 datasets into a single dictionary.

  Every dataset found in the file(s) is converted to a torch tensor, passed
  through `transpose_if_needed`, and stored under its full HDF5 path name.

  Args:
    full_path (string): the HDF5 filename or directory that contains the HDF5
      files. For a directory, every `*.h5` file directly inside it is read;
      otherwise the value is used as a glob pattern.

  Returns:
    A state dictionary containing the loaded tensors, keyed by dataset name.

  Raises:
    ValueError: If no tensors are loaded from the provided directory or file.
  """
  pattern = (
      os.path.join(full_path, "*.h5") if os.path.isdir(full_path) else full_path
  )

  tensors = {}

  def collect_datasets(name, obj):
    # visititems() visits groups as well as datasets; only datasets hold data.
    if isinstance(obj, h5py.Dataset):
      tensors[name] = transpose_if_needed(torch.from_numpy(obj[:]))

  # Sort the matches so that, when several files define the same dataset
  # name, which one wins does not depend on filesystem listing order.
  for file in sorted(glob.glob(pattern)):
    # Open read-only explicitly; loading must never create or modify files.
    with h5py.File(file, "r") as f:
      f.visititems(collect_datasets)

  if not tensors:
    raise ValueError("Failed to load HDF5 file.")
  return tensors


class ModelLoader:
  """Utility class for loading and converting checkpoints to ODML transformer layer format."""

  @dataclass
  class TensorNames:
    # Checkpoint name prefix of each layer's variables. The loader looks up
    # "<prefix>/0" (weight) and, if present, "<prefix>/1" (bias). A value of
    # None means the layer is not loaded.
    conv1D_0: str = None
    conv1D_1: str = None
    conv1D_2: str = None
    group_norm: str = None

  def __init__(self, file_name: str, names: TensorNames) -> None:
    """ModelLoader constructor.

    Can be used to load multiple models of the same type.

    Args:
      file_name (str): Path to the checkpoint. Can be a directory or an exact
        file.
      names (TensorNames): An instance of `TensorNames` to determine mappings.
    """
    self._file_name = file_name
    self._names = names
    self._loader = load_h5_statedict

  def load(
      self,
      model: torch.nn.Module,
      strict: bool = True,
  ):
    """Loads the checkpoint tensors into `model`.

    Args:
      model (torch.nn.Module): The pytorch model that needs to be loaded.
      strict (bool, optional): Whether the converted keys are strictly
        matched. Defaults to True.

    Raises:
      ValueError: If `names` is not a `TensorNames` instance, or if
        conversion results in unmapped tensors and strict mode is enabled.
    """
    state = self._loader(self._file_name)

    if not isinstance(self._names, ModelLoader.TensorNames):
      raise ValueError(f"Unknown type for names: {type(self._names)}")
    converted_state = self._do_load(model, state, self._names)

    # _do_load pops every tensor it maps, so anything still left in `state`
    # is a checkpoint tensor with no destination in the model.
    if strict and state:
      raise ValueError(
          "Failed to map all tensor. Remaining tensor are:"
          f" {list(state.keys())}"
      )
    model.load_state_dict(converted_state, strict=strict)

  def _do_load(self, model, state, names, additional_prefix=""):
    """Maps checkpoint tensors to the model's state-dict keys.

    Mapped tensors are removed from `state` (it is mutated in place) so the
    caller can detect leftovers.

    Args:
      model (torch.nn.Module): The pytorch model that needs to be loaded.
        Not used by the mapping itself.
      state (Dict[str, torch.Tensor]): The pytorch state dictionary
      names (TensorNames): The TensorNames for the model we are loading.
      additional_prefix (str): Currently unused; kept for interface
        compatibility.

    Returns:
      Dict[str, torch.Tensor]: Map of name to tensor for loading.

    Raises:
      KeyError: If a configured layer has no "<prefix>/0" weight in `state`.
    """
    # (model attribute name, checkpoint prefix), in state-dict order.
    layer_sources = (
        ("conv1", names.conv1D_0),
        ("conv2", names.conv1D_1),
        ("conv3", names.conv1D_2),
        ("group_norm", names.group_norm),
    )

    converted_state = {}
    for target, source in layer_sources:
      if source is None:
        continue
      # The weight is mandatory; the bias is only copied when the checkpoint
      # provides one (e.g. conv1 is built with bias=False).
      converted_state[f"{target}.weight"] = state.pop(f"{source}/0")
      bias_key = f"{source}/1"
      if bias_key in state:
        converted_state[f"{target}.bias"] = state.pop(bias_key)

    return converted_state

0 comments on commit 10e9c59

Please sign in to comment.