From e3b7efb1257c011db907ef40ab340e795cc5684c Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 14 Nov 2023 15:19:38 +0200 Subject: [PATCH] [ONNX export] Add depth-estimation w/ DPT+GLPN (#1529) * Add depth-estimation w/ dpt * Fix depth-estimation outputs * Add GLPN onnx export --- optimum/exporters/onnx/base.py | 1 + optimum/exporters/onnx/model_configs.py | 8 ++++++++ optimum/exporters/tasks.py | 11 +++++++++++ tests/exporters/exporters_utils.py | 2 ++ 4 files changed, 22 insertions(+) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 2958d3d920c..8bf93feb29c 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -146,6 +146,7 @@ class OnnxConfig(ExportConfig, ABC): "audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "automatic-speech-recognition": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "audio-xvector": OrderedDict({"logits": {0: "batch_size"}, "embeddings": {0: "batch_size"}}), + "depth-estimation": OrderedDict({"predicted_depth": {0: "batch_size", 1: "height", 2: "width"}}), "document-question-answering": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}), "fill-mask": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index fb4d190a2c9..f4d50ad58d4 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -743,6 +743,14 @@ class Swin2srOnnxConfig(SwinOnnxConfig): pass +class DptOnnxConfig(ViTOnnxConfig): + pass + + +class GlpnOnnxConfig(ViTOnnxConfig): + pass + + class PoolFormerOnnxConfig(ViTOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig ATOL_FOR_VALIDATION = 2e-3 diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 2208a6fe26d..7545c72d6c6 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -164,6 +164,7 @@ class TasksManager: "audio-xvector": "AutoModelForAudioXVector", "automatic-speech-recognition": ("AutoModelForSpeechSeq2Seq", "AutoModelForCTC"), "conversational": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), + "depth-estimation": "AutoModelForDepthEstimation", "feature-extraction": "AutoModel", "fill-mask": "AutoModelForMaskedLM", "image-classification": "AutoModelForImageClassification", @@ -497,6 +498,11 @@ class TasksManager: "feature-extraction", onnx="DonutSwinOnnxConfig", ), + "dpt": supported_tasks_mapping( + "feature-extraction", + "depth-estimation", + onnx="DptOnnxConfig", + ), "electra": supported_tasks_mapping( "feature-extraction", "fill-mask", @@ -533,6 +539,11 @@ class TasksManager: onnx="FlaubertOnnxConfig", tflite="FlaubertTFLiteConfig", ), + "glpn": supported_tasks_mapping( + "feature-extraction", + "depth-estimation", + onnx="GlpnOnnxConfig", + ), "gpt2": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index f17184a1b7c..6e43b65e34f 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -65,6 +65,7 @@ "donut-swin": "hf-internal-testing/tiny-random-DonutSwinModel", "detr": "hf-internal-testing/tiny-random-DetrModel", # hf-internal-testing/tiny-random-detr is larger "distilbert": "hf-internal-testing/tiny-random-DistilBertModel", + "dpt": "hf-internal-testing/tiny-random-DPTModel", "electra": "hf-internal-testing/tiny-random-ElectraModel", "encoder-decoder": { "hf-internal-testing/tiny-random-EncoderDecoderModel-bert-bert": [ @@ -84,6 +85,7 @@ "fxmarty/tiny-testing-falcon-alibi": ["text-generation", "text-generation-with-past"], }, "flaubert": "hf-internal-testing/tiny-random-flaubert", + "glpn": "hf-internal-testing/tiny-random-GLPNModel", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", "gpt-neo": "hf-internal-testing/tiny-random-GPTNeoModel",