Skip to content

Commit

Permalink
Merge pull request #120 from dattalab/dev
Browse files Browse the repository at this point in the history
Bugfixes (linestyle and extract results)
  • Loading branch information
calebweinreb authored Dec 29, 2023
2 parents 82ea15e + 1dd80ee commit 1c80236
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 83 deletions.
100 changes: 26 additions & 74 deletions keypoint_moseq/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ def generate_index(project_dir, model_name, index_filepath):
else:
# generate a new index file
results_dict = load_results(project_dir, model_name)
index_df = pd.DataFrame(
{"name": list(results_dict.keys()), "group": "default"}
)
index_df = pd.DataFrame({"name": list(results_dict.keys()), "group": "default"})
# write index dataframe
index_df.to_csv(index_filepath, index=False)

Expand Down Expand Up @@ -218,9 +216,7 @@ def compute_moseq_df(project_dir, model_name, *, fps=30, smooth_heading=True):
np.concatenate(
(
[0],
np.sqrt(
np.square(np.diff(v["centroid"], axis=0)).sum(axis=1)
)
np.sqrt(np.square(np.diff(v["centroid"], axis=0)).sum(axis=1))
* fps,
)
)
Expand All @@ -229,8 +225,7 @@ def compute_moseq_df(project_dir, model_name, *, fps=30, smooth_heading=True):
if index_data is not None:
# find the group for each recording from index data
s_group.append(
[index_data[index_data["name"] == k]["group"].values[0]]
* n_frame
[index_data[index_data["name"] == k]["group"].values[0]] * n_frame
)
else:
# no index data
Expand Down Expand Up @@ -360,24 +355,18 @@ def compute_stats_df(
["heading", "angular_velocity", "velocity_px_s"]
].agg(["mean", "std", "min", "max"])

features.columns = [
"_".join(col).strip() for col in features.columns.values
]
features.columns = ["_".join(col).strip() for col in features.columns.values]
features.reset_index(inplace=True)

# get durations
trials = filtered_df["onset"].cumsum()
trials.name = "trials"
durations = filtered_df.groupby(groupby + ["syllable"] + [trials])[
"onset"
].count()
durations = filtered_df.groupby(groupby + ["syllable"] + [trials])["onset"].count()
# average duration in seconds
durations = durations.groupby(groupby + ["syllable"]).mean() / fps
durations.name = "duration"
# only keep the columns we need
durations = durations.fillna(0).reset_index()[
groupby + ["syllable", "duration"]
]
durations = durations.fillna(0).reset_index()[groupby + ["syllable", "duration"]]

stats_df = pd.merge(features, frequency_df, on=groupby + ["syllable"])
stats_df = pd.merge(stats_df, durations, on=groupby + ["syllable"])
Expand All @@ -400,9 +389,7 @@ def generate_syll_info(project_dir, model_name, syll_info_path):
}
)

grid_movies = glob(
os.path.join(project_dir, model_name, "grid_movies", "*.mp4")
)
grid_movies = glob(os.path.join(project_dir, model_name, "grid_movies", "*.mp4"))
assert len(grid_movies) > 0, (
"No grid movies found. Please run `generate_grid_movies` as described in the docs: "
"https://keypoint-moseq.readthedocs.io/en/latest/modeling.html#visualization"
Expand Down Expand Up @@ -443,9 +430,7 @@ def label_syllables(project_dir, model_name, moseq_df):
generate_syll_info(project_dir, model_name, syll_info_path)

# ensure there are grid movies
grid_movies = glob(
os.path.join(project_dir, model_name, "grid_movies", "*.mp4")
)
grid_movies = glob(os.path.join(project_dir, model_name, "grid_movies", "*.mp4"))
assert len(grid_movies) > 0, (
"No grid movies found. Please run `generate_grid_movies` as described in the docs: "
"https://keypoint-moseq.readthedocs.io/en/latest/modeling.html#visualization"
Expand Down Expand Up @@ -479,9 +464,7 @@ def show_movie(syllable):
# create the labeler dataframe
# only include the syllables that have grid movies
include = syll_info_df_with_movie.syllable.values
syll_df = (
moseq_df[["syllable"]].groupby("syllable").mean().reset_index().copy()
)
syll_df = moseq_df[["syllable"]].groupby("syllable").mean().reset_index().copy()
syll_df = syll_df[syll_df.syllable.isin(include)]

# get labels and description from syll info
Expand Down Expand Up @@ -530,9 +513,7 @@ def show_movie(syllable):
configuration=base_configuration,
)

button = pn.widgets.Button(
name="Save syllable info", button_type="primary"
)
button = pn.widgets.Button(name="Save syllable info", button_type="primary")

# call back function to save the index file
def save_index(syll_df):
Expand Down Expand Up @@ -710,9 +691,7 @@ def dunns_z_test_permute_within_group_pairs(

n_mice = is_i.sum() + is_j.sum()

ranks_perm = real_ranks[(is_i | is_j)][
rnd.rand(n_perm, n_mice).argsort(-1)
]
ranks_perm = real_ranks[(is_i | is_j)][rnd.rand(n_perm, n_mice).argsort(-1)]
diff = np.abs(
ranks_perm[:, : is_i.sum(), :].mean(1)
- ranks_perm[:, is_i.sum() :, :].mean(1)
Expand All @@ -722,8 +701,7 @@ def dunns_z_test_permute_within_group_pairs(
# also do for real data
group_ranks = real_ranks[(is_i | is_j)]
real_diff = np.abs(
group_ranks[: is_i.sum(), :].mean(0)
- group_ranks[is_i.sum() :, :].mean(0)
group_ranks[: is_i.sum(), :].mean(0) - group_ranks[is_i.sum() :, :].mean(0)
)

# add to dict
Expand Down Expand Up @@ -787,13 +765,9 @@ def compute_pvalues_for_group_pairs(
def correct_p(x):
return multipletests(x, alpha=thresh, method=mc_method)[1]

df_pval_corrected = df_pval.apply(
correct_p, axis=1, result_type="broadcast"
)
df_pval_corrected = df_pval.apply(correct_p, axis=1, result_type="broadcast")

return df_pval_corrected, (
(df_pval_corrected[df_k_real.is_sig] < thresh).sum(0)
)
return df_pval_corrected, ((df_pval_corrected[df_k_real.is_sig] < thresh).sum(0))


def run_kruskal(
Expand Down Expand Up @@ -868,9 +842,7 @@ def run_kruskal(
df_k_real = pd.DataFrame(
[
stats.kruskal(
*np.array_split(
syllable_data[:, s_i], np.cumsum(n_per_group[:-1])
)
*np.array_split(syllable_data[:, s_i], np.cumsum(n_per_group[:-1]))
)
for s_i in range(N_s)
]
Expand Down Expand Up @@ -963,9 +935,7 @@ def sort_syllables_by_stat_difference(
exp_df = mutation_df.loc[exp_group]

# compute mean difference at each syll frequency and reorder based on difference
ordering = (
(exp_df[stat] - control_df[stat]).sort_values(ascending=False).index
)
ordering = (exp_df[stat] - control_df[stat]).sort_values(ascending=False).index

return list(ordering)

Expand Down Expand Up @@ -996,11 +966,7 @@ def sort_syllables_by_stat(stats_df, stat="frequency"):
else:
ordering = (
stats_df.drop(
[
col
for col, dtype in stats_df.dtypes.items()
if dtype == "object"
],
[col for col, dtype in stats_df.dtypes.items() if dtype == "object"],
axis=1,
)
.groupby("syllable")
Expand Down Expand Up @@ -1188,7 +1154,6 @@ def plot_syll_stats_with_sem(
y=stat,
hue=hue,
order=ordering,
linestyles="none",
errorbar=("ci", 68),
ax=ax,
hue_order=groups,
Expand All @@ -1215,7 +1180,6 @@ def plot_syll_stats_with_sem(
[],
color="red",
marker="*",
linestyles="None",
markersize=9,
label="Significant Syllable",
)
Expand Down Expand Up @@ -1366,9 +1330,7 @@ def get_transition_matrix(
transitions = get_transitions(v)[0]

trans_mat = (
n_gram_transition_matrix(
transitions, n=2, max_label=max_syllable
)
n_gram_transition_matrix(transitions, n=2, max_label=max_syllable)
+ smoothing
)

Expand All @@ -1379,9 +1341,7 @@ def get_transition_matrix(
return all_mats


def get_group_trans_mats(
labels, label_group, group, syll_include, normalize="bigram"
):
def get_group_trans_mats(labels, label_group, group, syll_include, normalize="bigram"):
"""Get the transition matrices for each group.
Parameters
Expand Down Expand Up @@ -1410,18 +1370,16 @@ def get_group_trans_mats(
# Computing transition matrices for each given group
for plt_group in group:
# list of syll labels in recordings in the group
use_labels = [
lbl for lbl, grp in zip(labels, label_group) if grp == plt_group
]
use_labels = [lbl for lbl, grp in zip(labels, label_group) if grp == plt_group]
# find stack np array shape
row_num = len(use_labels)
max_len = max([len(lbl) for lbl in use_labels])
# Get recordings to include in trans_mat
# subset only the included syllables
trans_mats.append(
get_transition_matrix(
use_labels, normalize=normalize, combine=True
)[syll_include, :][:, syll_include]
get_transition_matrix(use_labels, normalize=normalize, combine=True)[
syll_include, :
][:, syll_include]
)

# Getting frequency information for node scaling
Expand Down Expand Up @@ -1466,9 +1424,7 @@ def visualize_transition_bigram(
# infer max_syllables
max_syllables = trans_mats[0].shape[0]

fig, ax = plt.subplots(
1, len(group), figsize=figsize, sharex=False, sharey=True
)
fig, ax = plt.subplots(1, len(group), figsize=figsize, sharex=False, sharey=True)
title_map = dict(bigram="Bigram", columns="Incoming", rows="Outgoing")
color_lim = max([x.max() for x in trans_mats])
if len(group) == 1:
Expand All @@ -1488,9 +1444,7 @@ def visualize_transition_bigram(
cb.set_label(f"{title_map[normalize]} transition probability")
axs[i].set_xlabel("Outgoing syllable")
axs[i].set_title(g)
axs[i].set_xticks(
np.arange(len(syll_include)), syll_names, rotation=90
)
axs[i].set_xticks(np.arange(len(syll_include)), syll_names, rotation=90)

# save the figure
plot_name = "transition_matrices"
Expand Down Expand Up @@ -1531,9 +1485,7 @@ def generate_transition_matrices(
results_dict = load_results(project_dir, model_name)

# filter out syllables by frequency
model_labels = [
results_dict[recording]["syllable"] for recording in recordings
]
model_labels = [results_dict[recording]["syllable"] for recording in recordings]
frequencies = get_frequencies(model_labels)
syll_include = np.where(frequencies > min_frequency)[0]

Expand Down
14 changes: 5 additions & 9 deletions keypoint_moseq/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,20 +648,16 @@ def extract_results(
path = _get_path(project_dir, model_name, path, "results.h5")

states = jax.device_get(model["states"])
keys, bounds = metadata

# extract syllables; repeat first syllable an extra `nlags` times
nlags = states["x"].shape[1] - states["z"].shape[1]
syllables = unbatch(states["z"], keys, bounds + np.array([nlags, 0]))
syllables = {
k: np.pad(z[nlags:], (nlags, 0), mode="edge")
for k, z in syllables.items()
}
z = np.pad(states["z"], ((0,0),(nlags, 0)), mode="edge")
syllables = unbatch(z, *metadata)

# extract latent state, centroid, and heading
latent_state = unbatch(states["x"], keys, bounds)
centroid = unbatch(states["v"], keys, bounds)
heading = unbatch(states["h"], keys, bounds)
latent_state = unbatch(states["x"], *metadata)
centroid = unbatch(states["v"], *metadata)
heading = unbatch(states["h"], *metadata)

results_dict = {
recording_name: {
Expand Down

0 comments on commit 1c80236

Please sign in to comment.