-
Notifications
You must be signed in to change notification settings - Fork 52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Agent] Enhance FileManager
and simplify the use of file managers
#193
Changes from all commits
260e9bc
8b23664
0a9e9e6
38cb8ff
f178300
4cd12a1
23287ad
4807109
bc8bbd8
b280333
f4f61e6
0b02a1a
82b63a5
1a5157b
d54a93b
7df8462
dc5024a
866ba6d
29be29b
b74ead7
961faea
9b73220
b19b42a
1f28164
d372395
5e3273d
c0858c9
1e6c17f
f84514d
87ae9b2
275c572
357a28c
04d6ee1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,31 @@ | ||
.DEFAULT_GOAL = format lint type_check | ||
.DEFAULT_GOAL = dev | ||
files_to_format_and_lint = src examples tests | ||
|
||
.PHONY: dev | ||
dev: format lint type-check | ||
|
||
.PHONY: format | ||
format: | ||
python -m black $(files_to_format_and_lint) | ||
python -m isort --filter-files $(files_to_format_and_lint) | ||
python -m isort $(files_to_format_and_lint) | ||
|
||
.PHONY: format-check | ||
format-check: | ||
python -m black --check --diff $(files_to_format_and_lint) | ||
python -m isort --check-only --diff $(files_to_format_and_lint) | ||
|
||
.PHONY: lint | ||
lint: | ||
python -m flake8 $(files_to_format_and_lint) | ||
|
||
.PHONY: type_check | ||
type_check: | ||
.PHONY: type-check | ||
type-check: | ||
python -m mypy src | ||
|
||
.PHONY: test | ||
test: | ||
python -m pytest tests/unit_tests | ||
|
||
.PHONY: test_coverage | ||
test_coverage: | ||
.PHONY: coverage | ||
coverage: | ||
python -m pytest tests/unit_tests --cov erniebot_agent --cov-report xml:coverage.xml |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
aiohttp | ||
anyio | ||
asyncio-atexit | ||
# erniebot | ||
jinja2 | ||
langchain | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,8 @@ | |
|
||
import abc | ||
import json | ||
from typing import Any, Dict, List, Literal, Optional, Union | ||
import logging | ||
from typing import Any, Dict, List, Literal, Optional, Union, final | ||
|
||
from erniebot_agent.agents.callback.callback_manager import CallbackManager | ||
from erniebot_agent.agents.callback.default import get_default_callbacks | ||
|
@@ -26,16 +27,17 @@ | |
ToolResponse, | ||
) | ||
from erniebot_agent.chat_models.base import ChatModel | ||
from erniebot_agent.file import get_file_manager | ||
from erniebot_agent.file import GlobalFileManagerHandler, protocol | ||
from erniebot_agent.file.base import File | ||
from erniebot_agent.file.file_manager import FileManager | ||
from erniebot_agent.file.protocol import is_local_file_id, is_remote_file_id | ||
from erniebot_agent.memory import Memory | ||
from erniebot_agent.memory.messages import Message, SystemMessage | ||
from erniebot_agent.tools.base import BaseTool | ||
from erniebot_agent.tools.tool_manager import ToolManager | ||
from erniebot_agent.utils.gradio_mixin import GradioMixin | ||
from erniebot_agent.utils.logging import logger | ||
from erniebot_agent.utils.exceptions import FileError | ||
from erniebot_agent.utils.mixins import GradioMixin | ||
|
||
logger = logging.getLogger(__name__) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 采用统一的logging风格,后同。 |
||
|
||
|
||
class BaseAgent(metaclass=abc.ABCMeta): | ||
|
@@ -75,25 +77,24 @@ def __init__( | |
self._callback_manager = callbacks | ||
else: | ||
self._callback_manager = CallbackManager(callbacks) | ||
if file_manager is None: | ||
file_manager = get_file_manager() | ||
self.plugins = plugins | ||
self._file_manager = file_manager | ||
self._plugins = plugins | ||
self._init_file_repr() | ||
|
||
def _init_file_repr(self): | ||
self.file_needs_url = False | ||
|
||
if self.plugins: | ||
if self._plugins: | ||
PLUGIN_WO_FILE = ["eChart"] | ||
for plugin in self.plugins: | ||
for plugin in self._plugins: | ||
if plugin not in PLUGIN_WO_FILE: | ||
self.file_needs_url = True | ||
|
||
@property | ||
def tools(self) -> List[BaseTool]: | ||
return self._tool_manager.get_tools() | ||
|
||
@final | ||
async def async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse: | ||
await self._callback_manager.on_run_start(agent=self, prompt=prompt) | ||
agent_resp = await self._async_run(prompt, files) | ||
|
@@ -113,6 +114,7 @@ def reset_memory(self) -> None: | |
async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse: | ||
raise NotImplementedError | ||
|
||
@final | ||
async def _async_run_tool(self, tool_name: str, tool_args: str) -> ToolResponse: | ||
tool = self._tool_manager.get_tool(tool_name) | ||
await self._callback_manager.on_tool_start(agent=self, tool=tool, input_args=tool_args) | ||
|
@@ -124,6 +126,7 @@ async def _async_run_tool(self, tool_name: str, tool_args: str) -> ToolResponse: | |
await self._callback_manager.on_tool_end(agent=self, tool=tool, response=tool_resp) | ||
return tool_resp | ||
|
||
@final | ||
async def _async_run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse: | ||
await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages) | ||
try: | ||
|
@@ -158,10 +161,10 @@ def _parse_tool_args(self, tool_args: str) -> Dict[str, Any]: | |
try: | ||
args_dict = json.loads(tool_args) | ||
except json.JSONDecodeError: | ||
raise ValueError(f"`tool_args` cannot be parsed as JSON. `tool_args` is {tool_args}") | ||
raise ValueError(f"`tool_args` cannot be parsed as JSON. `tool_args`: {tool_args}") | ||
|
||
if not isinstance(args_dict, dict): | ||
raise ValueError(f"`tool_args` cannot be interpreted as a dict. It loads as {args_dict} ") | ||
raise ValueError(f"`tool_args` cannot be interpreted as a dict. `tool_args`: {tool_args}") | ||
return args_dict | ||
|
||
async def _sniff_and_extract_files_from_args( | ||
|
@@ -170,30 +173,26 @@ async def _sniff_and_extract_files_from_args( | |
agent_files: List[AgentFile] = [] | ||
for val in args.values(): | ||
if isinstance(val, str): | ||
if is_local_file_id(val): | ||
if self._file_manager is None: | ||
logger.warning( | ||
f"A file is used by {repr(tool)}, but the agent has no file manager to fetch it." | ||
) | ||
continue | ||
file = self._file_manager.look_up_file_by_id(val) | ||
if file is None: | ||
raise RuntimeError(f"Unregistered ID {repr(val)} is used by {repr(tool)}.") | ||
elif is_remote_file_id(val): | ||
if self._file_manager is None: | ||
logger.warning( | ||
f"A file is used by {repr(tool)}, but the agent has no file manager to fetch it." | ||
) | ||
continue | ||
file = self._file_manager.look_up_file_by_id(val) | ||
if file is None: | ||
file = await self._file_manager.retrieve_remote_file_by_id(val) | ||
else: | ||
continue | ||
agent_files.append(AgentFile(file=file, type=file_type, used_by=tool.tool_name)) | ||
if protocol.is_file_id(val): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 重构后 |
||
file_manager = await self._get_file_manager() | ||
try: | ||
file = file_manager.look_up_file_by_id(val) | ||
except FileError as e: | ||
raise FileError( | ||
f"Unregistered file with ID {repr(val)} is used by {repr(tool)}." | ||
f" File type: {file_type}" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 突然发现在look_up_file_by_id中 如果file为None已经抛出FileError了 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
) from e | ||
agent_files.append(AgentFile(file=file, type=file_type, used_by=tool.tool_name)) | ||
elif isinstance(val, dict): | ||
agent_files.extend(await self._sniff_and_extract_files_from_args(val, tool, file_type)) | ||
elif isinstance(val, list) and len(val) > 0 and isinstance(val[0], dict): | ||
for item in val: | ||
agent_files.extend(await self._sniff_and_extract_files_from_args(item, tool, file_type)) | ||
return agent_files | ||
|
||
async def _get_file_manager(self) -> FileManager: | ||
if self._file_manager is None: | ||
file_manager = await GlobalFileManagerHandler().get() | ||
else: | ||
file_manager = self._file_manager | ||
return file_manager |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.DEFAULT_GOAL
不能是多目标