diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 268c010a6a04ab..1f5a164815aaed 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1127,6 +1127,7 @@ def batch_select_indices(self, indices: torch.Tensor): self.self_attention_cache.batch_select_indices(indices) self.cross_attention_cache.batch_select_indices(indices) + class HybridCache(Cache): def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None: if not hasattr(config, "sliding_window") or config.sliding_window is None: