Add token cost + runtime monitoring to Agent and HfEngine children (#34548)

* Add monitoring to Agent and HfEngine children
aymeric-roucher authored Dec 3, 2024
1 parent ee37bf0 commit 901f504
Showing 5 changed files with 346 additions and 128 deletions.
186 changes: 127 additions & 59 deletions src/transformers/agents/agents.py
@@ -17,14 +17,16 @@
import json
import logging
import re
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from .. import is_torch_available
from ..utils import logging as transformers_logging
from ..utils.import_utils import is_pygments_available
from .agent_types import AgentAudio, AgentImage
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfApiEngine, MessageRole
from .monitoring import Monitor
from .prompts import (
DEFAULT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
@@ -353,17 +355,23 @@ class Agent:
def __init__(
self,
tools: Union[List[Tool], Toolbox],
llm_engine: Callable = HfApiEngine(),
system_prompt=DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template=None,
additional_args={},
llm_engine: Callable = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
additional_args: Dict = {},
max_iterations: int = 6,
tool_parser=parse_json_tool_call,
tool_parser: Optional[Callable] = None,
add_base_tools: bool = False,
verbose: int = 0,
grammar: Dict[str, str] = None,
managed_agents: List = None,
grammar: Optional[Dict[str, str]] = None,
managed_agents: Optional[List] = None,
step_callbacks: Optional[List[Callable]] = None,
monitor_metrics: bool = True,
):
if system_prompt is None:
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
if tool_parser is None:
tool_parser = parse_json_tool_call
self.agent_name = self.__class__.__name__
self.llm_engine = llm_engine
self.system_prompt_template = system_prompt
@@ -406,6 +414,15 @@ def __init__(
elif verbose == 2:
logger.setLevel(logging.DEBUG)

# Initialize step callbacks
self.step_callbacks = step_callbacks if step_callbacks is not None else []

# Initialize Monitor if monitor_metrics is True
self.monitor = None
if monitor_metrics:
self.monitor = Monitor(self.llm_engine)
self.step_callbacks.append(self.monitor.update_metrics)
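
A minimal sketch of the callback contract wired up here: any callable that accepts the step log dict can sit in step_callbacks alongside the built-in Monitor. The real Monitor lives in .monitoring; the collector below is an illustration of the interface, not its actual implementation.

from typing import Dict, List

class SimpleMetricsCollector:
    """Toy stand-in for Monitor: accumulates the per-step durations
    that the run loops below stamp into each step log."""

    def __init__(self):
        self.step_durations: List[float] = []

    def update_metrics(self, step_log: Dict):
        duration = step_log.get("step_duration", 0.0)
        self.step_durations.append(duration)
        print(f"Step {step_log.get('iteration', '?')} took {duration:.2f}s "
              f"(total {sum(self.step_durations):.2f}s)")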

@property
def toolbox(self) -> Toolbox:
"""Get the toolbox currently available to the agent"""
@@ -578,13 +595,19 @@ class CodeAgent(Agent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
**kwargs,
):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = DEFAULT_CODE_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__(
tools=tools,
llm_engine=llm_engine,
@@ -700,15 +723,24 @@ class ReactAgent(Agent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
plan_type: Literal[tuple(SUPPORTED_PLAN_TYPES)] = SUPPORTED_PLAN_TYPES[0],
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
plan_type: Optional[str] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
assert plan_type in SUPPORTED_PLAN_TYPES, f"plan type {plan_type} is not supported"
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
if plan_type is None:
plan_type = SUPPORTED_PLAN_TYPES[0]
else:
assert plan_type in SUPPORTED_PLAN_TYPES, f"plan type {plan_type} is not supported"
super().__init__(
tools=tools,
llm_engine=llm_engine,
@@ -776,16 +808,24 @@ def stream_run(self, task: str):
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
step_start_time = time.time()
step_log_entry = {"iteration": iteration, "start_time": step_start_time}
try:
step_logs = self.step()
if "final_answer" in step_logs:
final_answer = step_logs["final_answer"]
self.step(step_log_entry)
if "final_answer" in step_log_entry:
final_answer = step_log_entry["final_answer"]
except AgentError as e:
self.logger.error(e, exc_info=1)
self.logs[-1]["error"] = e
step_log_entry["error"] = e
finally:
step_end_time = time.time()
step_log_entry["step_end_time"] = step_end_time
step_log_entry["step_duration"] = step_end_time - step_start_time
self.logs.append(step_log_entry)
for callback in self.step_callbacks:
callback(step_log_entry)
iteration += 1
yield self.logs[-1]
yield step_log_entry

if final_answer is None and iteration == self.max_iterations:
error_message = "Reached max iterations."
@@ -794,6 +834,9 @@ def stream_run(self, task: str):
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
final_step_log["step_duration"] = 0
for callback in self.step_callbacks:
callback(final_step_log)
yield final_step_log

yield final_answer
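
Because stream_run now yields each step's log entry (and the final step log) before the final answer, callers can observe timings live. A hedged usage sketch; `agent` is assumed to be an instance of a ReactAgent subclass from this file:

for output in agent.stream_run("How many seconds are in a day?"):
    if isinstance(output, dict):
        # Per-step log: carries "iteration", "start_time", "step_end_time",
        # "step_duration", and possibly "error" or "final_answer".
        print(f"step {output.get('iteration', '?')}: {output.get('step_duration', 0.0):.2f}s")
    else:
        # The last yielded value is the final answer itself.
        print("Final answer:", output)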
@@ -805,16 +848,24 @@ def direct_run(self, task: str):
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
step_start_time = time.time()
step_log_entry = {"iteration": iteration, "start_time": step_start_time}
try:
if self.planning_interval is not None and iteration % self.planning_interval == 0:
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
step_logs = self.step()
if "final_answer" in step_logs:
final_answer = step_logs["final_answer"]
self.step(step_log_entry)
if "final_answer" in step_log_entry:
final_answer = step_log_entry["final_answer"]
except AgentError as e:
self.logger.error(e, exc_info=1)
self.logs[-1]["error"] = e
step_log_entry["error"] = e
finally:
step_end_time = time.time()
step_log_entry["step_end_time"] = step_end_time
step_log_entry["step_duration"] = step_end_time - step_start_time
self.logs.append(step_log_entry)
for callback in self.step_callbacks:
callback(step_log_entry)
iteration += 1

if final_answer is None and iteration == self.max_iterations:
@@ -824,6 +875,9 @@ def direct_run(self, task: str):
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
final_step_log["step_duration"] = 0
for callback in self.step_callbacks:
callback(final_step_log)

return final_answer
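
The same bookkeeping runs in direct_run, so custom callbacks fire on every iteration there too. A usage sketch under the defaults in this file (running it for real requires Hugging Face API access for the default HfApiEngine):

def alert_slow_steps(step_log: dict):
    # Hypothetical callback: flag any step that took longer than 30 seconds.
    if step_log.get("step_duration", 0.0) > 30:
        print(f"Slow step {step_log.get('iteration', '?')}: {step_log['step_duration']:.1f}s")

agent = ReactCodeAgent(tools=[], step_callbacks=[alert_slow_steps])
result = agent.run("What is 2 to the power of 10?")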

@@ -937,13 +991,19 @@ class ReactJsonAgent(ReactAgent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = DEFAULT_REACT_JSON_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__(
tools=tools,
llm_engine=llm_engine,
@@ -954,7 +1014,7 @@ def __init__(
**kwargs,
)

def step(self):
def step(self, log_entry: Dict[str, Any]):
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
The errors are raised here, they are caught and logged in the run() method.
@@ -965,9 +1025,7 @@ def step(self):
self.logger.debug("===== New step =====")

# Add new step in logs
current_step_logs = {}
self.logs.append(current_step_logs)
current_step_logs["agent_memory"] = agent_memory.copy()
log_entry["agent_memory"] = agent_memory.copy()

self.logger.info("===== Calling LLM with this last message: =====")
self.logger.info(self.prompt[-1])
@@ -981,7 +1039,7 @@ def step(self):
raise AgentGenerationError(f"Error in generating llm output: {e}.")
self.logger.debug("===== Output message of the LLM: =====")
self.logger.debug(llm_output)
current_step_logs["llm_output"] = llm_output
log_entry["llm_output"] = llm_output

# Parse
self.logger.debug("===== Extracting action =====")
@@ -992,8 +1050,8 @@ def step(self):
except Exception as e:
raise AgentParsingError(f"Could not parse the given action: {e}.")

current_step_logs["rationale"] = rationale
current_step_logs["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
log_entry["rationale"] = rationale
log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}

# Execute
self.logger.warning("=== Agent thoughts:")
@@ -1011,8 +1069,8 @@ def step(self):
answer = arguments
else:
answer = arguments
current_step_logs["final_answer"] = answer
return current_step_logs
log_entry["final_answer"] = answer
return answer
else:
if arguments is None:
arguments = {}
@@ -1030,8 +1088,8 @@ def step(self):
else:
updated_information = str(observation).strip()
self.logger.info(updated_information)
current_step_logs["observation"] = updated_information
return current_step_logs
log_entry["observation"] = updated_information
return log_entry
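
Note the contract change: step() no longer creates and appends its own log dict; the run loop owns the entry, step() mutates it in place, and the "final_answer" key signals completion. A stub illustrating the shape, with no LLM call:

def stub_step(log_entry: dict):
    # Mimics the mutation pattern above.
    log_entry["llm_output"] = "Thought: I know the answer. Action: final_answer"
    log_entry["final_answer"] = 42
    return log_entry["final_answer"]

entry = {"iteration": 0}
stub_step(entry)
assert "final_answer" in entry  # exactly what the run loops check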


class ReactCodeAgent(ReactAgent):
@@ -1044,14 +1102,20 @@ class ReactCodeAgent(ReactAgent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__(
tools=tools,
llm_engine=llm_engine,
@@ -1075,21 +1139,18 @@ def __init__(
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
self.custom_tools = {}

def step(self):
def step(self, log_entry: Dict[str, Any]):
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
The errors are raised here, they are caught and logged in the run() method.
"""
agent_memory = self.write_inner_memory_from_logs()

self.prompt = agent_memory.copy()

self.logger.debug("===== New step =====")

# Add new step in logs
current_step_logs = {}
self.logs.append(current_step_logs)
current_step_logs["agent_memory"] = agent_memory.copy()
log_entry["agent_memory"] = agent_memory.copy()

self.logger.info("===== Calling LLM with these last messages: =====")
self.logger.info(self.prompt[-2:])
@@ -1104,7 +1165,7 @@ def step(self):

self.logger.debug("=== Output message of the LLM:")
self.logger.debug(llm_output)
current_step_logs["llm_output"] = llm_output
log_entry["llm_output"] = llm_output

# Parse
self.logger.debug("=== Extracting action ===")
@@ -1120,8 +1181,8 @@ def step(self):
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
raise AgentParsingError(error_msg)

current_step_logs["rationale"] = rationale
current_step_logs["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
log_entry["rationale"] = rationale
log_entry["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}

# Execute
self.log_rationale_code_action(rationale, code_action)
@@ -1146,7 +1207,7 @@ def step(self):
self.logger.warning("Last output from code snippet:")
self.logger.log(32, str(result))
observation += "Last output from code snippet:\n" + str(result)[:100000]
current_step_logs["observation"] = observation
log_entry["observation"] = observation
except Exception as e:
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
if "'dict' object has no attribute 'read'" in str(e):
@@ -1156,8 +1217,11 @@ def step(self):
if line[: len("final_answer")] == "final_answer":
self.logger.log(33, "Final answer:")
self.logger.log(32, result)
current_step_logs["final_answer"] = result
return current_step_logs
log_entry["final_answer"] = result
return result
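
The final-answer detection above is a plain line-prefix match on the generated code. In isolation:

code_action = "x = 2 ** 10\nfinal_answer(x)"
for line in code_action.split("\n"):
    if line[: len("final_answer")] == "final_answer":
        print("final-answer line detected:", line)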


LENGTH_TRUNCATE_REPORTS = 1000


class ManagedAgent:
@@ -1200,10 +1264,14 @@ def __call__(self, request, **kwargs):
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
content = message["content"]
if len(str(content)) < 1000 or "[FACTS LIST]" in str(content):
if len(str(content)) < LENGTH_TRUNCATE_REPORTS or "[FACTS LIST]" in str(content):
answer += "\n" + str(content) + "\n---"
else:
answer += "\n" + str(content)[:1000] + "\n(...Step was truncated because too long)...\n---"
answer += (
"\n"
+ str(content)[:LENGTH_TRUNCATE_REPORTS]
+ "\n(...Step was truncated because too long)...\n---"
)
answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
return answer
else: