From da426e4b24bbae0b360e61484d5528611c3772fa Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 5 Jun 2024 18:47:16 +0200
Subject: [PATCH] working example

---
 router/src/infer/v3/block_allocator.rs | 72 +++++++++++++----
 router/src/infer/v3/queue.rs           | 16 ++--
 router/src/infer/v3/scheduler.rs       | 80 ++++++++++++-------
 .../models/flash_causal_lm.py           |  3 -
 4 files changed, 119 insertions(+), 52 deletions(-)

diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
index 811efb262ec..3e7cde893ed 100644
--- a/router/src/infer/v3/block_allocator.rs
+++ b/router/src/infer/v3/block_allocator.rs
@@ -1,10 +1,13 @@
-use std::cmp::min;
+use std::cmp::{max, min};
+use thiserror::Error;
 use tokio::sync::{mpsc, oneshot};
 
 #[derive(Debug, Clone)]
 pub(crate) struct BlockAllocation {
     pub blocks: Vec<u32>,
     pub slots: Vec<u32>,
+    prompt_tokens: u32,
+    decode_tokens: u32,
     block_allocator: BlockAllocator,
 }
 
@@ -12,6 +15,14 @@ impl BlockAllocation {
     pub(crate) fn len(&self) -> usize {
         self.slots.len()
     }
+
+    pub(crate) async fn extend(&mut self, current_length: u32) -> Result<(), AllocationError> {
+        let remaining_tokens = max(self.prompt_tokens + self.decode_tokens - current_length, 1);
+        self.block_allocator
+            .clone()
+            .extend(self, remaining_tokens)
+            .await
+    }
 }
 
 impl Drop for BlockAllocation {
@@ -48,11 +59,16 @@ impl BlockAllocator {
         }
     }
 
-    pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
+    pub(crate) async fn allocate(
+        &self,
+        prompt_tokens: u32,
+        decode_tokens: u32,
+    ) -> Result<BlockAllocation, AllocationError> {
         let (response_sender, response_receiver) = oneshot::channel();
         self.block_allocator
             .send(BlockAllocatorCommand::Allocate {
-                tokens,
+                prompt_tokens,
+                decode_tokens,
                 response_sender,
             })
             .unwrap();
@@ -63,10 +79,32 @@ impl BlockAllocator {
             .map(|(blocks, slots)| BlockAllocation {
                 blocks,
                 slots,
+                prompt_tokens,
+                decode_tokens,
                 block_allocator: self.clone(),
             })
     }
 
+    pub(crate) async fn extend(
+        &self,
+        block_allocation: &mut BlockAllocation,
+        tokens: u32,
+    ) -> Result<(), AllocationError> {
+        let (response_sender, response_receiver) = oneshot::channel();
+        self.block_allocator
+            .send(BlockAllocatorCommand::Allocate {
+                prompt_tokens: 0,
+                decode_tokens: tokens,
+                response_sender,
+            })
+            .unwrap();
+
+        let (blocks, slots) = response_receiver.await.unwrap()?;
+        block_allocation.blocks.extend(blocks);
+        block_allocation.slots.extend(slots);
+        Ok(())
+    }
+
     pub(crate) fn free(&self, blocks: Vec<u32>) {
         self.block_allocator
             .send(BlockAllocatorCommand::Free { blocks })
@@ -86,10 +124,12 @@ async fn block_allocator_task(
         match cmd {
             BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
             BlockAllocatorCommand::Allocate {
-                tokens,
+                prompt_tokens,
+                decode_tokens,
                 response_sender,
             } => {
-                // let tokens = 16;
+                let decode_tokens = min(decode_tokens, block_size);
+                let tokens = prompt_tokens + decode_tokens;
 
                 // Apply window size
                 let (required_blocks, repeats) = {
@@ -106,9 +146,8 @@ async fn block_allocator_task(
                     (required_blocks, repeats)
                 };
 
-                let tokens = tokens as usize;
                 let allocation = if required_blocks > free_blocks.len() as u32 {
-                    None
+                    Err(AllocationError::NotEnoughPages)
                 } else {
                     let blocks =
                         free_blocks.split_off(free_blocks.len() - required_blocks as usize);
@@ -115,16 +154,13 @@ async fn block_allocator_task(
                     let mut slots = Vec::with_capacity(
                         (required_blocks * block_size * repeats as u32) as usize,
                     );
 
-                    'slots: for block_id in blocks.repeat(repeats).iter() {
+                    for block_id in blocks.repeat(repeats).iter() {
                         for s in (block_id * block_size)..((block_id + 1) * block_size) {
                             slots.push(s);
-                            if slots.len() == tokens {
-                                break 'slots;
-                            }
                         }
                     }
-                    Some((blocks, slots))
+                    Ok((blocks, slots))
                 };
                 response_sender.send(allocation).unwrap();
             }
@@ -138,7 +174,15 @@ enum BlockAllocatorCommand {
         blocks: Vec<u32>,
     },
     Allocate {
-        tokens: u32,
-        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
+        prompt_tokens: u32,
+        decode_tokens: u32,
+        #[allow(clippy::type_complexity)]
+        response_sender: oneshot::Sender<Result<(Vec<u32>, Vec<u32>), AllocationError>>,
     },
 }
+
+#[derive(Error, Debug)]
+pub enum AllocationError {
+    #[error("Not enough pages")]
+    NotEnoughPages,
+}
diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs
index 1ac06ae97c9..9a7b1084bbe 100644
--- a/router/src/infer/v3/queue.rs
+++ b/router/src/infer/v3/queue.rs
@@ -295,20 +295,20 @@ impl State {
                         break;
                     }
 
-                    let tokens = entry.request.input_length
-                        + entry.request.stopping_parameters.max_new_tokens
-                        + self.speculate
-                        - 1;
-
-                    match block_allocator.allocate(tokens).await {
-                        None => {
+                    let decode_tokens =
+                        entry.request.stopping_parameters.max_new_tokens + self.speculate - 1;
+                    match block_allocator
+                        .allocate(entry.request.input_length, decode_tokens)
+                        .await
+                    {
+                        Err(_) => {
                             // Entry is over budget
                             // Add it back to the front
                             tracing::debug!("Over budget: not enough free blocks");
                             self.entries.push_front((id, entry));
                             break 'entry_loop;
                         }
-                        Some(block_allocation) => {
+                        Ok(block_allocation) => {
                             tracing::debug!("Allocation: {block_allocation:?}");
                             max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
                             Some(block_allocation)
diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs
index faa899ecd1d..b76c5c50eab 100644
--- a/router/src/infer/v3/scheduler.rs
+++ b/router/src/infer/v3/scheduler.rs
@@ -247,7 +247,7 @@ async fn prefill(
             filter_send_generations(generations, entries);
 
             // Filter next batch and remove requests that were stopped
-            let next_batch = filter_batch(client, next_batch, entries).await;
+            let next_batch = filter_batch(client, next_batch, entries, false).await;
 
             metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
             metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
@@ -288,10 +288,10 @@ async fn decode(
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
-            filter_update_allocations(entries).await;
+            let updated = filter_update_allocations(entries).await;
 
             // Filter next batch and remove requests that were stopped
-            let next_batch = filter_batch(client, next_batch, entries).await;
+            let next_batch = filter_batch(client, next_batch, entries, updated).await;
 
             if let Some(concat_duration) = timings.concat {
                 metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
@@ -322,11 +322,12 @@ async fn filter_batch(
     client: &mut ShardedClient,
     next_batch: Option<CachedBatch>,
     entries: &IntMap<u64, Entry>,
+    force_update: bool,
 ) -> Option<CachedBatch> {
     let batch = next_batch?;
 
     // No need to filter
-    if batch.size as usize == entries.len() {
+    if batch.size as usize == entries.len() && !force_update {
         return Some(batch);
     }
 
@@ -348,6 +349,7 @@ async fn filter_batch(
                 .as_ref()
                 .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
                 .unwrap_or((Vec::new(), Vec::new()));
+
             UpdatedRequest {
                 id: *request_id,
                 blocks,
@@ -393,34 +395,58 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
     });
 }
 
-async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) {
-    entries.retain(|request_id, entry| {
-        if entry.block_allocation.is_none() {
-            return true;
-        }
+async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
+    let ids: Vec<u64> = entries
+        .iter()
+        .filter_map(|(id, entry)| {
+            entry
+                .block_allocation
+                .as_ref()
+                .map(|block_allocation| {
+                    if entry.current_length > block_allocation.len() as u32 {
+                        // We need to re-allocate
+                        Some(*id)
+                    } else {
+                        None
+                    }
+                })
+                .unwrap_or(None)
+        })
+        .collect();
 
-        // We can unwrap since we already validated above that block_allocation is not None
-        let mut block_allocation = entry.block_allocation.as_ref().unwrap();
+    for id in ids.iter() {
+        // Get entry
+        // We can `expect` here as the request id should always be in the entries
+        let extension = {
+            let entry = entries
+                .get_mut(id)
+                .expect("ID not found in entries. This is a bug.");
+            entry
+                .block_allocation
+                .as_mut()
+                .unwrap()
+                .extend(entry.current_length)
+                .await
+        };
 
-        // Nothing to update
-        if entry.current_length <= block_allocation.len() as u32 {
-            return true;
-        }
+        if extension.is_err() {
+            let entry = entries
+                .remove(id)
+                .expect("ID not found in entries. This is a bug.");
 
-        // Create and enter a span to link this function back to the entry
-        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
-        let err = InferError::OutOfPages;
-        metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
-        tracing::error!("{err}");
+            // Create and enter a span to link this function back to the entry
+            let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+            let err = InferError::OutOfPages;
+            metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
+            tracing::error!("{err}");
 
-        // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry
-            .response_tx
-            .send(Err(err))
-            .unwrap_or(());
+            // unwrap_or is valid here as we don't care if the receiver is gone.
+            entry.response_tx.send(Err(err)).unwrap_or(());
+        }
+    }
 
-        false
-    });
+    // If ids is not empty, we need to update
+    !ids.is_empty()
 }
 
 /// Send responses through the `entry` response channel
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d9a22755923..b7326c7ca41 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -399,9 +399,6 @@ def filter(
     ) -> Optional["FlashCausalLMBatch"]:
         if len(updated_requests) == 0:
             raise ValueError("Batch must have at least one request")
-        # We assume that if len(requests) == len(self) then the requests are the same
-        if len(updated_requests) == len(self):
-            return self
 
         device = self.input_ids.device
 
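Note (reviewer sketch, not part of the patch): with this change, `allocate(prompt_tokens, decode_tokens)` reserves pages for the whole prompt but caps the decode reservation at a single page, and the scheduler later grows an allocation through `BlockAllocation::extend`; a request whose allocation can no longer be grown is failed with `AllocationError::NotEnoughPages` / `InferError::OutOfPages` instead of being admitted with a full worst-case reservation. Below is a minimal standalone Rust sketch of the up-front page arithmetic only, assuming an illustrative 16-token page size; the helper name `upfront_pages` is ours, and the sliding-window `repeats` handling that the diff leaves unchanged is omitted.

    // Sketch only: mirrors the `min(decode_tokens, block_size)` cap and the
    // page rounding performed in `block_allocator_task`; `block_size` is an
    // assumed example value, not a constant taken from the codebase.
    fn upfront_pages(prompt_tokens: u32, decode_tokens: u32, block_size: u32) -> u32 {
        // The decode reservation is capped at one page; further pages are added
        // lazily through `BlockAllocation::extend` as `current_length` grows.
        let decode_tokens = decode_tokens.min(block_size);
        let tokens = prompt_tokens + decode_tokens;
        // Ceiling division: a partially filled page still costs a full page of slots,
        // since the patched loop now fills every slot of every reserved block.
        (tokens + block_size - 1) / block_size
    }

    fn main() {
        // A 1000-token prompt with a 200-token decode budget and 16-token pages
        // reserves ceil((1000 + 16) / 16) = 64 pages at admission time.
        assert_eq!(upfront_pages(1000, 200, 16), 64);
    }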