diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index b6dbd0cc..91d19f1f 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,8 +2,8 @@ # See https://help.github.com/articles/about-code-owners/ # Code -/onnx2torch @ivkalgin @irkjero -/tests @ivkalgin @irkjero +/onnx2torch @ivkalgin @senysenyseny16 +/tests @ivkalgin @senysenyseny16 # Actions /.github @senysenyseny16 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index eeb9ffb1..d98525e5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -11,18 +11,18 @@ jobs: name: Pylint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v3 - with: - python-version: "3.9" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install pylint - python -m pip install -e . - - name: Analysing the code with pylint - run: | - pylint --rcfile .pylintrc --output-format=colorized $(git ls-files '*.py') + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + with: + python-version: '3.9' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pylint + python -m pip install -e .[dev] + - name: Analysing the code with pylint + run: | + pylint --output-format=colorized $(git ls-files '*.py') lint-python-format: name: Python format @@ -31,12 +31,10 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v3 with: - python-version: "3.9" + python-version: '3.9' - uses: psf/black@stable with: - options: "--check --diff" + options: --check --diff - uses: isort/isort-action@master with: - configuration: - --check - --diff + configuration: --check --diff diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c3c97f07..ba5cc041 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,46 +1,79 @@ +default_install_hook_types: [commit-msg, pre-commit, pre-push] + repos: -- repo: https://github.com/psf/black - rev: 23.1.0 - hooks: - - id: black -- repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - args: - [ - "--force-single-line-imports", - "--ensure-newline-before-comments", - "--line-length=120", - "--profile=black", - ] -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 - hooks: - - id: check-yaml - - id: check-toml - - id: check-json - - id: end-of-file-fixer - - id: trailing-whitespace - - id: check-added-large-files - - id: check-case-conflict - - id: check-merge-conflict - - id: detect-private-key - - id: end-of-file-fixer - - id: requirements-txt-fixer - - id: no-commit-to-branch - args: - [ - "-b=main", - ] -- repo: https://github.com/PyCQA/pylint - rev: v2.16.0 - hooks: - - id: pylint - args: - [ - "-rn", - "-sn", - "--rcfile=.pylintrc", - "--output-format=colorized", - ] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-yaml + - id: check-toml + - id: check-json + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-added-large-files + - id: check-case-conflict + - id: check-merge-conflict + - id: detect-private-key + - id: end-of-file-fixer + - id: debug-statements + - id: detect-private-key + - id: detect-aws-credentials + args: [--allow-missing-credentials] + - id: no-commit-to-branch + args: [-b=main] + + - repo: https://github.com/commitizen-tools/commitizen + rev: v3.20.0 + hooks: + - id: commitizen + + - repo: https://github.com/gitleaks/gitleaks + rev: v8.18.2 + hooks: + - id: gitleaks + + - repo: https://github.com/executablebooks/mdformat + rev: 0.7.17 + hooks: + - id: mdformat + additional_dependencies: + - mdformat-gfm + - mdformat-black + - mdformat-shfmt + + - repo: https://github.com/lyz-code/yamlfix + rev: 1.16.0 + hooks: + - id: yamlfix + + - repo: https://github.com/adrienverge/yamllint.git + rev: v1.35.1 + hooks: + - id: yamllint + args: + - --format + - parsable + - --strict + - -d + - '{extends: relaxed, rules: {line-length: {max: 120}}}' + + - repo: https://github.com/psf/black + rev: 24.3.0 + hooks: + - id: black + + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + + - repo: https://github.com/PyCQA/pylint + rev: v3.1.0 + hooks: + - id: pylint + language: system + args: [-rn, -sn] + + - repo: https://github.com/RobertCraigie/pyright-python + rev: v1.1.356 + hooks: + - id: pyright diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 757551f4..00000000 --- a/.pylintrc +++ /dev/null @@ -1,44 +0,0 @@ -[MAIN] -load-plugins=pylint.extensions.docparams - -[BASIC] -# Good variable names which should always be accepted, separated by a comma. -good-names=x,y,z - -[DESIGN] -# Maximum number of arguments for function / method -max-args=12 -# Maximum number of locals for function / method body -max-locals=30 -# Maximum number of statements in function / method body -max-statements=60 -# Maximum number of attributes for a class (see R0902). -max-attributes=20 -# Minimum number of public methods for a class (see R0903). -min-public-methods=0 -# Maximum number of public methods for a class (see R0904). -max-public-methods=20 - -[FORMAT] -max-line-length=120 - -[MESSAGES] -disable=logging-fstring-interpolation, import-error, missing-module-docstring, missing-raises-doc, duplicate-code, fixme, unnecessary-lambda-assignment - -[SIMILARITIES] -# Minimum lines number of a similarity. -min-similarity-lines=4 -# Ignore comments when computing similarities. -ignore-comments=yes -# Ignore docstrings when computing similarities. -ignore-docstrings=yes -# Ignore imports when computing similarities. -ignore-imports=yes -# Ignore function signatures when computing similarities. -ignore-signatures=no - -[TYPECHECK] -# List of members which are set dynamically and missed by Pylint inference -# system, and so shouldn't trigger E1101 when accessed. -generated-members=numpy.*, torch.*, onnx.*, onnxruntime.* -ignored-modules=onnx.onnx_ml_pb2 diff --git a/README.md b/README.md index f0d713f5..dae844c4 100644 --- a/README.md +++ b/README.md @@ -11,36 +11,47 @@ - - - - + - + +
+ + + + + + + + +

