From c8f8ac4e32c4d79402a82a0068a05c4fdc726330 Mon Sep 17 00:00:00 2001 From: Arseny <82811840+senysenyseny16@users.noreply.github.com> Date: Thu, 30 Mar 2023 23:17:48 +0600 Subject: [PATCH] feat: EyeLike Co-authored-by: Mason Ma --- onnx2torch/node_converters/__init__.py | 1 + onnx2torch/node_converters/eye_like.py | 67 ++++++++++++++++++++ onnx2torch/utils/dtype.py | 87 ++++++++++++++++++++++++++ operators.md | 2 +- tests/node_converters/eye_like_test.py | 31 +++++++++ 5 files changed, 187 insertions(+), 1 deletion(-) create mode 100644 onnx2torch/node_converters/eye_like.py create mode 100644 onnx2torch/utils/dtype.py create mode 100644 tests/node_converters/eye_like_test.py diff --git a/onnx2torch/node_converters/__init__.py b/onnx2torch/node_converters/__init__.py index b01bbee9..dd40354a 100644 --- a/onnx2torch/node_converters/__init__.py +++ b/onnx2torch/node_converters/__init__.py @@ -13,6 +13,7 @@ from onnx2torch.node_converters.dropout import * from onnx2torch.node_converters.einsum import * from onnx2torch.node_converters.expand import * +from onnx2torch.node_converters.eye_like import * from onnx2torch.node_converters.flatten import * from onnx2torch.node_converters.functions import * from onnx2torch.node_converters.gather import * diff --git a/onnx2torch/node_converters/eye_like.py b/onnx2torch/node_converters/eye_like.py new file mode 100644 index 00000000..dc13c54f --- /dev/null +++ b/onnx2torch/node_converters/eye_like.py @@ -0,0 +1,67 @@ +__all__ = [ + 'OnnxEyeLike', +] + +from typing import Optional + +import torch +from torch import nn + +from onnx2torch.node_converters.registry import add_converter +from onnx2torch.onnx_graph import OnnxGraph +from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.dtype import onnx_dtype_to_torch_dtype + + +class OnnxEyeLike(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring + def __init__(self, dtype: Optional[int] = None, k: int = 0): # pylint: disable=invalid-name + super().__init__() + self.dtype = dtype + self.k = k # pylint: disable=invalid-name + + def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring + if len(x.shape) != 2: + raise ValueError(f'EyeLike only supports 2D tensors, got {len(x.shape)}') + + dtype = x.dtype if self.dtype is None else onnx_dtype_to_torch_dtype(self.dtype) + if not isinstance(dtype, torch.dtype): + raise ValueError(f'Expected type of dtype is torch.dtype, got {type(dtype)}') + + rows, cols = x.size() + if self.k > rows: + raise ValueError( + f'EyeLike attribute k should be less or equal than the zero dimension of input tensor,' + f'got {self.k} and {rows}' + ) + + if self.k == 0: + return torch.eye(n=rows, m=cols, dtype=dtype) + if self.k > 0: + return torch.concat( + [ + torch.zeros(rows, self.k, dtype=dtype), + torch.eye(n=rows, m=(cols - self.k), dtype=dtype), + ], + dim=1, + ) + return torch.concat( # k < 0: + [ + torch.zeros(-self.k, cols, dtype=dtype), + torch.eye(n=(rows + self.k), m=cols, dtype=dtype), + ], + dim=0, + ) + + +@add_converter(operation_type='EyeLike', version=9) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument + node_attributes = node.attributes + k = node_attributes.get('k', 0) # pylint: disable=invalid-name + dtype = node_attributes.get('dtype', None) + return OperationConverterResult( + torch_module=OnnxEyeLike(dtype=dtype, k=k), + onnx_mapping=onnx_mapping_from_node(node=node), + ) diff --git a/onnx2torch/utils/dtype.py b/onnx2torch/utils/dtype.py new file mode 100644 index 00000000..5c3c6e26 --- /dev/null +++ b/onnx2torch/utils/dtype.py @@ -0,0 +1,87 @@ +from typing import Any +from typing import Dict +from typing import Type +from typing import Union + +import numpy as np +import torch + + +def onnx_dtype_to_torch_dtype(dtype: int) -> Union[torch.dtype, Type[str], Type[bool]]: + """ + Convert ONNX dtype to PyTorch dtype. + + Parameters + ---------- + dtype : int + ONNX data type. + + Returns + ------- + Union[torch.dtype, Type[str], Type[bool]] + Corresponding PyTorch dtype. + + """ + # https://github.com/onnx/onnx/blob/main/onnx/onnx-ml.proto#L485 + _dtypes: Dict[int, Union[torch.dtype, Type[str], Type[bool]]] = { + 1: torch.float32, + 2: torch.uint8, + 3: torch.int8, + # 4: UINT16 is not supported: https://github.com/pytorch/pytorch/issues/58734. + 5: torch.int16, + 6: torch.int32, + 7: torch.int64, + 8: str, + 9: bool, + 10: torch.float16, + 11: torch.float64, + # 12: UINT32 is not supported: https://github.com/pytorch/pytorch/issues/58734. + # 13: UINT64 is not supported: https://github.com/pytorch/pytorch/issues/58734. + 14: torch.complex64, + 15: torch.complex128, + 16: torch.bfloat16, + } + try: + return _dtypes[dtype] + except KeyError as exc: + raise ValueError(f'dtype={dtype} is not supported') from exc + + +def onnx_dtype_to_numpy_dtype(dtype: int) -> Union[np.dtype, Type[str], Type[bool]]: + """ + Convert ONNX dtype to Numpy dtype. + + Parameters + ---------- + dtype : int + ONNX data type. + + Returns + ------- + Union[torch.dtype, Type[str], Type[bool]] + Corresponding Numpy dtype. + + """ + # https://numpy.org/doc/stable/reference/arrays.dtypes.html + _dtypes: Dict[int, Any] = { + 1: np.float32, + 2: np.uint8, + 3: np.int8, + 4: np.uint16, + 5: np.int16, + 6: np.int32, + 7: np.int64, + 8: str, + 9: bool, + 10: np.float16, + 11: np.float64, + 12: np.uint32, + 13: np.uint64, + 14: np.complex64, + 15: np.complex128, + # 16: bfloat16 is not supported. + } + try: + return _dtypes[dtype] + except KeyError as exc: + raise ValueError(f'dtype={dtype} is not supported') from exc diff --git a/operators.md b/operators.md index f2db4396..c0aa5768 100644 --- a/operators.md +++ b/operators.md @@ -43,7 +43,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops | Erf | Y | | | Exp | Y | | | Expand | Y | | -| EyeLike | N | | +| EyeLike | Y | | | Flatten | Y | | | Floor | Y | | | GRU | N | | diff --git a/tests/node_converters/eye_like_test.py b/tests/node_converters/eye_like_test.py new file mode 100644 index 00000000..bf44e6f5 --- /dev/null +++ b/tests/node_converters/eye_like_test.py @@ -0,0 +1,31 @@ +from typing import Optional +from typing import Tuple + +import numpy as np +import onnx +import pytest +from onnx.helper import make_tensor_value_info + +from tests.utils.common import check_onnx_model +from tests.utils.common import make_model_from_nodes + + +@pytest.mark.parametrize('dtype', [None, 1, 6, 7, 11]) +@pytest.mark.parametrize('k', [-2, -1, 0, 1, 2]) +@pytest.mark.parametrize('shape', [[2, 3], [3, 4], [3, 3]]) +def test_eye_like( # pylint: disable=missing-function-docstring + shape: Tuple[int], + dtype: Optional[int], + k: int, # pylint: disable=invalid-name +) -> None: + input_values = np.random.randn(*shape).astype(np.float32) + test_inputs = {'x': input_values} + + node = onnx.helper.make_node(op_type='EyeLike', inputs=['x'], outputs=['z'], dtype=dtype, k=k) + model = make_model_from_nodes( + nodes=node, + initializers={}, + inputs_example=test_inputs, + outputs_info=[make_tensor_value_info(name='z', elem_type=dtype, shape=shape)] if dtype else None, + ) + check_onnx_model(model, test_inputs)