diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index d13777b337789b..dcc6403a294135 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1325,10 +1325,12 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
         model_kwargs["cache_position"] = cache_position + past_length - 1
         return model_kwargs
 
-    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache:
+    def _get_cache(
+        self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: Union[torch.device, str]
+    ) -> Cache:
         """
         Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
-        new `generate` call requires a larger cache.
+        new `generate` call requires a larger cache or uses a different batch size.
 
         Returns the resulting cache object.
         """
@@ -1351,12 +1353,14 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
             if hasattr(self.config, "_pre_quantization_dtype"):
                 cache_dtype = self.config._pre_quantization_dtype
             else:
-                cache_dtype = self.dtype
+                # `self.dtype` (which calls `ModuleUtilsMixin.dtype()`) is not compileable, so we fall back to
+                # `self.config.torch_dtype`. Compiling `generate` after calling `model.to(some_dtype)` will fail
+                cache_dtype = self.dtype if not is_torchdynamo_compiling() else self.config.torch_dtype
             self._cache = cache_cls(
                 config=self.config,
                 max_batch_size=max_batch_size,
                 max_cache_len=max_cache_len,
-                device=self.device,
+                device=device,
                 dtype=cache_dtype,
             )
         else:
@@ -1631,7 +1635,7 @@ def generate(
                         "issue: https://github.com/huggingface/transformers/issues/28981"
                     )
                 model_kwargs["past_key_values"] = self._get_cache(
-                    generation_config.cache_implementation, batch_size, generation_config.max_length
+                    generation_config.cache_implementation, batch_size, generation_config.max_length, device=device
                 )
 
         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index d1ba54213a0346..883b598415f0d6 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -717,9 +717,11 @@ def bigbird_block_sparse_attention(
             attention_probs[:, :, -2 * from_block_size : -from_block_size, :to_block_size] = second_last_attn_weights[
                 :, :, :, :to_block_size
             ]  # 1st key block (global)
-            attention_probs[:, :, -2 * from_block_size : -from_block_size, -3 * to_block_size :] = (
-                second_last_attn_weights[:, :, :, to_block_size : 4 * to_block_size]
-            )  # last three blocks (global + sliding)
+            attention_probs[
+                :, :, -2 * from_block_size : -from_block_size, -3 * to_block_size :
+            ] = second_last_attn_weights[
+                :, :, :, to_block_size : 4 * to_block_size
+            ]  # last three blocks (global + sliding)
             # random keys
             for p1, i1, w1 in zip(range(bsz), rand_attn, second_last_attn_weights):
                 # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch
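
Note (not part of the patch): a minimal sketch of the dtype selection the `generation/utils.py` hunks introduce, restated as a standalone helper for illustration only; the helper name `pick_cache_dtype` is hypothetical and does not exist in the library.

import torch
from transformers.utils import is_torchdynamo_compiling


def pick_cache_dtype(model) -> torch.dtype:
    # Quantized checkpoints keep their pre-quantization dtype on the config.
    if hasattr(model.config, "_pre_quantization_dtype"):
        return model.config._pre_quantization_dtype
    # `model.dtype` (ModuleUtilsMixin.dtype) inspects module parameters, which torchdynamo
    # cannot trace; under compilation, read the dtype recorded on the config instead.
    if is_torchdynamo_compiling():
        return model.config.torch_dtype
    return model.dtype

As the inline comment in the patch notes, the trade-off is that `config.torch_dtype` can be stale if the model was cast with `model.to(some_dtype)` after loading, so compiling `generate` in that situation is expected to fail. The `device` argument threaded through `_get_cache` serves a similar purpose: the static cache is allocated on the device passed in from `generate` rather than on `self.device`.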