Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify training to update init_weights path #100

Merged
merged 4 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.

## [0.2.11] - 2023-11-28

+ Fix - Modify training to update init_weights path in pose_cfg.yaml

## [0.2.10] - 2023-11-20

+ Fix - Revert fixing of networkx version in setup
Expand Down
23 changes: 18 additions & 5 deletions element_deeplabcut/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pathlib import Path
from element_interface.utils import find_full_path, dict_to_uuid
from .readers import dlc_reader
import yaml

schema = dj.schema()
_linking_module = None
Expand Down Expand Up @@ -241,8 +242,7 @@ class ModelTraining(dj.Computed):
# https://github.com/DeepLabCut/DeepLabCut/issues/70

def make(self, key):
from deeplabcut import train_network # isort:skip

import deeplabcut
try:
from deeplabcut.utils.auxiliaryfunctions import (
get_model_folder,
Expand Down Expand Up @@ -288,13 +288,26 @@ def make(self, key):
)
model_train_folder = project_path / model_folder / "train"

# update init_weight
with open(model_train_folder / "pose_cfg.yaml", "r") as f:
pose_cfg = yaml.safe_load(f)
init_weights_path = Path(pose_cfg["init_weights"])

if "pose_estimation_tensorflow/models/pretrained" in init_weights_path.as_posix():
# this is the res_net models, construct new path here
init_weights_path = Path(deeplabcut.__path__[0]) / "pose_estimation_tensorflow/models/pretrained" / init_weights_path.name
else:
# this is existing snapshot weights, update path here
init_weights_path = model_train_folder / init_weights_path.name

edit_config(
model_train_folder / "pose_cfg.yaml",
{"project_path": project_path.as_posix()},
{"project_path": project_path.as_posix(),
"init_weights": init_weights_path.as_posix()},
)

# ---- Trigger DLC model training job ----
train_network_input_args = list(inspect.signature(train_network).parameters)
train_network_input_args = list(inspect.signature(deeplabcut.train_network).parameters)
train_network_kwargs = {
k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v
for k, v in dlc_config.items()
Expand All @@ -304,7 +317,7 @@ def make(self, key):
train_network_kwargs[k] = int(train_network_kwargs[k])

try:
train_network(dlc_cfg_filepath, **train_network_kwargs)
deeplabcut.train_network(dlc_cfg_filepath, **train_network_kwargs)
except KeyboardInterrupt: # Instructions indicate to train until interrupt
print("DLC training stopped via Keyboard Interrupt")

Expand Down
2 changes: 1 addition & 1 deletion element_deeplabcut/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
Package metadata
"""
__version__ = "0.2.10"
__version__ = "0.2.11"
Loading