Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update yaml safe loading syntax #102

Closed
wants to merge 11 commits into from
9 changes: 8 additions & 1 deletion element_deeplabcut/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,14 @@ def make(self, key):
task_mode, output_dir = (PoseEstimationTask & key).fetch1(
"task_mode", "pose_estimation_output_dir"
)

if not output_dir:
output_dir = PoseEstimationTask.infer_output_dir(
key, relative=True, mkdir=True
)
# update pose_estimation_output_dir
PoseEstimationTask.update1(
{**key, "pose_estimation_output_dir": output_dir.as_posix()}
)
output_dir = find_full_path(get_dlc_root_data_dir(), output_dir)

# Triger PoseEstimation
Expand Down
36 changes: 27 additions & 9 deletions element_deeplabcut/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pathlib import Path
from element_interface.utils import find_full_path, dict_to_uuid
from .readers import dlc_reader
import yaml

schema = dj.schema()
_linking_module = None
Expand Down Expand Up @@ -241,8 +242,7 @@ class ModelTraining(dj.Computed):
# https://github.com/DeepLabCut/DeepLabCut/issues/70

def make(self, key):
from deeplabcut import train_network # isort:skip

import deeplabcut
try:
from deeplabcut.utils.auxiliaryfunctions import (
get_model_folder,
Expand Down Expand Up @@ -288,13 +288,26 @@ def make(self, key):
)
model_train_folder = project_path / model_folder / "train"

# update init_weight
with open(model_train_folder / "pose_cfg.yaml", "r") as f:
pose_cfg = yaml.safe_load(f)
init_weights_path = Path(pose_cfg["init_weights"])

if "pose_estimation_tensorflow/models/pretrained" in init_weights_path.as_posix():
# this is the res_net models, construct new path here
init_weights_path = Path(deeplabcut.__path__[0]) / "pose_estimation_tensorflow/models/pretrained" / init_weights_path.name
else:
# this is existing snapshot weights, update path here
init_weights_path = model_train_folder / init_weights_path.name

edit_config(
model_train_folder / "pose_cfg.yaml",
{"project_path": project_path.as_posix()},
{"project_path": project_path.as_posix(),
"init_weights": init_weights_path.as_posix()},
)

# ---- Trigger DLC model training job ----
train_network_input_args = list(inspect.signature(train_network).parameters)
train_network_input_args = list(inspect.signature(deeplabcut.train_network).parameters)
train_network_kwargs = {
k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v
for k, v in dlc_config.items()
Expand All @@ -304,27 +317,32 @@ def make(self, key):
train_network_kwargs[k] = int(train_network_kwargs[k])

try:
train_network(dlc_cfg_filepath, **train_network_kwargs)
deeplabcut.train_network(dlc_cfg_filepath, **train_network_kwargs)
except KeyboardInterrupt: # Instructions indicate to train until interrupt
print("DLC training stopped via Keyboard Interrupt")

snapshots = list(model_train_folder.glob("*index*"))
snapshots = sorted(list(model_train_folder.glob("*index*")))
max_modified_time = 0
# DLC goes by snapshot magnitude when judging 'latest' for evaluation
# Here, we mean most recently generated
for snapshot in snapshots:
modified_time = os.path.getmtime(snapshot)
if modified_time > max_modified_time:
latest_snapshot = int(snapshot.stem[9:])
latest_snapshot_file = snapshot
latest_snapshot = int(latest_snapshot_file.stem[9:])
max_modified_time = modified_time

# update snapshotindex in the config
dlc_config["snapshotindex"] = latest_snapshot
snapshotindex = snapshots.index(latest_snapshot_file)

dlc_config["snapshotindex"] = snapshotindex
edit_config(
dlc_cfg_filepath,
{"snapshotindex": latest_snapshot},
{"snapshotindex": snapshotindex},
)

self.insert1(
{**key, "latest_snapshot": latest_snapshot, "config_template": dlc_config}
)


Loading