diff --git a/onnx2torch/node_converters/binary_math_operations.py b/onnx2torch/node_converters/binary_math_operations.py index fa2ff39e..bda52c4f 100644 --- a/onnx2torch/node_converters/binary_math_operations.py +++ b/onnx2torch/node_converters/binary_math_operations.py @@ -19,7 +19,6 @@ 'Add': torch.add, 'Sub': torch.sub, 'Mul': torch.mul, - 'Div': torch.div, } @@ -42,6 +41,27 @@ def forward( # pylint: disable=missing-function-docstring return self.math_op_function(first, second) +class OnnxDiv(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring + def __init__(self, broadcast: Optional[int] = None, axis: Optional[int] = None): + super().__init__() + + self.broadcast = broadcast + self.axis = axis + + def forward( # pylint: disable=missing-function-docstring + self, + first: torch.Tensor, + second: torch.Tensor, + ) -> torch.Tensor: + if self.broadcast == 1 and self.axis is not None: + second = old_style_broadcast(first, second, self.axis) + + if first.is_floating_point() or second.is_floating_point(): # float division + return torch.div(first, second) + + return torch.div(first, second, rounding_mode='trunc') # integer division + + @add_converter(operation_type='Add', version=1) @add_converter(operation_type='Add', version=6) @add_converter(operation_type='Add', version=7) @@ -57,6 +77,17 @@ def forward( # pylint: disable=missing-function-docstring @add_converter(operation_type='Mul', version=7) @add_converter(operation_type='Mul', version=13) @add_converter(operation_type='Mul', version=14) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument + return OperationConverterResult( + torch_module=OnnxBinaryMathOperation( + operation_type=node.operation_type, + broadcast=node.attributes.get('broadcast', None), + axis=node.attributes.get('axis', None), + ), + onnx_mapping=onnx_mapping_from_node(node=node), + ) + + @add_converter(operation_type='Div', version=1) @add_converter(operation_type='Div', version=6) @add_converter(operation_type='Div', version=7) @@ -64,8 +95,7 @@ def forward( # pylint: disable=missing-function-docstring @add_converter(operation_type='Div', version=14) def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument return OperationConverterResult( - torch_module=OnnxBinaryMathOperation( - operation_type=node.operation_type, + torch_module=OnnxDiv( broadcast=node.attributes.get('broadcast', None), axis=node.attributes.get('axis', None), ), diff --git a/tests/node_converters/binary_operations_test.py b/tests/node_converters/binary_operations_test.py index 9809bff6..73bec3bb 100644 --- a/tests/node_converters/binary_operations_test.py +++ b/tests/node_converters/binary_operations_test.py @@ -30,3 +30,31 @@ def test_math_binary_operation(op_type: str) -> None: # pylint: disable=missing model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) check_onnx_model(model, test_inputs) + + +@pytest.mark.parametrize( + 'x, y', + [ + (1, 2), + (1, 5), + (5, 30), + (-1, 2), + (-1, 5), + (5, -30), + (5, 2), + (-5, 2), + ], +) +def test_div_operation(x: int, y: int) -> None: # pylint: disable=missing-function-docstring + x_ = np.array(x, dtype=np.int64) # pylint: disable=invalid-name + y_ = np.array(y, dtype=np.int64) # pylint: disable=invalid-name + test_inputs = {'x': x_, 'y': y_} + + node = onnx.helper.make_node( + op_type='Div', + inputs=['x', 'y'], + outputs=['z'], + ) + + model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs) + check_onnx_model(model, test_inputs)