Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: supports openai chat completions API #1427

Merged
merged 10 commits into from
Jan 16, 2024
39 changes: 25 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = "1.0.10"
futures-util = "0.3.30"

[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
Expand Down
51 changes: 38 additions & 13 deletions router/src/infer.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
use crate::HubTokenizerConfig;
use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken};
use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template};
use nohash_hasher::IntMap;
use std::sync::{
atomic::{AtomicBool, Ordering},
Expand All @@ -13,7 +15,7 @@ use text_generation_client::{
};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
Expand All @@ -30,6 +32,8 @@ pub struct Infer {
shared: Arc<Shared>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
/// Chat template
template: Option<Template<'static, 'static>>,
}

/// Infer shared state
Expand All @@ -52,6 +56,7 @@ impl Infer {
window_size: Option<u32>,
speculate: u32,
generation_health: Arc<AtomicBool>,
tokenizer_config: HubTokenizerConfig,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding, 16, window_size, speculate);
Expand All @@ -74,11 +79,21 @@ impl Infer {
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));

let template = tokenizer_config.chat_template.map(|t| {
let env = Box::new(Environment::new());
let template_str = t.into_boxed_str();
// leaking env and template_str as read-only, static resources for performance.
Box::leak(env)
.template_from_str(Box::leak(template_str))
.unwrap()
});

Self {
validation,
queue,
shared,
limit_concurrent_requests: semaphore,
template,
}
}

Expand All @@ -87,14 +102,7 @@ impl Infer {
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
) -> Result<
(
OwnedSemaphorePermit,
u32,
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
),
InferError,
> {
) -> Result<GenerateStreamResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let permit = self
.clone()
Expand Down Expand Up @@ -139,6 +147,20 @@ impl Infer {
))
}

/// Render a chat request through the model's chat template.
///
/// Returns the fully templated prompt string, or a
/// [`InferError::TemplateError`] when no template was configured for this
/// model or when rendering the request against the template fails.
#[instrument(skip_all)]
pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
    // Guard: a model without a chat template cannot serve chat requests.
    let template = self
        .template
        .as_ref()
        .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?;

    // Render the request; on failure, record the failure metric and log
    // before surfacing the error to the caller.
    template.render(chat).map_err(|e| {
        metrics::increment_counter!("tgi_request_failure", "err" => "template");
        tracing::error!("{e}");
        InferError::TemplateError(e)
    })
}

/// Add a new request to the queue and return a InferResponse
#[instrument(skip_all)]
pub(crate) async fn generate(
Expand Down Expand Up @@ -550,9 +572,9 @@ fn send_responses(
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs.into_iter())
.zip(tokens_.texts.into_iter())
.zip(tokens_.is_special.into_iter())
.zip(tokens_.logprobs)
.zip(tokens_.texts)
.zip(tokens_.is_special)
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
Expand Down Expand Up @@ -665,6 +687,8 @@ pub enum InferError {
ValidationError(#[from] ValidationError),
#[error("Incomplete generation")]
IncompleteGeneration,
#[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error),
}

impl InferError {
Expand All @@ -674,6 +698,7 @@ impl InferError {
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
InferError::TemplateError(_) => "template_error",
}
}
}
Loading
Loading