Skip to content

Commit

Permalink
Adds custom agents to the langchain benchmarking repo (#120)
Browse files Browse the repository at this point in the history
* This PR adds code for running custom agents to the langchain
benchmarking repo.
* The agent code is good enough for experimentation / prototyping, but I
don't think it's good enough for the langchain repo:
-- The abstractions aren't fully implemented and aren't ready for
production use -- but OK for research
-- For production use, one may want to remove all the intermediate
abstractions to keep the agent as simple as possible

I was thinking initially of including this in a different repo, but I
think it's over-complicating things, probably OK to include some
reference implementations inside of langchain benchmarks.
  • Loading branch information
eyurtsev authored Dec 14, 2023
1 parent 7ed859c commit 8798735
Show file tree
Hide file tree
Showing 19 changed files with 991 additions and 74 deletions.
Empty file added agents/__init__.py
Empty file.
Empty file added agents/tests/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion langchain_benchmarks/model_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@
),
RegisteredModel(
provider="fireworks",
name="mixtral-8x7b-instruct",
name="mixtral-8x7b-instruct-fw",
description="Mistral MoE 8x7B Instruct v0.1 model with Sparse "
"Mixture of Experts. Fine tuned for instruction following",
type="llm",
Expand Down
7 changes: 7 additions & 0 deletions langchain_benchmarks/tool_usage/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter
from langchain_benchmarks.tool_usage.agents.experimental.factory import (
CustomAgentFactory,
)
from langchain_benchmarks.tool_usage.agents.openai_functions import OpenAIAgentFactory

__all__ = ["OpenAIAgentFactory", "apply_agent_executor_adapter", "CustomAgentFactory"]
71 changes: 71 additions & 0 deletions langchain_benchmarks/tool_usage/agents/adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import Optional, Callable, Any

from langchain.agents import AgentExecutor
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough


def _ensure_output_exists(inputs: dict) -> dict:
"""Make sure that the output key is always present."""
if "output" not in inputs:
return {"output": "", **inputs}
return inputs


def apply_agent_executor_adapter(
agent_executor: AgentExecutor,
*,
state_reader: Optional[Callable[[], Any]] = None,
) -> Runnable:
"""An adapter for the agent executor to standardize its input and output.
1) Map `question` to `input` (`question` is used in the datasets,
but `input` is used in the agent executor)
2) Ensure that `output` is always returned (will be set to "" if missing) --
note that this may be relaxed after more updates in the eval config.
3) Populate `state` key in the response of the agent with the system state
if a state reader is provided.
Args:
agent_executor: the agent executor
state_reader: A callable without parameters that if invoked will return
the state of the environment. Used to populate the 'state' key.
Returns:
a new runnable with a standardized output.
"""

def _read_state(*args: Any, **kwargs: Any) -> Any:
"""Read the state of the environment."""
if state_reader is not None:
return state_reader()
else:
return None

def _format_input(inputs: dict) -> dict:
"""Make sure that the input is always called `input`."""

if "question" not in inputs:
raise ValueError(
"Expected 'question' to be in the inputs. Found only the following "
f"keys {sorted(inputs.keys())}."
)

inputs = inputs.copy() # Because 'question' is popped below

if "input" not in inputs:
return {"input": inputs.pop("question"), **inputs}
return inputs

runnable = (
RunnableLambda(_format_input).with_config({"run_name": "Format Input"})
| agent_executor
| RunnableLambda(_ensure_output_exists).with_config(
{"run_name": "Ensure Output"}
)
)

if state_reader is not None:
runnable = runnable | RunnablePassthrough.assign(state=_read_state).with_config(
{"run_name": "Read Env State"}
)
return runnable
Empty file.
133 changes: 133 additions & 0 deletions langchain_benchmarks/tool_usage/agents/experimental/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from typing import List, Literal, Optional, Sequence, Tuple, Union

from langchain.agents import AgentOutputParser
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema.runnable import Runnable
from langchain.tools import StructuredTool
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.prompts import MessagesPlaceholder
from typing_extensions import NotRequired, TypedDict

from langchain_benchmarks import RateLimiter
from langchain_benchmarks.rate_limiting import with_rate_limit
from langchain_benchmarks.tool_usage.agents.experimental.encoder import (
AstPrinter,
TypeScriptEncoder,
XMLEncoder,
)
from langchain_benchmarks.tool_usage.agents.experimental.encoder import FunctionResult
from langchain_benchmarks.tool_usage.agents.experimental.prompts import (
_AGENT_INSTRUCTIONS_BLOB_STYLE,
)
from langchain_benchmarks.tool_usage.agents.experimental.tool_utils import (
convert_tool_to_function_definition,
)


def format_steps_for_chat(
intermediate_steps: List[Tuple[AgentAction, str]],
ast_printer: AstPrinter,
) -> List[BaseMessage]:
"""Format the steps."""
messages = []
for action, observation in intermediate_steps:
# Action messages contains the tool invocation request from the LLM
# Now add the result of the tool invocation.

if action.tool == "_Exception":
messages.append(
AIMessage(
content=action.log,
)
)
messages.append(
# Tool input is the error message for the exception
HumanMessage(content=action.tool_input)
)
else:
messages.extend(action.messages)
function_result: FunctionResult = {
"name": action.tool,
"error": None,
"result": observation,
}
messages.append(
HumanMessage(
content=ast_printer.visit_function_result(function_result),
)
)

return messages


# PUBLIC API


class AgentInput(TypedDict):
"""The input to the agent."""

input: str
"""The input to the agent."""
intermediate_steps: List[Tuple[AgentAction, str]]
"""The intermediate steps taken by the agent."""
examples: NotRequired[List[BaseMessage]]
"""A list of messages that can be used to form example traces."""


def create_agent(
model: Union[BaseChatModel, BaseLanguageModel],
tools: Sequence[StructuredTool],
parser: AgentOutputParser,
*,
ast_printer: Union[AstPrinter, Literal["xml"]] = "xml",
rate_limiter: Optional[RateLimiter] = None,
) -> Runnable[AgentInput, Union[AgentAction, AgentFinish]]:
"""Create an agent for a chat model."""
if isinstance(ast_printer, str):
if ast_printer == "xml":
ast_printer_ = XMLEncoder()
elif ast_printer == "typescript":
ast_printer_ = TypeScriptEncoder()
else:
raise ValueError(f"Unknown ast printer: {ast_printer}")
elif isinstance(ast_printer, AstPrinter):
ast_printer_ = ast_printer
else:
raise TypeError(
f"Expected AstPrinter or str, got {type(ast_printer)} for `ast_printer`"
)

function_definitions = [convert_tool_to_function_definition(tool) for tool in tools]
tool_description = ast_printer_.visit_function_definitions(function_definitions)

template = ChatPromptTemplate.from_messages(
[
("system", _AGENT_INSTRUCTIONS_BLOB_STYLE),
MessagesPlaceholder("examples"), # Can use to add example traces
("human", "{input}"),
MessagesPlaceholder("history"),
]
).partial(tool_description=tool_description)

# For the time being, hard-coding the fact that we're using a <tool> tag.
model = model.bind(stop=["</tool>"])

if rate_limiter:
# Apply a rate limiter if it was provided
model = with_rate_limit(model, rate_limiter)

agent = (
{
"input": lambda x: x["input"],
"history": lambda x: format_steps_for_chat(
x["intermediate_steps"], ast_printer_
),
"examples": lambda x: x.get("examples", []),
}
| template
| model
| parser
)
return agent
Loading

0 comments on commit 8798735

Please sign in to comment.