From 51befc3a16d80e956f71d806f021005bb3de4d2e Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 5 Jun 2024 17:01:06 +0200
Subject: [PATCH] wip

---
 proto/v3/generate.proto                        | 33 +++-----
 router/client/src/v3/client.rs                 |  4 +-
 router/client/src/v3/mod.rs                    |  2 +-
 router/client/src/v3/sharded_client.rs         |  6 +-
 router/src/infer/mod.rs                        |  3 +
 router/src/infer/v3/block_allocator.rs         |  8 ++
 router/src/infer/v3/queue.rs                   |  2 +-
 router/src/infer/v3/scheduler.rs               | 84 +++++++++++--------
 router/src/lib.rs                              |  3 -
 router/src/server.rs                           |  1 +
 .../models/causal_lm.py                        |  8 +-
 .../models/flash_causal_lm.py                  | 67 +++++++--------
 .../models/idefics_causal_lm.py                |  8 +-
 server/text_generation_server/models/mamba.py  |  8 +-
 .../models/seq2seq_lm.py                       |  6 +-
 server/text_generation_server/models/types.py  |  2 +-
 .../models/vlm_causal_lm.py                    |  6 +-
 server/text_generation_server/server.py        |  2 +-
 18 files changed, 139 insertions(+), 114 deletions(-)

diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index d57fbbad4b1..192cd111bb7 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -17,8 +17,6 @@ service TextGenerationService {
     /// Prefill batch and decode first token
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
     rpc Decode (DecodeRequest) returns (DecodeResponse);
-    /// Update batch
-    rpc Update(UpdateRequest) returns (UpdateResponse);
     /// Health check
     rpc Health (HealthRequest) returns (HealthResponse);
 }
@@ -204,11 +202,20 @@ message Generation {
     uint32 current_length = 6;
 }

+message UpdatedRequest {
+    /// Request ID
+    uint64 id = 1;
+    /// Paged attention blocks
+    repeated uint32 blocks = 2;
+    /// Paged attention slots
+    repeated uint32 slots = 3;
+}
+
 message FilterBatchRequest {
     /// Batch ID
     uint64 batch_id = 1;
     /// Requests to keep
-    repeated uint64 request_ids = 2;
+    repeated UpdatedRequest updated_requests = 2;
 }

 message FilterBatchResponse {
@@ -255,26 +262,6 @@ message DecodeResponse {
     optional uint64 concat_ns = 6;
 }

-message ExtendedRequest {
-    /// Request ID
-    uint64 request_id = 1;
-    /// Paged attention blocks to add
-    repeated uint32 blocks = 2;
-    /// Paged attention slots to add
-    repeated uint32 slots = 3;
-}
-
-message UpdateRequest {
-    /// Batch ID
-    uint64 batch_id = 1;
-    /// Requests to update
-    repeated ExtendedRequest extend_requests = 2;
-    /// Requests to terminate
-    repeated uint64 terminated_request_ids = 3;
-}
-
-message UpdateResponse {}
-
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs
index 9a3892fb186..8cefd3137b2 100644
--- a/router/client/src/v3/client.rs
+++ b/router/client/src/v3/client.rs
@@ -90,11 +90,11 @@ impl Client {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        request_ids: Vec<u64>,
+        updated_requests: Vec<UpdatedRequest>,
     ) -> Result<Option<CachedBatch>> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
-            request_ids,
+            updated_requests,
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
diff --git a/router/client/src/v3/mod.rs b/router/client/src/v3/mod.rs
index 4a1296a2247..df2bb380734 100644
--- a/router/client/src/v3/mod.rs
+++ b/router/client/src/v3/mod.rs
@@ -8,6 +8,6 @@ pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, Tokens,
+    StoppingCriteriaParameters, Tokens, UpdatedRequest,
 };
 pub use sharded_client::ShardedClient;
diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs
index 94002f55064..a066176ce5e 100644
--- a/router/client/src/v3/sharded_client.rs
+++ b/router/client/src/v3/sharded_client.rs
@@ -10,7 +10,7 @@ use tracing::instrument;
 use v3::client::{DecodeTimings, PrefillTimings};
 use v3::{
     Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, UpdatedRequest,
 };

 #[derive(Debug, Clone)]
@@ -84,12 +84,12 @@ impl ShardedClient {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        request_ids: Vec<u64>,
+        updated_requests: Vec<UpdatedRequest>,
     ) -> Result<Option<CachedBatch>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
+            .map(|client| Box::pin(client.filter_batch(batch_id, updated_requests.clone())))
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 20630c1b0cd..3b61e46667f 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -506,6 +506,8 @@ pub enum InferError {
     TemplateError(#[from] minijinja::Error),
     #[error("Tool error: {0}")]
     ToolError(String),
+    #[error("Request could not be re-allocated: out of pages")]
+    OutOfPages,
 }

 impl InferError {
@@ -517,6 +519,7 @@ impl InferError {
             InferError::IncompleteGeneration => "incomplete_generation",
             InferError::TemplateError(_) => "template_error",
             InferError::ToolError(_) => "tool_error",
+            InferError::OutOfPages => "out_of_pages",
         }
     }
 }
diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
index 7467fd85997..811efb262ec 100644
--- a/router/src/infer/v3/block_allocator.rs
+++ b/router/src/infer/v3/block_allocator.rs
@@ -8,6 +8,12 @@ pub(crate) struct BlockAllocation {
     block_allocator: BlockAllocator,
 }

