diff --git a/onnx2torch/node_converters/activations.py b/onnx2torch/node_converters/activations.py index 374b231f..3f6f5234 100644 --- a/onnx2torch/node_converters/activations.py +++ b/onnx2torch/node_converters/activations.py @@ -58,7 +58,7 @@ def _forward(): if slope.nelement() == 1 or ( slope.shape[0] == input_tensor.shape[1] and all(s == 1 for s in slope.shape[1:]) ): - return nn.functional.prelu(input_tensor, weight=slope.view(-1)) + return nn.functional.prelu(input_tensor, weight=slope.view(-1)) # pylint: disable=not-callable output = input_tensor.clone() output = output * slope diff --git a/onnx2torch/node_converters/constant_of_shape.py b/onnx2torch/node_converters/constant_of_shape.py index d175336b..e81a6137 100644 --- a/onnx2torch/node_converters/constant_of_shape.py +++ b/onnx2torch/node_converters/constant_of_shape.py @@ -29,9 +29,11 @@ def __init__(self, value: Optional[torch.Tensor] = None): self.register_buffer('value', value) def forward(self, shape: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring + fill_value = self.value.item() + return torch.full( size=torch.Size(shape), - fill_value=self.value.item(), + fill_value=int(fill_value) if isinstance(fill_value, bool) else fill_value, dtype=self.value.dtype, device=self.value.device, ) diff --git a/onnx2torch/node_converters/pad.py b/onnx2torch/node_converters/pad.py index 66ea86f9..cb0dd1eb 100644 --- a/onnx2torch/node_converters/pad.py +++ b/onnx2torch/node_converters/pad.py @@ -87,7 +87,7 @@ def create_from_onnx_params( # pylint: disable=missing-function-docstring return cls(pads=torch_padding, mode=torch_mode, constant_value=constant_value) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring - return F.pad( + return F.pad( # pylint: disable=not-callable input_tensor, mode=self.mode, pad=self.pads, @@ -109,7 +109,7 @@ def forward( # pylint: disable=missing-function-docstring torch_pads = _onnx_padding_to_torch(pads.tolist()) torch_pads = _torch_padding_to_mode_format(torch_pads, self.mode) - return F.pad(input_tensor, mode=self.mode, pad=torch_pads, value=constant_value) + return F.pad(input_tensor, mode=self.mode, pad=torch_pads, value=constant_value) # pylint: disable=not-callable @add_converter(operation_type='Pad', version=11) diff --git a/tests/node_converters/constant_of_shape_test.py b/tests/node_converters/constant_of_shape_test.py index 1ea1e89f..8277bb27 100644 --- a/tests/node_converters/constant_of_shape_test.py +++ b/tests/node_converters/constant_of_shape_test.py @@ -40,3 +40,5 @@ def test_constant_of_shape() -> None: # pylint: disable=missing-function-docstr shape = np.random.randint(low=1, high=2, size=(size,)) value = np.random.uniform(low=-10000, high=10000, size=(1,)) _test_constant_of_shape(shape, value) + + _test_constant_of_shape(np.asarray([3, 3]), np.asarray([True]))