[Enhancement] Update system message (#272)
* revise system

* typing

* fix typing of functions

* add system unit test

* fix unit test

* update_system_message

* update

* fix_unit_test

* merge conflict

* fix run llm

* add unit test for retrieval agent

* fix file

* update

* update

* fix_lint

---------

Co-authored-by: Southpika <[email protected]>
shiyutang and Southpika authored Jan 10, 2024
1 parent 1c06a5d commit b55b450
Showing 10 changed files with 84 additions and 41 deletions.
2 changes: 1 addition & 1 deletion docs/modules/memory.md
@@ -22,7 +22,7 @@
* add_messages(self, messages): Add multiple messages in a batch.
* add_message(self, message): Add a single message.
* get_messages(self): Get all messages stored in Memory.
* get_system_message(self): Get the system message in Memory. Memory holds exactly one system message, which can be passed to the LLM's system interface to establish the LLM's persona.
* set_system_message(self): Set the system message in Memory. Memory holds exactly one system message, which is synchronized through the Agent's system interface to establish the LLM's persona. (A usage sketch follows this list.)
* clear_chat_history(self): Clear the history of all messages in memory.
* Relationship: serves as the base class for Memory and is associated with the MessageManager class.
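
A minimal usage sketch of the Memory interface described above; the `WholeMemory` import path is an assumption, while the method calls and message constructors follow the code shown elsewhere in this commit:

```python
from erniebot_agent.memory import WholeMemory  # import path assumed
from erniebot_agent.memory.messages import HumanMessage, SystemMessage

memory = WholeMemory()

# A Memory holds exactly one system message; the Agent's `system` argument
# is synchronized into it through set_system_message.
memory.set_system_message(SystemMessage("You are a helpful bot."))
memory.add_message(HumanMessage(content="Hello, world!"))

# After this commit, get_messages() prepends the system message (if set)
# to the retrieved chat history.
print(memory.get_messages())
```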

10 changes: 5 additions & 5 deletions erniebot-agent/applications/rpg_game/rpg_game_agent.py
@@ -25,11 +25,11 @@
from erniebot_agent.agents.schema import AgentResponse
from erniebot_agent.chat_models.erniebot import ERNIEBot
from erniebot_agent.file import File
from erniebot_agent.memory.messages import AIMessage, HumanMessage, SystemMessage
from erniebot_agent.memory.messages import AIMessage, HumanMessage
from erniebot_agent.memory.sliding_window_memory import SlidingWindowMemory
from erniebot_agent.tools.base import BaseTool
from erniebot_agent.tools.image_generation_tool import (
ImageGenerationTool, # currently a remote tool; can be swapped for yinian if direct display is needed
ImageGenerationTool, # currently a self-hosted remote tool; replace it directly once it is available on aistudio
)
from erniebot_agent.tools.tool_manager import ToolManager

@@ -72,7 +72,7 @@ def __init__(
model: str,
script: str,
tools: Union[ToolManager, List[BaseTool]],
system_message: Optional[SystemMessage] = None,
system: Optional[str] = None,
access_token: Union[str, None] = None,
max_round: int = 3,
) -> None:
@@ -83,7 +83,7 @@
llm=ERNIEBot(model, api_type="aistudio", access_token=access_token),
memory=memory,
tools=tools,
system_message=system_message,
system=system,
)

async def handle_tool(self, tool_name: str, tool_args: str) -> str:
@@ -191,7 +191,7 @@ async def _handle_gradio_stream(self, history) -> AsyncGenerator:
model=args.model,
script=args.game,
tools=[ImageGenerationTool()],
system_message=SystemMessage(INSTRUCTION.format(SCRIPT=args.game)),
system=INSTRUCTION.format(SCRIPT=args.game),
access_token=access_token,
)
game_system.launch_gradio_demo()
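
For agents built on this framework, the migration shown in the diff above is the pattern: pass the system prompt as a plain string via `system` instead of wrapping it in `SystemMessage` via `system_message`. An illustrative sketch; the model name and access token are placeholders:

```python
from erniebot_agent.agents.function_agent import FunctionAgent
from erniebot_agent.chat_models.erniebot import ERNIEBot

# Placeholder model name and token; substitute real values.
llm = ERNIEBot("ernie-3.5", api_type="aistudio", access_token="<your-access-token>")

# Before this commit:
#   agent = FunctionAgent(llm=llm, tools=[], system_message=SystemMessage("You are a helpful bot."))
# After this commit:
agent = FunctionAgent(llm=llm, tools=[], system="You are a helpful bot.")
```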
41 changes: 27 additions & 14 deletions erniebot-agent/src/erniebot_agent/agents/agent.py
@@ -1,5 +1,6 @@
import abc
import json
import logging
from typing import (
Any,
Dict,
@@ -36,6 +37,8 @@

_PLUGINS_WO_FILE_IO: Final[Tuple[str]] = ("eChart",)

_logger = logging.getLogger(__name__)


class Agent(GradioMixin, BaseAgent[BaseERNIEBot]):
"""The base class for agents.
@@ -58,7 +61,7 @@ def __init__(
tools: Union[ToolManager, Iterable[BaseTool]],
*,
memory: Optional[Memory] = None,
system_message: Optional[SystemMessage] = None,
system: Optional[str] = None,
callbacks: Optional[Union[CallbackManager, Iterable[CallbackHandler]]] = None,
file_manager: Optional[FileManager] = None,
plugins: Optional[List[str]] = None,
@@ -69,12 +72,11 @@
llm: An LLM for the agent to use.
tools: Tools for the agent to use.
memory: A memory object that equips the agent to remember chat
history. If `None`, a `WholeMemory` object will be used.
system_message: A message that tells the LLM how to interpret the
conversations. If `None`, the system message contained in
`memory` will be used.
callbacks: Callback handlers for the agent to use. If `None`, a
default list of callbacks will be used.
history. If not specified, a new WholeMemory object will be instantiated.
system: A message that tells the LLM how to interpret the
conversations.
callbacks: A list of callback handlers for the agent to use. If
`None`, a default list of callbacks will be used.
file_manager: A file manager for the agent to interact with files.
If `None`, a global file manager that can be shared among
different components will be implicitly created and used.
@@ -91,10 +93,11 @@ def __init__(
if memory is None:
memory = self._create_default_memory()
self.memory = memory
if system_message:
self.system_message = system_message
else:
self.system_message = self.memory.get_system_message()

self.system = SystemMessage(system) if system is not None else system
if self.system is not None:
self.memory.set_system_message(self.system)

if callbacks is None:
callbacks = get_default_callbacks()
if isinstance(callbacks, CallbackManager):
@@ -225,11 +228,21 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
raise NotImplementedError

async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
for reserved_opt in ("stream", "functions", "system", "plugins"):
for reserved_opt in ("stream", "system", "plugins"):
if reserved_opt in opts:
raise TypeError(f"`{reserved_opt}` should not be set.")
functions = self._tool_manager.get_tool_schemas()
opts["system"] = self.system_message.content if self.system_message is not None else None

if "functions" not in opts:
functions = self._tool_manager.get_tool_schemas()
else:
functions = opts.pop("functions")

if hasattr(self.llm, "system"):
_logger.warning(
"The `system` message has already been set in the agent; "
"the `system` message configured in ERNIEBot will become ineffective."
)
opts["system"] = self.system.content if self.system is not None else None
opts["plugins"] = self._plugins
llm_ret = await self.llm.chat(messages, stream=False, functions=functions, **opts)
return LLMResponse(message=llm_ret)
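
Since `functions` is no longer a reserved option, a subclass can override the tool schemas passed to the LLM on a per-call basis; for example, passing `functions=None` suppresses function calling, as the retrieval agent does further down in this commit. A toy subclass sketch, not part of this commit; the class name and its behavior are illustrative only:

```python
from typing import Optional, Sequence

from erniebot_agent.agents.agent import Agent
from erniebot_agent.agents.schema import AgentResponse
from erniebot_agent.file import File
from erniebot_agent.memory.messages import HumanMessage


class DirectAnswerAgent(Agent):
    """Illustrative subclass that never exposes tool schemas to the LLM."""

    async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
        # Overriding `functions` with None keeps tool schemas out of this call,
        # so the model answers directly instead of planning a tool call.
        llm_resp = await self._run_llm(messages=[HumanMessage(content=prompt)], functions=None)
        ...  # build and return an AgentResponse from llm_resp.message
```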
13 changes: 4 additions & 9 deletions erniebot-agent/src/erniebot_agent/agents/function_agent.py
@@ -29,12 +29,7 @@
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.file import File, FileManager
from erniebot_agent.memory import Memory
from erniebot_agent.memory.messages import (
FunctionMessage,
HumanMessage,
Message,
SystemMessage,
)
from erniebot_agent.memory.messages import FunctionMessage, HumanMessage, Message
from erniebot_agent.tools.base import BaseTool
from erniebot_agent.tools.tool_manager import ToolManager

@@ -67,7 +62,7 @@ def __init__(
tools: Union[ToolManager, Iterable[BaseTool]],
*,
memory: Optional[Memory] = None,
system_message: Optional[SystemMessage] = None,
system: Optional[str] = None,
callbacks: Optional[Union[CallbackManager, Iterable[CallbackHandler]]] = None,
file_manager: Optional[FileManager] = None,
plugins: Optional[List[str]] = None,
@@ -80,7 +75,7 @@
tools: A list of tools for the agent to use.
memory: A memory object that equips the agent to remember chat
history. If `None`, a `WholeMemory` object will be used.
system_message: A message that tells the LLM how to interpret the
system: A message that tells the LLM how to interpret the
conversations. If `None`, the system message contained in
`memory` will be used.
callbacks: A list of callback handlers for the agent to use. If
@@ -98,7 +93,7 @@
llm=llm,
tools=tools,
memory=memory,
system_message=system_message,
system=system,
callbacks=callbacks,
file_manager=file_manager,
plugins=plugins,
@@ -106,7 +106,8 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
try:
docs = self._enforce_token_limit(results)
step_input = HumanMessage(content=self.rag_prompt.format(query=prompt, documents=docs))
chat_history: List[Message] = [step_input]
chat_history: List[Message] = []
chat_history.append(step_input)
steps_taken: List[AgentStep] = []

tool_ret_json = json.dumps(results, ensure_ascii=False)
@@ -121,6 +122,7 @@
)
llm_resp = await self._run_llm(
messages=chat_history,
functions=None,
)
output_message = llm_resp.message
if output_message.search_info is None:
13 changes: 8 additions & 5 deletions erniebot-agent/src/erniebot_agent/memory/base.py
@@ -86,6 +86,10 @@ class Memory:
def __init__(self):
self.msg_manager = MessageManager()

def set_system_message(self, message: SystemMessage):
"""Set the system message of a conversation."""
self.msg_manager.system_message = message

def add_messages(self, messages: List[Message]):
"""Add a list of messages to memory."""
for message in messages:
@@ -99,11 +103,10 @@ def add_message(self, message: Message):

def get_messages(self) -> List[Message]:
"""Get all the messages in memory."""
return self.msg_manager.retrieve_messages()

def get_system_message(self) -> SystemMessage:
"""Get the system message in memory."""
return self.msg_manager.system_message
if self.msg_manager.system_message is not None:
return [self.msg_manager.system_message] + self.msg_manager.retrieve_messages()
else:
return self.msg_manager.retrieve_messages()

def clear_chat_history(self):
"""Reset the memory."""
12 changes: 12 additions & 0 deletions erniebot-agent/tests/unit_tests/agents/test_function_agent.py
@@ -8,6 +8,7 @@
from tests.unit_tests.testing_utils.components import CountingCallbackHandler
from tests.unit_tests.testing_utils.mocks.mock_chat_models import (
FakeERNIEBotWithPresetResponses,
FakeSimpleChatModel,
)
from tests.unit_tests.testing_utils.mocks.mock_memory import FakeMemory
from tests.unit_tests.testing_utils.mocks.mock_tool import FakeTool
@@ -219,3 +220,14 @@ async def test_function_agent_max_steps(identity_tool):
response = await agent.run("Run!")

assert response.status == "STOPPED"


@pytest.mark.asyncio
async def test_function_agent_system():
agent = FunctionAgent(
llm=FakeSimpleChatModel(),
tools=[],
system="You are a helpful bot.",
)
response = await agent.run("Run!")
assert "Received system message" in response.text
@@ -138,3 +138,24 @@ async def test_functional_agent_with_retrieval_run_retrieval(identity_tool):
assert response.chat_history[0].content == "Hello, world!"
# AIMessage
assert response.chat_history[1].content == "Text response"


@pytest.mark.asyncio
async def test_function_agent_with_retrieval_system():
knowledge_base_name = "test"
access_token = "your access token"
knowledge_base_id = 111
with mock.patch("requests.post") as my_mock:
search_db = BaizhongSearch(
knowledge_base_name=knowledge_base_name,
access_token=access_token,
knowledge_base_id=knowledge_base_id if knowledge_base_id != "" else None,
)
agent = FunctionAgentWithRetrieval(
llm=FakeSimpleChatModel(), tools=[], system="You are a helpful bot.", knowledge_base=search_db
)
with mock.patch("requests.post") as my_mock:
my_mock.return_value = MagicMock(status_code=200, json=lambda: EXAMPLE_RESPONSE)
response = await agent.run("Run!")

assert "Received system message" in response.text
@@ -14,6 +14,9 @@ def response(self):
async def chat(self, messages, *, stream=False, **kwargs):
if stream:
raise ValueError("Streaming is not supported.")
if "system" in kwargs and kwargs["system"] is not None:
response = f"Received system message: {kwargs['system']}"
return AIMessage(content=response, function_call=None, token_usage=None)
return self.response


@@ -1,6 +1,3 @@
from erniebot_agent.memory import SystemMessage


class FakeMemory(object):
def __init__(self):
super().__init__()
@@ -16,8 +13,5 @@ def add_message(self, message):
def get_messages(self):
return self._history[:]

def get_system_message(self):
return SystemMessage("System message")

def clear_chat_history(self):
self._history.clear()
