From 3c0328bc899677fbf03c485dd20fce033e3632a6 Mon Sep 17 00:00:00 2001 From: Caleb Weinreb Date: Thu, 21 Dec 2023 15:17:08 -0500 Subject: [PATCH 1/2] removed linestyles="None" --- keypoint_moseq/analysis.py | 100 ++++++++++--------------------------- 1 file changed, 26 insertions(+), 74 deletions(-) diff --git a/keypoint_moseq/analysis.py b/keypoint_moseq/analysis.py index b7e848d..dfa806d 100644 --- a/keypoint_moseq/analysis.py +++ b/keypoint_moseq/analysis.py @@ -88,9 +88,7 @@ def generate_index(project_dir, model_name, index_filepath): else: # generate a new index file results_dict = load_results(project_dir, model_name) - index_df = pd.DataFrame( - {"name": list(results_dict.keys()), "group": "default"} - ) + index_df = pd.DataFrame({"name": list(results_dict.keys()), "group": "default"}) # write index dataframe index_df.to_csv(index_filepath, index=False) @@ -218,9 +216,7 @@ def compute_moseq_df(project_dir, model_name, *, fps=30, smooth_heading=True): np.concatenate( ( [0], - np.sqrt( - np.square(np.diff(v["centroid"], axis=0)).sum(axis=1) - ) + np.sqrt(np.square(np.diff(v["centroid"], axis=0)).sum(axis=1)) * fps, ) ) @@ -229,8 +225,7 @@ def compute_moseq_df(project_dir, model_name, *, fps=30, smooth_heading=True): if index_data is not None: # find the group for each recording from index data s_group.append( - [index_data[index_data["name"] == k]["group"].values[0]] - * n_frame + [index_data[index_data["name"] == k]["group"].values[0]] * n_frame ) else: # no index data @@ -360,24 +355,18 @@ def compute_stats_df( ["heading", "angular_velocity", "velocity_px_s"] ].agg(["mean", "std", "min", "max"]) - features.columns = [ - "_".join(col).strip() for col in features.columns.values - ] + features.columns = ["_".join(col).strip() for col in features.columns.values] features.reset_index(inplace=True) # get durations trials = filtered_df["onset"].cumsum() trials.name = "trials" - durations = filtered_df.groupby(groupby + ["syllable"] + [trials])[ - "onset" - 
].count() + durations = filtered_df.groupby(groupby + ["syllable"] + [trials])["onset"].count() # average duration in seconds durations = durations.groupby(groupby + ["syllable"]).mean() / fps durations.name = "duration" # only keep the columns we need - durations = durations.fillna(0).reset_index()[ - groupby + ["syllable", "duration"] - ] + durations = durations.fillna(0).reset_index()[groupby + ["syllable", "duration"]] stats_df = pd.merge(features, frequency_df, on=groupby + ["syllable"]) stats_df = pd.merge(stats_df, durations, on=groupby + ["syllable"]) @@ -400,9 +389,7 @@ def generate_syll_info(project_dir, model_name, syll_info_path): } ) - grid_movies = glob( - os.path.join(project_dir, model_name, "grid_movies", "*.mp4") - ) + grid_movies = glob(os.path.join(project_dir, model_name, "grid_movies", "*.mp4")) assert len(grid_movies) > 0, ( "No grid movies found. Please run `generate_grid_movies` as described in the docs: " "https://keypoint-moseq.readthedocs.io/en/latest/modeling.html#visualization" @@ -443,9 +430,7 @@ def label_syllables(project_dir, model_name, moseq_df): generate_syll_info(project_dir, model_name, syll_info_path) # ensure there is grid movies - grid_movies = glob( - os.path.join(project_dir, model_name, "grid_movies", "*.mp4") - ) + grid_movies = glob(os.path.join(project_dir, model_name, "grid_movies", "*.mp4")) assert len(grid_movies) > 0, ( "No grid movies found. 
Please run `generate_grid_movies` as described in the docs: " "https://keypoint-moseq.readthedocs.io/en/latest/modeling.html#visualization" @@ -479,9 +464,7 @@ def show_movie(syllable): # create the labeler dataframe # only include the syllable that have grid movies include = syll_info_df_with_movie.syllable.values - syll_df = ( - moseq_df[["syllable"]].groupby("syllable").mean().reset_index().copy() - ) + syll_df = moseq_df[["syllable"]].groupby("syllable").mean().reset_index().copy() syll_df = syll_df[syll_df.syllable.isin(include)] # get labels and description from syll info @@ -530,9 +513,7 @@ def show_movie(syllable): configuration=base_configuration, ) - button = pn.widgets.Button( - name="Save syllable info", button_type="primary" - ) + button = pn.widgets.Button(name="Save syllable info", button_type="primary") # call back function to save the index file def save_index(syll_df): @@ -710,9 +691,7 @@ def dunns_z_test_permute_within_group_pairs( n_mice = is_i.sum() + is_j.sum() - ranks_perm = real_ranks[(is_i | is_j)][ - rnd.rand(n_perm, n_mice).argsort(-1) - ] + ranks_perm = real_ranks[(is_i | is_j)][rnd.rand(n_perm, n_mice).argsort(-1)] diff = np.abs( ranks_perm[:, : is_i.sum(), :].mean(1) - ranks_perm[:, is_i.sum() :, :].mean(1) @@ -722,8 +701,7 @@ def dunns_z_test_permute_within_group_pairs( # also do for real data group_ranks = real_ranks[(is_i | is_j)] real_diff = np.abs( - group_ranks[: is_i.sum(), :].mean(0) - - group_ranks[is_i.sum() :, :].mean(0) + group_ranks[: is_i.sum(), :].mean(0) - group_ranks[is_i.sum() :, :].mean(0) ) # add to dict @@ -787,13 +765,9 @@ def compute_pvalues_for_group_pairs( def correct_p(x): return multipletests(x, alpha=thresh, method=mc_method)[1] - df_pval_corrected = df_pval.apply( - correct_p, axis=1, result_type="broadcast" - ) + df_pval_corrected = df_pval.apply(correct_p, axis=1, result_type="broadcast") - return df_pval_corrected, ( - (df_pval_corrected[df_k_real.is_sig] < thresh).sum(0) - ) + return df_pval_corrected, 
((df_pval_corrected[df_k_real.is_sig] < thresh).sum(0)) def run_kruskal( @@ -868,9 +842,7 @@ def run_kruskal( df_k_real = pd.DataFrame( [ stats.kruskal( - *np.array_split( - syllable_data[:, s_i], np.cumsum(n_per_group[:-1]) - ) + *np.array_split(syllable_data[:, s_i], np.cumsum(n_per_group[:-1])) ) for s_i in range(N_s) ] @@ -963,9 +935,7 @@ def sort_syllables_by_stat_difference( exp_df = mutation_df.loc[exp_group] # compute mean difference at each syll frequency and reorder based on difference - ordering = ( - (exp_df[stat] - control_df[stat]).sort_values(ascending=False).index - ) + ordering = (exp_df[stat] - control_df[stat]).sort_values(ascending=False).index return list(ordering) @@ -996,11 +966,7 @@ def sort_syllables_by_stat(stats_df, stat="frequency"): else: ordering = ( stats_df.drop( - [ - col - for col, dtype in stats_df.dtypes.items() - if dtype == "object" - ], + [col for col, dtype in stats_df.dtypes.items() if dtype == "object"], axis=1, ) .groupby("syllable") @@ -1188,7 +1154,6 @@ def plot_syll_stats_with_sem( y=stat, hue=hue, order=ordering, - linestyles="none", errorbar=("ci", 68), ax=ax, hue_order=groups, @@ -1215,7 +1180,6 @@ def plot_syll_stats_with_sem( [], color="red", marker="*", - linestyles="None", markersize=9, label="Significant Syllable", ) @@ -1366,9 +1330,7 @@ def get_transition_matrix( transitions = get_transitions(v)[0] trans_mat = ( - n_gram_transition_matrix( - transitions, n=2, max_label=max_syllable - ) + n_gram_transition_matrix(transitions, n=2, max_label=max_syllable) + smoothing ) @@ -1379,9 +1341,7 @@ def get_transition_matrix( return all_mats -def get_group_trans_mats( - labels, label_group, group, syll_include, normalize="bigram" -): +def get_group_trans_mats(labels, label_group, group, syll_include, normalize="bigram"): """Get the transition matrices for each group. 
Parameters @@ -1410,18 +1370,16 @@ def get_group_trans_mats( # Computing transition matrices for each given group for plt_group in group: # list of syll labels in recordings in the group - use_labels = [ - lbl for lbl, grp in zip(labels, label_group) if grp == plt_group - ] + use_labels = [lbl for lbl, grp in zip(labels, label_group) if grp == plt_group] # find stack np array shape row_num = len(use_labels) max_len = max([len(lbl) for lbl in use_labels]) # Get recordings to include in trans_mat # subset only syllable included trans_mats.append( - get_transition_matrix( - use_labels, normalize=normalize, combine=True - )[syll_include, :][:, syll_include] + get_transition_matrix(use_labels, normalize=normalize, combine=True)[ + syll_include, : + ][:, syll_include] ) # Getting frequency information for node scaling @@ -1466,9 +1424,7 @@ def visualize_transition_bigram( # infer max_syllables max_syllables = trans_mats[0].shape[0] - fig, ax = plt.subplots( - 1, len(group), figsize=figsize, sharex=False, sharey=True - ) + fig, ax = plt.subplots(1, len(group), figsize=figsize, sharex=False, sharey=True) title_map = dict(bigram="Bigram", columns="Incoming", rows="Outgoing") color_lim = max([x.max() for x in trans_mats]) if len(group) == 1: @@ -1488,9 +1444,7 @@ def visualize_transition_bigram( cb.set_label(f"{title_map[normalize]} transition probability") axs[i].set_xlabel("Outgoing syllable") axs[i].set_title(g) - axs[i].set_xticks( - np.arange(len(syll_include)), syll_names, rotation=90 - ) + axs[i].set_xticks(np.arange(len(syll_include)), syll_names, rotation=90) # save the figure plot_name = "transition_matrices" @@ -1531,9 +1485,7 @@ def generate_transition_matrices( results_dict = load_results(project_dir, model_name) # filter out syllables by freqency - model_labels = [ - results_dict[recording]["syllable"] for recording in recordings - ] + model_labels = [results_dict[recording]["syllable"] for recording in recordings] frequencies = get_frequencies(model_labels) 
syll_include = np.where(frequencies > min_frequency)[0] From 516059068165d4703a51490a2a369d5d18359743 Mon Sep 17 00:00:00 2001 From: Caleb Weinreb Date: Thu, 28 Dec 2023 21:50:41 -0500 Subject: [PATCH 2/2] bugfix for sequences of length < nlags --- keypoint_moseq/io.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/keypoint_moseq/io.py b/keypoint_moseq/io.py index fc3633f..ab97372 100644 --- a/keypoint_moseq/io.py +++ b/keypoint_moseq/io.py @@ -648,20 +648,16 @@ def extract_results( path = _get_path(project_dir, model_name, path, "results.h5") states = jax.device_get(model["states"]) - keys, bounds = metadata # extract syllables; repeat first syllable an extra `nlags` times nlags = states["x"].shape[1] - states["z"].shape[1] - syllables = unbatch(states["z"], keys, bounds + np.array([nlags, 0])) - syllables = { - k: np.pad(z[nlags:], (nlags, 0), mode="edge") - for k, z in syllables.items() - } + z = np.pad(states["z"], ((0,0),(nlags, 0)), mode="edge") + syllables = unbatch(z, *metadata) # extract latent state, centroid, and heading - latent_state = unbatch(states["x"], keys, bounds) - centroid = unbatch(states["v"], keys, bounds) - heading = unbatch(states["h"], keys, bounds) + latent_state = unbatch(states["x"], *metadata) + centroid = unbatch(states["v"], *metadata) + heading = unbatch(states["h"], *metadata) results_dict = { recording_name: {