diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 3b491b04607b57..490280ce813bd6 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -528,7 +528,7 @@ def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int cache = cls() for idx in range(len(splits[0])): key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] - value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] + value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []] if key_cache != []: layer_keys = torch.cat(key_cache, dim=0) layer_values = torch.cat(value_cache, dim=0) @@ -1523,7 +1523,10 @@ def crop(self, maximum_length: int): self.check_dynamic_cache(self.crop.__name__) self.self_attention_cache.crop(maximum_length) - def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]": + @deprecate_kwarg("num_hidden_layers", version="4.47.0") + def batch_split( + self, full_batch_size: int, split_size: int, num_hidden_layers: int = None + ) -> "List[EncoderDecoderCache]": """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by `_split_model_inputs()` in `generation.utils`""" self.check_dynamic_cache(self.batch_split.__name__) @@ -1536,7 +1539,10 @@ def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDec return out @classmethod - def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache": + @deprecate_kwarg("num_hidden_layers", version="4.47.0") + def from_batch_splits( + cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int = None + ) -> "EncoderDecoderCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" self_attention_cache = DynamicCache() diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index b1d0042c65167f..76dc23ed9bf7c1 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1046,7 +1046,6 @@ def test_contrastive_generate_low_memory(self): self.assertListEqual(low_output.tolist(), high_output.tolist()) @pytest.mark.generate - @unittest.skip("Started to break with https://github.com/huggingface/transformers/pull/33703") def test_beam_search_low_memory(self): # Check that choosing 'low_memory' does not change the model output for model_class in self.all_generative_model_classes: diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 0943661b96666c..a141ef40be1959 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -330,7 +330,7 @@ def __init__( hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, - max_position_embeddings=20, + max_position_embeddings=512, eos_token_id=2, pad_token_id=1, bos_token_id=0,