Skip to content

Commit

Permalink
Merge pull request #100 from sidhulyalkar/main
Browse files Browse the repository at this point in the history
Modify training to update init_weights path
  • Loading branch information
Thinh Nguyen authored Nov 29, 2023
2 parents 11802ff + fcc7007 commit fcdb2ec
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.

## [0.2.11] - 2023-11-28

+ Fix - Modify training to update init_weights path in pose_cfg.yaml

## [0.2.10] - 2023-11-20

+ Fix - Revert fixing of networkx version in setup
Expand Down
23 changes: 18 additions & 5 deletions element_deeplabcut/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pathlib import Path
from element_interface.utils import find_full_path, dict_to_uuid
from .readers import dlc_reader
import yaml

schema = dj.schema()
_linking_module = None
Expand Down Expand Up @@ -241,8 +242,7 @@ class ModelTraining(dj.Computed):
# https://github.com/DeepLabCut/DeepLabCut/issues/70

def make(self, key):
from deeplabcut import train_network # isort:skip

import deeplabcut
try:
from deeplabcut.utils.auxiliaryfunctions import (
get_model_folder,
Expand Down Expand Up @@ -288,13 +288,26 @@ def make(self, key):
)
model_train_folder = project_path / model_folder / "train"

# update init_weight
with open(model_train_folder / "pose_cfg.yaml", "r") as f:
pose_cfg = yaml.safe_load(f)
init_weights_path = Path(pose_cfg["init_weights"])

if "pose_estimation_tensorflow/models/pretrained" in init_weights_path.as_posix():
# this is the res_net models, construct new path here
init_weights_path = Path(deeplabcut.__path__[0]) / "pose_estimation_tensorflow/models/pretrained" / init_weights_path.name
else:
# this is existing snapshot weights, update path here
init_weights_path = model_train_folder / init_weights_path.name

edit_config(
model_train_folder / "pose_cfg.yaml",
{"project_path": project_path.as_posix()},
{"project_path": project_path.as_posix(),
"init_weights": init_weights_path.as_posix()},
)

# ---- Trigger DLC model training job ----
train_network_input_args = list(inspect.signature(train_network).parameters)
train_network_input_args = list(inspect.signature(deeplabcut.train_network).parameters)
train_network_kwargs = {
k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v
for k, v in dlc_config.items()
Expand All @@ -304,7 +317,7 @@ def make(self, key):
train_network_kwargs[k] = int(train_network_kwargs[k])

try:
train_network(dlc_cfg_filepath, **train_network_kwargs)
deeplabcut.train_network(dlc_cfg_filepath, **train_network_kwargs)
except KeyboardInterrupt: # Instructions indicate to train until interrupt
print("DLC training stopped via Keyboard Interrupt")

Expand Down
2 changes: 1 addition & 1 deletion element_deeplabcut/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
Package metadata
"""
__version__ = "0.2.10"
__version__ = "0.2.11"

0 comments on commit fcdb2ec

Please sign in to comment.