Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache: dynamic cache with cross attention and UMT5 Cache support #28185

Closed
wants to merge 3 commits into from

Conversation

gante
Copy link
Member

@gante gante commented Dec 21, 2023

What does this PR do?

#28065 was becoming messy due to all Bart "copied from" dependencies, so this PR is a tiny version of it.

This PR:

  1. Introduces DynamicCacheWithCrossAttention, which expands DynamicCache [cache object equivalent to the previous past_key_values input/output] with the ability to hold a cross-attention cache. This design was intentional: most LLMs (and now even multimodel models) tend to be decoder-only, so this separation will keep the cache class for decoder-only models simpler. It also enables us to be more strict -- in Cache: Bart and related architectures support Cache objects #28065 I've caught an unintended cache deletion in Whisper thanks to the increased specificity!
  2. Adds Cache support to modeling_umt5.py, which is a form to test whether DynamicCacheWithCrossAttention is equivalent to the previous cache. These changes are the equivalent of the modeling changes in Generate: New Cache abstraction and Attention Sinks support #26681, but for encoder-decoder models.

Local tests run:

  1. RUN_SLOW=1 py.test tests/models/umt5/test_modeling_umt5.py -vv [Note: adds a test to ensure we keep the same results as in main]

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante gante requested a review from ArthurZucker December 21, 2023 17:56
@gante gante marked this pull request as ready for review December 21, 2023 17:56
@@ -240,41 +248,71 @@ def compute_bias(self, query_length, key_length, device=None):
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values

def _prepare_key_values(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This abstraction does not look particularly useful here. However, for models with multiple attention implementations, this abstraction is useful: all attention implementations can share it!

(e.g. in Bart the benefits are clear)

@@ -481,6 +501,7 @@ class UMT5PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["UMT5Block"]
_keep_in_fp32_modules = ["wo"]
_supports_cache_class = True
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This enables the test_new_cache_format test -> converting back and forth between the new cache and the legacy cache with cross attention is tested

@@ -560,6 +560,27 @@ def test_training_gradient_checkpointing_use_reentrant_false(self):
@require_sentencepiece
@require_tokenizers
class Umt5IntegrationTest(unittest.TestCase):
def test_generation(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensures there is no regression from main

I've double-checked that we get the same values in main. I've also checked the results with and without cache, in both main and this PR.

Comment on lines +1367 to +1373
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and decoder_attention_mask is not None
and cache_length + decoder_input_ids.shape[1] > max_cache_length
):
decoder_attention_mask = decoder_attention_mask[:, -max_cache_length:]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic copied from llama + sink cache -> this makes the model ready for caches like sink cache

@huggingface huggingface deleted a comment from github-actions bot Jan 23, 2024
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#27931 will shamble things up 👿

Copy link

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Mar 3, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants