Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ready4recording #892

Merged
merged 13 commits into from
Dec 17, 2024
226 changes: 173 additions & 53 deletions brainbox/behavior/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def get_subject_training_status(subj, date=None, details=True, one=None):
if not trials:
return
sess_dates = list(trials.keys())
status, info = get_training_status(trials, task_protocol, ephys_sess, n_delay)
status, info, _ = get_training_status(trials, task_protocol, ephys_sess, n_delay)

if details:
if np.any(info.get('psych')):
Expand Down Expand Up @@ -265,13 +265,13 @@ def get_sessions(subj, date=None, one=None):
if not np.any(np.array(task_protocol) == 'training'):
ephys_sess = one.alyx.rest('sessions', 'list', subject=subj,
date_range=[sess_dates[-1], sess_dates[0]],
django='json__PYBPOD_BOARD__icontains,ephys')
django='location__name__icontains,ephys')
if len(ephys_sess) > 0:
ephys_sess_dates = [sess['start_time'][:10] for sess in ephys_sess]

n_delay = len(one.alyx.rest('sessions', 'list', subject=subj,
date_range=[sess_dates[-1], sess_dates[0]],
django='json__SESSION_START_DELAY_SEC__gte,900'))
django='json__SESSION_DELAY_START__gte,900'))
else:
ephys_sess_dates = []
n_delay = 0
Expand Down Expand Up @@ -313,23 +313,32 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay):

info = Bunch()
trials_all = concatenate_trials(trials)
info.session_dates = list(trials.keys())
info.protocols = [p for p in task_protocol]

# Case when all sessions are trainingChoiceWorld
if np.all(np.array(task_protocol) == 'training'):
signed_contrast = get_signed_contrast(trials_all)
signed_contrast = np.unique(get_signed_contrast(trials_all))
(info.perf_easy, info.n_trials,
info.psych, info.rt) = compute_training_info(trials, trials_all)
if not np.any(signed_contrast == 0):
status = 'in training'

pass_criteria, criteria = criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt,
signed_contrast)
if pass_criteria:
failed_criteria = Bunch()
failed_criteria['NBiased'] = {'val': info.protocols, 'pass': False}
failed_criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': False}
status = 'trained 1b'
else:
if criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt):
status = 'trained 1b'
elif criterion_1a(info.psych, info.n_trials, info.perf_easy):
failed_criteria = criteria
pass_criteria, criteria = criterion_1a(info.psych, info.n_trials, info.perf_easy, signed_contrast)
if pass_criteria:
status = 'trained 1a'
else:
failed_criteria = criteria
status = 'in training'

return status, info
return status, info, failed_criteria

# Case when there are < 3 biasedChoiceWorld sessions after reaching trained_1b criterion
if ~np.all(np.array(task_protocol) == 'training') and \
Expand All @@ -338,45 +347,52 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay):
(info.perf_easy, info.n_trials,
info.psych, info.rt) = compute_training_info(trials, trials_all)

return status, info
criteria = Bunch()
criteria['NBiased'] = {'val': info.protocols, 'pass': False}
criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': False}

return status, info, criteria

# Case when there is biasedChoiceWorld or ephysChoiceWorld in last three sessions
if not np.any(np.array(task_protocol) == 'training'):

(info.perf_easy, info.n_trials,
info.psych_20, info.psych_80,
info.rt) = compute_bias_info(trials, trials_all)
# We are still on training rig and so all sessions should be biased
if len(ephys_sess_dates) == 0:
assert np.all(np.array(task_protocol) == 'biased')
if criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy,
info.rt):
status = 'ready4ephysrig'
else:
status = 'trained 1b'

elif len(ephys_sess_dates) < 3:
n_ephys = len(ephys_sess_dates)
info.n_ephys = n_ephys
info.n_delay = n_delay

# Criterion recording
pass_criteria, criteria = criteria_recording(n_ephys, n_delay, info.psych_20, info.psych_80, info.n_trials,
info.perf_easy, info.rt)
if pass_criteria:
# Here the criteria doesn't actually fail but we have no other criteria to meet so we return this
failed_criteria = criteria
status = 'ready4recording'
else:
failed_criteria = criteria
assert all(date in trials for date in ephys_sess_dates)
perf_ephys_easy = np.array([compute_performance_easy(trials[k]) for k in
ephys_sess_dates])
n_ephys_trials = np.array([compute_n_trials(trials[k]) for k in ephys_sess_dates])

if criterion_delay(n_ephys_trials, perf_ephys_easy):
status = 'ready4delay'
else:
status = 'ready4ephysrig'

elif len(ephys_sess_dates) >= 3:
if n_delay > 0 and \
criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy,
info.rt):
status = 'ready4recording'
elif criterion_delay(info.n_trials, info.perf_easy):
pass_criteria, criteria = criterion_delay(n_ephys, n_ephys_trials, perf_ephys_easy)

if pass_criteria:
status = 'ready4delay'
else:
status = 'ready4ephysrig'
failed_criteria = criteria
pass_criteria, criteria = criterion_ephys(info.psych_20, info.psych_80, info.n_trials,
info.perf_easy, info.rt)
if pass_criteria:
status = 'ready4ephysrig'
else:
failed_criteria = criteria
status = 'trained 1b'

return status, info
return status, info, failed_criteria


def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None,
Expand Down Expand Up @@ -814,7 +830,7 @@ def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='re
return reaction_time, contrasts, n_contrasts,


def criterion_1a(psych, n_trials, perf_easy):
def criterion_1a(psych, n_trials, perf_easy, signed_contrast):
"""
Returns bool indicating whether criteria for status 'trained_1a' are met.

Expand All @@ -825,6 +841,7 @@ def criterion_1a(psych, n_trials, perf_easy):
- Lapse rate on both sides is less than 0.2
- The total number of trials is greater than 200 for each session
- Performance on easy contrasts > 80% for all sessions
- Zero contrast trials must be present

Parameters
----------
Expand All @@ -835,24 +852,39 @@ def criterion_1a(psych, n_trials, perf_easy):
The number for trials for each session.
perf_easy : numpy.array of float
The proportion of correct high contrast trials for each session.
signed_contrast: numpy.array
Unique list of contrasts displayed

Returns
-------
bool
True if the criteria are met for 'trained_1a'.
Bunch
Bunch containing breakdown of the passing/ failing critieria

Notes
-----
The parameter thresholds chosen here were originally determined by averaging the parameter fits
for a number of sessions determined to be of 'good' performance by an experimenter.
"""

criterion = (abs(psych[0]) < 16 and psych[1] < 19 and psych[2] < 0.2 and psych[3] < 0.2 and
np.all(n_trials > 200) and np.all(perf_easy > 0.8))
return criterion
criteria = Bunch()
criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)}
criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.2}
criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.2}
criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 16}
criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 19}
criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 200)}
criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.8)}

passing = np.all([v['pass'] for k, v in criteria.items()])

criteria['Criteria'] = {'val': 'trained_1a', 'pass': passing}

def criterion_1b(psych, n_trials, perf_easy, rt):
return passing, criteria


def criterion_1b(psych, n_trials, perf_easy, rt, signed_contrast):
"""
Returns bool indicating whether criteria for trained_1b are met.

Expand All @@ -864,6 +896,7 @@ def criterion_1b(psych, n_trials, perf_easy, rt):
- The total number of trials is greater than 400 for each session
- Performance on easy contrasts > 90% for all sessions
- The median response time across all zero contrast trials is less than 2 seconds
- Zero contrast trials must be present

Parameters
----------
Expand All @@ -876,11 +909,15 @@ def criterion_1b(psych, n_trials, perf_easy, rt):
The proportion of correct high contrast trials for each session.
rt : float
The median response time for zero contrast trials.
signed_contrast: numpy.array
Unique list of contrasts displayed

Returns
-------
bool
True if the criteria are met for 'trained_1b'.
Bunch
Bunch containing breakdown of the passing/ failing critieria

Notes
-----
Expand All @@ -890,17 +927,27 @@ def criterion_1b(psych, n_trials, perf_easy, rt):
regrettably means that the maximum threshold fit for 1b is greater than for 1a, meaning the
slope of the psychometric curve may be slightly less steep than 1a.
"""
criterion = (abs(psych[0]) < 10 and psych[1] < 20 and psych[2] < 0.1 and psych[3] < 0.1 and
np.all(n_trials > 400) and np.all(perf_easy > 0.9) and rt < 2)
return criterion

criteria = Bunch()
criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)}
criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.1}
criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.1}
criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 10}
criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 20}
criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)}
criteria['Perf_tasy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)}
criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2}

passing = np.all([v['pass'] for k, v in criteria.items()])

criteria['Criteria'] = {'val': 'trained_1b', 'pass': passing}

return passing, criteria


def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt):
"""
Returns bool indicating whether criteria for ready4ephysrig or ready4recording are met.

NB: The difference between these two is whether the sessions were acquired ot a recording rig
with a delay before the first trial. Neither of these two things are tested here.
Returns bool indicating whether criteria for ready4ephysrig are met.

Criteria
--------
Expand Down Expand Up @@ -929,21 +976,34 @@ def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt):
Returns
-------
bool
True if subject passes the ready4ephysrig or ready4recording criteria.
True if subject passes the ready4ephysrig criteria.
Bunch
Bunch containing breakdown of the passing/ failing critieria
"""
criteria = Bunch()
criteria['LapseLow_80'] = {'val': psych_80[2], 'pass': psych_80[2] < 0.1}
criteria['LapseHigh_80'] = {'val': psych_80[3], 'pass': psych_80[3] < 0.1}
criteria['LapseLow_20'] = {'val': psych_20[2], 'pass': psych_20[2] < 0.1}
criteria['LapseHigh_20'] = {'val': psych_20[3], 'pass': psych_20[3] < 0.1}
criteria['Bias_shift'] = {'val': psych_80[0] - psych_20[0], 'pass': psych_80[0] - psych_20[0] > 5}
criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)}
criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)}
criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2}

criterion = (np.all(np.r_[psych_20[2:4], psych_80[2:4]] < 0.1) and # lapse
psych_80[0] - psych_20[0] > 5 and np.all(n_trials > 400) and # bias shift and n trials
np.all(perf_easy > 0.9) and rt < 2) # overall performance and response times
return criterion
passing = np.all([v['pass'] for k, v in criteria.items()])

criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': passing}

def criterion_delay(n_trials, perf_easy):
return passing, criteria


def criterion_delay(n_ephys, n_trials, perf_easy):
"""
Returns bool indicating whether criteria for 'ready4delay' is met.

Criteria
--------
- At least one session on an ephys rig
- Total number of trials for any of the sessions is greater than 400
- Performance on easy contrasts is greater than 90% for any of the sessions

Expand All @@ -959,9 +1019,69 @@ def criterion_delay(n_trials, perf_easy):
-------
bool
True if subject passes the 'ready4delay' criteria.
Bunch
Bunch containing breakdown of the passing/ failing critieria
"""

criteria = Bunch()
criteria['N_ephys'] = {'val': n_ephys, 'pass': n_ephys > 0}
criteria['N_trials'] = {'val': n_trials, 'pass': np.any(n_trials > 400)}
criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.any(perf_easy > 0.9)}

passing = np.all([v['pass'] for k, v in criteria.items()])

criteria['Criteria'] = {'val': 'ready4delay', 'pass': passing}

return passing, criteria


def criteria_recording(n_ephys, delay, psych_20, psych_80, n_trials, perf_easy, rt):
"""
criterion = np.any(n_trials > 400) and np.any(perf_easy > 0.9)
return criterion
Returns bool indicating whether criteria for ready4recording are met.

Criteria
--------
- At least 3 ephys sessions
- Delay on any session > 0
- Lapse on both sides < 0.1 for both bias blocks
- Bias shift between blocks > 5
- Total number of trials > 400 for all sessions
- Performance on easy contrasts > 90% for all sessions
- Median response time for zero contrast stimuli < 2 seconds

Parameters
----------
psych_20 : numpy.array
The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2.
Parameters are bias, threshold, lapse high, lapse low.
psych_80 : numpy.array
The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8.
Parameters are bias, threshold, lapse high, lapse low.
n_trials : numpy.array
The number of trials for each session (typically three consecutive sessions).
perf_easy : numpy.array
The proportion of correct high contrast trials for each session (typically three
consecutive sessions).
rt : float
The median response time for zero contrast trials.

Returns
-------
bool
True if subject passes the ready4recording criteria.
Bunch
Bunch containing breakdown of the passing/ failing critieria
"""

_, criteria = criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt)
criteria['N_ephys'] = {'val': n_ephys, 'pass': n_ephys >= 3}
criteria['N_delay'] = {'val': delay, 'pass': delay > 0}

passing = np.all([v['pass'] for k, v in criteria.items()])

criteria['Criteria'] = {'val': 'ready4recording', 'pass': passing}

return passing, criteria


def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.032, **kwargs):
Expand Down
Loading
Loading