From 18871599c9ae76f7b5a09186b2c09fc5b8826604 Mon Sep 17 00:00:00 2001 From: Jonathan Mamou Date: Thu, 21 Nov 2024 15:46:35 +0200 Subject: [PATCH] Fix heuristic scheduling for UAG (#34805) * fix heuristic schedule * fix style * fix format --- src/transformers/generation/candidate_generator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index d8344c25a6526a..df213b458cf8bb 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -255,7 +255,8 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F "heuristic", "heuristic_transient", }: - if num_matches == int(self.num_assistant_tokens): + # len(scores[0])-1 is the number of candidates according to the target tokenizer. + if num_matches == len(scores[0]) - 1: self.num_assistant_tokens += 2.0 else: self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)