From eeb31595e200ddd08036c66613a0573eb51da6c2 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 22 Nov 2024 19:47:45 +0000 Subject: [PATCH] Add ONNX export support for patchtsmixer --- docs/source/exporters/onnx/overview.mdx | 1 + optimum/exporters/onnx/model_configs.py | 4 ++++ optimum/exporters/tasks.py | 11 ++++++++++- tests/exporters/exporters_utils.py | 2 ++ 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 38919cd1da7..204d9c51129 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -75,6 +75,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra - Nystromformer - OWL-ViT - PatchTST +- PatchTSMixer - Pegasus - Perceiver - Phi diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 45bbfca2315..61ab251d3b9 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -2483,3 +2483,7 @@ class PatchTSTOnnxConfig(OnnxConfig): @property def inputs(self) -> Dict[str, Dict[int, str]]: return {"past_values": {0: "batch_size", 1: "sequence_length"}} + + +class PatchTSMixerOnnxConfig(PatchTSTOnnxConfig): + pass diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 8e94c668c82..54164a6491e 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -222,7 +222,6 @@ class TasksManager: "text-generation": "AutoModelForCausalLM", "text2text-generation": "AutoModelForSeq2SeqLM", "text-classification": "AutoModelForSequenceClassification", - "time-series-forecasting": "PatchTSTForPrediction", # TODO: AutoModelForPrediction is not yet supported "token-classification": "AutoModelForTokenClassification", "zero-shot-image-classification": "AutoModelForZeroShotImageClassification", "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection", @@ -315,6 +314,10 @@ class TasksManager: } _CUSTOM_CLASSES = { + ("pt", "patchtsmixer", "feature-extraction"): ("transformers", "PatchTSMixerModel"), + ("pt", "patchtsmixer", "time-series-forecasting"): ("transformers", "PatchTSMixerForPrediction"), + ("pt", "patchtst", "feature-extraction"): ("transformers", "PatchTSTModel"), + ("pt", "patchtst", "time-series-forecasting"): ("transformers", "PatchTSTForPrediction"), ("pt", "pix2struct", "image-to-text"): ("transformers", "Pix2StructForConditionalGeneration"), ("pt", "pix2struct", "visual-question-answering"): ("transformers", "Pix2StructForConditionalGeneration"), ("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"), @@ -913,9 +916,15 @@ class TasksManager: onnx="OPTOnnxConfig", ), "patchtst": supported_tasks_mapping( + "feature-extraction", "time-series-forecasting", onnx="PatchTSTOnnxConfig", ), + "patchtsmixer": supported_tasks_mapping( + "feature-extraction", + "time-series-forecasting", + onnx="PatchTSMixerOnnxConfig", + ), "qwen2": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 9803c5c8d37..a58ee505316 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -129,6 +129,7 @@ "owlv2": "hf-internal-testing/tiny-random-Owlv2Model", "owlvit": "hf-tiny-model-private/tiny-random-OwlViTModel", "patchtst": "ibm/test-patchtst", + "patchtsmixer": "ibm/test-patchtsmixer", "pegasus": "hf-internal-testing/tiny-random-PegasusModel", "perceiver": { "hf-internal-testing/tiny-random-language_perceiver": ["fill-mask", "text-classification"], @@ -257,6 +258,7 @@ "owlv2": "google/owlv2-base-patch16", "owlvit": "google/owlvit-base-patch32", "patchtst": "ibm/test-patchtst", + "patchtsmixer": "ibm/test-patchtsmixer", "perceiver": "hf-internal-testing/tiny-random-PerceiverModel", # Not using deepmind/language-perceiver because it takes too much time for testing. # "rembert": "google/rembert", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",