From 8798735ea450de2c8761ba39d7de89689226638c Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Thu, 14 Dec 2023 12:05:59 -0500
Subject: [PATCH] Adds custom agents to the langchain benchmarking repo (#120)

* This PR adds code for running custom agents to the langchain benchmarking repo.
* The agent code is good enough for experimentation / prototyping, but I don't
  think it's good enough for the langchain repo:
  -- The abstractions aren't fully implemented and aren't ready for production
     use -- but they are OK for research.
  -- For production use, one may want to remove all the intermediate abstractions
     to keep the agent as simple as possible.

I was thinking initially of including this in a different repo, but I think that
would be over-complicating things; it's probably OK to include some reference
implementations inside of langchain-benchmarks.
---
 agents/__init__.py                                 |   0
 agents/tests/__init__.py                           |   0
 langchain_benchmarks/model_registration.py         |   2 +-
 .../tool_usage/agents/__init__.py                  |   7 +
 .../tool_usage/agents/adapters.py                  |  71 ++++++
 .../agents/experimental/__init__.py                |   0
 .../tool_usage/agents/experimental/agent.py        | 133 ++++++++++
 .../tool_usage/agents/experimental/encoder.py      | 240 ++++++++++++++++++
 .../tool_usage/agents/experimental/factory.py      |  86 +++++++
 .../tool_usage/agents/experimental/parser.py       | 122 +++++++++
 .../tool_usage/agents/experimental/prompts.py      |  42 +++
 .../agents/experimental/tool_utils.py              |  57 +++++
 .../{agents.py => agents/openai_functions.py}      |  77 +-----
 tests/unit_tests/agents/__init__.py                |   0
 .../agents/encoding_and_decoding/__init__.py       |   0
 .../encoding_and_decoding/test_decoding.py         |  54 ++++
 .../test_typescript_encoding.py                    |  25 ++
 .../test_xml_encoding.py                           |  90 +++++++
 tests/unit_tests/agents/test_tool_utils.py         |  59 +++++
 19 files changed, 991 insertions(+), 74 deletions(-)
 create mode 100644 agents/__init__.py
 create mode 100644 agents/tests/__init__.py
 create mode 100644 langchain_benchmarks/tool_usage/agents/__init__.py
 create mode 100644 langchain_benchmarks/tool_usage/agents/adapters.py
 create mode 100644 langchain_benchmarks/tool_usage/agents/experimental/__init__.py
 create mode 100644 langchain_benchmarks/tool_usage/agents/experimental/agent.py
 create mode 100644 langchain_benchmarks/tool_usage/agents/experimental/encoder.py
 create mode 100644 langchain_benchmarks/tool_usage/agents/experimental/factory.py
 create mode 100644 langchain_benchmarks/tool_usage/agents/experimental/parser.py
 create mode 100644 langchain_benchmarks/tool_usage/agents/experimental/prompts.py
 create mode 100644 langchain_benchmarks/tool_usage/agents/experimental/tool_utils.py
 rename langchain_benchmarks/tool_usage/{agents.py => agents/openai_functions.py} (51%)
 create mode 100644 tests/unit_tests/agents/__init__.py
 create mode 100644 tests/unit_tests/agents/encoding_and_decoding/__init__.py
 create mode 100644 tests/unit_tests/agents/encoding_and_decoding/test_decoding.py
 create mode 100644 tests/unit_tests/agents/encoding_and_decoding/test_typescript_encoding.py
 create mode 100644 tests/unit_tests/agents/encoding_and_decoding/test_xml_encoding.py
 create mode 100644 tests/unit_tests/agents/test_tool_utils.py

diff --git a/agents/__init__.py b/agents/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/agents/tests/__init__.py b/agents/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/langchain_benchmarks/model_registration.py b/langchain_benchmarks/model_registration.py
index fa9da072..d2a5a39b 100644
--- a/langchain_benchmarks/model_registration.py
+++ b/langchain_benchmarks/model_registration.py
@@ -194,7 +194,7 @@
     ),
     RegisteredModel(
         provider="fireworks",
-        name="mixtral-8x7b-instruct",
+        name="mixtral-8x7b-instruct-fw",
         description="Mistral MoE 8x7B Instruct v0.1 model with Sparse "
         "Mixture of Experts. Fine tuned for instruction following",
         type="llm",
diff --git a/langchain_benchmarks/tool_usage/agents/__init__.py b/langchain_benchmarks/tool_usage/agents/__init__.py
new file mode 100644
index 00000000..4e9f2896
--- /dev/null
+++ b/langchain_benchmarks/tool_usage/agents/__init__.py
@@ -0,0 +1,7 @@
+from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter
+from langchain_benchmarks.tool_usage.agents.experimental.factory import (
+    CustomAgentFactory,
+)
+from langchain_benchmarks.tool_usage.agents.openai_functions import OpenAIAgentFactory
+
+__all__ = ["OpenAIAgentFactory", "apply_agent_executor_adapter", "CustomAgentFactory"]
diff --git a/langchain_benchmarks/tool_usage/agents/adapters.py b/langchain_benchmarks/tool_usage/agents/adapters.py
new file mode 100644
index 00000000..29bd3ddf
--- /dev/null
+++ b/langchain_benchmarks/tool_usage/agents/adapters.py
@@ -0,0 +1,71 @@
+from typing import Optional, Callable, Any
+
+from langchain.agents import AgentExecutor
+from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
+
+
+def _ensure_output_exists(inputs: dict) -> dict:
+    """Make sure that the output key is always present."""
+    if "output" not in inputs:
+        return {"output": "", **inputs}
+    return inputs
+
+
+def apply_agent_executor_adapter(
+    agent_executor: AgentExecutor,
+    *,
+    state_reader: Optional[Callable[[], Any]] = None,
+) -> Runnable:
+    """An adapter for the agent executor to standardize its input and output.
+
+    1) Map `question` to `input` (`question` is used in the datasets,
+       but `input` is used in the agent executor)
+    2) Ensure that `output` is always returned (will be set to "" if missing) --
+       note that this may be relaxed after more updates in the eval config.
+    3) Populate `state` key in the response of the agent with the system state
+       if a state reader is provided.
+
+    Args:
+        agent_executor: the agent executor
+        state_reader: A callable without parameters that if invoked will return
+            the state of the environment. Used to populate the 'state' key.
+
+    Returns:
+        a new runnable with a standardized output.
+    """
+
+    def _read_state(*args: Any, **kwargs: Any) -> Any:
+        """Read the state of the environment."""
+        if state_reader is not None:
+            return state_reader()
+        else:
+            return None
+
+    def _format_input(inputs: dict) -> dict:
+        """Make sure that the input is always called `input`."""
+
+        if "question" not in inputs:
+            raise ValueError(
+                "Expected 'question' to be in the inputs. Found only the following "
+                f"keys {sorted(inputs.keys())}."
+            )
+
+        inputs = inputs.copy()  # Because 'question' is popped below
+
+        if "input" not in inputs:
+            return {"input": inputs.pop("question"), **inputs}
+        return inputs
+
+    runnable = (
+        RunnableLambda(_format_input).with_config({"run_name": "Format Input"})
+        | agent_executor
+        | RunnableLambda(_ensure_output_exists).with_config(
+            {"run_name": "Ensure Output"}
+        )
+    )
+
+    if state_reader is not None:
+        runnable = runnable | RunnablePassthrough.assign(state=_read_state).with_config(
+            {"run_name": "Read Env State"}
+        )
+    return runnable
diff --git a/langchain_benchmarks/tool_usage/agents/experimental/__init__.py b/langchain_benchmarks/tool_usage/agents/experimental/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/langchain_benchmarks/tool_usage/agents/experimental/agent.py b/langchain_benchmarks/tool_usage/agents/experimental/agent.py
new file mode 100644
index 00000000..14ba932d
--- /dev/null
+++ b/langchain_benchmarks/tool_usage/agents/experimental/agent.py
@@ -0,0 +1,133 @@
+from typing import List, Literal, Optional, Sequence, Tuple, Union
+
+from langchain.agents import AgentOutputParser
+from langchain.prompts.chat import ChatPromptTemplate
+from langchain.schema.runnable import Runnable
+from langchain.tools import StructuredTool
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.language_models import BaseChatModel, BaseLanguageModel
+from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
+from langchain_core.prompts import MessagesPlaceholder
+from typing_extensions import NotRequired, TypedDict
+
+from langchain_benchmarks import RateLimiter
+from langchain_benchmarks.rate_limiting import with_rate_limit
+from langchain_benchmarks.tool_usage.agents.experimental.encoder import (
+    AstPrinter,
+    TypeScriptEncoder,
+    XMLEncoder,
+)
+from langchain_benchmarks.tool_usage.agents.experimental.encoder import FunctionResult
+from langchain_benchmarks.tool_usage.agents.experimental.prompts import (
+    _AGENT_INSTRUCTIONS_BLOB_STYLE,
+)
+from langchain_benchmarks.tool_usage.agents.experimental.tool_utils import (
+    convert_tool_to_function_definition,
+)
+
+
+def format_steps_for_chat(
+    intermediate_steps: List[Tuple[AgentAction, str]],
+    ast_printer: AstPrinter,
+) -> List[BaseMessage]:
+    """Format the steps."""
+    messages = []
+    for action, observation in intermediate_steps:
+        # Action messages contains the tool invocation request from the LLM
+        # Now add the result of the tool invocation.
+ + if action.tool == "_Exception": + messages.append( + AIMessage( + content=action.log, + ) + ) + messages.append( + # Tool input is the error message for the exception + HumanMessage(content=action.tool_input) + ) + else: + messages.extend(action.messages) + function_result: FunctionResult = { + "name": action.tool, + "error": None, + "result": observation, + } + messages.append( + HumanMessage( + content=ast_printer.visit_function_result(function_result), + ) + ) + + return messages + + +# PUBLIC API + + +class AgentInput(TypedDict): + """The input to the agent.""" + + input: str + """The input to the agent.""" + intermediate_steps: List[Tuple[AgentAction, str]] + """The intermediate steps taken by the agent.""" + examples: NotRequired[List[BaseMessage]] + """A list of messages that can be used to form example traces.""" + + +def create_agent( + model: Union[BaseChatModel, BaseLanguageModel], + tools: Sequence[StructuredTool], + parser: AgentOutputParser, + *, + ast_printer: Union[AstPrinter, Literal["xml"]] = "xml", + rate_limiter: Optional[RateLimiter] = None, +) -> Runnable[AgentInput, Union[AgentAction, AgentFinish]]: + """Create an agent for a chat model.""" + if isinstance(ast_printer, str): + if ast_printer == "xml": + ast_printer_ = XMLEncoder() + elif ast_printer == "typescript": + ast_printer_ = TypeScriptEncoder() + else: + raise ValueError(f"Unknown ast printer: {ast_printer}") + elif isinstance(ast_printer, AstPrinter): + ast_printer_ = ast_printer + else: + raise TypeError( + f"Expected AstPrinter or str, got {type(ast_printer)} for `ast_printer`" + ) + + function_definitions = [convert_tool_to_function_definition(tool) for tool in tools] + tool_description = ast_printer_.visit_function_definitions(function_definitions) + + template = ChatPromptTemplate.from_messages( + [ + ("system", _AGENT_INSTRUCTIONS_BLOB_STYLE), + MessagesPlaceholder("examples"), # Can use to add example traces + ("human", "{input}"), + MessagesPlaceholder("history"), + ] + ).partial(tool_description=tool_description) + + # For the time being, hard-coding the fact that we're using a tag. + model = model.bind(stop=[""]) + + if rate_limiter: + # Apply a rate limiter if it was provided + model = with_rate_limit(model, rate_limiter) + + agent = ( + { + "input": lambda x: x["input"], + "history": lambda x: format_steps_for_chat( + x["intermediate_steps"], ast_printer_ + ), + "examples": lambda x: x.get("examples", []), + } + | template + | model + | parser + ) + return agent diff --git a/langchain_benchmarks/tool_usage/agents/experimental/encoder.py b/langchain_benchmarks/tool_usage/agents/experimental/encoder.py new file mode 100644 index 00000000..c6799609 --- /dev/null +++ b/langchain_benchmarks/tool_usage/agents/experimental/encoder.py @@ -0,0 +1,240 @@ +"""Prototyping code for rendering function definitions, invocations, and results. + +Types are simplified for now to `str`. + +We should actually support something like pydantic or jsonschema for the types, so +we can expand them recursively for nested types. 
+""" +import abc +from typing import Any, List, Optional + +from typing_extensions import NotRequired, TypedDict + + +class Parameter(TypedDict): + """Representation for a parameter.""" + + name: str + type: str + description: str + + +class Arguments(TypedDict): + """Arguments are passed to a function during function invocation.""" + + name: Optional[str] + value: Any + + +class ReturnValue(TypedDict): + """Representation for a return value of a function call.""" + + type: str + description: NotRequired[str] + + +class FunctionDefinition(TypedDict): + """Representation for a function.""" + + name: str + description: str # Function description + parameters: List[Parameter] + return_value: ReturnValue + + +class FunctionInvocation(TypedDict): + """Representation for a function invocation.""" + + id: NotRequired[str] + name: str + arguments: List[Arguments] + + +class FunctionResult(TypedDict): + """Representation for a function result.""" + + id: NotRequired[str] + name: str + result: Optional[str] + error: Optional[str] + + +class Visitor(abc.ABC): + @abc.abstractmethod + def visit_function_definition(self, function_definition: FunctionDefinition) -> str: + """Render a function.""" + + @abc.abstractmethod + def visit_function_definitions( + self, function_definitions: List[FunctionDefinition] + ) -> str: + """Render a function.""" + + @abc.abstractmethod + def visit_function_invocation(self, function_invocation: FunctionInvocation) -> str: + """Render a function invocation.""" + + @abc.abstractmethod + def visit_function_result(self, function_result: FunctionResult) -> str: + """Render a function result.""" + + +class AstPrinter(Visitor): + """Print the AST.""" + + +class XMLEncoder(AstPrinter): + def visit_function_definition(self, function_definition: FunctionDefinition) -> str: + """Render a function.""" + parameters_lines = [] + + for parameter in function_definition["parameters"]: + parameters_lines.extend( + [ + "", + f"{parameter['name']}", + f"{parameter['type']}", + f"{parameter['description']}", + "", + ] + ) + lines = [ + "", + f"{function_definition['name']}", + "", + f"{function_definition['description']}", + "", + "", + *parameters_lines, + "", + "", + f"{function_definition['return_value']['type']}", + ] + if function_definition["return_value"].get("description"): + lines.append( + f"{function_definition['return_value']['description']}" + f"" + ) + + lines.extend(["", ""]) + return "\n".join(lines) + + def visit_function_definitions( + self, function_definitions: List[FunctionDefinition] + ) -> str: + """Render a function.""" + strs = [ + self.visit_function_definition(function_definition) + for function_definition in function_definitions + ] + return "\n" + "\n".join(strs) + "\n" + + def visit_function_invocation(self, invocation: FunctionInvocation) -> str: + """Render a function invocation.""" + arguments_as_strings = [ + "\n" + f"{argument['name']}\n" + f"{argument['value']}\n" + "\n" + for argument in invocation["arguments"] + ] + lines = [""] + + if invocation.get("id"): + lines.append(f"{invocation['id']}") + + lines.extend( + [ + f"{invocation['name']}\n" + "\n" + f"{''.join(arguments_as_strings)}" # Already includes trailing newline + "\n" + "" + ] + ) + return "\n".join(lines) + + def visit_function_result(self, function_result: FunctionResult) -> str: + """Render a function result.""" + lines = [ + "", + ] + + if function_result.get("id"): + lines.append(f"{function_result['id']}") + + lines.append(f"{function_result['name']}") + + if function_result["error"]: + 
lines.extend( + [ + f"{function_result['error']}", + ] + ) + else: + lines.append( + f"{function_result['result']}", + ) + + lines.append("") + + return "\n".join(lines) + + +class TypeScriptEncoder(AstPrinter): + def visit_function_definition(self, function_definition: FunctionDefinition) -> str: + """Render a function.""" + parameters_as_strings = [ + f"{parameter['name']}: {parameter['type']}" + for parameter in function_definition["parameters"] + ] + # Let's use JSdoc style comments + # First the function description + lines = [ + f"// {function_definition['description']}", + # Then the parameter descriptions + *[ + f"// @param {parameter['name']} {parameter['description']}" + for parameter in function_definition["parameters"] + ], + # Then the return value description + f"// @returns {function_definition['return_value']['description']}", + # Then the function definition + f"function {function_definition['name']}(" + f"{', '.join(parameters_as_strings)}): " + f"{function_definition['return_value']['type']};", + ] + + # finally join + function = "\n".join(lines) + return function + + def visit_function_definitions( + self, function_definitions: List[FunctionDefinition] + ) -> str: + """Render a function.""" + strs = [ + self.visit_function_definition(function_definition) + for function_definition in function_definitions + ] + return "\n\n".join(strs) + + def visit_function_invocation(self, invocation: FunctionInvocation) -> str: + """Render a function invocation.""" + arguments_as_strings = [ + f"{argument['name']}: {argument['value']}" + for argument in invocation["arguments"] + ] + lines = [f"{invocation['name']}(" f"{', '.join(arguments_as_strings)});"] + return "\n".join(lines) + + def visit_function_result(self, function_result: FunctionResult) -> str: + """Render a function result.""" + lines = [] + if function_result["error"]: + lines.append(f"ERROR: {function_result['error']}") + else: + lines.append(f"> {function_result['result']}") + if function_result.get("id"): + lines.append(f"// ID: {function_result['id']}") + return "\n".join(lines) diff --git a/langchain_benchmarks/tool_usage/agents/experimental/factory.py b/langchain_benchmarks/tool_usage/agents/experimental/factory.py new file mode 100644 index 00000000..d03c7626 --- /dev/null +++ b/langchain_benchmarks/tool_usage/agents/experimental/factory.py @@ -0,0 +1,86 @@ +"""Factory for creating agents for the tool usage task.""" +from typing import Optional + +from langchain.agents import AgentExecutor +from langchain_core.runnables import Runnable, RunnableConfig + +from langchain_benchmarks import RateLimiter, model_registry +from langchain_benchmarks.schema import ToolUsageTask +from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter +from langchain_benchmarks.tool_usage.agents.experimental.agent import create_agent +from langchain_benchmarks.tool_usage.agents.experimental.parser import ( + GenericAgentParser, +) + + +class CustomAgentFactory: + """A factory for creating tool using agents. + + A factory for agents that do not leverage any special JSON mode for + function usage; instead all function invocation behavior is implemented solely + through prompt engineering and parsing. + """ + + def __init__( + self, + task: ToolUsageTask, + model: str, + *, + rate_limiter: Optional[RateLimiter] = None, + ) -> None: + """Create an agent factory for the given tool usage task. 
+ + Args: + task: The task to create an agent factory for + model: model name (check model_registry) + rate_limiter: The rate limiter to use if provided + """ + if model not in model_registry: + raise ValueError(f"Unknown model: {model}") + self.task = task + self.model = model + self.rate_limiter = rate_limiter + + def __call__(self) -> Runnable: + if isinstance(self.model, str): + registered_model = model_registry.get_model(self.model) + if registered_model is None: + raise ValueError(f"Unknown model: {self.model}") + model = registered_model.get_model(model_params={"temperature": 0}) + else: + model = self.model + + def _add_task_instructions( + input: dict, config: Optional[RunnableConfig] = None, **kwargs + ) -> dict: + """Add task instructions to the question.""" + if not isinstance(input, dict): + raise ValueError( + f"Expected input to be a dict with key `question`. " + f"Found {type(input)}." + ) + input = input.copy() + input["question"] = ( + f"{self.task.instructions}\nWrite down your answer, " + f"but do not explain it. Input: `{input['question']}`" + ) + return input + + env = self.task.create_environment() + + agent = create_agent( + model, + env.tools, + GenericAgentParser(wrapping_xml_tag="tool", require_closing_xml_tag=False), + rate_limiter=self.rate_limiter, + ) + executor = AgentExecutor( + agent=agent, + tools=env.tools, + handle_parsing_errors=True, + return_intermediate_steps=True, + ) + + return _add_task_instructions | apply_agent_executor_adapter( + executor, state_reader=env.read_state + ) diff --git a/langchain_benchmarks/tool_usage/agents/experimental/parser.py b/langchain_benchmarks/tool_usage/agents/experimental/parser.py new file mode 100644 index 00000000..7be09776 --- /dev/null +++ b/langchain_benchmarks/tool_usage/agents/experimental/parser.py @@ -0,0 +1,122 @@ +import ast +import re +from typing import Union, Dict, Optional + +from langchain.agents import AgentOutputParser +from langchain.pydantic_v1 import BaseModel, Field +from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish +from langchain_core.exceptions import OutputParserException +from langchain_core.messages import AIMessage + + +class _ToolInvocationRequest(BaseModel): + """Light-weight pydantic model for validating the raw tool invocation request. + + The purpose of this model, is to make sure that whatever as parsed from + the raw llm output has `tool_name` and potential `arguments` fields, and + nothing else. + """ + + tool_name: str + # OK parameterless tools which do not take arguments + arguments: Optional[Dict] = Field(default_factory=dict) + + +class GenericAgentParser(AgentOutputParser): + """A generalized parser that makes it easier to parameterize different parsing.""" + + wrapping_xml_tag: str + """The tag that wraps the function invocation request. + + For example, if "tool", then the function invocation request should be wrapped + in .... + """ + require_closing_xml_tag: bool = False + """Whether we should require a closing tag for the wrapping_xml_tag. + + For example, if True, then the function invocation request should be wrapped + """ + + def parse(self, text: str) -> Union[AgentFinish, AgentAction]: + """Parse the output of the agent.""" + open_tag = f"<{self.wrapping_xml_tag}>" + close_tag = f"" + if open_tag in text: + # This is a hack to make sure that is always present + # in the output if . may be a stop sequence for the + # language model, so depending on implementation + # the stop sequence may be cut off. 
+ # There might be a better way to do this, but this works and + # is simple. + if not self.require_closing_xml_tag: + text += close_tag + + pattern = rf"{open_tag}(?P.*?){close_tag}" + match = re.search(pattern, text, re.DOTALL) + if match: + content = match.group("invocation").strip() + return parse_invocation(content, self.wrapping_xml_tag) + + return AgentFinish( + log=text, + return_values={ + "output": text, + }, + ) + + +def parse_invocation(text: str, tag: str) -> AgentAction: + """Parse the content of the function invocation. + + Args: + text: The text to parse. + tag: The tag that wraps the function invocation request. + + Returns: + An AgentAction that corresponds to the function invocation. + + Raises: + OutputParserException: If the parsing fails. + + This exception is meant to be caught by the agent executor and + handled appropriately to provide feedback to the LLM. + """ + ai_content = f"<{tag}>{text}\n" + + try: + result = ast.literal_eval(text) + except BaseException as e: + # Convert this to something controllable by the user. + err_msg = ( + f"ERROR: Please use the format " + f'<{tag}>{{"tool_name": $TOOL_NAME, "arguments": $ARGUMENTS}}\n' + ) + + raise OutputParserException( + error=e, + llm_output=ai_content, + observation=err_msg, + send_to_llm=True, + ) + + try: + request = _ToolInvocationRequest.validate(result) + except Exception as e: # Using broad exception since it's not just ValidationError + # Can also raise DictError if result is not a dict. + err_msg = ( + f"ERROR: Please use the format " + f'<{tag}>{{"tool_name": $TOOL_NAME, "arguments": $ARGUMENTS}}\n' + ) + raise OutputParserException( + error=e, + llm_output=ai_content, + send_to_llm=True, + observation=err_msg, + ) + + return AgentActionMessageLog( + message_log=[AIMessage(content=ai_content)], + tool=request.tool_name, + tool_input=request.arguments, + log=f"\nInvoking {request.tool_name}: {request.arguments}\n\t", + ) diff --git a/langchain_benchmarks/tool_usage/agents/experimental/prompts.py b/langchain_benchmarks/tool_usage/agents/experimental/prompts.py new file mode 100644 index 00000000..9abc051e --- /dev/null +++ b/langchain_benchmarks/tool_usage/agents/experimental/prompts.py @@ -0,0 +1,42 @@ +AGENT_INSTRUCTIONS_XML_FORMAT = """\ +In this environment you have access to a set of tools you can use to answer the user's question. + +You may call them like this: + + +$TOOL_NAME + +<$PARAMETER_NAME>$PARAMETER_VALUE +... + + + + +Here are the tools available: + +{tool_description} +""" # noqa: E501 + +_AGENT_INSTRUCTIONS_BLOB_STYLE = """\ +In this environment you have access to a set of tools you can use to answer the user's question. + +Here are the tools available: + +{tool_description} + +You may call one tool at a time using a format that includes and tag. + +Inside the tag the content is a python dictionary that uses python literals (e.g., numbers, strings, lists, dictionaries, etc.) to specify the tool invocation. + +It must match the schema of the function as described in the tool description. +"arguments" is a dictionary of the arguments to the function. + + +{{ + "tool_name": $TOOL_NAME, + "arguments": $ARGUMENTS +}} + + +If you do not know the answer use more tools. 
You can only take a single action at a time.\ +""" # noqa: E501 diff --git a/langchain_benchmarks/tool_usage/agents/experimental/tool_utils.py b/langchain_benchmarks/tool_usage/agents/experimental/tool_utils.py new file mode 100644 index 00000000..976fc061 --- /dev/null +++ b/langchain_benchmarks/tool_usage/agents/experimental/tool_utils.py @@ -0,0 +1,57 @@ +"""Utilities to extract information from langchain tools for use in prompts.""" +import inspect +from textwrap import dedent +from typing import List + +from langchain.tools.base import StructuredTool + +from langchain_benchmarks.tool_usage.agents.experimental.encoder import ( + Parameter, + FunctionDefinition, +) + +# PUBLIC API + + +def get_parameters_from_tool(tool: StructuredTool) -> List[Parameter]: + """Convert a langchain tool to a tool user tool.""" + schema = tool.args_schema.schema() + + properties = schema["properties"] + parameters = [] + # Is this needed or is string OK? + type_adapter = { + "string": "str", # str or string? + "integer": "int", + "number": "float", + "boolean": "bool", + } + for key, value in properties.items(): + parameters.append( + { + "name": key, + "type": type_adapter.get(value["type"], value["type"]), + "description": value.get("description", ""), + } + ) + + return parameters + + +# +def convert_tool_to_function_definition(tool: StructuredTool) -> FunctionDefinition: + """Convert a langchain tool to a tool user tool.""" + # Here we re-inspect the underlying function to get the doc-string + # since StructuredTool modifies it, but we want the raw one for maximum + # flexibility. + description = inspect.getdoc(tool.func) + + parameters = get_parameters_from_tool(tool) + return { + "name": tool.name, + "description": dedent(description), + "parameters": parameters, + "return_value": { + "type": "Any", + }, + } diff --git a/langchain_benchmarks/tool_usage/agents.py b/langchain_benchmarks/tool_usage/agents/openai_functions.py similarity index 51% rename from langchain_benchmarks/tool_usage/agents.py rename to langchain_benchmarks/tool_usage/agents/openai_functions.py index ad28a54f..32186d02 100644 --- a/langchain_benchmarks/tool_usage/agents.py +++ b/langchain_benchmarks/tool_usage/agents/openai_functions.py @@ -1,23 +1,17 @@ """Code for creating an agent factory for evaluating tool usage tasks.""" -from typing import Any, Callable, Optional +from typing import Optional from langchain.agents import AgentExecutor from langchain.agents.format_scratchpad import format_to_openai_functions from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser from langchain.chat_models import ChatOpenAI from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain.schema.runnable import Runnable, RunnableLambda, RunnablePassthrough +from langchain.schema.runnable import Runnable from langchain.tools.render import format_tool_to_openai_function from langchain_benchmarks import rate_limiting from langchain_benchmarks.schema import ToolUsageTask - - -def _ensure_output_exists(inputs: dict) -> dict: - """Make sure that the output key is always present.""" - if "output" not in inputs: - return {"output": "", **inputs} - return inputs +from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter # PUBLIC API @@ -61,7 +55,7 @@ def __call__(self) -> Runnable: if rate_limiting: # Rate limited model - model = with_rate_limit(model, self.rate_limiter) + model = rate_limiting.with_rate_limit(model, self.rate_limiter) prompt = ChatPromptTemplate.from_messages( [ 
@@ -96,66 +90,3 @@ def __call__(self) -> Runnable: # Returns `state` in the output if the environment has a state reader # makes sure that `output` is always in the output return apply_agent_executor_adapter(runnable, state_reader=env.read_state) - - -# PUBLIC API - - -def apply_agent_executor_adapter( - agent_executor: AgentExecutor, - *, - state_reader: Optional[Callable[[], Any]] = None, -) -> Runnable: - """An adapter for the agent executor to standardize its input and output. - - 1) Map `question` to `input` (`question` is used in the datasets, - but `input` is used in the agent executor) - 2) Ensure that `output` is always returned (will be set to "" if missing) -- - note that this may be relaxed after more updates in the eval config. - 3) Populate `state` key in the response of the agent with the system state - if a state reader is provided. - - Args: - agent_executor: the agent executor - state_reader: A callable without parameters that if invoked will return - the state of the environment. Used to populate the 'state' key. - - Returns: - a new runnable with a standardized output. - """ - - def _read_state(*args: Any, **kwargs: Any) -> Any: - """Read the state of the environment.""" - if state_reader is not None: - return state_reader() - else: - return None - - def _format_input(inputs: dict) -> dict: - """Make sure that the input is always called `input`.""" - - if "question" not in inputs: - raise ValueError( - "Expected 'question' to be in the inputs. Found only the following " - f"keys {sorted(inputs.keys())}." - ) - - inputs = inputs.copy() # Because 'question' is popped below - - if "input" not in inputs: - return {"input": inputs.pop("question"), **inputs} - return inputs - - runnable = ( - RunnableLambda(_format_input).with_config({"run_name": "Format Input"}) - | agent_executor - | RunnableLambda(_ensure_output_exists).with_config( - {"run_name": "Ensure Output"} - ) - ) - - if state_reader is not None: - runnable = runnable | RunnablePassthrough.assign(state=_read_state).with_config( - {"run_name": "Read Env State"} - ) - return runnable diff --git a/tests/unit_tests/agents/__init__.py b/tests/unit_tests/agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/agents/encoding_and_decoding/__init__.py b/tests/unit_tests/agents/encoding_and_decoding/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/agents/encoding_and_decoding/test_decoding.py b/tests/unit_tests/agents/encoding_and_decoding/test_decoding.py new file mode 100644 index 00000000..5a416cfa --- /dev/null +++ b/tests/unit_tests/agents/encoding_and_decoding/test_decoding.py @@ -0,0 +1,54 @@ +import pytest +from langchain_core.agents import AgentFinish, AgentActionMessageLog +from langchain_core.exceptions import OutputParserException +from langchain_core.messages import AIMessage + +from langchain_benchmarks.tool_usage.agents.experimental.parser import ( + GenericAgentParser, +) + + +def test_parser() -> None: + """Test parser.""" + parser = GenericAgentParser(require_closing_tag=False, wrapping_xml_tag="tool") + + # If tag not found then it's an agent finish + assert isinstance(parser.invoke("goodbye"), AgentFinish) + + with pytest.raises(OutputParserException): + # Invocation content is missing tool name and arguments + parser.invoke("'hello'") + + with pytest.raises(OutputParserException): + parser.invoke("hello") + + # Full invocation + text = ( + '{\n "tool_name": "type_letter",\n ' + '"arguments": {\n ' + '"letter": "h"\n }\n}\n' + 
) + + assert parser.invoke(text) == AgentActionMessageLog( + tool="type_letter", + tool_input={"letter": "h"}, + log="\nInvoking type_letter: {'letter': 'h'}\n\t", + message_log=[AIMessage(content=text)], + ) + + # Test more cases + parsed = parser.invoke('{"tool_name": "hello"}') + assert parsed.tool == "hello" + # Assumes that it's a structured tool by default! + assert parsed.tool_input == {} + + with pytest.raises(OutputParserException): + # Arguments need to be a dict + parser.invoke('{"tool_name": "hello", "arguments": [1, 2]}') + + parsed = parser.invoke( + '{"tool_name": "hello", "arguments": {"a": "b"}}' + ) + assert parsed.tool == "hello" + # Assumes that it's a structured tool by default! + assert parsed.tool_input == {"a": "b"} diff --git a/tests/unit_tests/agents/encoding_and_decoding/test_typescript_encoding.py b/tests/unit_tests/agents/encoding_and_decoding/test_typescript_encoding.py new file mode 100644 index 00000000..39175919 --- /dev/null +++ b/tests/unit_tests/agents/encoding_and_decoding/test_typescript_encoding.py @@ -0,0 +1,25 @@ +"""Test typescript encoding.""" +from langchain_benchmarks.tool_usage.agents.experimental.encoder import ( + FunctionDefinition, + TypeScriptEncoder, +) + + +def test_function_definition() -> None: + """Test encoding a function definition.""" + function_definition = FunctionDefinition( + name="test_function", + description="A test function", + parameters=[ + {"name": "test_parameter", "type": "str", "description": "A test parameter"} + ], + return_value={"type": "str", "description": "A test return value"}, + ) + encoder = TypeScriptEncoder() + xml = encoder.visit_function_definition(function_definition) + assert xml == ( + "// A test function\n" + "// @param test_parameter A test parameter\n" + "// @returns A test return value\n" + "function test_function(test_parameter: str): str;" + ) diff --git a/tests/unit_tests/agents/encoding_and_decoding/test_xml_encoding.py b/tests/unit_tests/agents/encoding_and_decoding/test_xml_encoding.py new file mode 100644 index 00000000..d41b63be --- /dev/null +++ b/tests/unit_tests/agents/encoding_and_decoding/test_xml_encoding.py @@ -0,0 +1,90 @@ +"""Test XML encoding and decoding of function definitions, invocation, and results.""" +from langchain_benchmarks.tool_usage.agents.experimental.encoder import ( + FunctionDefinition, + FunctionInvocation, + FunctionResult, + XMLEncoder, +) + + +def test_function_definition_encoding() -> None: + """Test encoding a function definition.""" + function_definition = FunctionDefinition( + name="test_function", + description="A test function", + parameters=[ + {"name": "test_parameter", "type": "str", "description": "A test parameter"} + ], + return_value={"type": "str", "description": "A test return value"}, + ) + encoder = XMLEncoder() + xml = encoder.visit_function_definition(function_definition) + assert xml == ( + "\n" + "test_function\n" + "\n" + "A test function\n" + "\n" + "\n" + "\n" + "test_parameter\n" + "str\n" + "A test parameter\n" + "\n" + "\n" + "\n" + "str\n" + "A test return value\n" + "\n" + "" + ) + + +def test_function_result_encoding() -> None: + """Test encoding a function result.""" + encoder = XMLEncoder() + function_result = FunctionResult( + name="test_function", + result="test_result", + error=None, + ) + xml = encoder.visit_function_result(function_result) + assert xml == ( + "\n" + "test_function\n" + "test_result\n" + "" + ) + + function_result = FunctionResult( + name="test_function", + error="error", + ) + xml = 
encoder.visit_function_result(function_result) + assert xml == ( + "\n" + "test_function\n" + "error\n" + "" + ) + + +def test_function_invocation() -> None: + """Test function invocation.""" + function_invocation = FunctionInvocation( + name="test_function", + arguments=[{"name": "test_argument", "value": "test_value"}], + ) + encoder = XMLEncoder() + xml = encoder.visit_function_invocation(function_invocation) + assert xml == ( + "\n" + "test_function\n" + "\n" + "\n" + "test_argument\n" + "test_value\n" + "\n" + "\n" + "" + ) diff --git a/tests/unit_tests/agents/test_tool_utils.py b/tests/unit_tests/agents/test_tool_utils.py new file mode 100644 index 00000000..9e4bb95f --- /dev/null +++ b/tests/unit_tests/agents/test_tool_utils.py @@ -0,0 +1,59 @@ +import pytest +from langchain.tools import tool + +from langchain_benchmarks.tool_usage.agents.experimental.tool_utils import ( + convert_tool_to_function_definition, +) + + +@tool +def get_hello() -> str: + """Get hello.""" + return "hello" + + +@tool +def repeat(x: str) -> str: + """Repeat x. + + Args: + x: The string to repeat. + + Returns: + The repeated string. + """ + return x + + +def test_parameterless_function() -> None: + """Test foo.""" + function_definition = convert_tool_to_function_definition(get_hello) + assert function_definition == { + "name": "get_hello", + "description": "Get hello.", + "parameters": [], + "return_value": { + "type": "Any", + }, + } + + +@pytest.mark.skip("Need to fix handling of leading whitespace") +def test_function_with_parameters() -> None: + import textwrap + + doc = textwrap.dedent(repeat.func.__doc__) + assert convert_tool_to_function_definition(repeat) == { + "name": "repeat", + "description": doc, + "parameters": [ + { + "name": "x", + "type": "str", + "description": "", # Need to fix this + } + ], + "return_value": { + "type": "Any", + }, + }
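
A quick sketch of what the `apply_agent_executor_adapter` wrapper in `adapters.py` does to the payload. A `RunnableLambda` stands in for a real `AgentExecutor` (at runtime the adapter only needs a `Runnable`), and the fake state reader is illustrative:

```python
from langchain_core.runnables import RunnableLambda

from langchain_benchmarks.tool_usage.agents import apply_agent_executor_adapter

# Pretend agent executor that echoes its input and never sets an "output" key.
fake_executor = RunnableLambda(lambda inputs: {"input": inputs["input"]})

adapted = apply_agent_executor_adapter(
    fake_executor, state_reader=lambda: {"position": 3}
)

# `question` is mapped to `input`, a missing `output` is filled with "",
# and `state` is populated from the state reader.
print(adapted.invoke({"question": "Type the word 'cat'."}))
# {'output': '', 'input': "Type the word 'cat'.", 'state': {'position': 3}}
```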
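
A sketch of how `GenericAgentParser` interprets raw model text, mirroring the decoding unit tests; the `<tool>` wrapping tag is the one the factory configures via `wrapping_xml_tag="tool"`:

```python
from langchain_core.agents import AgentFinish

from langchain_benchmarks.tool_usage.agents.experimental.parser import (
    GenericAgentParser,
)

parser = GenericAgentParser(wrapping_xml_tag="tool", require_closing_xml_tag=False)

# A tool invocation is a python-literal dict wrapped in the configured tag.
action = parser.invoke(
    '<tool>{"tool_name": "type_letter", "arguments": {"letter": "h"}}</tool>'
)
print(action.tool)        # type_letter
print(action.tool_input)  # {'letter': 'h'}

# Text without the wrapping tag is treated as the final answer.
finish = parser.invoke("The word 'cat' has been typed.")
assert isinstance(finish, AgentFinish)
print(finish.return_values["output"])  # The word 'cat' has been typed.
```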
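
And a rough end-to-end sketch of running the custom agent against a tool-usage task. The task name is an assumed example from the benchmark registry, the call needs the model provider's API key, and `mixtral-8x7b-instruct-fw` is the registry name updated by this PR:

```python
from langchain_benchmarks import registry
from langchain_benchmarks.tool_usage.agents import CustomAgentFactory

task = registry["Tool Usage - Typewriter (26 tools)"]  # assumed task name
factory = CustomAgentFactory(task, model="mixtral-8x7b-instruct-fw")

agent = factory()  # a Runnable with the standardized input/output contract
result = agent.invoke({"question": "Type the word 'cat'."})

print(result["output"])  # the agent's final answer ("" if it produced none)
print(result["state"])   # environment state captured via env.read_state
```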