
Add ONNX export support for patchtsmixer
xenova committed Nov 22, 2024
1 parent e348d47 commit eeb3159
Showing 4 changed files with 17 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
@@ -75,6 +75,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
 - Nystromformer
 - OWL-ViT
 - PatchTST
+- PatchTSMixer
 - Pegasus
 - Perceiver
 - Phi
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -2483,3 +2483,7 @@ class PatchTSTOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         return {"past_values": {0: "batch_size", 1: "sequence_length"}}
+
+
+class PatchTSMixerOnnxConfig(PatchTSTOnnxConfig):
+    pass
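Since PatchTSMixerOnnxConfig just inherits the single past_values input (with dynamic batch and sequence dimensions) from PatchTSTOnnxConfig, exporting a PatchTSMixer checkpoint should work the same way as for PatchTST. A minimal sketch, assuming the main_export helper keeps its usual signature; the tiny test checkpoint is the one registered in the tests further below, and the output directory name here is arbitrary:

# Sketch: export a PatchTSMixer checkpoint to ONNX using the newly registered config.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="ibm/test-patchtsmixer",  # tiny test checkpoint used in the exporter tests
    task="time-series-forecasting",              # task registered for patchtsmixer in tasks.py
    output="patchtsmixer_onnx",                  # arbitrary output directory
)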
11 changes: 10 additions & 1 deletion optimum/exporters/tasks.py
@@ -222,7 +222,6 @@ class TasksManager:
         "text-generation": "AutoModelForCausalLM",
         "text2text-generation": "AutoModelForSeq2SeqLM",
         "text-classification": "AutoModelForSequenceClassification",
-        "time-series-forecasting": "PatchTSTForPrediction",  # TODO: AutoModelForPrediction is not yet supported
         "token-classification": "AutoModelForTokenClassification",
         "zero-shot-image-classification": "AutoModelForZeroShotImageClassification",
         "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection",
@@ -315,6 +314,10 @@ class TasksManager:
     }

     _CUSTOM_CLASSES = {
+        ("pt", "patchtsmixer", "feature-extraction"): ("transformers", "PatchTSMixerModel"),
+        ("pt", "patchtsmixer", "time-series-forecasting"): ("transformers", "PatchTSMixerForPrediction"),
+        ("pt", "patchtst", "feature-extraction"): ("transformers", "PatchTSTModel"),
+        ("pt", "patchtst", "time-series-forecasting"): ("transformers", "PatchTSTForPrediction"),
         ("pt", "pix2struct", "image-to-text"): ("transformers", "Pix2StructForConditionalGeneration"),
         ("pt", "pix2struct", "visual-question-answering"): ("transformers", "Pix2StructForConditionalGeneration"),
         ("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"),
@@ -913,9 +916,15 @@ class TasksManager:
             onnx="OPTOnnxConfig",
         ),
         "patchtst": supported_tasks_mapping(
+            "feature-extraction",
             "time-series-forecasting",
             onnx="PatchTSTOnnxConfig",
         ),
+        "patchtsmixer": supported_tasks_mapping(
+            "feature-extraction",
+            "time-series-forecasting",
+            onnx="PatchTSMixerOnnxConfig",
+        ),
         "qwen2": supported_tasks_mapping(
             "feature-extraction",
             "feature-extraction-with-past",
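The new _CUSTOM_CLASSES entries are needed because Transformers has no AutoModel class for time-series forecasting yet (which is also why the hardcoded PatchTSTForPrediction TODO entry above is removed), so TasksManager has to map the (framework, model type, task) triple to a concrete class per model type. A sketch of how the registration can be inspected, assuming get_supported_tasks_for_model_type and get_model_class_for_task keep their current signatures:

from optimum.exporters.tasks import TasksManager

# Tasks now exportable to ONNX for patchtsmixer, each mapped to its OnnxConfig constructor
# (expected keys: "feature-extraction" and "time-series-forecasting").
print(TasksManager.get_supported_tasks_for_model_type("patchtsmixer", exporter="onnx"))

# The custom-class table should resolve the forecasting task to PatchTSMixerForPrediction.
model_class = TasksManager.get_model_class_for_task(
    "time-series-forecasting", framework="pt", model_type="patchtsmixer"
)
print(model_class.__name__)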
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
@@ -129,6 +129,7 @@
     "owlv2": "hf-internal-testing/tiny-random-Owlv2Model",
     "owlvit": "hf-tiny-model-private/tiny-random-OwlViTModel",
     "patchtst": "ibm/test-patchtst",
+    "patchtsmixer": "ibm/test-patchtsmixer",
     "pegasus": "hf-internal-testing/tiny-random-PegasusModel",
     "perceiver": {
         "hf-internal-testing/tiny-random-language_perceiver": ["fill-mask", "text-classification"],
@@ -257,6 +258,7 @@
     "owlv2": "google/owlv2-base-patch16",
     "owlvit": "google/owlvit-base-patch32",
     "patchtst": "ibm/test-patchtst",
+    "patchtsmixer": "ibm/test-patchtsmixer",
     "perceiver": "hf-internal-testing/tiny-random-PerceiverModel",  # Not using deepmind/language-perceiver because it takes too much time for testing.
     # "rembert": "google/rembert",
     "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
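After export, the graph should expose a single past_values input with dynamic batch and sequence axes, as declared in PatchTSTOnnxConfig.inputs. A rough sketch of running it with onnxruntime, assuming the export above wrote patchtsmixer_onnx/model.onnx and that the checkpoint's config provides context_length and num_input_channels:

import numpy as np
import onnxruntime as ort
from transformers import PatchTSMixerConfig

config = PatchTSMixerConfig.from_pretrained("ibm/test-patchtsmixer")
session = ort.InferenceSession("patchtsmixer_onnx/model.onnx")

# Dummy multivariate history with shape (batch_size, sequence_length, num_channels).
past_values = np.random.randn(2, config.context_length, config.num_input_channels).astype(np.float32)

# Output names depend on the exported task, so query them instead of hardcoding.
outputs = session.run(None, {"past_values": past_values})
print([o.name for o in session.get_outputs()], [o.shape for o in outputs])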
