From f04c3a7231177ae06260f8a65893358502b6d105 Mon Sep 17 00:00:00 2001 From: Arseny <82811840+senysenyseny16@users.noreply.github.com> Date: Fri, 29 Mar 2024 21:00:38 +0100 Subject: [PATCH] fix: reduce keepdim type (int -> bool) (#206) --- onnx2torch/node_converters/reduce.py | 58 +++++++++++++++------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/onnx2torch/node_converters/reduce.py b/onnx2torch/node_converters/reduce.py index 025b135e..fb73c32e 100644 --- a/onnx2torch/node_converters/reduce.py +++ b/onnx2torch/node_converters/reduce.py @@ -1,3 +1,4 @@ +# pylint: disable=missing-class-docstring __all__ = [ 'OnnxReduceSumDynamicAxes', 'OnnxReduceSumStaticAxes', @@ -11,6 +12,7 @@ from typing import Optional from typing import Tuple from typing import Union +from typing import cast import torch from torch import nn @@ -29,14 +31,17 @@ @torch.fx.wrap -def _get_element(x: Union[List, Tuple], index: int = 0) -> Any: +def _get_element(x: Any, index: int = 0) -> Any: if isinstance(x, (tuple, list)): return x[index] return x -def _initialize_none_dim(dim: Optional[Union[int, Tuple[int, ...]]], input_dim: int): +def _initialize_none_dim( + dim: Optional[Union[int, Tuple[int, ...]]], + input_dim: int, +) -> Union[List[int], Tuple[int, ...], int]: if dim is None: return list(range(input_dim)) @@ -47,27 +52,27 @@ def _log_sum( input_tensor: torch.Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False, -): - dim = _initialize_none_dim(dim, input_tensor.dim()) - return torch.log(torch.sum(input_tensor, dim=dim, keepdim=keepdim)) +) -> torch.Tensor: + dim_ = _initialize_none_dim(dim, input_tensor.dim()) + return torch.log(torch.sum(input_tensor, dim=dim_, keepdim=keepdim)) def _log_sum_exp( input_tensor: torch.Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False, -): - dim = _initialize_none_dim(dim, input_tensor.dim()) - return torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim) +) -> torch.Tensor: + dim_ = _initialize_none_dim(dim, input_tensor.dim()) + return torch.logsumexp(input_tensor, dim=dim_, keepdim=keepdim) def _sum_square( input_tensor: torch.Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False, -): - dim = _initialize_none_dim(dim, input_tensor.dim()) - return torch.sum(torch.square(input_tensor), dim=dim, keepdim=keepdim) +) -> torch.Tensor: + dim_ = _initialize_none_dim(dim, input_tensor.dim()) + return torch.sum(torch.square(input_tensor), dim=dim_, keepdim=keepdim) _TORCH_FUNCTION_FROM_ONNX_TYPE = { @@ -84,10 +89,7 @@ def _sum_square( } -class OnnxReduceSumDynamicAxes( # pylint: disable=missing-class-docstring - nn.Module, - OnnxToTorchModuleWithCustomExport, -): +class OnnxReduceSumDynamicAxes(nn.Module, OnnxToTorchModuleWithCustomExport): def __init__(self, keepdims: int = 1, noop_with_empty_axes: int = 0): super().__init__() @@ -95,6 +97,7 @@ def __init__(self, keepdims: int = 1, noop_with_empty_axes: int = 0): self._noop_with_empty_axes = noop_with_empty_axes def _onnx_attrs(self, opset_version: int) -> Dict[str, Any]: + del opset_version return { 'noop_with_empty_axes_i': self._noop_with_empty_axes, 'keepdims_i': self._keepdims, @@ -105,7 +108,7 @@ def forward( # pylint: disable=missing-function-docstring input_tensor: torch.Tensor, axes: Optional[torch.Tensor] = None, ) -> torch.Tensor: - def _forward(): + def _forward() -> torch.Tensor: if axes is None or axes.nelement() == 0: if self._noop_with_empty_axes: return input_tensor @@ -130,7 +133,7 @@ def _forward(): return _forward() -class OnnxReduceSumStaticAxes(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring +class OnnxReduceSumStaticAxes(nn.Module, OnnxToTorchModule): def __init__( self, axes: List[int], @@ -155,14 +158,14 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disabl self._axes = list(range(input_tensor.dim())) - return torch.sum(input_tensor, dim=self._axes, keepdim=self._keepdims) + return torch.sum(input_tensor, dim=self._axes, keepdim=bool(self._keepdims)) -class OnnxReduceStaticAxes(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring +class OnnxReduceStaticAxes(nn.Module, OnnxToTorchModule): def __init__( self, operation_type: str, - axes: List[int], + axes: Optional[List[int]], keepdims: int = 1, ): super().__init__() @@ -228,10 +231,11 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disabl @add_converter(operation_type='ReduceSumSquare', version=1) @add_converter(operation_type='ReduceSumSquare', version=11) @add_converter(operation_type='ReduceSumSquare', version=13) -def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: + del graph node_attributes = node.attributes - axes = node_attributes.get('axes', None) - keepdims = node_attributes.get('keepdims', 1) + axes: Optional[List[int]] = node_attributes.get('axes', None) + keepdims: int = node_attributes.get('keepdims', 1) return OperationConverterResult( torch_module=OnnxReduceStaticAxes( @@ -244,13 +248,13 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: @add_converter(operation_type='ReduceSum', version=13) -def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument - keepdims = node.attributes.get('keepdims', 1) - noop_with_empty_axes = node.attributes.get('noop_with_empty_axes', 0) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: + keepdims: int = node.attributes.get('keepdims', 1) + noop_with_empty_axes: int = node.attributes.get('noop_with_empty_axes', 0) if len(node.input_values) == 2: try: - axes = get_const_value(node.input_values[1], graph) + axes = cast(torch.Tensor, get_const_value(node.input_values[1], graph)) axes = axes.tolist() return OperationConverterResult( torch_module=OnnxReduceSumStaticAxes(