diff --git a/agents/__init__.py b/agents/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/agents/tests/__init__.py b/agents/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/langchain_benchmarks/model_registration.py b/langchain_benchmarks/model_registration.py
index fa9da072..d2a5a39b 100644
--- a/langchain_benchmarks/model_registration.py
+++ b/langchain_benchmarks/model_registration.py
@@ -194,7 +194,7 @@
),
RegisteredModel(
provider="fireworks",
- name="mixtral-8x7b-instruct",
+ name="mixtral-8x7b-instruct-fw",
description="Mistral MoE 8x7B Instruct v0.1 model with Sparse "
"Mixture of Experts. Fine tuned for instruction following",
type="llm",
diff --git a/langchain_benchmarks/tool_usage/agents/__init__.py b/langchain_benchmarks/tool_usage/agents/__init__.py
new file mode 100644
index 00000000..4e9f2896
--- /dev/null
+++ b/langchain_benchmarks/tool_usage/agents/__init__.py
@@ -0,0 +1,7 @@
+from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter
+from langchain_benchmarks.tool_usage.agents.experimental.factory import (
+ CustomAgentFactory,
+)
+from langchain_benchmarks.tool_usage.agents.openai_functions import OpenAIAgentFactory
+
+__all__ = ["OpenAIAgentFactory", "apply_agent_executor_adapter", "CustomAgentFactory"]
diff --git a/langchain_benchmarks/tool_usage/agents/adapters.py b/langchain_benchmarks/tool_usage/agents/adapters.py
new file mode 100644
index 00000000..29bd3ddf
--- /dev/null
+++ b/langchain_benchmarks/tool_usage/agents/adapters.py
@@ -0,0 +1,71 @@
+from typing import Optional, Callable, Any
+
+from langchain.agents import AgentExecutor
+from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
+
+
+def _ensure_output_exists(inputs: dict) -> dict:
+ """Make sure that the output key is always present."""
+ if "output" not in inputs:
+ return {"output": "", **inputs}
+ return inputs
+
+
+def apply_agent_executor_adapter(
+ agent_executor: AgentExecutor,
+ *,
+ state_reader: Optional[Callable[[], Any]] = None,
+) -> Runnable:
+ """An adapter for the agent executor to standardize its input and output.
+
+ 1) Map `question` to `input` (`question` is used in the datasets,
+ but `input` is used in the agent executor)
+ 2) Ensure that `output` is always returned (will be set to "" if missing) --
+ note that this may be relaxed after more updates in the eval config.
+ 3) Populate `state` key in the response of the agent with the system state
+ if a state reader is provided.
+
+ Args:
+ agent_executor: the agent executor
+ state_reader: A callable without parameters that if invoked will return
+ the state of the environment. Used to populate the 'state' key.
+
+ Returns:
+ a new runnable with a standardized output.
+ """
+
+ def _read_state(*args: Any, **kwargs: Any) -> Any:
+ """Read the state of the environment."""
+ if state_reader is not None:
+ return state_reader()
+ else:
+ return None
+
+ def _format_input(inputs: dict) -> dict:
+ """Make sure that the input is always called `input`."""
+
+ if "question" not in inputs:
+ raise ValueError(
+ "Expected 'question' to be in the inputs. Found only the following "
+ f"keys {sorted(inputs.keys())}."
+ )
+
+ inputs = inputs.copy() # Because 'question' is popped below
+
+ if "input" not in inputs:
+ return {"input": inputs.pop("question"), **inputs}
+ return inputs
+
+ runnable = (
+ RunnableLambda(_format_input).with_config({"run_name": "Format Input"})
+ | agent_executor
+ | RunnableLambda(_ensure_output_exists).with_config(
+ {"run_name": "Ensure Output"}
+ )
+ )
+
+ if state_reader is not None:
+ runnable = runnable | RunnablePassthrough.assign(state=_read_state).with_config(
+ {"run_name": "Read Env State"}
+ )
+ return runnable
diff --git a/langchain_benchmarks/tool_usage/agents/experimental/__init__.py b/langchain_benchmarks/tool_usage/agents/experimental/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/langchain_benchmarks/tool_usage/agents/experimental/agent.py b/langchain_benchmarks/tool_usage/agents/experimental/agent.py
new file mode 100644
index 00000000..14ba932d
--- /dev/null
+++ b/langchain_benchmarks/tool_usage/agents/experimental/agent.py
@@ -0,0 +1,133 @@
+from typing import List, Literal, Optional, Sequence, Tuple, Union
+
+from langchain.agents import AgentOutputParser
+from langchain.prompts.chat import ChatPromptTemplate
+from langchain.schema.runnable import Runnable
+from langchain.tools import StructuredTool
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.language_models import BaseChatModel, BaseLanguageModel
+from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
+from langchain_core.prompts import MessagesPlaceholder
+from typing_extensions import NotRequired, TypedDict
+
+from langchain_benchmarks import RateLimiter
+from langchain_benchmarks.rate_limiting import with_rate_limit
+from langchain_benchmarks.tool_usage.agents.experimental.encoder import (
+ AstPrinter,
+ TypeScriptEncoder,
+ XMLEncoder,
+)
+from langchain_benchmarks.tool_usage.agents.experimental.encoder import FunctionResult
+from langchain_benchmarks.tool_usage.agents.experimental.prompts import (
+ _AGENT_INSTRUCTIONS_BLOB_STYLE,
+)
+from langchain_benchmarks.tool_usage.agents.experimental.tool_utils import (
+ convert_tool_to_function_definition,
+)
+
+
+def format_steps_for_chat(
+ intermediate_steps: List[Tuple[AgentAction, str]],
+ ast_printer: AstPrinter,
+) -> List[BaseMessage]:
+ """Format the steps."""
+ messages = []
+ for action, observation in intermediate_steps:
+ # Action messages contains the tool invocation request from the LLM
+ # Now add the result of the tool invocation.
+
+ if action.tool == "_Exception":
+ messages.append(
+ AIMessage(
+ content=action.log,
+ )
+ )
+ messages.append(
+ # Tool input is the error message for the exception
+ HumanMessage(content=action.tool_input)
+ )
+ else:
+ messages.extend(action.messages)
+ function_result: FunctionResult = {
+ "name": action.tool,
+ "error": None,
+ "result": observation,
+ }
+ messages.append(
+ HumanMessage(
+ content=ast_printer.visit_function_result(function_result),
+ )
+ )
+
+ return messages
+
+
+# PUBLIC API
+
+
+class AgentInput(TypedDict):
+ """The input to the agent."""
+
+ input: str
+ """The input to the agent."""
+ intermediate_steps: List[Tuple[AgentAction, str]]
+ """The intermediate steps taken by the agent."""
+ examples: NotRequired[List[BaseMessage]]
+ """A list of messages that can be used to form example traces."""
+
+
+def create_agent(
+ model: Union[BaseChatModel, BaseLanguageModel],
+ tools: Sequence[StructuredTool],
+ parser: AgentOutputParser,
+ *,
+ ast_printer: Union[AstPrinter, Literal["xml"]] = "xml",
+ rate_limiter: Optional[RateLimiter] = None,
+) -> Runnable[AgentInput, Union[AgentAction, AgentFinish]]:
+ """Create an agent for a chat model."""
+ if isinstance(ast_printer, str):
+ if ast_printer == "xml":
+ ast_printer_ = XMLEncoder()
+ elif ast_printer == "typescript":
+ ast_printer_ = TypeScriptEncoder()
+ else:
+ raise ValueError(f"Unknown ast printer: {ast_printer}")
+ elif isinstance(ast_printer, AstPrinter):
+ ast_printer_ = ast_printer
+ else:
+ raise TypeError(
+ f"Expected AstPrinter or str, got {type(ast_printer)} for `ast_printer`"
+ )
+
+ function_definitions = [convert_tool_to_function_definition(tool) for tool in tools]
+ tool_description = ast_printer_.visit_function_definitions(function_definitions)
+
+ template = ChatPromptTemplate.from_messages(
+ [
+ ("system", _AGENT_INSTRUCTIONS_BLOB_STYLE),
+ MessagesPlaceholder("examples"), # Can use to add example traces
+ ("human", "{input}"),
+ MessagesPlaceholder("history"),
+ ]
+ ).partial(tool_description=tool_description)
+
+ # For the time being, hard-coding the fact that we're using a tag.
+ model = model.bind(stop=[""])
+
+ if rate_limiter:
+ # Apply a rate limiter if it was provided
+ model = with_rate_limit(model, rate_limiter)
+
+ agent = (
+ {
+ "input": lambda x: x["input"],
+ "history": lambda x: format_steps_for_chat(
+ x["intermediate_steps"], ast_printer_
+ ),
+ "examples": lambda x: x.get("examples", []),
+ }
+ | template
+ | model
+ | parser
+ )
+ return agent
diff --git a/langchain_benchmarks/tool_usage/agents/experimental/encoder.py b/langchain_benchmarks/tool_usage/agents/experimental/encoder.py
new file mode 100644
index 00000000..c6799609
--- /dev/null
+++ b/langchain_benchmarks/tool_usage/agents/experimental/encoder.py
@@ -0,0 +1,240 @@
+"""Prototyping code for rendering function definitions, invocations, and results.
+
+Types are simplified for now to `str`.
+
+We should actually support something like pydantic or jsonschema for the types, so
+we can expand them recursively for nested types.
+"""
+import abc
+from typing import Any, List, Optional
+
+from typing_extensions import NotRequired, TypedDict
+
+
+class Parameter(TypedDict):
+ """Representation for a parameter."""
+
+ name: str
+ type: str
+ description: str
+
+
+class Arguments(TypedDict):
+ """Arguments are passed to a function during function invocation."""
+
+ name: Optional[str]
+ value: Any
+
+
+class ReturnValue(TypedDict):
+ """Representation for a return value of a function call."""
+
+ type: str
+ description: NotRequired[str]
+
+
+class FunctionDefinition(TypedDict):
+ """Representation for a function."""
+
+ name: str
+ description: str # Function description
+ parameters: List[Parameter]
+ return_value: ReturnValue
+
+
+class FunctionInvocation(TypedDict):
+ """Representation for a function invocation."""
+
+ id: NotRequired[str]
+ name: str
+ arguments: List[Arguments]
+
+
+class FunctionResult(TypedDict):
+ """Representation for a function result."""
+
+ id: NotRequired[str]
+ name: str
+ result: Optional[str]
+ error: Optional[str]
+
+
+class Visitor(abc.ABC):
+ @abc.abstractmethod
+ def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
+ """Render a function."""
+
+ @abc.abstractmethod
+ def visit_function_definitions(
+ self, function_definitions: List[FunctionDefinition]
+ ) -> str:
+ """Render a function."""
+
+ @abc.abstractmethod
+ def visit_function_invocation(self, function_invocation: FunctionInvocation) -> str:
+ """Render a function invocation."""
+
+ @abc.abstractmethod
+ def visit_function_result(self, function_result: FunctionResult) -> str:
+ """Render a function result."""
+
+
+class AstPrinter(Visitor):
+ """Print the AST."""
+
+
+class XMLEncoder(AstPrinter):
+ def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
+ """Render a function."""
+ parameters_lines = []
+
+ for parameter in function_definition["parameters"]:
+ parameters_lines.extend(
+ [
+ "",
+ f"{parameter['name']}",
+ f"{parameter['type']}",
+ f"{parameter['description']}",
+ "",
+ ]
+ )
+ lines = [
+ "",
+ f"{function_definition['name']}",
+ "",
+ f"{function_definition['description']}",
+ "",
+ "",
+ *parameters_lines,
+ "",
+ "",
+ f"{function_definition['return_value']['type']}",
+ ]
+ if function_definition["return_value"].get("description"):
+ lines.append(
+ f"{function_definition['return_value']['description']}"
+ f""
+ )
+
+ lines.extend(["", ""])
+ return "\n".join(lines)
+
+ def visit_function_definitions(
+ self, function_definitions: List[FunctionDefinition]
+ ) -> str:
+ """Render a function."""
+ strs = [
+ self.visit_function_definition(function_definition)
+ for function_definition in function_definitions
+ ]
+ return "\n" + "\n".join(strs) + "\n"
+
+ def visit_function_invocation(self, invocation: FunctionInvocation) -> str:
+ """Render a function invocation."""
+ arguments_as_strings = [
+ "\n"
+ f"{argument['name']}\n"
+ f"{argument['value']}\n"
+ "\n"
+ for argument in invocation["arguments"]
+ ]
+ lines = [""]
+
+ if invocation.get("id"):
+ lines.append(f"{invocation['id']}")
+
+ lines.extend(
+ [
+ f"{invocation['name']}\n"
+ "\n"
+ f"{''.join(arguments_as_strings)}" # Already includes trailing newline
+ "\n"
+ ""
+ ]
+ )
+ return "\n".join(lines)
+
+ def visit_function_result(self, function_result: FunctionResult) -> str:
+ """Render a function result."""
+ lines = [
+ "",
+ ]
+
+ if function_result.get("id"):
+ lines.append(f"{function_result['id']}")
+
+ lines.append(f"{function_result['name']}")
+
+ if function_result["error"]:
+ lines.extend(
+ [
+ f"{function_result['error']}",
+ ]
+ )
+ else:
+ lines.append(
+ f"{function_result['result']}",
+ )
+
+ lines.append("")
+
+ return "\n".join(lines)
+
+
+class TypeScriptEncoder(AstPrinter):
+ def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
+ """Render a function."""
+ parameters_as_strings = [
+ f"{parameter['name']}: {parameter['type']}"
+ for parameter in function_definition["parameters"]
+ ]
+ # Let's use JSdoc style comments
+ # First the function description
+ lines = [
+ f"// {function_definition['description']}",
+ # Then the parameter descriptions
+ *[
+ f"// @param {parameter['name']} {parameter['description']}"
+ for parameter in function_definition["parameters"]
+ ],
+ # Then the return value description
+ f"// @returns {function_definition['return_value']['description']}",
+ # Then the function definition
+ f"function {function_definition['name']}("
+ f"{', '.join(parameters_as_strings)}): "
+ f"{function_definition['return_value']['type']};",
+ ]
+
+ # finally join
+ function = "\n".join(lines)
+ return function
+
+ def visit_function_definitions(
+ self, function_definitions: List[FunctionDefinition]
+ ) -> str:
+ """Render a function."""
+ strs = [
+ self.visit_function_definition(function_definition)
+ for function_definition in function_definitions
+ ]
+ return "\n\n".join(strs)
+
+ def visit_function_invocation(self, invocation: FunctionInvocation) -> str:
+ """Render a function invocation."""
+ arguments_as_strings = [
+ f"{argument['name']}: {argument['value']}"
+ for argument in invocation["arguments"]
+ ]
+ lines = [f"{invocation['name']}(" f"{', '.join(arguments_as_strings)});"]
+ return "\n".join(lines)
+
+ def visit_function_result(self, function_result: FunctionResult) -> str:
+ """Render a function result."""
+ lines = []
+ if function_result["error"]:
+ lines.append(f"ERROR: {function_result['error']}")
+ else:
+ lines.append(f"> {function_result['result']}")
+ if function_result.get("id"):
+ lines.append(f"// ID: {function_result['id']}")
+ return "\n".join(lines)
diff --git a/langchain_benchmarks/tool_usage/agents/experimental/factory.py b/langchain_benchmarks/tool_usage/agents/experimental/factory.py
new file mode 100644
index 00000000..d03c7626
--- /dev/null
+++ b/langchain_benchmarks/tool_usage/agents/experimental/factory.py
@@ -0,0 +1,86 @@
+"""Factory for creating agents for the tool usage task."""
+from typing import Optional
+
+from langchain.agents import AgentExecutor
+from langchain_core.runnables import Runnable, RunnableConfig
+
+from langchain_benchmarks import RateLimiter, model_registry
+from langchain_benchmarks.schema import ToolUsageTask
+from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter
+from langchain_benchmarks.tool_usage.agents.experimental.agent import create_agent
+from langchain_benchmarks.tool_usage.agents.experimental.parser import (
+ GenericAgentParser,
+)
+
+
+class CustomAgentFactory:
+ """A factory for creating tool using agents.
+
+ A factory for agents that do not leverage any special JSON mode for
+ function usage; instead all function invocation behavior is implemented solely
+ through prompt engineering and parsing.
+ """
+
+ def __init__(
+ self,
+ task: ToolUsageTask,
+ model: str,
+ *,
+ rate_limiter: Optional[RateLimiter] = None,
+ ) -> None:
+ """Create an agent factory for the given tool usage task.
+
+ Args:
+ task: The task to create an agent factory for
+ model: model name (check model_registry)
+ rate_limiter: The rate limiter to use if provided
+ """
+ if model not in model_registry:
+ raise ValueError(f"Unknown model: {model}")
+ self.task = task
+ self.model = model
+ self.rate_limiter = rate_limiter
+
+ def __call__(self) -> Runnable:
+ if isinstance(self.model, str):
+ registered_model = model_registry.get_model(self.model)
+ if registered_model is None:
+ raise ValueError(f"Unknown model: {self.model}")
+ model = registered_model.get_model(model_params={"temperature": 0})
+ else:
+ model = self.model
+
+ def _add_task_instructions(
+ input: dict, config: Optional[RunnableConfig] = None, **kwargs
+ ) -> dict:
+ """Add task instructions to the question."""
+ if not isinstance(input, dict):
+ raise ValueError(
+ f"Expected input to be a dict with key `question`. "
+ f"Found {type(input)}."
+ )
+ input = input.copy()
+ input["question"] = (
+ f"{self.task.instructions}\nWrite down your answer, "
+ f"but do not explain it. Input: `{input['question']}`"
+ )
+ return input
+
+ env = self.task.create_environment()
+
+ agent = create_agent(
+ model,
+ env.tools,
+ GenericAgentParser(wrapping_xml_tag="tool", require_closing_xml_tag=False),
+ rate_limiter=self.rate_limiter,
+ )
+ executor = AgentExecutor(
+ agent=agent,
+ tools=env.tools,
+ handle_parsing_errors=True,
+ return_intermediate_steps=True,
+ )
+
+ return _add_task_instructions | apply_agent_executor_adapter(
+ executor, state_reader=env.read_state
+ )
diff --git a/langchain_benchmarks/tool_usage/agents/experimental/parser.py b/langchain_benchmarks/tool_usage/agents/experimental/parser.py
new file mode 100644
index 00000000..7be09776
--- /dev/null
+++ b/langchain_benchmarks/tool_usage/agents/experimental/parser.py
@@ -0,0 +1,122 @@
+import ast
+import re
+from typing import Union, Dict, Optional
+
+from langchain.agents import AgentOutputParser
+from langchain.pydantic_v1 import BaseModel, Field
+from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
+from langchain_core.exceptions import OutputParserException
+from langchain_core.messages import AIMessage
+
+
+class _ToolInvocationRequest(BaseModel):
+ """Light-weight pydantic model for validating the raw tool invocation request.
+
+ The purpose of this model, is to make sure that whatever as parsed from
+ the raw llm output has `tool_name` and potential `arguments` fields, and
+ nothing else.
+ """
+
+ tool_name: str
+ # OK parameterless tools which do not take arguments
+ arguments: Optional[Dict] = Field(default_factory=dict)
+
+
+class GenericAgentParser(AgentOutputParser):
+ """A generalized parser that makes it easier to parameterize different parsing."""
+
+ wrapping_xml_tag: str
+ """The tag that wraps the function invocation request.
+
+ For example, if "tool", then the function invocation request should be wrapped
+ in ....
+ """
+ require_closing_xml_tag: bool = False
+ """Whether we should require a closing tag for the wrapping_xml_tag.
+
+ For example, if True, then the function invocation request should be wrapped
+ """
+
+ def parse(self, text: str) -> Union[AgentFinish, AgentAction]:
+ """Parse the output of the agent."""
+ open_tag = f"<{self.wrapping_xml_tag}>"
+ close_tag = f"{self.wrapping_xml_tag}>"
+ if open_tag in text:
+ # This is a hack to make sure that is always present
+ # in the output if . may be a stop sequence for the
+ # language model, so depending on implementation
+ # the stop sequence may be cut off.
+ # There might be a better way to do this, but this works and
+ # is simple.
+ if not self.require_closing_xml_tag:
+ text += close_tag
+
+ pattern = rf"{open_tag}(?P.*?){close_tag}"
+ match = re.search(pattern, text, re.DOTALL)
+ if match:
+ content = match.group("invocation").strip()
+ return parse_invocation(content, self.wrapping_xml_tag)
+
+ return AgentFinish(
+ log=text,
+ return_values={
+ "output": text,
+ },
+ )
+
+
+def parse_invocation(text: str, tag: str) -> AgentAction:
+ """Parse the content of the function invocation.
+
+ Args:
+ text: The text to parse.
+ tag: The tag that wraps the function invocation request.
+
+ Returns:
+ An AgentAction that corresponds to the function invocation.
+
+ Raises:
+ OutputParserException: If the parsing fails.
+
+ This exception is meant to be caught by the agent executor and
+ handled appropriately to provide feedback to the LLM.
+ """
+ ai_content = f"<{tag}>{text}{tag}>\n"
+
+ try:
+ result = ast.literal_eval(text)
+ except BaseException as e:
+ # Convert this to something controllable by the user.
+ err_msg = (
+ f"ERROR: Please use the format "
+ f'<{tag}>{{"tool_name": $TOOL_NAME, "arguments": $ARGUMENTS}}{tag}>\n'
+ )
+
+ raise OutputParserException(
+ error=e,
+ llm_output=ai_content,
+ observation=err_msg,
+ send_to_llm=True,
+ )
+
+ try:
+ request = _ToolInvocationRequest.validate(result)
+ except Exception as e: # Using broad exception since it's not just ValidationError
+ # Can also raise DictError if result is not a dict.
+ err_msg = (
+ f"ERROR: Please use the format "
+ f'<{tag}>{{"tool_name": $TOOL_NAME, "arguments": $ARGUMENTS}}{tag}>\n'
+ )
+ raise OutputParserException(
+ error=e,
+ llm_output=ai_content,
+ send_to_llm=True,
+ observation=err_msg,
+ )
+
+ return AgentActionMessageLog(
+ message_log=[AIMessage(content=ai_content)],
+ tool=request.tool_name,
+ tool_input=request.arguments,
+ log=f"\nInvoking {request.tool_name}: {request.arguments}\n\t",
+ )
diff --git a/langchain_benchmarks/tool_usage/agents/experimental/prompts.py b/langchain_benchmarks/tool_usage/agents/experimental/prompts.py
new file mode 100644
index 00000000..9abc051e
--- /dev/null
+++ b/langchain_benchmarks/tool_usage/agents/experimental/prompts.py
@@ -0,0 +1,42 @@
+AGENT_INSTRUCTIONS_XML_FORMAT = """\
+In this environment you have access to a set of tools you can use to answer the user's question.
+
+You may call them like this:
+
+
+$TOOL_NAME
+
+<$PARAMETER_NAME>$PARAMETER_VALUE$PARAMETER_NAME>
+...
+
+
+
+
+Here are the tools available:
+
+{tool_description}
+""" # noqa: E501
+
+_AGENT_INSTRUCTIONS_BLOB_STYLE = """\
+In this environment you have access to a set of tools you can use to answer the user's question.
+
+Here are the tools available:
+
+{tool_description}
+
+You may call one tool at a time using a format that includes and tag.
+
+Inside the tag the content is a python dictionary that uses python literals (e.g., numbers, strings, lists, dictionaries, etc.) to specify the tool invocation.
+
+It must match the schema of the function as described in the tool description.
+"arguments" is a dictionary of the arguments to the function.
+
+
+{{
+ "tool_name": $TOOL_NAME,
+ "arguments": $ARGUMENTS
+}}
+
+
+If you do not know the answer use more tools. You can only take a single action at a time.\
+""" # noqa: E501
diff --git a/langchain_benchmarks/tool_usage/agents/experimental/tool_utils.py b/langchain_benchmarks/tool_usage/agents/experimental/tool_utils.py
new file mode 100644
index 00000000..976fc061
--- /dev/null
+++ b/langchain_benchmarks/tool_usage/agents/experimental/tool_utils.py
@@ -0,0 +1,57 @@
+"""Utilities to extract information from langchain tools for use in prompts."""
+import inspect
+from textwrap import dedent
+from typing import List
+
+from langchain.tools.base import StructuredTool
+
+from langchain_benchmarks.tool_usage.agents.experimental.encoder import (
+ Parameter,
+ FunctionDefinition,
+)
+
+# PUBLIC API
+
+
+def get_parameters_from_tool(tool: StructuredTool) -> List[Parameter]:
+ """Convert a langchain tool to a tool user tool."""
+ schema = tool.args_schema.schema()
+
+ properties = schema["properties"]
+ parameters = []
+ # Is this needed or is string OK?
+ type_adapter = {
+ "string": "str", # str or string?
+ "integer": "int",
+ "number": "float",
+ "boolean": "bool",
+ }
+ for key, value in properties.items():
+ parameters.append(
+ {
+ "name": key,
+ "type": type_adapter.get(value["type"], value["type"]),
+ "description": value.get("description", ""),
+ }
+ )
+
+ return parameters
+
+
+#
+def convert_tool_to_function_definition(tool: StructuredTool) -> FunctionDefinition:
+ """Convert a langchain tool to a tool user tool."""
+ # Here we re-inspect the underlying function to get the doc-string
+ # since StructuredTool modifies it, but we want the raw one for maximum
+ # flexibility.
+ description = inspect.getdoc(tool.func)
+
+ parameters = get_parameters_from_tool(tool)
+ return {
+ "name": tool.name,
+ "description": dedent(description),
+ "parameters": parameters,
+ "return_value": {
+ "type": "Any",
+ },
+ }
diff --git a/langchain_benchmarks/tool_usage/agents.py b/langchain_benchmarks/tool_usage/agents/openai_functions.py
similarity index 51%
rename from langchain_benchmarks/tool_usage/agents.py
rename to langchain_benchmarks/tool_usage/agents/openai_functions.py
index ad28a54f..32186d02 100644
--- a/langchain_benchmarks/tool_usage/agents.py
+++ b/langchain_benchmarks/tool_usage/agents/openai_functions.py
@@ -1,23 +1,17 @@
"""Code for creating an agent factory for evaluating tool usage tasks."""
-from typing import Any, Callable, Optional
+from typing import Optional
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain.schema.runnable import Runnable, RunnableLambda, RunnablePassthrough
+from langchain.schema.runnable import Runnable
from langchain.tools.render import format_tool_to_openai_function
from langchain_benchmarks import rate_limiting
from langchain_benchmarks.schema import ToolUsageTask
-
-
-def _ensure_output_exists(inputs: dict) -> dict:
- """Make sure that the output key is always present."""
- if "output" not in inputs:
- return {"output": "", **inputs}
- return inputs
+from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter
# PUBLIC API
@@ -61,7 +55,7 @@ def __call__(self) -> Runnable:
if rate_limiting:
# Rate limited model
- model = with_rate_limit(model, self.rate_limiter)
+ model = rate_limiting.with_rate_limit(model, self.rate_limiter)
prompt = ChatPromptTemplate.from_messages(
[
@@ -96,66 +90,3 @@ def __call__(self) -> Runnable:
# Returns `state` in the output if the environment has a state reader
# makes sure that `output` is always in the output
return apply_agent_executor_adapter(runnable, state_reader=env.read_state)
-
-
-# PUBLIC API
-
-
-def apply_agent_executor_adapter(
- agent_executor: AgentExecutor,
- *,
- state_reader: Optional[Callable[[], Any]] = None,
-) -> Runnable:
- """An adapter for the agent executor to standardize its input and output.
-
- 1) Map `question` to `input` (`question` is used in the datasets,
- but `input` is used in the agent executor)
- 2) Ensure that `output` is always returned (will be set to "" if missing) --
- note that this may be relaxed after more updates in the eval config.
- 3) Populate `state` key in the response of the agent with the system state
- if a state reader is provided.
-
- Args:
- agent_executor: the agent executor
- state_reader: A callable without parameters that if invoked will return
- the state of the environment. Used to populate the 'state' key.
-
- Returns:
- a new runnable with a standardized output.
- """
-
- def _read_state(*args: Any, **kwargs: Any) -> Any:
- """Read the state of the environment."""
- if state_reader is not None:
- return state_reader()
- else:
- return None
-
- def _format_input(inputs: dict) -> dict:
- """Make sure that the input is always called `input`."""
-
- if "question" not in inputs:
- raise ValueError(
- "Expected 'question' to be in the inputs. Found only the following "
- f"keys {sorted(inputs.keys())}."
- )
-
- inputs = inputs.copy() # Because 'question' is popped below
-
- if "input" not in inputs:
- return {"input": inputs.pop("question"), **inputs}
- return inputs
-
- runnable = (
- RunnableLambda(_format_input).with_config({"run_name": "Format Input"})
- | agent_executor
- | RunnableLambda(_ensure_output_exists).with_config(
- {"run_name": "Ensure Output"}
- )
- )
-
- if state_reader is not None:
- runnable = runnable | RunnablePassthrough.assign(state=_read_state).with_config(
- {"run_name": "Read Env State"}
- )
- return runnable
diff --git a/tests/unit_tests/agents/__init__.py b/tests/unit_tests/agents/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit_tests/agents/encoding_and_decoding/__init__.py b/tests/unit_tests/agents/encoding_and_decoding/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit_tests/agents/encoding_and_decoding/test_decoding.py b/tests/unit_tests/agents/encoding_and_decoding/test_decoding.py
new file mode 100644
index 00000000..5a416cfa
--- /dev/null
+++ b/tests/unit_tests/agents/encoding_and_decoding/test_decoding.py
@@ -0,0 +1,54 @@
+import pytest
+from langchain_core.agents import AgentFinish, AgentActionMessageLog
+from langchain_core.exceptions import OutputParserException
+from langchain_core.messages import AIMessage
+
+from langchain_benchmarks.tool_usage.agents.experimental.parser import (
+ GenericAgentParser,
+)
+
+
+def test_parser() -> None:
+ """Test parser."""
+ parser = GenericAgentParser(require_closing_tag=False, wrapping_xml_tag="tool")
+
+ # If tag not found then it's an agent finish
+ assert isinstance(parser.invoke("goodbye"), AgentFinish)
+
+ with pytest.raises(OutputParserException):
+ # Invocation content is missing tool name and arguments
+ parser.invoke("'hello'")
+
+ with pytest.raises(OutputParserException):
+ parser.invoke("hello")
+
+ # Full invocation
+ text = (
+ '{\n "tool_name": "type_letter",\n '
+ '"arguments": {\n '
+ '"letter": "h"\n }\n}\n'
+ )
+
+ assert parser.invoke(text) == AgentActionMessageLog(
+ tool="type_letter",
+ tool_input={"letter": "h"},
+ log="\nInvoking type_letter: {'letter': 'h'}\n\t",
+ message_log=[AIMessage(content=text)],
+ )
+
+ # Test more cases
+ parsed = parser.invoke('{"tool_name": "hello"}')
+ assert parsed.tool == "hello"
+ # Assumes that it's a structured tool by default!
+ assert parsed.tool_input == {}
+
+ with pytest.raises(OutputParserException):
+ # Arguments need to be a dict
+ parser.invoke('{"tool_name": "hello", "arguments": [1, 2]}')
+
+ parsed = parser.invoke(
+ '{"tool_name": "hello", "arguments": {"a": "b"}}'
+ )
+ assert parsed.tool == "hello"
+ # Assumes that it's a structured tool by default!
+ assert parsed.tool_input == {"a": "b"}
diff --git a/tests/unit_tests/agents/encoding_and_decoding/test_typescript_encoding.py b/tests/unit_tests/agents/encoding_and_decoding/test_typescript_encoding.py
new file mode 100644
index 00000000..39175919
--- /dev/null
+++ b/tests/unit_tests/agents/encoding_and_decoding/test_typescript_encoding.py
@@ -0,0 +1,25 @@
+"""Test typescript encoding."""
+from langchain_benchmarks.tool_usage.agents.experimental.encoder import (
+ FunctionDefinition,
+ TypeScriptEncoder,
+)
+
+
+def test_function_definition() -> None:
+ """Test encoding a function definition."""
+ function_definition = FunctionDefinition(
+ name="test_function",
+ description="A test function",
+ parameters=[
+ {"name": "test_parameter", "type": "str", "description": "A test parameter"}
+ ],
+ return_value={"type": "str", "description": "A test return value"},
+ )
+ encoder = TypeScriptEncoder()
+ xml = encoder.visit_function_definition(function_definition)
+ assert xml == (
+ "// A test function\n"
+ "// @param test_parameter A test parameter\n"
+ "// @returns A test return value\n"
+ "function test_function(test_parameter: str): str;"
+ )
diff --git a/tests/unit_tests/agents/encoding_and_decoding/test_xml_encoding.py b/tests/unit_tests/agents/encoding_and_decoding/test_xml_encoding.py
new file mode 100644
index 00000000..d41b63be
--- /dev/null
+++ b/tests/unit_tests/agents/encoding_and_decoding/test_xml_encoding.py
@@ -0,0 +1,90 @@
+"""Test XML encoding and decoding of function definitions, invocation, and results."""
+from langchain_benchmarks.tool_usage.agents.experimental.encoder import (
+ FunctionDefinition,
+ FunctionInvocation,
+ FunctionResult,
+ XMLEncoder,
+)
+
+
+def test_function_definition_encoding() -> None:
+ """Test encoding a function definition."""
+ function_definition = FunctionDefinition(
+ name="test_function",
+ description="A test function",
+ parameters=[
+ {"name": "test_parameter", "type": "str", "description": "A test parameter"}
+ ],
+ return_value={"type": "str", "description": "A test return value"},
+ )
+ encoder = XMLEncoder()
+ xml = encoder.visit_function_definition(function_definition)
+ assert xml == (
+ "\n"
+ "test_function\n"
+ "\n"
+ "A test function\n"
+ "\n"
+ "\n"
+ "\n"
+ "test_parameter\n"
+ "str\n"
+ "A test parameter\n"
+ "\n"
+ "\n"
+ "\n"
+ "str\n"
+ "A test return value\n"
+ "\n"
+ ""
+ )
+
+
+def test_function_result_encoding() -> None:
+ """Test encoding a function result."""
+ encoder = XMLEncoder()
+ function_result = FunctionResult(
+ name="test_function",
+ result="test_result",
+ error=None,
+ )
+ xml = encoder.visit_function_result(function_result)
+ assert xml == (
+ "\n"
+ "test_function\n"
+ "test_result\n"
+ ""
+ )
+
+ function_result = FunctionResult(
+ name="test_function",
+ error="error",
+ )
+ xml = encoder.visit_function_result(function_result)
+ assert xml == (
+ "\n"
+ "test_function\n"
+ "error\n"
+ ""
+ )
+
+
+def test_function_invocation() -> None:
+ """Test function invocation."""
+ function_invocation = FunctionInvocation(
+ name="test_function",
+ arguments=[{"name": "test_argument", "value": "test_value"}],
+ )
+ encoder = XMLEncoder()
+ xml = encoder.visit_function_invocation(function_invocation)
+ assert xml == (
+ "\n"
+ "test_function\n"
+ "\n"
+ "\n"
+ "test_argument\n"
+ "test_value\n"
+ "\n"
+ "\n"
+ ""
+ )
diff --git a/tests/unit_tests/agents/test_tool_utils.py b/tests/unit_tests/agents/test_tool_utils.py
new file mode 100644
index 00000000..9e4bb95f
--- /dev/null
+++ b/tests/unit_tests/agents/test_tool_utils.py
@@ -0,0 +1,59 @@
+import pytest
+from langchain.tools import tool
+
+from langchain_benchmarks.tool_usage.agents.experimental.tool_utils import (
+ convert_tool_to_function_definition,
+)
+
+
+@tool
+def get_hello() -> str:
+ """Get hello."""
+ return "hello"
+
+
+@tool
+def repeat(x: str) -> str:
+ """Repeat x.
+
+ Args:
+ x: The string to repeat.
+
+ Returns:
+ The repeated string.
+ """
+ return x
+
+
+def test_parameterless_function() -> None:
+ """Test foo."""
+ function_definition = convert_tool_to_function_definition(get_hello)
+ assert function_definition == {
+ "name": "get_hello",
+ "description": "Get hello.",
+ "parameters": [],
+ "return_value": {
+ "type": "Any",
+ },
+ }
+
+
+@pytest.mark.skip("Need to fix handling of leading whitespace")
+def test_function_with_parameters() -> None:
+ import textwrap
+
+ doc = textwrap.dedent(repeat.func.__doc__)
+ assert convert_tool_to_function_definition(repeat) == {
+ "name": "repeat",
+ "description": doc,
+ "parameters": [
+ {
+ "name": "x",
+ "type": "str",
+ "description": "", # Need to fix this
+ }
+ ],
+ "return_value": {
+ "type": "Any",
+ },
+ }