Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: SplitToSequence #181

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnx2torch/node_converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from onnx2torch.node_converters.shape import *
from onnx2torch.node_converters.slice import *
from onnx2torch.node_converters.split import *
from onnx2torch.node_converters.split_to_sequence import *
from onnx2torch.node_converters.squeeze import *
from onnx2torch.node_converters.sum import *
from onnx2torch.node_converters.tile import *
Expand Down
69 changes: 69 additions & 0 deletions onnx2torch/node_converters/split_to_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
__all__ = ['OnnxSplitToSequence']

from typing import Any
from typing import Dict
from typing import List
from typing import Optional

import torch
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import get_onnx_version
from onnx2torch.utils.common import onnx_mapping_from_node
from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx
from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport


class OnnxSplitToSequence(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-class-docstring
def __init__(self, axis: int = 0, keepdims: int = 1):
super().__init__()
self.axis = axis
self.keepdims = keepdims

def _onnx_attrs(self, opset_version: int) -> Dict[str, Any]:
del opset_version
return {'axis_i': self.axis, 'keepdims_i': self.keepdims}

@staticmethod
def _split_to_sequence(
inputs: torch.Tensor,
split: Optional[torch.IntTensor] = None,
axis: int = 0,
keepdims: int = 1,
) -> List[torch.Tensor]:
del keepdims

split_size_or_sections = split.tolist() if split is not None else 1
return torch.split(tensor=inputs, split_size_or_sections=split_size_or_sections, dim=axis)

# pylint: disable=missing-function-docstring
def forward(self, inputs: torch.Tensor, split: Optional[torch.IntTensor] = None) -> List[torch.Tensor]:
if torch.onnx.is_in_onnx_export():

def _stub_forward():
return torch.Tensor()

onnx_attrs = self._onnx_attrs(opset_version=get_onnx_version())
return DefaultExportToOnnx.export(_stub_forward, 'SplitToSequence', inputs, split, onnx_attrs)

return self._split_to_sequence(inputs=inputs, split=split, axis=self.axis, keepdims=self.keepdims)


@add_converter(operation_type='SplitToSequence', version=11)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
del graph

axis: int = node.attributes.get('axis', 0)
keepdims: int = node.attributes.get('keepdims', 1)

if len(node.input_values) == 1 and keepdims == 0: # no split and keepdim = 0
raise NotImplementedError('SplitToSequence without split argument and keepdims == 0 is not implemented')

return OperationConverterResult(
torch_module=OnnxSplitToSequence(axis=axis, keepdims=keepdims),
onnx_mapping=onnx_mapping_from_node(node=node),
)
1 change: 1 addition & 0 deletions onnx2torch/utils/custom_export_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,5 @@ class DefaultExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-metho
@staticmethod
def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value:
op_type, *inputs, onnx_attrs = args
inputs = tuple(value for value in inputs if value is not None) # filter out None
return graph.op(op_type, *inputs, **onnx_attrs, outputs=1)
2 changes: 1 addition & 1 deletion operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops
| Softsign | Y | |
| SpaceToDepth | N | |
| Split | Y | |
| SplitToSequence | N | |
| SplitToSequence | Y | SplitToSequence without "split" argument and "keepdims" = 0 is not implemented |
| Sqrt | Y | |
| Squeeze | Y | |
| StringNormalizer | N | |
Expand Down
76 changes: 76 additions & 0 deletions tests/node_converters/split_to_sequence_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Optional

import numpy as np
import pytest
from onnx.helper import make_node
from onnx.helper import make_tensor_sequence_value_info
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_split(
inputs: np.ndarray,
split: Optional[np.ndarray],
axis: Optional[int],
keepdims: Optional[int],
) -> None:
test_inputs = {'input': inputs}
if split is not None:
test_inputs['split'] = split

node = make_node(
op_type='SplitToSequence',
inputs=[*test_inputs],
outputs=['output'],
axis=axis,
keepdims=keepdims,
)

outputs_info = [
make_tensor_sequence_value_info(
name='output',
elem_type=NP_TYPE_TO_TENSOR_TYPE[inputs.dtype],
shape=None,
)
]

model = make_model_from_nodes(
nodes=node,
initializers={},
inputs_example=test_inputs,
outputs_info=outputs_info,
)
check_onnx_model(model, test_inputs)


@pytest.mark.parametrize(
'inputs, split, axis, keepdims',
[
(
np.arange(18).reshape((3, 6)).astype(np.float32),
np.array(2, dtype=np.int64),
1,
None,
),
(
np.arange(18).reshape((3, 6)).astype(np.float32),
None,
1,
1,
),
],
)
def test_split( # pylint: disable=missing-function-docstring
inputs: np.ndarray,
split: Optional[np.ndarray],
axis: Optional[int],
keepdims: Optional[int],
) -> None:
_test_split(
inputs=inputs,
split=split,
axis=axis,
keepdims=keepdims,
)