Skip to content

Commit

Permalink
Merge branch 'main' into fix-windows-int32
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil authored Jun 5, 2024
2 parents 05582db + ac951ca commit f770d85
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 5 deletions.
17 changes: 13 additions & 4 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,13 +1727,20 @@ class ORTModelForSemanticSegmentation(ORTModel):
checkpoint="optimum/segformer-b0-finetuned-ade-512-512",
)
)
def forward(self, **model_inputs: Union[torch.Tensor, np.ndarray]):
use_torch = isinstance(next(iter(model_inputs.values())), torch.Tensor)

def forward(
self,
pixel_values: Union[torch.Tensor, np.ndarray],
**kwargs,
):
use_torch = isinstance(pixel_values, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
**model_inputs, ordered_input_names=self._ordered_input_names
io_binding = IOBindingHelper.prepare_io_binding(
self,
pixel_values,
ordered_input_names=self._ordered_input_names,
)

# run inference with binding
Expand All @@ -1743,6 +1750,8 @@ def forward(self, **model_inputs: Union[torch.Tensor, np.ndarray]):

logits = output_buffers["logits"].view(output_shapes["logits"])
else:
model_inputs = {"pixel_values": pixel_values}

onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class ORTConfigManager:
"nystromformer": "bert",
"pegasus": "bert",
"roberta": "bert",
"segformer": "vit",
"t5": "bert",
"vit": "vit",
"whisper": "bart",
Expand Down
15 changes: 14 additions & 1 deletion optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,19 @@ class NormalizedVisionConfig(NormalizedConfig):
INPUT_SIZE = "input_size"


class NormalizedSegformerConfig(NormalizedVisionConfig):
NUM_ATTENTION_HEADS = "num_attention_heads"
HIDDEN_SIZE = "hidden_sizes"

# If the attribute is a list, return 0
# 0 means let the optimizer infer the correct value based on the model graph
def __getattr__(self, attr_name):
attr_value = super().__getattr__(attr_name)
if isinstance(attr_value, list):
attr_value = 0
return attr_value


class NormalizedTextAndVisionConfig(NormalizedTextConfig, NormalizedVisionConfig):
TEXT_CONFIG = None
VISION_CONFIG = None
Expand Down Expand Up @@ -203,7 +216,6 @@ class NormalizedConfigManager:
'owlvit',
'perceiver',
'roformer',
'segformer',
'squeezebert',
'table-transformer',
"""
Expand Down Expand Up @@ -258,6 +270,7 @@ class NormalizedConfigManager:
"regnet": NormalizedVisionConfig,
"resnet": NormalizedVisionConfig,
"roberta": NormalizedTextConfig,
"segformer": NormalizedSegformerConfig,
"speech-to-text": SpeechToTextLikeNormalizedTextConfig,
"splinter": NormalizedTextConfig,
"t5": T5LikeNormalizedTextConfig,
Expand Down
2 changes: 2 additions & 0 deletions tests/onnxruntime/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
AutoOptimizationConfig,
ORTConfig,
ORTModelForImageClassification,
ORTModelForSemanticSegmentation,
ORTModelForSequenceClassification,
ORTOptimizer,
)
Expand Down Expand Up @@ -171,6 +172,7 @@ def test_compare_original_seq2seq_model_with_optimized_model(self, model_cls, mo

# Contribution note: Please add test models in alphabetical order. Find test models here: https://huggingface.co/hf-internal-testing.
SUPPORTED_IMAGE_ARCHITECTURES_WITH_MODEL_ID = (
(ORTModelForSemanticSegmentation, "hf-internal-testing/tiny-random-segformer"),
(ORTModelForImageClassification, "hf-internal-testing/tiny-random-vit"),
)

Expand Down

0 comments on commit f770d85

Please sign in to comment.