Skip to content

Commit

Permalink
Ensure concurrency safety
Browse files Browse the repository at this point in the history
  • Loading branch information
Bobholamovic committed Jan 10, 2024
1 parent 287462c commit 5b4bcd3
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 148 deletions.
27 changes: 8 additions & 19 deletions erniebot-agent/src/erniebot_agent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Final,
Iterable,
List,
NoReturn,
Optional,
Sequence,
Tuple,
Expand All @@ -32,7 +31,6 @@
from erniebot_agent.memory.messages import Message, SystemMessage
from erniebot_agent.tools.base import BaseTool
from erniebot_agent.tools.tool_manager import ToolManager
from erniebot_agent.utils.exceptions import FileError

_PLUGINS_WO_FILE_IO: Final[Tuple[str]] = ("eChart",)

Expand Down Expand Up @@ -110,6 +108,7 @@ def __init__(
if plugins is not None:
raise NotImplementedError("The use of plugins is not supported yet.")
self._init_file_needs_url()
self._is_running = False

@final
async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
Expand All @@ -123,8 +122,9 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen
Returns:
Response from the agent.
"""
if files:
await self._ensure_managed_files(files)
if self._is_running:
raise RuntimeError("The agent is already running.")
self._is_running = True
await self._callback_manager.on_run_start(agent=self, prompt=prompt)
try:
agent_resp = await self._run(prompt, files)
Expand All @@ -133,6 +133,8 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen
raise e
else:
await self._callback_manager.on_run_end(agent=self, response=agent_resp)
finally:
self._is_running = False
return agent_resp

@final
Expand Down Expand Up @@ -251,10 +253,10 @@ async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse:
# XXX: Sniffing is less efficient and probably unnecessary.
# Can we make a protocol to statically recognize file inputs and outputs
# or can we have the tools introspect about this?
input_files = file_manager.sniff_and_extract_files_from_list(list(parsed_tool_args.values()))
input_files = await file_manager.sniff_and_extract_files_from_obj(parsed_tool_args)
tool_ret = await tool(**parsed_tool_args)
if isinstance(tool_ret, dict):
output_files = file_manager.sniff_and_extract_files_from_list(list(tool_ret.values()))
output_files = await file_manager.sniff_and_extract_files_from_obj(tool_ret.values())
else:
output_files = []
tool_ret_json = json.dumps(tool_ret, ensure_ascii=False)
Expand All @@ -279,16 +281,3 @@ def _parse_tool_args(self, tool_args: str) -> Dict[str, Any]:
if not isinstance(args_dict, dict):
raise ValueError(f"`tool_args` cannot be interpreted as a dict. `tool_args`: {tool_args}")
return args_dict

async def _ensure_managed_files(self, files: Sequence[File]) -> None:
def _raise_exception(file: File) -> NoReturn:
raise FileError(f"{repr(file)} is not managed by the file manager of the agent.")

file_manager = self.get_file_manager()
for file in files:
try:
managed_file = file_manager.look_up_file_by_id(file.id)
except FileError:
_raise_exception(file)
if file is not managed_file:
_raise_exception(file)
6 changes: 3 additions & 3 deletions erniebot-agent/src/erniebot_agent/agents/function_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ async def _step(
PluginStep(
info=output_message.plugin_info,
result=output_message.content,
input_files=file_manager.sniff_and_extract_files_from_text(
chat_history[-1].content
input_files=await file_manager.sniff_and_extract_files_from_text(
input_messages[-1].content
), # TODO: make sure this is correct.
output_files=file_manager.sniff_and_extract_files_from_text(output_message.content),
output_files=[],
),
new_messages,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ async def _upload(file, history):
history = history + [((single_file.name,), None)]
size = len(file)

output_lis = file_manager.list_registered_files()
output_lis = await file_manager.list_files()
item = ""
for i in range(len(output_lis) - size):
item += f'<li>{str(output_lis[i]).strip("<>")}</li>'
Expand Down
2 changes: 1 addition & 1 deletion erniebot-agent/src/erniebot_agent/file/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
>>> file_manager = GlobalFileManagerHandler().get()
>>> local_file = await file_manager.create_file_from_path(file_path='your_path', file_type='local')
>>> file = file_manager.look_up_file_by_id(file_id='your_file_id')
>>> file = await file_manager.look_up_file_by_id(file_id='your_file_id')
>>> file_content = await file.read_contents() # get file content(bytes)
>>> await local_file.write_contents_to('your_willing_path') # save to location you want
"""
Expand Down
Loading

0 comments on commit 5b4bcd3

Please sign in to comment.