-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
6 changed files
with
321 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright 2024 The AI Edge Torch Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== |
50 changes: 50 additions & 0 deletions
50
ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright 2024 The AI Edge Torch Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""Example of converting a Moonshine model to multi-signature tflite model.""" | ||
|
||
import os | ||
import pathlib | ||
|
||
from absl import app | ||
from absl import flags | ||
import ai_edge_torch | ||
from ai_edge_torch.generative.examples.moonshine import moonshine | ||
from ai_edge_torch.generative.utilities import converter | ||
import torch | ||
|
||
_CHECKPOINT_PATH = flags.DEFINE_string( | ||
'checkpoint_path', | ||
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/moonshine'), | ||
'The path to the model checkpoint, or directory holding the checkpoint.', | ||
) | ||
_TFLITE_PATH = flags.DEFINE_string( | ||
'tflite_path', | ||
'/tmp/', | ||
'The tflite file path to export.', | ||
) | ||
|
||
|
||
def main(_): | ||
p_model = moonshine.build_preprocessor(_CHECKPOINT_PATH.value) | ||
output_filename = f'moonshine_preprocessor.tflite' | ||
_input = torch.randn((1, 1, 159414), dtype=torch.float) | ||
edge_model = ai_edge_torch.convert(p_model, (_input,), quant_config=None) | ||
tflite_path = os.path.join(_TFLITE_PATH.value, output_filename) | ||
edge_model.export(tflite_path) | ||
|
||
|
||
if __name__ == '__main__': | ||
app.run(main) |
Binary file not shown.
Binary file not shown.
103 changes: 103 additions & 0 deletions
103
ai_edge_torch/generative/examples/moonshine/moonshine.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Copyright 2024 The AI Edge Torch Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""Example of building the Moonshine model.""" | ||
|
||
import os | ||
import pathlib | ||
from typing import Optional, Tuple | ||
from absl import app | ||
from ai_edge_torch.generative.layers import attention | ||
from ai_edge_torch.generative.layers import builder | ||
from ai_edge_torch.generative.layers import kv_cache as kv_utils | ||
import ai_edge_torch.generative.layers.attention_utils as attn_utils | ||
import ai_edge_torch.generative.layers.model_config as cfg | ||
import ai_edge_torch.generative.layers.normalization as normalization | ||
import ai_edge_torch.generative.utilities.moonshine_loader as loading_utils | ||
import h5py | ||
import torch | ||
from torch import nn | ||
import torch.nn as nn | ||
|
||
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( | ||
conv1D_0="layers/sequential/layers/conv1d/vars", | ||
conv1D_1="layers/sequential/layers/conv1d_1/vars", | ||
conv1D_2="layers/sequential/layers/conv1d_2/vars", | ||
group_norm="layers/sequential/layers/group_normalization/vars", | ||
) | ||
|
||
|
||
class AudioPreprocessor(nn.Module): | ||
|
||
def __init__(self, dim): | ||
super(AudioPreprocessor, self).__init__() | ||
self.conv1 = nn.Conv1d( | ||
in_channels=1, out_channels=dim, kernel_size=127, stride=64, bias=False | ||
) | ||
self.tanh = nn.Tanh() | ||
self.group_norm = normalization.GroupNorm(group_num=1, dim=dim, eps=1e-5) | ||
self.conv2 = nn.Conv1d( | ||
in_channels=dim, | ||
out_channels=2 * dim, | ||
kernel_size=7, | ||
stride=3, | ||
padding=0, # Equivalent to padding="valid" | ||
) | ||
self.gelu1 = nn.GELU() | ||
self.conv3 = nn.Conv1d( | ||
in_channels=2 * dim, | ||
out_channels=dim, | ||
kernel_size=3, | ||
stride=2, | ||
padding=0, # Equivalent to padding="valid" | ||
) | ||
self.gelu2 = nn.GELU() | ||
|
||
def forward(self, inputs): | ||
x = self.conv1(inputs) | ||
x = self.tanh(x) | ||
x = self.group_norm(x) | ||
x = self.conv2(x) | ||
x = self.gelu1(x) | ||
x = self.conv3(x) | ||
x = self.gelu2(x) | ||
return x | ||
|
||
|
||
def build_preprocessor(checkpoint_path: str, **kwargs) -> nn.Module: | ||
ap = AudioPreprocessor(dim=416) | ||
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) | ||
loader.load(ap, strict=True) | ||
return ap | ||
|
||
|
||
def main(_): | ||
# TODO(b/375421767) Remove golden checks once full model is implemented. | ||
HF_PATH = os.path.join(pathlib.Path.home(), "Downloads/llm_data/moonshine") | ||
|
||
test_data_path = pathlib.Path(__file__).parent.resolve() | ||
INPUT_PATH = test_data_path / "data" / "pp_input.pt") | ||
GOLDEN_PATH = test_data_path / "data" / "pp_output.pt") | ||
|
||
ap = build_preprocessor(HF_PATH) | ||
ap.eval() | ||
inputs = torch.load(INPUT_PATH).reshape((1, 1, 159414)) | ||
out = ap(inputs) | ||
golden = torch.load(GOLDEN_PATH).transpose(1, 2) | ||
assert torch.allclose(out, golden, atol=1e-4, rtol=1e-4) | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# Copyright 2024 The AI Edge Torch Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
# Common utility functions for data loading etc. | ||
from dataclasses import dataclass | ||
import glob | ||
import os | ||
from typing import Callable, Dict | ||
|
||
import h5py | ||
import torch | ||
|
||
|
||
def transpose_if_needed(t): | ||
"""We assume the file is from Keras, i.e. channel last format.""" | ||
if len(t.shape) > 2: | ||
return t.permute(2, 1, 0) | ||
return t | ||
|
||
|
||
def load_h5_statedict(full_path: str): | ||
"""Loads the HDF5 DataSets into a single dctionary. | ||
Args: | ||
full_path (string): the HDF5 filename or directory that contains the HDF5 | ||
files. | ||
Returns: | ||
A state dictionary contating loaded tensors. | ||
Raises: | ||
ValueError: If no tensors are loaded from the provided directory or file. | ||
""" | ||
pattern = ( | ||
os.path.join(full_path, "*.h5") if os.path.isdir(full_path) else full_path | ||
) | ||
files = [] | ||
for file in glob.glob(pattern): | ||
files.append(file) | ||
|
||
tensors = {} | ||
|
||
def collect_datasets(name, obj): | ||
if isinstance(obj, h5py.Dataset): | ||
tensors[name] = transpose_if_needed(torch.from_numpy(obj[:])) | ||
|
||
for file in files: | ||
with h5py.File(file) as f: | ||
f.visititems(collect_datasets) | ||
|
||
if not tensors: | ||
raise ValueError("Failed to load HDF5 file.") | ||
return tensors | ||
|
||
|
||
class ModelLoader: | ||
"""Utility class for loading and converting checkpoints to ODML transformer layer format.""" | ||
|
||
@dataclass | ||
class TensorNames: | ||
conv1D_0: str = None | ||
conv1D_1: str = None | ||
conv1D_2: str = None | ||
group_norm: str = None | ||
|
||
def __init__(self, file_name: str, names: TensorNames) -> None: | ||
"""ModelLoader constructor. | ||
Can be used to load multiple models of the same type. | ||
Args: | ||
file_name (str): Path to the checkpoint. Can be a directory or an exact | ||
file. | ||
names (TensorNames): An instance of `TensorNames` to determine mappings. | ||
""" | ||
self._file_name = file_name | ||
self._names = names | ||
self._loader = load_h5_statedict | ||
|
||
def load( | ||
self, | ||
model: torch.nn.Module, | ||
strict: bool = True, | ||
): | ||
"""Load the model from the checkpoint | ||
Args: | ||
model (torch.nn.Module): The pytorch model that needs to be loaded. | ||
strict (bool, optional): Whether the converted keys are strictly | ||
matched. Defaults to True. | ||
Raises: | ||
ValueError: If conversion results in unmapped tensors and strict mode is | ||
enabled. | ||
""" | ||
state = self._loader(self._file_name) | ||
|
||
if isinstance(self._names, ModelLoader.TensorNames): | ||
converted_state = self._do_load(model, state, self._names) | ||
else: | ||
raise ValueError(f"Unkown type for names: {type(self._names)}") | ||
|
||
if strict and state: | ||
raise ValueError( | ||
"Failed to map all tensor. Remaining tensor are:" | ||
f" {list(state.keys())}" | ||
) | ||
model.load_state_dict(converted_state, strict=strict) | ||
|
||
def _do_load(self, model, state, names, additional_prefix=""): | ||
"""Load the model from the checkpoint | ||
Args: | ||
model (torch.nn.Module): The pytorch model that needs to be loaded. | ||
state (Dict[str, torch.Tensor]): The pytorch state dictionary | ||
names (TensorNames]): The TensorNames for the model we are loading. | ||
Returns: | ||
Dict[str, torch.Tensor]: Map of name to tensor for loading. | ||
""" | ||
converted_state = dict() | ||
if names.conv1D_0 is not None: | ||
converted_state["conv1.weight"] = state.pop(f"{names.conv1D_0}/0") | ||
if f"{names.conv1D_0}/1" in state: | ||
converted_state["conv1.bias"] = state.pop(f"{names.conv1D_0}/1") | ||
|
||
if names.conv1D_1 is not None: | ||
converted_state["conv2.weight"] = state.pop(f"{names.conv1D_1}/0") | ||
if f"{names.conv1D_1}/1" in state: | ||
converted_state["conv2.bias"] = state.pop(f"{names.conv1D_1}/1") | ||
|
||
if names.conv1D_2 is not None: | ||
converted_state["conv3.weight"] = state.pop(f"{names.conv1D_2}/0") | ||
if f"{names.conv1D_2}/1" in state: | ||
converted_state["conv3.bias"] = state.pop(f"{names.conv1D_2}/1") | ||
|
||
if names.group_norm is not None: | ||
group_norm_name = names.group_norm | ||
converted_state[f"group_norm.weight"] = state.pop(f"{group_norm_name}/0") | ||
if f"{group_norm_name}/1" in state: | ||
converted_state["group_norm.bias"] = state.pop(f"{group_norm_name}/1") | ||
|
||
return converted_state |