diff --git a/onnx2torch/node_converters/__init__.py b/onnx2torch/node_converters/__init__.py
index 53dfb4d6..c1a81ae4 100644
--- a/onnx2torch/node_converters/__init__.py
+++ b/onnx2torch/node_converters/__init__.py
@@ -48,6 +48,7 @@
 from onnx2torch.node_converters.shape import *
 from onnx2torch.node_converters.slice import *
 from onnx2torch.node_converters.split import *
+from onnx2torch.node_converters.split_to_sequence import *
 from onnx2torch.node_converters.squeeze import *
 from onnx2torch.node_converters.sum import *
 from onnx2torch.node_converters.tile import *
diff --git a/onnx2torch/node_converters/split_to_sequence.py b/onnx2torch/node_converters/split_to_sequence.py
new file mode 100644
index 00000000..bbae59e8
--- /dev/null
+++ b/onnx2torch/node_converters/split_to_sequence.py
@@ -0,0 +1,69 @@
+__all__ = ['OnnxSplitToSequence']
+
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+import torch
+from torch import nn
+
+from onnx2torch.node_converters.registry import add_converter
+from onnx2torch.onnx_graph import OnnxGraph
+from onnx2torch.onnx_node import OnnxNode
+from onnx2torch.utils.common import OperationConverterResult
+from onnx2torch.utils.common import get_onnx_version
+from onnx2torch.utils.common import onnx_mapping_from_node
+from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx
+from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport
+
+
+class OnnxSplitToSequence(nn.Module, OnnxToTorchModuleWithCustomExport):  # pylint: disable=missing-class-docstring
+    def __init__(self, axis: int = 0, keepdims: int = 1):
+        super().__init__()
+        self.axis = axis
+        self.keepdims = keepdims
+
+    def _onnx_attrs(self, opset_version: int) -> Dict[str, Any]:
+        del opset_version
+        return {'axis_i': self.axis, 'keepdims_i': self.keepdims}
+
+    @staticmethod
+    def _split_to_sequence(
+        inputs: torch.Tensor,
+        split: Optional[torch.IntTensor] = None,
+        axis: int = 0,
+        keepdims: int = 1,
+    ) -> List[torch.Tensor]:
+        del keepdims
+
+        split_size_or_sections = split.tolist() if split is not None else 1
+        return torch.split(tensor=inputs, split_size_or_sections=split_size_or_sections, dim=axis)
+
+    # pylint: disable=missing-function-docstring
+    def forward(self, inputs: torch.Tensor, split: Optional[torch.IntTensor] = None) -> List[torch.Tensor]:
+        if torch.onnx.is_in_onnx_export():
+
+            def _stub_forward():
+                return torch.Tensor()
+
+            onnx_attrs = self._onnx_attrs(opset_version=get_onnx_version())
+            return DefaultExportToOnnx.export(_stub_forward, 'SplitToSequence', inputs, split, onnx_attrs)
+
+        return self._split_to_sequence(inputs=inputs, split=split, axis=self.axis, keepdims=self.keepdims)
+
+
+@add_converter(operation_type='SplitToSequence', version=11)
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
+    del graph
+
+    axis: int = node.attributes.get('axis', 0)
+    keepdims: int = node.attributes.get('keepdims', 1)
+
+    if len(node.input_values) == 1 and keepdims == 0:  # no split and keepdims == 0
+        raise NotImplementedError('SplitToSequence without split argument and keepdims == 0 is not implemented')
+
+    return OperationConverterResult(
+        torch_module=OnnxSplitToSequence(axis=axis, keepdims=keepdims),
+        onnx_mapping=onnx_mapping_from_node(node=node),
+    )
diff --git a/onnx2torch/utils/custom_export_to_onnx.py b/onnx2torch/utils/custom_export_to_onnx.py
index 78e80297..3f3eb881 100644
--- a/onnx2torch/utils/custom_export_to_onnx.py
+++ b/onnx2torch/utils/custom_export_to_onnx.py
@@ -97,4 +97,5 @@ class DefaultExportToOnnx(CustomExportToOnnx):  # pylint: disable=abstract-metho
     @staticmethod
     def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value:
         op_type, *inputs, onnx_attrs = args
+        inputs = tuple(value for value in inputs if value is not None)  # filter out None
         return graph.op(op_type, *inputs, **onnx_attrs, outputs=1)
diff --git a/operators.md b/operators.md
index 20c84832..f0d60d93 100644
--- a/operators.md
+++ b/operators.md
@@ -143,7 +143,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops
 | Softsign | Y | |
 | SpaceToDepth | N | |
 | Split | Y | |
-| SplitToSequence | N | |
+| SplitToSequence | Y | SplitToSequence without "split" argument and "keepdims" = 0 is not implemented |
 | Sqrt | Y | |
 | Squeeze | Y | |
 | StringNormalizer | N | |
diff --git a/tests/node_converters/split_to_sequence_test.py b/tests/node_converters/split_to_sequence_test.py
new file mode 100644
index 00000000..6b9c40c4
--- /dev/null
+++ b/tests/node_converters/split_to_sequence_test.py
@@ -0,0 +1,76 @@
+from typing import Optional
+
+import numpy as np
+import pytest
+from onnx.helper import make_node
+from onnx.helper import make_tensor_sequence_value_info
+from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
+
+from tests.utils.common import check_onnx_model
+from tests.utils.common import make_model_from_nodes
+
+
+def _test_split(
+    inputs: np.ndarray,
+    split: Optional[np.ndarray],
+    axis: Optional[int],
+    keepdims: Optional[int],
+) -> None:
+    test_inputs = {'input': inputs}
+    if split is not None:
+        test_inputs['split'] = split
+
+    node = make_node(
+        op_type='SplitToSequence',
+        inputs=[*test_inputs],
+        outputs=['output'],
+        axis=axis,
+        keepdims=keepdims,
+    )
+
+    outputs_info = [
+        make_tensor_sequence_value_info(
+            name='output',
+            elem_type=NP_TYPE_TO_TENSOR_TYPE[inputs.dtype],
+            shape=None,
+        )
+    ]
+
+    model = make_model_from_nodes(
+        nodes=node,
+        initializers={},
+        inputs_example=test_inputs,
+        outputs_info=outputs_info,
+    )
+    check_onnx_model(model, test_inputs)
+
+
+@pytest.mark.parametrize(
+    'inputs, split, axis, keepdims',
+    [
+        (
+            np.arange(18).reshape((3, 6)).astype(np.float32),
+            np.array(2, dtype=np.int64),
+            1,
+            None,
+        ),
+        (
+            np.arange(18).reshape((3, 6)).astype(np.float32),
+            None,
+            1,
+            1,
+        ),
+    ],
+)
+def test_split(  # pylint: disable=missing-function-docstring
+    inputs: np.ndarray,
+    split: Optional[np.ndarray],
+    axis: Optional[int],
+    keepdims: Optional[int],
+) -> None:
+    _test_split(
+        inputs=inputs,
+        split=split,
+        axis=axis,
+        keepdims=keepdims,
+    )
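
Not part of the patch: a minimal usage sketch of the new `OnnxSplitToSequence` module on its own, outside ONNX export, mirroring the two parametrized test cases above (a scalar `split` of 2 along axis 1, and no `split` with `keepdims=1`). Import path and behaviour follow the code added in `onnx2torch/node_converters/split_to_sequence.py`; the tensor values are arbitrary.

```python
# Minimal sketch (not part of the patch): calling OnnxSplitToSequence directly.
import torch

from onnx2torch.node_converters.split_to_sequence import OnnxSplitToSequence

inputs = torch.arange(18, dtype=torch.float32).reshape(3, 6)
module = OnnxSplitToSequence(axis=1, keepdims=1)

# Scalar `split` = 2: axis 1 is cut into chunks of size 2 -> three (3, 2) tensors.
chunks = module(inputs, torch.tensor(2, dtype=torch.int64))
print([tuple(chunk.shape) for chunk in chunks])  # [(3, 2), (3, 2), (3, 2)]

# No `split`: axis 1 is cut into size-1 slices, kept because keepdims=1 -> six (3, 1) tensors.
slices = module(inputs)
print([tuple(s.shape) for s in slices])  # [(3, 1), (3, 1), (3, 1), (3, 1), (3, 1), (3, 1)]
```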