Skip to content

Commit

Permalink
release v1.2.0 (#25)
Browse files Browse the repository at this point in the history
* removed useless context manager (SkipTorchTracing), removed useless debug prints (#18)

Co-authored-by: igor <[email protected]>

* fixed conversion of interpolation modes, fixed checking of empty inputs (#20)

Co-authored-by: igor <[email protected]>

* Fix: average pool (#19)

* fixed avg pool operation
* added tests (avg pool)

Co-authored-by: igor <[email protected]>

* fix_optional_args

* common.py and custom_export_to_onnx.py moved to utils package

* fixed moving OnnxConstantOfShape between devices

* fixed placeholders names

* fixed error string

* Models test (#23)

* Add tests for classification, segmentation and detection models
* Fix readme

Co-authored-by: Peter Ivanov <[email protected]>

* fix empty nms, tracing reduce

* fix bug

* change version to 1.2

* fixed indexing in scatter nd

Co-authored-by: IvanovP <[email protected]>

* Transformers support (#27)

* now we support visual transformers models from tiim such as vit and swin
* added split operation
* added matmul operation

Co-authored-by: IvanovP <[email protected]>
Co-authored-by: igor <[email protected]>

* fixed pylint warnings (#28)

Co-authored-by: igor <[email protected]>

* fix slow scatternd (#31)

fix slow scatter nd

Co-authored-by: IvanovP <[email protected]>

* Roialign, roundings, trigonometry (#33)

* added roialigne
* added trigonometric functions
* added rounding functions

Co-authored-by: IvanovP <[email protected]>
Co-authored-by: igor <[email protected]>

* Fix: export to onnx (#34)

* updated export logic of reshape + slice + shape operations

* added custom export for tile

Co-authored-by: igor <[email protected]>

* Feature: modules markers (#36)

* added modules markers
* added custom export to onnx for gather

Co-authored-by: igor <[email protected]>

* Fix: model tests (#35)

* added logsoftmax operation
* refactored tests resourses

Co-authored-by: IvanovP <[email protected]>
Co-authored-by: igor <[email protected]>

* fix expand test for onnx 1.11

Co-authored-by: ivkalgin <[email protected]>
Co-authored-by: igor <[email protected]>
Co-authored-by: IvanovP <[email protected]>
Co-authored-by: Peter Ivanov <[email protected]>
  • Loading branch information
5 people authored Mar 23, 2022
1 parent 4715ffd commit d2dc0a0
Show file tree
Hide file tree
Showing 87 changed files with 1,494 additions and 517 deletions.
6 changes: 5 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[BASIC]
# Good variable names which should always be accepted, separated by a comma.
good-names=x,y,z

[DESIGN]
# Maximum number of arguments for function / method
max-args=12
Expand All @@ -16,7 +20,7 @@ max-public-methods=20
max-line-length=120

[MESSAGES]
disable=logging-fstring-interpolation
disable=logging-fstring-interpolation,no-self-use

[SIMILARITIES]
# Minimum lines number of a similarity.
Expand Down
49 changes: 35 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Below you can find some examples of use.
### Convert
```python
import torch
from onnx2torch.converter import convert
from onnx2torch import convert

# Path to ONNX model
onnx_model_path = '/some/path/mobile_net_v2.onnx'
Expand Down Expand Up @@ -60,8 +60,34 @@ print(np.allclose(outputs_ort, out_torch.detach().numpy(), atol=1.e-7))
## Models

We have tested the following models:
- [x] ResNet50
- [x] SSDLite with MobileNetV2 backbone

Segmentation models:
- [x] DeepLabv3plus
- [x] DeepLabv3 resnet50 (torchvision)
- [x] HRNet
- [x] UNet (torchvision)
- [x] FCN resnet50 (torchvision)
- [x] lraspp mobilenetv3 (torchvision)

Detection from MMdetection:
- [x] [SSDLite with MobileNetV2 backbone](https://github.com/open-mmlab/mmdetection)
- [x] [RetinaNet R50](https://github.com/open-mmlab/mmdetection)
- [x] [SSD300 with VGG backbone](https://github.com/open-mmlab/mmdetection)
- [x] [Yolov3_d53](https://github.com/open-mmlab/mmdetection)
- [x] [Yolov5](https://github.com/ultralytics/yolov5)

Classification from __torchvision__:
- [x] Resnet18
- [x] Resnet50
- [x] MobileNet v2
- [x] MobileNet v3 large
- [x] EfficientNet_b{0, 1, 2, 3}
- [x] WideResNet50
- [x] ResNext50
- [x] VGG16
- [x] GoogleleNet
- [x] MnasNet
- [x] RegNet

## How to add new operations to converter

Expand All @@ -86,24 +112,19 @@ If Operation's behaviour differs from one opset version to another, you should i
```python
class OnnxExpand(nn.Module):

@staticmethod
def _do_forward(input_tensor: torch.Tensor, shape: torch.Tensor) -> torch.Tensor:
return input_tensor * torch.ones(torch.Size(shape), dtype=input_tensor.dtype, device=input_tensor.device)

def forward(self, *args) -> torch.Tensor:
def forward(self, input_tensor: torch.Tensor, shape: torch.Tensor) -> torch.Tensor:
output = input_tensor * torch.ones(torch.Size(shape), dtype=input_tensor.dtype, device=input_tensor.device)
if torch.onnx.is_in_onnx_export():
with SkipTorchTracing():
output = self._do_forward(*args)
return _ExpandExportToOnnx.set_output_and_apply(output, *args)
return _ExpandExportToOnnx.set_output_and_apply(output, input_tensor, shape)

return self._do_forward(*args)
return output


class _ExpandExportToOnnx(CustomExportToOnnx):

@staticmethod
def symbolic(graph: torch_C.Graph, *args, **kwargs) -> torch_C.Value:
return graph.op('Expand', *args, **kwargs, outputs=1)
def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value:
return graph.op('Expand', *args, outputs=1)


@add_converter(operation_type='Expand', version=8)
Expand Down
1 change: 1 addition & 0 deletions onnx2torch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from onnx2torch.converter import convert
55 changes: 38 additions & 17 deletions onnx2torch/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,29 +39,37 @@ def forward(self, *args, **kwargs): # pylint: disable=no-self-use
raise RuntimeError('Got unexpected "forward" on constant container')


def convert(onnx_model_or_path: Union[str, Path, ModelProto], attach_onnx_mapping: bool = False):
def convert(
onnx_model_or_path: Union[str, Path, ModelProto],
save_input_names: bool = False,
attach_onnx_mapping: bool = False,
) -> fx.GraphModule:
"""Convert model from onnx to PyTorch.
This function build torch.fx GraphModule from onnx ModelProto using operations from the converter registry.
The registered operation can be found in onnx2torch/node_converters
The registered operation can be found in onnx2torch/node_converters.
Usage example:
from onnx2torch.converter import convert
from onnx2torch import convert
torch_module = convert('path/to/onnx_model.onnx')
Parameters
----------
onnx_model_or_path:
onnx_model_or_path : Union[str, Path, ModelProto]
Onnx ModelProto or model path to convert.
attach_onnx_mapping:
save_input_names : bool
Whether to use original onnx inputs names as fx graph placeholders names or to use generated names (input_n).
False by default.
attach_onnx_mapping : bool
Whether to attach info about mapping to original onnx tensors names.
Returns
-------
:
fx.GraphModule
PyTorch GraphModule
"""

if isinstance(onnx_model_or_path, ModelProto):
Expand Down Expand Up @@ -90,8 +98,16 @@ def convert(onnx_model_or_path: Union[str, Path, ModelProto], attach_onnx_mappin
torch_nodes = {}

# create input nodes
for name in onnx_graph.input_values:
torch_nodes[name] = torch_graph.placeholder(name=name)
for i, name in enumerate(onnx_graph.input_values, 1):
if save_input_names:
if not name.isidentifier():
raise ValueError(f'Input name "{name}" cannot be used as name of placeholder in fx.GraphModule.')

placeholder_name = name
else:
placeholder_name = f'input_{i}'

torch_nodes[name] = torch_graph.placeholder(name=placeholder_name)

# create intermediate nodes
# IMPORTANT: nodes already topologically sorted
Expand Down Expand Up @@ -131,10 +147,16 @@ def convert(onnx_model_or_path: Union[str, Path, ModelProto], attach_onnx_mappin
args.append(torch_input_node)

elif value_type == ValueType.GRAPH_INITIALIZER:
# The name of putorch buffer must not contain '.'(dot)
len_torch_initializers = sum(1 for _ in torch_initializers.buffers())
torch_buffer_name = f'onnx_initializer_{len_torch_initializers}'
if value_name not in torch_nodes:
torch_initializers.add_initializer(value_name, onnx_graph.initializers[value_name].to_torch())
torch_nodes[value_name] = torch_graph.get_attr(f'initializers.{value_name}')
args.append(torch_nodes[value_name])
torch_initializers.add_initializer(
torch_buffer_name,
onnx_graph.initializers[value_name].to_torch(),
)
torch_nodes[torch_buffer_name] = torch_graph.get_attr(f'initializers.{torch_buffer_name}')
args.append(torch_nodes[torch_buffer_name])

elif value_type == ValueType.EMPTY:
args.append(None)
Expand All @@ -147,12 +169,11 @@ def convert(onnx_model_or_path: Union[str, Path, ModelProto], attach_onnx_mappin
if None in args:
first_skipped_arg = args.index(None)
forward_args = tuple(inspect.signature(torch_module.forward).parameters.keys())
forward_args = forward_args[first_skipped_arg:]

for arg_name in forward_args:
arg_value = args.pop(first_skipped_arg)
if arg_value is not None:
kwargs[arg_name] = arg_value
forward_args = forward_args[first_skipped_arg:len(args)]
args, kwargs_values = args[:first_skipped_arg], args[first_skipped_arg:]
kwargs.update(
{name: value for name, value in zip(forward_args, kwargs_values) if value is not None}
)

torch_nodes[name] = torch_graph.call_module(module_name=name, args=tuple(args), kwargs=kwargs)

Expand Down
6 changes: 6 additions & 0 deletions onnx2torch/node_converters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from onnx2torch.node_converters.activations import *
from onnx2torch.node_converters.average_pool import *
from onnx2torch.node_converters.batch_norm import *
from onnx2torch.node_converters.binary_math_operations import *
from onnx2torch.node_converters.cast import *
Expand All @@ -10,21 +11,26 @@
from onnx2torch.node_converters.conv import *
from onnx2torch.node_converters.expand import *
from onnx2torch.node_converters.flatten import *
from onnx2torch.node_converters.functions import *
from onnx2torch.node_converters.gather import *
from onnx2torch.node_converters.gemm import *
from onnx2torch.node_converters.global_average_pool import *
from onnx2torch.node_converters.identity import *
from onnx2torch.node_converters.logical import *
from onnx2torch.node_converters.matmul import *
from onnx2torch.node_converters.max_pool import *
from onnx2torch.node_converters.nms import *
from onnx2torch.node_converters.pow import *
from onnx2torch.node_converters.range import *
from onnx2torch.node_converters.reduce import *
from onnx2torch.node_converters.reshape import *
from onnx2torch.node_converters.resize import *
from onnx2torch.node_converters.roialign import *
from onnx2torch.node_converters.roundings import *
from onnx2torch.node_converters.scatter_nd import *
from onnx2torch.node_converters.shape import *
from onnx2torch.node_converters.slice import *
from onnx2torch.node_converters.split import *
from onnx2torch.node_converters.squeeze import *
from onnx2torch.node_converters.tile import *
from onnx2torch.node_converters.topk import *
Expand Down
47 changes: 32 additions & 15 deletions onnx2torch/node_converters/activations.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
__all__ = ['OnnxExp', 'OnnxHardSigmoid', 'OnnxSoftmaxV1V11']
__all__ = ['OnnxErf', 'OnnxHardSigmoid', 'OnnxSoftmaxV1V11']

import torch
from torch import nn

from onnx2torch.common import OperationConverterResult
from onnx2torch.common import onnx_mapping_from_node
from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node


class OnnxExp(nn.Module):
class OnnxErf(nn.Module, OnnxToTorchModule):

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
return torch.exp(input_tensor)
return torch.erf(input_tensor)


class OnnxHardSigmoid(nn.Module):
class OnnxHardSigmoid(nn.Module, OnnxToTorchModule):
def __init__(self, alpha: float = 0.2, beta: float = 0.5):
super().__init__()
self.alpha = alpha
Expand All @@ -26,24 +27,25 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
return torch.clip(self.alpha * input_tensor + self.beta, min=0.0, max=1.0)


class OnnxSoftmaxV1V11(nn.Module):
def __init__(self, axis: int = 1):
class OnnxSoftmaxV1V11(nn.Module, OnnxToTorchModule):
def __init__(self, axis: int = 1, is_log: bool = False):
super().__init__()
self.axis = axis
self.is_log = is_log

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
shape = input_tensor.shape
result = torch.flatten(input_tensor, start_dim=self.axis)
result = torch.softmax(result, -1)
result = torch.log_softmax(result, -1) if self.is_log else torch.softmax(result, -1)

return torch.reshape(result, shape)


@add_converter(operation_type='Exp', version=6)
@add_converter(operation_type='Exp', version=13)
@add_converter(operation_type='Erf', version=9)
@add_converter(operation_type='Erf', version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
return OperationConverterResult(
torch_module=OnnxExp(),
torch_module=OnnxErf(),
onnx_mapping=onnx_mapping_from_node(node=node),
)

Expand Down Expand Up @@ -71,6 +73,23 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint:
)


@add_converter(operation_type='LogSoftmax', version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
return OperationConverterResult(
torch_module=nn.LogSoftmax(dim=node.attributes.get('axis', -1)),
onnx_mapping=onnx_mapping_from_node(node=node),
)


@add_converter(operation_type='LogSoftmax', version=1)
@add_converter(operation_type='LogSoftmax', version=11)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
return OperationConverterResult(
torch_module=OnnxSoftmaxV1V11(axis=node.attributes.get('axis', 1), is_log=True),
onnx_mapping=onnx_mapping_from_node(node=node),
)


@add_converter(operation_type='Relu', version=6)
@add_converter(operation_type='Relu', version=13)
@add_converter(operation_type='Relu', version=14)
Expand Down Expand Up @@ -102,9 +121,7 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint:

@add_converter(operation_type='Softmax', version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
axis = node.attributes.get('axis', -1)

return OperationConverterResult(
torch_module=torch.nn.Softmax(dim=axis),
torch_module=torch.nn.Softmax(dim=node.attributes.get('axis', -1)),
onnx_mapping=onnx_mapping_from_node(node=node),
)
Loading

0 comments on commit d2dc0a0

Please sign in to comment.