From a44bf12488ade388b8a1737f1534f1c92a0a2b28 Mon Sep 17 00:00:00 2001 From: Zhicheng Zhang Date: Thu, 10 Oct 2024 19:15:05 +0800 Subject: [PATCH] support openapi schema and test api --- apps/agentfabric/server.py | 128 +++++++++++++++++- modelscope_agent/tools/base.py | 59 +++++++- modelscope_agent/tools/utils/openapi_utils.py | 26 +++- 3 files changed, 202 insertions(+), 11 deletions(-) diff --git a/apps/agentfabric/server.py b/apps/agentfabric/server.py index 6aaa3b0f..3c26081b 100644 --- a/apps/agentfabric/server.py +++ b/apps/agentfabric/server.py @@ -20,6 +20,7 @@ from modelscope_agent.constants import (MODELSCOPE_AGENT_TOKEN_HEADER_NAME, ApiNames) from modelscope_agent.schemas import Message +from modelscope_agent.tools.base import OpenapiServiceProxy from publish_util import (pop_user_info_from_config, prepare_agent_zip, reload_agent_dir) from server_logging import logger, request_id_var @@ -561,6 +562,131 @@ def get_preview_chat_file(uuid_str, session_str): }), 404 +@app.route('/openapi/schema/', methods=['POST']) +@with_request_id +def openapi_schema_parser(uuid_str): + logger.info(f'parse openapi schema for: uuid_str_{uuid_str}') + params_str = request.get_data(as_text=True) + openapi_schema = None + try: + params = json.loads(params_str) + openapi_schema = params.get('openapi_schema') + except json.decoder.JSONDecodeError: + logger.error('OpenAPI schema format error, should be a valid json') + if not openapi_schema: + return jsonify({ + 'success': False, + 'message': 'OpenAPI schema format error, should be valid json', + 'request_id': request_id_var.get('') + }) + openapi_schema_instance = OpenapiServiceProxy(openapi=openapi_schema) + import copy + schema_info = copy.deepcopy(openapi_schema_instance.api_info_dict) + for item in schema_info: + schema_info[item].pop('is_active') + schema_info[item].pop('is_remote_tool') + schema_info[item].pop('details') + + return jsonify({ + 'success': True, + 'schema_info': schema_info, + 'request_id': request_id_var.get('') + }) + + +@app.route('/openapi/test/', methods=['POST']) +@with_request_id +def openapi_test_parser(uuid_str): + logger.info(f'parse openapi schema for: uuid_str_{uuid_str}') + params_str = request.get_data(as_text=True) + openapi_schema = None + tool_params = None + tool_name = '' + credentials = {} + try: + params = json.loads(params_str) + openapi_schema = params.get('openapi_schema') + tool_params = params.get('tool_params') + tool_name = params.get('tool_name') + credentials = params.get('credentials') + except json.decoder.JSONDecodeError: + logger.error('OpenAPI schema format error, should be a valid json') + if not openapi_schema: + return jsonify({ + 'success': False, + 'message': 'OpenAPI schema format error, should be valid json', + 'request_id': request_id_var.get('') + }) + openapi_schema_instance = OpenapiServiceProxy( + openapi=openapi_schema, is_remote=False) + result = openapi_schema_instance.call( + tool_params, **{ + 'tool_name': tool_name, + 'credentials': credentials + }) + if not result: + return jsonify({ + 'success': False, + 'result': None, + 'request_id': request_id_var.get('') + }) + return jsonify({ + 'success': True, + 'result': result, + 'request_id': request_id_var.get('') + }) + + +# Mock database +todos_db = {} + + +@app.route('/todos/', methods=['GET']) +def get_todos(username): + if username in todos_db: + return jsonify({'output': {'todos': todos_db[username]}}) + else: + return jsonify({'output': {'todos': []}}) + + +@app.route('/todos/', methods=['POST']) +def add_todo(username): + if not request.is_json: + return jsonify({'output': 'Missing JSON in request'}), 400 + + todo_data = request.get_json() + todo = todo_data.get('todo') + + if not todo: + return jsonify({'output': "Missing 'todo' in request"}), 400 + + if username in todos_db: + todos_db[username].append(todo) + else: + todos_db[username] = [todo] + + return jsonify({'output': 'Todo added successfully'}), 200 + + +@app.route('/todos/', methods=['DELETE']) +def delete_todo(username): + if not request.is_json: + return jsonify({'output': 'Missing JSON in request'}), 400 + + todo_data = request.get_json() + todo_idx = todo_data.get('todo_idx') + + if todo_idx is None: + return jsonify({'output': "Missing 'todo_idx' in request"}), 400 + + if username in todos_db and 0 <= todo_idx < len(todos_db[username]): + deleted_todo = todos_db[username].pop(todo_idx) + return jsonify( + {'output': f"Todo '{deleted_todo}' deleted successfully"}), 200 + else: + return jsonify({'output': "Invalid 'todo_idx' or username"}), 400 + + @app.errorhandler(Exception) @with_request_id def handle_error(error): @@ -579,4 +705,4 @@ def handle_error(error): if __name__ == '__main__': port = int(os.getenv('PORT', '5001')) - app.run(host='0.0.0.0', port=port, debug=False) + app.run(host='0.0.0.0', port=5002, debug=False) diff --git a/modelscope_agent/tools/base.py b/modelscope_agent/tools/base.py index 3911b7c2..ee09fb74 100644 --- a/modelscope_agent/tools/base.py +++ b/modelscope_agent/tools/base.py @@ -541,8 +541,10 @@ def parser_function_by_tool_name(self, tool_name: str): def parse_service_response(response): try: # Assuming the response is a JSON string - response_data = response.json() - + if not isinstance(response, dict): + response_data = response.json() + else: + response_data = response # Extract the 'output' field from the response output_data = response_data.get('output', {}) return output_data @@ -591,9 +593,45 @@ def _verify_args(self, params: str, api_info) -> Union[str, dict]: raise ValueError(f'param `{param["name"]}` is required') return params_json + def _parse_credentials(self, credentials: dict, headers=None): + if not headers: + headers = {} + if 'auth_type' not in credentials: + raise KeyError('Missing auth_type') + if credentials['auth_type'] == 'api_key': + api_key_header = 'api_key' + + if 'api_key_header' in credentials: + api_key_header = credentials['api_key_header'] + + if 'api_key_value' not in credentials: + raise KeyError('Missing api_key_value') + elif not isinstance(credentials['api_key_value'], str): + raise KeyError('api_key_value must be a string') + + if 'api_key_header_prefix' in credentials: + api_key_header_prefix = credentials['api_key_header_prefix'] + if api_key_header_prefix == 'basic' and credentials[ + 'api_key_value']: + credentials[ + 'api_key_value'] = f'Basic {credentials["api_key_value"]}' + elif api_key_header_prefix == 'bearer' and credentials[ + 'api_key_value']: + credentials[ + 'api_key_value'] = f'Bearer {credentials["api_key_value"]}' + elif api_key_header_prefix == 'custom': + pass + + headers[api_key_header] = credentials['api_key_value'] + return headers + def call(self, params: str, **kwargs): # ms_token tool_name = kwargs.get('tool_name', '') + if tool_name not in self.api_info_dict: + raise ValueError( + f'tool name {tool_name} not in the list of tools {self.tool_names}' + ) api_info = self.api_info_dict[tool_name] self.user_token = kwargs.get('user_token', self.user_token) service_token = os.getenv('TOOL_MANAGER_AUTH', '') @@ -630,7 +668,6 @@ def call(self, params: str, **kwargs): for name, value in path_params.items(): url = url.replace(f'{{{name}}}', f'{value}') - try: # visit tool node to call tool if self.is_remote: @@ -650,7 +687,9 @@ def call(self, params: str, **kwargs): response.raise_for_status() else: - response = execute_api_call(url, method, headers, params, data, + credentials = kwargs.get('credentials', {}) + header = self._parse_credentials(credentials, header) + response = execute_api_call(url, method, header, params, data, cookies) return OpenapiServiceProxy.parse_service_response(response) except Exception as e: @@ -661,7 +700,17 @@ def call(self, params: str, **kwargs): if __name__ == '__main__': import copy - openapi_instance = OpenapiServiceProxy('openapi_plugin') + + test_str = 'openapi_plugin' + openapi_instance = OpenapiServiceProxy(openapi=test_str) + schema_info = copy.deepcopy(openapi_instance.api_info_dict) + for item in schema_info: + schema_info[item].pop('is_active') + schema_info[item].pop('is_remote_tool') + schema_info[item].pop('details') + + print(schema_info) + print(openapi_instance.api_info_dict) function_map = {} tool_names = openapi_instance.tool_names for tool_name in tool_names: diff --git a/modelscope_agent/tools/utils/openapi_utils.py b/modelscope_agent/tools/utils/openapi_utils.py index 7acce5cb..3fc83cd0 100644 --- a/modelscope_agent/tools/utils/openapi_utils.py +++ b/modelscope_agent/tools/utils/openapi_utils.py @@ -26,8 +26,7 @@ def execute_api_call(url: str, method: str, headers: dict, params: dict, return response.json() except requests.exceptions.RequestException as e: - raise Exception( - f'An error occurred: {response.message}, with error {e}') + raise Exception(f'An error occurred with error {e}') def parse_nested_parameters(param_name, param_info, parameters_list, content): @@ -64,7 +63,9 @@ def parse_nested_parameters(param_name, param_info, parameters_list, content): 'type': inner_param_type, 'enum': - inner_param_info.get('enum', '') + inner_param_info.get('enum', ''), + 'in': + 'requestBody' }) else: # Non-nested parameters are added directly to the parameter list @@ -73,7 +74,8 @@ def parse_nested_parameters(param_name, param_info, parameters_list, content): 'description': param_description, 'required': param_required, 'type': param_type, - 'enum': param_info.get('enum', '') + 'enum': param_info.get('enum', ''), + 'in': 'requestBody' }) except Exception as e: raise ValueError(f'{e}:schema结构出错') @@ -117,7 +119,21 @@ def openapi_schema_convert(schema: dict, auth: dict = {}): path_parameters = details.get('parameters', []) if isinstance(path_parameters, dict): path_parameters = [path_parameters] - parameters_list.extend(path_parameters) + for path_parameter in path_parameters: + parameters_list.append({ + 'name': + path_parameter['name'], + 'description': + path_parameter.get('description', 'No description'), + 'in': + path_parameter['in'], + 'required': + path_parameter.get('required', False), + 'type': + path_parameter['schema']['type'], + 'enum': + path_parameter.get('enum', '') + }) summary = details.get('summary', 'No summary').replace(' ', '_').lower()