feat: supports openai chat completions API #1408

Closed
wants to merge 9 commits
39 changes: 25 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions router/Cargo.toml
@@ -43,6 +43,8 @@ utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
hf-hub = "0.3.1"
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = "1.0.10"
futures-util = "0.3.30"

[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
46 changes: 45 additions & 1 deletion router/src/infer.rs
@@ -1,7 +1,8 @@
/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
use crate::HubTokenizerConfig;
use crate::{ChatRequest, GenerateRequest, PrefillToken};
use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken};
use futures::future::try_join_all;
use nohash_hasher::IntMap;
use std::sync::{
@@ -26,6 +27,8 @@ pub struct Infer {
validation: Validation,
/// Request queue
queue: Queue,
/// Tokenizer config, holding the optional chat template
tokenizer_config: HubTokenizerConfig,
/// Shared state
shared: Arc<Shared>,
/// Inference limit
@@ -52,6 +55,7 @@ impl Infer {
window_size: Option<u32>,
speculate: u32,
generation_health: Arc<AtomicBool>,
tokenizer_config: HubTokenizerConfig,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding, 16, window_size, speculate);
@@ -79,6 +83,7 @@
queue,
shared,
limit_concurrent_requests: semaphore,
tokenizer_config,
}
}

@@ -133,6 +138,28 @@ impl Infer {
Ok((permit, UnboundedReceiverStream::new(response_rx)))
}

/// Apply the chat template to the chat request
#[instrument(skip_all)]
pub(crate) fn apply_chat_template(
&self,
chat: ChatRequest,
) -> Result<String, ChatTemplateError> {
let mut env = minijinja::Environment::new();
let chat_template = self
.tokenizer_config
.chat_template
.as_ref()
.ok_or(ChatTemplateError::TemplateNotFound)?;
env.add_template("_", chat_template)
.map_err(|e| ChatTemplateError::TemplateError(e))?;
let jinja_tmpl = env
.get_template("_")
.map_err(|e| ChatTemplateError::TemplateError(e))?;
jinja_tmpl
.render(chat)
.map_err(|e| ChatTemplateError::TemplateError(e))
}
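
For illustration only (not part of this diff): the sketch below drives minijinja the same way apply_chat_template does, so the rendering step can be checked in isolation. The ChatML-style template string and the render_chat_example helper are invented for the example; real templates come from a model's tokenizer_config.json.

use minijinja::{context, Environment};

fn render_chat_example() -> Result<String, minijinja::Error> {
    let mut env = Environment::new();
    // Hypothetical template: prefix every message with its role, ChatML-style.
    env.add_template(
        "_",
        "{% for message in messages %}<|{{ message.role }}|>\n{{ message.content }}\n{% endfor %}",
    )?;
    // Render against a message list shaped like ChatRequest::messages.
    env.get_template("_")?.render(context! {
        messages => vec![
            context! { role => "system", content => "You are a helpful assistant." },
            context! { role => "user", content => "What is Deep Learning?" },
        ],
    })
}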

/// Add a new request to the queue and return an InferResponse
#[instrument(skip_all)]
pub(crate) async fn generate(
@@ -666,3 +693,20 @@ impl InferError {
}
}
}

#[derive(Debug, Error)]
pub enum ChatTemplateError {
#[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error),
#[error("Template not found")]
TemplateNotFound,
}

impl ChatTemplateError {
pub(crate) fn error_type(&self) -> &str {
match self {
ChatTemplateError::TemplateError(_) => "template_error",
ChatTemplateError::TemplateNotFound => "template_not_found",
}
}
}
182 changes: 182 additions & 0 deletions router/src/lib.rs
@@ -20,6 +20,12 @@ pub struct HubModelInfo {
pub pipeline_tag: Option<String>,
}

#[derive(Clone, Deserialize)]
pub struct HubTokenizerConfig {
#[serde(default)]
pub chat_template: Option<String>,
}
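
A minimal sketch, not from this PR: HubTokenizerConfig is intended to be read from a model repo's tokenizer_config.json. Serde ignores the file's many unrelated keys, and a missing chat_template key simply yields None. The helper name and the use of serde_json are assumptions for the example.

fn read_tokenizer_config(path: &std::path::Path) -> Option<HubTokenizerConfig> {
    // A missing file or malformed JSON simply means "no chat template available".
    let content = std::fs::read_to_string(path).ok()?;
    serde_json::from_str::<HubTokenizerConfig>(&content).ok()
}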

#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct Info {
/// Model info
@@ -165,6 +171,182 @@ fn default_parameters() -> GenerateParameters {
}
}

#[derive(Clone, Deserialize, Serialize)]
pub(crate) struct ChatCompletion {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub system_fingerprint: String,
pub choices: Vec<ChatCompletionComplete>,
pub usage: Usage,
}

#[derive(Clone, Deserialize, Serialize)]
pub(crate) struct ChatCompletionComplete {
pub index: u32,
pub message: Message,
pub logprobs: Option<Vec<f32>>,
pub finish_reason: Option<String>,
}

#[derive(Clone, Deserialize, Serialize)]
pub(crate) struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}

impl ChatCompletion {
pub(crate) fn new(
output: String,
created: u64,
details: Details,
prompt_character_count: u32,
) -> Self {
Self {
id: "".to_string(),
object: "text_completion".to_string(),
created,
model: "".to_string(),
system_fingerprint: "".to_string(),
choices: vec![ChatCompletionComplete {
index: 0,
message: Message {
role: "assistant".to_string(),
content: output,
},
logprobs: None,
finish_reason: None,
}],
usage: Usage {
prompt_tokens: prompt_character_count,
completion_tokens: details.generated_tokens,
total_tokens: prompt_character_count + details.generated_tokens,
},
}
}
}
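
A hedged sketch, not from this PR: building and serializing a non-streaming response with the constructor above. The literal values and the use of serde_json are invented for the example; note that prompt_tokens is fed with a character count rather than a token count.

fn example_chat_completion_json(details: Details) -> Result<String, serde_json::Error> {
    let completion = ChatCompletion::new(
        "Deep Learning is a subset of machine learning.".to_string(), // generated text
        1706000000, // creation time as a unix timestamp in seconds
        details,    // generation details from the inference response
        27,         // prompt character count, reported as prompt_tokens
    );
    serde_json::to_string_pretty(&completion)
}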

#[derive(Clone, Deserialize, Serialize)]
pub(crate) struct ChatCompletionChunk {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub system_fingerprint: String,
pub choices: Vec<ChatCompletionChoice>,
}

#[derive(Clone, Deserialize, Serialize)]
pub(crate) struct ChatCompletionChoice {
pub index: u32,
pub delta: ChatCompletionDelta,
pub logprobs: Option<Vec<f32>>,
pub finish_reason: Option<String>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct ChatCompletionDelta {
pub role: String,
pub content: String,
}

impl ChatCompletionChunk {
pub(crate) fn new(delta: String, created: u64, index: u32) -> Self {
Self {
id: "".to_string(),
object: "text_completion".to_string(),
created,
model: "".to_string(),
system_fingerprint: "".to_string(),
choices: vec![ChatCompletionChoice {
index,
delta: ChatCompletionDelta {
role: "assistant".to_string(),
content: delta,
},
logprobs: None,
finish_reason: None,
}],
}
}
}
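
For illustration only (not part of this diff): wrapping a single streamed token into a chunk and framing it as a server-sent event. The data: framing and the helper name are assumptions about the surrounding handler, not code from this PR.

fn example_sse_line(token_text: String, index: u32) -> Result<String, serde_json::Error> {
    // Seconds since the unix epoch, as expected by the `created` field.
    let created = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0);
    let chunk = ChatCompletionChunk::new(token_text, created, index);
    Ok(format!("data: {}\n\n", serde_json::to_string(&chunk)?))
}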

fn default_request_messages() -> Vec<Message> {
vec![Message {
role: "system".to_string(),
content: "My name is David and I".to_string(),
}]
}

#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct ChatRequest {
/// UNUSED
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
#[schema(example = "bigscience/bloom-560m")]
pub model: String, /* NOTE: UNUSED */

/// A list of messages comprising the conversation so far.
#[serde(default = "default_request_messages")]
pub messages: Vec<Message>,

/// UNUSED
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim.
#[serde(default)]
pub frequency_penalty: Option<f32>,

/// UNUSED
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
/// result in a ban or exclusive selection of the relevant token.
#[serde(default)]
pub logit_bias: Option<Vec<f32>>,

/// UNUSED
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
/// output token returned in the content of message. This option is currently not available on the gpt-4-vision-preview
/// model.
#[serde(default)]
pub logprobs: Option<u32>,

/// UNUSED
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
/// an associated log probability. logprobs must be set to true if this parameter is used.
#[serde(default)]
pub top_logprobs: Option<u32>,

/// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)]
pub max_tokens: Option<u32>,

/// UNUSED
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
#[serde(default)]
pub n: Option<u32>,

/// UNUSED
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
/// increasing the model's likelihood to talk about new topics
#[serde(default)]
pub presence_penalty: Option<f32>,

#[serde(default = "bool::default")]
pub stream: bool,
}
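
A minimal sketch, not from this PR, of the smallest request body the new struct accepts: model is required by serde even though it is marked UNUSED, the other UNUSED fields are parsed but not wired to generation yet, and omitted optional fields fall back to their defaults. The use of serde_json is an assumption for the example.

fn example_parse_request() -> Result<ChatRequest, serde_json::Error> {
    serde_json::from_str(
        r#"{
            "model": "tgi",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is Deep Learning?"}
            ],
            "max_tokens": 32,
            "stream": true
        }"#,
    )
}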

#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct Message {
#[schema(example = "system")]
pub role: String,
#[schema(example = "My name is David and I")]
pub content: String,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateRequest {
#[schema(example = "My name is Olivier and I")]