Skip to content

Commit

Permalink
Task qc extractor refactor (#649)
Browse files Browse the repository at this point in the history
* Habituation extract phase and position

* Independent task QC method in behaviour tasks

* var names

* flake8

* Initialize settings property
  • Loading branch information
k1o0 authored Oct 6, 2023
1 parent 935c0c7 commit 68fb8cd
Show file tree
Hide file tree
Showing 12 changed files with 320 additions and 273 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ __pycache__
python/scratch
.idea/*
.vscode/
*.code-workspace
*checkpoint.ipynb
build/
venv/
Expand Down
2 changes: 1 addition & 1 deletion ibllib/ephys/ephysqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def _qc_from_path(sess_path, display=True):
sync, chmap = ephys_fpga.get_main_probe_sync(sess_path, bin_exists=False)
_ = ephys_fpga.extract_all(sess_path, output_path=temp_alf_folder, save=True)
# check that the output is complete
fpga_trials = ephys_fpga.extract_behaviour_sync(sync, chmap=chmap, display=display)
fpga_trials, *_ = ephys_fpga.extract_behaviour_sync(sync, chmap=chmap, display=display)
# align with the bpod
bpod2fpga = ephys_fpga.align_with_bpod(temp_alf_folder.parent)
alf_trials = alfio.load_object(temp_alf_folder, 'trials')
Expand Down
24 changes: 12 additions & 12 deletions ibllib/io/extractors/biased_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ class TrialsTableBiased(BaseBpodTrialsExtractor):
intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
Additionally extracts the following wheel data:
wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude
wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude
"""
save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
'_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals',
'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement')
var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement')

def _extract(self, extractor_classes=None, **kwargs):
base = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes, FeedbackType,
Expand All @@ -120,13 +120,13 @@ class TrialsTableEphys(BaseBpodTrialsExtractor):
intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
Additionally extracts the following wheel data:
wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude
wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude
"""
save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
'_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
None, None, None, '_ibl_trials.quiescencePeriod.npy')
var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals',
'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement',
var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement',
'phase', 'position', 'quiescence')

def _extract(self, extractor_classes=None, **kwargs):
Expand Down Expand Up @@ -154,16 +154,16 @@ class BiasedTrials(BaseBpodTrialsExtractor):
None, None, '_ibl_trials.quiescencePeriod.npy')
var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', 'included',
'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included',
'phase', 'position', 'quiescence')

def _extract(self, extractor_classes=None, **kwargs):
def _extract(self, extractor_classes=None, **kwargs) -> dict:
base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
ErrorCueTriggerTimes, TrialsTableBiased, IncludedTrials, PhasePosQuiescence]
# Exclude from trials table
out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
save=False, task_collection=self.task_collection)
return tuple(out.pop(x) for x in self.var_names)
return {k: out[k] for k in self.var_names}


class EphysTrials(BaseBpodTrialsExtractor):
Expand All @@ -177,16 +177,16 @@ class EphysTrials(BaseBpodTrialsExtractor):
'_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy')
var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', 'included',
'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included',
'phase', 'position', 'quiescence')

def _extract(self, extractor_classes=None, **kwargs):
def _extract(self, extractor_classes=None, **kwargs) -> dict:
base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials, PhasePosQuiescence]
# Exclude from trials table
out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
save=False, task_collection=self.task_collection)
return tuple(out.pop(x) for x in self.var_names)
return {k: out[k] for k in self.var_names}


def extract_all(session_path, save=False, bpod_trials=False, settings=False, extra_classes=None,
Expand Down
60 changes: 44 additions & 16 deletions ibllib/io/extractors/ephys_fpga.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from iblutil.spacer import Spacer

import ibllib.exceptions as err
from ibllib.io import raw_data_loaders, session_params
from ibllib.io import raw_data_loaders as raw, session_params
from ibllib.io.extractors.bpod_trials import extract_all as bpod_extract_all
import ibllib.io.extractors.base as extractors_base
from ibllib.io.extractors.training_wheel import extract_wheel_moves
Expand Down Expand Up @@ -554,7 +554,7 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm
ax.set_yticks([0, 1, 2, 3, 4, 5])
ax.set_ylim([0, 5])

return trials
return trials, frame2ttl, audio, bpod


def extract_sync(session_path, overwrite=False, ephys_files=None, namespace='spikeglx'):
Expand Down Expand Up @@ -734,6 +734,7 @@ def __init__(self, *args, bpod_trials=None, bpod_extractor=None, **kwargs):
super().__init__(*args, **kwargs)
self.bpod2fpga = None
self.bpod_trials = bpod_trials
self.frame2ttl = self.audio = self.bpod = self.settings = None
if bpod_extractor:
self.bpod_extractor = bpod_extractor
self._update_var_names()
Expand All @@ -750,14 +751,37 @@ def _update_var_names(self, bpod_fields=None, bpod_rsync_fields=None):
A set of Bpod trials fields to keep.
bpod_rsync_fields : tuple
A set of Bpod trials fields to sync to the DAQ times.
TODO Turn into property getter; requires ensuring the output fields are the same for legacy
"""
if self.bpod_extractor:
self.var_names = self.bpod_extractor.var_names
self.save_names = self.bpod_extractor.save_names
self.bpod_rsync_fields = bpod_rsync_fields or self._time_fields(self.bpod_extractor.var_names)
self.bpod_fields = bpod_fields or [x for x in self.bpod_extractor.var_names if x not in self.bpod_rsync_fields]
for var_name, save_name in zip(self.bpod_extractor.var_names, self.bpod_extractor.save_names):
if var_name not in self.var_names:
self.var_names += (var_name,)
self.save_names += (save_name,)

# self.var_names = self.bpod_extractor.var_names
# self.save_names = self.bpod_extractor.save_names
self.settings = self.bpod_extractor.settings # This is used by the TaskQC
self.bpod_rsync_fields = bpod_rsync_fields
if self.bpod_rsync_fields is None:
self.bpod_rsync_fields = tuple(self._time_fields(self.bpod_extractor.var_names))
if 'table' in self.bpod_extractor.var_names:
if not self.bpod_trials:
self.bpod_trials = self.bpod_extractor.extract(save=False)
table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys()
self.bpod_rsync_fields += tuple(self._time_fields(table_keys))
elif bpod_rsync_fields:
self.bpod_rsync_fields = bpod_rsync_fields
excluded = (*self.bpod_rsync_fields, 'table')
if bpod_fields:
assert not set(self.bpod_fields).intersection(excluded), 'bpod_fields must not also be bpod_rsync_fields'
self.bpod_fields = bpod_fields
elif self.bpod_extractor:
self.bpod_fields = tuple(x for x in self.bpod_extractor.var_names if x not in excluded)
if 'table' in self.bpod_extractor.var_names:
if not self.bpod_trials:
self.bpod_trials = self.bpod_extractor.extract(save=False)
table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys()
self.bpod_fields += (*[x for x in table_keys if x not in excluded], self.sync_field + '_bpod')

@staticmethod
def _time_fields(trials_attr) -> set:
Expand All @@ -778,7 +802,8 @@ def _time_fields(trials_attr) -> set:
pattern = re.compile(fr'^[_\w]*({"|".join(FIELDS)})[_\w]*$')
return set(filter(pattern.match, trials_attr))

def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task_collection='raw_behavior_data', **kwargs):
def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data',
task_collection='raw_behavior_data', **kwargs) -> dict:
"""Extracts ephys trials by combining Bpod and FPGA sync pulses"""
# extract the behaviour data from bpod
if sync is None or chmap is None:
Expand All @@ -804,7 +829,8 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task
else:
tmin = tmax = None

fpga_trials = extract_behaviour_sync(
# Store the cleaned frame2ttl, audio, and bpod pulses as these will be used for QC
fpga_trials, self.frame2ttl, self.audio, self.bpod = extract_behaviour_sync(
sync=sync, chmap=chmap, bpod_trials=self.bpod_trials, tmin=tmin, tmax=tmax)
assert self.sync_field in self.bpod_trials and self.sync_field in fpga_trials
self.bpod_trials[f'{self.sync_field}_bpod'] = np.copy(self.bpod_trials[self.sync_field])
Expand All @@ -827,18 +853,20 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task
# extract the wheel data
wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax)
from ibllib.io.extractors.training_wheel import extract_first_movement_times
settings = raw_data_loaders.load_settings(session_path=self.session_path, task_collection=task_collection)
min_qt = settings.get('QUIESCENT_PERIOD', None)
if not self.settings:
self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection)
min_qt = self.settings.get('QUIESCENT_PERIOD', None)
first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt)
out.update({'firstMovement_times': first_move_onsets})
# Re-create trials table
trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns})
out['table'] = trials_table.to_df()

out.update({f'wheel_{k}': v for k, v in wheel.items()})
out.update({f'wheelMoves_{k}': v for k, v in moves.items()})
out = {k: out[k] for k in self.var_names if k in out} # Reorder output
assert tuple(filter(lambda x: 'wheel' not in x, self.var_names)) == tuple(out.keys())
return [out[k] for k in out] + [wheel['timestamps'], wheel['position'],
moves['intervals'], moves['peakAmplitude']]
assert self.var_names == tuple(out.keys())
return out

def get_wheel_positions(self, *args, **kwargs):
"""Extract wheel and wheelMoves objects.
Expand Down Expand Up @@ -882,7 +910,7 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_
If save is True, a list of file paths to the extracted data.
"""
# Extract Bpod trials
bpod_raw = raw_data_loaders.load_data(session_path, task_collection=task_collection)
bpod_raw = raw.load_data(session_path, task_collection=task_collection)
assert bpod_raw is not None, 'No task trials data in raw_behavior_data - Exit'
bpod_trials, *_ = bpod_extract_all(
session_path=session_path, bpod_trials=bpod_raw, task_collection=task_collection,
Expand Down
18 changes: 11 additions & 7 deletions ibllib/io/extractors/habituation_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@ class HabituationTrials(BaseBpodTrialsExtractor):
var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight',
'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
'stimCenterTrigger_times', 'stimCenter_times')
'stimCenterTrigger_times', 'stimCenter_times', 'position', 'phase')

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
exclude = ['itiIn_times', 'stimOffTrigger_times',
'stimCenter_times', 'stimCenterTrigger_times']
self.save_names = tuple([f'_ibl_trials.{x}.npy' if x not in exclude else None
for x in self.var_names])
exclude = ['itiIn_times', 'stimOffTrigger_times', 'stimCenter_times',
'stimCenterTrigger_times', 'position', 'phase']
self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names)

def _extract(self):
def _extract(self) -> dict:
# Extract all trials...

# Get all stim_sync events detected
Expand Down Expand Up @@ -101,9 +100,14 @@ def _extract(self):
["iti"][0][0] for tr in self.bpod_trials]
)

# Phase and position
out['position'] = np.array([t['position'] for t in self.bpod_trials])
out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials])

# NB: We lose the last trial because the stim off event occurs at trial_num + 1
n_trials = out['stimOff_times'].size
return [out[k][:n_trials] for k in self.var_names]
# return [out[k][:n_trials] for k in self.var_names]
return {k: out[k][:n_trials] for k in self.var_names}


def extract_all(session_path, save=False, bpod_trials=False, settings=False, task_collection='raw_behavior_data', save_path=None):
Expand Down
21 changes: 9 additions & 12 deletions ibllib/io/extractors/mesoscope.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(self, *args, sync_collection='raw_sync_data', **kwargs):
super().__init__(*args, **kwargs)
self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')

def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs):
def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict:
if not (sync or chmap):
sync, chmap = load_timeline_sync_and_chmap(
self.session_path / sync_collection, timeline=self.timeline, chmap=chmap)
Expand All @@ -110,20 +110,17 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa
trials = super()._extract(sync, chmap, sync_collection, extractor_type='ephys', **kwargs)

# If no protocol number is defined, trim timestamps based on Bpod trials intervals
trials_table = trials[self.var_names.index('table')]
trials_table = trials['table']
bpod = get_sync_fronts(sync, chmap['bpod'])
if kwargs.get('protocol_number') is None:
tmin = trials_table.intervals_0.iloc[0] - 1
tmax = trials_table.intervals_1.iloc[-1]
# Ensure wheel is cut off based on trials
wheel_ts_idx = self.var_names.index('wheel_timestamps')
mask = np.logical_and(tmin <= trials[wheel_ts_idx], trials[wheel_ts_idx] <= tmax)
trials[wheel_ts_idx] = trials[wheel_ts_idx][mask]
wheel_pos_idx = self.var_names.index('wheel_position')
trials[wheel_pos_idx] = trials[wheel_pos_idx][mask]
move_idx = self.var_names.index('wheelMoves_intervals')
mask = np.logical_and(trials[move_idx][:, 0] >= tmin, trials[move_idx][:, 0] <= tmax)
trials[move_idx] = trials[move_idx][mask, :]
mask = np.logical_and(tmin <= trials['wheel_timestamps'], trials['wheel_timestamps'] <= tmax)
trials['wheel_timestamps'] = trials['wheel_timestamps'][mask]
trials['wheel_position'] = trials['wheel_position'][mask]
mask = np.logical_and(trials['wheelMoves_intervals'][:, 0] >= tmin, trials['wheelMoves_intervals'][:, 0] <= tmax)
trials['wheelMoves_intervals'] = trials['wheelMoves_intervals'][mask, :]
else:
tmin, tmax = get_protocol_period(self.session_path, kwargs['protocol_number'], bpod)
bpod = get_sync_fronts(sync, chmap['bpod'], tmin, tmax)
Expand All @@ -138,7 +135,7 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa
valve_open_times = self.get_valve_open_times(driver_ttls=driver_out)
assert len(valve_open_times) == sum(trials_table.feedbackType == 1) # TODO Relax assertion
correct = trials_table.feedbackType == 1
trials[self.var_names.index('valveOpen_times')][correct] = valve_open_times
trials['valveOpen_times'][correct] = valve_open_times
trials_table.feedback_times[correct] = valve_open_times

# Replace audio events
Expand Down Expand Up @@ -191,7 +188,7 @@ def first_true(arr):

trials_table.feedback_times[~correct] = error_cue
trials_table.goCue_times = go_cue
return trials
return {k: trials[k] for k in self.var_names}

def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None):
"""
Expand Down
Loading

0 comments on commit 68fb8cd

Please sign in to comment.