[Enhancement] Add clarify (#281)
* add clarify

* add extra paras

* rm breakpoint

* renew type

* add clarify in step and response

* rm pdb

* add ernie-turbo

* detailed step info

* add unit test of clarify

* add end step

* support retrieval function agent
Southpika authored Jan 11, 2024
1 parent d02cc4b commit 68c305b
Showing 8 changed files with 160 additions and 34 deletions.
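
In short, this change lets an ERNIEBot-backed agent pause and ask the user for missing information instead of finishing outright. A minimal usage sketch, assuming the package's existing `FunctionAgent`, `WholeMemory`, and `Agent.run` APIs (only the `enable_human_clarify` flag and the `CLARIFY` status come from this commit):

import asyncio

from erniebot_agent.agents.function_agent import FunctionAgent
from erniebot_agent.chat_models.erniebot import ERNIEBot
from erniebot_agent.memory import WholeMemory  # assumed import path

async def main():
    # New flag from this commit: let the model ask the user to clarify.
    llm = ERNIEBot(model="ernie-3.5", api_type="aistudio", enable_human_clarify=True)
    agent = FunctionAgent(llm=llm, tools=[], memory=WholeMemory())

    response = await agent.run("这个地方今天天气如何?")
    if response.status == "CLARIFY":
        # The run ended early because the model needs more information;
        # answer its follow-up question and run again.
        response = await agent.run("深圳")
    print(response.text)

asyncio.run(main())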
28 changes: 20 additions & 8 deletions erniebot-agent/src/erniebot_agent/agents/function_agent.py
@@ -19,10 +19,11 @@
from erniebot_agent.agents.callback.callback_manager import CallbackManager
from erniebot_agent.agents.callback.handlers.base import CallbackHandler
from erniebot_agent.agents.schema import (
NO_ACTION_STEP,
DEFAULT_FINISH_STEP,
AgentResponse,
AgentStep,
NoActionStep,
EndInfo,
EndStep,
PluginStep,
ToolInfo,
ToolStep,
@@ -136,7 +137,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age

for tool in self._first_tools:
curr_step, new_messages = await self._step(chat_history, selected_tool=tool)
if not isinstance(curr_step, NoActionStep):
if not isinstance(curr_step, EndStep):
chat_history.extend(new_messages)
num_steps_taken += 1
steps_taken.append(curr_step)
@@ -147,11 +148,18 @@
while num_steps_taken < self.max_steps:
curr_step, new_messages = await self._step(chat_history)
chat_history.extend(new_messages)
if not isinstance(curr_step, NoActionStep):
if isinstance(curr_step, ToolStep):
steps_taken.append(curr_step)

if isinstance(curr_step, (NoActionStep, PluginStep)): # plugin with action
response = self._create_finished_response(chat_history, steps_taken)
elif isinstance(curr_step, PluginStep):
steps_taken.append(curr_step)
# Reserved: hook for plugins that do not end the run after being called

# This branch handles plugins that end the run immediately after being called
curr_step = DEFAULT_FINISH_STEP

if isinstance(curr_step, EndStep):
response = self._create_finished_response(chat_history, steps_taken, curr_step)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
@@ -204,19 +212,23 @@ async def _step(
new_messages,
)
else:
return NO_ACTION_STEP, new_messages
if output_message.clarify:
# `clarify` never co-occurs with `function_call` or a directly-ending `plugin`
return EndStep(info=EndInfo(end_reason="CLARIFY"), result=None), new_messages
return DEFAULT_FINISH_STEP, new_messages

def _create_finished_response(
self,
chat_history: List[Message],
steps: List[AgentStep],
curr_step: EndStep,
) -> AgentResponse:
last_message = chat_history[-1]
return AgentResponse(
text=last_message.content,
chat_history=chat_history,
steps=steps,
status="FINISHED",
status=curr_step.info["end_reason"],
)

def _create_stopped_response(
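Pulling the hunks above together, the `_run` loop now dispatches on the step type roughly as follows (a condensed sketch of the control flow only; callbacks, the `_first_tools` pass, memory updates, and the stopped-response path are omitted):

while num_steps_taken < self.max_steps:
    curr_step, new_messages = await self._step(chat_history)
    chat_history.extend(new_messages)

    if isinstance(curr_step, ToolStep):
        # A tool call: record it and keep iterating.
        steps_taken.append(curr_step)
    elif isinstance(curr_step, PluginStep):
        # Plugins still end the run for now; a non-terminating plugin path
        # is reserved for later.
        steps_taken.append(curr_step)
        curr_step = DEFAULT_FINISH_STEP

    if isinstance(curr_step, EndStep):
        # The end reason ("FINISHED" or "CLARIFY") becomes the response status.
        return self._create_finished_response(chat_history, steps_taken, curr_step)

    num_steps_taken += 1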
@@ -6,10 +6,11 @@

from erniebot_agent.agents.function_agent import FunctionAgent
from erniebot_agent.agents.schema import (
DEFAULT_FINISH_STEP,
AgentResponse,
AgentStep,
EndStep,
File,
NoActionStep,
PluginStep,
ToolAction,
ToolInfo,
@@ -148,7 +149,9 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
await self._callback_manager.on_tool_error(agent=self, tool=self.search_tool, error=e)
raise
await self._callback_manager.on_tool_end(agent=self, tool=self.search_tool, response=tool_resp)
response = self._create_finished_response(chat_history, steps_taken)
response = self._create_finished_response(
chat_history, steps_taken, curr_step=DEFAULT_FINISH_STEP
)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
@@ -249,10 +252,18 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
while num_steps_taken < self.max_steps:
curr_step, new_messages = await self._step(chat_history)
chat_history.extend(new_messages)
if not isinstance(curr_step, NoActionStep):
if isinstance(curr_step, ToolStep):
steps_taken.append(curr_step)

elif isinstance(curr_step, PluginStep):
steps_taken.append(curr_step)
if isinstance(curr_step, (NoActionStep, PluginStep)): # plugin with action
response = self._create_finished_response(chat_history, steps_taken)
# Reserved: hook for plugins that do not end the run after being called

# This branch handles plugins that end the run immediately after being called
curr_step = DEFAULT_FINISH_STEP

if isinstance(curr_step, EndStep):
response = self._create_finished_response(chat_history, steps_taken, curr_step)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
@@ -351,10 +362,18 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
while num_steps_taken < self.max_steps:
curr_step, new_messages = await self._step(chat_history)
chat_history.extend(new_messages)
if not isinstance(curr_step, NoActionStep):
if isinstance(curr_step, ToolStep):
steps_taken.append(curr_step)
if isinstance(curr_step, (NoActionStep, PluginStep)): # plugin with action
response = self._create_finished_response(chat_history, steps_taken)

elif isinstance(curr_step, PluginStep):
steps_taken.append(curr_step)
# Reserved: hook for plugins that do not end the run after being called

# This branch handles plugins that end the run immediately after being called
curr_step = DEFAULT_FINISH_STEP

if isinstance(curr_step, EndStep):
response = self._create_finished_response(chat_history, steps_taken, curr_step=curr_step)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
return response
13 changes: 13 additions & 0 deletions erniebot-agent/src/erniebot_agent/agents/schema.py
@@ -130,6 +130,19 @@ class NoActionStep(AgentStep[_NullInfo, _NullResult]):
NO_ACTION_STEP = NoActionStep(info=_NullInfo(), result=_NullResult())


class EndInfo(Dict):
end_reason: str
extra_info: str # json format


@dataclass
class EndStep(AgentStep[EndInfo, None]):
"""A step taken by an agent that ends whole run."""


DEFAULT_FINISH_STEP = EndStep(info=EndInfo(end_reason="FINISHED"), result=None)


@dataclass
class AgentResponse(object):
"""The final response from an agent."""
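The new schema pieces compose directly; a small sketch using only the names added in this hunk:

from erniebot_agent.agents.schema import DEFAULT_FINISH_STEP, EndInfo, EndStep

# The shared "normal" termination step:
assert DEFAULT_FINISH_STEP.info["end_reason"] == "FINISHED"

# A clarify termination, as built in FunctionAgent._step when the model
# asks the user for more information:
clarify_step = EndStep(info=EndInfo(end_reason="CLARIFY"), result=None)

# _create_finished_response copies the reason into AgentResponse.status:
status = clarify_step.info["end_reason"]  # "CLARIFY"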
22 changes: 17 additions & 5 deletions erniebot-agent/src/erniebot_agent/chat_models/erniebot.py
@@ -101,6 +101,7 @@ def __init__(
api_type: str = "aistudio",
access_token: Optional[str] = None,
enable_multi_step_tool_call: bool = False,
enable_human_clarify: bool = False,
**default_chat_kwargs: Any,
) -> None:
"""Initializes an instance of the `ERNIEBot` class.
@@ -114,6 +115,7 @@ def __init__(
If access_token is None, the global access_token will be used.
enable_multi_step_tool_call (bool): Whether to enable the multi-step tool call.
Defaults to False.
enable_human_clarify (bool): Whether to enable human clarification. Defaults to False.
**default_chat_kwargs: Keyword arguments, such as `_config_`, `top_p`, `temperature`,
`penalty_score`, and `system`.
"""
@@ -125,9 +127,9 @@ def __init__(
self.access_token = access_token
self._maybe_validate_qianfan_auth()

self.enable_multi_step_json = json.dumps(
{"multi_step_tool_call_close": not enable_multi_step_tool_call}
)
self.extra_data = {}
self.extra_data["multi_step_tool_call_close"] = not enable_multi_step_tool_call
self.extra_data["chat_with_human_close"] = not enable_human_clarify

@overload
async def chat(
@@ -178,6 +180,7 @@ async def chat(
If `stream` is False, returns a single message.
If `stream` is True, returns an asynchronous iterator of message chunks.
"""

cfg_dict = self._generate_config(messages, functions=functions, **kwargs)

response = await self._generate_response(cfg_dict, stream, functions)
@@ -260,14 +263,14 @@ async def _generate_response(
_config_=cfg_dict["_config_"],
functions=functions, # type: ignore
extra_params={
"extra_data": self.enable_multi_step_json,
"extra_data": json.dumps(self.extra_data),
},
)
else:
response = await erniebot.ChatCompletion.acreate(
stream=stream,
extra_params={
"extra_data": self.enable_multi_step_json,
"extra_data": json.dumps(self.extra_data),
},
**cfg_dict,
)
@@ -276,6 +279,11 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Type[_T]) -> _T:


def convert_response_to_output(response: ChatCompletionResponse, output_type: Type[_T]) -> _T:
clarify = False
# ernie-turbo has no `finish_reason`
if hasattr(response, "finish_reason") and response["finish_reason"] == "plugin_clarify":
clarify = True

if hasattr(response, "function_call"):
function_call = FunctionCall(
name=response.function_call["name"],
@@ -287,6 +295,7 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
function_call=function_call,
plugin_info=None,
search_info=None,
clarify=clarify,
token_usage=response.usage,
)
elif hasattr(response, "plugin_info"):
@@ -303,6 +312,7 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
plugin_info=plugin_info,
search_info=None,
token_usage=response.usage,
clarify=clarify,
)
elif hasattr(response, "search_info") and len(response.search_info.items()) > 0:
search_info = SearchInfo(
@@ -314,6 +324,7 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
plugin_info=None,
search_info=search_info,
token_usage=response.usage,
clarify=clarify,
)
else:
return output_type(
@@ -322,4 +333,5 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
plugin_info=None,
search_info=None,
token_usage=response.usage,
clarify=clarify,
)
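
On the request side, the two feature toggles now travel together in a single JSON-encoded `extra_data` payload; on the response side, a clarify turn is inferred from `finish_reason`. A small illustration of both, with values shown for the case where both features are enabled (the `is_clarify` helper is illustrative, not part of the change):

import json

# Payload sent by ERNIEBot when enable_multi_step_tool_call=True and
# enable_human_clarify=True (the "close" flags are the negations):
extra_data = {"multi_step_tool_call_close": False, "chat_with_human_close": False}
extra_params = {"extra_data": json.dumps(extra_data)}

# Detecting a clarify turn; ernie-turbo responses omit `finish_reason`,
# so the lookup must tolerate a missing field:
def is_clarify(response) -> bool:
    return getattr(response, "finish_reason", None) == "plugin_clarify"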
2 changes: 2 additions & 0 deletions erniebot-agent/src/erniebot_agent/memory/messages.py
@@ -269,6 +269,7 @@ def __init__(
token_usage: Optional[TokenUsage] = None,
plugin_info: Optional[PluginInfo] = None,
search_info: Optional[SearchInfo] = None,
clarify: Optional[bool] = False,
):
if token_usage is None:
prompt_tokens = 0
@@ -280,6 +281,7 @@
self.query_tokens_count = prompt_tokens
self.plugin_info = plugin_info
self.search_info = search_info
self.clarify = clarify
self._to_dict_keys = ["role", "content", "function_call", "plugin_info", "search_info"]

def _parse_token_count(self, token_usage: TokenUsage):
@@ -83,3 +83,48 @@ async def test_function_call(self):

content = res.content or None
self.assertIsNotNone(content)

@pytest.mark.asyncio
async def test_function_call_with_clarify(self):
functions = [
{
"name": "get_current_weather",
"description": "获得指定地点的天气",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string", "description": "省,市名,例如:河北省"},
"unit": {"type": "string", "enum": ["摄氏度", "华氏度"]},
},
"required": ["location"],
},
"responses": {
"type": "object",
"properties": {
"temperature": {"type": "number", "description": "当前温度"},
"weather_condition": {"type": "string", "description": "当前天气状况,例如:晴,多云,雨等"},
"humidity": {"type": "number", "description": "当前湿度百分比"},
"wind_speed": {"type": "number", "description": "风速,单位为公里每小时或英里每小时"},
},
"required": ["temperature", "weather_condition"],
},
}
]
eb = ERNIEBot(
model="ernie-3.5",
api_type="aistudio",
enable_human_clarify=True,
enable_multi_step_tool_call=True,
)
messages = [
HumanMessage(content="这个地方今天天气如何?"),
]
res = await eb.chat(messages, functions=functions)
self.assertTrue(isinstance(res, AIMessage))
self.assertTrue(res.clarify)

messages.append(res)
messages.append(HumanMessage(content="深圳"))
res_2 = await eb.chat(messages, functions=functions)
self.assertTrue(hasattr(res_2, "function_call"))
self.assertEqual(res_2.function_call["arguments"], '{"location":"深圳"}')