Skip to content

Commit

Permalink
Fix heuristic scheduling for UAG (#34805)
Browse files Browse the repository at this point in the history
* fix heuristic schedule

* fix style

* fix format
  • Loading branch information
jmamou authored Nov 21, 2024
1 parent d6a5c23 commit 1887159
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
"heuristic",
"heuristic_transient",
}:
if num_matches == int(self.num_assistant_tokens):
# len(scores[0])-1 is the number of candidates according to the target tokenizer.
if num_matches == len(scores[0]) - 1:
self.num_assistant_tokens += 2.0
else:
self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
Expand Down

0 comments on commit 1887159

Please sign in to comment.