From 13637f2a690652f0c6ae410322665fd5c34fc503 Mon Sep 17 00:00:00 2001 From: wish2023 Date: Sat, 21 Dec 2024 04:11:59 +0000 Subject: [PATCH] Add support for H2O cache eviction with LLaMA --- benchmark/h20/benchmark_varying_cache.py | 146 +++++++++++++ benchmark/h20/benchmark_varying_output.py | 172 +++++++++++++++ src/transformers/cache_utils.py | 195 ++++++++++++++++++ .../models/cohere/modeling_cohere.py | 1 + src/transformers/models/glm/modeling_glm.py | 1 + .../models/granite/modeling_granite.py | 1 + .../models/llama/modeling_llama.py | 4 + src/transformers/models/olmo/modeling_olmo.py | 1 + .../models/olmo2/modeling_olmo2.py | 1 + .../models/olmoe/modeling_olmoe.py | 1 + tests/h2O/__init__.py | 0 tests/h2O/test_h2O.py | 41 ++++ 12 files changed, 564 insertions(+) create mode 100644 benchmark/h20/benchmark_varying_cache.py create mode 100644 benchmark/h20/benchmark_varying_output.py create mode 100644 tests/h2O/__init__.py create mode 100644 tests/h2O/test_h2O.py diff --git a/benchmark/h20/benchmark_varying_cache.py b/benchmark/h20/benchmark_varying_cache.py new file mode 100644 index 00000000000000..96ef831fc785fa --- /dev/null +++ b/benchmark/h20/benchmark_varying_cache.py @@ -0,0 +1,146 @@ +from transformers import LlamaTokenizer, LlamaForCausalLM, BitsAndBytesConfig +from transformers.cache_utils import H2OCache, DynamicCache +from tqdm import tqdm +from nltk.translate.bleu_score import sentence_bleu +from rouge_score import rouge_scorer +import matplotlib.pyplot as plt +import torch +import time +import copy + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) + +model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", + quantization_config=quantization_config, + device_map="auto", + attn_implementation="eager" +) +tokenizer = LlamaTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") +device = model.device + +def calculate_cache_memory(cache): + total_memory = 0 + for key_token, value_token in zip(cache.key_cache, cache.value_cache): + total_memory += key_token.element_size() * key_token.numel() + total_memory += value_token.element_size() * value_token.numel() + return total_memory + + +def run_generation(model, tokenizer, user_prompt, max_cache_len=None, pre_fill_cache=False): + message = [{"role": "system", "content": "You are a personal fitness coach who creates customized workout plans. 
Keep advice practical and motivating."}] + total_time = 0 + total_prompt_tokens = 0 + total_output_tokens = 0 + + print("Generating...") + # Run generation + message.append({"role": "user", "content": user_prompt}) + inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device) + input_length = inputs["input_ids"].shape[1] + target_length = 5 * input_length # Total tokens we want to generate + warmup_length = 4 * input_length # Length at which we start timing (that way the H2OCache size is 20% of the sequence length) + + if pre_fill_cache: + assert max_cache_len + past_key_values = H2OCache(max_cache_len=max_cache_len) + else: + past_key_values = DynamicCache() + + # First generate up to warmup length without timing + outputs = model.generate( + **inputs, + do_sample=False, + min_new_tokens=warmup_length, + max_new_tokens=warmup_length, + past_key_values=past_key_values + ) + + # Now time the generation of the remaining tokens + warmup_inputs = copy.deepcopy(inputs) + warmup_inputs["input_ids"] = outputs + warmup_inputs["attention_mask"] = torch.ones_like(outputs) + + # Time the generation of the final portion + remaining_tokens = target_length - warmup_length + start = time.time() + final_outputs = model.generate( + **warmup_inputs, + do_sample=False, + min_new_tokens=remaining_tokens, + max_new_tokens=remaining_tokens, + past_key_values=past_key_values + ) + end = time.time() + + # if pre_fill_cache: + # past_key_values.print_profile_summary() + + total_time += end - start + total_prompt_tokens += final_outputs[:,:input_length].shape[1] + # Only count the tokens generated in the timed portion + total_output_tokens += remaining_tokens + + completion = tokenizer.decode(final_outputs[0, input_length:], skip_special_tokens=True) + message.append({"role": "assistant", "content": completion}) + + throughput = total_output_tokens / total_time + memory = calculate_cache_memory(past_key_values) + + return { + "message": message, + "total_prompt_tokens": total_prompt_tokens, + "total_output_tokens": total_output_tokens, + "total_time": total_time, + "throughput": throughput, + "memory": memory + } + +# Test prompts +user_prompt = "I'm a beginner looking to exercise at home. I have dumbbells and a yoga mat. I can work out 3 times per week." # Run multiple times for better measurement ? 
+ +print("\nRunning without pre-filled cache:") +results_normal = run_generation(model, tokenizer, user_prompt, pre_fill_cache=False) +for key, value in results_normal.items(): + print(f"No-prefill {key}: {value}") + +cache_sizes = [] +bleus = [] +rouges = [] +messages = [] +throughputs = [] +times = [] +memories = [] +for max_cache_len in range(78, 468, 10): # 78 is default + print(f"\nRunning with pre-filled cache size {max_cache_len}") + results_prefill = run_generation(model, tokenizer, user_prompt, max_cache_len=max_cache_len, pre_fill_cache=True) + + # Print comparison + for key, value in results_prefill.items(): + print(f"Prefilled {key}: {value} ") + print(f"Speedup: {results_prefill['throughput']/results_normal['throughput']:.2f}x") + print(f"KV cache memory saved: {100*(results_normal['memory'] - results_prefill['memory'])/results_normal['memory']:.2f}%") + + bleu_score = sentence_bleu([results_normal['message'][-1]["content"].split()], results_prefill['message'][-1]["content"].split()) + print("BLEU Score:", bleu_score) + + scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) + rouge = scorer.score(results_normal['message'][-1]["content"], results_prefill['message'][-1]["content"])['rougeL'].fmeasure + + cache_sizes.append(max_cache_len) + bleus.append(bleu_score) + rouges.append(rouge) + messages.append([results_normal['message'][-1]["content"]]) + throughputs.append(results_prefill['throughput']) + times.append(results_prefill['total_time']) + memories.append(results_prefill['memory']) + + + + +plt.plot(cache_sizes, bleus, marker = 'o', label='BLEU', markersize=3) +plt.plot(cache_sizes, rouges, marker = 'o', label='ROUGE', markersize=3) +plt.title("Accuracy vs H20 Cache Size (total tokens = 468)") +plt.xlabel('H20 Cache Size') +plt.ylabel('Accuracy') +plt.legend() +plt.savefig('acc_vs_cache_size.png') \ No newline at end of file diff --git a/benchmark/h20/benchmark_varying_output.py b/benchmark/h20/benchmark_varying_output.py new file mode 100644 index 00000000000000..99e66fdab498e7 --- /dev/null +++ b/benchmark/h20/benchmark_varying_output.py @@ -0,0 +1,172 @@ +from transformers import LlamaTokenizer, LlamaForCausalLM, BitsAndBytesConfig +from transformers.cache_utils import H2OCache, DynamicCache +from tqdm import tqdm +from nltk.translate.bleu_score import sentence_bleu +from rouge_score import rouge_scorer +import matplotlib.pyplot as plt +import torch +import time +import copy + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) + +model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", + quantization_config=quantization_config, + device_map="auto", + attn_implementation="eager" +) +tokenizer = LlamaTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") +device = model.device + +def calculate_cache_memory(cache): + total_memory = 0 + for key_token, value_token in zip(cache.key_cache, cache.value_cache): + total_memory += key_token.element_size() * key_token.numel() + total_memory += value_token.element_size() * value_token.numel() + return total_memory + + +def run_generation(model, tokenizer, user_prompt, target_length, max_cache_len=None, pre_fill_cache=False): + message = [{"role": "system", "content": "You are a personal fitness coach who creates customized workout plans. 
Keep advice practical and motivating."}] + total_time = 0 + total_prompt_tokens = 0 + total_output_tokens = 0 + + print("Generating...") + # Run generation + message.append({"role": "user", "content": user_prompt}) + inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device) + input_length = inputs["input_ids"].shape[1] + target_length = target_length # Total tokens we want to generate + warmup_length = 4 * input_length # Length at which we start timing (that way the H2OCache size is 20% of the sequence length) + + if pre_fill_cache: + assert max_cache_len + past_key_values = H2OCache(max_cache_len=max_cache_len) + else: + past_key_values = DynamicCache() + + # First generate up to warmup length without timing + outputs = model.generate( + **inputs, + do_sample=False, + min_new_tokens=warmup_length, + max_new_tokens=warmup_length, + past_key_values=past_key_values + ) + + # Now time the generation of the remaining tokens + warmup_inputs = copy.deepcopy(inputs) + warmup_inputs["input_ids"] = outputs + warmup_inputs["attention_mask"] = torch.ones_like(outputs) + + # Time the generation of the final portion + remaining_tokens = target_length - warmup_length + start = time.time() + final_outputs = model.generate( + **warmup_inputs, + do_sample=False, + min_new_tokens=remaining_tokens, + max_new_tokens=remaining_tokens, + past_key_values=past_key_values + ) + end = time.time() + + # if pre_fill_cache: + # past_key_values.print_profile_summary() + + total_time += end - start + total_prompt_tokens += final_outputs[:,:input_length].shape[1] + # Only count the tokens generated in the timed portion + total_output_tokens += remaining_tokens + + completion = tokenizer.decode(final_outputs[0, input_length:], skip_special_tokens=True) + message.append({"role": "assistant", "content": completion}) + + throughput = total_output_tokens / total_time + memory = calculate_cache_memory(past_key_values) + + torch.cuda.empty_cache() + + return { + "message": message, + "total_prompt_tokens": total_prompt_tokens, + "total_output_tokens": total_output_tokens, + "total_time": total_time, + "throughput": throughput, + "memory": memory + } + +# Test prompts +user_prompt = "I'm a beginner looking to exercise at home. I have dumbbells and a yoga mat. I can work out 3 times per week. Tell me everything I need to know, you have 2000 words, feel free to ramble on and leave no detail out." # Run multiple times for better measurement ? 
+ +target_lengths = [] +bleus = [] +rouges = [] +messages = [] +throughputs = [] +times = [] +speedups = [] +for target_length in range(505, 3000, 214): # 101 is default + print(f"\nRunning with pre-filled, target length is {target_length}") + results_prefill = run_generation(model, tokenizer, user_prompt, target_length=target_length, max_cache_len=214, pre_fill_cache=True) + + print("\nRunning without pre-filled cache:") + results_normal = run_generation(model, tokenizer, user_prompt, target_length=target_length, pre_fill_cache=False) + + # Print comparison + for key, value in results_prefill.items(): + print(f"Prefilled {key}: {value} ") + for key, value in results_normal.items(): + print(f"No-prefill {key}: {value}") + print(f"Speedup: {results_prefill['throughput']/results_normal['throughput']:.2f}x") + print(f"KV cache memory saved: {100*(results_normal['memory'] - results_prefill['memory'])/results_normal['memory']:.2f}%") + + bleu_score = sentence_bleu([results_normal['message'][-1]["content"].split()], results_prefill['message'][-1]["content"].split()) + print("BLEU Score:", bleu_score) + + scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) + rouge = scorer.score(results_normal['message'][-1]["content"], results_prefill['message'][-1]["content"])['rougeL'].fmeasure + + + target_lengths.append(target_length) + bleus.append(bleu_score) + rouges.append(rouge) + messages.append([results_normal['message'][-1]["content"]]) + throughputs.append(results_prefill['throughput']) + times.append(results_prefill['total_time']) + speedups.append(results_prefill['throughput']/results_normal['throughput']) + + +# print(messages) +# print(target_lengths) +# print(bleus) +# print(throughputs) +# print(times) + + +plt.plot(target_lengths, bleus, marker = 'o', label='BLEU', markersize=3) +plt.plot(target_lengths, rouges, marker = 'o', label='ROUGE', markersize=3) +plt.title("Accuracy vs Output tokens generated (cache size = 214)") +plt.xlabel('Number of output tokens') +plt.ylabel('Accuracy') +plt.legend() +plt.savefig('acc_vs_output_length.png') + + +# plt.figure(1) +# plt.plot(target_lengths, bleus, marker = 'o', markersize=3) +# plt.title("BLEU Score vs Output tokens generated (cache size = 214)") +# plt.xlabel('Number of output tokens') +# plt.ylabel('BLEU score') +# plt.savefig('bleu_vs_output_length.png') +# plt.close() + + +# plt.figure(2) +# plt.plot(target_lengths, speedups, marker = 'o', markersize=3) +# plt.title("Speedup vs Output tokens generated (cache size = 214)") +# plt.xlabel('Number of output tokens') +# plt.ylabel('Speedup') +# plt.savefig('speedup_vs_output_length.png') +# plt.close() \ No newline at end of file diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index f38fc8f9824d3b..8c6b6fb907cedd 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -58,6 +58,28 @@ def update( """ raise NotImplementedError("Make sure to implement `update` in a subclass.") + def post_process( + self, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with any post processing logic for the specified layer_idx + + Parameters: + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache post processing logic to be implemented. 
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        # Default implementation - return current states without modification
+        if len(self.key_cache) <= layer_idx:
+            return None, None
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+        # raise NotImplementedError("Make sure to implement 'post_process' in a subclass")
+
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         # TODO: deprecate this function in favor of `cache_position`
@@ -1381,6 +1403,179 @@ def reset(self):
                 self.value_cache[layer_idx].zero_()
 
+
+class H2OCache(Cache):
+    """
+    A cache that implements the H2O (Heavy-Hitter Oracle) eviction strategy. It keeps a balance of the most recent
+    tokens and heavy-hitter tokens, i.e. the tokens that have received the highest cumulative attention scores.
+    The accumulated attention scores are updated per layer during generation through the `post_process` method,
+    which requires the attention weights to be available (e.g. `attn_implementation="eager"`).
+
+    Parameters:
+        max_cache_len (`int`):
+            The maximum sequence length of the cache. This is split between recent and heavy-hitter sections.
+        heavy_ratio (`float`, *optional*, defaults to 0.5):
+            The ratio of the cache dedicated to heavy hitters vs recent tokens. A ratio of 0.5 means half the cache is
+            used for heavy hitters and half for recent tokens.
+        device (`str` or `torch.device`, *optional*):
+            Device on which to allocate the cache indices.
+
+    Example:
+        ```python
+        >>> from transformers import LlamaTokenizer, LlamaForCausalLM
+        >>> from transformers.cache_utils import H2OCache
+
+        >>> model = LlamaForCausalLM.from_pretrained(
+                "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+                device_map="auto",
+                attn_implementation="eager")
+        >>> tokenizer = LlamaTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+
+        >>> messages = [{"role": "user", "content": "Tell me a joke"}]
+        >>> inputs = tokenizer.apply_chat_template(
+                messages,
+                add_generation_prompt=True,
+                return_tensors="pt",
+                return_dict=True).to(model.device)
+        >>> input_length = inputs["input_ids"].shape[1]
+
+        >>> # Prepare a cache class and pass it to `generate`
+        >>> past_key_values = H2OCache(max_cache_len=32, heavy_ratio=0.5)
+        >>> outputs = model.generate(**inputs, past_key_values=past_key_values, min_new_tokens=50, max_new_tokens=50)
+        >>> completion = tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True)
+        >>> messages.append({"role": "assistant", "content": completion})
+        >>> print(messages)
+        ```
+    """
+
+    def __init__(
+        self, max_cache_len: int, heavy_ratio: float = 0.5, device: Optional[Union[torch.device, str]] = None
+    ):
+        super().__init__()
+        self.max_cache_len = max_cache_len
+        self.heavy_ratio = heavy_ratio
+        self.heavy_size = int(max_cache_len * heavy_ratio)
+        self.recent_size = max_cache_len - self.heavy_size
+
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        self._accum_attn_scores: List[torch.Tensor] = []  # layer-wise accumulated attention scores
+
+        # Pre-compute and store indices for heavy and recent sections
+        self.recent_start = self.max_cache_len - self.recent_size
+        self.recent_indices = torch.arange(self.recent_start + 1, self.max_cache_len, device=device)
+        self.heavy_indices_template = torch.arange(self.recent_start + 1, device=device)
+
+    def update(
+        self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Manages both
+        the heavy-hitter and recent sections of the cache according to the H2O strategy.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, *optional*):
+                Additional arguments for the cache subclass. Unused by `H2OCache.update`.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+
+        # Initialize layer caches if needed
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+            # Scores start at zero; `post_process` truncates this tensor to the actual sequence length
+            self._accum_attn_scores.append(torch.zeros(self.max_cache_len, device=key_states.device))
+            return key_states, value_states
+
+        if self.key_cache[layer_idx].shape[-2] < self.max_cache_len:
+            # Cache not full yet - concatenate directly
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+            # Extend the accumulated scores tensor with zero-initialized entries for the new token(s)
+            self._accum_attn_scores[layer_idx] = torch.cat(
+                [self._accum_attn_scores[layer_idx], torch.zeros(key_states.shape[-2], device=key_states.device)]
+            )
+        else:
+            # Get heavy hitters only from the non-recent section
+            # + 1 to consider the token at position "recent_start" that
+            # we will remove from the recent_indices to make space for the new token
+            scores_slice = self._accum_attn_scores[layer_idx].narrow(0, 0, self.recent_start + 1)
+
+            heavy_scores, heavy_indices = torch.topk(scores_slice, k=self.heavy_size, sorted=False)
+
+            # Get recent indices (excludes the token at index self.recent_start to make room for the current token)
+            combined_indices = torch.cat([heavy_indices, self.recent_indices])
+
+            # Update cache with kept indices and new states
+            self.key_cache[layer_idx] = torch.cat(
+                [torch.index_select(self.key_cache[layer_idx], -2, combined_indices), key_states], dim=-2
+            )
+            self.value_cache[layer_idx] = torch.cat(
+                [torch.index_select(self.value_cache[layer_idx], -2, combined_indices), value_states], dim=-2
+            )
+
+            # Update the accumulated scores tensor with zero-initialized entries for the token(s) we just added
+            self._accum_attn_scores[layer_idx] = torch.cat(
+                [
+                    torch.index_select(self._accum_attn_scores[layer_idx], 0, combined_indices),
+                    torch.zeros(key_states.shape[-2], device=key_states.device),
+                ]
+            )
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def post_process(
+        self, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the accumulated attention scores after each attention computation. This is called in each attention
+        layer's forward pass, after the current token's KV pair has been added through `update`, to maintain running
+        statistics of token importance.
+
+        Parameters:
+            layer_idx (`int`):
+                The index of the layer to process.
+            cache_kwargs (`Dict[str, Any]`, *optional*):
+                Additional arguments, must include 'attn_weights' containing the attention weights from the current layer.
+
+        Return:
+            A tuple containing the layer's key and value cached states.
+ """ + if cache_kwargs is None or "attn_weights" not in cache_kwargs: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + attn_weights = cache_kwargs["attn_weights"] + + # Update accumulated attention scores + # Sum attention weights per token accross attention heads, batch, and query length + scores = attn_weights.sum(dim=(0, 1, 2)) # [kv_length] + + if self._accum_attn_scores[layer_idx].size(0) != scores.size(0): + self._accum_attn_scores[layer_idx] = self._accum_attn_scores[layer_idx][: scores.size(0)] + + self._accum_attn_scores[layer_idx] += scores + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of cached states.""" + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_cache_shape(self) -> Optional[int]: + """Returns maximum sequence length the cache can hold.""" + return self.max_cache_len + + class EncoderDecoderCache(Cache): """ Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 7b8b9547ac1c33..4e98f585527c33 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -673,6 +673,7 @@ class CoherePreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_h2o_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 95ad0d9719951d..b5886f1bf89eb1 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -406,6 +406,7 @@ class GlmPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_h2o_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 2e045e149d95de..20e6c8816e7016 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -405,6 +405,7 @@ class GranitePreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_h2o_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5be33c26414cd7..0d1d9212a36079 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -303,6 +303,10 @@ def forward( **kwargs, ) + if past_key_value is not None: + cache_kwargs = {"attn_weights": attn_weights} + key_states, value_states = past_key_value.post_process(self.layer_idx, cache_kwargs) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 11d3d99f4f72c9..ffc3fd2cd8c014 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ 
b/src/transformers/models/olmo/modeling_olmo.py @@ -370,6 +370,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_h2o_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 49ae798e7f1101..36ab21aa2402b6 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -371,6 +371,7 @@ class Olmo2PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_h2o_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index fa3c2f3cd4d11b..d0227843481cd8 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -772,6 +772,7 @@ class OlmoePreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_h2o_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/tests/h2O/__init__.py b/tests/h2O/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/h2O/test_h2O.py b/tests/h2O/test_h2O.py new file mode 100644 index 00000000000000..8ceedf04a89420 --- /dev/null +++ b/tests/h2O/test_h2O.py @@ -0,0 +1,41 @@ +import unittest + +from transformers import BitsAndBytesConfig, LlamaForCausalLM, LlamaTokenizer +from transformers.cache_utils import H2OCache +from transformers.testing_utils import require_torch, torch_device + + +@require_torch +class TestH2OCache(unittest.TestCase): + @classmethod + def setUpClass(cls): + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + cls.model = LlamaForCausalLM.from_pretrained( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + quantization_config=quantization_config, + device_map="auto", + attn_implementation="eager", + ) + cls.tokenizer = LlamaTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + + @unittest.skipIf(torch_device == "cpu", "Requires CUDA") + def test_h2o_cache_response(self): + past_key_values = H2OCache(max_cache_len=50, device=torch_device) + + messages = [ + {"role": "system", "content": "You are a friendly chatbot."}, + {"role": "user", "content": "Tell me a joke."}, + ] + + prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) + + outputs = self.model.generate(**inputs, do_sample=False, max_new_tokens=50, past_key_values=past_key_values) + response = self.tokenizer.decode(outputs[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True) + + self.assertIsInstance(response, str, "Response should be a string.") + self.assertGreater(len(response.strip()), 0, "Response should not be empty.") + + +if __name__ == "__main__": + unittest.main()
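
For reviewers who want to try the patch end to end, here is a minimal usage sketch. It mirrors the benchmark scripts and the `H2OCache` docstring example above (TinyLlama checkpoint, eager attention so the attention weights reach `post_process`, and a fixed `max_cache_len`); the model id, cache size, and generation length are illustrative choices, not part of the patch.

```python
# Minimal H2O cache usage sketch (assumes this patch is applied).
# Model id, cache size, and generation length are illustrative.
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.cache_utils import H2OCache

model = LlamaForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    device_map="auto",
    attn_implementation="eager",  # H2OCache needs attention weights for its heavy-hitter scores
)
tokenizer = LlamaTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

messages = [{"role": "user", "content": "Tell me a joke."}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
).to(model.device)

# Half of the 64 cached positions hold heavy hitters, half hold the most recent tokens
past_key_values = H2OCache(max_cache_len=64, heavy_ratio=0.5)

outputs = model.generate(
    **inputs, do_sample=False, max_new_tokens=100, past_key_values=past_key_values
)
print(tokenizer.decode(outputs[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```

Once generation exceeds `max_cache_len`, each `update` call evicts the lowest-scoring non-recent entry, so KV cache memory stays bounded at roughly `max_cache_len` tokens per layer regardless of output length.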