diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index cec9ae5521a..df5be0b1c6d 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -641,8 +641,10 @@ def generate_token(
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids[:, 0],
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d6af07f4c01..6945f44b0b4 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1008,8 +1008,10 @@ def generate_token(
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids,
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     generated_text = GeneratedText(
                         output_text,
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index f4177145b18..02dda681247 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -611,11 +611,6 @@ def __init__(
     def batch_type(self) -> Type[IdeficsCausalLMBatch]:
         return IdeficsCausalLMBatch
 
-    def decode(self, generated_ids: List[int]) -> str:
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-
     def forward(
         self,
         input_ids,
@@ -728,8 +723,10 @@ def generate_token(
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids[:, 0],
+                        prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                     )
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 1a7911acd3c..b8a0daf5aed 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -710,8 +710,10 @@ def generate_token(
                 if stop:
                     # Slice with decoder_input_length to remove padding
                     # Decode all tokens
-                    output_text = self.decode(
-                        all_decoder_input_ids[-decoder_input_length:]
+                    output_text, _, _ = self.decode_token(
+                        all_decoder_input_ids,
+                        prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
+                        read_offset=len(all_decoder_input_ids) - decoder_input_length,
                     )
                     # Get seed