From 68fb8cd807ca91d5a51ff9c985442934825ce8ac Mon Sep 17 00:00:00 2001 From: k1o0 Date: Fri, 6 Oct 2023 16:03:36 +0300 Subject: [PATCH] Task qc extractor refactor (#649) * Habituation extract phase and position * Independent task QC method in behaviour tasks * var names * flake8 * Initialize settings property --- .gitignore | 1 + ibllib/ephys/ephysqc.py | 2 +- ibllib/io/extractors/biased_trials.py | 24 +- ibllib/io/extractors/ephys_fpga.py | 60 +++-- ibllib/io/extractors/habituation_trials.py | 18 +- ibllib/io/extractors/mesoscope.py | 21 +- ibllib/io/extractors/training_trials.py | 10 +- ibllib/io/extractors/training_wheel.py | 4 +- ibllib/io/session_params.py | 2 +- ibllib/pipes/behavior_tasks.py | 190 +++++++++------- ibllib/qc/task_extractors.py | 9 +- ibllib/qc/task_metrics.py | 252 ++++++++++----------- 12 files changed, 320 insertions(+), 273 deletions(-) diff --git a/.gitignore b/.gitignore index 906c5d9ac..e291b8572 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ __pycache__ python/scratch .idea/* .vscode/ +*.code-workspace *checkpoint.ipynb build/ venv/ diff --git a/ibllib/ephys/ephysqc.py b/ibllib/ephys/ephysqc.py index b8721bfe2..16ab9f870 100644 --- a/ibllib/ephys/ephysqc.py +++ b/ibllib/ephys/ephysqc.py @@ -580,7 +580,7 @@ def _qc_from_path(sess_path, display=True): sync, chmap = ephys_fpga.get_main_probe_sync(sess_path, bin_exists=False) _ = ephys_fpga.extract_all(sess_path, output_path=temp_alf_folder, save=True) # check that the output is complete - fpga_trials = ephys_fpga.extract_behaviour_sync(sync, chmap=chmap, display=display) + fpga_trials, *_ = ephys_fpga.extract_behaviour_sync(sync, chmap=chmap, display=display) # align with the bpod bpod2fpga = ephys_fpga.align_with_bpod(temp_alf_folder.parent) alf_trials = alfio.load_object(temp_alf_folder, 'trials') diff --git a/ibllib/io/extractors/biased_trials.py b/ibllib/io/extractors/biased_trials.py index c7c16d6c0..16d8f8111 100644 --- a/ibllib/io/extractors/biased_trials.py +++ b/ibllib/io/extractors/biased_trials.py @@ -95,12 +95,12 @@ class TrialsTableBiased(BaseBpodTrialsExtractor): intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight, feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times Additionally extracts the following wheel data: - wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude + wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude """ save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None) - var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', - 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement') + var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', + 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement') def _extract(self, extractor_classes=None, **kwargs): base = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes, FeedbackType, @@ -120,13 +120,13 @@ class TrialsTableEphys(BaseBpodTrialsExtractor): intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight, feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times Additionally extracts the following wheel data: - wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude + wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude """ save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None, '_ibl_trials.quiescencePeriod.npy') - var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', - 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', + var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', + 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence') def _extract(self, extractor_classes=None, **kwargs): @@ -154,16 +154,16 @@ class BiasedTrials(BaseBpodTrialsExtractor): None, None, '_ibl_trials.quiescencePeriod.npy') var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', - 'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', 'included', + 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included', 'phase', 'position', 'quiescence') - def _extract(self, extractor_classes=None, **kwargs): + def _extract(self, extractor_classes=None, **kwargs) -> dict: base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTableBiased, IncludedTrials, PhasePosQuiescence] # Exclude from trials table out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False, task_collection=self.task_collection) - return tuple(out.pop(x) for x in self.var_names) + return {k: out[k] for k in self.var_names} class EphysTrials(BaseBpodTrialsExtractor): @@ -177,16 +177,16 @@ class EphysTrials(BaseBpodTrialsExtractor): '_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy') var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', - 'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', 'included', + 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included', 'phase', 'position', 'quiescence') - def _extract(self, extractor_classes=None, **kwargs): + def _extract(self, extractor_classes=None, **kwargs) -> dict: base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials, PhasePosQuiescence] # Exclude from trials table out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False, task_collection=self.task_collection) - return tuple(out.pop(x) for x in self.var_names) + return {k: out[k] for k in self.var_names} def extract_all(session_path, save=False, bpod_trials=False, settings=False, extra_classes=None, diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index 98bdcdd25..74ac1e551 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -17,7 +17,7 @@ from iblutil.spacer import Spacer import ibllib.exceptions as err -from ibllib.io import raw_data_loaders, session_params +from ibllib.io import raw_data_loaders as raw, session_params from ibllib.io.extractors.bpod_trials import extract_all as bpod_extract_all import ibllib.io.extractors.base as extractors_base from ibllib.io.extractors.training_wheel import extract_wheel_moves @@ -554,7 +554,7 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm ax.set_yticks([0, 1, 2, 3, 4, 5]) ax.set_ylim([0, 5]) - return trials + return trials, frame2ttl, audio, bpod def extract_sync(session_path, overwrite=False, ephys_files=None, namespace='spikeglx'): @@ -734,6 +734,7 @@ def __init__(self, *args, bpod_trials=None, bpod_extractor=None, **kwargs): super().__init__(*args, **kwargs) self.bpod2fpga = None self.bpod_trials = bpod_trials + self.frame2ttl = self.audio = self.bpod = self.settings = None if bpod_extractor: self.bpod_extractor = bpod_extractor self._update_var_names() @@ -750,14 +751,37 @@ def _update_var_names(self, bpod_fields=None, bpod_rsync_fields=None): A set of Bpod trials fields to keep. bpod_rsync_fields : tuple A set of Bpod trials fields to sync to the DAQ times. - - TODO Turn into property getter; requires ensuring the output field are the same for legacy """ if self.bpod_extractor: - self.var_names = self.bpod_extractor.var_names - self.save_names = self.bpod_extractor.save_names - self.bpod_rsync_fields = bpod_rsync_fields or self._time_fields(self.bpod_extractor.var_names) - self.bpod_fields = bpod_fields or [x for x in self.bpod_extractor.var_names if x not in self.bpod_rsync_fields] + for var_name, save_name in zip(self.bpod_extractor.var_names, self.bpod_extractor.save_names): + if var_name not in self.var_names: + self.var_names += (var_name,) + self.save_names += (save_name,) + + # self.var_names = self.bpod_extractor.var_names + # self.save_names = self.bpod_extractor.save_names + self.settings = self.bpod_extractor.settings # This is used by the TaskQC + self.bpod_rsync_fields = bpod_rsync_fields + if self.bpod_rsync_fields is None: + self.bpod_rsync_fields = tuple(self._time_fields(self.bpod_extractor.var_names)) + if 'table' in self.bpod_extractor.var_names: + if not self.bpod_trials: + self.bpod_trials = self.bpod_extractor.extract(save=False) + table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys() + self.bpod_rsync_fields += tuple(self._time_fields(table_keys)) + elif bpod_rsync_fields: + self.bpod_rsync_fields = bpod_rsync_fields + excluded = (*self.bpod_rsync_fields, 'table') + if bpod_fields: + assert not set(self.bpod_fields).intersection(excluded), 'bpod_fields must not also be bpod_rsync_fields' + self.bpod_fields = bpod_fields + elif self.bpod_extractor: + self.bpod_fields = tuple(x for x in self.bpod_extractor.var_names if x not in excluded) + if 'table' in self.bpod_extractor.var_names: + if not self.bpod_trials: + self.bpod_trials = self.bpod_extractor.extract(save=False) + table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys() + self.bpod_fields += (*[x for x in table_keys if x not in excluded], self.sync_field + '_bpod') @staticmethod def _time_fields(trials_attr) -> set: @@ -778,7 +802,8 @@ def _time_fields(trials_attr) -> set: pattern = re.compile(fr'^[_\w]*({"|".join(FIELDS)})[_\w]*$') return set(filter(pattern.match, trials_attr)) - def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task_collection='raw_behavior_data', **kwargs): + def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', + task_collection='raw_behavior_data', **kwargs) -> dict: """Extracts ephys trials by combining Bpod and FPGA sync pulses""" # extract the behaviour data from bpod if sync is None or chmap is None: @@ -804,7 +829,8 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task else: tmin = tmax = None - fpga_trials = extract_behaviour_sync( + # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC + fpga_trials, self.frame2ttl, self.audio, self.bpod = extract_behaviour_sync( sync=sync, chmap=chmap, bpod_trials=self.bpod_trials, tmin=tmin, tmax=tmax) assert self.sync_field in self.bpod_trials and self.sync_field in fpga_trials self.bpod_trials[f'{self.sync_field}_bpod'] = np.copy(self.bpod_trials[self.sync_field]) @@ -827,18 +853,20 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task # extract the wheel data wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) from ibllib.io.extractors.training_wheel import extract_first_movement_times - settings = raw_data_loaders.load_settings(session_path=self.session_path, task_collection=task_collection) - min_qt = settings.get('QUIESCENT_PERIOD', None) + if not self.settings: + self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) + min_qt = self.settings.get('QUIESCENT_PERIOD', None) first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt) out.update({'firstMovement_times': first_move_onsets}) # Re-create trials table trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns}) out['table'] = trials_table.to_df() + out.update({f'wheel_{k}': v for k, v in wheel.items()}) + out.update({f'wheelMoves_{k}': v for k, v in moves.items()}) out = {k: out[k] for k in self.var_names if k in out} # Reorder output - assert tuple(filter(lambda x: 'wheel' not in x, self.var_names)) == tuple(out.keys()) - return [out[k] for k in out] + [wheel['timestamps'], wheel['position'], - moves['intervals'], moves['peakAmplitude']] + assert self.var_names == tuple(out.keys()) + return out def get_wheel_positions(self, *args, **kwargs): """Extract wheel and wheelMoves objects. @@ -882,7 +910,7 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ If save is True, a list of file paths to the extracted data. """ # Extract Bpod trials - bpod_raw = raw_data_loaders.load_data(session_path, task_collection=task_collection) + bpod_raw = raw.load_data(session_path, task_collection=task_collection) assert bpod_raw is not None, 'No task trials data in raw_behavior_data - Exit' bpod_trials, *_ = bpod_extract_all( session_path=session_path, bpod_trials=bpod_raw, task_collection=task_collection, diff --git a/ibllib/io/extractors/habituation_trials.py b/ibllib/io/extractors/habituation_trials.py index a78a57eef..9dedbd3d5 100644 --- a/ibllib/io/extractors/habituation_trials.py +++ b/ibllib/io/extractors/habituation_trials.py @@ -15,16 +15,15 @@ class HabituationTrials(BaseBpodTrialsExtractor): var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight', 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals', 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times', - 'stimCenterTrigger_times', 'stimCenter_times') + 'stimCenterTrigger_times', 'stimCenter_times', 'position', 'phase') def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - exclude = ['itiIn_times', 'stimOffTrigger_times', - 'stimCenter_times', 'stimCenterTrigger_times'] - self.save_names = tuple([f'_ibl_trials.{x}.npy' if x not in exclude else None - for x in self.var_names]) + exclude = ['itiIn_times', 'stimOffTrigger_times', 'stimCenter_times', + 'stimCenterTrigger_times', 'position', 'phase'] + self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names) - def _extract(self): + def _extract(self) -> dict: # Extract all trials... # Get all stim_sync events detected @@ -101,9 +100,14 @@ def _extract(self): ["iti"][0][0] for tr in self.bpod_trials] ) + # Phase and position + out['position'] = np.array([t['position'] for t in self.bpod_trials]) + out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials]) + # NB: We lose the last trial because the stim off event occurs at trial_num + 1 n_trials = out['stimOff_times'].size - return [out[k][:n_trials] for k in self.var_names] + # return [out[k][:n_trials] for k in self.var_names] + return {k: out[k][:n_trials] for k in self.var_names} def extract_all(session_path, save=False, bpod_trials=False, settings=False, task_collection='raw_behavior_data', save_path=None): diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index 93491945e..561bb6343 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -100,7 +100,7 @@ def __init__(self, *args, sync_collection='raw_sync_data', **kwargs): super().__init__(*args, **kwargs) self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline') - def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs): + def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict: if not (sync or chmap): sync, chmap = load_timeline_sync_and_chmap( self.session_path / sync_collection, timeline=self.timeline, chmap=chmap) @@ -110,20 +110,17 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa trials = super()._extract(sync, chmap, sync_collection, extractor_type='ephys', **kwargs) # If no protocol number is defined, trim timestamps based on Bpod trials intervals - trials_table = trials[self.var_names.index('table')] + trials_table = trials['table'] bpod = get_sync_fronts(sync, chmap['bpod']) if kwargs.get('protocol_number') is None: tmin = trials_table.intervals_0.iloc[0] - 1 tmax = trials_table.intervals_1.iloc[-1] # Ensure wheel is cut off based on trials - wheel_ts_idx = self.var_names.index('wheel_timestamps') - mask = np.logical_and(tmin <= trials[wheel_ts_idx], trials[wheel_ts_idx] <= tmax) - trials[wheel_ts_idx] = trials[wheel_ts_idx][mask] - wheel_pos_idx = self.var_names.index('wheel_position') - trials[wheel_pos_idx] = trials[wheel_pos_idx][mask] - move_idx = self.var_names.index('wheelMoves_intervals') - mask = np.logical_and(trials[move_idx][:, 0] >= tmin, trials[move_idx][:, 0] <= tmax) - trials[move_idx] = trials[move_idx][mask, :] + mask = np.logical_and(tmin <= trials['wheel_timestamps'], trials['wheel_timestamps'] <= tmax) + trials['wheel_timestamps'] = trials['wheel_timestamps'][mask] + trials['wheel_position'] = trials['wheel_position'][mask] + mask = np.logical_and(trials['wheelMoves_intervals'][:, 0] >= tmin, trials['wheelMoves_intervals'][:, 0] <= tmax) + trials['wheelMoves_intervals'] = trials['wheelMoves_intervals'][mask, :] else: tmin, tmax = get_protocol_period(self.session_path, kwargs['protocol_number'], bpod) bpod = get_sync_fronts(sync, chmap['bpod'], tmin, tmax) @@ -138,7 +135,7 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa valve_open_times = self.get_valve_open_times(driver_ttls=driver_out) assert len(valve_open_times) == sum(trials_table.feedbackType == 1) # TODO Relax assertion correct = trials_table.feedbackType == 1 - trials[self.var_names.index('valveOpen_times')][correct] = valve_open_times + trials['valveOpen_times'][correct] = valve_open_times trials_table.feedback_times[correct] = valve_open_times # Replace audio events @@ -191,7 +188,7 @@ def first_true(arr): trials_table.feedback_times[~correct] = error_cue trials_table.goCue_times = go_cue - return trials + return {k: trials[k] for k in self.var_names} def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None): """ diff --git a/ibllib/io/extractors/training_trials.py b/ibllib/io/extractors/training_trials.py index dc13ed7dd..41a69d815 100644 --- a/ibllib/io/extractors/training_trials.py +++ b/ibllib/io/extractors/training_trials.py @@ -682,8 +682,8 @@ class TrialsTable(BaseBpodTrialsExtractor): """ save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None) - var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', - 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement') + var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', + 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement') def _extract(self, extractor_classes=None, **kwargs): base = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes, FeedbackType, @@ -703,16 +703,16 @@ class TrainingTrials(BaseBpodTrialsExtractor): '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None, None) var_names = ('repNum', 'goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', - 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', 'wheel_moves_peak_amplitude', + 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence') - def _extract(self): + def _extract(self) -> dict: base = [RepNum, GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTable, PhasePosQuiescence] out, _ = run_extractor_classes( base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False, task_collection=self.task_collection) - return tuple(out.pop(x) for x in self.var_names) + return {k: out[k] for k in self.var_names} def extract_all(session_path, save=False, bpod_trials=None, settings=None, task_collection='raw_behavior_data', save_path=None): diff --git a/ibllib/io/extractors/training_wheel.py b/ibllib/io/extractors/training_wheel.py index 617b5f1df..2f1aded8c 100644 --- a/ibllib/io/extractors/training_wheel.py +++ b/ibllib/io/extractors/training_wheel.py @@ -385,8 +385,8 @@ class Wheel(BaseBpodTrialsExtractor): save_names = ('_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, '_ibl_trials.firstMovement_times.npy', None) - var_names = ('wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', - 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'firstMovement_times', + var_names = ('wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', + 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'firstMovement_times', 'is_final_movement') def _extract(self): diff --git a/ibllib/io/session_params.py b/ibllib/io/session_params.py index 34e668ced..5bcaf2873 100644 --- a/ibllib/io/session_params.py +++ b/ibllib/io/session_params.py @@ -510,7 +510,7 @@ def prepare_experiment(session_path, acquisition_description=None, local=None, r # won't be preserved by create_basic_transfer_params by default remote = False if remote is False else params['REMOTE_DATA_FOLDER_PATH'] - # THis is in the docstring but still, if the session Path is absolute, we need to make it relative + # This is in the docstring but still, if the session Path is absolute, we need to make it relative if Path(session_path).is_absolute(): session_path = Path(*session_path.parts[-3:]) diff --git a/ibllib/pipes/behavior_tasks.py b/ibllib/pipes/behavior_tasks.py index 7cc317c28..6f1c8d506 100644 --- a/ibllib/pipes/behavior_tasks.py +++ b/ibllib/pipes/behavior_tasks.py @@ -9,14 +9,12 @@ from ibllib.oneibl.registration import get_lab from ibllib.pipes import base_tasks -from ibllib.io.raw_data_loaders import load_settings +from ibllib.io.raw_data_loaders import load_settings, load_bpod_fronts from ibllib.qc.task_extractors import TaskQCExtractor from ibllib.qc.task_metrics import HabituationQC, TaskQC from ibllib.io.extractors.ephys_passive import PassiveChoiceWorld -from ibllib.io.extractors import bpod_trials -from ibllib.io.extractors.base import get_session_extractor_type from ibllib.io.extractors.bpod_trials import get_bpod_extractor -from ibllib.io.extractors.ephys_fpga import extract_all +from ibllib.io.extractors.ephys_fpga import FpgaTrials, get_sync_and_chn_map from ibllib.io.extractors.mesoscope import TimelineTrials from ibllib.pipes import training_status from ibllib.plots.figures import BehaviourPlots @@ -73,25 +71,43 @@ def signature(self): } return signature - def _run(self, update=True): + def _run(self, update=True, save=True): """ Extracts an iblrig training session """ - extractor = bpod_trials.get_bpod_extractor(self.session_path, task_collection=self.collection) - trials, output_files = extractor.extract(task_collection=self.collection, save=True) + trials, output_files = self._extract_behaviour(save=save) if trials is None: return None if self.one is None or self.one.offline: return output_files + # Run the task QC + self._run_qc(trials, update=update) + return output_files + + def _extract_behaviour(self, **kwargs): + self.extractor = get_bpod_extractor(self.session_path, task_collection=self.collection) + self.extractor.default_path = self.output_collection + return self.extractor.extract(task_collection=self.collection, **kwargs) + + def _run_qc(self, trials_data=None, update=True): + if not self.extractor or trials_data is None: + trials_data, _ = self._extract_behaviour(save=False) + if not trials_data: + raise ValueError('No trials data found') + # Compile task data for QC qc = HabituationQC(self.session_path, one=self.one) - qc.extractor = TaskQCExtractor(self.session_path, sync_collection=self.sync_collection, + qc.extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one, sync_type=self.sync, task_collection=self.collection) + + # Currently only the data field is accessed + qc.extractor.data = qc.extractor.rename_data(trials_data.copy()) + namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' qc.run(update=update, namespace=namespace) - return output_files + return qc class TrialRegisterRaw(base_tasks.RegisterRawDataTask, base_tasks.BehaviourTask): @@ -213,6 +229,7 @@ def _run(self, **kwargs): class ChoiceWorldTrialsBpod(base_tasks.BehaviourTask): priority = 90 job_size = 'small' + extractor = None @property def signature(self): @@ -234,38 +251,53 @@ def signature(self): } return signature - def _run(self, update=True): + def _run(self, update=True, save=True): """ Extracts an iblrig training session """ - extractor = bpod_trials.get_bpod_extractor(self.session_path, task_collection=self.collection) - extractor.default_path = self.output_collection - trials, output_files = extractor.extract(task_collection=self.collection, save=True) + trials, output_files = self._extract_behaviour(save=save) if trials is None: return None if self.one is None or self.one.offline: return output_files + # Run the task QC + self._run_qc(trials) + + return output_files + + def _extract_behaviour(self, **kwargs): + self.extractor = get_bpod_extractor(self.session_path, task_collection=self.collection) + self.extractor.default_path = self.output_collection + return self.extractor.extract(task_collection=self.collection, **kwargs) + + def _run_qc(self, trials_data=None, update=True): + if not self.extractor or trials_data is None: + trials_data, _ = self._extract_behaviour(save=False) + if not trials_data: + raise ValueError('No trials data found') + # Compile task data for QC - type = get_session_extractor_type(self.session_path, task_collection=self.collection) - # FIXME Task data should not need re-extracting - if type == 'habituation': - qc = HabituationQC(self.session_path, one=self.one) - qc.extractor = TaskQCExtractor(self.session_path, one=self.one, sync_collection=self.sync_collection, - sync_type=self.sync, task_collection=self.collection) - else: # Update wheel data - qc = TaskQC(self.session_path, one=self.one) - qc.extractor = TaskQCExtractor(self.session_path, one=self.one, sync_collection=self.sync_collection, - sync_type=self.sync, task_collection=self.collection) - qc.extractor.wheel_encoding = 'X1' + qc_extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one, + sync_type=self.sync, task_collection=self.collection) + qc_extractor.data = qc_extractor.rename_data(trials_data) + if type(self.extractor).__name__ == 'HabituationTrials': + qc = HabituationQC(self.session_path, one=self.one, log=_logger) + else: + qc = TaskQC(self.session_path, one=self.one, log=_logger) + qc_extractor.wheel_encoding = 'X1' + qc_extractor.settings = self.extractor.settings + qc_extractor.frame_ttls, qc_extractor.audio_ttls = load_bpod_fronts( + self.session_path, task_collection=self.collection) + qc.extractor = qc_extractor + # Aggregate and update Alyx QC fields namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' qc.run(update=update, namespace=namespace) - - return output_files + return qc -class ChoiceWorldTrialsNidq(base_tasks.BehaviourTask): +class ChoiceWorldTrialsNidq(ChoiceWorldTrialsBpod): priority = 90 job_size = 'small' @@ -312,21 +344,41 @@ def _behaviour_criterion(self, update=True): "sessions", eid, "extended_qc", {"behavior": int(good_enough)} ) - def _extract_behaviour(self): - dsets, out_files = extract_all(self.session_path, self.sync_collection, task_collection=self.collection, - save_path=self.session_path.joinpath(self.output_collection), - protocol_number=self.protocol_number, save=True) + def _extract_behaviour(self, save=True, **kwargs): + # Extract Bpod trials + bpod_trials, _ = super()._extract_behaviour(save=False, **kwargs) - return dsets, out_files + # Sync Bpod trials to FPGA + sync, chmap = get_sync_and_chn_map(self.session_path, self.sync_collection) + self.extractor = FpgaTrials(self.session_path, bpod_trials=bpod_trials, bpod_extractor=self.extractor) + outputs, files = self.extractor.extract( + save=save, sync=sync, chmap=chmap, path_out=self.session_path.joinpath(self.output_collection), + task_collection=self.collection, protocol_number=self.protocol_number, **kwargs) + return outputs, files - def _run_qc(self, trials_data, update=True, plot_qc=True): - # Run the task QC - qc = TaskQC(self.session_path, one=self.one, log=_logger) - qc.extractor = TaskQCExtractor(self.session_path, lazy=True, one=qc.one, sync_collection=self.sync_collection, + def _run_qc(self, trials_data=None, update=False, plot_qc=False): + if not self.extractor or trials_data is None: + trials_data, _ = self._extract_behaviour(save=False) + if not trials_data: + raise ValueError('No trials data found') + + # Compile task data for QC + qc_extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one, sync_type=self.sync, task_collection=self.collection) - # Extract extra datasets required for QC - qc.extractor.data = trials_data # FIXME This line is pointless - qc.extractor.extract_data() + qc_extractor.data = qc_extractor.rename_data(trials_data.copy()) + if type(self.extractor).__name__ == 'HabituationTrials': + qc = HabituationQC(self.session_path, one=self.one, log=_logger) + else: + qc = TaskQC(self.session_path, one=self.one, log=_logger) + qc_extractor.settings = self.extractor.settings + # Add Bpod wheel data + wheel_ts_bpod = self.extractor.bpod2fpga(self.extractor.bpod_trials['wheel_timestamps']) + qc_extractor.data['wheel_timestamps_bpod'] = wheel_ts_bpod + qc_extractor.data['wheel_position_bpod'] = self.extractor.bpod_trials['wheel_position'] + qc_extractor.wheel_encoding = 'X4' + qc_extractor.frame_ttls = self.extractor.frame2ttl + qc_extractor.audio_ttls = self.extractor.audio + qc.extractor = qc_extractor # Aggregate and update Alyx QC fields namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' @@ -345,9 +397,10 @@ def _run_qc(self, trials_data, update=True, plot_qc=True): _logger.error('Could not create Trials QC Plot') _logger.error(traceback.format_exc()) self.status = -1 + return qc - def _run(self, update=True, plot_qc=True): - dsets, out_files = self._extract_behaviour() + def _run(self, update=True, plot_qc=True, save=True): + dsets, out_files = self._extract_behaviour(save=save) if not self.one or self.one.offline: return out_files @@ -378,63 +431,24 @@ def signature(self): for fn in filter(None, extractor.save_names)] return signature - def _extract_behaviour(self): + def _extract_behaviour(self, save=True, **kwargs): """Extract the Bpod trials data and Timeline acquired signals.""" # First determine the extractor from the task protocol - extractor = get_bpod_extractor(self.session_path, self.protocol, self.collection) - ret, _ = extractor.extract(save=False, task_collection=self.collection) - bpod_trials = {k: v for k, v in zip(extractor.var_names, ret)} + bpod_trials, _ = ChoiceWorldTrialsBpod._extract_behaviour(self, save=False, **kwargs) - trials = TimelineTrials(self.session_path, bpod_trials=bpod_trials) + # Sync Bpod trials to DAQ + self.extractor = TimelineTrials(self.session_path, bpod_trials=bpod_trials, bpod_extractor=self.extractor) save_path = self.session_path / self.output_collection - if not self._spacer_support(extractor.settings): + if not self._spacer_support(self.extractor.settings): _logger.warning('Protocol spacers not supported; setting protocol_number to None') self.protocol_number = None - dsets, out_files = trials.extract( - save=True, path_out=save_path, sync_collection=self.sync_collection, - task_collection=self.collection, protocol_number=self.protocol_number) - if not isinstance(dsets, dict): - dsets = {k: v for k, v in zip(trials.var_names, dsets)} - - self.timeline = trials.timeline # Store for QC later - self.frame2ttl = trials.frame2ttl - self.audio = trials.audio + dsets, out_files = self.extractor.extract( + save=save, path_out=save_path, sync_collection=self.sync_collection, + task_collection=self.collection, protocol_number=self.protocol_number, **kwargs) return dsets, out_files - def _run_qc(self, trials_data, update=True, **kwargs): - """ - Run the task QC and update Alyx with results. - - Parameters - ---------- - trials_data : dict - The extracted trials data. - update : bool - If true, update Alyx with the result. - - Notes - ----- - - Unlike the super class, currently the QC plots are not generated. - - Expects the frame2ttl and audio attributes to be set from running _extract_behaviour. - """ - # TODO Task QC extractor for Timeline - qc = TaskQC(self.session_path, one=self.one, log=_logger) - qc.extractor = TaskQCExtractor(self.session_path, lazy=True, one=qc.one, sync_collection=self.sync_collection, - sync_type=self.sync, task_collection=self.collection) - # Extract extra datasets required for QC - qc.extractor.data = TaskQCExtractor.rename_data(trials_data.copy()) - qc.extractor.load_raw_data() - - qc.extractor.frame_ttls = self.frame2ttl - qc.extractor.audio_ttls = self.audio - # qc.extractor.bpod_ttls = channel_events('bpod') - - # Aggregate and update Alyx QC fields - namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' - qc.run(update=update, namespace=namespace) - class TrainingStatus(base_tasks.BehaviourTask): priority = 90 diff --git a/ibllib/qc/task_extractors.py b/ibllib/qc/task_extractors.py index f0d46ed02..5f5269710 100644 --- a/ibllib/qc/task_extractors.py +++ b/ibllib/qc/task_extractors.py @@ -1,4 +1,5 @@ import logging +import warnings import numpy as np from scipy.interpolate import interp1d @@ -26,16 +27,16 @@ 'wheel_position', 'wheel_timestamps'] -class TaskQCExtractor(object): +class TaskQCExtractor: def __init__(self, session_path, lazy=False, one=None, download_data=False, bpod_only=False, sync_collection=None, sync_type=None, task_collection=None): """ - A class for extracting the task data required to perform task quality control + A class for extracting the task data required to perform task quality control. :param session_path: a valid session path :param lazy: if True, the data are not extracted immediately :param one: an instance of ONE, used to download the raw data if download_data is True :param download_data: if True, any missing raw data is downloaded via ONE - :param bpod_only: extract from from raw Bpod data only, even for FPGA sessions + :param bpod_only: extract from raw Bpod data only, even for FPGA sessions """ if not is_session_path(session_path): raise ValueError('Invalid session path') @@ -151,6 +152,8 @@ def extract_data(self): intervals_bpod to be assigned to the data attribute before calling this function. :return: """ + warnings.warn('The TaskQCExtractor.extract_data will be removed in the future, ' + 'use dynamic pipeline behaviour tasks instead.', DeprecationWarning) self.log.info(f'Extracting session: {self.session_path}') self.type = self.type or get_session_extractor_type(self.session_path, task_collection=self.task_collection) # Finds the sync type when it isn't explicitly set, if ephys we assume nidq otherwise bpod diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index 36f2b4806..42361645d 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -69,21 +69,21 @@ class TaskQC(base.QC): """A class for computing task QC metrics""" criteria = dict() - criteria['default'] = {"PASS": 0.99, "WARNING": 0.90, "FAIL": 0} # Note: WARNING was 0.95 prior to Aug 2022 - criteria['_task_stimOff_itiIn_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_positive_feedback_stimOff_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_negative_feedback_stimOff_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_wheel_move_during_closed_loop'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_response_stimFreeze_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_detected_wheel_moves'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_trial_length'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_goCue_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_errorCue_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_stimOn_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_stimOff_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_stimFreeze_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_iti_delays'] = {"NOT_SET": 0} - criteria['_task_passed_trial_checks'] = {"NOT_SET": 0} + criteria['default'] = {'PASS': 0.99, 'WARNING': 0.90, 'FAIL': 0} # Note: WARNING was 0.95 prior to Aug 2022 + criteria['_task_stimOff_itiIn_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_positive_feedback_stimOff_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_negative_feedback_stimOff_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_wheel_move_during_closed_loop'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_response_stimFreeze_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_detected_wheel_moves'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_trial_length'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_goCue_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_errorCue_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_stimOn_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_stimOff_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_stimFreeze_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_iti_delays'] = {'NOT_SET': 0} + criteria['_task_passed_trial_checks'] = {'NOT_SET': 0} @staticmethod def _thresholding(qc_value, thresholds=None): @@ -100,7 +100,7 @@ def _thresholding(qc_value, thresholds=None): if qc_value is None or np.isnan(qc_value): return int(-1) elif (qc_value > MAX_BOUND) or (qc_value < MIN_BOUND): - raise ValueError("Values out of bound") + raise ValueError('Values out of bound') if 'PASS' in thresholds.keys() and qc_value >= thresholds['PASS']: return 0 if 'WARNING' in thresholds.keys() and qc_value >= thresholds['WARNING']: @@ -151,7 +151,7 @@ def compute(self, **kwargs): if self.extractor is None: kwargs['download_data'] = kwargs.pop('download_data', self.download_data) self.load_data(**kwargs) - self.log.info(f"Session {self.session_path}: Running QC on behavior data...") + self.log.info(f'Session {self.session_path}: Running QC on behavior data...') self.metrics, self.passed = get_bpodqc_metrics_frame( self.extractor.data, wheel_gain=self.extractor.settings['STIM_GAIN'], # The wheel gain @@ -229,7 +229,7 @@ def compute(self, download_data=None): # If download_data is None, decide based on whether eid or session path was provided ensure_data = self.download_data if download_data is None else download_data self.load_data(download_data=ensure_data) - self.log.info(f"Session {self.session_path}: Running QC on habituation data...") + self.log.info(f'Session {self.session_path}: Running QC on habituation data...') # Initialize checks prefix = '_task_' @@ -274,16 +274,16 @@ def compute(self, download_data=None): # Check event orders: trial_start < stim on < stim center < feedback < stim off check = prefix + 'trial_event_sequence' nans = ( - np.isnan(data["intervals"][:, 0]) | # noqa - np.isnan(data["stimOn_times"]) | # noqa - np.isnan(data["stimCenter_times"]) | - np.isnan(data["valveOpen_times"]) | # noqa - np.isnan(data["stimOff_times"]) + np.isnan(data['intervals'][:, 0]) | # noqa + np.isnan(data['stimOn_times']) | # noqa + np.isnan(data['stimCenter_times']) | + np.isnan(data['valveOpen_times']) | # noqa + np.isnan(data['stimOff_times']) ) - a = np.less(data["intervals"][:, 0], data["stimOn_times"], where=~nans) - b = np.less(data["stimOn_times"], data["stimCenter_times"], where=~nans) - c = np.less(data["stimCenter_times"], data["valveOpen_times"], where=~nans) - d = np.less(data["valveOpen_times"], data["stimOff_times"], where=~nans) + a = np.less(data['intervals'][:, 0], data['stimOn_times'], where=~nans) + b = np.less(data['stimOn_times'], data['stimCenter_times'], where=~nans) + c = np.less(data['stimCenter_times'], data['valveOpen_times'], where=~nans) + d = np.less(data['valveOpen_times'], data['stimOff_times'], where=~nans) metrics[check] = a & b & c & d & ~nans passed[check] = metrics[check].astype(float) @@ -291,7 +291,7 @@ def compute(self, download_data=None): # Check that the time difference between the visual stimulus center-command being # triggered and the stimulus effectively appearing in the center is smaller than 150 ms. check = prefix + 'stimCenter_delays' - metric = np.nan_to_num(data["stimCenter_times"] - data["stimCenterTrigger_times"], + metric = np.nan_to_num(data['stimCenter_times'] - data['stimCenterTrigger_times'], nan=np.inf) passed[check] = (metric <= 0.15) & (metric > 0) metrics[check] = metric @@ -375,9 +375,9 @@ def check_stimOn_goCue_delays(data, **_): """ # Calculate the difference between stimOn and goCue times. # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. - metric = np.nan_to_num(data["goCue_times"] - data["stimOn_times"], nan=np.inf) + metric = np.nan_to_num(data['goCue_times'] - data['stimOn_times'], nan=np.inf) passed = (metric < 0.01) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -391,9 +391,9 @@ def check_response_feedback_delays(data, **_): :param data: dict of trial data with keys ('feedback_times', 'response_times', 'intervals') """ - metric = np.nan_to_num(data["feedback_times"] - data["response_times"], nan=np.inf) + metric = np.nan_to_num(data['feedback_times'] - data['response_times'], nan=np.inf) passed = (metric < 0.01) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -410,13 +410,13 @@ def check_response_stimFreeze_delays(data, **_): """ # Calculate the difference between stimOn and goCue times. # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. - metric = np.nan_to_num(data["stimFreeze_times"] - data["response_times"], nan=np.inf) + metric = np.nan_to_num(data['stimFreeze_times'] - data['response_times'], nan=np.inf) # Test for valid values passed = ((metric < 0.1) & (metric > 0)).astype(float) # Finally remove no_go trials (stimFreeze triggered differently in no_go trials) # These values are ignored in calculation of proportion passed - passed[data["choice"] == 0] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[data['choice'] == 0] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -431,12 +431,12 @@ def check_stimOff_itiIn_delays(data, **_): 'choice') """ # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. - metric = np.nan_to_num(data["itiIn_times"] - data["stimOff_times"], nan=np.inf) + metric = np.nan_to_num(data['itiIn_times'] - data['stimOff_times'], nan=np.inf) passed = ((metric < 0.01) & (metric >= 0)).astype(float) # Remove no_go trials (stimOff triggered differently in no_go trials) # NaN values are ignored in calculation of proportion passed - metric[data["choice"] == 0] = passed[data["choice"] == 0] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + metric[data['choice'] == 0] = passed[data['choice'] == 0] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -451,14 +451,14 @@ def check_iti_delays(data, **_): :param data: dict of trial data with keys ('stimOff_times', 'intervals') """ # Initialize array the length of completed trials - metric = np.full(data["intervals"].shape[0], np.nan) + metric = np.full(data['intervals'].shape[0], np.nan) passed = metric.copy() # Get the difference between stim off and the start of the next trial # Missing data are set to Inf, except for the last trial which is a NaN metric[:-1] = \ - np.nan_to_num(data["intervals"][1:, 0] - data["stimOff_times"][:-1] - 0.5, nan=np.inf) + np.nan_to_num(data['intervals'][1:, 0] - data['stimOff_times'][:-1] - 0.5, nan=np.inf) passed[:-1] = np.abs(metric[:-1]) < .5 # Last trial is not counted - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -474,11 +474,11 @@ def check_positive_feedback_stimOff_delays(data, **_): 'correct') """ # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. - metric = np.nan_to_num(data["stimOff_times"] - data["feedback_times"] - 1, nan=np.inf) + metric = np.nan_to_num(data['stimOff_times'] - data['feedback_times'] - 1, nan=np.inf) passed = (np.abs(metric) < 0.15).astype(float) # NaN values are ignored in calculation of proportion passed; ignore incorrect trials here - metric[~data["correct"]] = passed[~data["correct"]] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + metric[~data['correct']] = passed[~data['correct']] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -492,12 +492,12 @@ def check_negative_feedback_stimOff_delays(data, **_): :param data: dict of trial data with keys ('stimOff_times', 'errorCue_times', 'intervals') """ - metric = np.nan_to_num(data["stimOff_times"] - data["errorCue_times"] - 2, nan=np.inf) + metric = np.nan_to_num(data['stimOff_times'] - data['errorCue_times'] - 2, nan=np.inf) # Apply criteria passed = (np.abs(metric) < 0.15).astype(float) # Remove none negative feedback trials - metric[data["correct"]] = passed[data["correct"]] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + metric[data['correct']] = passed[data['correct']] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -515,12 +515,12 @@ def check_wheel_move_before_feedback(data, **_): """ # Get tuple of wheel times and positions within 100ms of feedback traces = traces_by_trial( - data["wheel_timestamps"], - data["wheel_position"], - start=data["feedback_times"] - 0.05, - end=data["feedback_times"] + 0.05, + data['wheel_timestamps'], + data['wheel_position'], + start=data['feedback_times'] - 0.05, + end=data['feedback_times'] + 0.05, ) - metric = np.zeros_like(data["feedback_times"]) + metric = np.zeros_like(data['feedback_times']) # For each trial find the displacement for i, trial in enumerate(traces): pos = trial[1] @@ -528,12 +528,12 @@ def check_wheel_move_before_feedback(data, **_): metric[i] = pos[-1] - pos[0] # except no-go trials - metric[data["choice"] == 0] = np.nan # NaN = trial ignored for this check + metric[data['choice'] == 0] = np.nan # NaN = trial ignored for this check nans = np.isnan(metric) passed = np.zeros_like(metric) * np.nan passed[~nans] = (metric[~nans] != 0).astype(float) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -555,15 +555,15 @@ def _wheel_move_during_closed_loop(re_ts, re_pos, data, wheel_gain=None, tol=1, :param tol: the criterion in visual degrees """ if wheel_gain is None: - _log.warning("No wheel_gain input in function call, returning None") + _log.warning('No wheel_gain input in function call, returning None') return None, None # Get tuple of wheel times and positions over each trial's closed-loop period traces = traces_by_trial(re_ts, re_pos, - start=data["goCueTrigger_times"], - end=data["response_times"]) + start=data['goCueTrigger_times'], + end=data['response_times']) - metric = np.zeros_like(data["feedback_times"]) + metric = np.zeros_like(data['feedback_times']) # For each trial find the absolute displacement for i, trial in enumerate(traces): t, pos = trial @@ -574,16 +574,16 @@ def _wheel_move_during_closed_loop(re_ts, re_pos, data, wheel_gain=None, tol=1, metric[i] = np.abs(pos - origin).max() # Load wheel_gain and thresholds for each trial - wheel_gain = np.array([wheel_gain] * len(data["position"])) - thresh = data["position"] + wheel_gain = np.array([wheel_gain] * len(data['position'])) + thresh = data['position'] # abs displacement, s, in mm required to move 35 visual degrees s_mm = np.abs(thresh / wheel_gain) # don't care about direction criterion = cm_to_rad(s_mm * 1e-1) # convert abs displacement to radians (wheel pos is in rad) metric = metric - criterion # difference should be close to 0 rad_per_deg = cm_to_rad(1 / wheel_gain * 1e-1) passed = (np.abs(metric) < rad_per_deg * tol).astype(float) # less than 1 visual degree off - metric[data["choice"] == 0] = passed[data["choice"] == 0] = np.nan # except no-go trials - assert data["intervals"].shape[0] == len(metric) == len(passed) + metric[data['choice'] == 0] = passed[data['choice'] == 0] = np.nan # except no-go trials + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -642,25 +642,25 @@ def check_wheel_freeze_during_quiescence(data, **_): :param data: dict of trial data with keys ('wheel_timestamps', 'wheel_position', 'quiescence', 'intervals', 'stimOnTrigger_times') """ - assert np.all(np.diff(data["wheel_timestamps"]) >= 0) - assert data["quiescence"].size == data["stimOnTrigger_times"].size + assert np.all(np.diff(data['wheel_timestamps']) >= 0) + assert data['quiescence'].size == data['stimOnTrigger_times'].size # Get tuple of wheel times and positions over each trial's quiescence period - qevt_start_times = data["stimOnTrigger_times"] - data["quiescence"] + qevt_start_times = data['stimOnTrigger_times'] - data['quiescence'] traces = traces_by_trial( - data["wheel_timestamps"], - data["wheel_position"], + data['wheel_timestamps'], + data['wheel_position'], start=qevt_start_times, - end=data["stimOnTrigger_times"] + end=data['stimOnTrigger_times'] ) - metric = np.zeros((len(data["quiescence"]), 2)) # (n_trials, n_directions) + metric = np.zeros((len(data['quiescence']), 2)) # (n_trials, n_directions) for i, trial in enumerate(traces): t, pos = trial # Get the last position before the period began if pos.size > 0: # Find the position of the preceding sample and subtract it - idx = np.abs(data["wheel_timestamps"] - t[0]).argmin() - 1 - origin = data["wheel_position"][idx if idx != -1 else 0] + idx = np.abs(data['wheel_timestamps'] - t[0]).argmin() - 1 + origin = data['wheel_position'][idx if idx != -1 else 0] # Find the absolute min and max relative to the last sample metric[i, :] = np.abs([np.min(pos - origin), np.max(pos - origin)]) # Reduce to the largest displacement found in any direction @@ -668,7 +668,7 @@ def check_wheel_freeze_during_quiescence(data, **_): metric = 180 * metric / np.pi # convert to degrees from radians criterion = 2 # Position shouldn't change more than 2 in either direction passed = metric < criterion - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -685,8 +685,8 @@ def check_detected_wheel_moves(data, min_qt=0, **_): """ # Depending on task version this may be a single value or an array of quiescent periods min_qt = np.array(min_qt) - if min_qt.size > data["intervals"].shape[0]: - min_qt = min_qt[:data["intervals"].shape[0]] + if min_qt.size > data['intervals'].shape[0]: + min_qt = min_qt[:data['intervals'].shape[0]] metric = data['firstMovement_times'] qevt_start = data['goCueTrigger_times'] - np.array(min_qt) @@ -714,25 +714,25 @@ def check_error_trial_event_sequence(data, **_): """ # An array the length of N trials where True means at least one event time was NaN (bad) nans = ( - np.isnan(data["intervals"][:, 0]) | - np.isnan(data["goCue_times"]) | # noqa - np.isnan(data["errorCue_times"]) | # noqa - np.isnan(data["itiIn_times"]) | # noqa - np.isnan(data["intervals"][:, 1]) + np.isnan(data['intervals'][:, 0]) | + np.isnan(data['goCue_times']) | # noqa + np.isnan(data['errorCue_times']) | # noqa + np.isnan(data['itiIn_times']) | # noqa + np.isnan(data['intervals'][:, 1]) ) # For each trial check that the events happened in the correct order (ignore NaN values) - a = np.less(data["intervals"][:, 0], data["goCue_times"], where=~nans) # Start time < go cue - b = np.less(data["goCue_times"], data["errorCue_times"], where=~nans) # Go cue < error cue - c = np.less(data["errorCue_times"], data["itiIn_times"], where=~nans) # Error cue < ITI start - d = np.less(data["itiIn_times"], data["intervals"][:, 1], where=~nans) # ITI start < end time + a = np.less(data['intervals'][:, 0], data['goCue_times'], where=~nans) # Start time < go cue + b = np.less(data['goCue_times'], data['errorCue_times'], where=~nans) # Go cue < error cue + c = np.less(data['errorCue_times'], data['itiIn_times'], where=~nans) # Error cue < ITI start + d = np.less(data['itiIn_times'], data['intervals'][:, 1], where=~nans) # ITI start < end time # For each trial check all events were in order AND all event times were not NaN metric = a & b & c & d & ~nans passed = metric.astype(float) - passed[data["correct"]] = np.nan # Look only at incorrect trials - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[data['correct']] = np.nan # Look only at incorrect trials + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -749,25 +749,25 @@ def check_correct_trial_event_sequence(data, **_): """ # An array the length of N trials where True means at least one event time was NaN (bad) nans = ( - np.isnan(data["intervals"][:, 0]) | - np.isnan(data["goCue_times"]) | # noqa - np.isnan(data["valveOpen_times"]) | - np.isnan(data["itiIn_times"]) | # noqa - np.isnan(data["intervals"][:, 1]) + np.isnan(data['intervals'][:, 0]) | + np.isnan(data['goCue_times']) | # noqa + np.isnan(data['valveOpen_times']) | + np.isnan(data['itiIn_times']) | # noqa + np.isnan(data['intervals'][:, 1]) ) # For each trial check that the events happened in the correct order (ignore NaN values) - a = np.less(data["intervals"][:, 0], data["goCue_times"], where=~nans) # Start time < go cue - b = np.less(data["goCue_times"], data["valveOpen_times"], where=~nans) # Go cue < feedback - c = np.less(data["valveOpen_times"], data["itiIn_times"], where=~nans) # Feedback < ITI start - d = np.less(data["itiIn_times"], data["intervals"][:, 1], where=~nans) # ITI start < end time + a = np.less(data['intervals'][:, 0], data['goCue_times'], where=~nans) # Start time < go cue + b = np.less(data['goCue_times'], data['valveOpen_times'], where=~nans) # Go cue < feedback + c = np.less(data['valveOpen_times'], data['itiIn_times'], where=~nans) # Feedback < ITI start + d = np.less(data['itiIn_times'], data['intervals'][:, 1], where=~nans) # ITI start < end time # For each trial True means all events were in order AND all event times were not NaN metric = a & b & c & d & ~nans passed = metric.astype(float) - passed[~data["correct"]] = np.nan # Look only at correct trials - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[~data['correct']] = np.nan # Look only at correct trials + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -799,7 +799,7 @@ def check_n_trial_events(data, **_): 'wheel_moves_peak_amplitude', 'wheel_moves_intervals', 'wheel_timestamps', 'wheel_intervals', 'stimFreeze_times'] events = [k for k in data.keys() if k.endswith('_times') and k not in exclude] - metric = np.zeros(data["intervals"].shape[0], dtype=bool) + metric = np.zeros(data['intervals'].shape[0], dtype=bool) # For each trial interval check that one of each trial event occurred. For incorrect trials, # check the error cue trigger occurred within the interval, otherwise check it is nan. @@ -822,9 +822,9 @@ def check_trial_length(data, **_): :param data: dict of trial data with keys ('feedback_times', 'goCue_times', 'intervals') """ # NaN values are usually ignored so replace them with Inf so they fail the threshold - metric = np.nan_to_num(data["feedback_times"] - data["goCue_times"], nan=np.inf) + metric = np.nan_to_num(data['feedback_times'] - data['goCue_times'], nan=np.inf) passed = (metric < 60.1) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -835,14 +835,14 @@ def check_goCue_delays(data, **_): effectively played is smaller than 1ms. Metric: M = goCue_times - goCueTrigger_times - Criterion: 0 < M <= 0.001 s + Criterion: 0 < M <= 0.0015 s Units: seconds [s] :param data: dict of trial data with keys ('goCue_times', 'goCueTrigger_times', 'intervals') """ - metric = np.nan_to_num(data["goCue_times"] - data["goCueTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['goCue_times'] - data['goCueTrigger_times'], nan=np.inf) passed = (metric <= 0.0015) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -850,16 +850,16 @@ def check_errorCue_delays(data, **_): """ Check that the time difference between the error sound being triggered and effectively played is smaller than 1ms. Metric: M = errorCue_times - errorCueTrigger_times - Criterion: 0 < M <= 0.001 s + Criterion: 0 < M <= 0.0015 s Units: seconds [s] :param data: dict of trial data with keys ('errorCue_times', 'errorCueTrigger_times', 'intervals', 'correct') """ - metric = np.nan_to_num(data["errorCue_times"] - data["errorCueTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['errorCue_times'] - data['errorCueTrigger_times'], nan=np.inf) passed = ((metric <= 0.0015) & (metric > 0)).astype(float) - passed[data["correct"]] = metric[data["correct"]] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[data['correct']] = metric[data['correct']] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -868,15 +868,15 @@ def check_stimOn_delays(data, **_): and the stimulus effectively appearing on the screen is smaller than 150 ms. Metric: M = stimOn_times - stimOnTrigger_times - Criterion: 0 < M < 0.150 s + Criterion: 0 < M < 0.15 s Units: seconds [s] :param data: dict of trial data with keys ('stimOn_times', 'stimOnTrigger_times', 'intervals') """ - metric = np.nan_to_num(data["stimOn_times"] - data["stimOnTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['stimOn_times'] - data['stimOnTrigger_times'], nan=np.inf) passed = (metric <= 0.15) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -886,15 +886,15 @@ def check_stimOff_delays(data, **_): is smaller than 150 ms. Metric: M = stimOff_times - stimOffTrigger_times - Criterion: 0 < M < 0.150 s + Criterion: 0 < M < 0.15 s Units: seconds [s] :param data: dict of trial data with keys ('stimOff_times', 'stimOffTrigger_times', 'intervals') """ - metric = np.nan_to_num(data["stimOff_times"] - data["stimOffTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['stimOff_times'] - data['stimOffTrigger_times'], nan=np.inf) passed = (metric <= 0.15) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -904,15 +904,15 @@ def check_stimFreeze_delays(data, **_): is smaller than 150 ms. Metric: M = stimFreeze_times - stimFreezeTrigger_times - Criterion: 0 < M < 0.150 s + Criterion: 0 < M < 0.15 s Units: seconds [s] :param data: dict of trial data with keys ('stimFreeze_times', 'stimFreezeTrigger_times', 'intervals') """ - metric = np.nan_to_num(data["stimFreeze_times"] - data["stimFreezeTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['stimFreeze_times'] - data['stimFreezeTrigger_times'], nan=np.inf) passed = (metric <= 0.15) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -934,7 +934,7 @@ def check_reward_volumes(data, **_): passed[correct] = (1.5 <= metric[correct]) & (metric[correct] <= 3.) # Check incorrect trials are 0 passed[~correct] = metric[~correct] == 0 - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -946,7 +946,7 @@ def check_reward_volume_set(data, **_): :param data: dict of trial data with keys ('rewardVolume') """ - metric = data["rewardVolume"] + metric = data['rewardVolume'] passed = 0 < len(set(metric)) <= 2 and 0. in metric return metric, passed @@ -994,19 +994,19 @@ def check_stimulus_move_before_goCue(data, photodiode=None, **_): :param photodiode: the fronts from Bpod's BNC1 input or FPGA frame2ttl channel """ if photodiode is None: - _log.warning("No photodiode TTL input in function call, returning None") + _log.warning('No photodiode TTL input in function call, returning None') return None photodiode_clean = ephys_fpga._clean_frame2ttl(photodiode) - s = photodiode_clean["times"] + s = photodiode_clean['times'] s = s[~np.isnan(s)] # Remove NaNs metric = np.array([]) - for i, c in zip(data["intervals"][:, 0], data["goCue_times"]): + for i, c in zip(data['intervals'][:, 0], data['goCue_times']): metric = np.append(metric, np.count_nonzero(s[s > i] < (c - 0.02))) passed = (metric == 0).astype(float) # Remove no go trials - passed[data["choice"] == 0] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[data['choice'] == 0] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -1022,12 +1022,12 @@ def check_audio_pre_trial(data, audio=None, **_): :param audio: the fronts from Bpod's BNC2 input FPGA audio sync channel """ if audio is None: - _log.warning("No BNC2 input in function call, retuning None") + _log.warning('No BNC2 input in function call, retuning None') return None - s = audio["times"][~np.isnan(audio["times"])] # Audio TTLs with NaNs removed + s = audio['times'][~np.isnan(audio['times'])] # Audio TTLs with NaNs removed metric = np.array([], dtype=np.int8) - for i, c in zip(data["intervals"][:, 0], data["goCue_times"]): + for i, c in zip(data['intervals'][:, 0], data['goCue_times']): metric = np.append(metric, sum(s[s > i] < (c - 0.02))) passed = metric == 0 - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed