Skip to content

Commit

Permalink
working example
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene committed Jun 5, 2024
1 parent 51befc3 commit da426e4
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 52 deletions.
72 changes: 58 additions & 14 deletions router/src/infer/v3/block_allocator.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
use std::cmp::min;
use std::cmp::{max, min};
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};

#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
pub blocks: Vec<u32>,
pub slots: Vec<u32>,
prompt_tokens: u32,
decode_tokens: u32,
block_allocator: BlockAllocator,
}

impl BlockAllocation {
pub(crate) fn len(&self) -> usize {
self.slots.len()
}

pub(crate) async fn extend(&mut self, current_length: u32) -> Result<(), AllocationError> {
let remaining_tokens = max(self.prompt_tokens + self.decode_tokens - current_length, 1);
self.block_allocator
.clone()
.extend(self, remaining_tokens)
.await
}
}

impl Drop for BlockAllocation {
Expand Down Expand Up @@ -48,11 +59,16 @@ impl BlockAllocator {
}
}

pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
pub(crate) async fn allocate(
&self,
prompt_tokens: u32,
decode_tokens: u32,
) -> Result<BlockAllocation, AllocationError> {
let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator
.send(BlockAllocatorCommand::Allocate {
tokens,
prompt_tokens,
decode_tokens,
response_sender,
})
.unwrap();
Expand All @@ -63,10 +79,32 @@ impl BlockAllocator {
.map(|(blocks, slots)| BlockAllocation {
blocks,
slots,
prompt_tokens,
decode_tokens,
block_allocator: self.clone(),
})
}

pub(crate) async fn extend(
&self,
block_allocation: &mut BlockAllocation,
tokens: u32,
) -> Result<(), AllocationError> {
let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator
.send(BlockAllocatorCommand::Allocate {
prompt_tokens: 0,
decode_tokens: tokens,
response_sender,
})
.unwrap();

let (blocks, slots) = response_receiver.await.unwrap()?;
block_allocation.blocks.extend(blocks);
block_allocation.slots.extend(slots);
Ok(())
}

pub(crate) fn free(&self, blocks: Vec<u32>) {
self.block_allocator
.send(BlockAllocatorCommand::Free { blocks })
Expand All @@ -86,10 +124,12 @@ async fn block_allocator_task(
match cmd {
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
BlockAllocatorCommand::Allocate {
tokens,
prompt_tokens,
decode_tokens,
response_sender,
} => {
// let tokens = 16;
let decode_tokens = min(decode_tokens, block_size);
let tokens = prompt_tokens + decode_tokens;

// Apply window size
let (required_blocks, repeats) = {
Expand All @@ -106,25 +146,21 @@ async fn block_allocator_task(
(required_blocks, repeats)
};

let tokens = tokens as usize;
let allocation = if required_blocks > free_blocks.len() as u32 {
None
Err(AllocationError::NotEnoughPages)
} else {
let blocks =
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
let mut slots = Vec::with_capacity(
(required_blocks * block_size * repeats as u32) as usize,
);

'slots: for block_id in blocks.repeat(repeats).iter() {
for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * block_size)..((block_id + 1) * block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some((blocks, slots))
Ok((blocks, slots))
};
response_sender.send(allocation).unwrap();
}
Expand All @@ -138,7 +174,15 @@ enum BlockAllocatorCommand {
blocks: Vec<u32>,
},
Allocate {
tokens: u32,
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
prompt_tokens: u32,
decode_tokens: u32,
#[allow(clippy::type_complexity)]
response_sender: oneshot::Sender<Result<(Vec<u32>, Vec<u32>), AllocationError>>,
},
}

#[derive(Error, Debug)]
pub enum AllocationError {
#[error("Not enough pages")]
NotEnoughPages,
}
16 changes: 8 additions & 8 deletions router/src/infer/v3/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,20 +295,20 @@ impl State {
break;
}

let tokens = entry.request.input_length
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;

match block_allocator.allocate(tokens).await {
None => {
let decode_tokens =
entry.request.stopping_parameters.max_new_tokens + self.speculate - 1;
match block_allocator
.allocate(entry.request.input_length, decode_tokens)
.await
{
Err(_) => {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
break 'entry_loop;
}
Some(block_allocation) => {
Ok(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
Some(block_allocation)
Expand Down
80 changes: 53 additions & 27 deletions router/src/infer/v3/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ async fn prefill(
filter_send_generations(generations, entries);

// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
let next_batch = filter_batch(client, next_batch, entries, false).await;

metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
Expand Down Expand Up @@ -288,10 +288,10 @@ async fn decode(
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);

filter_update_allocations(entries).await;
let updated = filter_update_allocations(entries).await;

// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
let next_batch = filter_batch(client, next_batch, entries, updated).await;

if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
Expand Down Expand Up @@ -322,11 +322,12 @@ async fn filter_batch(
client: &mut ShardedClient,
next_batch: Option<CachedBatch>,
entries: &IntMap<u64, Entry>,
force_update: bool,
) -> Option<CachedBatch> {
let batch = next_batch?;

// No need to filter
if batch.size as usize == entries.len() {
if batch.size as usize == entries.len() && !force_update {
return Some(batch);
}

Expand All @@ -348,6 +349,7 @@ async fn filter_batch(
.as_ref()
.map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
.unwrap_or((Vec::new(), Vec::new()));

UpdatedRequest {
id: *request_id,
blocks,
Expand Down Expand Up @@ -393,34 +395,58 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
/// Check if block allocations need to be extended
/// If we don't have enough blocks, request will be filtered with an OutOfPages error
#[instrument(skip_all)]
async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) {
entries.retain(|request_id, entry| {
if entry.block_allocation.is_none() {
return true;
}
async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
let ids: Vec<u64> = entries
.iter()
.filter_map(|(id, entry)| {
entry
.block_allocation
.as_ref()
.map(|block_allocation| {
if entry.current_length > block_allocation.len() as u32 {
// We need to re-allocate
Some(*id)
} else {
None
}
})
.unwrap_or(None)
})
.collect();

// We can unwrap since we already validated above that block_allocation is not None
let mut block_allocation = entry.block_allocation.as_ref().unwrap();
for id in ids.iter() {
// Get entry
// We can `expect` here as the request id should always be in the entries
let extension = {
let entry = entries
.get_mut(id)
.expect("ID not found in entries. This is a bug.");
entry
.block_allocation
.as_mut()
.unwrap()
.extend(entry.current_length)
.await
};

// Nothing to update
if entry.current_length <= block_allocation.len() as u32 {
return true;
}
if extension.is_err() {
let entry = entries
.remove(id)
.expect("ID not found in entries. This is a bug.");

// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::OutOfPages;
metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
tracing::error!("{err}");
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::OutOfPages;
metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
tracing::error!("{err}");

// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(err))
.unwrap_or(());
// unwrap_or is valid here as we don't care if the receiver is gone.
entry.response_tx.send(Err(err)).unwrap_or(());
}
}

false
});
// If ids is not empty, we need to update
!ids.is_empty()
}

/// Send responses through the `entry` response channel
Expand Down
3 changes: 0 additions & 3 deletions server/text_generation_server/models/flash_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,6 @@ def filter(
) -> Optional["FlashCausalLMBatch"]:
if len(updated_requests) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
if len(updated_requests) == len(self):
return self

device = self.input_ids.device

Expand Down

0 comments on commit da426e4

Please sign in to comment.