diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index f3f0bd6fe5458f..9d4d90f11221db 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1140,13 +1140,13 @@ def __init__(
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super().__init__()
-        if max_batch_size is not None:
+        if batch_size is not None:
             logger.warning_once(
-                f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
-                "v4.46. Use the more precisely named 'batch_size' argument instead."
+                f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
+                "v4.49. Use the more precisely named 'max_batch_size' argument instead."
             )

-        self.batch_size = batch_size or max_batch_size
+        self.max_batch_size = batch_size or max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len

         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
@@ -1254,6 +1254,14 @@ def reset(self):
             self.key_cache[layer_idx].zero_()
             self.value_cache[layer_idx].zero_()

+    @property
+    def batch_size(self):
+        logger.warning_once(
+            f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
+            "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
+        )
+        return self.max_batch_size
+

 class SlidingWindowCache(StaticCache):
     """
@@ -1626,10 +1634,10 @@ def __init__(
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super().__init__()
-        if max_batch_size is not None:
+        if batch_size is not None:
             logger.warning_once(
-                f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
-                "v4.46. Use the more precisely named 'batch_size' argument instead."
+                f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
+                "v4.49. Use the more precisely named 'max_batch_size' argument instead."
             )
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
@@ -1638,7 +1646,7 @@ def __init__(
                 "config and it's not set to None."
             )
         self.max_cache_len = max_cache_len
-        self.batch_size = batch_size or max_batch_size
+        self.max_batch_size = batch_size or max_batch_size
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         self.head_dim = (
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -1758,6 +1766,14 @@ def reset(self):
             self.key_cache[layer_idx].zero_()
             self.value_cache[layer_idx].zero_()

+    @property
+    def batch_size(self):
+        logger.warning_once(
+            f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
+            "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
+        )
+        return self.max_batch_size
+

 class MambaCache:
     """
@@ -1815,20 +1831,20 @@ def __init__(
         device: Optional[Union[torch.device, str]] = None,
         max_batch_size: Optional[int] = None,
     ):
-        if max_batch_size is not None:
+        if batch_size is not None:
             logger.warning_once(
-                f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
-                "v4.46. Use the more precisely named 'batch_size' argument instead."
+                f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
+                "v4.49. Use the more precisely named 'max_batch_size' argument instead."
             )
         self.dtype = dtype
-        self.batch_size = batch_size or max_batch_size
+        self.max_batch_size = batch_size or max_batch_size
         self.intermediate_size = config.intermediate_size
         self.ssm_state_size = config.state_size
         self.conv_kernel_size = config.conv_kernel

         self.conv_states: torch.Tensor = torch.zeros(
             config.num_hidden_layers,
-            self.batch_size,
+            self.max_batch_size,
             self.intermediate_size,
             self.conv_kernel_size,
             device=device,
@@ -1836,7 +1852,7 @@ def __init__(
         )
         self.ssm_states: torch.Tensor = torch.zeros(
             config.num_hidden_layers,
-            self.batch_size,
+            self.max_batch_size,
             self.intermediate_size,
             self.ssm_state_size,
             device=device,
@@ -1866,6 +1882,14 @@ def reset(self):
         self.conv_states.zero_()
         self.ssm_states.zero_()

+    @property
+    def batch_size(self):
+        logger.warning_once(
+            f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
+            "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
+        )
+        return self.max_batch_size
+

 class OffloadedStaticCache(StaticCache):
     """
@@ -1887,6 +1911,9 @@ class OffloadedStaticCache(StaticCache):
             The default `dtype` to use when initializing the cache.
         offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
             The device to offload to. Defaults to CPU.
+        layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
+            Mapping between the layers and their device. This is required when you are manually initializing the cache and the model is split between different GPUs.
+            You can check which layers are mapped to which device in the associated device_map: `model.hf_device_map`.

     Attributes:
         key_cache (`List[torch.Tensor]`):
@@ -1933,10 +1960,11 @@ def __init__(
         device: Union[str, torch.device],
         dtype: Optional[torch.dtype] = None,
         offload_device: Union[str, torch.device] = torch.device("cpu"),
+        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
-        self.device = torch.device(device)
+        self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
         self.offload_device = torch.device(offload_device)
         self.dtype = dtype if dtype is not None else torch.float32
@@ -1944,7 +1972,9 @@ def __init__(
         head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads

         num_key_value_heads = (
-            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
         )

         cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index cbc445308ad541..30a632aa8cca6a 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -72,7 +72,9 @@
         "mamba": MambaCache,
     }
     QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
-    ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys())
+    ALL_CACHE_IMPLEMENTATIONS = (
+        list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ["offloaded"]
+    )


 class GenerationMode(ExplicitEnum):
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 16b26ade7a62f9..05e39c4a9b56b5 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1610,7 +1610,7 @@ def _get_cache(
         need_new_cache = (
             not hasattr(self, "_cache")
             or (not isinstance(cache_to_check, cache_cls))
-            or cache_to_check.batch_size != batch_size
+            or cache_to_check.max_batch_size != batch_size
         )
         if cache_implementation != "mamba":
             need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
@@ -1666,7 +1666,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None):

         cache_kwargs = {
             "config": self.config.get_text_config(),
-            "batch_size": batch_size,
+            "max_batch_size": batch_size,
             "max_cache_len": max_cache_len,
             "device": device,
             "dtype": cache_dtype,
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index e403a528a8c8b7..063e9a3da8fdad 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1880,6 +1880,32 @@ def test_new_cache_format(self, num_beams, do_sample):
                 )
             )

+    @parameterized.expand([("offloaded",)])  # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
+    @require_torch_gpu
+    @pytest.mark.generate
+    def test_offloaded_cache_implementation(self, cache_implementation):
+        """Tests that we can generate by indicating `cache_implementation` for each possible cache class"""
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_cache_class:
+                self.skipTest(reason="This model does not support the new cache format")
+
+            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+
+            model = model_class(config).to(torch_device).eval()
+            generation_kwargs = {
+                "max_new_tokens": 5,
+                "use_cache": True,
+                "cache_implementation": cache_implementation,
+            }
+
+            legacy_results = model.generate(**generation_kwargs, **inputs_dict)
+
+            # Most cache classes have their own tests except for some that are tested here
+            # The ones here do not need special treatment when passing `cache_implementation`
+            # and are not bound to specific models only
+            new_results = model.generate(**generation_kwargs, **inputs_dict)
+            self.assertListEqual(legacy_results.tolist(), new_results.tolist())
+
     @pytest.mark.generate
     def test_generate_with_static_cache(self):
         """
diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py
index 8da927f815db81..cfd64aee5368f3 100644
--- a/tests/models/mllama/test_modeling_mllama.py
+++ b/tests/models/mllama/test_modeling_mllama.py
@@ -16,7 +16,9 @@

 import unittest

+import pytest
 import requests
+from parameterized import parameterized

 from transformers import (
     AutoProcessor,
@@ -365,6 +367,12 @@ def test_sdpa_can_compile_dynamic(self):
     def test_model_parallelism(self):
         pass

+    @parameterized.expand([("offloaded",)])
+    @pytest.mark.generate
+    @unittest.skip(reason="Offloaded cache seems to not work with mllama's kv cache type")
+    def test_offloaded_cache_implementation(self, cache_implementation):
+        pass
+
     def test_generate_text_only_with_cache(self):
         """
         Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index bf0da2c2757bbf..9389c4f47def1b 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -567,6 +567,12 @@ def test_training_gradient_checkpointing_use_reentrant_false(self):
     def test_generate_with_head_masking(self):
         pass

+    @parameterized.expand([("offloaded",)])
+    @pytest.mark.generate
+    @unittest.skip(reason="Whisper doesn't work with the offloaded cache implementation yet")
+    def test_offloaded_cache_implementation(self, cache_implementation):
+        pass
+
     @require_torch_fp16
     def test_generate_fp16(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs()
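
Below the patch, a minimal usage sketch of how the renamed `max_batch_size` argument and the `"offloaded"` cache option are exercised from user code; the checkpoint name, prompt, and generation settings are illustrative assumptions, not part of this PR.

# Usage sketch (not part of the diff). Assumes a decoder-only checkpoint that
# supports static caches; the model id and prompt below are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-3.2-1B"  # illustrative checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
inputs = tokenizer("The KV cache is", return_tensors="pt").to(device)

# Caches are now constructed with `max_batch_size`; passing the old `batch_size`
# keyword still works but triggers the deprecation warning added in the diff.
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=256,
    device=device,
    dtype=model.dtype,
)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=20)

# "offloaded" is now listed in ALL_CACHE_IMPLEMENTATIONS, so it can be requested
# by name just like "static" (this is what the new generation test exercises).
out = model.generate(**inputs, cache_implementation="offloaded", max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))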