From e56192506ea6a5a7fa68b3923d1c8823ed73f82e Mon Sep 17 00:00:00 2001 From: Guang Yang Date: Tue, 5 Nov 2024 16:35:48 -0800 Subject: [PATCH] Export to ExecuTorch: Initial Integration --- .github/workflows/test_executorch_runtime.yml | 37 ++ .github/workflows/test_export_executorch.yml | 37 ++ optimum/commands/__init__.py | 2 +- optimum/commands/export/__init__.py | 1 + optimum/commands/export/base.py | 6 + optimum/commands/export/executorch.py | 55 +++ optimum/executorchruntime/__init__.py | 16 + .../executorchruntime/modeling_executorch.py | 456 ++++++++++++++++++ optimum/exporters/executorch/__init__.py | 24 + optimum/exporters/executorch/__main__.py | 148 ++++++ optimum/exporters/executorch/convert.py | 73 +++ .../exporters/executorch/recipe_registry.py | 56 +++ .../exporters/executorch/recipes/__init__.py | 0 .../exporters/executorch/recipes/xnnpack.py | 68 +++ optimum/exporters/executorch/task_registry.py | 56 +++ .../exporters/executorch/tasks/__init__.py | 0 .../exporters/executorch/tasks/causal_lm.py | 54 +++ optimum/onnxruntime/runs/__init__.py | 6 +- setup.py | 4 + tests/executorch/test_modeling.py | 49 ++ tests/exporters/executorch/__init__.py | 0 .../test_exporters_executorch_cli.py | 26 + 22 files changed, 1170 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/test_executorch_runtime.yml create mode 100644 .github/workflows/test_export_executorch.yml create mode 100644 optimum/commands/export/executorch.py create mode 100644 optimum/executorchruntime/__init__.py create mode 100644 optimum/executorchruntime/modeling_executorch.py create mode 100644 optimum/exporters/executorch/__init__.py create mode 100644 optimum/exporters/executorch/__main__.py create mode 100644 optimum/exporters/executorch/convert.py create mode 100644 optimum/exporters/executorch/recipe_registry.py create mode 100644 optimum/exporters/executorch/recipes/__init__.py create mode 100644 optimum/exporters/executorch/recipes/xnnpack.py create mode 100644 optimum/exporters/executorch/task_registry.py create mode 100644 optimum/exporters/executorch/tasks/__init__.py create mode 100644 optimum/exporters/executorch/tasks/causal_lm.py create mode 100644 tests/executorch/test_modeling.py create mode 100644 tests/exporters/executorch/__init__.py create mode 100644 tests/exporters/executorch/test_exporters_executorch_cli.py diff --git a/.github/workflows/test_executorch_runtime.yml b/.github/workflows/test_executorch_runtime.yml new file mode 100644 index 00000000000..4d99c74fd8f --- /dev/null +++ b/.github/workflows/test_executorch_runtime.yml @@ -0,0 +1,37 @@ +name: ExecuTorch Runtime / Python - Test + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11', '3.12'] + os: [ubuntu-20.04, macos-15] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for ExecuTorch + run: | + pip install .[tests,exporters-executorch] + python -c "import executorch; print(executorch.__version__)" + python -c "import torch; print(torch.__version__)" + python -c "import transformers; print(transformers.__version__)" + - name: Run tests + working-directory: tests + run: | + pytest tests/executorch/test_*.py -s -vvvv --durations=0 diff --git a/.github/workflows/test_export_executorch.yml b/.github/workflows/test_export_executorch.yml new file mode 100644 index 00000000000..3b300de785d --- /dev/null +++ b/.github/workflows/test_export_executorch.yml @@ -0,0 +1,37 @@ +name: Exporters ExecuTorch / Python - Test + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11', '3.12'] + os: [ubuntu-20.04, macos-15] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for ExecuTorch + run: | + pip install .[tests,exporters-executorch] + python -c "import executorch; print(executorch.__version__)" + python -c "import torch; print(torch.__version__)" + python -c "import transformers; print(transformers.__version__)" + - name: Run tests + working-directory: tests + run: | + pytest tests/exporters/executorch/test_*.py -s -vvvv --durations=0 diff --git a/optimum/commands/__init__.py b/optimum/commands/__init__.py index 8a2a276d1c5..a31344ed133 100644 --- a/optimum/commands/__init__.py +++ b/optimum/commands/__init__.py @@ -14,5 +14,5 @@ from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand from .env import EnvironmentCommand -from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand +from .export import ExecuTorchExportCommand, ExportCommand, ONNXExportCommand, TFLiteExportCommand from .optimum_cli import optimum_cli_subcommand diff --git a/optimum/commands/export/__init__.py b/optimum/commands/export/__init__.py index 19da68a60d2..b72cd5dbc8d 100644 --- a/optimum/commands/export/__init__.py +++ b/optimum/commands/export/__init__.py @@ -14,5 +14,6 @@ from .base import ExportCommand +from .executorch import ExecuTorchExportCommand from .onnx import ONNXExportCommand from .tflite import TFLiteExportCommand diff --git a/optimum/commands/export/base.py b/optimum/commands/export/base.py index 07737cb8eaf..e5ed4c90ff5 100644 --- a/optimum/commands/export/base.py +++ b/optimum/commands/export/base.py @@ -15,6 +15,7 @@ """optimum.exporters command-line interface base classes.""" from .. import BaseOptimumCLICommand, CommandInfo +from .executorch import ExecuTorchExportCommand from .onnx import ONNXExportCommand from .tflite import TFLiteExportCommand @@ -25,6 +26,11 @@ class ExportCommand(BaseOptimumCLICommand): help="Export PyTorch and TensorFlow models to several format.", ) SUBCOMMANDS = ( + CommandInfo( + name="executorch", + help="Export PyTorch model to ExecuTorch.", + subcommand_class=ExecuTorchExportCommand, + ), CommandInfo( name="onnx", help="Export PyTorch and TensorFlow to ONNX.", diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py new file mode 100644 index 00000000000..7c0dab1f14f --- /dev/null +++ b/optimum/commands/export/executorch.py @@ -0,0 +1,55 @@ +"""Defines the command line for the export with ExecuTorch.""" + +from pathlib import Path +from typing import TYPE_CHECKING + +from ...exporters import TasksManager +from ..base import BaseOptimumCLICommand + + +if TYPE_CHECKING: + from argparse import ArgumentParser + + +def parse_args_executorch(parser): + required_group = parser.add_argument_group("Required arguments") + required_group.add_argument( + "-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from." + ) + required_group.add_argument( + "-o", + "--output_dir", + type=Path, + help="Path indicating the directory where to store the generated ExecuTorch model.", + ) + required_group.add_argument( + "--task", + type=str, + default="text-generation", + help=( + "The task to export the model for. Available tasks depend on the model, but are among:" + f" {str(TasksManager.get_all_tasks())}." + ), + ) + required_group.add_argument( + "--recipe", + type=str, + default="xnnpack", + help='Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".', + ) + + +class ExecuTorchExportCommand(BaseOptimumCLICommand): + @staticmethod + def parse_args(parser: "ArgumentParser"): + return parse_args_executorch(parser) + + def run(self): + from ...exporters.executorch import main_export + + main_export( + model_name_or_path=self.args.model, + task=self.args.task, + recipe=self.args.recipe, + output_dir=self.args.output_dir, + ) diff --git a/optimum/executorchruntime/__init__.py b/optimum/executorchruntime/__init__.py new file mode 100644 index 00000000000..3b9eec7fed5 --- /dev/null +++ b/optimum/executorchruntime/__init__.py @@ -0,0 +1,16 @@ +from typing import TYPE_CHECKING +from transformers.utils import _LazyModule + + +_import_structure = { + "modeling_executorch": [ + "ExecuTorchModelForCausalLM", + ], +} + +if TYPE_CHECKING: + from .modeling_executorch import ExecuTorchModelForCausalLM +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/optimum/executorchruntime/modeling_executorch.py b/optimum/executorchruntime/modeling_executorch.py new file mode 100644 index 00000000000..6fb03780bbf --- /dev/null +++ b/optimum/executorchruntime/modeling_executorch.py @@ -0,0 +1,456 @@ +"""ExecuTorchModelForXXX classes, allowing to run ExecuTorch Models with ExecuTorch Runtime using the same API as Transformers.""" + +import logging +import os +import warnings +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union + +import torch +from executorch.extension.pybindings.portable_lib import _load_for_executorch +from huggingface_hub import hf_hub_download +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from huggingface_hub.utils import EntryNotFoundError +from transformers import ( + AutoConfig, + AutoModel, + GenerationMixin, + AutoModelForCausalLM, + GenerationConfig, + PretrainedConfig, +) +from transformers.integrations.executorch import TorchExportableModuleWithStaticCache +from transformers.modeling_outputs import ( + BaseModelOutput, + CausalLMOutput, + CausalLMOutputWithPast, + ModelOutput, +) + +from ..exporters import TasksManager +from ..exporters.executorch import main_export +from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + +logger = logging.getLogger(__name__) + + +class ExecuTorchModelForCausalLM(OptimizedModel): + """ + ExecuTorch model with a causal language modeling head for inference using the ExecuTorch Runtime. + + This class provides an interface for loading, running, and generating outputs from a causal language model + optimized for ExecuTorch Runtime. It includes utilities for exporting and loading pre-trained models + compatible with ExecuTorch runtime. + + Attributes: + auto_model_class (`Type`): + Associated Transformers class, `AutoModelForCausalLM`. + et_model (`ExecuTorchModule`): + The loaded ExecuTorch model. + use_kv_cache (`bool`): + Whether key-value caching is enabled. For performance reasons, the exported model is + optimized to use a static cache. + max_cache_size (`int`): + Maximum sequence length supported by the cache. + max_batch_size (`int`): + Maximum supported batch size. + dtype (`str`): + Data type of the model parameters. + bos_token_id (`int`): + Beginning-of-sequence token ID. + eos_token_id (`int`): + End-of-sequence token ID. + vocab_size (`int`): + Size of the model vocabulary. + """ + + auto_model_class = AutoModelForCausalLM + + def __init__( + self, + model: "ExecuTorchModule", + config: "PretrainedConfig", + ): + super().__init__(model, config) + self.et_model = model + logger.debug(f"Load all static methods: {self.et_model.method_names()}") + self.use_kv_cache = self.et_model.run_method("use_kv_cache")[0] + self.max_cache_size = self.et_model.run_method("get_max_seq_len")[0] + self.max_batch_size = self.et_model.run_method("get_max_batch_size")[0] + self.dtype = self.et_model.run_method("get_dtype")[0] + self.bos_token_id = self.et_model.run_method("get_bos_id")[0] + self.eos_token_id = self.et_model.run_method("get_eos_id")[0] + self.vocab_size = self.et_model.run_method("get_vocab_size")[0] + + def forward( + self, + input_ids: torch.Tensor, + cache_position: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the model, which is compatible with the ExecuTorch runtime for LLM. + + Args: + input_ids (`torch.Tensor`): Tensor representing current input token id to the model. + cache_position (`torch.Tensor`): Tensor representing current input position in the cache. + + Returns: + torch.Tensor: Logits output from the model. + """ + return self.et_model.forward((input_ids, cache_position))[0] + + @classmethod + def from_pretrained( + cls, + model_name_or_path: Union[str, Path], + task: str, + recipe: str, + export: bool = False, + config: "PretrainedConfig" = None, + subfolder: str = "", + revision: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + **kwargs, + ) -> "ExecuTorchModelForCausalLM": + """ + Load a pre-trained ExecuTorch model. + + Args: + model_name_or_path (`Union[str, Path]`): + Model ID on huggingface.co or path on disk to the model repository to export. Example: `model_name_or_path="meta-llama/Llama-3.2-1B"` or `mode_name_or_path="/path/to/model_folder`. + task (`str`): + The task to export the model for, e.g. "text-generation". + recipe (`str`): + The recipe to use to do the export, e.g. "xnnpack". + export (`bool`, *optional*, defaults to `False`): + If `True`, the model will be exported from eager to ExecuTorch after fetched from huggingface.co. + If `False`, the previously exported ExecuTorch model will be loaded from a local path. + config (`PretrainedConfig`, *optional*): + Configuration of the pre-trained model. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + **kwargs: + Additional configuration options to tasks and recipes. + + Returns: + `ExecuTorchModelForCausalLM`: An instance of the ExecuTorch model for text generation task. + """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + if export: + return cls._export( + model_id=model_name_or_path, + task=task, + recipe=recipe, + config=config, + **kwargs, + ) + else: + return cls._from_pretrained( + model_dir_path=model_name_or_path, + task=task, + recipe=recipe, + config=config, + ) + + @classmethod + def _from_pretrained( + cls, + model_dir_path: Union[str, Path], + task: str, + recipe: str, + config: PretrainedConfig, + subfolder: str = "", + revision: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + ) -> "ExecuTorchModelForCausalLM": + """ + Load a pre-trained ExecuTorch model from a local directory. + + Args: + model_dir_path (`Union[str, Path]`): + Path to the directory containing the ExecuTorch model file (`model.pte`). + task (`str`): + The task to export the model for, e.g. "text-generation". + recipe (`str`): + The recipe to use to do the export, e.g. "xnnpack". + config (`PretrainedConfig`, *optional*): + Configuration of the pre-trained model. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + + Returns: + `ExecuTorchModelForCausalLM`: The initialized ExecuTorch model. + + """ + full_path = os.path.join(f"{model_dir_path}", "model.pte") + model = _load_for_executorch(full_path) + logging.debug(f"{model.method_meta('forward')}") + return cls( + model=model, + config=config, + ) + + def _save_pretrained(self, save_directory): + """ + Saves a model weights into a directory, so that it can be re-loaded using the + [`from_pretrained`] class method. + """ + raise NotImplementedError + + @classmethod + def _export( + cls, + model_id: str, + task: str, + recipe: str, + config: PretrainedConfig, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + trust_remote_code: bool = False, + subfolder: str = "", + revision: Optional[str] = None, + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + **kwargs, + ): + """ + Fetch a model from the Hugging Face Hub and export it to ExecuTorch format. + + Args: + model_id (`str`): + Model ID on huggingface.co, for example: `model_name_or_path="meta-llama/Llama-3.2-1B"`. + task (`str`): + The task to export the model for, e.g. "text-generation". + recipe (`str`): + The recipe to use to do the export, e.g. "xnnpack". + config (`PretrainedConfig`, *optional*): + Configuration of the pre-trained model. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + trust_remote_code (`bool`, defaults to `False`): + Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories + you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the + model repository. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + **kwargs: + Additional configuration options to tasks and recipes. + + Returns: + `ExecuTorchModelForCausalLM`: The loaded and exported ExecuTorch model. + + """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + + # Export to ExecuTorch and save the pte file to the temporary directory + main_export( + model_name_or_path=model_id, + output_dir=save_dir_path, + task=task, + recipe=recipe, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + return cls._from_pretrained( + model_dir_path=save_dir_path, + task=task, + recipe=recipe, + config=config, + use_auth_token=use_auth_token, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + ) + + def generate( + self, + prompt_tokens: List[int], + echo: bool = False, + pos_base: int = 0, + max_seq_len: Optional[int] = None, + ) -> List[int]: + """ + Generate tokens from a prompt using the ExecuTorch model. + + Args: + prompt_tokens (List[int]): + List of token IDs representing the prompt. + echo (`bool`, *optional*): + Whether to include prompt tokens in the generated output. Defaults to `False`. + pos_base (`int`, *optional*): + Base position for the prompt tokens. Defaults to 0. + max_seq_len (`int`, *optional*): + Maximum sequence length for the generated output. + Defaults to None and uses the model's `max_cache_size` attribute. + Will be truncated to maximal cache size if larger than `max_cache_size`. + + Returns: + List[int]: List of generated token IDs. + + Note: + Temporarily implemented this method in Python due to limited access to ExecuTorch's c++ LLM runner via pybind. + Expect improvements to the pybind interface in ExecuTorch version 0.4.1. + """ + self.device = torch.device("cpu") + if max_seq_len is None: + # Default to max_cache_size if max_seq_len is not specified + max_seq_len = self.max_cache_size + elif max_seq_len > self.max_cache_size: + logger.warning( + f"max_seq_len={max_seq_len} is larger than max_cache_size={self.max_cache_size}. Generating tokens will be truncated to max_cache_size." + ) + max_seq_len = self.max_cache_size + generated_tokens = [] + + # prefill + for i, prompt_token in enumerate(prompt_tokens): + logits = self.forward( + input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0), + cache_position=torch.tensor([i], dtype=torch.long, device=self.device), + ) + + next_token = torch.argmax(logits, dim=-1).item() + generated_tokens = prompt_tokens + [next_token] + + while len(generated_tokens) < max_seq_len: + logits = self.forward( + input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0), + cache_position=torch.tensor( + [pos_base + len(generated_tokens) - 1], + dtype=torch.long, + device=self.device, + ), + ) + next_token = torch.argmax(logits, dim=-1).item() + generated_tokens.append(next_token) + if next_token == self.eos_token_id: + break + + return generated_tokens if echo else generated_tokens[len(prompt_tokens) :] + + def text_generation( + self, + tokenizer: "PreTrainedTokenizer", + prompt: str, + echo: bool = True, + max_seq_len: Optional[int] = None, + ): + """ + Perform text generation task for a given prompt using the ExecuTorch model. + + Args: + tokenizer (`PreTrainedTokenizer`): + The tokenizer used to encode and decode the prompt and output. + prompt (`str`): + The text prompt to complete. + echo (`bool`, *optional*): + Whether to include prompt tokens in the generated output. Defaults to `True`. + max_seq_len (`int`, *optional*): + Maximum sequence length for the generated output. + Defaults to None and uses the model's `max_cache_size` attribute. + Will be truncated to maximal cache size if larger than `max_cache_size`. + """ + self.tokenizer = tokenizer + + # Sanity check + if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id: + raise ValueError( + f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}." + ) + if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id != self.eos_token_id: + raise ValueError( + f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must be the same as the model's eos_token_id={self.eos_token_id}." + ) + + prompt_tokens = self.tokenizer.encode(prompt) + generated_tokens = self.generate( + prompt_tokens=prompt_tokens, + echo=echo, + max_seq_len=max_seq_len, + ) + return self.tokenizer.decode(generated_tokens, skip_special_tokens=True) diff --git a/optimum/exporters/executorch/__init__.py b/optimum/exporters/executorch/__init__.py new file mode 100644 index 00000000000..90e228392eb --- /dev/null +++ b/optimum/exporters/executorch/__init__.py @@ -0,0 +1,24 @@ +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "convert": [ + "export_to_executorch", + ], + "__main__": ["main_export"], +} + +if TYPE_CHECKING: + from .__main__ import main_export + from .convert import export_to_executorch +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/optimum/exporters/executorch/__main__.py b/optimum/exporters/executorch/__main__.py new file mode 100644 index 00000000000..6c257984b6d --- /dev/null +++ b/optimum/exporters/executorch/__main__.py @@ -0,0 +1,148 @@ +"""Entry point to the optimum.exporters.executorch command line.""" + +import argparse +import os +import warnings +from pathlib import Path + +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from optimum.utils.import_utils import check_if_transformers_greater +from transformers import AutoConfig, AutoTokenizer +from transformers.utils import is_torch_available + +from ...commands.export.executorch import parse_args_executorch +from .convert import export_to_executorch +from .task_registry import discover_tasks, task_registry + + +if is_torch_available(): + import torch + +from typing import Optional, Union + + +def main_export( + model_name_or_path: str, + task: str, + recipe: str, + output_dir: Union[str, Path], + cache_dir: str = HUGGINGFACE_HUB_CACHE, + trust_remote_code: bool = False, + pad_token_id: Optional[int] = None, + subfolder: str = "", + revision: str = "main", + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + **kwargs, +): + """ + Full-suite ExecuTorch export function, exporting **from a model ID on Hugging Face Hub or a local model repository**. + + Args: + model_name_or_path (`str`): + Model ID on huggingface.co or path on disk to the model repository to export. Example: `model_name_or_path="meta-llama/Llama-3.2-1B"` or `mode_name_or_path="/path/to/model_folder`. + task (`str`): + The task to export the model for, e.g. "text-generation". + recipe (`str`): + The recipe to use to do the export, e.g. "xnnpack". + output_dir (`Union[str, Path]`): + Path indicating the directory where to store the generated ExecuTorch model. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + trust_remote_code (`bool`, defaults to `False`): + Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories + you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the + model repository. + pad_token_id (`Optional[int]`, defaults to `None`): + This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + **kwargs: + Additional configuration options to tasks and recipes. + + Example usage: + ```python + >>> from optimum.exporters.executorch import main_export + + >>> main_export("meta-llama/Llama-3.2-1B", "text-generation", "xnnpack", "meta_llama3_2_1b/") + ``` + """ + + if not check_if_transformers_greater("4.46"): + raise ValueError( + "The minimum Transformers version compatible with ExecuTorch is 4.46.0. Please upgrade to Transformers 4.46.0 or later." + ) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + # Dynamically discover and import registered tasks + discover_tasks() + + # Load the model for specific task + try: + task_func = task_registry.get(task) + except KeyError as e: + raise RuntimeError(f"The task '{task}' isn't registered. Detailed error: {e}") + + model = task_func(model_name_or_path, **kwargs) + + if task == "text-generation": + from transformers.integrations.executorch import TorchExportableModuleWithStaticCache + + model = TorchExportableModuleWithStaticCache(model) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + return export_to_executorch( + model=model, + task=task, + recipe=recipe, + output_dir=output_dir, + **kwargs, + ) + + +def main(): + parser = argparse.ArgumentParser("Hugging Face Optimum ExecuTorch exporter") + + parse_args_executorch(parser) + + # Retrieve CLI arguments + args = parser.parse_args() + + main_export( + model_name_or_path=args.model, + output_dir=args.output_dir, + task=args.task, + recipe=args.recipe, + cache_dir=args.cache_dir, + trust_remote_code=args.trust_remote_code, + pad_token_id=args.pad_token_id, + ) + + +if __name__ == "__main__": + main() diff --git a/optimum/exporters/executorch/convert.py b/optimum/exporters/executorch/convert.py new file mode 100644 index 00000000000..742f4acd0b7 --- /dev/null +++ b/optimum/exporters/executorch/convert.py @@ -0,0 +1,73 @@ +"""ExecuTorch model check and export functions.""" + +import logging +import os +from pathlib import Path +from typing import Union + +from transformers.utils import is_torch_available +from ...utils import check_if_transformers_greater +from ..tasks import TasksManager +from .recipe_registry import discover_recipes, recipe_registry + + +if is_torch_available(): + import torch + from transformers.modeling_utils import PreTrainedModel + +logger = logging.getLogger(__name__) + + +def export_to_executorch( + model: Union["PreTrainedModel", "TorchExportableModuleWithStaticCache"], + task: str, + recipe: str, + output_dir: Union[str, Path], + **kwargs, +): + """ + Export a pre-trained PyTorch model to the ExecuTorch format using a specified recipe. + + This function facilitates the transformation of a PyTorch model into an optimized ExecuTorch program. + + Args: + model (`Union["PreTrainedModel", "TorchExportableModuleWithStaticCache"]`): + A PyTorch model to be exported. This can be a standard HuggingFace `PreTrainedModel` or a wrapped + module like `TorchExportableModuleWithStaticCache` for text generation task. + task (`str`): + The specific task the exported model will perform, e.g., "text-generation". + recipe (`str`): + The recipe to guide the export process, e.g., "xnnpack". Recipes define the optimization and lowering steps. + Will raise an exception if the specified recipe is not registered in the recipe registry. + output_dir (`Union[str, Path]`): + Path to the directory where the resulting ExecuTorch model will be saved. + **kwargs: + Additional configuration options passed to the recipe. + + Returns: + `ExecuTorchProgram`: + The lowered ExecuTorch program object. + + Notes: + - The function uses a dynamic recipe discovery mechanism to identify and import the specified recipe. + - The exported model is stored in the specified output directory with the fixed filename `model.pte`. + - The resulting ExecuTorch program is serialized and saved to the output directory. + """ + + # Dynamically discover and import registered recipes + discover_recipes() + + # Export and lower the model to ExecuTorch with the recipe + try: + recipe_func = recipe_registry.get(recipe) + except KeyError as e: + raise RuntimeError(f"The recipe '{recipe}' isn't registered. Detailed error: {e}") + + executorch_prog = recipe_func(model, task, **kwargs) + + full_path = os.path.join(f"{output_dir}", "model.pte") + with open(full_path, "wb") as f: + executorch_prog.write_to_file(f) + logger.info(f"Saved exported program to {full_path}") + + return executorch_prog diff --git a/optimum/exporters/executorch/recipe_registry.py b/optimum/exporters/executorch/recipe_registry.py new file mode 100644 index 00000000000..162a835196b --- /dev/null +++ b/optimum/exporters/executorch/recipe_registry.py @@ -0,0 +1,56 @@ +import importlib +import logging +import os +import pkgutil + +logger = logging.getLogger(__name__) + +recipe_registry = {} + +package_name = "optimum.exporters.executorch.recipes" + + +def register_recipe(recipe_name): + """ + Decorator to register a recipe for exporting and lowering an ExecuTorch model under a specific name. + + Args: + recipe_name (`str`): + The name of the recipe to associate with a callable recipe. + + Returns: + `Callable`: + The original function wrapped as a registered recipe. + + Example: + ```python + @register_recipe("my_new_recipe") + def my_new_recipe(...): + ... + ``` + """ + + def decorator(func): + recipe_registry[recipe_name] = func + return func + + return decorator + + +def discover_recipes(): + """ + Dynamically discovers and imports all recipe modules within the `optimum.exporters.executorch.recipes` package. + + Ensures recipes under `./recipes` directory are dynamically loaded without requiring manual imports. + + Notes: + New recipes **must** be added to the `./recipes` directory to be discovered and used by `main_export`. + Failure to do so will prevent dynamic discovery and registration. Recipes must also use the + `@register_recipe` decorator to be properly registered in the `recipe_registry`. + """ + package = importlib.import_module(package_name) + package_path = package.__path__ + + for _, module_name, _ in pkgutil.iter_modules(package_path): + logger.info(f"Importing {package_name}.{module_name}") + importlib.import_module(f"{package_name}.{module_name}") diff --git a/optimum/exporters/executorch/recipes/__init__.py b/optimum/exporters/executorch/recipes/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/optimum/exporters/executorch/recipes/xnnpack.py b/optimum/exporters/executorch/recipes/xnnpack.py new file mode 100644 index 00000000000..72ad2a4a54f --- /dev/null +++ b/optimum/exporters/executorch/recipes/xnnpack.py @@ -0,0 +1,68 @@ +from typing import Union + +import torch +import torch.export._trace +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from torch.nn.attention import SDPBackend +from transformers import PreTrainedModel, TorchExportableModuleWithStaticCache + +from ..recipe_registry import register_recipe + + +@register_recipe("xnnpack") +def export_to_executorch_with_xnnpack( + model: Union[PreTrainedModel, TorchExportableModuleWithStaticCache], + task: str, + **kwargs, +): + metadata = {} + if task == "text-generation": + example_input_ids = torch.tensor([[1]], dtype=torch.long) + example_cache_position = torch.tensor([0], dtype=torch.long) + + def _get_constant_methods(model: PreTrainedModel): + metadata = { + "get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6, + "get_bos_id": model.config.bos_token_id, + "get_eos_id": model.config.eos_token_id, + "get_head_dim": model.config.hidden_size / model.config.num_attention_heads, + "get_max_batch_size": model.generation_config.cache_config.batch_size, + "get_max_seq_len": model.generation_config.cache_config.max_cache_len, + "get_n_kv_heads": model.config.num_key_value_heads, + "get_n_layers": model.config.num_hidden_layers, + "get_vocab_size": model.config.vocab_size, + "use_kv_cache": model.generation_config.use_cache, + } + return {k: v for k, v in metadata.items() if v is not None} + + metadata = _get_constant_methods(model if isinstance(model, PreTrainedModel) else model.model) + else: + # TODO: Prepare model inputs for other tasks + raise ValueError(f"Unsupported task '{task}'.") + + with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): + exported_program = torch.export._trace._export( + model, + args=(example_input_ids,), + kwargs={"cache_position": example_cache_position}, + pre_dispatch=False, + strict=True, + ) + + return to_edge_transform_and_lower( + exported_program, + partitioner=[XnnpackPartitioner()], + compile_config=EdgeCompileConfig( + _skip_dim_order=True, + ), + constant_methods=metadata, + ).to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + ), + ) diff --git a/optimum/exporters/executorch/task_registry.py b/optimum/exporters/executorch/task_registry.py new file mode 100644 index 00000000000..5a297f2ced2 --- /dev/null +++ b/optimum/exporters/executorch/task_registry.py @@ -0,0 +1,56 @@ +import importlib +import logging +import os +import pkgutil + +logger = logging.getLogger(__name__) + +task_registry = {} + +package_name = "optimum.exporters.executorch.tasks" + + +def register_task(task_name): + """ + Decorator to register a task under a specific name. + + Args: + task_name (`str`): + The name of the task to associate with a callable task. + + Returns: + `Callable`: + The original function wrapped as a registered task. + + Example: + ```python + @register_task("my_new_task") + def my_new_task(...): + ... + ``` + """ + + def decorator(func): + task_registry[task_name] = func + return func + + return decorator + + +def discover_tasks(): + """ + Dynamically discovers and imports all task modules within the `optimum.exporters.executorch.tasks` package. + + Ensures tasks under `./tasks` directory are dynamically loaded without requiring manual imports. + + Notes: + New tasks **must** be added to the `./tasks` directory to be discovered and used by `main_export`. + Failure to do so will prevent dynamic discovery and registration. Tasks must also use the + `@register_task` decorator to be properly registered in the `task_registry`. + """ + package = importlib.import_module(package_name) + package_path = package.__path__ + + for _, module_name, _ in pkgutil.iter_modules(package_path): + logger.info(f"Importing {package_name}.{module_name}") + importlib.import_module(f"{package_name}.{module_name}") diff --git a/optimum/exporters/executorch/tasks/__init__.py b/optimum/exporters/executorch/tasks/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/optimum/exporters/executorch/tasks/causal_lm.py b/optimum/exporters/executorch/tasks/causal_lm.py new file mode 100644 index 00000000000..10b07c84ed1 --- /dev/null +++ b/optimum/exporters/executorch/tasks/causal_lm.py @@ -0,0 +1,54 @@ +from transformers import AutoModelForCausalLM, GenerationConfig + +from ..task_registry import register_task + + +@register_task("text-generation") +def load_causal_lm_model(model_name_or_path: str, **kwargs): + """ + Loads a causal language model for text generation and registers it under the task + 'text-generation' using Hugging Face's AutoModelForCausalLM. + + Args: + model_name_or_path (str): + Model ID on huggingface.co or path on disk to the model repository to export. For example: + `model_name_or_path="meta-llama/Llama-3.2-1B"` or `mode_name_or_path="/path/to/model_folder` + **kwargs: + Additional configuration options for the model: + - dtype (str, optional): + Data type for model weights (default: "float32"). + Options include "float16" and "bfloat16". + - attn_implementation (str, optional): + Attention mechanism implementation (default: "sdpa"). + - cache_implementation (str, optional): + Cache management strategy (default: "static"). + - max_length (int, optional): + Maximum sequence length for generation (default: 2048). + + Returns: + transformers.PreTrainedModel: + An instance of a model subclass (e.g., Llama, Gemma) with the configuration for exporting + and lowering to ExecuTorch. + """ + device = "cpu" + batch_size = 1 + dtype = kwargs.get("dtype", "float32") + attn_implementation = kwargs.get("attn_implementation", "sdpa") + cache_implementation = kwargs.get("cache_implementation", "static") + max_length = kwargs.get("max_length", 2048) + + return AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_length, + }, + ), + ) diff --git a/optimum/onnxruntime/runs/__init__.py b/optimum/onnxruntime/runs/__init__.py index 1d982949344..d21db2a4aca 100644 --- a/optimum/onnxruntime/runs/__init__.py +++ b/optimum/onnxruntime/runs/__init__.py @@ -110,9 +110,9 @@ def __init__(self, run_config): model_class = FeaturesManager.get_model_class_for_feature(get_autoclass_name(self.task)) self.torch_model = model_class.from_pretrained(run_config["model_name_or_path"]) - self.return_body[ - "model_type" - ] = self.torch_model.config.model_type # return_body is initialized in parent class + self.return_body["model_type"] = ( + self.torch_model.config.model_type + ) # return_body is initialized in parent class def _launch_time(self, trial): batch_size = trial.suggest_categorical("batch_size", self.batch_sizes) diff --git a/setup.py b/setup.py index 82892bfcc8c..460a51816a2 100644 --- a/setup.py +++ b/setup.py @@ -88,6 +88,10 @@ "datasets<=2.16", "transformers>=4.26,<4.38", ], + "exporters-executorch": [ + "executorch>=0.4.0", + "transformers>=4.46", + ], "diffusers": ["diffusers"], "intel": "optimum-intel>=1.18.0", "openvino": "optimum-intel[openvino]>=1.18.0", diff --git a/tests/executorch/test_modeling.py b/tests/executorch/test_modeling.py new file mode 100644 index 00000000000..28a55a87dee --- /dev/null +++ b/tests/executorch/test_modeling.py @@ -0,0 +1,49 @@ +import os +import subprocess +import tempfile +import unittest +from pathlib import Path +from typing import Dict + +import numpy as np +import pytest +import requests +import torch +from huggingface_hub import HfApi +from huggingface_hub.constants import default_cache_path +from parameterized import parameterized + +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, +) + +from optimum.executorchruntime import ExecuTorchModelForCausalLM +from optimum.utils.import_utils import check_if_transformers_greater +from optimum.utils.testing_utils import require_hf_token +from executorch.extension.pybindings.portable_lib import ExecuTorchModule + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.TEST_MODEL_ID = "meta-llama/Llama-3.2-1B" + + def test_text_generation_with_xnnpack(self): + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=self.TEST_MODEL_ID, + task="text-generation", + recipe="xnnpack", + export=True, + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "Simply put, the theory of relativity states that the laws of physics are the same in all inertial frames of reference." + tokenizer = AutoTokenizer.from_pretrained(self.TEST_MODEL_ID) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Simply put, the theory of relativity states that", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) diff --git a/tests/exporters/executorch/__init__.py b/tests/exporters/executorch/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/exporters/executorch/test_exporters_executorch_cli.py b/tests/exporters/executorch/test_exporters_executorch_cli.py new file mode 100644 index 00000000000..6bbd643f20d --- /dev/null +++ b/tests/exporters/executorch/test_exporters_executorch_cli.py @@ -0,0 +1,26 @@ +import os +import subprocess +import tempfile +import unittest +import optimum.commands + + +class TestExportToExecuTorchCLI(unittest.TestCase): + def test_helps_no_raise(self): + subprocess.run( + "optimum-cli export executorch --help", + shell=True, + check=True, + ) + + def test_export_to_executorch_command(self): + model_id = "meta-llama/Llama-3.2-1B" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))