Skip to content

Commit

Permalink
feat: leverage serde for conditional deser
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Jun 27, 2024
1 parent 9500b03 commit bf24b76
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 163 deletions.
40 changes: 33 additions & 7 deletions router/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ pub(crate) use health::HealthCheck;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::{
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token,
};
use crate::{
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
};
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
Expand Down Expand Up @@ -270,7 +272,11 @@ struct ChatTemplate {
}

impl ChatTemplate {
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
fn new(
template: String,
bos_token: Option<TokenizerConfigToken>,
eos_token: Option<TokenizerConfigToken>,
) -> Self {
let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback);
Expand All @@ -287,8 +293,20 @@ impl ChatTemplate {

Self {
template,
bos_token,
eos_token,
bos_token: match bos_token {
Some(token) => match token {
TokenizerConfigToken::String(token) => Some(token),
TokenizerConfigToken::Object { content } => Some(content),
},
None => None,
},
eos_token: match eos_token {
Some(token) => match token {
TokenizerConfigToken::String(token) => Some(token),
TokenizerConfigToken::Object { content } => Some(content),
},
None => None,
},
use_default_tool_template,
}
}
Expand All @@ -301,9 +319,9 @@ impl ChatTemplate {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text(Text {
last_message.content.push(MessageChunk::Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
}));
});
}
}
}
Expand Down Expand Up @@ -340,6 +358,14 @@ impl ToolGrammar {
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
.clone()]
}
ToolType::Function { function } => {
let tool = req_tools
.iter()
.find(|tool| tool.function.name == function.name)
.unwrap_or_else(|| panic!("Tool with name {} not found", function.name))
.clone();
vec![tool]
}
ToolType::OneOf => req_tools.to_owned(),
};

Expand Down
Loading

0 comments on commit bf24b76

Please sign in to comment.