From f7152251b0b1be89490c0e424303894d5757d851 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 23 May 2024 17:25:20 +0500 Subject: [PATCH] Quantized KV Cache (#30483) * clean-up * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fixup * Update tests/quantization/quanto_integration/test_quanto.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/generation/configuration_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * more suggestions * mapping if torch available * run tests & add 'support_quantized' flag * fix jamba test * revert, will be fixed by another PR * codestyle * HQQ and versatile cache classes * final update * typo * make tests happy --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- docker/transformers-all-latest-gpu/Dockerfile | 3 + docs/source/en/generation_strategies.md | 37 ++ docs/source/en/internal/generation_utils.md | 16 +- src/transformers/__init__.py | 24 +- src/transformers/cache_utils.py | 370 +++++++++++++++++- .../generation/configuration_utils.py | 35 +- src/transformers/generation/utils.py | 67 +++- src/transformers/modeling_utils.py | 3 + .../models/cohere/modeling_cohere.py | 1 + src/transformers/models/dbrx/modeling_dbrx.py | 1 + .../models/gemma/modeling_gemma.py | 1 + .../models/llama/modeling_llama.py | 1 + src/transformers/models/olmo/modeling_olmo.py | 1 + .../modeling_recurrent_gemma.py | 2 + .../models/stablelm/modeling_stablelm.py | 1 + .../models/whisper/generation_whisper.py | 10 +- src/transformers/utils/dummy_pt_objects.py | 35 ++ tests/generation/test_utils.py | 36 +- .../quanto_integration/test_quanto.py | 36 +- 19 files changed, 652 insertions(+), 28 deletions(-) diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index f6fa587fb64073..930fdfb799cd33 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -51,6 +51,9 @@ RUN python3 -m pip install --no-cache-dir gguf # Some slow tests require bnb RUN python3 -m pip install --no-cache-dir bitsandbytes +# Some tests require quanto +RUN python3 -m pip install --no-cache-dir quanto + # For `dinat` model # The `XXX` part in `torchXXX` needs to match `PYTORCH` (to some extent) RUN python3 -m pip install --no-cache-dir natten==0.15.1+torch220$CUDA -f https://shi-labs.com/natten/wheels diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index 6c7c70cb1400b5..b000cc06779918 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -174,6 +174,43 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te ``` +## KV Cache Quantization + +The `generate()` method supports caching keys and values to enhance efficiency and avoid re-computations. However the key and value +cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models. +Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed. + +KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache] +(https://arxiv.org/abs/2402.02750) and currently supports `quanto` and `HQQ` as backends. For more information on the inner workings see the paper. + +To enable quantization of the key-value cache, one needs to indicate `cache_implementation="quantized"` in the `generation_config`. +Quantization related arguments should be passed to the `generation_config` either as a `dict` or an instance of a [`QuantizedCacheConfig`] class. +One has to indicate which quantization backend to use in the [`QuantizedCacheConfig`], the default is `quanto`. + + + +Cache quantization can be detrimental if the context length is short and there is enough GPU VRAM available to run without cache quantization. + + + + +```python +>>> import torch +>>> from transformers import AutoTokenizer, AutoModelForCausalLM + +>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") +>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0") +>>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device) + +>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"}) +>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0]) +I like rock music because it's loud and energetic. It's a great way to express myself and rel + +>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20) +>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0]) +I like rock music because it's loud and energetic. I like to listen to it when I'm feeling +``` + ## Watermarking The `generate()` supports watermarking the generated text by randomly marking a portion of tokens as "green". diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 04a4428a008085..5bf8b5c4a0b36f 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -360,6 +360,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] Cache - update +[[autodoc]] CacheConfig + - update + +[[autodoc]] QuantizedCacheConfig + - validate + [[autodoc]] DynamicCache - update - get_seq_length @@ -367,6 +373,14 @@ A [`Constraint`] can be used to force the generation to include specific tokens - to_legacy_cache - from_legacy_cache +[[autodoc]] QuantizedCache + - update + - get_seq_length + +[[autodoc]] QuantoQuantizedCache + +[[autodoc]] HQQQuantizedCache + [[autodoc]] SinkCache - update - get_seq_length @@ -375,7 +389,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] StaticCache - update - get_seq_length - - reorder_cache + - reset ## Watermark Utils diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4255e303799442..8da7a8b3e39a1a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1182,7 +1182,17 @@ _import_structure["activations"] = [] _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"] _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"] - _import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"] + _import_structure["cache_utils"] = [ + "Cache", + "CacheConfig", + "DynamicCache", + "HQQQuantizedCache", + "QuantizedCache", + "QuantizedCacheConfig", + "QuantoQuantizedCache", + "SinkCache", + "StaticCache", + ] _import_structure["data.datasets"] = [ "GlueDataset", "GlueDataTrainingArguments", @@ -5792,7 +5802,17 @@ # Benchmarks from .benchmark.benchmark import PyTorchBenchmark from .benchmark.benchmark_args import PyTorchBenchmarkArguments - from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache + from .cache_utils import ( + Cache, + CacheConfig, + DynamicCache, + HQQQuantizedCache, + QuantizedCache, + QuantizedCacheConfig, + QuantoQuantizedCache, + SinkCache, + StaticCache, + ) from .data.datasets import ( GlueDataset, GlueDataTrainingArguments, diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 990a863e18e8dc..ad91edfcbb50b2 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1,12 +1,21 @@ +import copy +import json +import os from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import torch from .configuration_utils import PretrainedConfig -from .utils import logging +from .utils import is_hqq_available, is_quanto_available, logging +if is_quanto_available(): + from quanto import QBitsTensor, qint2, qint4 + +if is_hqq_available(): + from hqq.core.quantize import Quantizer as HQQQuantizer + logger = logging.get_logger(__name__) @@ -82,6 +91,201 @@ def seen_tokens(self): return None +@dataclass +class CacheConfig: + """ + Base class for cache configs + """ + + cache_implementation: None + + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a CacheConfig instance from a dictionary of parameters. + Args: + config_dict (Dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. + Returns: + CacheConfig: Instance of CacheConfig constructed from the dictionary. + """ + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + Returns: + str: JSON formatted string representing the configuration instance. + """ + return json.dumps(self.__dict__, indent=2) + "\n" + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + +@dataclass +class QuantizedCacheConfig(CacheConfig): + """ + Configuration class for quantized cache settings. + + Attributes: + backend (`str`, *optional*, defaults to `"quanto"`): + Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] + nbits (`Optional[int]`, *optional*, defaults to 4): + Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. + axis_key (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + axis_value (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + q_group_size (`Optional[int]`, *optional*, defaults to 64): + Size of the quantization group, should be a divisor of the model's hidden dimension. + Defaults to 64. + residual_length (`Optional[int]`, *optional*, defaults to 128): + Length of the residual cache which will always be stored in original presicion. + Defaults to 128. + compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The defualt dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. + device (`str`, *optional*, defaults to `"cpu"`): + Device on which to peform computations, should be same as the model's device. + """ + + def __init__( + self, + backend: str = "quanto", + nbits: Optional[int] = 4, + axis_key: Optional[int] = 0, + axis_value: Optional[int] = 0, + q_group_size: Optional[int] = 64, + residual_length: Optional[int] = 128, + compute_dtype: Optional[torch.dtype] = torch.float16, + device: Optional[str] = "cpu", + ): + self.backend = backend + self.nbits = nbits + self.axis_key = axis_key + self.axis_value = axis_value + self.q_group_size = q_group_size + self.residual_length = residual_length + self.compute_dtype = compute_dtype + self.device = device + + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + # Check that the values are reasonable in general (nbits, axis) + # Later in QuantizedCache init we check if they are supported for that particular backend + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + incorrect_arg_msg.format( + key="nbits", + correct_value="2 or 4 or 8", + found_value=self.nbits, + ), + ) + if self.q_group_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="q_group_size", + correct_value="a positive integer", + found_value=self.q_group_size, + ), + ) + if self.residual_length < 0: + raise ValueError( + incorrect_arg_msg.format( + key="residual_length", + correct_value="a positive integer", + found_value=self.residual_length, + ), + ) + + if self.axis_key not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_key", + correct_value="`1` or `0`, `-1`", + found_value=self.axis_key, + ), + ) + + if self.axis_value not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_value", + correct_value="`1` or `0` or `-1`", + found_value=self.axis_value, + ), + ) + + class DynamicCache(Cache): """ A cache that grows dynamically as more tokens are generated. This is the default for generative models. @@ -186,6 +390,168 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens return cache +class QuantizedCache(DynamicCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + """ + + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + self._quantized_key_cache: List[torch.Tensor] = [] + self._quantized_value_cache: List[torch.Tensor] = [] + + self.nbits = cache_config.nbits + self.residual_length = cache_config.residual_length + self.q_group_size = cache_config.q_group_size + self.axis_key = cache_config.axis_key + self.axis_value = cache_config.axis_value + self.compute_dtype = cache_config.compute_dtype + self.device = cache_config.device + + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + if len(self.key_cache) <= layer_idx: + self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key)) + self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value)) + self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) + self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) + keys_to_return, values_to_return = key_states, value_states + else: + dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) + dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) + keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states] + values_to_return = [dequant_value, self.value_cache[layer_idx], value_states] + + keys_to_return = torch.cat(keys_to_return, dim=-2) + values_to_return = torch.cat(values_to_return, dim=-2) + if ( + self.key_cache[layer_idx].dim() == 4 + and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length + ): + self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) + self._quantized_value_cache[layer_idx] = self._quantize( + values_to_return.contiguous(), axis=self.axis_value + ) + self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) + self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return keys_to_return, values_to_return + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.key_cache) <= layer_idx: + return 0 + # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is + # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx + # this part of code otherwise fails when used to verify attn_weight shape in some models + return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 + + def _quantize(self, tensor, axis): + """Quantizes a key/value using a defined quantization method.""" + raise NotImplementedError("Make sure to implement `_quantize` in a subclass.") + + def _dequantize(self, q_tensor): + """Dequantizes back the tensor that was quantized by `self._quantize()`""" + raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.") + + +class QuantoQuantizedCache(QuantizedCache): + """ + Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. + + Parameters: + cache_config (`QuantizedCacheConfig`,): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. + """ + + def __init__(self, cache_config: CacheConfig) -> None: + super().__init__(cache_config) + if self.nbits not in [2, 4]: + raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") + + if self.axis_key not in [0, -1]: + raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}") + + if self.axis_value not in [0, -1]: + raise ValueError( + f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" + ) + + self.qtype = qint4 if self.nbits == 4 else qint2 + + def _quantize(self, tensor, axis): + qtensor = QBitsTensor.quantize(tensor, axis=axis, qtype=self.qtype, group_size=self.q_group_size) + return qtensor + + def _dequantize(self, qtensor): + return qtensor.dequantize() + + +class HQQQuantizedCache(QuantizedCache): + """ + Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. + + Parameters: + cache_config (`QuantizedCacheConfig`,): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. + """ + + def __init__(self, cache_config: CacheConfig) -> None: + super().__init__(cache_config) + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" + ) + + if self.axis_key not in [0, 1]: + raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}") + + if self.axis_value not in [0, 1]: + raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}") + + self.quantizer = HQQQuantizer + + def _quantize(self, tensor, axis): + qtensor, meta = self.quantizer.quantize( + tensor, + axis=axis, + device=self.device, + compute_dtype=self.compute_dtype, + nbits=self.nbits, + group_size=self.q_group_size, + ) + meta["compute_dtype"] = self.compute_dtype + self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype + return qtensor, meta + + def _dequantize(self, qtensor): + quant_tensor, meta = qtensor + tensor = self.quantizer.dequantize(quant_tensor, meta) + return tensor + + class SinkCache(Cache): """ A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index eb14c60d9af905..0d1eba0bd5d6ef 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -31,6 +31,7 @@ download_url, extract_commit_hash, is_remote_url, + is_torch_available, logging, ) @@ -41,6 +42,12 @@ logger = logging.get_logger(__name__) METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version") +NEEDS_CACHE_CONFIG = {} + +if is_torch_available(): + from ..cache_utils import QuantizedCacheConfig + + NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig class GenerationMode(ExplicitEnum): @@ -299,6 +306,10 @@ class GenerationConfig(PushToHubMixin): cache_implementation (`str`, *optional*, default to `None`): Cache class that should be used when generating. + cache_config (`Union[CacheConfig, dict]`, *optional*, default to `None`): + Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and + it will be converted to its repsective `CacheConfig` internally. + Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`. > Wild card @@ -382,6 +393,13 @@ def __init__(self, **kwargs): # Cache implementation self.cache_implementation = kwargs.pop("cache_implementation", None) + self.cache_config = kwargs.pop("cache_config", None) + if self.cache_implementation is not None: + cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation] + if self.cache_config is None: + self.cache_config = cache_config_class() + elif isinstance(self.cache_config, dict): + self.cache_config = cache_config_class.from_dict(self.cache_config) # Prompt lookup decoding self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) @@ -638,13 +656,26 @@ def validate(self, is_init=False): f"({self.num_beams})." ) - # check watermarking arguments + # 5. check `cache_config` + if self.cache_config is not None: + cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation) + if cache_class is None: + raise ValueError( + "You provided a `cache_config` but the cache implementation you are using " + f"({self.cache_implementation}) does not require any config. Make sure to use the " + "correct cache implementation matching your cache config." + ) + if not isinstance(self.cache_config, cache_class): + self.cache_config = cache_class.from_dict(self.cache_config) + self.cache_config.validate() + + # 6. check watermarking arguments if self.watermarking_config is not None: if not isinstance(self.watermarking_config, WatermarkingConfig): self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config) self.watermarking_config.validate() - # 5. check common issue: passing `generate` arguments inside the generation config + # 7. check common issue: passing `generate` arguments inside the generation config generate_arguments = ( "logits_processor", "stopping_criteria", diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 149ce144e66272..84c9dd995eb4f1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -24,7 +24,15 @@ import torch.distributed as dist from torch import nn -from ..cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ..cache_utils import ( + Cache, + DynamicCache, + HQQQuantizedCache, + QuantizedCacheConfig, + QuantoQuantizedCache, + SlidingWindowCache, + StaticCache, +) from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..models.auto import ( @@ -34,7 +42,14 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING, ) -from ..utils import ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging +from ..utils import ( + ModelOutput, + is_accelerate_available, + is_hqq_available, + is_quanto_available, + is_torchdynamo_compiling, + logging, +) from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .candidate_generator import ( @@ -97,6 +112,7 @@ from accelerate.hooks import AlignDevicesHook, add_hook_to_module NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache} +QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} @dataclass @@ -1658,20 +1674,43 @@ def generate( "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " "Cache object) is unsupported. Please use only one of the two." ) - elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: - if not self._supports_cache_class: - raise ValueError( - "This model does not support the `cache_implementation` argument. Please check the following " - "issue: https://github.com/huggingface/transformers/issues/28981." + elif generation_config.cache_implementation is not None: + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation == "static" and not self._supports_static_cache: + raise ValueError( + "This model does not support `cache_implementation='static'`. Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981" + ) + model_kwargs["past_key_values"] = self._get_cache( + generation_config.cache_implementation, batch_size, generation_config.max_length ) - if generation_config.cache_implementation == "static" and not self._supports_static_cache: - raise ValueError( - "This model does not support `cache_implementation='static'`. Please check the following " - "issue: https://github.com/huggingface/transformers/issues/28981" + elif generation_config.cache_implementation == "quantized": + if not self._supports_quantized_cache: + raise ValueError( + "This model does not support the quantized cache. If you want your model to support quantized " + "cache, please open an issue." + ) + + cache_config = ( + generation_config.cache_config + if generation_config.cache_config is not None + else QuantizedCacheConfig() ) - model_kwargs["past_key_values"] = self._get_cache( - generation_config.cache_implementation, batch_size, generation_config.max_length - ) + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + + if cache_config.backend == "quanto" and not is_quanto_available(): + raise ImportError( + "You need to install `quanto` in order to use KV cache quantization with quanto backend. " + "Please install it via with `pip install quanto`" + ) + elif cache_config.backend == "HQQ" and not is_hqq_available(): + raise ImportError( + "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " + "Please install it via with `pip install hqq`" + ) + + model_kwargs["past_key_values"] = cache_class(cache_config) + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) # 7. determine generation mode diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 106f79ae8e3b58..354962bab055df 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1284,6 +1284,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix _supports_cache_class = False _supports_static_cache = False + # Has support for a `QuantoQuantizedCache` instance as `past_key_values` + _supports_quantized_cache = False + @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: """ diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 41c4e151a3da13..7d1b0e19fc4df6 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -712,6 +712,7 @@ class CoherePreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 850e3a3f81dde5..67f4b819e990d8 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -937,6 +937,7 @@ class DbrxPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module: nn.Module): diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 565a976fd74ad2..474dccf3081d49 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -698,6 +698,7 @@ class GemmaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 1f4f6ac9a0660d..226d14c18b991c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -767,6 +767,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 9b4b08239bc4d9..1630297cd82d19 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -745,6 +745,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index c48132c83c28a5..ab9f8c3d853006 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -539,6 +539,8 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["cache"] _supports_flash_attn_2 = False _supports_sdpa = False # we can't compare with eager for now + _supports_cache_class = True + _supports_quantized_cache = True def _init_weights(self, module): std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index df4e922c5a2f75..160d70fe617d5f 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -832,6 +832,7 @@ class StableLmPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_cache_class = True _supports_sdpa = True + _supports_quantized_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index a42d7b7dec3626..f30cfe19476504 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -784,11 +784,11 @@ def generate_with_fallback( del generate_kwargs[key] seek_outputs = super().generate( segment_input, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, decoder_input_ids=decoder_input_ids, **generate_kwargs, ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5e00230aed4cf5..5ac2a2ccbd5973 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -23,6 +23,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class CacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class DynamicCache(metaclass=DummyObject): _backends = ["torch"] @@ -30,6 +37,34 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class HQQQuantizedCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class QuantizedCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class QuantizedCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class QuantoQuantizedCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class SinkCache(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 840b64e17db010..7d654312a3a069 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -27,6 +27,7 @@ from transformers.testing_utils import ( is_flaky, require_accelerate, + require_quanto, require_torch, require_torch_multi_accelerator, slow, @@ -55,7 +56,7 @@ ImageGPTForCausalImageModeling, SpeechEncoderDecoderModel, ) - from transformers.cache_utils import DynamicCache + from transformers.cache_utils import DynamicCache, QuantoQuantizedCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -1654,6 +1655,39 @@ def test_new_cache_format(self, num_beams, do_sample): ) ) + @require_quanto + def test_generate_with_quant_cache(self): + for model_class in self.all_generative_model_classes: + if not model_class._supports_quantized_cache: + self.skipTest("This model does not support the quantized cache format") + + config, input_ids, attention_mask = self._get_input_ids_and_config() + config.use_cache = True + config.is_decoder = True + + model = model_class(config).to(torch_device).eval() + generation_kwargs = { + "max_new_tokens": 5, + "cache_implementation": "quantized", + # careful with group size, should be divisor of model's hidden size + "cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128}, + "return_dict_in_generate": True, # Required to return `past_key_values` + } + + results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache)) + + # passing past key values of different type should raise Error + with self.assertRaises(ValueError): + model.generate( + input_ids, attention_mask=attention_mask, past_key_valyes=DynamicCache(), **generation_kwargs + ) + + # setting incorrect cache_config args should raise an Error, i.e. nbits=60 does not make sense + generation_kwargs["cache_config"] = {"nbits": 60, "q_group_size": 8, "residual_length": 128} + with self.assertRaises(ValueError): + model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape num_sequences_in_output = batch_size * num_return_sequences diff --git a/tests/quantization/quanto_integration/test_quanto.py b/tests/quantization/quanto_integration/test_quanto.py index 69bf998ace572f..f574478241979d 100644 --- a/tests/quantization/quanto_integration/test_quanto.py +++ b/tests/quantization/quanto_integration/test_quanto.py @@ -17,13 +17,22 @@ import unittest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig -from transformers.testing_utils import require_accelerate, require_quanto, require_torch_gpu, slow +from transformers.testing_utils import ( + require_accelerate, + require_quanto, + require_read_token, + require_torch_gpu, + slow, + torch_device, +) from transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available if is_torch_available(): import torch + from transformers import LlamaForCausalLM, LlamaTokenizer + if is_accelerate_available(): from accelerate import init_empty_weights @@ -429,3 +438,28 @@ def test_quantize_activation(self): with self.assertRaises(ValueError) as e: AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", quantization_config=quantization_config) self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception)) + + +@require_torch_gpu +class QuantoKVCacheQuantizationTest(unittest.TestCase): + @slow + @require_read_token + def test_quantized_cache(self): + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my burgers, my hot dogs, my sandwiches, my chicken, my pizza, my sal", + ] + + prompts = [ + "Simply put, the theory of relativity states that ", + "My favorite all time favorite condiment is ketchup.", + ] + tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="left") + model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(torch_device) + + generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="quantized") + text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text)