Skip to content

Commit

Permalink
fix: div return type for integer arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
senysenyseny16 committed Oct 27, 2023
1 parent bebf52f commit 6094bcd
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
36 changes: 33 additions & 3 deletions onnx2torch/node_converters/binary_math_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
'Add': torch.add,
'Sub': torch.sub,
'Mul': torch.mul,
'Div': torch.div,
}


Expand All @@ -42,6 +41,27 @@ def forward( # pylint: disable=missing-function-docstring
return self.math_op_function(first, second)


class OnnxDiv(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring
def __init__(self, broadcast: Optional[int] = None, axis: Optional[int] = None):
super().__init__()

self.broadcast = broadcast
self.axis = axis

def forward( # pylint: disable=missing-function-docstring
self,
first: torch.Tensor,
second: torch.Tensor,
) -> torch.Tensor:
if self.broadcast == 1 and self.axis is not None:
second = old_style_broadcast(first, second, self.axis)

if first.is_floating_point() or second.is_floating_point(): # float division
return torch.div(first, second)

return torch.div(first, second, rounding_mode='trunc') # integer division


@add_converter(operation_type='Add', version=1)
@add_converter(operation_type='Add', version=6)
@add_converter(operation_type='Add', version=7)
Expand All @@ -57,15 +77,25 @@ def forward( # pylint: disable=missing-function-docstring
@add_converter(operation_type='Mul', version=7)
@add_converter(operation_type='Mul', version=13)
@add_converter(operation_type='Mul', version=14)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
return OperationConverterResult(
torch_module=OnnxBinaryMathOperation(
operation_type=node.operation_type,
broadcast=node.attributes.get('broadcast', None),
axis=node.attributes.get('axis', None),
),
onnx_mapping=onnx_mapping_from_node(node=node),
)


@add_converter(operation_type='Div', version=1)
@add_converter(operation_type='Div', version=6)
@add_converter(operation_type='Div', version=7)
@add_converter(operation_type='Div', version=13)
@add_converter(operation_type='Div', version=14)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
return OperationConverterResult(
torch_module=OnnxBinaryMathOperation(
operation_type=node.operation_type,
torch_module=OnnxDiv(
broadcast=node.attributes.get('broadcast', None),
axis=node.attributes.get('axis', None),
),
Expand Down
28 changes: 28 additions & 0 deletions tests/node_converters/binary_operations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,31 @@ def test_math_binary_operation(op_type: str) -> None: # pylint: disable=missing

model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs)
check_onnx_model(model, test_inputs)


@pytest.mark.parametrize(
'x, y',
[
(1, 2),
(1, 5),
(5, 30),
(-1, 2),
(-1, 5),
(5, -30),
(5, 2),
(-5, 2),
],
)
def test_div_operation(x: int, y: int) -> None: # pylint: disable=missing-function-docstring
x_ = np.array(x, dtype=np.int64) # pylint: disable=invalid-name
y_ = np.array(y, dtype=np.int64) # pylint: disable=invalid-name
test_inputs = {'x': x_, 'y': y_}

node = onnx.helper.make_node(
op_type='Div',
inputs=['x', 'y'],
outputs=['z'],
)

model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs)
check_onnx_model(model, test_inputs)

0 comments on commit 6094bcd

Please sign in to comment.