Release 1.2.5 (#40)
* Fixed type of axes parameter (squeeze/unsqueeze/reduce) (#39)

Co-authored-by: igor <[email protected]>

* Feat: added converters for R-CNN models (#38)

Added ConvTranspose converters and tests
Fixed ONNX export of RoiAlign and Squeeze nodes
Fixed typo in ConstantOfShape

Co-authored-by: a.nazdryukhin <[email protected]>

* Added new version 1.2.5

Co-authored-by: ivkalgin <[email protected]>
Co-authored-by: igor <[email protected]>
Co-authored-by: niobeus <[email protected]>
Co-authored-by: a.nazdryukhin <[email protected]>
Co-authored-by: p.ivanov <[email protected]>
6 people authored Apr 7, 2022
1 parent d2dc0a0 commit 5617ca6
Showing 10 changed files with 135 additions and 52 deletions.
2 changes: 1 addition & 1 deletion onnx2torch/VERSION
@@ -1 +1 @@
-1.2.0
+1.2.5
2 changes: 1 addition & 1 deletion onnx2torch/node_converters/constant_of_shape.py
@@ -19,7 +19,7 @@ def __init__(self, value: Optional[torch.Tensor] = None):
         super().__init__()
 
         if value is None:
-            value = torch.Tensor(0.0, dtype=torch.float32)
+            value = torch.tensor(0.0, dtype=torch.float32)
 
         if value.numel() != 1:
             raise ValueError('parameter "value" must be scalar')
14 changes: 10 additions & 4 deletions onnx2torch/node_converters/conv.py
@@ -11,14 +11,19 @@
 from onnx2torch.utils.common import onnx_padding_to_torch_padding
 
 _CONV_CLASS_FROM_SPATIAL_RANK = {
-    1: nn.Conv1d,
-    2: nn.Conv2d,
-    3: nn.Conv3d,
+    ('Conv', 1): nn.Conv1d,
+    ('Conv', 2): nn.Conv2d,
+    ('Conv', 3): nn.Conv3d,
+    ('ConvTranspose', 1): nn.ConvTranspose1d,
+    ('ConvTranspose', 2): nn.ConvTranspose2d,
+    ('ConvTranspose', 3): nn.ConvTranspose3d,
 }
 
 
 @add_converter(operation_type='Conv', version=1)
 @add_converter(operation_type='Conv', version=11)
+@add_converter(operation_type='ConvTranspose', version=1)
+@add_converter(operation_type='ConvTranspose', version=11)
 def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
     weights_value_name = node.input_values[1]
     weights = graph.initializers[weights_value_name]
@@ -30,8 +35,9 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
     else:
         bias = None
 
+    op_type = node.operation_type
     spatial_rank = len(weights.shape) - 2
-    conv_class = _CONV_CLASS_FROM_SPATIAL_RANK.get(spatial_rank, None)
+    conv_class = _CONV_CLASS_FROM_SPATIAL_RANK.get((op_type, spatial_rank), None)
     if conv_class is None:
         raise NotImplementedError(f'Convolution operation with spatial rank == {spatial_rank} is not implemented')
 
1 change: 1 addition & 0 deletions onnx2torch/node_converters/reduce.py
@@ -259,6 +259,7 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
     if len(node.input_values) == 2:
         try:
             axes = get_const_value(node.input_values[1], graph)
+            axes = axes.tolist()
             return OperationConverterResult(
                 torch_module=OnnxReduceSumStaticAxes(
                     axes=axes,
76 changes: 66 additions & 10 deletions onnx2torch/node_converters/roialign.py
@@ -1,18 +1,22 @@
 __all__ = ['OnnxRoiAlign']
 
+from typing import Tuple
+
 import torch
+import torch._C as torch_C
 from torch import nn
 from torchvision.ops import roi_align
 
 from onnx2torch.node_converters.registry import add_converter
 from onnx2torch.onnx_graph import OnnxGraph
 from onnx2torch.onnx_node import OnnxNode
-from onnx2torch.utils.common import OnnxToTorchModule
 from onnx2torch.utils.common import OperationConverterResult
 from onnx2torch.utils.common import onnx_mapping_from_node
+from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx
+from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport
 
 
-class OnnxRoiAlign(nn.Module, OnnxToTorchModule):
+class OnnxRoiAlign(nn.Module, OnnxToTorchModuleWithCustomExport):
 
     def __init__(
         self,
@@ -23,31 +27,83 @@ def __init__(
         spatial_scale: float = 1.0,
     ):
         super().__init__()
 
         if mode != 'avg':
             raise NotImplementedError(f'"{mode}" roi align mode is not implemented.')
 
-        self._output_size = (output_height, output_width)
+        self._output_height = output_height
+        self._output_width = output_width
         self._sampling_ratio = sampling_ratio
         self._spatial_scale = spatial_scale
 
-    def forward(
-        self,
+    @staticmethod
+    def _do_forward(
         input_tensor: torch.Tensor,
         rois: torch.Tensor,
         batch_indices: torch.Tensor,
+        output_size: Tuple[int, int],
+        sampling_ratio: int,
+        spatial_scale: float,
     ) -> torch.Tensor:
-        batched_rois = torch.concat([batch_indices.unsqueeze(1).to(rois.dtype), rois], dim=1)
-
+        batch_indices = batch_indices.unsqueeze(1).to(rois.dtype)
+        batched_rois = torch.concat([batch_indices, rois], dim=1)
 
         return roi_align(
             input=input_tensor,
             boxes=batched_rois,
-            output_size=self._output_size,
-            spatial_scale=self._spatial_scale,
-            sampling_ratio=self._sampling_ratio,
+            output_size=output_size,
+            spatial_scale=spatial_scale,
+            sampling_ratio=sampling_ratio,
             aligned=False,
         )
 
+    def forward(
+        self,
+        input_tensor: torch.Tensor,
+        rois: torch.Tensor,
+        batch_indices: torch.Tensor,
+    ) -> torch.Tensor:
+        output = self._do_forward(
+            input_tensor=input_tensor,
+            rois=rois,
+            batch_indices=batch_indices,
+            output_size=(self._output_height, self._output_width),
+            sampling_ratio=self._sampling_ratio,
+            spatial_scale=self._spatial_scale,
+        )
+        if torch.onnx.is_in_onnx_export():
+            args = [
+                input_tensor,
+                rois,
+                batch_indices,
+                self._output_height,
+                self._output_width,
+                self._sampling_ratio,
+                self._spatial_scale
+            ]
+            return _RoiAlignExportToOnnx.set_output_and_apply(output, *args)
+
+        return output
+
+
+class _RoiAlignExportToOnnx(CustomExportToOnnx):  # pylint: disable=abstract-method
+
+    @staticmethod
+    def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value:
+        input_args = args[:3]
+        output_height, output_width, sampling_ratio, spatial_scale = args[3:]
+        return graph.op(
+            'RoiAlign',
+            *input_args,
+            output_height_i=output_height,
+            output_width_i=output_width,
+            sampling_ratio_i=sampling_ratio,
+            spatial_scale_f=spatial_scale,
+            outputs=1,
+        )
+
+
 @add_converter(operation_type='RoiAlign', version=10)
 def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
53 changes: 29 additions & 24 deletions onnx2torch/node_converters/squeeze.py
@@ -13,16 +13,14 @@
 from onnx2torch.node_converters.registry import add_converter
 from onnx2torch.onnx_graph import OnnxGraph
 from onnx2torch.onnx_node import OnnxNode
-from onnx2torch.utils.common import OnnxMapping
-from onnx2torch.utils.common import OnnxToTorchModule
 from onnx2torch.utils.common import OperationConverterResult
-from onnx2torch.utils.common import get_const_value
+from onnx2torch.utils.common import get_onnx_version
 from onnx2torch.utils.common import onnx_mapping_from_node
 from onnx2torch.utils.custom_export_to_onnx import CustomExportToOnnx
 from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport
 
 
-class OnnxSqueezeStaticAxes(nn.Module, OnnxToTorchModule):
+class OnnxSqueezeStaticAxes(nn.Module, OnnxToTorchModuleWithCustomExport):
 
     def __init__(self, axes: Optional[List[int]] = None):
         super().__init__()
@@ -31,22 +29,42 @@ def __init__(self, axes: Optional[List[int]] = None):
 
         self.axes = axes
 
-    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
-        if not self.axes:
+    @staticmethod
+    def _do_forward(input_tensor: torch.Tensor, axes: Optional[List[int]]) -> torch.Tensor:
+        if not axes:
             return torch.squeeze(input_tensor)
 
         result = input_tensor
-        for axes_id in self.axes:
+        for axes_id in axes:
             result = torch.squeeze(result, dim=axes_id)
 
         return result
 
+    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+        output = self._do_forward(input_tensor, self.axes)
+        if torch.onnx.is_in_onnx_export() and get_onnx_version() >= 13:
+            args = [input_tensor]
+            if self.axes:
+                axes = torch.tensor(
+                    self.axes,
+                    device=input_tensor.device,
+                    dtype=torch.int64
+                )
+                args.append(axes)
+            return _SqueezeDynamicAxesExportToOnnx.set_output_and_apply(output, *args)
+
+        return output
+
 
 class OnnxSqueezeDynamicAxes(nn.Module, OnnxToTorchModuleWithCustomExport):
 
+    @staticmethod
+    def is_empty_axes(axes: torch.Tensor) -> bool:
+        return axes is None or axes.nelement() == 0
+
     @staticmethod
     def _do_forward(input_tensor: torch.Tensor, axes: Optional[torch.Tensor]) -> torch.Tensor:
-        if axes is None or axes.nelement() == 0:
+        if OnnxSqueezeDynamicAxes.is_empty_axes(axes):
             return torch.squeeze(input_tensor)
 
         result = input_tensor
@@ -59,15 +77,15 @@ def forward(self, input_tensor: torch.Tensor, axes: Optional[torch.Tensor] = None) -> torch.Tensor:
         output = self._do_forward(input_tensor, axes)
         if torch.onnx.is_in_onnx_export():
             args = [input_tensor]
-            if axes is not None:
+            if not self.is_empty_axes(axes):
                 args.append(axes)
 
-            return _SqueezeExportToOnnx.set_output_and_apply(output, *args)
+            return _SqueezeDynamicAxesExportToOnnx.set_output_and_apply(output, *args)
 
         return output
 
 
-class _SqueezeExportToOnnx(CustomExportToOnnx):  # pylint: disable=abstract-method
+class _SqueezeDynamicAxesExportToOnnx(CustomExportToOnnx):  # pylint: disable=abstract-method
 
     @staticmethod
     def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value:
@@ -86,19 +104,6 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
 
 @add_converter(operation_type='Squeeze', version=13)
 def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
-    if len(node.input_values) == 2:
-        try:
-            axes = get_const_value(node.input_values[1], graph)
-            return OperationConverterResult(
-                torch_module=OnnxSqueezeStaticAxes(axes=axes),
-                onnx_mapping=OnnxMapping(
-                    inputs=(node.input_values[0],),
-                    outputs=node.output_values,
-                ),
-            )
-        except KeyError:
-            pass
-
     return OperationConverterResult(
         torch_module=OnnxSqueezeDynamicAxes(),
         onnx_mapping=onnx_mapping_from_node(node),
3 changes: 2 additions & 1 deletion onnx2torch/node_converters/unsqueeze.py
@@ -24,7 +24,7 @@
 
 class OnnxUnsqueezeStaticAxes(nn.Module, OnnxToTorchModule):
 
-    def __init__(self, axes: Optional[List[int]] = None):
+    def __init__(self, axes: List[int]):
         super().__init__()
         self.axes = sorted(axes)
 
@@ -75,6 +75,7 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
 def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
     try:
         axes = get_const_value(node.input_values[1], graph)
+        axes = axes.tolist()
         return OperationConverterResult(
             torch_module=OnnxUnsqueezeStaticAxes(axes=axes),
             onnx_mapping=OnnxMapping(
6 changes: 5 additions & 1 deletion onnx2torch/utils/common.py
@@ -4,8 +4,9 @@
 from typing import Union
 
 import torch
-from onnx import ValueInfoProto
 from torch import nn
+from torch.onnx import symbolic_helper
+from onnx import ValueInfoProto
 
 from onnx2torch.onnx_graph import OnnxGraph
 from onnx2torch.onnx_node import OnnxNode
@@ -34,6 +35,9 @@ def onnx_mapping_from_node(node: OnnxNode) -> OnnxMapping:
         outputs=node.output_values,
     )
 
+def get_onnx_version():
+    return symbolic_helper._export_onnx_opset_version
+
 
 def get_shape_from_value_info(value_info: ValueInfoProto) -> List[int]:
     return [
29 changes: 19 additions & 10 deletions tests/node_converters/conv_test.py
@@ -10,6 +10,7 @@
 
 
 def _test_conv(
+        op_type: str,
         in_channels: int,
         out_channels: int,
         kernel_shape: Tuple[int, int],
@@ -20,13 +21,16 @@ def _test_conv(
 
     x_shape = (2, in_channels) + input_hw
     x = np.random.uniform(low=-1.0, high=1.0, size=x_shape).astype(np.float32)
-    weights_shape = (out_channels, in_channels//group) + kernel_shape
+    if op_type == 'Conv':
+        weights_shape = (out_channels, in_channels//group) + kernel_shape
+    elif op_type == 'ConvTranspose':
+        weights_shape = (in_channels, out_channels//group) + kernel_shape
     weights = np.random.uniform(low=-1.0, high=1.0, size=weights_shape).astype(np.float32)
 
     test_inputs = {'x': x}
     initializers = {'weights': weights}
     node = onnx.helper.make_node(
-        op_type='Conv',
+        op_type=op_type,
         inputs=['x', 'weights'],
         outputs=['y'],
         kernel_shape=kernel_shape,
@@ -44,27 +48,30 @@
 
 
 def test_conv2d_base_params() -> None:
+    op_type_variants = ('ConvTranspose', 'Conv')
     in_channels_variants = (1, 2, 3, 4, 16)
     out_channels_variants = (1, 2, 3, 4, 16)
-    input_hw_variants = ((32, 32), (32, 31), (31, 31))
+    input_hw_variants = ((32, 32), (32, 31), (31, 32), (31, 31))
     kernel_shape_variants = tuple(chain(
         ((i, i) for i in range(1, 6)),
         ((1, 2), (1, 3), (1, 5)),
        ((2, 2), (2, 3), (2, 5)),
     ))
-    all_variants = product(in_channels_variants, out_channels_variants, input_hw_variants, kernel_shape_variants)
-    for in_channels, out_channels, input_hw, kernel_shape in all_variants:
+    all_variants = product(op_type_variants, in_channels_variants, out_channels_variants, input_hw_variants, kernel_shape_variants)
+    for op_type, in_channels, out_channels, input_hw, kernel_shape in all_variants:
         _test_conv(
+            op_type=op_type,
             in_channels=in_channels,
             out_channels=out_channels,
             input_hw=input_hw,
             kernel_shape=kernel_shape,
         )
 
     in_out_channels_variants = (2, 3, 4, 16)
-    all_variants = product(in_out_channels_variants, input_hw_variants, kernel_shape_variants)
-    for in_out_channels, input_hw, kernel_shape in all_variants:
+    all_variants = product(op_type_variants, in_out_channels_variants, input_hw_variants, kernel_shape_variants)
+    for op_type, in_out_channels, input_hw, kernel_shape in all_variants:
         _test_conv(
+            op_type=op_type,
             in_channels=in_out_channels,
             out_channels=in_out_channels,
             input_hw=input_hw,
@@ -74,7 +81,8 @@ def test_conv2d_base_params() -> None:
 
 
 def test_conv_stride_dilations_pads() -> None:
-    input_hw_variants = ((32, 32), (32, 27), (27, 27))
+    op_type_variants = ('ConvTranspose', 'Conv')
+    input_hw_variants = ((32, 32), (32, 27), (27, 32), (27, 27))
     kernel_shape_variants = tuple(chain(
         ((i, i) for i in range(1, 4)),
         ((1, 2), (1, 3), (2, 3)),
@@ -85,9 +93,10 @@
     dilations_variants = (
         (1, 1), (2, 2), (1, 2), (2, 1),
     )
-    all_variants = product(input_hw_variants, kernel_shape_variants, stride_variants, dilations_variants)
-    for input_hw, kernel_shape, strides, dilations in all_variants:
+    all_variants = product(op_type_variants, input_hw_variants, kernel_shape_variants, stride_variants, dilations_variants)
+    for op_type, input_hw, kernel_shape, strides, dilations in all_variants:
         _test_conv(
+            op_type=op_type,
             in_channels=16,
             out_channels=16,
             input_hw=input_hw,