
Commit

nits
ArthurZucker committed Jan 10, 2024
1 parent 07f5cdc commit f769b0e
Showing 1 changed file with 31 additions and 26 deletions.
57 changes: 31 additions & 26 deletions src/transformers/cache_utils.py
@@ -324,8 +324,8 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))


class StaticCache(Cache):
import torch.nn as nn
class StaticCache(Cache, nn.Module):

def __init__(self, config: PretrainedConfig, max_batch_size) -> None:
super().__init__()
@@ -334,13 +334,13 @@ def __init__(self, config: PretrainedConfig, max_batch_size) -> None:
self.max_sequence_length = config.max_position_embeddings
self.head_dim = config.hidden_size // config.num_attention_heads
self.num_heads = config.num_attention_heads
self.shape = (max_batch_size, self.max_sequence_length, config.hidden_size // self.num_heads, self.num_heads)
self.dtype = config.torch_dtype if config.torch_dtype is not None else torch.float16

# TODO device meta?
cache_shape = (max_batch_size, self.num_heads, self.max_sequence_length, self.head_dim)

self.key_cache: List[torch.Tensor] = [torch.zeros(max_batch_size, self.num_heads, self.max_sequence_length, self.head_dim, dtype=self.dtype) for _ in range(self.num_layers)]
self.value_cache: List[torch.Tensor] = [torch.zeros(max_batch_size, self.num_heads, self.max_sequence_length, self.head_dim, dtype=self.dtype) for _ in range(self.num_layers)]
self.key_cache: List[torch.Tensor] = [torch.zeros(cache_shape, dtype=self.dtype, device="cuda") for _ in range(self.num_layers)]
self.value_cache: List[torch.Tensor] = [torch.zeros(cache_shape, dtype=self.dtype, device="cuda") for _ in range(self.num_layers)]

# FIXME our format should be the following, but most of our NLP models apply a transpose operation on the k and v, so we can't
# self.key_cache: List[torch.Tensor] = [torch.zeros(max_batch_size, max_sequence_length, num_heads, self.head_dim, dtype=dtype) for _ in range(num_layers)]
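Note: the buffers above are allocated with a hard-coded device="cuda", and the "# TODO device meta?" question is still open. Since the class now inherits from nn.Module, one option would be to register the preallocated tensors as buffers so that a single .to(device) call moves them. The sketch below only illustrates that idea under assumed constructor arguments; it is not code from this commit.

import torch
import torch.nn as nn

class StaticKVSketch(nn.Module):
    """Sketch: per-layer preallocated KV buffers registered on an nn.Module."""

    def __init__(self, num_layers, max_batch_size, num_heads, max_sequence_length, head_dim, dtype=torch.float16):
        super().__init__()
        cache_shape = (max_batch_size, num_heads, max_sequence_length, head_dim)
        for layer_idx in range(num_layers):
            # register_buffer makes the tensors follow .to(...) and appear in state_dict,
            # so no device has to be chosen at construction time.
            self.register_buffer(f"key_cache_{layer_idx}", torch.zeros(cache_shape, dtype=dtype))
            self.register_buffer(f"value_cache_{layer_idx}", torch.zeros(cache_shape, dtype=dtype))

# Usage: allocate once on CPU, then move everything in one call if a GPU is present.
cache = StaticKVSketch(num_layers=2, max_batch_size=1, num_heads=8, max_sequence_length=128, head_dim=64)
if torch.cuda.is_available():
    cache = cache.to("cuda")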
@@ -377,28 +377,33 @@ def update(
attention_mask = cache_kwargs.get("attention_mask")

# place each cache on the correct layer device, not optimised?
if self.seen_tokens == 0:
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
self.causal_4d_mask = self.causal_4d_mask.to(value_states.device)
# if self.seen_tokens == 0:
# self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
# self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
# self.causal_4d_mask = self.causal_4d_mask.to(value_states.device)

if attention_mask is not None:
# if the past length changes then we do have a problem
_, _, query_length, past_length = attention_mask.shape
self.causal_4d_mask[:,:,self.seen_tokens:self.seen_tokens + query_length,:past_length] = attention_mask
attention_mask = self.causal_4d_mask[:,:, self.seen_tokens:self.seen_tokens + query_length,:]
# if attention_mask is not None:
# # if the past length changes then we do have a problem
# _, _, query_length, past_length = attention_mask.shape
# self.causal_4d_mask[:,:,self.seen_tokens:self.seen_tokens + query_length,:past_length] = attention_mask
# attention_mask = self.causal_4d_mask[:,:, self.seen_tokens:self.seen_tokens + query_length,:]

self.key_cache[layer_idx][:, :, self.seen_tokens: self.seen_tokens + key_states.shape[-2]] = key_states
self.value_cache[layer_idx][:, :, self.seen_tokens: self.seen_tokens + key_states.shape[-2]] = value_states

# Update the number of seen tokens
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
# write the new states at the next free positions along the sequence dimension
prev_pos = self.seen_tokens
pos = torch.arange(prev_pos, prev_pos + key_states.shape[-2], dtype=torch.long, device=key_states.device)
# k_out[:, :, pos] = key_states
# v_out[:, :, pos] = value_states
k_out.index_copy_(2, pos, key_states)
v_out.index_copy_(2, pos, value_states)
# Update the number of seen tokens
if layer_idx == self.num_layers - 1:
# Update the cache
if self.seen_tokens + key_states.shape[-2] > self.max_sequence_length:
raise ValueError("Your are going outside the allocated cache")
self.seen_tokens += key_states.shape[-2]

return self.key_cache[layer_idx], self.value_cache[layer_idx], attention_mask
return self.key_cache[layer_idx], self.value_cache[layer_idx] , attention_mask

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. A layer index can be optionally passed."""
@@ -408,10 +413,10 @@ def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return self.max_sequence_length

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
# def reorder_cache(self, beam_idx: torch.LongTensor):
# """Reorders the cache for beam search, given the selected beam indices."""
# for layer_idx in range(len(self.key_cache)):
# device = self.key_cache[layer_idx].device
# self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
# device = self.value_cache[layer_idx].device
# self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
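Taken together, the intended call pattern appears to be one update() per layer per forward pass, with the token counter advancing only once the last layer has been written (the layer_idx == num_layers - 1 branch above). A toy driver loop under those assumptions, with made-up shapes and random states standing in for real per-layer projections:

import torch

num_layers, batch, heads, max_len, head_dim = 2, 1, 4, 16, 8
key_cache = [torch.zeros(batch, heads, max_len, head_dim) for _ in range(num_layers)]
value_cache = [torch.zeros(batch, heads, max_len, head_dim) for _ in range(num_layers)]
seen_tokens = 0

for step_len in (5, 1, 1):  # a 5-token prefill followed by two single-token decode steps
    new_k = torch.randn(batch, heads, step_len, head_dim)
    new_v = torch.randn(batch, heads, step_len, head_dim)
    for layer_idx in range(num_layers):
        key_cache[layer_idx][:, :, seen_tokens:seen_tokens + step_len] = new_k
        value_cache[layer_idx][:, :, seen_tokens:seen_tokens + step_len] = new_v
        if layer_idx == num_layers - 1:  # the counter moves once per forward pass, not per layer
            seen_tokens += step_len

print(seen_tokens)  # 7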
