From 77c5d59e0ee7d59fc12dbf83fbbfcc162a7491af Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 20 Sep 2024 17:01:49 +0100 Subject: [PATCH] Generate: assistant should sample when the main model samples (#33534) --- src/transformers/generation/candidate_generator.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 62d5fb6eed0c49..0b799dceb267c2 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -160,12 +160,6 @@ def __init__( self.generation_config.output_scores = True self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold - # Disable sampling -- this implementation of assisted generation/speculative decoding uses the assistant - # greedily to maximize matches. Disables sampling-related flags to prevent warnings - self.generation_config.do_sample = False - for attr in ("temperature", "top_p", "min_p", "typical_p", "top_k", "epsilon_cutoff", "eta_cutoff"): - setattr(self.generation_config, attr, None) - # avoid unnecessary warnings that min_length is larger than max_new_tokens # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`) self.main_model_min_length = self.generation_config.min_length