Commit 310490a

refactor: static analysis

senysenyseny16 committed Mar 29, 2024
1 parent a8b0603 commit 310490a
Showing 11 changed files with 105 additions and 114 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -22,7 +22,7 @@ jobs:
           python -m pip install -e .
       - name: Analysing the code with pylint
         run: |
-          pylint --rcfile .pylintrc --output-format=colorized $(git ls-files '*.py')
+          pylint --output-format=colorized $(git ls-files '*.py')
   lint-python-format:
     name: Python format
25 changes: 9 additions & 16 deletions .pre-commit-config.yaml
@@ -1,21 +1,14 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 23.1.0
+    rev: 24.3.0
     hooks:
       - id: black
   - repo: https://github.com/PyCQA/isort
-    rev: 5.12.0
+    rev: 5.13.2
     hooks:
       - id: isort
-      args:
-        [
-          "--force-single-line-imports",
-          "--ensure-newline-before-comments",
-          "--line-length=120",
-          "--profile=black",
-        ]
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       - id: check-yaml
       - id: check-toml
@@ -29,18 +22,18 @@ repos:
       - id: end-of-file-fixer
       - id: requirements-txt-fixer
       - id: no-commit-to-branch
-        args:
-          [
-            "-b=main",
-          ]
+        args: ["-b=main"]
   - repo: https://github.com/PyCQA/pylint
-    rev: v2.16.0
+    rev: v3.1.0
     hooks:
       - id: pylint
+        name: pylint
+        entry: pylint
+        language: system
+        types: [ python ]
         args:
           [
             "-rn",
             "-sn",
-            "--rcfile=.pylintrc",
             "--output-format=colorized",
           ]
44 changes: 0 additions & 44 deletions .pylintrc

This file was deleted.

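With the .pylintrc file removed and the --rcfile flags dropped from both CI and the pre-commit hook above, pylint presumably reads its settings from pyproject.toml, which it discovers automatically (supported since pylint 2.5). The commit does not show that file, so the [tool.pylint] table below is an assumption; a minimal stdlib check that it exists:

import tomllib  # standard library on Python 3.11+

with open('pyproject.toml', 'rb') as config_file:
    config = tomllib.load(config_file)

print('pylint' in config.get('tool', {}))  # True if the settings moved here
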
16 changes: 10 additions & 6 deletions onnx2torch/node_converters/base_element_wise.py
@@ -1,11 +1,12 @@
+# pylint: disable=missing-docstring
 import torch
 from torch import nn
 
 from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx
 from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport
 
 
-class OnnxBaseElementWise(nn.Module, OnnxToTorchModuleWithCustomExport):  # pylint: disable=missing-docstring
+class OnnxBaseElementWise(nn.Module, OnnxToTorchModuleWithCustomExport):
     def __init__(self, op_type: str):
         super().__init__()
         self._op_type = op_type
@@ -16,17 +17,20 @@ def _broadcast_shape(*tensors: torch.Tensor):
         broadcast_shape = torch.broadcast_shapes(*shapes)
         return broadcast_shape
 
-    def apply_reduction(self, *tensors: torch.Tensor) -> torch.Tensor:  # pylint: disable=missing-function-docstring
+    def apply_reduction(self, *tensors: torch.Tensor) -> torch.Tensor:
         del tensors
         raise NotImplementedError
 
-    def forward(self, *input_tensors: torch.Tensor) -> torch.Tensor:  # pylint: disable=missing-function-docstring
+    def forward(self, *input_tensors: torch.Tensor) -> torch.Tensor:
         if len(input_tensors) == 1:
             # If there is a single element, return it (no op).
             # Also, no need for manually building the ONNX node.
             return input_tensors[0]
 
-        forward_lambda = lambda: self.apply_reduction(*input_tensors)
+        def _forward() -> torch.Tensor:
+            return self.apply_reduction(*input_tensors)
 
         if torch.onnx.is_in_onnx_export():
-            return DefaultExportToOnnx.export(forward_lambda, self._op_type, *input_tensors, {})
+            return DefaultExportToOnnx.export(_forward, self._op_type, *input_tensors, {})
 
-        return forward_lambda()
+        return _forward()
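
The rewrite above (an assigned lambda replaced by a named inner function) repeats in every file below. Pylint 3 reports `name = lambda: ...` as C3001, unnecessary-lambda-assignment, and the named `_forward` can also carry a return annotation. A minimal sketch of the pattern with illustrative names:

import torch


def forward_old(x: torch.Tensor) -> torch.Tensor:
    forward_lambda = lambda: torch.relu(x)  # pylint 3: C3001 (unnecessary-lambda-assignment)
    return forward_lambda()


def forward_new(x: torch.Tensor) -> torch.Tensor:
    def _forward() -> torch.Tensor:  # same closure, lint-clean, annotated
        return torch.relu(x)

    return _forward()
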
20 changes: 10 additions & 10 deletions onnx2torch/node_converters/global_average_pool.py
@@ -1,3 +1,4 @@
+# pylint: disable=missing-docstring
 __all__ = [
     'OnnxGlobalAveragePool',
     'OnnxGlobalAveragePoolWithKnownInputShape',
Expand All @@ -18,8 +19,8 @@
 from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport
 
 
-class OnnxGlobalAveragePool(nn.Module, OnnxToTorchModuleWithCustomExport):  # pylint: disable=missing-docstring
-    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:  # pylint: disable=missing-function-docstring
+class OnnxGlobalAveragePool(nn.Module, OnnxToTorchModuleWithCustomExport):
+    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
         def _forward():
             x_dims = list(range(2, len(input_tensor.shape)))
             return torch.mean(input_tensor, dim=x_dims, keepdim=True)
@@ -30,24 +31,23 @@ def _forward():
         return _forward()
 
 
-class OnnxGlobalAveragePoolWithKnownInputShape(
-    nn.Module, OnnxToTorchModuleWithCustomExport
-):  # pylint: disable=missing-docstring
+class OnnxGlobalAveragePoolWithKnownInputShape(nn.Module, OnnxToTorchModuleWithCustomExport):
     def __init__(self, input_shape: List[int]):
         super().__init__()
         self._x_dims = list(range(2, len(input_shape)))
 
-    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:  # pylint: disable=missing-function-docstring
-        forward_lambda = lambda: torch.mean(input_tensor, dim=self._x_dims, keepdim=True)
+    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+        def _forward() -> torch.Tensor:
+            return torch.mean(input_tensor, dim=self._x_dims, keepdim=True)
 
         if torch.onnx.is_in_onnx_export():
-            return DefaultExportToOnnx.export(forward_lambda, 'GlobalAveragePool', input_tensor, {})
+            return DefaultExportToOnnx.export(_forward, 'GlobalAveragePool', input_tensor, {})
 
-        return forward_lambda()
+        return _forward()
 
 
 @add_converter(operation_type='GlobalAveragePool', version=1)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
     input_value_info = graph.value_info[node.input_values[0]]
     input_shape = get_shape_from_value_info(input_value_info)
 
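A quick equivalence check, a sketch rather than part of the commit: averaging over every dimension past the channel axis, as both modules above do, is exactly ONNX GlobalAveragePool and matches torch's adaptive pooling on 4-D input:

import torch

x = torch.randn(2, 3, 8, 8)  # NCHW input
x_dims = list(range(2, len(x.shape)))  # [2, 3], computed as in the modules above
manual = torch.mean(x, dim=x_dims, keepdim=True)
reference = torch.nn.AdaptiveAvgPool2d(1)(x)
assert torch.allclose(manual, reference)  # both are shape (2, 3, 1, 1)
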
25 changes: 14 additions & 11 deletions onnx2torch/node_converters/logical.py
@@ -1,3 +1,4 @@
+# pylint: disable=missing-docstring
 __all__ = [
     'OnnxNot',
     'OnnxLogical',
@@ -25,26 +26,26 @@
 }
 
 
-class OnnxNot(nn.Module, OnnxToTorchModuleWithCustomExport):  # pylint: disable=missing-class-docstring
-    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:  # pylint: disable=missing-function-docstring
-        forward_lambda = lambda: torch.logical_not(input_tensor)
+class OnnxNot(nn.Module, OnnxToTorchModuleWithCustomExport):
+    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+        def _forward() -> torch.Tensor:
+            return torch.logical_not(input_tensor)
 
         if torch.onnx.is_in_onnx_export():
-            return DefaultExportToOnnx.export(forward_lambda, 'Not', input_tensor, {})
+            return DefaultExportToOnnx.export(_forward, 'Not', input_tensor, {})
 
-        return forward_lambda()
+        return _forward()
 
 
-class OnnxLogical(nn.Module, OnnxToTorchModule):  # pylint: disable=missing-class-docstring
+class OnnxLogical(nn.Module, OnnxToTorchModule):
     def __init__(self, operation_type: str, broadcast: Optional[int] = None, axis: Optional[int] = None):
         super().__init__()
         self.broadcast = broadcast
         self.axis = axis
 
         self.logic_op_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type]
 
-    def forward(  # pylint: disable=missing-function-docstring
-        self, first_tensor: torch.Tensor, second_tensor: torch.Tensor
-    ):
+    def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor):
         if self.broadcast == 1 and self.axis is not None:
             second_tensor = old_style_broadcast(first_tensor, second_tensor, self.axis)
 
@@ -57,7 +58,8 @@ def forward(  # pylint: disable=missing-function-docstring
 @add_converter(operation_type='And', version=7)
 @add_converter(operation_type='Or', version=1)
 @add_converter(operation_type='Or', version=7)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+    del graph
     return OperationConverterResult(
         torch_module=OnnxLogical(
             operation_type=node.operation_type,
@@ -69,7 +71,8 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint:


 @add_converter(operation_type='Not', version=1)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+    del graph
     return OperationConverterResult(
         torch_module=OnnxNot(),
         onnx_mapping=onnx_mapping_from_node(node=node),
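The added `del graph` lines replace the `# pylint: disable=unused-argument` comments: deleting a parameter at the top of the body documents that it is intentionally ignored, and pylint no longer reports it. The idiom in isolation, with hypothetical names:

def convert(node: dict, graph: dict) -> str:
    del graph  # required by the converter signature, deliberately unused
    return str(node['op_type']).lower()
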
16 changes: 10 additions & 6 deletions onnx2torch/node_converters/nms.py
@@ -1,3 +1,4 @@
+# pylint: disable=missing-docstring
 __all__ = [
     'OnnxNonMaxSuppression',
 ]
@@ -20,23 +21,25 @@
 from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport
 
 
-class OnnxNonMaxSuppression(nn.Module, OnnxToTorchModuleWithCustomExport):  # pylint: disable=missing-class-docstring
+class OnnxNonMaxSuppression(nn.Module, OnnxToTorchModuleWithCustomExport):
     def __init__(self, center_point_box: int = 0):
         super().__init__()
         self._center_point_box = center_point_box
 
     def _onnx_attrs(self, opset_version: int) -> Dict[str, Any]:
+        del opset_version
         return {'center_point_box_i': self._center_point_box}
 
-    def forward(  # pylint: disable=missing-function-docstring
+    def forward(
         self,
         boxes: torch.Tensor,
         scores: torch.Tensor,
         max_output_boxes_per_class: Optional[torch.Tensor] = None,
         iou_threshold: Optional[torch.Tensor] = None,
         score_threshold: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        forward_lambda = lambda: self._nms(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
+        def _forward() -> torch.Tensor:
+            return self._nms(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
 
         if torch.onnx.is_in_onnx_export():
             if max_output_boxes_per_class is None:
@@ -48,7 +51,7 @@ def forward(  # pylint: disable=missing-function-docstring

             onnx_attrs = self._onnx_attrs(opset_version=get_onnx_version())
             return DefaultExportToOnnx.export(
-                forward_lambda,
+                _forward,
                 'NonMaxSuppression',
                 boxes,
                 scores,
@@ -58,7 +61,7 @@ def forward(  # pylint: disable=missing-function-docstring
                 onnx_attrs,
             )
 
-        return forward_lambda()
+        return _forward()
 
     def _nms(
         self,
@@ -109,7 +112,8 @@ def _nms(

 @add_converter(operation_type='NonMaxSuppression', version=10)
 @add_converter(operation_type='NonMaxSuppression', version=11)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+    del graph
     center_point_box = node.attributes.get('center_point_box', 0)
     return OperationConverterResult(
         torch_module=OnnxNonMaxSuppression(center_point_box=center_point_box),
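The body of `_nms` is collapsed in this view. Per the ONNX NonMaxSuppression specification it selects boxes class by class; a hedged sketch of the per-class core using torchvision (the actual implementation may differ, and `center_point_box == 1` inputs would first need conversion to corner format):

import torch
from torchvision.ops import nms  # expects corner-format (x1, y1, x2, y2) boxes


def nms_single_class(
    boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float, max_output: int
) -> torch.Tensor:
    keep = nms(boxes, scores, iou_threshold)  # kept indices, sorted by descending score
    return keep[:max_output]
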
15 changes: 9 additions & 6 deletions onnx2torch/node_converters/range.py
@@ -1,3 +1,4 @@
+# pylint: disable=missing-docstring
 __all__ = [
     'OnnxRange',
 ]
@@ -16,7 +17,7 @@
 from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport
 
 
-class OnnxRange(nn.Module, OnnxToTorchModuleWithCustomExport):  # pylint: disable=missing-class-docstring
+class OnnxRange(nn.Module, OnnxToTorchModuleWithCustomExport):
     def __init__(self):
         super().__init__()
         self.register_buffer('dummy_buffer', torch.Tensor(), persistent=False)
@@ -41,22 +42,24 @@ def _arange(
             device=self.dummy_buffer.device,
         )
 
-    def forward(  # pylint: disable=missing-function-docstring
+    def forward(
         self,
         start: Union[torch.Tensor, float, int],
         limit: Union[torch.Tensor, float, int],
         delta: Union[torch.Tensor, float, int],
     ) -> torch.Tensor:
-        forward_lambda = lambda: self._arange(start, limit, delta)
+        def _forward() -> torch.Tensor:
+            return self._arange(start, limit, delta)
 
         if torch.onnx.is_in_onnx_export():
-            return DefaultExportToOnnx.export(forward_lambda, 'Range', start, limit, delta, {})
+            return DefaultExportToOnnx.export(_forward, 'Range', start, limit, delta, {})
 
-        return forward_lambda()
+        return _forward()
 
 
 @add_converter(operation_type='Range', version=11)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+    del graph
     return OperationConverterResult(
         torch_module=OnnxRange(),
         onnx_mapping=onnx_mapping_from_node(node),
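Why OnnxRange registers an empty, non-persistent `dummy_buffer`: buffers travel with the module through `.to(device)`, so `_arange` can allocate its output on whatever device the module lives on without adding anything to the state dict. A reduced sketch of the same trick:

import torch
from torch import nn


class DeviceTracked(nn.Module):
    def __init__(self):
        super().__init__()
        # Empty and non-persistent: tracks the module's device, never saved.
        self.register_buffer('dummy_buffer', torch.Tensor(), persistent=False)

    def forward(self, start: float, limit: float, delta: float) -> torch.Tensor:
        return torch.arange(start, limit, delta, device=self.dummy_buffer.device)
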
7 changes: 4 additions & 3 deletions onnx2torch/node_converters/reshape.py
@@ -27,12 +27,13 @@ def forward(  # pylint: disable=missing-function-docstring
         input_tensor: torch.Tensor,
         shape: torch.Tensor,
     ) -> torch.Tensor:
-        forward_lambda = lambda: self._do_reshape(input_tensor, shape)
+        def _forward() -> torch.Tensor:
+            return self._do_reshape(input_tensor, shape)
 
         if torch.onnx.is_in_onnx_export():
-            return DefaultExportToOnnx.export(forward_lambda, 'Reshape', input_tensor, shape, {})
+            return DefaultExportToOnnx.export(_forward, 'Reshape', input_tensor, shape, {})
 
-        return forward_lambda()
+        return _forward()
 
 
 @add_converter(operation_type='Reshape', version=5)