Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Task qc extractor refactor #649

Merged
merged 6 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ __pycache__
python/scratch
.idea/*
.vscode/
*.code-workspace
*checkpoint.ipynb
build/
venv/
Expand Down
2 changes: 1 addition & 1 deletion ibllib/ephys/ephysqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def _qc_from_path(sess_path, display=True):
sync, chmap = ephys_fpga.get_main_probe_sync(sess_path, bin_exists=False)
_ = ephys_fpga.extract_all(sess_path, output_path=temp_alf_folder, save=True)
# check that the output is complete
fpga_trials = ephys_fpga.extract_behaviour_sync(sync, chmap=chmap, display=display)
fpga_trials, *_ = ephys_fpga.extract_behaviour_sync(sync, chmap=chmap, display=display)
# align with the bpod
bpod2fpga = ephys_fpga.align_with_bpod(temp_alf_folder.parent)
alf_trials = alfio.load_object(temp_alf_folder, 'trials')
Expand Down
24 changes: 12 additions & 12 deletions ibllib/io/extractors/biased_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ class TrialsTableBiased(BaseBpodTrialsExtractor):
intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
Additionally extracts the following wheel data:
wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude
wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude
"""
save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
'_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals',
'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement')
var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement')

def _extract(self, extractor_classes=None, **kwargs):
base = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes, FeedbackType,
Expand All @@ -120,13 +120,13 @@ class TrialsTableEphys(BaseBpodTrialsExtractor):
intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
Additionally extracts the following wheel data:
wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude
wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude
"""
save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
'_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
None, None, None, '_ibl_trials.quiescencePeriod.npy')
var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals',
'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement',
var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement',
'phase', 'position', 'quiescence')

def _extract(self, extractor_classes=None, **kwargs):
Expand Down Expand Up @@ -154,16 +154,16 @@ class BiasedTrials(BaseBpodTrialsExtractor):
None, None, '_ibl_trials.quiescencePeriod.npy')
var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', 'included',
'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included',
'phase', 'position', 'quiescence')

def _extract(self, extractor_classes=None, **kwargs):
def _extract(self, extractor_classes=None, **kwargs) -> dict:
base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
ErrorCueTriggerTimes, TrialsTableBiased, IncludedTrials, PhasePosQuiescence]
# Exclude from trials table
out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
save=False, task_collection=self.task_collection)
return tuple(out.pop(x) for x in self.var_names)
return {k: out[k] for k in self.var_names}


class EphysTrials(BaseBpodTrialsExtractor):
Expand All @@ -177,16 +177,16 @@ class EphysTrials(BaseBpodTrialsExtractor):
'_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy')
var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', 'included',
'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included',
'phase', 'position', 'quiescence')

def _extract(self, extractor_classes=None, **kwargs):
def _extract(self, extractor_classes=None, **kwargs) -> dict:
base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials, PhasePosQuiescence]
# Exclude from trials table
out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
save=False, task_collection=self.task_collection)
return tuple(out.pop(x) for x in self.var_names)
return {k: out[k] for k in self.var_names}


def extract_all(session_path, save=False, bpod_trials=False, settings=False, extra_classes=None,
Expand Down
60 changes: 44 additions & 16 deletions ibllib/io/extractors/ephys_fpga.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from iblutil.spacer import Spacer

import ibllib.exceptions as err
from ibllib.io import raw_data_loaders, session_params
from ibllib.io import raw_data_loaders as raw, session_params
from ibllib.io.extractors.bpod_trials import extract_all as bpod_extract_all
import ibllib.io.extractors.base as extractors_base
from ibllib.io.extractors.training_wheel import extract_wheel_moves
Expand Down Expand Up @@ -554,7 +554,7 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm
ax.set_yticks([0, 1, 2, 3, 4, 5])
ax.set_ylim([0, 5])

return trials
return trials, frame2ttl, audio, bpod


def extract_sync(session_path, overwrite=False, ephys_files=None, namespace='spikeglx'):
Expand Down Expand Up @@ -734,6 +734,7 @@ def __init__(self, *args, bpod_trials=None, bpod_extractor=None, **kwargs):
super().__init__(*args, **kwargs)
self.bpod2fpga = None
self.bpod_trials = bpod_trials
self.frame2ttl = self.audio = self.bpod = self.settings = None
if bpod_extractor:
self.bpod_extractor = bpod_extractor
self._update_var_names()
Expand All @@ -750,14 +751,37 @@ def _update_var_names(self, bpod_fields=None, bpod_rsync_fields=None):
A set of Bpod trials fields to keep.
bpod_rsync_fields : tuple
A set of Bpod trials fields to sync to the DAQ times.

TODO Turn into property getter; requires ensuring the output field are the same for legacy
"""
if self.bpod_extractor:
self.var_names = self.bpod_extractor.var_names
self.save_names = self.bpod_extractor.save_names
self.bpod_rsync_fields = bpod_rsync_fields or self._time_fields(self.bpod_extractor.var_names)
self.bpod_fields = bpod_fields or [x for x in self.bpod_extractor.var_names if x not in self.bpod_rsync_fields]
for var_name, save_name in zip(self.bpod_extractor.var_names, self.bpod_extractor.save_names):
if var_name not in self.var_names:
self.var_names += (var_name,)
self.save_names += (save_name,)

# self.var_names = self.bpod_extractor.var_names
# self.save_names = self.bpod_extractor.save_names
self.settings = self.bpod_extractor.settings # This is used by the TaskQC
self.bpod_rsync_fields = bpod_rsync_fields
if self.bpod_rsync_fields is None:
self.bpod_rsync_fields = tuple(self._time_fields(self.bpod_extractor.var_names))
if 'table' in self.bpod_extractor.var_names:
if not self.bpod_trials:
self.bpod_trials = self.bpod_extractor.extract(save=False)
table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys()
self.bpod_rsync_fields += tuple(self._time_fields(table_keys))
elif bpod_rsync_fields:
self.bpod_rsync_fields = bpod_rsync_fields
excluded = (*self.bpod_rsync_fields, 'table')
if bpod_fields:
assert not set(self.bpod_fields).intersection(excluded), 'bpod_fields must not also be bpod_rsync_fields'
self.bpod_fields = bpod_fields
elif self.bpod_extractor:
self.bpod_fields = tuple(x for x in self.bpod_extractor.var_names if x not in excluded)
if 'table' in self.bpod_extractor.var_names:
if not self.bpod_trials:
self.bpod_trials = self.bpod_extractor.extract(save=False)
table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys()
self.bpod_fields += (*[x for x in table_keys if x not in excluded], self.sync_field + '_bpod')

@staticmethod
def _time_fields(trials_attr) -> set:
Expand All @@ -778,7 +802,8 @@ def _time_fields(trials_attr) -> set:
pattern = re.compile(fr'^[_\w]*({"|".join(FIELDS)})[_\w]*$')
return set(filter(pattern.match, trials_attr))

def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task_collection='raw_behavior_data', **kwargs):
def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data',
task_collection='raw_behavior_data', **kwargs) -> dict:
"""Extracts ephys trials by combining Bpod and FPGA sync pulses"""
# extract the behaviour data from bpod
if sync is None or chmap is None:
Expand All @@ -804,7 +829,8 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task
else:
tmin = tmax = None

fpga_trials = extract_behaviour_sync(
# Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC
fpga_trials, self.frame2ttl, self.audio, self.bpod = extract_behaviour_sync(
sync=sync, chmap=chmap, bpod_trials=self.bpod_trials, tmin=tmin, tmax=tmax)
assert self.sync_field in self.bpod_trials and self.sync_field in fpga_trials
self.bpod_trials[f'{self.sync_field}_bpod'] = np.copy(self.bpod_trials[self.sync_field])
Expand All @@ -827,18 +853,20 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task
# extract the wheel data
wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax)
from ibllib.io.extractors.training_wheel import extract_first_movement_times
settings = raw_data_loaders.load_settings(session_path=self.session_path, task_collection=task_collection)
min_qt = settings.get('QUIESCENT_PERIOD', None)
if not self.settings:
self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection)
min_qt = self.settings.get('QUIESCENT_PERIOD', None)
first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt)
out.update({'firstMovement_times': first_move_onsets})
# Re-create trials table
trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns})
out['table'] = trials_table.to_df()

out.update({f'wheel_{k}': v for k, v in wheel.items()})
out.update({f'wheelMoves_{k}': v for k, v in moves.items()})
out = {k: out[k] for k in self.var_names if k in out} # Reorder output
assert tuple(filter(lambda x: 'wheel' not in x, self.var_names)) == tuple(out.keys())
return [out[k] for k in out] + [wheel['timestamps'], wheel['position'],
moves['intervals'], moves['peakAmplitude']]
assert self.var_names == tuple(out.keys())
return out

def get_wheel_positions(self, *args, **kwargs):
"""Extract wheel and wheelMoves objects.
Expand Down Expand Up @@ -882,7 +910,7 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_
If save is True, a list of file paths to the extracted data.
"""
# Extract Bpod trials
bpod_raw = raw_data_loaders.load_data(session_path, task_collection=task_collection)
bpod_raw = raw.load_data(session_path, task_collection=task_collection)
assert bpod_raw is not None, 'No task trials data in raw_behavior_data - Exit'
bpod_trials, *_ = bpod_extract_all(
session_path=session_path, bpod_trials=bpod_raw, task_collection=task_collection,
Expand Down
18 changes: 11 additions & 7 deletions ibllib/io/extractors/habituation_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@ class HabituationTrials(BaseBpodTrialsExtractor):
var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight',
'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
'stimCenterTrigger_times', 'stimCenter_times')
'stimCenterTrigger_times', 'stimCenter_times', 'position', 'phase')

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
exclude = ['itiIn_times', 'stimOffTrigger_times',
'stimCenter_times', 'stimCenterTrigger_times']
self.save_names = tuple([f'_ibl_trials.{x}.npy' if x not in exclude else None
for x in self.var_names])
exclude = ['itiIn_times', 'stimOffTrigger_times', 'stimCenter_times',
'stimCenterTrigger_times', 'position', 'phase']
self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names)

def _extract(self):
def _extract(self) -> dict:
# Extract all trials...

# Get all stim_sync events detected
Expand Down Expand Up @@ -101,9 +100,14 @@ def _extract(self):
["iti"][0][0] for tr in self.bpod_trials]
)

# Phase and position
out['position'] = np.array([t['position'] for t in self.bpod_trials])
out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials])

# NB: We lose the last trial because the stim off event occurs at trial_num + 1
n_trials = out['stimOff_times'].size
return [out[k][:n_trials] for k in self.var_names]
# return [out[k][:n_trials] for k in self.var_names]
return {k: out[k][:n_trials] for k in self.var_names}


def extract_all(session_path, save=False, bpod_trials=False, settings=False, task_collection='raw_behavior_data', save_path=None):
Expand Down
21 changes: 9 additions & 12 deletions ibllib/io/extractors/mesoscope.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(self, *args, sync_collection='raw_sync_data', **kwargs):
super().__init__(*args, **kwargs)
self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')

def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs):
def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict:
if not (sync or chmap):
sync, chmap = load_timeline_sync_and_chmap(
self.session_path / sync_collection, timeline=self.timeline, chmap=chmap)
Expand All @@ -110,20 +110,17 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa
trials = super()._extract(sync, chmap, sync_collection, extractor_type='ephys', **kwargs)

# If no protocol number is defined, trim timestamps based on Bpod trials intervals
trials_table = trials[self.var_names.index('table')]
trials_table = trials['table']
bpod = get_sync_fronts(sync, chmap['bpod'])
if kwargs.get('protocol_number') is None:
tmin = trials_table.intervals_0.iloc[0] - 1
tmax = trials_table.intervals_1.iloc[-1]
# Ensure wheel is cut off based on trials
wheel_ts_idx = self.var_names.index('wheel_timestamps')
mask = np.logical_and(tmin <= trials[wheel_ts_idx], trials[wheel_ts_idx] <= tmax)
trials[wheel_ts_idx] = trials[wheel_ts_idx][mask]
wheel_pos_idx = self.var_names.index('wheel_position')
trials[wheel_pos_idx] = trials[wheel_pos_idx][mask]
move_idx = self.var_names.index('wheelMoves_intervals')
mask = np.logical_and(trials[move_idx][:, 0] >= tmin, trials[move_idx][:, 0] <= tmax)
trials[move_idx] = trials[move_idx][mask, :]
mask = np.logical_and(tmin <= trials['wheel_timestamps'], trials['wheel_timestamps'] <= tmax)
trials['wheel_timestamps'] = trials['wheel_timestamps'][mask]
trials['wheel_position'] = trials['wheel_position'][mask]
mask = np.logical_and(trials['wheelMoves_intervals'][:, 0] >= tmin, trials['wheelMoves_intervals'][:, 0] <= tmax)
trials['wheelMoves_intervals'] = trials['wheelMoves_intervals'][mask, :]
else:
tmin, tmax = get_protocol_period(self.session_path, kwargs['protocol_number'], bpod)
bpod = get_sync_fronts(sync, chmap['bpod'], tmin, tmax)
Expand All @@ -138,7 +135,7 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa
valve_open_times = self.get_valve_open_times(driver_ttls=driver_out)
assert len(valve_open_times) == sum(trials_table.feedbackType == 1) # TODO Relax assertion
correct = trials_table.feedbackType == 1
trials[self.var_names.index('valveOpen_times')][correct] = valve_open_times
trials['valveOpen_times'][correct] = valve_open_times
trials_table.feedback_times[correct] = valve_open_times

# Replace audio events
Expand Down Expand Up @@ -191,7 +188,7 @@ def first_true(arr):

trials_table.feedback_times[~correct] = error_cue
trials_table.goCue_times = go_cue
return trials
return {k: trials[k] for k in self.var_names}

def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None):
"""
Expand Down
Loading
Loading