Skip to content

Commit

Permalink
fix: export to onnx ConstantOfShape in case of bool value attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
ivkalgin authored Oct 9, 2023
1 parent 9731aa3 commit f3f6dd6
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion onnx2torch/node_converters/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _forward():
if slope.nelement() == 1 or (
slope.shape[0] == input_tensor.shape[1] and all(s == 1 for s in slope.shape[1:])
):
return nn.functional.prelu(input_tensor, weight=slope.view(-1))
return nn.functional.prelu(input_tensor, weight=slope.view(-1)) # pylint: disable=not-callable

output = input_tensor.clone()
output = output * slope
Expand Down
4 changes: 3 additions & 1 deletion onnx2torch/node_converters/constant_of_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ def __init__(self, value: Optional[torch.Tensor] = None):
self.register_buffer('value', value)

def forward(self, shape: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
fill_value = self.value.item()

return torch.full(
size=torch.Size(shape),
fill_value=self.value.item(),
fill_value=int(fill_value) if isinstance(fill_value, bool) else fill_value,
dtype=self.value.dtype,
device=self.value.device,
)
Expand Down
4 changes: 2 additions & 2 deletions onnx2torch/node_converters/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def create_from_onnx_params( # pylint: disable=missing-function-docstring
return cls(pads=torch_padding, mode=torch_mode, constant_value=constant_value)

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
return F.pad(
return F.pad( # pylint: disable=not-callable
input_tensor,
mode=self.mode,
pad=self.pads,
Expand All @@ -109,7 +109,7 @@ def forward( # pylint: disable=missing-function-docstring
torch_pads = _onnx_padding_to_torch(pads.tolist())
torch_pads = _torch_padding_to_mode_format(torch_pads, self.mode)

return F.pad(input_tensor, mode=self.mode, pad=torch_pads, value=constant_value)
return F.pad(input_tensor, mode=self.mode, pad=torch_pads, value=constant_value) # pylint: disable=not-callable


@add_converter(operation_type='Pad', version=11)
Expand Down
2 changes: 2 additions & 0 deletions tests/node_converters/constant_of_shape_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,5 @@ def test_constant_of_shape() -> None: # pylint: disable=missing-function-docstr
shape = np.random.randint(low=1, high=2, size=(size,))
value = np.random.uniform(low=-10000, high=10000, size=(1,))
_test_constant_of_shape(shape, value)

_test_constant_of_shape(np.asarray([3, 3]), np.asarray([True]))

0 comments on commit f3f6dd6

Please sign in to comment.