diff --git a/src/transformers/generation/streamers.py b/src/transformers/generation/streamers.py
index c75b43466af7a8..7f03fad1306b10 100644
--- a/src/transformers/generation/streamers.py
+++ b/src/transformers/generation/streamers.py
@@ -34,6 +34,10 @@ def end(self):
         """Function that is called by `.generate()` to signal the end of generation"""
         raise NotImplementedError()
 
+    def is_running(self) -> bool:
+        """Function that is called by `.generate()` to check if the streamer has ended"""
+        raise NotImplementedError()
+
 
 class TextStreamer(BaseStreamer):
     """
@@ -69,7 +73,9 @@ class TextStreamer(BaseStreamer):
     ```
     """
 
-    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
+    def __init__(
+        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs
+    ):
         self.tokenizer = tokenizer
         self.skip_prompt = skip_prompt
         self.decode_kwargs = decode_kwargs
@@ -203,12 +209,17 @@ class TextIteratorStreamer(TextStreamer):
     """
 
     def __init__(
-        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
+        self,
+        tokenizer: "AutoTokenizer",
+        skip_prompt: bool = False,
+        timeout: Optional[float] = None,
+        **decode_kwargs
     ):
         super().__init__(tokenizer, skip_prompt, **decode_kwargs)
         self.text_queue = Queue()
         self.stop_signal = None
         self.timeout = timeout
+        self.stopped = False
 
     def on_finalized_text(self, text: str, stream_end: bool = False):
         """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
@@ -216,6 +227,13 @@ def on_finalized_text(self, text: str, stream_end: bool = False):
         if stream_end:
             self.text_queue.put(self.stop_signal, timeout=self.timeout)
 
+    def end(self):
+        self.stopped = True
+        self.on_finalized_text("", stream_end=True)
+
+    def is_running(self) -> bool:
+        return not self.stopped
+
     def __iter__(self):
         return self
 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 1d7eef755bf984..a58fbf70fc28c1 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2164,6 +2164,8 @@ def _contrastive_search(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
+                if not streamer.is_running():
+                    break
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )
@@ -2450,6 +2452,8 @@ def _greedy_search(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
+                if not streamer.is_running():
+                    break
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs,
                 model_kwargs,
@@ -2752,6 +2756,8 @@ def _sample(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
+                if not streamer.is_running():
+                    break
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )
@@ -4612,6 +4618,8 @@ def _assisted_decoding(
             input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
             if streamer is not None:
                 streamer.put(valid_tokens.cpu())
+                if not streamer.is_running():
+                    break
             new_cur_len = input_ids.shape[-1]
 
             # 4.2. Discard past key values relative to unused assistant tokens
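
For context when reviewing: the new `is_running()` hook lets the consumer of a `TextIteratorStreamer` abort generation early by calling `end()` from its own thread, since each decoding loop above now breaks as soon as the check fails. Below is a minimal usage sketch; it mirrors the existing `TextIteratorStreamer` docstring example, and the `"gpt2"` checkpoint plus the 80-character cutoff are illustrative assumptions, not part of this diff.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer(["An increasing sequence: one,"], return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer)

# Run generation in a background thread so the main thread can consume the stream.
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=500)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

generated_text = ""
for new_text in streamer:
    generated_text += new_text
    if len(generated_text) > 80:  # arbitrary client-side stop condition (assumption)
        # end() sets `stopped`, so the decoding loop sees is_running() == False
        # after its next put() and breaks instead of running to max_new_tokens.
        streamer.end()
        break
thread.join()
```

Without the `is_running()` check, `generate()` would keep filling the queue after the consumer stopped reading; with it, the background thread exits shortly after `end()` is called.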