From 3be3d22ddeb93af45cd829c37ab50f69d4800590 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Wed, 25 Sep 2024 09:36:18 +0100 Subject: [PATCH 01/59] get_protocol_period: relax assertion for spacer_times --- ibllib/io/extractors/ephys_fpga.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index 3810abf29..eb6c5141d 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -569,7 +569,8 @@ def get_protocol_period(session_path, protocol_number, bpod_sync): # Ensure that the number of detected spacers matched the number of expected tasks if acquisition_description := session_params.read_params(session_path): n_tasks = len(acquisition_description.get('tasks', [])) - assert n_tasks == len(spacer_times), f'expected {n_tasks} spacers, found {len(spacer_times)}' + assert len(spacer_times) >= protocol_number, (f'expected {n_tasks} spacers, found only {len(spacer_times)} - ' + f'can not return protocol number {protocol_number}.') assert n_tasks > protocol_number >= 0, f'protocol number must be between 0 and {n_tasks}' else: assert protocol_number < len(spacer_times) From a607f241d9f0d8fda5e1b71c150993a036e8cee8 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Wed, 25 Sep 2024 09:40:55 +0100 Subject: [PATCH 02/59] lower alpha for event markers --- ibllib/qc/task_qc_viewer/task_qc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ibllib/qc/task_qc_viewer/task_qc.py b/ibllib/qc/task_qc_viewer/task_qc.py index 89a8d172f..86767328a 100644 --- a/ibllib/qc/task_qc_viewer/task_qc.py +++ b/ibllib/qc/task_qc_viewer/task_qc.py @@ -140,7 +140,8 @@ def create_plots(self, axes, 'ymin': 0, 'ymax': 4, 'linewidth': 2, - 'ax': axes + 'ax': axes, + 'alpha': 0.5, } bnc1 = self.qc.extractor.frame_ttls From 68beb147cc244fc525992d11d0512b9a10c51496 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Wed, 25 Sep 2024 12:34:17 +0100 Subject: [PATCH 03/59] some cleaning-up --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 97 ++++++++++++++----------- 1 file changed, 53 insertions(+), 44 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 48155b270..3f3717d54 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -1,7 +1,8 @@ """An interactive PyQT QC data frame.""" import logging -from PyQt5 import QtCore, QtWidgets +from PyQt5 import QtWidgets +from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, pyqtSlot from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT import pandas as pd @@ -12,13 +13,13 @@ _logger = logging.getLogger(__name__) -class DataFrameModel(QtCore.QAbstractTableModel): - DtypeRole = QtCore.Qt.UserRole + 1000 - ValueRole = QtCore.Qt.UserRole + 1001 +class DataFrameTableModel(QAbstractTableModel): + DtypeRole = Qt.UserRole + 1000 + ValueRole = Qt.UserRole + 1001 - def __init__(self, df=pd.DataFrame(), parent=None): - super(DataFrameModel, self).__init__(parent) - self._dataframe = df + def __init__(self, parent=None, dataFrame: pd.DataFrame = pd.DataFrame()): + super(DataFrameTableModel, self).__init__(parent) + self._dataframe = dataFrame def setDataFrame(self, dataframe): self.beginResetModel() @@ -28,50 +29,50 @@ def setDataFrame(self, dataframe): def dataFrame(self): return self._dataframe - dataFrame = QtCore.pyqtProperty(pd.DataFrame, fget=dataFrame, fset=setDataFrame) + dataFrame = pyqtProperty(pd.DataFrame, fget=dataFrame, fset=setDataFrame) - @QtCore.pyqtSlot(int, QtCore.Qt.Orientation, result=str) - def headerData(self, section: int, orientation: QtCore.Qt.Orientation, - role: int = QtCore.Qt.DisplayRole): - if role == QtCore.Qt.DisplayRole: - if orientation == QtCore.Qt.Horizontal: + @pyqtSlot(int, Qt.Orientation, result=str) + def headerData(self, section: int, orientation: Qt.Orientation, + role: int = Qt.DisplayRole): + if role == Qt.DisplayRole: + if orientation == Qt.Horizontal: return self._dataframe.columns[section] else: return str(self._dataframe.index[section]) - return QtCore.QVariant() + return QVariant() - def rowCount(self, parent=QtCore.QModelIndex()): + def rowCount(self, parent=QModelIndex()): if parent.isValid(): return 0 return len(self._dataframe.index) - def columnCount(self, parent=QtCore.QModelIndex()): + def columnCount(self, parent=QModelIndex()): if parent.isValid(): return 0 return self._dataframe.columns.size - def data(self, index, role=QtCore.Qt.DisplayRole): + def data(self, index, role=Qt.DisplayRole): if (not index.isValid() or not (0 <= index.row() < self.rowCount() and 0 <= index.column() < self.columnCount())): - return QtCore.QVariant() + return QVariant() row = self._dataframe.index[index.row()] col = self._dataframe.columns[index.column()] dt = self._dataframe[col].dtype val = self._dataframe.iloc[row][col] - if role == QtCore.Qt.DisplayRole: + if role == Qt.DisplayRole: return str(val) - elif role == DataFrameModel.ValueRole: + elif role == DataFrameTableModel.ValueRole: return val - if role == DataFrameModel.DtypeRole: + if role == DataFrameTableModel.DtypeRole: return dt - return QtCore.QVariant() + return QVariant() def roleNames(self): roles = { - QtCore.Qt.DisplayRole: b'display', - DataFrameModel.DtypeRole: b'dtype', - DataFrameModel.ValueRole: b'value' + Qt.DisplayRole: b'display', + DataFrameTableModel.DtypeRole: b'dtype', + DataFrameTableModel.ValueRole: b'value' } return roles @@ -83,6 +84,8 @@ def sort(self, col, order): :param order: the order to be sorted, 0 is descending; 1, ascending :return: """ + if self._dataframe.empty: + return self.layoutAboutToBeChanged.emit() col_name = self._dataframe.columns.values[col] # print('sorting by ' + col_name) @@ -125,37 +128,43 @@ def __init__(self, parent=None, wheel=None): class GraphWindow(QtWidgets.QWidget): def __init__(self, parent=None, wheel=None): QtWidgets.QWidget.__init__(self, parent=parent) + self.lineEditPath = QtWidgets.QLineEdit(self) + + self.pushButtonLoad = QtWidgets.QPushButton("Select File", self) + self.pushButtonLoad.clicked.connect(self.loadFile) + + self.tableModel = DataFrameTableModel(self) + self.tableView = QtWidgets.QTableView(self) + self.tableView.setModel(self.tableModel) + self.tableView.setSortingEnabled(True) + self.tableView.doubleClicked.connect(self.tv_double_clicked) + vLayout = QtWidgets.QVBoxLayout(self) hLayout = QtWidgets.QHBoxLayout() - self.pathLE = QtWidgets.QLineEdit(self) - hLayout.addWidget(self.pathLE) - self.loadBtn = QtWidgets.QPushButton("Select File", self) - hLayout.addWidget(self.loadBtn) + hLayout.addWidget(self.lineEditPath) + hLayout.addWidget(self.pushButtonLoad) vLayout.addLayout(hLayout) - self.pandasTv = QtWidgets.QTableView(self) - vLayout.addWidget(self.pandasTv) - self.loadBtn.clicked.connect(self.load_file) - self.pandasTv.setSortingEnabled(True) - self.pandasTv.doubleClicked.connect(self.tv_double_clicked) + vLayout.addWidget(self.tableView) + self.wplot = PlotWindow(wheel=wheel) self.wplot.show() + self.tableModel.dataChanged.connect(self.wplot.canvas.draw) + self.wheel = wheel - def load_file(self): + def loadFile(self): fileName, _ = QtWidgets.QFileDialog.getOpenFileName( self, "Open File", "", "CSV Files (*.csv)") - self.pathLE.setText(fileName) + self.lineEditPath.setText(fileName) df = pd.read_csv(fileName) - self.update_df(df) + self.updateDataframe(df) - def update_df(self, df): - model = DataFrameModel(df) - self.pandasTv.setModel(model) - self.wplot.canvas.draw() + def updateDataframe(self, dataFrame: pd.DataFrame): + self.tableModel.setDataFrame(dataFrame) def tv_double_clicked(self): - df = self.pandasTv.model()._dataframe - ind = self.pandasTv.currentIndex() + df = self.tableView.model()._dataframe + ind = self.tableView.currentIndex() start = df.loc[ind.row()]['intervals_0'] finish = df.loc[ind.row()]['intervals_1'] dt = finish - start @@ -179,6 +188,6 @@ def viewqc(qc=None, title=None, wheel=None): qcw = GraphWindow(wheel=wheel) qcw.setWindowTitle(title) if qc is not None: - qcw.update_df(qc) + qcw.updateDataframe(qc) qcw.show() return qcw From 9f99ce68e7737d8114d8a0e1614388cd8ceaf678 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Wed, 25 Sep 2024 13:58:21 +0100 Subject: [PATCH 04/59] remove unused rolenames from DataFrameTableModel, various small fixes --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 74 +++++++++++-------------- 1 file changed, 32 insertions(+), 42 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 3f3717d54..507099767 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -1,4 +1,5 @@ """An interactive PyQT QC data frame.""" + import logging from PyQt5 import QtWidgets @@ -14,26 +15,24 @@ class DataFrameTableModel(QAbstractTableModel): - DtypeRole = Qt.UserRole + 1000 - ValueRole = Qt.UserRole + 1001 - def __init__(self, parent=None, dataFrame: pd.DataFrame = pd.DataFrame()): super(DataFrameTableModel, self).__init__(parent) self._dataframe = dataFrame - def setDataFrame(self, dataframe): + def setDataFrame(self, dataFrame: pd.DataFrame): self.beginResetModel() - self._dataframe = dataframe.copy() + self._dataframe = dataFrame.copy() self.endResetModel() - def dataFrame(self): + def dataFrame(self) -> pd.DataFrame: return self._dataframe dataFrame = pyqtProperty(pd.DataFrame, fget=dataFrame, fset=setDataFrame) @pyqtSlot(int, Qt.Orientation, result=str) - def headerData(self, section: int, orientation: Qt.Orientation, - role: int = Qt.DisplayRole): + def headerData( + self, section: int, orientation: Qt.Orientation, role: int = Qt.DisplayRole + ): if role == Qt.DisplayRole: if orientation == Qt.Horizontal: return self._dataframe.columns[section] @@ -52,30 +51,19 @@ def columnCount(self, parent=QModelIndex()): return self._dataframe.columns.size def data(self, index, role=Qt.DisplayRole): - if (not index.isValid() or not (0 <= index.row() < self.rowCount() and - 0 <= index.column() < self.columnCount())): + if not index.isValid(): return QVariant() row = self._dataframe.index[index.row()] col = self._dataframe.columns[index.column()] - dt = self._dataframe[col].dtype - val = self._dataframe.iloc[row][col] if role == Qt.DisplayRole: + if isinstance(val, np.generic): + return val.item() return str(val) - elif role == DataFrameTableModel.ValueRole: - return val - if role == DataFrameTableModel.DtypeRole: - return dt + # elif role == Qt.BackgroundRole: + # return QBrush(Qt.red) return QVariant() - def roleNames(self): - roles = { - Qt.DisplayRole: b'display', - DataFrameTableModel.DtypeRole: b'dtype', - DataFrameTableModel.ValueRole: b'value' - } - return roles - def sort(self, col, order): """ Sort table by given column number. @@ -84,7 +72,7 @@ def sort(self, col, order): :param order: the order to be sorted, 0 is descending; 1, ascending :return: """ - if self._dataframe.empty: + if self.columnCount() == 0: return self.layoutAboutToBeChanged.emit() col_name = self._dataframe.columns.values[col] @@ -95,7 +83,6 @@ def sort(self, col, order): class PlotCanvas(FigureCanvasQTAgg): - def __init__(self, parent=None, width=5, height=4, dpi=100, wheel=None): fig = Figure(figsize=(width, height), dpi=dpi) @@ -103,13 +90,13 @@ def __init__(self, parent=None, width=5, height=4, dpi=100, wheel=None): self.setParent(parent) FigureCanvasQTAgg.setSizePolicy( - self, - QtWidgets.QSizePolicy.Expanding, - QtWidgets.QSizePolicy.Expanding) + self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding + ) FigureCanvasQTAgg.updateGeometry(self) if wheel: self.ax, self.ax2 = fig.subplots( - 2, 1, gridspec_kw={'height_ratios': [2, 1]}, sharex=True) + 2, 1, gridspec_kw={"height_ratios": [2, 1]}, sharex=True + ) else: self.ax = fig.add_subplot(111) self.draw() @@ -119,7 +106,7 @@ class PlotWindow(QtWidgets.QWidget): def __init__(self, parent=None, wheel=None): QtWidgets.QWidget.__init__(self, parent=None) self.canvas = PlotCanvas(wheel=wheel) - self.vbl = QtWidgets.QVBoxLayout() # Set box for plotting + self.vbl = QtWidgets.QVBoxLayout() # Set box for plotting self.vbl.addWidget(self.canvas) self.setLayout(self.vbl) self.vbl.addWidget(NavigationToolbar2QT(self.canvas, self)) @@ -137,6 +124,7 @@ def __init__(self, parent=None, wheel=None): self.tableView = QtWidgets.QTableView(self) self.tableView.setModel(self.tableModel) self.tableView.setSortingEnabled(True) + self.tableView.horizontalHeader().setDefaultAlignment(Qt.AlignLeft | Qt.AlignVCenter) self.tableView.doubleClicked.connect(self.tv_double_clicked) vLayout = QtWidgets.QVBoxLayout(self) @@ -154,7 +142,10 @@ def __init__(self, parent=None, wheel=None): def loadFile(self): fileName, _ = QtWidgets.QFileDialog.getOpenFileName( - self, "Open File", "", "CSV Files (*.csv)") + self, "Open File", "", "CSV Files (*.csv)" + ) + if len(fileName) == 0: + return self.lineEditPath.setText(fileName) df = pd.read_csv(fileName) self.updateDataframe(df) @@ -163,22 +154,21 @@ def updateDataframe(self, dataFrame: pd.DataFrame): self.tableModel.setDataFrame(dataFrame) def tv_double_clicked(self): - df = self.tableView.model()._dataframe ind = self.tableView.currentIndex() - start = df.loc[ind.row()]['intervals_0'] - finish = df.loc[ind.row()]['intervals_1'] - dt = finish - start + data = self.tableModel.dataFrame.loc[ind.row()] + t0 = data["intervals_0"] + t1 = data["intervals_1"] + dt = t1 - t0 if self.wheel: - idx = np.searchsorted( - self.wheel['re_ts'], np.array([start - dt / 10, finish + dt / 10])) - period = self.wheel['re_pos'][idx[0]:idx[1]] + idx = np.searchsorted(self.wheel["re_ts"], np.array([t0 - dt / 10, t1 + dt / 10])) + period = self.wheel["re_pos"][idx[0] : idx[1]] if period.size == 0: - _logger.warning('No wheel data during trial #%i', ind.row()) + _logger.warning("No wheel data during trial #%i", ind.row()) else: min_val, max_val = np.min(period), np.max(period) self.wplot.canvas.ax2.set_ylim(min_val - 1, max_val + 1) - self.wplot.canvas.ax2.set_xlim(start - dt / 10, finish + dt / 10) - self.wplot.canvas.ax.set_xlim(start - dt / 10, finish + dt / 10) + self.wplot.canvas.ax2.set_xlim(t0 - dt / 10, t1 + dt / 10) + self.wplot.canvas.ax.set_xlim(t0 - dt / 10, t1 + dt / 10) self.wplot.canvas.draw() From bfbe1007c119de768ac6c20fe1239f8fa3bd9c93 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Wed, 25 Sep 2024 15:21:10 +0100 Subject: [PATCH 05/59] Update task_qc.py --- ibllib/qc/task_qc_viewer/task_qc.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ibllib/qc/task_qc_viewer/task_qc.py b/ibllib/qc/task_qc_viewer/task_qc.py index 86767328a..b8fc1749f 100644 --- a/ibllib/qc/task_qc_viewer/task_qc.py +++ b/ibllib/qc/task_qc_viewer/task_qc.py @@ -285,8 +285,11 @@ def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=N trial_events=list(events), color_map=cm, linestyle=ls) + # Update table and callbacks - w.update_df(qc.frame) + n_trials = qc.frame.shape[0] + df_trials = pd.DataFrame({k: v for k, v in task_qc.extractor.data.items() if v.size == n_trials}) + w.updateDataframe(df_trials.merge(qc.frame, left_index=True, right_index=True)) qt.run_app() return qc From 6ae91f6656648a004f90901159fa6996011a2e8e Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Wed, 25 Sep 2024 20:19:50 +0100 Subject: [PATCH 06/59] add ColoredDataFrameTableModel --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 81 +++++++++++++++++-------- 1 file changed, 55 insertions(+), 26 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 507099767..cbde17809 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -3,7 +3,9 @@ import logging from PyQt5 import QtWidgets -from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, pyqtSlot +from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, QObject +from PyQt5.QtGui import QBrush, QColor +import matplotlib.pyplot as plt from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT import pandas as pd @@ -15,8 +17,8 @@ class DataFrameTableModel(QAbstractTableModel): - def __init__(self, parent=None, dataFrame: pd.DataFrame = pd.DataFrame()): - super(DataFrameTableModel, self).__init__(parent) + def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame = pd.DataFrame()): + super().__init__(parent) self._dataframe = dataFrame def setDataFrame(self, dataFrame: pd.DataFrame): @@ -29,10 +31,7 @@ def dataFrame(self) -> pd.DataFrame: dataFrame = pyqtProperty(pd.DataFrame, fget=dataFrame, fset=setDataFrame) - @pyqtSlot(int, Qt.Orientation, result=str) - def headerData( - self, section: int, orientation: Qt.Orientation, role: int = Qt.DisplayRole - ): + def headerData(self, section: int, orientation: Qt.Orientation, role: int = ...): if role == Qt.DisplayRole: if orientation == Qt.Horizontal: return self._dataframe.columns[section] @@ -40,17 +39,17 @@ def headerData( return str(self._dataframe.index[section]) return QVariant() - def rowCount(self, parent=QModelIndex()): - if parent.isValid(): + def rowCount(self, parent: QModelIndex = ...): + if isinstance(parent, QModelIndex) and parent.isValid(): return 0 return len(self._dataframe.index) - def columnCount(self, parent=QModelIndex()): - if parent.isValid(): + def columnCount(self, parent: QModelIndex = ...): + if isinstance(parent, QModelIndex) and parent.isValid(): return 0 return self._dataframe.columns.size - def data(self, index, role=Qt.DisplayRole): + def data(self, index: QModelIndex, role: int = ...) -> QVariant: if not index.isValid(): return QVariant() row = self._dataframe.index[index.row()] @@ -59,29 +58,59 @@ def data(self, index, role=Qt.DisplayRole): if role == Qt.DisplayRole: if isinstance(val, np.generic): return val.item() - return str(val) - # elif role == Qt.BackgroundRole: - # return QBrush(Qt.red) + return QVariant(str(val)) return QVariant() - def sort(self, col, order): - """ - Sort table by given column number. - - :param col: the column number selected (between 0 and self._dataframe.columns.size) - :param order: the order to be sorted, 0 is descending; 1, ascending - :return: - """ + def sort(self, column: int, order: Qt.SortOrder = ...): if self.columnCount() == 0: return self.layoutAboutToBeChanged.emit() - col_name = self._dataframe.columns.values[col] - # print('sorting by ' + col_name) + col_name = self._dataframe.columns.values[column] self._dataframe.sort_values(by=col_name, ascending=not order, inplace=True) self._dataframe.reset_index(inplace=True, drop=True) self.layoutChanged.emit() +class ColoredDataFrameTableModel(DataFrameTableModel): + _colors: pd.DataFrame + _cmap = plt.get_cmap('plasma') + _alpha = 0.5 + + def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame = pd.DataFrame()): + super().__init__(parent=parent, dataFrame=dataFrame) + self._setColors() + self.modelReset.connect(self._setColors) + self.dataChanged.connect(self._setColors) + self.layoutChanged.connect(self._setColors) + + def _setColors(self): + df = self._dataframe.copy() + df = df.replace([np.inf, -np.inf], np.nan) + for col in df.select_dtypes(include=['bool']): + df[col] = df[col].astype(float) + for col in df.select_dtypes(exclude=['bool']): + df[col] = pd.to_numeric(df[col], errors='coerce').astype(float) + if df[col].nunique() == 1: + df[col] = QColor.fromRgb(*self._cmap(0, self._alpha, True)) + else: + df[col] = (df[col] - df[col].min()) / (df[col].max() - df[col].min()) + df[col] = [QColor.fromRgb(*x) for x in self._cmap(df[col], self._alpha, True).tolist()] + self._colors = df + + def data(self, index, role=...): + if not index.isValid(): + return QVariant() + if role == Qt.BackgroundRole: + row = self._dataframe.index[index.row()] + col = self._dataframe.columns[index.column()] + val = self._dataframe.iloc[row][col] + if isinstance(val, (np.bool_, np.number)) and not np.isnan(val): + return self._colors.iloc[row][col] + else: + return QBrush(Qt.white) + return super().data(index, role) + + class PlotCanvas(FigureCanvasQTAgg): def __init__(self, parent=None, width=5, height=4, dpi=100, wheel=None): fig = Figure(figsize=(width, height), dpi=dpi) @@ -120,7 +149,7 @@ def __init__(self, parent=None, wheel=None): self.pushButtonLoad = QtWidgets.QPushButton("Select File", self) self.pushButtonLoad.clicked.connect(self.loadFile) - self.tableModel = DataFrameTableModel(self) + self.tableModel = ColoredDataFrameTableModel(self) self.tableView = QtWidgets.QTableView(self) self.tableView.setModel(self.tableModel) self.tableView.setSortingEnabled(True) From 5644b352f8037b2a15a5fe99f8821b732319aa0d Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Wed, 25 Sep 2024 22:12:28 +0100 Subject: [PATCH 07/59] Update ViewEphysQC.py --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 47 ++++++++++++++++--------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index cbde17809..a1b46f174 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -10,6 +10,7 @@ from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT import pandas as pd import numpy as np +from sklearn.preprocessing import MinMaxScaler from ibllib.misc import qt @@ -84,18 +85,34 @@ def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame = pd.DataFrame self.layoutChanged.connect(self._setColors) def _setColors(self): - df = self._dataframe.copy() - df = df.replace([np.inf, -np.inf], np.nan) - for col in df.select_dtypes(include=['bool']): - df[col] = df[col].astype(float) - for col in df.select_dtypes(exclude=['bool']): - df[col] = pd.to_numeric(df[col], errors='coerce').astype(float) - if df[col].nunique() == 1: - df[col] = QColor.fromRgb(*self._cmap(0, self._alpha, True)) - else: - df[col] = (df[col] - df[col].min()) / (df[col].max() - df[col].min()) - df[col] = [QColor.fromRgb(*x) for x in self._cmap(df[col], self._alpha, True).tolist()] - self._colors = df + vals = self._dataframe.copy() + if vals.empty: + self._colors = vals + return + + # coerce non-bool / non-numeric values to numeric + for col in vals.select_dtypes(exclude=['bool', 'number']): + vals[col] = vals[col].to_numeric(errors='coerce') + + # normalize numeric values + cols = vals.select_dtypes(include=['number']).columns + vals.replace([np.inf, -np.inf], np.nan, inplace=True) + vals[cols] = MinMaxScaler().fit_transform(vals[cols]) + + # convert boolean values + cols = vals.select_dtypes(include=['bool']).columns + vals[cols] = vals[cols].astype(float) + + # assign QColors + colors = vals.astype(object) + for col in vals.columns: + colors[col] = [QColor.fromRgb(*x) for x in self._cmap(vals[col], self._alpha, True)] + + # NaNs should be white + nans = vals.isna() + colors[nans] = QColor('white') + + self._colors = colors def data(self, index, role=...): if not index.isValid(): @@ -103,11 +120,7 @@ def data(self, index, role=...): if role == Qt.BackgroundRole: row = self._dataframe.index[index.row()] col = self._dataframe.columns[index.column()] - val = self._dataframe.iloc[row][col] - if isinstance(val, (np.bool_, np.number)) and not np.isnan(val): - return self._colors.iloc[row][col] - else: - return QBrush(Qt.white) + return self._colors.iloc[row][col] return super().data(index, role) From 074b01a09f79ca119a83788fd4b3b3b888f786f1 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 26 Sep 2024 00:38:39 +0100 Subject: [PATCH 08/59] Update ViewEphysQC.py --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 71 ++++++++++++------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index a1b46f174..4e1f31c5c 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -6,11 +6,11 @@ from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, QObject from PyQt5.QtGui import QBrush, QColor import matplotlib.pyplot as plt +from matplotlib.colors import ListedColormap from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT import pandas as pd import numpy as np -from sklearn.preprocessing import MinMaxScaler from ibllib.misc import qt @@ -18,9 +18,9 @@ class DataFrameTableModel(QAbstractTableModel): - def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame = pd.DataFrame()): + def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None): super().__init__(parent) - self._dataframe = dataFrame + self._dataframe = pd.DataFrame() if dataFrame is None else dataFrame def setDataFrame(self, dataFrame: pd.DataFrame): self.beginResetModel() @@ -73,54 +73,53 @@ def sort(self, column: int, order: Qt.SortOrder = ...): class ColoredDataFrameTableModel(DataFrameTableModel): - _colors: pd.DataFrame - _cmap = plt.get_cmap('plasma') - _alpha = 0.5 + _rgba: np.ndarray + _cmap: ListedColormap - def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame = pd.DataFrame()): + def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None, + colorMap: ListedColormap | None = None, alpha: float = 0.5): super().__init__(parent=parent, dataFrame=dataFrame) - self._setColors() - self.modelReset.connect(self._setColors) - self.dataChanged.connect(self._setColors) - self.layoutChanged.connect(self._setColors) - - def _setColors(self): - vals = self._dataframe.copy() - if vals.empty: - self._colors = vals + + self._alpha = alpha + if colorMap is None: + self._cmap = plt.get_cmap('plasma') + self._cmap.set_bad(color='w') + else: + self._cmap = colorMap + + self._setRgba() + self.modelReset.connect(self._setRgba) + self.dataChanged.connect(self._setRgba) + self.layoutChanged.connect(self._setRgba) + + def _setRgba(self): + values = self._dataframe.copy() + if values.empty: + self._rgba = values return # coerce non-bool / non-numeric values to numeric - for col in vals.select_dtypes(exclude=['bool', 'number']): - vals[col] = vals[col].to_numeric(errors='coerce') + cols = values.select_dtypes(exclude=['bool', 'number']).columns + values[cols] = values[cols].apply(pd.to_numeric, errors='coerce') # normalize numeric values - cols = vals.select_dtypes(include=['number']).columns - vals.replace([np.inf, -np.inf], np.nan, inplace=True) - vals[cols] = MinMaxScaler().fit_transform(vals[cols]) + cols = values.select_dtypes(include=['number']).columns + values.replace([np.inf, -np.inf], np.nan, inplace=True) + values[cols] -= values[cols].min() + values[cols] /= values[cols].max() # convert boolean values - cols = vals.select_dtypes(include=['bool']).columns - vals[cols] = vals[cols].astype(float) - - # assign QColors - colors = vals.astype(object) - for col in vals.columns: - colors[col] = [QColor.fromRgb(*x) for x in self._cmap(vals[col], self._alpha, True)] - - # NaNs should be white - nans = vals.isna() - colors[nans] = QColor('white') + cols = values.select_dtypes(include=['bool']).columns + values[cols] = values[cols].astype(float) - self._colors = colors + # store color values to ndarray + self._rgba = self._cmap(values, self._alpha, True) def data(self, index, role=...): if not index.isValid(): return QVariant() if role == Qt.BackgroundRole: - row = self._dataframe.index[index.row()] - col = self._dataframe.columns[index.column()] - return self._colors.iloc[row][col] + return QColor.fromRgb(*self._rgba[index.row(), index.column()]) return super().data(index, role) From 9203f04436b77fc88ec7e1460518d86fab25ef6f Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 26 Sep 2024 00:58:23 +0100 Subject: [PATCH 09/59] Update ViewEphysQC.py --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 4e1f31c5c..f101435e5 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -103,10 +103,11 @@ def _setRgba(self): values[cols] = values[cols].apply(pd.to_numeric, errors='coerce') # normalize numeric values - cols = values.select_dtypes(include=['number']).columns values.replace([np.inf, -np.inf], np.nan, inplace=True) + cols = values.select_dtypes(include=['number']).columns + values[cols].astype(float) values[cols] -= values[cols].min() - values[cols] /= values[cols].max() + values[cols] = values[cols].div(values[cols].max()).replace(np.inf, 0, inplace=True) # convert boolean values cols = values.select_dtypes(include=['bool']).columns From 2f8a5c9ee380d2844b6287865114fe7715ff1b8b Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 26 Sep 2024 01:43:34 +0100 Subject: [PATCH 10/59] Update ViewEphysQC.py --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 32 +++++++++++++------------ 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index f101435e5..5836275cf 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -93,28 +93,30 @@ def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None, self.layoutChanged.connect(self._setRgba) def _setRgba(self): - values = self._dataframe.copy() - if values.empty: - self._rgba = values + df = self._dataframe.copy() + if df.empty: + self._rgba = df return # coerce non-bool / non-numeric values to numeric - cols = values.select_dtypes(exclude=['bool', 'number']).columns - values[cols] = values[cols].apply(pd.to_numeric, errors='coerce') - - # normalize numeric values - values.replace([np.inf, -np.inf], np.nan, inplace=True) - cols = values.select_dtypes(include=['number']).columns - values[cols].astype(float) - values[cols] -= values[cols].min() - values[cols] = values[cols].div(values[cols].max()).replace(np.inf, 0, inplace=True) + cols = df.select_dtypes(exclude=['bool', 'number']).columns + df[cols] = df[cols].apply(pd.to_numeric, errors='coerce') + + # normalize numeric values, avoiding inf values and division by zero + num_cols = df.select_dtypes(include=['number']).columns + df[num_cols].replace([np.inf, -np.inf], np.nan) + mask = df[num_cols].nunique(dropna=True) == 1 + cols = num_cols[mask] + df[cols] = df[cols].where(df[cols].isna(), other=0.0) + cols = num_cols[~mask] + df[cols] = (df[cols] - df[cols].min()) / (df[cols].max() - df[cols].min()) # convert boolean values - cols = values.select_dtypes(include=['bool']).columns - values[cols] = values[cols].astype(float) + cols = df.select_dtypes(include=['bool']).columns + df[cols] = df[cols].astype(float) # store color values to ndarray - self._rgba = self._cmap(values, self._alpha, True) + self._rgba = self._cmap(df, self._alpha, True) def data(self, index, role=...): if not index.isValid(): From 4b4b063cba242d82c69cffb3c33b4ec44bfb86c3 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 26 Sep 2024 09:59:58 +0100 Subject: [PATCH 11/59] moveable sections, tooltips for header --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 5836275cf..ed6847aff 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -6,6 +6,7 @@ from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, QObject from PyQt5.QtGui import QBrush, QColor import matplotlib.pyplot as plt +from PyQt5.QtWidgets import QStyledItemDelegate from matplotlib.colors import ListedColormap from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT @@ -33,7 +34,7 @@ def dataFrame(self) -> pd.DataFrame: dataFrame = pyqtProperty(pd.DataFrame, fget=dataFrame, fset=setDataFrame) def headerData(self, section: int, orientation: Qt.Orientation, role: int = ...): - if role == Qt.DisplayRole: + if role in (Qt.DisplayRole, Qt.ToolTipRole): if orientation == Qt.Horizontal: return self._dataframe.columns[section] else: @@ -103,12 +104,11 @@ def _setRgba(self): df[cols] = df[cols].apply(pd.to_numeric, errors='coerce') # normalize numeric values, avoiding inf values and division by zero - num_cols = df.select_dtypes(include=['number']).columns - df[num_cols].replace([np.inf, -np.inf], np.nan) - mask = df[num_cols].nunique(dropna=True) == 1 - cols = num_cols[mask] - df[cols] = df[cols].where(df[cols].isna(), other=0.0) - cols = num_cols[~mask] + cols = df.select_dtypes(include=['number']).columns + df[cols].replace([np.inf, -np.inf], np.nan) + m = df[cols].nunique() <= 1 # boolean mask for columns with only 1 unique value + df[cols[m]] = df[cols[m]].where(df[cols[m]].isna(), other=0.0) + cols = cols[~m] df[cols] = (df[cols] - df[cols].min()) / (df[cols].max() - df[cols].min()) # convert boolean values @@ -169,6 +169,8 @@ def __init__(self, parent=None, wheel=None): self.tableView.setModel(self.tableModel) self.tableView.setSortingEnabled(True) self.tableView.horizontalHeader().setDefaultAlignment(Qt.AlignLeft | Qt.AlignVCenter) + self.tableView.horizontalHeader().setSectionsMovable(True) + self.tableView.verticalHeader().hide() self.tableView.doubleClicked.connect(self.tv_double_clicked) vLayout = QtWidgets.QVBoxLayout(self) From ddbad74022d4a5aec46df12c4a65a8a579a05aed Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 26 Sep 2024 11:25:58 +0100 Subject: [PATCH 12/59] speed up sort and color handling --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index ed6847aff..2df0b91c3 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -54,10 +54,8 @@ def columnCount(self, parent: QModelIndex = ...): def data(self, index: QModelIndex, role: int = ...) -> QVariant: if not index.isValid(): return QVariant() - row = self._dataframe.index[index.row()] - col = self._dataframe.columns[index.column()] - val = self._dataframe.iloc[row][col] if role == Qt.DisplayRole: + val = self._dataframe.iloc[index.row()][index.column()] if isinstance(val, np.generic): return val.item() return QVariant(str(val)) @@ -66,10 +64,9 @@ def data(self, index: QModelIndex, role: int = ...) -> QVariant: def sort(self, column: int, order: Qt.SortOrder = ...): if self.columnCount() == 0: return + columnName = self._dataframe.columns.values[column] self.layoutAboutToBeChanged.emit() - col_name = self._dataframe.columns.values[column] - self._dataframe.sort_values(by=col_name, ascending=not order, inplace=True) - self._dataframe.reset_index(inplace=True, drop=True) + self._dataframe.sort_values(by=columnName, ascending=not order, inplace=True) self.layoutChanged.emit() @@ -91,7 +88,6 @@ def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None, self._setRgba() self.modelReset.connect(self._setRgba) self.dataChanged.connect(self._setRgba) - self.layoutChanged.connect(self._setRgba) def _setRgba(self): df = self._dataframe.copy() @@ -122,7 +118,8 @@ def data(self, index, role=...): if not index.isValid(): return QVariant() if role == Qt.BackgroundRole: - return QColor.fromRgb(*self._rgba[index.row(), index.column()]) + row = self._dataframe.index[index.row()] + return QColor.fromRgb(*self._rgba[row][index.column()]) return super().data(index, role) @@ -170,7 +167,7 @@ def __init__(self, parent=None, wheel=None): self.tableView.setSortingEnabled(True) self.tableView.horizontalHeader().setDefaultAlignment(Qt.AlignLeft | Qt.AlignVCenter) self.tableView.horizontalHeader().setSectionsMovable(True) - self.tableView.verticalHeader().hide() + # self.tableView.verticalHeader().hide() self.tableView.doubleClicked.connect(self.tv_double_clicked) vLayout = QtWidgets.QVBoxLayout(self) From 2fbc5c681c3c5cddaf0d06d7167a96a06e712867 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 26 Sep 2024 12:25:34 +0100 Subject: [PATCH 13/59] add filter for column names --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 2df0b91c3..62ca7cb94 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -6,7 +6,6 @@ from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, QObject from PyQt5.QtGui import QBrush, QColor import matplotlib.pyplot as plt -from PyQt5.QtWidgets import QStyledItemDelegate from matplotlib.colors import ListedColormap from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT @@ -64,9 +63,9 @@ def data(self, index: QModelIndex, role: int = ...) -> QVariant: def sort(self, column: int, order: Qt.SortOrder = ...): if self.columnCount() == 0: return - columnName = self._dataframe.columns.values[column] + column = self._dataframe.columns[column] self.layoutAboutToBeChanged.emit() - self._dataframe.sort_values(by=columnName, ascending=not order, inplace=True) + self._dataframe.sort_values(by=column, ascending=not order, inplace=True) self.layoutChanged.emit() @@ -156,7 +155,10 @@ def __init__(self, parent=None, wheel=None): class GraphWindow(QtWidgets.QWidget): def __init__(self, parent=None, wheel=None): QtWidgets.QWidget.__init__(self, parent=parent) - self.lineEditPath = QtWidgets.QLineEdit(self) + + self.lineEditFilter = QtWidgets.QLineEdit(self) + self.lineEditFilter.setPlaceholderText('Filter columns by name') + self.lineEditFilter.textChanged.connect(self.changeFilter) self.pushButtonLoad = QtWidgets.QPushButton("Select File", self) self.pushButtonLoad.clicked.connect(self.loadFile) @@ -167,12 +169,12 @@ def __init__(self, parent=None, wheel=None): self.tableView.setSortingEnabled(True) self.tableView.horizontalHeader().setDefaultAlignment(Qt.AlignLeft | Qt.AlignVCenter) self.tableView.horizontalHeader().setSectionsMovable(True) - # self.tableView.verticalHeader().hide() + self.tableView.verticalHeader().hide() self.tableView.doubleClicked.connect(self.tv_double_clicked) vLayout = QtWidgets.QVBoxLayout(self) hLayout = QtWidgets.QHBoxLayout() - hLayout.addWidget(self.lineEditPath) + hLayout.addWidget(self.lineEditFilter) hLayout.addWidget(self.pushButtonLoad) vLayout.addLayout(hLayout) vLayout.addWidget(self.tableView) @@ -183,13 +185,18 @@ def __init__(self, parent=None, wheel=None): self.wheel = wheel + def changeFilter(self, string: str): + headers = [self.tableModel.headerData(x, Qt.Horizontal, Qt.DisplayRole) + for x in range(self.tableModel.columnCount())] + for idx, column in enumerate(headers): + self.tableView.setColumnHidden(idx, string.lower() not in column.lower()) + def loadFile(self): fileName, _ = QtWidgets.QFileDialog.getOpenFileName( self, "Open File", "", "CSV Files (*.csv)" ) if len(fileName) == 0: return - self.lineEditPath.setText(fileName) df = pd.read_csv(fileName) self.updateDataframe(df) From 7bfcf723ca4e2a3303228e871a2634abec0457fe Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 26 Sep 2024 13:11:09 +0100 Subject: [PATCH 14/59] Update ViewEphysQC.py --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 62ca7cb94..ea2688b3f 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -6,6 +6,7 @@ from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, QObject from PyQt5.QtGui import QBrush, QColor import matplotlib.pyplot as plt +from PyQt5.QtWidgets import QTableView from matplotlib.colors import ListedColormap from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT @@ -17,6 +18,8 @@ _logger = logging.getLogger(__name__) +class FreezeTableView(QTableView) + class DataFrameTableModel(QAbstractTableModel): def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None): super().__init__(parent) @@ -203,9 +206,8 @@ def loadFile(self): def updateDataframe(self, dataFrame: pd.DataFrame): self.tableModel.setDataFrame(dataFrame) - def tv_double_clicked(self): - ind = self.tableView.currentIndex() - data = self.tableModel.dataFrame.loc[ind.row()] + def tv_double_clicked(self, index: QModelIndex): + data = self.tableModel.dataFrame.iloc[index.row()] t0 = data["intervals_0"] t1 = data["intervals_1"] dt = t1 - t0 @@ -213,13 +215,13 @@ def tv_double_clicked(self): idx = np.searchsorted(self.wheel["re_ts"], np.array([t0 - dt / 10, t1 + dt / 10])) period = self.wheel["re_pos"][idx[0] : idx[1]] if period.size == 0: - _logger.warning("No wheel data during trial #%i", ind.row()) + _logger.warning("No wheel data during trial #%i", index.row()) else: min_val, max_val = np.min(period), np.max(period) self.wplot.canvas.ax2.set_ylim(min_val - 1, max_val + 1) self.wplot.canvas.ax2.set_xlim(t0 - dt / 10, t1 + dt / 10) self.wplot.canvas.ax.set_xlim(t0 - dt / 10, t1 + dt / 10) - + self.wplot.setWindowTitle(f"Trial {data.get('trial_no', '?')}") self.wplot.canvas.draw() From f227b9a7e3acc496d9a148214ccb9b53486285f4 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 26 Sep 2024 13:12:31 +0100 Subject: [PATCH 15/59] Update ViewEphysQC.py --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index ea2688b3f..22fbcef06 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -4,9 +4,8 @@ from PyQt5 import QtWidgets from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, QObject -from PyQt5.QtGui import QBrush, QColor +from PyQt5.QtGui import QColor import matplotlib.pyplot as plt -from PyQt5.QtWidgets import QTableView from matplotlib.colors import ListedColormap from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT @@ -18,8 +17,6 @@ _logger = logging.getLogger(__name__) -class FreezeTableView(QTableView) - class DataFrameTableModel(QAbstractTableModel): def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None): super().__init__(parent) From fd30945c87524703a561df217d8ce87407e0df63 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 26 Sep 2024 14:59:54 +0100 Subject: [PATCH 16/59] allow pinning of columns to that they won't be filtered --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 38 +++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 22fbcef06..a46711af1 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -3,9 +3,11 @@ import logging from PyQt5 import QtWidgets -from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, QObject +from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, \ + QObject, QPoint, pyqtSignal, pyqtSlot from PyQt5.QtGui import QColor import matplotlib.pyplot as plt +from PyQt5.QtWidgets import QMenu, QAction, QHeaderView from matplotlib.colors import ListedColormap from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT @@ -153,9 +155,13 @@ def __init__(self, parent=None, wheel=None): class GraphWindow(QtWidgets.QWidget): + _pinnedColumns = [] + def __init__(self, parent=None, wheel=None): QtWidgets.QWidget.__init__(self, parent=parent) + self.columnPinned = pyqtSignal(int, bool) + self.lineEditFilter = QtWidgets.QLineEdit(self) self.lineEditFilter.setPlaceholderText('Filter columns by name') self.lineEditFilter.textChanged.connect(self.changeFilter) @@ -169,9 +175,15 @@ def __init__(self, parent=None, wheel=None): self.tableView.setSortingEnabled(True) self.tableView.horizontalHeader().setDefaultAlignment(Qt.AlignLeft | Qt.AlignVCenter) self.tableView.horizontalHeader().setSectionsMovable(True) - self.tableView.verticalHeader().hide() + self.tableView.horizontalHeader().setContextMenuPolicy(Qt.CustomContextMenu) + self.tableView.horizontalHeader().customContextMenuRequested.connect( + self.contextMenu) self.tableView.doubleClicked.connect(self.tv_double_clicked) + self.pinAction = QAction('Pin column', self) + self.pinAction.setCheckable(True) + self.pinAction.toggled.connect(self.pinColumn) + vLayout = QtWidgets.QVBoxLayout(self) hLayout = QtWidgets.QHBoxLayout() hLayout.addWidget(self.lineEditFilter) @@ -185,11 +197,31 @@ def __init__(self, parent=None, wheel=None): self.wheel = wheel + def contextMenu(self, pos: QPoint): + idx = self.sender().logicalIndexAt(pos) + action = self.pinAction + action.setData(idx) + action.setChecked(idx in self._pinnedColumns) + menu = QMenu(self) + menu.addAction(action) + menu.exec(self.mapToParent(pos)) + + @pyqtSlot(bool) + @pyqtSlot(bool, int) + def pinColumn(self, pin: bool, idx: int | None = None): + idx = idx if idx is not None else self.sender().data() + if not pin and idx in self._pinnedColumns: + self._pinnedColumns.remove(idx) + if pin and idx not in self._pinnedColumns: + self._pinnedColumns.append(idx) + self.changeFilter(self.lineEditFilter.text()) + def changeFilter(self, string: str): headers = [self.tableModel.headerData(x, Qt.Horizontal, Qt.DisplayRole) for x in range(self.tableModel.columnCount())] for idx, column in enumerate(headers): - self.tableView.setColumnHidden(idx, string.lower() not in column.lower()) + self.tableView.setColumnHidden(idx, string.lower() not in column.lower() + and idx not in self._pinnedColumns) def loadFile(self): fileName, _ = QtWidgets.QFileDialog.getOpenFileName( From 520441c00b207c450c9b2aaf4df2e18181db8bd9 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 26 Sep 2024 15:43:47 +0100 Subject: [PATCH 17/59] correct location of context menu popup --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index a46711af1..ef563bd84 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -204,7 +204,7 @@ def contextMenu(self, pos: QPoint): action.setChecked(idx in self._pinnedColumns) menu = QMenu(self) menu.addAction(action) - menu.exec(self.mapToParent(pos)) + menu.exec(self.sender().mapToGlobal(pos)) @pyqtSlot(bool) @pyqtSlot(bool, int) From 54b0df0cbbe61d41ece2788e24153796ee1dfd45 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 27 Sep 2024 10:18:18 +0100 Subject: [PATCH 18/59] happy colors --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index ef563bd84..c46763db4 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -76,12 +76,12 @@ class ColoredDataFrameTableModel(DataFrameTableModel): _cmap: ListedColormap def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None, - colorMap: ListedColormap | None = None, alpha: float = 0.5): + colorMap: ListedColormap | None = None, alpha: float = 1): super().__init__(parent=parent, dataFrame=dataFrame) self._alpha = alpha if colorMap is None: - self._cmap = plt.get_cmap('plasma') + self._cmap = plt.get_cmap('spring') self._cmap.set_bad(color='w') else: self._cmap = colorMap From cdbbed51e95438f6043b04239ebb3fe3af932fb5 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 27 Sep 2024 12:37:43 +0100 Subject: [PATCH 19/59] add signals & slots for ColoredDataFrameTableModel, alpha slider --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 91 ++++++++++++++++++------- 1 file changed, 68 insertions(+), 23 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index c46763db4..9a89d7303 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -5,10 +5,10 @@ from PyQt5 import QtWidgets from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, \ QObject, QPoint, pyqtSignal, pyqtSlot -from PyQt5.QtGui import QColor +from PyQt5.QtGui import QColor, QPalette import matplotlib.pyplot as plt from PyQt5.QtWidgets import QMenu, QAction, QHeaderView -from matplotlib.colors import ListedColormap +from matplotlib.colors import Colormap from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT import pandas as pd @@ -56,7 +56,7 @@ def data(self, index: QModelIndex, role: int = ...) -> QVariant: if not index.isValid(): return QVariant() if role == Qt.DisplayRole: - val = self._dataframe.iloc[index.row()][index.column()] + val = self._dataframe.iloc[index.row(), index.column()] if isinstance(val, np.generic): return val.item() return QVariant(str(val)) @@ -72,24 +72,47 @@ def sort(self, column: int, order: Qt.SortOrder = ...): class ColoredDataFrameTableModel(DataFrameTableModel): + colormapChanged = pyqtSignal(Colormap) + alphaChanged = pyqtSignal(float) _rgba: np.ndarray - _cmap: ListedColormap + _cmap: Colormap + _alpha: int def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None, - colorMap: ListedColormap | None = None, alpha: float = 1): + colormap: Colormap | None = None, alpha: int = 255): super().__init__(parent=parent, dataFrame=dataFrame) - self._alpha = alpha - if colorMap is None: - self._cmap = plt.get_cmap('spring') - self._cmap.set_bad(color='w') - else: - self._cmap = colorMap - - self._setRgba() + self.colormapChanged.connect(self._setRgba) self.modelReset.connect(self._setRgba) self.dataChanged.connect(self._setRgba) + if colormap is None: + colormap = plt.get_cmap('spring') + colormap.set_bad(color='w') + self.setColormap(colormap) + self.setAlpha(alpha) + + @pyqtSlot(Colormap) + def setColormap(self, colormap: Colormap): + self._cmap = colormap + self.colormapChanged.emit(colormap) + + def getColormap(self) -> Colormap: + return self._cmap + + colormap = pyqtProperty(Colormap, fget=getColormap, fset=setColormap) + + @pyqtSlot(int) + def setAlpha(self, alpha: int = 255): + _, self._alpha, _ = sorted([0, alpha, 255]) + self.alphaChanged.emit(self._alpha) + self.layoutChanged.emit() + + def getAlpha(self) -> int: + return self._alpha + + alpha = pyqtProperty(int, fget=getAlpha, fset=setAlpha) + def _setRgba(self): df = self._dataframe.copy() if df.empty: @@ -112,15 +135,16 @@ def _setRgba(self): cols = df.select_dtypes(include=['bool']).columns df[cols] = df[cols].astype(float) - # store color values to ndarray - self._rgba = self._cmap(df, self._alpha, True) + # store color values to ndarray & emit signal + self._rgba = self._cmap(df, alpha=None, bytes=True) + self.layoutChanged.emit() def data(self, index, role=...): if not index.isValid(): return QVariant() if role == Qt.BackgroundRole: row = self._dataframe.index[index.row()] - return QColor.fromRgb(*self._rgba[row][index.column()]) + return QColor.fromRgb(*self._rgba[row][index.column()][:3], self._alpha) return super().data(index, role) @@ -162,13 +186,10 @@ def __init__(self, parent=None, wheel=None): self.columnPinned = pyqtSignal(int, bool) - self.lineEditFilter = QtWidgets.QLineEdit(self) - self.lineEditFilter.setPlaceholderText('Filter columns by name') - self.lineEditFilter.textChanged.connect(self.changeFilter) - self.pushButtonLoad = QtWidgets.QPushButton("Select File", self) self.pushButtonLoad.clicked.connect(self.loadFile) + # define table model & view self.tableModel = ColoredDataFrameTableModel(self) self.tableView = QtWidgets.QTableView(self) self.tableView.setModel(self.tableModel) @@ -176,18 +197,42 @@ def __init__(self, parent=None, wheel=None): self.tableView.horizontalHeader().setDefaultAlignment(Qt.AlignLeft | Qt.AlignVCenter) self.tableView.horizontalHeader().setSectionsMovable(True) self.tableView.horizontalHeader().setContextMenuPolicy(Qt.CustomContextMenu) - self.tableView.horizontalHeader().customContextMenuRequested.connect( - self.contextMenu) + self.tableView.horizontalHeader().customContextMenuRequested.connect(self.contextMenu) self.tableView.doubleClicked.connect(self.tv_double_clicked) + # define colors for highlighted cells + p = self.tableView.palette() + p.setColor(QPalette.Highlight, Qt.black) + p.setColor(QPalette.HighlightedText, Qt.white) + self.tableView.setPalette(p) + + # QAction for pinning columns self.pinAction = QAction('Pin column', self) self.pinAction.setCheckable(True) self.pinAction.toggled.connect(self.pinColumn) - vLayout = QtWidgets.QVBoxLayout(self) + # Filter columns by name + self.lineEditFilter = QtWidgets.QLineEdit(self) + self.lineEditFilter.setPlaceholderText('Filter columns') + self.lineEditFilter.textChanged.connect(self.changeFilter) + + # slider for alpha values + self.sliderAlpha = QtWidgets.QSlider(Qt.Horizontal, self) + self.sliderAlpha.setMinimum(0) + self.sliderAlpha.setMaximum(255) + self.sliderAlpha.setValue(self.tableModel.alpha) + self.sliderAlpha.valueChanged.connect(self.tableModel.setAlpha) + + # Horizontal layout hLayout = QtWidgets.QHBoxLayout() hLayout.addWidget(self.lineEditFilter) + hLayout.addWidget(QtWidgets.QLabel('Alpha', self)) + hLayout.addWidget(self.sliderAlpha) + hLayout.addStretch(1) hLayout.addWidget(self.pushButtonLoad) + + # Vertical layout + vLayout = QtWidgets.QVBoxLayout(self) vLayout.addLayout(hLayout) vLayout.addWidget(self.tableView) From 57ce07827098e8acf74e78ba94db9378b3cf814c Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 27 Sep 2024 12:53:13 +0100 Subject: [PATCH 20/59] add picker for colormap --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 9a89d7303..dccd4368e 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -92,6 +92,12 @@ def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None, self.setColormap(colormap) self.setAlpha(alpha) + @pyqtSlot(str) + def setColormapByName(self, colormapName: str): + colormap = plt.get_cmap(colormapName) + colormap.set_bad(color='w') + self.setColormap(colormap) + @pyqtSlot(Colormap) def setColormap(self, colormap: Colormap): self._cmap = colormap @@ -216,6 +222,11 @@ def __init__(self, parent=None, wheel=None): self.lineEditFilter.setPlaceholderText('Filter columns') self.lineEditFilter.textChanged.connect(self.changeFilter) + # colormap picker + self.comboboxColormap = QtWidgets.QComboBox(self) + self.comboboxColormap.addItems(['plasma', 'spring', 'summer', 'autumn', 'winter']) + self.comboboxColormap.currentTextChanged.connect(self.tableModel.setColormapByName) + # slider for alpha values self.sliderAlpha = QtWidgets.QSlider(Qt.Horizontal, self) self.sliderAlpha.setMinimum(0) @@ -226,6 +237,8 @@ def __init__(self, parent=None, wheel=None): # Horizontal layout hLayout = QtWidgets.QHBoxLayout() hLayout.addWidget(self.lineEditFilter) + hLayout.addWidget(QtWidgets.QLabel('Colormap', self)) + hLayout.addWidget(self.comboboxColormap) hLayout.addWidget(QtWidgets.QLabel('Alpha', self)) hLayout.addWidget(self.sliderAlpha) hLayout.addStretch(1) From 61fbb6c25b5bf397a722649d82f1b6d22f4c14d9 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 27 Sep 2024 12:58:28 +0100 Subject: [PATCH 21/59] Update ViewEphysQC.py --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index dccd4368e..e4bfb2cc4 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -224,7 +224,9 @@ def __init__(self, parent=None, wheel=None): # colormap picker self.comboboxColormap = QtWidgets.QComboBox(self) - self.comboboxColormap.addItems(['plasma', 'spring', 'summer', 'autumn', 'winter']) + colormaps = {self.tableModel.colormap.name, 'plasma', 'spring', 'summer', 'autumn', 'winter'} + self.comboboxColormap.addItems(colormaps) + self.comboboxColormap.setCurrentText(self.tableModel.colormap.name) self.comboboxColormap.currentTextChanged.connect(self.tableModel.setColormapByName) # slider for alpha values From da50ac9c291b506b199716f8da18185813b8e40b Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Sat, 28 Sep 2024 08:37:45 +0100 Subject: [PATCH 22/59] separate normalization from rgba calculation --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index e4bfb2cc4..164c47de1 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -74,6 +74,7 @@ def sort(self, column: int, order: Qt.SortOrder = ...): class ColoredDataFrameTableModel(DataFrameTableModel): colormapChanged = pyqtSignal(Colormap) alphaChanged = pyqtSignal(float) + _normalizedData = pd.DataFrame _rgba: np.ndarray _cmap: Colormap _alpha: int @@ -82,9 +83,9 @@ def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None, colormap: Colormap | None = None, alpha: int = 255): super().__init__(parent=parent, dataFrame=dataFrame) + self.modelReset.connect(self._normalizeData) + self.dataChanged.connect(self._normalizeData) self.colormapChanged.connect(self._setRgba) - self.modelReset.connect(self._setRgba) - self.dataChanged.connect(self._setRgba) if colormap is None: colormap = plt.get_cmap('spring') @@ -119,7 +120,7 @@ def getAlpha(self) -> int: alpha = pyqtProperty(int, fget=getAlpha, fset=setAlpha) - def _setRgba(self): + def _normalizeData(self): df = self._dataframe.copy() if df.empty: self._rgba = df @@ -141,8 +142,15 @@ def _setRgba(self): cols = df.select_dtypes(include=['bool']).columns df[cols] = df[cols].astype(float) - # store color values to ndarray & emit signal - self._rgba = self._cmap(df, alpha=None, bytes=True) + # store as property & call _setRgba() + self._normalizedData = df + self._setRgba() + + def _setRgba(self): + if self._normalizedData.empty: + self._rgba = np.ndarray([]) + else: + self._rgba = self._cmap(self._normalizedData, alpha=None, bytes=True) self.layoutChanged.emit() def data(self, index, role=...): From 027c2ea16268020fa3544b5ff7f05b786c3ccf5b Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Sat, 28 Sep 2024 16:43:39 +0100 Subject: [PATCH 23/59] dynamic handling of text color --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 44 ++++++++++++++++--------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 164c47de1..1fbaa3034 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -4,10 +4,10 @@ from PyQt5 import QtWidgets from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, \ - QObject, QPoint, pyqtSignal, pyqtSlot + QObject, QPoint, pyqtSignal, pyqtSlot, QCoreApplication, QSettings from PyQt5.QtGui import QColor, QPalette import matplotlib.pyplot as plt -from PyQt5.QtWidgets import QMenu, QAction, QHeaderView +from PyQt5.QtWidgets import QMenu, QAction from matplotlib.colors import Colormap from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT @@ -74,10 +74,11 @@ def sort(self, column: int, order: Qt.SortOrder = ...): class ColoredDataFrameTableModel(DataFrameTableModel): colormapChanged = pyqtSignal(Colormap) alphaChanged = pyqtSignal(float) - _normalizedData = pd.DataFrame - _rgba: np.ndarray + _normData = pd.DataFrame + _background: np.ndarray _cmap: Colormap _alpha: int + _foreground: np.ndarray def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None, colormap: Colormap | None = None, alpha: int = 255): @@ -85,10 +86,10 @@ def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None, self.modelReset.connect(self._normalizeData) self.dataChanged.connect(self._normalizeData) - self.colormapChanged.connect(self._setRgba) + self.colormapChanged.connect(self._defineColors) if colormap is None: - colormap = plt.get_cmap('spring') + colormap = plt.get_cmap('plasma') colormap.set_bad(color='w') self.setColormap(colormap) self.setAlpha(alpha) @@ -123,7 +124,7 @@ def getAlpha(self) -> int: def _normalizeData(self): df = self._dataframe.copy() if df.empty: - self._rgba = df + self._background = df return # coerce non-bool / non-numeric values to numeric @@ -143,14 +144,17 @@ def _normalizeData(self): df[cols] = df[cols].astype(float) # store as property & call _setRgba() - self._normalizedData = df - self._setRgba() + self._normData = df + self._defineColors() - def _setRgba(self): - if self._normalizedData.empty: - self._rgba = np.ndarray([]) + def _defineColors(self): + if self._normData.empty: + self._background = np.ndarray([]) + self._foreground = np.ndarray([]) else: - self._rgba = self._cmap(self._normalizedData, alpha=None, bytes=True) + self._background = self._cmap(self._normData, alpha=None, bytes=True)[:, :, :3] + brightness = (self._background * np.array([[[0.21, 0.72, 0.07]]])).sum(axis=2) + self._foreground = 255 - brightness.astype(int) self.layoutChanged.emit() def data(self, index, role=...): @@ -158,7 +162,11 @@ def data(self, index, role=...): return QVariant() if role == Qt.BackgroundRole: row = self._dataframe.index[index.row()] - return QColor.fromRgb(*self._rgba[row][index.column()][:3], self._alpha) + return QColor.fromRgb(*self._background[row][index.column()], self._alpha) + if role == Qt.ForegroundRole: + row = self._dataframe.index[index.row()] + val = self._foreground[row][index.column()] * self._alpha + return QColor('black') if val < 32512 else QColor('white') return super().data(index, role) @@ -198,6 +206,9 @@ class GraphWindow(QtWidgets.QWidget): def __init__(self, parent=None, wheel=None): QtWidgets.QWidget.__init__(self, parent=parent) + # Store layout changes to QSettings + self.settings = QSettings() + self.columnPinned = pyqtSignal(int, bool) self.pushButtonLoad = QtWidgets.QPushButton("Select File", self) @@ -232,7 +243,7 @@ def __init__(self, parent=None, wheel=None): # colormap picker self.comboboxColormap = QtWidgets.QComboBox(self) - colormaps = {self.tableModel.colormap.name, 'plasma', 'spring', 'summer', 'autumn', 'winter'} + colormaps = sorted(list({self.tableModel.colormap.name, 'cividis', 'inferno', 'magma', 'plasma', 'viridis'})) self.comboboxColormap.addItems(colormaps) self.comboboxColormap.setCurrentText(self.tableModel.colormap.name) self.comboboxColormap.currentTextChanged.connect(self.tableModel.setColormapByName) @@ -323,6 +334,9 @@ def tv_double_clicked(self, index: QModelIndex): def viewqc(qc=None, title=None, wheel=None): + QCoreApplication.setOrganizationName('International Brain Laboratory') + QCoreApplication.setOrganizationDomain('internationalbrainlab.org') + QCoreApplication.setApplicationName('QC Viewer') qt.create_app() qcw = GraphWindow(wheel=wheel) qcw.setWindowTitle(title) From 7f6dbb625bcac5fb8c83c6e96cd0fb8d70d8feb9 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Sat, 28 Sep 2024 22:58:07 +0100 Subject: [PATCH 24/59] switch to using pyqtgraph's colormaps --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 62 ++++++++++++------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 1fbaa3034..731d05bb2 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -5,14 +5,13 @@ from PyQt5 import QtWidgets from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, \ QObject, QPoint, pyqtSignal, pyqtSlot, QCoreApplication, QSettings -from PyQt5.QtGui import QColor, QPalette -import matplotlib.pyplot as plt +from PyQt5.QtGui import QColor, QPalette, QShowEvent from PyQt5.QtWidgets import QMenu, QAction -from matplotlib.colors import Colormap from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT import pandas as pd import numpy as np +from pyqtgraph import colormap, ColorMap from ibllib.misc import qt @@ -72,43 +71,36 @@ def sort(self, column: int, order: Qt.SortOrder = ...): class ColoredDataFrameTableModel(DataFrameTableModel): - colormapChanged = pyqtSignal(Colormap) - alphaChanged = pyqtSignal(float) + colormapChanged = pyqtSignal(str) + alphaChanged = pyqtSignal(int) _normData = pd.DataFrame _background: np.ndarray - _cmap: Colormap + _cmap: ColorMap _alpha: int _foreground: np.ndarray def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None, - colormap: Colormap | None = None, alpha: int = 255): + colormap: str = 'plasma', alpha: int = 255): super().__init__(parent=parent, dataFrame=dataFrame) - self.modelReset.connect(self._normalizeData) self.dataChanged.connect(self._normalizeData) self.colormapChanged.connect(self._defineColors) - - if colormap is None: - colormap = plt.get_cmap('plasma') - colormap.set_bad(color='w') self.setColormap(colormap) self.setAlpha(alpha) @pyqtSlot(str) - def setColormapByName(self, colormapName: str): - colormap = plt.get_cmap(colormapName) - colormap.set_bad(color='w') - self.setColormap(colormap) + def setColormap(self, name: str): + for source in [None, 'matplotlib', 'colorcet']: + if name in colormap.listMaps(source): + self._cmap = colormap.get(name, source) + self.colormapChanged.emit(name) + return + _logger.warning(f'No such colormap: "{name}"') - @pyqtSlot(Colormap) - def setColormap(self, colormap: Colormap): - self._cmap = colormap - self.colormapChanged.emit(colormap) + def getColormap(self) -> str: + return self._cmap.name - def getColormap(self) -> Colormap: - return self._cmap - - colormap = pyqtProperty(Colormap, fget=getColormap, fset=setColormap) + colormap = pyqtProperty(str, fget=getColormap, fset=setColormap) @pyqtSlot(int) def setAlpha(self, alpha: int = 255): @@ -152,9 +144,10 @@ def _defineColors(self): self._background = np.ndarray([]) self._foreground = np.ndarray([]) else: - self._background = self._cmap(self._normData, alpha=None, bytes=True)[:, :, :3] - brightness = (self._background * np.array([[[0.21, 0.72, 0.07]]])).sum(axis=2) - self._foreground = 255 - brightness.astype(int) + m = np.isfinite(self._normData) # binary mask for finite values + self._background = np.ones((*self._normData.shape, 3), dtype=int) * 255 + self._background[m] = self._cmap.mapToByte(self._normData.values[m])[:, :3] + self._foreground = 255 - (self._background * np.array([[[0.21, 0.72, 0.07]]])).sum(axis=2).astype(int) self.layoutChanged.emit() def data(self, index, role=...): @@ -162,7 +155,8 @@ def data(self, index, role=...): return QVariant() if role == Qt.BackgroundRole: row = self._dataframe.index[index.row()] - return QColor.fromRgb(*self._background[row][index.column()], self._alpha) + val = self._background[row][index.column()] + return QColor.fromRgb(*val, self._alpha) if role == Qt.ForegroundRole: row = self._dataframe.index[index.row()] val = self._foreground[row][index.column()] * self._alpha @@ -243,10 +237,10 @@ def __init__(self, parent=None, wheel=None): # colormap picker self.comboboxColormap = QtWidgets.QComboBox(self) - colormaps = sorted(list({self.tableModel.colormap.name, 'cividis', 'inferno', 'magma', 'plasma', 'viridis'})) - self.comboboxColormap.addItems(colormaps) - self.comboboxColormap.setCurrentText(self.tableModel.colormap.name) - self.comboboxColormap.currentTextChanged.connect(self.tableModel.setColormapByName) + colormaps = {self.tableModel.colormap, 'inferno', 'magma', 'plasma'} + self.comboboxColormap.addItems(sorted(list(colormaps))) + self.comboboxColormap.setCurrentText(self.tableModel.colormap) + self.comboboxColormap.currentTextChanged.connect(self.tableModel.setColormap) # slider for alpha values self.sliderAlpha = QtWidgets.QSlider(Qt.Horizontal, self) @@ -276,6 +270,10 @@ def __init__(self, parent=None, wheel=None): self.wheel = wheel + def showEvent(self, a0: QShowEvent) -> None: + super().showEvent(a0) + self.activateWindow() + def contextMenu(self, pos: QPoint): idx = self.sender().logicalIndexAt(pos) action = self.pinAction From 5608ec1c8bc3ceeace16b262a4949e4665a889ee Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Sun, 29 Sep 2024 19:34:32 +0100 Subject: [PATCH 25/59] filter by tokens --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 49 ++++++++++++++++--------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 731d05bb2..34672530b 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -73,11 +73,11 @@ def sort(self, column: int, order: Qt.SortOrder = ...): class ColoredDataFrameTableModel(DataFrameTableModel): colormapChanged = pyqtSignal(str) alphaChanged = pyqtSignal(int) - _normData = pd.DataFrame + _normData = pd.DataFrame() _background: np.ndarray + _foreground: np.ndarray _cmap: ColorMap _alpha: int - _foreground: np.ndarray def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None, colormap: str = 'plasma', alpha: int = 255): @@ -115,9 +115,6 @@ def getAlpha(self) -> int: def _normalizeData(self): df = self._dataframe.copy() - if df.empty: - self._background = df - return # coerce non-bool / non-numeric values to numeric cols = df.select_dtypes(exclude=['bool', 'number']).columns @@ -135,14 +132,23 @@ def _normalizeData(self): cols = df.select_dtypes(include=['bool']).columns df[cols] = df[cols].astype(float) - # store as property & call _setRgba() + # store as property & call _defineColors() self._normData = df self._defineColors() def _defineColors(self): + """ + Define the background and foreground colors according to the table's data. + + The background color is set to the colormap-mapped values of the normalized + data, and the foreground color is set to the inverse of the background's + approximated luminosity. + + The `layoutChanged` signal is emitted after the colors are defined. + """ if self._normData.empty: - self._background = np.ndarray([]) - self._foreground = np.ndarray([]) + self._background = np.zeros((0, 0, 3), dtype=int) + self._foreground = np.zeros((0, 0), dtype=int) else: m = np.isfinite(self._normData) # binary mask for finite values self._background = np.ones((*self._normData.shape, 3), dtype=int) * 255 @@ -153,14 +159,15 @@ def _defineColors(self): def data(self, index, role=...): if not index.isValid(): return QVariant() - if role == Qt.BackgroundRole: - row = self._dataframe.index[index.row()] - val = self._background[row][index.column()] - return QColor.fromRgb(*val, self._alpha) - if role == Qt.ForegroundRole: + if role in (Qt.BackgroundRole, Qt.ForegroundRole): row = self._dataframe.index[index.row()] - val = self._foreground[row][index.column()] * self._alpha - return QColor('black') if val < 32512 else QColor('white') + col = index.column() + if role == Qt.BackgroundRole: + val = self._background[row][col] + return QColor.fromRgb(*val, self._alpha) + if role == Qt.ForegroundRole: + val = self._foreground[row][col] + return QColor('black' if (val * self._alpha) < 32512 else 'white') return super().data(index, role) @@ -234,6 +241,7 @@ def __init__(self, parent=None, wheel=None): self.lineEditFilter = QtWidgets.QLineEdit(self) self.lineEditFilter.setPlaceholderText('Filter columns') self.lineEditFilter.textChanged.connect(self.changeFilter) + self.lineEditFilter.setMinimumWidth(200) # colormap picker self.comboboxColormap = QtWidgets.QComboBox(self) @@ -252,6 +260,7 @@ def __init__(self, parent=None, wheel=None): # Horizontal layout hLayout = QtWidgets.QHBoxLayout() hLayout.addWidget(self.lineEditFilter) + hLayout.addSpacing(50) hLayout.addWidget(QtWidgets.QLabel('Colormap', self)) hLayout.addWidget(self.comboboxColormap) hLayout.addWidget(QtWidgets.QLabel('Alpha', self)) @@ -264,6 +273,8 @@ def __init__(self, parent=None, wheel=None): vLayout.addLayout(hLayout) vLayout.addWidget(self.tableView) + self.setMinimumSize(500, 400) + self.wplot = PlotWindow(wheel=wheel) self.wplot.show() self.tableModel.dataChanged.connect(self.wplot.canvas.draw) @@ -294,11 +305,13 @@ def pinColumn(self, pin: bool, idx: int | None = None): self.changeFilter(self.lineEditFilter.text()) def changeFilter(self, string: str): - headers = [self.tableModel.headerData(x, Qt.Horizontal, Qt.DisplayRole) + headers = [self.tableModel.headerData(x, Qt.Horizontal, Qt.DisplayRole).lower() for x in range(self.tableModel.columnCount())] + tokens = [y.lower() for y in (x.strip() for x in string.split(',')) if len(y)] + showAll = len(tokens) == 0 for idx, column in enumerate(headers): - self.tableView.setColumnHidden(idx, string.lower() not in column.lower() - and idx not in self._pinnedColumns) + show = showAll or any((t in column for t in tokens)) or idx in self._pinnedColumns + self.tableView.setColumnHidden(idx, not show) def loadFile(self): fileName, _ = QtWidgets.QFileDialog.getOpenFileName( From 02e9a2cd268ca613d74473baac26b69ea72046cb Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Tue, 1 Oct 2024 11:23:00 +0100 Subject: [PATCH 26/59] move models to iblqt --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 164 +----------------------- requirements.txt | 1 + 2 files changed, 6 insertions(+), 159 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 34672530b..420c5ea9a 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -3,174 +3,20 @@ import logging from PyQt5 import QtWidgets -from PyQt5.QtCore import pyqtProperty, Qt, QVariant, QAbstractTableModel, QModelIndex, \ - QObject, QPoint, pyqtSignal, pyqtSlot, QCoreApplication, QSettings -from PyQt5.QtGui import QColor, QPalette, QShowEvent +from PyQt5.QtCore import Qt, QModelIndex, QPoint, pyqtSignal, pyqtSlot, QCoreApplication, QSettings +from PyQt5.QtGui import QPalette, QShowEvent from PyQt5.QtWidgets import QMenu, QAction +from iblqt.core import ColoredDataFrameTableModel from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT import pandas as pd import numpy as np -from pyqtgraph import colormap, ColorMap from ibllib.misc import qt _logger = logging.getLogger(__name__) -class DataFrameTableModel(QAbstractTableModel): - def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None): - super().__init__(parent) - self._dataframe = pd.DataFrame() if dataFrame is None else dataFrame - - def setDataFrame(self, dataFrame: pd.DataFrame): - self.beginResetModel() - self._dataframe = dataFrame.copy() - self.endResetModel() - - def dataFrame(self) -> pd.DataFrame: - return self._dataframe - - dataFrame = pyqtProperty(pd.DataFrame, fget=dataFrame, fset=setDataFrame) - - def headerData(self, section: int, orientation: Qt.Orientation, role: int = ...): - if role in (Qt.DisplayRole, Qt.ToolTipRole): - if orientation == Qt.Horizontal: - return self._dataframe.columns[section] - else: - return str(self._dataframe.index[section]) - return QVariant() - - def rowCount(self, parent: QModelIndex = ...): - if isinstance(parent, QModelIndex) and parent.isValid(): - return 0 - return len(self._dataframe.index) - - def columnCount(self, parent: QModelIndex = ...): - if isinstance(parent, QModelIndex) and parent.isValid(): - return 0 - return self._dataframe.columns.size - - def data(self, index: QModelIndex, role: int = ...) -> QVariant: - if not index.isValid(): - return QVariant() - if role == Qt.DisplayRole: - val = self._dataframe.iloc[index.row(), index.column()] - if isinstance(val, np.generic): - return val.item() - return QVariant(str(val)) - return QVariant() - - def sort(self, column: int, order: Qt.SortOrder = ...): - if self.columnCount() == 0: - return - column = self._dataframe.columns[column] - self.layoutAboutToBeChanged.emit() - self._dataframe.sort_values(by=column, ascending=not order, inplace=True) - self.layoutChanged.emit() - - -class ColoredDataFrameTableModel(DataFrameTableModel): - colormapChanged = pyqtSignal(str) - alphaChanged = pyqtSignal(int) - _normData = pd.DataFrame() - _background: np.ndarray - _foreground: np.ndarray - _cmap: ColorMap - _alpha: int - - def __init__(self, parent: QObject = ..., dataFrame: pd.DataFrame | None = None, - colormap: str = 'plasma', alpha: int = 255): - super().__init__(parent=parent, dataFrame=dataFrame) - self.modelReset.connect(self._normalizeData) - self.dataChanged.connect(self._normalizeData) - self.colormapChanged.connect(self._defineColors) - self.setColormap(colormap) - self.setAlpha(alpha) - - @pyqtSlot(str) - def setColormap(self, name: str): - for source in [None, 'matplotlib', 'colorcet']: - if name in colormap.listMaps(source): - self._cmap = colormap.get(name, source) - self.colormapChanged.emit(name) - return - _logger.warning(f'No such colormap: "{name}"') - - def getColormap(self) -> str: - return self._cmap.name - - colormap = pyqtProperty(str, fget=getColormap, fset=setColormap) - - @pyqtSlot(int) - def setAlpha(self, alpha: int = 255): - _, self._alpha, _ = sorted([0, alpha, 255]) - self.alphaChanged.emit(self._alpha) - self.layoutChanged.emit() - - def getAlpha(self) -> int: - return self._alpha - - alpha = pyqtProperty(int, fget=getAlpha, fset=setAlpha) - - def _normalizeData(self): - df = self._dataframe.copy() - - # coerce non-bool / non-numeric values to numeric - cols = df.select_dtypes(exclude=['bool', 'number']).columns - df[cols] = df[cols].apply(pd.to_numeric, errors='coerce') - - # normalize numeric values, avoiding inf values and division by zero - cols = df.select_dtypes(include=['number']).columns - df[cols].replace([np.inf, -np.inf], np.nan) - m = df[cols].nunique() <= 1 # boolean mask for columns with only 1 unique value - df[cols[m]] = df[cols[m]].where(df[cols[m]].isna(), other=0.0) - cols = cols[~m] - df[cols] = (df[cols] - df[cols].min()) / (df[cols].max() - df[cols].min()) - - # convert boolean values - cols = df.select_dtypes(include=['bool']).columns - df[cols] = df[cols].astype(float) - - # store as property & call _defineColors() - self._normData = df - self._defineColors() - - def _defineColors(self): - """ - Define the background and foreground colors according to the table's data. - - The background color is set to the colormap-mapped values of the normalized - data, and the foreground color is set to the inverse of the background's - approximated luminosity. - - The `layoutChanged` signal is emitted after the colors are defined. - """ - if self._normData.empty: - self._background = np.zeros((0, 0, 3), dtype=int) - self._foreground = np.zeros((0, 0), dtype=int) - else: - m = np.isfinite(self._normData) # binary mask for finite values - self._background = np.ones((*self._normData.shape, 3), dtype=int) * 255 - self._background[m] = self._cmap.mapToByte(self._normData.values[m])[:, :3] - self._foreground = 255 - (self._background * np.array([[[0.21, 0.72, 0.07]]])).sum(axis=2).astype(int) - self.layoutChanged.emit() - - def data(self, index, role=...): - if not index.isValid(): - return QVariant() - if role in (Qt.BackgroundRole, Qt.ForegroundRole): - row = self._dataframe.index[index.row()] - col = index.column() - if role == Qt.BackgroundRole: - val = self._background[row][col] - return QColor.fromRgb(*val, self._alpha) - if role == Qt.ForegroundRole: - val = self._foreground[row][col] - return QColor('black' if (val * self._alpha) < 32512 else 'white') - return super().data(index, role) - - class PlotCanvas(FigureCanvasQTAgg): def __init__(self, parent=None, width=5, height=4, dpi=100, wheel=None): fig = Figure(figsize=(width, height), dpi=dpi) @@ -245,7 +91,7 @@ def __init__(self, parent=None, wheel=None): # colormap picker self.comboboxColormap = QtWidgets.QComboBox(self) - colormaps = {self.tableModel.colormap, 'inferno', 'magma', 'plasma'} + colormaps = {self.tableModel.colormap, 'inferno', 'magma', 'plasma', 'summer'} self.comboboxColormap.addItems(sorted(list(colormaps))) self.comboboxColormap.setCurrentText(self.tableModel.colormap) self.comboboxColormap.currentTextChanged.connect(self.tableModel.setColormap) @@ -305,7 +151,7 @@ def pinColumn(self, pin: bool, idx: int | None = None): self.changeFilter(self.lineEditFilter.text()) def changeFilter(self, string: str): - headers = [self.tableModel.headerData(x, Qt.Horizontal, Qt.DisplayRole).lower() + headers = [self.tableModel.headerData(x, Qt.Horizontal, Qt.DisplayRole).value().lower() for x in range(self.tableModel.columnCount())] tokens = [y.lower() for y in (x.strip() for x in string.split(',')) if len(y)] showAll = len(tokens) == 0 diff --git a/requirements.txt b/requirements.txt index 7524c22f3..005e43b5e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ tqdm>=4.32.1 iblatlas>=0.5.3 ibl-neuropixel>=1.0.1 iblutil>=1.11.0 +iblqt>=0.1.0 mtscomp>=1.0.1 ONE-api~=2.9.rc0 phylib>=2.6.0 From af0abe3080a9ce1990ed8f0129ea4aa55687e7cf Mon Sep 17 00:00:00 2001 From: owinter Date: Wed, 2 Oct 2024 18:41:54 +0100 Subject: [PATCH 27/59] fix stim freeze indexing issue in ephys_fpga extraction --- ibllib/io/extractors/ephys_fpga.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index eb6c5141d..e7ae93220 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -936,6 +936,7 @@ def build_trials(self, sync, chmap, display=False, **kwargs): t_trial_start = np.sort(np.r_[fpga_events['intervals_0'][:, 0], missing_bpod]) else: t_trial_start = fpga_events['intervals_0'] + t_trial_start = t_trial_start[ifpga] out = alfio.AlfBunch() # Add the Bpod trial events, converting the timestamp fields to FPGA time. @@ -960,9 +961,9 @@ def build_trials(self, sync, chmap, display=False, **kwargs): # f2ttl times are unreliable owing to calibration and Bonsai sync square update issues. # Take the first event after the FPGA aligned stimulus trigger time. - fpga_trials['stimOn_times'][ibpod] = _assign_events_to_trial( + fpga_trials['stimOn_times'] = _assign_events_to_trial( out['stimOnTrigger_times'], f2ttl_t, take='first', t_trial_end=out['stimOffTrigger_times']) - fpga_trials['stimOff_times'][ibpod] = _assign_events_to_trial( + fpga_trials['stimOff_times'] = _assign_events_to_trial( out['stimOffTrigger_times'], f2ttl_t, take='first', t_trial_end=out['intervals'][:, 1]) # For stim freeze we take the last event before the stim off trigger time. # To avoid assigning early events (e.g. for sessions where there are few flips due to @@ -981,13 +982,12 @@ def build_trials(self, sync, chmap, display=False, **kwargs): # take last event after freeze/stim on trigger, before stim off trigger stim_freeze = _assign_events_to_trial(lims, f2ttl_t, take='last', t_trial_end=out['stimOffTrigger_times']) fpga_trials['stimFreeze_times'][go_trials] = stim_freeze[go_trials] - # Feedback times are valve open on correct trials and error tone in on incorrect trials fpga_trials['feedback_times'] = np.copy(fpga_trials['valveOpen_times']) ind_err = np.isnan(fpga_trials['valveOpen_times']) fpga_trials['feedback_times'][ind_err] = fpga_trials['errorCue_times'][ind_err] - out.update({k: fpga_trials[k][ifpga] for k in fpga_trials.keys()}) + out.update({k: fpga_trials[k] for k in fpga_trials.keys()}) if display: # pragma: no cover width = 0.5 From 4a3627fc3e22ab69fcdd0402a41d5b04009183e7 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Thu, 3 Oct 2024 10:21:12 +0100 Subject: [PATCH 28/59] change ITI constants from 1s to 500ms --- ibllib/qc/task_metrics.py | 2 +- ibllib/tests/qc/test_task_metrics.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index c510d271c..2a7a4934f 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -679,7 +679,7 @@ def check_iti_delays(data, subtract_pauses=False, **_): An array of boolean values, 1 per trial, where True means trial passes QC threshold. """ # Initialize array the length of completed trials - ITI = 1. + ITI = .5 metric = np.full(data['intervals'].shape[0], np.nan) passed = metric.copy() pauses = (data['pause_duration'] if subtract_pauses else np.zeros_like(metric))[:-1] diff --git a/ibllib/tests/qc/test_task_metrics.py b/ibllib/tests/qc/test_task_metrics.py index a0433332a..6e4a206b4 100644 --- a/ibllib/tests/qc/test_task_metrics.py +++ b/ibllib/tests/qc/test_task_metrics.py @@ -159,7 +159,7 @@ def load_fake_bpod_data(n=5): # add a 5s pause on 3rd trial pauses[2] = 5. quiescence_length = 0.2 + np.random.standard_exponential(size=(n,)) - iti_length = 1 # inter-trial interval + iti_length = .5 # inter-trial interval # trial lengths include quiescence period, a couple small trigger delays and iti trial_lengths = quiescence_length + resp_feeback_delay + (trigg_delay * 4) + iti_length # add on 60s for nogos + feedback time (1 or 2s) + ~0.5s for other responses From 04ef5bc1a5101e17164e81ea2d1d31be0c5e7c84 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 4 Oct 2024 11:11:33 +0100 Subject: [PATCH 29/59] sort dataframe, store UI settings --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 100 ++++++++++++++++++------ ruff.toml | 4 + 2 files changed, 78 insertions(+), 26 deletions(-) create mode 100644 ruff.toml diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 420c5ea9a..d4178260b 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -3,7 +3,16 @@ import logging from PyQt5 import QtWidgets -from PyQt5.QtCore import Qt, QModelIndex, QPoint, pyqtSignal, pyqtSlot, QCoreApplication, QSettings +from PyQt5.QtCore import ( + Qt, + QModelIndex, + pyqtSignal, + pyqtSlot, + QCoreApplication, + QSettings, + QSize, + QPoint, +) from PyQt5.QtGui import QPalette, QShowEvent from PyQt5.QtWidgets import QMenu, QAction from iblqt.core import ColoredDataFrameTableModel @@ -24,14 +33,10 @@ def __init__(self, parent=None, width=5, height=4, dpi=100, wheel=None): FigureCanvasQTAgg.__init__(self, fig) self.setParent(parent) - FigureCanvasQTAgg.setSizePolicy( - self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding - ) + FigureCanvasQTAgg.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) FigureCanvasQTAgg.updateGeometry(self) if wheel: - self.ax, self.ax2 = fig.subplots( - 2, 1, gridspec_kw={"height_ratios": [2, 1]}, sharex=True - ) + self.ax, self.ax2 = fig.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, sharex=True) else: self.ax = fig.add_subplot(111) self.draw() @@ -53,12 +58,10 @@ class GraphWindow(QtWidgets.QWidget): def __init__(self, parent=None, wheel=None): QtWidgets.QWidget.__init__(self, parent=parent) - # Store layout changes to QSettings - self.settings = QSettings() - self.columnPinned = pyqtSignal(int, bool) - self.pushButtonLoad = QtWidgets.QPushButton("Select File", self) + # load button + self.pushButtonLoad = QtWidgets.QPushButton('Select File', self) self.pushButtonLoad.clicked.connect(self.loadFile) # define table model & view @@ -70,6 +73,7 @@ def __init__(self, parent=None, wheel=None): self.tableView.horizontalHeader().setSectionsMovable(True) self.tableView.horizontalHeader().setContextMenuPolicy(Qt.CustomContextMenu) self.tableView.horizontalHeader().customContextMenuRequested.connect(self.contextMenu) + self.tableView.verticalHeader().hide() self.tableView.doubleClicked.connect(self.tv_double_clicked) # define colors for highlighted cells @@ -98,6 +102,7 @@ def __init__(self, parent=None, wheel=None): # slider for alpha values self.sliderAlpha = QtWidgets.QSlider(Qt.Horizontal, self) + self.sliderAlpha.setMaximumWidth(100) self.sliderAlpha.setMinimum(0) self.sliderAlpha.setMaximum(255) self.sliderAlpha.setValue(self.tableModel.alpha) @@ -111,7 +116,7 @@ def __init__(self, parent=None, wheel=None): hLayout.addWidget(self.comboboxColormap) hLayout.addWidget(QtWidgets.QLabel('Alpha', self)) hLayout.addWidget(self.sliderAlpha) - hLayout.addStretch(1) + hLayout.addSpacing(50) hLayout.addWidget(self.pushButtonLoad) # Vertical layout @@ -119,7 +124,13 @@ def __init__(self, parent=None, wheel=None): vLayout.addLayout(hLayout) vLayout.addWidget(self.tableView) - self.setMinimumSize(500, 400) + # Recover layout from QSettings + self.settings = QSettings() + self.settings.beginGroup('MainWindow') + self.resize(self.settings.value('size', QSize(800, 600), QSize)) + self.comboboxColormap.setCurrentText(self.settings.value('colormap', 'plasma', str)) + self.sliderAlpha.setValue(self.settings.value('alpha', 255, int)) + self.settings.endGroup() self.wplot = PlotWindow(wheel=wheel) self.wplot.show() @@ -127,6 +138,14 @@ def __init__(self, parent=None, wheel=None): self.wheel = wheel + def closeEvent(self, _) -> bool: + self.settings.beginGroup('MainWindow') + self.settings.setValue('size', self.size()) + self.settings.setValue('colormap', self.tableModel.colormap) + self.settings.setValue('alpha', self.tableModel.alpha) + self.settings.endGroup() + self.wplot.close() + def showEvent(self, a0: QShowEvent) -> None: super().showEvent(a0) self.activateWindow() @@ -151,8 +170,10 @@ def pinColumn(self, pin: bool, idx: int | None = None): self.changeFilter(self.lineEditFilter.text()) def changeFilter(self, string: str): - headers = [self.tableModel.headerData(x, Qt.Horizontal, Qt.DisplayRole).value().lower() - for x in range(self.tableModel.columnCount())] + headers = [ + self.tableModel.headerData(x, Qt.Horizontal, Qt.DisplayRole).value().lower() + for x in range(self.tableModel.columnCount()) + ] tokens = [y.lower() for y in (x.strip() for x in string.split(',')) if len(y)] showAll = len(tokens) == 0 for idx, column in enumerate(headers): @@ -160,27 +181,53 @@ def changeFilter(self, string: str): self.tableView.setColumnHidden(idx, not show) def loadFile(self): - fileName, _ = QtWidgets.QFileDialog.getOpenFileName( - self, "Open File", "", "CSV Files (*.csv)" - ) + fileName, _ = QtWidgets.QFileDialog.getOpenFileName(self, 'Open File', '', 'CSV Files (*.csv)') if len(fileName) == 0: return df = pd.read_csv(fileName) self.updateDataframe(df) - def updateDataframe(self, dataFrame: pd.DataFrame): - self.tableModel.setDataFrame(dataFrame) + def updateDataframe(self, df: pd.DataFrame): + # clear pinned columns + self._pinnedColumns = [] + + # try to identify and sort columns containing timestamps + col_names = df.columns + df_interp = df.replace([-np.inf, np.inf], np.nan) + df_interp = df_interp.interpolate(limit_direction='both') + cols_mono = col_names[[df_interp[c].is_monotonic_increasing for c in col_names]] + cols_mono = [c for c in cols_mono if df[c].nunique() > 1] + cols_mono = df_interp[cols_mono].mean().sort_values().keys() + for idx, col_name in enumerate(cols_mono): + df.insert(idx, col_name, df.pop(col_name)) + + # columns containing boolean values are sorted to the end + cols_bool = list(df.select_dtypes('bool').columns) + cols_pass = [cols_bool.pop(i) for i, c in enumerate(cols_bool) if 'pass' in c] + cols_bool += cols_pass + for col_name in cols_bool: + df = df.join(df.pop(col_name)) + + # trial_no should always be the first column + if 'trial_no' in col_names: + df.insert(0, 'trial_no', df.pop('trial_no')) + + # define columns that should be pinned by default + for col in ['trial_no']: + self._pinnedColumns.append(df.columns.get_loc(col)) + + self.tableModel.setDataFrame(df) def tv_double_clicked(self, index: QModelIndex): data = self.tableModel.dataFrame.iloc[index.row()] - t0 = data["intervals_0"] - t1 = data["intervals_1"] + t0 = data['intervals_0'] + t1 = data['intervals_1'] dt = t1 - t0 if self.wheel: - idx = np.searchsorted(self.wheel["re_ts"], np.array([t0 - dt / 10, t1 + dt / 10])) - period = self.wheel["re_pos"][idx[0] : idx[1]] + idx = np.searchsorted(self.wheel['re_ts'], np.array([t0 - dt / 10, t1 + dt / 10])) + period = self.wheel['re_pos'][idx[0] : idx[1]] if period.size == 0: - _logger.warning("No wheel data during trial #%i", index.row()) + _logger.warning('No wheel data during trial #%i', index.row()) else: min_val, max_val = np.min(period), np.max(period) self.wplot.canvas.ax2.set_ylim(min_val - 1, max_val + 1) @@ -191,10 +238,11 @@ def tv_double_clicked(self, index: QModelIndex): def viewqc(qc=None, title=None, wheel=None): + app = qt.create_app() + app.setStyle('Fusion') QCoreApplication.setOrganizationName('International Brain Laboratory') QCoreApplication.setOrganizationDomain('internationalbrainlab.org') QCoreApplication.setApplicationName('QC Viewer') - qt.create_app() qcw = GraphWindow(wheel=wheel) qcw.setWindowTitle(title) if qc is not None: diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..253516e9f --- /dev/null +++ b/ruff.toml @@ -0,0 +1,4 @@ +line-length = 130 + +[format] +quote-style = "single" From 72aec0ffbbada6302d44f74d43c9592e1735a843 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 4 Oct 2024 11:12:37 +0100 Subject: [PATCH 30/59] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 005e43b5e..bf2ec84a1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ tqdm>=4.32.1 iblatlas>=0.5.3 ibl-neuropixel>=1.0.1 iblutil>=1.11.0 -iblqt>=0.1.0 +iblqt>=0.1.2 mtscomp>=1.0.1 ONE-api~=2.9.rc0 phylib>=2.6.0 From c6d480b66986177f94561bb4f74a20c8374f3967 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 4 Oct 2024 18:43:18 +0100 Subject: [PATCH 31/59] add passing status of individual tests --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 12 ++++++++---- ibllib/qc/task_qc_viewer/task_qc.py | 7 ++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index d4178260b..8a7679aac 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -192,8 +192,8 @@ def updateDataframe(self, df: pd.DataFrame): self._pinnedColumns = [] # try to identify and sort columns containing timestamps - col_names = df.columns - df_interp = df.replace([-np.inf, np.inf], np.nan) + col_names = df.select_dtypes('number').columns + df_interp = df[col_names].replace([-np.inf, np.inf], np.nan) df_interp = df_interp.interpolate(limit_direction='both') cols_mono = col_names[[df_interp[c].is_monotonic_increasing for c in col_names]] cols_mono = [c for c in cols_mono if df[c].nunique() > 1] @@ -202,8 +202,12 @@ def updateDataframe(self, df: pd.DataFrame): df.insert(idx, col_name, df.pop(col_name)) # columns containing boolean values are sorted to the end - cols_bool = list(df.select_dtypes('bool').columns) - cols_pass = [cols_bool.pop(i) for i, c in enumerate(cols_bool) if 'pass' in c] + # of those, columns containing 'pass' in their title will be sorted by number of False values + col_names = df.columns + cols_bool = list(df.select_dtypes(['bool', 'boolean']).columns) + cols_pass = [c for c in cols_bool if 'pass' in c] + cols_bool = [c for c in cols_bool if c not in cols_pass] # I know. Friday evening, brain is fried ... sorry. + cols_pass = list((~df[cols_pass]).sum().sort_values().keys()) cols_bool += cols_pass for col_name in cols_bool: df = df.join(df.pop(col_name)) diff --git a/ibllib/qc/task_qc_viewer/task_qc.py b/ibllib/qc/task_qc_viewer/task_qc.py index b8fc1749f..8d0c034c9 100644 --- a/ibllib/qc/task_qc_viewer/task_qc.py +++ b/ibllib/qc/task_qc_viewer/task_qc.py @@ -289,7 +289,12 @@ def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=N # Update table and callbacks n_trials = qc.frame.shape[0] df_trials = pd.DataFrame({k: v for k, v in task_qc.extractor.data.items() if v.size == n_trials}) - w.updateDataframe(df_trials.merge(qc.frame, left_index=True, right_index=True)) + df_pass = pd.DataFrame({k: v for k, v in qc.qc.passed.items() if isinstance(v, np.ndarray) and v.size == n_trials}) + df_pass.drop('_task_passed_trial_checks', axis=1, errors='ignore', inplace=True) + df_pass.rename(columns=lambda x: x.replace('_task', 'passed'), inplace=True) + df = df_trials.merge(qc.frame, left_index=True, right_index=True) + df = df.merge(df_pass.astype('boolean'), left_index=True, right_index=True) + w.updateDataframe(df) qt.run_app() return qc From 4c58d00345ded61c2ca94ae2aa65f6aefac7c9ad Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Wed, 9 Oct 2024 14:12:47 +0100 Subject: [PATCH 32/59] require iblqt >= 0.2.0 --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 8a7679aac..4e6684313 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -171,7 +171,7 @@ def pinColumn(self, pin: bool, idx: int | None = None): def changeFilter(self, string: str): headers = [ - self.tableModel.headerData(x, Qt.Horizontal, Qt.DisplayRole).value().lower() + self.tableModel.headerData(x, Qt.Horizontal, Qt.DisplayRole).lower() for x in range(self.tableModel.columnCount()) ] tokens = [y.lower() for y in (x.strip() for x in string.split(',')) if len(y)] diff --git a/requirements.txt b/requirements.txt index bf2ec84a1..b0f9fd71a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ tqdm>=4.32.1 iblatlas>=0.5.3 ibl-neuropixel>=1.0.1 iblutil>=1.11.0 -iblqt>=0.1.2 +iblqt>=0.2.0 mtscomp>=1.0.1 ONE-api~=2.9.rc0 phylib>=2.6.0 From af96df2131bb870735fe5408c82588385bcae7dc Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Tue, 15 Oct 2024 14:53:02 +0300 Subject: [PATCH 33/59] Resolves #853 --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 2 +- ibllib/qc/task_qc_viewer/task_qc.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 4e6684313..cae7431c2 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -229,7 +229,7 @@ def tv_double_clicked(self, index: QModelIndex): dt = t1 - t0 if self.wheel: idx = np.searchsorted(self.wheel['re_ts'], np.array([t0 - dt / 10, t1 + dt / 10])) - period = self.wheel['re_pos'][idx[0] : idx[1]] + period = self.wheel['re_pos'][idx[0]:idx[1]] if period.size == 0: _logger.warning('No wheel data during trial #%i', index.row()) else: diff --git a/ibllib/qc/task_qc_viewer/task_qc.py b/ibllib/qc/task_qc_viewer/task_qc.py index 8d0c034c9..a49c703eb 100644 --- a/ibllib/qc/task_qc_viewer/task_qc.py +++ b/ibllib/qc/task_qc_viewer/task_qc.py @@ -288,7 +288,10 @@ def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=N # Update table and callbacks n_trials = qc.frame.shape[0] - df_trials = pd.DataFrame({k: v for k, v in task_qc.extractor.data.items() if v.size == n_trials}) + df_trials = pd.DataFrame({ + k: v for k, v in task_qc.extractor.data.items() + if v.size == n_trials and not k.startswith('wheel') + }) df_pass = pd.DataFrame({k: v for k, v in qc.qc.passed.items() if isinstance(v, np.ndarray) and v.size == n_trials}) df_pass.drop('_task_passed_trial_checks', axis=1, errors='ignore', inplace=True) df_pass.rename(columns=lambda x: x.replace('_task', 'passed'), inplace=True) From bbf915ddd937c506c6ac0e597fcebe5ff72891b9 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Wed, 16 Oct 2024 12:46:19 +0100 Subject: [PATCH 34/59] Revert "Resolves #853" This reverts commit af96df2131bb870735fe5408c82588385bcae7dc. --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 2 +- ibllib/qc/task_qc_viewer/task_qc.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index cae7431c2..4e6684313 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -229,7 +229,7 @@ def tv_double_clicked(self, index: QModelIndex): dt = t1 - t0 if self.wheel: idx = np.searchsorted(self.wheel['re_ts'], np.array([t0 - dt / 10, t1 + dt / 10])) - period = self.wheel['re_pos'][idx[0]:idx[1]] + period = self.wheel['re_pos'][idx[0] : idx[1]] if period.size == 0: _logger.warning('No wheel data during trial #%i', index.row()) else: diff --git a/ibllib/qc/task_qc_viewer/task_qc.py b/ibllib/qc/task_qc_viewer/task_qc.py index a49c703eb..8d0c034c9 100644 --- a/ibllib/qc/task_qc_viewer/task_qc.py +++ b/ibllib/qc/task_qc_viewer/task_qc.py @@ -288,10 +288,7 @@ def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=N # Update table and callbacks n_trials = qc.frame.shape[0] - df_trials = pd.DataFrame({ - k: v for k, v in task_qc.extractor.data.items() - if v.size == n_trials and not k.startswith('wheel') - }) + df_trials = pd.DataFrame({k: v for k, v in task_qc.extractor.data.items() if v.size == n_trials}) df_pass = pd.DataFrame({k: v for k, v in qc.qc.passed.items() if isinstance(v, np.ndarray) and v.size == n_trials}) df_pass.drop('_task_passed_trial_checks', axis=1, errors='ignore', inplace=True) df_pass.rename(columns=lambda x: x.replace('_task', 'passed'), inplace=True) From 4cb2ee3a33b5a493521c7c5b58ffc67a73de5a1b Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Wed, 16 Oct 2024 12:56:05 +0100 Subject: [PATCH 35/59] Remove 'peakVelocity_times` from QC trials table --- ibllib/qc/task_qc_viewer/task_qc.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ibllib/qc/task_qc_viewer/task_qc.py b/ibllib/qc/task_qc_viewer/task_qc.py index 8d0c034c9..4e3be3ca9 100644 --- a/ibllib/qc/task_qc_viewer/task_qc.py +++ b/ibllib/qc/task_qc_viewer/task_qc.py @@ -288,7 +288,10 @@ def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=N # Update table and callbacks n_trials = qc.frame.shape[0] - df_trials = pd.DataFrame({k: v for k, v in task_qc.extractor.data.items() if v.size == n_trials}) + df_trials = pd.DataFrame({ + k: v for k, v in task_qc.extractor.data.items() + if v.size == n_trials and not k.startswith('peakVelocity') + }) df_pass = pd.DataFrame({k: v for k, v in qc.qc.passed.items() if isinstance(v, np.ndarray) and v.size == n_trials}) df_pass.drop('_task_passed_trial_checks', axis=1, errors='ignore', inplace=True) df_pass.rename(columns=lambda x: x.replace('_task', 'passed'), inplace=True) From f32005d700c2fd1655097c177794a649134d642a Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Wed, 16 Oct 2024 13:04:30 +0100 Subject: [PATCH 36/59] Revert "Remove 'peakVelocity_times` from QC trials table" This reverts commit 4cb2ee3a33b5a493521c7c5b58ffc67a73de5a1b. --- ibllib/qc/task_qc_viewer/task_qc.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/task_qc.py b/ibllib/qc/task_qc_viewer/task_qc.py index 4e3be3ca9..8d0c034c9 100644 --- a/ibllib/qc/task_qc_viewer/task_qc.py +++ b/ibllib/qc/task_qc_viewer/task_qc.py @@ -288,10 +288,7 @@ def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=N # Update table and callbacks n_trials = qc.frame.shape[0] - df_trials = pd.DataFrame({ - k: v for k, v in task_qc.extractor.data.items() - if v.size == n_trials and not k.startswith('peakVelocity') - }) + df_trials = pd.DataFrame({k: v for k, v in task_qc.extractor.data.items() if v.size == n_trials}) df_pass = pd.DataFrame({k: v for k, v in qc.qc.passed.items() if isinstance(v, np.ndarray) and v.size == n_trials}) df_pass.drop('_task_passed_trial_checks', axis=1, errors='ignore', inplace=True) df_pass.rename(columns=lambda x: x.replace('_task', 'passed'), inplace=True) From 5e54694279a380ecf95021c4d964df4159b0c988 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Wed, 16 Oct 2024 13:04:34 +0100 Subject: [PATCH 37/59] Reapply "Resolves #853" This reverts commit bbf915ddd937c506c6ac0e597fcebe5ff72891b9. --- ibllib/qc/task_qc_viewer/ViewEphysQC.py | 2 +- ibllib/qc/task_qc_viewer/task_qc.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 4e6684313..cae7431c2 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -229,7 +229,7 @@ def tv_double_clicked(self, index: QModelIndex): dt = t1 - t0 if self.wheel: idx = np.searchsorted(self.wheel['re_ts'], np.array([t0 - dt / 10, t1 + dt / 10])) - period = self.wheel['re_pos'][idx[0] : idx[1]] + period = self.wheel['re_pos'][idx[0]:idx[1]] if period.size == 0: _logger.warning('No wheel data during trial #%i', index.row()) else: diff --git a/ibllib/qc/task_qc_viewer/task_qc.py b/ibllib/qc/task_qc_viewer/task_qc.py index 8d0c034c9..a49c703eb 100644 --- a/ibllib/qc/task_qc_viewer/task_qc.py +++ b/ibllib/qc/task_qc_viewer/task_qc.py @@ -288,7 +288,10 @@ def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=N # Update table and callbacks n_trials = qc.frame.shape[0] - df_trials = pd.DataFrame({k: v for k, v in task_qc.extractor.data.items() if v.size == n_trials}) + df_trials = pd.DataFrame({ + k: v for k, v in task_qc.extractor.data.items() + if v.size == n_trials and not k.startswith('wheel') + }) df_pass = pd.DataFrame({k: v for k, v in qc.qc.passed.items() if isinstance(v, np.ndarray) and v.size == n_trials}) df_pass.drop('_task_passed_trial_checks', axis=1, errors='ignore', inplace=True) df_pass.rename(columns=lambda x: x.replace('_task', 'passed'), inplace=True) From cf76f6054dca7b775c4d51cffa9f96529e8d31cf Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 3 Dec 2024 13:21:15 +0000 Subject: [PATCH 38/59] fix ready4recording with new available json info --- brainbox/behavior/training.py | 4 ++-- ibllib/oneibl/registration.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/brainbox/behavior/training.py b/brainbox/behavior/training.py index 6959c3bae..92d95d1cb 100644 --- a/brainbox/behavior/training.py +++ b/brainbox/behavior/training.py @@ -265,13 +265,13 @@ def get_sessions(subj, date=None, one=None): if not np.any(np.array(task_protocol) == 'training'): ephys_sess = one.alyx.rest('sessions', 'list', subject=subj, date_range=[sess_dates[-1], sess_dates[0]], - django='json__PYBPOD_BOARD__icontains,ephys') + django='location__name__icontains,ephys') if len(ephys_sess) > 0: ephys_sess_dates = [sess['start_time'][:10] for sess in ephys_sess] n_delay = len(one.alyx.rest('sessions', 'list', subject=subj, date_range=[sess_dates[-1], sess_dates[0]], - django='json__SESSION_START_DELAY_SEC__gte,900')) + django='json__SESSION_DELAY_START__gte,900')) else: ephys_sess_dates = [] n_delay = 0 diff --git a/ibllib/oneibl/registration.py b/ibllib/oneibl/registration.py index 48767628e..85f4c174f 100644 --- a/ibllib/oneibl/registration.py +++ b/ibllib/oneibl/registration.py @@ -286,6 +286,11 @@ def register_session(self, ses_path, file_list=True, projects=None, procedures=N poo_counts = [md.get('POOP_COUNT') for md in settings if md.get('POOP_COUNT') is not None] if poo_counts: json_field['POOP_COUNT'] = int(sum(poo_counts)) + # Get the session start delay if available, needed for the training status + session_delay = [md.get('SESSION_DELAY_START') for md in settings + if md.get('SESSION_DELAY_START') is not None] + if session_delay: + json_field['SESSION_DELAY_START'] = int(sum(session_delay)) if not len(session): # Create session and weighings ses_ = {'subject': subject['nickname'], From e05da9ca07ab26c1024bca535eb8f8ff746315e9 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 3 Dec 2024 14:14:00 +0000 Subject: [PATCH 39/59] fix keys in training status code too --- ibllib/pipes/training_status.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibllib/pipes/training_status.py b/ibllib/pipes/training_status.py index fec33baaf..aecb33a13 100644 --- a/ibllib/pipes/training_status.py +++ b/ibllib/pipes/training_status.py @@ -433,12 +433,12 @@ def compute_session_duration_delay_location(sess_path, collections=None, **kwarg try: start_time, end_time = _get_session_times(sess_path, md, sess_data) session_duration = session_duration + int((end_time - start_time).total_seconds() / 60) - session_delay = session_delay + md.get('SESSION_START_DELAY_SEC', 0) + session_delay = session_delay + md.get('SESSION_DELAY_START', 0) except Exception: session_duration = session_duration + 0 session_delay = session_delay + 0 - if 'ephys' in md.get('PYBPOD_BOARD', None): + if 'ephys' in md.get('RIG_NAME', None): session_location = 'ephys_rig' else: session_location = 'training_rig' From f7c78e698b60ba8defd3e348cbf17db0b8d4f6cd Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 3 Dec 2024 15:08:18 +0000 Subject: [PATCH 40/59] add ibl-style to requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index bf0f3128e..92066d9a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,3 +31,4 @@ phylib>=2.6.0 psychofit slidingRP>=1.1.1 # steinmetz lab refractory period metrics pyqt5 +ibl-style From ecffa8d186ebd46eec836ffb5d2a2d1b78eeceeb Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 3 Dec 2024 15:10:18 +0000 Subject: [PATCH 41/59] relax mode to warn --- ibllib/pipes/training_status.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibllib/pipes/training_status.py b/ibllib/pipes/training_status.py index aecb33a13..b0bb2dbb3 100644 --- a/ibllib/pipes/training_status.py +++ b/ibllib/pipes/training_status.py @@ -116,7 +116,7 @@ def load_existing_dataframe(subj_path): return None -def load_trials(sess_path, one, collections=None, force=True, mode='raise'): +def load_trials(sess_path, one, collections=None, force=True, mode='warn'): """ Load trials data for session. First attempts to load from local session path, if this fails will attempt to download via ONE, if this also fails, will then attempt to re-extract locally From 62848916f8b5ac57ffdbbcc0ee8e5616b858287d Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Fri, 6 Dec 2024 16:37:26 +0000 Subject: [PATCH 42/59] return info about fit parameters --- brainbox/behavior/training.py | 220 ++++++++++++++++++++++++-------- brainbox/tests/test_behavior.py | 33 +++-- ibllib/pipes/training_status.py | 101 +++++++++++---- 3 files changed, 263 insertions(+), 91 deletions(-) diff --git a/brainbox/behavior/training.py b/brainbox/behavior/training.py index 92d95d1cb..69614e51a 100644 --- a/brainbox/behavior/training.py +++ b/brainbox/behavior/training.py @@ -157,7 +157,7 @@ def get_subject_training_status(subj, date=None, details=True, one=None): if not trials: return sess_dates = list(trials.keys()) - status, info = get_training_status(trials, task_protocol, ephys_sess, n_delay) + status, info, _ = get_training_status(trials, task_protocol, ephys_sess, n_delay) if details: if np.any(info.get('psych')): @@ -313,23 +313,32 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay): info = Bunch() trials_all = concatenate_trials(trials) + info.session_dates = list(trials.keys()) + info.protocols = [p for p in task_protocol] # Case when all sessions are trainingChoiceWorld if np.all(np.array(task_protocol) == 'training'): signed_contrast = get_signed_contrast(trials_all) (info.perf_easy, info.n_trials, info.psych, info.rt) = compute_training_info(trials, trials_all) - if not np.any(signed_contrast == 0): - status = 'in training' + + pass_criteria, criteria = criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt, + signed_contrast) + if pass_criteria: + failed_criteria = Bunch() + failed_criteria['NBiased'] = {'val': info.protocols, 'pass': False} + failed_criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': False} + status = 'trained 1b' else: - if criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt): - status = 'trained 1b' - elif criterion_1a(info.psych, info.n_trials, info.perf_easy): + failed_criteria = criteria + pass_criteria, criteria = criterion_1a(info.psych, info.n_trials, info.perf_easy, signed_contrast) + if pass_criteria: status = 'trained 1a' else: + failed_criteria = criteria status = 'in training' - return status, info + return status, info, failed_criteria # Case when there are < 3 biasedChoiceWorld sessions after reaching trained_1b criterion if ~np.all(np.array(task_protocol) == 'training') and \ @@ -338,7 +347,11 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay): (info.perf_easy, info.n_trials, info.psych, info.rt) = compute_training_info(trials, trials_all) - return status, info + criteria = Bunch() + criteria['NBiased'] = {'val': info.protocols, 'pass': False} + criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': False} + + return status, info, criteria # Case when there is biasedChoiceWorld or ephysChoiceWorld in last three sessions if not np.any(np.array(task_protocol) == 'training'): @@ -346,37 +359,40 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay): (info.perf_easy, info.n_trials, info.psych_20, info.psych_80, info.rt) = compute_bias_info(trials, trials_all) - # We are still on training rig and so all sessions should be biased - if len(ephys_sess_dates) == 0: - assert np.all(np.array(task_protocol) == 'biased') - if criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy, - info.rt): - status = 'ready4ephysrig' - else: - status = 'trained 1b' - elif len(ephys_sess_dates) < 3: + n_ephys = len(ephys_sess_dates) + info.n_ephys = n_ephys + info.n_delay = n_delay + + # Criterion recording + pass_criteria, criteria = criteria_recording(n_ephys, n_delay, info.psych_20, info.psych_80, info.n_trials, + info.perf_easy, info.rt) + if pass_criteria: + # Here the criteria doesn't actually fail but we have no other criteria to meet so we return this + failed_criteria = criteria + status = 'ready4recording' + else: + failed_criteria = criteria assert all(date in trials for date in ephys_sess_dates) perf_ephys_easy = np.array([compute_performance_easy(trials[k]) for k in ephys_sess_dates]) n_ephys_trials = np.array([compute_n_trials(trials[k]) for k in ephys_sess_dates]) - if criterion_delay(n_ephys_trials, perf_ephys_easy): - status = 'ready4delay' - else: - status = 'ready4ephysrig' - - elif len(ephys_sess_dates) >= 3: - if n_delay > 0 and \ - criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy, - info.rt): - status = 'ready4recording' - elif criterion_delay(info.n_trials, info.perf_easy): + pass_criteria, criteria = criterion_delay(n_ephys, n_ephys_trials, perf_ephys_easy) + + if pass_criteria: status = 'ready4delay' else: - status = 'ready4ephysrig' + failed_criteria = criteria + pass_criteria, criteria = criterion_ephys(info.psych_20, info.psych_80, info.n_trials, + info.perf_easy, info.rt) + if pass_criteria: + status = 'ready4ephysrig' + else: + failed_criteria = criteria + status = 'trained 1b' - return status, info + return status, info, failed_criteria def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None, @@ -814,7 +830,7 @@ def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='re return reaction_time, contrasts, n_contrasts, -def criterion_1a(psych, n_trials, perf_easy): +def criterion_1a(psych, n_trials, perf_easy, signed_contrast): """ Returns bool indicating whether criteria for status 'trained_1a' are met. @@ -825,6 +841,7 @@ def criterion_1a(psych, n_trials, perf_easy): - Lapse rate on both sides is less than 0.2 - The total number of trials is greater than 200 for each session - Performance on easy contrasts > 80% for all sessions + - Zero contrast trials must be present Parameters ---------- @@ -835,11 +852,15 @@ def criterion_1a(psych, n_trials, perf_easy): The number for trials for each session. perf_easy : numpy.array of float The proportion of correct high contrast trials for each session. + signed_contrast: numpy.array + Unique list of contrasts displayed Returns ------- bool True if the criteria are met for 'trained_1a'. + Bunch + Bunch containing breakdown of the passing/ failing critieria Notes ----- @@ -847,12 +868,23 @@ def criterion_1a(psych, n_trials, perf_easy): for a number of sessions determined to be of 'good' performance by an experimenter. """ - criterion = (abs(psych[0]) < 16 and psych[1] < 19 and psych[2] < 0.2 and psych[3] < 0.2 and - np.all(n_trials > 200) and np.all(perf_easy > 0.8)) - return criterion + criteria = Bunch() + criteria['ZeroContrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)} + criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.2 } + criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.2} + criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 16} + criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 19} + criteria['NTrials'] = {'val': n_trials, 'pass': np.all(n_trials > 200)} + criteria['PerfEasy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.8)} + + passing = np.all([v['pass'] for k, v in criteria.items()]) + criteria['Criteria'] = {'val': 'trained_1a', 'pass': passing} -def criterion_1b(psych, n_trials, perf_easy, rt): + return passing, criteria + + +def criterion_1b(psych, n_trials, perf_easy, rt, signed_contrast): """ Returns bool indicating whether criteria for trained_1b are met. @@ -864,6 +896,7 @@ def criterion_1b(psych, n_trials, perf_easy, rt): - The total number of trials is greater than 400 for each session - Performance on easy contrasts > 90% for all sessions - The median response time across all zero contrast trials is less than 2 seconds + - Zero contrast trials must be present Parameters ---------- @@ -876,11 +909,15 @@ def criterion_1b(psych, n_trials, perf_easy, rt): The proportion of correct high contrast trials for each session. rt : float The median response time for zero contrast trials. + signed_contrast: numpy.array + Unique list of contrasts displayed Returns ------- bool True if the criteria are met for 'trained_1b'. + Bunch + Bunch containing breakdown of the passing/ failing critieria Notes ----- @@ -890,17 +927,27 @@ def criterion_1b(psych, n_trials, perf_easy, rt): regrettably means that the maximum threshold fit for 1b is greater than for 1a, meaning the slope of the psychometric curve may be slightly less steep than 1a. """ - criterion = (abs(psych[0]) < 10 and psych[1] < 20 and psych[2] < 0.1 and psych[3] < 0.1 and - np.all(n_trials > 400) and np.all(perf_easy > 0.9) and rt < 2) - return criterion + + criteria = Bunch() + criteria['ZeroContrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)} + criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.1} + criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.1} + criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 10} + criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 20} + criteria['NTrials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)} + criteria['PerfEasy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)} + criteria['ReactionTime'] = {'val': rt, 'pass': rt < 2} + + passing = np.all([v['pass'] for k, v in criteria.items()]) + + criteria['Criteria'] = {'val': 'trained_1b', 'pass': passing} + + return passing, criteria def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt): """ - Returns bool indicating whether criteria for ready4ephysrig or ready4recording are met. - - NB: The difference between these two is whether the sessions were acquired ot a recording rig - with a delay before the first trial. Neither of these two things are tested here. + Returns bool indicating whether criteria for ready4ephysrig are met. Criteria -------- @@ -929,21 +976,34 @@ def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt): Returns ------- bool - True if subject passes the ready4ephysrig or ready4recording criteria. + True if subject passes the ready4ephysrig criteria. + Bunch + Bunch containing breakdown of the passing/ failing critieria """ + criteria = Bunch() + criteria['LapseLow_80'] = {'val': psych_80[2], 'pass': psych_80[2] < 0.1} + criteria['LapseHigh_80'] = {'val': psych_80[3], 'pass': psych_80[3] < 0.1} + criteria['LapseLow_20'] = {'val': psych_20[2], 'pass': psych_20[2] < 0.1} + criteria['LapseHigh_20'] = {'val': psych_20[3], 'pass': psych_20[3] < 0.1} + criteria['BiasShift'] = {'val': psych_80[0] - psych_20[0], 'pass': psych_80[0] - psych_20[0] > 5} + criteria['NTrials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)} + criteria['PerfEasy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)} + criteria['ReactionTime'] = {'val': rt, 'pass': rt < 2} - criterion = (np.all(np.r_[psych_20[2:4], psych_80[2:4]] < 0.1) and # lapse - psych_80[0] - psych_20[0] > 5 and np.all(n_trials > 400) and # bias shift and n trials - np.all(perf_easy > 0.9) and rt < 2) # overall performance and response times - return criterion + passing = np.all([v['pass'] for k, v in criteria.items()]) + criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': passing} -def criterion_delay(n_trials, perf_easy): + return passing, criteria + + +def criterion_delay(n_ephys, n_trials, perf_easy): """ Returns bool indicating whether criteria for 'ready4delay' is met. Criteria -------- + - At least one session on an ephys rig - Total number of trials for any of the sessions is greater than 400 - Performance on easy contrasts is greater than 90% for any of the sessions @@ -959,9 +1019,69 @@ def criterion_delay(n_trials, perf_easy): ------- bool True if subject passes the 'ready4delay' criteria. + Bunch + Bunch containing breakdown of the passing/ failing critieria + """ + + criteria = Bunch() + criteria['NEphys'] = {'val': n_ephys, 'pass': n_ephys > 0} + criteria['NTrials'] = {'val': n_trials, 'pass': np.any(n_trials > 400)} + criteria['PerfEasy'] = {'val': perf_easy, 'pass': np.any(perf_easy > 0.9)} + + passing = np.all([v['pass'] for k, v in criteria.items()]) + + criteria['Criteria'] = {'val': 'ready4delay', 'pass': passing} + + return passing, criteria + + +def criteria_recording(n_ephys, delay, psych_20, psych_80, n_trials, perf_easy, rt): """ - criterion = np.any(n_trials > 400) and np.any(perf_easy > 0.9) - return criterion + Returns bool indicating whether criteria for ready4recording are met. + + Criteria + -------- + - At least 3 ephys sessions + - Delay on any session > 0 + - Lapse on both sides < 0.1 for both bias blocks + - Bias shift between blocks > 5 + - Total number of trials > 400 for all sessions + - Performance on easy contrasts > 90% for all sessions + - Median response time for zero contrast stimuli < 2 seconds + + Parameters + ---------- + psych_20 : numpy.array + The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2. + Parameters are bias, threshold, lapse high, lapse low. + psych_80 : numpy.array + The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8. + Parameters are bias, threshold, lapse high, lapse low. + n_trials : numpy.array + The number of trials for each session (typically three consecutive sessions). + perf_easy : numpy.array + The proportion of correct high contrast trials for each session (typically three + consecutive sessions). + rt : float + The median response time for zero contrast trials. + + Returns + ------- + bool + True if subject passes the ready4recording criteria. + Bunch + Bunch containing breakdown of the passing/ failing critieria + """ + + _, criteria = criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt) + criteria['NEphys'] = {'val': n_ephys, 'pass': n_ephys >= 3} + criteria['NDelay'] = {'val': delay, 'pass': delay > 0} + + passing = np.all([v['pass'] for k, v in criteria.items()]) + + criteria['Criteria'] = {'val': 'ready4recording', 'pass': passing} + + return passing, criteria def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.032, **kwargs): diff --git a/brainbox/tests/test_behavior.py b/brainbox/tests/test_behavior.py index 8d02d185a..493234937 100644 --- a/brainbox/tests/test_behavior.py +++ b/brainbox/tests/test_behavior.py @@ -177,58 +177,65 @@ def test_in_training(self): trials, task_protocol = self._get_trials( sess_dates=['2020-08-25', '2020-08-24', '2020-08-21']) assert (np.all(np.array(task_protocol) == 'training')) - status, info = train.get_training_status( + status, info, crit = train.get_training_status( trials, task_protocol, ephys_sess_dates=[], n_delay=0) assert (status == 'in training') + assert (crit['Criteria']['val'] == 'trained_1a') def test_trained_1a(self): trials, task_protocol = self._get_trials( sess_dates=['2020-08-26', '2020-08-25', '2020-08-24']) assert (np.all(np.array(task_protocol) == 'training')) - status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], - n_delay=0) + status, info, crit = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], + n_delay=0) assert (status == 'trained 1a') + assert (crit['Criteria']['val'] == 'trained_1b') def test_trained_1b(self): trials, task_protocol = self._get_trials( sess_dates=['2020-08-27', '2020-08-26', '2020-08-25']) assert (np.all(np.array(task_protocol) == 'training')) - status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], - n_delay=0) + status, info, crit = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], + n_delay=0) self.assertEqual(status, 'trained 1b') + assert (crit['Criteria']['val'] == 'ready4ephysrig') def test_training_to_bias(self): trials, task_protocol = self._get_trials( sess_dates=['2020-08-31', '2020-08-28', '2020-08-27']) assert (~np.all(np.array(task_protocol) == 'training') and np.any(np.array(task_protocol) == 'training')) - status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], - n_delay=0) + status, info, crit = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], + n_delay=0) assert (status == 'trained 1b') + assert (crit['Criteria']['val'] == 'ready4ephysrig') def test_ready4ephys(self): sess_dates = ['2020-09-01', '2020-08-31', '2020-08-28'] trials, task_protocol = self._get_trials(sess_dates=sess_dates) assert (np.all(np.array(task_protocol) == 'biased')) - status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], - n_delay=0) + status, info, crit = train.get_training_status(trials, task_protocol, ephys_sess_dates=[], + n_delay=0) assert (status == 'ready4ephysrig') + assert (crit['Criteria']['val'] == 'ready4delay') def test_ready4delay(self): sess_dates = ['2020-09-03', '2020-09-02', '2020-08-31'] trials, task_protocol = self._get_trials(sess_dates=sess_dates) assert (np.all(np.array(task_protocol) == 'biased')) - status, info = train.get_training_status(trials, task_protocol, - ephys_sess_dates=['2020-09-03'], n_delay=0) + status, info, crit = train.get_training_status(trials, task_protocol, + ephys_sess_dates=['2020-09-03'], n_delay=0) assert (status == 'ready4delay') + assert (crit['Criteria']['val'] == 'ready4recording') def test_ready4recording(self): sess_dates = ['2020-09-01', '2020-08-31', '2020-08-28'] trials, task_protocol = self._get_trials(sess_dates=sess_dates) assert (np.all(np.array(task_protocol) == 'biased')) - status, info = train.get_training_status(trials, task_protocol, - ephys_sess_dates=sess_dates, n_delay=1) + status, info, crit = train.get_training_status(trials, task_protocol, + ephys_sess_dates=sess_dates, n_delay=1) assert (status == 'ready4recording') + assert (crit['Criteria']['val'] == 'ready4recording') def test_query_criterion(self): """Test for brainbox.behavior.training.query_criterion function.""" diff --git a/ibllib/pipes/training_status.py b/ibllib/pipes/training_status.py index b0bb2dbb3..afa730add 100644 --- a/ibllib/pipes/training_status.py +++ b/ibllib/pipes/training_status.py @@ -270,7 +270,7 @@ def get_latest_training_information(sess_path, one, save=True): # Find the earliest date in missing dates that we need to recompute the training status for missing_status = find_earliest_recompute_date(df.drop_duplicates('date').reset_index(drop=True)) for date in missing_status: - df = compute_training_status(df, date, one) + df, _, _, _ = compute_training_status(df, date, one) df_lim = df.drop_duplicates(subset='session_path', keep='first') @@ -314,7 +314,7 @@ def find_earliest_recompute_date(df): return df[first_index:].date.values -def compute_training_status(df, compute_date, one, force=True): +def compute_training_status(df, compute_date, one, force=True, populate=True): """ Compute the training status for compute date based on training from that session and two previous days. @@ -331,11 +331,19 @@ def compute_training_status(df, compute_date, one, force=True): An instance of ONE for loading trials data. force : bool When true and if the session trials can't be found, will attempt to re-extract from disk. + populate : bool + Whether to update the training data frame with the new training status value Returns ------- pandas.DataFrame - The input data frame with a 'training_status' column populated for `compute_date`. + The input data frame with a 'training_status' column populated for `compute_date` if populate=True + Bunch + Bunch containing information fit parameters information for the combined sessions + Bunch + Bunch cotaining the training status criteria information + str + The training status """ # compute_date = str(alfiles.session_path_parts(session_path, as_dict=True)['date']) @@ -378,11 +386,13 @@ def compute_training_status(df, compute_date, one, force=True): ephys_sessions.append(df_date.iloc[-1]['date']) n_status = np.max([-2, -1 * len(status)]) - training_status, _ = training.get_training_status(trials, protocol, ephys_sessions, n_delay) + #training_status, info, criteria = training.get_training_status(trials, protocol, ephys_sessions, n_delay) + training_status, info, criteria = get_training_status(trials, protocol, ephys_sessions, n_delay) training_status = pass_through_training_hierachy(training_status, status[n_status]) - df.loc[df['date'] == compute_date, 'training_status'] = training_status + if populate: + df.loc[df['date'] == compute_date, 'training_status'] = training_status - return df + return df, info, criteria, training_status def pass_through_training_hierachy(status_new, status_old): @@ -586,9 +596,12 @@ def get_training_info_for_session(session_paths, one, force=True): session_path = Path(session_path) protocols = [] for c in collections: - prot = get_bpod_extractor_class(session_path, task_collection=c) - prot = prot[:-6].lower() - protocols.append(prot) + try: + prot = get_bpod_extractor_class(session_path, task_collection=c) + prot = prot[:-6].lower() + protocols.append(prot) + except ValueError: + continue un_protocols = np.unique(protocols) # Example, training, training, biased - training would be combined, biased not @@ -751,10 +764,42 @@ def plot_performance_easy_median_reaction_time(df, subject): return ax +def display_info(df, axs): + compute_date = df['date'].values[-1] + _, info, criteria, _ = compute_training_status(df, compute_date, None, force=False, populate=False) + + def _array_to_string(vals): + if isinstance(vals, (str, bool, int)): + return f'{vals}' + + str_vals = '' + for v in vals: + if isinstance(v, float): + v = np.round(v, 3) + str_vals += f'{v}, ' + return str_vals[:-2] + + pos = np.arange(len(criteria))[::-1] * 0.1 + for i, (k, v) in enumerate(info.items()): + str_v = _array_to_string(v) + text = axs[0].text(0, pos[i], k.capitalize(), color='k', fontsize=7, transform=axs[0].transAxes) + axs[0].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom", + color='k', fontsize=7) + + pos = np.arange(len(criteria))[::-1] * 0.1 + for i, (k, v) in enumerate(criteria.items()): + c = 'g' if v['pass'] else 'r' + str_v = _array_to_string(v['val']) + text = axs[1].text(0, pos[i], k.capitalize(), color='k', fontsize=7, transform=axs[1].transAxes) + axs[1].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom", + color=c, fontsize=7) + def plot_fit_params(df, subject): - fig, axs = plt.subplots(2, 2, figsize=(12, 6)) + fig, axs = plt.subplots(2, 3, figsize=(12, 6)) axs = axs.ravel() + display_info(df, axs=[axs[0, 2], axs[1, 2]]) + df = df.drop_duplicates('date').reset_index(drop=True) cmap = sns.diverging_palette(20, 220, n=3, center="dark") @@ -777,11 +822,11 @@ def plot_fit_params(df, subject): 'color': cmap[0], 'join': False} - plot_over_days(df, subject, y50, ax=axs[0], legend=False, title=False) - plot_over_days(df, subject, y80, ax=axs[0], legend=False, title=False) - plot_over_days(df, subject, y20, ax=axs[0], legend=False, title=False) - axs[0].axhline(16, linewidth=2, linestyle='--', color='k') - axs[0].axhline(-16, linewidth=2, linestyle='--', color='k') + plot_over_days(df, subject, y50, ax=axs[0, 0], legend=False, title=False) + plot_over_days(df, subject, y80, ax=axs[0, 0], legend=False, title=False) + plot_over_days(df, subject, y20, ax=axs[0, 0], legend=False, title=False) + axs[0, 0].axhline(16, linewidth=2, linestyle='--', color='k') + axs[0, 0].axhline(-16, linewidth=2, linestyle='--', color='k') y50['column'] = 'combined_thres_50' y50['title'] = 'Threshold' @@ -793,10 +838,10 @@ def plot_fit_params(df, subject): y20['title'] = 'Threshold' y80['lim'] = [0, 100] - plot_over_days(df, subject, y50, ax=axs[1], legend=False, title=False) - plot_over_days(df, subject, y80, ax=axs[1], legend=False, title=False) - plot_over_days(df, subject, y20, ax=axs[1], legend=False, title=False) - axs[1].axhline(19, linewidth=2, linestyle='--', color='k') + plot_over_days(df, subject, y50, ax=axs[0, 1], legend=False, title=False) + plot_over_days(df, subject, y80, ax=axs[0, 1], legend=False, title=False) + plot_over_days(df, subject, y20, ax=axs[0, 1], legend=False, title=False) + axs[0, 1].axhline(19, linewidth=2, linestyle='--', color='k') y50['column'] = 'combined_lapselow_50' y50['title'] = 'Lapse Low' @@ -808,10 +853,10 @@ def plot_fit_params(df, subject): y20['title'] = 'Lapse Low' y20['lim'] = [0, 1] - plot_over_days(df, subject, y50, ax=axs[2], legend=False, title=False) - plot_over_days(df, subject, y80, ax=axs[2], legend=False, title=False) - plot_over_days(df, subject, y20, ax=axs[2], legend=False, title=False) - axs[2].axhline(0.2, linewidth=2, linestyle='--', color='k') + plot_over_days(df, subject, y50, ax=axs[1, 0], legend=False, title=False) + plot_over_days(df, subject, y80, ax=axs[1, 0], legend=False, title=False) + plot_over_days(df, subject, y20, ax=axs[1, 0], legend=False, title=False) + axs[1, 0].axhline(0.2, linewidth=2, linestyle='--', color='k') y50['column'] = 'combined_lapsehigh_50' y50['title'] = 'Lapse High' @@ -823,13 +868,13 @@ def plot_fit_params(df, subject): y20['title'] = 'Lapse High' y20['lim'] = [0, 1] - plot_over_days(df, subject, y50, ax=axs[3], legend=False, title=False, training_lines=True) - plot_over_days(df, subject, y80, ax=axs[3], legend=False, title=False, training_lines=False) - plot_over_days(df, subject, y20, ax=axs[3], legend=False, title=False, training_lines=False) - axs[3].axhline(0.2, linewidth=2, linestyle='--', color='k') + plot_over_days(df, subject, y50, ax=axs[1, 1], legend=False, title=False, training_lines=True) + plot_over_days(df, subject, y80, ax=axs[1, 1], legend=False, title=False, training_lines=False) + plot_over_days(df, subject, y20, ax=axs[1, 1], legend=False, title=False, training_lines=False) + axs[1, 1].axhline(0.2, linewidth=2, linestyle='--', color='k') fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}') - lines, labels = axs[3].get_legend_handles_labels() + lines, labels = axs[1, 1].get_legend_handles_labels() fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=5) legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8), From 5ab4065758bd79bc4fd689d92ce53c617237ac0c Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Fri, 6 Dec 2024 16:58:32 +0000 Subject: [PATCH 43/59] order the information, remove axis --- brainbox/behavior/training.py | 32 ++++++++++++++++---------------- ibllib/pipes/training_status.py | 30 ++++++++++++++++++++++-------- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/brainbox/behavior/training.py b/brainbox/behavior/training.py index 69614e51a..0c3dc9000 100644 --- a/brainbox/behavior/training.py +++ b/brainbox/behavior/training.py @@ -869,13 +869,13 @@ def criterion_1a(psych, n_trials, perf_easy, signed_contrast): """ criteria = Bunch() - criteria['ZeroContrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)} + criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)} criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.2 } criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.2} criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 16} criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 19} - criteria['NTrials'] = {'val': n_trials, 'pass': np.all(n_trials > 200)} - criteria['PerfEasy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.8)} + criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 200)} + criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.8)} passing = np.all([v['pass'] for k, v in criteria.items()]) @@ -929,14 +929,14 @@ def criterion_1b(psych, n_trials, perf_easy, rt, signed_contrast): """ criteria = Bunch() - criteria['ZeroContrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)} + criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)} criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.1} criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.1} criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 10} criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 20} - criteria['NTrials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)} - criteria['PerfEasy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)} - criteria['ReactionTime'] = {'val': rt, 'pass': rt < 2} + criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)} + criteria['Perf_tasy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)} + criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2} passing = np.all([v['pass'] for k, v in criteria.items()]) @@ -985,10 +985,10 @@ def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt): criteria['LapseHigh_80'] = {'val': psych_80[3], 'pass': psych_80[3] < 0.1} criteria['LapseLow_20'] = {'val': psych_20[2], 'pass': psych_20[2] < 0.1} criteria['LapseHigh_20'] = {'val': psych_20[3], 'pass': psych_20[3] < 0.1} - criteria['BiasShift'] = {'val': psych_80[0] - psych_20[0], 'pass': psych_80[0] - psych_20[0] > 5} - criteria['NTrials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)} - criteria['PerfEasy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)} - criteria['ReactionTime'] = {'val': rt, 'pass': rt < 2} + criteria['Bias_shift'] = {'val': psych_80[0] - psych_20[0], 'pass': psych_80[0] - psych_20[0] > 5} + criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)} + criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)} + criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2} passing = np.all([v['pass'] for k, v in criteria.items()]) @@ -1024,9 +1024,9 @@ def criterion_delay(n_ephys, n_trials, perf_easy): """ criteria = Bunch() - criteria['NEphys'] = {'val': n_ephys, 'pass': n_ephys > 0} - criteria['NTrials'] = {'val': n_trials, 'pass': np.any(n_trials > 400)} - criteria['PerfEasy'] = {'val': perf_easy, 'pass': np.any(perf_easy > 0.9)} + criteria['N_ephys'] = {'val': n_ephys, 'pass': n_ephys > 0} + criteria['N_trials'] = {'val': n_trials, 'pass': np.any(n_trials > 400)} + criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.any(perf_easy > 0.9)} passing = np.all([v['pass'] for k, v in criteria.items()]) @@ -1074,8 +1074,8 @@ def criteria_recording(n_ephys, delay, psych_20, psych_80, n_trials, perf_easy, """ _, criteria = criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt) - criteria['NEphys'] = {'val': n_ephys, 'pass': n_ephys >= 3} - criteria['NDelay'] = {'val': delay, 'pass': delay > 0} + criteria['N_ephys'] = {'val': n_ephys, 'pass': n_ephys >= 3} + criteria['N_delay'] = {'val': delay, 'pass': delay > 0} passing = np.all([v['pass'] for k, v in criteria.items()]) diff --git a/ibllib/pipes/training_status.py b/ibllib/pipes/training_status.py index afa730add..7eac77e47 100644 --- a/ibllib/pipes/training_status.py +++ b/ibllib/pipes/training_status.py @@ -769,7 +769,9 @@ def display_info(df, axs): _, info, criteria, _ = compute_training_status(df, compute_date, None, force=False, populate=False) def _array_to_string(vals): - if isinstance(vals, (str, bool, int)): + if isinstance(vals, (str, bool, int, float)): + if isinstance(vals, float): + vals = np.round(vals, 3) return f'{vals}' str_vals = '' @@ -782,21 +784,31 @@ def _array_to_string(vals): pos = np.arange(len(criteria))[::-1] * 0.1 for i, (k, v) in enumerate(info.items()): str_v = _array_to_string(v) - text = axs[0].text(0, pos[i], k.capitalize(), color='k', fontsize=7, transform=axs[0].transAxes) + text = axs[0].text(0, pos[i], k.capitalize(), color='k', weight='bold', fontsize=8, transform=axs[0].transAxes) axs[0].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom", color='k', fontsize=7) pos = np.arange(len(criteria))[::-1] * 0.1 + crit_val = criteria.pop('Criteria') + c = 'g' if crit_val['pass'] else 'r' + str_v = _array_to_string(crit_val['val']) + text = axs[1].text(0, pos[0], 'Criteria', color='k', weight='bold', fontsize=8, transform=axs[1].transAxes) + axs[1].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom", + color=c, fontsize=7) + pos = pos[1:] + for i, (k, v) in enumerate(criteria.items()): c = 'g' if v['pass'] else 'r' str_v = _array_to_string(v['val']) - text = axs[1].text(0, pos[i], k.capitalize(), color='k', fontsize=7, transform=axs[1].transAxes) + text = axs[1].text(0, pos[i], k.capitalize(), color='k', weight='bold', fontsize=8, transform=axs[1].transAxes) axs[1].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom", color=c, fontsize=7) + axs[0].set_axis_off() + axs[1].set_axis_off() + def plot_fit_params(df, subject): - fig, axs = plt.subplots(2, 3, figsize=(12, 6)) - axs = axs.ravel() + fig, axs = plt.subplots(2, 3, figsize=(12, 6), gridspec_kw={'width_ratios': [2, 2, 1]}) display_info(df, axs=[axs[0, 2], axs[1, 2]]) @@ -875,12 +887,14 @@ def plot_fit_params(df, subject): fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}') lines, labels = axs[1, 1].get_legend_handles_labels() - fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=5) + fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), facecolor='w', fancybox=True, shadow=True, + ncol=5) legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8), Line2D([0], [0], marker='o', color='w', label='p=0.2', markerfacecolor=cmap[0], markersize=8), Line2D([0], [0], marker='o', color='w', label='p=0.8', markerfacecolor=cmap[2], markersize=8)] - legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, shadow=True) + legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, + shadow=True, facecolor='w') fig.add_artist(legend2) return axs @@ -952,7 +966,7 @@ def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, t box.width, box.height * 0.9]) if legend: ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1), - fancybox=True, shadow=True, ncol=5) + fancybox=True, shadow=True, ncol=5, fc='white') return ax1 From f8bed5afc44614333eec7809a5f832fe195cc3a3 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Fri, 6 Dec 2024 17:09:12 +0000 Subject: [PATCH 44/59] fix facecolor --- ibllib/pipes/training_status.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibllib/pipes/training_status.py b/ibllib/pipes/training_status.py index 7eac77e47..71fca95cd 100644 --- a/ibllib/pipes/training_status.py +++ b/ibllib/pipes/training_status.py @@ -966,7 +966,7 @@ def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, t box.width, box.height * 0.9]) if legend: ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1), - fancybox=True, shadow=True, ncol=5, fc='white') + fancybox=True, shadow=True, ncol=5, facecolor='white') return ax1 From 68ffd2d69200e5c90432fdc1c08c466cd490802c Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Fri, 6 Dec 2024 17:10:54 +0000 Subject: [PATCH 45/59] correct function --- ibllib/pipes/training_status.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ibllib/pipes/training_status.py b/ibllib/pipes/training_status.py index 71fca95cd..e490a35fb 100644 --- a/ibllib/pipes/training_status.py +++ b/ibllib/pipes/training_status.py @@ -386,8 +386,7 @@ def compute_training_status(df, compute_date, one, force=True, populate=True): ephys_sessions.append(df_date.iloc[-1]['date']) n_status = np.max([-2, -1 * len(status)]) - #training_status, info, criteria = training.get_training_status(trials, protocol, ephys_sessions, n_delay) - training_status, info, criteria = get_training_status(trials, protocol, ephys_sessions, n_delay) + training_status, info, criteria = training.get_training_status(trials, protocol, ephys_sessions, n_delay) training_status = pass_through_training_hierachy(training_status, status[n_status]) if populate: df.loc[df['date'] == compute_date, 'training_status'] = training_status From 266fa73ef463e9b0e151645cfb0a7bc9b7edc6af Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Fri, 6 Dec 2024 17:13:53 +0000 Subject: [PATCH 46/59] correctly index axis --- ibllib/pipes/training_status.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibllib/pipes/training_status.py b/ibllib/pipes/training_status.py index e490a35fb..6a6ece21f 100644 --- a/ibllib/pipes/training_status.py +++ b/ibllib/pipes/training_status.py @@ -1068,7 +1068,7 @@ def make_plots(session_path, one, df=None, save=False, upload=False, task_collec save_name = save_path.joinpath('subj_psychometric_fit_params.png') outputs.append(save_name) - ax4[0].get_figure().savefig(save_name, bbox_inches='tight') + ax4[0, 0].get_figure().savefig(save_name, bbox_inches='tight') save_name = save_path.joinpath('subj_psychometric_curve.png') outputs.append(save_name) From cba9e03ca71c3c5f305a8a367dc3f2e0868e8077 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Fri, 6 Dec 2024 17:23:08 +0000 Subject: [PATCH 47/59] only display unique contrasts --- brainbox/behavior/training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainbox/behavior/training.py b/brainbox/behavior/training.py index 0c3dc9000..59abd8563 100644 --- a/brainbox/behavior/training.py +++ b/brainbox/behavior/training.py @@ -318,7 +318,7 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay): # Case when all sessions are trainingChoiceWorld if np.all(np.array(task_protocol) == 'training'): - signed_contrast = get_signed_contrast(trials_all) + signed_contrast = np.unique(get_signed_contrast(trials_all)) (info.perf_easy, info.n_trials, info.psych, info.rt) = compute_training_info(trials, trials_all) From 3485785edc8f6003b221bb9801c5dbf94f503ace Mon Sep 17 00:00:00 2001 From: owinter Date: Thu, 5 Dec 2024 13:21:17 +0000 Subject: [PATCH 48/59] Add brain region query to the loading notebook --- examples/exploring_data/data_download.ipynb | 34 ++++++++++++++++++--- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/examples/exploring_data/data_download.ipynb b/examples/exploring_data/data_download.ipynb index bfaca800f..48d706d2d 100644 --- a/examples/exploring_data/data_download.ipynb +++ b/examples/exploring_data/data_download.ipynb @@ -37,10 +37,10 @@ "source": [ "## Installation\n", "### Environment\n", - "To use IBL data you will need a python environment with python > 3.8. To create a new environment from scratch you can install [anaconda](https://www.anaconda.com/products/distribution#download-section) and follow the instructions below to create a new python environment (more information can also be found [here](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html))\n", + "To use IBL data you will need a python environment with python > 3.10, although Python 3.12 is recommended. To create a new environment from scratch you can install [anaconda](https://www.anaconda.com/products/distribution#download-section) and follow the instructions below to create a new python environment (more information can also be found [here](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html))\n", "\n", "```\n", - "conda create --name ibl python=3.11\n", + "conda create --name ibl python=3.12\n", "```\n", "Make sure to always activate this environment before installing or working with the IBL data\n", "```\n", @@ -138,9 +138,33 @@ "outputs": [], "source": [ "# Each session is represented by a unique experiment id (eID)\n", - "print(sessions[0])" + "print(sessions[0],)" ] }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "### Find recordings of a specific brain region\n", + "If we are interested in a given brain region, we can use the `search_insertions` method to find all recordings associated with that region. For example, to find all recordings associated with the **Rhomboid Nucleus (RH)** region of the thalamus." + ] + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "# this is the query that yields the few recordings for the Rhomboid Nucleus (RH) region\n", + "insertions_rh = one.search_insertions(atlas_acronym='RH', datasets='spikes.times.npy', project='brainwide')\n", + "\n", + "# if we want to extend the search to all thalamic regions, we can do the following\n", + "insertions_th = one.search_insertions(atlas_acronym='TH', datasets='spikes.times.npy', project='brainwide')\n", + "\n", + "# the Allen brain regions parcellation is hierarchical, and searching for Thalamus will return all child Rhomboid Nucleus (RH) regions\n", + "assert set(insertions_rh).issubset(set(insertions_th))\n" + ], + "outputs": [], + "execution_count": null + }, { "cell_type": "markdown", "metadata": {}, @@ -402,9 +426,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.11.9" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } From b6a3ba46f485c90d841456fc4fcd60e5cad5e6a7 Mon Sep 17 00:00:00 2001 From: owinter Date: Thu, 5 Dec 2024 17:31:42 +0000 Subject: [PATCH 49/59] wip waveforms --- .../loading_spike_waveforms.ipynb | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/examples/loading_data/loading_spike_waveforms.ipynb b/examples/loading_data/loading_spike_waveforms.ipynb index 44b659980..eb19b1daf 100644 --- a/examples/loading_data/loading_spike_waveforms.ipynb +++ b/examples/loading_data/loading_spike_waveforms.ipynb @@ -10,12 +10,14 @@ }, { "cell_type": "code", - "execution_count": null, "id": "ea70eb4a", "metadata": { - "nbsphinx": "hidden" + "nbsphinx": "hidden", + "ExecuteTime": { + "end_time": "2024-12-05T12:37:41.707044Z", + "start_time": "2024-12-05T12:37:41.703076Z" + } }, - "outputs": [], "source": [ "# Turn off logging and disable tqdm this is a hidden cell on docs page\n", "import logging\n", @@ -25,7 +27,9 @@ "logger.setLevel(logging.CRITICAL)\n", "\n", "os.environ[\"TQDM_DISABLE\"] = \"1\"" - ] + ], + "outputs": [], + "execution_count": 1 }, { "cell_type": "markdown", @@ -62,19 +66,16 @@ "%%capture\n", "from one.api import ONE\n", "from brainbox.io.one import SpikeSortingLoader\n", - "from iblatlas.atlas import AllenAtlas\n", "\n", - "one = ONE()\n", - "ba = AllenAtlas()\n", - "pid = 'da8dfec1-d265-44e8-84ce-6ae9c109b8bd' \n", + "one = ONE(base_url='https://openalyx.internationalbrainlab.org')\n", "\n", - "# Load in the spikesorting\n", - "sl = SpikeSortingLoader(pid=pid, one=one, atlas=ba)\n", - "spikes, clusters, channels = sl.load_spike_sorting()\n", - "clusters = sl.merge_clusters(spikes, clusters, channels)\n", + "pid = 'da8dfec1-d265-44e8-84ce-6ae9c109b8bd'\n", "\n", - "# Load the spike waveforms\n", - "spike_wfs = one.load_object(sl.eid, '_phy_spikes_subset', collection=sl.collection)" + "# Load in the spikesorting\n", + "ssl = SpikeSortingLoader(pid=pid, one=one)\n", + "spikes, clusters, channels = ssl.load_spike_sorting(revision='2024-05-06')\n", + "clusters = ssl.merge_clusters(spikes, clusters, channels)\n", + "waveforms = ssl.load_spike_sorting_object('waveforms')\n" ] }, { From 4a7fa9713cf94980b9f40d504ee54b04c18f2e98 Mon Sep 17 00:00:00 2001 From: owinter Date: Tue, 10 Dec 2024 15:29:20 +0000 Subject: [PATCH 50/59] notebook waveforms --- .../loading_spike_waveforms.ipynb | 133 ++++++++++-------- .../loading_spikesorting_data.ipynb | 10 +- 2 files changed, 82 insertions(+), 61 deletions(-) diff --git a/examples/loading_data/loading_spike_waveforms.ipynb b/examples/loading_data/loading_spike_waveforms.ipynb index eb19b1daf..b8e1c0ffc 100644 --- a/examples/loading_data/loading_spike_waveforms.ipynb +++ b/examples/loading_data/loading_spike_waveforms.ipynb @@ -31,37 +31,26 @@ "outputs": [], "execution_count": 1 }, - { - "cell_type": "markdown", - "id": "64cec921", - "metadata": {}, - "source": [ - "Sample of spike waveforms extracted during spike sorting" - ] - }, { "cell_type": "markdown", "id": "dca47f09", "metadata": {}, "source": [ "## Relevant Alf objects\n", - "* \\_phy_spikes_subset" + "* waveforms" ] }, { "cell_type": "markdown", "id": "eb34d848", "metadata": {}, - "source": [ - "## Loading" - ] + "source": "## Loading the spike sorting and average waveforms" }, { - "cell_type": "code", - "execution_count": null, - "id": "c5d32232", "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "%%capture\n", "from one.api import ONE\n", @@ -75,52 +64,63 @@ "ssl = SpikeSortingLoader(pid=pid, one=one)\n", "spikes, clusters, channels = ssl.load_spike_sorting(revision='2024-05-06')\n", "clusters = ssl.merge_clusters(spikes, clusters, channels)\n", - "waveforms = ssl.load_spike_sorting_object('waveforms')\n" - ] + "waveforms = ssl.load_spike_sorting_object('waveforms')\n", + "\n", + "\n" + ], + "id": "c5d32232" }, { - "cell_type": "markdown", - "id": "327a23e7", "metadata": {}, - "source": [ - "## More details\n", - "* [Description of datasets](https://docs.google.com/document/d/1OqIqqakPakHXRAwceYLwFY9gOrm8_P62XIfCTnHwstg/edit#heading=h.vcop4lz26gs9)" - ] - }, - { "cell_type": "markdown", - "id": "257fb8b8", - "metadata": {}, - "source": [ - "## Useful modules\n", - "* COMING SOON" - ] + "source": "## Displaying a few average waveforms", + "id": "baf9a06dcf72940a" }, { - "cell_type": "markdown", - "id": "157bf219", "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, "source": [ - "## Exploring sample waveforms" - ] + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import ibldsp.waveforms\n", + "from ibl_style.style import figure_style\n", + "figure_style()\n", + "\n", + "cids = np.random.choice(np.where(clusters['label'] == 1)[0], 3)\n", + "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n", + "for i, cid in enumerate(cids):\n", + " wf = waveforms['templates'][cid, :, :]\n", + " ax = ibldsp.waveforms.double_wiggle(wf * 1e6 / 80, fs=30_000, ax=axs[i])\n", + " ax.set(title=f'Cluster {cid}')" + ], + "id": "41d410af9a6f9c0a" }, { "cell_type": "markdown", - "id": "a617f8fb", + "id": "327a23e7", "metadata": {}, "source": [ - "### Example 1: Finding the cluster ID for each sample waveform" + "## More details\n", + "During spike sorting, pre-processing operations are performed on the voltage data to create the average templates.\n", + " Although those templates are useful for clustering, they may not be the best description of the neural activity.\n", + " As such we extract average waveforms for each cluster using a different pre-processing. Details are provided on the spike sorting white paper [on figshare](https://figshare.com/articles/online_resource/Spike_sorting_pipeline_for_the_International_Brain_Laboratory/19705522?file=49783080).\n", + "* [Description of datasets](https://docs.google.com/document/d/1OqIqqakPakHXRAwceYLwFY9gOrm8_P62XIfCTnHwstg/edit?tab=t.0#heading=h.i89jwttog3fq)" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "1ac805b6", + "cell_type": "markdown", + "id": "157bf219", "metadata": {}, - "outputs": [], "source": [ - "# Find the cluster id for each sample waveform\n", - "wf_clusterIDs = spikes['clusters'][spike_wfs['spikes']]" + "## Exploring the raw waveforms\n", + "For each unit, we compiled up to raw data 256 waveforms chosen randomly from the entire recording.\n", + "The pre-processing steps included rephasing of the channels, low-cut filtering, bad channel detection and common-average referencing.\n", + "\n", + "To perform the loading we will use a convenience.\n", + "\n", + "Warning, this will download a few Gigabytes of data to your computer !" ] }, { @@ -128,7 +128,8 @@ "id": "baf9eb11", "metadata": {}, "source": [ - "### Example 2: Compute average waveform for cluster" + "### Compute average waveform for cluster\n", + "Here we will load data from the striatum and compute the average wvaeform for several stack orders: 1, 2, 4, 8, 16, 32, 64, 128 and display the resulting waveform." ] }, { @@ -138,27 +139,45 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", + "from ibldsp.utils import rms\n", + "# instantiating the waveform loader will download the raw waveform arrays and can take a few minutes (~3Gb)\n", + "wfl = ssl.raw_waveforms()\n", + "ic = np.where(np.logical_and(clusters['acronym'] == 'LSr', clusters['bitwise_fail'] == 0))[0]\n", + "# look at templates\n", "\n", - "# define cluster of interest\n", - "clustID = 2\n", + "raw_wav, info, channel_map = wfl.load_waveforms(labels=ic[12])\n", "\n", - "# Find waveforms for this cluster\n", - "wf_idx = np.where(wf_clusterIDs == clustID)[0]\n", - "wfs = spike_wfs['waveforms'][wf_idx, :, :]\n", - "\n", - "# Compute average waveform on channel with max signal (chn_index 0)\n", - "wf_avg_chn_max = np.mean(wfs[:, :, 0], axis=0)" + "snr = np.zeros(8)\n", + "fig, axs = plt.subplots(2, 4, figsize=(10, 6))\n", + "for i, ax in enumerate(axs.flat):\n", + " w_stack = np.mean(raw_wav[0, :(2 ** i), :, :], axis=0)\n", + " ax = ibldsp.waveforms.double_wiggle( w_stack * 1e6 / 80, fs=30_000, ax=axs.flatten()[i])\n", + " ax.set_title(f\"Stack {2 ** i}\")\n", + " snr[i] = 20 * np.log10(rms(w_stack[19:22, wfl.trough_offset - 10:wfl.trough_offset + 10].flatten()) / np.mean(rms(w_stack)))" ] }, { + "metadata": {}, "cell_type": "markdown", - "id": "a20b24ea", + "source": [ + "For constant gaussian noise in `n` repeated experiments, we expect the SNR to scale proportionally to the square root of `n`.\n", + "In decibels, this corresponds to 3dB / octave. Let's see how the data compares to the prediction:" + ], + "id": "10b7139533052b12" + }, + { "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, "source": [ - "## Other relevant examples\n", - "* COMING SOON" - ] + "fig, ax = plt.subplots(figsize=(6, 4))\n", + "ax.plot(np.arange(8), snr, '*', label='SNR estimation')\n", + "ax.plot(np.arange(8), np.arange(8) * 3, label='SNR predicted')\n", + "ax.set(xlabel='Stack index (log2(N))', ylabel='SNR (dB)')\n", + "ax.legend()" + ], + "id": "c092737cafb28598" } ], "metadata": { diff --git a/examples/loading_data/loading_spikesorting_data.ipynb b/examples/loading_data/loading_spikesorting_data.ipynb index f711414a1..db568215b 100644 --- a/examples/loading_data/loading_spikesorting_data.ipynb +++ b/examples/loading_data/loading_spikesorting_data.ipynb @@ -43,7 +43,8 @@ "## Relevant Alf objects\n", "* channels\n", "* clusters\n", - "* spikes" + "* spikes\n", + "* waveforms" ] }, { @@ -74,9 +75,10 @@ "outputs": [], "source": [ "pid = 'da8dfec1-d265-44e8-84ce-6ae9c109b8bd' \n", - "sl = SpikeSortingLoader(pid=pid, one=one)\n", - "spikes, clusters, channels = sl.load_spike_sorting()\n", - "clusters = sl.merge_clusters(spikes, clusters, channels)" + "ssl = SpikeSortingLoader(pid=pid, one=one)\n", + "spikes, clusters, channels = ssl.load_spike_sorting()\n", + "clusters = ssl.merge_clusters(spikes, clusters, channels)\n", + "waveforms = ssl.load_spike_sorting_object('waveforms') # loads in the template waveforms" ] }, { From 941287fd19b218ff33597c1052f22d602a0befef Mon Sep 17 00:00:00 2001 From: owinter Date: Tue, 10 Dec 2024 16:27:03 +0000 Subject: [PATCH 51/59] remove the waveforms docs and move to iblenv --- .../loading_spike_waveforms.ipynb | 204 ------------------ 1 file changed, 204 deletions(-) delete mode 100644 examples/loading_data/loading_spike_waveforms.ipynb diff --git a/examples/loading_data/loading_spike_waveforms.ipynb b/examples/loading_data/loading_spike_waveforms.ipynb deleted file mode 100644 index b8e1c0ffc..000000000 --- a/examples/loading_data/loading_spike_waveforms.ipynb +++ /dev/null @@ -1,204 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "f73e02ee", - "metadata": {}, - "source": [ - "# Loading Spike Waveforms" - ] - }, - { - "cell_type": "code", - "id": "ea70eb4a", - "metadata": { - "nbsphinx": "hidden", - "ExecuteTime": { - "end_time": "2024-12-05T12:37:41.707044Z", - "start_time": "2024-12-05T12:37:41.703076Z" - } - }, - "source": [ - "# Turn off logging and disable tqdm this is a hidden cell on docs page\n", - "import logging\n", - "import os\n", - "\n", - "logger = logging.getLogger('ibllib')\n", - "logger.setLevel(logging.CRITICAL)\n", - "\n", - "os.environ[\"TQDM_DISABLE\"] = \"1\"" - ], - "outputs": [], - "execution_count": 1 - }, - { - "cell_type": "markdown", - "id": "dca47f09", - "metadata": {}, - "source": [ - "## Relevant Alf objects\n", - "* waveforms" - ] - }, - { - "cell_type": "markdown", - "id": "eb34d848", - "metadata": {}, - "source": "## Loading the spike sorting and average waveforms" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "%%capture\n", - "from one.api import ONE\n", - "from brainbox.io.one import SpikeSortingLoader\n", - "\n", - "one = ONE(base_url='https://openalyx.internationalbrainlab.org')\n", - "\n", - "pid = 'da8dfec1-d265-44e8-84ce-6ae9c109b8bd'\n", - "\n", - "# Load in the spikesorting\n", - "ssl = SpikeSortingLoader(pid=pid, one=one)\n", - "spikes, clusters, channels = ssl.load_spike_sorting(revision='2024-05-06')\n", - "clusters = ssl.merge_clusters(spikes, clusters, channels)\n", - "waveforms = ssl.load_spike_sorting_object('waveforms')\n", - "\n", - "\n" - ], - "id": "c5d32232" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "## Displaying a few average waveforms", - "id": "baf9a06dcf72940a" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import ibldsp.waveforms\n", - "from ibl_style.style import figure_style\n", - "figure_style()\n", - "\n", - "cids = np.random.choice(np.where(clusters['label'] == 1)[0], 3)\n", - "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n", - "for i, cid in enumerate(cids):\n", - " wf = waveforms['templates'][cid, :, :]\n", - " ax = ibldsp.waveforms.double_wiggle(wf * 1e6 / 80, fs=30_000, ax=axs[i])\n", - " ax.set(title=f'Cluster {cid}')" - ], - "id": "41d410af9a6f9c0a" - }, - { - "cell_type": "markdown", - "id": "327a23e7", - "metadata": {}, - "source": [ - "## More details\n", - "During spike sorting, pre-processing operations are performed on the voltage data to create the average templates.\n", - " Although those templates are useful for clustering, they may not be the best description of the neural activity.\n", - " As such we extract average waveforms for each cluster using a different pre-processing. Details are provided on the spike sorting white paper [on figshare](https://figshare.com/articles/online_resource/Spike_sorting_pipeline_for_the_International_Brain_Laboratory/19705522?file=49783080).\n", - "* [Description of datasets](https://docs.google.com/document/d/1OqIqqakPakHXRAwceYLwFY9gOrm8_P62XIfCTnHwstg/edit?tab=t.0#heading=h.i89jwttog3fq)" - ] - }, - { - "cell_type": "markdown", - "id": "157bf219", - "metadata": {}, - "source": [ - "## Exploring the raw waveforms\n", - "For each unit, we compiled up to raw data 256 waveforms chosen randomly from the entire recording.\n", - "The pre-processing steps included rephasing of the channels, low-cut filtering, bad channel detection and common-average referencing.\n", - "\n", - "To perform the loading we will use a convenience.\n", - "\n", - "Warning, this will download a few Gigabytes of data to your computer !" - ] - }, - { - "cell_type": "markdown", - "id": "baf9eb11", - "metadata": {}, - "source": [ - "### Compute average waveform for cluster\n", - "Here we will load data from the striatum and compute the average wvaeform for several stack orders: 1, 2, 4, 8, 16, 32, 64, 128 and display the resulting waveform." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3d8a729c", - "metadata": {}, - "outputs": [], - "source": [ - "from ibldsp.utils import rms\n", - "# instantiating the waveform loader will download the raw waveform arrays and can take a few minutes (~3Gb)\n", - "wfl = ssl.raw_waveforms()\n", - "ic = np.where(np.logical_and(clusters['acronym'] == 'LSr', clusters['bitwise_fail'] == 0))[0]\n", - "# look at templates\n", - "\n", - "raw_wav, info, channel_map = wfl.load_waveforms(labels=ic[12])\n", - "\n", - "snr = np.zeros(8)\n", - "fig, axs = plt.subplots(2, 4, figsize=(10, 6))\n", - "for i, ax in enumerate(axs.flat):\n", - " w_stack = np.mean(raw_wav[0, :(2 ** i), :, :], axis=0)\n", - " ax = ibldsp.waveforms.double_wiggle( w_stack * 1e6 / 80, fs=30_000, ax=axs.flatten()[i])\n", - " ax.set_title(f\"Stack {2 ** i}\")\n", - " snr[i] = 20 * np.log10(rms(w_stack[19:22, wfl.trough_offset - 10:wfl.trough_offset + 10].flatten()) / np.mean(rms(w_stack)))" - ] - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": [ - "For constant gaussian noise in `n` repeated experiments, we expect the SNR to scale proportionally to the square root of `n`.\n", - "In decibels, this corresponds to 3dB / octave. Let's see how the data compares to the prediction:" - ], - "id": "10b7139533052b12" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "fig, ax = plt.subplots(figsize=(6, 4))\n", - "ax.plot(np.arange(8), snr, '*', label='SNR estimation')\n", - "ax.plot(np.arange(8), np.arange(8) * 3, label='SNR predicted')\n", - "ax.set(xlabel='Stack index (log2(N))', ylabel='SNR (dB)')\n", - "ax.legend()" - ], - "id": "c092737cafb28598" - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 3b051e45a480853152e33818bb3ee1636812d4e7 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 17 Dec 2024 13:19:03 +0000 Subject: [PATCH 52/59] add back compatibility for old json keys --- brainbox/behavior/training.py | 2 +- ibllib/pipes/training_status.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/brainbox/behavior/training.py b/brainbox/behavior/training.py index 59abd8563..a8f2f383d 100644 --- a/brainbox/behavior/training.py +++ b/brainbox/behavior/training.py @@ -870,7 +870,7 @@ def criterion_1a(psych, n_trials, perf_easy, signed_contrast): criteria = Bunch() criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)} - criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.2 } + criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.2} criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.2} criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 16} criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 19} diff --git a/ibllib/pipes/training_status.py b/ibllib/pipes/training_status.py index 6a6ece21f..50d28707f 100644 --- a/ibllib/pipes/training_status.py +++ b/ibllib/pipes/training_status.py @@ -116,7 +116,7 @@ def load_existing_dataframe(subj_path): return None -def load_trials(sess_path, one, collections=None, force=True, mode='warn'): +def load_trials(sess_path, one, collections=None, force=True, mode='raise'): """ Load trials data for session. First attempts to load from local session path, if this fails will attempt to download via ONE, if this also fails, will then attempt to re-extract locally @@ -208,9 +208,11 @@ def load_combined_trials(sess_paths, one, force=True): """ trials_dict = {} for sess_path in sess_paths: - trials = load_trials(Path(sess_path), one, force=force) + trials = load_trials(Path(sess_path), one, force=force, mode='warn') if trials is not None: - trials_dict[Path(sess_path).stem] = load_trials(Path(sess_path), one, force=force) + trials_dict[Path(sess_path).stem] = load_trials(Path(sess_path), one, force=force, mode='warn' + + ) return training.concatenate_trials(trials_dict) @@ -442,12 +444,13 @@ def compute_session_duration_delay_location(sess_path, collections=None, **kwarg try: start_time, end_time = _get_session_times(sess_path, md, sess_data) session_duration = session_duration + int((end_time - start_time).total_seconds() / 60) - session_delay = session_delay + md.get('SESSION_DELAY_START', 0) + session_delay = session_delay + md.get('SESSION_DELAY_START', + md.get('SESSION_START_DELAY_SEC', 0)) except Exception: session_duration = session_duration + 0 session_delay = session_delay + 0 - if 'ephys' in md.get('RIG_NAME', None): + if 'ephys' in md.get('RIG_NAME', md.get('PYBPOD_BOARD', None)): session_location = 'ephys_rig' else: session_location = 'training_rig' @@ -806,6 +809,7 @@ def _array_to_string(vals): axs[0].set_axis_off() axs[1].set_axis_off() + def plot_fit_params(df, subject): fig, axs = plt.subplots(2, 3, figsize=(12, 6), gridspec_kw={'width_ratios': [2, 2, 1]}) @@ -902,7 +906,7 @@ def plot_fit_params(df, subject): def plot_psychometric_curve(df, subject, one): df = df.drop_duplicates('date').reset_index(drop=True) sess_path = Path(df.iloc[-1]["session_path"]) - trials = load_trials(sess_path, one) + trials = load_trials(sess_path, one, mode='warn') fig, ax1 = plt.subplots(figsize=(8, 6)) From 8bfa7e365cbc8a6f76c83fc2200495b665eb7356 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Tue, 17 Dec 2024 15:05:56 +0000 Subject: [PATCH 53/59] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5f2fc9c35..b890b3e5e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ tqdm>=4.32.1 iblatlas>=0.5.3 ibl-neuropixel>=1.5.0 iblutil>=1.13.0 -iblqt>=0.2.0 +iblqt>=0.3.2 mtscomp>=1.0.1 ONE-api>=2.11 phylib>=2.6.0 From fb3ea3188299c6666356ae704fde5c79bd3bca18 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Tue, 17 Dec 2024 15:47:56 +0000 Subject: [PATCH 54/59] fix CI --- ibllib/qc/task_qc_viewer/task_qc.py | 13 ++++++++----- ibllib/tests/qc/test_task_qc_viewer.py | 1 + requirements.txt | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/ibllib/qc/task_qc_viewer/task_qc.py b/ibllib/qc/task_qc_viewer/task_qc.py index a49c703eb..7b75589a4 100644 --- a/ibllib/qc/task_qc_viewer/task_qc.py +++ b/ibllib/qc/task_qc_viewer/task_qc.py @@ -288,14 +288,17 @@ def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=N # Update table and callbacks n_trials = qc.frame.shape[0] - df_trials = pd.DataFrame({ - k: v for k, v in task_qc.extractor.data.items() - if v.size == n_trials and not k.startswith('wheel') - }) + if 'task_qc' in locals(): + df_trials = pd.DataFrame({ + k: v for k, v in task_qc.extractor.data.items() + if v.size == n_trials and not k.startswith('wheel') + }) + df = df_trials.merge(qc.frame, left_index=True, right_index=True) + else: + df = qc.frame df_pass = pd.DataFrame({k: v for k, v in qc.qc.passed.items() if isinstance(v, np.ndarray) and v.size == n_trials}) df_pass.drop('_task_passed_trial_checks', axis=1, errors='ignore', inplace=True) df_pass.rename(columns=lambda x: x.replace('_task', 'passed'), inplace=True) - df = df_trials.merge(qc.frame, left_index=True, right_index=True) df = df.merge(df_pass.astype('boolean'), left_index=True, right_index=True) w.updateDataframe(df) qt.run_app() diff --git a/ibllib/tests/qc/test_task_qc_viewer.py b/ibllib/tests/qc/test_task_qc_viewer.py index 6db045f91..7115f371f 100644 --- a/ibllib/tests/qc/test_task_qc_viewer.py +++ b/ibllib/tests/qc/test_task_qc_viewer.py @@ -66,6 +66,7 @@ def test_show_session_task_qc(self, trials_tasks_mock, run_app_mock): qc_mock.compute_session_status.return_value = ('Fail', qc_mock.metrics, {'foo': 'FAIL'}) qc_mock.extractor.data = {'intervals': np.array([[0, 1]])} qc_mock.extractor.frame_ttls = qc_mock.extractor.audio_ttls = qc_mock.extractor.bpod_ttls = mock.MagicMock() + qc_mock.passed = dict() active_task = mock.Mock(spec=ChoiceWorldTrialsNidq, unsafe=True) active_task.run_qc.return_value = qc_mock diff --git a/requirements.txt b/requirements.txt index 5f2fc9c35..b890b3e5e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ tqdm>=4.32.1 iblatlas>=0.5.3 ibl-neuropixel>=1.5.0 iblutil>=1.13.0 -iblqt>=0.2.0 +iblqt>=0.3.2 mtscomp>=1.0.1 ONE-api>=2.11 phylib>=2.6.0 From ad682fea463a6d1e0d6aaf9b785ff61bfc66b0da Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Tue, 17 Dec 2024 15:51:27 +0000 Subject: [PATCH 55/59] Update task_qc.py --- ibllib/qc/task_qc_viewer/task_qc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ibllib/qc/task_qc_viewer/task_qc.py b/ibllib/qc/task_qc_viewer/task_qc.py index 7b75589a4..b9f212a5c 100644 --- a/ibllib/qc/task_qc_viewer/task_qc.py +++ b/ibllib/qc/task_qc_viewer/task_qc.py @@ -241,7 +241,8 @@ def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=N if isinstance(qc_or_session, QcFrame): qc = qc_or_session elif isinstance(qc_or_session, TaskQC): - qc = QcFrame(qc_or_session) + task_qc = qc_or_session + qc = QcFrame(task_qc) else: # assumed to be eid or session path one = one or ONE(mode='local' if local else 'auto') if not is_session_path(Path(qc_or_session)): From 4d94d8c97c2f74da8e437996d01a01872986ac42 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Wed, 11 Dec 2024 13:54:40 +0200 Subject: [PATCH 56/59] Exclude spacers from dud protocols --- ibllib/io/extractors/ephys_fpga.py | 21 +++++++++++++++++---- ibllib/pipes/behavior_tasks.py | 2 +- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index 2980eb7bf..4da3d6bd8 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -69,7 +69,7 @@ """int: The number of encoder pulses per channel for one complete rotation.""" BPOD_FPGA_DRIFT_THRESHOLD_PPM = 150 -"""int: Throws an error if Bpod to FPGA clock drift is higher than this value.""" +"""int: Logs a warning if Bpod to FPGA clock drift is higher than this value.""" CHMAPS = {'3A': {'ap': @@ -545,17 +545,23 @@ def get_main_probe_sync(session_path, bin_exists=False): return sync, sync_chmap -def get_protocol_period(session_path, protocol_number, bpod_sync): +def get_protocol_period(session_path, protocol_number, bpod_sync, exclude_empty_periods=True): """ + Return the start and end time of the protocol number. + + Note that the start time is the start of the spacer pulses and the end time is either None + if the protocol is the final one, or the start of the next spacer. Parameters ---------- session_path : str, pathlib.Path The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'. protocol_number : int - The order that the protocol was run in. + The order that the protocol was run in, counted from 0. bpod_sync : dict The sync times and polarities for Bpod BNC1. + exclude_empty_periods : bool + When true, spacers are ignored if no bpod pulses are detected between periods. Returns ------- @@ -565,7 +571,14 @@ def get_protocol_period(session_path, protocol_number, bpod_sync): The time of the next detected spacer or None if this is the last protocol run. """ # The spacers are TTLs generated by Bpod at the start of each protocol - spacer_times = Spacer().find_spacers_from_fronts(bpod_sync) + sp = Spacer() + spacer_times = sp.find_spacers_from_fronts(bpod_sync) + if exclude_empty_periods: + # Drop dud protocol spacers (those without any bpod pulses after the spacer) + spacer_length = len(sp.generate_template(fs=1000)) / 1000 + periods = np.c_[spacer_times + spacer_length, np.r_[spacer_times[1:], np.inf]] + valid = [np.any((bpod_sync['times'] > pp[0]) & (bpod_sync['times'] < pp[1])) for pp in periods] + spacer_times = spacer_times[valid] # Ensure that the number of detected spacers matched the number of expected tasks if acquisition_description := session_params.read_params(session_path): n_tasks = len(acquisition_description.get('tasks', [])) diff --git a/ibllib/pipes/behavior_tasks.py b/ibllib/pipes/behavior_tasks.py index 3f519a10c..be75cf0d6 100644 --- a/ibllib/pipes/behavior_tasks.py +++ b/ibllib/pipes/behavior_tasks.py @@ -20,7 +20,7 @@ from ibllib.pipes import training_status from ibllib.plots.figures import BehaviourPlots -_logger = logging.getLogger('ibllib') +_logger = logging.getLogger(__name__) class HabituationRegisterRaw(base_tasks.RegisterRawDataTask, base_tasks.BehaviourTask): From f8d7c6a3a7398a3dcc8b9a9e10d3deef3ed051d2 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Wed, 18 Dec 2024 10:14:25 +0000 Subject: [PATCH 57/59] release notes and version number --- ibllib/__init__.py | 2 +- release_notes.md | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/ibllib/__init__.py b/ibllib/__init__.py index 3525165e1..c1857164d 100644 --- a/ibllib/__init__.py +++ b/ibllib/__init__.py @@ -2,7 +2,7 @@ import logging import warnings -__version__ = '3.1.0' +__version__ = '3.2.0' warnings.filterwarnings('always', category=DeprecationWarning, module='ibllib') # if this becomes a full-blown library we should let the logging configuration to the discretion of the dev diff --git a/release_notes.md b/release_notes.md index e2eb6ce78..2292f42fe 100644 --- a/release_notes.md +++ b/release_notes.md @@ -1,3 +1,12 @@ +## Release Note 3.2.0 + +### features +- Add session delay info during registration of Bpod session +- Add detailed criteria info to behaviour plots + +### Bugfixes +- Read in updated json keys from task settings to establish ready4recording + ## Release Note 3.1.0 ### features From 9a0dd190e8dd8a61f2693ac5d536fa44601b4e84 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Wed, 18 Dec 2024 10:18:47 +0000 Subject: [PATCH 58/59] Update release_notes.md --- release_notes.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/release_notes.md b/release_notes.md index 2292f42fe..2a63db6ec 100644 --- a/release_notes.md +++ b/release_notes.md @@ -3,6 +3,8 @@ ### features - Add session delay info during registration of Bpod session - Add detailed criteria info to behaviour plots +- Add column filtering, sorting and color-coding of values to metrics table of + task_qc_viewer ### Bugfixes - Read in updated json keys from task settings to establish ready4recording From 549265e18b18f663726687dc07a66957a887cc36 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Wed, 18 Dec 2024 12:22:27 +0200 Subject: [PATCH 59/59] Update release notes --- release_notes.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/release_notes.md b/release_notes.md index 2a63db6ec..a89c2b60d 100644 --- a/release_notes.md +++ b/release_notes.md @@ -3,11 +3,12 @@ ### features - Add session delay info during registration of Bpod session - Add detailed criteria info to behaviour plots -- Add column filtering, sorting and color-coding of values to metrics table of +- Add column filtering, sorting and color-coding of values to metrics table of task_qc_viewer ### Bugfixes - Read in updated json keys from task settings to establish ready4recording +- Handle extraction of sessions with dud spacers ## Release Note 3.1.0