diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index b5ef6669d8c5d1..fe1aa2aadaf2c5 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -374,31 +374,28 @@ def update(
         Return:
             A tuple containing the updated key and value states.
         """
-        attention_mask = cache_kwargs.get("attention_mask")
-        # make sure the parts that are not seen are masked as well
-
+        # place each cache on the correct layer device, not optimised?
         if self.seen_tokens == 0:
             self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
             self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
             self.causal_4d_mask = self.causal_4d_mask.to(value_states.device)

         if attention_mask is not None:
+            # if the past length changes then we do have a problem
            _, _, query_length, past_length = attention_mask.shape
-            self.causal_4d_mask[:,:,self.seen_tokens:self.seen_tokens + key_states.shape[-2],:past_length] = attention_mask
-            attention_mask = self.causal_4d_mask[:,:, self.seen_tokens:self.seen_tokens + key_states.shape[-2],:]
+            self.causal_4d_mask[:,:,self.seen_tokens:self.seen_tokens + query_length,:past_length] = attention_mask
+            attention_mask = self.causal_4d_mask[:,:, self.seen_tokens:self.seen_tokens + query_length,:]

-        # Update the cache
-        if len(self.key_cache) + 1 == self.max_sequence_length:
-            # let's overwrite and roll the cache to support going beyond?
-            raise ValueError("Your are going outside the allocated cache")
-        else:
-            self.key_cache[layer_idx][:, :, self.seen_tokens: self.seen_tokens + key_states.shape[-2]] = key_states
-            self.value_cache[layer_idx][:, :, self.seen_tokens: self.seen_tokens + key_states.shape[-2]] = value_states
+        self.key_cache[layer_idx][:, :, self.seen_tokens: self.seen_tokens + key_states.shape[-2]] = key_states
+        self.value_cache[layer_idx][:, :, self.seen_tokens: self.seen_tokens + key_states.shape[-2]] = value_states

         # Update the number of seen tokens
         if layer_idx == self.num_layers - 1:
+            # Update the cache
+            if self.seen_tokens + key_states.shape[-2] > self.max_sequence_length:
+                raise ValueError("You are going outside the allocated cache")
             self.seen_tokens += key_states.shape[-2]

         return self.key_cache[layer_idx], self.value_cache[layer_idx], attention_mask
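
For reference, here is a minimal standalone sketch of the preallocated-cache pattern the hunk implements: new key/value states are written in place at the `seen_tokens` offset, the counter only advances after the last layer, and writes beyond `max_sequence_length` raise. The class name, constructor arguments, and tensor shapes below are illustrative assumptions, not the PR's API, and the `causal_4d_mask` bookkeeping from the hunk is omitted for brevity.

```python
import torch


class PreallocatedCache:
    """Illustrative sketch only: one fixed-size key/value buffer per layer, filled in place."""

    def __init__(self, num_layers, batch_size, num_heads, max_sequence_length, head_dim):
        self.num_layers = num_layers
        self.max_sequence_length = max_sequence_length
        self.seen_tokens = 0
        self.key_cache = [
            torch.zeros(batch_size, num_heads, max_sequence_length, head_dim) for _ in range(num_layers)
        ]
        self.value_cache = [
            torch.zeros(batch_size, num_heads, max_sequence_length, head_dim) for _ in range(num_layers)
        ]

    def update(self, key_states, value_states, layer_idx):
        new_tokens = key_states.shape[-2]
        # Refuse to write past the allocated length (the PR performs this check
        # once per forward pass, on the last layer, before bumping the counter).
        if self.seen_tokens + new_tokens > self.max_sequence_length:
            raise ValueError("You are going outside the allocated cache")
        # Write the new states into the preallocated slots starting at seen_tokens.
        self.key_cache[layer_idx][:, :, self.seen_tokens : self.seen_tokens + new_tokens] = key_states
        self.value_cache[layer_idx][:, :, self.seen_tokens : self.seen_tokens + new_tokens] = value_states
        # Only advance the counter after the last layer of the forward pass.
        if layer_idx == self.num_layers - 1:
            self.seen_tokens += new_tokens
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


# Usage: two layers, a prefill of 4 tokens followed by one decoded token.
cache = PreallocatedCache(num_layers=2, batch_size=1, num_heads=8, max_sequence_length=16, head_dim=64)
for step_len in (4, 1):
    k = torch.randn(1, 8, step_len, 64)
    v = torch.randn(1, 8, step_len, 64)
    for layer_idx in range(2):
        cache.update(k, v, layer_idx)
print(cache.seen_tokens)  # 5
```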