
Require a torch.Tensor before casting #31755

Merged: 1 commit merged into huggingface:main on Jul 3, 2024

Conversation

@echarlaix (Collaborator) commented on Jul 2, 2024:

Fixes the ONNX export for swin, swin-donut and clap models, currently failing on

self.shift_size = torch_int(0)

introduced in #31311, since torch_int expects a torch.Tensor:

return x.to(torch.int64) if torch.jit.is_tracing() else int(x)

Also, I think we should be able to simply have here:

self.shift_size = 0

cc @merveenoyan @xenova
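
(For context, a minimal sketch of the failure mode; the module below is an illustrative stand-in, not the actual swin code:)

```python
import torch

# Pre-fix utility, matching the body quoted above (minus the lazy import):
def torch_int(x):
    # While tracing this assumes x is a torch.Tensor; a plain Python int
    # has no .to method, so torch_int(0) raises
    # AttributeError: 'int' object has no attribute 'to'.
    return x.to(torch.int64) if torch.jit.is_tracing() else int(x)

class Block(torch.nn.Module):  # illustrative stand-in for the swin layer
    def forward(self, x):
        shift_size = torch_int(0)  # fine in eager mode, fails while tracing
        return x + shift_size

torch.jit.trace(Block(), torch.zeros(1))  # AttributeError with the old utility
```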

@echarlaix requested a review from @amyeroberts on July 2, 2024 at 17:16.

@amyeroberts (Collaborator) left a comment:

Thanks for fixing!

@@ -762,7 +762,7 @@ def torch_int(x):
     import torch

-    return x.to(torch.int64) if torch.jit.is_tracing() else int(x)
+    return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
A Collaborator commented:

Here you can use the torch_int utility instead

@echarlaix (Collaborator, Author) replied:

Do you mean ensuring shift_size is a torch.Tensor when jit tracing, directly in the modeling code?

self.shift_size = torch_int(0)
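
(For illustration, a sketch of how the guarded utility from this PR behaves; the module is again a stand-in:)

```python
import torch

# Patched utility: only tensors are cast while tracing; plain Python
# numbers fall through to int() in every mode.
def torch_int(x):
    return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)

print(torch_int(0))                # 0 (eager mode, plain int)
print(torch_int(torch.tensor(3)))  # 3 (eager mode, tensor -> int())

class Block(torch.nn.Module):  # illustrative stand-in
    def forward(self, x):
        return x + torch_int(0)    # now also safe while tracing

traced = torch.jit.trace(Block(), torch.zeros(1))  # no longer raises
```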

@@ -774,7 +774,7 @@ def torch_float(x):
     import torch

-    return x.to(torch.float32) if torch.jit.is_tracing() else int(x)
+    return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
A Collaborator commented:

And here torch_float

A Contributor replied:

I think since we're explicitly calling torch_float where float used to be, it should be fine, no?
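
(A small illustration of why the extra guard is harmless for torch_float: tensor inputs always pass the isinstance check, and non-tensor inputs take the eager fallback instead of crashing. The sketch mirrors the diff above, including its int() fallback:)

```python
import torch

# Patched torch_float, copied from the diff above (the eager fallback
# is int(x), exactly as in the repository code).
def torch_float(x):
    return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)

# Eager mode: both calls hit the int() fallback, same as before the patch.
print(torch_float(torch.tensor(2)))  # 2
print(torch_float(5))                # 5
# While tracing, a tensor is cast to float32 as before; a plain number
# now falls through to int() instead of raising AttributeError.
```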

@echarlaix (Collaborator, Author) commented on Jul 3, 2024:

@amyeroberts do you think this fix could be included in a patch release?

@merveenoyan (Contributor) left a comment:

thanks!

@LysandreJik merged commit dc72fd7 into huggingface:main on Jul 3, 2024 (18 of 21 checks passed).
@LysandreJik (Member) commented:

@echarlaix I'm happy to include it in a patch release towards the end of the week

@echarlaix deleted the fix-onnx branch on July 3, 2024 at 09:18.
@echarlaix (Collaborator, Author) replied:

> @echarlaix I'm happy to include it in a patch release towards the end of the week

thanks a lot!
