diff --git a/keypoint_moseq/io.py b/keypoint_moseq/io.py index ab97372..889a149 100644 --- a/keypoint_moseq/io.py +++ b/keypoint_moseq/io.py @@ -31,9 +31,7 @@ def _build_yaml(sections, comments): return "\n".join(text_blocks) -def _get_path( - project_dir, model_name, path, filename, pathname_for_error_msg="path" -): +def _get_path(project_dir, model_name, path, filename, pathname_for_error_msg="path"): if path is None: assert project_dir is not None and model_name is not None, fill( f"`model_name` and `project_dir` are required if `{pathname_for_error_msg}` is None." @@ -258,16 +256,10 @@ def load_config(project_dir, check_if_valid=True, build_indexes=True): if build_indexes: config["anterior_idxs"] = jnp.array( - [ - config["use_bodyparts"].index(bp) - for bp in config["anterior_bodyparts"] - ] + [config["use_bodyparts"].index(bp) for bp in config["anterior_bodyparts"]] ) config["posterior_idxs"] = jnp.array( - [ - config["use_bodyparts"].index(bp) - for bp in config["posterior_bodyparts"] - ] + [config["use_bodyparts"].index(bp) for bp in config["posterior_bodyparts"]] ) if not "skeleton" in config or config["skeleton"] is None: @@ -297,9 +289,7 @@ def update_config(project_dir, **kwargs): >>> print(load_config(project_dir)['trans_hypparams']['kappa']) 100 """ - config = load_config( - project_dir, check_if_valid=False, build_indexes=False - ) + config = load_config(project_dir, check_if_valid=False, build_indexes=False) config.update(kwargs) generate_config(project_dir, **config) @@ -364,14 +354,9 @@ def setup_project( f"{deeplabcut_config} does not exists or is not a" " valid yaml file" ) - if ( - "multianimalproject" in dlc_config - and dlc_config["multianimalproject"] - ): + if "multianimalproject" in dlc_config and dlc_config["multianimalproject"]: dlc_options["bodyparts"] = dlc_config["multianimalbodyparts"] - dlc_options["use_bodyparts"] = dlc_config[ - "multianimalbodyparts" - ] + dlc_options["use_bodyparts"] = dlc_config["multianimalbodyparts"] else: dlc_options["bodyparts"] = dlc_config["bodyparts"] dlc_options["use_bodyparts"] = dlc_config["bodyparts"] @@ -392,15 +377,12 @@ def setup_project( ) skeleton = slp_file.skeletons[0] node_names = skeleton.node_names - edge_names = [ - [e.source.name, e.destination.name] for e in skeleton.edges - ] + edge_names = [[e.source.name, e.destination.name] for e in skeleton.edges] else: with h5py.File(sleap_file, "r") as f: node_names = [n.decode("utf-8") for n in f["node_names"]] edge_names = [ - [n.decode("utf-8") for n in edge] - for edge in f["edge_names"] + [n.decode("utf-8") for n in edge] for edge in f["edge_names"] ] sleap_options["bodyparts"] = node_names sleap_options["use_bodyparts"] = node_names @@ -457,15 +439,11 @@ def load_pca(project_dir, pca_path=None): """ if pca_path is None: pca_path = os.path.join(project_dir, "pca.p") - assert os.path.exists(pca_path), fill( - f"No PCA model found at {pca_path}" - ) + assert os.path.exists(pca_path), fill(f"No PCA model found at {pca_path}") return joblib.load(pca_path) -def load_checkpoint( - project_dir=None, model_name=None, path=None, iteration=None -): +def load_checkpoint(project_dir=None, model_name=None, path=None, iteration=None): """Load data and model snapshot from a saved checkpoint. The checkpoint path can be specified directly via `path` or else it is @@ -571,9 +549,7 @@ def reindex_syllables_in_checkpoint( num_states = f[f"model_snapshots/{last_iter}/params/pi"].shape[0] z = f[f"model_snapshots/{last_iter}/states/z"][()] mask = f["data/mask"][()] - index = np.argsort(get_frequencies(z, mask, num_states, runlength))[ - ::-1 - ] + index = np.argsort(get_frequencies(z, mask, num_states, runlength))[::-1] def _reindex(model): model["params"]["betas"] = model["params"]["betas"][index] @@ -651,7 +627,7 @@ def extract_results( # extract syllables; repeat first syllable an extra `nlags` times nlags = states["x"].shape[1] - states["z"].shape[1] - z = np.pad(states["z"], ((0,0),(nlags, 0)), mode="edge") + z = np.pad(states["z"], ((0, 0), (nlags, 0)), mode="edge") syllables = unbatch(z, *metadata) # extract latent state, centroid, and heading @@ -724,9 +700,7 @@ def save_results_as_csv( If a path separator ("/" or "\") is present in the recording name, it will be replaced with `path_sep` when saving the csv file. """ - save_dir = _get_path( - project_dir, model_name, save_dir, "results", "save_dir" - ) + save_dir = _get_path(project_dir, model_name, save_dir, "results", "save_dir") if not os.path.exists(save_dir): os.makedirs(save_dir) @@ -749,15 +723,10 @@ def save_results_as_csv( if "latent_state" in results[key].keys(): latent_dim = results[key]["latent_state"].shape[1] - column_names.append( - [f"latent_state {i}" for i in range(latent_dim)] - ) + column_names.append([f"latent_state {i}" for i in range(latent_dim)]) data.append(results[key]["latent_state"]) - dfs = [ - pd.DataFrame(arr, columns=cols) - for arr, cols in zip(data, column_names) - ] + dfs = [pd.DataFrame(arr, columns=cols) for arr, cols in zip(data, column_names)] df = pd.concat(dfs, axis=1) for col in df.select_dtypes(include=[np.floating]).columns: @@ -790,6 +759,7 @@ def load_keypoints( path_sep="-", path_in_name=False, remove_extension=True, + exclude_individuals=["single"], ): """ Load keypoint tracking results from one or more files. Several file @@ -886,6 +856,10 @@ def load_keypoints( Whether to remove the file extension when naming the tracking results from each file. + exclude_individuals: list of str, default=["single"] + List of individuals to exclude from the results. This is only used for + multi-animal tracking with deeplabcut. + Returns ------- coordinates: dict @@ -933,9 +907,12 @@ def load_keypoints( "facemap": _facemap_loader, }[format] - filepaths = list_files_with_exts( - filepath_pattern, extensions, recursive=recursive - ) + if format == "deeplabcut": + additional_args = {"exclude_individuals": exclude_individuals} + else: + additional_args = {} + + filepaths = list_files_with_exts(filepath_pattern, extensions, recursive=recursive) assert len(filepaths) > 0, fill( f"No files with extensions {extensions} found for {filepath_pattern}" ) @@ -943,11 +920,9 @@ def load_keypoints( coordinates, confidences, bodyparts = {}, {}, None for filepath in tqdm.tqdm(filepaths, desc=f"Loading keypoints", ncols=72): try: - name = _name_from_path( - filepath, path_in_name, path_sep, remove_extension - ) + name = _name_from_path(filepath, path_in_name, path_sep, remove_extension) new_coordinates, new_confidences, bodyparts = loader( - filepath, name + filepath, name, **additional_args ) if set(new_coordinates.keys()) & set(coordinates.keys()): @@ -967,15 +942,13 @@ def load_keypoints( coordinates.update(new_coordinates) confidences.update(new_confidences) - assert len(coordinates) > 0, fill( - f"No valid results found for {filepath_pattern}" - ) + assert len(coordinates) > 0, fill(f"No valid results found for {filepath_pattern}") check_nan_proportions(coordinates, bodyparts) return coordinates, confidences, bodyparts -def _deeplabcut_loader(filepath, name): +def _deeplabcut_loader(filepath, name, exclude_individuals=["single"]): """Load tracking results from deeplabcut csv or hdf5 files.""" ext = os.path.splitext(filepath)[1] if ext == ".h5": @@ -990,14 +963,30 @@ def _deeplabcut_loader(filepath, name): df = pd.read_csv(filepath, header=header, index_col=0) coordinates, confidences = {}, {} - bodyparts = df.columns.get_level_values("bodyparts").unique().tolist() if "individuals" in df.columns.names: + ind_bodyparts = {} for ind in df.columns.get_level_values("individuals").unique(): - ind_df = df.xs(ind, axis=1, level="individuals") - arr = ind_df.to_numpy().reshape(len(ind_df), -1, 3) - coordinates[f"{name}_{ind}"] = arr[:, :, :-1] - confidences[f"{name}_{ind}"] = arr[:, :, -1] + if ind in exclude_individuals: + print( + f'Excluding individual: "{ind}". Set `exclude_individuals=[]` to include.' + ) + else: + ind_df = df.xs(ind, axis=1, level="individuals") + bps = ind_df.columns.get_level_values("bodyparts").unique().tolist() + ind_bodyparts[ind] = bps + + arr = ind_df.to_numpy().reshape(len(ind_df), -1, 3) + coordinates[f"{name}_{ind}"] = arr[:, :, :-1] + confidences[f"{name}_{ind}"] = arr[:, :, -1] + + bodyparts = set(ind_bodyparts[list(ind_bodyparts.keys())[0]]) + assert all([set(bps) == bodyparts for bps in ind_bodyparts.values()]), ( + f"Bodyparts are not consistent across individuals. The following bodyparts " + f"were found for each individual: {ind_bodyparts}. Use `exclude_individuals`" + "to exclude specific individuals." + ) else: + bodyparts = df.columns.get_level_values("bodyparts").unique().tolist() arr = df.to_numpy().reshape(len(df), -1, 3) coordinates[name] = arr[:, :, :-1] confidences[name] = arr[:, :, -1] @@ -1030,12 +1019,8 @@ def _sleap_loader(filepath, name): coordinates = {name: coords[0].T} confidences = {name: confs[0].T} else: - coordinates = { - f"{name}_track{i}": coords[i].T for i in range(coords.shape[0]) - } - confidences = { - f"{name}_track{i}": confs[i].T for i in range(coords.shape[0]) - } + coordinates = {f"{name}_track{i}": coords[i].T for i in range(coords.shape[0])} + confidences = {f"{name}_track{i}": confs[i].T for i in range(coords.shape[0])} return coordinates, confidences, bodyparts @@ -1050,10 +1035,7 @@ def _anipose_loader(filepath, name): df = pd.read_csv(filepath) coordinates = { name: np.stack( - [ - df[[f"{bp}_x", f"{bp}_y", f"{bp}_z"]].to_numpy() - for bp in bodyparts - ], + [df[[f"{bp}_x", f"{bp}_y", f"{bp}_z"]].to_numpy() for bp in bodyparts], axis=1, ) } @@ -1075,8 +1057,7 @@ def _sleap_anipose_loader(filepath, name): confidences = {name: confs[:, 0]} else: coordinates = { - f"{name}_track{i}": coords[:, i] - for i in range(coords.shape[1]) + f"{name}_track{i}": coords[:, i] for i in range(coords.shape[1]) } confidences = { f"{name}_track{i}": confs[:, i] for i in range(coords.shape[1]) @@ -1088,9 +1069,7 @@ def _load_nwb_pose_obj(io, filepath): """Grab PoseEstimation object from an opened .nwb file.""" all_objs = io.read().all_children() pose_objs = [o for o in all_objs if isinstance(o, PoseEstimation)] - assert len(pose_objs) > 0, fill( - f"No PoseEstimation objects found in {filepath}" - ) + assert len(pose_objs) > 0, fill(f"No PoseEstimation objects found in {filepath}") assert len(pose_objs) == 1, fill( f"Found multiple PoseEstimation objects in {filepath}. " "This is not currently supported. Please open a github " @@ -1109,10 +1088,7 @@ def _nwb_loader(filepath, name): [pose_obj.pose_estimation_series[bp].data[()] for bp in bodyparts], axis=1, ) - if ( - "confidence" - in pose_obj.pose_estimation_series[bodyparts[0]].fields - ): + if "confidence" in pose_obj.pose_estimation_series[bodyparts[0]].fields: confs = np.stack( [ pose_obj.pose_estimation_series[bp].confidence[()]