
[⚠️ removed a default argument] Make AttentionMaskConverter compatible with torch.compile(..., fullgraph=True) #27868

Merged: 5 commits merged into huggingface:main on Dec 8, 2023

Conversation

fxmarty (Contributor) commented on Dec 6, 2023:

As per title, this fixes #27789.

The issue only affects PyTorch 2.1 and has already been fixed in torch nightly.
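For context, a minimal sketch of the kind of call this PR is meant to unblock; the checkpoint and shapes below are illustrative and not taken from the PR or the linked issue:

```python
# Minimal sketch of the scenario this PR targets (checkpoint and shapes are
# illustrative). On PyTorch 2.1, compiling with fullgraph=True could error
# inside the attention-mask utilities because of the default dtype argument.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative tiny checkpoint
)
model.eval()

compiled_model = torch.compile(model, fullgraph=True)

input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
with torch.no_grad():
    output = compiled_model(input_ids=input_ids)
```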

@@ -66,7 +66,7 @@ def to_causal_4d(
         batch_size: int,
         query_length: int,
         key_value_length: int,
-        dtype: torch.dtype = torch.float32,
+        dtype: torch.dtype,
fxmarty (Contributor, Author) commented on the diff:

It is fine to remove the default (which is what causes the error): in the privately exposed helpers _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask and _create_4d_causal_attention_mask, dtype is always passed explicitly.
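As an illustration of that point, here is a rough sketch (paraphrased, not the exact library code) of how a caller supplies dtype explicitly, which is why the default can be dropped without changing behavior:

```python
# Rough sketch, not the exact transformers code: the private helpers derive
# dtype/device from the input embeddings and always pass them through,
# so to_causal_4d never has to rely on a default dtype.
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

inputs_embeds = torch.randn(1, 8, 32, dtype=torch.float16)  # illustrative tensor
converter = AttentionMaskConverter(is_causal=True)

causal_4d_mask = converter.to_causal_4d(
    batch_size=inputs_embeds.shape[0],
    query_length=inputs_embeds.shape[1],
    key_value_length=inputs_embeds.shape[1],
    dtype=inputs_embeds.dtype,   # explicitly supplied by the internal callers
    device=inputs_embeds.device,
)
```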

Collaborator replied:

But to_causal_4d is part of the public API, no? Oh, this is breaking 😅

fxmarty (Contributor, Author) replied:

AFAIK it is used nowhere other than in _create_4d_causal_attention_mask, _prepare_4d_attention_mask and _prepare_4d_causal_attention_mask. cc @patrickvonplaten: I don't think the mask converter API was meant to be exposed?

ArthurZucker (Collaborator) left a comment:

I think it should be alright even if breaking, as it's mostly used for internal calls and hasn't been around for very long. Could you link the PyTorch issue?


fxmarty (Contributor, Author) commented on Dec 7, 2023:

Also, AttentionMaskConverter is not in the documentation, so it is not really user-facing.

fxmarty requested a review from ArthurZucker on Dec 7, 2023 at 11:26.
ArthurZucker (Collaborator) left a comment:

Alright with me, let's just add a ⚠️ to the title.

fxmarty changed the title from "Make AttentionMaskConverter compatible with torch.compile(..., fullgraph=True)" to "[⚠️ removed a default argument] Make AttentionMaskConverter compatible with torch.compile(..., fullgraph=True)" on Dec 8, 2023.
fxmarty merged commit 307a7d0 into huggingface:main on Dec 8, 2023. 20 checks passed.
Successfully merging this pull request may close these issues.

_prepare_4d_causal_attention_mask doesn't work with torch.compile (#27789)