Skip to content

Commit

Permalink
Merge 'main' of datajoint/element-deeplabcut
Browse files Browse the repository at this point in the history
  • Loading branch information
kabilar committed Aug 7, 2023
2 parents b16c386 + 78ac980 commit 8ca047d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.

## [0.2.7] - 2023-08-04
## [0.2.8] - 2023-08-07

+ Update - GitHub Actions with new reusable workflows
+ Update - Readme instructions

## [0.2.7] - 2023-08-04

+ Fix - Update the project path in the pose config file to train the model

## [0.2.6] - 2023-05-22

+ Add - DeepLabCut, NWB, and DANDI citations
Expand Down Expand Up @@ -73,6 +77,7 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
graciously provided by the Mathis Lab.
+ Add - Support for 2d single-animal models

[0.2.8]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.8
[0.2.7]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.7
[0.2.6]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.6
[0.2.5]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.5
Expand Down
28 changes: 16 additions & 12 deletions element_deeplabcut/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def make(self, key):
try:
from deeplabcut.utils.auxiliaryfunctions import (
get_model_folder,
edit_config,
) # isort:skip
except ImportError:
from deeplabcut.utils.auxiliaryfunctions import (
Expand Down Expand Up @@ -278,6 +279,20 @@ def make(self, key):
# Write dlc config file to base project folder
dlc_cfg_filepath = dlc_reader.save_yaml(project_path, dlc_config)

# ---- Update the project path in the DLC pose configuration (yaml) files ----
model_folder = get_model_folder(
trainFraction=dlc_config["train_fraction"],
shuffle=dlc_config["shuffle"],
cfg=dlc_config,
modelprefix=dlc_config["modelprefix"],
)
model_train_folder = project_path / model_folder / "train"

edit_config(
model_train_folder / "pose_cfg.yaml",
{"project_path": project_path.as_posix()},
)

# ---- Trigger DLC model training job ----
train_network_input_args = list(inspect.signature(train_network).parameters)
train_network_kwargs = {
Expand All @@ -293,18 +308,7 @@ def make(self, key):
except KeyboardInterrupt: # Instructions indicate to train until interrupt
print("DLC training stopped via Keyboard Interrupt")

snapshots = list(
(
project_path
/ get_model_folder(
trainFraction=dlc_config["train_fraction"],
shuffle=dlc_config["shuffle"],
cfg=dlc_config,
modelprefix=dlc_config["modelprefix"],
)
/ "train"
).glob("*index*")
)
snapshots = list(model_train_folder.glob("*index*"))
max_modified_time = 0
# DLC goes by snapshot magnitude when judging 'latest' for evaluation
# Here, we mean most recently generated
Expand Down

0 comments on commit 8ca047d

Please sign in to comment.