Offloaded cache: fix generate (#34921)
* fix cache impl

* require_torch_gpu

* fix mamba

* fix copies
zucchini-nlp authored Nov 28, 2024
1 parent 57ca9e6 commit 5e8c1d7
Showing 6 changed files with 91 additions and 19 deletions.
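
This commit makes `cache_implementation="offloaded"` a valid option for `generate()` again. A minimal, hedged usage sketch (the checkpoint name and prompt are illustrative placeholders, not taken from this diff):

# Sketch: generate with the CPU-offloaded KV cache that this fix re-enables.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

inputs = tokenizer("Offloaded caches trade speed for memory because", return_tensors="pt").to("cuda")
output = model.generate(
    **inputs,
    max_new_tokens=20,
    use_cache=True,
    cache_implementation="offloaded",  # accepted again after this fix
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
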
62 changes: 46 additions & 16 deletions src/transformers/cache_utils.py
@@ -1140,13 +1140,13 @@ def __init__(
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
if batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)

self.batch_size = batch_size or max_batch_size
self.max_batch_size = batch_size or max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len

# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
@@ -1254,6 +1254,14 @@ def reset(self):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size


class SlidingWindowCache(StaticCache):
"""
@@ -1626,10 +1634,10 @@ def __init__(
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
if batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
@@ -1638,7 +1646,7 @@ def __init__(
"config and it's not set to None."
)
self.max_cache_len = max_cache_len
self.batch_size = batch_size or max_batch_size
self.max_batch_size = batch_size or max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -1758,6 +1766,14 @@ def reset(self):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size


class MambaCache:
"""
@@ -1815,28 +1831,28 @@ def __init__(
device: Optional[Union[torch.device, str]] = None,
max_batch_size: Optional[int] = None,
):
if max_batch_size is not None:
if batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
self.dtype = dtype
self.batch_size = batch_size or max_batch_size
self.max_batch_size = batch_size or max_batch_size
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel

self.conv_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.batch_size,
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=device,
dtype=dtype,
)
self.ssm_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.batch_size,
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=device,
@@ -1866,6 +1882,14 @@ def reset(self):
self.conv_states.zero_()
self.ssm_states.zero_()

@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size


class OffloadedStaticCache(StaticCache):
"""
@@ -1887,6 +1911,9 @@ class OffloadedStaticCache(StaticCache):
The default `dtype` to use when initializing the cache.
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
The device to offload to. Defaults to CPU.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split across several GPUs.
You can check which layers are mapped to which device via the associated device map: `model.hf_device_map`.
Attributes:
key_cache (`List[torch.Tensor]`):
@@ -1933,18 +1960,21 @@ def __init__(
device: Union[str, torch.device],
dtype: Optional[torch.dtype] = None,
offload_device: Union[str, torch.device] = torch.device("cpu"),
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device)
self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
self.offload_device = torch.device(offload_device)
self.dtype = dtype if dtype is not None else torch.float32

# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads

num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)

cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)
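The renaming above makes `max_batch_size` the canonical constructor argument and attribute for the static-style caches, with `batch_size` kept only as a deprecated alias until v4.49. A hedged sketch of the resulting behaviour, assuming a Llama-style config (the checkpoint name is a placeholder):

import torch
from transformers import AutoConfig, StaticCache

config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder checkpoint
cache = StaticCache(
    config=config,
    max_batch_size=2,   # preferred name after this change
    max_cache_len=128,
    device="cpu",
    dtype=torch.float16,
)
print(cache.max_batch_size)  # 2
print(cache.batch_size)      # 2, via the deprecated property, which emits a warning (removal planned for v4.49)

The same pattern applies to SlidingWindowCache and MambaCache, which gain the same `batch_size` property in this diff.
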
4 changes: 3 additions & 1 deletion src/transformers/generation/configuration_utils.py
@@ -72,7 +72,9 @@
"mamba": MambaCache,
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys())
ALL_CACHE_IMPLEMENTATIONS = (
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ["offloaded"]
)


class GenerationMode(ExplicitEnum):
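With "offloaded" added to `ALL_CACHE_IMPLEMENTATIONS`, checks against that list no longer reject the string. A quick hedged check:

from transformers import GenerationConfig
from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATIONS

assert "offloaded" in ALL_CACHE_IMPLEMENTATIONS  # added by this commit
gen_config = GenerationConfig(cache_implementation="offloaded", max_new_tokens=5)
print(gen_config.cache_implementation)  # "offloaded"
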
4 changes: 2 additions & 2 deletions src/transformers/generation/utils.py
@@ -1610,7 +1610,7 @@ def _get_cache(
need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(cache_to_check, cache_cls))
or cache_to_check.batch_size != batch_size
or cache_to_check.max_batch_size != batch_size
)
if cache_implementation != "mamba":
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
@@ -1666,7 +1666,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None):

cache_kwargs = {
"config": self.config.get_text_config(),
"batch_size": batch_size,
"max_batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
"dtype": cache_dtype,
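For readers following the `_get_cache` change, a hedged, self-contained restatement of the reuse rule it now applies (the function name and parameters here are illustrative, not the library's API):

def _needs_new_cache(existing_cache, cache_cls, cache_implementation, batch_size, max_cache_len):
    # Build a new static-style cache only when there is no cache yet, its class
    # changed, or its max_batch_size differs from the requested batch size.
    need_new = (
        existing_cache is None
        or not isinstance(existing_cache, cache_cls)
        or existing_cache.max_batch_size != batch_size  # renamed from .batch_size
    )
    if cache_implementation != "mamba":
        # `or` short-circuits, so a missing cache never reaches this attribute access.
        need_new = need_new or existing_cache.max_cache_len < max_cache_len
    return need_new
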
26 changes: 26 additions & 0 deletions tests/generation/test_utils.py
@@ -1880,6 +1880,32 @@ def test_new_cache_format(self, num_beams, do_sample):
)
)

@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
@require_torch_gpu
@pytest.mark.generate
def test_offloaded_cache_implementation(self, cache_implementation):
"""Tests we can generate by indicating `cache_implementation` for each possible cache class"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_cache_class:
self.skipTest(reason="This model does not support the new cache format")

config, inputs_dict = self.prepare_config_and_inputs_for_generate()

model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_new_tokens": 5,
"use_cache": True,
"cache_implementation": cache_implementation,
}

legacy_results = model.generate(**generation_kwargs, **inputs_dict)

# Most cache classes have their own tests; the ones covered here do not need special
# treatment when passing `cache_implementation` and are not bound to specific models.
new_results = model.generate(**generation_kwargs, **inputs_dict)
self.assertListEqual(legacy_results.tolist(), new_results.tolist())

@pytest.mark.generate
def test_generate_with_static_cache(self):
"""
8 changes: 8 additions & 0 deletions tests/models/mllama/test_modeling_mllama.py
@@ -16,7 +16,9 @@

import unittest

import pytest
import requests
from parameterized import parameterized

from transformers import (
AutoProcessor,
@@ -365,6 +367,12 @@ def test_sdpa_can_compile_dynamic(self):
def test_model_parallelism(self):
pass

@parameterized.expand([("offloaded",)])
@pytest.mark.generate
@unittest.skip(reason="Offloaded cache seems to not work with mllama's kv cache type")
def test_offloaded_cache_implementation(self, cache_implementation):
pass

def test_generate_text_only_with_cache(self):
"""
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
6 changes: 6 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
@@ -567,6 +567,12 @@ def test_training_gradient_checkpointing_use_reentrant_false(self):
def test_generate_with_head_masking(self):
pass

@parameterized.expand([("offloaded",)])
@pytest.mark.generate
@unittest.skip(reason="Whisper doesnt work with offloaded cache implementation yet")
def test_offloaded_cache_implementation(self, cache_implementation):
pass

@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
