diff --git a/router/src/infer.rs b/router/src/infer.rs
index de8debc3871..e5bab868b95 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -527,13 +527,14 @@ fn send_responses(
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
     metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
-    for (i, (((id, logprob), text), special)) in tokens_
+    let mut iterator = tokens_
         .ids
         .into_iter()
         .zip(tokens_.logprobs.into_iter())
         .zip(tokens_.texts.into_iter())
         .zip(tokens_.is_special.into_iter())
-        .enumerate()
+        .enumerate().peekable();
+    while let Some( (i, (((id, logprob), text), special))) = iterator.next()
     {
         let token = Token {
             id,
@@ -557,9 +558,9 @@ fn send_responses(
                 .collect()
         } else {
             vec![]
-        };
-        match (&generation.generated_text, i) {
-            (Some(generated_text), i) if i == n - 1 => {
+        };
+        match (&generation.generated_text, iterator.peek()) {
+            (Some(generated_text), None) => {
                 // Generation has ended
                 stopped = true;
                 // Send message
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 4b1e2e547d8..27e3897d01c 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -101,7 +101,7 @@ def get_model(
     if speculate is not None:
         set_speculate(speculate)
     else:
-        set_speculate(2)
+        set_speculate(0)
 
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
@@ -159,7 +159,10 @@ def get_model(
         method = "medusa"
     else:
         method = "n-gram"
-    logger.info(f"Using speculation {method} with {get_speculate()} input ids.")
+
+    speculate = get_speculate()
+    if speculate > 0:
+        logger.info(f"Using speculation {method} with {speculate} input ids.")
 
     model_type = config_dict["model_type"]
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 2ca864887ad..5e12e67c18b 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -960,9 +960,6 @@ def generate_token(
             top_token_logprobs,
         ) in enumerate(iterator):
             # Append next token to all tokens
-            _next_token_ids = next_token_ids[index: index+n_accepted_ids]
-            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
-
             next_token_texts = []
             left = 0
             for j in range(index, index + n_accepted_ids):
@@ -983,12 +980,14 @@ def generate_token(
 
                 if stop:
                     stopped = True
-                    left = len(_next_token_ids) - 1 - j
+                    left = n_accepted_ids - 1 - j
                     break
             else:
                 stopped = False
+
+            _next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
+            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
             index += n_accepted_ids
-            _next_token_ids = _next_token_ids[:len(_next_token_ids) - left]
 
             # Shard generations
             # All generations will be appended in the rust sharded client
@@ -1085,8 +1084,6 @@ def generate_token(
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
-        if prefill:
-            batch.max_seqlen += speculative_length
         batch.max_seqlen = batch.max_seqlen + 1
 
         return generations, batch
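
Note on the generate_token change above: a small, self-contained Python sketch of the intent (not code from the repository; every name in it is illustrative). When a stopping criterion fires on one of the speculated tokens, the ids and logprobs returned for that request should stop at that token, while the cursor into the flat batch buffers still advances by the full accepted window.

from typing import Callable, List, Tuple

def accept_until_stop(
    next_token_ids: List[int],
    next_token_logprobs: List[float],
    index: int,
    n_accepted_ids: int,
    is_stop_token: Callable[[int], bool],
) -> Tuple[List[int], List[float], int]:
    # Number of trailing speculative tokens to discard for this request.
    left = 0
    for offset in range(n_accepted_ids):
        if is_stop_token(next_token_ids[index + offset]):
            left = n_accepted_ids - 1 - offset
            break
    # Keep only the tokens up to and including the stopping one.
    kept_ids = next_token_ids[index : index + n_accepted_ids - left]
    kept_logprobs = next_token_logprobs[index : index + n_accepted_ids - left]
    # The cursor still moves past the whole accepted window in the flat buffers.
    return kept_ids, kept_logprobs, index + n_accepted_ids

For example, with index=2, n_accepted_ids=4 and a stop on the second accepted token, this keeps two ids and two logprobs and returns a new cursor of 6, mirroring the `index: index+n_accepted_ids - left` slices in the diff.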