
bugfix: ReduceSumStaticAxes type correction #171

Closed
wants to merge 2 commits into from

Conversation


@francis2tm francis2tm commented Aug 6, 2023

Hello,

Bug details:

File "onnx2torch/onnx2torch/node_converters/reduce.py", line 159, in forward
    return torch.sum(input_tensor, dim=self._axes, keepdim=self._keepdims)
TypeError: sum() received an invalid combination of arguments - got (Tensor, keepdim=int, dim=list), but expected one of:
 * (Tensor input, *, torch.dtype dtype)
      didn't match because some of the keywords were incorrect: keepdim, dim
 * (Tensor input, tuple of ints dim, bool keepdim, *, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of names dim, bool keepdim, *, torch.dtype dtype, Tensor out)

Args/Attrs to reproduce
input_tensor.shape: torch.Size([1, 8, 8, 65])
self._axes: [1, 2] <class 'list'>
self._keepdims: 0 <class 'int'>

Solution

My fix is a simple type cast; I have only tested it against my particular example.
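The cast can be sketched in isolation like this. `normalize_reduce_args` is a hypothetical helper (not part of onnx2torch); it only shows the type normalization the patch performs, assuming ONNX stores `axes` as a list of ints and `keepdims` as an int flag while `torch.sum` expects a tuple of ints and a bool:

```python
def normalize_reduce_args(axes, keepdims):
    """Cast ONNX ReduceSum attributes to the types torch.sum expects.

    axes:     list of ints (or None for "reduce over all dims") -> tuple of ints (or None)
    keepdims: int flag 0/1                                      -> bool
    """
    dim = tuple(axes) if axes is not None else None
    keepdim = bool(keepdims)
    return dim, keepdim

# With the values from the traceback above:
dim, keepdim = normalize_reduce_args([1, 2], 0)
# dim == (1, 2), keepdim is False
```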


@senysenyseny16 senysenyseny16 left a comment


@@ -18,6 +18,7 @@ dependencies = [
     'onnx>=1.9.0',
     'torch>=1.8.0',
     'torchvision>=0.9.0',
+    'onnxruntime>=1.15.1',


Suggested change (remove the added dependency):
-    'onnxruntime>=1.15.1',

@@ -155,7 +155,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:  # pylint: disabl

             self._axes = list(range(input_tensor.dim()))

-        return torch.sum(input_tensor, dim=self._axes, keepdim=self._keepdims)
+        return torch.sum(input_tensor, dim=tuple(self._axes), keepdim=bool(self._keepdims))

Could you please explain why self._axes should be wrapped into tuple?

@senysenyseny16 senysenyseny16 self-assigned this Oct 13, 2023
@senysenyseny16 senysenyseny16 added the fix bug fix for the user label Oct 13, 2023
@sniklaus

I had the very same error message, and in my case the bool(...) cast was sufficient; I didn't need the tuple(...).

@Sbisseb

Sbisseb commented Mar 25, 2024

I am also using return torch.sum(input_tensor, dim=self._axes, keepdim=bool(self._keepdims)), but now I am getting issues in onnx2torch/node_converters/gather.py during the slicing, and I am not sure why:

return input_tensor[slice_for_take]
IndexError: index 4 is out of bounds for dimension 0 with size 3

Would you have any idea how to fix that?
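The failure mode above can be illustrated in miniature without torch. `safe_take` is a hypothetical stand-in for the `input_tensor[slice_for_take]` indexing, assuming only that gathering index 4 from an axis of size 3 must raise the same IndexError:

```python
def safe_take(seq, idx):
    """Gather elements of seq at positions idx, mimicking torch's bounds check."""
    for i in idx:
        # torch raises IndexError when an index falls outside [-size, size)
        if not -len(seq) <= i < len(seq):
            raise IndexError(
                f'index {i} is out of bounds for dimension 0 with size {len(seq)}'
            )
    return [seq[i] for i in idx]

data = [10, 20, 30]       # dimension 0 has size 3
ok = safe_take(data, [0, 2])   # in bounds
# safe_take(data, [4]) would raise: index 4 is out of bounds for dimension 0 with size 3
```

In a converted model this usually means the Gather indices were baked in for a different input shape than the one used at runtime.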

@senysenyseny16

#206

@senysenyseny16

> I am also using return torch.sum(input_tensor, dim=self._axes, keepdim=bool(self._keepdims)) but I am now getting issues in onnx2torch/node_converters/gather.py in the slicing, I am not sure why
>
> return input_tensor[slice_for_take]
> IndexError: index 4 is out of bounds for dimension 0 with size 3
>
> Would you have any idea how to fix that?

Hi, could you please open an issue with instructions for reproducing?
