Commit 07f5cdc: more nits

ArthurZucker committed Jan 9, 2024
1 parent e05f8da commit 07f5cdc
Showing 1 changed file with 9 additions and 12 deletions.
src/transformers/cache_utils.py (21 changes: 9 additions & 12 deletions)
@@ -374,31 +374,28 @@ def update(
         Return:
             A tuple containing the updated key and value states.
         """
-
         attention_mask = cache_kwargs.get("attention_mask")
-
         # make sure the parts that are not seen are masked as well
-
         # place each cache on the correct layer device, not optimised?
         if self.seen_tokens == 0:
             self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
             self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
             self.causal_4d_mask = self.causal_4d_mask.to(value_states.device)

         if attention_mask is not None:
             # if the past length changes then we do have a problem
             _, _, query_length, past_length = attention_mask.shape
-            self.causal_4d_mask[:, :, self.seen_tokens : self.seen_tokens + key_states.shape[-2], :past_length] = attention_mask
-            attention_mask = self.causal_4d_mask[:, :, self.seen_tokens : self.seen_tokens + key_states.shape[-2], :]
+            self.causal_4d_mask[:, :, self.seen_tokens : self.seen_tokens + query_length, :past_length] = attention_mask
+            attention_mask = self.causal_4d_mask[:, :, self.seen_tokens : self.seen_tokens + query_length, :]

-        # Update the cache
-        if len(self.key_cache) + 1 == self.max_sequence_length:
-            # let's overwrite and roll the cache to support going beyond?
-            raise ValueError("You are going outside the allocated cache")
-        else:
-            self.key_cache[layer_idx][:, :, self.seen_tokens : self.seen_tokens + key_states.shape[-2]] = key_states
-            self.value_cache[layer_idx][:, :, self.seen_tokens : self.seen_tokens + key_states.shape[-2]] = value_states
+        # Update the cache
+        self.key_cache[layer_idx][:, :, self.seen_tokens : self.seen_tokens + key_states.shape[-2]] = key_states
+        self.value_cache[layer_idx][:, :, self.seen_tokens : self.seen_tokens + key_states.shape[-2]] = value_states

         # Update the number of seen tokens
         if layer_idx == self.num_layers - 1:
+            # Update the cache
+            if self.seen_tokens + key_states.shape[-2] > self.max_sequence_length:
+                raise ValueError("You are going outside the allocated cache")
             self.seen_tokens += key_states.shape[-2]

         return self.key_cache[layer_idx], self.value_cache[layer_idx], attention_mask
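The mask half of the hunk is a preallocation pattern: a full-size causal 4D mask lives on the cache object, the incoming attention_mask is copied into the rows for the current query window (now indexed with query_length instead of key_states.shape[-2]), and those rows are sliced back out at the full allocated key length. A minimal sketch of that slicing, assuming an additive-mask convention and illustrative shapes; none of the names or shapes here are the transformers API:

import torch

max_sequence_length = 16
min_value = torch.finfo(torch.float32).min

# Preallocated additive causal mask: 0 where attention is allowed,
# a large negative value strictly above the diagonal (future positions).
causal_4d_mask = torch.full((1, 1, max_sequence_length, max_sequence_length), min_value).triu(diagonal=1)

seen_tokens = 5                           # tokens already written to the cache
attention_mask = torch.zeros(1, 1, 1, 6)  # one new query token seeing 6 positions

_, _, query_length, past_length = attention_mask.shape
# Copy the incoming mask into the rows covering the current query window...
causal_4d_mask[:, :, seen_tokens : seen_tokens + query_length, :past_length] = attention_mask
# ...then slice those rows back out, padded to the full allocated key length.
attention_mask = causal_4d_mask[:, :, seen_tokens : seen_tokens + query_length, :]
print(attention_mask.shape)  # torch.Size([1, 1, 1, 16])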

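The cache half drops the old guard (len(self.key_cache) + 1 == self.max_sequence_length, which compared the layer count against the sequence length) in favour of unconditional in-place slice writes, with the bounds check moved next to the seen_tokens update on the last layer. A toy version of this preallocated-cache pattern; MiniStaticCache and its shapes are illustrative assumptions, not the transformers implementation:

import torch

class MiniStaticCache:
    """Toy preallocated KV cache mirroring the commit's update pattern."""

    def __init__(self, num_layers, batch, num_heads, max_sequence_length, head_dim):
        self.num_layers = num_layers
        self.max_sequence_length = max_sequence_length
        self.seen_tokens = 0
        shape = (batch, num_heads, max_sequence_length, head_dim)
        # Buffers are allocated once at full length; update() writes slices in place.
        self.key_cache = [torch.zeros(shape) for _ in range(num_layers)]
        self.value_cache = [torch.zeros(shape) for _ in range(num_layers)]

    def update(self, key_states, value_states, layer_idx):
        new_tokens = key_states.shape[-2]
        # Same bound as the commit, checked here before writing rather than
        # only on the last layer.
        if self.seen_tokens + new_tokens > self.max_sequence_length:
            raise ValueError("You are going outside the allocated cache")
        self.key_cache[layer_idx][:, :, self.seen_tokens : self.seen_tokens + new_tokens] = key_states
        self.value_cache[layer_idx][:, :, self.seen_tokens : self.seen_tokens + new_tokens] = value_states
        # Advance the counter once per forward pass, after the last layer:
        # update() runs once per layer, so an unconditional increment would
        # advance the offset num_layers times per step.
        if layer_idx == self.num_layers - 1:
            self.seen_tokens += new_tokens
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

# Prefill a 5-token prompt across both layers; the offset then advances once.
cache = MiniStaticCache(num_layers=2, batch=1, num_heads=4, max_sequence_length=16, head_dim=8)
kv = torch.randn(1, 4, 5, 8)
for layer in range(2):
    cache.update(kv, kv, layer)
assert cache.seen_tokens == 5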