Skip to content

Commit

Permalink
feat: add concrete tool types
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Feb 22, 2024
1 parent 7be5244 commit 6956a32
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 68 deletions.
73 changes: 48 additions & 25 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,38 +537,61 @@ pub(crate) struct ChatRequest {
pub tools: Option<Vec<Tool>>,
}

// TODO: define and use better types for tools

// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
// enum ToolType {
// #[serde(rename = "function")]
// Function,
// }

// impl Default for ToolType {
// fn default() -> Self {
// ToolType::Function
// }
// }

// #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
// pub(crate) struct Function {
// pub description: String,
// pub name: String,
// #[serde(
// rename = "json",
// deserialize_with = "json_object_or_string_to_string::deserialize"
// )]
// pub parameters: String,
// }
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
pub struct Tools {
// rename to "$function" to avoid conflicts with other fields
#[serde(rename = "$function")]
pub function: std::collections::HashMap<String, serde_json::Value>,
pub any_of: Vec<FunctionRef>,
}

// add traut to convert to serde_json::Value for tools
impl From<Tools> for serde_json::Value {
fn from(tools: Tools) -> Self {
println!("tools: {:?}", tools);
let mut map = serde_json::Map::new();
let mut functions = serde_json::Map::new();
for (name, value) in tools.function {
functions.insert(name, value);
}
map.insert("$functions".to_string(), serde_json::json!(functions));
let mut properties = serde_json::Map::new();
let mut function = serde_json::Map::new();
function.insert("anyOf".to_string(), serde_json::json!(tools.any_of));
properties.insert("function".to_string(), serde_json::json!(function));
map.insert("properties".to_string(), serde_json::json!(properties));
serde_json::Value::Object(map)
}
}

#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub struct FunctionRef {
#[serde(rename = "$ref")]
pub _ref: String,
}

impl FunctionRef {
pub fn new(name: &str) -> Self {
Self {
_ref: format!("#/$functions/{}", name),
}
}
}

#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct Function {
pub description: String,
pub name: String,
pub parameters: serde_json::Value,
}

#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct Tool {
// The type of the tool. Currently, only 'function' is supported.
#[schema(example = "function")]
pub r#type: String,
// Grab the tool as generic JSON for debugging purposes.
pub function: serde_json::Value,
pub function: Function,
}

#[derive(Clone, Serialize, Deserialize)]
Expand Down
100 changes: 57 additions & 43 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
};
use crate::{FunctionRef, Tools};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
Expand All @@ -22,6 +23,8 @@ use futures::stream::StreamExt;
use futures::Stream;
use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value;
use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
Expand Down Expand Up @@ -580,45 +583,6 @@ async fn chat_completions(
let logprobs = req.logprobs.unwrap_or(false);
let seed = req.seed;

// Build a new JSON schema that defines the "$functions" object
// and requires the grammar to choose anyOf the functions defined.
let mut tools = serde_json::json!({});

// First decompose the tools and use the function name as the key
// and the parameters as the value in the "$functions" object.
if let Some(req_tools) = &req.tools {
for tool in req_tools {
let func = tool.function.clone();
let name = func.get("name").unwrap().as_str().unwrap();
let parameters = func.get("parameters").unwrap().as_object().unwrap().clone();
// add a entry to the "$functions" object
tools["$functions"][name] = serde_json::Value::Object(parameters);
}

// now add the properties to the root object
tools["properties"]["function"]["anyOf"] = serde_json::Value::Array(
req.tools
.as_ref()
.unwrap()
.iter()
// map each tool to a $ref to the function
.map(|tool| {
let func = tool.function.clone();
let name = func.get("name").unwrap().as_str().unwrap();
serde_json::json!({
"$ref": format!("#/$functions/{}", name)
})
})
.collect(),
);
}

// only add grammar if tools are present
let grammar = match req.tools {
Some(_grammar) => Some(crate::GrammarType::Json(tools.to_string())),
None => None,
};

// apply chat template to flatten the request into a single input
let mut inputs = match infer.apply_chat_template(req.messages) {
Ok(inputs) => inputs,
Expand All @@ -635,10 +599,60 @@ async fn chat_completions(
}
};

// append the tools to the inputs with TOOL prompt
let tool_prompt =
"Based on the conversation, please choose the most appropriate tool to use:".to_string();
inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools}\n\n");
// if theres a tools object, we need to decompose it and use the function name as the key
// and the parameters as the value in the "$functions" object.
let grammar = if let Some(req_tools) = &req.tools {
let functions: HashMap<String, Value> = {
let mut tools = HashMap::new();
for tool in req_tools {
let func = tool.function.clone();
let name = func.name;
let parameters = match func.parameters.as_object() {
Some(parameters) => parameters.clone(),
None => {
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Input validation error".to_string(),
error_type: "Input validation error".to_string(),
}),
))
}
};

tools.insert(name, Value::Object(parameters));
}
tools
};

let tools = Tools {
function: functions,
any_of: req_tools
.iter()
.map(|tool| FunctionRef::new(&tool.function.name))
.collect(),
};

// update the input
let tools_str = serde_json::to_string(&tools).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "Input validation error".to_string(),
}),
)
})?;

let tool_prompt =
"Based on the conversation, please choose the most appropriate tool to use:"
.to_string();
inputs = format!("{inputs}\n\n{tool_prompt}\n\n{tools_str}\n\n");

Some(GrammarType::Json(tools.into()))
} else {
None
};

// build the request passing some parameters
let generate_request = GenerateRequest {
Expand Down

0 comments on commit 6956a32

Please sign in to comment.