
Commit

wip
OlivierDehaene committed Jun 5, 2024
1 parent 3ba6566 commit 51befc3
Showing 18 changed files with 139 additions and 114 deletions.
33 changes: 10 additions & 23 deletions proto/v3/generate.proto
@@ -17,8 +17,6 @@ service TextGenerationService {
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse);
- /// Update batch
- rpc Update(UpdateRequest) returns (UpdateResponse);
/// Health check
rpc Health (HealthRequest) returns (HealthResponse);
}
@@ -204,11 +202,20 @@ message Generation {
uint32 current_length = 6;
}

+ message UpdatedRequest {
+ /// Request ID
+ uint64 id = 1;
+ /// Paged attention blocks
+ repeated uint32 blocks = 2;
+ /// Paged attention slots
+ repeated uint32 slots = 3;
+ }

message FilterBatchRequest {
/// Batch ID
uint64 batch_id = 1;
/// Requests to keep
- repeated uint64 request_ids = 2;
+ repeated UpdatedRequest updated_requests = 2;
}

message FilterBatchResponse {
@@ -255,26 +262,6 @@ message DecodeResponse {
optional uint64 concat_ns = 6;
}

- message ExtendedRequest {
- /// Request ID
- uint64 request_id = 1;
- /// Paged attention blocks to add
- repeated uint32 blocks = 2;
- /// Paged attention slots to add
- repeated uint32 slots = 3;
- }
-
- message UpdateRequest {
- /// Batch ID
- uint64 batch_id = 1;
- /// Requests to update
- repeated ExtendedRequest extend_requests = 2;
- /// Requests to terminate
- repeated uint64 terminated_request_ids = 3;
- }
-
- message UpdateResponse {}

message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
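For orientation, a minimal sketch of the new call shape (illustrative only, not part of the commit): FilterBatchRequest now carries each surviving request's paged-attention blocks and slots via UpdatedRequest instead of a bare list of request ids, so the shards can keep their page tables in sync. The helper name keep_request is hypothetical, and the sketch assumes the prost-generated types re-exported from text_generation_client::v3:

use text_generation_client::v3::{CachedBatch, Client, UpdatedRequest};
use text_generation_client::ClientError;

// Hypothetical helper (illustration only): hand the shard back the pages still
// owned by one surviving request once the rest of the batch has finished.
async fn keep_request(
    client: &mut Client,
    batch_id: u64,
    request_id: u64,
    blocks: Vec<u32>,
    slots: Vec<u32>,
) -> Result<Option<CachedBatch>, ClientError> {
    let updated_requests = vec![UpdatedRequest {
        id: request_id,
        blocks,
        slots,
    }];
    // filter_batch now takes the per-request allocations, not a Vec<u64> of ids
    client.filter_batch(batch_id, updated_requests).await
}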
4 changes: 2 additions & 2 deletions router/client/src/v3/client.rs
@@ -90,11 +90,11 @@ impl Client {
pub async fn filter_batch(
&mut self,
batch_id: u64,
- request_ids: Vec<u64>,
+ updated_requests: Vec<UpdatedRequest>,
) -> Result<Option<CachedBatch>> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
- request_ids,
+ updated_requests,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
2 changes: 1 addition & 1 deletion router/client/src/v3/mod.rs
@@ -8,6 +8,6 @@ pub use client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
- StoppingCriteriaParameters, Tokens,
+ StoppingCriteriaParameters, Tokens, UpdatedRequest,
};
pub use sharded_client::ShardedClient;
6 changes: 3 additions & 3 deletions router/client/src/v3/sharded_client.rs
@@ -10,7 +10,7 @@ use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings};
use v3::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
- NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+ NextTokenChooserParameters, Request, StoppingCriteriaParameters, UpdatedRequest,
};

#[derive(Debug, Clone)]
@@ -84,12 +84,12 @@ impl ShardedClient {
pub async fn filter_batch(
&mut self,
batch_id: u64,
- request_ids: Vec<u64>,
+ updated_requests: Vec<UpdatedRequest>,
) -> Result<Option<CachedBatch>> {
let futures: Vec<_> = self
.clients
.iter_mut()
- .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
+ .map(|client| Box::pin(client.filter_batch(batch_id, updated_requests.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
3 changes: 3 additions & 0 deletions router/src/infer/mod.rs
@@ -506,6 +506,8 @@ pub enum InferError {
TemplateError(#[from] minijinja::Error),
#[error("Tool error: {0}")]
ToolError(String),
#[error("Request could not be re-allocated: out of pages")]
OutOfPages,
}

impl InferError {
@@ -517,6 +519,7 @@ impl InferError {
InferError::IncompleteGeneration => "incomplete_generation",
InferError::TemplateError(_) => "template_error",
InferError::ToolError(_) => "tool_error",
+ InferError::OutOfPages => "out_of_pages",
}
}
}
8 changes: 8 additions & 0 deletions router/src/infer/v3/block_allocator.rs
@@ -8,6 +8,12 @@ pub(crate) struct BlockAllocation {
block_allocator: BlockAllocator,
}

+ impl BlockAllocation {
+ pub(crate) fn len(&self) -> usize {
+ self.slots.len()
+ }
+ }

impl Drop for BlockAllocation {
fn drop(&mut self) {
self.block_allocator.free(self.blocks.clone())
@@ -83,6 +89,8 @@ async fn block_allocator_task(
tokens,
response_sender,
} => {
+ // let tokens = 16;

// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match window_size {
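The hunk above is cut off mid-computation. As a rough sketch of the arithmetic being applied (assumed structure and names, not this commit's exact code): with a sliding window, only window_size tokens of KV-cache are live at a time, so the allocation is sized for the window, repeated enough times to cover the full sequence, and rounded up to whole paged-attention blocks.

// Illustrative sketch (assumed shapes): how many blocks a request needs and how
// many times the slot pattern repeats under an optional attention window.
fn required_allocation(tokens: u32, block_size: u32, window_size: Option<u32>) -> (u32, u32) {
    let (tokens, repeats) = match window_size {
        // No sliding window: allocate for every token, single pass
        None => (tokens, 1),
        Some(window_size) => {
            // The window's worth of slots is reused ceil(tokens / window_size) times
            let repeats = (tokens + window_size - 1) / window_size;
            (tokens.min(window_size), repeats)
        }
    };
    // Round the live tokens up to whole paged-attention blocks
    let required_blocks = (tokens + block_size - 1) / block_size;
    (required_blocks, repeats)
}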
2 changes: 1 addition & 1 deletion router/src/infer/v3/queue.rs
@@ -34,7 +34,7 @@ pub(crate) struct Entry {
/// Block Allocation
pub block_allocation: Option<BlockAllocation>,
/// Current length (in tokens) of the request (prompt tokens + generated_tokens)
- pub current_length: u32
+ pub current_length: u32,
}

/// Request Queue
84 changes: 50 additions & 34 deletions router/src/infer/v3/scheduler.rs
@@ -10,7 +10,7 @@ use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
- use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
+ use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient, UpdatedRequest};
use text_generation_client::ClientError;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
@@ -288,7 +288,7 @@ async fn decode(
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);

- filter_update_allocations(client, entries).await;
+ filter_update_allocations(entries).await;

// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
@@ -323,19 +323,15 @@ async fn filter_batch(
next_batch: Option<CachedBatch>,
entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> {
- let mut batch = next_batch?;
+ let batch = next_batch?;

// No need to filter
if batch.size as usize == entries.len() {
return Some(batch);
}

let id = batch.id;

- // Retain only requests that are still in entries
- batch.request_ids.retain(|id| entries.contains_key(id));

- if batch.request_ids.is_empty() {
+ if entries.is_empty() {
// All requests have been filtered out
// Next batch is now empty
// Clear it from the Python shards cache
@@ -344,8 +340,24 @@
None
} else {
// Filter Python shard cache
+ let updated_requests = entries
+ .iter()
+ .map(|(request_id, entry)| {
+ let (blocks, slots) = entry
+ .block_allocation
+ .as_ref()
+ .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
+ .unwrap_or((Vec::new(), Vec::new()));
+ UpdatedRequest {
+ id: *request_id,
+ blocks,
+ slots,
+ }
+ })
+ .collect();

// We unwrap here as we need to panic since we cannot recover if this method fails
- client.filter_batch(id, batch.request_ids).await.unwrap()
+ client.filter_batch(id, updated_requests).await.unwrap()
}
}

@@ -379,32 +391,36 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
}

/// Check if block allocations need to be extended
- /// If we don't have enough blocks, request will be filtered with an OutOfPages finish reason
+ /// If we don't have enough blocks, request will be filtered with an OutOfPages error
#[instrument(skip_all)]
- async fn filter_update_allocations(client: &mut ShardedClient, entries: &mut IntMap<u64, Entry>) {
- // let mut extend_entries = Vec::with_capacity(entries.len());
- // let mut finish_entries = Vec::with_capacity(entries.len());
-
- // for (request_id, entry) in entries.into_iter() {
- // tracing::info!("Allocation {}; Current Length: {}", entry.block_allocation.as_ref().unwrap().allocated_tokens, entry.current_length);
- //
- // if let Some(block_allocation) = &mut entry.block_allocation {
- // tracing::info!("Allocation {:?}", block_allocation);
- //
- // if entry.current_length > block_allocation.allocated_tokens {
- // // We need to add new blocks to this entry
- // let remaining_tokens = block_allocation.total_tokens - entry.current_length;
- // match block_allocation.extend(remaining_tokens).await {
- // true => {
- //
- // },
- // false => {
- //
- // }
- // }
- // }
- // }
- // }
+ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) {
+ entries.retain(|request_id, entry| {
+ if entry.block_allocation.is_none() {
+ return true;
+ }
+
+ // We can unwrap since we already validated above that block_allocation is not None
+ let mut block_allocation = entry.block_allocation.as_ref().unwrap();
+
+ // Nothing to update
+ if entry.current_length <= block_allocation.len() as u32 {
+ return true;
+ }
+
+ // Create and enter a span to link this function back to the entry
+ let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+ let err = InferError::OutOfPages;
+ metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
+ tracing::error!("{err}");
+
+ // unwrap_or is valid here as we don't care if the receiver is gone.
+ entry
+ .response_tx
+ .send(Err(err))
+ .unwrap_or(());
+
+ false
+ });
}

/// Send responses through the `entry` response channel
3 changes: 0 additions & 3 deletions router/src/lib.rs
@@ -1085,8 +1085,6 @@ pub(crate) enum FinishReason {
EndOfSequenceToken,
#[schema(rename = "stop_sequence")]
StopSequence,
#[schema(rename = "out_of_pages")]
OutOfPages
}

impl std::fmt::Display for FinishReason {
@@ -1095,7 +1093,6 @@
FinishReason::Length => write!(f, "length"),
FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
FinishReason::StopSequence => write!(f, "stop_sequence"),
- FinishReason::OutOfPages => write!(f, "out_of_pages"),
}
}
}
1 change: 1 addition & 0 deletions router/src/server.rs
@@ -1859,6 +1859,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+ InferError::OutOfPages => StatusCode::TOO_MANY_REQUESTS,
};

(
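Because OutOfPages now reaches HTTP callers as 429 Too Many Requests (see the mapping above), a client can treat it as ordinary backpressure and retry. A rough sketch of such a caller follows; the /generate path and payload follow the public TGI API, while the helper itself is illustrative and not part of this commit:

use std::time::Duration;

use reqwest::StatusCode;
use serde_json::json;

// Illustrative client-side handling of the new 429 (out_of_pages) response.
async fn generate_with_retry(base_url: &str, prompt: &str) -> anyhow::Result<String> {
    let client = reqwest::Client::new();
    let mut delay = Duration::from_millis(250);

    for _ in 0..5 {
        let response = client
            .post(format!("{base_url}/generate"))
            .json(&json!({ "inputs": prompt, "parameters": { "max_new_tokens": 64 } }))
            .send()
            .await?;

        if response.status() == StatusCode::TOO_MANY_REQUESTS {
            // The router could not re-allocate KV-cache pages; back off and retry.
            tokio::time::sleep(delay).await;
            delay *= 2;
            continue;
        }

        return Ok(response.error_for_status()?.text().await?);
    }

    anyhow::bail!("gave up after repeated out_of_pages responses")
}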
8 changes: 6 additions & 2 deletions server/text_generation_server/models/causal_lm.py
@@ -156,7 +156,11 @@ def from_pb(
)

@tracer.start_as_current_span("filter")
- def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
+ def filter(
+ self, updated_requests: List[generate_pb2.UpdatedRequest]
+ ) -> Optional["CausalLMBatch"]:
+ request_ids = [r.id for r in updated_requests]

if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self):
@@ -744,7 +748,7 @@ def generate_token(
),
generated_text,
top_tokens,
- new_input_length
+ new_input_length,
)

generations.append(generation)