Skip to content

Commit

Permalink
Add ONNX export support for Depth Anything and Prompt Depth Anything
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Dec 25, 2024
1 parent d21256c commit 30f2bda
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 0 deletions.
58 changes: 58 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,64 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
return common_outputs


class DepthAnythingOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for Depth Anything models.

    Adds nothing of its own: every attribute and property is inherited
    unchanged from ``ViTOnnxConfig``.
    """


class DummyPromptDepthInputGenerator(DummyVisionInputGenerator):
    """Dummy input generator that also knows how to build ``prompt_depth``.

    ``prompt_depth`` is produced here as a random float tensor; any other
    input name is delegated to ``DummyVisionInputGenerator.generate``.
    """

    SUPPORTED_INPUT_NAMES = ("pixel_values", "prompt_depth")

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        width: int = DEFAULT_DUMMY_SHAPES["width"],
        height: int = DEFAULT_DUMMY_SHAPES["height"],
        prompt_height: int = DEFAULT_DUMMY_SHAPES["prompt_height"],
        prompt_width: int = DEFAULT_DUMMY_SHAPES["prompt_width"],
        **kwargs,
    ):
        """Initialise the parent generator and remember the prompt size.

        Args:
            task: Forwarded to the parent constructor.
            normalized_config: Forwarded to the parent constructor.
            batch_size: Forwarded to the parent constructor.
            num_channels: Forwarded to the parent constructor.
            width: Forwarded to the parent constructor.
            height: Forwarded to the parent constructor.
            prompt_height: Size of dimension 2 of the generated
                ``prompt_depth`` tensor.
            prompt_width: Size of dimension 3 of the generated
                ``prompt_depth`` tensor.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(
            task=task,
            normalized_config=normalized_config,
            batch_size=batch_size,
            num_channels=num_channels,
            width=width,
            height=height,
            **kwargs,
        )
        # Only the prompt dimensions are stored here; everything else is
        # handled by the parent constructor.
        self.prompt_height, self.prompt_width = prompt_height, prompt_width

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """Return a dummy tensor for ``input_name``.

        Args:
            input_name: Name of the input to generate.
            framework: Framework identifier passed on to the tensor helper.
            int_dtype: Integer dtype name, only used by the parent generator.
            float_dtype: Float dtype name for the generated tensor.

        Returns:
            For ``"prompt_depth"``, a random float tensor of shape
            ``(batch_size, 1, prompt_height, prompt_width)``; otherwise
            whatever the parent ``generate`` returns.
        """
        if input_name != "prompt_depth":
            # Not ours: let the vision generator deal with it.
            return super().generate(
                input_name=input_name,
                framework=framework,
                int_dtype=int_dtype,
                float_dtype=float_dtype,
            )

        depth_shape = (self.batch_size, 1, self.prompt_height, self.prompt_width)
        return self.random_float_tensor(depth_shape, framework=framework, dtype=float_dtype)


class PromptDepthAnythingOnnxConfig(DepthAnythingOnnxConfig):
    """ONNX export configuration for Prompt Depth Anything models.

    Extends ``DepthAnythingOnnxConfig`` with a second model input,
    ``prompt_depth``, and uses ``DummyPromptDepthInputGenerator`` to
    produce dummy data for both inputs.
    """

    DUMMY_INPUT_GENERATOR_CLASSES = (DummyPromptDepthInputGenerator,)

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Map each model input name to its ``{axis index: axis name}`` dict."""
        pixel_values_axes = {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}
        # Axis 1 of ``prompt_depth`` is deliberately absent from the mapping.
        prompt_depth_axes = {0: "batch_size", 2: "prompt_height", 3: "prompt_width"}
        return {
            "pixel_values": pixel_values_axes,
            "prompt_depth": prompt_depth_axes,
        }


class CvTOnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 13
ATOL_FOR_VALIDATION = 1e-2
Expand Down
10 changes: 10 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,11 @@ class TasksManager:
"masked-im",
onnx="DeiTOnnxConfig",
),
"depth-anything": supported_tasks_mapping(
"feature-extraction",
"depth-estimation",
onnx="DepthAnythingOnnxConfig",
),
"detr": supported_tasks_mapping(
"feature-extraction",
"object-detection",
Expand Down Expand Up @@ -1033,6 +1038,11 @@ class TasksManager:
"image-classification",
onnx="PoolFormerOnnxConfig",
),
"prompt-depth-anything": supported_tasks_mapping(
"feature-extraction",
"depth-estimation",
onnx="PromptDepthAnythingOnnxConfig",
),
"pvt": supported_tasks_mapping(
"feature-extraction",
"image-classification",
Expand Down
2 changes: 2 additions & 0 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def wrapper(*args, **kwargs):
"num_channels": 3,
"point_batch_size": 3,
"nb_points_per_image": 2,
"prompt_width": 32,
"prompt_height": 32,
# audio
"feature_size": 80,
"nb_max_frames": 3000,
Expand Down

0 comments on commit 30f2bda

Please sign in to comment.