diff --git a/element_deeplabcut/model.py b/element_deeplabcut/model.py index 9b83921..1aa3902 100644 --- a/element_deeplabcut/model.py +++ b/element_deeplabcut/model.py @@ -354,7 +354,6 @@ def insert_new_model( *, shuffle: int, trainingsetindex, - project_path=None, model_description="", model_prefix="", paramset_idx: int = None, @@ -450,13 +449,19 @@ def insert_new_model( print("Canceled insert.") return + def _do_insert(): + cls.insert1(model_dict) + # Returns array, so check size for unambiguous truth value + if BodyPart.extract_new_body_parts(dlc_config, verbose=False).size > 0: + BodyPart.insert_from_config(dlc_config, prompt=prompt) + cls.BodyPart.insert((model_name, bp) for bp in dlc_config["bodyparts"]) + # ____ Insert into table ---- - # with cls.connection.transaction: - cls.insert1(model_dict) - # Returns array, so check size for unambiguous truth value - if BodyPart.extract_new_body_parts(dlc_config, verbose=False).size > 0: - BodyPart.insert_from_config(dlc_config, prompt=prompt) - cls.BodyPart.insert((model_name, bp) for bp in dlc_config["bodyparts"]) + if cls.connection.in_transaction: + _do_insert() + else: + with cls.connection.transaction: + _do_insert() @schema diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index cfb8f0b..b4f2765 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -318,6 +318,13 @@ def make(self, key): latest_snapshot = int(snapshot.stem[9:]) max_modified_time = modified_time + # update snapshotindex in the config + dlc_config["snapshotindex"] = latest_snapshot + edit_config( + dlc_cfg_filepath, + {"snapshotindex": latest_snapshot}, + ) + self.insert1( {**key, "latest_snapshot": latest_snapshot, "config_template": dlc_config} ) diff --git a/setup.py b/setup.py index e501313..5ddda39 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,6 @@ "datajoint>=0.13", "graphviz", "pydot", - "networkx==2.8.2", "ipykernel", "ipywidgets", ],