Commit 9c08cec: nits
gante committed May 22, 2024
1 parent ae17ea4 commit 9c08cec
Showing 2 changed files with 14 additions and 8 deletions.
14 changes: 9 additions & 5 deletions src/transformers/generation/utils.py
@@ -1325,10 +1325,12 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
model_kwargs["cache_position"] = cache_position + past_length - 1
return model_kwargs

-def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache:
+def _get_cache(
+    self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: Union[torch.device, str]
+) -> Cache:
"""
Sets a cache for `generate` that will persist across calls. A new cache will only be initialized if a
-new `generate` call requires a larger cache.
+new `generate` call requires a larger cache or uses a different batch size.
Returns the resulting cache object.
"""
@@ -1351,12 +1353,14 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
-cache_dtype = self.dtype
+# `self.dtype` (which calls `ModuleUtilsMixin.dtype()`) is not compileable, so we fall back to
+# `self.config.torch_dtype`. Compiling `generate` after calling `model.to(some_dtype)` will fail
+cache_dtype = self.dtype if not is_torchdynamo_compiling() else self.config.torch_dtype
self._cache = cache_cls(
config=self.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
-device=self.device,
+device=device,
dtype=cache_dtype,
)
else:
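
The new comment explains the dtype fallback: `self.dtype` inspects the module's parameters and is not traceable when `generate` is compiled, so the compiled path reads `self.config.torch_dtype` instead. The caveat it mentions comes from `config.torch_dtype` being recorded at load time and not updated by `model.to(...)`; a short sketch of the mismatch (the checkpoint name is only a placeholder):

    import torch
    from transformers import AutoModelForCausalLM

    # Placeholder checkpoint; any causal LM shows the same behavior.
    model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float32)
    model.to(torch.float16)

    print(model.config.torch_dtype)  # torch.float32, the load-time value
    print(model.dtype)               # torch.float16, the actual parameter dtype
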
@@ -1631,7 +1635,7 @@ def generate(
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_cache(
-generation_config.cache_implementation, batch_size, generation_config.max_length
+generation_config.cache_implementation, batch_size, generation_config.max_length, device=device
)
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
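
For reference, the call site above is reached from the user side by requesting a static cache through the generation config. A hedged usage sketch, with a placeholder checkpoint name standing in for any model that supports the static cache:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "some-static-cache-capable-model"  # placeholder name
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)

    inputs = tokenizer("Static caches persist across calls:", return_tensors="pt")
    # `cache_implementation="static"` routes generation through `_get_cache`,
    # which after this commit receives the target device explicitly.
    out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
    print(tokenizer.decode(out[0], skip_special_tokens=True))
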

@@ -717,9 +717,11 @@ def bigbird_block_sparse_attention(
attention_probs[:, :, -2 * from_block_size : -from_block_size, :to_block_size] = second_last_attn_weights[
:, :, :, :to_block_size
] # 1st key block (global)
-attention_probs[:, :, -2 * from_block_size : -from_block_size, -3 * to_block_size :] = (
-    second_last_attn_weights[:, :, :, to_block_size : 4 * to_block_size]
-)  # last three blocks (global + sliding)
+attention_probs[
+    :, :, -2 * from_block_size : -from_block_size, -3 * to_block_size :
+] = second_last_attn_weights[
+    :, :, :, to_block_size : 4 * to_block_size
+]  # last three blocks (global + sliding)
# random keys
for p1, i1, w1 in zip(range(bsz), rand_attn, second_last_attn_weights):
# p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch
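
The reformatting above is a pure line-wrapping nit; the old and new forms perform the same slice assignment. A toy sketch with invented block sizes (variable names follow the diff, dimensions are made up for illustration):

    import torch

    bsz, heads, from_block_size, to_block_size = 1, 2, 3, 3
    attention_probs = torch.zeros(bsz, heads, 5 * from_block_size, 8 * to_block_size)
    second_last_attn_weights = torch.rand(bsz, heads, from_block_size, 5 * to_block_size)

    # Second-to-last query block attends to the first key block (global)...
    attention_probs[:, :, -2 * from_block_size : -from_block_size, :to_block_size] = (
        second_last_attn_weights[:, :, :, :to_block_size]
    )
    # ...and to the last three key blocks (global + sliding), as in the hunk above.
    attention_probs[:, :, -2 * from_block_size : -from_block_size, -3 * to_block_size :] = (
        second_last_attn_weights[:, :, :, to_block_size : 4 * to_block_size]
    )
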
