From 901f50458050116f8df860717ac38fe172c6809f Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Tue, 3 Dec 2024 13:14:52 +0100 Subject: [PATCH] Add token cost + runtime monitoring to Agent and HfEngine children (#34548) * Add monitoring to Agent and HfEngine children --- src/transformers/agents/agents.py | 186 ++++++++++++++++++-------- src/transformers/agents/llm_engine.py | 149 +++++++++++++-------- src/transformers/agents/monitoring.py | 28 +++- src/transformers/agents/tools.py | 25 ++-- tests/agents/test_monitoring.py | 86 +++++++++++- 5 files changed, 346 insertions(+), 128 deletions(-) diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index c461c50f29592c..08c30d54fd43d5 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -17,7 +17,8 @@ import json import logging import re -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +import time +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from .. import is_torch_available from ..utils import logging as transformers_logging @@ -25,6 +26,7 @@ from .agent_types import AgentAudio, AgentImage from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools from .llm_engine import HfApiEngine, MessageRole +from .monitoring import Monitor from .prompts import ( DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, @@ -353,17 +355,23 @@ class Agent: def __init__( self, tools: Union[List[Tool], Toolbox], - llm_engine: Callable = HfApiEngine(), - system_prompt=DEFAULT_REACT_CODE_SYSTEM_PROMPT, - tool_description_template=None, - additional_args={}, + llm_engine: Callable = None, + system_prompt: Optional[str] = None, + tool_description_template: Optional[str] = None, + additional_args: Dict = {}, max_iterations: int = 6, - tool_parser=parse_json_tool_call, + tool_parser: Optional[Callable] = None, add_base_tools: bool = False, verbose: int = 0, - grammar: Dict[str, str] = None, - managed_agents: List = None, + grammar: Optional[Dict[str, str]] = None, + managed_agents: Optional[List] = None, + step_callbacks: Optional[List[Callable]] = None, + monitor_metrics: bool = True, ): + if system_prompt is None: + system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT + if tool_parser is None: + tool_parser = parse_json_tool_call self.agent_name = self.__class__.__name__ self.llm_engine = llm_engine self.system_prompt_template = system_prompt @@ -406,6 +414,15 @@ def __init__( elif verbose == 2: logger.setLevel(logging.DEBUG) + # Initialize step callbacks + self.step_callbacks = step_callbacks if step_callbacks is not None else [] + + # Initialize Monitor if monitor_metrics is True + self.monitor = None + if monitor_metrics: + self.monitor = Monitor(self.llm_engine) + self.step_callbacks.append(self.monitor.update_metrics) + @property def toolbox(self) -> Toolbox: """Get the toolbox currently available to the agent""" @@ -578,13 +595,19 @@ class CodeAgent(Agent): def __init__( self, tools: List[Tool], - llm_engine: Callable = HfApiEngine(), - system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT, - tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, - grammar: Dict[str, str] = None, + llm_engine: Optional[Callable] = None, + system_prompt: Optional[str] = None, + tool_description_template: Optional[str] = None, + grammar: Optional[Dict[str, str]] = None, additional_authorized_imports: Optional[List[str]] = None, **kwargs, ): + if llm_engine is None: + 
llm_engine = HfApiEngine() + if system_prompt is None: + system_prompt = DEFAULT_CODE_SYSTEM_PROMPT + if tool_description_template is None: + tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE super().__init__( tools=tools, llm_engine=llm_engine, @@ -700,15 +723,24 @@ class ReactAgent(Agent): def __init__( self, tools: List[Tool], - llm_engine: Callable = HfApiEngine(), - system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT, - tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, - grammar: Dict[str, str] = None, - plan_type: Literal[tuple(SUPPORTED_PLAN_TYPES)] = SUPPORTED_PLAN_TYPES[0], + llm_engine: Optional[Callable] = None, + system_prompt: Optional[str] = None, + tool_description_template: Optional[str] = None, + grammar: Optional[Dict[str, str]] = None, + plan_type: Optional[str] = None, planning_interval: Optional[int] = None, **kwargs, ): - assert plan_type in SUPPORTED_PLAN_TYPES, f"plan type {plan_type} is not supported" + if llm_engine is None: + llm_engine = HfApiEngine() + if system_prompt is None: + system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT + if tool_description_template is None: + tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE + if plan_type is None: + plan_type = SUPPORTED_PLAN_TYPES[0] + else: + assert plan_type in SUPPORTED_PLAN_TYPES, f"plan type {plan_type} is not supported" super().__init__( tools=tools, llm_engine=llm_engine, @@ -776,16 +808,24 @@ def stream_run(self, task: str): final_answer = None iteration = 0 while final_answer is None and iteration < self.max_iterations: + step_start_time = time.time() + step_log_entry = {"iteration": iteration, "start_time": step_start_time} try: - step_logs = self.step() - if "final_answer" in step_logs: - final_answer = step_logs["final_answer"] + self.step(step_log_entry) + if "final_answer" in step_log_entry: + final_answer = step_log_entry["final_answer"] except AgentError as e: self.logger.error(e, exc_info=1) - self.logs[-1]["error"] = e + step_log_entry["error"] = e finally: + step_end_time = time.time() + step_log_entry["step_end_time"] = step_end_time + step_log_entry["step_duration"] = step_end_time - step_start_time + self.logs.append(step_log_entry) + for callback in self.step_callbacks: + callback(step_log_entry) iteration += 1 - yield self.logs[-1] + yield step_log_entry if final_answer is None and iteration == self.max_iterations: error_message = "Reached max iterations." 
@@ -794,6 +834,9 @@ def stream_run(self, task: str): self.logger.error(error_message, exc_info=1) final_answer = self.provide_final_answer(task) final_step_log["final_answer"] = final_answer + final_step_log["step_duration"] = 0 + for callback in self.step_callbacks: + callback(final_step_log) yield final_step_log yield final_answer @@ -805,16 +848,24 @@ def direct_run(self, task: str): final_answer = None iteration = 0 while final_answer is None and iteration < self.max_iterations: + step_start_time = time.time() + step_log_entry = {"iteration": iteration, "start_time": step_start_time} try: if self.planning_interval is not None and iteration % self.planning_interval == 0: self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration) - step_logs = self.step() - if "final_answer" in step_logs: - final_answer = step_logs["final_answer"] + self.step(step_log_entry) + if "final_answer" in step_log_entry: + final_answer = step_log_entry["final_answer"] except AgentError as e: self.logger.error(e, exc_info=1) - self.logs[-1]["error"] = e + step_log_entry["error"] = e finally: + step_end_time = time.time() + step_log_entry["step_end_time"] = step_end_time + step_log_entry["step_duration"] = step_end_time - step_start_time + self.logs.append(step_log_entry) + for callback in self.step_callbacks: + callback(step_log_entry) iteration += 1 if final_answer is None and iteration == self.max_iterations: @@ -824,6 +875,9 @@ def direct_run(self, task: str): self.logger.error(error_message, exc_info=1) final_answer = self.provide_final_answer(task) final_step_log["final_answer"] = final_answer + final_step_log["step_duration"] = 0 + for callback in self.step_callbacks: + callback(final_step_log) return final_answer @@ -937,13 +991,19 @@ class ReactJsonAgent(ReactAgent): def __init__( self, tools: List[Tool], - llm_engine: Callable = HfApiEngine(), - system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT, - tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, - grammar: Dict[str, str] = None, + llm_engine: Optional[Callable] = None, + system_prompt: Optional[str] = None, + tool_description_template: Optional[str] = None, + grammar: Optional[Dict[str, str]] = None, planning_interval: Optional[int] = None, **kwargs, ): + if llm_engine is None: + llm_engine = HfApiEngine() + if system_prompt is None: + system_prompt = DEFAULT_REACT_JSON_SYSTEM_PROMPT + if tool_description_template is None: + tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE super().__init__( tools=tools, llm_engine=llm_engine, @@ -954,7 +1014,7 @@ def __init__( **kwargs, ) - def step(self): + def step(self, log_entry: Dict[str, Any]): """ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. The errors are raised here, they are caught and logged in the run() method. 
@@ -965,9 +1025,7 @@ def step(self): self.logger.debug("===== New step =====") # Add new step in logs - current_step_logs = {} - self.logs.append(current_step_logs) - current_step_logs["agent_memory"] = agent_memory.copy() + log_entry["agent_memory"] = agent_memory.copy() self.logger.info("===== Calling LLM with this last message: =====") self.logger.info(self.prompt[-1]) @@ -981,7 +1039,7 @@ def step(self): raise AgentGenerationError(f"Error in generating llm output: {e}.") self.logger.debug("===== Output message of the LLM: =====") self.logger.debug(llm_output) - current_step_logs["llm_output"] = llm_output + log_entry["llm_output"] = llm_output # Parse self.logger.debug("===== Extracting action =====") @@ -992,8 +1050,8 @@ def step(self): except Exception as e: raise AgentParsingError(f"Could not parse the given action: {e}.") - current_step_logs["rationale"] = rationale - current_step_logs["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments} + log_entry["rationale"] = rationale + log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments} # Execute self.logger.warning("=== Agent thoughts:") @@ -1011,8 +1069,8 @@ def step(self): answer = arguments else: answer = arguments - current_step_logs["final_answer"] = answer - return current_step_logs + log_entry["final_answer"] = answer + return answer else: if arguments is None: arguments = {} @@ -1030,8 +1088,8 @@ def step(self): else: updated_information = str(observation).strip() self.logger.info(updated_information) - current_step_logs["observation"] = updated_information - return current_step_logs + log_entry["observation"] = updated_information + return log_entry class ReactCodeAgent(ReactAgent): @@ -1044,14 +1102,20 @@ class ReactCodeAgent(ReactAgent): def __init__( self, tools: List[Tool], - llm_engine: Callable = HfApiEngine(), - system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT, - tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, - grammar: Dict[str, str] = None, + llm_engine: Optional[Callable] = None, + system_prompt: Optional[str] = None, + tool_description_template: Optional[str] = None, + grammar: Optional[Dict[str, str]] = None, additional_authorized_imports: Optional[List[str]] = None, planning_interval: Optional[int] = None, **kwargs, ): + if llm_engine is None: + llm_engine = HfApiEngine() + if system_prompt is None: + system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT + if tool_description_template is None: + tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE super().__init__( tools=tools, llm_engine=llm_engine, @@ -1075,7 +1139,7 @@ def __init__( self.system_prompt = self.system_prompt.replace("<>", str(self.authorized_imports)) self.custom_tools = {} - def step(self): + def step(self, log_entry: Dict[str, Any]): """ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. The errors are raised here, they are caught and logged in the run() method. 
@@ -1083,13 +1147,10 @@ def step(self): agent_memory = self.write_inner_memory_from_logs() self.prompt = agent_memory.copy() - self.logger.debug("===== New step =====") # Add new step in logs - current_step_logs = {} - self.logs.append(current_step_logs) - current_step_logs["agent_memory"] = agent_memory.copy() + log_entry["agent_memory"] = agent_memory.copy() self.logger.info("===== Calling LLM with these last messages: =====") self.logger.info(self.prompt[-2:]) @@ -1104,7 +1165,7 @@ def step(self): self.logger.debug("=== Output message of the LLM:") self.logger.debug(llm_output) - current_step_logs["llm_output"] = llm_output + log_entry["llm_output"] = llm_output # Parse self.logger.debug("=== Extracting action ===") @@ -1120,8 +1181,8 @@ def step(self): error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" raise AgentParsingError(error_msg) - current_step_logs["rationale"] = rationale - current_step_logs["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action} + log_entry["rationale"] = rationale + log_entry["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action} # Execute self.log_rationale_code_action(rationale, code_action) @@ -1146,7 +1207,7 @@ def step(self): self.logger.warning("Last output from code snippet:") self.logger.log(32, str(result)) observation += "Last output from code snippet:\n" + str(result)[:100000] - current_step_logs["observation"] = observation + log_entry["observation"] = observation except Exception as e: error_msg = f"Code execution failed due to the following error:\n{str(e)}" if "'dict' object has no attribute 'read'" in str(e): @@ -1156,8 +1217,11 @@ def step(self): if line[: len("final_answer")] == "final_answer": self.logger.log(33, "Final answer:") self.logger.log(32, result) - current_step_logs["final_answer"] = result - return current_step_logs + log_entry["final_answer"] = result + return result + + +LENGTH_TRUNCATE_REPORTS = 1000 class ManagedAgent: @@ -1200,10 +1264,14 @@ def __call__(self, request, **kwargs): answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n" for message in self.agent.write_inner_memory_from_logs(summary_mode=True): content = message["content"] - if len(str(content)) < 1000 or "[FACTS LIST]" in str(content): + if len(str(content)) < LENGTH_TRUNCATE_REPORTS or "[FACTS LIST]" in str(content): answer += "\n" + str(content) + "\n---" else: - answer += "\n" + str(content)[:1000] + "\n(...Step was truncated because too long)...\n---" + answer += ( + "\n" + + str(content)[:LENGTH_TRUNCATE_REPORTS] + + "\n(...Step was truncated because too long)...\n---" + ) answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'." return answer else: diff --git a/src/transformers/agents/llm_engine.py b/src/transformers/agents/llm_engine.py index 456c6172a77cb0..afa4d62d059e5b 100644 --- a/src/transformers/agents/llm_engine.py +++ b/src/transformers/agents/llm_engine.py @@ -20,7 +20,12 @@ from huggingface_hub import InferenceClient +from .. import AutoTokenizer from ..pipelines.base import Pipeline +from ..utils import logging + + +logger = logging.get_logger(__name__) class MessageRole(str, Enum): @@ -67,46 +72,32 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: } -class HfApiEngine: - """A class to interact with Hugging Face's Inference API for language model interaction. - - This engine allows you to communicate with Hugging Face's models using the Inference API. 
It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization. - - Parameters: - model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`): - The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub. - token (`str`, *optional*): - The Hugging Face API token for authentication. If not provided, the class will use the token stored in the Hugging Face CLI configuration. - max_tokens (`int`, *optional*, defaults to 1500): - The maximum number of tokens allowed in the output. - timeout (`int`, *optional*, defaults to 120): - Timeout for the API request, in seconds. - - Raises: - ValueError: - If the model name is not provided. - """ - - def __init__( - self, - model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct", - token: Optional[str] = None, - max_tokens: Optional[int] = 1500, - timeout: Optional[int] = 120, +class HfEngine: + def __init__(self, model_id: Optional[str] = None): + self.last_input_token_count = None + self.last_output_token_count = None + if model_id is None: + model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct" + logger.warning(f"Using default model for token counting: '{model_id}'") + try: + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + except Exception as e: + logger.warning(f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead.") + self.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct") + + def get_token_counts(self): + return { + "input_token_count": self.last_input_token_count, + "output_token_count": self.last_output_token_count, + } + + def generate( + self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None ): - """Initialize the HfApiEngine.""" - if not model: - raise ValueError("Model name must be provided.") - - self.model = model - self.client = InferenceClient(self.model, token=token, timeout=timeout) - self.max_tokens = max_tokens + raise NotImplementedError def __call__( - self, - messages: List[Dict[str, str]], - stop_sequences: Optional[List[str]] = [], - grammar: Optional[str] = None, + self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None ) -> str: """Process the input messages and return the model's response. @@ -136,6 +127,57 @@ def __call__( "Quantum mechanics is the branch of physics that studies..." ``` """ + if not isinstance(messages, List): + raise ValueError("Messages should be a list of dictionaries with 'role' and 'content' keys.") + if stop_sequences is None: + stop_sequences = [] + response = self.generate(messages, stop_sequences, grammar) + self.last_input_token_count = len(self.tokenizer.apply_chat_template(messages, tokenize=True)) + self.last_output_token_count = len(self.tokenizer.encode(response)) + + # Remove stop sequences from LLM output + for stop_seq in stop_sequences: + if response[-len(stop_seq) :] == stop_seq: + response = response[: -len(stop_seq)] + return response + + +class HfApiEngine(HfEngine): + """A class to interact with Hugging Face's Inference API for language model interaction. + + This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization. 
+ + Parameters: + model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`): + The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub. + token (`str`, *optional*): + Token used by the Hugging Face API for authentication. + If not provided, the class will use the token stored in the Hugging Face CLI configuration. + max_tokens (`int`, *optional*, defaults to 1500): + The maximum number of tokens allowed in the output. + timeout (`int`, *optional*, defaults to 120): + Timeout for the API request, in seconds. + + Raises: + ValueError: + If the model name is not provided. + """ + + def __init__( + self, + model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct", + token: Optional[str] = None, + max_tokens: Optional[int] = 1500, + timeout: Optional[int] = 120, + ): + super().__init__(model_id=model) + self.model = model + self.client = InferenceClient(self.model, token=token, timeout=timeout) + self.max_tokens = max_tokens + + def generate( + self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None + ) -> str: # Get clean message list messages = get_clean_message_list(messages, role_conversions=llama_role_conversions) @@ -148,41 +190,40 @@ def __call__( response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens) response = response.choices[0].message.content - - # Remove stop sequences from LLM output - for stop_seq in stop_sequences: - if response[-len(stop_seq) :] == stop_seq: - response = response[: -len(stop_seq)] return response -class TransformersEngine: +class TransformersEngine(HfEngine): """This engine uses a pre-initialized local text-generation pipeline.""" - def __init__(self, pipeline: Pipeline): + def __init__(self, pipeline: Pipeline, model_id: Optional[str] = None): + super().__init__(model_id) self.pipeline = pipeline - def __call__( - self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None + def generate( + self, + messages: List[Dict[str, str]], + stop_sequences: Optional[List[str]] = None, + grammar: Optional[str] = None, + max_length: int = 1500, ) -> str: # Get clean message list messages = get_clean_message_list(messages, role_conversions=llama_role_conversions) # Get LLM output + if stop_sequences is not None and len(stop_sequences) > 0: + stop_strings = stop_sequences + else: + stop_strings = None + output = self.pipeline( messages, - stop_strings=stop_sequences, - max_length=1500, + stop_strings=stop_strings, + max_length=max_length, tokenizer=self.pipeline.tokenizer, ) response = output[0]["generated_text"][-1]["content"] - - # Remove stop sequences from LLM output - if stop_sequences is not None: - for stop_seq in stop_sequences: - if response[-len(stop_seq) :] == stop_seq: - response = response[: -len(stop_seq)] return response diff --git a/src/transformers/agents/monitoring.py b/src/transformers/agents/monitoring.py index 755418d35a56a3..7126e72b5fd060 100644 --- a/src/transformers/agents/monitoring.py +++ b/src/transformers/agents/monitoring.py @@ -14,8 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from ..utils import logging from .agent_types import AgentAudio, AgentImage, AgentText -from .agents import ReactAgent + + +logger = logging.get_logger(__name__) def pull_message(step_log: dict, test_mode: bool = True): @@ -54,7 +57,7 @@ def __init__(self, role, content, metadata=None): ) -def stream_to_gradio(agent: ReactAgent, task: str, test_mode: bool = False, **kwargs): +def stream_to_gradio(agent, task: str, test_mode: bool = False, **kwargs): """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" try: @@ -91,3 +94,24 @@ def __init__(self, role, content, metadata=None): ) else: yield ChatMessage(role="assistant", content=str(final_answer)) + + +class Monitor: + def __init__(self, tracked_llm_engine): + self.step_durations = [] + self.tracked_llm_engine = tracked_llm_engine + if getattr(self.tracked_llm_engine, "last_input_token_count", "Not found") != "Not found": + self.total_input_token_count = 0 + self.total_output_token_count = 0 + + def update_metrics(self, step_log): + step_duration = step_log["step_duration"] + self.step_durations.append(step_duration) + logger.info(f"Step {len(self.step_durations)}:") + logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)") + + if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None: + self.total_input_token_count += self.tracked_llm_engine.last_input_token_count + self.total_output_token_count += self.tracked_llm_engine.last_output_token_count + logger.info(f"- Input tokens: {self.total_input_token_count}") + logger.info(f"- Output tokens: {self.total_output_token_count}") diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index e33e5cf5ce3c9e..759704612c2f4f 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -785,21 +785,22 @@ def launch_gradio_demo(tool_class: Tool): def fn(*args, **kwargs): return tool(*args, **kwargs) + TYPE_TO_COMPONENT_CLASS_MAPPING = { + "image": gr.Image, + "audio": gr.Audio, + "string": gr.Textbox, + "integer": gr.Textbox, + "number": gr.Textbox, + } + gradio_inputs = [] for input_name, input_details in tool_class.inputs.items(): - input_type = input_details["type"] - if input_type == "image": - gradio_inputs.append(gr.Image(label=input_name)) - elif input_type == "audio": - gradio_inputs.append(gr.Audio(label=input_name)) - elif input_type in ["string", "integer", "number"]: - gradio_inputs.append(gr.Textbox(label=input_name)) - else: - error_message = f"Input type '{input_type}' not supported." - raise ValueError(error_message) + input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]] + new_component = input_gradio_component_class(label=input_name) + gradio_inputs.append(new_component) - gradio_output = tool_class.output_type - assert gradio_output in ["string", "image", "audio"], f"Output type '{gradio_output}' not supported." 
+ output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool_class.output_type] + gradio_output = output_gradio_componentclass(label=input_name) gr.Interface( fn=fn, diff --git a/tests/agents/test_monitoring.py b/tests/agents/test_monitoring.py index c43c9cb8bf86dd..c35074270ea272 100644 --- a/tests/agents/test_monitoring.py +++ b/tests/agents/test_monitoring.py @@ -21,11 +21,95 @@ class MonitoringTester(unittest.TestCase): + def test_code_agent_metrics(self): + class FakeLLMEngine: + def __init__(self): + self.last_input_token_count = 10 + self.last_output_token_count = 20 + + def __call__(self, prompt, **kwargs): + return """ +Code: +```py +final_answer('This is the final answer.') +```""" + + agent = ReactCodeAgent( + tools=[], + llm_engine=FakeLLMEngine(), + max_iterations=1, + ) + + agent.run("Fake task") + + self.assertEqual(agent.monitor.total_input_token_count, 10) + self.assertEqual(agent.monitor.total_output_token_count, 20) + + def test_json_agent_metrics(self): + class FakeLLMEngine: + def __init__(self): + self.last_input_token_count = 10 + self.last_output_token_count = 20 + + def __call__(self, prompt, **kwargs): + return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' + + agent = ReactJsonAgent( + tools=[], + llm_engine=FakeLLMEngine(), + max_iterations=1, + ) + + agent.run("Fake task") + + self.assertEqual(agent.monitor.total_input_token_count, 10) + self.assertEqual(agent.monitor.total_output_token_count, 20) + + def test_code_agent_metrics_max_iterations(self): + class FakeLLMEngine: + def __init__(self): + self.last_input_token_count = 10 + self.last_output_token_count = 20 + + def __call__(self, prompt, **kwargs): + return "Malformed answer" + + agent = ReactCodeAgent( + tools=[], + llm_engine=FakeLLMEngine(), + max_iterations=1, + ) + + agent.run("Fake task") + + self.assertEqual(agent.monitor.total_input_token_count, 20) + self.assertEqual(agent.monitor.total_output_token_count, 40) + + def test_code_agent_metrics_generation_error(self): + class FakeLLMEngine: + def __init__(self): + self.last_input_token_count = 10 + self.last_output_token_count = 20 + + def __call__(self, prompt, **kwargs): + raise AgentError + + agent = ReactCodeAgent( + tools=[], + llm_engine=FakeLLMEngine(), + max_iterations=1, + ) + + agent.run("Fake task") + + self.assertEqual(agent.monitor.total_input_token_count, 20) + self.assertEqual(agent.monitor.total_output_token_count, 40) + def test_streaming_agent_text_output(self): def dummy_llm_engine(prompt, **kwargs): return """ Code: -```` +```py final_answer('This is the final answer.') ```"""
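For context (not part of the patch itself), a minimal usage sketch of the monitoring API this PR introduces might look like the following. It assumes the `transformers.agents` exports shown in the diff (`ReactCodeAgent`, `HfApiEngine`), a configured Hugging Face API token, and network access to the Inference API; the task string and model ID are illustrative only.

```python
# Sketch of the monitoring features added in this patch (step_callbacks, Monitor,
# per-call token counts). Names come from the diff above; the task/model are examples.
from transformers.agents import HfApiEngine, ReactCodeAgent


def log_step(step_log: dict):
    # Custom per-step callback: step logs now carry timing information.
    print(f"Step {step_log.get('iteration')} took {step_log.get('step_duration', 0.0):.2f}s")


llm_engine = HfApiEngine(model="meta-llama/Meta-Llama-3.1-8B-Instruct")
agent = ReactCodeAgent(
    tools=[],
    llm_engine=llm_engine,
    max_iterations=3,
    step_callbacks=[log_step],  # called after every step, alongside the built-in Monitor
)

agent.run("What is 2 + 2?")

# Aggregated metrics collected by the built-in Monitor (enabled by monitor_metrics=True).
print("Total input tokens:", agent.monitor.total_input_token_count)
print("Total output tokens:", agent.monitor.total_output_token_count)
print("Step durations:", agent.monitor.step_durations)

# Per-call token counts are also exposed directly on the engine.
print(llm_engine.get_token_counts())
```

Note the design split visible in the diff: per-call token counting lives on the `HfEngine` base class (via the tokenizer), while cross-step aggregation lives in `Monitor`, which is wired in automatically as a step callback when `monitor_metrics` is left at its default.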