diff --git a/brainbox/behavior/training.py b/brainbox/behavior/training.py index 6959c3bae..a8f2f383d 100644 --- a/brainbox/behavior/training.py +++ b/brainbox/behavior/training.py @@ -157,7 +157,7 @@ def get_subject_training_status(subj, date=None, details=True, one=None): if not trials: return sess_dates = list(trials.keys()) - status, info = get_training_status(trials, task_protocol, ephys_sess, n_delay) + status, info, _ = get_training_status(trials, task_protocol, ephys_sess, n_delay) if details: if np.any(info.get('psych')): @@ -265,13 +265,13 @@ def get_sessions(subj, date=None, one=None): if not np.any(np.array(task_protocol) == 'training'): ephys_sess = one.alyx.rest('sessions', 'list', subject=subj, date_range=[sess_dates[-1], sess_dates[0]], - django='json__PYBPOD_BOARD__icontains,ephys') + django='location__name__icontains,ephys') if len(ephys_sess) > 0: ephys_sess_dates = [sess['start_time'][:10] for sess in ephys_sess] n_delay = len(one.alyx.rest('sessions', 'list', subject=subj, date_range=[sess_dates[-1], sess_dates[0]], - django='json__SESSION_START_DELAY_SEC__gte,900')) + django='json__SESSION_DELAY_START__gte,900')) else: ephys_sess_dates = [] n_delay = 0 @@ -313,23 +313,32 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay): info = Bunch() trials_all = concatenate_trials(trials) + info.session_dates = list(trials.keys()) + info.protocols = [p for p in task_protocol] # Case when all sessions are trainingChoiceWorld if np.all(np.array(task_protocol) == 'training'): - signed_contrast = get_signed_contrast(trials_all) + signed_contrast = np.unique(get_signed_contrast(trials_all)) (info.perf_easy, info.n_trials, info.psych, info.rt) = compute_training_info(trials, trials_all) - if not np.any(signed_contrast == 0): - status = 'in training' + + pass_criteria, criteria = criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt, + signed_contrast) + if pass_criteria: + failed_criteria = Bunch() + failed_criteria['NBiased'] = {'val': info.protocols, 'pass': False} + failed_criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': False} + status = 'trained 1b' else: - if criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt): - status = 'trained 1b' - elif criterion_1a(info.psych, info.n_trials, info.perf_easy): + failed_criteria = criteria + pass_criteria, criteria = criterion_1a(info.psych, info.n_trials, info.perf_easy, signed_contrast) + if pass_criteria: status = 'trained 1a' else: + failed_criteria = criteria status = 'in training' - return status, info + return status, info, failed_criteria # Case when there are < 3 biasedChoiceWorld sessions after reaching trained_1b criterion if ~np.all(np.array(task_protocol) == 'training') and \ @@ -338,7 +347,11 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay): (info.perf_easy, info.n_trials, info.psych, info.rt) = compute_training_info(trials, trials_all) - return status, info + criteria = Bunch() + criteria['NBiased'] = {'val': info.protocols, 'pass': False} + criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': False} + + return status, info, criteria # Case when there is biasedChoiceWorld or ephysChoiceWorld in last three sessions if not np.any(np.array(task_protocol) == 'training'): @@ -346,37 +359,40 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay): (info.perf_easy, info.n_trials, info.psych_20, info.psych_80, info.rt) = compute_bias_info(trials, trials_all) - # We are still on training rig and so all sessions should be biased - if len(ephys_sess_dates) == 0: - assert np.all(np.array(task_protocol) == 'biased') - if criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy, - info.rt): - status = 'ready4ephysrig' - else: - status = 'trained 1b' - elif len(ephys_sess_dates) < 3: + n_ephys = len(ephys_sess_dates) + info.n_ephys = n_ephys + info.n_delay = n_delay + + # Criterion recording + pass_criteria, criteria = criteria_recording(n_ephys, n_delay, info.psych_20, info.psych_80, info.n_trials, + info.perf_easy, info.rt) + if pass_criteria: + # Here the criteria doesn't actually fail but we have no other criteria to meet so we return this + failed_criteria = criteria + status = 'ready4recording' + else: + failed_criteria = criteria assert all(date in trials for date in ephys_sess_dates) perf_ephys_easy = np.array([compute_performance_easy(trials[k]) for k in ephys_sess_dates]) n_ephys_trials = np.array([compute_n_trials(trials[k]) for k in ephys_sess_dates]) - if criterion_delay(n_ephys_trials, perf_ephys_easy): - status = 'ready4delay' - else: - status = 'ready4ephysrig' - - elif len(ephys_sess_dates) >= 3: - if n_delay > 0 and \ - criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy, - info.rt): - status = 'ready4recording' - elif criterion_delay(info.n_trials, info.perf_easy): + pass_criteria, criteria = criterion_delay(n_ephys, n_ephys_trials, perf_ephys_easy) + + if pass_criteria: status = 'ready4delay' else: - status = 'ready4ephysrig' + failed_criteria = criteria + pass_criteria, criteria = criterion_ephys(info.psych_20, info.psych_80, info.n_trials, + info.perf_easy, info.rt) + if pass_criteria: + status = 'ready4ephysrig' + else: + failed_criteria = criteria + status = 'trained 1b' - return status, info + return status, info, failed_criteria def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None, @@ -814,7 +830,7 @@ def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='re return reaction_time, contrasts, n_contrasts, -def criterion_1a(psych, n_trials, perf_easy): +def criterion_1a(psych, n_trials, perf_easy, signed_contrast): """ Returns bool indicating whether criteria for status 'trained_1a' are met. @@ -825,6 +841,7 @@ def criterion_1a(psych, n_trials, perf_easy): - Lapse rate on both sides is less than 0.2 - The total number of trials is greater than 200 for each session - Performance on easy contrasts > 80% for all sessions + - Zero contrast trials must be present Parameters ---------- @@ -835,11 +852,15 @@ def criterion_1a(psych, n_trials, perf_easy): The number for trials for each session. perf_easy : numpy.array of float The proportion of correct high contrast trials for each session. + signed_contrast: numpy.array + Unique list of contrasts displayed Returns ------- bool True if the criteria are met for 'trained_1a'. + Bunch + Bunch containing breakdown of the passing/ failing critieria Notes ----- @@ -847,12 +868,23 @@ def criterion_1a(psych, n_trials, perf_easy): for a number of sessions determined to be of 'good' performance by an experimenter. """ - criterion = (abs(psych[0]) < 16 and psych[1] < 19 and psych[2] < 0.2 and psych[3] < 0.2 and - np.all(n_trials > 200) and np.all(perf_easy > 0.8)) - return criterion + criteria = Bunch() + criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)} + criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.2} + criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.2} + criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 16} + criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 19} + criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 200)} + criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.8)} + + passing = np.all([v['pass'] for k, v in criteria.items()]) + criteria['Criteria'] = {'val': 'trained_1a', 'pass': passing} -def criterion_1b(psych, n_trials, perf_easy, rt): + return passing, criteria + + +def criterion_1b(psych, n_trials, perf_easy, rt, signed_contrast): """ Returns bool indicating whether criteria for trained_1b are met. @@ -864,6 +896,7 @@ def criterion_1b(psych, n_trials, perf_easy, rt): - The total number of trials is greater than 400 for each session - Performance on easy contrasts > 90% for all sessions - The median response time across all zero contrast trials is less than 2 seconds + - Zero contrast trials must be present Parameters ---------- @@ -876,11 +909,15 @@ def criterion_1b(psych, n_trials, perf_easy, rt): The proportion of correct high contrast trials for each session. rt : float The median response time for zero contrast trials. + signed_contrast: numpy.array + Unique list of contrasts displayed Returns ------- bool True if the criteria are met for 'trained_1b'. + Bunch + Bunch containing breakdown of the passing/ failing critieria Notes ----- @@ -890,17 +927,27 @@ def criterion_1b(psych, n_trials, perf_easy, rt): regrettably means that the maximum threshold fit for 1b is greater than for 1a, meaning the slope of the psychometric curve may be slightly less steep than 1a. """ - criterion = (abs(psych[0]) < 10 and psych[1] < 20 and psych[2] < 0.1 and psych[3] < 0.1 and - np.all(n_trials > 400) and np.all(perf_easy > 0.9) and rt < 2) - return criterion + + criteria = Bunch() + criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)} + criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.1} + criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.1} + criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 10} + criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 20} + criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)} + criteria['Perf_tasy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)} + criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2} + + passing = np.all([v['pass'] for k, v in criteria.items()]) + + criteria['Criteria'] = {'val': 'trained_1b', 'pass': passing} + + return passing, criteria def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt): """ - Returns bool indicating whether criteria for ready4ephysrig or ready4recording are met. - - NB: The difference between these two is whether the sessions were acquired ot a recording rig - with a delay before the first trial. Neither of these two things are tested here. + Returns bool indicating whether criteria for ready4ephysrig are met. Criteria -------- @@ -929,21 +976,34 @@ def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt): Returns ------- bool - True if subject passes the ready4ephysrig or ready4recording criteria. + True if subject passes the ready4ephysrig criteria. + Bunch + Bunch containing breakdown of the passing/ failing critieria """ + criteria = Bunch() + criteria['LapseLow_80'] = {'val': psych_80[2], 'pass': psych_80[2] < 0.1} + criteria['LapseHigh_80'] = {'val': psych_80[3], 'pass': psych_80[3] < 0.1} + criteria['LapseLow_20'] = {'val': psych_20[2], 'pass': psych_20[2] < 0.1} + criteria['LapseHigh_20'] = {'val': psych_20[3], 'pass': psych_20[3] < 0.1} + criteria['Bias_shift'] = {'val': psych_80[0] - psych_20[0], 'pass': psych_80[0] - psych_20[0] > 5} + criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)} + criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)} + criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2} - criterion = (np.all(np.r_[psych_20[2:4], psych_80[2:4]] < 0.1) and # lapse - psych_80[0] - psych_20[0] > 5 and np.all(n_trials > 400) and # bias shift and n trials - np.all(perf_easy > 0.9) and rt < 2) # overall performance and response times - return criterion + passing = np.all([v['pass'] for k, v in criteria.items()]) + criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': passing} -def criterion_delay(n_trials, perf_easy): + return passing, criteria + + +def criterion_delay(n_ephys, n_trials, perf_easy): """ Returns bool indicating whether criteria for 'ready4delay' is met. Criteria -------- + - At least one session on an ephys rig - Total number of trials for any of the sessions is greater than 400 - Performance on easy contrasts is greater than 90% for any of the sessions @@ -959,9 +1019,69 @@ def criterion_delay(n_trials, perf_easy): ------- bool True if subject passes the 'ready4delay' criteria. + Bunch + Bunch containing breakdown of the passing/ failing critieria + """ + + criteria = Bunch() + criteria['N_ephys'] = {'val': n_ephys, 'pass': n_ephys > 0} + criteria['N_trials'] = {'val': n_trials, 'pass': np.any(n_trials > 400)} + criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.any(perf_easy > 0.9)} + + passing = np.all([v['pass'] for k, v in criteria.items()]) + + criteria['Criteria'] = {'val': 'ready4delay', 'pass': passing} + + return passing, criteria + + +def criteria_recording(n_ephys, delay, psych_20, psych_80, n_trials, perf_easy, rt): """ - criterion = np.any(n_trials > 400) and np.any(perf_easy > 0.9) - return criterion + Returns bool indicating whether criteria for ready4recording are met. + + Criteria + -------- + - At least 3 ephys sessions + - Delay on any session > 0 + - Lapse on both sides < 0.1 for both bias blocks + - Bias shift between blocks > 5 + - Total number of trials > 400 for all sessions + - Performance on easy contrasts > 90% for all sessions + - Median response time for zero contrast stimuli < 2 seconds + + Parameters + ---------- + psych_20 : numpy.array + The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2. + Parameters are bias, threshold, lapse high, lapse low. + psych_80 : numpy.array + The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8. + Parameters are bias, threshold, lapse high, lapse low. + n_trials : numpy.array + The number of trials for each session (typically three consecutive sessions). + perf_easy : numpy.array + The proportion of correct high contrast trials for each session (typically three + consecutive sessions). + rt : float + The median response time for zero contrast trials. + + Returns + ------- + bool + True if subject passes the ready4recording criteria. + Bunch + Bunch containing breakdown of the passing/ failing critieria + """ + + _, criteria = criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt) + criteria['N_ephys'] = {'val': n_ephys, 'pass': n_ephys >= 3} + criteria['N_delay'] = {'val': delay, 'pass': delay > 0} + + passing = np.all([v['pass'] for k, v in criteria.items()]) + + criteria['Criteria'] = {'val': 'ready4recording', 'pass': passing} + + return passing, criteria def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.032, **kwargs): diff --git a/brainbox/tests/test_behavior.py b/brainbox/tests/test_behavior.py index 8d02d185a..493234937 100644 --- a/brainbox/tests/test_behavior.py +++ b/brainbox/tests/test_behavior.py @@ -177,58 +177,65 @@ def test_in_training(self): trials, task_protocol = self._get_trials( sess_dates=['2020-08-25', '2020-08-24', '2020-08-21']) assert (np.all(np.array(task_protocol) == 'training')) - status, info = train.get_training_status( + status, info, crit = train.get_training_status( trials, task_protocol, ephys_sess_dates=[], n_delay=0) assert (status == 'in training') + assert (crit['Criteria']['val'] == 'trained_1a') def test_trained_1a(self): trials, task_protocol = self._get_trials( sess_dates=['2020-08-26', '2020-08-25', '2020-08-24']) assert (np.all(np.array(task_protocol) == 'training')) - status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], - n_delay=0) + status, info, crit = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], + n_delay=0) assert (status == 'trained 1a') + assert (crit['Criteria']['val'] == 'trained_1b') def test_trained_1b(self): trials, task_protocol = self._get_trials( sess_dates=['2020-08-27', '2020-08-26', '2020-08-25']) assert (np.all(np.array(task_protocol) == 'training')) - status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], - n_delay=0) + status, info, crit = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], + n_delay=0) self.assertEqual(status, 'trained 1b') + assert (crit['Criteria']['val'] == 'ready4ephysrig') def test_training_to_bias(self): trials, task_protocol = self._get_trials( sess_dates=['2020-08-31', '2020-08-28', '2020-08-27']) assert (~np.all(np.array(task_protocol) == 'training') and np.any(np.array(task_protocol) == 'training')) - status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], - n_delay=0) + status, info, crit = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], + n_delay=0) assert (status == 'trained 1b') + assert (crit['Criteria']['val'] == 'ready4ephysrig') def test_ready4ephys(self): sess_dates = ['2020-09-01', '2020-08-31', '2020-08-28'] trials, task_protocol = self._get_trials(sess_dates=sess_dates) assert (np.all(np.array(task_protocol) == 'biased')) - status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], - n_delay=0) + status, info, crit = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], + n_delay=0) assert (status == 'ready4ephysrig') + assert (crit['Criteria']['val'] == 'ready4delay') def test_ready4delay(self): sess_dates = ['2020-09-03', '2020-09-02', '2020-08-31'] trials, task_protocol = self._get_trials(sess_dates=sess_dates) assert (np.all(np.array(task_protocol) == 'biased')) - status, info = train.get_training_status(trials, task_protocol, - ephys_sess_dates=['2020-09-03'], n_delay=0) + status, info, crit = train.get_training_status(trials, task_protocol, + ephys_sess_dates=['2020-09-03'], n_delay=0) assert (status == 'ready4delay') + assert (crit['Criteria']['val'] == 'ready4recording') def test_ready4recording(self): sess_dates = ['2020-09-01', '2020-08-31', '2020-08-28'] trials, task_protocol = self._get_trials(sess_dates=sess_dates) assert (np.all(np.array(task_protocol) == 'biased')) - status, info = train.get_training_status(trials, task_protocol, - ephys_sess_dates=sess_dates, n_delay=1) + status, info, crit = train.get_training_status(trials, task_protocol, + ephys_sess_dates=sess_dates, n_delay=1) assert (status == 'ready4recording') + assert (crit['Criteria']['val'] == 'ready4recording') def test_query_criterion(self): """Test for brainbox.behavior.training.query_criterion function.""" diff --git a/ibllib/oneibl/registration.py b/ibllib/oneibl/registration.py index c5dc71553..51a4af84f 100644 --- a/ibllib/oneibl/registration.py +++ b/ibllib/oneibl/registration.py @@ -293,6 +293,11 @@ def register_session(self, ses_path, file_list=True, projects=None, procedures=N poo_counts = [md.get('POOP_COUNT') for md in settings if md.get('POOP_COUNT') is not None] if poo_counts: json_field['POOP_COUNT'] = int(sum(poo_counts)) + # Get the session start delay if available, needed for the training status + session_delay = [md.get('SESSION_DELAY_START') for md in settings + if md.get('SESSION_DELAY_START') is not None] + if session_delay: + json_field['SESSION_DELAY_START'] = int(sum(session_delay)) if not len(session): # Create session and weighings ses_ = {'subject': subject['nickname'], diff --git a/ibllib/pipes/training_status.py b/ibllib/pipes/training_status.py index fec33baaf..50d28707f 100644 --- a/ibllib/pipes/training_status.py +++ b/ibllib/pipes/training_status.py @@ -208,9 +208,11 @@ def load_combined_trials(sess_paths, one, force=True): """ trials_dict = {} for sess_path in sess_paths: - trials = load_trials(Path(sess_path), one, force=force) + trials = load_trials(Path(sess_path), one, force=force, mode='warn') if trials is not None: - trials_dict[Path(sess_path).stem] = load_trials(Path(sess_path), one, force=force) + trials_dict[Path(sess_path).stem] = load_trials(Path(sess_path), one, force=force, mode='warn' + + ) return training.concatenate_trials(trials_dict) @@ -270,7 +272,7 @@ def get_latest_training_information(sess_path, one, save=True): # Find the earliest date in missing dates that we need to recompute the training status for missing_status = find_earliest_recompute_date(df.drop_duplicates('date').reset_index(drop=True)) for date in missing_status: - df = compute_training_status(df, date, one) + df, _, _, _ = compute_training_status(df, date, one) df_lim = df.drop_duplicates(subset='session_path', keep='first') @@ -314,7 +316,7 @@ def find_earliest_recompute_date(df): return df[first_index:].date.values -def compute_training_status(df, compute_date, one, force=True): +def compute_training_status(df, compute_date, one, force=True, populate=True): """ Compute the training status for compute date based on training from that session and two previous days. @@ -331,11 +333,19 @@ def compute_training_status(df, compute_date, one, force=True): An instance of ONE for loading trials data. force : bool When true and if the session trials can't be found, will attempt to re-extract from disk. + populate : bool + Whether to update the training data frame with the new training status value Returns ------- pandas.DataFrame - The input data frame with a 'training_status' column populated for `compute_date`. + The input data frame with a 'training_status' column populated for `compute_date` if populate=True + Bunch + Bunch containing information fit parameters information for the combined sessions + Bunch + Bunch cotaining the training status criteria information + str + The training status """ # compute_date = str(alfiles.session_path_parts(session_path, as_dict=True)['date']) @@ -378,11 +388,12 @@ def compute_training_status(df, compute_date, one, force=True): ephys_sessions.append(df_date.iloc[-1]['date']) n_status = np.max([-2, -1 * len(status)]) - training_status, _ = training.get_training_status(trials, protocol, ephys_sessions, n_delay) + training_status, info, criteria = training.get_training_status(trials, protocol, ephys_sessions, n_delay) training_status = pass_through_training_hierachy(training_status, status[n_status]) - df.loc[df['date'] == compute_date, 'training_status'] = training_status + if populate: + df.loc[df['date'] == compute_date, 'training_status'] = training_status - return df + return df, info, criteria, training_status def pass_through_training_hierachy(status_new, status_old): @@ -433,12 +444,13 @@ def compute_session_duration_delay_location(sess_path, collections=None, **kwarg try: start_time, end_time = _get_session_times(sess_path, md, sess_data) session_duration = session_duration + int((end_time - start_time).total_seconds() / 60) - session_delay = session_delay + md.get('SESSION_START_DELAY_SEC', 0) + session_delay = session_delay + md.get('SESSION_DELAY_START', + md.get('SESSION_START_DELAY_SEC', 0)) except Exception: session_duration = session_duration + 0 session_delay = session_delay + 0 - if 'ephys' in md.get('PYBPOD_BOARD', None): + if 'ephys' in md.get('RIG_NAME', md.get('PYBPOD_BOARD', None)): session_location = 'ephys_rig' else: session_location = 'training_rig' @@ -586,9 +598,12 @@ def get_training_info_for_session(session_paths, one, force=True): session_path = Path(session_path) protocols = [] for c in collections: - prot = get_bpod_extractor_class(session_path, task_collection=c) - prot = prot[:-6].lower() - protocols.append(prot) + try: + prot = get_bpod_extractor_class(session_path, task_collection=c) + prot = prot[:-6].lower() + protocols.append(prot) + except ValueError: + continue un_protocols = np.unique(protocols) # Example, training, training, biased - training would be combined, biased not @@ -751,9 +766,54 @@ def plot_performance_easy_median_reaction_time(df, subject): return ax +def display_info(df, axs): + compute_date = df['date'].values[-1] + _, info, criteria, _ = compute_training_status(df, compute_date, None, force=False, populate=False) + + def _array_to_string(vals): + if isinstance(vals, (str, bool, int, float)): + if isinstance(vals, float): + vals = np.round(vals, 3) + return f'{vals}' + + str_vals = '' + for v in vals: + if isinstance(v, float): + v = np.round(v, 3) + str_vals += f'{v}, ' + return str_vals[:-2] + + pos = np.arange(len(criteria))[::-1] * 0.1 + for i, (k, v) in enumerate(info.items()): + str_v = _array_to_string(v) + text = axs[0].text(0, pos[i], k.capitalize(), color='k', weight='bold', fontsize=8, transform=axs[0].transAxes) + axs[0].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom", + color='k', fontsize=7) + + pos = np.arange(len(criteria))[::-1] * 0.1 + crit_val = criteria.pop('Criteria') + c = 'g' if crit_val['pass'] else 'r' + str_v = _array_to_string(crit_val['val']) + text = axs[1].text(0, pos[0], 'Criteria', color='k', weight='bold', fontsize=8, transform=axs[1].transAxes) + axs[1].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom", + color=c, fontsize=7) + pos = pos[1:] + + for i, (k, v) in enumerate(criteria.items()): + c = 'g' if v['pass'] else 'r' + str_v = _array_to_string(v['val']) + text = axs[1].text(0, pos[i], k.capitalize(), color='k', weight='bold', fontsize=8, transform=axs[1].transAxes) + axs[1].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom", + color=c, fontsize=7) + + axs[0].set_axis_off() + axs[1].set_axis_off() + + def plot_fit_params(df, subject): - fig, axs = plt.subplots(2, 2, figsize=(12, 6)) - axs = axs.ravel() + fig, axs = plt.subplots(2, 3, figsize=(12, 6), gridspec_kw={'width_ratios': [2, 2, 1]}) + + display_info(df, axs=[axs[0, 2], axs[1, 2]]) df = df.drop_duplicates('date').reset_index(drop=True) @@ -777,11 +837,11 @@ def plot_fit_params(df, subject): 'color': cmap[0], 'join': False} - plot_over_days(df, subject, y50, ax=axs[0], legend=False, title=False) - plot_over_days(df, subject, y80, ax=axs[0], legend=False, title=False) - plot_over_days(df, subject, y20, ax=axs[0], legend=False, title=False) - axs[0].axhline(16, linewidth=2, linestyle='--', color='k') - axs[0].axhline(-16, linewidth=2, linestyle='--', color='k') + plot_over_days(df, subject, y50, ax=axs[0, 0], legend=False, title=False) + plot_over_days(df, subject, y80, ax=axs[0, 0], legend=False, title=False) + plot_over_days(df, subject, y20, ax=axs[0, 0], legend=False, title=False) + axs[0, 0].axhline(16, linewidth=2, linestyle='--', color='k') + axs[0, 0].axhline(-16, linewidth=2, linestyle='--', color='k') y50['column'] = 'combined_thres_50' y50['title'] = 'Threshold' @@ -793,10 +853,10 @@ def plot_fit_params(df, subject): y20['title'] = 'Threshold' y80['lim'] = [0, 100] - plot_over_days(df, subject, y50, ax=axs[1], legend=False, title=False) - plot_over_days(df, subject, y80, ax=axs[1], legend=False, title=False) - plot_over_days(df, subject, y20, ax=axs[1], legend=False, title=False) - axs[1].axhline(19, linewidth=2, linestyle='--', color='k') + plot_over_days(df, subject, y50, ax=axs[0, 1], legend=False, title=False) + plot_over_days(df, subject, y80, ax=axs[0, 1], legend=False, title=False) + plot_over_days(df, subject, y20, ax=axs[0, 1], legend=False, title=False) + axs[0, 1].axhline(19, linewidth=2, linestyle='--', color='k') y50['column'] = 'combined_lapselow_50' y50['title'] = 'Lapse Low' @@ -808,10 +868,10 @@ def plot_fit_params(df, subject): y20['title'] = 'Lapse Low' y20['lim'] = [0, 1] - plot_over_days(df, subject, y50, ax=axs[2], legend=False, title=False) - plot_over_days(df, subject, y80, ax=axs[2], legend=False, title=False) - plot_over_days(df, subject, y20, ax=axs[2], legend=False, title=False) - axs[2].axhline(0.2, linewidth=2, linestyle='--', color='k') + plot_over_days(df, subject, y50, ax=axs[1, 0], legend=False, title=False) + plot_over_days(df, subject, y80, ax=axs[1, 0], legend=False, title=False) + plot_over_days(df, subject, y20, ax=axs[1, 0], legend=False, title=False) + axs[1, 0].axhline(0.2, linewidth=2, linestyle='--', color='k') y50['column'] = 'combined_lapsehigh_50' y50['title'] = 'Lapse High' @@ -823,19 +883,21 @@ def plot_fit_params(df, subject): y20['title'] = 'Lapse High' y20['lim'] = [0, 1] - plot_over_days(df, subject, y50, ax=axs[3], legend=False, title=False, training_lines=True) - plot_over_days(df, subject, y80, ax=axs[3], legend=False, title=False, training_lines=False) - plot_over_days(df, subject, y20, ax=axs[3], legend=False, title=False, training_lines=False) - axs[3].axhline(0.2, linewidth=2, linestyle='--', color='k') + plot_over_days(df, subject, y50, ax=axs[1, 1], legend=False, title=False, training_lines=True) + plot_over_days(df, subject, y80, ax=axs[1, 1], legend=False, title=False, training_lines=False) + plot_over_days(df, subject, y20, ax=axs[1, 1], legend=False, title=False, training_lines=False) + axs[1, 1].axhline(0.2, linewidth=2, linestyle='--', color='k') fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}') - lines, labels = axs[3].get_legend_handles_labels() - fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=5) + lines, labels = axs[1, 1].get_legend_handles_labels() + fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), facecolor='w', fancybox=True, shadow=True, + ncol=5) legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8), Line2D([0], [0], marker='o', color='w', label='p=0.2', markerfacecolor=cmap[0], markersize=8), Line2D([0], [0], marker='o', color='w', label='p=0.8', markerfacecolor=cmap[2], markersize=8)] - legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, shadow=True) + legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, + shadow=True, facecolor='w') fig.add_artist(legend2) return axs @@ -844,7 +906,7 @@ def plot_fit_params(df, subject): def plot_psychometric_curve(df, subject, one): df = df.drop_duplicates('date').reset_index(drop=True) sess_path = Path(df.iloc[-1]["session_path"]) - trials = load_trials(sess_path, one) + trials = load_trials(sess_path, one, mode='warn') fig, ax1 = plt.subplots(figsize=(8, 6)) @@ -907,7 +969,7 @@ def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, t box.width, box.height * 0.9]) if legend: ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1), - fancybox=True, shadow=True, ncol=5) + fancybox=True, shadow=True, ncol=5, facecolor='white') return ax1 @@ -1010,7 +1072,7 @@ def make_plots(session_path, one, df=None, save=False, upload=False, task_collec save_name = save_path.joinpath('subj_psychometric_fit_params.png') outputs.append(save_name) - ax4[0].get_figure().savefig(save_name, bbox_inches='tight') + ax4[0, 0].get_figure().savefig(save_name, bbox_inches='tight') save_name = save_path.joinpath('subj_psychometric_curve.png') outputs.append(save_name) diff --git a/requirements.txt b/requirements.txt index bf0f3128e..92066d9a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,3 +31,4 @@ phylib>=2.6.0 psychofit slidingRP>=1.1.1 # steinmetz lab refractory period metrics pyqt5 +ibl-style