Skip to content

Commit

Permalink
fix: adjust HubTokenizerConfig after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Jun 27, 2024
1 parent bf24b76 commit d43ef3d
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 14 deletions.
24 changes: 17 additions & 7 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@ use std::path::Path;
pub struct HubTokenizerConfig {
pub chat_template: Option<ChatTemplateVersions>,
pub completion_template: Option<String>,
#[serde(deserialize_with = "token_serde::deserialize")]
pub bos_token: Option<String>,
#[serde(deserialize_with = "token_serde::deserialize")]
pub eos_token: Option<String>,
pub bos_token: Option<TokenizerConfigToken>,
pub eos_token: Option<TokenizerConfigToken>,
pub tokenizer_class: Option<String>,
pub add_bos_token: Option<bool>,
pub add_eos_token: Option<bool>,
Expand All @@ -83,15 +81,27 @@ pub enum TokenizerConfigToken {
Object { content: String },
}

impl From<TokenizerConfigToken> for String {
fn from(token: TokenizerConfigToken) -> Self {
match token {
impl TokenizerConfigToken {
pub fn as_str(&self) -> &str {
match self {
TokenizerConfigToken::String(s) => s,
TokenizerConfigToken::Object { content } => content,
}
}
}

impl From<TokenizerConfigToken> for String {
fn from(token: TokenizerConfigToken) -> Self {
token.as_str().to_string()
}
}

impl From<String> for TokenizerConfigToken {
fn from(s: String) -> Self {
TokenizerConfigToken::String(s)
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "processor_class")]
pub enum HubPreprocessorConfig {
Expand Down
14 changes: 8 additions & 6 deletions router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,17 +315,19 @@ async fn main() -> Result<(), RouterError> {
let mut special_tokens = vec![];
if let Some(true) = &tokenizer_config.add_bos_token{
if let Some(bos_token) = &tokenizer_config.bos_token{
let bos_token_id = tokenizer.token_to_id(&bos_token).expect("Should have found the bos token id");
special_tokens.push((bos_token.clone(), bos_token_id));
single.push(bos_token.to_string());
let bos_token_str = bos_token.as_str();
let bos_token_id = tokenizer.token_to_id(bos_token_str).expect("Should have found the bos token id");
special_tokens.push((bos_token_str.to_string(), bos_token_id));
single.push(bos_token_str.to_string());
}
}
single.push("$0".to_string());
if let Some(true) = &tokenizer_config.add_eos_token{
if let Some(eos_token) = &tokenizer_config.eos_token{
let eos_token_id = tokenizer.token_to_id(&eos_token).expect("Should have found the eos token id");
special_tokens.push((eos_token.clone(), eos_token_id));
single.push(eos_token.to_string());
let eos_token_str = eos_token.as_str();
let eos_token_id = tokenizer.token_to_id(eos_token_str).expect("Should have found the eos token id");
special_tokens.push((eos_token_str.to_string(), eos_token_id));
single.push(eos_token_str.to_string());
}
}
let post_processor = TemplateProcessing::builder().try_single(single).unwrap().special_tokens(special_tokens).build().unwrap();
Expand Down
2 changes: 1 addition & 1 deletion router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::{
CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest,
VertexResponse,
};
use crate::{FunctionDefinition, ToolCall, ToolType};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
Expand Down

0 comments on commit d43ef3d

Please sign in to comment.