-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds custom agents to the langchain benchmarking repo (#120)
* This PR adds code for running custom agents to the langchain benchmarking repo. * The agent code is good enough for experimentation / prototyping, but I don't think it's good enough for the langchain repo: -- The abstractions aren't fully implemented and aren't ready for production use -- but OK for research -- For production use, one may want to remove all the intermediate abstractions to keep the agent as simple as possible I was thinking initially of including this in a different repo, but I think it's over-complicating things, probably OK to include some reference implementations inside of langchain benchmarks.
- Loading branch information
Showing
19 changed files
with
991 additions
and
74 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter | ||
from langchain_benchmarks.tool_usage.agents.experimental.factory import ( | ||
CustomAgentFactory, | ||
) | ||
from langchain_benchmarks.tool_usage.agents.openai_functions import OpenAIAgentFactory | ||
|
||
__all__ = ["OpenAIAgentFactory", "apply_agent_executor_adapter", "CustomAgentFactory"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from typing import Optional, Callable, Any | ||
|
||
from langchain.agents import AgentExecutor | ||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough | ||
|
||
|
||
def _ensure_output_exists(inputs: dict) -> dict: | ||
"""Make sure that the output key is always present.""" | ||
if "output" not in inputs: | ||
return {"output": "", **inputs} | ||
return inputs | ||
|
||
|
||
def apply_agent_executor_adapter( | ||
agent_executor: AgentExecutor, | ||
*, | ||
state_reader: Optional[Callable[[], Any]] = None, | ||
) -> Runnable: | ||
"""An adapter for the agent executor to standardize its input and output. | ||
1) Map `question` to `input` (`question` is used in the datasets, | ||
but `input` is used in the agent executor) | ||
2) Ensure that `output` is always returned (will be set to "" if missing) -- | ||
note that this may be relaxed after more updates in the eval config. | ||
3) Populate `state` key in the response of the agent with the system state | ||
if a state reader is provided. | ||
Args: | ||
agent_executor: the agent executor | ||
state_reader: A callable without parameters that if invoked will return | ||
the state of the environment. Used to populate the 'state' key. | ||
Returns: | ||
a new runnable with a standardized output. | ||
""" | ||
|
||
def _read_state(*args: Any, **kwargs: Any) -> Any: | ||
"""Read the state of the environment.""" | ||
if state_reader is not None: | ||
return state_reader() | ||
else: | ||
return None | ||
|
||
def _format_input(inputs: dict) -> dict: | ||
"""Make sure that the input is always called `input`.""" | ||
|
||
if "question" not in inputs: | ||
raise ValueError( | ||
"Expected 'question' to be in the inputs. Found only the following " | ||
f"keys {sorted(inputs.keys())}." | ||
) | ||
|
||
inputs = inputs.copy() # Because 'question' is popped below | ||
|
||
if "input" not in inputs: | ||
return {"input": inputs.pop("question"), **inputs} | ||
return inputs | ||
|
||
runnable = ( | ||
RunnableLambda(_format_input).with_config({"run_name": "Format Input"}) | ||
| agent_executor | ||
| RunnableLambda(_ensure_output_exists).with_config( | ||
{"run_name": "Ensure Output"} | ||
) | ||
) | ||
|
||
if state_reader is not None: | ||
runnable = runnable | RunnablePassthrough.assign(state=_read_state).with_config( | ||
{"run_name": "Read Env State"} | ||
) | ||
return runnable |
Empty file.
133 changes: 133 additions & 0 deletions
133
langchain_benchmarks/tool_usage/agents/experimental/agent.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
from typing import List, Literal, Optional, Sequence, Tuple, Union | ||
|
||
from langchain.agents import AgentOutputParser | ||
from langchain.prompts.chat import ChatPromptTemplate | ||
from langchain.schema.runnable import Runnable | ||
from langchain.tools import StructuredTool | ||
from langchain_core.agents import AgentAction, AgentFinish | ||
from langchain_core.language_models import BaseChatModel, BaseLanguageModel | ||
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage | ||
from langchain_core.prompts import MessagesPlaceholder | ||
from typing_extensions import NotRequired, TypedDict | ||
|
||
from langchain_benchmarks import RateLimiter | ||
from langchain_benchmarks.rate_limiting import with_rate_limit | ||
from langchain_benchmarks.tool_usage.agents.experimental.encoder import ( | ||
AstPrinter, | ||
TypeScriptEncoder, | ||
XMLEncoder, | ||
) | ||
from langchain_benchmarks.tool_usage.agents.experimental.encoder import FunctionResult | ||
from langchain_benchmarks.tool_usage.agents.experimental.prompts import ( | ||
_AGENT_INSTRUCTIONS_BLOB_STYLE, | ||
) | ||
from langchain_benchmarks.tool_usage.agents.experimental.tool_utils import ( | ||
convert_tool_to_function_definition, | ||
) | ||
|
||
|
||
def format_steps_for_chat( | ||
intermediate_steps: List[Tuple[AgentAction, str]], | ||
ast_printer: AstPrinter, | ||
) -> List[BaseMessage]: | ||
"""Format the steps.""" | ||
messages = [] | ||
for action, observation in intermediate_steps: | ||
# Action messages contains the tool invocation request from the LLM | ||
# Now add the result of the tool invocation. | ||
|
||
if action.tool == "_Exception": | ||
messages.append( | ||
AIMessage( | ||
content=action.log, | ||
) | ||
) | ||
messages.append( | ||
# Tool input is the error message for the exception | ||
HumanMessage(content=action.tool_input) | ||
) | ||
else: | ||
messages.extend(action.messages) | ||
function_result: FunctionResult = { | ||
"name": action.tool, | ||
"error": None, | ||
"result": observation, | ||
} | ||
messages.append( | ||
HumanMessage( | ||
content=ast_printer.visit_function_result(function_result), | ||
) | ||
) | ||
|
||
return messages | ||
|
||
|
||
# PUBLIC API | ||
|
||
|
||
class AgentInput(TypedDict): | ||
"""The input to the agent.""" | ||
|
||
input: str | ||
"""The input to the agent.""" | ||
intermediate_steps: List[Tuple[AgentAction, str]] | ||
"""The intermediate steps taken by the agent.""" | ||
examples: NotRequired[List[BaseMessage]] | ||
"""A list of messages that can be used to form example traces.""" | ||
|
||
|
||
def create_agent( | ||
model: Union[BaseChatModel, BaseLanguageModel], | ||
tools: Sequence[StructuredTool], | ||
parser: AgentOutputParser, | ||
*, | ||
ast_printer: Union[AstPrinter, Literal["xml"]] = "xml", | ||
rate_limiter: Optional[RateLimiter] = None, | ||
) -> Runnable[AgentInput, Union[AgentAction, AgentFinish]]: | ||
"""Create an agent for a chat model.""" | ||
if isinstance(ast_printer, str): | ||
if ast_printer == "xml": | ||
ast_printer_ = XMLEncoder() | ||
elif ast_printer == "typescript": | ||
ast_printer_ = TypeScriptEncoder() | ||
else: | ||
raise ValueError(f"Unknown ast printer: {ast_printer}") | ||
elif isinstance(ast_printer, AstPrinter): | ||
ast_printer_ = ast_printer | ||
else: | ||
raise TypeError( | ||
f"Expected AstPrinter or str, got {type(ast_printer)} for `ast_printer`" | ||
) | ||
|
||
function_definitions = [convert_tool_to_function_definition(tool) for tool in tools] | ||
tool_description = ast_printer_.visit_function_definitions(function_definitions) | ||
|
||
template = ChatPromptTemplate.from_messages( | ||
[ | ||
("system", _AGENT_INSTRUCTIONS_BLOB_STYLE), | ||
MessagesPlaceholder("examples"), # Can use to add example traces | ||
("human", "{input}"), | ||
MessagesPlaceholder("history"), | ||
] | ||
).partial(tool_description=tool_description) | ||
|
||
# For the time being, hard-coding the fact that we're using a <tool> tag. | ||
model = model.bind(stop=["</tool>"]) | ||
|
||
if rate_limiter: | ||
# Apply a rate limiter if it was provided | ||
model = with_rate_limit(model, rate_limiter) | ||
|
||
agent = ( | ||
{ | ||
"input": lambda x: x["input"], | ||
"history": lambda x: format_steps_for_chat( | ||
x["intermediate_steps"], ast_printer_ | ||
), | ||
"examples": lambda x: x.get("examples", []), | ||
} | ||
| template | ||
| model | ||
| parser | ||
) | ||
return agent |
Oops, something went wrong.