
Commit f762501

server : fix crash when system prompt is bigger than batch size (ggerganov#5714)

The system prompt is now decoded in batches.
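For context, the fix decodes the already-filled llama_batch in views of at most params.n_batch tokens instead of handing the whole batch to llama_decode in one call. Below is a minimal standalone sketch of that loop; the decode_chunked helper and its signature are illustrative rather than part of the commit, while the llama_batch field order is the one used by the llama.cpp C API at this revision:

#include "llama.h"

#include <algorithm>
#include <cstdio>

// Illustrative helper: decode a pre-filled batch in chunks of at most n_batch tokens.
static bool decode_chunked(llama_context * ctx, const llama_batch & batch, int32_t n_batch) {
    for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

        // a view into the original batch, shifted by i tokens (no copies are made)
        llama_batch batch_view = {
            n_tokens,
            batch.token    + i,
            nullptr,           // embd: unused for a token-based batch
            batch.pos      + i,
            batch.n_seq_id + i,
            batch.seq_id   + i,
            batch.logits   + i,
            0, 0, 0,           // all_pos_0, all_pos_1, all_seq_id: unused legacy fields
        };

        if (llama_decode(ctx, batch_view) != 0) {
            fprintf(stderr, "llama_decode() failed on the chunk starting at token %d\n", (int) i);
            return false;
        }
    }
    return true;
}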

* server : fix off-by-one n_past when start of prompt matches whole cache

The tokens right after the matching part would otherwise skip a pos value.
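To see the off-by-one concretely: when the whole cache is a prefix of the new prompt, the last cached token was sampled and appended to cache_tokens but never passed through llama_decode, so its position has no KV entry yet; decrementing n_past makes the next decode re-submit it instead of skipping that position. A self-contained sketch of the check (common_part is re-implemented locally for the example and the token values are arbitrary; server.cpp has its own helper of the same name):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t; // matches the llama.cpp typedef

// length of the shared prefix of two token sequences (local stand-in for the server helper)
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) { i++; }
    return i;
}

int main() {
    // the cache ends with a token that was sampled and stored, but not yet decoded,
    // so the KV cache only holds entries for positions 0..2
    const std::vector<llama_token> cache_tokens  = {11, 22, 33, 44};
    const std::vector<llama_token> prompt_tokens = {11, 22, 33, 44, 55};

    int32_t n_past = (int32_t) common_part(cache_tokens, prompt_tokens); // 4

    // without the adjustment, decoding would resume at pos 4 and pos 3 would never get
    // a KV entry; with it, the last cached token is decoded again at pos 3
    if (n_past > 0 && n_past == (int32_t) cache_tokens.size()) {
        n_past -= 1;
    }

    printf("n_past = %d\n", (int) n_past); // prints 3
    return 0;
}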
compilade authored Feb 25, 2024
1 parent abbabc5 commit f762501
Showing 1 changed file with 25 additions and 3 deletions.
examples/server/server.cpp (28 changes: 25 additions & 3 deletions)
@@ -902,10 +902,24 @@ struct llama_server_context
             llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
         }
 
-        if (llama_decode(ctx, batch) != 0)
+        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch)
         {
-            LOG_TEE("%s: llama_decode() failed\n", __func__);
-            return;
+            const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i));
+            llama_batch batch_view = {
+                n_tokens,
+                batch.token    + i,
+                nullptr,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
+                0, 0, 0, // unused
+            };
+            if (llama_decode(ctx, batch_view) != 0)
+            {
+                LOG_TEE("%s: llama_decode() failed\n", __func__);
+                return;
+            }
         }
 
         // assign the system KV cache to all parallel sequences
@@ -1785,6 +1799,14 @@ struct llama_server_context
                         }
 
                         slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+
+                        // the last token of the cache is not in the KV cache until the next call to llama_decode
+                        // (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
+                        if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size())
+                        {
+                            slot.n_past -= 1;
+                        }
+
                         slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
 
                         if (slot.ga_n != 1)
