🚨 Bloom support for cache class (#31445)
* bloom dynamic cache

* bloom follows standard cache format

* no skips for bloom anymore

* use cache position when possible

* clean up

* codestyle

* Update src/transformers/models/bloom/modeling_bloom.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/bloom/modeling_bloom.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/bloom/modeling_bloom.py

Co-authored-by: amyeroberts <[email protected]>

* pr comments

* isinstance fix

* address comments

* make musicgen test happy

* [run-slow] bloom

---------

Co-authored-by: amyeroberts <[email protected]>
zucchini-nlp and amyeroberts authored Jul 29, 2024
1 parent 44f6fdd commit f739687
Showing 6 changed files with 229 additions and 195 deletions.
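
The gist of the commit: Bloom now reads and writes its past key/values through the shared cache classes (e.g. `DynamicCache`) instead of its model-specific tuple layout, so the generation utilities below no longer need Bloom-only branches. A minimal sketch of what that looks like from user code, assuming a `transformers` build that includes this commit; the checkpoint name and the explicit `DynamicCache` argument are illustrative choices, and whether the returned object stays a `DynamicCache` (rather than being converted back to legacy tuples) is an assumption noted in the comments:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

# Any BloomForCausalLM checkpoint should behave the same way; bloom-560m is just small.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Passing a cache object explicitly opts into the standard cache format.
cache = DynamicCache()
with torch.no_grad():
    outputs = model(**inputs, past_key_values=cache, use_cache=True)

# With this commit, Bloom's layers go through the shared Cache API instead of the
# old (batch_size * num_heads, ...) tuple layout.
print(type(outputs.past_key_values))             # DynamicCache (assumption: no legacy conversion applied)
print(outputs.past_key_values.get_seq_length())  # number of cached tokens
```

Passing the cache object explicitly sidesteps whatever backward-compatibility conversion the model may still apply when it receives legacy tuples or `None`.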
15 changes: 1 addition & 14 deletions src/transformers/generation/candidate_generator.py
@@ -378,19 +378,7 @@ def _crop_past_key_values(model, past_key_values, max_length):
                 )
             )
         past_key_values = tuple(new_past)
-    # bloom is special
-    elif "bloom" in model.__class__.__name__.lower() or (
-        model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
-    ):
-        for idx in range(len(past_key_values)):
-            new_past.append(
-                (
-                    past_key_values[idx][0][:, :, :max_length],
-                    past_key_values[idx][1][:, :max_length, :],
-                )
-            )
-        past_key_values = tuple(new_past)
-    # gptbigcode is too
+    # gptbigcode is special and stores kv in shape (batch_size, seq_len, dim), if it's a multi_query model
     elif "gptbigcode" in model.__class__.__name__.lower() or (
         model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
     ):
@@ -402,7 +390,6 @@ def _crop_past_key_values(model, past_key_values, max_length):
                 past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
     elif isinstance(past_key_values, DynamicCache):
         past_key_values.crop(max_length)
-
     elif past_key_values is not None:
         for idx in range(len(past_key_values)):
             new_past.append(
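
The branch deleted above cropped Bloom's legacy per-layer tuples by hand (keys on their last axis, values on their second axis); with Bloom on the cache class, the `past_key_values.crop(max_length)` path a few lines below covers it. A toy sketch of that call with made-up sizes (the layer count and tensor shapes are illustrative):

```python
import torch
from transformers import DynamicCache

# Build a toy cache: 2 layers, batch 1, 4 heads, 10 cached tokens, head_dim 8.
cache = DynamicCache()
for layer_idx in range(2):
    key = torch.randn(1, 4, 10, 8)
    value = torch.randn(1, 4, 10, 8)
    cache.update(key, value, layer_idx)

print(cache.get_seq_length())  # 10

# Assisted generation rejects some candidate tokens and crops the cache back.
cache.crop(7)
print(cache.get_seq_length())  # 7
```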
17 changes: 4 additions & 13 deletions src/transformers/generation/utils.py
@@ -639,7 +639,7 @@ def _expand_dict_for_generation(dict_to_expand):

         return input_ids, model_kwargs

-    def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
+    def _extract_past_from_model_output(self, outputs: ModelOutput):
         past_key_values = None
         cache_name = "past_key_values"
         if "past_key_values" in outputs:
@@ -652,24 +652,17 @@ def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
             past_key_values = outputs.cache_params
             cache_name = "cache_params"

-        # Bloom fix: standardizes the cache format when requested
-        if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
-            batch_size = outputs.logits.shape[0]
-            past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size)
         return cache_name, past_key_values

     def _update_model_kwargs_for_generation(
         self,
         outputs: ModelOutput,
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
-        standardize_cache_format: bool = False,
         num_new_tokens: int = 1,
     ) -> Dict[str, Any]:
         # update past_key_values keeping its naming used in model code
-        cache_name, cache = self._extract_past_from_model_output(
-            outputs, standardize_cache_format=standardize_cache_format
-        )
+        cache_name, cache = self._extract_past_from_model_output(outputs)
         model_kwargs[cache_name] = cache
         if getattr(outputs, "state", None) is not None:
             model_kwargs["state"] = outputs.state
@@ -2558,7 +2551,6 @@ def _contrastive_search(
                     outputs,
                     model_kwargs,
                     is_encoder_decoder=self.config.is_encoder_decoder,
-                    standardize_cache_format=True,
                 )

                 if not sequential:
@@ -2723,7 +2715,7 @@ def _contrastive_search(
                 next_past_key_values = selected_outputs["past_key_values"]

             else:
-                _, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
+                _, next_past_key_values = self._extract_past_from_model_output(outputs)
                 # Do it in-place layer per layer to save memory
                 if isinstance(next_past_key_values, DynamicCache) or (
                     isinstance(next_past_key_values, EncoderDecoderCache)
@@ -3033,7 +3025,7 @@ def _temporary_reorder_cache(self, past_key_values, beam_idx):
             past_key_values = self._reorder_cache(past_key_values, beam_idx)
         # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
         # cache format is standardized, to avoid adding complexity to the codebase.
-        elif "bloom" in model_class or "gptbigcode" in model_class:
+        elif "gptbigcode" in model_class:
             if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)):
                 raise ValueError(
                     f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
@@ -3161,7 +3153,6 @@ def _beam_search(
                 for model_name in [
                     "fsmt",
                     "reformer",
-                    "bloom",
                     "ctrl",
                     "gpt_bigcode",
                     "transo_xl",
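
Dropping `bloom` from `_temporary_reorder_cache`'s exception list and from the beam-search legacy-format list works because a standard-layout cache keeps the batch dimension first, so beam reordering is a plain `index_select` per layer. A small illustrative sketch of that reorder on a legacy-format tuple in the standard layout (not the library's internal helper):

```python
import torch

def reorder_standard_cache(past_key_values, beam_idx: torch.LongTensor):
    """Reorder a standard-layout legacy cache tuple for beam search.

    Each layer holds (key, value) tensors of shape (batch, num_heads, seq_len, head_dim);
    beam search only needs to shuffle the batch dimension to follow the chosen beams.
    """
    return tuple(
        (key.index_select(0, beam_idx), value.index_select(0, beam_idx))
        for key, value in past_key_values
    )

# Toy cache: 2 layers, 4 beams, 2 heads, 3 cached tokens, head_dim 8.
cache = tuple((torch.randn(4, 2, 3, 8), torch.randn(4, 2, 3, 8)) for _ in range(2))
beam_idx = torch.tensor([2, 2, 0, 1])  # beams 2, 2, 0, 1 survive this step
reordered = reorder_standard_cache(cache, beam_idx)
print(reordered[0][0].shape)  # torch.Size([4, 2, 3, 8])
```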
