Skip to content

Commit

Permalink
Merge pull request #190 from FederatedAI/feature-1.7.2-fate_flow-merge
Browse files (browse the repository at this point in the history)
Feature 1.7.2 fate flow merge
  • Loading branch information
zhihuiwan authored Feb 23, 2022
2 parents: 2342e30 + cf505ab; commit: 9bf35dd
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 71 deletions.
56 changes: 39 additions & 17 deletions python/fate_flow/controller/job_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,8 +590,9 @@ def log_reload(cls, job):
def output_reload(cls, job, source_tasks: dict, target_tasks: dict):
# model reload
schedule_logger(job.f_job_id).info("start reload model")
cls.output_model_reload(job)
cls.checkpoint_reload(job)
source_job = JobSaver.query_job(job_id=job.f_inheritance_info.get("job_id"))[0]
cls.output_model_reload(job, source_job)
cls.checkpoint_reload(job, source_job)
schedule_logger(job.f_job_id).info("start reload data")
source_tracker_dict = cls.load_task_tracker(source_tasks)
target_tracker_dict = cls.load_task_tracker(target_tasks)
Expand Down Expand Up @@ -634,22 +635,43 @@ def status_reload(cls, job, source_tasks, target_tasks):
schedule_logger(job.f_job_id).info("reload status success")

@classmethod
def output_model_reload(cls, job):

model_id = model_utils.gen_party_model_id(job.f_runtime_conf.get("job_parameters").get("common").get("model_id"),
job.f_role, job.f_party_id)
PipelinedModel(model_id=model_id, model_version=job.f_job_id).reload_component_model(model_id=model_id, model_version=job.f_inheritance_info.get("job_id"),
component_list=job.f_inheritance_info.get("component_list"))
@classmethod
def checkpoint_reload(cls, job):
def output_model_reload(cls, job, source_job):
source_model_id = model_utils.gen_party_model_id(
source_job.f_runtime_conf.get("job_parameters").get("common").get("model_id"),
job.f_role,
job.f_party_id
)
model_id = model_utils.gen_party_model_id(
job.f_runtime_conf.get("job_parameters").get("common").get("model_id"),
job.f_role,
job.f_party_id
)
PipelinedModel(
model_id=model_id,
model_version=job.f_job_id
).reload_component_model(
model_id=source_model_id,
model_version=job.f_inheritance_info.get("job_id"),
component_list=job.f_inheritance_info.get("component_list")
)

@classmethod
def checkpoint_reload(cls, job, source_job):
for component_name in job.f_inheritance_info.get("component_list"):
path = CheckpointManager(role=job.f_role, party_id=job.f_party_id,
component_name=component_name, model_version=job.f_inheritance_info.get("job_id"),
model_id=job.f_runtime_conf.get("job_parameters").get("common").get("model_id")).directory
target_path = CheckpointManager(role=job.f_role, party_id=job.f_party_id,
component_name=component_name, model_version=job.f_job_id,
model_id=job.f_runtime_conf.get("job_parameters").get("common").get(
"model_id")).directory
path = CheckpointManager(
role=job.f_role,
party_id=job.f_party_id,
component_name=component_name,
model_version=job.f_inheritance_info.get("job_id"),
model_id=source_job.f_runtime_conf.get("job_parameters").get("common").get("model_id")
).directory
target_path = CheckpointManager(
role=job.f_role,
party_id=job.f_party_id,
component_name=component_name,
model_version=job.f_job_id,
model_id=job.f_runtime_conf.get("job_parameters").get("common").get("model_id")
).directory
if os.path.exists(path):
if os.path.exists(target_path):
shutil.rmtree(target_path)
Expand Down
114 changes: 62 additions & 52 deletions python/fate_flow/db/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,20 @@ class Meta:
def init_database_tables():
members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
table_objs = []
create_failed_list = []
for name, obj in members:
if obj != DataBaseModel and issubclass(obj, DataBaseModel):
table_objs.append(obj)
DB.create_tables(table_objs)
LOGGER.info(f"start create table {obj.__name__}")
try:
obj.create_table()
LOGGER.info(f"create table success: {obj.__name__}")
except Exception as e:
LOGGER.exception(e)
create_failed_list.append(obj.__name__)
if create_failed_list:
LOGGER.info(f"create tables failed: {create_failed_list}")
raise Exception(f"create tables failed: {create_failed_list}")


