Skip to content

Commit

Permalink
feat: EyeLike
Browse files Browse the repository at this point in the history
Co-authored-by: Mason Ma <[email protected]>
  • Loading branch information
senysenyseny16 and JohnMasoner authored Mar 30, 2023
1 parent 1626c74 commit c8f8ac4
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 1 deletion.
1 change: 1 addition & 0 deletions onnx2torch/node_converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from onnx2torch.node_converters.dropout import *
from onnx2torch.node_converters.einsum import *
from onnx2torch.node_converters.expand import *
from onnx2torch.node_converters.eye_like import *
from onnx2torch.node_converters.flatten import *
from onnx2torch.node_converters.functions import *
from onnx2torch.node_converters.gather import *
Expand Down
67 changes: 67 additions & 0 deletions onnx2torch/node_converters/eye_like.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
__all__ = [
'OnnxEyeLike',
]

from typing import Optional

import torch
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node
from onnx2torch.utils.dtype import onnx_dtype_to_torch_dtype


class OnnxEyeLike(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring
def __init__(self, dtype: Optional[int] = None, k: int = 0): # pylint: disable=invalid-name
super().__init__()
self.dtype = dtype
self.k = k # pylint: disable=invalid-name

def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
if len(x.shape) != 2:
raise ValueError(f'EyeLike only supports 2D tensors, got {len(x.shape)}')

dtype = x.dtype if self.dtype is None else onnx_dtype_to_torch_dtype(self.dtype)
if not isinstance(dtype, torch.dtype):
raise ValueError(f'Expected type of dtype is torch.dtype, got {type(dtype)}')

rows, cols = x.size()
if self.k > rows:
raise ValueError(
f'EyeLike attribute k should be less or equal than the zero dimension of input tensor,'
f'got {self.k} and {rows}'
)

if self.k == 0:
return torch.eye(n=rows, m=cols, dtype=dtype)
if self.k > 0:
return torch.concat(
[
torch.zeros(rows, self.k, dtype=dtype),
torch.eye(n=rows, m=(cols - self.k), dtype=dtype),
],
dim=1,
)
return torch.concat( # k < 0:
[
torch.zeros(-self.k, cols, dtype=dtype),
torch.eye(n=(rows + self.k), m=cols, dtype=dtype),
],
dim=0,
)


@add_converter(operation_type='EyeLike', version=9)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
node_attributes = node.attributes
k = node_attributes.get('k', 0) # pylint: disable=invalid-name
dtype = node_attributes.get('dtype', None)
return OperationConverterResult(
torch_module=OnnxEyeLike(dtype=dtype, k=k),
onnx_mapping=onnx_mapping_from_node(node=node),
)
87 changes: 87 additions & 0 deletions onnx2torch/utils/dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Any
from typing import Dict
from typing import Type
from typing import Union

import numpy as np
import torch


def onnx_dtype_to_torch_dtype(dtype: int) -> Union[torch.dtype, Type[str], Type[bool]]:
"""
Convert ONNX dtype to PyTorch dtype.
Parameters
----------
dtype : int
ONNX data type.
Returns
-------
Union[torch.dtype, Type[str], Type[bool]]
Corresponding PyTorch dtype.
"""
# https://github.com/onnx/onnx/blob/main/onnx/onnx-ml.proto#L485
_dtypes: Dict[int, Union[torch.dtype, Type[str], Type[bool]]] = {
1: torch.float32,
2: torch.uint8,
3: torch.int8,
# 4: UINT16 is not supported: https://github.com/pytorch/pytorch/issues/58734.
5: torch.int16,
6: torch.int32,
7: torch.int64,
8: str,
9: bool,
10: torch.float16,
11: torch.float64,
# 12: UINT32 is not supported: https://github.com/pytorch/pytorch/issues/58734.
# 13: UINT64 is not supported: https://github.com/pytorch/pytorch/issues/58734.
14: torch.complex64,
15: torch.complex128,
16: torch.bfloat16,
}
try:
return _dtypes[dtype]
except KeyError as exc:
raise ValueError(f'dtype={dtype} is not supported') from exc


def onnx_dtype_to_numpy_dtype(dtype: int) -> Union[np.dtype, Type[str], Type[bool]]:
"""
Convert ONNX dtype to Numpy dtype.
Parameters
----------
dtype : int
ONNX data type.
Returns
-------
Union[torch.dtype, Type[str], Type[bool]]
Corresponding Numpy dtype.
"""
# https://numpy.org/doc/stable/reference/arrays.dtypes.html
_dtypes: Dict[int, Any] = {
1: np.float32,
2: np.uint8,
3: np.int8,
4: np.uint16,
5: np.int16,
6: np.int32,
7: np.int64,
8: str,
9: bool,
10: np.float16,
11: np.float64,
12: np.uint32,
13: np.uint64,
14: np.complex64,
15: np.complex128,
# 16: bfloat16 is not supported.
}
try:
return _dtypes[dtype]
except KeyError as exc:
raise ValueError(f'dtype={dtype} is not supported') from exc
2 changes: 1 addition & 1 deletion operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops
| Erf | Y | |
| Exp | Y | |
| Expand | Y | |
| EyeLike | N | |
| EyeLike | Y | |
| Flatten | Y | |
| Floor | Y | |
| GRU | N | |
Expand Down
31 changes: 31 additions & 0 deletions tests/node_converters/eye_like_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Optional
from typing import Tuple

import numpy as np
import onnx
import pytest
from onnx.helper import make_tensor_value_info

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


@pytest.mark.parametrize('dtype', [None, 1, 6, 7, 11])
@pytest.mark.parametrize('k', [-2, -1, 0, 1, 2])
@pytest.mark.parametrize('shape', [[2, 3], [3, 4], [3, 3]])
def test_eye_like( # pylint: disable=missing-function-docstring
shape: Tuple[int],
dtype: Optional[int],
k: int, # pylint: disable=invalid-name
) -> None:
input_values = np.random.randn(*shape).astype(np.float32)
test_inputs = {'x': input_values}

node = onnx.helper.make_node(op_type='EyeLike', inputs=['x'], outputs=['z'], dtype=dtype, k=k)
model = make_model_from_nodes(
nodes=node,
initializers={},
inputs_example=test_inputs,
outputs_info=[make_tensor_value_info(name='z', elem_type=dtype, shape=shape)] if dtype else None,
)
check_onnx_model(model, test_inputs)

0 comments on commit c8f8ac4

Please sign in to comment.