diff --git a/router/src/lib.rs b/router/src/lib.rs index a4bd322961f..d5e551a0a0f 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -59,10 +59,8 @@ use std::path::Path; pub struct HubTokenizerConfig { pub chat_template: Option, pub completion_template: Option, - #[serde(deserialize_with = "token_serde::deserialize")] - pub bos_token: Option, - #[serde(deserialize_with = "token_serde::deserialize")] - pub eos_token: Option, + pub bos_token: Option, + pub eos_token: Option, pub tokenizer_class: Option, pub add_bos_token: Option, pub add_eos_token: Option, @@ -83,15 +81,27 @@ pub enum TokenizerConfigToken { Object { content: String }, } -impl From for String { - fn from(token: TokenizerConfigToken) -> Self { - match token { +impl TokenizerConfigToken { + pub fn as_str(&self) -> &str { + match self { TokenizerConfigToken::String(s) => s, TokenizerConfigToken::Object { content } => content, } } } +impl From for String { + fn from(token: TokenizerConfigToken) -> Self { + token.as_str().to_string() + } +} + +impl From for TokenizerConfigToken { + fn from(s: String) -> Self { + TokenizerConfigToken::String(s) + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "processor_class")] pub enum HubPreprocessorConfig { diff --git a/router/src/main.rs b/router/src/main.rs index 3aa5a6bf9d2..ef103d925d4 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -315,17 +315,19 @@ async fn main() -> Result<(), RouterError> { let mut special_tokens = vec![]; if let Some(true) = &tokenizer_config.add_bos_token{ if let Some(bos_token) = &tokenizer_config.bos_token{ - let bos_token_id = tokenizer.token_to_id(&bos_token).expect("Should have found the bos token id"); - special_tokens.push((bos_token.clone(), bos_token_id)); - single.push(bos_token.to_string()); + let bos_token_str = bos_token.as_str(); + let bos_token_id = tokenizer.token_to_id(bos_token_str).expect("Should have found the bos token id"); + special_tokens.push((bos_token_str.to_string(), bos_token_id)); + single.push(bos_token_str.to_string()); } } single.push("$0".to_string()); if let Some(true) = &tokenizer_config.add_eos_token{ if let Some(eos_token) = &tokenizer_config.eos_token{ - let eos_token_id = tokenizer.token_to_id(&eos_token).expect("Should have found the eos token id"); - special_tokens.push((eos_token.clone(), eos_token_id)); - single.push(eos_token.to_string()); + let eos_token_str = eos_token.as_str(); + let eos_token_id = tokenizer.token_to_id(eos_token_str).expect("Should have found the eos token id"); + special_tokens.push((eos_token_str.to_string(), eos_token_id)); + single.push(eos_token_str.to_string()); } } let post_processor = TemplateProcessing::builder().try_single(single).unwrap().special_tokens(special_tokens).build().unwrap(); diff --git a/router/src/server.rs b/router/src/server.rs index e11403eb537..d24774f96c3 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -23,7 +23,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, ToolCall, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode};