Make LogitsProcessor compatible with torch.compile #29018
Conversation
Looking good! I've added a few comments below.
Note: the `fixes ...` in the PR header will close the tracker, which we don't want. In this case, avoid using one of the magic keywords 🤗
mask = torch.full_like(scores, -math.inf)
mask[:, self.bos_token_id] = 0
scores = scores + mask
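For context, a minimal runnable sketch of this additive-mask pattern (the values of `bos_token_id` and the tensor shapes are illustrative, not taken from the PR):

```python
import math

import torch

# Illustrative values; in the real processor these come from the tokenizer/config.
bos_token_id = 1
scores = torch.randn(2, 5)  # (batch_size, vocab_size)

# Additive mask: -inf everywhere except the forced token.
mask = torch.full_like(scores, -math.inf)
mask[:, bos_token_id] = 0
forced = scores + mask

# Every row's argmax is now the forced token, but note that
# forced[:, bos_token_id] keeps its original value instead of becoming 0.
print(forced.argmax(dim=-1))
```

This is exactly the backward-compatibility concern raised below: the forced token is selected, but its score is no longer set to 0.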
Although technically correct (BOS will be selected), it is not fully backwards compatible: previously, `scores[:, self.bos_token_id]` was set to 0; now it will not be set.
Perhaps `scores = mask`? Or the whole thing could also be something like `scores = torch.where(torch.arange(scores.shape[1]) == self.bos_token_id, 0, -math.inf)`
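A sketch of the `torch.where` alternative suggested here, written with `zeros_like`/`full_like` so the result keeps the `(batch, vocab)` shape explicitly (shapes and `bos_token_id` are illustrative):

```python
import math

import torch

bos_token_id = 1  # illustrative
scores = torch.randn(2, 5)  # (batch_size, vocab_size)

# (vocab,) boolean mask, broadcast against the (batch, vocab) score tensors:
# 0 at the forced token, -inf everywhere else.
new_scores = torch.where(
    torch.arange(scores.shape[1]) == bos_token_id,
    torch.zeros_like(scores),
    torch.full_like(scores, -math.inf),
)
```

Unlike the additive mask, this reproduces the old behaviour of `scores[:, bos_token_id] = 0` exactly, since the forced token's score is 0 rather than its original value.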
Yeah, I tried, but it raises an error saying that the built-in equality operator is not supported for tracing in this situation. So it's kind of a hack to work around the error. I will go with `scores = mask` then.
mask = torch.full_like(scores, -math.inf)
mask[:, self.eos_token_id] = 0
scores = scores + mask
(same as above)
@@ -1572,14 +1579,15 @@ def __init__(
     self.eos_token_id = eos_token_id

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        cur_len = input_ids.shape[-1]
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor:
Looking at this one, there is a chance we need to make `cur_len` a tensor, and not an `int`. If torch's compilation behaves anything like TF's, then integers will trigger a new compilation each time it is called with a different value 🤔
Try to confirm that a recompilation is not happening as you call this one with different `cur_len`. If a recompilation is happening, try changing `cur_len` to a tensor. Avoiding recompilations is crucial for performance.
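One way to observe this is a toy experiment with a custom `torch.compile` backend that just counts how many graphs get compiled (a sketch; the exact recompilation count depends on dynamo's automatic dynamic-shape handling, so only the presence of compilation is asserted):

```python
import torch

compile_count = 0


def counting_backend(gm, example_inputs):
    # Custom torch.compile backend: count each graph compilation,
    # then run the captured graph eagerly.
    global compile_count
    compile_count += 1
    return gm.forward


@torch.compile(backend=counting_backend)
def scale(scores, cur_len):
    return scores * cur_len


scores = torch.ones(2, 4)
for cur_len in (1, 2, 3):
    scale(scores, cur_len)

# A Python int argument can specialize the graph, so new values may
# trigger recompilations; a 0-d tensor argument generally does not.
print(compile_count)
```

Equivalently, running with `TORCH_LOGS=recompiles` prints a reason for each recompilation.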
Yes, you are right: `cur_len` as an int causes a recompilation every time the processor is called. But when we pass it in as a tensor, it raises an error when trying to get the tensor's value and compare it against the current length.
There is actually a way to use `functorch.cond`, and then we can avoid graph breaks and recompilation, but it does not work with the inductor backend right now. PyTorch devs said that it will be ready to use and efficient around the end of Q1.
@gante I went through the comments and fixed where possible. I am wondering if it is a good idea to add warnings as I did? Maybe there is a better way to do it, so that users do not see lots of unrelated warnings. I guess not everyone will use …
@gante Ready to review. I fixed the tests and the generation utils to work with `cur_len`; everything runs successfully on my machine.
While re-reviewing the PR, I noticed that we were not taking backward compatibility (BC) into account 😬 The logits processors are public classes, and thus we need to warn users about the new `cur_len` argument. We should:
1. set `cur_len=None` in ALL signatures.
2. when `cur_len=None`, set it to `input_ids.shape[1]` (1. + 2. make these changes BC)
3. add a deprecation cycle for the case where `cur_len is None` in ALL logits processors. A deprecation cycle looks like this (in this case, with a `warning_once`, see below). Since this is a substantial change, let's give users 4 minor versions to adapt (i.e. target removal in v4.43).
Other warning-related comments:
- you should be able to emit the warning using this function, as in this example.
- replace `logger.warning` with `logger.warning_once`. Otherwise there is a warning at each generation step :)
@@ -863,6 +877,10 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
     def __init__(self, ngram_size: int):
         if not isinstance(ngram_size, int) or ngram_size <= 0:
             raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
+        logger.warning(
These cases that prevent compilation should raise exceptions instead. More precisely, if possible, the original exception should be caught, the message you wrote here appended at the end, and then the same exception re-raised.
That way, the user has a full stack trace and, if for some reason torch fixes the issue on their end, the problem becomes automatically solved on our end as well!
(not sure if this is feasible, lmk if it is not)
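The catch-append-reraise pattern described here can be sketched as a small helper (hypothetical function and message, purely illustrative):

```python
def call_with_hint(fn, hint):
    """Run fn(); on failure, append a hint to the original exception's
    message and re-raise it, keeping the type and full stack trace."""
    try:
        return fn()
    except Exception as e:
        # Mutate args so str(e) includes our context, then re-raise the
        # very same exception object (bare `raise` preserves the trace).
        e.args = (f"{e.args[0] if e.args else ''}\n{hint}".lstrip(),) + e.args[1:]
        raise
```

If torch later starts supporting the traced construct, `fn()` simply succeeds and the hint never appears.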
I found out that logging/print/warning calls cause a graph break in the current version of PyTorch. There is an open PR to fix it.
I am committing a temporary change with all the logging and warnings right now. Everything has to be re-tested once the PyTorch PR gets merged; the current changes will not compile with `fullgraph=True`.
Co-authored-by: Joao Gante <[email protected]>
What does this PR do?
Small part of issue #28981. This PR makes sure that the Logits Processors and Stopping Criteria are compatible with `torch.compile` when `fullgraph=True`. The changes were tested with dummy inputs and logits, and also with Llama. For now, only the processors used in `generate` were checked; those used in the bark/whisper models can be checked later if needed.
The processors below are not compatible; exceptions will be added later:
`_prepare_bias_variables`, which leads to recompiling it the second time we call it with new arguments. Can be fixed if we either merge them into one processor or separate them as two distinct ones.
FYI @gante