Skip to content

Commit

Permalink
fix: div return type for integer arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
senysenyseny16 authored Oct 27, 2023
1 parent bebf52f commit afbddcb
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
10 changes: 9 additions & 1 deletion onnx2torch/node_converters/binary_math_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,19 @@
from onnx2torch.utils.common import old_style_broadcast
from onnx2torch.utils.common import onnx_mapping_from_node


def _onnx_div(first: torch.Tensor, second: torch.Tensor) -> torch.Tensor:
if first.is_floating_point() or second.is_floating_point(): # float division
return torch.div(first, second)

return torch.div(first, second, rounding_mode='trunc') # integer division


_TORCH_FUNCTION_FROM_ONNX_TYPE = {
'Add': torch.add,
'Sub': torch.sub,
'Mul': torch.mul,
'Div': torch.div,
'Div': _onnx_div,
}


Expand Down
28 changes: 28 additions & 0 deletions tests/node_converters/binary_operations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,31 @@ def test_math_binary_operation(op_type: str) -> None: # pylint: disable=missing

model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs)
check_onnx_model(model, test_inputs)


@pytest.mark.parametrize(
'x, y',
[
(1, 2),
(1, 5),
(5, 30),
(-1, 2),
(-1, 5),
(5, -30),
(5, 2),
(-5, 2),
],
)
def test_div_operation(x: int, y: int) -> None: # pylint: disable=missing-function-docstring
x_ = np.array(x, dtype=np.int64) # pylint: disable=invalid-name
y_ = np.array(y, dtype=np.int64) # pylint: disable=invalid-name
test_inputs = {'x': x_, 'y': y_}

node = onnx.helper.make_node(
op_type='Div',
inputs=['x', 'y'],
outputs=['z'],
)

model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs)
check_onnx_model(model, test_inputs)

0 comments on commit afbddcb

Please sign in to comment.