feat: supports openai chat completions API
prefer PR from original repo rather than fork to run CI #1408
drbh committed Jan 10, 2024
1 parent da27fbd commit 53fca4c
Showing 7 changed files with 675 additions and 216 deletions.
39 changes: 25 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions router/Cargo.toml
@@ -43,6 +43,8 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
+minijinja = "1.0.10"
+futures-util = "0.3.30"

[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
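
The new minijinja dependency is what enables the chat formatting introduced below: the model's Jinja chat_template can now be rendered server-side into a single prompt string. A minimal sketch of that kind of rendering, assuming an illustrative Message shape and template (this is not code from the commit):

// Illustrative only: render a Hugging Face-style chat_template with minijinja.
use minijinja::{context, Environment};
use serde::Serialize;

#[derive(Serialize)]
struct Message {
    role: String,
    content: String,
}

fn render_chat(chat_template: &str, messages: Vec<Message>) -> Result<String, minijinja::Error> {
    let mut env = Environment::new();
    env.add_template("chat", chat_template)?;
    // Chat templates iterate over `messages` and usually honor `add_generation_prompt`.
    env.get_template("chat")?
        .render(context! { messages => messages, add_generation_prompt => true })
}
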
46 changes: 31 additions & 15 deletions router/src/infer.rs
@@ -1,7 +1,8 @@
/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
+use crate::HubTokenizerConfig;
+use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
use crate::{Entry, Queue, Token};
-use crate::{GenerateRequest, PrefillToken};
use futures::future::try_join_all;
use nohash_hasher::IntMap;
use std::sync::{
@@ -13,7 +14,7 @@ use text_generation_client::{
};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
@@ -26,6 +27,8 @@ pub struct Infer {
validation: Validation,
/// Request queue
queue: Queue,
+/// Chat formatter
+tokenizer_config: HubTokenizerConfig,
/// Shared state
shared: Arc<Shared>,
/// Inference limit
@@ -52,6 +55,7 @@ impl Infer {
window_size: Option<u32>,
speculate: u32,
generation_health: Arc<AtomicBool>,
+tokenizer_config: HubTokenizerConfig,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding, 16, window_size, speculate);
@@ -79,6 +83,7 @@ impl Infer {
queue,
shared,
limit_concurrent_requests: semaphore,
+tokenizer_config,
}
}

@@ -87,13 +92,7 @@ impl Infer {
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
-) -> Result<
-(
-OwnedSemaphorePermit,
-UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
-),
-InferError,
-> {
+) -> Result<GenerateStreamResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let permit = self
.clone()
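
GenerateStreamResponse is defined in router/src/lib.rs rather than in this file. Judging from the tuple built at the end of generate_stream and the destructuring in generate further down, it is presumably a type alias along these lines, where ValidGenerateRequest stands for the validated request produced by Validation (a sketch inferred from this diff, not the verbatim definition):

// Presumed shape, inferred from how the tuple is produced and consumed in this diff.
pub(crate) type GenerateStreamResponse = (
    OwnedSemaphorePermit,
    ValidGenerateRequest,
    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
);
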
@@ -117,7 +116,7 @@ impl Infer {

// Append the request to the queue
self.queue.append(Entry {
-request: valid_request,
+request: valid_request.clone(),
response_tx,
span: Span::current(),
temp_span: None,
@@ -130,7 +129,19 @@ impl Infer {
self.shared.batching_task.notify_one();

// Return stream
-Ok((permit, UnboundedReceiverStream::new(response_rx)))
+Ok((
+permit,
+valid_request,
+UnboundedReceiverStream::new(response_rx),
+))
}

+/// Apply the chat template to the chat request
+#[instrument(skip_all)]
+pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
+self.tokenizer_config
+.apply_chat_template(chat)
+.map_err(InferError::TemplateError)
+}
+
/// Add a new request to the queue and return a InferResponse
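
The new apply_chat_template helper is what an OpenAI-compatible chat completions route (added elsewhere in this PR, in server.rs) can call before handing off to the ordinary generation path. A hedged caller-side sketch; the wiring and parameter handling here are assumptions, not the route from the commit:

// Hypothetical caller-side sketch: turn a ChatRequest into a regular generate call.
async fn chat_to_generate(
    infer: &Infer,
    chat: ChatRequest,
    parameters: GenerateParameters,
) -> Result<InferResponse, InferError> {
    // Render the chat messages into one prompt using the model's chat template.
    let inputs = infer.apply_chat_template(chat)?;
    // Then reuse the existing text-generation path unchanged.
    infer.generate(GenerateRequest { inputs, parameters }).await
}
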
@@ -142,7 +153,7 @@ impl Infer {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);

// Create stream and keep semaphore permit as long as generate lives
-let (_permit, mut stream) = self.generate_stream(request).await?;
+let (_permit, valid_request, mut stream) = self.generate_stream(request).await?;

// Return values
let mut result_prefill = Vec::new();
@@ -195,6 +206,7 @@ impl Infer {
(result_generated_text, result_queued, result_start)
{
Ok(InferResponse {
+prompt_token_count: valid_request.input_length,
prefill: result_prefill,
tokens: result_tokens,
generated_text,
@@ -543,9 +555,9 @@ fn send_responses(
let mut iterator = tokens_
.ids
.into_iter()
-.zip(tokens_.logprobs.into_iter())
-.zip(tokens_.texts.into_iter())
-.zip(tokens_.is_special.into_iter())
+.zip(tokens_.logprobs)
+.zip(tokens_.texts)
+.zip(tokens_.is_special)
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
@@ -636,6 +648,7 @@ pub(crate) enum InferStreamResponse {

#[derive(Debug)]
pub(crate) struct InferResponse {
+pub(crate) prompt_token_count: u32,
pub(crate) prefill: Vec<PrefillToken>,
pub(crate) tokens: Vec<Token>,
pub(crate) generated_text: GeneratedText,
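
prompt_token_count (populated from valid_request.input_length above) gives the HTTP layer what it needs to report OpenAI-style token usage. A sketch of that accounting; the Usage struct below is illustrative rather than the commit's actual response type:

// Illustrative usage accounting for an OpenAI-style chat completion response.
#[derive(serde::Serialize)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

fn usage_from(response: &InferResponse) -> Usage {
    let prompt_tokens = response.prompt_token_count;
    let completion_tokens = response.generated_text.generated_tokens;
    Usage {
        prompt_tokens,
        completion_tokens,
        total_tokens: prompt_tokens + completion_tokens,
    }
}
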
@@ -654,6 +667,8 @@ pub enum InferError {
ValidationError(#[from] ValidationError),
#[error("Incomplete generation")]
IncompleteGeneration,
+#[error("Template error: {0}")]
+TemplateError(#[from] minijinja::Error),
}

impl InferError {
@@ -663,6 +678,7 @@ impl InferError {
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
+InferError::TemplateError(_) => "template_error",
}
}
}
