diff --git a/element_deeplabcut/model.py b/element_deeplabcut/model.py index 8a6e47f..c06f617 100644 --- a/element_deeplabcut/model.py +++ b/element_deeplabcut/model.py @@ -706,7 +706,14 @@ def make(self, key): task_mode, output_dir = (PoseEstimationTask & key).fetch1( "task_mode", "pose_estimation_output_dir" ) - + if not output_dir: + output_dir = PoseEstimationTask.infer_output_dir( + key, relative=True, mkdir=True + ) + # update pose_estimation_output_dir + PoseEstimationTask.update1( + {**key, "pose_estimation_output_dir": output_dir.as_posix()} + ) output_dir = find_full_path(get_dlc_root_data_dir(), output_dir) # Triger PoseEstimation diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index b4f2765..750f3bd 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -11,6 +11,7 @@ from pathlib import Path from element_interface.utils import find_full_path, dict_to_uuid from .readers import dlc_reader +import yaml schema = dj.schema() _linking_module = None @@ -241,8 +242,7 @@ class ModelTraining(dj.Computed): # https://github.com/DeepLabCut/DeepLabCut/issues/70 def make(self, key): - from deeplabcut import train_network # isort:skip - + import deeplabcut try: from deeplabcut.utils.auxiliaryfunctions import ( get_model_folder, @@ -288,13 +288,26 @@ def make(self, key): ) model_train_folder = project_path / model_folder / "train" + # update init_weight + with open(model_train_folder / "pose_cfg.yaml", "r") as f: + pose_cfg = yaml.safe_load(f) + init_weights_path = Path(pose_cfg["init_weights"]) + + if "pose_estimation_tensorflow/models/pretrained" in init_weights_path.as_posix(): + # this is the res_net models, construct new path here + init_weights_path = Path(deeplabcut.__path__[0]) / "pose_estimation_tensorflow/models/pretrained" / init_weights_path.name + else: + # this is existing snapshot weights, update path here + init_weights_path = model_train_folder / init_weights_path.name + edit_config( model_train_folder / "pose_cfg.yaml", - {"project_path": project_path.as_posix()}, + {"project_path": project_path.as_posix(), + "init_weights": init_weights_path.as_posix()}, ) # ---- Trigger DLC model training job ---- - train_network_input_args = list(inspect.signature(train_network).parameters) + train_network_input_args = list(inspect.signature(deeplabcut.train_network).parameters) train_network_kwargs = { k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v for k, v in dlc_config.items() @@ -304,27 +317,32 @@ def make(self, key): train_network_kwargs[k] = int(train_network_kwargs[k]) try: - train_network(dlc_cfg_filepath, **train_network_kwargs) + deeplabcut.train_network(dlc_cfg_filepath, **train_network_kwargs) except KeyboardInterrupt: # Instructions indicate to train until interrupt print("DLC training stopped via Keyboard Interrupt") - snapshots = list(model_train_folder.glob("*index*")) + snapshots = sorted(list(model_train_folder.glob("*index*"))) max_modified_time = 0 # DLC goes by snapshot magnitude when judging 'latest' for evaluation # Here, we mean most recently generated for snapshot in snapshots: modified_time = os.path.getmtime(snapshot) if modified_time > max_modified_time: - latest_snapshot = int(snapshot.stem[9:]) + latest_snapshot_file = snapshot + latest_snapshot = int(latest_snapshot_file.stem[9:]) max_modified_time = modified_time # update snapshotindex in the config - dlc_config["snapshotindex"] = latest_snapshot + snapshotindex = snapshots.index(latest_snapshot_file) + + dlc_config["snapshotindex"] = snapshotindex edit_config( dlc_cfg_filepath, - {"snapshotindex": latest_snapshot}, + {"snapshotindex": snapshotindex}, ) self.insert1( {**key, "latest_snapshot": latest_snapshot, "config_template": dlc_config} ) + +