Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Add clarify #281

Merged
merged 11 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions erniebot-agent/src/erniebot_agent/agents/function_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
NO_ACTION_STEP,
AgentResponse,
AgentStep,
EndInfo,
NoActionStep,
NullResult,
PluginStep,
ToolInfo,
ToolStep,
Expand Down Expand Up @@ -151,7 +153,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
steps_taken.append(curr_step)

if isinstance(curr_step, (NoActionStep, PluginStep)): # plugin with action
response = self._create_finished_response(chat_history, steps_taken)
response = self._create_finished_response(chat_history, steps_taken, curr_step)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
Expand Down Expand Up @@ -204,20 +206,33 @@ async def _step(
new_messages,
)
else:
if output_message.clarify:
# `clarify` and [`function_call`, `plugin`(directly end)] will not appear at the same time
return NoActionStep(info=EndInfo(end_reason="CLARIFY"), result=NullResult()), new_messages
return NO_ACTION_STEP, new_messages

def _create_finished_response(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里其实不一定是创建finished response了,这个方法名字需要修改吗?

self,
chat_history: List[Message],
steps: List[AgentStep],
curr_step: Union[NoActionStep, PluginStep],
) -> AgentResponse:
last_message = chat_history[-1]
return AgentResponse(
text=last_message.content,
chat_history=chat_history,
steps=steps,
status="FINISHED",
)
if isinstance(curr_step, NoActionStep):
return AgentResponse(
text=last_message.content,
chat_history=chat_history,
steps=steps,
status=curr_step.info["end_reason"],
)
else:
# plugin end
return AgentResponse(
text=last_message.content,
chat_history=chat_history,
steps=steps,
status="FINISHED",
)

def _create_stopped_response(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from erniebot_agent.agents.function_agent import FunctionAgent
from erniebot_agent.agents.schema import (
NO_ACTION_STEP,
AgentResponse,
AgentStep,
File,
Expand Down Expand Up @@ -148,7 +149,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
await self._callback_manager.on_tool_error(agent=self, tool=self.search_tool, error=e)
raise
await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp)
response = self._create_finished_response(chat_history, steps_taken)
response = self._create_finished_response(chat_history, steps_taken, curr_step=NO_ACTION_STEP)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
Expand Down Expand Up @@ -252,7 +253,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
if not isinstance(curr_step, NoActionStep):
steps_taken.append(curr_step)
if isinstance(curr_step, (NoActionStep, PluginStep)): # plugin with action
response = self._create_finished_response(chat_history, steps_taken)
response = self._create_finished_response(chat_history, steps_taken, curr_step)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
Expand Down Expand Up @@ -354,7 +355,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
if not isinstance(curr_step, NoActionStep):
steps_taken.append(curr_step)
if isinstance(curr_step, (NoActionStep, PluginStep)): # plugin with action
response = self._create_finished_response(chat_history, steps_taken)
response = self._create_finished_response(chat_history, steps_taken, curr_step=curr_step)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
Expand Down
10 changes: 5 additions & 5 deletions erniebot-agent/src/erniebot_agent/agents/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,20 +114,20 @@ class PluginStep(AgentStepWithFiles[PluginInfo, str]):
"""A step taken by an agent that calls a plugin."""


class _NullInfo(Dict):
pass
class EndInfo(Dict):
    """Mapping attached to an agent step describing why the run ended.

    Callers read the reason via ``info["end_reason"]`` (e.g. ``"CLARIFY"``
    or ``"FINISHED"``).  NOTE(review): the ``end_reason`` annotation below
    is documentation only — on a dict subclass it is not enforced at
    runtime; a ``TypedDict`` would give static checking — confirm intent.
    """

    end_reason: str


class _NullResult(object):
class NullResult(object):
    """Sentinel result for agent steps that produce no real output."""


@dataclass
class NoActionStep(AgentStep[_NullInfo, _NullResult]):
class NoActionStep(AgentStep[EndInfo, NullResult]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我又思考了一下,有一些不同看法:现在NoActionStep可能已经不是简单的sentinel了。不过,NoActionStep在定位上其实并不一定与end强绑定(只是FunctionAgent正好遇到NoActionStep终止)。为了提升自由度,建议NoActionStep_IT设置为Dict[str, Any]_RT设置为Any,允许自由配置。在FunctionAgent中可以通过NoActionStep传递end_reason信息。
NO_ACTION_STEP仍然可以提供,建议直接NoActionStep(info={}, result=None),用作哨兵。

"""A step taken by an agent that performs no action and gives no result."""


NO_ACTION_STEP = NoActionStep(info=_NullInfo(), result=_NullResult())
# Module-level sentinel for "the agent took no action"; "FINISHED" marks a
# normal end of the run.
# NOTE(review): this is a single shared instance whose ``info`` dict is
# mutable — confirm no caller mutates it in place.
NO_ACTION_STEP = NoActionStep(info=EndInfo(end_reason="FINISHED"), result=NullResult())


@dataclass
Expand Down
22 changes: 17 additions & 5 deletions erniebot-agent/src/erniebot_agent/chat_models/erniebot.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
api_type: str = "aistudio",
access_token: Optional[str] = None,
enable_multi_step_tool_call: bool = False,
enable_human_clarify: bool = False,
**default_chat_kwargs: Any,
) -> None:
"""Initializes an instance of the `ERNIEBot` class.
Expand All @@ -114,6 +115,7 @@ def __init__(
If access_token is None, the global access_token will be used.
enable_multi_step_tool_call (bool): Whether to enable the multi-step tool call.
Defaults to False.
enable_human_clarify (bool): Whether to enable the human clarify. Defaults to False.
**default_chat_kwargs: Keyword arguments, such as `_config_`, `top_p`, `temperature`,
`penalty_score`, and `system`.
"""
Expand All @@ -125,9 +127,9 @@ def __init__(
self.access_token = access_token
self._maybe_validate_qianfan_auth()

self.enable_multi_step_json = json.dumps(
{"multi_step_tool_call_close": not enable_multi_step_tool_call}
)
self.extra_data = {}
self.extra_data["multi_step_tool_call_close"] = not enable_multi_step_tool_call
self.extra_data["chat_with_human_close"] = not enable_human_clarify

@overload
async def chat(
Expand Down Expand Up @@ -178,6 +180,7 @@ async def chat(
If `stream` is False, returns a single message.
If `stream` is True, returns an asynchronous iterator of message chunks.
"""

cfg_dict = self._generate_config(messages, functions=functions, **kwargs)

response = await self._generate_response(cfg_dict, stream, functions)
Expand Down Expand Up @@ -260,14 +263,14 @@ async def _generate_response(
_config_=cfg_dict["_config_"],
functions=functions, # type: ignore
extra_params={
"extra_data": self.enable_multi_step_json,
"extra_data": json.dumps(self.extra_data),
},
)
else:
response = await erniebot.ChatCompletion.acreate(
stream=stream,
extra_params={
"extra_data": self.enable_multi_step_json,
"extra_data": json.dumps(self.extra_data),
},
**cfg_dict,
)
Expand All @@ -276,6 +279,11 @@ async def _generate_response(


def convert_response_to_output(response: ChatCompletionResponse, output_type: Type[_T]) -> _T:
clarify = False
# ernie-turbo has no `finish_reason`
if hasattr(response, "finish_reason") and response["finish_reason"] == "plugin_clarify":
clarify = True

if hasattr(response, "function_call"):
function_call = FunctionCall(
name=response.function_call["name"],
Expand All @@ -287,6 +295,7 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
function_call=function_call,
plugin_info=None,
search_info=None,
clarify=clarify,
token_usage=response.usage,
)
elif hasattr(response, "plugin_info"):
Expand All @@ -303,6 +312,7 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
plugin_info=plugin_info,
search_info=None,
token_usage=response.usage,
clarify=clarify,
)
elif hasattr(response, "search_info") and len(response.search_info.items()) > 0:
search_info = SearchInfo(
Expand All @@ -314,6 +324,7 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
plugin_info=None,
search_info=search_info,
token_usage=response.usage,
clarify=clarify,
)
else:
return output_type(
Expand All @@ -322,4 +333,5 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
plugin_info=None,
search_info=None,
token_usage=response.usage,
clarify=clarify,
)
2 changes: 2 additions & 0 deletions erniebot-agent/src/erniebot_agent/memory/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def __init__(
token_usage: Optional[TokenUsage] = None,
plugin_info: Optional[PluginInfo] = None,
search_info: Optional[SearchInfo] = None,
clarify: Optional[bool] = False,
):
if token_usage is None:
prompt_tokens = 0
Expand All @@ -280,6 +281,7 @@ def __init__(
self.query_tokens_count = prompt_tokens
self.plugin_info = plugin_info
self.search_info = search_info
self.clarify = clarify
self._to_dict_keys = ["role", "content", "function_call", "plugin_info", "search_info"]

def _parse_token_count(self, token_usage: TokenUsage):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,48 @@ async def test_function_call(self):

content = res.content or None
self.assertIsNotNone(content)

@pytest.mark.asyncio
async def test_function_call_with_clarify(self):
    """End-to-end check of the clarify flow.

    An ambiguous query ("this place") should make the model ask for
    clarification (``res.clarify`` is truthy); after the user supplies the
    missing location, the model should emit a normal function call whose
    arguments contain that location.

    NOTE(review): hits the live ERNIE Bot service — requires network access
    and valid aistudio credentials.
    """
    import json  # local: only needed for the final argument check

    functions = [
        {
            "name": "get_current_weather",
            "description": "获得指定地点的天气",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "省,市名,例如:河北省"},
                    "unit": {"type": "string", "enum": ["摄氏度", "华氏度"]},
                },
                "required": ["location"],
            },
            "responses": {
                "type": "object",
                "properties": {
                    "temperature": {"type": "number", "description": "当前温度"},
                    "weather_condition": {"type": "string", "description": "当前天气状况,例如:晴,多云,雨等"},
                    "humidity": {"type": "number", "description": "当前湿度百分比"},
                    "wind_speed": {"type": "number", "description": "风速,单位为公里每小时或英里每小时"},
                },
                "required": ["temperature", "weather_condition"],
            },
        }
    ]
    eb = ERNIEBot(
        model="ernie-3.5",
        api_type="aistudio",
        enable_human_clarify=True,
        enable_multi_step_tool_call=True,
    )
    # The query omits the location, so the model should ask to clarify.
    messages = [
        HumanMessage(content="这个地方今天天气如何?"),
    ]
    res = await eb.chat(messages, functions=functions)
    self.assertTrue(isinstance(res, AIMessage))
    self.assertTrue(res.clarify)

    # Answer the clarification; the model should now emit a function call.
    messages.append(res)
    messages.append(HumanMessage(content="深圳"))
    res_2 = await eb.chat(messages, functions=functions)
    self.assertTrue(hasattr(res_2, "function_call"))
    # BUG FIX: the original used assertTrue(x, msg), whose second argument is
    # a failure *message*, so the check passed for any non-empty string.
    # Compare the parsed arguments instead, keyed on "location" so extra
    # optional keys (e.g. "unit") do not break the test.
    self.assertEqual(json.loads(res_2.function_call["arguments"])["location"], "深圳")
Loading