def fill_db_model_object(model_object, human_model_dict):
Expand All @@ -141,40 +151,40 @@ def fill_db_model_object(model_object, human_model_dict):

class Job(DataBaseModel):
# multi-party common configuration
f_user_id = CharField(max_length=25, index=True, null=True)
f_user_id = CharField(max_length=25, null=True)
f_job_id = CharField(max_length=25, index=True)
f_name = CharField(max_length=500, null=True, default='')
f_description = TextField(null=True, default='')
f_tag = CharField(max_length=50, null=True, index=True, default='')
f_tag = CharField(max_length=50, null=True, default='')
f_dsl = JSONField()
f_runtime_conf = JSONField()
f_runtime_conf_on_party = JSONField()
f_train_runtime_conf = JSONField(null=True)
f_roles = JSONField()
f_initiator_role = CharField(max_length=50, index=True)
f_initiator_party_id = CharField(max_length=50, index=True)
f_status = CharField(max_length=50, index=True)
f_status_code = IntegerField(null=True, index=True)
f_initiator_role = CharField(max_length=50)
f_initiator_party_id = CharField(max_length=50)
f_status = CharField(max_length=50)
f_status_code = IntegerField(null=True)
f_user = JSONField()
# this party configuration
f_role = CharField(max_length=50, index=True)
f_party_id = CharField(max_length=10, index=True)
f_is_initiator = BooleanField(null=True, index=True, default=False)
f_is_initiator = BooleanField(null=True, default=False)
f_progress = IntegerField(null=True, default=0)
f_ready_signal = BooleanField(index=True, default=False)
f_ready_signal = BooleanField(default=False)
f_ready_time = BigIntegerField(null=True)
f_cancel_signal = BooleanField(index=True, default=False)
f_cancel_signal = BooleanField(default=False)
f_cancel_time = BigIntegerField(null=True)
f_rerun_signal = BooleanField(index=True, default=False)
f_rerun_signal = BooleanField(default=False)
f_end_scheduling_updates = IntegerField(null=True, default=0)

f_engine_name = CharField(max_length=50, null=True, index=True)
f_engine_type = CharField(max_length=10, null=True, index=True)
f_cores = IntegerField(index=True, default=0)
f_memory = IntegerField(index=True, default=0) # MB
f_remaining_cores = IntegerField(index=True, default=0)
f_remaining_memory = IntegerField(index=True, default=0) # MB
f_resource_in_use = BooleanField(index=True, default=False)
f_engine_name = CharField(max_length=50, null=True)
f_engine_type = CharField(max_length=10, null=True)
f_cores = IntegerField(default=0)
f_memory = IntegerField(default=0) # MB
f_remaining_cores = IntegerField(default=0)
f_remaining_memory = IntegerField(default=0) # MB
f_resource_in_use = BooleanField(default=False)
f_apply_resource_time = BigIntegerField(null=True)
f_return_resource_time = BigIntegerField(null=True)

Expand All @@ -196,26 +206,26 @@ class Task(DataBaseModel):
# multi-party common configuration
f_job_id = CharField(max_length=25, index=True)
f_component_name = TextField()
f_component_module = CharField(max_length=200, index=True)
f_task_id = CharField(max_length=100, index=True)
f_task_version = BigIntegerField(index=True)
f_initiator_role = CharField(max_length=50, index=True)
f_initiator_party_id = CharField(max_length=50, index=True, default=-1)
f_federated_mode = CharField(max_length=10, index=True)
f_federated_status_collect_type = CharField(max_length=10, index=True)
f_component_module = CharField(max_length=200)
f_task_id = CharField(max_length=100)
f_task_version = BigIntegerField()
f_initiator_role = CharField(max_length=50)
f_initiator_party_id = CharField(max_length=50, default=-1)
f_federated_mode = CharField(max_length=10)
f_federated_status_collect_type = CharField(max_length=10)
f_status = CharField(max_length=50, index=True)
f_status_code = IntegerField(null=True, index=True)
f_auto_retries = IntegerField(default=0, index=True)
f_status_code = IntegerField(null=True)
f_auto_retries = IntegerField(default=0)
f_auto_retry_delay = IntegerField(default=0)
# this party configuration
f_role = CharField(max_length=50, index=True)
f_party_id = CharField(max_length=10, index=True)
f_run_on_this_party = BooleanField(null=True, index=True, default=False)
f_worker_id = CharField(null=True, max_length=100, index=True)
f_worker_id = CharField(null=True, max_length=100)
f_cmd = JSONField(null=True)
f_run_ip = CharField(max_length=100, null=True)
f_run_pid = IntegerField(null=True)
f_party_status = CharField(max_length=50, index=True)
f_party_status = CharField(max_length=50)
f_provider_info = JSONField()
f_component_parameters = JSONField()
f_engine_conf = JSONField(null=True)
Expand Down Expand Up @@ -260,11 +270,11 @@ class Meta:
f_task_version = BigIntegerField(null=True, index=True)
f_role = CharField(max_length=50, index=True)
f_party_id = CharField(max_length=10, index=True)
f_metric_namespace = CharField(max_length=180, index=True)
f_metric_name = CharField(max_length=180, index=True)
f_metric_namespace = CharField(max_length=180)
f_metric_name = CharField(max_length=180)
f_key = CharField(max_length=200)
f_value = LongTextField()
f_type = IntegerField(index=True) # 0 is data, 1 is meta
f_type = IntegerField() # 0 is data, 1 is meta


