Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Agent] Enhance FileManager and simplify the use of file managers #193

Merged
merged 33 commits into from
Dec 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
260e9bc
Fix makefiles
Bobholamovic Dec 21, 2023
8b23664
Fix bugs
Bobholamovic Dec 21, 2023
0a9e9e6
Enhance file_io
Bobholamovic Dec 21, 2023
38cb8ff
Update library code
Bobholamovic Dec 21, 2023
f178300
Update examples
Bobholamovic Dec 21, 2023
4cd12a1
Update tests
Bobholamovic Dec 21, 2023
23287ad
Fix exceptions
Bobholamovic Dec 21, 2023
4807109
Fix error info
Bobholamovic Dec 21, 2023
bc8bbd8
Remove use of environment variables in integration tests
Bobholamovic Dec 21, 2023
b280333
Fix protocol
Bobholamovic Dec 21, 2023
f4f61e6
Fix type hints
Bobholamovic Dec 21, 2023
0b02a1a
Fix
Bobholamovic Dec 22, 2023
82b63a5
Fix typing
Bobholamovic Dec 22, 2023
1a5157b
Merge remote-tracking branch 'official' into agent/feat/add_environs
Bobholamovic Dec 22, 2023
d54a93b
Fix style
Bobholamovic Dec 22, 2023
7df8462
Fix cleanup bugs
Bobholamovic Dec 22, 2023
dc5024a
Fix data race
Bobholamovic Dec 22, 2023
866ba6d
Fix bugs
Bobholamovic Dec 22, 2023
29be29b
Show file type
Bobholamovic Dec 22, 2023
b74ead7
Fix bugs
Bobholamovic Dec 22, 2023
961faea
Merge branch 'develop' into agent/feat/add_environs
Bobholamovic Dec 22, 2023
9b73220
Fix linting issues
Bobholamovic Dec 22, 2023
b19b42a
Fix linting issues
Bobholamovic Dec 22, 2023
1f28164
Fix integration tests
Bobholamovic Dec 22, 2023
d372395
Remove unused file
Bobholamovic Dec 22, 2023
5e3273d
Fix and enhance
Bobholamovic Dec 24, 2023
c0858c9
Fix style
Bobholamovic Dec 24, 2023
1e6c17f
Fix CI
Bobholamovic Dec 24, 2023
f84514d
Update CI config
Bobholamovic Dec 24, 2023
87ae9b2
Merge remote-tracking branch 'official/develop' into agent/feat/add_e…
Bobholamovic Dec 24, 2023
275c572
Fix style
Bobholamovic Dec 24, 2023
357a28c
Remove anchors
Bobholamovic Dec 24, 2023
04d6ee1
Enhance mixins
Bobholamovic Dec 25, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions .github/workflows/agent_ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,24 @@ jobs:
cache-dependency-path: |
erniebot/setup.cfg
erniebot-agent/setup.cfg
erniebot-agent/dev-requirements.txt
erniebot-agent/*-requirements.txt
erniebot-agent/tests/requirements.txt
- name: Install erniebot-agent and dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r erniebot-agent/dev-requirements.txt
python -m pip install ./erniebot
python -m pip install ./erniebot-agent[all]
python -m pip install -r erniebot-agent/dev-requirements.txt
- name: Show make version
run: make --version
- name: Format Python code
run: make format
- name: Perform format checks on Python code
run: make format-check
working-directory: erniebot-agent
- name: Lint Python code
run: make lint
working-directory: erniebot-agent
- name: Type-check Python code
run: make type_check
- name: Perform type checks on Python code
run: make type-check
working-directory: erniebot-agent
UnitTest:
name: Unit Test
Expand All @@ -59,17 +60,18 @@ jobs:
cache-dependency-path: |
erniebot/setup.cfg
erniebot-agent/setup.cfg
erniebot-agent/*-requirements.txt
erniebot-agent/tests/requirements.txt
- name: Install erniebot-agent and dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r erniebot-agent/tests/requirements.txt
python -m pip install ./erniebot
python -m pip install ./erniebot-agent[all]
python -m pip install -r erniebot-agent/tests/requirements.txt
- name: Show make version
run: make --version
- name: Run unit tests
run: make test_coverage
run: make coverage
working-directory: erniebot-agent
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
Expand Down
14 changes: 7 additions & 7 deletions .github/workflows/client_ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@ jobs:
cache: pip # Caching pip dependencies
cache-dependency-path: |
setup.cfg
dev-requirements.txt
- name: Install dependencies
*-requirements.txt
- name: Install erniebot and dependencies
run: |
python -m pip install -r dev-requirements.txt
python -m pip install --upgrade pip
python -m pip install .
python -m pip install -r dev-requirements.txt
- name: Show make version
run: make --version
- name: Format Python code
run: make format
- name: Perform format checks on Python code
run: make format-check
- name: Lint Python code
run: make lint
- name: Type-check Python code
run: make type_check
- name: Perform type checks on Python code
run: make type-check
20 changes: 14 additions & 6 deletions erniebot-agent/Makefile
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
.DEFAULT_GOAL = format lint type_check
.DEFAULT_GOAL = dev
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.DEFAULT_GOAL不能是多目标

files_to_format_and_lint = src examples tests

.PHONY: dev
dev: format lint type-check

.PHONY: format
format:
python -m black $(files_to_format_and_lint)
python -m isort --filter-files $(files_to_format_and_lint)
python -m isort $(files_to_format_and_lint)

.PHONY: format-check
format-check:
python -m black --check --diff $(files_to_format_and_lint)
python -m isort --check-only --diff $(files_to_format_and_lint)

.PHONY: lint
lint:
python -m flake8 $(files_to_format_and_lint)

.PHONY: type_check
type_check:
.PHONY: type-check
type-check:
python -m mypy src

.PHONY: test
test:
python -m pytest tests/unit_tests

.PHONY: test_coverage
test_coverage:
.PHONY: coverage
coverage:
python -m pytest tests/unit_tests --cov erniebot_agent --cov-report xml:coverage.xml
15 changes: 8 additions & 7 deletions erniebot-agent/examples/cv_agent/CV_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from erniebot_agent.agents.functional_agent import FunctionalAgent
from erniebot_agent.chat_models.erniebot import ERNIEBot
from erniebot_agent.file_io import get_file_manager
from erniebot_agent.file import GlobalFileManagerHandler
from erniebot_agent.memory.whole_memory import WholeMemory
from erniebot_agent.tools import RemoteToolkit

Expand All @@ -14,14 +14,15 @@ def __init__(self):
self.tools = self.toolkit.get_tools()


llm = ERNIEBot(model="ernie-3.5", api_type="aistudio", access_token="<your-access-token>")
toolkit = CVToolkit()
memory = WholeMemory()
file_manager = get_file_manager()
agent = FunctionalAgent(llm=llm, tools=toolkit.tools, memory=memory, file_manager=file_manager)
async def run_agent():
await GlobalFileManagerHandler().configure(access_token="<your-access-token>")

llm = ERNIEBot(model="ernie-bot", api_type="aistudio", access_token="<your-access-token>")
toolkit = CVToolkit()
memory = WholeMemory()
agent = FunctionalAgent(llm=llm, tools=toolkit.tools, memory=memory)

async def run_agent():
file_manager = await GlobalFileManagerHandler().get()
seg_file = await file_manager.create_file_from_path(file_path="cityscapes_demo.png", file_type="local")
clas_file = await file_manager.create_file_from_path(file_path="class_img.jpg", file_type="local")
ocr_file = await file_manager.create_file_from_path(file_path="ch.png", file_type="local")
Expand Down
10 changes: 5 additions & 5 deletions erniebot-agent/examples/plugins/multiple_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from erniebot_agent.agents.callback.default import get_no_ellipsis_callback
from erniebot_agent.agents.functional_agent import FunctionalAgent
from erniebot_agent.chat_models.erniebot import ERNIEBot
from erniebot_agent.file_io import get_file_manager
from erniebot_agent.file import GlobalFileManagerHandler
from erniebot_agent.memory import AIMessage, HumanMessage, Message
from erniebot_agent.memory.sliding_window_memory import SlidingWindowMemory
from erniebot_agent.messages import AIMessage, HumanMessage, Message
from erniebot_agent.tools.base import Tool
from erniebot_agent.tools.calculator_tool import CalculatorTool
from erniebot_agent.tools.schema import ToolParameterView
Expand All @@ -32,7 +32,7 @@ async def __call__(self, input_file_id: str, repeat_times: int) -> Dict[str, Any
if "<split>" in input_file_id:
input_file_id = input_file_id.split("<split>")[0]

file_manager = get_file_manager() # Access_token needs to be set here.
file_manager = await GlobalFileManagerHandler().get()
input_file = file_manager.look_up_file_by_id(input_file_id)
if input_file is None:
raise RuntimeError("File not found")
Expand Down Expand Up @@ -109,20 +109,20 @@ def examples(self) -> List[Message]:
# TODO(shiyutang): replace this when model is online
llm = ERNIEBot(model="ernie-3.5", api_type="custom")
memory = SlidingWindowMemory(max_round=1)
file_manager = get_file_manager(access_token="") # Access_token needs to be set here.
# plugins = ["ChatFile", "eChart"]
plugins: List[str] = []
agent = FunctionalAgent(
llm=llm,
tools=[TextRepeaterTool(), TextRepeaterNoFileTool(), CalculatorTool()],
memory=memory,
file_manager=file_manager,
callbacks=get_no_ellipsis_callback(),
plugins=plugins,
)


async def run_agent():
file_manager = await GlobalFileManagerHandler().get()

docx_file = await file_manager.create_file_from_path(
file_path="浅谈牛奶的营养与消费趋势.docx",
file_type="remote",
Expand Down
5 changes: 1 addition & 4 deletions erniebot-agent/examples/rpg_game_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
from erniebot_agent.agents.base import Agent
from erniebot_agent.agents.schema import AgentFile, AgentResponse
from erniebot_agent.chat_models.erniebot import ERNIEBot
from erniebot_agent.file_io import get_file_manager
from erniebot_agent.file_io.base import File
from erniebot_agent.file_io.file_manager import FileManager
from erniebot_agent.file.base import File
from erniebot_agent.memory.sliding_window_memory import SlidingWindowMemory
from erniebot_agent.messages import AIMessage, HumanMessage, SystemMessage
from erniebot_agent.tools.base import BaseTool
Expand Down Expand Up @@ -87,7 +85,6 @@ def __init__(
tools=tools,
system_message=system_message,
)
self.file_manager: FileManager = get_file_manager()

async def handle_tool(self, tool_name: str, tool_args: str) -> str:
tool_response = await self._async_run_tool(
Expand Down
1 change: 1 addition & 0 deletions erniebot-agent/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
aiohttp
anyio
asyncio-atexit
# erniebot
jinja2
langchain
Expand Down
65 changes: 32 additions & 33 deletions erniebot-agent/src/erniebot_agent/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import abc
import json
from typing import Any, Dict, List, Literal, Optional, Union
import logging
from typing import Any, Dict, List, Literal, Optional, Union, final

from erniebot_agent.agents.callback.callback_manager import CallbackManager
from erniebot_agent.agents.callback.default import get_default_callbacks
Expand All @@ -26,16 +27,17 @@
ToolResponse,
)
from erniebot_agent.chat_models.base import ChatModel
from erniebot_agent.file import get_file_manager
from erniebot_agent.file import GlobalFileManagerHandler, protocol
from erniebot_agent.file.base import File
from erniebot_agent.file.file_manager import FileManager
from erniebot_agent.file.protocol import is_local_file_id, is_remote_file_id
from erniebot_agent.memory import Memory
from erniebot_agent.memory.messages import Message, SystemMessage
from erniebot_agent.tools.base import BaseTool
from erniebot_agent.tools.tool_manager import ToolManager
from erniebot_agent.utils.gradio_mixin import GradioMixin
from erniebot_agent.utils.logging import logger
from erniebot_agent.utils.exceptions import FileError
from erniebot_agent.utils.mixins import GradioMixin

logger = logging.getLogger(__name__)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

采用统一的logging风格,后同。



class BaseAgent(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -75,25 +77,24 @@ def __init__(
self._callback_manager = callbacks
else:
self._callback_manager = CallbackManager(callbacks)
if file_manager is None:
file_manager = get_file_manager()
self.plugins = plugins
self._file_manager = file_manager
self._plugins = plugins
self._init_file_repr()

def _init_file_repr(self):
self.file_needs_url = False

if self.plugins:
if self._plugins:
PLUGIN_WO_FILE = ["eChart"]
for plugin in self.plugins:
for plugin in self._plugins:
if plugin not in PLUGIN_WO_FILE:
self.file_needs_url = True

@property
def tools(self) -> List[BaseTool]:
return self._tool_manager.get_tools()

@final
async def async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse:
await self._callback_manager.on_run_start(agent=self, prompt=prompt)
agent_resp = await self._async_run(prompt, files)
Expand All @@ -113,6 +114,7 @@ def reset_memory(self) -> None:
async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse:
raise NotImplementedError

@final
async def _async_run_tool(self, tool_name: str, tool_args: str) -> ToolResponse:
tool = self._tool_manager.get_tool(tool_name)
await self._callback_manager.on_tool_start(agent=self, tool=tool, input_args=tool_args)
Expand All @@ -124,6 +126,7 @@ async def _async_run_tool(self, tool_name: str, tool_args: str) -> ToolResponse:
await self._callback_manager.on_tool_end(agent=self, tool=tool, response=tool_resp)
return tool_resp

@final
async def _async_run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages)
try:
Expand Down Expand Up @@ -158,10 +161,10 @@ def _parse_tool_args(self, tool_args: str) -> Dict[str, Any]:
try:
args_dict = json.loads(tool_args)
except json.JSONDecodeError:
raise ValueError(f"`tool_args` cannot be parsed as JSON. `tool_args` is {tool_args}")
raise ValueError(f"`tool_args` cannot be parsed as JSON. `tool_args`: {tool_args}")

if not isinstance(args_dict, dict):
raise ValueError(f"`tool_args` cannot be interpreted as a dict. It loads as {args_dict} ")
raise ValueError(f"`tool_args` cannot be interpreted as a dict. `tool_args`: {tool_args}")
return args_dict

async def _sniff_and_extract_files_from_args(
Expand All @@ -170,30 +173,26 @@ async def _sniff_and_extract_files_from_args(
agent_files: List[AgentFile] = []
for val in args.values():
if isinstance(val, str):
if is_local_file_id(val):
if self._file_manager is None:
logger.warning(
f"A file is used by {repr(tool)}, but the agent has no file manager to fetch it."
)
continue
file = self._file_manager.look_up_file_by_id(val)
if file is None:
raise RuntimeError(f"Unregistered ID {repr(val)} is used by {repr(tool)}.")
elif is_remote_file_id(val):
if self._file_manager is None:
logger.warning(
f"A file is used by {repr(tool)}, but the agent has no file manager to fetch it."
)
continue
file = self._file_manager.look_up_file_by_id(val)
if file is None:
file = await self._file_manager.retrieve_remote_file_by_id(val)
else:
continue
agent_files.append(AgentFile(file=file, type=file_type, used_by=tool.tool_name))
if protocol.is_file_id(val):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

重构后FileManager不再考虑多FileManager共享一个FileRegistry的情形,这大大简化了我们要处理的情况,例如这里只需要判断FileManager在自己的registry中持有的文件。

file_manager = await self._get_file_manager()
try:
file = file_manager.look_up_file_by_id(val)
except FileError as e:
raise FileError(
f"Unregistered file with ID {repr(val)} is used by {repr(tool)}."
f" File type: {file_type}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

突然发现在look_up_file_by_id中 如果file为None已经抛出FileError了

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

) from e
agent_files.append(AgentFile(file=file, type=file_type, used_by=tool.tool_name))
elif isinstance(val, dict):
agent_files.extend(await self._sniff_and_extract_files_from_args(val, tool, file_type))
elif isinstance(val, list) and len(val) > 0 and isinstance(val[0], dict):
for item in val:
agent_files.extend(await self._sniff_and_extract_files_from_args(item, tool, file_type))
return agent_files

async def _get_file_manager(self) -> FileManager:
if self._file_manager is None:
file_manager = await GlobalFileManagerHandler().get()
else:
file_manager = self._file_manager
return file_manager
Original file line number Diff line number Diff line change
Expand Up @@ -32,37 +32,31 @@
class CallbackManager(object):
def __init__(self, handlers: List[CallbackHandler]):
super().__init__()
self._handlers = handlers
self._handlers: List[CallbackHandler] = []
self.set_handlers(handlers)

@property
def handlers(self) -> List[CallbackHandler]:
return self._handlers

def add_handler(self, handler: CallbackHandler):
if handler in self._handlers:
raise RuntimeError(f"The callback handler {handler} is already registered.")
self._handlers.append(handler)

def remove_handler(self, handler):
try:
self._handlers.remove(handler)
except ValueError as e:
raise RuntimeError(f"The callback handler {handler} is not registered.") from e
self._handlers.remove(handler)

def set_handlers(self, handlers: List[CallbackHandler]):
self._handlers = []
for handler in handlers:
self.add_handler(handler)
self._handlers[:] = handlers

def remove_all_handlers(self):
self._handlers = []
self._handlers.clear()

async def handle_event(self, event_type: EventType, *args: Any, **kwargs: Any) -> None:
callback_name = "on_" + event_type.value
for handler in self._handlers:
callback = getattr(handler, callback_name, None)
if not inspect.iscoroutinefunction(callback):
raise TypeError("Callback must be a coroutine function.")
raise RuntimeError("Callback must be a coroutine function.")
await callback(*args, **kwargs)

async def on_run_start(self, agent: Agent, prompt: str) -> None:
Expand Down
Loading