From 87e0aaeb197343937886dae992cf2ca08b37f560 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 28 Nov 2023 17:25:20 -0800 Subject: [PATCH 1/4] update init_weights_path before training model --- element_deeplabcut/train.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index b4f2765..6cde106 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -11,6 +11,7 @@ from pathlib import Path from element_interface.utils import find_full_path, dict_to_uuid from .readers import dlc_reader +import yaml schema = dj.schema() _linking_module = None @@ -241,8 +242,7 @@ class ModelTraining(dj.Computed): # https://github.com/DeepLabCut/DeepLabCut/issues/70 def make(self, key): - from deeplabcut import train_network # isort:skip - + import deeplabcut try: from deeplabcut.utils.auxiliaryfunctions import ( get_model_folder, @@ -288,13 +288,26 @@ def make(self, key): ) model_train_folder = project_path / model_folder / "train" + # update init_weight + with open(model_train_folder / "pose_cfg.yaml", "r") as f: + pose_cfg = yaml.safe_load(f) + init_weights_path = Path(pose_cfg["init_weights"]) + + if "pose_estimation_tensorflow/models/pretrained" in init_weights_path.as_posix(): + # this is the res_net models, construct new path here + init_weights_path = Path(deeplabcut.__file__).parent / "pose_estimation_tensorflow/models/pretrained" / init_weights_path.name + else: + # this is existing snapshot weights, update path here + init_weights_path = model_train_folder / init_weights_path.name + edit_config( model_train_folder / "pose_cfg.yaml", - {"project_path": project_path.as_posix()}, + {"project_path": project_path.as_posix(), + "init_weights": init_weights_path.as_posix()}, ) # ---- Trigger DLC model training job ---- - train_network_input_args = list(inspect.signature(train_network).parameters) + train_network_input_args = list(inspect.signature(deeplabcut.train_network).parameters) train_network_kwargs = { k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v for k, v in dlc_config.items() @@ -304,7 +317,7 @@ def make(self, key): train_network_kwargs[k] = int(train_network_kwargs[k]) try: - train_network(dlc_cfg_filepath, **train_network_kwargs) + deeplabcut.train_network(dlc_cfg_filepath, **train_network_kwargs) except KeyboardInterrupt: # Instructions indicate to train until interrupt print("DLC training stopped via Keyboard Interrupt") From 935938e12a5892aba91951a516b5631221c270ff Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 28 Nov 2023 17:36:04 -0800 Subject: [PATCH 2/4] update to use deeplabcut.__path__ --- element_deeplabcut/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index 6cde106..2f02193 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -295,7 +295,7 @@ def make(self, key): if "pose_estimation_tensorflow/models/pretrained" in init_weights_path.as_posix(): # this is the res_net models, construct new path here - init_weights_path = Path(deeplabcut.__file__).parent / "pose_estimation_tensorflow/models/pretrained" / init_weights_path.name + init_weights_path = Path(deeplabcut.__path__[0]) / "pose_estimation_tensorflow/models/pretrained" / init_weights_path.name else: # this is existing snapshot weights, update path here init_weights_path = model_train_folder / init_weights_path.name From 4543d087ccf51540d77a8b04facd4664dc8fae3d Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 28 Nov 2023 19:31:35 -0800 Subject: [PATCH 3/4] update version --- element_deeplabcut/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_deeplabcut/version.py b/element_deeplabcut/version.py index 1719e56..1ce9d0b 100644 --- a/element_deeplabcut/version.py +++ b/element_deeplabcut/version.py @@ -1,4 +1,4 @@ """ Package metadata """ -__version__ = "0.2.10" +__version__ = "0.2.11" From fcc700737d6c1ffa7f167a2ed2b76b6e85443dc2 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 28 Nov 2023 19:34:42 -0800 Subject: [PATCH 4/4] update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f7f5241..b9a9e10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,10 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention. +## [0.2.11] - 2023-11-28 + ++ Fix - Modify training to update init_weights path in pose_cfg.yaml + ## [0.2.10] - 2023-11-20 + Fix - Revert fixing of networkx version in setup