From 53fca4c48a3edc0547727452db8d097be5adc804 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 10 Jan 2024 10:08:51 -0500 Subject: [PATCH] feat: supports openai chat completions API prefer PR from original repo rather than fork to run CI https://github.com/huggingface/text-generation-inference/pull/1408 --- Cargo.lock | 39 ++-- router/Cargo.toml | 2 + router/src/infer.rs | 46 +++-- router/src/lib.rs | 233 +++++++++++++++++++++++- router/src/main.rs | 380 +++++++++++++++++++++------------------ router/src/server.rs | 189 ++++++++++++++++++- router/src/validation.rs | 2 +- 7 files changed, 675 insertions(+), 216 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f826ea34206..3baff665c8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -773,9 +773,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", "futures-sink", @@ -783,9 +783,9 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" [[package]] name = "futures-executor" @@ -800,15 +800,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" [[package]] name = "futures-macro" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", @@ -817,21 +817,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" [[package]] name = "futures-task" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" [[package]] name = "futures-util" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-channel", "futures-core", @@ -1373,6 +1373,15 @@ dependencies = [ "unicase", ] +[[package]] +name = "minijinja" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "208758577ef2c86cf5dd3e85730d161413ec3284e2d73b2ef65d9a24d9971bcb" +dependencies = [ + "serde", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -2807,10 +2816,12 @@ dependencies = [ 
"axum-tracing-opentelemetry", "clap", "futures", + "futures-util", "hf-hub", "init-tracing-opentelemetry", "metrics", "metrics-exporter-prometheus", + "minijinja", "ngrok", "nohash-hasher", "opentelemetry", diff --git a/router/Cargo.toml b/router/Cargo.toml index 5ccdb0cd48f..f6f16dae08c 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -43,6 +43,8 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +minijinja = "1.0.10" +futures-util = "0.3.30" [build-dependencies] vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } diff --git a/router/src/infer.rs b/router/src/infer.rs index bf5920dac7e..0c058f122df 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -1,7 +1,8 @@ /// Batching and inference logic use crate::validation::{Validation, ValidationError}; +use crate::HubTokenizerConfig; +use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken}; use crate::{Entry, Queue, Token}; -use crate::{GenerateRequest, PrefillToken}; use futures::future::try_join_all; use nohash_hasher::IntMap; use std::sync::{ @@ -13,7 +14,7 @@ use text_generation_client::{ }; use thiserror::Error; use tokio::sync::mpsc::error::SendError; -use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; +use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError}; use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; @@ -26,6 +27,8 @@ pub struct Infer { validation: Validation, /// Request queue queue: Queue, + /// Chat formatter + tokenizer_config: HubTokenizerConfig, /// Shared state shared: Arc, /// Inference limit @@ -52,6 +55,7 @@ impl Infer { window_size: Option, speculate: u32, generation_health: Arc, + tokenizer_config: HubTokenizerConfig, ) -> Self { // Infer shared state let queue = Queue::new(requires_padding, 16, window_size, speculate); @@ -79,6 +83,7 @@ impl Infer { queue, shared, limit_concurrent_requests: semaphore, + tokenizer_config, } } @@ -87,13 +92,7 @@ impl Infer { pub(crate) async fn generate_stream( &self, request: GenerateRequest, - ) -> Result< - ( - OwnedSemaphorePermit, - UnboundedReceiverStream>, - ), - InferError, - > { + ) -> Result { // Limit concurrent requests by acquiring a permit from the semaphore let permit = self .clone() @@ -117,7 +116,7 @@ impl Infer { // Append the request to the queue self.queue.append(Entry { - request: valid_request, + request: valid_request.clone(), response_tx, span: Span::current(), temp_span: None, @@ -130,7 +129,19 @@ impl Infer { self.shared.batching_task.notify_one(); // Return stream - Ok((permit, UnboundedReceiverStream::new(response_rx))) + Ok(( + permit, + valid_request, + UnboundedReceiverStream::new(response_rx), + )) + } + + /// Apply the chat template to the chat request + #[instrument(skip_all)] + pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result { + self.tokenizer_config + .apply_chat_template(chat) + .map_err(InferError::TemplateError) } /// Add a new request to the queue and return a InferResponse @@ -142,7 +153,7 @@ impl Infer { let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); // Create stream and keep semaphore permit as long as generate lives - let (_permit, mut stream) = self.generate_stream(request).await?; + let (_permit, valid_request, 
mut stream) = self.generate_stream(request).await?; // Return values let mut result_prefill = Vec::new(); @@ -195,6 +206,7 @@ impl Infer { (result_generated_text, result_queued, result_start) { Ok(InferResponse { + prompt_token_count: valid_request.input_length, prefill: result_prefill, tokens: result_tokens, generated_text, @@ -543,9 +555,9 @@ fn send_responses( let mut iterator = tokens_ .ids .into_iter() - .zip(tokens_.logprobs.into_iter()) - .zip(tokens_.texts.into_iter()) - .zip(tokens_.is_special.into_iter()) + .zip(tokens_.logprobs) + .zip(tokens_.texts) + .zip(tokens_.is_special) .enumerate() .peekable(); while let Some((i, (((id, logprob), text), special))) = iterator.next() { @@ -636,6 +648,7 @@ pub(crate) enum InferStreamResponse { #[derive(Debug)] pub(crate) struct InferResponse { + pub(crate) prompt_token_count: u32, pub(crate) prefill: Vec, pub(crate) tokens: Vec, pub(crate) generated_text: GeneratedText, @@ -654,6 +667,8 @@ pub enum InferError { ValidationError(#[from] ValidationError), #[error("Incomplete generation")] IncompleteGeneration, + #[error("Template error: {0}")] + TemplateError(#[from] minijinja::Error), } impl InferError { @@ -663,6 +678,7 @@ impl InferError { InferError::Overloaded(_) => "overloaded", InferError::ValidationError(_) => "validation", InferError::IncompleteGeneration => "incomplete_generation", + InferError::TemplateError(_) => "template_error", } } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 898fcd040d4..d5394f61286 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -5,12 +5,22 @@ mod queue; pub mod server; mod validation; -use infer::Infer; +use crate::validation::ValidGenerateRequest; +use infer::{Infer, InferError, InferStreamResponse}; use queue::{Entry, Queue}; use serde::{Deserialize, Serialize}; +use tokio::sync::OwnedSemaphorePermit; +use tokio_stream::wrappers::UnboundedReceiverStream; use utoipa::ToSchema; use validation::Validation; +/// Type alias for generation responses +pub(crate) type GenerateStreamResponse = ( + OwnedSemaphorePermit, + ValidGenerateRequest, + UnboundedReceiverStream>, +); + /// Hub type #[derive(Clone, Debug, Deserialize)] pub struct HubModelInfo { @@ -20,6 +30,28 @@ pub struct HubModelInfo { pub pipeline_tag: Option, } +#[derive(Clone, Deserialize)] +pub struct HubTokenizerConfig { + #[serde(default)] + pub chat_template: Option, +} + +impl HubTokenizerConfig { + /// Apply the chat template to the chat request + pub(crate) fn apply_chat_template( + &self, + chat: ChatRequest, + ) -> Result { + let mut env = minijinja::Environment::new(); + let chat_template = self + .chat_template + .as_ref() + .ok_or(minijinja::ErrorKind::TemplateNotFound)?; + env.add_template("_", chat_template)?; + env.get_template("_")?.render(chat) + } +} + #[derive(Clone, Debug, Serialize, ToSchema)] pub struct Info { /// Model info @@ -152,7 +184,7 @@ fn default_parameters() -> GenerateParameters { top_k: None, top_p: None, typical_p: None, - do_sample: false, + do_sample: true, max_new_tokens: default_max_new_tokens(), return_full_text: None, stop: Vec::new(), @@ -165,6 +197,190 @@ fn default_parameters() -> GenerateParameters { } } +#[derive(Clone, Deserialize, Serialize)] +pub(crate) struct ChatCompletion { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub system_fingerprint: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Clone, Deserialize, Serialize)] +pub(crate) struct ChatCompletionComplete { + pub index: u32, + pub message: Message, + pub logprobs: 
Option>, + pub finish_reason: String, +} + +#[derive(Clone, Deserialize, Serialize)] +pub(crate) struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +impl ChatCompletion { + pub(crate) fn new( + model: String, + system_fingerprint: String, + output: String, + created: u64, + details: Details, + return_logprobs: bool, + ) -> Self { + Self { + id: String::new(), + object: "text_completion".into(), + created, + model, + system_fingerprint, + choices: vec![ChatCompletionComplete { + index: 0, + message: Message { + role: "assistant".into(), + content: output, + }, + logprobs: return_logprobs + .then(|| details.tokens.iter().map(|t| t.logprob).collect()), + finish_reason: details.finish_reason.to_string(), + }], + usage: Usage { + prompt_tokens: details.prompt_token_count, + completion_tokens: details.generated_tokens, + total_tokens: details.prompt_token_count + details.generated_tokens, + }, + } + } +} + +#[derive(Clone, Deserialize, Serialize)] +pub(crate) struct ChatCompletionChunk { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub system_fingerprint: String, + pub choices: Vec, +} + +#[derive(Clone, Deserialize, Serialize)] +pub(crate) struct ChatCompletionChoice { + pub index: u32, + pub delta: ChatCompletionDelta, + pub logprobs: Option, + pub finish_reason: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub(crate) struct ChatCompletionDelta { + pub role: String, + pub content: String, +} + +impl ChatCompletionChunk { + pub(crate) fn new( + model: String, + system_fingerprint: String, + delta: String, + created: u64, + index: u32, + logprobs: Option, + finish_reason: Option, + ) -> Self { + Self { + id: "".to_string(), + object: "text_completion".to_string(), + created, + model, + system_fingerprint, + choices: vec![ChatCompletionChoice { + index, + delta: ChatCompletionDelta { + role: "assistant".to_string(), + content: delta, + }, + logprobs, + finish_reason, + }], + } + } +} + +fn default_request_messages() -> Vec { + vec![Message { + role: "system".to_string(), + content: "My name is David and I".to_string(), + }] +} + +#[derive(Clone, Deserialize, ToSchema, Serialize)] +pub(crate) struct ChatRequest { + /// UNUSED + #[schema(example = "bigscience/blomm-560m")] + /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. + pub model: String, /* NOTE: UNUSED */ + + /// A list of messages comprising the conversation so far. + #[serde(default = "default_request_messages")] + pub messages: Vec, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, + /// decreasing the model's likelihood to repeat the same line verbatim. + #[serde(default)] + pub frequency_penalty: Option, + + /// UNUSED + /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens + /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, + /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, + /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should + /// result in a ban or exclusive selection of the relevant token. + #[serde(default)] + pub logit_bias: Option>, + + /// Whether to return log probabilities of the output tokens or not. 
If true, returns the log probabilities of each + /// output token returned in the content of message. + #[serde(default)] + pub logprobs: Option, + + /// UNUSED + /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with + /// an associated log probability. logprobs must be set to true if this parameter is used. + #[serde(default)] + pub top_logprobs: Option, + + /// The maximum number of tokens that can be generated in the chat completion. + #[serde(default)] + pub max_tokens: Option, + + /// UNUSED + /// How many chat completion choices to generate for each input message. Note that you will be charged based on the + /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs. + #[serde(default)] + pub n: Option, + + /// UNUSED + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, + /// increasing the model's likelihood to talk about new topics + #[serde(default)] + pub presence_penalty: Option, + + #[serde(default = "bool::default")] + pub stream: bool, +} + +#[derive(Clone, Deserialize, ToSchema, Serialize)] +pub(crate) struct Message { + #[schema(example = "system")] + pub role: String, + #[schema(example = "My name is David and I")] + pub content: String, +} + #[derive(Clone, Debug, Deserialize, ToSchema)] pub(crate) struct GenerateRequest { #[schema(example = "My name is Olivier and I")] @@ -227,6 +443,16 @@ pub(crate) enum FinishReason { StopSequence, } +impl std::fmt::Display for FinishReason { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FinishReason::Length => write!(f, "length"), + FinishReason::EndOfSequenceToken => write!(f, "eos_token"), + FinishReason::StopSequence => write!(f, "stop_sequence"), + } + } +} + #[derive(Serialize, ToSchema)] pub(crate) struct BestOfSequence { #[schema(example = "test")] @@ -257,6 +483,8 @@ pub(crate) struct Details { pub best_of_sequences: Option>, #[serde(skip_serializing_if = "Vec::is_empty")] pub top_tokens: Vec>, + #[schema(example = 1)] + pub prompt_token_count: u32, } #[derive(Serialize, ToSchema)] @@ -279,6 +507,7 @@ pub(crate) struct StreamDetails { #[derive(Serialize, ToSchema)] pub(crate) struct StreamResponse { + pub index: u32, pub token: Token, #[serde(skip_serializing_if = "Vec::is_empty")] pub top_tokens: Vec, diff --git a/router/src/main.rs b/router/src/main.rs index 4637c77c7f0..875070d1aae 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,22 +1,19 @@ +/// Text Generation Inference webserver entrypoint use axum::http::HeaderValue; use clap::Parser; -use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; -use hf_hub::{Repo, RepoType}; use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::trace; use opentelemetry::sdk::trace::Sampler; use opentelemetry::sdk::Resource; use opentelemetry::{global, KeyValue}; use opentelemetry_otlp::WithExportConfig; -/// Text Generation Inference webserver entrypoint -use std::fs::File; -use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::Path; +use std::time::Duration; use text_generation_client::{ClientError, ShardedClient}; -use text_generation_router::{server, HubModelInfo}; +use text_generation_router::{server, HubModelInfo, HubTokenizerConfig}; use thiserror::Error; -use tokenizers::Tokenizer; +use tokenizers::{FromPretrainedParameters, Tokenizer}; use tower_http::cors::AllowOrigin; use 
tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -72,8 +69,7 @@ struct Args { ngrok_edge: Option, } -#[tokio::main] -async fn main() -> Result<(), RouterError> { +fn main() -> Result<(), RouterError> { // Get args let args = Args::parse(); // Pattern match configuration @@ -102,9 +98,6 @@ async fn main() -> Result<(), RouterError> { ngrok_edge, } = args; - // Launch Tokio runtime - init_logging(otlp_endpoint, json_output); - // Validate args if max_input_length >= max_total_tokens { return Err(RouterError::ArgumentValidation( @@ -148,158 +141,161 @@ async fn main() -> Result<(), RouterError> { // This will only be used to validate payloads let local_path = Path::new(&tokenizer_name); let local_model = local_path.exists() && local_path.is_dir(); - - let (tokenizer, model_info) = if local_model { - // Get Model info - let model_info = HubModelInfo { - model_id: tokenizer_name.clone(), - sha: None, - pipeline_tag: None, - }; - + let tokenizer = if local_model { // Load local tokenizer - let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok(); - - (tokenizer, model_info) + Tokenizer::from_file(local_path.join("tokenizer.json")).ok() } else { - let mut builder = ApiBuilder::new() - .with_progress(false) - .with_token(authorization_token); - - if let Some(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE").ok() { - builder = builder.with_cache_dir(cache_dir.into()); - } - - if revision.is_none() { - tracing::warn!("`--revision` is not set"); - tracing::warn!("We strongly advise to set it to a known supported commit."); - } - - let api = builder.build().unwrap(); - let api_repo = api.repo(Repo::with_revision( - tokenizer_name.clone(), - RepoType::Model, - revision.clone().unwrap_or("main".to_string()), - )); - - // Get Model info - let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| { - tracing::warn!("Could not retrieve model info from the Hugging Face hub."); - HubModelInfo { - model_id: tokenizer_name.to_string(), - sha: None, - pipeline_tag: None, - } - }); - - let tokenizer = match api_repo.get("tokenizer.json").await { - Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(), - Err(_) => get_base_tokenizer(&api, &api_repo).await, + // Download and instantiate tokenizer + // We need to download it outside of the Tokio runtime + let params = FromPretrainedParameters { + revision: revision.clone().unwrap_or("main".to_string()), + auth_token: authorization_token.clone(), + ..Default::default() }; - - (tokenizer, model_info) + Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok() }; - if tokenizer.is_none() { - tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); - tracing::warn!("Rust input length validation and truncation is disabled"); - } - - // if pipeline-tag == text-generation we default to return_full_text = true - let compat_return_full_text = match &model_info.pipeline_tag { - None => { - tracing::warn!("no pipeline tag found for model {tokenizer_name}"); - false - } - Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", - }; + // Launch Tokio runtime + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build()? 
+ .block_on(async { + init_logging(otlp_endpoint, json_output); - // Instantiate sharded client from the master unix socket - let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) - .await - .map_err(RouterError::Connection)?; - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .map_err(RouterError::Cache)?; - // Get info from the shard - let shard_info = sharded_client.info().await.map_err(RouterError::Info)?; - - // Warmup model - tracing::info!("Warming up model"); - let max_supported_batch_total_tokens = match sharded_client - .warmup( - max_input_length as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - ) - .await - .map_err(RouterError::Warmup)? - { - // Older models do not support automatic max-batch-total-tokens - None => { - let max_batch_total_tokens = max_batch_total_tokens - .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); - tracing::warn!("Model does not support automatic max batch total tokens"); - max_batch_total_tokens - } - // Flash attention models return their max supported total tokens - Some(max_supported_batch_total_tokens) => { - // Warn if user added his own max-batch-total-tokens as we will ignore it - if max_batch_total_tokens.is_some() { + if tokenizer.is_none() { tracing::warn!( - "`--max-batch-total-tokens` is deprecated for Flash \ - Attention models." + "Could not find a fast tokenizer implementation for {tokenizer_name}" ); - tracing::warn!( - "Inferred max batch total tokens: {max_supported_batch_total_tokens}" - ); - } - if max_total_tokens as u32 > max_supported_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}"))); + tracing::warn!("Rust input length validation and truncation is disabled"); } - max_supported_batch_total_tokens - } - }; - tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); - tracing::info!("Connected"); - - let addr = match hostname.parse() { - Ok(ip) => SocketAddr::new(ip, port), - Err(_) => { - tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) - } - }; - - // Run server - server::run( - model_info, - shard_info, - compat_return_full_text, - max_concurrent_requests, - max_best_of, - max_stop_sequences, - max_top_n_tokens, - max_input_length, - max_total_tokens, - waiting_served_ratio, - max_batch_prefill_tokens, - max_supported_batch_total_tokens, - max_waiting_tokens, - sharded_client, - tokenizer, - validation_workers, - addr, - cors_allow_origin, - ngrok, - ngrok_authtoken, - ngrok_edge, - ) - .await?; - Ok(()) + // Get Model info + let model_info = match local_model { + true => HubModelInfo { + model_id: tokenizer_name.clone(), + sha: None, + pipeline_tag: None, + }, + false => get_model_info(&tokenizer_name, revision.as_deref(), authorization_token.as_deref()) + .await + .unwrap_or_else(|| { + tracing::warn!("Could not retrieve model info from the Hugging Face hub."); + HubModelInfo { + model_id: tokenizer_name.to_string(), + sha: None, + pipeline_tag: None, + } + }), + }; + + let tokenizer_config: HubTokenizerConfig = match local_model { + true => HubTokenizerConfig{ + chat_template: None, + }, + false => get_tokenizer_config(&tokenizer_name, revision.as_deref(), authorization_token.as_deref()) + .await.unwrap_or_else(|| { + tracing::warn!("Could not retrieve 
tokenizer config from the Hugging Face hub."); + HubTokenizerConfig{ + chat_template: None, + } + }), + }; + + + // if pipeline-tag == text-generation we default to return_full_text = true + let compat_return_full_text = match &model_info.pipeline_tag { + None => { + tracing::warn!("no pipeline tag found for model {tokenizer_name}"); + false + } + Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", + }; + + // Instantiate sharded client from the master unix socket + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(RouterError::Connection)?; + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(RouterError::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(RouterError::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_supported_batch_total_tokens = match sharded_client + .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32) + .await + .map_err(RouterError::Warmup)? + { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens.unwrap_or( + 16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)), + ); + tracing::warn!("Model does not support automatic max batch total tokens"); + max_batch_total_tokens + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." + ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_supported_batch_total_tokens}"))); + } + + max_supported_batch_total_tokens + } + }; + tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); + tracing::info!("Connected"); + + let addr = match hostname.parse() { + Ok(ip) => SocketAddr::new(ip, port), + Err(_) => { + tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) + } + }; + + // Run server + server::run( + model_info, + shard_info, + compat_return_full_text, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_length, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_supported_batch_total_tokens, + max_waiting_tokens, + sharded_client, + tokenizer, + validation_workers, + addr, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + tokenizer_config, + ) + .await?; + Ok(()) + }) } /// Init logging using env variables LOG_LEVEL and LOG_FORMAT: @@ -358,8 +354,30 @@ fn init_logging(otlp_endpoint: Option, json_output: bool) { } /// get model info from the Huggingface Hub -pub async fn get_model_info(api: &ApiRepo) -> Option { - let response = api.info_request().send().await.ok()?; +pub async fn get_model_info( + model_id: &str, + revision: Option<&str>, + token: Option<&str>, +) -> Option { + let revision = match revision { + None => { + tracing::warn!("`--revision` is not set"); + tracing::warn!("We strongly advise to set it to a known supported commit."); + "main".to_string() + } + Some(revision) => revision.to_string(), + }; + + let client = reqwest::Client::new(); + // Poor man's urlencode + let revision = revision.replace('/', "%2F"); + let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}"); + let mut builder = client.get(url).timeout(Duration::from_secs(5)); + if let Some(token) = token { + builder = builder.bearer_auth(token); + } + + let response = builder.send().await.ok()?; if response.status().is_success() { let hub_model_info: HubModelInfo = @@ -376,26 +394,36 @@ pub async fn get_model_info(api: &ApiRepo) -> Option { } } -/// get base tokenizer -pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option { - let config_filename = api_repo.get("config.json").await.ok()?; - - // Open the file in read-only mode with buffer. - let file = File::open(config_filename).ok()?; - let reader = BufReader::new(file); - - // Read the JSON contents of the file as an instance of `User`. 
- let config: serde_json::Value = serde_json::from_reader(reader).ok()?; - - if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { - let api_base_repo = api.repo(Repo::with_revision( - base_model_id.to_string(), - RepoType::Model, - "main".to_string(), - )); - - let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?; - Tokenizer::from_file(tokenizer_filename).ok() +/// get tokenizer_config from the Huggingface Hub +pub async fn get_tokenizer_config( + model_id: &str, + revision: Option<&str>, + token: Option<&str>, +) -> Option { + let revision = match revision { + None => { + tracing::warn!("`--revision` is not set"); + tracing::warn!("We strongly advise to set it to a known supported commit."); + "main".to_string() + } + Some(revision) => revision.to_string(), + }; + let client = reqwest::Client::new(); + // Poor man's urlencode + let revision = revision.replace('/', "%2F"); + let url = format!( + "https://huggingface.co/{}/raw/{}/tokenizer_config.json", + model_id, revision + ); + let mut builder = client.get(url).timeout(Duration::from_secs(5)); + if let Some(token) = token { + builder = builder.bearer_auth(token); + } + let response = builder.send().await.ok()?; + if response.status().is_success() { + let text = response.text().await.ok()?; + let hub_tokenizer_config: HubTokenizerConfig = serde_json::from_str(&text).ok()?; + Some(hub_tokenizer_config) } else { None } diff --git a/router/src/server.rs b/router/src/server.rs index fe1b83090b4..f0fe3118855 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -2,10 +2,11 @@ use crate::health::Health; use crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::validation::ValidationError; +use crate::HubTokenizerConfig; use crate::{ - BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason, - GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, - StreamDetails, StreamResponse, Token, Validation, + BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest, + Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, + HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -206,6 +207,7 @@ async fn generate( seed: response.generated_text.seed, best_of_sequences, top_tokens: response.top_tokens, + prompt_token_count: response.prompt_token_count, }) } false => None, @@ -337,6 +339,21 @@ async fn generate_stream( HeaderMap, Sse>>, ) { + let on_message_callback = |stream_token: StreamResponse| { + let event = Event::default(); + event.json_data(stream_token).unwrap() + }; + let (headers, response_stream) = + generate_stream_internal(infer, Json(req), on_message_callback).await; + let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); + (headers, sse) +} + +async fn generate_stream_internal( + infer: Infer, + Json(req): Json, + on_message_callback: impl Fn(StreamResponse) -> Event, +) -> (HeaderMap, impl Stream>) { let span = tracing::Span::current(); let start_time = Instant::now(); metrics::increment_counter!("tgi_request_count"); @@ -378,9 +395,11 @@ async fn generate_stream( } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives - Ok((_permit, mut response_stream)) => { + Ok((_permit, _valid_request, 
mut response_stream)) => { + let mut index = 0; // Server-Sent Event stream while let Some(response) = response_stream.next().await { + index += 1; match response { Ok(response) => { match response { @@ -395,13 +414,14 @@ async fn generate_stream( // StreamResponse let stream_token = StreamResponse { + index, token, top_tokens, generated_text: None, details: None, }; - - yield Ok(Event::default().json_data(stream_token).unwrap()) + let event = on_message_callback(stream_token); + yield Ok(event); } // Yield event for last token and compute timings InferStreamResponse::End { @@ -457,13 +477,16 @@ async fn generate_stream( tracing::info!(parent: &span, "Success"); let stream_token = StreamResponse { + index, token, top_tokens, generated_text: Some(output_text), details }; - yield Ok(Event::default().json_data(stream_token).unwrap()); + + let event = on_message_callback(stream_token); + yield Ok(event); break; } } @@ -494,7 +517,153 @@ async fn generate_stream( } }; - (headers, Sse::new(stream).keep_alive(KeepAlive::default())) + (headers, stream) +} + +/// Generate tokens +#[utoipa::path( + post, + tag = "Text Generation Inference", + path = "/v1/chat/completions", + request_body = ChatRequest, + responses( + (status = 200, description = "Generated Text", body = GenerateResponse), + (status = 424, description = "Generation Error", body = ErrorResponse, + example = json ! ({"error": "Request failed during generation"})), + (status = 429, description = "Model is overloaded", body = ErrorResponse, + example = json ! ({"error": "Model is overloaded"})), + (status = 422, description = "Input validation error", body = ErrorResponse, + example = json ! ({"error": "Input validation error"})), + (status = 500, description = "Incomplete generation", body = ErrorResponse, + example = json ! ({"error": "Incomplete generation"})), + ) + )] +#[instrument( + skip_all, + fields( + // parameters = ? 
req.parameters, + total_time, + validation_time, + queue_time, + inference_time, + time_per_token, + seed, + ) + )] +async fn chat_completions( + Extension(infer): Extension, + Extension(info): Extension, + Json(req): Json, +) -> Result)> { + metrics::increment_counter!("tgi_request_count"); + + let stream = req.stream; + let max_new_tokens = req.max_tokens.or(Some(100)); + let repetition_penalty = req + .frequency_penalty + // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0) + .map(|x| x + 2.0); + let logprobs = req.logprobs.unwrap_or(false); + + // apply chat template to flatten the request into a single input + let inputs = match infer.apply_chat_template(req) { + Ok(inputs) => inputs, + Err(err) => { + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + return Err(( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: err.to_string(), + error_type: err.error_type().to_string(), + }), + )); + } + }; + + // build the request passing some parameters + let generate_request = GenerateRequest { + inputs: inputs.to_string(), + parameters: GenerateParameters { + best_of: None, + temperature: None, + repetition_penalty, + top_k: None, + top_p: None, + typical_p: None, + do_sample: true, + max_new_tokens, + return_full_text: None, + stop: Vec::new(), + truncate: None, + watermark: false, + details: true, + decoder_input_details: false, + seed: None, + top_n_tokens: None, + }, + }; + + // static values that will be returned in all cases + let model_id = info.model_id.clone(); + let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); + + // switch on stream + if stream { + // pass this callback to the stream generation and build the required event structure + let on_message_callback = move |stream_token: StreamResponse| { + let event = Event::default(); + + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + + event + .json_data(ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + stream_token.token.text, + current_time, + stream_token.index, + logprobs.then_some(stream_token.token.logprob), + stream_token.details.map(|d| d.finish_reason.to_string()), + )) + .map_or_else( + |e| { + println!("Failed to serialize ChatCompletionChunk: {:?}", e); + Event::default() + }, + |data| data, + ) + }; + + let (headers, response_stream) = + generate_stream_internal(infer, Json(generate_request), on_message_callback).await; + let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); + Ok((headers, sse).into_response()) + } else { + let (headers, Json(generation)) = + generate(Extension(infer), Json(generate_request)).await?; + + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + + // build the complete response object with the full text + let response = ChatCompletion::new( + generation.generated_text, + model_id, + system_fingerprint, + current_time, + generation.details.unwrap(), + logprobs, + ); + + // wrap generation inside a Vec to match api-inference + Ok((headers, Json(response)).into_response()) + } } /// Prometheus metrics scrape endpoint @@ -532,6 +701,7 @@ pub async fn run( ngrok: bool, ngrok_authtoken: Option, ngrok_edge: Option, + tokenizer_config: HubTokenizerConfig, ) -> Result<(), axum::BoxError> { // OpenAPI documentation 
#[derive(OpenApi)] @@ -598,6 +768,7 @@ pub async fn run( shard_info.window_size, shard_info.speculate, generation_health, + tokenizer_config, ); // Duration buckets @@ -687,6 +858,7 @@ pub async fn run( .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) + .route("/v1/chat/completions", post(chat_completions)) // AWS Sagemaker route .route("/invocations", post(compat_generate)) // Base Health route @@ -816,6 +988,7 @@ impl From for (StatusCode, Json) { InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS, InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, + InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, }; ( diff --git a/router/src/validation.rs b/router/src/validation.rs index 64f25c82994..370e9588a49 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -376,7 +376,7 @@ type TokenizerRequest = ( Span, ); -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) struct ValidGenerateRequest { pub inputs: String, pub input_length: u32,
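
The patch wires a new `/v1/chat/completions` route into the router but does not include a client example. The sketch below is a minimal, hypothetical call against that route, assuming a router already listening on localhost:3000 and a separate client crate with `tokio`, `serde_json`, and `reqwest` built with its `json` feature (none of which this patch provides); the request body fields mirror the `ChatRequest` schema added above, and the model name is a placeholder since the router ignores it.

// Minimal client sketch for the new OpenAI-style chat completions route.
// Assumptions: TGI router on http://localhost:3000, reqwest "json" feature enabled.
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Request body following the ChatRequest fields introduced in this patch.
    let body = json!({
        // `model` is accepted for OpenAI compatibility but unused by the router.
        "model": "tgi",
        "messages": [
            { "role": "system", "content": "You are a helpful assistant." },
            { "role": "user", "content": "What is deep learning?" }
        ],
        "max_tokens": 32,
        // Set to true to receive ChatCompletionChunk events over SSE instead
        // of a single ChatCompletion object.
        "stream": false
    });

    let response = reqwest::Client::new()
        .post("http://localhost:3000/v1/chat/completions")
        .json(&body)
        .send()
        .await?;

    // Non-streaming responses follow the ChatCompletion shape added in router/src/lib.rs
    // (choices, usage, system_fingerprint, ...); print the raw JSON here.
    println!("{}", response.text().await?);
    Ok(())
}

With `"stream": true`, the same endpoint returns a Server-Sent Events stream produced by `generate_stream_internal` with the `ChatCompletionChunk` callback, so a client would read the response body incrementally rather than awaiting a single JSON object.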