Skip to content

Commit

Permalink
fix: only keep stop sequence buffer if we have some
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene committed Dec 14, 2023
1 parent 80a6920 commit 9b78a6e
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions server/text_generation_server/utils/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = "test"
self.current_output = ""
self.ignore_eos_token = ignore_eos_token

def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
Expand All @@ -123,14 +123,15 @@ def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[st
if not self.ignore_eos_token and last_token == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN

self.current_output += last_output
# There is no need to keep an output that is too long
if len(self.current_output) > 300:
# Slice to -200 to avoid doing it all the time
self.current_output = self.current_output[-200:]
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
if self.stop_sequence_criterias:
self.current_output += last_output
# There is no need to keep an output that is too long
if len(self.current_output) > 300:
# Slice to -200 to avoid doing it all the time
self.current_output = self.current_output[-200:]
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

return False, None

Expand Down

0 comments on commit 9b78a6e

Please sign in to comment.