diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 0a5da755a3b..f6bd06cde19 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -38,6 +38,7 @@ Supported architectures: - Deberta-v2 - Deit - Detr +- DINOv2 - DistilBert - Donut-Swin - Electra diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index a58f42dca4b..a8fa6c1e01f 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -708,6 +708,45 @@ class ConvNextV2OnnxConfig(ViTOnnxConfig): pass +class Dinov2DummyInputGenerator(DummyVisionInputGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedVisionConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], + width: int = DEFAULT_DUMMY_SHAPES["width"], + height: int = DEFAULT_DUMMY_SHAPES["height"], + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + num_channels=num_channels, + width=width, + height=height, + **kwargs, + ) + + from transformers.onnx.utils import get_preprocessor + + preprocessor = get_preprocessor(normalized_config._name_or_path) + if preprocessor is not None and hasattr(preprocessor, "crop_size"): + self.height = preprocessor.crop_size.get("height", self.height) + self.width = preprocessor.crop_size.get("width", self.width) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + input_ = super().generate( + input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype + ) + return input_ + + +class Dinov2OnnxConfig(ViTOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = (Dinov2DummyInputGenerator,) + + class MobileViTOnnxConfig(ViTOnnxConfig): ATOL_FOR_VALIDATION = 1e-4 diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 4d3f9f98d05..91025b6d2d1 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -482,6 +482,11 @@ class TasksManager: "image-segmentation", onnx="DetrOnnxConfig", ), + "dinov2": supported_tasks_mapping( + "feature-extraction", + "image-classification", + onnx="Dinov2OnnxConfig", + ), "distilbert": supported_tasks_mapping( "feature-extraction", "fill-mask", diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index eb9b5404806..8d54c9a0f9c 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -1534,7 +1534,7 @@ def forward( @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) class ORTModelForImageClassification(ORTModel): """ - ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, vit. + ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, dinov2, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, vit. """ auto_model_class = AutoModelForImageClassification diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 6fa4adcecfd..38d63cf7825 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -190,6 +190,7 @@ class NormalizedConfigManager: 'data2vec-text', 'data2vec-vision', 'detr', + 'dinov2', 'flaubert', 'groupvit', 'ibert', diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 9af7806e7f8..7ca67d8f056 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -62,6 +62,7 @@ "deberta": "hf-internal-testing/tiny-random-DebertaModel", "deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model", "deit": "hf-internal-testing/tiny-random-DeiTModel", + "dinov2": "hf-internal-testing/tiny-random-Dinov2Model", "donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder", "donut-swin": "hf-internal-testing/tiny-random-DonutSwinModel", "detr": "hf-internal-testing/tiny-random-DetrModel", # hf-internal-testing/tiny-random-detr is larger diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index bb06e42157b..81e1cb31492 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -2702,6 +2702,7 @@ class ORTModelForImageClassificationIntegrationTest(ORTModelTestMixin): "convnextv2", "data2vec_vision", "deit", + "dinov2", "levit", "mobilenet_v1", "mobilenet_v2", diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 8e579879ea6..70be094fecf 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -50,6 +50,7 @@ "deit": "hf-internal-testing/tiny-random-DeiTModel", "donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder", "detr": "hf-internal-testing/tiny-random-detr", + "dinov2": "hf-internal-testing/tiny-random-Dinov2Model", "distilbert": "hf-internal-testing/tiny-random-DistilBertModel", "electra": "hf-internal-testing/tiny-random-ElectraModel", "encoder-decoder": {