
Make LogitsProcessor compatible with torch.compile #29018

Closed

Conversation

zucchini-nlp
Member

@zucchini-nlp zucchini-nlp commented Feb 14, 2024

What does this PR do?

This PR addresses a small part of issue #28981. It makes sure that the logits processors and stopping criteria are compatible with torch.compile when fullgraph=True. The changes were tested with dummy inputs and logits, and also with Llama. For now, only the processors used in generate were checked; those used in the Bark/Whisper models can be checked later if needed.

The processors below are not compatible; exceptions will be added later:

  • EncoderNoRepeatNGramLogitsProcessor and NoRepeatNGramLogitsProcessor -> try to fetch a value from a dict, which is input-dependent
  • PrefixConstrainedLogitsProcessor -> relies on user-provided functions, which are most probably also input-dependent
  • SequenceBiasLogitsProcessor will not work together with NoBadWordsLogitsProcessor; only one of them can be defined -> both call the same _prepare_bias_variables, which gets recompiled the second time it is called with new arguments. This can be fixed by either merging them into one processor or separating them into two fully distinct ones.
  • UnbatchedClassifierFreeGuidanceLogitsProcessor -> calls the model forward; the current Llama with sdpa failed because a non-None attention_mask was provided
  • MaxTimeCriteria -> uses the built-in time.time()
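To illustrate the kind of check involved (a minimal sketch, not the actual test code from this PR; force_token is a made-up stand-in), a compile-friendly processor can be compiled with fullgraph=True and called on dummy logits. fullgraph=True makes dynamo fail loudly if anything in the function causes a graph break:

```python
import math
import torch

# Hypothetical stand-in for a compile-friendly logits processor, mirroring
# the additive-mask pattern discussed in this PR.
def force_token(scores: torch.Tensor, token_id: int) -> torch.Tensor:
    mask = torch.full_like(scores, -math.inf)
    mask[:, token_id] = 0.0  # only `token_id` keeps its original score
    return scores + mask

# backend="eager" skips inductor codegen while still exercising dynamo's
# graph capture, which is what fullgraph=True validates.
compiled = torch.compile(force_token, fullgraph=True, backend="eager")

scores = torch.zeros(2, 10)
out = compiled(scores, 3)
```

If force_token contained input-dependent Python control flow (e.g. a dict lookup keyed on tensor values), this call would raise instead of silently falling back to eager mode.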

FYI @gante

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@gante gante left a comment


Looking good! I've added a few comments below

Note: the fixes ... line in the PR header will close the tracker issue, which we don't want. In this case, avoid using one of the magic keywords 🤗

src/transformers/generation/logits_process.py (outdated review thread, resolved)
Comment on lines 1422 to 1424
mask = torch.full_like(scores, -math.inf)
mask[:, self.bos_token_id] = 0
scores = scores + mask
Member


Although technically correct (BOS will still be selected), this is not fully backward compatible: previously, scores[:, self.bos_token_id] was set to 0; now it is not.

Perhaps scores = mask? Or the whole thing could be written as something like scores = torch.where(torch.arange(scores.shape[1]) == self.bos_token_id, 0, -math.inf)

Member Author


Yeah, I tried that, but it raises an error saying that the built-in equality operator is not supported for tracing in this situation. So the additive mask is kind of a hack to work around that error. I will go with scores = mask then.
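For reference, the two formulations produce identical mask values in eager mode (a quick sketch outside of torch.compile; the variable names are illustrative, not the PR's code):

```python
import math
import torch

bos_token_id = 3
scores = torch.randn(2, 8)

# Approach 1: additive mask, as used in this PR (then `scores = mask`).
mask = torch.full_like(scores, -math.inf)
mask[:, bos_token_id] = 0.0

# Approach 2: the torch.where one-liner suggested above, broadcast to the
# batch dimension. Note this is the version that failed to trace in the PR.
where_mask = torch.where(
    torch.arange(scores.shape[1]) == bos_token_id, 0.0, -math.inf
).expand_as(scores)

assert torch.equal(mask, where_mask)
```

The equivalence only holds for values; the tracing behaviour under fullgraph=True differs, which is why the mask version was kept.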

Comment on lines 1470 to 1472
mask = torch.full_like(scores, -math.inf)
mask[:, self.eos_token_id] = 0
scores = scores + mask
Member


(same as above)

src/transformers/generation/logits_process.py (outdated review thread, resolved)
@@ -1572,14 +1579,15 @@ def __init__(
self.eos_token_id = eos_token_id

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor:
Member


Looking at this one, there is a chance we need to make cur_len a tensor rather than an int. If torch's compilation behaves anything like TF's, integer arguments will trigger a new compilation each time the function is called with a different value 🤔

Try to confirm that a recompilation does not happen as you call this with different cur_len values. If a recompilation is happening, try changing cur_len to a tensor. Avoiding recompilations is crucial for performance.

Member Author


Yes, you are right: cur_len as an int causes a recompilation every time the processor is called. But when we pass it in as a tensor, it raises an error when trying to read the tensor's value to compare it with the current length.

There is actually a way to use functorch.cond, which would let us avoid graph breaks and recompilation, but it does not work with the inductor backend right now. The PyTorch devs said it should be ready to use and efficient around the end of Q1.
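One way to observe the recompilation behaviour described above (a hypothetical sketch, not part of this PR) is to count how many times torch.compile invokes its backend when the int argument changes:

```python
import torch
import torch._dynamo

torch._dynamo.reset()
compile_count = 0

def counting_backend(gm, example_inputs):
    # Called once per (re)compilation; fall back to eager execution of the graph.
    global compile_count
    compile_count += 1
    return gm.forward

@torch.compile(backend=counting_backend)
def processor_step(scores: torch.Tensor, cur_len: int) -> torch.Tensor:
    # The Python-level comparison on `cur_len` makes dynamo guard on its value.
    if cur_len < 5:
        return scores - 1.0
    return scores

scores = torch.zeros(1, 4)
processor_step(scores, 2)
processor_step(scores, 3)  # new int value -> failed guard -> recompilation
```

After the two calls, compile_count is greater than 1, confirming that a plain int argument used in control flow does not stay cached across different values.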

@zucchini-nlp
Member Author

@gante I went through the comments and fixed what I could. I am wondering whether adding warnings the way I did is a good idea? Maybe there is a better way to do it, so that users do not see lots of unrelated warnings. I guess not everyone will use compile to generate

@zucchini-nlp zucchini-nlp changed the title Make LogitsProcessor compatible with torch.compile (WIP) Make LogitsProcessor compatible with torch.compile Feb 20, 2024
@zucchini-nlp
Member Author

@gante Ready for review. I fixed the tests and the generation utils to work with "cur_len", and everything runs successfully on my machine.

Member

@gante gante left a comment


While re-reviewing the PR, I noticed that we were not keeping backward compatibility (BC) in mind 😬 The logits processors are public classes, so we need to warn users about the new cur_len argument. We should:

  1. set cur_len=None in ALL signatures.
  2. when cur_len is None, set it to input_ids.shape[1] (steps 1 and 2 make these changes BC)
  3. add a deprecation cycle for the case where cur_len is None in ALL logits processors. A deprecation cycle looks like this (in this case, with a warning_once, see below). Since this is a substantial change, let's give users 4 minor versions to adapt (i.e. target removal in v4.43).

Other warning-related comments:

  1. you should be able to emit the warning using this function, as in this example.
  2. replace logger.warning with logger.warning_once. Otherwise there is a warning at each generation step :)
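The BC pattern from the three steps above could look roughly like this (a hedged sketch; ExampleLogitsProcessor is a made-up class, and the stdlib warnings module stands in for transformers' logger.warning_once):

```python
import warnings
import torch

class ExampleLogitsProcessor:
    """Hypothetical processor illustrating the cur_len deprecation cycle."""

    def __call__(self, input_ids, scores, cur_len=None):
        if cur_len is None:
            # Old call style: keep working, but warn about the upcoming removal.
            warnings.warn(
                "Calling logits processors without `cur_len` is deprecated and "
                "will be removed in v4.43. Pass `cur_len=input_ids.shape[1]`.",
                FutureWarning,
            )
            cur_len = input_ids.shape[1]  # previous behaviour, kept for BC
        return scores

proc = ExampleLogitsProcessor()
input_ids = torch.tensor([[1, 2, 3]])
scores = torch.zeros(1, 5)
out = proc(input_ids, scores)             # old style: warns, falls back
out = proc(input_ids, scores, cur_len=3)  # new style: no warning
```

In the real code the warning would go through logger.warning_once so it fires only once per process rather than at every generation step.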

src/transformers/generation/logits_process.py (3 outdated review threads, resolved)
@@ -863,6 +877,10 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
def __init__(self, ngram_size: int):
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
logger.warning(
Member

@gante gante Feb 21, 2024


These cases that prevent compilation should raise exceptions instead. More precisely, if possible, the original exception should be caught, the message you wrote here appended to it, and the same exception re-raised.

That way the user gets a full stack trace and, if torch ever fixes the issue on their end, the problem becomes automatically solved on our end as well!

(not sure if this is feasible, lmk if it is not)
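The catch-and-reraise pattern suggested here is feasible in plain Python; a sketch with hypothetical names (the real integration point would be wherever generate invokes the processor):

```python
def call_with_compile_hint(processor, *args):
    """Re-raise any tracing error with a hint appended (illustrative sketch)."""
    try:
        return processor(*args)
    except Exception as err:
        hint = (
            f"{type(processor).__name__} is not compatible with "
            "`torch.compile(fullgraph=True)`; see #29018 for details."
        )
        # Re-raise the same exception type with the hint appended, keeping the
        # original traceback chained via `from err`. Caveat: this assumes the
        # exception type accepts a single message argument.
        raise type(err)(f"{err} -- {hint}") from err

class FakeProcessor:
    """Stand-in for a processor whose tracing fails (hypothetical)."""
    def __call__(self, scores):
        raise RuntimeError("data-dependent dict lookup cannot be traced")

try:
    call_with_compile_hint(FakeProcessor(), None)
except RuntimeError as e:
    message = str(e)
```

If torch later fixes the underlying limitation, the try branch simply succeeds and the hint code becomes dead, as the reviewer notes.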

Member Author

@zucchini-nlp zucchini-nlp Feb 26, 2024


I found out that logging/print/warning calls cause a graph break in the current version of PyTorch. There is an open PR to fix it.

I am committing a temporary change with all the logging and warnings for now. Everything has to be re-tested once the PyTorch PR gets merged; the current changes will not compile with fullgraph.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Mar 31, 2024