Commit 318fe25
Fix: Enable prefill phase key value caching of nemotron/minitron models (#34742)

* modeling nemotron kv caching bugfix

Signed-off-by: jeongin601 <[email protected]>

* test file deleted

Signed-off-by: jeongin601 <[email protected]>

* code refinement

Signed-off-by: jeongin601 <[email protected]>

* remove unused variables

Signed-off-by: jeongin601 <[email protected]>

* import block sorted

* removed deprecation warning

Signed-off-by: jeongin601 <[email protected]>

* removed support for tuple shape past_key_values

Signed-off-by: jeongin601 <[email protected]>

* Update conditional statement for cache initialization

Co-authored-by: Arthur <[email protected]>

---------

Signed-off-by: jeongin601 <[email protected]>
Co-authored-by: Arthur <[email protected]>
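As the log above notes, tuple-shaped past_key_values are no longer supported by this model's forward pass. A caller still holding a legacy tuple cache can convert it first; below is a minimal sketch, assuming the DynamicCache.from_legacy_cache helper from transformers.cache_utils (the tensor shapes and layer count are illustrative, not taken from this commit):

```python
import torch
from transformers.cache_utils import DynamicCache

# Illustrative legacy cache: one (key, value) pair per layer, each of shape
# (batch, num_heads, seq_len, head_dim). The exact sizes here are made up.
legacy_cache = tuple(
    (torch.zeros(1, 8, 5, 64), torch.zeros(1, 8, 5, 64)) for _ in range(2)
)

# Convert to the Cache API expected by the updated forward pass.
cache = DynamicCache.from_legacy_cache(legacy_cache)
print(cache.get_seq_length())  # 5 — tokens already in the cache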
jeongin601 and ArthurZucker authored Nov 25, 2024
1 parent 3a8eb74 commit 318fe25
Showing 1 changed file with 8 additions and 2 deletions.
src/transformers/models/nemotron/modeling_nemotron.py: 8 additions & 2 deletions

```diff
@@ -24,7 +24,7 @@
 from torch import Size, Tensor, nn
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, StaticCache
+from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import _flash_attention_forward
@@ -783,8 +783,14 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
         if cache_position is None:
-            cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
```
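With the DynamicCache initialization in place, a plain forward pass with use_cache=True returns a populated cache after prefill, and the arange offset by past_seen_tokens lets a decode step (a single new token) index past the prompt instead of overwriting position 0. A minimal sketch of the intended behavior; the checkpoint name is illustrative, and any Nemotron/Minitron causal-LM checkpoint should work the same way:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any Nemotron/Minitron causal LM applies.
name = "nvidia/Minitron-4B-Base"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello, world", return_tensors="pt")

# Prefill: with this fix, a DynamicCache is created and filled here.
with torch.no_grad():
    out = model(**inputs, use_cache=True)
cache = out.past_key_values
print(cache.get_seq_length())  # == prompt length, not 0

# Decode one step: cache_position now starts at past_seen_tokens,
# so the new token is appended to the cache rather than replacing it.
next_token = out.logits[:, -1:].argmax(-1)
with torch.no_grad():
    out = model(input_ids=next_token, past_key_values=cache, use_cache=True)
print(out.past_key_values.get_seq_length())  # prompt length + 1
```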
