Skip to content

Commit

Permalink
done
Browse files Browse the repository at this point in the history
Signed-off-by: N <[email protected]>
  • Loading branch information
yao-matrix committed Nov 20, 2024
1 parent c967bbe commit 980aa08
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,6 @@ def __init__(
"Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
"Please pass in `min_length` into `.generate()` instead"
)
# assume cache created while _prepare_cache_for_generation is called
self.generation_config.cache_implementation = None


def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Expand Down Expand Up @@ -229,6 +226,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,

# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
self.generation_config.cache_implementation = None

# 4. Prepare variables for output
candidate_logits = torch.stack(assistant_output.scores, dim=1)
Expand Down

0 comments on commit 980aa08

Please sign in to comment.