[⚠️ removed a default argument] Make AttentionMaskConverter compatible with torch.compile(..., fullgraph=True) #27868
Conversation
@@ -66,7 +66,7 @@ def to_causal_4d(
         batch_size: int,
         query_length: int,
         key_value_length: int,
-        dtype: torch.dtype = torch.float32,
+        dtype: torch.dtype,
It is fine to remove the default (it is the cause of the error): in the privately exposed helpers _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask and _create_4d_causal_attention_mask, dtype is always passed explicitly.
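For illustration, a minimal sketch of the caller-side consequence: with the default removed, dtype now always has to be passed explicitly, typically the dtype of the hidden states. The constructor arguments and the device parameter below are assumptions not shown in the diff above.

```python
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

# Sketch only: is_causal and device are assumptions, not part of the diff above.
converter = AttentionMaskConverter(is_causal=True)
hidden_states = torch.randn(2, 8, 16, dtype=torch.float16)

causal_4d_mask = converter.to_causal_4d(
    batch_size=hidden_states.shape[0],
    query_length=hidden_states.shape[1],
    key_value_length=hidden_states.shape[1],
    dtype=hidden_states.dtype,  # no longer defaults to torch.float32
    device=hidden_states.device,
)
print(causal_4d_mask.shape)  # expected: (2, 1, 8, 8)
```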
But to_causal_4d is part of the public API, no? Oh, this is breaking 😅
AFAIK it is used nowhere else than in _create_4d_causal_attention_mask, _prepare_4d_attention_mask and _prepare_4d_causal_attention_mask. cc @patrickvonplaten, I don't think the mask converter API was meant to be exposed?
I think it should be alright even if it is breaking, as it's mostly used for internal calls and hasn't been around for very long. Could you link the PyTorch issue?
But to_causal_4d is part of the public API, no? Oh, this is breaking 😅

Also,
Alright with me, let's just add a
As per the title, this fixes #27789. The issue only affects PyTorch 2.1 and has already been fixed in torch nightly.
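For context, a rough sketch of the pattern this change enables (this is not the reproduction from #27789, and the make_causal_mask/attend helpers below are hypothetical): with dtype required as an explicit argument rather than relying on a torch.dtype default, mask creation can sit inside a function traced with torch.compile(..., fullgraph=True) on PyTorch 2.1.

```python
import torch

# Hypothetical stand-in for mask creation; dtype is a required argument,
# mirroring this PR's removal of the torch.dtype default.
def make_causal_mask(query_length: int, key_value_length: int, dtype: torch.dtype):
    mask = torch.full((query_length, key_value_length), torch.finfo(dtype).min, dtype=dtype)
    i = torch.arange(query_length).unsqueeze(-1)
    j = torch.arange(key_value_length)
    # zero out positions a query token may attend to (past and present)
    return mask.masked_fill(j <= i + (key_value_length - query_length), 0.0)

@torch.compile(fullgraph=True)
def attend(q, k, v):
    mask = make_causal_mask(q.shape[-2], k.shape[-2], q.dtype)  # dtype passed explicitly
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5 + mask
    return scores.softmax(dim=-1) @ v

q = k = v = torch.randn(1, 4, 8, 16)
print(attend(q, k, v).shape)  # torch.Size([1, 4, 8, 16])
```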