diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 07c334a3647..7a1022b8748 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -7,9 +7,11 @@ pub(crate) use health::HealthCheck; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::{ ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token, + HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, +}; +use crate::{ + FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools, }; -use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use futures::future::try_join_all; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; @@ -270,7 +272,11 @@ struct ChatTemplate { } impl ChatTemplate { - fn new(template: String, bos_token: Option, eos_token: Option) -> Self { + fn new( + template: String, + bos_token: Option, + eos_token: Option, + ) -> Self { let mut env = Box::new(Environment::new()); // enable things like .strip() or .capitalize() env.set_unknown_method_callback(pycompat::unknown_method_callback); @@ -287,8 +293,20 @@ impl ChatTemplate { Self { template, - bos_token, - eos_token, + bos_token: match bos_token { + Some(token) => match token { + TokenizerConfigToken::String(token) => Some(token), + TokenizerConfigToken::Object { content } => Some(content), + }, + None => None, + }, + eos_token: match eos_token { + Some(token) => match token { + TokenizerConfigToken::String(token) => Some(token), + TokenizerConfigToken::Object { content } => Some(content), + }, + None => None, + }, use_default_tool_template, } } @@ -301,9 +319,9 @@ impl ChatTemplate { if self.use_default_tool_template { if let Some(last_message) = messages.last_mut() { if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text(Text { + last_message.content.push(MessageChunk::Text { text: format!("\n---\n{}\n{}", tool_prompt, tools), - })); + }); } } } @@ -340,6 +358,14 @@ impl ToolGrammar { .unwrap_or_else(|| panic!("Tool with name {} not found", name)) .clone()] } + ToolType::Function { function } => { + let tool = req_tools + .iter() + .find(|tool| tool.function.name == function.name) + .unwrap_or_else(|| panic!("Tool with name {} not found", function.name)) + .clone(); + vec![tool] + } ToolType::OneOf => req_tools.to_owned(), }; diff --git a/router/src/lib.rs b/router/src/lib.rs index 5ee3eb883b2..b9718c72b9e 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -53,20 +53,37 @@ pub enum ChatTemplateVersions { Multiple(Vec), } +use std::path::Path; + #[derive(Debug, Clone, Deserialize, Default)] pub struct HubTokenizerConfig { pub chat_template: Option, pub completion_template: Option, - #[serde(deserialize_with = "token_serde::deserialize")] - pub bos_token: Option, - #[serde(deserialize_with = "token_serde::deserialize")] - pub eos_token: Option, + pub bos_token: Option, + pub eos_token: Option, } impl HubTokenizerConfig { - pub fn from_file>(filename: P) -> Option { - let content = std::fs::read_to_string(filename).ok()?; - serde_json::from_str(&content).ok() + pub fn from_file>(filename: P) -> Option { + std::fs::read_to_string(filename) + .ok() + .and_then(|content| serde_json::from_str(&content).ok()) + } +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[serde(untagged)] +pub enum TokenizerConfigToken { + String(String), + Object { content: String }, +} + +impl From for String { + fn from(token: TokenizerConfigToken) -> Self { + match token { + TokenizerConfigToken::String(s) => s, + TokenizerConfigToken::Object { content } => content, + } } } @@ -78,9 +95,10 @@ pub struct HubProcessorConfig { } impl HubProcessorConfig { - pub fn from_file>(filename: P) -> Option { - let content = std::fs::read_to_string(filename).ok()?; - serde_json::from_str(&content).ok() + pub fn from_file>(filename: P) -> Option { + std::fs::read_to_string(filename) + .ok() + .and_then(|content| serde_json::from_str(&content).ok()) } } @@ -99,35 +117,6 @@ pub(crate) enum GrammarType { Regex(String), } -mod token_serde { - use super::*; - use serde::de; - use serde::Deserializer; - use serde_json::Value; - - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; - - match value { - Value::String(s) => Ok(Some(s)), - Value::Object(map) => { - if let Some(content) = map.get("content").and_then(|v| v.as_str()) { - Ok(Some(content.to_string())) - } else { - Err(de::Error::custom( - "content key not found in structured token", - )) - } - } - Value::Null => Ok(None), - _ => Err(de::Error::custom("invalid token format")), - } - } -} - #[derive(Clone, Debug, Serialize, ToSchema)] pub struct Info { /// Model info @@ -337,30 +326,33 @@ fn default_parameters() -> GenerateParameters { } } -mod prompt_serde { - use serde::{self, Deserialize, Deserializer}; - use serde_json::Value; +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +#[serde(try_from = "PromptDeserializer")] +pub struct Prompt(pub Vec); - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; +#[derive(Deserialize)] +#[serde(untagged)] +enum PromptDeserializer { + Single(String), + Multiple(Vec), +} + +impl TryFrom for Prompt { + type Error = String; + + fn try_from(value: PromptDeserializer) -> Result { match value { - Value::String(s) => Ok(vec![s]), - Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom( - "Empty array detected. Do not use an empty array for the prompt.", - )), - Value::Array(arr) => arr - .iter() - .map(|v| match v { - Value::String(s) => Ok(s.to_owned()), - _ => Err(serde::de::Error::custom("Expected a string")), - }) - .collect(), - _ => Err(serde::de::Error::custom( - "Expected a string or an array of strings", - )), + PromptDeserializer::Single(s) => Ok(Prompt(vec![s])), + PromptDeserializer::Multiple(v) => { + if v.is_empty() { + Err( + "Empty array detected. Do not use an empty array for the prompt." + .to_string(), + ) + } else { + Ok(Prompt(v)) + } + } } } } @@ -374,8 +366,7 @@ pub struct CompletionRequest { /// The prompt to generate completions for. #[schema(example = "What is Deep Learning?")] - #[serde(deserialize_with = "prompt_serde::deserialize")] - pub prompt: Vec, + pub prompt: Prompt, /// The maximum number of tokens that can be generated in the chat completion. #[serde(default)] @@ -802,7 +793,6 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] #[schema(nullable = true, example = "null")] - #[serde(deserialize_with = "deserialize_tool_choice::deserialize")] pub tool_choice: Option, /// Response format constraints for the generation. @@ -818,44 +808,41 @@ fn default_tool_prompt() -> Option { "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(), ) } -#[derive(Clone, Deserialize, ToSchema, Serialize)] -enum ToolType { - FunctionName(String), + +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] +#[serde(untagged)] +pub enum ToolType { OneOf, + FunctionName(String), + Function { function: FunctionName }, } -/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None) -mod deserialize_tool_choice { - use super::*; - use serde::de; - use serde::Deserializer; - use serde_json::Value; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct FunctionName { + pub name: String, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(from = "ToolTypeDeserializer")] +pub struct ToolChoice(pub Option); - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; +#[derive(Deserialize)] +#[serde(untagged)] +enum ToolTypeDeserializer { + None(Option), + Some(ToolType), +} +impl From for ToolChoice { + fn from(value: ToolTypeDeserializer) -> Self { match value { - Value::String(s) => match s.as_str() { - "none" => Ok(None), - "auto" => Ok(Some(ToolType::OneOf)), - _ => Ok(Some(ToolType::FunctionName(s))), + ToolTypeDeserializer::None(opt) => match opt.as_deref() { + Some("none") => ToolChoice(None), + Some("auto") => ToolChoice(Some(ToolType::OneOf)), + Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))), + None => ToolChoice(Some(ToolType::OneOf)), }, - Value::Object(map) => { - if let Some(content) = map - .get("function") - .and_then(|v| v.get("name")) - .and_then(|v| v.as_str()) - { - Ok(Some(ToolType::FunctionName(content.to_string()))) - } else { - Err(de::Error::custom("function key not found in tool choice")) - } - } - Value::Null => Ok(Some(ToolType::OneOf)), - _ => Err(de::Error::custom("invalid token format")), + ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)), } } } @@ -931,26 +918,16 @@ pub(crate) struct ToolCall { } #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] -struct Url { +pub struct Url { url: String, } -#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] -struct ImageUrl { - image_url: Url, -} - -#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] -struct Text { - text: String, -} - #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] #[serde(tag = "type")] #[serde(rename_all = "snake_case")] -enum MessageChunk { - Text(Text), - ImageUrl(ImageUrl), +pub enum MessageChunk { + Text { text: String }, + ImageUrl { image_url: Url }, } #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] @@ -958,35 +935,31 @@ pub struct Message { #[schema(example = "user")] role: String, #[schema(example = "My name is David and I")] - #[serde(deserialize_with = "message_content_serde::deserialize")] - content: Vec, + pub content: MessageContent, #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "\"David\"")] name: Option, } -mod message_content_serde { - use super::*; - use serde::{Deserialize, Deserializer}; - - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - #[derive(Deserialize)] - #[serde(untagged)] - enum Message { - Text(String), - Chunks(Vec), - } - let message: Message = Deserialize::deserialize(deserializer)?; - let chunks = match message { - Message::Text(text) => { - vec![MessageChunk::Text(Text { text })] +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] +#[serde(untagged)] +pub enum MessageContent { + SingleText(String), + MultipleChunks(Vec), +} + +// Pushing a chunk to a single text message will convert it to a multiple chunks message +impl MessageContent { + pub fn push(&mut self, chunk: MessageChunk) { + match self { + MessageContent::SingleText(text) => { + *self = + MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]); } - Message::Chunks(s) => s, - }; - Ok(chunks) + MessageContent::MultipleChunks(chunks) => { + chunks.push(chunk); + } + } } } @@ -1002,18 +975,17 @@ impl From for TextMessage { fn from(value: Message) -> Self { TextMessage { role: value.role, - content: value - .content - .into_iter() - .map(|c| match c { - MessageChunk::Text(Text { text }) => text, - MessageChunk::ImageUrl(image) => { - let url = image.image_url.url; - format!("![]({url})") - } - }) - .collect::>() - .join(""), + content: match value.content { + MessageContent::SingleText(text) => text, + MessageContent::MultipleChunks(chunks) => chunks + .into_iter() + .map(|chunk| match chunk { + MessageChunk::Text { text } => text, + MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url), + }) + .collect::>() + .join(""), + }, } } } @@ -1221,9 +1193,16 @@ mod tests { ); assert_eq!( config.bos_token, - Some("<|begin▁of▁sentence|>".to_string()) + Some(TokenizerConfigToken::String( + "<|begin▁of▁sentence|>".to_string() + )) + ); + assert_eq!( + config.eos_token, + Some(TokenizerConfigToken::String( + "<|end▁of▁sentence|>".to_string() + )) ); - assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string())); // in this case we expect the tokens to be encoded as structured tokens // we want the content of the structured token @@ -1256,9 +1235,16 @@ mod tests { ); assert_eq!( config.bos_token, - Some("<|begin▁of▁sentence|>".to_string()) + Some(TokenizerConfigToken::Object { + content: "<|begin▁of▁sentence|>".to_string() + }) + ); + assert_eq!( + config.eos_token, + Some(TokenizerConfigToken::Object { + content: "<|end▁of▁sentence|>".to_string() + }) ); - assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string())); } #[test] @@ -1276,9 +1262,7 @@ mod tests { request.messages[0], Message { role: "user".to_string(), - content: vec![MessageChunk::Text(Text { - text: "What is Deep Learning?".to_string() - }),], + content: MessageContent::SingleText("What is Deep Learning?".to_string()), name: None } ); @@ -1302,10 +1286,10 @@ mod tests { request.messages[0], Message{ role: "user".to_string(), - content: vec![ - MessageChunk::Text(Text { text: "Whats in this image?".to_string() }), - MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }) - ], + content: MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: "Whats in this image?".to_string() }, + MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }}, + ]), name: None } ); @@ -1315,10 +1299,10 @@ mod tests { fn text_message_convert() { let message = Message{ role: "user".to_string(), - content: vec![ - MessageChunk::Text(Text { text: "Whats in this image?".to_string() }), - MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }) - ], + content: MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: "Whats in this image?".to_string() }, + MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } } + ]), name: None }; let textmsg: TextMessage = message.into(); diff --git a/router/src/server.rs b/router/src/server.rs index 466741a2c63..34c8a257720 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -636,7 +636,7 @@ async fn completions( )); } - if req.prompt.len() > info.max_client_batch_size { + if req.prompt.0.len() > info.max_client_batch_size { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); return Err(( StatusCode::UNPROCESSABLE_ENTITY, @@ -652,6 +652,7 @@ async fn completions( let generate_requests: Vec = req .prompt + .0 .iter() .map(|prompt| GenerateRequest { inputs: prompt.to_string(),