diff --git a/RELEASE.md b/RELEASE.md
index 1f1de7d6e..331372528 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,10 @@
+# Release 1.10.0
+## Major Features and Improvements
+* Add connection test API
+* Support configuring gRPC message size limit
+## Bug Fixes
+* Fix module duplication issue in model
+
 # Release 1.9.1
 ## Bug Fixes
 * Fix parameter inheritance when loading non-model modules from ModelLoader
diff --git a/conf/component_registry.json b/conf/component_registry.json
index f079a255b..6061b6494 100644
--- a/conf/component_registry.json
+++ b/conf/component_registry.json
@@ -19,7 +19,9 @@
             "homo_model_convert": "protobuf.homo_model_convert.homo_model_convert",
             "anonymous_generator": "util.anonymous_generator_util.Anonymous",
             "data_format": "util.data_format_preprocess.DataFormatPreProcess",
-            "hetero_model_merge": "protobuf.model_merge.merge_hetero_models.hetero_model_merge"
+            "hetero_model_merge": "protobuf.model_merge.merge_hetero_models.hetero_model_merge",
+            "extract_woe_array_dict": "protobuf.model_migrate.binning_model_migrate.extract_woe_array_dict",
+            "merge_woe_array_dict": "protobuf.model_migrate.binning_model_migrate.merge_woe_array_dict"
         }
     }
 }
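For readers unfamiliar with component_registry.json: entries under class_path map a short name to a dotted path that is imported on demand; component_app.py below fetches the two new entries via get_class_object('extract_woe_array_dict') and get_class_object('merge_woe_array_dict'). A rough, hypothetical equivalent of that lookup, assuming the dotted paths resolve inside the federatedml package (the helper name here is illustrative, not FATE's actual implementation):

    from importlib import import_module

    def resolve_registry_entry(dotted_path, package='federatedml'):
        # 'protobuf.model_migrate.binning_model_migrate.extract_woe_array_dict'
        # -> module 'federatedml.protobuf.model_migrate.binning_model_migrate',
        #    attribute 'extract_woe_array_dict'
        module_path, _, attr = f'{package}.{dotted_path}'.rpartition('.')
        return getattr(import_module(module_path), attr)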
diff --git a/python/fate_flow/apps/__init__.py b/python/fate_flow/apps/__init__.py
index d989bb767..be45bdb62 100644
--- a/python/fate_flow/apps/__init__.py
+++ b/python/fate_flow/apps/__init__.py
@@ -22,14 +22,17 @@
 from werkzeug.wrappers.request import Request
 
 from fate_arch.common.base_utils import CustomJSONEncoder
+
 from fate_flow.entity import RetCode
 from fate_flow.hook import HookManager
 from fate_flow.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters
-from fate_flow.settings import (API_VERSION, access_logger, stat_logger, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION)
-from fate_flow.utils.api_utils import server_error_response, get_json_result
+from fate_flow.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
+from fate_flow.utils.api_utils import get_json_result, server_error_response
+
+__all__ = ['app']
+
 
 logger = logging.getLogger('flask.app')
 for h in access_logger.handlers:
     logger.addHandler(h)
@@ -38,75 +41,75 @@
 app = Flask(__name__)
 app.url_map.strict_slashes = False
-app.errorhandler(Exception)(server_error_response)
 app.json_encoder = CustomJSONEncoder
+app.errorhandler(Exception)(server_error_response)
 
-pages_dir = [
-    Path(__file__).parent,
-    Path(__file__).parent.parent / 'scheduling_apps'
-]
-pages_path = [j for i in pages_dir for j in i.glob('*_app.py')]
-scheduling_url_prefix = []
-client_url_prefix = []
-for path in pages_path:
-    page_name = path.stem.rstrip('_app')
-    module_name = '.'.join(path.parts[path.parts.index('fate_flow'):-1] + (page_name, ))
-
-    spec = spec_from_file_location(module_name, path)
+
+def search_pages_path(pages_dir):
+    return [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
+
+
+def register_page(page_path):
+    page_name = page_path.stem.rstrip('_app')
+    module_name = '.'.join(page_path.parts[page_path.parts.index('fate_flow'):-1] + (page_name, ))
+
+    spec = spec_from_file_location(module_name, page_path)
     page = module_from_spec(spec)
     page.app = app
     page.manager = Blueprint(page_name, module_name)
     sys.modules[module_name] = page
     spec.loader.exec_module(page)
-    if not isinstance(page.manager, Blueprint):
-        raise TypeError(f'page.manager should be {Blueprint!r}, got {type(page.manager)}. filepath: {path!s}')
-
-    api_version = getattr(page, 'api_version', API_VERSION)
     page_name = getattr(page, 'page_name', page_name)
+    url_prefix = f'/{API_VERSION}/{page_name}'
 
-    app.register_blueprint(page.manager, url_prefix=f'/{api_version}/{page_name}')
-    if 'scheduling_apps' in path.parts:
-        scheduling_url_prefix.append(f'/{api_version}/{page_name}')
-    else:
-        client_url_prefix.append(f'/{api_version}/{page_name}')
-
+    app.register_blueprint(page.manager, url_prefix=url_prefix)
+    return url_prefix
 
-stat_logger.info('imported pages: %s', ' '.join(str(path) for path in pages_path))
-
-@app.before_request
-def authentication_before_request():
-    if CLIENT_AUTHENTICATION:
-        _result = client_authentication_before_request()
-        if _result:
-            return _result
-    if SITE_AUTHENTICATION:
-        _result = site_authentication_before_request()
-        if _result:
-            return _result
+
+client_urls_prefix = [
+    register_page(path)
+    for path in search_pages_path(Path(__file__).parent)
+]
+scheduling_urls_prefix = [
+    register_page(path)
+    for path in search_pages_path(Path(__file__).parent.parent / 'scheduling_apps')
+]
 
 
 def client_authentication_before_request():
-    for url_prefix in scheduling_url_prefix:
+    for url_prefix in scheduling_urls_prefix:
         if request.path.startswith(url_prefix):
             return
-    parm = ClientAuthenticationParameters(full_path=request.full_path, headers=request.headers, form=request.form,
-                                          data=request.data, json=request.json)
-    result = HookManager.client_authentication(parm)
+
+    result = HookManager.client_authentication(ClientAuthenticationParameters(
+        request.full_path, request.headers,
+        request.form, request.data, request.json,
+    ))
+
     if result.code != RetCode.SUCCESS:
         return get_json_result(result.code, result.message)
 
 
 def site_authentication_before_request():
-    from flask import request
-    for url_prefix in client_url_prefix:
+    for url_prefix in client_urls_prefix:
         if request.path.startswith(url_prefix):
             return
-    body = request.json
-    headers = request.headers
-    site_signature = headers.get("site_signature")
-    result = HookManager.site_authentication(
-        AuthenticationParameters(site_signature=site_signature, src_party_id=headers.get("src_party_id"), body=body))
+
+    result = HookManager.site_authentication(AuthenticationParameters(
+        request.headers.get('src_party_id'),
+        request.headers.get('site_signature'),
+        request.json,
+    ))
+
     if result.code != RetCode.SUCCESS:
         return get_json_result(result.code, result.message)
+
+
+@app.before_request
+def authentication_before_request():
+    if CLIENT_AUTHENTICATION:
+        return client_authentication_before_request()
+
+    if SITE_AUTHENTICATION:
+        return site_authentication_before_request()
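The refactored page loader is easier to see in isolation. The sketch below is a self-contained toy (generic names, not fate_flow's real module layout). One pre-existing quirk worth flagging: rstrip('_app') strips a character set, not a suffix, so a hypothetical data_app.py would register as "dat"; the existing page names all happen to survive it. Note also the Flask contract the new authentication_before_request relies on: a before_request callback that returns None lets the request proceed, while any other return value is sent back as the response.

    import sys
    from importlib.util import module_from_spec, spec_from_file_location
    from pathlib import Path

    from flask import Blueprint, Flask

    app = Flask(__name__)
    API_VERSION = 'v1'

    def register_page(page_path: Path) -> str:
        # rstrip('_app') removes trailing '_', 'a', 'p' characters, not the
        # literal suffix: fine for job_app.py, wrong for a data_app.py.
        page_name = page_path.stem.rstrip('_app')

        # Import the file as a module, pre-seeding the globals the page expects.
        spec = spec_from_file_location(page_name, page_path)
        page = module_from_spec(spec)
        page.app = app
        page.manager = Blueprint(page_name, page_name)
        sys.modules[page_name] = page
        spec.loader.exec_module(page)  # the page decorates routes onto page.manager

        url_prefix = f'/{API_VERSION}/{page_name}'
        app.register_blueprint(page.manager, url_prefix=url_prefix)
        return url_prefix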
diff --git a/python/fate_flow/apps/component_app.py b/python/fate_flow/apps/component_app.py
index 3d8a8fc16..c870a59ee 100644
--- a/python/fate_flow/apps/component_app.py
+++ b/python/fate_flow/apps/component_app.py
@@ -19,11 +19,13 @@
 from fate_flow.component_env_utils.env_utils import get_class_object
 from fate_flow.db.component_registry import ComponentRegistry
+from fate_flow.db.db_models import PipelineComponentMeta
 from fate_flow.model.sync_model import SyncComponent
 from fate_flow.pipelined_model.pipelined_model import PipelinedModel
 from fate_flow.settings import ENABLE_MODEL_STORE
 from fate_flow.utils.api_utils import error_response, get_json_result, validate_request
 from fate_flow.utils.detect_utils import check_config
+from fate_flow.utils.job_utils import generate_job_id
 from fate_flow.utils.model_utils import gen_party_model_id
 from fate_flow.utils.schedule_utils import get_dsl_parser_by_version
@@ -41,7 +43,7 @@ def get_component(component_name):
 @manager.route('/validate', methods=['POST'])
 def validate_component_param():
     if not request.json or not isinstance(request.json, dict):
-        return error_response(400, 'bad request')
+        return error_response(400)
 
     required_keys = [
         'component_name',
@@ -79,7 +81,7 @@ def validate_component_param():
     'model_id', 'model_version', 'guest_party_id', 'host_party_ids',
     'component_name', 'model_type', 'output_format',
 )
-def hetero_merge():
+def hetero_model_merge():
     request_data = request.json
 
     if ENABLE_MODEL_STORE:
@@ -91,7 +93,7 @@ def hetero_model_merge():
             component_name=request_data['component_name'],
         )
         if not sync_component.local_exists() and sync_component.remote_exists():
-            sync_component.download(True)
+            sync_component.download()
 
     for party_id in request_data['host_party_ids']:
         sync_component = SyncComponent(
@@ -102,7 +104,7 @@ def hetero_model_merge():
             component_name=request_data['component_name'],
         )
         if not sync_component.local_exists() and sync_component.remote_exists():
-            sync_component.download(True)
+            sync_component.download()
 
     model = PipelinedModel(
         gen_party_model_id(
@@ -167,3 +169,143 @@ def hetero_model_merge():
         request_data.get('include_guest_coef', False),
     )
     return get_json_result(data=data)
+
+
+@manager.route('/woe_array/extract', methods=['POST'])
+@validate_request(
+    'model_id', 'model_version', 'role', 'party_id', 'component_name',
+)
+def woe_array_extract():
+    if request.json['role'] != 'guest':
+        return error_response(400, 'Only support guest role.')
+
+    if ENABLE_MODEL_STORE:
+        sync_component = SyncComponent(
+            role=request.json['role'],
+            party_id=request.json['party_id'],
+            model_id=request.json['model_id'],
+            model_version=request.json['model_version'],
+            component_name=request.json['component_name'],
+        )
+        if not sync_component.local_exists() and sync_component.remote_exists():
+            sync_component.download()
+
+    model = PipelinedModel(
+        gen_party_model_id(
+            request.json['model_id'],
+            request.json['role'],
+            request.json['party_id'],
+        ),
+        request.json['model_version'],
+    ).read_component_model(
+        request.json['component_name'],
+        output_json=True,
+    )
+
+    param = None
+    meta = None
+
+    for k, v in model.items():
+        if k.endswith('Param'):
+            param = v
+        elif k.endswith('Meta'):
+            meta = v
+        else:
+            return error_response(400, f'Unknown model key: "{k}".')
+
+    if param is None or meta is None:
+        return error_response(400, 'Invalid model.')
+
+    data = get_class_object('extract_woe_array_dict')(param)
+    return get_json_result(data=data)
+
+
+@manager.route('/woe_array/merge', methods=['POST'])
+@validate_request(
+    'model_id', 'model_version', 'role', 'party_id', 'component_name', 'woe_array',
+)
+def woe_array_merge():
+    if request.json['role'] != 'host':
+        return error_response(400, 'Only support host role.')
+
+    pipelined_model = PipelinedModel(
+        gen_party_model_id(
+            request.json['model_id'],
+            request.json['role'],
+            request.json['party_id'],
+        ),
+        request.json['model_version'],
+    )
+
+    query = pipelined_model.pipelined_component.get_define_meta_from_db(
+        PipelineComponentMeta.f_component_name == request.json['component_name'],
+    )
+    if not query:
+        return error_response(404, 'Component not found.')
+    query = query[0]
+
+    if ENABLE_MODEL_STORE:
+        sync_component = SyncComponent(
+            role=query.f_role,
+            party_id=query.f_party_id,
+            model_id=query.f_model_id,
+            model_version=query.f_model_version,
+            component_name=query.f_component_name,
+        )
+        if not sync_component.local_exists() and sync_component.remote_exists():
+            sync_component.download()
+
+    model = pipelined_model._read_component_model(
+        query.f_component_name,
+        query.f_model_alias,
+    )
+
+    for model_name, (
+        buffer_name,
+        buffer_string,
+        buffer_dict,
+    ) in model.items():
+        if model_name.endswith('Param'):
+            string_merged, dict_merged = get_class_object('merge_woe_array_dict')(
+                buffer_name,
+                buffer_string,
+                buffer_dict,
+                request.json['woe_array'],
+            )
+            model[model_name] = (
+                buffer_name,
+                string_merged,
+                dict_merged,
+            )
+            break
+
+    pipelined_model = PipelinedModel(
+        pipelined_model.party_model_id,
+        generate_job_id()
+    )
+
+    pipelined_model.save_component_model(
+        query.f_component_name,
+        query.f_component_module_name,
+        query.f_model_alias,
+        model,
+        query.f_run_parameters,
+    )
+
+    if ENABLE_MODEL_STORE:
+        sync_component = SyncComponent(
+            role=query.f_role,
+            party_id=query.f_party_id,
+            model_id=query.f_model_id,
+            model_version=pipelined_model.model_version,
+            component_name=query.f_component_name,
+        )
+        sync_component.upload()
+
+    return get_json_result(data={
+        'role': query.f_role,
+        'party_id': query.f_party_id,
+        'model_id': query.f_model_id,
+        'model_version': pipelined_model.model_version,
+        'component_name': query.f_component_name,
+    })
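Usage sketch for the two new endpoints, assuming a local server on FATE Flow's default HTTP port 9380 and illustrative model coordinates; the required fields come from the validate_request decorators above, and the merge call returns a new model_version for the rewritten host model:

    import requests

    BASE = 'http://127.0.0.1:9380/v1/component'

    # guest side: pull the WOE arrays out of a binning model
    woe = requests.post(f'{BASE}/woe_array/extract', json={
        'model_id': 'guest-9999#host-10000#model',   # illustrative
        'model_version': '202201011200000000000',    # illustrative
        'role': 'guest',                             # extract is guest-only
        'party_id': 9999,
        'component_name': 'hetero_feature_binning_0',
    }).json()

    # host side: merge them back, producing a new model_version
    merged = requests.post(f'{BASE}/woe_array/merge', json={
        'model_id': 'guest-9999#host-10000#model',
        'model_version': '202201011200000000000',
        'role': 'host',                              # merge is host-only
        'party_id': 10000,
        'component_name': 'hetero_feature_binning_0',
        'woe_array': woe['data'],
    }).json()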
diff --git a/python/fate_flow/apps/info_app.py b/python/fate_flow/apps/info_app.py
index 148ca9076..e8eb3de25 100644
--- a/python/fate_flow/apps/info_app.py
+++ b/python/fate_flow/apps/info_app.py
@@ -15,11 +15,26 @@
 #
 import socket
 
-from fate_arch.common import CoordinationProxyService
-from fate_flow.utils.api_utils import error_response, get_json_result
-from fate_flow.settings import PROXY, IS_STANDALONE
+from flask import request
+from flask.json import jsonify
+
+from fate_arch.common import FederatedMode
+
+from fate_flow.db.runtime_config import RuntimeConfig
 from fate_flow.db.service_registry import ServerRegistry
-from fate_flow.db.db_models import DB
+from fate_flow.settings import API_VERSION, GRPC_PORT, HOST, HTTP_PORT, PARTY_ID
+from fate_flow.utils.api_utils import error_response, federated_api, get_json_result
+
+
+@manager.route('/common', methods=['POST'])
+def get_common_info():
+    return get_json_result(data={
+        'version': RuntimeConfig.get_env('FATE'),
+        'host': HOST,
+        'http_port': HTTP_PORT,
+        'grpc_port': GRPC_PORT,
+        'party_id': PARTY_ID,
+    })
 
 
 @manager.route('/fateboard', methods=['POST'])
@@ -28,35 +43,16 @@ def get_fateboard_info():
     port = ServerRegistry.FATEBOARD.get('port')
     if not host or not port:
         return error_response(404, 'fateboard is not configured')
+
     return get_json_result(data={
         'host': host,
         'port': port,
     })
 
 
-@manager.route('/mysql', methods=['POST'])
-def get_mysql_info():
-    if IS_STANDALONE:
-        return error_response(404, 'mysql only available on cluster mode')
-
-    try:
-        with DB.connection_context():
-            DB.random()
-    except Exception as e:
-        return error_response(503, str(e))
-
-    return error_response(200)
-
-
 # TODO: send greetings message using grpc protocol
 @manager.route('/eggroll', methods=['POST'])
 def get_eggroll_info():
-    if IS_STANDALONE:
-        return error_response(404, 'eggroll only available on cluster mode')
-
-    if PROXY != CoordinationProxyService.ROLLSITE:
-        return error_response(404, 'coordination communication protocol is not rollsite')
-
     conf = ServerRegistry.FATE_ON_EGGROLL['rollsite']
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         r = s.connect_ex((conf['host'], conf['port']))
@@ -64,3 +60,37 @@ def get_eggroll_info():
         return error_response(503)
 
     return error_response(200)
+
+
+@manager.route('/version', methods=['POST'])
+@app.route(f'/{API_VERSION}/version/get', methods=['POST'])
+def get_version():
+    module = request.json['module'] if isinstance(request.json, dict) and request.json.get('module') else 'FATE'
+    version = RuntimeConfig.get_env(module)
+    if version is None:
+        return error_response(404, f'unknown module {module}')
+
+    return get_json_result(data={
+        module: version,
+        'API': API_VERSION,
+    })
+
+
+@manager.route('/party/<dest_party_id>', methods=['POST'])
+def get_party_info(dest_party_id):
+    response = federated_api(
+        'party_info', 'POST', '/info/common',
+        PARTY_ID, dest_party_id, '',
+        {}, FederatedMode.MULTIPLE,
+    )
+    return jsonify(response)
+
+
+@manager.route('/party/<proxy_party_id>/<dest_party_id>', methods=['POST'])
+def get_party_info_from_another_party(proxy_party_id, dest_party_id):
+    response = federated_api(
+        'party_info', 'POST', f'/info/party/{dest_party_id}',
+        PARTY_ID, proxy_party_id, '',
+        {}, FederatedMode.MULTIPLE,
+    )
+    return jsonify(response)
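The connection test API from the release notes in action, assuming a local server on port 9380: /info/common describes this party, while /info/party/<dest_party_id> relays the same call to another party through the federation proxy (the party id below is illustrative):

    import requests

    print(requests.post('http://127.0.0.1:9380/v1/info/common').json())
    print(requests.post('http://127.0.0.1:9380/v1/info/party/10000').json())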
diff --git a/python/fate_flow/apps/proxy_app.py b/python/fate_flow/apps/proxy_app.py
index 5dd881756..3d31369dc 100644
--- a/python/fate_flow/apps/proxy_app.py
+++ b/python/fate_flow/apps/proxy_app.py
@@ -24,39 +24,38 @@
 page_name = 'forward'
 
 
-@manager.route('/<role>', methods=['post'])
+@manager.route('/<role>', methods=['POST'])
 def start_proxy(role):
+    _job_id = f'{role}_forward'
     request_config = request.json or request.form.to_dict()
-    _job_id = f"{role}_forward"
-    if role in ['marketplace']:
-        response = proxy_api(role, _job_id, request_config)
+
+    if request_config.get('header') and request_config.get('body'):
+        request_config['header'] = {
+            **request.headers,
+            **{
+                k.replace('_', '-').upper(): v
+                for k, v in request_config['header'].items()
+            },
+        }
     else:
-        headers = request.headers
-        json_body = {}
-        if request_config.get('header') and request_config.get("body"):
-            src_party_id = request_config.get('header').get('src_party_id')
-            dest_party_id = request_config.get('header').get('dest_party_id')
-            json_body = request_config
-            if headers:
-                json_body['header'].update(headers)
-        else:
-            src_party_id = headers.get('src_party_id')
-            dest_party_id = headers.get('dest_party_id')
-            json_body["header"] = request.headers
-            json_body["body"] = request_config
-        response = federated_api(job_id=_job_id,
-                                 method='POST',
-                                 endpoint=f'/forward/{role}/do',
-                                 src_party_id=src_party_id,
-                                 dest_party_id=dest_party_id,
-                                 src_role=None,
-                                 json_body=json_body,
-                                 federated_mode=FederatedMode.MULTIPLE)
+        request_config = {
+            'header': request.headers,
+            'body': request_config,
+        }
+
+    response = (
+        proxy_api(role, _job_id, request_config) if role == 'marketplace'
+        else federated_api(
+            _job_id, 'POST', f'/forward/{role}/do',
+            request_config['header'].get('SRC-PARTY-ID'),
+            request_config['header'].get('DEST-PARTY-ID'),
+            '', request_config, FederatedMode.MULTIPLE,
+        )
+    )
 
     return jsonify(response)
 
 
-@manager.route('/<role>/do', methods=['post'])
+@manager.route('/<role>/do', methods=['POST'])
 def start_forward(role):
     request_config = request.json or request.form.to_dict()
-    response = forward_api(role, request_config)
-    return jsonify(response)
+    return jsonify(forward_api(role, request_config))
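A quick illustration of the header normalization start_proxy now performs, so that caller-supplied keys match the uppercase SRC-PARTY-ID / DEST-PARTY-ID lookups used for routing:

    def normalize_headers(header):
        return {k.replace('_', '-').upper(): v for k, v in header.items()}

    assert normalize_headers({'src_party_id': '9999', 'dest_party_id': '10000'}) == \
        {'SRC-PARTY-ID': '9999', 'DEST-PARTY-ID': '10000'}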
diff --git a/python/fate_flow/apps/version_app.py b/python/fate_flow/apps/version_app.py
deleted file mode 100644
index cab915762..000000000
--- a/python/fate_flow/apps/version_app.py
+++ /dev/null
@@ -1,32 +0,0 @@
-#
-#  Copyright 2019 The FATE Authors. All Rights Reserved.
-#
-#  Licensed under the Apache License, Version 2.0 (the "License");
-#  you may not use this file except in compliance with the License.
-#  You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-#  Unless required by applicable law or agreed to in writing, software
-#  distributed under the License is distributed on an "AS IS" BASIS,
-#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#  See the License for the specific language governing permissions and
-#  limitations under the License.
-#
-from flask import request
-
-from fate_flow.settings import API_VERSION, FATE_ENV_KEY_LIST
-from fate_flow.utils.api_utils import get_json_result, error_response
-from fate_flow.db.runtime_config import RuntimeConfig
-
-
-@manager.route('/get', methods=['POST'])
-def get_fate_version_info():
-    module = request.json['module'] if isinstance(request.json, dict) and request.json.get('module') else 'FATE'
-    version = RuntimeConfig.get_env(module)
-    if version is None:
-        return error_response(404, 'invalid module, please input module parameter in this scope: ' + " or ".join(FATE_ENV_KEY_LIST))
-    return get_json_result(data={
-        module: version,
-        'API': API_VERSION,
-    })
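version_app.py is deleted, but its URL survives: get_version in info_app.py above binds both /v1/info/version and the legacy /v1/version/get to the same handler, so both calls below should return the same payload (host and port are assumptions; only the 404 message changes, since FATE_ENV_KEY_LIST is gone):

    import requests

    legacy = requests.post('http://127.0.0.1:9380/v1/version/get', json={'module': 'FATE'}).json()
    current = requests.post('http://127.0.0.1:9380/v1/info/version', json={'module': 'FATE'}).json()
    assert legacy['data'] == current['data']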
diff --git a/python/fate_flow/fate_flow_server.py b/python/fate_flow/fate_flow_server.py
index 35515d315..d64db9810 100644
--- a/python/fate_flow/fate_flow_server.py
+++ b/python/fate_flow/fate_flow_server.py
@@ -23,7 +23,6 @@
 import traceback
 
 import grpc
-from grpc._cython import cygrpc
 from werkzeug.serving import run_simple
 
 from fate_arch.common import file_utils
@@ -45,7 +44,7 @@
 from fate_flow.manager.provider_manager import ProviderManager
 from fate_flow.scheduler.dag_scheduler import DAGScheduler
 from fate_flow.settings import (
-    GRPC_PORT, GRPC_SERVER_MAX_WORKERS, HOST, HTTP_PORT,
+    GRPC_OPTIONS, GRPC_PORT, GRPC_SERVER_MAX_WORKERS, HOST, HTTP_PORT,
     access_logger, database_logger, detect_logger, stat_logger,
 )
 from fate_flow.utils.base_utils import get_fate_flow_directory
@@ -105,9 +104,7 @@
     thread_pool_executor = ThreadPoolExecutor(max_workers=GRPC_SERVER_MAX_WORKERS)
     stat_logger.info(f"start grpc server thread pool by {thread_pool_executor._max_workers} max workers")
-    server = grpc.server(thread_pool=thread_pool_executor,
-                         options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
-                                  (cygrpc.ChannelArgKey.max_receive_message_length, -1)])
+    server = grpc.server(thread_pool=thread_pool_executor, options=GRPC_OPTIONS)
 
     proxy_pb2_grpc.add_DataTransferServiceServicer_to_server(UnaryService(), server)
     server.add_insecure_port(f"{HOST}:{GRPC_PORT}")
diff --git a/python/fate_flow/pipelined_model/pipelined_model.py b/python/fate_flow/pipelined_model/pipelined_model.py
index e1794095f..351a71a26 100644
--- a/python/fate_flow/pipelined_model/pipelined_model.py
+++ b/python/fate_flow/pipelined_model/pipelined_model.py
@@ -178,7 +178,7 @@ def read_component_model(self, component_name, model_alias=None, parse=True, out
             else:
                 model_buffers[model_name] = [
                     buffer_name,
-                    base64.b64encode(buffer_object_serialized_string).decode("ascii"),
+                    base64.b64encode(buffer_object_serialized_string).decode(),
                 ]
 
         return model_buffers
diff --git a/python/fate_flow/pipelined_model/publish_model.py b/python/fate_flow/pipelined_model/publish_model.py
index 14cc59a4d..1ffdbefc2 100644
--- a/python/fate_flow/pipelined_model/publish_model.py
+++ b/python/fate_flow/pipelined_model/publish_model.py
@@ -25,7 +25,7 @@
 from fate_flow.pipelined_model.homo_model_deployer.model_deploy import model_deploy
 from fate_flow.settings import (
     ENABLE_MODEL_STORE, FATE_FLOW_MODEL_TRANSFER_ENDPOINT,
-    HOST, HTTP_PORT, USE_REGISTRY, stat_logger,
+    GRPC_OPTIONS, HOST, HTTP_PORT, USE_REGISTRY, stat_logger,
 )
 from fate_flow.utils import model_utils
 
@@ -49,7 +49,7 @@ def load_model(config_data):
         return 100, 'Please configure servings address'
 
     for serving in config_data['servings']:
-        with grpc.insecure_channel(serving) as channel:
+        with grpc.insecure_channel(serving, GRPC_OPTIONS) as channel:
             stub = model_service_pb2_grpc.ModelServiceStub(channel)
             load_model_request = model_service_pb2.PublishRequest()
             for role_name, role_partys in config_data.get("role", {}).items():
@@ -93,7 +93,7 @@ def bind_model_service(config_data):
     model_version = config_data['job_parameters']['model_version']
 
     for serving in config_data['servings']:
-        with grpc.insecure_channel(serving) as channel:
+        with grpc.insecure_channel(serving, GRPC_OPTIONS) as channel:
             stub = model_service_pb2_grpc.ModelServiceStub(channel)
             publish_model_request = model_service_pb2.PublishRequest()
             publish_model_request.serviceId = service_id
@@ -142,7 +142,7 @@ def convert_homo_model(request_data):
     if not model.exists():
         return 100, 'Model {} {} does not exist'.format(party_model_id, model_version), None
 
-    define_meta = pipelined_model.pipelined_component.get_define_meta()
+    define_meta = model.pipelined_component.get_define_meta()
 
     framework_name = request_data.get("framework_name")
     detail = []
diff --git a/python/fate_flow/settings.py b/python/fate_flow/settings.py
index 54008e929..11d0343e2 100644
--- a/python/fate_flow/settings.py
+++ b/python/fate_flow/settings.py
@@ -15,6 +15,8 @@
 #
 import os
 
+from grpc._cython import cygrpc
+
 from fate_arch.computing import ComputingEngine
 from fate_arch.common import engine_utils
 from fate_arch.common.conf_utils import get_base_config, decrypt_database_config
@@ -24,7 +26,6 @@
 # Server
 API_VERSION = "v1"
-FATE_ENV_KEY_LIST = ['FATE', 'FATEFlow', 'FATEBoard', 'EGGROLL', 'CENTOS', 'UBUNTU', 'PYTHON', 'MAVEN', 'JDK', 'SPARK']
 FATE_FLOW_SERVICE_NAME = "fateflow"
 SERVER_MODULE = "fate_flow_server.py"
 CASBIN_TABLE_NAME = "fate_casbin"
@@ -40,11 +41,15 @@
 SUBPROCESS_STD_LOG_NAME = "std.log"
 
 GRPC_SERVER_MAX_WORKERS = None
-MAX_TIMESTAMP_INTERVAL = 60
+GRPC_OPTIONS = [
+    (cygrpc.ChannelArgKey.max_send_message_length, -1),
+    (cygrpc.ChannelArgKey.max_receive_message_length, -1),
+]
 
 ERROR_REPORT = True
 ERROR_REPORT_WITH_PATH = False
 
+MAX_TIMESTAMP_INTERVAL = 60
 SESSION_VALID_PERIOD = 7 * 24 * 60 * 60 * 1000
 
 REQUEST_TRY_TIMES = 3
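What the new GRPC_OPTIONS setting amounts to, sketched with the public string keys (cygrpc.ChannelArgKey.max_send_message_length and friends are, to my understanding, the same "grpc.*" channel-arg names): -1 lifts the size caps, where gRPC would otherwise reject messages over its 4 MB default receive limit. The same list now feeds both the server in fate_flow_server.py and every insecure channel:

    from concurrent.futures import ThreadPoolExecutor

    import grpc

    GRPC_OPTIONS = [
        ('grpc.max_send_message_length', -1),    # -1 means unlimited
        ('grpc.max_receive_message_length', -1),
    ]

    server = grpc.server(ThreadPoolExecutor(max_workers=2), options=GRPC_OPTIONS)
    channel = grpc.insecure_channel('127.0.0.1:9360', options=GRPC_OPTIONS)  # 9360: default grpc port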
diff --git a/python/fate_flow/utils/api_utils.py b/python/fate_flow/utils/api_utils.py
index 6c19ffc92..2b9b99136 100644
--- a/python/fate_flow/utils/api_utils.py
+++ b/python/fate_flow/utils/api_utils.py
@@ -288,37 +288,48 @@ def federated_coordination_on_grpc(
 
 
 def proxy_api(role, _job_id, request_config):
-    job_id = request_config.get('header').get('job_id', _job_id)
-    method = request_config.get('header').get('method', 'POST')
-    endpoint = request_config.get('header').get('endpoint')
-    src_party_id = request_config.get('header').get('src_party_id')
-    dest_party_id = request_config.get('header').get('dest_party_id')
-    json_body = request_config.get('body')
-    _packet = forward_grpc_packet(json_body, method, endpoint, src_party_id, dest_party_id, job_id=job_id, role=role)
-    _routing_metadata = gen_routing_metadata(src_party_id=src_party_id, dest_party_id=dest_party_id)
+    headers = request_config.get('header', {})
+    body = request_config.get('body', {})
+    method = headers.get('METHOD', 'POST')
+    endpoint = headers.get('ENDPOINT', '')
+    job_id = headers.get('JOB-ID', _job_id)
+    src_party_id = headers.get('SRC-PARTY-ID', '')
+    dest_party_id = headers.get('DEST-PARTY-ID', '')
+
+    _packet = forward_grpc_packet(body, method, endpoint, src_party_id, dest_party_id, role, job_id)
+    _routing_metadata = gen_routing_metadata(src_party_id, dest_party_id)
     host, port, protocol = get_federated_proxy_address(src_party_id, dest_party_id)
+
     channel, stub = get_command_federation_channel(host, port)
     _return, _call = stub.unaryCall.with_call(_packet, metadata=_routing_metadata)
     channel.close()
-    json_body = json_loads(_return.body.value)
-    return json_body
+
+    response = json_loads(_return.body.value)
+    return response
 
 
 def forward_api(role, request_config):
-    method = request_config.get('header', {}).get('method', 'post')
-    endpoint = request_config.get('header', {}).get('endpoint')
-    if not getattr(ServerRegistry, role.upper()):
+    role = role.upper()
+    if not hasattr(ServerRegistry, role):
         ServerRegistry.load()
-    ip = getattr(ServerRegistry, role.upper()).get("host")
-    port = getattr(ServerRegistry, role.upper()).get("port")
-    url = "http://{}:{}{}".format(ip, port, endpoint)
-    audit_logger().info('api request: {}'.format(url))
-
-    http_response = request(method=method, url=url, json=request_config.get('body'), headers=request_config.get('header'))
-    if http_response.status_code == 200:
-        response = http_response.json()
-    else:
-        response = {"retcode": http_response.status_code, "retmsg": http_response.text}
+    if not hasattr(ServerRegistry, role):
+        return {'retcode': 404, 'retmsg': f'role "{role.lower()}" not supported'}
+    registry = getattr(ServerRegistry, role)
+
+    headers = request_config.get('header', {})
+    body = request_config.get('body', {})
+    method = headers.get('METHOD', 'POST')
+    endpoint = headers.get('ENDPOINT', '')
+    ip = registry.get('host', '')
+    port = registry.get('port', '')
+    url = f'http://{ip}:{port}{endpoint}'
+    audit_logger().info(f'api request: {url}')
+
+    response = request(method=method, url=url, json=body, headers=headers)
+    response = (
+        response.json() if response.status_code == 200
+        else {'retcode': response.status_code, 'retmsg': response.text}
+    )
 
     audit_logger().info(response)
     return response
diff --git a/python/fate_flow/utils/grpc_utils.py b/python/fate_flow/utils/grpc_utils.py
index 7fc977e96..2c754f2ed 100644
--- a/python/fate_flow/utils/grpc_utils.py
+++ b/python/fate_flow/utils/grpc_utils.py
@@ -20,13 +20,13 @@
 
 from fate_flow.db.job_default_config import JobDefaultConfig
 from fate_flow.db.runtime_config import RuntimeConfig
-from fate_flow.settings import FATE_FLOW_SERVICE_NAME, GRPC_PORT, HOST
+from fate_flow.settings import FATE_FLOW_SERVICE_NAME, GRPC_OPTIONS, GRPC_PORT, HOST
 from fate_flow.utils.log_utils import audit_logger
 from fate_flow.utils.requests_utils import request
 
 
 def get_command_federation_channel(host, port):
-    channel = grpc.insecure_channel(f"{host}:{port}")
+    channel = grpc.insecure_channel(f"{host}:{port}", GRPC_OPTIONS)
     stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
     return channel, stub
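Shape of the request_config that proxy_api and forward_api now consume: dispatch metadata lives in uppercase, dash-separated header keys (matching the normalization in proxy_app.py), and the payload sits under 'body'. All values below are illustrative:

    request_config = {
        'header': {
            'METHOD': 'POST',
            'ENDPOINT': '/v1/version/get',
            'JOB-ID': 'marketplace_forward',
            'SRC-PARTY-ID': '9999',
            'DEST-PARTY-ID': '10000',
        },
        'body': {'module': 'FATE'},
    }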
diff --git a/python/fate_flow/utils/model_utils.py b/python/fate_flow/utils/model_utils.py
index 9056eb394..1bccbe04a 100644
--- a/python/fate_flow/utils/model_utils.py
+++ b/python/fate_flow/utils/model_utils.py
@@ -105,9 +105,7 @@ def query_model_info_from_file(model_id='*', model_version='*', role='*', party_
             stat_logger.exception(e)
 
         if query_filters:
-            for k, v in model_info.items():
-                if k not in query_filters:
-                    del model_info[k]
+            model_info = {k: v for k, v in model_info.items() if k in query_filters}
 
         models.append(model_info)
diff --git a/python/fate_flow/worker/dependence_upload.py b/python/fate_flow/worker/dependence_upload.py
index 475efda48..90cbf876d 100644
--- a/python/fate_flow/worker/dependence_upload.py
+++ b/python/fate_flow/worker/dependence_upload.py
@@ -17,11 +17,13 @@
 import os
 import shutil
 import zipfile
+import subprocess
 
 from fate_arch.common import file_utils
 from fate_flow.utils.log_utils import getLogger
 from fate_flow.db.db_models import ComponentProviderInfo
 from fate_flow.db.dependence_registry import DependenceRegistry
+from fate_flow.db.service_registry import ServerRegistry
 from fate_flow.entity import ComponentProvider
 from fate_flow.entity.types import FateDependenceName, ComponentProviderName, FateDependenceStorageEngine
 from fate_flow.settings import FATE_VERSION_DEPENDENCIES_PATH
@@ -64,13 +66,12 @@ def upload_dependencies_to_hadoop(cls, provider, dependence_type, storage_engine
         LOGGER.info(f'dependencies loading ...')
         if dependence_type == FateDependenceName.Python_Env.value:
             # todo: version python env
-            target_file = os.path.join(FATE_VERSION_DEPENDENCIES_PATH, provider.version, "python_env.zip")
+            target_file = os.path.join(FATE_VERSION_DEPENDENCIES_PATH, provider.version, "python_env.tar.gz")
+            venv_pack_path = os.path.join(os.getenv("VIRTUAL_ENV"), "bin/venv-pack")
+            subprocess.run([venv_pack_path, "-o", target_file])
             source_path = os.path.dirname(os.path.dirname(os.getenv("VIRTUAL_ENV")))
             cls.rewrite_pyvenv_cfg(os.path.join(os.getenv("VIRTUAL_ENV"), "pyvenv.cfg"), "python_env")
-            env_dir_list = ["python", "miniconda3"]
-            cls.zip_dir(source_path, target_file, env_dir_list)
-
-            dependencies_conf = {"executor_python": f"./{dependence_type}/python/venv/bin/python",
+            dependencies_conf = {"executor_python": f"./{dependence_type}/bin/python",
                                  "driver_python": f"{os.path.join(os.getenv('VIRTUAL_ENV'), 'bin', 'python')}"}
         else:
             fate_code_dependencies = {
@@ -102,9 +103,11 @@ def upload_dependencies_to_hadoop(cls, provider, dependence_type, storage_engine
         LOGGER.info(f'start upload')
         snapshot_time = DependenceRegistry.get_modify_time(source_path)
+        hdfs_address = ServerRegistry.FATE_ON_SPARK.get("hdfs", {}).get("name_node")
+        LOGGER.info(f'hdfs address: {hdfs_address}')
         storage_dir = f"/fate_dependence/{provider.version}"
-        os.system(f" {os.getenv('HADOOP_HOME')}/bin/hdfs dfs -mkdir -p {storage_dir}")
-        status = os.system(f"{os.getenv('HADOOP_HOME')}/bin/hdfs dfs -put -f {target_file} {storage_dir}")
+        os.system(f" {os.getenv('HADOOP_HOME')}/bin/hdfs dfs -mkdir -p {hdfs_address}{storage_dir}")
+        status = os.system(f"{os.getenv('HADOOP_HOME')}/bin/hdfs dfs -put -f {target_file} {hdfs_address}{storage_dir}")
         LOGGER.info(f'upload end, status is {status}')
         if status == 0:
             storage_path = os.path.join(storage_dir, os.path.basename(target_file))
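Sketch of the new Python_Env packaging path, assuming venv-pack is installed in the virtualenv (pip install venv-pack): the active venv is archived as python_env.tar.gz, which is why executor_python can now point at ./<dependence_type>/bin/python inside the unpacked archive instead of the old zip layout. check=True is used here to surface packing failures loudly, something the os.system-style calls above would swallow:

    import os
    import subprocess

    venv = os.environ['VIRTUAL_ENV']
    target_file = 'python_env.tar.gz'
    subprocess.run(
        [os.path.join(venv, 'bin', 'venv-pack'), '-o', target_file],
        check=True,
    )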