Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement]Add clarify #281

Merged
merged 11 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions erniebot-agent/src/erniebot_agent/agents/function_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
from erniebot_agent.agents.callback.callback_manager import CallbackManager
from erniebot_agent.agents.callback.handlers.base import CallbackHandler
from erniebot_agent.agents.schema import (
NO_ACTION_STEP,
DEFAULT_FINISH_STEP,
AgentResponse,
AgentStep,
NoActionStep,
EndInfo,
EndStep,
PluginStep,
ToolInfo,
ToolStep,
Expand Down Expand Up @@ -136,7 +137,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age

for tool in self._first_tools:
curr_step, new_messages = await self._step(chat_history, selected_tool=tool)
if not isinstance(curr_step, NoActionStep):
if not isinstance(curr_step, EndStep):
chat_history.extend(new_messages)
num_steps_taken += 1
steps_taken.append(curr_step)
Expand All @@ -147,11 +148,18 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
while num_steps_taken < self.max_steps:
curr_step, new_messages = await self._step(chat_history)
chat_history.extend(new_messages)
if not isinstance(curr_step, NoActionStep):
if isinstance(curr_step, ToolStep):
steps_taken.append(curr_step)

if isinstance(curr_step, (NoActionStep, PluginStep)): # plugin with action
response = self._create_finished_response(chat_history, steps_taken)
elif isinstance(curr_step, PluginStep):
steps_taken.append(curr_step)
# 预留 调用了Plugin之后不结束的接口

# 此处为调用了Plugin之后直接结束的Plugin
curr_step = DEFAULT_FINISH_STEP

if isinstance(curr_step, EndStep):
response = self._create_finished_response(chat_history, steps_taken, curr_step)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
Expand Down Expand Up @@ -204,19 +212,23 @@ async def _step(
new_messages,
)
else:
return NO_ACTION_STEP, new_messages
if output_message.clarify:
# `clarify` and [`function_call`, `plugin`(directly end)] will not appear at the same time
return EndStep(info=EndInfo(end_reason="CLARIFY"), result=None), new_messages
return DEFAULT_FINISH_STEP, new_messages

def _create_finished_response(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里其实不一定是创建finished response了,这个方法名字需要修改吗?

self,
chat_history: List[Message],
steps: List[AgentStep],
curr_step: EndStep,
) -> AgentResponse:
last_message = chat_history[-1]
return AgentResponse(
text=last_message.content,
chat_history=chat_history,
steps=steps,
status="FINISHED",
status=curr_step.info["end_reason"],
)

def _create_stopped_response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

from erniebot_agent.agents.function_agent import FunctionAgent
from erniebot_agent.agents.schema import (
DEFAULT_FINISH_STEP,
AgentResponse,
AgentStep,
EndStep,
File,
NoActionStep,
PluginStep,
ToolAction,
ToolInfo,
Expand Down Expand Up @@ -148,7 +149,9 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
await self._callback_manager.on_tool_error(agent=self, tool=self.search_tool, error=e)
raise
await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp)
response = self._create_finished_response(chat_history, steps_taken)
response = self._create_finished_response(
chat_history, steps_taken, curr_step=DEFAULT_FINISH_STEP
)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
Expand Down Expand Up @@ -249,10 +252,18 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
while num_steps_taken < self.max_steps:
curr_step, new_messages = await self._step(chat_history)
chat_history.extend(new_messages)
if not isinstance(curr_step, NoActionStep):
if isinstance(curr_step, ToolStep):
steps_taken.append(curr_step)

elif isinstance(curr_step, PluginStep):
steps_taken.append(curr_step)
if isinstance(curr_step, (NoActionStep, PluginStep)): # plugin with action
response = self._create_finished_response(chat_history, steps_taken)
# 预留 调用了Plugin之后不结束的接口

# 此处为调用了Plugin之后直接结束的Plugin
curr_step = DEFAULT_FINISH_STEP

if isinstance(curr_step, EndStep):
response = self._create_finished_response(chat_history, steps_taken, curr_step)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
Expand Down Expand Up @@ -351,10 +362,18 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
while num_steps_taken < self.max_steps:
curr_step, new_messages = await self._step(chat_history)
chat_history.extend(new_messages)
if not isinstance(curr_step, NoActionStep):
if isinstance(curr_step, ToolStep):
steps_taken.append(curr_step)
if isinstance(curr_step, (NoActionStep, PluginStep)): # plugin with action
response = self._create_finished_response(chat_history, steps_taken)

elif isinstance(curr_step, PluginStep):
steps_taken.append(curr_step)
# 预留 调用了Plugin之后不结束的接口

# 此处为调用了Plugin之后直接结束的Plugin
curr_step = DEFAULT_FINISH_STEP

if isinstance(curr_step, EndStep): # plugin with action
response = self._create_finished_response(chat_history, steps_taken, curr_step=curr_step)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
Expand Down
13 changes: 13 additions & 0 deletions erniebot-agent/src/erniebot_agent/agents/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,19 @@ class NoActionStep(AgentStep[_NullInfo, _NullResult]):
NO_ACTION_STEP = NoActionStep(info=_NullInfo(), result=_NullResult())


class EndInfo(Dict):
end_reason: str
extra_info: str # json format


@dataclass
class EndStep(AgentStep[EndInfo, None]):
"""A step taken by an agent that ends whole run."""


DEFAULT_FINISH_STEP = EndStep(info=EndInfo(end_reason="FINISHED"), result=None)


@dataclass
class AgentResponse(object):
"""The final response from an agent."""
Expand Down
22 changes: 17 additions & 5 deletions erniebot-agent/src/erniebot_agent/chat_models/erniebot.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
api_type: str = "aistudio",
access_token: Optional[str] = None,
enable_multi_step_tool_call: bool = False,
enable_human_clarify: bool = False,
**default_chat_kwargs: Any,
) -> None:
"""Initializes an instance of the `ERNIEBot` class.
Expand All @@ -114,6 +115,7 @@ def __init__(
If access_token is None, the global access_token will be used.
enable_multi_step_tool_call (bool): Whether to enable the multi-step tool call.
Defaults to False.
enable_human_clarify (bool): Whether to enable the human clarify. Defaults to False.
**default_chat_kwargs: Keyword arguments, such as `_config_`, `top_p`, `temperature`,
`penalty_score`, and `system`.
"""
Expand All @@ -125,9 +127,9 @@ def __init__(
self.access_token = access_token
self._maybe_validate_qianfan_auth()

self.enable_multi_step_json = json.dumps(
{"multi_step_tool_call_close": not enable_multi_step_tool_call}
)
self.extra_data = {}
self.extra_data["multi_step_tool_call_close"] = not enable_multi_step_tool_call
self.extra_data["chat_with_human_close"] = not enable_human_clarify

@overload
async def chat(
Expand Down Expand Up @@ -178,6 +180,7 @@ async def chat(
If `stream` is False, returns a single message.
If `stream` is True, returns an asynchronous iterator of message chunks.
"""

cfg_dict = self._generate_config(messages, functions=functions, **kwargs)

response = await self._generate_response(cfg_dict, stream, functions)
Expand Down Expand Up @@ -260,14 +263,14 @@ async def _generate_response(
_config_=cfg_dict["_config_"],
functions=functions, # type: ignore
extra_params={
"extra_data": self.enable_multi_step_json,
"extra_data": json.dumps(self.extra_data),
},
)
else:
response = await erniebot.ChatCompletion.acreate(
stream=stream,
extra_params={
"extra_data": self.enable_multi_step_json,
"extra_data": json.dumps(self.extra_data),
},
**cfg_dict,
)
Expand All @@ -276,6 +279,11 @@ async def _generate_response(


def convert_response_to_output(response: ChatCompletionResponse, output_type: Type[_T]) -> _T:
clarify = False
# ernie-turbo has no `finish_reason`
if hasattr(response, "finish_reason") and response["finish_reason"] == "plugin_clarify":
clarify = True

if hasattr(response, "function_call"):
function_call = FunctionCall(
name=response.function_call["name"],
Expand All @@ -287,6 +295,7 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
function_call=function_call,
plugin_info=None,
search_info=None,
clarify=clarify,
token_usage=response.usage,
)
elif hasattr(response, "plugin_info"):
Expand All @@ -303,6 +312,7 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
plugin_info=plugin_info,
search_info=None,
token_usage=response.usage,
clarify=clarify,
)
elif hasattr(response, "search_info") and len(response.search_info.items()) > 0:
search_info = SearchInfo(
Expand All @@ -314,6 +324,7 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
plugin_info=None,
search_info=search_info,
token_usage=response.usage,
clarify=clarify,
)
else:
return output_type(
Expand All @@ -322,4 +333,5 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
plugin_info=None,
search_info=None,
token_usage=response.usage,
clarify=clarify,
)
2 changes: 2 additions & 0 deletions erniebot-agent/src/erniebot_agent/memory/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def __init__(
token_usage: Optional[TokenUsage] = None,
plugin_info: Optional[PluginInfo] = None,
search_info: Optional[SearchInfo] = None,
clarify: Optional[bool] = False,
):
if token_usage is None:
prompt_tokens = 0
Expand All @@ -280,6 +281,7 @@ def __init__(
self.query_tokens_count = prompt_tokens
self.plugin_info = plugin_info
self.search_info = search_info
self.clarify = clarify
self._to_dict_keys = ["role", "content", "function_call", "plugin_info", "search_info"]

def _parse_token_count(self, token_usage: TokenUsage):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,48 @@ async def test_function_call(self):

content = res.content or None
self.assertIsNotNone(content)

@pytest.mark.asyncio
async def test_function_call_with_clarify(self):
functions = [
{
"name": "get_current_weather",
"description": "获得指定地点的天气",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string", "description": "省,市名,例如:河北省"},
"unit": {"type": "string", "enum": ["摄氏度", "华氏度"]},
},
"required": ["location"],
},
"responses": {
"type": "object",
"properties": {
"temperature": {"type": "number", "description": "当前温度"},
"weather_condition": {"type": "string", "description": "当前天气状况,例如:晴,多云,雨等"},
"humidity": {"type": "number", "description": "当前湿度百分比"},
"wind_speed": {"type": "number", "description": "风速,单位为公里每小时或英里每小时"},
},
"required": ["temperature", "weather_condition"],
},
}
]
eb = ERNIEBot(
model="ernie-3.5",
api_type="aistudio",
enable_human_clarify=True,
enable_multi_step_tool_call=True,
)
messages = [
HumanMessage(content="这个地方今天天气如何?"),
]
res = await eb.chat(messages, functions=functions)
self.assertTrue(isinstance(res, AIMessage))
self.assertTrue(res.clarify)

messages.append(res)
messages.append(HumanMessage(content="深圳"))
res_2 = await eb.chat(messages, functions=functions)
self.assertTrue(hasattr(res_2, "function_call"))
self.assertTrue(res_2.function_call["arguments"], '{"location":"深圳"}')
Loading