+impl BlockAllocation {
+    pub(crate) fn len(&self) -> usize {
+        self.slots.len()
+    }
+}
+
 impl Drop for BlockAllocation {
     fn drop(&mut self) {
         self.block_allocator.free(self.blocks.clone())
@@ -83,6 +89,8 @@ async fn block_allocator_task(
                 tokens,
                 response_sender,
             } => {
+                // let tokens = 16;
+
                 // Apply window size
                 let (required_blocks, repeats) = {
                     let (tokens, repeats) = match window_size {
diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs
index 1522679445c..1ac06ae97c9 100644
--- a/router/src/infer/v3/queue.rs
+++ b/router/src/infer/v3/queue.rs
@@ -34,7 +34,7 @@ pub(crate) struct Entry {
     /// Block Allocation
     pub block_allocation: Option<BlockAllocation>,
     /// Current length (in tokens) of the request (prompt tokens + generated_tokens)
-    pub current_length: u32
+    pub current_length: u32,
 }

 /// Request Queue
diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs
index bf52e69f72a..faa899ecd1d 100644
--- a/router/src/infer/v3/scheduler.rs
+++ b/router/src/infer/v3/scheduler.rs
@@ -10,7 +10,7 @@ use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
+use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient, UpdatedRequest};
 use text_generation_client::ClientError;
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
@@ -288,7 +288,7 @@ async fn decode(

             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
-            filter_update_allocations(client, entries).await;
+            filter_update_allocations(entries).await;

             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;
@@ -323,7 +323,7 @@ async fn filter_batch(
     next_batch: Option<CachedBatch>,
     entries: &IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {
-    let mut batch = next_batch?;
+    let batch = next_batch?;

     // No need to filter
     if batch.size as usize == entries.len() {
@@ -331,11 +331,7 @@ async fn filter_batch(
     }

     let id = batch.id;
-
-    // Retain only requests that are still in entries
-    batch.request_ids.retain(|id| entries.contains_key(id));
-
-    if batch.request_ids.is_empty() {
+    if entries.is_empty() {
         // All requests have been filtered out
         // Next batch is now empty
         // Clear it from the Python shards cache
@@ -344,8 +340,24 @@ async fn filter_batch(
         None
     } else {
         // Filter Python shard cache
+        let updated_requests = entries
+            .iter()
+            .map(|(request_id, entry)| {
+                let (blocks, slots) = entry
+                    .block_allocation
+                    .as_ref()
+                    .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
+                    .unwrap_or((Vec::new(), Vec::new()));
+                UpdatedRequest {
+                    id: *request_id,
+                    blocks,
+                    slots,
+                }
+            })
+            .collect();
+
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.filter_batch(id, batch.request_ids).await.unwrap()
+        client.filter_batch(id, updated_requests).await.unwrap()
     }
 }
@@ -379,32 +391,36 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
-async fn filter_update_allocations(client: &mut ShardedClient, entries: &mut IntMap<u64, Entry>) {
-    // let mut extend_entries = Vec::with_capacity(entries.len());
-    // let mut finish_entries = Vec::with_capacity(entries.len());
-
-    // for (request_id, entry) in entries.into_iter() {
-    //     tracing::info!("Allocation {}; Current Length: {}", entry.block_allocation.as_ref().unwrap().allocated_tokens, entry.current_length);
-    //
-    //     if let Some(block_allocation) = &mut entry.block_allocation {
-    //         tracing::info!("Allocation {:?}", block_allocation);
-    //
-    //         if entry.current_length > block_allocation.allocated_tokens {
-    //             // We need to add new blocks to this entry
-    //             let remaining_tokens = block_allocation.total_tokens - entry.current_length;
-    //             match block_allocation.extend(remaining_tokens).await {
-    //                 true => {
-    //
-    //                 },
-    //                 false => {
-    //
-    //                 }
-    //             }
-    //         }
-    //     }
-    // }
+async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) {
+    entries.retain(|request_id, entry| {
+        if entry.block_allocation.is_none() {
+            return true;
+        }
+
+        // We can unwrap since we already validated above that block_allocation is not None
+        let mut block_allocation = entry.block_allocation.as_ref().unwrap();
+
+        // Nothing to update
+        if entry.current_length <= block_allocation.len() as u32 {
+            return true;
+        }
+
+        // Create and enter a span to link this function back to the entry
+        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+        let err = InferError::OutOfPages;
+        metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
+        tracing::error!("{err}");
+
+        // unwrap_or is valid here as we don't care if the receiver is gone.
+        entry
+            .response_tx
+            .send(Err(err))
+            .unwrap_or(());
+
+        false
+    });
 }

 /// Send responses through the `entry` response channel
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 52c5aa461fd..b6902c497c1 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1085,8 +1085,6 @@ pub(crate) enum FinishReason {
     EndOfSequenceToken,
     #[schema(rename = "stop_sequence")]
     StopSequence,
-    #[schema(rename = "out_of_pages")]
-    OutOfPages
 }

 impl std::fmt::Display for FinishReason {
@@ -1095,7 +1093,6 @@ impl std::fmt::Display for FinishReason {
             FinishReason::Length => write!(f, "length"),
             FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
             FinishReason::StopSequence => write!(f, "stop_sequence"),
-            FinishReason::OutOfPages => write!(f, "out_of_pages"),
         }
     }
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 30479b0e4dc..9df33739c40 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1859,6 +1859,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
             InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+            InferError::OutOfPages => StatusCode::TOO_MANY_REQUESTS,
         };

         (
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 3b75e5c6dcd..b90a37686d6 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -156,7 +156,11 @@ def from_pb(
         )

     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["CausalLMBatch"]:
+        request_ids = [r.id for r in updated_requests]
+
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         if len(request_ids) == len(self):
@@ -744,7 +748,7 @@ def generate_token(
                     ),
                     generated_text,
                     top_tokens,
-                    new_input_length
+                    new_input_length,
                 )

                 generations.append(generation)
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 0aea9d10a80..d9a22755923 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -81,14 +81,10 @@ class FlashCausalLMBatch(Batch):
     # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
     slot_indices: torch.Tensor

-    # list of length b of list of length s_i // block_size
-    block_tables: List[List[int]]
     # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
     block_tables_tensor: torch.Tensor
-    # list of length b of list of length s_i
-    slots: List[List[int]]
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
-    slots_tensor: torch.Tensor
+    slots: torch.Tensor

     max_seqlen: int
@@ -180,7 +176,6 @@ def from_tokenized(
         max_blocks = 0

         block_tables = []
-        slots = []
         flat_slots = []

         # Parse batch
@@ -250,7 +245,6 @@ def from_tokenized(
                 len(flat_slots) + input_length,
                 dtype=torch.int64,
             )
-            slots.append(request_slots)
            flat_slots.extend(request_slots)
             slot_indices.append(request_slot_indices)

@@ -350,7 +344,7 @@ def from_tokenized(
             top_n_tokens, device=device, dtype=torch.int64
         )

-        slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device)
+        slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
         block_tables_tensor = torch.zeros(
             (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
         )
@@ -367,10 +361,8 @@ def from_tokenized(
             cu_seqlen_prefill=cu_seqlen_prefill,
             prefill_cache_indices=prefill_cache_indices,
             slot_indices=slot_indices,
-            block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
-            slots_tensor=slots_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
@@ -402,11 +394,13 @@ def from_pb(
         return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
-        if len(request_ids) == 0:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["FlashCausalLMBatch"]:
+        if len(updated_requests) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
-        if len(request_ids) == len(self):
+        if len(updated_requests) == len(self):
             return self

         device = self.input_ids.device
@@ -422,7 +416,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
         requests = []
         block_tables = []
-        slots = []
         flat_slots = []

         all_input_ids = []
@@ -436,7 +429,9 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
         num_blocks = 0
         max_blocks = 0

-        for i, request_id in enumerate(request_ids):
+        for i, request in enumerate(updated_requests):
+            request_id = request.id
+
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)
             requests_idx_mapping[request_id] = i
@@ -458,13 +453,12 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":

             top_n_tokens.append(self.top_n_tokens[idx])

-            request_block_table = self.block_tables[idx]
+            request_block_table = request.blocks
             num_blocks += len(request_block_table)
             block_tables.append(request_block_table)

             # List of slots allocated for this request
-            request_slots = self.slots[idx]
-            slots.append(request_slots)
+            request_slots = request.slots

             # Index
             slot_indices.append(len(flat_slots) + request_input_length - 1)
@@ -476,7 +470,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
-        block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
@@ -484,10 +477,20 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
             self.speculative_ids[indices] if self.speculative_ids is not None else None
         )

+        # Create block_tables_tensor on CPU
+        block_tables_tensor = torch.zeros(
+            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
+        )
+        for i, request_blocks in enumerate(block_tables):
+            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
+
         # Allocate on GPU
-        slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device)
+        slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
         slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)

+        # Move to GPU
+        block_tables_tensor = block_tables_tensor.to(device)
+
         return type(self)(
             batch_id=self.batch_id,
             requests=requests,
@@ -497,10 +500,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
             slot_indices=slot_indices,
-            block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
-            slots_tensor=slots_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -535,7 +536,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
         max_seqlen = 0
         for b in batches:
             total_batch_size += len(b)
-            total_slots += len(b.slots_tensor)
+            total_slots += len(b.slots)
             num_blocks += b.num_blocks
             speculative_length = (
                 b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
@@ -558,7 +559,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch

         input_ids = batches[0].input_ids.new_empty(total_batch_size)
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
-        slots_tensor = batches[0].slots_tensor.new_empty(total_slots)
+        slots = batches[0].slots.new_empty(total_slots)
         slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
         input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
             total_batch_size
@@ -573,8 +574,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             total_batch_size,
         )

-        slots = []
-        block_tables = []
         all_input_ids = []

         input_lengths = []
@@ -603,7 +602,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             start_index = cumulative_batch_size
             end_index = cumulative_batch_size + len(batch)
             slots_start_index = cumulative_slots
-            slots_end_index = cumulative_slots + len(batch.slots_tensor)
+            slots_end_index = cumulative_slots + len(batch.slots)

             # Copy tensors (GPU)
             input_ids[start_index:end_index] = batch.input_ids
@@ -611,7 +610,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
             top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
-            slots_tensor[slots_start_index:slots_end_index] = batch.slots_tensor
+            slots[slots_start_index:slots_end_index] = batch.slots

             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@@ -621,8 +620,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]

-            slots.extend(batch.slots)
-            block_tables.extend(batch.block_tables)
             all_input_ids.extend(batch.all_input_ids)

             input_lengths.extend(batch.input_lengths)
@@ -637,7 +634,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch

             # Update
             cumulative_batch_size += len(batch)
-            cumulative_slots += len(batch.slots_tensor)
+            cumulative_slots += len(batch.slots)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters,
@@ -662,10 +659,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
             slot_indices=slot_indices,
-            block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
-            slots_tensor=slots_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -961,7 +956,7 @@ def forward(
             cu_seqlen_prefill = batch.cu_seqlen_prefill
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
-            slots = batch.slots_tensor[batch.slot_indices]
+            slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
@@ -1000,7 +995,7 @@ def forward(
             cu_seqlen_prefill = batch.cu_seqlen_prefill
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
-            slots = batch.slots_tensor[batch.slot_indices]
+            slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
@@ -1342,7 +1337,7 @@ def generate_token(
                     ),
                     generated_text,
                     top_tokens,
-                    input_length + n_accepted_ids
+                    input_length + n_accepted_ids,
                 )

                 generations.append(generation)
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 09203ff3ec6..f5fe8f9ba16 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -211,7 +211,11 @@ def from_pb_processor(
         )

     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["IdeficsCausalLMBatch"]:
+        request_ids = [r.id for r in updated_requests]
+
         # It deletes requests from the batch. For instance when client lost connection
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
@@ -826,7 +830,7 @@ def generate_token(
                     ),
                     generated_text,
                     top_tokens,
-                    new_input_length
+                    new_input_length,
                 )

                 generations.append(generation)
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index f26ca6d67b5..5ecc64faa98 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -194,7 +194,11 @@ def from_pb(
             max_tokens=max_tokens,
         )

-    def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["MambaBatch"]:
+        request_ids = [r.id for r in updated_requests]
+
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         if len(request_ids) == len(self):
@@ -774,7 +778,7 @@ def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, in
                     ),
                     generated_text,
                     top_tokens,
-                    new_input_length
+                    new_input_length,
                 )

                 generations.append(generation)
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index dfa0431fbbb..4a998bb5dfb 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -165,7 +165,11 @@ def from_pb(
         )

     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
+    def filter(
+        self, updated_requests: List[generate_pb2.UpdatedRequest]
+    ) -> Optional["Seq2SeqLMBatch"]:
+        request_ids = [r.id for r in updated_requests]
+
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         if len(request_ids) == len(self):
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index 1c7a157a453..50c14862762 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -28,7 +28,7 @@ def from_pb(
         raise NotImplementedError

     @abstractmethod
-    def filter(self, request_ids: List[int]) -> "Batch":
+    def filter(self, updated_requests: List[generate_pb2.UpdatedRequest]) -> "Batch":
"Batch": raise NotImplementedError @classmethod diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 4693c90f216..4780ad898e7 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -151,8 +151,10 @@ def concatenate(cls, batches): return batch @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]): - batch = super().filter(request_ids) + def filter( + self, updated_requests: List[generate_pb2.UpdatedRequest] + ) -> Optional["VlmCausalLMBatch"]: + batch = super().filter(updated_requests) batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 4118b3f6fad..7eef22b20bd 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -83,7 +83,7 @@ async def FilterBatch(self, request, context): batch = self.cache.pop(request.batch_id) if batch is None: raise ValueError(f"Batch ID {request.batch_id} not found in cache.") - filtered_batch = batch.filter(request.request_ids) + filtered_batch = batch.filter(request.updated_requests) self.cache.set(filtered_batch) return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())