Export to ExecuTorch: Initial Integration
Guang Yang committed Dec 11, 2024
1 parent 4a7cb29 commit 154f7eb
Showing 29 changed files with 1,898 additions and 5 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/test_executorch_export.yml
@@ -0,0 +1,35 @@
name: ExecuTorch Export / Python - Test

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  build:
    strategy:
      fail-fast: false
      matrix:
        python-version: ['3.10', '3.11', '3.12']
        os: [ubuntu-20.04, macos-15]

    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2
      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies for ExecuTorch
        run: |
          pip install .[tests,exporters-executorch]
          pip list
      - name: Run tests
        working-directory: tests
        run: |
          RUN_SLOW=1 pytest executorch/export/test_*.py -s -vvvv --durations=0
35 changes: 35 additions & 0 deletions .github/workflows/test_executorch_runtime.yml
@@ -0,0 +1,35 @@
name: ExecuTorch Runtime / Python - Test

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  build:
    strategy:
      fail-fast: false
      matrix:
        python-version: ['3.10', '3.11', '3.12']
        os: [ubuntu-20.04, macos-15]

    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2
      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies for ExecuTorch
        run: |
          pip install .[tests,exporters-executorch]
          pip list
      - name: Run tests
        working-directory: tests
        run: |
          RUN_SLOW=1 pytest executorch/runtime/test_*.py -s -vvvv --durations=0
26 changes: 26 additions & 0 deletions docs/source/exporters/executorch/overview.mdx
@@ -0,0 +1,26 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Overview

🤗 Optimum handles the export of PyTorch models to ExecuTorch in the `exporters.executorch` module. It provides classes, functions, and a command line interface to perform the export easily.
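For example, a model can be exported from the command line in a single step (the usage guides cover this in detail):

```bash
optimum-cli export executorch --model "meta-llama/Llama-3.2-1B" --task "text-generation" --recipe "xnnpack" --output_dir "meta_llama3_2_1b"
```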

Supported architectures from [🤗 Transformers](https://huggingface.co/docs/transformers/index):

- Gemma
- Gemma2
- Llama2
- Llama3 (Llama3.2)
- OLMo
- Qwen2 (Qwen2.5)

Many more models are supported by ExecuTorch; we will add them to Optimum over time. Read more at [pytorch/executorch/examples/](https://github.com/pytorch/executorch/tree/main/examples).
54 changes: 54 additions & 0 deletions docs/source/exporters/executorch/package_reference/configuration.mdx
@@ -0,0 +1,54 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Configuration for ExecuTorch Export

ExecuTorch export provides a flexible configuration mechanism through dynamic registration, enabling users to have
complete control over the export process. The configuration system is divided into task configurations and recipe
configurations, each addressing specific aspects of the export pipeline.


## Task Configurations

Task configurations determine how a Hugging Face model should be loaded and prepared for export, tailored to specific tasks.

For instance, when exporting a model for a text generation task, the provided configuration utilizes **static caching** and
**SDPA (Scaled Dot-Product Attention)** for inference optimization.

By leveraging task configurations, users can ensure that their models are appropriately prepared for efficient execution on
the ExecuTorch backend.

[[autodoc]] exporters.executorch.task_registry.discover_tasks

[[autodoc]] exporters.executorch.task_registry.register_task

[[autodoc]] exporters.executorch.tasks.causal_lm.load_causal_lm_model
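
As a rough illustration, a text-generation task configuration amounts to a loader along the following lines. This is a simplified sketch of the idea, not the actual `load_causal_lm_model` implementation; the function name and arguments below are assumptions:

```python
from transformers import AutoModelForCausalLM

def load_causal_lm_for_export(model_name_or_path, **kwargs):
    # Hypothetical loader: fetch the eager model with SDPA attention enabled.
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        attn_implementation="sdpa",
        **kwargs,
    )
    # Use a static KV cache so tensor shapes stay fixed during generation,
    # which suits ahead-of-time export for inference.
    model.generation_config.cache_implementation = "static"
    return model
```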


## Recipe Configurations

Recipe configurations control the specifics of lowering an eager PyTorch module to the ExecuTorch backend. These
configurations allow users to:

- Specify whether and how to **quantize** the model.
- Delegate computation to various accelerators, such as **CPU**, **GPU**, **NPU**, **DSP**, and others.
- Define **custom transformation passes**.
- Implement advanced techniques like memory planning algorithms to optimize resource utilization.

[[autodoc]] exporters.executorch.recipe_registry.discover_recipes

[[autodoc]] exporters.executorch.recipe_registry.register_recipe

[[autodoc]] exporters.executorch.recipes.xnnpack.export_to_executorch_with_xnnpack
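
For orientation, the heart of an XNNPACK-style recipe follows the standard ExecuTorch lowering flow. The sketch below shows that general pattern; it is not the exact code of `export_to_executorch_with_xnnpack`:

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge

def lower_with_xnnpack(model, example_inputs):
    # 1. Capture the eager module as an ExportedProgram
    #    (example_inputs is a tuple of sample inputs).
    exported_program = torch.export.export(model, example_inputs)
    # 2. Convert to the Edge dialect and delegate supported subgraphs to XNNPACK.
    edge_program = to_edge(exported_program).to_backend(XnnpackPartitioner())
    # 3. Serialize into an ExecuTorch program, the payload of a .pte file.
    return edge_program.to_executorch()
```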

The combination of task and recipe configurations ensures that users can customize both the high-level task setup
and the low-level export details to suit their deployment requirements.
26 changes: 26 additions & 0 deletions docs/source/exporters/executorch/package_reference/export.mdx
@@ -0,0 +1,26 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Export functions

## Main functions

[[autodoc]] exporters.executorch.convert.export_to_executorch

The primary export function is designed to be **model- and task-independent** as well as **optimization-agnostic**, providing a
highly flexible and modular interface for exporting Hugging Face models to the ExecuTorch backend.

This approach highlights the **composability** of the ExecuTorch export pipeline, where dynamically registered **task configurations**
specify how a 🤗 model is prepared, and **recipe configurations** encapsulate device-specific optimizations during export. This
separation allows users to customize the export process without altering the core function.

For more details on task and recipe configurations, see the [Configuration for ExecuTorch Export](./configuration.mdx).
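
As a rough illustration, a direct call could look like the following sketch. The import path and argument names are assumptions; refer to the autodoc signature above for the actual interface:

```python
# Sketch only: the import path and argument names are assumptions.
from optimum.exporters.executorch import export_to_executorch

export_to_executorch(
    model="meta-llama/Llama-3.2-1B",  # model ID on the Hub or a local path
    task="text-generation",           # selects a registered task configuration
    recipe="xnnpack",                 # selects a registered recipe configuration
    output_dir="meta_llama3_2_1b",    # directory where model.pte is written
)
```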
57 changes: 57 additions & 0 deletions docs/source/exporters/executorch/usage_guides/contribute.mdx
@@ -0,0 +1,57 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Adding support for an unsupported architecture

We welcome contributions to extend the functionality of ExecuTorch export. This guide provides high-level instructions for contributors who want to:

1. Export a new model that is not currently supported.
2. Add new recipes or support a new task for export.

---

## Exporting a New Model

If you want to export a model that is not already supported by the library, follow these steps:

### Step 1: Export and Test the Model
1. Attempt to export and lower the model using an existing task and recipe. On success, the exported model is stored in a `.pte` file.
2. Add a test case for the model in the appropriate test suite.
- For example, you can make sure tests pass for the new `my_new_model` by running:
```bash
pytest tests/executorch/export/test_*.py -k "test_my_new_model" # doctest: +SKIP
pytest tests/executorch/runtime/test_*.py -k "test_my_new_model" # doctest: +SKIP
```

### Step 2: Handle Export Failures
1. If the export fails in Step 1, open a GitHub issue to report it.
2. If the issue requires changes to the model’s architecture or its Hugging Face implementation, these modifications may be made upstream in the Hugging Face Transformers library.

---

## Adding New Recipes or Tasks

To extend ExecuTorch with new recipes or tasks, follow these guidelines:

### Registering a New Recipe
You can add a custom recipe to define specific optimizations or configurations for exporting models. Below is an example:

```python
from optimum.exporters.executorch import register_recipe

@register_recipe("my_custom_recipe")
def export_with_custom_recipe(model, config, *args, **kwargs):
    # Example: apply a custom quantization scheme, then lower the model to
    # an ExecuTorch program and return it.
    ...
```

### Registering a Task
The task registration process is the same as adding a recipe, as sketched below. In addition, you may need to implement a new `ExecuTorchModelForXXX` class.
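
Below is a hypothetical sketch mirroring the recipe example above. The import path and loader signature are assumptions; see `task_registry.register_task` in the package reference:

```python
# Hypothetical sketch: the import path and signature are assumptions.
from optimum.exporters.executorch import register_task
from transformers import AutoModelForImageClassification

@register_task("image-classification")
def load_image_classification_model(model_name_or_path, **kwargs):
    # Load and prepare the eager model so a recipe can export and lower it.
    return AutoModelForImageClassification.from_pretrained(model_name_or_path, **kwargs)
```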
124 changes: 124 additions & 0 deletions docs/source/exporters/executorch/usage_guides/export_a_model.mdx
@@ -0,0 +1,124 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Export a model to ExecuTorch with optimum.exporters.executorch

If you need to deploy 🤗 Transformers models for on-device use cases, we recommend
exporting them to a serialized format that can be distributed and executed on specialized
runtimes and hardware. In this guide, we'll show you how to export these
models to [ExecuTorch](https://pytorch.org/executorch/main/intro-overview.html).


## Why ExecuTorch?

ExecuTorch is the ideal solution for deploying PyTorch models on edge devices, offering a streamlined process from
export to deployment without leaving the PyTorch ecosystem.

Supporting on-device AI presents unique challenges: diverse hardware, tight power budgets, low or no internet
connectivity, and real-time processing needs. These constraints have historically prevented or slowed down the creation
of scalable, performant on-device AI solutions. ExecuTorch, backed by industry partners such as Meta, Arm, Apple,
Qualcomm, and MediaTek, is designed to be highly portable and to provide strong developer productivity without
sacrificing performance.


## Summary

Exporting a PyTorch model to ExecuTorch is as simple as running:

```bash
optimum-cli export executorch --model "meta-llama/Llama-3.2-1B" --task "text-generation" --recipe "xnnpack" --output_dir "meta_llama3_2_1b"
```

Check out the help for more options:

```bash
optimum-cli export executorch --help
```


## Exporting a model to ExecuTorch using the CLI

To export a 🤗 Transformers model to ExecuTorch, you'll first need to install some extra
dependencies:

```bash
pip install optimum[exporters-executorch]
```

The Optimum ExecuTorch export can be used through Optimum command-line:

```bash
optimum-cli export executorch --help

usage: optimum-cli export executorch [-h] -m MODEL [-o OUTPUT_DIR] [--task TASK] [--recipe RECIPE]

options:
  -h, --help            show this help message and exit

Required arguments:
  -m MODEL, --model MODEL
                        Model ID on huggingface.co or path on disk to load model from.
  -o OUTPUT_DIR, --output_dir OUTPUT_DIR
                        Path indicating the directory where to store the generated ExecuTorch model.
  --task TASK           The task to export the model for. Available tasks depend on the model, but are among: ['audio-classification', 'feature-extraction',
                        'image-to-text', 'sentence-similarity', 'depth-estimation', 'image-segmentation', 'audio-frame-classification', 'masked-im',
                        'semantic-segmentation', 'text-classification', 'audio-xvector', 'mask-generation', 'question-answering', 'text-to-audio',
                        'automatic-speech-recognition', 'image-to-image', 'multiple-choice', 'image-classification', 'text2text-generation',
                        'token-classification', 'object-detection', 'zero-shot-object-detection', 'zero-shot-image-classification', 'text-generation',
                        'fill-mask'].
  --recipe RECIPE       Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".

```

Exporting a checkpoint can be done as follows:

```bash
optimum-cli export executorch --model "meta-llama/Llama-3.2-1B" --task "text-generation" --recipe "xnnpack" --output_dir "meta_llama3_2_1b"
```

You should see the `model.pte` file stored under `./meta_llama3_2_1b/`:

```bash
meta_llama3_2_1b/
└── model.pte
```

This fetches the model from the Hub and exports it with the specified recipe. The resulting `model.pte` file can then be run on the [XNNPACK backend](https://pytorch.org/executorch/main/tutorial-xnnpack-delegate-lowering.html), or on many
other ExecuTorch-supported backends when exported with different recipes, e.g. Apple's [Core ML](https://pytorch.org/executorch/main/build-run-coreml.html) or [MPS](https://pytorch.org/executorch/main/build-run-mps.html), [Qualcomm's SoCs](https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html), [ARM's Ethos-U](https://pytorch.org/executorch/main/executorch-arm-delegate-tutorial.html), [Xtensa HiFi4 DSP](https://pytorch.org/executorch/main/build-run-xtensa.html), [Vulkan GPU](https://pytorch.org/executorch/main/build-run-vulkan.html), [MediaTek](https://pytorch.org/executorch/main/build-run-mediatek-backend.html), etc.

For example, we can load and run the model with [ExecuTorch
Runtime](https://pytorch.org/executorch/main/runtime-overview.html) using the `optimum.executorchruntime` package as follows:

```python
>>> from transformers import AutoTokenizer
>>> from optimum.executorchruntime import ExecuTorchModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") # doctest: +SKIP
>>> model = ExecuTorchModelForCausalLM.from_pretrained("meta_llama3_2_1b/", export=False) # doctest: +SKIP
>>> generated_text = model.text_generation(tokenizer=tokenizer, prompt="Simply put, the theory of relativity states that", max_seq_len=45) # doctest: +SKIP
```
Printing `generated_text` gives:
```
"Simply put, the theory of relativity states that the laws of physics are the same in all inertial frames of reference. In other words, the laws of physics are the same in all inertial frames of reference."
```
As you can see, converting a model to ExecuTorch does not mean leaving the Hugging Face ecosystem: you end up with an API similar to that of regular 🤗 Transformers models!
It is also possible to export the model to ExecuTorch directly from the `ExecuTorchModelForCausalLM` class by doing the following:
```python
>>> from optimum.executorchruntime import ExecuTorchModelForCausalLM
>>> model = ExecuTorchModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", export=True, task="text-generation", recipe="xnnpack")
```
2 changes: 1 addition & 1 deletion docs/source/exporters/overview.mdx
@@ -12,4 +12,4 @@ specific language governing permissions and limitations under the License.

# Overview

🤗 Optimum enables exporting models from PyTorch or TensorFlow to different formats through its `exporters` module. For now, two exporting format are supported: ONNX and TFLite (TensorFlow Lite).
🤗 Optimum enables exporting models from PyTorch or TensorFlow to different formats through its `exporters` module. For now, three export formats are supported: ONNX, TFLite (TensorFlow Lite), and ExecuTorch.
2 changes: 1 addition & 1 deletion optimum/commands/__init__.py
@@ -14,5 +14,5 @@

from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand
from .env import EnvironmentCommand
from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand
from .export import ExecuTorchExportCommand, ExportCommand, ONNXExportCommand, TFLiteExportCommand
from .optimum_cli import optimum_cli_subcommand
1 change: 1 addition & 0 deletions optimum/commands/export/__init__.py
@@ -14,5 +14,6 @@


from .base import ExportCommand
from .executorch import ExecuTorchExportCommand
from .onnx import ONNXExportCommand
from .tflite import TFLiteExportCommand