diff --git a/onnx2torch/node_converters/binary_math_operations.py b/onnx2torch/node_converters/binary_math_operations.py index fa2ff39e..ccf589c1 100644 --- a/onnx2torch/node_converters/binary_math_operations.py +++ b/onnx2torch/node_converters/binary_math_operations.py @@ -15,11 +15,19 @@ from onnx2torch.utils.common import old_style_broadcast from onnx2torch.utils.common import onnx_mapping_from_node + +def _onnx_div(first: torch.Tensor, second: torch.Tensor) -> torch.Tensor: + if first.is_floating_point() or second.is_floating_point(): # float division + return torch.div(first, second) + + return torch.div(first, second, rounding_mode='trunc') # integer division + + _TORCH_FUNCTION_FROM_ONNX_TYPE = { 'Add': torch.add, 'Sub': torch.sub, 'Mul': torch.mul, - 'Div': torch.div, + 'Div': _onnx_div, } diff --git a/tests/node_converters/binary_operations_test.py b/tests/node_converters/binary_operations_test.py index 9809bff6..73bec3bb 100644 --- a/tests/node_converters/binary_operations_test.py +++ b/tests/node_converters/binary_operations_test.py @@ -30,3 +30,31 @@ def test_math_binary_operation(op_type: str) -> None: # pylint: disable=missing model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) check_onnx_model(model, test_inputs) + + +@pytest.mark.parametrize( + 'x, y', + [ + (1, 2), + (1, 5), + (5, 30), + (-1, 2), + (-1, 5), + (5, -30), + (5, 2), + (-5, 2), + ], +) +def test_div_operation(x: int, y: int) -> None: # pylint: disable=missing-function-docstring + x_ = np.array(x, dtype=np.int64) # pylint: disable=invalid-name + y_ = np.array(y, dtype=np.int64) # pylint: disable=invalid-name + test_inputs = {'x': x_, 'y': y_} + + node = onnx.helper.make_node( + op_type='Div', + inputs=['x', 'y'], + outputs=['z'], + ) + + model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs) + check_onnx_model(model, test_inputs)