Skip to content

Commit

Permalink
fix: reduce keepdim type (int -> bool) (#206)
Browse files Browse the repository at this point in the history
  • Loading branch information
senysenyseny16 authored Mar 29, 2024
1 parent a8b0603 commit f04c3a7
Showing 1 changed file with 31 additions and 27 deletions.
58 changes: 31 additions & 27 deletions onnx2torch/node_converters/reduce.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=missing-class-docstring
__all__ = [
'OnnxReduceSumDynamicAxes',
'OnnxReduceSumStaticAxes',
Expand All @@ -11,6 +12,7 @@
from typing import Optional
from typing import Tuple
from typing import Union
from typing import cast

import torch
from torch import nn
Expand All @@ -29,14 +31,17 @@


@torch.fx.wrap
def _get_element(x: Union[List, Tuple], index: int = 0) -> Any:
def _get_element(x: Any, index: int = 0) -> Any:
if isinstance(x, (tuple, list)):
return x[index]

return x


def _initialize_none_dim(dim: Optional[Union[int, Tuple[int, ...]]], input_dim: int):
def _initialize_none_dim(
dim: Optional[Union[int, Tuple[int, ...]]],
input_dim: int,
) -> Union[List[int], Tuple[int, ...], int]:
if dim is None:
return list(range(input_dim))

Expand All @@ -47,27 +52,27 @@ def _log_sum(
input_tensor: torch.Tensor,
dim: Optional[Union[int, Tuple[int, ...]]] = None,
keepdim: bool = False,
):
dim = _initialize_none_dim(dim, input_tensor.dim())
return torch.log(torch.sum(input_tensor, dim=dim, keepdim=keepdim))
) -> torch.Tensor:
dim_ = _initialize_none_dim(dim, input_tensor.dim())
return torch.log(torch.sum(input_tensor, dim=dim_, keepdim=keepdim))


def _log_sum_exp(
input_tensor: torch.Tensor,
dim: Optional[Union[int, Tuple[int, ...]]] = None,
keepdim: bool = False,
):
dim = _initialize_none_dim(dim, input_tensor.dim())
return torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim)
) -> torch.Tensor:
dim_ = _initialize_none_dim(dim, input_tensor.dim())
return torch.logsumexp(input_tensor, dim=dim_, keepdim=keepdim)


def _sum_square(
input_tensor: torch.Tensor,
dim: Optional[Union[int, Tuple[int, ...]]] = None,
keepdim: bool = False,
):
dim = _initialize_none_dim(dim, input_tensor.dim())
return torch.sum(torch.square(input_tensor), dim=dim, keepdim=keepdim)
) -> torch.Tensor:
dim_ = _initialize_none_dim(dim, input_tensor.dim())
return torch.sum(torch.square(input_tensor), dim=dim_, keepdim=keepdim)


_TORCH_FUNCTION_FROM_ONNX_TYPE = {
Expand All @@ -84,17 +89,15 @@ def _sum_square(
}


class OnnxReduceSumDynamicAxes( # pylint: disable=missing-class-docstring
nn.Module,
OnnxToTorchModuleWithCustomExport,
):
class OnnxReduceSumDynamicAxes(nn.Module, OnnxToTorchModuleWithCustomExport):
def __init__(self, keepdims: int = 1, noop_with_empty_axes: int = 0):
super().__init__()

self._keepdims = keepdims
self._noop_with_empty_axes = noop_with_empty_axes

def _onnx_attrs(self, opset_version: int) -> Dict[str, Any]:
del opset_version
return {
'noop_with_empty_axes_i': self._noop_with_empty_axes,
'keepdims_i': self._keepdims,
Expand All @@ -105,7 +108,7 @@ def forward( # pylint: disable=missing-function-docstring
input_tensor: torch.Tensor,
axes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
def _forward():
def _forward() -> torch.Tensor:
if axes is None or axes.nelement() == 0:
if self._noop_with_empty_axes:
return input_tensor
Expand All @@ -130,7 +133,7 @@ def _forward():
return _forward()


class OnnxReduceSumStaticAxes(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring
class OnnxReduceSumStaticAxes(nn.Module, OnnxToTorchModule):
def __init__(
self,
axes: List[int],
Expand All @@ -155,14 +158,14 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disabl

self._axes = list(range(input_tensor.dim()))

return torch.sum(input_tensor, dim=self._axes, keepdim=self._keepdims)
return torch.sum(input_tensor, dim=self._axes, keepdim=bool(self._keepdims))


class OnnxReduceStaticAxes(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring
class OnnxReduceStaticAxes(nn.Module, OnnxToTorchModule):
def __init__(
self,
operation_type: str,
axes: List[int],
axes: Optional[List[int]],
keepdims: int = 1,
):
super().__init__()
Expand Down Expand Up @@ -228,10 +231,11 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disabl
@add_converter(operation_type='ReduceSumSquare', version=1)
@add_converter(operation_type='ReduceSumSquare', version=11)
@add_converter(operation_type='ReduceSumSquare', version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
del graph
node_attributes = node.attributes
axes = node_attributes.get('axes', None)
keepdims = node_attributes.get('keepdims', 1)
axes: Optional[List[int]] = node_attributes.get('axes', None)
keepdims: int = node_attributes.get('keepdims', 1)

return OperationConverterResult(
torch_module=OnnxReduceStaticAxes(
Expand All @@ -244,13 +248,13 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint:


@add_converter(operation_type='ReduceSum', version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
keepdims = node.attributes.get('keepdims', 1)
noop_with_empty_axes = node.attributes.get('noop_with_empty_axes', 0)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
keepdims: int = node.attributes.get('keepdims', 1)
noop_with_empty_axes: int = node.attributes.get('noop_with_empty_axes', 0)

if len(node.input_values) == 2:
try:
axes = get_const_value(node.input_values[1], graph)
axes = cast(torch.Tensor, get_const_value(node.input_values[1], graph))
axes = axes.tolist()
return OperationConverterResult(
torch_module=OnnxReduceSumStaticAxes(
Expand Down

0 comments on commit f04c3a7

Please sign in to comment.