-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9a58fec
commit 536a299
Showing
4 changed files
with
279 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
__all__ = [ | ||
"OnnxArgExtremumOld", | ||
"OnnxArgExtremum", | ||
] | ||
|
||
from typing import Optional | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torch import nn | ||
|
||
from onnx2torch.node_converters.registry import add_converter | ||
from onnx2torch.onnx_graph import OnnxGraph | ||
from onnx2torch.onnx_node import OnnxNode | ||
from onnx2torch.utils.common import OnnxToTorchModule | ||
from onnx2torch.utils.common import OperationConverterResult | ||
from onnx2torch.utils.common import onnx_mapping_from_node | ||
|
||
DEFAULT_AXIS = 0 | ||
DEFAULT_KEEPDIMS = 1 | ||
DEFAULT_SELECT_LAST_INDEX = 0 | ||
|
||
_TORCH_FUNCTION_FROM_ONNX_TYPE = { | ||
"ArgMax": torch.argmax, | ||
"ArgMin": torch.argmin, | ||
} | ||
|
||
|
||
class OnnxArgExtremumOld(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring | ||
def __init__(self, operation_type: str, axis: int, keepdims: int): | ||
super().__init__() | ||
self.axis = axis | ||
self.keepdims = bool(keepdims) | ||
self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type] | ||
|
||
def forward(self, data: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring | ||
return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims) | ||
|
||
|
||
class OnnxArgExtremum(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring | ||
def __init__(self, operation_type: str, axis: int, keepdims: int, select_last_index: int): | ||
super().__init__() | ||
self.axis = axis | ||
self.keepdims = bool(keepdims) | ||
self.select_last_index = bool(select_last_index) | ||
self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type] | ||
|
||
def forward(self, data: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring | ||
if self.select_last_index: | ||
# torch's argmax does not handle the select_last_index attribute from Onnx. | ||
# We flip the data, call the normal argmax, then map it back to the original | ||
flipped = torch.flip(data, dims=[self.axis]) | ||
|
||
extremum_index_flipped = self.extremum_function(flipped, dim=self.axis, keepdim=self.keepdims) | ||
extremum_index_original = data.size(dim=self.axis) - 1 - extremum_index_flipped | ||
return extremum_index_original | ||
else: | ||
return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims) | ||
|
||
|
||
@add_converter(operation_type="ArgMax", version=12) | ||
@add_converter(operation_type="ArgMax", version=13) | ||
@add_converter(operation_type="ArgMin", version=12) | ||
@add_converter(operation_type="ArgMin", version=13) | ||
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument | ||
return OperationConverterResult( | ||
torch_module=OnnxArgExtremum( | ||
operation_type=node.operation_type, | ||
axis=node.attributes.get("axis", DEFAULT_AXIS), | ||
keepdims=node.attributes.get("keepdims", DEFAULT_KEEPDIMS), | ||
select_last_index=node.attributes.get("select_last_index", DEFAULT_SELECT_LAST_INDEX), | ||
), | ||
onnx_mapping=onnx_mapping_from_node(node=node), | ||
) | ||
|
||
|
||
@add_converter(operation_type="ArgMax", version=11) | ||
@add_converter(operation_type="ArgMin", version=11) | ||
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument | ||
return OperationConverterResult( | ||
torch_module=OnnxArgExtremumOld( | ||
operation_type=node.operation_type, | ||
axis=node.attributes.get("axis", DEFAULT_AXIS), | ||
keepdims=node.attributes.get("keepdims", DEFAULT_KEEPDIMS), | ||
), | ||
onnx_mapping=onnx_mapping_from_node(node=node), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import onnx | ||
from onnx.helper import make_tensor_value_info | ||
import pytest | ||
import torch | ||
|
||
from tests.utils.common import check_onnx_model | ||
from tests.utils.common import make_model_from_nodes | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"op_type", | ||
( | ||
"ArgMax", | ||
"ArgMin", | ||
), | ||
) | ||
@pytest.mark.parametrize( | ||
"opset_version", | ||
( | ||
11, | ||
12, | ||
13, | ||
), | ||
) | ||
@pytest.mark.parametrize( | ||
"dims,axis", | ||
( | ||
(1, 0), | ||
(2, 0), | ||
(2, 1), | ||
(3, 0), | ||
(3, 1), | ||
(3, 2), | ||
(4, 0), | ||
(4, 1), | ||
(4, 2), | ||
(4, 3), | ||
), | ||
) | ||
@pytest.mark.parametrize( | ||
"keepdims", | ||
( | ||
0, | ||
1, | ||
), | ||
) | ||
@pytest.mark.parametrize( | ||
"select_last_index", | ||
(0, 1), | ||
) | ||
def test_arg_max_arg_min( # pylint: disable=missing-function-docstring | ||
op_type: str, | ||
opset_version: int, | ||
dims: int, | ||
axis: int, | ||
keepdims: int, | ||
select_last_index: int, | ||
) -> None: | ||
input_shape = [3] * dims # arbitrary magnitude in each dimension | ||
test_inputs = {"data": np.random.randn(*input_shape).astype(np.float32)} | ||
|
||
kwargs = {"keepdims": keepdims, "axis": axis} | ||
if opset_version >= 12: | ||
# since opset_version 12, we can specify whether to return the LAST index | ||
# of the max/min (respectively) occurance | ||
kwargs["select_last_index"] = select_last_index | ||
|
||
node = onnx.helper.make_node(op_type=op_type, inputs=["data"], outputs=["reduced"], **kwargs) | ||
|
||
# we need to specify outputs_info, since the required output type for arg max (int64) | ||
# is different than the input type | ||
outputs_info = [make_tensor_value_info(name="reduced", elem_type=onnx.TensorProto.INT64, shape=None)] | ||
|
||
model = make_model_from_nodes( | ||
nodes=node, | ||
initializers={}, | ||
inputs_example=test_inputs, | ||
outputs_info=outputs_info, | ||
opset_version=opset_version, | ||
) | ||
|
||
check_onnx_model(model, test_inputs) | ||
|
||
# Test once again with input we know to all be the same. | ||
# This is a way to force the testing of the select_last_index attribute. | ||
# We need the min/max index to occur more than once. | ||
test_inputs2 = {"data": np.ones_like(test_inputs["data"])} | ||
check_onnx_model(model, test_inputs2) | ||
|
||
|
||
class ArgMaxModel(torch.nn.Module): | ||
def __init__(self, axis: int, keepdims: bool): | ||
super().__init__() | ||
self.axis = axis | ||
self.keepdims = bool(keepdims) | ||
|
||
def forward(self, data: torch.Tensor) -> torch.Tensor: | ||
return torch.argmax(data, dim=self.axis, keepdim=self.keepdims) | ||
|
||
|
||
class ArgMinModel(torch.nn.Module): | ||
def __init__(self, axis: int, keepdims: bool): | ||
super().__init__() | ||
self.axis = axis | ||
self.keepdims = bool(keepdims) | ||
|
||
def forward(self, data: torch.Tensor) -> torch.Tensor: | ||
return torch.argmin(data, dim=self.axis, keepdim=self.keepdims) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"op_type", | ||
( | ||
"ArgMax", | ||
"ArgMin", | ||
), | ||
) | ||
@pytest.mark.parametrize( | ||
"opset_version", | ||
( | ||
11, | ||
12, | ||
13, | ||
), | ||
) | ||
@pytest.mark.parametrize( | ||
"dims,axis", | ||
( | ||
(1, 0), | ||
(2, 0), | ||
(2, 1), | ||
(3, 0), | ||
(3, 1), | ||
(3, 2), | ||
(4, 0), | ||
(4, 1), | ||
(4, 2), | ||
(4, 3), | ||
), | ||
) | ||
@pytest.mark.parametrize( | ||
"keepdims", | ||
( | ||
0, | ||
1, | ||
), | ||
) | ||
def test_start_from_torch_module( | ||
op_type: str, | ||
opset_version: int, | ||
dims: int, | ||
axis: int, | ||
keepdims: int, | ||
tmp_path: Path, | ||
) -> None: | ||
""" | ||
Test starting from a torch module, export to Onnx, then converting back to torch. | ||
""" | ||
if op_type == "ArgMax": | ||
model = ArgMaxModel(axis=axis, keepdims=keepdims) | ||
else: | ||
model = ArgMinModel(axis=axis, keepdims=keepdims) | ||
|
||
input_shape = [3] * dims # arbitrary magnitude in each dimension | ||
|
||
# export the pytorch model to onnx | ||
dummy_data = {"data": torch.randn(*input_shape)} | ||
input_names = ["data"] | ||
output_names = ["indices"] | ||
model_path = tmp_path / "model.onnx" | ||
torch.onnx.export( | ||
model, | ||
(dummy_data,), | ||
str(model_path), | ||
export_params=True, | ||
input_names=input_names, | ||
output_names=output_names, | ||
do_constant_folding=False, | ||
training=torch._C._onnx.TrainingMode.TRAINING, | ||
) | ||
|
||
# load the exported onnx file | ||
model = onnx.load(model_path) | ||
onnx.checker.check_model(model, False) | ||
|
||
test_inputs = {"data": np.random.randn(*input_shape).astype(np.float32)} | ||
check_onnx_model(model, test_inputs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters