Merge branch 'hotfix/2.13.4'
mayofaulkner committed Jul 22, 2022
2 parents e316c0c + 70cccde commit aa5bec6
Showing 6 changed files with 210 additions and 43 deletions.
5 changes: 3 additions & 2 deletions brainbox/io/spikeglx.py
@@ -168,9 +168,10 @@ def read(self, nsel=slice(0, 10000), csel=slice(None), sync=True):
"""
overload the read function by downloading the necessary chunks
"""
first_chunk = np.maximum(0, np.searchsorted(self.chunks['chunk_bounds'], nsel.start + 0.01 * self.fs) - 1)
last_chunk = np.maximum(0, np.searchsorted(self.chunks['chunk_bounds'], nsel.stop + 0.01 * self.fs) - 2)
first_chunk = np.maximum(0, np.searchsorted(self.chunks['chunk_bounds'], nsel.start) - 1)
last_chunk = np.maximum(0, np.searchsorted(self.chunks['chunk_bounds'], nsel.stop) - 1)
n0 = self.chunks['chunk_bounds'][first_chunk]
_logger.debug(f'Streamer: caching sample {n0}, (t={n0 / self.fs})')
self.cache_folder.mkdir(exist_ok=True, parents=True)
sr = self._download_raw_partial(first_chunk=first_chunk, last_chunk=last_chunk)
data = sr[nsel.start - n0: nsel.stop - n0, csel]
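
For context, a minimal sketch of the corrected chunk lookup; the chunk bounds and sample indices below are made up for illustration, and only np.searchsorted and the -1 offsets mirror the change above:

import numpy as np

# Hypothetical chunk boundaries in samples: five chunks of 1000 samples each.
chunk_bounds = np.array([0, 1000, 2000, 3000, 4000, 5000])

start, stop = 1500, 3500  # requested sample range
first_chunk = np.maximum(0, np.searchsorted(chunk_bounds, start) - 1)  # 1, the chunk containing sample 1500
last_chunk = np.maximum(0, np.searchsorted(chunk_bounds, stop) - 1)    # 3, the chunk containing sample 3500
n0 = chunk_bounds[first_chunk]  # 1000, the first cached sample, so the slice into the cache is offset by n0
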
2 changes: 1 addition & 1 deletion ibllib/__init__.py
@@ -1,4 +1,4 @@
__version__ = "2.13.3"
__version__ = "2.13.4"
import warnings

from ibllib.misc import logger_config
229 changes: 196 additions & 33 deletions ibllib/pipes/training_status.py
@@ -1,4 +1,3 @@
from one.api import ONE
import one.alf.io as alfio
from one.alf.spec import is_session_path
from one.alf.exceptions import ALFObjectNotFound
@@ -16,9 +15,10 @@
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.lines import Line2D
from datetime import datetime
import seaborn as sns

one = ONE()

TRAINING_STATUS = {'not_computed': (-2, (0, 0, 0, 0)),
'habituation': (-1, (0, 0, 0, 0)),
@@ -301,6 +301,12 @@ def get_training_info_for_session(session_paths, one):
sess_dict['n_delay'] = np.nan
sess_dict['location'] = np.nan
sess_dict['training_status'] = 'habituation'
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
(np.nan, np.nan, np.nan, np.nan)
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
(np.nan, np.nan, np.nan, np.nan)
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
(np.nan, np.nan, np.nan, np.nan)

else:
# if we can't compute trials then we need to pass
@@ -309,12 +315,26 @@ def get_training_info_for_session(session_paths, one):
continue

sess_dict['performance'], sess_dict['contrasts'], _ = training.compute_performance(trials, prob_right=True)
if sess_dict['task_protocol'] == 'training':
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
training.compute_psychometric(trials)
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
(np.nan, np.nan, np.nan, np.nan)
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
(np.nan, np.nan, np.nan, np.nan)
else:
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
training.compute_psychometric(trials, block=0.5)
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
training.compute_psychometric(trials, block=0.2)
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
training.compute_psychometric(trials, block=0.8)

sess_dict['performance_easy'] = training.compute_performance_easy(trials)
sess_dict['reaction_time'] = training.compute_median_reaction_time(trials)
sess_dict['n_trials'] = training.compute_n_trials(trials)
sess_dict['sess_duration'], sess_dict['n_delay'], sess_dict['location'] = \
compute_session_duration_delay_location(session_path)
sess_dict['task_protocol'] = get_session_extractor_type(session_path)
sess_dict['training_status'] = 'not_computed'

sess_dicts.append(sess_dict)
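
As an aside, the per-block psychometric assignments in the biased branch above could be written more compactly; the snippet below is only an equivalent sketch, reusing the trials and sess_dict variables from the surrounding function and assuming compute_psychometric returns (bias, threshold, lapse_high, lapse_low) in the order unpacked above:

for block, key in [(0.5, '50'), (0.2, '20'), (0.8, '80')]:
    bias, thres, lapse_high, lapse_low = training.compute_psychometric(trials, block=block)
    sess_dict[f'bias_{key}'] = bias
    sess_dict[f'thres_{key}'] = thres
    sess_dict[f'lapsehigh_{key}'] = lapse_high
    sess_dict[f'lapselow_{key}'] = lapse_low
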
@@ -328,6 +348,11 @@ def get_training_info_for_session(session_paths, one):
print(f'{len(sess_dicts)} sessions being combined for date {sess_dicts[0]["date"]}')
combined_trials = load_combined_trials(session_paths, one)
performance, contrasts, _ = training.compute_performance(combined_trials, prob_right=True)
psychs = {}
psychs['50'] = training.compute_psychometric(trials, block=0.5)
psychs['20'] = training.compute_psychometric(trials, block=0.2)
psychs['80'] = training.compute_psychometric(trials, block=0.8)

performance_easy = training.compute_performance_easy(combined_trials)
reaction_time = training.compute_median_reaction_time(combined_trials)
n_trials = training.compute_n_trials(combined_trials)
@@ -344,6 +369,12 @@ def get_training_info_for_session(session_paths, one):
sess_dict['combined_sess_duration'] = sess_duration
sess_dict['combined_n_delay'] = n_delay

for bias in [50, 20, 80]:
sess_dict[f'combined_bias_{bias}'] = psychs[f'{bias}'][0]
sess_dict[f'combined_thres_{bias}'] = psychs[f'{bias}'][1]
sess_dict[f'combined_lapsehigh_{bias}'] = psychs[f'{bias}'][2]
sess_dict[f'combined_lapselow_{bias}'] = psychs[f'{bias}'][3]

# Case where two sessions on same day with different number of contrasts! Oh boy
if sess_dict['combined_performance'].size != sess_dict['performance'].size:
sess_dict['performance'] = \
@@ -363,6 +394,12 @@ def get_training_info_for_session(session_paths, one):
sess_dict['combined_sess_duration'] = sess_dict['sess_duration']
sess_dict['combined_n_delay'] = sess_dict['n_delay']

for bias in [50, 20, 80]:
sess_dict[f'combined_bias_{bias}'] = sess_dict[f'bias_{bias}']
sess_dict[f'combined_thres_{bias}'] = sess_dict[f'thres_{bias}']
sess_dict[f'combined_lapsehigh_{bias}'] = sess_dict[f'lapsehigh_{bias}']
sess_dict[f'combined_lapselow_{bias}'] = sess_dict[f'lapselow_{bias}']

return sess_dicts
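
For orientation, each entry of sess_dicts now carries the per-block fit parameters alongside their combined_* counterparts; the values below are made up purely to show the shape of the record:

example_entry = {
    'date': '2022-07-20',
    'task_protocol': 'biased',
    'bias_50': 3.2, 'thres_50': 18.5, 'lapsehigh_50': 0.08, 'lapselow_50': 0.05,
    'combined_bias_50': 3.2, 'combined_thres_50': 18.5,
    'combined_lapsehigh_50': 0.08, 'combined_lapselow_50': 0.05,
    # ... the same pattern repeats for the 20 and 80 blocks, next to the existing
    # performance, reaction time and session duration fields
}
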


@@ -384,7 +421,7 @@ def check_up_to_date(subj_path, df):
df_session = pd.concat([df_session, pd.DataFrame({'date': date, 'session_path': str(sess)}, index=[0])],
ignore_index=True)

if df is None:
if df is None or 'combined_thres_50' not in df.columns:
return df_session
else:
# recorded_session_paths = df['session_path'].values
@@ -399,14 +436,18 @@ def plot_trial_count_and_session_duration(df, subject):

y1 = {'column': 'combined_n_trials',
'title': 'Trial counts',
'lim': None}
'lim': None,
'color': 'k',
'join': True}

y2 = {'column': 'combined_sess_duration',
'title': 'Session duration (mins)',
'lim': None,
'log': False}
'color': 'r',
'log': False,
'join': True}

ax = plot_over_days(df, y1, y2, subject)
ax = plot_over_days(df, subject, y1, y2)

return ax

@@ -416,40 +457,152 @@ def plot_performance_easy_median_reaction_time(df, subject):

y1 = {'column': 'combined_performance_easy',
'title': 'Performance on easy trials',
'lim': [0, 1.05]}
'lim': [0, 1.05],
'color': 'k',
'join': True}

y2 = {'column': 'combined_reaction_time',
'title': 'Median reaction time (s)',
'lim': [0.1, np.nanmax([10, np.nanmax(df.combined_reaction_time.values)])],
'log': True}
ax = plot_over_days(df, y1, y2, subject)
'color': 'r',
'log': True,
'join': True}
ax = plot_over_days(df, subject, y1, y2)

return ax


def plot_over_days(df, y1, y2, subject, ax=None):
def plot_fit_params(df, subject):
fig, axs = plt.subplots(2, 2, figsize=(12, 6))
axs = axs.ravel()

df = df.drop_duplicates('date').reset_index(drop=True)

cmap = sns.diverging_palette(20, 220, n=3, center="dark")

y50 = {'column': 'combined_bias_50',
'title': 'Bias',
'lim': [-100, 100],
'color': cmap[1],
'join': False}

y80 = {'column': 'combined_bias_80',
'title': 'Bias',
'lim': [-100, 100],
'color': cmap[2],
'join': False}

y20 = {'column': 'combined_bias_20',
'title': 'Bias',
'lim': [-100, 100],
'color': cmap[0],
'join': False}

plot_over_days(df, subject, y50, ax=axs[0], legend=False, title=False)
plot_over_days(df, subject, y80, ax=axs[0], legend=False, title=False)
plot_over_days(df, subject, y20, ax=axs[0], legend=False, title=False)
axs[0].axhline(16, linewidth=2, linestyle='--', color='k')
axs[0].axhline(-16, linewidth=2, linestyle='--', color='k')

y50['column'] = 'combined_thres_50'
y50['title'] = 'Threshold'
y50['lim'] = [0, 100]
y80['column'] = 'combined_thres_20'
y80['title'] = 'Threshold'
y20['lim'] = [0, 100]
y20['column'] = 'combined_thres_80'
y20['title'] = 'Threshold'
y80['lim'] = [0, 100]

plot_over_days(df, subject, y50, ax=axs[1], legend=False, title=False)
plot_over_days(df, subject, y80, ax=axs[1], legend=False, title=False)
plot_over_days(df, subject, y20, ax=axs[1], legend=False, title=False)
axs[1].axhline(19, linewidth=2, linestyle='--', color='k')

y50['column'] = 'combined_lapselow_50'
y50['title'] = 'Lapse Low'
y50['lim'] = [0, 1]
y80['column'] = 'combined_lapselow_20'
y80['title'] = 'Lapse Low'
y80['lim'] = [0, 1]
y20['column'] = 'combined_lapselow_80'
y20['title'] = 'Lapse Low'
y20['lim'] = [0, 1]

plot_over_days(df, subject, y50, ax=axs[2], legend=False, title=False)
plot_over_days(df, subject, y80, ax=axs[2], legend=False, title=False)
plot_over_days(df, subject, y20, ax=axs[2], legend=False, title=False)
axs[2].axhline(0.2, linewidth=2, linestyle='--', color='k')

y50['column'] = 'combined_lapsehigh_50'
y50['title'] = 'Lapse High'
y50['lim'] = [0, 1]
y80['column'] = 'combined_lapsehigh_20'
y80['title'] = 'Lapse High'
y80['lim'] = [0, 1]
y20['column'] = 'combined_lapsehigh_80'
y20['title'] = 'Lapse High'
y20['lim'] = [0, 1]

plot_over_days(df, subject, y50, ax=axs[3], legend=False, title=False, training_lines=True)
plot_over_days(df, subject, y80, ax=axs[3], legend=False, title=False, training_lines=False)
plot_over_days(df, subject, y20, ax=axs[3], legend=False, title=False, training_lines=False)
axs[3].axhline(0.2, linewidth=2, linestyle='--', color='k')

fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
lines, labels = axs[3].get_legend_handles_labels()
fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=5)

legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8),
Line2D([0], [0], marker='o', color='w', label='p=0.2', markerfacecolor=cmap[0], markersize=8),
Line2D([0], [0], marker='o', color='w', label='p=0.8', markerfacecolor=cmap[2], markersize=8)]
legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, shadow=True)
fig.add_artist(legend2)

return axs


def plot_psychometric_curve(df, subject, one):
df = df.drop_duplicates('date').reset_index(drop=True)
sess_path = Path(df.iloc[-1]["session_path"])
trials = load_trials(sess_path, one)

fig, ax1 = plt.subplots(figsize=(8, 6))

training.plot_psychometric(trials, ax=ax1, title=f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

return ax1


def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, training_lines=True):

if ax is None:
fig, ax1 = plt.subplots(figsize=(12, 6))
else:
ax1 = ax

ax2 = ax1.twinx()

dates = [datetime.strptime(dat, '%Y-%m-%d') for dat in df['date']]
ax1.plot(dates, df[y1['column']], 'k')
ax1.scatter(dates, df[y1['column']], c='k')
if y1['join']:
ax1.plot(dates, df[y1['column']], color=y1['color'])
ax1.scatter(dates, df[y1['column']], color=y1['color'])
ax1.set_ylabel(y1['title'])
ax1.set_ylim(y1['lim'])

ax2.plot(dates, df[y2['column']], 'r')
ax2.scatter(dates, df[y2['column']], c='r')
ax2.set_ylabel(y2['title'])
ax2.yaxis.label.set_color('r')
ax2.tick_params(axis='y', colors='r')
ax2.set_ylim(y2['lim'])
if y2['log']:
ax2.set_yscale('log')
if y2 is not None:
ax2 = ax1.twinx()
if y2['join']:
ax2.plot(dates, df[y2['column']], color=y2['color'])
ax2.scatter(dates, df[y2['column']], color=y2['color'])
ax2.set_ylabel(y2['title'])
ax2.yaxis.label.set_color(y2['color'])
ax2.tick_params(axis='y', colors=y2['color'])
ax2.set_ylim(y2['lim'])
if y2['log']:
ax2.set_yscale('log')

ax2.spines['right'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax2.spines['left'].set_visible(False)

month_format = mdates.DateFormatter('%b %Y')
month_locator = mdates.MonthLocator()
@@ -462,20 +615,20 @@ def plot_over_days(df, y1, y2, subject, ax=None):
ax1.spines['left'].set_visible(False)
ax1.spines['right'].set_visible(False)
ax1.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax2.spines['left'].set_visible(False)

ax1 = add_training_lines(df, ax1)
if training_lines:
ax1 = add_training_lines(df, ax1)

ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
box = ax1.get_position()
ax1.set_position([box.x0, box.y0 + box.height * 0.1,
box.width, box.height * 0.9])
if title:
ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

# Put a legend below current axis
ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1),
fancybox=True, shadow=True, ncol=5)
box = ax1.get_position()
ax1.set_position([box.x0, box.y0 + box.height * 0.1,
box.width, box.height * 0.9])
if legend:
ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1),
fancybox=True, shadow=True, ncol=5)

return ax1
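
For reference, a hedged sketch of calling the updated plot_over_days signature with a single y-axis spec; the DataFrame contents and subject name are made up, and only the keyword arguments and y-spec keys come from the code above:

import pandas as pd
from ibllib.pipes.training_status import plot_over_days

# Toy training table with the columns used above (illustrative values only).
df = pd.DataFrame({
    'date': ['2022-07-18', '2022-07-19', '2022-07-20'],
    'training_status': ['in training', 'in training', 'trained 1a'],
    'combined_n_trials': [420, 512, 601],
})

y1 = {'column': 'combined_n_trials', 'title': 'Trial counts', 'lim': None,
      'color': 'k', 'join': True}

# y2 is now optional, so a single-axis plot works; legend and training lines are
# switched off to keep this toy example minimal.
ax = plot_over_days(df, 'SW_001', y1, legend=False, training_lines=False)
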

@@ -554,6 +707,8 @@ def make_plots(session_path, one, df=None, save=False, upload=False):
ax1 = plot_trial_count_and_session_duration(df, subject)
ax2 = plot_performance_easy_median_reaction_time(df, subject)
ax3 = plot_heatmap_performance_over_days(df, subject)
ax4 = plot_fit_params(df, subject)
ax5 = plot_psychometric_curve(df, subject, one)

outputs = []
if save:
@@ -570,6 +725,14 @@ def make_plots(session_path, one, df=None, save=False, upload=False):
outputs.append(save_name)
ax3.get_figure().savefig(save_name, bbox_inches='tight')

save_name = save_path.joinpath('subj_psychometric_fit_params.png')
outputs.append(save_name)
ax4[0].get_figure().savefig(save_name, bbox_inches='tight')

save_name = save_path.joinpath('subj_psychometric_curve.png')
outputs.append(save_name)
ax5.get_figure().savefig(save_name, bbox_inches='tight')

if upload:
subj = one.alyx.rest('subjects', 'list', nickname=subject)[0]
snp = ReportSnapshot(session_path, subj['id'], content_type='subject', one=one)
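
Finally, a hedged sketch of driving this entry point directly; the session path is a made-up example and only the signature shown above is assumed:

from pathlib import Path
from one.api import ONE
from ibllib.pipes.training_status import make_plots

one = ONE()
session_path = Path('/mnt/data/Subjects/SW_001/2022-07-20/001')  # hypothetical local session path
make_plots(session_path, one, save=True, upload=False)
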