From 4aa72ca49f0123a4c5f52e3eec89bc323416470d Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Mon, 22 Apr 2024 14:42:57 +0200
Subject: [PATCH] `Llama` family, fix `use_cache=False` generation (#30380)

* nit to make sure cache positions are not sliced

* fix other models

* nit

* style
---
 src/transformers/models/cohere/modeling_cohere.py | 13 ++++++++++---
 src/transformers/models/gemma/modeling_gemma.py   | 13 ++++++++++---
 src/transformers/models/llama/modeling_llama.py   | 13 ++++++++++---
 src/transformers/models/olmo/modeling_olmo.py     | 13 ++++++++++---
 4 files changed, 40 insertions(+), 12 deletions(-)

diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 950d45ea867a30..41bb4c0516928c 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -1175,7 +1175,14 @@ def forward(
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        use_cache=True,
+        **kwargs,
     ):
         # With static cache, the `past_key_values` is None
         # TODO joao: standardize interface for the different Cache classes and remove of this if
@@ -1239,7 +1246,7 @@ def prepare_inputs_for_generation(
         input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
         if cache_position is None:
             cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
-        else:
+        elif use_cache:
             cache_position = cache_position[-input_length:]
 
         if has_static_cache:
@@ -1250,7 +1257,7 @@ def prepare_inputs_for_generation(
                 "position_ids": position_ids,
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
+                "use_cache": use_cache,
                 "attention_mask": attention_mask,
             }
         )
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 6077259d0b0fac..e5b6b207748a53 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -1157,7 +1157,14 @@ def forward(
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        use_cache=True,
+        **kwargs,
     ):
         # With static cache, the `past_key_values` is None
         # TODO joao: standardize interface for the different Cache classes and remove of this if
@@ -1221,7 +1228,7 @@ def prepare_inputs_for_generation(
         input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
         if cache_position is None:
             cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
-        else:
+        elif use_cache:
             cache_position = cache_position[-input_length:]
 
         if has_static_cache:
@@ -1232,7 +1239,7 @@ def prepare_inputs_for_generation(
                 "position_ids": position_ids,
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
+                "use_cache": use_cache,
                 "attention_mask": attention_mask,
             }
         )
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 2b8e8f6d0958dd..905edf5f71a63d 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1253,7 +1253,14 @@ def forward(
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        use_cache=True,
+        **kwargs,
     ):
         # With static cache, the `past_key_values` is None
         # TODO joao: standardize interface for the different Cache classes and remove of this if
@@ -1317,7 +1324,7 @@ def prepare_inputs_for_generation(
         input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
         if cache_position is None:
             cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
-        else:
+        elif use_cache:
             cache_position = cache_position[-input_length:]
 
         if has_static_cache:
@@ -1328,7 +1335,7 @@ def prepare_inputs_for_generation(
                 "position_ids": position_ids,
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
+                "use_cache": use_cache,
                 "attention_mask": attention_mask,
             }
         )
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 83637536a12531..e3b0e05127c52d 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -1234,7 +1234,14 @@ def forward(
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        use_cache=True,
+        **kwargs,
     ):
         # With static cache, the `past_key_values` is None
         # TODO joao: standardize interface for the different Cache classes and remove of this if
@@ -1298,7 +1305,7 @@ def prepare_inputs_for_generation(
         input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
             cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
-        else:
+        elif use_cache:
             cache_position = cache_position[-input_length:]
 
         if has_static_cache:
@@ -1309,7 +1316,7 @@ def prepare_inputs_for_generation(
                 "position_ids": position_ids,
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
+                "use_cache": use_cache,
                 "attention_mask": attention_mask,
             }
         )
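
Note (not part of the patch): the change makes `prepare_inputs_for_generation` accept `use_cache` explicitly and slice `cache_position` down to the newly processed tokens only when caching is enabled; with `use_cache=False` the whole sequence is re-processed at each decoding step, so the full position range must be kept. A minimal usage sketch of the call path this fixes is below; the checkpoint name is an illustrative assumption, and any Llama-family model (Llama, Cohere, Gemma, OLMo) exercises the same code.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "meta-llama/Llama-2-7b-hf"  # illustrative; any Llama-family checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

    inputs = tokenizer("Hello, my name is", return_tensors="pt")
    # Before this patch, use_cache=False generation still went through the
    # `cache_position = cache_position[-input_length:]` branch even though the
    # full sequence is re-run each step; with the fix that slice is skipped.
    output = model.generate(**inputs, max_new_tokens=20, use_cache=False)
    print(tokenizer.decode(output[0], skip_special_tokens=True))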