Skip to content

Commit

Permalink
fix: prefer serde structs over custom functions (#2127)
Browse files Browse the repository at this point in the history
* fix: prefer enum for chat object

* fix: adjust typo

* fix: enum CompletionType not ObjectType

* fix: adjust typo

* feat: leverage serde for conditional deser

* fix: adjust HubTokenizerConfig after rebase

* fix: update create_post_processor logic for token type

* fix: adjust unwrap syntax in template

* Fixing the post processor.

---------

Co-authored-by: Nicolas Patry <[email protected]>
  • Loading branch information
2 people authored and ErikKaum committed Jul 26, 2024
1 parent dca51e4 commit 89a0c5d
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 204 deletions.
28 changes: 21 additions & 7 deletions router/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ pub(crate) use health::HealthCheck;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::{
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token,
};
use crate::{
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
};
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
Expand Down Expand Up @@ -270,7 +272,11 @@ struct ChatTemplate {
}

impl ChatTemplate {
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
fn new(
template: String,
bos_token: Option<TokenizerConfigToken>,
eos_token: Option<TokenizerConfigToken>,
) -> Self {
let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback);
Expand All @@ -287,8 +293,8 @@ impl ChatTemplate {

Self {
template,
bos_token,
eos_token,
bos_token: bos_token.map(|token| token.as_str().to_string()),
eos_token: eos_token.map(|token| token.as_str().to_string()),
use_default_tool_template,
}
}
Expand All @@ -301,9 +307,9 @@ impl ChatTemplate {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text(Text {
last_message.content.push(MessageChunk::Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
}));
});
}
}
}
Expand Down Expand Up @@ -340,6 +346,14 @@ impl ToolGrammar {
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
.clone()]
}
ToolType::Function { function } => {
let tool = req_tools
.iter()
.find(|tool| tool.function.name == function.name)
.unwrap_or_else(|| panic!("Tool with name {} not found", function.name))
.clone();
vec![tool]
}
ToolType::OneOf => req_tools.to_owned(),
};

Expand Down
Loading

0 comments on commit 89a0c5d

Please sign in to comment.