diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index b6dbd0cc..91d19f1f 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -2,8 +2,8 @@
# See https://help.github.com/articles/about-code-owners/
# Code
-/onnx2torch @ivkalgin @irkjero
-/tests @ivkalgin @irkjero
+/onnx2torch @ivkalgin @senysenyseny16
+/tests @ivkalgin @senysenyseny16
# Actions
/.github @senysenyseny16
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index eeb9ffb1..d98525e5 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -11,18 +11,18 @@ jobs:
name: Pylint
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v3
- - uses: actions/setup-python@v3
- with:
- python-version: "3.9"
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install pylint
- python -m pip install -e .
- - name: Analysing the code with pylint
- run: |
- pylint --rcfile .pylintrc --output-format=colorized $(git ls-files '*.py')
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v3
+ with:
+ python-version: '3.9'
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install pylint
+ python -m pip install -e .[dev]
+ - name: Analysing the code with pylint
+ run: |
+ pylint --output-format=colorized $(git ls-files '*.py')
lint-python-format:
name: Python format
@@ -31,12 +31,10 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
- python-version: "3.9"
+ python-version: '3.9'
- uses: psf/black@stable
with:
- options: "--check --diff"
+ options: --check --diff
- uses: isort/isort-action@master
with:
- configuration:
- --check
- --diff
+ configuration: --check --diff
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c3c97f07..ba5cc041 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,46 +1,79 @@
+default_install_hook_types: [commit-msg, pre-commit, pre-push]
+
repos:
-- repo: https://github.com/psf/black
- rev: 23.1.0
- hooks:
- - id: black
-- repo: https://github.com/PyCQA/isort
- rev: 5.12.0
- hooks:
- - id: isort
- args:
- [
- "--force-single-line-imports",
- "--ensure-newline-before-comments",
- "--line-length=120",
- "--profile=black",
- ]
-- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.4.0
- hooks:
- - id: check-yaml
- - id: check-toml
- - id: check-json
- - id: end-of-file-fixer
- - id: trailing-whitespace
- - id: check-added-large-files
- - id: check-case-conflict
- - id: check-merge-conflict
- - id: detect-private-key
- - id: end-of-file-fixer
- - id: requirements-txt-fixer
- - id: no-commit-to-branch
- args:
- [
- "-b=main",
- ]
-- repo: https://github.com/PyCQA/pylint
- rev: v2.16.0
- hooks:
- - id: pylint
- args:
- [
- "-rn",
- "-sn",
- "--rcfile=.pylintrc",
- "--output-format=colorized",
- ]
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.5.0
+ hooks:
+ - id: check-yaml
+ - id: check-toml
+ - id: check-json
+ - id: end-of-file-fixer
+ - id: trailing-whitespace
+ - id: check-added-large-files
+ - id: check-case-conflict
+ - id: check-merge-conflict
+ - id: detect-private-key
+      - id: debug-statements
+ - id: detect-aws-credentials
+ args: [--allow-missing-credentials]
+ - id: no-commit-to-branch
+ args: [-b=main]
+
+ - repo: https://github.com/commitizen-tools/commitizen
+ rev: v3.20.0
+ hooks:
+ - id: commitizen
+
+ - repo: https://github.com/gitleaks/gitleaks
+ rev: v8.18.2
+ hooks:
+ - id: gitleaks
+
+ - repo: https://github.com/executablebooks/mdformat
+ rev: 0.7.17
+ hooks:
+ - id: mdformat
+ additional_dependencies:
+ - mdformat-gfm
+ - mdformat-black
+ - mdformat-shfmt
+
+ - repo: https://github.com/lyz-code/yamlfix
+ rev: 1.16.0
+ hooks:
+ - id: yamlfix
+
+ - repo: https://github.com/adrienverge/yamllint.git
+ rev: v1.35.1
+ hooks:
+ - id: yamllint
+ args:
+ - --format
+ - parsable
+ - --strict
+ - -d
+ - '{extends: relaxed, rules: {line-length: {max: 120}}}'
+
+ - repo: https://github.com/psf/black
+ rev: 24.3.0
+ hooks:
+ - id: black
+
+ - repo: https://github.com/PyCQA/isort
+ rev: 5.13.2
+ hooks:
+ - id: isort
+
+ - repo: https://github.com/PyCQA/pylint
+ rev: v3.1.0
+ hooks:
+ - id: pylint
+ language: system
+ args: [-rn, -sn]
+
+ - repo: https://github.com/RobertCraigie/pyright-python
+ rev: v1.1.356
+ hooks:
+ - id: pyright
diff --git a/.pylintrc b/.pylintrc
deleted file mode 100644
index 757551f4..00000000
--- a/.pylintrc
+++ /dev/null
@@ -1,44 +0,0 @@
-[MAIN]
-load-plugins=pylint.extensions.docparams
-
-[BASIC]
-# Good variable names which should always be accepted, separated by a comma.
-good-names=x,y,z
-
-[DESIGN]
-# Maximum number of arguments for function / method
-max-args=12
-# Maximum number of locals for function / method body
-max-locals=30
-# Maximum number of statements in function / method body
-max-statements=60
-# Maximum number of attributes for a class (see R0902).
-max-attributes=20
-# Minimum number of public methods for a class (see R0903).
-min-public-methods=0
-# Maximum number of public methods for a class (see R0904).
-max-public-methods=20
-
-[FORMAT]
-max-line-length=120
-
-[MESSAGES]
-disable=logging-fstring-interpolation, import-error, missing-module-docstring, missing-raises-doc, duplicate-code, fixme, unnecessary-lambda-assignment
-
-[SIMILARITIES]
-# Minimum lines number of a similarity.
-min-similarity-lines=4
-# Ignore comments when computing similarities.
-ignore-comments=yes
-# Ignore docstrings when computing similarities.
-ignore-docstrings=yes
-# Ignore imports when computing similarities.
-ignore-imports=yes
-# Ignore function signatures when computing similarities.
-ignore-signatures=no
-
-[TYPECHECK]
-# List of members which are set dynamically and missed by Pylint inference
-# system, and so shouldn't trigger E1101 when accessed.
-generated-members=numpy.*, torch.*, onnx.*, onnxruntime.*
-ignored-modules=onnx.onnx_ml_pb2
diff --git a/README.md b/README.md
index f0d713f5..dae844c4 100644
--- a/README.md
+++ b/README.md
@@ -11,36 +11,47 @@
onnx2torch is an ONNX to PyTorch converter.
Our converter:
-* Is easy to use – Convert the ONNX model with the function call ``convert``;
-* Is easy to extend – Write your own custom layer in PyTorch and register it with ``@add_converter``;
-* Convert back to ONNX – You can convert the model back to ONNX using the ``torch.onnx.export`` function.
-If you find an issue, please [let us know](https://github.com/ENOT-AutoDL/onnx2torch/issues)! And feel free to create merge requests.
+- Is easy to use – Convert the ONNX model with the function call `convert`;
+- Is easy to extend – Write your own custom layer in PyTorch and register it with `@add_converter`;
+- Convert back to ONNX – You can convert the model back to ONNX using the `torch.onnx.export` function.
+
+If you find an issue, please [let us know](https://github.com/ENOT-AutoDL/onnx2torch/issues)!
+And feel free to create merge requests.
Please note that this converter covers only a limited number of PyTorch / ONNX models and operations.
-Let us know which models you use or want to convert from onnx to torch [here](https://github.com/ENOT-AutoDL/onnx2torch/discussions).
+Let us know which models you use or want to convert from ONNX to PyTorch [here](https://github.com/ENOT-AutoDL/onnx2torch/discussions).
## Installation
```bash
pip install onnx2torch
```
+
or
+
```bash
conda install -c conda-forge onnx2torch
```
@@ -57,7 +68,7 @@ import torch
from onnx2torch import convert
# Path to ONNX model
-onnx_model_path = '/some/path/mobile_net_v2.onnx'
+onnx_model_path = "/some/path/mobile_net_v2.onnx"
# You can pass the path to the onnx model to convert it or...
torch_model_1 = convert(onnx_model_path)
@@ -68,21 +79,22 @@ torch_model_2 = convert(onnx_model)
### Execute
-We can execute the returned ``PyTorch model`` in the same way as the original torch model.
+We can execute the returned PyTorch model in the same way as the original torch model.
```python
import onnxruntime as ort
+
# Create example data
x = torch.ones((1, 2, 224, 224)).cuda()
out_torch = torch_model_1(x)
ort_sess = ort.InferenceSession(onnx_model_path)
-outputs_ort = ort_sess.run(None, {'input': x.numpy()})
+outputs_ort = ort_sess.run(None, {"input": x.numpy()})
# Check the Onnx output against PyTorch
print(torch.max(torch.abs(outputs_ort - out_torch.detach().numpy())))
-print(np.allclose(outputs_ort, out_torch.detach().numpy(), atol=1.e-7))
+print(np.allclose(outputs_ort, out_torch.detach().numpy(), atol=1.0e-7))
```
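As a side note, the snippet above creates `x` on the GPU but then calls `x.numpy()`, and it passes NumPy results to `torch` functions; a self-contained, CPU-only sketch of the same check might look like this (the model path and the input name `"input"` are assumptions carried over from above):

```python
import numpy as np
import onnxruntime as ort
import torch

from onnx2torch import convert

onnx_model_path = "/some/path/mobile_net_v2.onnx"
torch_model = convert(onnx_model_path)

# Keep the tensor on CPU so that x.numpy() works without an explicit copy.
x = torch.ones((1, 2, 224, 224))
out_torch = torch_model(x)

ort_sess = ort.InferenceSession(onnx_model_path)
# run() returns a list of output arrays; take the first one.
outputs_ort = ort_sess.run(None, {"input": x.numpy()})[0]

# Check the ONNX Runtime output against PyTorch.
print(np.max(np.abs(outputs_ort - out_torch.detach().numpy())))
print(np.allclose(outputs_ort, out_torch.detach().numpy(), atol=1.0e-7))
```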
## Models
@@ -90,35 +102,39 @@ print(np.allclose(outputs_ort, out_torch.detach().numpy(), atol=1.e-7))
We have tested the following models:
Segmentation models:
-- [x] DeepLabv3plus
-- [x] DeepLabv3 resnet50 (torchvision)
+
+- [x] DeepLabV3+
+- [x] DeepLabV3 ResNet-50 (TorchVision)
- [x] HRNet
-- [x] UNet (torchvision)
-- [x] FCN resnet50 (torchvision)
-- [x] lraspp mobilenetv3 (torchvision)
+- [x] UNet (TorchVision)
+- [x] FCN ResNet-50 (TorchVision)
+- [x] LRASPP MobileNetV3 (TorchVision)
+
+Detection from MMDetection:
-Detection from MMdetection:
- [x] [SSDLite with MobileNetV2 backbone](https://github.com/open-mmlab/mmdetection)
- [x] [RetinaNet R50](https://github.com/open-mmlab/mmdetection)
- [x] [SSD300 with VGG backbone](https://github.com/open-mmlab/mmdetection)
-- [x] [Yolov3_d53](https://github.com/open-mmlab/mmdetection)
-- [x] [Yolov5](https://github.com/ultralytics/yolov5)
-
-Classification from __torchvision__:
-- [x] Resnet18
-- [x] Resnet50
-- [x] MobileNet v2
-- [x] MobileNet v3 large
-- [x] EfficientNet_b{0, 1, 2, 3}
-- [x] WideResNet50
-- [x] ResNext50
-- [x] VGG16
-- [x] GoogleleNet
+- [x] [YOLOv3 d53](https://github.com/open-mmlab/mmdetection)
+- [x] [YOLOv5](https://github.com/ultralytics/yolov5)
+
+Classification from __TorchVision__:
+
+- [x] ResNet-18
+- [x] ResNet-50
+- [x] MobileNetV2
+- [x] MobileNetV3 Large
+- [x] EfficientNet-B{0, 1, 2, 3}
+- [x] WideResNet-50
+- [x] ResNeXt-50
+- [x] VGG-16
+- [x] GoogLeNet
- [x] MnasNet
- [x] RegNet
Transformers:
-- [x] Vit
+
+- [x] ViT
- [x] Swin
- [x] GPT-J
@@ -127,15 +143,16 @@ Transformers:
## How to add new operations to converter
Here we show how to extend onnx2torch with a new ONNX operation that is supported by both PyTorch and ONNX
and has the same behaviour.
An example of such a module is [Relu](./onnx2torch/node_converters/activations.py).
```python
-@add_converter(operation_type='Relu', version=6)
-@add_converter(operation_type='Relu', version=13)
-@add_converter(operation_type='Relu', version=14)
+@add_converter(operation_type="Relu", version=6)
+@add_converter(operation_type="Relu", version=13)
+@add_converter(operation_type="Relu", version=14)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
return OperationConverterResult(
torch_module=nn.ReLU(),
@@ -143,9 +160,10 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
)
```
-Here we have registered an operation named ``Relu`` for opset versions 6, 13, 14.
-Note that the ``torch_module`` argument in ``OperationConverterResult`` must be a torch.nn.Module, not just a callable object!
+Here we have registered an operation named `Relu` for opset versions 6, 13, and 14.
+Note that the `torch_module` argument in `OperationConverterResult` must be a `torch.nn.Module`, not just a callable object!
If an operation's behaviour differs from one opset version to another, you should implement it separately.
+
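To make the `torch.nn.Module` requirement concrete, here is a minimal sketch of adding the currently unsupported `Acosh` operation by wrapping a plain function in a module (the `OnnxAcosh` class is hypothetical, and the import paths are assumed from the repository layout):

```python
import torch
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node


class OnnxAcosh(nn.Module):
    """Wraps the plain torch.acosh callable in a Module, as the registry requires."""

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return torch.acosh(input_tensor)


@add_converter(operation_type="Acosh", version=9)
def _(node, graph) -> OperationConverterResult:
    del graph  # unused, following the convention adopted in this PR
    return OperationConverterResult(
        torch_module=OnnxAcosh(),  # a Module instance, not the bare callable
        onnx_mapping=onnx_mapping_from_node(node=node),
    )
```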
@@ -156,9 +174,9 @@ An example of such a module is [ScatterND](./onnx2torch/node_converters/scatter_
```python
# It is recommended to use Enum for string ONNX attributes.
class ReductionOnnxAttr(Enum):
- NONE = 'none'
- ADD = 'add'
- MUL = 'mul'
+ NONE = "none"
+ ADD = "add"
+ MUL = "mul"
class OnnxScatterND(nn.Module, OnnxToTorchModuleWithCustomExport):
@@ -177,13 +195,13 @@ class OnnxScatterND(nn.Module, OnnxToTorchModuleWithCustomExport):
if opset_version < 16:
if self._reduction != ReductionOnnxAttr.NONE:
raise ValueError(
- 'ScatterND from opset < 16 does not support'
- f'reduction attribute != {ReductionOnnxAttr.NONE.value},'
- f'got {self._reduction.value}'
+ "ScatterND from opset < 16 does not support"
+ f"reduction attribute != {ReductionOnnxAttr.NONE.value},"
+ f"got {self._reduction.value}"
)
return onnx_attrs
- onnx_attrs['reduction_s'] = self._reduction.value
+ onnx_attrs["reduction_s"] = self._reduction.value
return onnx_attrs
def forward(
@@ -200,23 +218,27 @@ class OnnxScatterND(nn.Module, OnnxToTorchModuleWithCustomExport):
# Please follow our convention, args consists of:
# forward function, operation type, operation inputs, operation attributes.
onnx_attrs = self._onnx_attrs(opset_version=get_onnx_version())
- return DefaultExportToOnnx.export(_forward, 'ScatterND', data, indices, updates, onnx_attrs)
+ return DefaultExportToOnnx.export(
+ _forward, "ScatterND", data, indices, updates, onnx_attrs
+ )
return _forward()
-@add_converter(operation_type='ScatterND', version=11)
-@add_converter(operation_type='ScatterND', version=13)
-@add_converter(operation_type='ScatterND', version=16)
+@add_converter(operation_type="ScatterND", version=11)
+@add_converter(operation_type="ScatterND", version=13)
+@add_converter(operation_type="ScatterND", version=16)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
node_attributes = node.attributes
- reduction = ReductionOnnxAttr(node_attributes.get('reduction', 'none'))
+ reduction = ReductionOnnxAttr(node_attributes.get("reduction", "none"))
return OperationConverterResult(
torch_module=OnnxScatterND(reduction=reduction),
onnx_mapping=onnx_mapping_from_node(node=node),
)
```
-Here we have used a trick to convert the model from torch back to ONNX by defining the custom ``_ScatterNDExportToOnnx``.
+
+Here we have used a trick to convert the model from torch back to ONNX: at export time, `forward` delegates to `DefaultExportToOnnx.export` instead of tracing the PyTorch implementation.
+
## Opset version workaround
@@ -235,13 +257,13 @@ import torch
from onnx2torch import convert
# Load the ONNX model.
-model = onnx.load('model.onnx')
+model = onnx.load("model.onnx")
# Convert the model to the target version.
target_version = 13
converted_model = version_converter.convert_version(model, target_version)
# Convert to torch.
torch_model = convert(converted_model)
-torch.save(torch_model, 'model.pt')
+torch.save(torch_model, "model.pt")
```
@@ -251,6 +273,7 @@ Note: use this only when the model does not convert to PyTorch using the existin
## Citation
To cite onnx2torch use `Cite this repository` button, or:
+
```
@misc{onnx2torch,
title={onnx2torch},
diff --git a/onnx2torch/node_converters/base_element_wise.py b/onnx2torch/node_converters/base_element_wise.py
index 60e94214..762e513a 100644
--- a/onnx2torch/node_converters/base_element_wise.py
+++ b/onnx2torch/node_converters/base_element_wise.py
@@ -1,3 +1,4 @@
+# pylint: disable=missing-docstring
import torch
from torch import nn
@@ -5,7 +6,7 @@
from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport
-class OnnxBaseElementWise(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-docstring
+class OnnxBaseElementWise(nn.Module, OnnxToTorchModuleWithCustomExport):
def __init__(self, op_type: str):
super().__init__()
self._op_type = op_type
@@ -16,17 +17,20 @@ def _broadcast_shape(*tensors: torch.Tensor):
broadcast_shape = torch.broadcast_shapes(*shapes)
return broadcast_shape
- def apply_reduction(self, *tensors: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
+ def apply_reduction(self, *tensors: torch.Tensor) -> torch.Tensor:
+ del tensors
raise NotImplementedError
- def forward(self, *input_tensors: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
+ def forward(self, *input_tensors: torch.Tensor) -> torch.Tensor:
if len(input_tensors) == 1:
# If there is a single element, return it (no op).
# Also, no need for manually building the ONNX node.
return input_tensors[0]
- forward_lambda = lambda: self.apply_reduction(*input_tensors)
+ def _forward() -> torch.Tensor:
+ return self.apply_reduction(*input_tensors)
+
if torch.onnx.is_in_onnx_export():
- return DefaultExportToOnnx.export(forward_lambda, self._op_type, *input_tensors, {})
+ return DefaultExportToOnnx.export(_forward, self._op_type, *input_tensors, {})
- return forward_lambda()
+ return _forward()
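The same refactor — an assigned lambda replaced by a named closure — recurs in every converter below; the new pylint configuration no longer disables `unnecessary-lambda-assignment`, which the old `.pylintrc` did. Schematically, the pattern looks like this (a sketch with a hypothetical `OnnxNeg` module; import paths are assumed to match the modules shown in these hunks):

```python
import torch
from torch import nn

from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx
from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport


class OnnxNeg(nn.Module, OnnxToTorchModuleWithCustomExport):
    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        def _forward() -> torch.Tensor:  # named closure instead of an assigned lambda
            return torch.neg(input_tensor)

        if torch.onnx.is_in_onnx_export():
            # Re-emit a real ONNX 'Neg' node rather than tracing the PyTorch call.
            return DefaultExportToOnnx.export(_forward, 'Neg', input_tensor, {})

        return _forward()
```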
diff --git a/onnx2torch/node_converters/global_average_pool.py b/onnx2torch/node_converters/global_average_pool.py
index f57726dd..4e27287a 100644
--- a/onnx2torch/node_converters/global_average_pool.py
+++ b/onnx2torch/node_converters/global_average_pool.py
@@ -1,3 +1,4 @@
+# pylint: disable=missing-docstring
__all__ = [
'OnnxGlobalAveragePool',
'OnnxGlobalAveragePoolWithKnownInputShape',
@@ -18,8 +19,8 @@
from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport
-class OnnxGlobalAveragePool(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-docstring
- def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
+class OnnxGlobalAveragePool(nn.Module, OnnxToTorchModuleWithCustomExport):
+ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
def _forward():
x_dims = list(range(2, len(input_tensor.shape)))
return torch.mean(input_tensor, dim=x_dims, keepdim=True)
@@ -30,24 +31,23 @@ def _forward():
return _forward()
-class OnnxGlobalAveragePoolWithKnownInputShape(
- nn.Module, OnnxToTorchModuleWithCustomExport
-): # pylint: disable=missing-docstring
+class OnnxGlobalAveragePoolWithKnownInputShape(nn.Module, OnnxToTorchModuleWithCustomExport):
def __init__(self, input_shape: List[int]):
super().__init__()
self._x_dims = list(range(2, len(input_shape)))
- def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
- forward_lambda = lambda: torch.mean(input_tensor, dim=self._x_dims, keepdim=True)
+ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+ def _forward() -> torch.Tensor:
+ return torch.mean(input_tensor, dim=self._x_dims, keepdim=True)
if torch.onnx.is_in_onnx_export():
- return DefaultExportToOnnx.export(forward_lambda, 'GlobalAveragePool', input_tensor, {})
+ return DefaultExportToOnnx.export(_forward, 'GlobalAveragePool', input_tensor, {})
- return forward_lambda()
+ return _forward()
@add_converter(operation_type='GlobalAveragePool', version=1)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
input_value_info = graph.value_info[node.input_values[0]]
input_shape = get_shape_from_value_info(input_value_info)
diff --git a/onnx2torch/node_converters/logical.py b/onnx2torch/node_converters/logical.py
index 79ea08c0..ac5b4681 100644
--- a/onnx2torch/node_converters/logical.py
+++ b/onnx2torch/node_converters/logical.py
@@ -1,3 +1,4 @@
+# pylint: disable=missing-docstring
__all__ = [
'OnnxNot',
'OnnxLogical',
@@ -25,16 +26,18 @@
}
-class OnnxNot(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-class-docstring
- def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
- forward_lambda = lambda: torch.logical_not(input_tensor)
+class OnnxNot(nn.Module, OnnxToTorchModuleWithCustomExport):
+ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+ def _forward() -> torch.Tensor:
+ return torch.logical_not(input_tensor)
+
if torch.onnx.is_in_onnx_export():
- return DefaultExportToOnnx.export(forward_lambda, 'Not', input_tensor, {})
+ return DefaultExportToOnnx.export(_forward, 'Not', input_tensor, {})
- return forward_lambda()
+ return _forward()
-class OnnxLogical(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring
+class OnnxLogical(nn.Module, OnnxToTorchModule):
def __init__(self, operation_type: str, broadcast: Optional[int] = None, axis: Optional[int] = None):
super().__init__()
self.broadcast = broadcast
@@ -42,9 +45,7 @@ def __init__(self, operation_type: str, broadcast: Optional[int] = None, axis: O
self.logic_op_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type]
- def forward( # pylint: disable=missing-function-docstring
- self, first_tensor: torch.Tensor, second_tensor: torch.Tensor
- ):
+ def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor):
if self.broadcast == 1 and self.axis is not None:
second_tensor = old_style_broadcast(first_tensor, second_tensor, self.axis)
@@ -57,7 +58,8 @@ def forward( # pylint: disable=missing-function-docstring
@add_converter(operation_type='And', version=7)
@add_converter(operation_type='Or', version=1)
@add_converter(operation_type='Or', version=7)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+ del graph
return OperationConverterResult(
torch_module=OnnxLogical(
operation_type=node.operation_type,
@@ -69,7 +71,8 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint:
@add_converter(operation_type='Not', version=1)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+ del graph
return OperationConverterResult(
torch_module=OnnxNot(),
onnx_mapping=onnx_mapping_from_node(node=node),
diff --git a/onnx2torch/node_converters/nms.py b/onnx2torch/node_converters/nms.py
index 37bbfa46..c6b87317 100644
--- a/onnx2torch/node_converters/nms.py
+++ b/onnx2torch/node_converters/nms.py
@@ -1,3 +1,4 @@
+# pylint: disable=missing-docstring
__all__ = [
'OnnxNonMaxSuppression',
]
@@ -20,15 +21,16 @@
from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport
-class OnnxNonMaxSuppression(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-class-docstring
+class OnnxNonMaxSuppression(nn.Module, OnnxToTorchModuleWithCustomExport):
def __init__(self, center_point_box: int = 0):
super().__init__()
self._center_point_box = center_point_box
def _onnx_attrs(self, opset_version: int) -> Dict[str, Any]:
+ del opset_version
return {'center_point_box_i': self._center_point_box}
- def forward( # pylint: disable=missing-function-docstring
+ def forward(
self,
boxes: torch.Tensor,
scores: torch.Tensor,
@@ -36,7 +38,8 @@ def forward( # pylint: disable=missing-function-docstring
iou_threshold: Optional[torch.Tensor] = None,
score_threshold: Optional[torch.Tensor] = None,
) -> torch.Tensor:
- forward_lambda = lambda: self._nms(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
+ def _forward() -> torch.Tensor:
+ return self._nms(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
if torch.onnx.is_in_onnx_export():
if max_output_boxes_per_class is None:
@@ -48,7 +51,7 @@ def forward( # pylint: disable=missing-function-docstring
onnx_attrs = self._onnx_attrs(opset_version=get_onnx_version())
return DefaultExportToOnnx.export(
- forward_lambda,
+ _forward,
'NonMaxSuppression',
boxes,
scores,
@@ -58,7 +61,7 @@ def forward( # pylint: disable=missing-function-docstring
onnx_attrs,
)
- return forward_lambda()
+ return _forward()
def _nms(
self,
@@ -109,7 +112,8 @@ def _nms(
@add_converter(operation_type='NonMaxSuppression', version=10)
@add_converter(operation_type='NonMaxSuppression', version=11)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+ del graph
center_point_box = node.attributes.get('center_point_box', 0)
return OperationConverterResult(
torch_module=OnnxNonMaxSuppression(center_point_box=center_point_box),
diff --git a/onnx2torch/node_converters/range.py b/onnx2torch/node_converters/range.py
index 8535e613..392e32e8 100644
--- a/onnx2torch/node_converters/range.py
+++ b/onnx2torch/node_converters/range.py
@@ -1,3 +1,4 @@
+# pylint: disable=missing-docstring
__all__ = [
'OnnxRange',
]
@@ -16,7 +17,7 @@
from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport
-class OnnxRange(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-class-docstring
+class OnnxRange(nn.Module, OnnxToTorchModuleWithCustomExport):
def __init__(self):
super().__init__()
self.register_buffer('dummy_buffer', torch.Tensor(), persistent=False)
@@ -41,22 +42,24 @@ def _arange(
device=self.dummy_buffer.device,
)
- def forward( # pylint: disable=missing-function-docstring
+ def forward(
self,
start: Union[torch.Tensor, float, int],
limit: Union[torch.Tensor, float, int],
delta: Union[torch.Tensor, float, int],
) -> torch.Tensor:
- forward_lambda = lambda: self._arange(start, limit, delta)
+ def _forward() -> torch.Tensor:
+ return self._arange(start, limit, delta)
if torch.onnx.is_in_onnx_export():
- return DefaultExportToOnnx.export(forward_lambda, 'Range', start, limit, delta, {})
+ return DefaultExportToOnnx.export(_forward, 'Range', start, limit, delta, {})
- return forward_lambda()
+ return _forward()
@add_converter(operation_type='Range', version=11)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+ del graph
return OperationConverterResult(
torch_module=OnnxRange(),
onnx_mapping=onnx_mapping_from_node(node),
diff --git a/onnx2torch/node_converters/registry.py b/onnx2torch/node_converters/registry.py
index 3d0b02d3..b4ddd174 100644
--- a/onnx2torch/node_converters/registry.py
+++ b/onnx2torch/node_converters/registry.py
@@ -37,7 +37,7 @@ def deco(converter: TConverter):
raise ValueError(f'Operation "{description}" already registered')
_CONVERTER_REGISTRY[description] = converter
- _LOGGER.info(f'Operation converter registered {description}')
+ _LOGGER.debug(f'Operation converter registered {description}')
return converter
diff --git a/onnx2torch/node_converters/reshape.py b/onnx2torch/node_converters/reshape.py
index edb8f05f..e8ec3aa7 100644
--- a/onnx2torch/node_converters/reshape.py
+++ b/onnx2torch/node_converters/reshape.py
@@ -27,12 +27,13 @@ def forward( # pylint: disable=missing-function-docstring
input_tensor: torch.Tensor,
shape: torch.Tensor,
) -> torch.Tensor:
- forward_lambda = lambda: self._do_reshape(input_tensor, shape)
+ def _forward() -> torch.Tensor:
+ return self._do_reshape(input_tensor, shape)
if torch.onnx.is_in_onnx_export():
- return DefaultExportToOnnx.export(forward_lambda, 'Reshape', input_tensor, shape, {})
+ return DefaultExportToOnnx.export(_forward, 'Reshape', input_tensor, shape, {})
- return forward_lambda()
+ return _forward()
@add_converter(operation_type='Reshape', version=5)
diff --git a/onnx2torch/node_converters/tile.py b/onnx2torch/node_converters/tile.py
index 9cfd0db4..0508a877 100644
--- a/onnx2torch/node_converters/tile.py
+++ b/onnx2torch/node_converters/tile.py
@@ -1,3 +1,4 @@
+# pylint: disable=missing-docstring
__all__ = [
'OnnxTile',
]
@@ -14,23 +15,21 @@
from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport
-class OnnxTile(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-class-docstring
- def forward( # pylint: disable=missing-function-docstring
- self,
- input_tensor: torch.Tensor,
- repeats: torch.Tensor,
- ) -> torch.Tensor:
- # torch.tile(input_tensor, repeats) is not supported for exporting
- forward_lambda = lambda: input_tensor.repeat(torch.Size(repeats))
+class OnnxTile(nn.Module, OnnxToTorchModuleWithCustomExport):
+ def forward(self, input_tensor: torch.Tensor, repeats: torch.Tensor) -> torch.Tensor:
+ def _forward() -> torch.Tensor:
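+            # torch.tile(input_tensor, repeats) is not supported for exporting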
+ return input_tensor.repeat(torch.Size(repeats))
+
if torch.onnx.is_in_onnx_export():
- return DefaultExportToOnnx.export(forward_lambda, 'Tile', input_tensor, repeats, {})
+ return DefaultExportToOnnx.export(_forward, 'Tile', input_tensor, repeats, {})
- return forward_lambda()
+ return _forward()
@add_converter(operation_type='Tile', version=6)
@add_converter(operation_type='Tile', version=13)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+ del graph
return OperationConverterResult(
torch_module=OnnxTile(),
onnx_mapping=onnx_mapping_from_node(node=node),
diff --git a/operators.md b/operators.md
index be77b4e1..949a1c6e 100644
--- a/operators.md
+++ b/operators.md
@@ -2,177 +2,177 @@
Minimum tested opset version: 9; maximum tested opset version: 16; recommended opset version: 13.
-| Operation type | Supported | Restrictions |
-|---------------------------|-----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| Abs | Y | |
-| Acos | Y | |
-| Acosh | N | |
-| Add | Y | |
-| And | Y | |
-| ArgMax | N | |
-| ArgMin | N | |
-| Asin | Y | |
-| Asinh | N | |
-| Atan | Y | |
-| Atanh | N | |
-| AveragePool | Y | Average pool operation with spatial rank > 3 is not implemented |
-| BatchNormalization | Y | BatchNorm operation with spatial rank > 3 is not implemented. BatchNorm nodes in training mode are not supported |
-| BitShift | N | |
-| Cast | Y | |
-| Ceil | Y | |
-| Clip | Y | Dynamic value of min/max is not implemented |
-| Compress | N | |
-| Concat | Y | |
-| ConcatFromSequence | N | |
-| Constant | Y | |
-| ConstantOfShape | Y | Parameter "value" must be scalar |
-| Conv | Y | Convolution operation with spatial rank > 3 is not implemented |
-| ConvInteger | N | |
-| ConvTranspose | Y | Convolution operation with spatial rank > 3 is not implemented |
-| Cos | Y | |
-| Cosh | N | |
-| CumSum | Y | |
-| DepthToSpace | Y | DCR mode is not implemented |
-| DequantizeLinear | N | |
-| Det | N | |
-| Div | Y | |
-| Dropout | Y | |
-| Einsum | Y | |
-| Elu | Y | |
-| Equal | Y | |
-| Erf | Y | |
-| Exp | Y | |
-| Expand | Y | |
-| EyeLike | Y | |
-| Flatten | Y | |
-| Floor | Y | |
-| GRU | N | |
-| Gather | Y | |
-| GatherElements | Y | |
-| GatherND | Y | GatherND operation with parameter "batch_dims" > 0 is not implemented |
-| Gemm | Y | |
-| GlobalAveragePool | Y | |
-| GlobalLpPool | N | |
-| GlobalMaxPool | N | |
-| Greater | Y | |
-| GridSample | N | |
-| HardSigmoid | Y | |
-| Hardmax | N | |
-| Identity | Y | |
-| If | N | |
-| InstanceNormalization | Y | |
-| IsInf | Y | |
-| IsNaN | Y | |
-| LayerNormalization | Y | LayerNormalization outputs "Mean" and "InvStdDev" are not implemented |
-| LRN | Y | |
-| LSTM | N | |
-| LeakyRelu | Y | |
-| Less | Y | |
-| Log | Y | |
-| Loop | N | |
-| LpNormalization | N | |
-| LpPool | N | |
-| MatMul | Y | |
-| MatMulInteger | N | |
-| Max | Y | |
-| MaxPool | Y | Max pool operation with spatial rank > 3 is not implemented |
-| MaxRoiPool | N | |
-| MaxUnpool | N | |
-| Mean | Y | |
-| Min | Y | |
-| Mod | Y | |
-| Mul | Y | |
-| Multinomial | N | |
-| Neg | Y | |
-| NonMaxSuppression | Y | |
-| NonZero | Y | |
-| Not | Y | |
-| OneHot | N | |
-| Optional | N | |
-| OptionalGetElement | N | |
-| OptionalHasElement | N | |
-| Or | Y | |
-| PRelu | Y | |
-| Pad | Y | Padding is implemented to pad the last 3 dimensions of 5D input tensor, or the last 2 dimensions of 4D input tensor, or the last dimension of 3D input tensor |
-| Pow | Y | |
-| QLinearConv | N | |
-| QLinearMatMul | N | |
-| QuantizeLinear | N | |
-| RNN | N | |
-| RandomNormal | N | |
-| RandomNormalLike | N | |
-| RandomUniform | N | |
-| RandomUniformLike | N | |
-| Reciprocal | Y | |
-| ReduceL1 | Y | |
-| ReduceL2 | Y | |
-| ReduceLogSum | Y | |
-| ReduceLogSumExp | Y | |
-| ReduceMax | Y | |
-| ReduceMean | Y | |
-| ReduceMin | Y | |
-| ReduceProd | Y | |
-| ReduceSum | Y | |
-| ReduceSumSquare | Y | |
-| Relu | Y | |
-| Reshape | Y | Parameter "allowzero" = 1 is not implemented |
-| Resize | Y | Roi logic is not implemented (pytorch's interpolate cannot resize channel or batch dimensions) |
-| ReverseSequence | N | |
-| RoiAlign | Y | Only "avg" mode is supported |
-| Round | Y | |
-| Scan | N | |
-| Scatter(deprecated) | N | |
-| ScatterElements | N | |
-| ScatterND | Y | Only "none" reduction is supported |
-| Selu | Y | Parameters "alpha" and "gamma" must be default |
-| SequenceAt | N | |
-| SequenceConstruct | N | |
-| SequenceEmpty | N | |
-| SequenceErase | N | |
-| SequenceInsert | N | |
-| SequenceLength | N | |
-| Shape | Y | |
-| Shrink | N | |
-| Sigmoid | Y | |
-| Sign | Y | |
-| Sin | Y | |
-| Sinh | N | |
-| Size | N | |
-| Slice | Y | |
-| Softplus | Y | |
-| Softsign | Y | |
-| SpaceToDepth | N | |
-| Split | Y | |
-| SplitToSequence | N | |
-| Sqrt | Y | |
-| Squeeze | Y | |
-| StringNormalizer | N | |
-| Sub | Y | |
-| Sum | Y | |
-| Tan | Y | |
-| Tanh | Y | |
-| TfIdfVectorizer | N | |
-| ThresholdedRelu | N | |
-| Tile | Y | |
-| TopK | Y | |
-| Transpose | Y | |
-| Trilu | N | |
-| Unique | N | |
-| Unsqueeze | Y | |
-| Upsample(deprecated) | N | |
-| Where | Y | |
-| Xor | Y | |
-| Function | N | |
-| Bernoulli | N | |
-| CastLike | N | |
-| Celu | Y | |
-| DynamicQuantizeLinear | N | |
-| GreaterOrEqual | Y | |
-| HardSwish | Y | |
-| LessOrEqual | Y | |
-| LogSoftmax | Y | |
-| MeanVarianceNormalization | N | |
-| NegativeLogLikelihoodLoss | N | |
-| Range | Y | |
-| SequenceMap | N | |
-| Softmax | Y | |
+| Operation type | Supported | Restrictions |
+| ------------------------- | --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| Abs | Y | |
+| Acos | Y | |
+| Acosh | N | |
+| Add | Y | |
+| And | Y | |
+| ArgMax | N | |
+| ArgMin | N | |
+| Asin | Y | |
+| Asinh | N | |
+| Atan | Y | |
+| Atanh | N | |
+| AveragePool | Y | Average pool operation with spatial rank > 3 is not implemented |
+| BatchNormalization | Y | BatchNorm operation with spatial rank > 3 is not implemented. BatchNorm nodes in training mode are not supported |
+| BitShift | N | |
+| Cast | Y | |
+| Ceil | Y | |
+| Clip | Y | Dynamic value of min/max is not implemented |
+| Compress | N | |
+| Concat | Y | |
+| ConcatFromSequence | N | |
+| Constant | Y | |
+| ConstantOfShape | Y | Parameter "value" must be scalar |
+| Conv | Y | Convolution operation with spatial rank > 3 is not implemented |
+| ConvInteger | N | |
+| ConvTranspose | Y | Convolution operation with spatial rank > 3 is not implemented |
+| Cos | Y | |
+| Cosh | N | |
+| CumSum | Y | |
+| DepthToSpace | Y | DCR mode is not implemented |
+| DequantizeLinear | N | |
+| Det | N | |
+| Div | Y | |
+| Dropout | Y | |
+| Einsum | Y | |
+| Elu | Y | |
+| Equal | Y | |
+| Erf | Y | |
+| Exp | Y | |
+| Expand | Y | |
+| EyeLike | Y | |
+| Flatten | Y | |
+| Floor | Y | |
+| GRU | N | |
+| Gather | Y | |
+| GatherElements | Y | |
+| GatherND | Y | GatherND operation with parameter "batch_dims" > 0 is not implemented |
+| Gemm | Y | |
+| GlobalAveragePool | Y | |
+| GlobalLpPool | N | |
+| GlobalMaxPool | N | |
+| Greater | Y | |
+| GridSample | N | |
+| HardSigmoid | Y | |
+| Hardmax | N | |
+| Identity | Y | |
+| If | N | |
+| InstanceNormalization | Y | |
+| IsInf | Y | |
+| IsNaN | Y | |
+| LayerNormalization | Y | LayerNormalization outputs "Mean" and "InvStdDev" are not implemented |
+| LRN | Y | |
+| LSTM | N | |
+| LeakyRelu | Y | |
+| Less | Y | |
+| Log | Y | |
+| Loop | N | |
+| LpNormalization | N | |
+| LpPool | N | |
+| MatMul | Y | |
+| MatMulInteger | N | |
+| Max | Y | |
+| MaxPool | Y | Max pool operation with spatial rank > 3 is not implemented |
+| MaxRoiPool | N | |
+| MaxUnpool | N | |
+| Mean | Y | |
+| Min | Y | |
+| Mod | Y | |
+| Mul | Y | |
+| Multinomial | N | |
+| Neg | Y | |
+| NonMaxSuppression | Y | |
+| NonZero | Y | |
+| Not | Y | |
+| OneHot | N | |
+| Optional | N | |
+| OptionalGetElement | N | |
+| OptionalHasElement | N | |
+| Or | Y | |
+| PRelu | Y | |
+| Pad | Y | Padding is implemented to pad the last 3 dimensions of 5D input tensor, or the last 2 dimensions of 4D input tensor, or the last dimension of 3D input tensor |
+| Pow | Y | |
+| QLinearConv | N | |
+| QLinearMatMul | N | |
+| QuantizeLinear | N | |
+| RNN | N | |
+| RandomNormal | N | |
+| RandomNormalLike | N | |
+| RandomUniform | N | |
+| RandomUniformLike | N | |
+| Reciprocal | Y | |
+| ReduceL1 | Y | |
+| ReduceL2 | Y | |
+| ReduceLogSum | Y | |
+| ReduceLogSumExp | Y | |
+| ReduceMax | Y | |
+| ReduceMean | Y | |
+| ReduceMin | Y | |
+| ReduceProd | Y | |
+| ReduceSum | Y | |
+| ReduceSumSquare | Y | |
+| Relu | Y | |
+| Reshape | Y | Parameter "allowzero" = 1 is not implemented |
+| Resize | Y | Roi logic is not implemented (pytorch's interpolate cannot resize channel or batch dimensions) |
+| ReverseSequence | N | |
+| RoiAlign | Y | Only "avg" mode is supported |
+| Round | Y | |
+| Scan | N | |
+| Scatter(deprecated) | N | |
+| ScatterElements | N | |
+| ScatterND | Y | Only "none" reduction is supported |
+| Selu | Y | Parameters "alpha" and "gamma" must be default |
+| SequenceAt | N | |
+| SequenceConstruct | N | |
+| SequenceEmpty | N | |
+| SequenceErase | N | |
+| SequenceInsert | N | |
+| SequenceLength | N | |
+| Shape | Y | |
+| Shrink | N | |
+| Sigmoid | Y | |
+| Sign | Y | |
+| Sin | Y | |
+| Sinh | N | |
+| Size | N | |
+| Slice | Y | |
+| Softplus | Y | |
+| Softsign | Y | |
+| SpaceToDepth | N | |
+| Split | Y | |
+| SplitToSequence | N | |
+| Sqrt | Y | |
+| Squeeze | Y | |
+| StringNormalizer | N | |
+| Sub | Y | |
+| Sum | Y | |
+| Tan | Y | |
+| Tanh | Y | |
+| TfIdfVectorizer | N | |
+| ThresholdedRelu | N | |
+| Tile | Y | |
+| TopK | Y | |
+| Transpose | Y | |
+| Trilu | N | |
+| Unique | N | |
+| Unsqueeze | Y | |
+| Upsample(deprecated) | N | |
+| Where | Y | |
+| Xor | Y | |
+| Function | N | |
+| Bernoulli | N | |
+| CastLike | N | |
+| Celu | Y | |
+| DynamicQuantizeLinear | N | |
+| GreaterOrEqual | Y | |
+| HardSwish | Y | |
+| LessOrEqual | Y | |
+| LogSoftmax | Y | |
+| MeanVarianceNormalization | N | |
+| NegativeLogLikelihoodLoss | N | |
+| Range | Y | |
+| SequenceMap | N | |
+| Softmax | Y | |
diff --git a/pyproject.toml b/pyproject.toml
index c4d03d30..b438a774 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,29 @@ repository = 'https://github.com/ENOT-AutoDL/onnx2torch'
[tool.setuptools.packages.find]
include = ['onnx2torch*']
+[tool.commitizen]
+name = 'cz_conventional_commits'
+tag_format = '$version'
+version_scheme = 'pep440'
+version_provider = 'pep621'
+update_changelog_on_bump = true
+major_version_zero = true
+
+[tool.docformatter]
+recursive = true
+wrap-summaries = 0
+wrap-descriptions = 0
+blank = true
+black = true
+pre-summary-newline = true
+
+[tool.yamlfix]
+line_length = 120
+explicit_start = false
+sequence_style = 'keep_style'
+whitelines = 1
+section_whitelines = 1
+
[tool.black]
line-length = 120
target-version = ['py36', 'py37', 'py38', 'py39']
@@ -51,3 +74,36 @@ profile = 'black'
line_length = 120
ensure_newline_before_comments = true
force_single_line = true
+
+[tool.pylint.master]
+load-plugins = ['pylint.extensions.docparams']
+
+[tool.pylint.format]
+max-line-length = 120
+
+[tool.pylint.design]
+max-args = 12
+max-locals = 30
+max-attributes = 20
+min-public-methods = 0
+
+[tool.pylint.typecheck]
+generated-members = ['torch.*']
+
+[tool.pylint.messages_control]
+disable = [
+ 'logging-fstring-interpolation',
+ 'cyclic-import',
+ 'duplicate-code',
+ 'missing-module-docstring',
+ 'unnecessary-pass',
+ 'no-name-in-module',
+]
+
+[tool.pylint.basic]
+good-names = ['bs', 'bn']
+
+[tool.pyright]
+reportMissingImports = false
+reportMissingTypeStubs = false
+reportWildcardImportFromLibrary = false
diff --git a/tests/node_converters/nms_test.py b/tests/node_converters/nms_test.py
index 6341fdbc..1ffd45ed 100644
--- a/tests/node_converters/nms_test.py
+++ b/tests/node_converters/nms_test.py
@@ -136,7 +136,6 @@ def _test_nms(
'boxes,scores,max_output_boxes_per_class,iou_threshold,score_threshold,center_point_box',
(
(_BOXES_CXCYWH_FORMAT_TEST, _SCORES_CXCYWH_FORMAT_TEST, 3, 0.1, 0.0, 1), # center point box format
- # FIXME
# flipped coordinates
# (_BOXES_FLIPPED_COORDINATES_TEST, _SCORES_FLIPPED_COORDINATES_TEST, 3, 0.5, 0.0, None),
(_BOXES_IDENTICAL_BOXES_TEST, _SCORES_IDENTICAL_BOXES_TEST, 3, 0.5, 0.0, None), # identical boxes