From 154f7eba05186ad3bb1b2e4c98cf0fbdf202dc53 Mon Sep 17 00:00:00 2001 From: Guang Yang Date: Tue, 5 Nov 2024 16:35:48 -0800 Subject: [PATCH] Export to ExecuTorch: Initial Integration --- .github/workflows/test_executorch_export.yml | 35 ++ .github/workflows/test_executorch_runtime.yml | 35 ++ docs/source/exporters/executorch/overview.mdx | 26 + .../package_reference/configuration.mdx | 54 ++ .../executorch/package_reference/export.mdx | 26 + .../executorch/usage_guides/contribute.mdx | 57 +++ .../usage_guides/export_a_model.mdx | 124 +++++ docs/source/exporters/overview.mdx | 2 +- optimum/commands/__init__.py | 2 +- optimum/commands/export/__init__.py | 1 + optimum/commands/export/base.py | 6 + optimum/commands/export/executorch.py | 67 +++ optimum/executorchruntime/__init__.py | 29 ++ .../executorchruntime/modeling_executorch.py | 464 ++++++++++++++++++ optimum/exporters/executorch/__init__.py | 44 ++ optimum/exporters/executorch/__main__.py | 160 ++++++ optimum/exporters/executorch/convert.py | 90 ++++ .../exporters/executorch/recipe_registry.py | 68 +++ .../exporters/executorch/recipes/__init__.py | 11 + .../exporters/executorch/recipes/xnnpack.py | 97 ++++ optimum/exporters/executorch/task_registry.py | 68 +++ .../exporters/executorch/tasks/__init__.py | 11 + .../exporters/executorch/tasks/causal_lm.py | 66 +++ optimum/onnxruntime/runs/__init__.py | 6 +- setup.py | 4 + tests/executorch/export/__init__.py | 14 + .../export/test_exporters_executorch.py | 115 +++++ tests/executorch/runtime/__init__.py | 14 + tests/executorch/runtime/test_modeling.py | 207 ++++++++ 29 files changed, 1898 insertions(+), 5 deletions(-) create mode 100644 .github/workflows/test_executorch_export.yml create mode 100644 .github/workflows/test_executorch_runtime.yml create mode 100644 docs/source/exporters/executorch/overview.mdx create mode 100644 docs/source/exporters/executorch/package_reference/configuration.mdx create mode 100644 docs/source/exporters/executorch/package_reference/export.mdx create mode 100644 docs/source/exporters/executorch/usage_guides/contribute.mdx create mode 100644 docs/source/exporters/executorch/usage_guides/export_a_model.mdx create mode 100644 optimum/commands/export/executorch.py create mode 100644 optimum/executorchruntime/__init__.py create mode 100644 optimum/executorchruntime/modeling_executorch.py create mode 100644 optimum/exporters/executorch/__init__.py create mode 100644 optimum/exporters/executorch/__main__.py create mode 100644 optimum/exporters/executorch/convert.py create mode 100644 optimum/exporters/executorch/recipe_registry.py create mode 100644 optimum/exporters/executorch/recipes/__init__.py create mode 100644 optimum/exporters/executorch/recipes/xnnpack.py create mode 100644 optimum/exporters/executorch/task_registry.py create mode 100644 optimum/exporters/executorch/tasks/__init__.py create mode 100644 optimum/exporters/executorch/tasks/causal_lm.py create mode 100644 tests/executorch/export/__init__.py create mode 100644 tests/executorch/export/test_exporters_executorch.py create mode 100644 tests/executorch/runtime/__init__.py create mode 100644 tests/executorch/runtime/test_modeling.py diff --git a/.github/workflows/test_executorch_export.yml b/.github/workflows/test_executorch_export.yml new file mode 100644 index 00000000000..eb8f995f71c --- /dev/null +++ b/.github/workflows/test_executorch_export.yml @@ -0,0 +1,35 @@ +name: ExecuTorch Export / Python - Test + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11', '3.12'] + os: [ubuntu-20.04, macos-15] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for ExecuTorch + run: | + pip install .[tests,exporters-executorch] + pip list + - name: Run tests + working-directory: tests + run: | + RUN_SLOW=1 pytest executorch/export/test_*.py -s -vvvv --durations=0 diff --git a/.github/workflows/test_executorch_runtime.yml b/.github/workflows/test_executorch_runtime.yml new file mode 100644 index 00000000000..f7e3abcceff --- /dev/null +++ b/.github/workflows/test_executorch_runtime.yml @@ -0,0 +1,35 @@ +name: ExecuTorch Runtime / Python - Test + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11', '3.12'] + os: [ubuntu-20.04, macos-15] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for ExecuTorch + run: | + pip install .[tests,exporters-executorch] + pip list + - name: Run tests + working-directory: tests + run: | + RUN_SLOW=1 pytest executorch/runtime/test_*.py -s -vvvv --durations=0 diff --git a/docs/source/exporters/executorch/overview.mdx b/docs/source/exporters/executorch/overview.mdx new file mode 100644 index 00000000000..0e880968bf7 --- /dev/null +++ b/docs/source/exporters/executorch/overview.mdx @@ -0,0 +1,26 @@ + + +# Overview + +πŸ€— Optimum handles the export of PyTorch to ExecuTorch in the `exporters.executorch` module. It provides classes, functions, and a command line interface to perform the export easily. + +Supported architectures from [πŸ€— Transformers](https://huggingface.co/docs/transformers/index): + +- Gemma +- Gemma2 +- Llama2 +- Llama3(Llama3.2) +- OLMo +- Qwen2(Qwen2.5) + +There are many more models are supported by ExecuTorch, we will add those models to Optimum over time. Read more at [pytorch/executorch/examples/](https://github.com/pytorch/executorch/tree/main/examples) diff --git a/docs/source/exporters/executorch/package_reference/configuration.mdx b/docs/source/exporters/executorch/package_reference/configuration.mdx new file mode 100644 index 00000000000..b7a10b80419 --- /dev/null +++ b/docs/source/exporters/executorch/package_reference/configuration.mdx @@ -0,0 +1,54 @@ + + +# Configuration for ExecuTorch Export + +ExecuTorch export provides a flexible configuration mechanism through dynamic registration, enabling users to have +complete control over the export process. The configuration system is divided into task configurations and recipe +configurations, each addressing specific aspects of the export pipeline. + + +## Task Configurations + +Task configurations determine how a Hugging Face model should be loaded and prepared for export, tailored to specific tasks. + +For instance, when exporting a model for a text generation task, the provided configuration utilizes **static caching** and +**SDPA (Scaled Dot-Product Attention)** for inference optimization. + +By leveraging task configurations, users can ensure that their models are appropriately prepared for efficient execution on +the ExecuTorch backend. + +[[autodoc]] exporters.executorch.task_registry.discover_tasks + +[[autodoc]] exporters.executorch.task_registry.register_task + +[[autodoc]] exporters.executorch.tasks.causal_lm.load_causal_lm_model + + +## Recipe Configurations + +Recipe configurations control the specifics of lowering an eager PyTorch module to the ExecuTorch backend. These +configurations allow users to: + +- Specify whether and how to **quantize** the model. +- Delegate computation to various accelerators, such as **CPU**, **GPU**, **NPU**, **DSP**, and others. +- Define **custom transformation passes**. +- Implement advanced techniques like memory planning algorithms to optimize resource utilization. + +[[autodoc]] exporters.executorch.recipe_registry.discover_recipes + +[[autodoc]] exporters.executorch.recipe_registry.register_recipe + +[[autodoc]] exporters.executorch.recipes.xnnpack.export_to_executorch_with_xnnpack + +The combination of task and recipe configurations ensures that users can customize both the high-level task setup +and the low-level export details to suit their deployment requirements. diff --git a/docs/source/exporters/executorch/package_reference/export.mdx b/docs/source/exporters/executorch/package_reference/export.mdx new file mode 100644 index 00000000000..6663eb5278e --- /dev/null +++ b/docs/source/exporters/executorch/package_reference/export.mdx @@ -0,0 +1,26 @@ + + +# Export functions + +## Main functions + +[[autodoc]] exporters.executorch.convert.export_to_executorch + +The primary export function is designed to be **model- and task-independent** as well as **optimization-agnostic**, providing a +highly flexible and modular interface for exporting Hugging Face models to the ExecuTorch backend. + +This approach highlights the **composability** of ExecuTorch export pipeline, where dynamically registered **task configurations** +specify how a :hug model is prepared, and **recipe configurations** encapsulate device-specific optimizations during export. This +separation allows users to customize the export process without altering the core function. + +For more details on task and recipe configurations, see the [Configuration for ExecuTorch Export](./configuration.mdx). diff --git a/docs/source/exporters/executorch/usage_guides/contribute.mdx b/docs/source/exporters/executorch/usage_guides/contribute.mdx new file mode 100644 index 00000000000..2c6c1593169 --- /dev/null +++ b/docs/source/exporters/executorch/usage_guides/contribute.mdx @@ -0,0 +1,57 @@ + + +# Adding support for an unsupported architecture + +We welcome contributions to extend the functionality of ExecuTorch export. This guide provides high-level instructions for contributors who want to: + +1. Export a new model that is not currently supported. +2. Add new recipes or support a new task for export. + +--- + +## Exporting a New Model + +If you want to export a model that is not already supported by the library, follow these steps: + +### Step 1: Export and Test the Model +1. Attempt to export and lower the model using an existing task and recipe. On success, it will store the exported model in a `.pte` file. +2. Add a test case for the model in the appropriate test suite. + - For example, you can make sure tests pass for the new `my_new_model` by running: + ```bash + pytest tests/executorch/export/test_*.py -k "test_my_new_model" # doctest: +SKIP + pytest tests/executorch/runtime/test_*.py -k "test_my_new_model" # doctest: +SKIP + ``` + +### Step 2: Handle Export Failures +1. If the export fails in Step 1, report the issue by opening a GitHub issue. +2. If the issue requires changes to the model’s architecture or its Hugging Face implementation, these modifications may be made upstream in the Hugging Face Transformers library. + +--- + +## Adding New Recipes or Tasks + +To extend ExecuTorch with new recipes or tasks, follow these guidelines: + +### Registering a New Recipe +You can add a custom recipe to define specific optimizations or configurations for exporting models. Below is an example: + +```python +from exporters.executorch import register_recipe + +@register_recipe("my_custom_recipe") +def export_with_custom_recipe(model, config, *args, **kwargs): + # Example: Apply a custom quantization +``` + +### Registering a Task +The task registration process is same as adding a recipe. Besides that you may need to implement a new `ExecuTorchModelForXXX` class. diff --git a/docs/source/exporters/executorch/usage_guides/export_a_model.mdx b/docs/source/exporters/executorch/usage_guides/export_a_model.mdx new file mode 100644 index 00000000000..7993188cbd5 --- /dev/null +++ b/docs/source/exporters/executorch/usage_guides/export_a_model.mdx @@ -0,0 +1,124 @@ + + +# Export a model to ExecuTorch with optimum.exporters.executorch + +If you need to deploy πŸ€— Transformers models for on-device use cases, we recommend +exporting them to a serialized format that can be distributed and executed on specialized +runtimes and hardware. In this guide, we'll show you how to export these +models to [ExecuTorch](https://pytorch.org/executorch/main/intro-overview.html). + + +## Why ExecuTorch? + +ExecuTorch is the ideal solution for deploying PyTorch models on edge devices, offering a streamlined process from +export to deployment without leaving PyTorch ecosystem. + +Supporting on-device AI presents unique challenges with diverse hardware, critical power requirements, low/no internet +connectivity, and realtime processing needs. These constraints have historically prevented or slowed down the creation +of scalable and performant on-device AI solutions. We designed ExecuTorch, backed by our industry partners like Meta, +Arm, Apple, Qualcomm, MediaTek, etc. to be highly portable and provide superior developer productivity without losing on +performance. + + +## Summary + +Exporting a PyTorch model to ExecuTorch is as simple as + +```bash +optimum-cli export executorch --model "meta-llama/Llama-3.2-1B" --task "text-generation" --recipe "xnnpack" --output_dir "meta_llama3_2_1b" +``` + +Check out the help for more options: + +```bash +optimum-cli export executorch --help +``` + + +## Exporting a model to ExecuTorch using the CLI + +To export a πŸ€— Transformers model to ExecuTorch, you'll first need to install some extra +dependencies: + +```bash +pip install optimum[exporters-executorch] +``` + +The Optimum ExecuTorch export can be used through Optimum command-line: + +```bash +optimum-cli export executorch --help + +usage: optimum-cli export executorch [-h] -m MODEL [-o OUTPUT_DIR] [--task TASK] [--recipe RECIPE] + +options: + -h, --help show this help message and exit + +Required arguments: + -m MODEL, --model MODEL + Model ID on huggingface.co or path on disk to load model from. + -o OUTPUT_DIR, --output_dir OUTPUT_DIR + Path indicating the directory where to store the generated ExecuTorch model. + --task TASK The task to export the model for. Available tasks depend on the model, but are among: ['audio-classification', 'feature-extraction', 'image-to-text', + 'sentence-similarity', 'depth-estimation', 'image-segmentation', 'audio-frame-classification', 'masked-im', 'semantic-segmentation', 'text-classification', + 'audio-xvector', 'mask-generation', 'question-answering', 'text-to-audio', 'automatic-speech-recognition', 'image-to-image', 'multiple-choice', 'image- + classification', 'text2text-generation', 'token-classification', 'object-detection', 'zero-shot-object-detection', 'zero-shot-image-classification', 'text- + generation', 'fill-mask']. + --recipe RECIPE Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack". + +``` + +Exporting a checkpoint can be done as follows: + +```bash +optimum-cli export executorch --model "meta-llama/Llama-3.2-1B" --task "text-generation" --recipe "xnnpack" --output_dir "meta_llama3_2_1b" +``` + +You should see a `model.pte` file is stored under "./meta_llama3_2_1b/": + +```bash +meta_llama3_2_1b/ +└── model.pte +``` + +This will fetch the model on the Hub and exports the PyTorch model with the specialized recipe. The resulting `model.pte` file can then be run on the [XNNPACK backend](https://pytorch.org/executorch/main/tutorial-xnnpack-delegate-lowering.html), or on many +other ExecuTorh supported backends if exports with different recipes, e.g. Apple's [Core ML](https://pytorch.org/executorch/main/build-run-coreml.html) or [MPS](https://pytorch.org/executorch/main/build-run-mps.html), [Qualcomm's SoCs](https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html), [ARM's Ethos-U](https://pytorch.org/executorch/main/executorch-arm-delegate-tutorial.html), [Xtensa HiFi4 DSP](https://pytorch.org/executorch/main/build-run-xtensa.html), [Vulkan GPU](https://pytorch.org/executorch/main/build-run-vulkan.html), [MediaTek](https://pytorch.org/executorch/main/build-run-mediatek-backend.html), etc. + +For example, we can load and run the model with [ExecuTorch +Runtime](https://pytorch.org/executorch/main/runtime-overview.html) using the `optimum.executorchruntime` package as follows: + +```python +>>> from transformers import AutoTokenizer +>>> from optimum.executorchruntime import ExecuTorchModelForCausalLM + +>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") # doctest: +SKIP +>>> model = ExecuTorchModelForCausalLM.from_pretrained("meta_llama3_2_1b/", export=False) # doctest: +SKIP + +>>> generated_text = model.text_generation(tokenizer=tokenizer, prompt="Simply put, the theory of relativity states that", max_seq_len=45) # doctest: +SKIP +``` + +Printing the `generated_text` would give that: + +``` +"Simply put, the theory of relativity states that the laws of physics are the same in all inertial frames of reference. In other words, the laws of physics are the same in all inertial frames of reference." +``` + +As you can see, converting a model to ExecuTorch does not mean leaving the Hugging Face ecosystem. You end up with a similar API as regular πŸ€— Transformers models! + +It is also possible to export the model to ExecuTorch directly from the `ExecuTorchModelForCausalLM` class by doing the following: + +```python +>>> from optimum.executorchruntime import ExecuTorchModelForCausalLM + +>>> model = ExecuTorchModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", export=True, task="text-generation", recipe="xnnpack") +``` diff --git a/docs/source/exporters/overview.mdx b/docs/source/exporters/overview.mdx index 6fd7bd9d916..2b4c2e11792 100644 --- a/docs/source/exporters/overview.mdx +++ b/docs/source/exporters/overview.mdx @@ -12,4 +12,4 @@ specific language governing permissions and limitations under the License. # Overview -πŸ€— Optimum enables exporting models from PyTorch or TensorFlow to different formats through its `exporters` module. For now, two exporting format are supported: ONNX and TFLite (TensorFlow Lite). +πŸ€— Optimum enables exporting models from PyTorch or TensorFlow to different formats through its `exporters` module. For now, three exporting format are supported: ONNX, TFLite (TensorFlow Lite), and ExecuTorch. diff --git a/optimum/commands/__init__.py b/optimum/commands/__init__.py index 8a2a276d1c5..a31344ed133 100644 --- a/optimum/commands/__init__.py +++ b/optimum/commands/__init__.py @@ -14,5 +14,5 @@ from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand from .env import EnvironmentCommand -from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand +from .export import ExecuTorchExportCommand, ExportCommand, ONNXExportCommand, TFLiteExportCommand from .optimum_cli import optimum_cli_subcommand diff --git a/optimum/commands/export/__init__.py b/optimum/commands/export/__init__.py index 19da68a60d2..b72cd5dbc8d 100644 --- a/optimum/commands/export/__init__.py +++ b/optimum/commands/export/__init__.py @@ -14,5 +14,6 @@ from .base import ExportCommand +from .executorch import ExecuTorchExportCommand from .onnx import ONNXExportCommand from .tflite import TFLiteExportCommand diff --git a/optimum/commands/export/base.py b/optimum/commands/export/base.py index 07737cb8eaf..e5ed4c90ff5 100644 --- a/optimum/commands/export/base.py +++ b/optimum/commands/export/base.py @@ -15,6 +15,7 @@ """optimum.exporters command-line interface base classes.""" from .. import BaseOptimumCLICommand, CommandInfo +from .executorch import ExecuTorchExportCommand from .onnx import ONNXExportCommand from .tflite import TFLiteExportCommand @@ -25,6 +26,11 @@ class ExportCommand(BaseOptimumCLICommand): help="Export PyTorch and TensorFlow models to several format.", ) SUBCOMMANDS = ( + CommandInfo( + name="executorch", + help="Export PyTorch model to ExecuTorch.", + subcommand_class=ExecuTorchExportCommand, + ), CommandInfo( name="onnx", help="Export PyTorch and TensorFlow to ONNX.", diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py new file mode 100644 index 00000000000..2bf2f1d3054 --- /dev/null +++ b/optimum/commands/export/executorch.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""Defines the command line for the export with ExecuTorch.""" + +from pathlib import Path +from typing import TYPE_CHECKING + +from ...exporters import TasksManager +from ..base import BaseOptimumCLICommand + + +if TYPE_CHECKING: + from argparse import ArgumentParser + + +def parse_args_executorch(parser): + required_group = parser.add_argument_group("Required arguments") + required_group.add_argument( + "-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from." + ) + required_group.add_argument( + "-o", + "--output_dir", + type=Path, + help="Path indicating the directory where to store the generated ExecuTorch model.", + ) + required_group.add_argument( + "--task", + type=str, + default="text-generation", + help=( + "The task to export the model for. Available tasks depend on the model, but are among:" + f" {str(TasksManager.get_all_tasks())}." + ), + ) + required_group.add_argument( + "--recipe", + type=str, + default="xnnpack", + help='Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".', + ) + + +class ExecuTorchExportCommand(BaseOptimumCLICommand): + @staticmethod + def parse_args(parser: "ArgumentParser"): + return parse_args_executorch(parser) + + def run(self): + from ...exporters.executorch import main_export + + main_export( + model_name_or_path=self.args.model, + task=self.args.task, + recipe=self.args.recipe, + output_dir=self.args.output_dir, + ) diff --git a/optimum/executorchruntime/__init__.py b/optimum/executorchruntime/__init__.py new file mode 100644 index 00000000000..0a84c3a139b --- /dev/null +++ b/optimum/executorchruntime/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "modeling_executorch": [ + "ExecuTorchModelForCausalLM", + ], +} + +if TYPE_CHECKING: + from .modeling_executorch import ExecuTorchModelForCausalLM +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/optimum/executorchruntime/modeling_executorch.py b/optimum/executorchruntime/modeling_executorch.py new file mode 100644 index 00000000000..39c75a03863 --- /dev/null +++ b/optimum/executorchruntime/modeling_executorch.py @@ -0,0 +1,464 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""ExecuTorchModelForXXX classes, allowing to run ExecuTorch Models with ExecuTorch Runtime using the same API as Transformers.""" + +import logging +import os +import warnings +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, List, Optional, Union + +import torch +from executorch.extension.pybindings.portable_lib import ( + ExecuTorchModule, + _load_for_executorch, +) +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from transformers import ( + AutoModelForCausalLM, + PretrainedConfig, + PreTrainedTokenizer, +) + +from ..exporters.executorch import main_export +from ..modeling_base import OptimizedModel + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + +logger = logging.getLogger(__name__) + + +class ExecuTorchModelForCausalLM(OptimizedModel): + """ + ExecuTorch model with a causal language modeling head for inference using the ExecuTorch Runtime. + + This class provides an interface for loading, running, and generating outputs from a causal language model + optimized for ExecuTorch Runtime. It includes utilities for exporting and loading pre-trained models + compatible with ExecuTorch runtime. + + Attributes: + auto_model_class (`Type`): + Associated Transformers class, `AutoModelForCausalLM`. + et_model (`ExecuTorchModule`): + The loaded ExecuTorch model. + use_kv_cache (`bool`): + Whether key-value caching is enabled. For performance reasons, the exported model is + optimized to use a static cache. + max_cache_size (`int`): + Maximum sequence length supported by the cache. + max_batch_size (`int`): + Maximum supported batch size. + dtype (`str`): + Data type of the model parameters. + bos_token_id (`int`): + Beginning-of-sequence token ID. + eos_token_id (`int`): + End-of-sequence token ID. + vocab_size (`int`): + Size of the model vocabulary. + """ + + auto_model_class = AutoModelForCausalLM + + def __init__( + self, + model: "ExecuTorchModule", + config: "PretrainedConfig", + ): + super().__init__(model, config) + self.et_model = model + metadata = self.et_model.method_names() + logging.info(f"Load all static methods: {metadata}") + if "use_kv_cache" in metadata: + self.use_kv_cache = self.et_model.run_method("use_kv_cache")[0] + if "get_max_seq_len" in metadata: + self.max_cache_size = self.et_model.run_method("get_max_seq_len")[0] + if "get_max_batch_size" in metadata: + self.max_batch_size = self.et_model.run_method("get_max_batch_size")[0] + if "get_dtype" in metadata: + self.dtype = self.et_model.run_method("get_dtype")[0] + if "get_bos_id" in metadata: + self.bos_token_id = self.et_model.run_method("get_bos_id")[0] + if "get_eos_id" in metadata: + self.eos_token_id = self.et_model.run_method("get_eos_id")[0] + if "get_vocab_size" in metadata: + self.vocab_size = self.et_model.run_method("get_vocab_size")[0] + + def forward( + self, + input_ids: torch.Tensor, + cache_position: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the model, which is compatible with the ExecuTorch runtime for LLM. + + Args: + input_ids (`torch.Tensor`): Tensor representing current input token id to the model. + cache_position (`torch.Tensor`): Tensor representing current input position in the cache. + + Returns: + torch.Tensor: Logits output from the model. + """ + return self.et_model.forward((input_ids, cache_position))[0] + + @classmethod + def from_pretrained( + cls, + model_name_or_path: Union[str, Path], + export: bool = True, + task: str = "", + recipe: str = "", + config: "PretrainedConfig" = None, + subfolder: str = "", + revision: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + **kwargs, + ) -> "ExecuTorchModelForCausalLM": + """ + Load a pre-trained ExecuTorch model. + + Args: + model_name_or_path (`Union[str, Path]`): + Model ID on huggingface.co or path on disk to the model repository to export. Example: `model_name_or_path="meta-llama/Llama-3.2-1B"` or `mode_name_or_path="/path/to/model_folder`. + export (`bool`, *optional*, defaults to `True`): + If `True`, the model will be exported from eager to ExecuTorch after fetched from huggingface.co. `model_name_or_path` must be a valid model ID on huggingface.co. + If `False`, the previously exported ExecuTorch model will be loaded from a local path. `model_name_or_path` must be a valid local directory where a `model.pte` is stored. + task (`str`, defaults to `""`): + The task to export the model for, e.g. "text-generation". It is required to specify a task when `export` is `True`. + recipe (`str`, defaults to `""`): + The recipe to use to do the export, e.g. "xnnpack". It is required to specify a task when `export` is `True`. + config (`PretrainedConfig`, *optional*): + Configuration of the pre-trained model. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + **kwargs: + Additional configuration options to tasks and recipes. + + Returns: + `ExecuTorchModelForCausalLM`: An instance of the ExecuTorch model for text generation task. + """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + if export: + # Fetch the model from huggingface.co and export it to ExecuTorch + if task == "": + raise ValueError("Please specify a task to export the model for.") + if recipe == "": + raise ValueError("Please specify a recipe to export the model for.") + return cls._export( + model_id=model_name_or_path, + task=task, + recipe=recipe, + config=config, + **kwargs, + ) + else: + # Load the ExecuTorch model from a local path + return cls._from_pretrained( + model_dir_path=model_name_or_path, + config=config, + ) + + @classmethod + def _from_pretrained( + cls, + model_dir_path: Union[str, Path], + config: PretrainedConfig, + subfolder: str = "", + revision: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + ) -> "ExecuTorchModelForCausalLM": + """ + Load a pre-trained ExecuTorch model from a local directory. + + Args: + model_dir_path (`Union[str, Path]`): + Path to the directory containing the ExecuTorch model file (`model.pte`). + config (`PretrainedConfig`, *optional*): + Configuration of the pre-trained model. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + + Returns: + `ExecuTorchModelForCausalLM`: The initialized ExecuTorch model. + + """ + full_path = os.path.join(f"{model_dir_path}", "model.pte") + model = _load_for_executorch(full_path) + logging.info(f"Loaded model from {full_path}") + logging.debug(f"{model.method_meta('forward')}") + return cls( + model=model, + config=config, + ) + + def _save_pretrained(self, save_directory): + """ + Saves a model weights into a directory, so that it can be re-loaded using the + [`from_pretrained`] class method. + """ + raise NotImplementedError + + @classmethod + def _export( + cls, + model_id: str, + task: str, + recipe: str, + config: PretrainedConfig, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + trust_remote_code: bool = False, + subfolder: str = "", + revision: Optional[str] = None, + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + **kwargs, + ): + """ + Fetch a model from the Hugging Face Hub and export it to ExecuTorch format. + + Args: + model_id (`str`): + Model ID on huggingface.co, for example: `model_name_or_path="meta-llama/Llama-3.2-1B"`. + task (`str`): + The task to export the model for, e.g. "text-generation". + recipe (`str`): + The recipe to use to do the export, e.g. "xnnpack". + config (`PretrainedConfig`, *optional*): + Configuration of the pre-trained model. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + trust_remote_code (`bool`, defaults to `False`): + Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories + you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the + model repository. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + **kwargs: + Additional configuration options to tasks and recipes. + + Returns: + `ExecuTorchModelForCausalLM`: The loaded and exported ExecuTorch model. + + """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + + # Export to ExecuTorch and save the pte file to the temporary directory + main_export( + model_name_or_path=model_id, + output_dir=save_dir_path, + task=task, + recipe=recipe, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + return cls._from_pretrained( + model_dir_path=save_dir_path, + config=config, + use_auth_token=use_auth_token, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + ) + + def generate( + self, + prompt_tokens: List[int], + echo: bool = False, + pos_base: int = 0, + max_seq_len: Optional[int] = None, + ) -> List[int]: + """ + Generate tokens from a prompt using the ExecuTorch model. + + Args: + prompt_tokens (List[int]): + List of token IDs representing the prompt. + echo (`bool`, *optional*): + Whether to include prompt tokens in the generated output. Defaults to `False`. + pos_base (`int`, *optional*): + Base position for the prompt tokens. Defaults to 0. + max_seq_len (`int`, *optional*): + Maximum sequence length for the generated output. + Defaults to None and uses the model's `max_cache_size` attribute. + Will be truncated to maximal cache size if larger than `max_cache_size`. + + Returns: + List[int]: List of generated token IDs. + + Note: + Temporarily implemented this method in Python due to limited access to ExecuTorch's c++ LLM runner via pybind. + Expect improvements to the pybind interface in ExecuTorch version 0.4.1. + """ + self.device = torch.device("cpu") + if max_seq_len is None: + # Default to max_cache_size if max_seq_len is not specified + max_seq_len = self.max_cache_size + elif max_seq_len > self.max_cache_size: + logging.warning( + f"max_seq_len={max_seq_len} is larger than max_cache_size={self.max_cache_size}. Generating tokens will be truncated to max_cache_size." + ) + max_seq_len = self.max_cache_size + generated_tokens = [] + + # prefill + for i, prompt_token in enumerate(prompt_tokens): + logits = self.forward( + input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0), + cache_position=torch.tensor([i], dtype=torch.long, device=self.device), + ) + + next_token = torch.argmax(logits, dim=-1).item() + generated_tokens = prompt_tokens + [next_token] + + while len(generated_tokens) < max_seq_len: + logits = self.forward( + input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0), + cache_position=torch.tensor( + [pos_base + len(generated_tokens) - 1], + dtype=torch.long, + device=self.device, + ), + ) + next_token = torch.argmax(logits, dim=-1).item() + generated_tokens.append(next_token) + if next_token == self.eos_token_id: + break + + return generated_tokens if echo else generated_tokens[len(prompt_tokens) :] + + def text_generation( + self, + tokenizer: "PreTrainedTokenizer", + prompt: str, + echo: bool = True, + max_seq_len: Optional[int] = None, + ): + """ + Perform text generation task for a given prompt using the ExecuTorch model. + + Args: + tokenizer (`PreTrainedTokenizer`): + The tokenizer used to encode and decode the prompt and output. + prompt (`str`): + The text prompt to complete. + echo (`bool`, *optional*): + Whether to include prompt tokens in the generated output. Defaults to `True`. + max_seq_len (`int`, *optional*): + Maximum sequence length for the generated output. + Defaults to None and uses the model's `max_cache_size` attribute. + Will be truncated to maximal cache size if larger than `max_cache_size`. + """ + self.tokenizer = tokenizer + + # Sanity check + if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id: + raise ValueError( + f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}." + ) + if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id != self.eos_token_id: + raise ValueError( + f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must be the same as the model's eos_token_id={self.eos_token_id}." + ) + + prompt_tokens = self.tokenizer.encode(prompt) + generated_tokens = self.generate( + prompt_tokens=prompt_tokens, + echo=echo, + max_seq_len=max_seq_len, + ) + return self.tokenizer.decode(generated_tokens, skip_special_tokens=True) diff --git a/optimum/exporters/executorch/__init__.py b/optimum/exporters/executorch/__init__.py new file mode 100644 index 00000000000..cbdd2bfc0a9 --- /dev/null +++ b/optimum/exporters/executorch/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "convert": [ + "export_to_executorch", + ], + "recipe_registry": [ + "discover_recipes", + "register_recipe", + ], + "task_registry": [ + "discover_tasks", + "register_task", + ], + "__main__": ["main_export"], +} + +if TYPE_CHECKING: + from .__main__ import main_export + from .convert import export_to_executorch +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/optimum/exporters/executorch/__main__.py b/optimum/exporters/executorch/__main__.py new file mode 100644 index 00000000000..33a668b0674 --- /dev/null +++ b/optimum/exporters/executorch/__main__.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""Entry point to the optimum.exporters.executorch command line.""" + +import argparse +import os +import warnings +from pathlib import Path + +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from transformers.utils import is_torch_available + +from optimum.utils.import_utils import check_if_transformers_greater + +from ...commands.export.executorch import parse_args_executorch +from .convert import export_to_executorch +from .task_registry import discover_tasks, task_registry + + +if is_torch_available(): + pass + +from typing import Optional, Union + + +def main_export( + model_name_or_path: str, + task: str, + recipe: str, + output_dir: Union[str, Path], + cache_dir: str = HUGGINGFACE_HUB_CACHE, + trust_remote_code: bool = False, + pad_token_id: Optional[int] = None, + subfolder: str = "", + revision: str = "main", + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + **kwargs, +): + """ + Full-suite ExecuTorch export function, exporting **from a model ID on Hugging Face Hub or a local model repository**. + + Args: + model_name_or_path (`str`): + Model ID on huggingface.co or path on disk to the model repository to export. Example: `model_name_or_path="meta-llama/Llama-3.2-1B"` or `mode_name_or_path="/path/to/model_folder`. + task (`str`): + The task to export the model for, e.g. "text-generation". + recipe (`str`): + The recipe to use to do the export, e.g. "xnnpack". + output_dir (`Union[str, Path]`): + Path indicating the directory where to store the generated ExecuTorch model. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + trust_remote_code (`bool`, defaults to `False`): + Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories + you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the + model repository. + pad_token_id (`Optional[int]`, defaults to `None`): + This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + **kwargs: + Additional configuration options to tasks and recipes. + + Example usage: + ```python + >>> from optimum.exporters.executorch import main_export + + >>> main_export("meta-llama/Llama-3.2-1B", "text-generation", "xnnpack", "meta_llama3_2_1b/") + ``` + """ + + if not check_if_transformers_greater("4.46"): + raise ValueError( + "The minimum Transformers version compatible with ExecuTorch is 4.46.0. Please upgrade to Transformers 4.46.0 or later." + ) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + # Dynamically discover and import registered tasks + discover_tasks() + + # Load the model for specific task + try: + task_func = task_registry.get(task) + except KeyError as e: + raise RuntimeError(f"The task '{task}' isn't registered. Detailed error: {e}") + + model = task_func(model_name_or_path, **kwargs) + + if task == "text-generation": + from transformers.integrations.executorch import TorchExportableModuleWithStaticCache + + model = TorchExportableModuleWithStaticCache(model) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + return export_to_executorch( + model=model, + task=task, + recipe=recipe, + output_dir=output_dir, + **kwargs, + ) + + +def main(): + parser = argparse.ArgumentParser("Hugging Face Optimum ExecuTorch exporter") + + parse_args_executorch(parser) + + # Retrieve CLI arguments + args = parser.parse_args() + + main_export( + model_name_or_path=args.model, + output_dir=args.output_dir, + task=args.task, + recipe=args.recipe, + cache_dir=args.cache_dir, + trust_remote_code=args.trust_remote_code, + pad_token_id=args.pad_token_id, + ) + + +if __name__ == "__main__": + main() diff --git a/optimum/exporters/executorch/convert.py b/optimum/exporters/executorch/convert.py new file mode 100644 index 00000000000..f50a4b54a96 --- /dev/null +++ b/optimum/exporters/executorch/convert.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""ExecuTorch model check and export functions.""" + +import logging +import os +from pathlib import Path +from typing import Union + +from transformers.utils import is_torch_available + +from optimum.utils.import_utils import check_if_transformers_greater + +from .recipe_registry import discover_recipes, recipe_registry + + +if is_torch_available(): + from transformers.modeling_utils import PreTrainedModel + +if check_if_transformers_greater("4.46"): + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + ) + +logger = logging.getLogger(__name__) + + +def export_to_executorch( + model: Union["PreTrainedModel", "TorchExportableModuleWithStaticCache"], + task: str, + recipe: str, + output_dir: Union[str, Path], + **kwargs, +): + """ + Export a pre-trained PyTorch model to the ExecuTorch format using a specified recipe. + + This function facilitates the transformation of a PyTorch model into an optimized ExecuTorch program. + + Args: + model (`Union["PreTrainedModel", "TorchExportableModuleWithStaticCache"]`): + A PyTorch model to be exported. This can be a standard HuggingFace `PreTrainedModel` or a wrapped + module like `TorchExportableModuleWithStaticCache` for text generation task. + task (`str`): + The specific task the exported model will perform, e.g., "text-generation". + recipe (`str`): + The recipe to guide the export process, e.g., "xnnpack". Recipes define the optimization and lowering steps. + Will raise an exception if the specified recipe is not registered in the recipe registry. + output_dir (`Union[str, Path]`): + Path to the directory where the resulting ExecuTorch model will be saved. + **kwargs: + Additional configuration options passed to the recipe. + + Returns: + `ExecuTorchProgram`: + The lowered ExecuTorch program object. + + Notes: + - The function uses a dynamic recipe discovery mechanism to identify and import the specified recipe. + - The exported model is stored in the specified output directory with the fixed filename `model.pte`. + - The resulting ExecuTorch program is serialized and saved to the output directory. + """ + + # Dynamically discover and import registered recipes + discover_recipes() + + # Export and lower the model to ExecuTorch with the recipe + try: + recipe_func = recipe_registry.get(recipe) + except KeyError as e: + raise RuntimeError(f"The recipe '{recipe}' isn't registered. Detailed error: {e}") + + executorch_prog = recipe_func(model, task, **kwargs) + + full_path = os.path.join(f"{output_dir}", "model.pte") + with open(full_path, "wb") as f: + executorch_prog.write_to_file(f) + logging.info(f"Saved exported program to {full_path}") + + return executorch_prog diff --git a/optimum/exporters/executorch/recipe_registry.py b/optimum/exporters/executorch/recipe_registry.py new file mode 100644 index 00000000000..2eb728b7573 --- /dev/null +++ b/optimum/exporters/executorch/recipe_registry.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import importlib +import logging +import pkgutil + + +logger = logging.getLogger(__name__) + +recipe_registry = {} + +package_name = "optimum.exporters.executorch.recipes" + + +def register_recipe(recipe_name): + """ + Decorator to register a recipe for exporting and lowering an ExecuTorch model under a specific name. + + Args: + recipe_name (`str`): + The name of the recipe to associate with a callable recipe. + + Returns: + `Callable`: + The original function wrapped as a registered recipe. + + Example: + ```python + @register_recipe("my_new_recipe") + def my_new_recipe(...): + ... + ``` + """ + + def decorator(func): + recipe_registry[recipe_name] = func + return func + + return decorator + + +def discover_recipes(): + """ + Dynamically discovers and imports all recipe modules within the `optimum.exporters.executorch.recipes` package. + + Ensures recipes under `./recipes` directory are dynamically loaded without requiring manual imports. + + Notes: + New recipes **must** be added to the `./recipes` directory to be discovered and used by `main_export`. + Failure to do so will prevent dynamic discovery and registration. Recipes must also use the + `@register_recipe` decorator to be properly registered in the `recipe_registry`. + """ + package = importlib.import_module(package_name) + package_path = package.__path__ + + for _, module_name, _ in pkgutil.iter_modules(package_path): + logger.info(f"Importing {package_name}.{module_name}") + importlib.import_module(f"{package_name}.{module_name}") diff --git a/optimum/exporters/executorch/recipes/__init__.py b/optimum/exporters/executorch/recipes/__init__.py new file mode 100644 index 00000000000..30466c2d1a1 --- /dev/null +++ b/optimum/exporters/executorch/recipes/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. diff --git a/optimum/exporters/executorch/recipes/xnnpack.py b/optimum/exporters/executorch/recipes/xnnpack.py new file mode 100644 index 00000000000..d3b3a5d52aa --- /dev/null +++ b/optimum/exporters/executorch/recipes/xnnpack.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from typing import Union + +import torch +import torch.export._trace +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from torch.nn.attention import SDPBackend +from transformers import PreTrainedModel, TorchExportableModuleWithStaticCache + +from ..recipe_registry import register_recipe + + +@register_recipe("xnnpack") +def export_to_executorch_with_xnnpack( + model: Union[PreTrainedModel, TorchExportableModuleWithStaticCache], + task: str, + **kwargs, +): + """ + Export a PyTorch model to ExecuTorch w/ delegation to XNNPACK backend. + + This function also write metadata required by the ExecuTorch runtime to the model. + + Args: + model (Union[PreTrainedModel, TorchExportableModuleWithStaticCache]): + The PyTorch model to be exported to ExecuTorch. + task (str): + The task name to export the model for (e.g., "text-generation"). + **kwargs: + Additional keyword arguments for recipe-specific configurations. + + Returns: + ExecuTorchProgram: + The exported and optimized program for ExecuTorch. + """ + metadata = {} + if task == "text-generation": + example_input_ids = torch.tensor([[1]], dtype=torch.long) + example_cache_position = torch.tensor([0], dtype=torch.long) + + def _get_constant_methods(model: PreTrainedModel): + metadata = { + "get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6, + "get_bos_id": model.config.bos_token_id, + "get_eos_id": model.config.eos_token_id, + "get_head_dim": model.config.hidden_size / model.config.num_attention_heads, + "get_max_batch_size": model.generation_config.cache_config.batch_size, + "get_max_seq_len": model.generation_config.cache_config.max_cache_len, + "get_n_kv_heads": model.config.num_key_value_heads, + "get_n_layers": model.config.num_hidden_layers, + "get_vocab_size": model.config.vocab_size, + "use_kv_cache": model.generation_config.use_cache, + } + return {k: v for k, v in metadata.items() if v is not None} + + metadata = _get_constant_methods(model if isinstance(model, PreTrainedModel) else model.model) + else: + # TODO: Prepare model inputs for other tasks + raise ValueError(f"Unsupported task '{task}'.") + + with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): + exported_program = torch.export._trace._export( + model, + args=(example_input_ids,), + kwargs={"cache_position": example_cache_position}, + pre_dispatch=False, + strict=True, + ) + + return to_edge_transform_and_lower( + exported_program, + partitioner=[XnnpackPartitioner()], + compile_config=EdgeCompileConfig( + _skip_dim_order=True, + ), + constant_methods=metadata, + ).to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + ), + ) diff --git a/optimum/exporters/executorch/task_registry.py b/optimum/exporters/executorch/task_registry.py new file mode 100644 index 00000000000..fdc34f0359a --- /dev/null +++ b/optimum/exporters/executorch/task_registry.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import importlib +import logging +import pkgutil + + +logger = logging.getLogger(__name__) + +task_registry = {} + +package_name = "optimum.exporters.executorch.tasks" + + +def register_task(task_name): + """ + Decorator to register a task under a specific name. + + Args: + task_name (`str`): + The name of the task to associate with a callable task. + + Returns: + `Callable`: + The original function wrapped as a registered task. + + Example: + ```python + @register_task("my_new_task") + def my_new_task(...): + ... + ``` + """ + + def decorator(func): + task_registry[task_name] = func + return func + + return decorator + + +def discover_tasks(): + """ + Dynamically discovers and imports all task modules within the `optimum.exporters.executorch.tasks` package. + + Ensures tasks under `./tasks` directory are dynamically loaded without requiring manual imports. + + Notes: + New tasks **must** be added to the `./tasks` directory to be discovered and used by `main_export`. + Failure to do so will prevent dynamic discovery and registration. Tasks must also use the + `@register_task` decorator to be properly registered in the `task_registry`. + """ + package = importlib.import_module(package_name) + package_path = package.__path__ + + for _, module_name, _ in pkgutil.iter_modules(package_path): + logger.info(f"Importing {package_name}.{module_name}") + importlib.import_module(f"{package_name}.{module_name}") diff --git a/optimum/exporters/executorch/tasks/__init__.py b/optimum/exporters/executorch/tasks/__init__.py new file mode 100644 index 00000000000..30466c2d1a1 --- /dev/null +++ b/optimum/exporters/executorch/tasks/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. diff --git a/optimum/exporters/executorch/tasks/causal_lm.py b/optimum/exporters/executorch/tasks/causal_lm.py new file mode 100644 index 00000000000..b02da8b319e --- /dev/null +++ b/optimum/exporters/executorch/tasks/causal_lm.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from transformers import AutoModelForCausalLM, GenerationConfig + +from ..task_registry import register_task + + +@register_task("text-generation") +def load_causal_lm_model(model_name_or_path: str, **kwargs): + """ + Loads a causal language model for text generation and registers it under the task + 'text-generation' using Hugging Face's AutoModelForCausalLM. + + Args: + model_name_or_path (str): + Model ID on huggingface.co or path on disk to the model repository to export. For example: + `model_name_or_path="meta-llama/Llama-3.2-1B"` or `mode_name_or_path="/path/to/model_folder` + **kwargs: + Additional configuration options for the model: + - dtype (str, optional): + Data type for model weights (default: "float32"). + Options include "float16" and "bfloat16". + - attn_implementation (str, optional): + Attention mechanism implementation (default: "sdpa"). + - cache_implementation (str, optional): + Cache management strategy (default: "static"). + - max_length (int, optional): + Maximum sequence length for generation (default: 2048). + + Returns: + transformers.PreTrainedModel: + An instance of a model subclass (e.g., Llama, Gemma) with the configuration for exporting + and lowering to ExecuTorch. + """ + device = "cpu" + batch_size = 1 + dtype = kwargs.get("dtype", "float32") + attn_implementation = kwargs.get("attn_implementation", "sdpa") + cache_implementation = kwargs.get("cache_implementation", "static") + max_length = kwargs.get("max_length", 2048) + + return AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_length, + }, + ), + ) diff --git a/optimum/onnxruntime/runs/__init__.py b/optimum/onnxruntime/runs/__init__.py index 1d982949344..d21db2a4aca 100644 --- a/optimum/onnxruntime/runs/__init__.py +++ b/optimum/onnxruntime/runs/__init__.py @@ -110,9 +110,9 @@ def __init__(self, run_config): model_class = FeaturesManager.get_model_class_for_feature(get_autoclass_name(self.task)) self.torch_model = model_class.from_pretrained(run_config["model_name_or_path"]) - self.return_body[ - "model_type" - ] = self.torch_model.config.model_type # return_body is initialized in parent class + self.return_body["model_type"] = ( + self.torch_model.config.model_type + ) # return_body is initialized in parent class def _launch_time(self, trial): batch_size = trial.suggest_categorical("batch_size", self.batch_sizes) diff --git a/setup.py b/setup.py index 6736085943a..bb5bcc11d43 100644 --- a/setup.py +++ b/setup.py @@ -85,6 +85,10 @@ "datasets<=2.16", "transformers>=4.36,<4.38", ], + "exporters-executorch": [ + "executorch>=0.4.0", + "transformers>=4.46", + ], "diffusers": ["diffusers"], "intel": "optimum-intel>=1.18.0", "openvino": "optimum-intel[openvino]>=1.18.0", diff --git a/tests/executorch/export/__init__.py b/tests/executorch/export/__init__.py new file mode 100644 index 00000000000..fdc02578672 --- /dev/null +++ b/tests/executorch/export/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/executorch/export/test_exporters_executorch.py b/tests/executorch/export/test_exporters_executorch.py new file mode 100644 index 00000000000..a4521bc0183 --- /dev/null +++ b/tests/executorch/export/test_exporters_executorch.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import tempfile +import unittest + +import pytest +from transformers.testing_utils import slow + + +class TestExportToExecuTorchCLI(unittest.TestCase): + def test_helps_no_raise(self): + subprocess.run( + "optimum-cli export executorch --help", + shell=True, + check=True, + ) + + @slow + @pytest.mark.run_slow + def test_llama3_2_1b_export_to_executorch(self): + model_id = "meta-llama/Llama-3.2-1B" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_llama3_2_3b_export_to_executorch(self): + model_id = "meta-llama/Llama-3.2-3B" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_qwen2_5_export_to_executorch(self): + model_id = "Qwen/Qwen2.5-0.5B" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_gemma2_export_to_executorch(self): + model_id = "google/gemma-2-2b" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_gemma_export_to_executorch(self): + model_id = "google/gemma-2b" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_olmo_export_to_executorch(self): + model_id = "allenai/OLMo-1B-hf" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) diff --git a/tests/executorch/runtime/__init__.py b/tests/executorch/runtime/__init__.py new file mode 100644 index 00000000000..fdc02578672 --- /dev/null +++ b/tests/executorch/runtime/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/executorch/runtime/test_modeling.py b/tests/executorch/runtime/test_modeling.py new file mode 100644 index 00000000000..88caf81b6d5 --- /dev/null +++ b/tests/executorch/runtime/test_modeling.py @@ -0,0 +1,207 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import pytest +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from transformers import AutoTokenizer +from transformers.testing_utils import slow + +from optimum.executorchruntime import ExecuTorchModelForCausalLM + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @slow + @pytest.mark.run_slow + def test_load_model_from_hub(self): + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path="meta-llama/Llama-3.2-1B", + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + @slow + @pytest.mark.run_slow + def test_load_model_from_local_path(self): + from optimum.exporters.executorch import main_export + + model_id = "meta-llama/Llama-3.2-1B" + task = "text-generation" + recipe = "xnnpack" + + with tempfile.TemporaryDirectory() as tempdir: + # Export to a local dir + main_export( + model_name_or_path=model_id, + task=task, + recipe=recipe, + output_dir=tempdir, + ) + self.assertTrue(os.path.exists(f"{tempdir}/model.pte")) + + # Load the exported model from a local dir + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=tempdir, + export=False, + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + @slow + @pytest.mark.run_slow + def test_llama3_2_1b_text_generation_with_xnnpack(self): + model_id = "meta-llama/Llama-3.2-1B" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "Simply put, the theory of relativity states that the laws of physics are the same in all inertial frames of reference." + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Simply put, the theory of relativity states that", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) + + @slow + @pytest.mark.run_slow + def test_llama3_2_3b_text_generation_with_xnnpack(self): + model_id = "meta-llama/Llama-3.2-3B" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = ( + "Simply put, the theory of relativity states that the speed of light is constant. This " + "means that no matter how fast you are traveling, the speed of light will always be " + "186,000 miles per second." + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Simply put, the theory of relativity states that", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) + + @slow + @pytest.mark.run_slow + def test_qwen2_5_text_generation_with_xnnpack(self): + model_id = "Qwen/Qwen2.5-0.5B" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "My favourite condiment is iced tea. I love it with my breakfast, my lunch" + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="My favourite condiment is ", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) + + @slow + @pytest.mark.run_slow + def test_gemma2_text_generation_with_xnnpack(self): + model_id = "google/gemma-2-2b" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "Hello I am doing a project for my school. I need help with my science homework" + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Hello I am doing a project for my school", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) + + @slow + @pytest.mark.run_slow + def test_gemma_text_generation_with_xnnpack(self): + model_id = "google/gemma-2b" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "Hello I am doing a project for my school and I need to make a 3D model of a car." + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Hello I am doing a project for my school", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) + + @slow + @pytest.mark.run_slow + def test_olmo_text_generation_with_xnnpack(self): + model_id = "allenai/OLMo-1B-hf" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = ( + "Simply put, the theory of relativity states that the speed of light is the same in all directions." + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Simply put, the theory of relativity states that", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT)