Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] support more file/array schema in tools #292

Merged
merged 6 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions erniebot-agent/src/erniebot_agent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,10 @@ async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse:
# XXX: Sniffing is less efficient and probably unnecessary.
# Can we make a protocol to statically recognize file inputs and outputs
# or can we have the tools introspect about this?
input_files = file_manager.sniff_and_extract_files_from_list(list(parsed_tool_args.values()))
input_files = file_manager.sniff_and_extract_files_from_dict(parsed_tool_args)
tool_ret = await tool(**parsed_tool_args)
if isinstance(tool_ret, dict):
output_files = file_manager.sniff_and_extract_files_from_list(list(tool_ret.values()))
output_files = file_manager.sniff_and_extract_files_from_dict(tool_ret)
else:
output_files = []
tool_ret_json = json.dumps(tool_ret, ensure_ascii=False)
Expand Down
26 changes: 26 additions & 0 deletions erniebot-agent/src/erniebot_agent/file/file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,32 @@ def sniff_and_extract_files_from_list(self, list_: List[Any]) -> List[File]:
files.append(file)
return files

def sniff_and_extract_files_from_dict(self, dict_data: Dict[str, Any]) -> List[File]:
    """Collect the files referenced by file IDs anywhere inside ``dict_data``.

    Strings accepted by ``protocol.is_file_id`` are resolved through
    ``self.look_up_file_by_id``. Lists and dicts are searched recursively at
    any nesting depth (including lists nested directly inside lists); values
    of every other type are ignored.

    Args:
        dict_data: Mapping to search, e.g. parsed tool arguments or the dict
            returned by a tool.

    Returns:
        The looked-up files in the order their IDs are encountered. An ID
        that appears several times yields one entry per occurrence.

    Raises:
        FileError: If a string is recognized as a file ID but
            ``look_up_file_by_id`` raises ``FileError`` for it.
    """
    files: List[File] = []

    def collect(value: Any) -> None:
        # Single recursive walker so that every container shape (dict in
        # list, list in list, list in dict, ...) is handled the same way.
        if isinstance(value, str):
            if not protocol.is_file_id(value):
                return
            try:
                file = self.look_up_file_by_id(value)
            except FileError as e:
                raise FileError(f"An unregistered file with ID {repr(value)} was found.") from e
            files.append(file)
        elif isinstance(value, list):
            for item in value:
                collect(item)
        elif isinstance(value, dict):
            # Only values can carry file IDs here; keys are never inspected.
            for item in value.values():
                collect(item)

    collect(dict_data)
    return files

def sniff_and_extract_files_from_text(self, text: str) -> List[File]:
file_ids = protocol.extract_file_ids(text)
files: List[File] = []
Expand Down
109 changes: 24 additions & 85 deletions erniebot-agent/src/erniebot_agent/tools/remote_tool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import base64
import dataclasses
import logging
from copy import deepcopy
Expand All @@ -18,11 +17,10 @@
from erniebot_agent.tools.schema import RemoteToolView
from erniebot_agent.tools.utils import (
get_file_info_from_param_view,
parse_file_from_json_response,
parse_file_from_response,
parse_json_request,
parse_response,
tool_response_contains_file,
)
from erniebot_agent.utils.common import is_json_response
from erniebot_agent.utils.exceptions import RemoteToolError

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -71,51 +69,12 @@ def tool_name(self):
return self.tool_view.name

async def __pre_process__(self, tool_arguments: Dict[str, Any]) -> dict:
async def fileid_to_byte(file_id, file_manager):
file = file_manager.look_up_file_by_id(file_id)
byte_str = await file.read_contents()
return byte_str

async def convert_to_file_data(file_data: str, format: str):
value = file_data.replace("<file>", "").replace("</file>", "")
byte_value = await fileid_to_byte(value, file_manager)
if format == "byte":
byte_value = base64.b64encode(byte_value).decode()
return byte_value

file_manager = self._get_file_manager()

# 1. replace fileid with byte string
parameter_file_info = get_file_info_from_param_view(self.tool_view.parameters)
for key in tool_arguments.keys():
if self.tool_view.parameters:
if key not in self.tool_view.parameters.model_fields:
keys = list(self.tool_view.parameters.model_fields.keys())
raise RemoteToolError(
f"`{self.tool_name}` received unexpected arguments `{key}`. "
f"The avaiable arguments are {keys}",
stage="Input parsing",
)
if key not in parameter_file_info:
continue
if self.tool_view.parameters is None:
break

argument_value = tool_arguments[key]
if isinstance(argument_value, list):
for index in range(len(argument_value)):
argument_value[index] = await convert_to_file_data(
argument_value[index], parameter_file_info[key]["format"]
)
else:
argument_value = await convert_to_file_data(
argument_value, parameter_file_info[key]["format"]
)

tool_arguments[key] = argument_value

# 2. call tool get response
if self.tool_view.parameters is not None:
tool_arguments = await parse_json_request(
self.tool_view.parameters, tool_arguments, file_manager
)
tool_arguments = self.tool_view.parameters(**tool_arguments).model_dump(mode="json")

return tool_arguments
Expand All @@ -132,9 +91,17 @@ async def __post_process__(self, tool_response: dict) -> dict:
"请务必确保每个符合'file-'格式的字段只出现一次,无需将其转换为链接,也无需添加任何HTML、Markdown或其他格式化元素。"
)

# TODO(wj-Mcat): enable tool-response validation with the pydantic model
# if self.tool_view.returns is not None:
# tool_response = dict(self.tool_view.returns(**tool_response))
if self.tool_view.returns is not None:
try:
origin_tool_response = deepcopy(tool_response)
valid_tool_response = self.tool_view.returns(**origin_tool_response).model_dump(mode="json")
tool_response.update(valid_tool_response)
except Exception as e:
_logger.warning(
"Unable to validate the 'tool_response' against the schema defined in the YAML file. "
f"The specific error encountered is: '<{e}>'. "
"As a result, the original response from the tool will be used.",
)
return tool_response

async def __call__(self, **tool_arguments: Dict[str, Any]) -> Any:
Expand All @@ -143,7 +110,8 @@ async def __call__(self, **tool_arguments: Dict[str, Any]) -> Any:
return await self.__post_process__(tool_response)

async def send_request(self, tool_arguments: Dict[str, Any]) -> dict:
url = self.server_url + self.tool_view.uri + "?version=" + self.version
url = "/".join([self.server_url.strip("/"), self.tool_view.uri.strip("/")])
url += "?version=" + self.version

headers = deepcopy(self.headers)
headers["Content-Type"] = self.tool_view.parameters_content_type
Expand Down Expand Up @@ -171,6 +139,7 @@ async def send_request(self, tool_arguments: Dict[str, Any]) -> dict:
raise RemoteToolError(
f"Unsupported content type: {self.tool_view.parameters_content_type}", stage="Executing"
)

if self.tool_view.method == "get":
response = requests.get(url, **requests_inputs) # type: ignore
elif self.tool_view.method == "post":
Expand All @@ -190,44 +159,14 @@ async def send_request(self, tool_arguments: Dict[str, Any]) -> dict:
stage="Executing",
)

# parse the file from response
returns_file_infos = get_file_info_from_param_view(self.tool_view.returns)

if len(returns_file_infos) == 0 and is_json_response(response):
return response.json()

file_manager = self._get_file_manager()

file_metadata = {"tool_name": self.tool_name}
if is_json_response(response) and len(returns_file_infos) > 0:
response_json = response.json()
file_info = await parse_file_from_json_response(
response_json,
file_manager=file_manager,
param_view=self.tool_view.returns, # type: ignore
tool_name=self.tool_name,
)
response_json.update(file_info)
return response_json
file = await parse_file_from_response(
response, file_manager, file_infos=returns_file_infos, file_metadata=file_metadata
)

if file is not None:
if len(returns_file_infos) == 0:
return {self.tool_view.returns_ref_uri: file.id}

file_name = list(returns_file_infos.keys())[0]
return {file_name: file.id}

if len(returns_file_infos) == 0:
return response.json()

raise RemoteToolError(
f"<{list(returns_file_infos.keys())}> are defined but cannot be processed from the "
"response. Please ensure that the response headers contain either the Content-Disposition "
"or Content-Type field.",
stage="Output parsing",
return await parse_response(
response=response,
tool_parameter_view=self.tool_view.returns,
file_manager=file_manager,
file_metadata=file_metadata,
)

def function_call_schema(self) -> dict:
Expand Down
3 changes: 3 additions & 0 deletions erniebot-agent/src/erniebot_agent/tools/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def python_type_from_json_type(json_type_dict: dict) -> Type[object]:
return List[float]
if json_type_value == "object":
return List[ToolParameterView]
elif json_type_value == "array":
sub_type = python_type_from_json_type(json_type_dict["items"])
return List[sub_type] # type: ignore

raise ValueError(f"unsupported data type: {json_type_value}")

Expand Down
Loading