From 89a0c5d3781800f3a864008bedf14f82e40aedc0 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 1 Jul 2024 09:08:05 -0400 Subject: [PATCH] fix: prefer serde structs over custom functions (#2127) * fix: prefer enum for chat object * fix: adjust typo * fix: enum CompletionType not ObjectType * fix: adjust typo * feat: leverage serde for conditional deser * fix: adjust HubTokenizerConfig after rebase * fix: update create_post_processor logic for token type * fix: adjust unwrap syntax in template * Fixing the post processor. --------- Co-authored-by: Nicolas Patry --- router/src/infer/mod.rs | 28 +++- router/src/lib.rs | 317 +++++++++++++++++++--------------------- router/src/main.rs | 29 ++-- router/src/server.rs | 38 ++--- 4 files changed, 208 insertions(+), 204 deletions(-) diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 07c334a3647..49282eb9eca 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -7,9 +7,11 @@ pub(crate) use health::HealthCheck; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::{ ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token, + HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, +}; +use crate::{ + FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools, }; -use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use futures::future::try_join_all; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; @@ -270,7 +272,11 @@ struct ChatTemplate { } impl ChatTemplate { - fn new(template: String, bos_token: Option, eos_token: Option) -> Self { + fn new( + template: String, + bos_token: Option, + eos_token: Option, + ) -> Self { let mut env = Box::new(Environment::new()); // enable things like .strip() or .capitalize() env.set_unknown_method_callback(pycompat::unknown_method_callback); @@ -287,8 +293,8 @@ impl ChatTemplate { Self { template, - bos_token, - eos_token, + bos_token: bos_token.map(|token| token.as_str().to_string()), + eos_token: eos_token.map(|token| token.as_str().to_string()), use_default_tool_template, } } @@ -301,9 +307,9 @@ impl ChatTemplate { if self.use_default_tool_template { if let Some(last_message) = messages.last_mut() { if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text(Text { + last_message.content.push(MessageChunk::Text { text: format!("\n---\n{}\n{}", tool_prompt, tools), - })); + }); } } } @@ -340,6 +346,14 @@ impl ToolGrammar { .unwrap_or_else(|| panic!("Tool with name {} not found", name)) .clone()] } + ToolType::Function { function } => { + let tool = req_tools + .iter() + .find(|tool| tool.function.name == function.name) + .unwrap_or_else(|| panic!("Tool with name {} not found", function.name)) + .clone(); + vec![tool] + } ToolType::OneOf => req_tools.to_owned(), }; diff --git a/router/src/lib.rs b/router/src/lib.rs index a5b97af36ee..9ecfa051258 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -53,23 +53,40 @@ pub enum ChatTemplateVersions { Multiple(Vec), } +use std::path::Path; + #[derive(Debug, Clone, Deserialize, Default)] pub struct HubTokenizerConfig { pub chat_template: Option, pub completion_template: Option, - #[serde(deserialize_with = "token_serde::deserialize")] - pub bos_token: Option, - #[serde(deserialize_with = "token_serde::deserialize")] - pub eos_token: Option, + pub bos_token: Option, + pub eos_token: Option, pub tokenizer_class: Option, pub add_bos_token: Option, pub add_eos_token: Option, } impl HubTokenizerConfig { - pub fn from_file>(filename: P) -> Option { - let content = std::fs::read_to_string(filename).ok()?; - serde_json::from_str(&content).ok() + pub fn from_file>(filename: P) -> Option { + std::fs::read_to_string(filename) + .ok() + .and_then(|content| serde_json::from_str(&content).ok()) + } +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[serde(untagged)] +pub enum TokenizerConfigToken { + String(String), + Object { content: String }, +} + +impl TokenizerConfigToken { + pub fn as_str(&self) -> &str { + match self { + TokenizerConfigToken::String(s) => s, + TokenizerConfigToken::Object { content } => content, + } } } @@ -100,9 +117,10 @@ pub struct HubProcessorConfig { } impl HubProcessorConfig { - pub fn from_file>(filename: P) -> Option { - let content = std::fs::read_to_string(filename).ok()?; - serde_json::from_str(&content).ok() + pub fn from_file>(filename: P) -> Option { + std::fs::read_to_string(filename) + .ok() + .and_then(|content| serde_json::from_str(&content).ok()) } } @@ -121,35 +139,6 @@ pub(crate) enum GrammarType { Regex(String), } -mod token_serde { - use super::*; - use serde::de; - use serde::Deserializer; - use serde_json::Value; - - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; - - match value { - Value::String(s) => Ok(Some(s)), - Value::Object(map) => { - if let Some(content) = map.get("content").and_then(|v| v.as_str()) { - Ok(Some(content.to_string())) - } else { - Err(de::Error::custom( - "content key not found in structured token", - )) - } - } - Value::Null => Ok(None), - _ => Err(de::Error::custom("invalid token format")), - } - } -} - #[derive(Clone, Debug, Serialize, ToSchema)] pub struct Info { /// Model info @@ -359,30 +348,33 @@ fn default_parameters() -> GenerateParameters { } } -mod prompt_serde { - use serde::{self, Deserialize, Deserializer}; - use serde_json::Value; +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +#[serde(try_from = "PromptDeserializer")] +pub struct Prompt(pub Vec); + +#[derive(Deserialize)] +#[serde(untagged)] +enum PromptDeserializer { + Single(String), + Multiple(Vec), +} + +impl TryFrom for Prompt { + type Error = String; - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; + fn try_from(value: PromptDeserializer) -> Result { match value { - Value::String(s) => Ok(vec![s]), - Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom( - "Empty array detected. Do not use an empty array for the prompt.", - )), - Value::Array(arr) => arr - .iter() - .map(|v| match v { - Value::String(s) => Ok(s.to_owned()), - _ => Err(serde::de::Error::custom("Expected a string")), - }) - .collect(), - _ => Err(serde::de::Error::custom( - "Expected a string or an array of strings", - )), + PromptDeserializer::Single(s) => Ok(Prompt(vec![s])), + PromptDeserializer::Multiple(v) => { + if v.is_empty() { + Err( + "Empty array detected. Do not use an empty array for the prompt." + .to_string(), + ) + } else { + Ok(Prompt(v)) + } + } } } } @@ -396,8 +388,7 @@ pub struct CompletionRequest { /// The prompt to generate completions for. #[schema(example = "What is Deep Learning?")] - #[serde(deserialize_with = "prompt_serde::deserialize")] - pub prompt: Vec, + pub prompt: Prompt, /// The maximum number of tokens that can be generated in the chat completion. #[serde(default)] @@ -445,7 +436,6 @@ pub struct CompletionRequest { #[derive(Clone, Deserialize, Serialize, ToSchema, Default)] pub(crate) struct Completion { pub id: String, - pub object: String, #[schema(example = "1706270835")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -466,7 +456,6 @@ pub(crate) struct CompletionComplete { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletion { pub id: String, - pub object: String, #[schema(example = "1706270835")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -562,6 +551,15 @@ pub(crate) struct Usage { pub total_tokens: u32, } +#[derive(Clone, Serialize, ToSchema)] +#[serde(tag = "object")] +enum CompletionType { + #[serde(rename = "chat.completion.chunk")] + ChatCompletionChunk(ChatCompletionChunk), + #[serde(rename = "chat.completion")] + ChatCompletion(ChatCompletion), +} + impl ChatCompletion { pub(crate) fn new( model: String, @@ -598,7 +596,6 @@ impl ChatCompletion { }; Self { id: String::new(), - object: "chat.completion".into(), created, model, system_fingerprint, @@ -620,7 +617,6 @@ impl ChatCompletion { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct CompletionCompleteChunk { pub id: String, - pub object: String, pub created: u64, pub choices: Vec, pub model: String, @@ -630,7 +626,6 @@ pub(crate) struct CompletionCompleteChunk { #[derive(Clone, Serialize, ToSchema)] pub(crate) struct ChatCompletionChunk { pub id: String, - pub object: String, #[schema(example = "1706270978")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -710,7 +705,6 @@ impl ChatCompletionChunk { }; Self { id: String::new(), - object: "chat.completion.chunk".to_string(), created, model, system_fingerprint, @@ -821,7 +815,6 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] #[schema(nullable = true, example = "null")] - #[serde(deserialize_with = "deserialize_tool_choice::deserialize")] pub tool_choice: Option, /// Response format constraints for the generation. @@ -837,44 +830,41 @@ fn default_tool_prompt() -> Option { "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(), ) } -#[derive(Clone, Deserialize, ToSchema, Serialize)] -enum ToolType { - FunctionName(String), + +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] +#[serde(untagged)] +pub enum ToolType { OneOf, + FunctionName(String), + Function { function: FunctionName }, } -/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None) -mod deserialize_tool_choice { - use super::*; - use serde::de; - use serde::Deserializer; - use serde_json::Value; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct FunctionName { + pub name: String, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(from = "ToolTypeDeserializer")] +pub struct ToolChoice(pub Option); - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; +#[derive(Deserialize)] +#[serde(untagged)] +enum ToolTypeDeserializer { + None(Option), + Some(ToolType), +} +impl From for ToolChoice { + fn from(value: ToolTypeDeserializer) -> Self { match value { - Value::String(s) => match s.as_str() { - "none" => Ok(None), - "auto" => Ok(Some(ToolType::OneOf)), - _ => Ok(Some(ToolType::FunctionName(s))), + ToolTypeDeserializer::None(opt) => match opt.as_deref() { + Some("none") => ToolChoice(None), + Some("auto") => ToolChoice(Some(ToolType::OneOf)), + Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))), + None => ToolChoice(Some(ToolType::OneOf)), }, - Value::Object(map) => { - if let Some(content) = map - .get("function") - .and_then(|v| v.get("name")) - .and_then(|v| v.as_str()) - { - Ok(Some(ToolType::FunctionName(content.to_string()))) - } else { - Err(de::Error::custom("function key not found in tool choice")) - } - } - Value::Null => Ok(Some(ToolType::OneOf)), - _ => Err(de::Error::custom("invalid token format")), + ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)), } } } @@ -950,26 +940,16 @@ pub(crate) struct ToolCall { } #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] -struct Url { +pub struct Url { url: String, } -#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] -struct ImageUrl { - image_url: Url, -} - -#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] -struct Text { - text: String, -} - #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] #[serde(tag = "type")] #[serde(rename_all = "snake_case")] -enum MessageChunk { - Text(Text), - ImageUrl(ImageUrl), +pub enum MessageChunk { + Text { text: String }, + ImageUrl { image_url: Url }, } #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] @@ -977,35 +957,31 @@ pub struct Message { #[schema(example = "user")] role: String, #[schema(example = "My name is David and I")] - #[serde(deserialize_with = "message_content_serde::deserialize")] - content: Vec, + pub content: MessageContent, #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "\"David\"")] name: Option, } -mod message_content_serde { - use super::*; - use serde::{Deserialize, Deserializer}; - - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - #[derive(Deserialize)] - #[serde(untagged)] - enum Message { - Text(String), - Chunks(Vec), - } - let message: Message = Deserialize::deserialize(deserializer)?; - let chunks = match message { - Message::Text(text) => { - vec![MessageChunk::Text(Text { text })] +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] +#[serde(untagged)] +pub enum MessageContent { + SingleText(String), + MultipleChunks(Vec), +} + +// Pushing a chunk to a single text message will convert it to a multiple chunks message +impl MessageContent { + pub fn push(&mut self, chunk: MessageChunk) { + match self { + MessageContent::SingleText(text) => { + *self = + MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]); } - Message::Chunks(s) => s, - }; - Ok(chunks) + MessageContent::MultipleChunks(chunks) => { + chunks.push(chunk); + } + } } } @@ -1021,18 +997,17 @@ impl From for TextMessage { fn from(value: Message) -> Self { TextMessage { role: value.role, - content: value - .content - .into_iter() - .map(|c| match c { - MessageChunk::Text(Text { text }) => text, - MessageChunk::ImageUrl(image) => { - let url = image.image_url.url; - format!("![]({url})") - } - }) - .collect::>() - .join(""), + content: match value.content { + MessageContent::SingleText(text) => text, + MessageContent::MultipleChunks(chunks) => chunks + .into_iter() + .map(|chunk| match chunk { + MessageChunk::Text { text } => text, + MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url), + }) + .collect::>() + .join(""), + }, } } } @@ -1240,9 +1215,16 @@ mod tests { ); assert_eq!( config.bos_token, - Some("<|begin▁of▁sentence|>".to_string()) + Some(TokenizerConfigToken::String( + "<|begin▁of▁sentence|>".to_string() + )) + ); + assert_eq!( + config.eos_token, + Some(TokenizerConfigToken::String( + "<|end▁of▁sentence|>".to_string() + )) ); - assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string())); // in this case we expect the tokens to be encoded as structured tokens // we want the content of the structured token @@ -1275,9 +1257,16 @@ mod tests { ); assert_eq!( config.bos_token, - Some("<|begin▁of▁sentence|>".to_string()) + Some(TokenizerConfigToken::Object { + content: "<|begin▁of▁sentence|>".to_string() + }) + ); + assert_eq!( + config.eos_token, + Some(TokenizerConfigToken::Object { + content: "<|end▁of▁sentence|>".to_string() + }) ); - assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string())); } #[test] @@ -1295,9 +1284,7 @@ mod tests { request.messages[0], Message { role: "user".to_string(), - content: vec![MessageChunk::Text(Text { - text: "What is Deep Learning?".to_string() - }),], + content: MessageContent::SingleText("What is Deep Learning?".to_string()), name: None } ); @@ -1321,10 +1308,10 @@ mod tests { request.messages[0], Message{ role: "user".to_string(), - content: vec![ - MessageChunk::Text(Text { text: "Whats in this image?".to_string() }), - MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }) - ], + content: MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: "Whats in this image?".to_string() }, + MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }}, + ]), name: None } ); @@ -1334,10 +1321,10 @@ mod tests { fn text_message_convert() { let message = Message{ role: "user".to_string(), - content: vec![ - MessageChunk::Text(Text { text: "Whats in this image?".to_string() }), - MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }) - ], + content: MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: "Whats in this image?".to_string() }, + MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } } + ]), name: None }; let textmsg: TextMessage = message.into(); diff --git a/router/src/main.rs b/router/src/main.rs index 08e14f79eb1..21942104e95 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -557,11 +557,11 @@ pub fn create_post_processor( if add_bos_token { if let Some(bos) = bos_token { let bos_token_id = tokenizer - .token_to_id(bos) + .token_to_id(bos.as_str()) .expect("Should have found the bos token id"); - special_tokens.push((bos.clone(), bos_token_id)); - single.push(format!("{}:0", bos)); - pair.push(format!("{}:0", bos)); + special_tokens.push((bos.as_str(), bos_token_id)); + single.push(format!("{}:0", bos.as_str())); + pair.push(format!("{}:0", bos.as_str())); } } @@ -571,17 +571,17 @@ pub fn create_post_processor( if add_eos_token { if let Some(eos) = eos_token { let eos_token_id = tokenizer - .token_to_id(eos) + .token_to_id(eos.as_str()) .expect("Should have found the eos token id"); - special_tokens.push((eos.clone(), eos_token_id)); - single.push(format!("{}:0", eos)); - pair.push(format!("{}:0", eos)); + special_tokens.push((eos.as_str(), eos_token_id)); + single.push(format!("{}:0", eos.as_str())); + pair.push(format!("{}:0", eos.as_str())); } } if add_bos_token { if let Some(bos) = bos_token { - pair.push(format!("{}:1", bos)); + pair.push(format!("{}:1", bos.as_str())); } } @@ -589,7 +589,7 @@ pub fn create_post_processor( if add_eos_token { if let Some(eos) = eos_token { - pair.push(format!("{}:1", eos)); + pair.push(format!("{}:1", eos.as_str())); } } @@ -615,14 +615,15 @@ enum RouterError { #[cfg(test)] mod tests { use super::*; + use text_generation_router::TokenizerConfigToken; #[test] fn test_create_post_processor() { let tokenizer_config = HubTokenizerConfig { add_bos_token: None, add_eos_token: None, - bos_token: Some("".to_string()), - eos_token: Some("".to_string()), + bos_token: Some(TokenizerConfigToken::String("".to_string())), + eos_token: Some(TokenizerConfigToken::String("".to_string())), chat_template: None, tokenizer_class: None, completion_template: None, @@ -633,9 +634,9 @@ mod tests { let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap(); let expected = TemplateProcessing::builder() - .try_single(":0 $A:0 :1") + .try_single(":0 $A:0") .unwrap() - .try_pair(":0 $A:0 $B:1") + .try_pair(":0 $A:0 :1 $B:1") .unwrap() .special_tokens(vec![("".to_string(), 1)]) .build() diff --git a/router/src/server.rs b/router/src/server.rs index fa2ba001cf6..581f0068c94 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -12,17 +12,18 @@ use crate::kserve::{ use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, - HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, - Token, TokenizeResponse, Usage, Validation, + GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, + Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, + Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, - CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, + CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest, + VertexResponse, }; -use crate::{FunctionDefinition, ToolCall, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -636,7 +637,7 @@ async fn completions( )); } - if req.prompt.len() > info.max_client_batch_size { + if req.prompt.0.len() > info.max_client_batch_size { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); return Err(( StatusCode::UNPROCESSABLE_ENTITY, @@ -652,6 +653,7 @@ async fn completions( let generate_requests: Vec = req .prompt + .0 .iter() .map(|prompt| GenerateRequest { inputs: prompt.to_string(), @@ -706,7 +708,6 @@ async fn completions( event .json_data(CompletionCompleteChunk { id: "".to_string(), - object: "text_completion".to_string(), created: current_time, choices: vec![CompletionComplete { @@ -933,7 +934,6 @@ async fn completions( let response = Completion { id: "".to_string(), - object: "text_completion".to_string(), created: current_time, model: info.model_id.clone(), system_fingerprint: format!( @@ -1154,14 +1154,16 @@ async fn chat_completions( }; event - .json_data(ChatCompletionChunk::new( - model_id.clone(), - system_fingerprint.clone(), - content, - tool_calls, - current_time, - logprobs, - stream_token.details.map(|d| d.finish_reason.to_string()), + .json_data(CompletionType::ChatCompletionChunk( + ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + content, + tool_calls, + current_time, + logprobs, + stream_token.details.map(|d| d.finish_reason.to_string()), + ), )) .unwrap_or_else(|e| { println!("Failed to serialize ChatCompletionChunk: {:?}", e); @@ -1229,7 +1231,7 @@ async fn chat_completions( (None, Some(generation.generated_text)) }; // build the complete response object with the full text - let response = ChatCompletion::new( + let response = CompletionType::ChatCompletion(ChatCompletion::new( model_id, system_fingerprint, output, @@ -1237,7 +1239,7 @@ async fn chat_completions( generation.details.unwrap(), logprobs, tool_calls, - ); + )); // wrap generation inside a Vec to match api-inference Ok((headers, Json(response)).into_response())