From 001b12438150091483e79e04e3db170acf30ecc0 Mon Sep 17 00:00:00 2001 From: Jeremy Chone Date: Tue, 22 Oct 2024 12:30:16 -0700 Subject: [PATCH] + tool - First pass at adding Function Calling for OpenAI and Anthropic (rel #24) --- Cargo.toml | 4 +- .../adapters/anthropic/adapter_impl.rs | 165 +++++++++++++--- src/adapter/adapters/anthropic/mod.rs | 1 + src/adapter/adapters/cohere/adapter_impl.rs | 20 +- src/adapter/adapters/gemini/adapter_impl.rs | 11 +- src/adapter/adapters/openai/adapter_impl.rs | 186 +++++++++++++++--- src/chat/chat_message.rs | 71 +++++++ src/chat/chat_options.rs | 2 +- ..._format.rs => chat_req_response_format.rs} | 2 +- src/chat/{chat_req.rs => chat_request.rs} | 101 +++------- src/chat/chat_res.rs | 24 ++- src/chat/message_content.rs | 58 ++++-- src/chat/mod.rs | 10 +- src/chat/tool.rs | 14 -- src/chat/tool/mod.rs | 11 ++ src/chat/tool/tool_base.rs | 64 ++++++ src/chat/tool/tool_call.rs | 10 + src/chat/tool/tool_response.rs | 29 +++ src/error.rs | 12 +- tests/support/common_tests.rs | 66 ++++++- tests/support/seeders.rs | 29 ++- tests/tests_p_anthropic.rs | 14 ++ tests/tests_p_openai.rs | 13 ++ 23 files changed, 740 insertions(+), 177 deletions(-) create mode 100644 src/chat/chat_message.rs rename src/chat/{chat_response_format.rs => chat_req_response_format.rs} (95%) rename src/chat/{chat_req.rs => chat_request.rs} (61%) delete mode 100644 src/chat/tool.rs create mode 100644 src/chat/tool/mod.rs create mode 100644 src/chat/tool/tool_base.rs create mode 100644 src/chat/tool/tool_call.rs create mode 100644 src/chat/tool/tool_response.rs diff --git a/Cargo.toml b/Cargo.toml index 8302e34..08986e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,8 +11,8 @@ repository = "https://github.com/jeremychone/rust-genai" [lints.rust] unsafe_code = "forbid" -# unused = { level = "allow", priority = -1 } # For exploratory dev. -missing_docs = "warn" +unused = { level = "allow", priority = -1 } # For exploratory dev. 
+# missing_docs = "warn" [dependencies] # -- Async diff --git a/src/adapter/adapters/anthropic/adapter_impl.rs b/src/adapter/adapters/anthropic/adapter_impl.rs index 6101070..3ce6f71 100644 --- a/src/adapter/adapters/anthropic/adapter_impl.rs +++ b/src/adapter/adapters/anthropic/adapter_impl.rs @@ -3,10 +3,11 @@ use crate::adapter::support::get_api_key; use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData}; use crate::chat::{ ChatOptionsSet, ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse, MessageContent, MetaUsage, + ToolCall, }; use crate::webc::WebResponse; -use crate::Result; use crate::{ClientConfig, ModelIden}; +use crate::{Error, Result}; use reqwest::RequestBuilder; use reqwest_eventsource::EventSource; use serde_json::{json, Value}; @@ -18,9 +19,9 @@ const BASE_URL: &str = "https://api.anthropic.com/v1/"; const MAX_TOKENS: u32 = 1024; const ANTRHOPIC_VERSION: &str = "2023-06-01"; const MODELS: &[&str] = &[ + "claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229", - "claude-3-sonnet-20240229", "claude-3-haiku-20240307", ]; @@ -53,7 +54,7 @@ impl Adapter for AnthropicAdapter { let url = Self::get_service_url(model_iden.clone(), service_type); // -- api_key (this Adapter requires it) - let api_key = get_api_key(model_iden, client_config)?; + let api_key = get_api_key(model_iden.clone(), client_config)?; let headers = vec![ // headers @@ -61,7 +62,11 @@ impl Adapter for AnthropicAdapter { ("anthropic-version".to_string(), ANTRHOPIC_VERSION.to_string()), ]; - let AnthropicRequestParts { system, messages } = Self::into_anthropic_request_parts(chat_req)?; + let AnthropicRequestParts { + system, + messages, + tools, + } = Self::into_anthropic_request_parts(model_iden, chat_req)?; // -- Build the basic payload let mut payload = json!({ @@ -69,10 +74,15 @@ impl Adapter for AnthropicAdapter { "messages": messages, "stream": stream }); + if let Some(system) = system { payload.x_insert("system", 
system)?; } + if let Some(tools) = tools { + payload.x_insert("/tools", tools); + } + // -- Add supported ChatOptions if let Some(temperature) = options_set.temperature() { payload.x_insert("temperature", temperature)?; @@ -90,27 +100,51 @@ impl Adapter for AnthropicAdapter { fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result { let WebResponse { mut body, .. } = web_response; - let json_content_items: Vec = body.x_take("content")?; - - let mut content: Vec = Vec::new(); + // -- Capture the usage let usage = body.x_take("usage").map(Self::into_usage).unwrap_or_default(); - for mut item in json_content_items { - let item_text: String = item.x_take("text")?; - content.push(item_text); + // -- Capture the content + // NOTE: Anthropic support a list of content of multitypes but not the ChatResponse + // So, the strategy is to: + // - List all of the content and capture the text and tool_use + // - If there is one or more tool_use, this will take precedence and MessageContent support tool_call list + // - Otherwise, the text is concatenated + // NOTE: We need to see if the multiple content type text happens and why. If not, we can probably simplify this by just capturing the first one. 
+ // Eventually, ChatResponse will have `content: Option>` for the multi parts (with images and such) + let content_items: Vec = body.x_take("content")?; + + let mut text_content: Vec = Vec::new(); + // Note: here tool_calls is probably the exception, so, not creating the vector if not needed + let mut tool_calls: Option> = None; + + for mut item in content_items { + let typ: &str = item.x_get_as("type")?; + if typ == "text" { + text_content.push(item.x_take("text")?); + } else if typ == "tool_use" { + let call_id = item.x_take::("id")?; + let fn_name = item.x_take::("name")?; + // if not found, will be Value::Null + let fn_arguments = item.x_take::("input").unwrap_or_default(); + let tool_call = ToolCall { + call_id, + fn_name, + fn_arguments, + }; + tool_calls.get_or_insert_with(Vec::new).push(tool_call); + } } - let content = if content.is_empty() { - None + let content = if let Some(tool_calls) = tool_calls { + Some(MessageContent::from(tool_calls)) } else { - Some(content.join("")) + Some(MessageContent::from(text_content.join("\n"))) }; - let content = content.map(MessageContent::from); Ok(ChatResponse { - model_iden, content, + model_iden, usage, }) } @@ -153,7 +187,7 @@ impl AnthropicAdapter { /// Takes the GenAI ChatMessages and constructs the System string and JSON Messages for Anthropic. 
/// - Will push the `ChatRequest.system` and system message to `AnthropicRequestParts.system` - fn into_anthropic_request_parts(chat_req: ChatRequest) -> Result { + fn into_anthropic_request_parts(model_iden: ModelIden, chat_req: ChatRequest) -> Result { let mut messages: Vec = Vec::new(); let mut systems: Vec = Vec::new(); @@ -161,32 +195,115 @@ impl AnthropicAdapter { systems.push(system); } + // -- Process the messages for msg in chat_req.messages { - // Note: Will handle more types later - let MessageContent::Text(content) = msg.content; - match msg.role { // for now, system and tool messages go to system - ChatRole::System | ChatRole::Tool => systems.push(content), - ChatRole::User => messages.push(json! ({"role": "user", "content": content})), - ChatRole::Assistant => messages.push(json! ({"role": "assistant", "content": content})), + ChatRole::System => { + if let MessageContent::Text(content) = msg.content { + systems.push(content) + } + // TODO: Needs to trace/warn that other type are not supported + } + ChatRole::User => { + if let MessageContent::Text(content) = msg.content { + messages.push(json! ({"role": "user", "content": content})) + } + // TODO: Needs to trace/warn that other type are not supported + } + ChatRole::Assistant => { + // + match msg.content { + MessageContent::Text(content) => { + messages.push(json! ({"role": "assistant", "content": content})) + } + MessageContent::ToolCalls(tool_calls) => { + let tool_calls = tool_calls + .into_iter() + .map(|tool_call| { + // see: https://docs.anthropic.com/en/docs/build-with-claude/tool-use#example-of-successful-tool-result + json!({ + "type": "tool_use", + "id": tool_call.call_id, + "name": tool_call.fn_name, + "input": tool_call.fn_arguments, + }) + }) + .collect::>(); + messages.push(json! 
({ + "role": "assistant", + "content": tool_calls + })); + } + // TODO: Probably need to trace/warn that this will be ignored + MessageContent::ToolResponses(_) => (), + } + } + ChatRole::Tool => { + if let MessageContent::ToolResponses(tool_responses) = msg.content { + let tool_responses = tool_responses + .into_iter() + .map(|tool_response| { + json!({ + "type": "tool_result", + "content": tool_response.content, + "tool_use_id": tool_response.call_id, + }) + }) + .collect::>(); + + // FIXME: MessageContent::ToolResponse should be MessageContent::ToolResponses (even if openAI does require multi Tool message) + messages.push(json!({ + "role": "user", + "content": tool_responses + })); + } + // TODO: Probably need to trace/warn that this will be ignored + } } } + // -- Create the Anthropic system + // NOTE: Anthropic does not have a "role": "system", just a single optional system property let system = if !systems.is_empty() { Some(systems.join("\n")) } else { None }; - Ok(AnthropicRequestParts { system, messages }) + // -- Process the tools + let tools = chat_req.tools.map(|tools| { + tools + .into_iter() + .map(|tool| { + // TODO: Need to handle the error correctly + // TODO: Needs to have a custom serializer (tool should not have to match to a provider) + // NOTE: Right now, low probability, so, we just return null if cannto to value. 
+ let mut tool_value = json!({ + "name": tool.name, + "input_schema": tool.schema, + }); + + if let Some(description) = tool.description { + tool_value.x_insert("description", description); + } + tool_value + }) + .collect::>() + }); + + Ok(AnthropicRequestParts { + system, + messages, + tools, + }) } } struct AnthropicRequestParts { system: Option, messages: Vec, - // TODO: need to add tools + tools: Option>, } // endregion: --- Support diff --git a/src/adapter/adapters/anthropic/mod.rs b/src/adapter/adapters/anthropic/mod.rs index 822e99c..e685fda 100644 --- a/src/adapter/adapters/anthropic/mod.rs +++ b/src/adapter/adapters/anthropic/mod.rs @@ -1,4 +1,5 @@ //! API Documentation: https://docs.anthropic.com/en/api/messages +//! Tool Documentation: https://docs.anthropic.com/en/docs/build-with-claude/tool-use //! Model Names: https://docs.anthropic.com/en/docs/models-overview //! Pricing: https://www.anthropic.com/pricing#anthropic-api diff --git a/src/adapter/adapters/cohere/adapter_impl.rs b/src/adapter/adapters/cohere/adapter_impl.rs index bac7b0d..d55d626 100644 --- a/src/adapter/adapters/cohere/adapter_impl.rs +++ b/src/adapter/adapters/cohere/adapter_impl.rs @@ -110,8 +110,8 @@ impl Adapter for CohereAdapter { .map(MessageContent::from); Ok(ChatResponse { - model_iden, content, + model_iden, usage, }) } @@ -185,13 +185,23 @@ impl CohereAdapter { actual_role: last_chat_msg.role, }); } - // Will handle more types later - let MessageContent::Text(message) = last_chat_msg.content; + + // TODO: Needs to implement tool_calls + let MessageContent::Text(message) = last_chat_msg.content else { + return Err(Error::MessageContentTypeNotSupported { + model_iden, + cause: "Only MessageContent::Text supported for this model (for now)", + }); + }; // -- Build for msg in chat_req.messages { - // Note: Will handle more types later - let MessageContent::Text(content) = msg.content; + let MessageContent::Text(content) = msg.content else { + return 
Err(Error::MessageContentTypeNotSupported { + model_iden, + cause: "Only MessageContent::Text supported for this model (for now)", + }); + }; match msg.role { // For now, system and tool go to the system diff --git a/src/adapter/adapters/gemini/adapter_impl.rs b/src/adapter/adapters/gemini/adapter_impl.rs index e165138..4ccfa50 100644 --- a/src/adapter/adapters/gemini/adapter_impl.rs +++ b/src/adapter/adapters/gemini/adapter_impl.rs @@ -121,8 +121,8 @@ impl Adapter for GeminiAdapter { let content = content.map(MessageContent::from); Ok(ChatResponse { - model_iden, content, + model_iden, usage, }) } @@ -192,8 +192,13 @@ impl GeminiAdapter { // -- Build for msg in chat_req.messages { - // Note: Will handle more types later - let MessageContent::Text(content) = msg.content; + // TODO: Needs to implement tool_calls + let MessageContent::Text(content) = msg.content else { + return Err(Error::MessageContentTypeNotSupported { + model_iden, + cause: "Only MessageContent::Text supported for this model (for now)", + }); + }; match msg.role { // For now, system goes as "user" (later, we might have adapter_config.system_to_user_impl) diff --git a/src/adapter/adapters/openai/adapter_impl.rs b/src/adapter/adapters/openai/adapter_impl.rs index 9e991bf..bc28984 100644 --- a/src/adapter/adapters/openai/adapter_impl.rs +++ b/src/adapter/adapters/openai/adapter_impl.rs @@ -3,13 +3,14 @@ use crate::adapter::support::get_api_key; use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData}; use crate::chat::{ ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse, - MessageContent, MetaUsage, + MessageContent, MetaUsage, Tool, ToolCall, }; use crate::webc::WebResponse; use crate::{ClientConfig, ModelIden}; use crate::{Error, Result}; use reqwest::RequestBuilder; use reqwest_eventsource::EventSource; +use serde::Deserialize; use serde_json::{json, Value}; use value_ext::JsonValueExt; @@ -54,15 +55,31 @@ impl Adapter for 
OpenAIAdapter { fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result { let WebResponse { mut body, .. } = web_response; + // -- Capture the usage let usage = body.x_take("usage").map(OpenAIAdapter::into_usage).unwrap_or_default(); - let first_choice: Option = body.x_take("/choices/0")?; - let content: Option = first_choice.map(|mut c| c.x_take("/message/content")).transpose()?; - let content = content.map(MessageContent::from); + // -- Capture the content + let content = if let Some(mut first_choice) = body.x_take::>("/choices/0")? { + if let Some(content) = first_choice + .x_take::>("/message/content")? + .map(MessageContent::from) + { + Some(content) + } else { + first_choice + .x_take("/message/tool_calls") + .ok() + .map(parse_tool_calls) + .transpose()? + .map(MessageContent::from_tool_calls) + } + } else { + None + }; Ok(ChatResponse { - model_iden, content, + model_iden, usage, }) } @@ -117,13 +134,17 @@ impl OpenAIAdapter { // -- Build the basic payload let model_name = model_iden.model_name.to_string(); - let OpenAIRequestParts { messages } = Self::into_openai_request_parts(model_iden, chat_req)?; + let OpenAIRequestParts { messages, tools } = Self::into_openai_request_parts(model_iden, chat_req)?; let mut payload = json!({ "model": model_name, "messages": messages, "stream": stream }); + if let Some(tools) = tools { + payload.x_insert("/tools", tools); + } + // -- Add options let response_format = if let Some(response_format) = options_set.response_format() { match response_format { @@ -199,16 +220,17 @@ impl OpenAIAdapter { /// Takes the genai ChatMessages and builds the OpenAIChatRequestParts /// - `genai::ChatRequest.system`, if present, is added as the first message with role 'system'. 
/// - All messages get added with the corresponding roles (tools are not supported for now) - /// - /// NOTE: Here, the last `true` is for the Ollama variant - /// It seems the Ollama compatibility layer does not work well with multiple system messages. - /// So, when `true`, it will concatenate the system message into a single one at the beginning fn into_openai_request_parts(model_iden: ModelIden, chat_req: ChatRequest) -> Result { - let mut system_messages: Vec = Vec::new(); let mut messages: Vec = Vec::new(); + /// NOTE: For now system_messages is use to fix an issue with the Ollama compatibility layer that does not support multiple system messages. + /// So, when ollama, it will concatenate the system message into a single one at the beginning + /// NOTE: This might be fixed now, so, we could remove this. + let mut system_messages: Vec = Vec::new(); + let ollama_variant = matches!(model_iden.adapter_kind, AdapterKind::Ollama); + // -- Process the system if let Some(system_msg) = chat_req.system { if ollama_variant { system_messages.push(system_msg) @@ -217,37 +239,98 @@ impl OpenAIAdapter { } } + // -- Process the messages for msg in chat_req.messages { // Note: Will handle more types later - let MessageContent::Text(content) = msg.content; - match msg.role { // For now, system and tool messages go to the system ChatRole::System => { - // See note in the function comment - if ollama_variant { - system_messages.push(content); - } else { - messages.push(json!({"role": "system", "content": content})) + if let MessageContent::Text(content) = msg.content { + // NOTE: Ollama does not support multiple system messages + + // See note in the function comment + if ollama_variant { + system_messages.push(content); + } else { + messages.push(json!({"role": "system", "content": content})) + } + } + // TODO: Probably need to warn if it is a ToolCalls type of content + } + ChatRole::User => { + if let MessageContent::Text(content) = msg.content { + messages.push(json! 
({"role": "user", "content": content})); } + // TODO: Probably need to warn if it is a ToolCalls type of content } - ChatRole::User => messages.push(json! ({"role": "user", "content": content})), - ChatRole::Assistant => messages.push(json! ({"role": "assistant", "content": content})), + + ChatRole::Assistant => match msg.content { + MessageContent::Text(content) => messages.push(json! ({"role": "assistant", "content": content})), + MessageContent::ToolCalls(tool_calls) => { + let tool_calls = tool_calls + .into_iter() + .map(|tool_call| { + json!({ + "type": "function", + "id": tool_call.call_id, + "function": { + "name": tool_call.fn_name, + "arguments": tool_call.fn_arguments.to_string(), + } + }) + }) + .collect::>(); + messages.push(json! ({"role": "assistant", "tool_calls": tool_calls})) + } + // TODO: Probably need to trace/warn that this will be ignored + MessageContent::ToolResponses(_) => (), + }, + ChatRole::Tool => { - return Err(Error::MessageRoleNotSupported { - model_iden, - role: ChatRole::Tool, - }) + if let MessageContent::ToolResponses(tool_responses) = msg.content { + for tool_response in tool_responses { + messages.push(json!({ + "role": "tool", + "content": tool_response.content, + "tool_call_id": tool_response.call_id, + })) + } + } + // TODO: Probably need to trace/warn that this will be ignored } } } + // -- Finalize the system messages ollama case if !system_messages.is_empty() { let system_message = system_messages.join("\n"); messages.insert(0, json!({"role": "system", "content": system_message})); } - Ok(OpenAIRequestParts { messages }) + // -- Process the tools + let tools = chat_req.tools.map(|tools| { + tools + .into_iter() + .map(|tool| { + // TODO: Need to handle the error correctly + // TODO: Needs to have a custom serializer (tool should not have to match to a provider) + // NOTE: Right now, low probability, so, we just return null if cannto to value. 
+ json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.schema, + // TODO: If we need to support `strict: true` we need to add additionalProperties: false into the schema + // above (like structured output) + "strict": false, + } + }) + }) + .collect::>() + }); + + Ok(OpenAIRequestParts { messages, tools }) } } @@ -255,6 +338,59 @@ impl OpenAIAdapter { struct OpenAIRequestParts { messages: Vec, + tools: Option>, +} + +fn parse_tool_calls(raw_tool_calls: Value) -> Result> { + let Value::Array(raw_tool_calls) = raw_tool_calls else { + return Err(Error::InvalidJsonResponseElement { + info: "tool calls is not an array", + }); + }; + + let tool_calls = raw_tool_calls.into_iter().map(parse_tool_call).collect::>>()?; + + Ok(tool_calls) +} + +fn parse_tool_call(raw_tool_call: Value) -> Result { + // Define a helper struct to match the original JSON structure. + #[derive(Deserialize)] + struct IterimToolFnCall { + id: String, + #[serde(rename = "type")] + r#type: String, + function: IterimFunction, + } + + #[derive(Deserialize)] + struct IterimFunction { + name: String, + arguments: Value, + } + + let iterim = serde_json::from_value::(raw_tool_call)?; + + let fn_name = iterim.function.name; + + // For now support Object only, and parse the eventual string as a json value. + // Eventually, we might check pricing + let fn_arguments = match iterim.function.arguments { + Value::Object(obj) => Value::Object(obj), + Value::String(txt) => serde_json::from_str(&txt)?, + _ => { + return Err(Error::InvalidJsonResponseElement { + info: "tool call arguments is not an object", + }) + } + }; + + // Then, map the fields of the helper struct to the flat structure. 
+ Ok(ToolCall { + call_id: iterim.id, + fn_name, + fn_arguments, + }) } // endregion: --- Support diff --git a/src/chat/chat_message.rs b/src/chat/chat_message.rs new file mode 100644 index 0000000..0e0905c --- /dev/null +++ b/src/chat/chat_message.rs @@ -0,0 +1,71 @@ +use crate::chat::{MessageContent, ToolCall, ToolResponse}; +use serde::{Deserialize, Serialize}; + +/// An individual chat message. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatMessage { + /// The role of the message. + pub role: ChatRole, + + /// The content of the message. + pub content: MessageContent, +} + +/// Constructors +impl ChatMessage { + /// Create a new ChatMessage with the role `ChatRole::System`. + pub fn system(content: impl Into) -> Self { + Self { + role: ChatRole::System, + content: content.into(), + } + } + + /// Create a new ChatMessage with the role `ChatRole::Assistant`. + pub fn assistant(content: impl Into) -> Self { + Self { + role: ChatRole::Assistant, + content: content.into(), + } + } + + /// Create a new ChatMessage with the role `ChatRole::User`. + pub fn user(content: impl Into) -> Self { + Self { + role: ChatRole::User, + content: content.into(), + } + } +} + +/// Chat roles. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[allow(missing_docs)] +pub enum ChatRole { + System, + User, + Assistant, + Tool, +} + +// region: --- Froms + +impl From> for ChatMessage { + fn from(tool_calls: Vec) -> Self { + Self { + role: ChatRole::Assistant, + content: MessageContent::from(tool_calls), + } + } +} + +impl From for ChatMessage { + fn from(value: ToolResponse) -> Self { + Self { + role: ChatRole::Tool, + content: MessageContent::from(value), + } + } +} + +// endregion: --- Froms diff --git a/src/chat/chat_options.rs b/src/chat/chat_options.rs index 4e9be90..d9ded9e 100644 --- a/src/chat/chat_options.rs +++ b/src/chat/chat_options.rs @@ -5,7 +5,7 @@ //! Note 1: In the future, we will probably allow setting the client //! 
Note 2: Extracting it from the `ChatRequest` object allows for better reusability of each component. -use crate::chat::chat_response_format::ChatResponseFormat; +use crate::chat::chat_req_response_format::ChatResponseFormat; use serde::{Deserialize, Serialize}; /// Chat Options that are taken into account for any `Client::exec...` calls. diff --git a/src/chat/chat_response_format.rs b/src/chat/chat_req_response_format.rs similarity index 95% rename from src/chat/chat_response_format.rs rename to src/chat/chat_req_response_format.rs index cbf5cab..2f31424 100644 --- a/src/chat/chat_response_format.rs +++ b/src/chat/chat_req_response_format.rs @@ -2,7 +2,7 @@ use derive_more::From; use serde::{Deserialize, Serialize}; use serde_json::Value; -/// The chat response format to be sent back by the LLM. +/// The chat response format for the ChatRequest for structured output. /// This will be taken into consideration only if the provider supports it. /// /// > Note: Currently, the AI Providers will not report an error if not supported. It will just be ignored. diff --git a/src/chat/chat_req.rs b/src/chat/chat_request.rs similarity index 61% rename from src/chat/chat_req.rs rename to src/chat/chat_request.rs index ff30e98..93b9b4e 100644 --- a/src/chat/chat_req.rs +++ b/src/chat/chat_request.rs @@ -1,6 +1,6 @@ //! This module contains all the types related to a Chat Request (except ChatOptions, which has its own file). -use crate::chat::MessageContent; +use crate::chat::{ChatMessage, ChatRole, MessageContent, Tool}; use serde::{Deserialize, Serialize}; // region: --- ChatRequest @@ -13,13 +13,19 @@ pub struct ChatRequest { /// The messages of the request. pub messages: Vec, + + pub tools: Option>, } /// Constructors impl ChatRequest { /// Create a new ChatRequest with the given messages. pub fn new(messages: Vec) -> Self { - Self { messages, system: None } + Self { + messages, + system: None, + tools: None, + } } /// From the `.system` property content. 
@@ -27,6 +33,7 @@ impl ChatRequest { Self { system: Some(content.into()), messages: Vec::new(), + tools: None, } } @@ -34,7 +41,8 @@ impl ChatRequest { pub fn from_user(content: impl Into) -> Self { Self { system: None, - messages: vec![ChatMessage::user(content)], + messages: vec![ChatMessage::user(content.into())], + tools: None, } } } @@ -48,8 +56,18 @@ impl ChatRequest { } /// Append a message to the request. - pub fn append_message(mut self, msg: ChatMessage) -> Self { - self.messages.push(msg); + pub fn append_message(mut self, msg: impl Into) -> Self { + self.messages.push(msg.into()); + self + } + + pub fn with_tools(mut self, tools: Vec) -> Self { + self.tools = Some(tools); + self + } + + pub fn append_tool(mut self, tool: impl Into) -> Self { + self.tools.get_or_insert_with(Vec::new).push(tool.into()); self } } @@ -65,6 +83,8 @@ impl ChatRequest { .chain(self.messages.iter().filter_map(|message| match message.role { ChatRole::System => match message.content { MessageContent::Text(ref content) => Some(content.as_str()), + /// If system content is not text, then, we do not add it for now. + _ => None, }, _ => None, })) @@ -97,74 +117,3 @@ impl ChatRequest { } // endregion: --- ChatRequest - -// region: --- ChatMessage - -/// An individual chat message. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatMessage { - /// The role of the message. - pub role: ChatRole, - - /// The content of the message. - pub content: MessageContent, - - /// Extra information about the message. - pub extra: Option, -} - -/// Constructors -impl ChatMessage { - /// Create a new ChatMessage with the role `ChatRole::System`. - pub fn system(content: impl Into) -> Self { - Self { - role: ChatRole::System, - content: content.into(), - extra: None, - } - } - - /// Create a new ChatMessage with the role `ChatRole::Assistant`. 
- pub fn assistant(content: impl Into) -> Self { - Self { - role: ChatRole::Assistant, - content: content.into(), - extra: None, - } - } - - /// Create a new ChatMessage with the role `ChatRole::User`. - pub fn user(content: impl Into) -> Self { - Self { - role: ChatRole::User, - content: content.into(), - extra: None, - } - } -} - -/// Chat roles. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[allow(missing_docs)] -pub enum ChatRole { - System, - User, - Assistant, - Tool, -} - -/// NOTE: DO NOT USE, just a placeholder for now. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[allow(missing_docs)] -pub enum MessageExtra { - Tool(ToolExtra), -} - -/// NOTE: DO NOT USE, just a placeholder for now. -#[allow(unused)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolExtra { - tool_id: String, -} - -// endregion: --- ChatMessage diff --git a/src/chat/chat_res.rs b/src/chat/chat_res.rs index 1551651..d974171 100644 --- a/src/chat/chat_res.rs +++ b/src/chat/chat_res.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; -use crate::chat::{ChatStream, MessageContent}; +use crate::chat::{ChatStream, MessageContent, ToolCall}; use crate::ModelIden; // region: --- ChatResponse @@ -13,12 +13,12 @@ pub struct ChatResponse { /// The eventual content of the chat response pub content: Option, - /// The eventual usage of the chat response - pub usage: MetaUsage, - /// The Model Identifier (AdapterKind/ModelName) used for this request. 
/// > NOTE: This might be different from the request model if changed by the ModelMapper pub model_iden: ModelIden, + + /// The eventual usage of the chat response + pub usage: MetaUsage, } // Getters @@ -34,6 +34,22 @@ impl ChatResponse { pub fn content_text_into_string(self) -> Option { self.content.and_then(MessageContent::text_into_string) } + + pub fn tool_calls(&self) -> Option> { + if let Some(MessageContent::ToolCalls(tool_calls)) = self.content.as_ref() { + Some(tool_calls.iter().collect()) + } else { + None + } + } + + pub fn into_tool_calls(self) -> Option> { + if let Some(MessageContent::ToolCalls(tool_calls)) = self.content { + Some(tool_calls) + } else { + None + } + } } // endregion: --- ChatResponse diff --git a/src/chat/message_content.rs b/src/chat/message_content.rs index e4fc5ff..481f66b 100644 --- a/src/chat/message_content.rs +++ b/src/chat/message_content.rs @@ -1,19 +1,34 @@ +use crate::chat::{ToolCall, ToolResponse}; +use derive_more::derive::From; use serde::{Deserialize, Serialize}; /// Currently, it only supports Text, /// but the goal is to support multi-part message content (see below) -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, From)] pub enum MessageContent { /// Text content Text(String), + + /// Tool calls + #[from] + ToolCalls(Vec), + + /// Tool call Responses + #[from] + ToolResponses(Vec), } /// Constructors impl MessageContent { /// Create a new MessageContent with the Text variant - pub fn text(content: impl Into) -> Self { + pub fn from_text(content: impl Into) -> Self { MessageContent::Text(content.into()) } + + /// Create a new MessageContent with the ToolCalls variant + pub fn from_tool_calls(tool_calls: Vec) -> Self { + MessageContent::ToolCalls(tool_calls) + } } /// Getters @@ -25,6 +40,8 @@ impl MessageContent { pub fn text_as_str(&self) -> Option<&str> { match self { MessageContent::Text(content) => Some(content.as_str()), + MessageContent::ToolCalls(_) => None, + 
MessageContent::ToolResponses(_) => None, } } @@ -36,29 +53,44 @@ impl MessageContent { pub fn text_into_string(self) -> Option { match self { MessageContent::Text(content) => Some(content), + MessageContent::ToolCalls(_) => None, + MessageContent::ToolResponses(_) => None, } } - /// Checks if the text content is empty (for now) - /// Later, this will also validate each variant to check if they can be considered "empty" + /// Checks if the text content or the tools calls is empty. pub fn is_empty(&self) -> bool { match self { MessageContent::Text(content) => content.is_empty(), + MessageContent::ToolCalls(tool_calls) => tool_calls.is_empty(), + MessageContent::ToolResponses(tool_responses) => tool_responses.is_empty(), } } } // region: --- Froms -/// Blanket implementation for MessageContent::Text for anything that implements Into -/// Note: This means that when we support base64 as images, it should not use `.into()` for MessageContent. -/// It should be acceptable but may need reassessment. 
-impl From for MessageContent -where - T: Into, -{ - fn from(s: T) -> Self { - MessageContent::text(s) +impl From for MessageContent { + fn from(s: String) -> Self { + MessageContent::from_text(s) + } +} + +impl<'a> From<&'a str> for MessageContent { + fn from(s: &'a str) -> Self { + MessageContent::from_text(s.to_string()) + } +} + +impl From<&String> for MessageContent { + fn from(s: &String) -> Self { + MessageContent::from_text(s.clone()) + } +} + +impl From for MessageContent { + fn from(tool_response: ToolResponse) -> Self { + MessageContent::ToolResponses(vec![tool_response]) } } diff --git a/src/chat/mod.rs b/src/chat/mod.rs index 6da3045..4389dd0 100644 --- a/src/chat/mod.rs +++ b/src/chat/mod.rs @@ -3,19 +3,21 @@ // region: --- Modules +mod chat_message; mod chat_options; -mod chat_req; +mod chat_req_response_format; +mod chat_request; mod chat_res; -mod chat_response_format; mod chat_stream; mod message_content; mod tool; // -- Flatten +pub use chat_message::*; pub use chat_options::*; -pub use chat_req::*; +pub use chat_req_response_format::*; +pub use chat_request::*; pub use chat_res::*; -pub use chat_response_format::*; pub use chat_stream::*; pub use message_content::*; pub use tool::*; diff --git a/src/chat/tool.rs b/src/chat/tool.rs deleted file mode 100644 index 810666d..0000000 --- a/src/chat/tool.rs +++ /dev/null @@ -1,14 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -/// NOT USED FOR NOW -/// > For later, it will be used for function calling -/// > It will probably use the JsonSpec type we had in the response format, -/// > or have a `From` implementation. 
-#[allow(unused)] // Not used yet -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Tool { - fn_name: String, - fn_description: String, - params: Value, -} diff --git a/src/chat/tool/mod.rs b/src/chat/tool/mod.rs new file mode 100644 index 0000000..003af4b --- /dev/null +++ b/src/chat/tool/mod.rs @@ -0,0 +1,11 @@ +// region: --- Modules + +mod tool_base; +mod tool_call; +mod tool_response; + +pub use tool_base::*; +pub use tool_call::*; +pub use tool_response::*; + +// endregion: --- Modules diff --git a/src/chat/tool/tool_base.rs b/src/chat/tool/tool_base.rs new file mode 100644 index 0000000..5d08013 --- /dev/null +++ b/src/chat/tool/tool_base.rs @@ -0,0 +1,64 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Tool { + /// The tool name, which is typically the function name + /// e.g., `get_weather` + pub name: String, + + /// The description of the tool which will be used by the LLM to understand the context/usage of this tool + pub description: Option, + + /// The json-schema for the parameters + /// e.g., + /// ```json + /// json!({ + /// "type": "object", + /// "properties": { + /// "city": { + /// "type": "string", + /// "description": "The city name" + /// }, + /// "country": { + /// "type": "string", + /// "description": "The most likely country of this city name" + /// }, + /// "unit": { + /// "type": "string", + /// "enum": ["C", "F"], + /// "description": "The temperature unit of the country. 
C for Celsius, and F for Fahrenheit" + /// } + /// }, + /// "required": ["city", "country", "unit"], + /// }) + /// ``` + pub schema: Option, +} + +/// Constructor +impl Tool { + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + description: None, + schema: None, + } + } +} + +// region: --- Setters + +impl Tool { + pub fn with_description(mut self, description: impl Into) -> Self { + self.description = Some(description.into()); + self + } + + pub fn with_schema(mut self, parameters: Value) -> Self { + self.schema = Some(parameters); + self + } +} + +// endregion: --- Setters diff --git a/src/chat/tool/tool_call.rs b/src/chat/tool/tool_call.rs new file mode 100644 index 0000000..75e49a0 --- /dev/null +++ b/src/chat/tool/tool_call.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +/// The tool call function name and arguments sent back by the LLM. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + pub call_id: String, + pub fn_name: String, + pub fn_arguments: Value, +} diff --git a/src/chat/tool/tool_response.rs b/src/chat/tool/tool_response.rs new file mode 100644 index 0000000..642a409 --- /dev/null +++ b/src/chat/tool/tool_response.rs @@ -0,0 +1,29 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolResponse { + pub call_id: String, + // for now, just string (would probably be serialized json) + pub content: String, +} + +/// Constructor +impl ToolResponse { + pub fn new(tool_call_id: impl Into, content: impl Into) -> Self { + Self { + call_id: tool_call_id.into(), + content: content.into(), + } + } +} + +/// Getters +impl ToolResponse { + fn tool_call_id(&self) -> &str { + &self.call_id + } + + fn content(&self) -> &str { + &self.content + } +} diff --git a/src/error.rs b/src/error.rs index ea4311d..89f5e30 100644 --- a/src/error.rs +++ b/src/error.rs @@ -23,12 +23,19 @@ pub enum Error { model_iden: ModelIden, role: 
ChatRole, }, + MessageContentTypeNotSupported { + model_iden: ModelIden, + cause: &'static str, + }, JsonModeWithoutInstruction, // -- Chat Output NoChatResponse { model_iden: ModelIden, }, + InvalidJsonResponseElement { + info: &'static str, + }, // -- Auth RequiresApiKey { @@ -77,14 +84,15 @@ pub enum Error { resolver_error: resolver::Error, }, - // -- Utils - // -- Externals #[from] EventSourceClone(reqwest_eventsource::CannotCloneRequestError), #[from] JsonValueExt(JsonValueExtError), ReqwestEventSource(reqwest_eventsource::Error), + // Note: will probably need to remove this one to give more context + #[from] + SerdeJson(serde_json::Error), } // region: --- Error Boilerplate diff --git a/tests/support/common_tests.rs b/tests/support/common_tests.rs index 7377517..e5714cf 100644 --- a/tests/support/common_tests.rs +++ b/tests/support/common_tests.rs @@ -1,6 +1,6 @@ use crate::get_option_value; -use crate::support::{extract_stream_end, seed_chat_req_simple, Result}; -use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponseFormat, JsonSpec}; +use crate::support::{extract_stream_end, seed_chat_req_simple, seed_chat_req_tool_simple, Result}; +use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponseFormat, JsonSpec, Tool, ToolResponse}; use genai::resolver::{AuthData, AuthResolver, AuthResolverFn, IntoAuthResolverFn}; use genai::{Client, ClientConfig, ModelIden}; use serde_json::{json, Value}; @@ -260,6 +260,68 @@ pub async fn common_test_chat_stream_capture_all_ok(model: &str) -> Result<()> { // endregion: --- Chat Stream Tests +// region: --- Tools + +/// Just making the tool request, and checking the tool call response +/// `complete_check` is for LLMs that are better at giving back the unit and weather. 
+pub async fn common_test_tool_simple_ok(model: &str, complete_check: bool) -> Result<()> { + // -- Setup & Fixtures + let client = Client::default(); + let chat_req = seed_chat_req_tool_simple(); + + // -- Exec + let chat_res = client.exec_chat(model, chat_req, None).await?; + + // -- Check + let mut tool_calls = chat_res.tool_calls().ok_or("Should have tool calls")?; + let tool_call = tool_calls.pop().ok_or("Should have at least one tool call")?; + assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("city")?, "Paris"); + assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("country")?, "France"); + if complete_check { + // Note: Not all LLMs will output the weather (e.g. Anthropic Haiku) + assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("unit")?, "C"); + } + + Ok(()) +} + +/// `complete_check` is for LLMs that are better at giving back the unit and weather. +/// +pub async fn common_test_tool_full_flow_ok(model: &str, complete_check: bool) -> Result<()> { + // -- Setup & Fixtures + let client = Client::default(); + let mut chat_req = seed_chat_req_tool_simple(); + + // -- Exec first request to get the tool calls + let chat_res = client.exec_chat(model, chat_req.clone(), None).await?; + let tool_calls = chat_res.into_tool_calls().ok_or("Should have tool calls in chat_res")?; + + // -- Exec the second request + // get the tool call id (first one) + let first_tool_call = tool_calls.first().ok_or("Should have at least one tool call")?; + let first_tool_call_id = &first_tool_call.call_id; + // simulate the response + let tool_response = ToolResponse::new(first_tool_call_id, r#"{"weather": "Sunny", "temperature": "32C"}"#); + + // Add the tool_calls, tool_response + let chat_req = chat_req.append_message(tool_calls).append_message(tool_response); + + let chat_res = client.exec_chat(model, chat_req.clone(), None).await?; + + // -- Check + let content = chat_res.content_text_as_str().ok_or("Last response should be message")?; + assert!(content.contains("Paris"), "Should 
contain 'Paris'"); + assert!(content.contains("32"), "Should contain '32'"); + if complete_check { + // Note: Not all LLM will output the weather (e.g. Anthropic Haiku) + assert!(content.contains("sunny"), "Should contain 'sunny'"); + } + + Ok(()) +} + +// endregion: --- Tools + // region: --- With Resolvers pub async fn common_test_resolver_auth_ok(model: &str, auth_data: AuthData) -> Result<()> { diff --git a/tests/support/seeders.rs b/tests/support/seeders.rs index 0f540a6..2fd1051 100644 --- a/tests/support/seeders.rs +++ b/tests/support/seeders.rs @@ -1,4 +1,5 @@ -use genai::chat::{ChatMessage, ChatRequest}; +use genai::chat::{ChatMessage, ChatRequest, Tool}; +use serde_json::json; pub fn seed_chat_req_simple() -> ChatRequest { ChatRequest::new(vec![ @@ -7,3 +8,29 @@ pub fn seed_chat_req_simple() -> ChatRequest { ChatMessage::user("Why is the sky blue?"), ]) } + +pub fn seed_chat_req_tool_simple() -> ChatRequest { + ChatRequest::new(vec![ + // -- Messages (deactivate to see the differences) + ChatMessage::user("What is the temperature in C, in Paris"), + ]) + .append_tool(Tool::new("get_weather").with_schema(json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name" + }, + "country": { + "type": "string", + "description": "The most likely country of this city name" + }, + "unit": { + "type": "string", + "enum": ["C", "F"], + "description": "The temperature unit of the country. 
C for Celsius, and F for Fahrenheit" + } + }, + "required": ["city", "country", "unit"], + }))) +} diff --git a/tests/tests_p_anthropic.rs b/tests/tests_p_anthropic.rs index 6624995..c54cc57 100644 --- a/tests/tests_p_anthropic.rs +++ b/tests/tests_p_anthropic.rs @@ -45,6 +45,20 @@ async fn test_chat_stream_capture_all_ok() -> Result<()> { // endregion: --- Chat Stream Tests +// region: --- Tool Tests + +#[tokio::test] +async fn test_tool_simple_ok() -> Result<()> { + common_tests::common_test_tool_simple_ok(MODEL, false).await +} + +#[tokio::test] +async fn test_tool_full_flow_ok() -> Result<()> { + common_tests::common_test_tool_full_flow_ok(MODEL, false).await +} + +// endregion: --- Tool Tests + // region: --- Resolver Tests #[tokio::test] diff --git a/tests/tests_p_openai.rs b/tests/tests_p_openai.rs index 5300099..3c2850e 100644 --- a/tests/tests_p_openai.rs +++ b/tests/tests_p_openai.rs @@ -50,6 +50,19 @@ async fn test_chat_stream_capture_all_ok() -> Result<()> { // endregion: --- Chat Stream Tests +// region: --- Tool Tests + +#[tokio::test] +async fn test_tool_simple_ok() -> Result<()> { + common_tests::common_test_tool_simple_ok(MODEL, true).await +} + +#[tokio::test] +async fn test_tool_full_flow_ok() -> Result<()> { + common_tests::common_test_tool_full_flow_ok(MODEL, true).await +} +// endregion: --- Tool Tests + // region: --- Resolver Tests #[tokio::test]