class TrackingOutputDataInfo(DataBaseModel):
Expand Down Expand Up @@ -294,7 +304,7 @@ class Meta:
f_job_id = CharField(max_length=25, index=True)
f_component_name = TextField()
f_task_id = CharField(max_length=100, null=True, index=True)
f_task_version = BigIntegerField(null=True, index=True)
f_task_version = BigIntegerField(null=True)
f_data_name = CharField(max_length=30)
# this party configuration
f_role = CharField(max_length=50, index=True)
Expand All @@ -305,17 +315,17 @@ class Meta:


class MachineLearningModelInfo(DataBaseModel):
f_role = CharField(max_length=50, index=True)
f_party_id = CharField(max_length=10, index=True)
f_role = CharField(max_length=50)
f_party_id = CharField(max_length=10)
f_roles = JSONField(default={})
f_job_id = CharField(max_length=25, index=True)
f_model_id = CharField(max_length=100, index=True)
f_model_version = CharField(max_length=100, index=True)
f_loaded_times = IntegerField(default=0)
f_size = BigIntegerField(default=0)
f_description = TextField(null=True, default='')
f_initiator_role = CharField(max_length=50, index=True)
f_initiator_party_id = CharField(max_length=50, index=True, default=-1)
f_initiator_role = CharField(max_length=50)
f_initiator_party_id = CharField(max_length=50, default=-1)
f_runtime_conf = JSONField(default={})
f_train_dsl = JSONField(default={})
f_train_runtime_conf = JSONField(default={})
Expand All @@ -334,8 +344,8 @@ class Meta:

class DataTableTracking(DataBaseModel):
f_table_id = BigAutoField(primary_key=True)
f_table_name = CharField(max_length=300, index=True, null=True)
f_table_namespace = CharField(max_length=300, index=True, null=True)
f_table_name = CharField(max_length=300, null=True)
f_table_namespace = CharField(max_length=300, null=True)
f_job_id = CharField(max_length=25, index=True, null=True)
f_have_parent = BooleanField(default=False)
f_parent_number = IntegerField(default=0)
Expand All @@ -350,13 +360,13 @@ class Meta:


class CacheRecord(DataBaseModel):
f_cache_key = CharField(max_length=500, primary_key=True)
f_cache_key = CharField(max_length=500)
f_cache = JsonSerializedField()
f_job_id = CharField(max_length=25, index=True, null=True)
f_role = CharField(max_length=50, index=True, null=True)
f_party_id = CharField(max_length=10, index=True, null=True)
f_component_name = TextField(null=True)
f_task_id = CharField(max_length=100, null=True, index=True)
f_task_id = CharField(max_length=100, null=True)
f_task_version = BigIntegerField(null=True, index=True)
f_cache_name = CharField(max_length=50, null=True)
t_ttl = BigIntegerField(default=0)
Expand All @@ -376,7 +386,7 @@ class Meta:

class Tag(DataBaseModel):
f_id = BigAutoField(primary_key=True)
f_name = CharField(max_length=100, index=True, unique=True)
f_name = CharField(max_length=100, unique=True)
f_desc = TextField(null=True)

