diff --git a/ai_edge_torch/generative/examples/moonshine/__init__.py b/ai_edge_torch/generative/examples/moonshine/__init__.py new file mode 100644 index 00000000..57b12003 --- /dev/null +++ b/ai_edge_torch/generative/examples/moonshine/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py b/ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py new file mode 100644 index 00000000..a5e65d0b --- /dev/null +++ b/ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py @@ -0,0 +1,50 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Example of converting a Moonshine model to multi-signature tflite model.""" + +import os +import pathlib + +from absl import app +from absl import flags +import ai_edge_torch +from ai_edge_torch.generative.examples.moonshine import moonshine +from ai_edge_torch.generative.utilities import converter +import torch + +_CHECKPOINT_PATH = flags.DEFINE_string( + 'checkpoint_path', + os.path.join(pathlib.Path.home(), 'Downloads/llm_data/moonshine'), + 'The path to the model checkpoint, or directory holding the checkpoint.', +) +_TFLITE_PATH = flags.DEFINE_string( + 'tflite_path', + '/tmp/', + 'The tflite file path to export.', +) + + +def main(_): + p_model = moonshine.build_preprocessor(_CHECKPOINT_PATH.value) + output_filename = f'moonshine_preprocessor.tflite' + _input = torch.randn((1, 1, 159414), dtype=torch.float) + edge_model = ai_edge_torch.convert(p_model, (_input,), quant_config=None) + tflite_path = os.path.join(_TFLITE_PATH.value, output_filename) + edge_model.export(tflite_path) + + +if __name__ == '__main__': + app.run(main) diff --git a/ai_edge_torch/generative/examples/moonshine/data/pp_input.pt b/ai_edge_torch/generative/examples/moonshine/data/pp_input.pt new file mode 100644 index 00000000..dc83481e Binary files /dev/null and b/ai_edge_torch/generative/examples/moonshine/data/pp_input.pt differ diff --git a/ai_edge_torch/generative/examples/moonshine/data/pp_output.pt b/ai_edge_torch/generative/examples/moonshine/data/pp_output.pt new file mode 100644 index 00000000..c8f02178 Binary files /dev/null and b/ai_edge_torch/generative/examples/moonshine/data/pp_output.pt differ diff --git a/ai_edge_torch/generative/examples/moonshine/moonshine.py b/ai_edge_torch/generative/examples/moonshine/moonshine.py new file mode 100644 index 00000000..73885e31 --- /dev/null +++ b/ai_edge_torch/generative/examples/moonshine/moonshine.py @@ -0,0 +1,103 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Example of building the Moonshine model.""" + +import os +import pathlib +from typing import Optional, Tuple +from absl import app +from ai_edge_torch.generative.layers import attention +from ai_edge_torch.generative.layers import builder +from ai_edge_torch.generative.layers import kv_cache as kv_utils +import ai_edge_torch.generative.layers.attention_utils as attn_utils +import ai_edge_torch.generative.layers.model_config as cfg +import ai_edge_torch.generative.layers.normalization as normalization +import ai_edge_torch.generative.utilities.moonshine_loader as loading_utils +import h5py +import torch +from torch import nn +import torch.nn as nn + +TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( + conv1D_0="layers/sequential/layers/conv1d/vars", + conv1D_1="layers/sequential/layers/conv1d_1/vars", + conv1D_2="layers/sequential/layers/conv1d_2/vars", + group_norm="layers/sequential/layers/group_normalization/vars", +) + + +class AudioPreprocessor(nn.Module): + + def __init__(self, dim): + super(AudioPreprocessor, self).__init__() + self.conv1 = nn.Conv1d( + in_channels=1, out_channels=dim, kernel_size=127, stride=64, bias=False + ) + self.tanh = nn.Tanh() + self.group_norm = normalization.GroupNorm(group_num=1, dim=dim, eps=1e-5) + self.conv2 = nn.Conv1d( + in_channels=dim, + out_channels=2 * dim, + kernel_size=7, + stride=3, + padding=0, # Equivalent to padding="valid" + ) + self.gelu1 = nn.GELU() + self.conv3 = nn.Conv1d( + in_channels=2 * dim, + out_channels=dim, + kernel_size=3, + stride=2, + padding=0, # Equivalent to padding="valid" + ) + self.gelu2 = nn.GELU() + + def forward(self, inputs): + x = self.conv1(inputs) + x = self.tanh(x) + x = self.group_norm(x) + x = self.conv2(x) + x = self.gelu1(x) + x = self.conv3(x) + x = self.gelu2(x) + return x + + +def build_preprocessor(checkpoint_path: str, **kwargs) -> nn.Module: + ap = AudioPreprocessor(dim=416) + loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) + loader.load(ap, strict=True) + return ap + + +def main(_): + # TODO(b/375421767) Remove golden checks once full model is implemented. + HF_PATH = os.path.join(pathlib.Path.home(), "Downloads/llm_data/moonshine") + + test_data_path = pathlib.Path(__file__).parent.resolve() + INPUT_PATH = test_data_path / "data" / "pp_input.pt") + GOLDEN_PATH = test_data_path / "data" / "pp_output.pt") + + ap = build_preprocessor(HF_PATH) + ap.eval() + inputs = torch.load(INPUT_PATH).reshape((1, 1, 159414)) + out = ap(inputs) + golden = torch.load(GOLDEN_PATH).transpose(1, 2) + assert torch.allclose(out, golden, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + app.run(main) diff --git a/ai_edge_torch/generative/utilities/moonshine_loader.py b/ai_edge_torch/generative/utilities/moonshine_loader.py new file mode 100644 index 00000000..66167044 --- /dev/null +++ b/ai_edge_torch/generative/utilities/moonshine_loader.py @@ -0,0 +1,154 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Common utility functions for data loading etc. +from dataclasses import dataclass +import glob +import os +from typing import Callable, Dict + +import h5py +import torch + + +def transpose_if_needed(t): + """We assume the file is from Keras, i.e. channel last format.""" + if len(t.shape) > 2: + return t.permute(2, 1, 0) + return t + + +def load_h5_statedict(full_path: str): + """Loads the HDF5 DataSets into a single dctionary. + + Args: + full_path (string): the HDF5 filename or directory that contains the HDF5 + files. + + Returns: + A state dictionary contating loaded tensors. + + Raises: + ValueError: If no tensors are loaded from the provided directory or file. + """ + pattern = ( + os.path.join(full_path, "*.h5") if os.path.isdir(full_path) else full_path + ) + files = [] + for file in glob.glob(pattern): + files.append(file) + + tensors = {} + + def collect_datasets(name, obj): + if isinstance(obj, h5py.Dataset): + tensors[name] = transpose_if_needed(torch.from_numpy(obj[:])) + + for file in files: + with h5py.File(file) as f: + f.visititems(collect_datasets) + + if not tensors: + raise ValueError("Failed to load HDF5 file.") + return tensors + + +class ModelLoader: + """Utility class for loading and converting checkpoints to ODML transformer layer format.""" + + @dataclass + class TensorNames: + conv1D_0: str = None + conv1D_1: str = None + conv1D_2: str = None + group_norm: str = None + + def __init__(self, file_name: str, names: TensorNames) -> None: + """ModelLoader constructor. + + Can be used to load multiple models of the same type. + + Args: + file_name (str): Path to the checkpoint. Can be a directory or an exact + file. + names (TensorNames): An instance of `TensorNames` to determine mappings. + """ + self._file_name = file_name + self._names = names + self._loader = load_h5_statedict + + def load( + self, + model: torch.nn.Module, + strict: bool = True, + ): + """Load the model from the checkpoint + + Args: + model (torch.nn.Module): The pytorch model that needs to be loaded. + strict (bool, optional): Whether the converted keys are strictly + matched. Defaults to True. + + Raises: + ValueError: If conversion results in unmapped tensors and strict mode is + enabled. + """ + state = self._loader(self._file_name) + + if isinstance(self._names, ModelLoader.TensorNames): + converted_state = self._do_load(model, state, self._names) + else: + raise ValueError(f"Unkown type for names: {type(self._names)}") + + if strict and state: + raise ValueError( + "Failed to map all tensor. Remaining tensor are:" + f" {list(state.keys())}" + ) + model.load_state_dict(converted_state, strict=strict) + + def _do_load(self, model, state, names, additional_prefix=""): + """Load the model from the checkpoint + + Args: + model (torch.nn.Module): The pytorch model that needs to be loaded. + state (Dict[str, torch.Tensor]): The pytorch state dictionary + names (TensorNames]): The TensorNames for the model we are loading. + + Returns: + Dict[str, torch.Tensor]: Map of name to tensor for loading. + """ + converted_state = dict() + if names.conv1D_0 is not None: + converted_state["conv1.weight"] = state.pop(f"{names.conv1D_0}/0") + if f"{names.conv1D_0}/1" in state: + converted_state["conv1.bias"] = state.pop(f"{names.conv1D_0}/1") + + if names.conv1D_1 is not None: + converted_state["conv2.weight"] = state.pop(f"{names.conv1D_1}/0") + if f"{names.conv1D_1}/1" in state: + converted_state["conv2.bias"] = state.pop(f"{names.conv1D_1}/1") + + if names.conv1D_2 is not None: + converted_state["conv3.weight"] = state.pop(f"{names.conv1D_2}/0") + if f"{names.conv1D_2}/1" in state: + converted_state["conv3.bias"] = state.pop(f"{names.conv1D_2}/1") + + if names.group_norm is not None: + group_norm_name = names.group_norm + converted_state[f"group_norm.weight"] = state.pop(f"{group_norm_name}/0") + if f"{group_norm_name}/1" in state: + converted_state["group_norm.bias"] = state.pop(f"{group_norm_name}/1") + + return converted_state