From 6363690cece6d2e9f10342511db1eee87d3863a8 Mon Sep 17 00:00:00 2001
From: zottelsheep <31415206+zottelsheep@users.noreply.github.com>
Date: Mon, 25 Mar 2024 11:52:51 +0100
Subject: [PATCH] [Feature] Merging "Total spiking probability edges" into
 elephant (#560)

* Feat: Add basic scaffold for total_spiking_probability_edges
* Feat: Add filterpair generation
* Feat: Add normalized_cross_correlation
* Feat: Add total_spiking_probability_edges
* Refactor: Rename filter parameters
* Feat: Add function to compute connectivity_matrix
* Feat: Combine total_spiking_probability_edges and get_connectivty_matrix
* Docs: Explanation why mean values are omitted in NCC
* Refactor: Naming conventions
* Fix: Extra dimension on delay_matrix
* Test: Add test for total_spiking_probability_edges
* Fix: Incorrect repo path
* Fix: Incorrect repo path part 2
* Docs: Add further documentation
* Docs: Fix typos

Co-authored-by: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com>

* fixed pep8 issue in total_spiking_probability_edges.py
* fixed pep8 issue in test_total_spiking_probability_edges.py
* convert tests to unittest.TestCase classes
* add zenodo info

---------

Co-authored-by: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com>
Co-authored-by: Moritz-Alexander-Kern
---
 .zenodo.json                                  |   5 +
 doc/authors.rst                               |   1 +
 doc/bib/elephant.bib                          |  14 +
 doc/modules.rst                               |   6 +
 .../functional_connectivity_estimation.rst    |   5 +
 elephant/functional_connectivity.py           |  31 ++
 .../total_spiking_probability_edges.py        | 343 ++++++++++++++++++
 .../test_total_spiking_probability_edges.py   | 219 +++++++++++
 8 files changed, 624 insertions(+)
 create mode 100644 doc/reference/functional_connectivity_estimation.rst
 create mode 100644 elephant/functional_connectivity.py
 create mode 100644 elephant/functional_connectivity_src/total_spiking_probability_edges.py
 create mode 100644 elephant/test/test_total_spiking_probability_edges.py

diff --git a/.zenodo.json b/.zenodo.json
index 5f12a3a85..fe61aecff 100644
--- a/.zenodo.json
+++ b/.zenodo.json
@@ -9,6 +9,11 @@
             "orcid": "0000-0001-7292-1982",
             "affiliation": "Institute of Neuroscience and Medicine (INM-6) and Institute for Advanced Simulation (IAS-6) and JARA-Institute Brain Structure-Function Relationships (INM-10), Jülich Research Centre, Jülich, Germany",
             "name": "Kern, Moritz"
+        },
+        {
+            "orcid": "0009-0003-9352-9826",
+            "affiliation": "BioMEMS Lab, University of Applied Sciences Aschaffenburg, Germany",
+            "name": "Richter, Felician"
         }
     ],

diff --git a/doc/authors.rst b/doc/authors.rst
index 4af092105..4018a71bc 100644
--- a/doc/authors.rst
+++ b/doc/authors.rst
@@ -51,6 +51,7 @@ contribution, and may not be the current affiliation of a contributor.
 * Florian Porrmann [13]
 * Sarah Pilz [13]
 * Oliver Kloß [1]
+* Felician Richter [12]
 
 1. Institute of Neuroscience and Medicine (INM-6) and Institute for Advanced Simulation (IAS-6) and JARA-Institute Brain Structure-Function Relationships (INM-10), Jülich Research Centre, Jülich, Germany
 2. Unité de Neurosciences, Information et Complexité, CNRS UPR 3293, Gif-sur-Yvette, France
diff --git a/doc/bib/elephant.bib b/doc/bib/elephant.bib
index 46ecea203..0ad8654d1 100644
--- a/doc/bib/elephant.bib
+++ b/doc/bib/elephant.bib
@@ -466,3 +466,17 @@ @article{Deger12_443
   title = {Statistical properties of superimposed stationary spike trains},
   volume = 32,
   year = 2012}
+
+
+@article{de_blasi19_169,
+  title = {Total spiking probability edges: {A} cross-correlation based method for effective connectivity estimation of cortical spiking neurons},
+  volume = {312},
+  shorttitle = {Total spiking probability edges},
+  doi = {10.1016/j.jneumeth.2018.11.013},
+  language = {en},
+  journal = {Journal of Neuroscience Methods},
+  author = {{De Blasi}, Stefano and Ciba, Manuel and Bahmer, Andreas and Thielemann, Christiane},
+  month = jan,
+  year = {2019},
+  pages = {169--181},
+}
diff --git a/doc/modules.rst b/doc/modules.rst
index d076db1e2..b89c733bb 100644
--- a/doc/modules.rst
+++ b/doc/modules.rst
@@ -32,11 +32,17 @@ Spike trains
     reference/_spike_train_processing
     reference/_spike_train_patterns
 
+.. toctree::
+    :maxdepth: 1
+
+    reference/functional_connectivity_estimation
+
 .. toctree::
     :maxdepth: 1
 
     reference/change_point_detection
     reference/gpfa
 
+
 .. toctree::
     :maxdepth: 1
 
diff --git a/doc/reference/functional_connectivity_estimation.rst b/doc/reference/functional_connectivity_estimation.rst
new file mode 100644
index 000000000..b9689937d
--- /dev/null
+++ b/doc/reference/functional_connectivity_estimation.rst
@@ -0,0 +1,5 @@
+==================================
+Functional connectivity estimation
+==================================
+
+.. automodule:: elephant.functional_connectivity
diff --git a/elephant/functional_connectivity.py b/elephant/functional_connectivity.py
new file mode 100644
index 000000000..dc5e764cc
--- /dev/null
+++ b/elephant/functional_connectivity.py
@@ -0,0 +1,31 @@
+"""
+Functions for analysing and estimating firing patterns and connectivity
+among neurons, in order to better understand the underlying neural networks
+and the information flow between neurons.
+
+
+Network connectivity estimation
+*******************************
+
+.. autosummary::
+    :toctree: _toctree/functional_connectivity/
+
+    total_spiking_probability_edges
+
+References
+----------
+
+.. bibliography::
+    :keyprefix: functional_connectivity-
+
+
+:copyright: Copyright 2014-2023 by the Elephant team, see `doc/authors.rst`.
+:license: Modified BSD, see LICENSE.txt for details.
+""" + +from elephant.functional_connectivity_src.total_spiking_probability_edges import ( + total_spiking_probability_edges, +) + +__all__ = ["total_spiking_probability_edges"] + diff --git a/elephant/functional_connectivity_src/total_spiking_probability_edges.py b/elephant/functional_connectivity_src/total_spiking_probability_edges.py new file mode 100644 index 000000000..b1964aa0b --- /dev/null +++ b/elephant/functional_connectivity_src/total_spiking_probability_edges.py @@ -0,0 +1,343 @@ +import itertools +from typing import Iterable, List, NamedTuple, Union, Optional + +import numpy as np +from scipy.signal import oaconvolve + +from elephant.conversion import BinnedSpikeTrain + + +def total_spiking_probability_edges( + spike_trains: BinnedSpikeTrain, + surrounding_window_sizes: Optional[List[int]] = None, + observed_window_sizes: Optional[List[int]] = None, + crossover_window_sizes: Optional[List[int]] = None, + max_delay: int = 25, + normalize: bool = False, +): + r""" + Use total spiking probability edges (TSPE) to estimate + the functional connectivity and delay-times of a neural-network. + + This algorithm uses a normalized cross correlation between pairs of + spiketrains at different delay-times to get a cross-correlogram. + Afterwards a series of convolutions with multiple edge-filters + on the cross-correlogram are preformed, in order to estimate the + connectivity between neurons and thus allowing the discrimination + between inhibitory and excitatory effects. + + The default window-sizes and max-delay were optimized using + in-silico generated spiketrains. + + *Background:* + + - On an excitatory connection the spikerate increases and decreases again + due to the refractory period which results in local maxima in the + cross-correlogram followed by downwards slope + + - On an inhibitory connection the spikerate decreases and after refractory + period, increases again which results in lokal minima surrounded by high + values in the cross-correlogram. + + - An Edge-Filter can be used to interpret the cross-correlogram and + accentuate the lokal Maxima and Minima + + *Procedure:* + + 1) Compute normalized cross-correlation :math:`NCC` of spiketrains of all + Neuronpairs + 2) Convolve :math:`NCC` with Edge-Filter :math:`g_{i}` to compute + :math:`SPE` + 3) Convolve :math:`SPE` with corresponding Running-Total-Filter + :math:`h_{i}` to account for different lengths after convolution with + Edge-Filter + 4) Compute :math:`TSPE` using the sum of all :math:`SPE` for all different + filterpairs + 5) Compute connectivitymatrix by using the index of the tspe-values with + the highest absolute values + + *Normalized Cross-Correlation:* + + .. math :: + + NCC_{XY}(d) = \frac{1}{N} \sum_{i=-\infty}^{\infty}{ \frac{ (y_{(i)} - + \bar{y}) \cdot (x_{(i-d)} - \bar{x}) }{ \sigma_x \cdot \sigma_y }} + + *Spiking Probability Edges* + + .. math :: + SPE_{X \rightarrow Y(d)} = NCC_{XY}(d) * g(i) + + *Total Spiking Probability Edges:* + + .. math :: + TSPE_{X \rightarrow Y}(d) = \sum_{n=1}^{N_a \cdot N_b \cdot N_c} + {SPE_{X \rightarrow Y}^{(n)}(d) * h(i)^{(n)} } + + :cite:`functional_connectivity-de_blasi19_169` + + Parameters + ---------- + spike_trains : (N, ) elephant.conversion.BinnedSpikeTrain + A binned spike train containing all neurons for connectivity estimation + surrounding_window_sizes : List[int], default = [3, 4, 5, 6, 7, 8] + Array of window-sizes for the surrounding area of the point of + interest. 
+    """
+
+    if not surrounding_window_sizes:
+        surrounding_window_sizes = [3, 4, 5, 6, 7, 8]
+
+    if not observed_window_sizes:
+        observed_window_sizes = [2, 3, 4, 5, 6]
+
+    if not crossover_window_sizes:
+        crossover_window_sizes = [0]
+
+    n_neurons, n_bins = spike_trains.shape
+
+    filter_pairs = generate_filter_pairs(
+        surrounding_window_sizes, observed_window_sizes, crossover_window_sizes
+    )
+
+    # Calculate the normalized cross-correlation for different delays.
+    # The delay range spans from 0 to max_delay and includes
+    # padding for the filter convolution
+    max_padding = max(surrounding_window_sizes) + max(crossover_window_sizes)
+    delay_times = list(range(-max_padding, max_delay + max_padding))
+    NCC_d = normalized_cross_correlation(spike_trains, delay_times=delay_times)
+
+    # Normalize to counter network bursts
+    if normalize:
+        for delay_time in delay_times:
+            NCC_d[:, :, delay_time] /= np.sum(
+                NCC_d[:, :, delay_time][~np.identity(NCC_d.shape[0],
+                                                     dtype=bool)]
+            )
+
+    # Apply the edge and running-total filters
+    tspe_matrix = np.zeros((n_neurons, n_neurons, max_delay))
+    for filter_pair in filter_pairs:
+        # Select the NCC window based on the needed filter padding
+        NCC_window = NCC_d[
+            :,
+            :,
+            max_padding
+            - filter_pair.needed_padding: max_delay
+            + max_padding
+            + filter_pair.needed_padding,
+        ]
+
+        # Compute two convolutions with the edge and running-total filters
+        x1 = oaconvolve(
+            NCC_window, np.expand_dims(filter_pair.edge_filter, (0, 1)),
+            mode="valid", axes=2
+        )
+        x2 = oaconvolve(
+            x1, np.expand_dims(filter_pair.running_total_filter, (0, 1)),
+            mode="full", axes=2
+        )
+
+        tspe_matrix += x2
+
+    # Take the maxima of the absolute values over all delays to get the
+    # connectivity estimate
+    connectivity_matrix_index = np.argmax(np.abs(tspe_matrix),
+                                          axis=2, keepdims=True)
+    connectivity_matrix = np.take_along_axis(tspe_matrix,
+                                             connectivity_matrix_index, axis=2
+                                             ).squeeze(axis=2)
+    delay_matrix = connectivity_matrix_index.squeeze()
+
+    return connectivity_matrix, delay_matrix
+
+
+def normalized_cross_correlation(
+    spike_trains: BinnedSpikeTrain,
+    delay_times: Union[int, List[int], Iterable[int]] = 0,
+) -> np.ndarray:
+    r"""
+    Normalized cross-correlation using the standard deviation.
+
+    Computes the normalized cross-correlation between all spike trains
+    inside a BinnedSpikeTrain object at the given delay times.
+
+    The underlying formula is:
+
+    .. math::
+
+        NCC_{X \rightarrow Y}(d) = \frac{1}{N_{bins}}
+        \sum_{i=-\infty}^{\infty}{ \frac{(y_{(i)} - \bar{y}) \cdot
+        (x_{(i-d)} - \bar{x})}{\sigma_x \cdot \sigma_y}}
+
+    The subtraction of the mean values is omitted, since it offers little
+    added accuracy but increases the computation time immensely.
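+
+    As an illustrative sketch (not part of the algorithm itself): for
+    binary spike counts, the delayed cross-correlation before
+    normalization reduces to a sliced dot product that counts aligned
+    spikes, e.g. with plain NumPy arrays:
+
+    >>> import numpy as np
+    >>> x = np.array([0, 1, 1, 1, 0, 0])
+    >>> y = np.array([0, 0, 1, 1, 1, 0])  # x delayed by one bin
+    >>> d = 1
+    >>> int(y[d:] @ x[:-d])  # three spike pairs align at delay d = 1
+    3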
+ """ + + n_neurons, n_bins = spike_trains.shape + + # Get sparse array of BinnedSpikeTrain + spike_trains_array = spike_trains.sparse_matrix + + # Get std deviation of spike trains + spike_trains_zeroed = spike_trains_array - spike_trains_array.mean(axis=1) + spike_trains_std = np.std(spike_trains_zeroed, ddof=1, axis=1) + std_factors = spike_trains_std @ spike_trains_std.T + + # Loop over delay times + if isinstance(delay_times, int): + delay_times = [delay_times] + elif isinstance(delay_times, list): + pass + elif isinstance(delay_times, Iterable): + delay_times = list(delay_times) + + NCC_d = np.zeros((len(delay_times), n_neurons, n_neurons)) + + for index, delay_time in enumerate(delay_times): + # Uses theoretical zero-padding for shifted values, + # but since $0 \cdot x = 0$ values can simply be omitted + if delay_time == 0: + CC = spike_trains_array[:, :] @ spike_trains_array[:, : + ].transpose() + + elif delay_time > 0: + CC = ( + spike_trains_array[:, delay_time:] + @ spike_trains_array[:, :-delay_time].transpose() + ) + + else: + CC = ( + spike_trains_array[:, :delay_time] + @ spike_trains_array[:, -delay_time:].transpose() + ) + + # Convert CC to dense matrix before performing the division + CC = CC.toarray() + # Normalize using std deviation + NCC = CC / std_factors / n_bins + + # Compute cross correlation at given delay time + NCC_d[index, :, :] = NCC + + # Move delay_time axis to back of array + # Makes index using neurons more intuitive → (n_neuron, n_neuron, + # delay_times) + NCC_d = np.moveaxis(NCC_d, 0, -1) + + return NCC_d + + +def generate_edge_filter( + surrounding_window_size: int, + observed_window_size: int, + crossover_window_size: int, +) -> np.ndarray: + r"""Generate an edge filter + + The edge filter is generated using following piecewise defined function: + + a = surrounding_window_size + b = observed_window_size + c = crossover_window_size + + .. 
+    """
+    filter_length = (
+        (2 * surrounding_window_size)
+        + observed_window_size
+        + (2 * crossover_window_size)
+    )
+    i = np.arange(1, filter_length + 1, dtype=np.float64)
+
+    conditions = [
+        (i > 0) & (i <= surrounding_window_size),
+        (i > (surrounding_window_size + crossover_window_size))
+        & (i <= surrounding_window_size + observed_window_size
+           + crossover_window_size),
+        (
+            i
+            > surrounding_window_size
+            + observed_window_size
+            + (2 * crossover_window_size)
+        )
+        & (
+            i
+            <= (2 * surrounding_window_size)
+            + observed_window_size
+            + (2 * crossover_window_size)
+        ),
+    ]
+
+    values = [
+        -(1 / surrounding_window_size),
+        2 / observed_window_size,
+        -(1 / surrounding_window_size),
+        0,  # default value where no condition matches
+    ]
+
+    edge_filter = np.piecewise(i, conditions, values)
+
+    return edge_filter
+
+
+def generate_running_total_filter(observed_window_size: int) -> np.ndarray:
+    return np.ones(observed_window_size)
+
+
+class TspeFilterPair(NamedTuple):
+    edge_filter: np.ndarray
+    running_total_filter: np.ndarray
+    needed_padding: int
+    surrounding_window_size: int
+    observed_window_size: int
+    crossover_window_size: int
+
+
+def generate_filter_pairs(
+    surrounding_window_sizes: List[int],
+    observed_window_sizes: List[int],
+    crossover_window_sizes: List[int],
+) -> List[TspeFilterPair]:
+    """Generate pairs of edge and running-total filters for all
+    combinations (the Cartesian product) of the given window sizes.
+    """
+    filter_pairs = []
+
+    for _a, _b, _c in itertools.product(
+        surrounding_window_sizes, observed_window_sizes, crossover_window_sizes
+    ):
+        edge_filter = generate_edge_filter(_a, _b, _c)
+        running_total_filter = generate_running_total_filter(_b)
+        needed_padding = _a + _c
+        filter_pair = TspeFilterPair(
+            edge_filter, running_total_filter, needed_padding, _a, _b, _c
+        )
+        filter_pairs.append(filter_pair)
+
+    return filter_pairs
diff --git a/elephant/test/test_total_spiking_probability_edges.py b/elephant/test/test_total_spiking_probability_edges.py
new file mode 100644
index 000000000..60bfa5e2f
--- /dev/null
+++ b/elephant/test/test_total_spiking_probability_edges.py
@@ -0,0 +1,219 @@
+import unittest
+from pathlib import Path
+from typing import Tuple, Union
+
+from neo import SpikeTrain
+import numpy as np
+from quantities import millisecond as ms
+from scipy.io import loadmat
+
+from elephant.conversion import BinnedSpikeTrain
+from elephant.functional_connectivity_src.total_spiking_probability_edges \
+    import (generate_filter_pairs,
+            normalized_cross_correlation,
+            TspeFilterPair,
+            total_spiking_probability_edges,
+            )
+
+from elephant.datasets import download_datasets
+
+
+class TotalSpikingProbabilityEdgesTestCase(unittest.TestCase):
+    def test_generate_filter_pairs(self):
+        a = [1]
+        b = [1]
+        c = [1]
+        test_output = [
+            TspeFilterPair(
+                edge_filter=np.array([-1.0, 0.0, 2.0, 0.0, -1.0]),
+                running_total_filter=np.array([1.0]),
+                needed_padding=2,
+                surrounding_window_size=1,
+                observed_window_size=1,
+                crossover_window_size=1,
+            )
+        ]
+
+        function_output = generate_filter_pairs(a, b, c)
+
+        for filter_pair_function, filter_pair_test in zip(function_output,
+                                                          test_output):
+            np.testing.assert_array_equal(
+                filter_pair_function.edge_filter,
+                filter_pair_test.edge_filter)
+
+            np.testing.assert_array_equal(
+                filter_pair_function.running_total_filter,
+                filter_pair_test.running_total_filter)
+
+            self.assertEqual(filter_pair_function.needed_padding,
+                             filter_pair_test.needed_padding)
+
+            self.assertEqual(filter_pair_function.surrounding_window_size,
+                             filter_pair_test.surrounding_window_size)
+
+            self.assertEqual(filter_pair_function.observed_window_size,
+                             filter_pair_test.observed_window_size)
+
+            self.assertEqual(filter_pair_function.crossover_window_size,
+                             filter_pair_test.crossover_window_size)
+
+    def test_normalized_cross_correlation(self):
+        # Generate spike trains
+        delay_time = 5
+        spike_times = [3, 4, 5] * ms
+        spike_times_delayed = spike_times + delay_time * ms
+
+        spiketrains = BinnedSpikeTrain(
+            [SpikeTrain(spike_times, t_stop=20.0 * ms),
+             SpikeTrain(spike_times_delayed, t_stop=20.0 * ms)],
+            bin_size=1 * ms,
+        )
+
+        test_output = np.array([[[0.0, 0.0], [1.1, 0.0]], [[0.0, 1.1],
+                                                           [0.0, 0.0]]])
+
+        function_output = normalized_cross_correlation(
+            spiketrains, [-delay_time, delay_time]
+        )
+
+        self.assertTrue(np.allclose(function_output, test_output, rtol=0.1))
+
+    def test_total_spiking_probability_edges(self):
+        files = ["SW/new_sim0_100.mat",
+                 "BA/new_sim0_100.mat",
+                 "CA/new_sim0_100.mat",
+                 "ER05/new_sim0_100.mat",
+                 "ER10/new_sim0_100.mat",
+                 "ER15/new_sim0_100.mat",
+                 ]
+
+        for datafile in files:
+            repo_base_path = 'unittest/functional_connectivity/' \
+                             'total_spiking_probability_edges/data/'
+            downloaded_dataset_path = download_datasets(repo_base_path +
+                                                        datafile)
+
+            spiketrains, original_data = load_spike_train_simulated(
+                downloaded_dataset_path)
+
+            connectivity_matrix, delay_matrix = \
+                total_spiking_probability_edges(spiketrains)
+
+            # Remove self-connections
+            np.fill_diagonal(connectivity_matrix, 0)
+
+            _, _, _, auc = roc_curve(connectivity_matrix, original_data)
+
+            self.assertGreater(auc, 0.95)
+
+
+# ====== HELPER FUNCTIONS ======
+
+
+def classify_connections(connectivity_matrix: np.ndarray, threshold: int):
+    connectivity_matrix_binarized = connectivity_matrix.copy()
+
+    mask_excitatory = connectivity_matrix_binarized > threshold
+    mask_inhibitory = connectivity_matrix_binarized < -threshold
+
+    mask_left = ~(mask_excitatory | mask_inhibitory)
+
+    connectivity_matrix_binarized[mask_excitatory] = 1
+    connectivity_matrix_binarized[mask_inhibitory] = -1
+    connectivity_matrix_binarized[mask_left] = 0
+
+    return connectivity_matrix_binarized
+
+
+def confusion_matrix(estimate, original, threshold: int = 1):
+    """
+    Definition:
+    - TP: matches of existing connections are true positives
+    - FP: mismatches (an estimated connection where none exists) are
+      false positives
+    - TN: matches for non-existing synapses are true negatives
+    - FN: mismatches (a missed existing connection) are false negatives
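+
+    An illustrative, hand-picked example: the estimate below is already
+    ternary and is used as-is, while the original is first binarized
+    with the default threshold:
+
+    >>> import numpy as np
+    >>> estimate = np.array([[0, 1], [-1, 0]])
+    >>> original = np.array([[0, 2], [0, 0]])
+    >>> [int(x) for x in confusion_matrix(estimate, original)]
+    [1, 2, 1, 0]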
+ """ + if not np.all(np.isin([-1, 0, 1], np.unique(estimate))): + estimate = classify_connections(estimate, threshold) + if not np.all(np.isin([-1, 0, 1], np.unique(original))): + original = classify_connections(original, threshold) + + TP = (np.not_equal(estimate, 0) & np.not_equal(original, 0)).sum() + + TN = (np.equal(estimate, 0) & np.equal(original, 0)).sum() + + FP = (np.not_equal(estimate, 0) & np.equal(original, 0)).sum() + + FN = (np.equal(estimate, 0) & np.not_equal(original, 0)).sum() + + return TP, TN, FP, FN + + +def fall_out(TP: int, TN: int, FP: int, FN: int): + FPR = FP / (FP + TN) + return FPR + + +def sensitivity(TP: int, TN: int, FP: int, FN: int): + TPR = TP / (TP + FN) + return TPR + + +def roc_curve(estimate, original): + tpr_list = [] + fpr_list = [] + + max_threshold = max(np.max(np.abs(estimate)), 1) + + thresholds = np.linspace(max_threshold, 0, 30) + + for t in thresholds: + conf_matrix = confusion_matrix(estimate, original, threshold=t) + + tpr_list.append(sensitivity(*conf_matrix)) + fpr_list.append(fall_out(*conf_matrix)) + + auc = np.trapz(tpr_list, fpr_list) + + return tpr_list, fpr_list, thresholds, auc + + +def load_spike_train_simulated(path: Union[Path, str], bin_size=None, + t_stop=None, + ) -> Tuple[BinnedSpikeTrain, np.ndarray]: + if isinstance(path, str): + path = Path(path) + + if not bin_size: + bin_size = 1 * ms + + data = loadmat(path, simplify_cells=True)["data"] + + if "asdf" not in data: + raise ValueError('Incorrect Dataformat: Missing spiketrain_data in' + '"asdf"') + + spiketrain_data = data["asdf"] + + # Get number of electrodesa and recording_duration from last element of + # data array + n_electrodes, recording_duration_ms = spiketrain_data[-1] + recording_duration_ms = recording_duration_ms * ms + + # Create spiketrains + spiketrains = [] + for spiketrain_raw in spiketrain_data[0:n_electrodes]: + spiketrains.append( + SpikeTrain( + spiketrain_raw * ms, + t_stop=recording_duration_ms, + ) + ) + + spiketrains = BinnedSpikeTrain(spiketrains, bin_size=bin_size, + t_stop=t_stop or recording_duration_ms) + + # Load original_data + original_data = data['SWM'].T + + return spiketrains, original_data