-
Notifications
You must be signed in to change notification settings - Fork 27.4k
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -597,7 +597,31 @@ def __init__( | |
if "final_answer" not in self._toolbox.tools: | ||
self._toolbox.add_tool(FinalAnswerTool()) | ||
|
||
def run(self, task: str, **kwargs): | ||
|
||
def provide_final_answer(self, task) -> str:
    """
    Ask the LLM for a final answer based on the agent's interaction logs.

    Fallback used when the agent could not produce a final answer on its own
    (e.g. after exhausting its iteration budget).

    Args:
        task: the original user request to answer.

    Returns:
        str: the LLM-generated answer, or an error message if generation failed.
    """
    self.prompt = [
        {
            "role": MessageRole.SYSTEM,
            "content": "An agent tried to answer a user query but it failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
        }
    ]
    # Drop the memory's own system message ([1:]) — it is replaced by the one above.
    self.prompt += self.write_inner_memory_from_logs()[1:]
    self.prompt += [
        {
            "role": MessageRole.USER,
            "content": f"Based on the above, please provide an answer to the following user request:\n{task}",
        }
    ]
    try:
        return self.llm_engine(self.prompt, stop_sequences=["<end_action>", "Observation:"])
    except Exception as e:
        # Best-effort fallback: surface the failure as the answer rather than raising.
        return f"Error in generating final llm output: {e}."
|
||
|
||
def run(self, task: str, stream: bool = False, **kwargs):
    """
    Runs the agent for the given task.

    Args:
        task (str): the task to perform.
        stream (bool): if True, return a generator of step logs
            (delegates to ``stream_run``); otherwise run to completion
            and return the final answer (``direct_run``).
        **kwargs: extra arguments forwarded to the run implementation.

    Example:
    ```py
    agent.run("What is the result of 2 power 3.7384?")
    ```
    """
    # Pure dispatch: streaming vs. blocking execution share the same step loop.
    if stream:
        return self.stream_run(task, **kwargs)
    return self.direct_run(task, **kwargs)
|
||
|
||
def stream_run(self, task: str, **kwargs):
    """
    Streaming run loop: yields each step's log dict as it completes.

    Runs ``self.step()`` until a step produces a ``final_answer`` or
    ``self.max_iterations`` is reached. On exhaustion, logs an
    AgentMaxIterationsError, yields that error step (so stream consumers
    see it), falls back to ``provide_final_answer``, and finally yields
    the final answer itself.

    NOTE(review): the original ended with ``return final_answer`` — in a
    generator that value is only visible via StopIteration.value, so
    consumers iterating the stream never saw it; it is now yielded.
    """
    self.initialize_for_run(task, **kwargs)

    final_answer = None
    iteration = 0
    while final_answer is None and iteration < self.max_iterations:
        try:
            step_logs = self.step()
            if "final_answer" in step_logs:
                final_answer = step_logs["final_answer"]
        except AgentError as e:
            # Step failures are recorded on the step's log entry, not raised.
            self.logger.error(e, exc_info=1)
            self.logs[-1]["error"] = e
        finally:
            iteration += 1
            yield self.logs[-1]

    if final_answer is None and iteration == self.max_iterations:
        error_message = "Reached max iterations."
        error_step = {"error": AgentMaxIterationsError(error_message)}
        self.logs.append(error_step)
        self.logger.error(error_message, exc_info=1)
        final_answer = self.provide_final_answer(task)
        # Surface the max-iterations error to stream consumers as well.
        yield error_step

    yield final_answer
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
freddyaboulton
Contributor
|
||
|
||
|
||
def direct_run(self, task: str, **kwargs):
    """
    Blocking run loop: runs to completion and returns the final answer.

    Runs ``self.step()`` until a step produces a ``final_answer`` or
    ``self.max_iterations`` is reached. On exhaustion, logs an
    AgentMaxIterationsError and falls back to ``provide_final_answer``.

    Args:
        task (str): the task to perform.
        **kwargs: extra arguments forwarded to ``initialize_for_run``.

    Returns:
        The final answer produced by a step, or the fallback LLM answer.
    """
    self.initialize_for_run(task, **kwargs)

    final_answer = None
    iteration = 0
    while final_answer is None and iteration < self.max_iterations:
        try:
            step_logs = self.step()
            if "final_answer" in step_logs:
                final_answer = step_logs["final_answer"]
        except AgentError as e:
            # Step failures are recorded on the step's log entry, not raised.
            self.logger.error(e, exc_info=1)
            self.logs[-1]["error"] = e
        finally:
            iteration += 1

    if final_answer is None and iteration == self.max_iterations:
        error_message = "Reached max iterations."
        self.logs.append({"error": AgentMaxIterationsError(error_message)})
        self.logger.error(error_message, exc_info=1)
        final_answer = self.provide_final_answer(task)

    return final_answer
|
||
|
@@ -683,12 +725,14 @@ def step(self): | |
""" | ||
agent_memory = self.write_inner_memory_from_logs() | ||
|
||
self.logs[-1]["agent_memory"] = agent_memory.copy() | ||
self.prompt = agent_memory | ||
self.logger.debug("===== New step =====") | ||
|
||
# Add new step in logs | ||
self.logs.append({}) | ||
current_step_logs = {} | ||
self.logs.append(current_step_logs) | ||
current_step_logs["agent_memory"] = agent_memory.copy() | ||
|
||
self.logger.info("===== Calling LLM with this last message: =====") | ||
self.logger.info(self.prompt[-1]) | ||
|
||
|
@@ -698,7 +742,7 @@ def step(self): | |
raise AgentGenerationError(f"Error in generating llm output: {e}.") | ||
self.logger.debug("===== Output message of the LLM: =====") | ||
self.logger.debug(llm_output) | ||
self.logs[-1]["llm_output"] = llm_output | ||
current_step_logs["llm_output"] = llm_output | ||
|
||
# Parse | ||
self.logger.debug("===== Extracting action =====") | ||
|
@@ -709,8 +753,8 @@ def step(self): | |
except Exception as e: | ||
raise AgentParsingError(f"Could not parse the given action: {e}.") | ||
|
||
self.logs[-1]["rationale"] = rationale | ||
self.logs[-1]["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments} | ||
current_step_logs["rationale"] = rationale | ||
current_step_logs["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments} | ||
|
||
# Execute | ||
self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}") | ||
|
@@ -740,8 +784,8 @@ def step(self): | |
updated_information = f"Stored '{observation_name}' in memory." | ||
|
||
self.logger.info(updated_information) | ||
self.logs[-1]["observation"] = updated_information | ||
return None | ||
current_step_logs["observation"] = updated_information | ||
return current_step_logs | ||
|
||
|
||
class ReactCodeAgent(ReactAgent): | ||
|
@@ -782,14 +826,15 @@ def step(self): | |
The errors are raised here, they are caught and logged in the run() method. | ||
""" | ||
agent_memory = self.write_inner_memory_from_logs() | ||
self.logs[-1]["agent_memory"] = agent_memory.copy() | ||
|
||
self.prompt = agent_memory.copy() | ||
|
||
self.logger.debug("===== New step =====") | ||
|
||
# Add new step in logs | ||
self.logs.append({}) | ||
current_step_logs = {} | ||
self.logs.append(current_step_logs) | ||
current_step_logs["agent_memory"] = agent_memory.copy() | ||
|
||
self.logger.info("===== Calling LLM with these last messages: =====") | ||
self.logger.info(self.prompt[-2:]) | ||
|
@@ -801,7 +846,7 @@ def step(self): | |
|
||
self.logger.debug("===== Output message of the LLM: =====") | ||
self.logger.debug(llm_output) | ||
self.logs[-1]["llm_output"] = llm_output | ||
current_step_logs["llm_output"] = llm_output | ||
|
||
# Parse | ||
self.logger.debug("===== Extracting action =====") | ||
|
@@ -813,8 +858,8 @@ def step(self): | |
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" | ||
raise AgentParsingError(error_msg) | ||
|
||
self.logs[-1]["rationale"] = rationale | ||
self.logs[-1]["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action} | ||
current_step_logs["rationale"] = rationale | ||
current_step_logs["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action} | ||
|
||
# Execute | ||
self.log_code_action(code_action) | ||
|
@@ -824,7 +869,7 @@ def step(self): | |
information = self.state["print_outputs"] | ||
self.logger.warning("Print outputs:") | ||
self.logger.log(32, information) | ||
self.logs[-1]["observation"] = information | ||
current_step_logs["observation"] = information | ||
except Exception as e: | ||
error_msg = f"Failed while trying to execute the code below:\n{CustomFormatter.reset + code_action + CustomFormatter.reset}\nThis failed due to the following error:\n{str(e)}" | ||
if "'dict' object has no attribute 'read'" in str(e): | ||
|
@@ -834,5 +879,5 @@ def step(self): | |
if line[: len("final_answer")] == "final_answer": | ||
self.logger.warning(">>> Final answer:") | ||
self.logger.log(32, result) | ||
return result | ||
return None | ||
current_step_logs["final_answer"] = result | ||
return current_step_logs |
Can you please also yield the max iterations error?