Cache: Bart and related architectures support Cache objects #28065

Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@amyeroberts this PR is not finalized, but I'd love to get an early review -- the failing tests are fixed by propagating the changes to models with the [...]. The key parts to review now are labeled as [...].
Impressive piece of work 🔥
I've just paid attention to the addition in `cache_utils` and the changes in BART. Just some nits and questions on my side for understanding, but overall the structure looks great! Would be good to get a second set of eyes from someone with more cache experience on this too.
```
`past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or
`config.use_cache=True`.

Two formats are allowed:
```
Is passing inputs in the legacy format discouraged? If both are allowed, then we should update the type hint to cover both; if the legacy format is deprecated, I'd reword this, as we don't want to encourage passing in the old format.
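For context, a minimal sketch (my own illustration, not code from this PR) of the two `past_key_values` formats for a plain decoder-only model, using the existing `DynamicCache` helpers; Bart's legacy tuples additionally carry cross-attention key/values per layer, which is what the new `DynamicCacheWithCrossAttention` is meant to absorb:

```python
# Illustration only: the two `past_key_values` formats for a decoder-only model.
# (Bart's legacy entries also contain cross-attention key/values per layer.)
import torch
from transformers.cache_utils import DynamicCache

batch, heads, seq_len, head_dim, num_layers = 1, 4, 3, 8, 2

# Legacy format: a tuple with one (key, value) pair of tensors per layer.
legacy = tuple(
    (
        torch.randn(batch, heads, seq_len, head_dim),
        torch.randn(batch, heads, seq_len, head_dim),
    )
    for _ in range(num_layers)
)

# Cache-object format: the same tensors wrapped behind a uniform interface.
cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())        # 3
print(len(cache.to_legacy_cache()))  # 2 -> back to one entry per layer
```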
```python
is_cross_attention = key_value_states is not None

if is_cross_attention:
```
nit - this variable is only used once, and on the immediately following line - the comment provides enough context
```diff
- is_cross_attention = key_value_states is not None
- if is_cross_attention:
+ if key_value_states is not None:
```
```python
# Keep only the unprocessed tokens:
# 1 - If the length of the decoder_attention_mask exceeds the length of decoder_input_ids, then we are in a
# setting where some of the inputs are exclusivelly passed as part of the cache (e.g. when passing
```
ultranit
```diff
- # setting where some of the inputs are exclusivelly passed as part of the cache (e.g. when passing
+ # setting where some of the inputs are exclusively passed as part of the cache (e.g. when passing
```
```python
# Keep only the unprocessed tokens:
# 1 - If the length of the decoder_attention_mask exceeds the length of decoder_input_ids, then we are in a
# setting where some of the inputs are exclusivelly passed as part of the cache (e.g. when passing
# input_embeds as input)
```
Just for my own understanding, am I right in thinking that the reason they're exclusively part of the cache when input_embeds is passed is that any input_ids must have been generated?
```python
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# decoder_input_ids based on the past_length.
elif past_length < decoder_input_ids.shape[1]:
    decoder_input_ids = decoder_input_ids[:, past_length:]
```
And in this case - we're removing tokens that have already been seen, i.e. have been processed and are part of the cache?
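To make the two questions above concrete, here is a toy walk-through (illustrative only, with made-up shapes; not code from the PR) of how the cropping behaves:

```python
# Toy numbers: 4 tokens already live in the cache, the mask covers 8 positions,
# and decoder_input_ids carries 6 token ids.
import torch

past_length = 4
decoder_input_ids = torch.tensor([[10, 11, 12, 13, 14, 15]])  # 6 tokens
decoder_attention_mask = torch.ones(1, 8)                     # 8 known positions

if decoder_attention_mask.shape[1] > decoder_input_ids.shape[1]:
    # Case 1: part of the prompt was passed as embeddings, so it only exists in the
    # cache; keep the last (mask_len - past_length) ids, i.e. the unprocessed tail.
    decoder_input_ids = decoder_input_ids[:, -(decoder_attention_mask.shape[1] - past_length):]
elif past_length < decoder_input_ids.shape[1]:
    # Case 2: decoder_input_ids holds the whole sequence; drop the cached prefix.
    decoder_input_ids = decoder_input_ids[:, past_length:]
# Case 3: otherwise, assume decoder_input_ids already contains only new tokens.

print(decoder_input_ids)  # tensor([[12, 13, 14, 15]])
```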
```python
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
    input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
    input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
    max_cache_length is not None
    and attention_mask is not None
    and cache_length + input_ids.shape[1] > max_cache_length
):
    attention_mask = attention_mask[:, -max_cache_length:]
```
By eye, this looks equivalent to the logic above, just with input_ids instead of decoder_input_ids -> can we abstract out the common logic here?
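One possible shape for such a shared helper -- purely a sketch, with a made-up name and signature rather than anything from this PR:

```python
def _crop_uncached_tokens(input_ids, attention_mask, past_length, cache_length, max_cache_length):
    """Drop token ids that are already represented in the cache (hypothetical helper)."""
    if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
        # Some inputs only exist in the cache (e.g. they were passed as embeddings).
        input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
    elif past_length < input_ids.shape[1]:
        # input_ids holds the full sequence; keep only the uncached suffix.
        input_ids = input_ids[:, past_length:]
    # Otherwise, assume input_ids already contains only unprocessed tokens.

    # Crop the mask if appending the new tokens would exceed the maximum cache length.
    if (
        max_cache_length is not None
        and attention_mask is not None
        and cache_length + input_ids.shape[1] > max_cache_length
    ):
        attention_mask = attention_mask[:, -max_cache_length:]
    return input_ids, attention_mask
```

The decoder-only and encoder-decoder paths would then call it with `input_ids` and `decoder_input_ids` respectively.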
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Mr bot, this is not stale (on hold while the static cache is being worked on, as they will likely have overlapping changes and the static cache is more important)
Closing this PR; at this point it's easier to start from scratch.
What does this PR do?
This PR applies the changes to `Bart` so it supports the new `Cache` objects. In other words, it is akin to #26681 but for encoder-decoder models. The main changes are in two files:

- `cache_utils.py`: I've introduced `DynamicCacheWithCrossAttention`, which expands `DynamicCache` [the cache object equivalent to the previous `past_key_values` input/output] with the ability to hold a cross-attention cache. This design was intentional: most LLMs (and now even multimodal models) tend to be decoder-only, so this separation keeps the cache class for decoder-only models simpler. It also enables us to be more strict -- I've caught an unintended cache deletion in Whisper thanks to the increased specificity! (A rough sketch of the idea follows this list.)
- `modeling_bart.py`: These changes are the equivalent of the modeling changes in Generate: New `Cache` abstraction and Attention Sinks support #26681, but for encoder-decoder models.
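As a rough sketch of the idea (illustrative only -- the real class in this PR may differ, and the method names below beyond `DynamicCacheWithCrossAttention` and `DynamicCache` are assumptions), the cache keeps growing self-attention key/values per step like `DynamicCache`, while the cross-attention key/values are written once from the encoder states and then reused:

```python
from typing import List, Tuple

import torch
from transformers.cache_utils import DynamicCache


class DynamicCacheWithCrossAttention(DynamicCache):
    """Sketch: DynamicCache plus a per-layer cross-attention cache.

    The cross-attention key/values come from the encoder output, so they are
    written once per layer and never grow during decoding.
    """

    def __init__(self):
        super().__init__()
        self.cross_attention_key_cache: List[torch.Tensor] = []
        self.cross_attention_value_cache: List[torch.Tensor] = []

    def update_cross_attention(
        self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Hypothetical method name: store the cross-attention states for `layer_idx`
        # on the first decoding step and return the cached states afterwards.
        if len(self.cross_attention_key_cache) <= layer_idx:
            self.cross_attention_key_cache.append(key_states)
            self.cross_attention_value_cache.append(value_states)
        return (
            self.cross_attention_key_cache[layer_idx],
            self.cross_attention_value_cache[layer_idx],
        )
```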
The remaining changes come from `make fix-copies` (plus a few manual changes like adding imports or updating docstrings), or are test upgrades for the new `DynamicCacheWithCrossAttention`.

The following tests were run locally -- this includes FA2 and some pretty challenging tests to ensure nothing was broken in the process:
RUN_SLOW=1 py.test tests/models/bart/test_modeling_bart.py -vv
RUN_SLOW=1 py.test tests/models/mbart/test_modeling_mbart.py -vv
RUN_SLOW=1 py.test tests/models/whisper/test_modeling_whisper.py -vv
👉 In any case, we should run the slow CI before merging!
Note on Whisper: same failures as in `main`.