Fixing infer iterator.
Narsil committed Dec 5, 2023
1 parent 09839b0 commit 9bf31fe
Showing 3 changed files with 15 additions and 14 deletions.
11 changes: 6 additions & 5 deletions router/src/infer.rs
@@ -527,13 +527,14 @@ fn send_responses(
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
- for (i, (((id, logprob), text), special)) in tokens_
+ let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs.into_iter())
.zip(tokens_.texts.into_iter())
.zip(tokens_.is_special.into_iter())
- .enumerate()
+ .enumerate().peekable();
+ while let Some( (i, (((id, logprob), text), special))) = iterator.next()
{
let token = Token {
id,
@@ -557,9 +558,9 @@ fn send_responses(
.collect()
} else {
vec![]
- };
- match (&generation.generated_text, i) {
- (Some(generated_text), i) if i == n - 1 => {
+ };
+ match (&generation.generated_text, iterator.peek()) {
+ (Some(generated_text), None) => {
// Generation has ended
stopped = true;
// Send message
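
The old loop flagged the final token by checking the enumeration index against n - 1; the rewritten loop makes the iterator peekable and treats "nothing left to peek at" as the end-of-generation signal, so the final message is sent on the last yielded token. A minimal Python sketch of the same peek-ahead pattern (illustrative only, not the Rust code above; stream_tokens and its field names are invented for the example):

def stream_tokens(ids, logprobs, texts, specials):
    """Yield (token, is_last) pairs, detecting the last token by peeking ahead."""
    it = zip(ids, logprobs, texts, specials)
    current = next(it, None)
    while current is not None:
        nxt = next(it, None)            # peek at the following element, if any
        is_last = nxt is None           # last token <=> nothing left to peek at
        token_id, logprob, text, special = current
        yield {"id": token_id, "logprob": logprob, "text": text, "special": special}, is_last
        current = nxt

# The end-of-generation branch keys off is_last instead of an index == n - 1 check.
for token, is_last in stream_tokens([5, 7, 2], [-0.1, -0.2, -0.3], ["a", "b", "c"], [False, False, True]):
    if is_last:
        pass  # send the final response carrying the generated text
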
7 changes: 5 additions & 2 deletions server/text_generation_server/models/__init__.py
@@ -101,7 +101,7 @@ def get_model(
if speculate is not None:
set_speculate(speculate)
else:
- set_speculate(2)
+ set_speculate(0)

if "facebook/galactica" in model_id:
return GalacticaSharded(
@@ -159,7 +159,10 @@ def get_model(
method = "medusa"
else:
method = "n-gram"
logger.info(f"Using speculation {method} with {get_speculate()} input ids.")

speculate = get_speculate()
if speculate > 0:
logger.info(f"Using speculation {method} with {speculate} input ids.")

model_type = config_dict["model_type"]

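
The net effect in get_model: speculative decoding is now off by default (set_speculate(0) instead of set_speculate(2)), and the log line is only emitted when speculation is actually enabled. A small hedged sketch of the resulting control flow (SPECULATE and configure_speculation are stand-ins, not the real module layout):

SPECULATE = 0  # stand-in for the module-level speculate state

def set_speculate(n):
    global SPECULATE
    SPECULATE = n

def get_speculate():
    return SPECULATE

def configure_speculation(requested=None, use_medusa=False):
    # Default is now 0: speculation stays disabled unless explicitly requested.
    set_speculate(requested if requested is not None else 0)
    method = "medusa" if use_medusa else "n-gram"
    speculate = get_speculate()
    # Only log when speculation is actually in use.
    if speculate > 0:
        print(f"Using speculation {method} with {speculate} input ids.")

configure_speculation()  # default: no speculation, no log line
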
11 changes: 4 additions & 7 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -960,9 +960,6 @@ def generate_token(
top_token_logprobs,
) in enumerate(iterator):
# Append next token to all tokens
- _next_token_ids = next_token_ids[index: index+n_accepted_ids]
- _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
-
next_token_texts = []
left = 0
for j in range(index, index + n_accepted_ids):
@@ -983,12 +980,14 @@

if stop:
stopped = True
- left = len(_next_token_ids) - 1 - j
+ left = n_accepted_ids - 1 - j
break
else:
stopped = False

+ _next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
+ _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
index += n_accepted_ids
- _next_token_ids = _next_token_ids[:len(_next_token_ids) - left]

# Shard generations
# All generations will be appended in the rust sharded client
@@ -1085,8 +1084,6 @@ def generate_token(
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
- if prefill:
- batch.max_seqlen += speculative_length
batch.max_seqlen = batch.max_seqlen + 1

return generations, batch
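
In generate_token, the accepted ids are now sliced once, after the stop check: left counts how many accepted ids land after the stopping token, and both _next_token_ids and _next_token_logprobs are cut to n_accepted_ids - left entries before index advances (the earlier pre-slicing and the prefill-time max_seqlen += speculative_length bump are removed). A hedged, self-contained sketch of that trimming logic, with hypothetical trim_accepted and is_stop names standing in for the real stopping criteria:

def trim_accepted(next_token_ids, next_token_logprobs, index, n_accepted_ids, is_stop):
    """Keep accepted tokens up to and including the first stopping token."""
    left = 0
    stopped = False
    for j in range(n_accepted_ids):                  # j is relative to `index` in this sketch
        if is_stop(next_token_ids[index + j]):
            stopped = True
            left = n_accepted_ids - 1 - j            # accepted ids falling after the stop
            break
    kept_ids = next_token_ids[index: index + n_accepted_ids - left]
    kept_logprobs = next_token_logprobs[index: index + n_accepted_ids - left]
    return kept_ids, kept_logprobs, stopped, index + n_accepted_ids

# Example: 4 speculative ids accepted, EOS (0) appears third -> the trailing 9 is dropped.
ids, lps, stopped, new_index = trim_accepted([5, 6, 0, 9], [-0.1, -0.2, -0.3, -0.4], 0, 4, lambda t: t == 0)
assert ids == [5, 6, 0] and stopped and new_index == 4
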