onnx2torch is an ONNX to PyTorch converter. Our converter: -* Is easy to use – Convert the ONNX model with the function call ``convert``; -* Is easy to extend – Write your own custom layer in PyTorch and register it with ``@add_converter``; -* Convert back to ONNX – You can convert the model back to ONNX using the ``torch.onnx.export`` function. -If you find an issue, please [let us know](https://github.com/ENOT-AutoDL/onnx2torch/issues)! And feel free to create merge requests. +- Is easy to use – Convert the ONNX model with the function call `convert`; +- Is easy to extend – Write your own custom layer in PyTorch and register it with `@add_converter`; +- Convert back to ONNX – You can convert the model back to ONNX using the `torch.onnx.export` function. + +If you find an issue, please [let us know](https://github.com/ENOT-AutoDL/onnx2torch/issues)! +And feel free to create merge requests. Please note that this converter covers only a limited number of PyTorch / ONNX models and operations. -Let us know which models you use or want to convert from onnx to torch [here](https://github.com/ENOT-AutoDL/onnx2torch/discussions). +Let us know which models you use or want to convert from ONNX to PyTorch [here](https://github.com/ENOT-AutoDL/onnx2torch/discussions). ## Installation ```bash pip install onnx2torch ``` + or + ```bash conda install -c conda-forge onnx2torch ``` @@ -57,7 +68,7 @@ import torch from onnx2torch import convert # Path to ONNX model -onnx_model_path = '/some/path/mobile_net_v2.onnx' +onnx_model_path = "/some/path/mobile_net_v2.onnx" # You can pass the path to the onnx model to convert it or... torch_model_1 = convert(onnx_model_path) @@ -68,21 +79,22 @@ torch_model_2 = convert(onnx_model) ### Execute -We can execute the returned ``PyTorch model`` in the same way as the original torch model. +We can execute the returned `PyTorch model` in the same way as the original torch model. ```python import onnxruntime as ort + # Create example data x = torch.ones((1, 2, 224, 224)).cuda() out_torch = torch_model_1(x) ort_sess = ort.InferenceSession(onnx_model_path) -outputs_ort = ort_sess.run(None, {'input': x.numpy()}) +outputs_ort = ort_sess.run(None, {"input": x.numpy()}) # Check the Onnx output against PyTorch print(torch.max(torch.abs(outputs_ort - out_torch.detach().numpy()))) -print(np.allclose(outputs_ort, out_torch.detach().numpy(), atol=1.e-7)) +print(np.allclose(outputs_ort, out_torch.detach().numpy(), atol=1.0e-7)) ``` ## Models @@ -90,35 +102,39 @@ print(np.allclose(outputs_ort, out_torch.detach().numpy(), atol=1.e-7)) We have tested the following models: Segmentation models: -- [x] DeepLabv3plus -- [x] DeepLabv3 resnet50 (torchvision) + +- [x] DeepLabV3+ +- [x] DeepLabV3 ResNet-50 (TorchVision) - [x] HRNet -- [x] UNet (torchvision) -- [x] FCN resnet50 (torchvision) -- [x] lraspp mobilenetv3 (torchvision) +- [x] UNet (TorchVision) +- [x] FCN ResNet-50 (TorchVision) +- [x] LRASPP MobileNetV3 (TorchVision) + +Detection from MMdetection: -Detection from MMdetection: - [x] [SSDLite with MobileNetV2 backbone](https://github.com/open-mmlab/mmdetection) - [x] [RetinaNet R50](https://github.com/open-mmlab/mmdetection) - [x] [SSD300 with VGG backbone](https://github.com/open-mmlab/mmdetection) -- [x] [Yolov3_d53](https://github.com/open-mmlab/mmdetection) -- [x] [Yolov5](https://github.com/ultralytics/yolov5) - -Classification from __torchvision__: -- [x] Resnet18 -- [x] Resnet50 -- [x] MobileNet v2 -- [x] MobileNet v3 large -- [x] EfficientNet_b{0, 1, 2, 3} -- [x] WideResNet50 -- [x] ResNext50 -- [x] VGG16 -- [x] GoogleleNet +- [x] [YOLOv3 d53](https://github.com/open-mmlab/mmdetection) +- [x] [YOLOv5](https://github.com/ultralytics/yolov5) + +Classification from __TorchVision__: + +- [x] ResNet-18 +- [x] ResNet-50 +- [x] MobileNetV2 +- [x] MobileNetV3 Large +- [x] EfficientNet-B{0, 1, 2, 3} +- [x] WideResNet-50 +- [x] ResNext-50 +- [x] VGG-16 +- [x] GoogLeNet - [x] MnasNet - [x] RegNet Transformers: -- [x] Vit + +- [x] ViT - [x] Swin - [x] GPT-J @@ -127,15 +143,16 @@ Transformers: ## How to add new operations to converter Here we show how to extend onnx2torch with new ONNX operation, that supported by both PyTorch and ONNX +
and has the same behaviour An example of such a module is [Relu](./onnx2torch/node_converters/activations.py) ```python -@add_converter(operation_type='Relu', version=6) -@add_converter(operation_type='Relu', version=13) -@add_converter(operation_type='Relu', version=14) +@add_converter(operation_type="Relu", version=6) +@add_converter(operation_type="Relu", version=13) +@add_converter(operation_type="Relu", version=14) def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: return OperationConverterResult( torch_module=nn.ReLU(), @@ -143,9 +160,10 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: ) ``` -Here we have registered an operation named ``Relu`` for opset versions 6, 13, 14. -Note that the ``torch_module`` argument in ``OperationConverterResult`` must be a torch.nn.Module, not just a callable object! +Here we have registered an operation named `Relu` for opset versions 6, 13, 14. +Note that the `torch_module` argument in `OperationConverterResult` must be a torch.nn.Module, not just a callable object! If Operation's behaviour differs from one opset version to another, you should implement it separately. +
@@ -156,9 +174,9 @@ An example of such a module is [ScatterND](./onnx2torch/node_converters/scatter_ ```python # It is recommended to use Enum for string ONNX attributes. class ReductionOnnxAttr(Enum): - NONE = 'none' - ADD = 'add' - MUL = 'mul' + NONE = "none" + ADD = "add" + MUL = "mul" class OnnxScatterND(nn.Module, OnnxToTorchModuleWithCustomExport): @@ -177,13 +195,13 @@ class OnnxScatterND(nn.Module, OnnxToTorchModuleWithCustomExport): if opset_version < 16: if self._reduction != ReductionOnnxAttr.NONE: raise ValueError( - 'ScatterND from opset < 16 does not support' - f'reduction attribute != {ReductionOnnxAttr.NONE.value},' - f'got {self._reduction.value}' + "ScatterND from opset < 16 does not support" + f"reduction attribute != {ReductionOnnxAttr.NONE.value}," + f"got {self._reduction.value}" ) return onnx_attrs - onnx_attrs['reduction_s'] = self._reduction.value + onnx_attrs["reduction_s"] = self._reduction.value return onnx_attrs def forward( @@ -200,23 +218,27 @@ class OnnxScatterND(nn.Module, OnnxToTorchModuleWithCustomExport): # Please follow our convention, args consists of: # forward function, operation type, operation inputs, operation attributes. onnx_attrs = self._onnx_attrs(opset_version=get_onnx_version()) - return DefaultExportToOnnx.export(_forward, 'ScatterND', data, indices, updates, onnx_attrs) + return DefaultExportToOnnx.export( + _forward, "ScatterND", data, indices, updates, onnx_attrs + ) return _forward() -@add_converter(operation_type='ScatterND', version=11) -@add_converter(operation_type='ScatterND', version=13) -@add_converter(operation_type='ScatterND', version=16) +@add_converter(operation_type="ScatterND", version=11) +@add_converter(operation_type="ScatterND", version=13) +@add_converter(operation_type="ScatterND", version=16) def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: node_attributes = node.attributes - reduction = ReductionOnnxAttr(node_attributes.get('reduction', 'none')) + reduction = ReductionOnnxAttr(node_attributes.get("reduction", "none")) return OperationConverterResult( torch_module=OnnxScatterND(reduction=reduction), onnx_mapping=onnx_mapping_from_node(node=node), ) ``` -Here we have used a trick to convert the model from torch back to ONNX by defining the custom ``_ScatterNDExportToOnnx``. + +Here we have used a trick to convert the model from torch back to ONNX by defining the custom `_ScatterNDExportToOnnx`. +
## Opset version workaround @@ -235,13 +257,13 @@ import torch from onnx2torch import convert # Load the ONNX model. -model = onnx.load('model.onnx') +model = onnx.load("model.onnx") # Convert the model to the target version. target_version = 13 converted_model = version_converter.convert_version(model, target_version) # Convert to torch. torch_model = convert(converted_model) -torch.save(torch_model, 'model.pt') +torch.save(torch_model, "model.pt") ``` @@ -251,6 +273,7 @@ Note: use this only when the model does not convert to PyTorch using the existin ## Citation To cite onnx2torch use `Cite this repository` button, or: + ``` @misc{onnx2torch, title={onnx2torch}, diff --git a/onnx2torch/node_converters/base_element_wise.py b/onnx2torch/node_converters/base_element_wise.py index 60e94214..762e513a 100644 --- a/onnx2torch/node_converters/base_element_wise.py +++ b/onnx2torch/node_converters/base_element_wise.py @@ -1,3 +1,4 @@ +# pylint: disable=missing-docstring import torch from torch import nn @@ -5,7 +6,7 @@ from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxBaseElementWise(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-docstring +class OnnxBaseElementWise(nn.Module, OnnxToTorchModuleWithCustomExport): def __init__(self, op_type: str): super().__init__() self._op_type = op_type @@ -16,17 +17,20 @@ def _broadcast_shape(*tensors: torch.Tensor): broadcast_shape = torch.broadcast_shapes(*shapes) return broadcast_shape - def apply_reduction(self, *tensors: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring + def apply_reduction(self, *tensors: torch.Tensor) -> torch.Tensor: + del tensors raise NotImplementedError - def forward(self, *input_tensors: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring + def forward(self, *input_tensors: torch.Tensor) -> torch.Tensor: if len(input_tensors) == 1: # If there is a single element, return it (no op). # Also, no need for manually building the ONNX node. return input_tensors[0] - forward_lambda = lambda: self.apply_reduction(*input_tensors) + def _forward() -> torch.Tensor: + return self.apply_reduction(*input_tensors) + if torch.onnx.is_in_onnx_export(): - return DefaultExportToOnnx.export(forward_lambda, self._op_type, *input_tensors, {}) + return DefaultExportToOnnx.export(_forward, self._op_type, *input_tensors, {}) - return forward_lambda() + return _forward() diff --git a/onnx2torch/node_converters/global_average_pool.py b/onnx2torch/node_converters/global_average_pool.py index f57726dd..4e27287a 100644 --- a/onnx2torch/node_converters/global_average_pool.py +++ b/onnx2torch/node_converters/global_average_pool.py @@ -1,3 +1,4 @@ +# pylint: disable=missing-docstring __all__ = [ 'OnnxGlobalAveragePool', 'OnnxGlobalAveragePoolWithKnownInputShape', @@ -18,8 +19,8 @@ from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxGlobalAveragePool(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-docstring - def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring +class OnnxGlobalAveragePool(nn.Module, OnnxToTorchModuleWithCustomExport): + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: def _forward(): x_dims = list(range(2, len(input_tensor.shape))) return torch.mean(input_tensor, dim=x_dims, keepdim=True) @@ -30,24 +31,23 @@ def _forward(): return _forward() -class OnnxGlobalAveragePoolWithKnownInputShape( - nn.Module, OnnxToTorchModuleWithCustomExport -): # pylint: disable=missing-docstring +class OnnxGlobalAveragePoolWithKnownInputShape(nn.Module, OnnxToTorchModuleWithCustomExport): def __init__(self, input_shape: List[int]): super().__init__() self._x_dims = list(range(2, len(input_shape))) - def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring - forward_lambda = lambda: torch.mean(input_tensor, dim=self._x_dims, keepdim=True) + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + def _forward() -> torch.Tensor: + return torch.mean(input_tensor, dim=self._x_dims, keepdim=True) if torch.onnx.is_in_onnx_export(): - return DefaultExportToOnnx.export(forward_lambda, 'GlobalAveragePool', input_tensor, {}) + return DefaultExportToOnnx.export(_forward, 'GlobalAveragePool', input_tensor, {}) - return forward_lambda() + return _forward() @add_converter(operation_type='GlobalAveragePool', version=1) -def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: input_value_info = graph.value_info[node.input_values[0]] input_shape = get_shape_from_value_info(input_value_info) diff --git a/onnx2torch/node_converters/logical.py b/onnx2torch/node_converters/logical.py index 79ea08c0..ac5b4681 100644 --- a/onnx2torch/node_converters/logical.py +++ b/onnx2torch/node_converters/logical.py @@ -1,3 +1,4 @@ +# pylint: disable=missing-docstring __all__ = [ 'OnnxNot', 'OnnxLogical', @@ -25,16 +26,18 @@ } -class OnnxNot(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-class-docstring - def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring - forward_lambda = lambda: torch.logical_not(input_tensor) +class OnnxNot(nn.Module, OnnxToTorchModuleWithCustomExport): + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + def _forward() -> torch.Tensor: + return torch.logical_not(input_tensor) + if torch.onnx.is_in_onnx_export(): - return DefaultExportToOnnx.export(forward_lambda, 'Not', input_tensor, {}) + return DefaultExportToOnnx.export(_forward, 'Not', input_tensor, {}) - return forward_lambda() + return _forward() -class OnnxLogical(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring +class OnnxLogical(nn.Module, OnnxToTorchModule): def __init__(self, operation_type: str, broadcast: Optional[int] = None, axis: Optional[int] = None): super().__init__() self.broadcast = broadcast @@ -42,9 +45,7 @@ def __init__(self, operation_type: str, broadcast: Optional[int] = None, axis: O self.logic_op_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type] - def forward( # pylint: disable=missing-function-docstring - self, first_tensor: torch.Tensor, second_tensor: torch.Tensor - ): + def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor): if self.broadcast == 1 and self.axis is not None: second_tensor = old_style_broadcast(first_tensor, second_tensor, self.axis) @@ -57,7 +58,8 @@ def forward( # pylint: disable=missing-function-docstring @add_converter(operation_type='And', version=7) @add_converter(operation_type='Or', version=1) @add_converter(operation_type='Or', version=7) -def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: + del graph return OperationConverterResult( torch_module=OnnxLogical( operation_type=node.operation_type, @@ -69,7 +71,8 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: @add_converter(operation_type='Not', version=1) -def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: + del graph return OperationConverterResult( torch_module=OnnxNot(), onnx_mapping=onnx_mapping_from_node(node=node), diff --git a/onnx2torch/node_converters/nms.py b/onnx2torch/node_converters/nms.py index 37bbfa46..c6b87317 100644 --- a/onnx2torch/node_converters/nms.py +++ b/onnx2torch/node_converters/nms.py @@ -1,3 +1,4 @@ +# pylint: disable=missing-docstring __all__ = [ 'OnnxNonMaxSuppression', ] @@ -20,15 +21,16 @@ from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxNonMaxSuppression(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-class-docstring +class OnnxNonMaxSuppression(nn.Module, OnnxToTorchModuleWithCustomExport): def __init__(self, center_point_box: int = 0): super().__init__() self._center_point_box = center_point_box def _onnx_attrs(self, opset_version: int) -> Dict[str, Any]: + del opset_version return {'center_point_box_i': self._center_point_box} - def forward( # pylint: disable=missing-function-docstring + def forward( self, boxes: torch.Tensor, scores: torch.Tensor, @@ -36,7 +38,8 @@ def forward( # pylint: disable=missing-function-docstring iou_threshold: Optional[torch.Tensor] = None, score_threshold: Optional[torch.Tensor] = None, ) -> torch.Tensor: - forward_lambda = lambda: self._nms(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) + def _forward() -> torch.Tensor: + return self._nms(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) if torch.onnx.is_in_onnx_export(): if max_output_boxes_per_class is None: @@ -48,7 +51,7 @@ def forward( # pylint: disable=missing-function-docstring onnx_attrs = self._onnx_attrs(opset_version=get_onnx_version()) return DefaultExportToOnnx.export( - forward_lambda, + _forward, 'NonMaxSuppression', boxes, scores, @@ -58,7 +61,7 @@ def forward( # pylint: disable=missing-function-docstring onnx_attrs, ) - return forward_lambda() + return _forward() def _nms( self, @@ -109,7 +112,8 @@ def _nms( @add_converter(operation_type='NonMaxSuppression', version=10) @add_converter(operation_type='NonMaxSuppression', version=11) -def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: + del graph center_point_box = node.attributes.get('center_point_box', 0) return OperationConverterResult( torch_module=OnnxNonMaxSuppression(center_point_box=center_point_box), diff --git a/onnx2torch/node_converters/range.py b/onnx2torch/node_converters/range.py index 8535e613..392e32e8 100644 --- a/onnx2torch/node_converters/range.py +++ b/onnx2torch/node_converters/range.py @@ -1,3 +1,4 @@ +# pylint: disable=missing-docstring __all__ = [ 'OnnxRange', ] @@ -16,7 +17,7 @@ from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxRange(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-class-docstring +class OnnxRange(nn.Module, OnnxToTorchModuleWithCustomExport): def __init__(self): super().__init__() self.register_buffer('dummy_buffer', torch.Tensor(), persistent=False) @@ -41,22 +42,24 @@ def _arange( device=self.dummy_buffer.device, ) - def forward( # pylint: disable=missing-function-docstring + def forward( self, start: Union[torch.Tensor, float, int], limit: Union[torch.Tensor, float, int], delta: Union[torch.Tensor, float, int], ) -> torch.Tensor: - forward_lambda = lambda: self._arange(start, limit, delta) + def _forward() -> torch.Tensor: + return self._arange(start, limit, delta) if torch.onnx.is_in_onnx_export(): - return DefaultExportToOnnx.export(forward_lambda, 'Range', start, limit, delta, {}) + return DefaultExportToOnnx.export(_forward, 'Range', start, limit, delta, {}) - return forward_lambda() + return _forward() @add_converter(operation_type='Range', version=11) -def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: + del graph return OperationConverterResult( torch_module=OnnxRange(), onnx_mapping=onnx_mapping_from_node(node), diff --git a/onnx2torch/node_converters/registry.py b/onnx2torch/node_converters/registry.py index 3d0b02d3..b4ddd174 100644 --- a/onnx2torch/node_converters/registry.py +++ b/onnx2torch/node_converters/registry.py @@ -37,7 +37,7 @@ def deco(converter: TConverter): raise ValueError(f'Operation "{description}" already registered') _CONVERTER_REGISTRY[description] = converter - _LOGGER.info(f'Operation converter registered {description}') + _LOGGER.debug(f'Operation converter registered {description}') return converter diff --git a/onnx2torch/node_converters/reshape.py b/onnx2torch/node_converters/reshape.py index edb8f05f..e8ec3aa7 100644 --- a/onnx2torch/node_converters/reshape.py +++ b/onnx2torch/node_converters/reshape.py @@ -27,12 +27,13 @@ def forward( # pylint: disable=missing-function-docstring input_tensor: torch.Tensor, shape: torch.Tensor, ) -> torch.Tensor: - forward_lambda = lambda: self._do_reshape(input_tensor, shape) + def _forward() -> torch.Tensor: + return self._do_reshape(input_tensor, shape) if torch.onnx.is_in_onnx_export(): - return DefaultExportToOnnx.export(forward_lambda, 'Reshape', input_tensor, shape, {}) + return DefaultExportToOnnx.export(_forward, 'Reshape', input_tensor, shape, {}) - return forward_lambda() + return _forward() @add_converter(operation_type='Reshape', version=5) diff --git a/onnx2torch/node_converters/tile.py b/onnx2torch/node_converters/tile.py index 9cfd0db4..0508a877 100644 --- a/onnx2torch/node_converters/tile.py +++ b/onnx2torch/node_converters/tile.py @@ -1,3 +1,4 @@ +# pylint: disable=missing-docstring __all__ = [ 'OnnxTile', ] @@ -14,23 +15,21 @@ from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxTile(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-class-docstring - def forward( # pylint: disable=missing-function-docstring - self, - input_tensor: torch.Tensor, - repeats: torch.Tensor, - ) -> torch.Tensor: - # torch.tile(input_tensor, repeats) is not supported for exporting - forward_lambda = lambda: input_tensor.repeat(torch.Size(repeats)) +class OnnxTile(nn.Module, OnnxToTorchModuleWithCustomExport): + def forward(self, input_tensor: torch.Tensor, repeats: torch.Tensor) -> torch.Tensor: + def _forward() -> torch.Tensor: + return input_tensor.repeat(torch.Size(repeats)) + if torch.onnx.is_in_onnx_export(): - return DefaultExportToOnnx.export(forward_lambda, 'Tile', input_tensor, repeats, {}) + return DefaultExportToOnnx.export(_forward, 'Tile', input_tensor, repeats, {}) - return forward_lambda() + return _forward() @add_converter(operation_type='Tile', version=6) @add_converter(operation_type='Tile', version=13) -def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: + del graph return OperationConverterResult( torch_module=OnnxTile(), onnx_mapping=onnx_mapping_from_node(node=node), diff --git a/operators.md b/operators.md index be77b4e1..949a1c6e 100644 --- a/operators.md +++ b/operators.md @@ -2,177 +2,177 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended opset version 13 -| Operation type | Supported | Restrictions | -|---------------------------|-----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Abs | Y | | -| Acos | Y | | -| Acosh | N | | -| Add | Y | | -| And | Y | | -| ArgMax | N | | -| ArgMin | N | | -| Asin | Y | | -| Asinh | N | | -| Atan | Y | | -| Atanh | N | | -| AveragePool | Y | Average pool operation with spatial rank > 3 is not implemented | -| BatchNormalization | Y | BatchNorm operation with spatial rank > 3 is not implemented. BatchNorm nodes in training mode are not supported | -| BitShift | N | | -| Cast | Y | | -| Ceil | Y | | -| Clip | Y | Dynamic value of min/max is not implemented | -| Compress | N | | -| Concat | Y | | -| ConcatFromSequence | N | | -| Constant | Y | | -| ConstantOfShape | Y | Parameter "value" must be scalar | -| Conv | Y | Convolution operation with spatial rank > 3 is not implemented | -| ConvInteger | N | | -| ConvTranspose | Y | Convolution operation with spatial rank > 3 is not implemented | -| Cos | Y | | -| Cosh | N | | -| CumSum | Y | | -| DepthToSpace | Y | DCR mode is not implemented | -| DequantizeLinear | N | | -| Det | N | | -| Div | Y | | -| Dropout | Y | | -| Einsum | Y | | -| Elu | Y | | -| Equal | Y | | -| Erf | Y | | -| Exp | Y | | -| Expand | Y | | -| EyeLike | Y | | -| Flatten | Y | | -| Floor | Y | | -| GRU | N | | -| Gather | Y | | -| GatherElements | Y | | -| GatherND | Y | GatherND operation with parameter "batch_dims" > 0 is not implemented | -| Gemm | Y | | -| GlobalAveragePool | Y | | -| GlobalLpPool | N | | -| GlobalMaxPool | N | | -| Greater | Y | | -| GridSample | N | | -| HardSigmoid | Y | | -| Hardmax | N | | -| Identity | Y | | -| If | N | | -| InstanceNormalization | Y | | -| IsInf | Y | | -| IsNaN | Y | | -| LayerNormalization | Y | LayerNormalization outputs "Mean" and "InvStdDev" are not implemented | -| LRN | Y | | -| LSTM | N | | -| LeakyRelu | Y | | -| Less | Y | | -| Log | Y | | -| Loop | N | | -| LpNormalization | N | | -| LpPool | N | | -| MatMul | Y | | -| MatMulInteger | N | | -| Max | Y | | -| MaxPool | Y | Max pool operation with spatial rank > 3 is not implemented | -| MaxRoiPool | N | | -| MaxUnpool | N | | -| Mean | Y | | -| Min | Y | | -| Mod | Y | | -| Mul | Y | | -| Multinomial | N | | -| Neg | Y | | -| NonMaxSuppression | Y | | -| NonZero | Y | | -| Not | Y | | -| OneHot | N | | -| Optional | N | | -| OptionalGetElement | N | | -| OptionalHasElement | N | | -| Or | Y | | -| PRelu | Y | | -| Pad | Y | Padding is implemented to pad the last 3 dimensions of 5D input tensor, or the last 2 dimensions of 4D input tensor, or the last dimension of 3D input tensor | -| Pow | Y | | -| QLinearConv | N | | -| QLinearMatMul | N | | -| QuantizeLinear | N | | -| RNN | N | | -| RandomNormal | N | | -| RandomNormalLike | N | | -| RandomUniform | N | | -| RandomUniformLike | N | | -| Reciprocal | Y | | -| ReduceL1 | Y | | -| ReduceL2 | Y | | -| ReduceLogSum | Y | | -| ReduceLogSumExp | Y | | -| ReduceMax | Y | | -| ReduceMean | Y | | -| ReduceMin | Y | | -| ReduceProd | Y | | -| ReduceSum | Y | | -| ReduceSumSquare | Y | | -| Relu | Y | | -| Reshape | Y | Parameter "allowzero" = 1 is not implemented | -| Resize | Y | Roi logic is not implemented (pytorch's interpolate cannot resize channel or batch dimensions) | -| ReverseSequence | N | | -| RoiAlign | Y | Only "avg" mode is supported | -| Round | Y | | -| Scan | N | | -| Scatter(deprecated) | N | | -| ScatterElements | N | | -| ScatterND | Y | Only "none" reduction is supported | -| Selu | Y | Parameters "alpha" and "gamma" must be default | -| SequenceAt | N | | -| SequenceConstruct | N | | -| SequenceEmpty | N | | -| SequenceErase | N | | -| SequenceInsert | N | | -| SequenceLength | N | | -| Shape | Y | | -| Shrink | N | | -| Sigmoid | Y | | -| Sign | Y | | -| Sin | Y | | -| Sinh | N | | -| Size | N | | -| Slice | Y | | -| Softplus | Y | | -| Softsign | Y | | -| SpaceToDepth | N | | -| Split | Y | | -| SplitToSequence | N | | -| Sqrt | Y | | -| Squeeze | Y | | -| StringNormalizer | N | | -| Sub | Y | | -| Sum | Y | | -| Tan | Y | | -| Tanh | Y | | -| TfIdfVectorizer | N | | -| ThresholdedRelu | N | | -| Tile | Y | | -| TopK | Y | | -| Transpose | Y | | -| Trilu | N | | -| Unique | N | | -| Unsqueeze | Y | | -| Upsample(deprecated) | N | | -| Where | Y | | -| Xor | Y | | -| Function | N | | -| Bernoulli | N | | -| CastLike | N | | -| Celu | Y | | -| DynamicQuantizeLinear | N | | -| GreaterOrEqual | Y | | -| HardSwish | Y | | -| LessOrEqual | Y | | -| LogSoftmax | Y | | -| MeanVarianceNormalization | N | | -| NegativeLogLikelihoodLoss | N | | -| Range | Y | | -| SequenceMap | N | | -| Softmax | Y | | +| Operation type | Supported | Restrictions | +| ------------------------- | --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Abs | Y | | +| Acos | Y | | +| Acosh | N | | +| Add | Y | | +| And | Y | | +| ArgMax | N | | +| ArgMin | N | | +| Asin | Y | | +| Asinh | N | | +| Atan | Y | | +| Atanh | N | | +| AveragePool | Y | Average pool operation with spatial rank > 3 is not implemented | +| BatchNormalization | Y | BatchNorm operation with spatial rank > 3 is not implemented. BatchNorm nodes in training mode are not supported | +| BitShift | N | | +| Cast | Y | | +| Ceil | Y | | +| Clip | Y | Dynamic value of min/max is not implemented | +| Compress | N | | +| Concat | Y | | +| ConcatFromSequence | N | | +| Constant | Y | | +| ConstantOfShape | Y | Parameter "value" must be scalar | +| Conv | Y | Convolution operation with spatial rank > 3 is not implemented | +| ConvInteger | N | | +| ConvTranspose | Y | Convolution operation with spatial rank > 3 is not implemented | +| Cos | Y | | +| Cosh | N | | +| CumSum | Y | | +| DepthToSpace | Y | DCR mode is not implemented | +| DequantizeLinear | N | | +| Det | N | | +| Div | Y | | +| Dropout | Y | | +| Einsum | Y | | +| Elu | Y | | +| Equal | Y | | +| Erf | Y | | +| Exp | Y | | +| Expand | Y | | +| EyeLike | Y | | +| Flatten | Y | | +| Floor | Y | | +| GRU | N | | +| Gather | Y | | +| GatherElements | Y | | +| GatherND | Y | GatherND operation with parameter "batch_dims" > 0 is not implemented | +| Gemm | Y | | +| GlobalAveragePool | Y | | +| GlobalLpPool | N | | +| GlobalMaxPool | N | | +| Greater | Y | | +| GridSample | N | | +| HardSigmoid | Y | | +| Hardmax | N | | +| Identity | Y | | +| If | N | | +| InstanceNormalization | Y | | +| IsInf | Y | | +| IsNaN | Y | | +| LayerNormalization | Y | LayerNormalization outputs "Mean" and "InvStdDev" are not implemented | +| LRN | Y | | +| LSTM | N | | +| LeakyRelu | Y | | +| Less | Y | | +| Log | Y | | +| Loop | N | | +| LpNormalization | N | | +| LpPool | N | | +| MatMul | Y | | +| MatMulInteger | N | | +| Max | Y | | +| MaxPool | Y | Max pool operation with spatial rank > 3 is not implemented | +| MaxRoiPool | N | | +| MaxUnpool | N | | +| Mean | Y | | +| Min | Y | | +| Mod | Y | | +| Mul | Y | | +| Multinomial | N | | +| Neg | Y | | +| NonMaxSuppression | Y | | +| NonZero | Y | | +| Not | Y | | +| OneHot | N | | +| Optional | N | | +| OptionalGetElement | N | | +| OptionalHasElement | N | | +| Or | Y | | +| PRelu | Y | | +| Pad | Y | Padding is implemented to pad the last 3 dimensions of 5D input tensor, or the last 2 dimensions of 4D input tensor, or the last dimension of 3D input tensor | +| Pow | Y | | +| QLinearConv | N | | +| QLinearMatMul | N | | +| QuantizeLinear | N | | +| RNN | N | | +| RandomNormal | N | | +| RandomNormalLike | N | | +| RandomUniform | N | | +| RandomUniformLike | N | | +| Reciprocal | Y | | +| ReduceL1 | Y | | +| ReduceL2 | Y | | +| ReduceLogSum | Y | | +| ReduceLogSumExp | Y | | +| ReduceMax | Y | | +| ReduceMean | Y | | +| ReduceMin | Y | | +| ReduceProd | Y | | +| ReduceSum | Y | | +| ReduceSumSquare | Y | | +| Relu | Y | | +| Reshape | Y | Parameter "allowzero" = 1 is not implemented | +| Resize | Y | Roi logic is not implemented (pytorch's interpolate cannot resize channel or batch dimensions) | +| ReverseSequence | N | | +| RoiAlign | Y | Only "avg" mode is supported | +| Round | Y | | +| Scan | N | | +| Scatter(deprecated) | N | | +| ScatterElements | N | | +| ScatterND | Y | Only "none" reduction is supported | +| Selu | Y | Parameters "alpha" and "gamma" must be default | +| SequenceAt | N | | +| SequenceConstruct | N | | +| SequenceEmpty | N | | +| SequenceErase | N | | +| SequenceInsert | N | | +| SequenceLength | N | | +| Shape | Y | | +| Shrink | N | | +| Sigmoid | Y | | +| Sign | Y | | +| Sin | Y | | +| Sinh | N | | +| Size | N | | +| Slice | Y | | +| Softplus | Y | | +| Softsign | Y | | +| SpaceToDepth | N | | +| Split | Y | | +| SplitToSequence | N | | +| Sqrt | Y | | +| Squeeze | Y | | +| StringNormalizer | N | | +| Sub | Y | | +| Sum | Y | | +| Tan | Y | | +| Tanh | Y | | +| TfIdfVectorizer | N | | +| ThresholdedRelu | N | | +| Tile | Y | | +| TopK | Y | | +| Transpose | Y | | +| Trilu | N | | +| Unique | N | | +| Unsqueeze | Y | | +| Upsample(deprecated) | N | | +| Where | Y | | +| Xor | Y | | +| Function | N | | +| Bernoulli | N | | +| CastLike | N | | +| Celu | Y | | +| DynamicQuantizeLinear | N | | +| GreaterOrEqual | Y | | +| HardSwish | Y | | +| LessOrEqual | Y | | +| LogSoftmax | Y | | +| MeanVarianceNormalization | N | | +| NegativeLogLikelihoodLoss | N | | +| Range | Y | | +| SequenceMap | N | | +| Softmax | Y | | diff --git a/pyproject.toml b/pyproject.toml index c4d03d30..b438a774 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,29 @@ repository = 'https://github.com/ENOT-AutoDL/onnx2torch' [tool.setuptools.packages.find] include = ['onnx2torch*'] +[tool.commitizen] +name = 'cz_conventional_commits' +tag_format = '$version' +version_scheme = 'pep440' +version_provider = 'pep621' +update_changelog_on_bump = true +major_version_zero = true + +[tool.docformatter] +recursive = true +wrap-summaries = 0 +wrap-descriptions = 0 +blank = true +black = true +pre-summary-newline = true + +[tool.yamlfix] +line_length = 120 +explicit_start = false +sequence_style = 'keep_style' +whitelines = 1 +section_whitelines = 1 + [tool.black] line-length = 120 target-version = ['py36', 'py37', 'py38', 'py39'] @@ -51,3 +74,36 @@ profile = 'black' line_length = 120 ensure_newline_before_comments = true force_single_line = true + +[tool.pylint.master] +load-plugins = ['pylint.extensions.docparams'] + +[tool.pylint.format] +max-line-length = 120 + +[tool.pylint.design] +max-args = 12 +max-locals = 30 +max-attributes = 20 +min-public-methods = 0 + +[tool.pylint.typecheck] +generated-members = ['torch.*'] + +[tool.pylint.messages_control] +disable = [ + 'logging-fstring-interpolation', + 'cyclic-import', + 'duplicate-code', + 'missing-module-docstring', + 'unnecessary-pass', + 'no-name-in-module', +] + +[tool.pylint.BASIC] +good-names = ['bs', 'bn'] + +[tool.pyright] +reportMissingImports = false +reportMissingTypeStubs = false +reportWildcardImportFromLibrary = false diff --git a/tests/node_converters/nms_test.py b/tests/node_converters/nms_test.py index 6341fdbc..1ffd45ed 100644 --- a/tests/node_converters/nms_test.py +++ b/tests/node_converters/nms_test.py @@ -136,7 +136,6 @@ def _test_nms( 'boxes,scores,max_output_boxes_per_class,iou_threshold,score_threshold,center_point_box', ( (_BOXES_CXCYWH_FORMAT_TEST, _SCORES_CXCYWH_FORMAT_TEST, 3, 0.1, 0.0, 1), # center point box format - # FIXME # flipped coordinates # (_BOXES_FLIPPED_COORDINATES_TEST, _SCORES_FLIPPED_COORDINATES_TEST, 3, 0.5, 0.0, None), (_BOXES_IDENTICAL_BOXES_TEST, _SCORES_IDENTICAL_BOXES_TEST, 3, 0.5, 0.0, None), # identical boxes