Fixing some simple stuff, adding speculate to budget.
Narsil committed Dec 5, 2023
1 parent 5aa3a01 commit 09839b0
Showing 6 changed files with 56 additions and 26 deletions.
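The central change is that each shard now reports a required speculate value (the number of extra tokens a speculative-decoding step may produce), and the router's queue adds that value to every entry's decode-token budget. The snippet below is a condensed, standalone sketch of that budget arithmetic as it appears in router/src/queue.rs after this commit; the function name decode_budget and the flattened parameters are illustrative, not the crate's real API.

use std::cmp::min;

/// Condensed sketch of the per-entry decode-token budget after this commit:
/// the base budget (max_new_tokens, clamped by the sliding window when one is
/// configured) grows by `speculate` and is then rounded up to the block size.
fn decode_budget(
    input_length: u32,
    max_new_tokens: u32,
    window_size: Option<u32>,
    block_size: u32,
    speculate: u32,
) -> u32 {
    let base = match window_size {
        None => max_new_tokens,
        Some(window_size) => min(
            window_size.saturating_sub(input_length),
            max_new_tokens,
        ),
    } + speculate;

    // Round up to a multiple of block_size, mirroring the "pad to block size"
    // step in next_batch.
    ((base + block_size - 1) / block_size) * block_size
}

fn main() {
    // Example: 1000 prompt tokens, a 1024-token sliding window, up to 100 new
    // tokens, block size 16, speculate 5:
    // min(1024 - 1000, 100) = 24, plus 5 = 29, padded to 32.
    assert_eq!(decode_budget(1000, 100, Some(1024), 16, 5), 32);
    println!("budget example: {}", decode_budget(1000, 100, Some(1024), 16, 5));
}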
2 changes: 1 addition & 1 deletion proto/generate.proto
@@ -32,7 +32,7 @@ message InfoResponse {
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
optional uint32 speculate = 5;
uint32 speculate = 5;
}

/// Empty request
3 changes: 2 additions & 1 deletion router/src/infer.rs
@@ -50,10 +50,11 @@ impl Infer {
max_concurrent_requests: usize,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding, 16, window_size);
let queue = Queue::new(requires_padding, 16, window_size, speculate);
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
53 changes: 40 additions & 13 deletions router/src/queue.rs
@@ -34,7 +34,12 @@ pub(crate) struct Queue {
}

impl Queue {
pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
pub(crate) fn new(
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

@@ -43,6 +48,7 @@ impl Queue {
requires_padding,
block_size,
window_size,
speculate,
queue_receiver,
));

@@ -91,9 +97,10 @@ async fn queue_task(
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
let mut state = State::new(requires_padding, block_size, window_size);
let mut state = State::new(requires_padding, block_size, window_size, speculate);

while let Some(cmd) = receiver.recv().await {
match cmd {
@@ -136,17 +143,26 @@ struct State {

/// Sliding window
window_size: Option<u32>,

/// Speculation amount
speculate: u32,
}

impl State {
fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
fn new(
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
) -> Self {
Self {
entries: VecDeque::with_capacity(128),
next_id: 0,
next_batch_id: 0,
requires_padding,
block_size,
window_size,
speculate,
}
}

@@ -221,7 +237,7 @@ impl State {
window_size.saturating_sub(entry.request.input_length),
entry.request.stopping_parameters.max_new_tokens,
),
};
} + self.speculate;

// pad to block size
decode_tokens +=
@@ -359,7 +375,7 @@ mod tests {

#[test]
fn test_append() {
let mut state = State::new(false, 1, None);
let mut state = State::new(false, 1, None, 0);
let (entry, _guard) = default_entry();

assert_eq!(state.next_id, 0);
@@ -375,15 +391,15 @@

#[test]
fn test_next_batch_empty() {
let mut state = State::new(false, 1, None);
let mut state = State::new(false, 1, None, 0);

assert!(state.next_batch(None, 1, 1).is_none());
assert!(state.next_batch(Some(1), 1, 1).is_none());
}

#[test]
fn test_next_batch_min_size() {
let mut state = State::new(false, 1, None);
let mut state = State::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@@ -415,7 +431,7 @@ mod tests {

#[test]
fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None);
let mut state = State::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@@ -448,22 +464,22 @@ mod tests {

#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
let (entry, _guard) = default_entry();
queue.append(entry);
}

#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);

assert!(queue.next_batch(None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
}

#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@@ -496,7 +512,7 @@ mod tests {

#[tokio::test]
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@@ -519,9 +535,20 @@ mod tests {
assert_eq!(batch.size, 2);
}

#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);

assert!(queue.next_batch(None, 1, 1).await.is_none());
}

#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
let (entry, _) = default_entry();
queue.append(entry);

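The new test_queue_next_batch_token_speculate above exercises the effect of this budget change: with block_size = 1 and speculate = 2, an entry that previously cost a single decode token now costs three, so it no longer fits a token budget of 1 and next_batch returns None. A standalone rundown of that arithmetic follows, assuming (as the sibling tests suggest, though it is not shown in this diff) that the default test entry requests one new token.

fn main() {
    // Parameters from Queue::new(false, 1, None, 2) and next_batch(None, 1, 1)
    // in the new test.
    let block_size: u32 = 1;
    let speculate: u32 = 2;
    let token_budget: u32 = 1;

    // Assumed request size: one new token (hypothetical; not taken from this diff).
    let max_new_tokens: u32 = 1;

    // Decode cost after this commit: (max_new_tokens + speculate), block-padded.
    let decode_tokens =
        ((max_new_tokens + speculate + block_size - 1) / block_size) * block_size;

    assert_eq!(decode_tokens, 3);
    assert!(decode_tokens > token_budget); // 3 > 1, so next_batch yields None
    println!("decode cost {} exceeds budget {}", decode_tokens, token_budget);
}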
1 change: 1 addition & 0 deletions router/src/server.rs
@@ -596,6 +596,7 @@ pub async fn run(
max_concurrent_requests,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
);

16 changes: 5 additions & 11 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -938,8 +938,6 @@ def generate_token(
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
# next_token_ids,
# next_token_logprobs,
accepted_ids,
batch_top_token_ids,
batch_top_token_logprobs,
@@ -957,8 +955,6 @@ def generate_token(
do_sample,
seed,
top_n_tokens,
# next_token_id,
# next_token_logprob,
n_accepted_ids,
top_token_ids,
top_token_logprobs,
@@ -968,21 +964,18 @@ def generate_token(
_next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]

next_token_texts = []
left = 0
for j in range(index, index + n_accepted_ids):
# Generated token
all_input_ids.append(next_token_ids[j])
next_token_id = next_token_ids[j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
index += n_accepted_ids

# Evaluate stopping criteria

left = 0
for j, next_token_id in enumerate(_next_token_ids):
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
@@ -994,6 +987,7 @@ def generate_token(
break
else:
stopped = False
index += n_accepted_ids
_next_token_ids = _next_token_ids[:len(_next_token_ids) - left]

# Shard generations
@@ -1003,7 +997,7 @@ def generate_token(
# Decode generated tokens
# Remove potentially accepted ids that do not respect
# the stopping_criteria
_ids = all_input_ids[:len(all_input_ids)-left]
_ids = all_input_ids
output_text, _, _ = self.decode_token(
_ids,
prefix_offset=len(_ids)
7 changes: 7 additions & 0 deletions server/text_generation_server/models/model.py
@@ -6,6 +6,7 @@
from transformers import PreTrainedTokenizerBase, PretrainedConfig

from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse

B = TypeVar("B", bound=Batch)
@@ -22,6 +23,7 @@ def __init__(
rank: int = 0,
world_size: int = 1,
sliding_window: Optional[int] = None,
speculate: Optional[int] = None,
):
self.model = model.eval()
self.tokenizer = tokenizer
@@ -33,6 +35,10 @@ def __init__(
self.world_size = world_size
self.sliding_window = sliding_window

if speculate is None:
speculate = get_speculate()
self.speculate = speculate

self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
@@ -50,6 +56,7 @@ def info(self) -> InfoResponse:
dtype=str(self.dtype),
device_type=self.device.type,
window_size=self.sliding_window,
speculate=self.speculate
)

@property
