diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index 8de6bc90ea3fec..1c56ecd56f54ae 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -24,7 +24,7 @@
 from torch import Size, Tensor, nn
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, StaticCache
+from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import _flash_attention_forward
@@ -783,8 +783,14 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
         if cache_position is None:
-            cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
 
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
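
For context (not part of the patch): a minimal, standalone sketch of the `cache_position` logic the diff introduces, assuming `transformers.cache_utils.DynamicCache` with its `get_seq_length()` / `update()` API. The helper `derive_cache_position` is hypothetical and only mirrors the patched lines; it is not the model's actual `forward`.

```python
# Hypothetical helper mirroring the patched forward(): create a DynamicCache when
# none is passed, then start cache_position at the number of tokens already cached
# instead of always starting at 0.
import torch
from transformers.cache_utils import DynamicCache


def derive_cache_position(inputs_embeds, past_key_values=None, use_cache=True):
    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
    cache_position = torch.arange(
        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
    )
    return past_key_values, cache_position


# Prefill: 5 prompt tokens with an empty cache -> positions start at 0.
prefill_embeds = torch.zeros(1, 5, 32)
cache, pos = derive_cache_position(prefill_embeds)
print(pos)  # tensor([0, 1, 2, 3, 4])

# Decode: pretend the attention layers cached those 5 tokens, then feed 1 new token.
cache.update(torch.zeros(1, 2, 5, 8), torch.zeros(1, 2, 5, 8), layer_idx=0)
decode_embeds = torch.zeros(1, 1, 32)
_, pos = derive_cache_position(decode_embeds, past_key_values=cache)
print(pos)  # tensor([5]) -- with the old code this would have been tensor([0])
```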