Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
sidhulyalkar committed Nov 20, 2023
2 parents 09a75cb + fbd1ff6 commit c0181c1
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
19 changes: 12 additions & 7 deletions element_deeplabcut/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@ def insert_new_model(
*,
shuffle: int,
trainingsetindex,
project_path=None,
model_description="",
model_prefix="",
paramset_idx: int = None,
Expand Down Expand Up @@ -450,13 +449,19 @@ def insert_new_model(
print("Canceled insert.")
return

def _do_insert():
cls.insert1(model_dict)
# Returns array, so check size for unambiguous truth value
if BodyPart.extract_new_body_parts(dlc_config, verbose=False).size > 0:
BodyPart.insert_from_config(dlc_config, prompt=prompt)
cls.BodyPart.insert((model_name, bp) for bp in dlc_config["bodyparts"])

# ____ Insert into table ----
# with cls.connection.transaction:
cls.insert1(model_dict)
# Returns array, so check size for unambiguous truth value
if BodyPart.extract_new_body_parts(dlc_config, verbose=False).size > 0:
BodyPart.insert_from_config(dlc_config, prompt=prompt)
cls.BodyPart.insert((model_name, bp) for bp in dlc_config["bodyparts"])
if cls.connection.in_transaction:
_do_insert()
else:
with cls.connection.transaction:
_do_insert()


@schema
Expand Down
7 changes: 7 additions & 0 deletions element_deeplabcut/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,13 @@ def make(self, key):
latest_snapshot = int(snapshot.stem[9:])
max_modified_time = modified_time

# update snapshotindex in the config
dlc_config["snapshotindex"] = latest_snapshot
edit_config(
dlc_cfg_filepath,
{"snapshotindex": latest_snapshot},
)

self.insert1(
{**key, "latest_snapshot": latest_snapshot, "config_template": dlc_config}
)
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
"datajoint>=0.13",
"graphviz",
"pydot",
"networkx==2.8.2",
"ipykernel",
"ipywidgets",
],
Expand Down

0 comments on commit c0181c1

Please sign in to comment.