From 9b78a6eee32c76316d55bbb86c507625bf79ac40 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 14 Dec 2023 17:04:58 +0100 Subject: [PATCH] fix: only keep stop sequence buffer if we have some --- server/text_generation_server/utils/tokens.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 53722fec054..04cc8d97091 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -112,7 +112,7 @@ def __init__( self.stop_sequence_criterias = stop_sequence_criterias self.max_new_tokens = max_new_tokens self.current_tokens = 0 - self.current_output = "test" + self.current_output = "" self.ignore_eos_token = ignore_eos_token def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]: @@ -123,14 +123,15 @@ def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[st if not self.ignore_eos_token and last_token == self.eos_token_id: return True, FinishReason.FINISH_REASON_EOS_TOKEN - self.current_output += last_output - # There is no need to keep an output that is too long - if len(self.current_output) > 300: - # Slice to -200 to avoid doing it all the time - self.current_output = self.current_output[-200:] - for stop_sequence_criteria in self.stop_sequence_criterias: - if stop_sequence_criteria(self.current_output): - return True, FinishReason.FINISH_REASON_STOP_SEQUENCE + if self.stop_sequence_criterias: + self.current_output += last_output + # There is no need to keep an output that is too long + if len(self.current_output) > 300: + # Slice to -200 to avoid doing it all the time + self.current_output = self.current_output[-200:] + for stop_sequence_criteria in self.stop_sequence_criterias: + if stop_sequence_criteria(self.current_output): + return True, FinishReason.FINISH_REASON_STOP_SEQUENCE return False, None