class Meta:
Expand Down Expand Up @@ -410,7 +420,7 @@ class Meta:
f_party_id = CharField(max_length=10, index=True)
f_component_name = TextField()
f_task_id = CharField(max_length=50, null=True, index=True)
f_task_version = CharField(max_length=50, null=True, index=True)
f_task_version = CharField(max_length=50, null=True)
f_summary = LongTextField()


Expand All @@ -420,8 +430,8 @@ class ModelOperationLog(DataBaseModel):
f_initiator_role = CharField(max_length=50, index=True, null=True)
f_initiator_party_id = CharField(max_length=10, index=True, null=True)
f_request_ip = CharField(max_length=20, null=True)
f_model_id = CharField(max_length=100, index=True)
f_model_version = CharField(max_length=100, index=True)
f_model_id = CharField(max_length=100)
f_model_version = CharField(max_length=100)

class Meta:
db_table = "t_model_operation_log"
Expand All @@ -432,11 +442,11 @@ class EngineRegistry(DataBaseModel):
f_engine_name = CharField(max_length=50, index=True)
f_engine_entrance = CharField(max_length=50, index=True)
f_engine_config = JSONField()
f_cores = IntegerField(index=True)
f_memory = IntegerField(index=True) # MB
f_remaining_cores = IntegerField(index=True)
f_remaining_memory = IntegerField(index=True) # MB
f_nodes = IntegerField(index=True)
f_cores = IntegerField()
f_memory = IntegerField() # MB
f_remaining_cores = IntegerField()
f_remaining_memory = IntegerField() # MB
f_nodes = IntegerField()

class Meta:
db_table = "t_engine_registry"
Expand Down Expand Up @@ -481,9 +491,9 @@ class WorkerInfo(DataBaseModel):
f_worker_id = CharField(max_length=100, primary_key=True)
f_worker_name = CharField(max_length=50, index=True)
f_job_id = CharField(max_length=25, index=True)
f_task_id = CharField(max_length=100, index=True)
f_task_id = CharField(max_length=100)
f_task_version = BigIntegerField(index=True)
f_role = CharField(max_length=50, index=True)
f_role = CharField(max_length=50)
f_party_id = CharField(max_length=10, index=True)
f_run_ip = CharField(max_length=100, null=True)
f_run_pid = IntegerField(null=True)
Expand Down
2 changes: 2 additions & 0 deletions python/fate_flow/pipelined_model/mysql_model_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from peewee import Model, CharField, BigIntegerField, TextField, CompositeKey, IntegerField, PeeweeException
from playhouse.pool import PooledMySQLDatabase

from fate_arch.common.conf_utils import decrypt_database_config
from fate_flow.pipelined_model.pipelined_model import PipelinedModel
from fate_flow.pipelined_model.model_storage_base import ModelStorageBase
from fate_flow.utils.log_utils import getLogger
Expand Down Expand Up @@ -149,6 +150,7 @@ def restore(self, model_id: str, model_version: str, store_address: dict):
def get_connection(store_address: dict):
store_address = deepcopy(store_address)
db_name = store_address.pop('database')
store_address = decrypt_database_config(store_address, passwd_key="password")
del store_address['storage']
DB.init(db_name, **store_address)

Expand Down
5 changes: 3 additions & 2 deletions python/fate_flow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from fate_arch.computing import ComputingEngine
from fate_arch.common import engine_utils
from fate_arch.common.conf_utils import get_base_config
from fate_arch.common.conf_utils import get_base_config, decrypt_database_config
from fate_flow.utils.base_utils import get_fate_flow_directory
from fate_flow.utils.log_utils import LoggerFactory, getLogger

Expand Down Expand Up @@ -61,7 +61,7 @@
ENGINES = engine_utils.get_engines()
IS_STANDALONE = engine_utils.is_standalone()

DATABASE = get_base_config("database", {})
DATABASE = decrypt_database_config()
ZOOKEEPER = get_base_config("zookeeper", {})
FATE_FLOW_SERVER_START_CONFIG_ITEMS = {
"use_registry",
Expand All @@ -71,6 +71,7 @@
"database",
"zookeeper",
"enable_model_store",
"private_key", "encrypt_password", "encrypt_module"
}

# Registry
Expand Down

0 comments on commit 9bf35dd

Please sign in to comment.