Building and Interpretation of a Seizure Classifier from EEG Recordings Using Graph Signal Processing
This repository contains the code for the seizure classifier and its interpretation, as described in the project report. The implementation consists of the following steps:
- Extraction and formatting of the seizures from the dataset. Run data_preparation/build_data.py to execute this step.
- Computation of the graphs from the signals. Run learn_graph_laplacian/laplacian.py to compute the adjacency matrix, and learn_graph_laplacian/covariance.py to compute the covariance matrix.
- Training of the classifier. All the classifiers can be found in classifier/.
- The explainability analysis.
Read sh_scripts/README.md to find out how to download the dataset.
The current implementation is based on the TUH EEG Corpus dataset v1.5.2, released in 2020. The dataset can be downloaded by following the instructions on the TUH EEG dataset page.
On Linux, the dataset can be downloaded easily with Rsync. The shell scripts containing the download commands can be found in sh_scripts/. On Windows, Rsync is unfortunately not available, and you might need MobaXterm to be able to run it. A tutorial explaining how to do this can be found here.
Because of the large amount of data, some steps of the implementation might require significant computational power, in particular the computation of the adjacency matrix and the training of the classifier. For that reason, the SCITAS EPFL cluster was used to run the heavy scripts.
Find here how to set up your access to the cluster (fidis.epfl.ch works perfectly) and here how to use the clusters. If using VS Code, you can use the Remote Explorer extension to set up the SSH connection and access your scripts on the cluster without having to navigate to them in the terminal. Before running a Python script on the cluster, you first need to allocate the resources and time desired for the run. You will therefore have to launch the Python scripts through the appropriate scripts in scitas_run/ instead of calling them directly from the terminal. The cluster does not offer the option to display a plot directly with Matplotlib, so you might need to transfer the outputs of the graph computation step to your local device to plot the graphs, or to produce the plots available in the explainability step.
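Since plots cannot be displayed on the cluster, one option is to save the computed matrices to disk and load them again on the local device. The snippet below is only a minimal sketch of that round trip with NumPy: the file name adjacency.npy, the helper names, and the identity matrix used as a stand-in are hypothetical, and the scripts in this repository may store their outputs in a different format.

```python
import os
import tempfile

import numpy as np


def save_matrix(matrix, path):
    """Write a matrix to a .npy file (to be run on the cluster)."""
    np.save(path, matrix)


def load_matrix(path):
    """Read a matrix back from a .npy file (to be run locally, before plotting)."""
    return np.load(path)


# Hypothetical stand-in for a 20 x 20 graph matrix computed on the cluster.
adjacency = np.eye(20)

# A temporary directory is used here; on the cluster this would be your output folder.
out_dir = tempfile.mkdtemp()
path = os.path.join(out_dir, "adjacency.npy")

save_matrix(adjacency, path)
restored = load_matrix(path)
```

Once the file has been copied to the local device, the loaded array can be passed to any Matplotlib plotting function.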
# To create the virtual environment on the cluster
module load gcc/8.4.0 python/3.7.7
virtualenv --system-site-packages env_name
# To activate the environment
source env_name/bin/activate
To install packages on the cluster, you will have to create your virtual environment (tutorial) and install the required packages with that environment activated. For example, for numpy, run (in LTS4/):
# Activate your virtual environment (adapt with your env name)
source ../rma_env/bin/activate
# Install the package (the --no-cache-dir option is required on the cluster)
pip install --no-cache-dir numpy
The dependencies listed in requirements.txt can be installed with the command below. The --no-cache-dir option might only be required on the cluster. Note that, as explained in the explainability section of classifier/README.md, the version of numpy on the cluster might be too old, and you might get an error when updating it through requirements.txt. If that happens, see the explanation at the bottom of classifier/README.md.
pip install --no-cache-dir -r requirements.txt
Read data_preparation/README.md to find out how to extract the seizure types of interest.
This step is adapted from the implementation found in this GitHub repository, which itself implements the work in SeizureNet: Multi-Spectral Deep Feature Learning for Seizure Type Classification (Asif et al., 2020). It provides a way to extract the segments of the recordings during which a seizure event occurred and to retrieve them by type.
Read learn_graph_laplacian/README.md to find out how to compute the different graphs.
A graph representation of the connectivity between the signals of the 20 channels is computed with two different techniques:
- Covariance matrix
- Adjacency matrix, recovered from the learned Laplacian matrix. This is an implementation of the framework described in Learning Laplacian Matrix in Smooth Graph Signal Representations (Dong et al., 2016). Their code can be found here.
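The two representations can be illustrated with a short NumPy sketch. It is not the code in learn_graph_laplacian/: the function names and the synthetic data are hypothetical, and the Laplacian here is built from a known toy adjacency matrix rather than learned with the method of Dong et al. The sketch only shows the sample covariance between channels and the identity L = D - W, which gives back W = D - L once a Laplacian is available.

```python
import numpy as np


def covariance_graph(signals):
    """Covariance between channels of a (n_channels, n_samples) array."""
    # np.cov treats each row as one variable by default.
    return np.cov(signals)


def adjacency_from_laplacian(laplacian):
    """Recover W from L = D - W, where D holds the diagonal of L."""
    degree = np.diag(np.diag(laplacian))
    return degree - laplacian


rng = np.random.default_rng(0)

# Synthetic stand-in for one 20-channel EEG segment (assumption, not real data).
signals = rng.standard_normal((20, 1000))
cov = covariance_graph(signals)

# Toy symmetric adjacency with a zero diagonal, used to build a valid Laplacian.
weights = np.abs(rng.standard_normal((20, 20)))
weights = (weights + weights.T) / 2
np.fill_diagonal(weights, 0.0)
laplacian = np.diag(weights.sum(axis=1)) - weights

# Computing the adjacency back from the Laplacian returns the toy weights.
recovered = adjacency_from_laplacian(laplacian)
```

In the actual pipeline the Laplacian would come from the learning step, so the recovered adjacency matrix reflects the smoothness assumptions of that framework rather than a known ground truth.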
Read classifier/README.md to find out how to use the multiple classifiers and explain their output.
Multiple classifiers can be tried here. Simple classifiers (kNN, Bayes, Logistic Regression, Decision Tree, SVM) can be trained by running classifier/graph_classifier.py, the Feed-forward Neural Network by running classifier/FC_NN.py, the Convolutional Neural Network by running classifier/CNN.py, and the 2-channel Convolutional Neural Network by running classifier/dual_CNN.py.
The best accuracy after training was obtained with the dual CNN, which combines both the covariance and the adjacency matrix graph representations into the same input. The output of this model is explained using SHAP's DeepExplainer tool, which is also run in classifier/dual_CNN.py when the appropriate arguments are passed.
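As a rough illustration of what combining both representations into the same input can look like, the sketch below stacks two 20 x 20 matrices into a single 2-channel array. The function name, the channel-first layout, and the placeholder matrices are assumptions made for the example; the layout actually expected by classifier/dual_CNN.py may differ.

```python
import numpy as np


def stack_graphs(covariance, adjacency):
    """Stack two square graph matrices into a (2, n, n) channel-first array."""
    if covariance.shape != adjacency.shape:
        raise ValueError("both graph matrices must have the same shape")
    return np.stack([covariance, adjacency], axis=0)


# Placeholder matrices standing in for the two graphs of one seizure sample.
covariance = np.full((20, 20), 0.5)
adjacency = np.ones((20, 20)) - np.eye(20)

sample = stack_graphs(covariance, adjacency)
print(sample.shape)
```

Each channel keeps its own matrix untouched, so a convolutional layer with two input channels sees both graph representations at the same electrode pair.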