diff --git a/neo/rawio/openephysbinaryrawio.py b/neo/rawio/openephysbinaryrawio.py index 27e3a80c9..297514b4e 100644 --- a/neo/rawio/openephysbinaryrawio.py +++ b/neo/rawio/openephysbinaryrawio.py @@ -217,7 +217,6 @@ def _parse_header(self): if name + "_npy" in info: data = np.load(info[name + "_npy"], mmap_mode="r") info[name] = data - # check that events have timestamps assert "timestamps" in info, "Event stream does not have timestamps!" # Updates for OpenEphys v0.6: @@ -253,30 +252,55 @@ def _parse_header(self): # 'states' was introduced in OpenEphys v0.6. For previous versions, events used 'channel_states' if "states" in info or "channel_states" in info: states = info["channel_states"] if "channel_states" in info else info["states"] + if states.size > 0: timestamps = info["timestamps"] labels = info["labels"] - rising = np.where(states > 0)[0] - falling = np.where(states < 0)[0] - # infer durations + # Identify unique channels based on state values + channels = np.unique(np.abs(states)) + + rising_indices = [] + falling_indices = [] + + # all channels are packed into the same `states` array. + # So the states array includes positive and negative values for each channel: + # for example channel one rising would be +1 and channel one falling would be -1, + # channel two rising would be +2 and channel two falling would be -2, etc. + # This is the case for sure for version >= 0.6.x. + for channel in channels: + # Find rising and falling edges for each channel + rising = np.where(states == channel)[0] + falling = np.where(states == -channel)[0] + + # Ensure each rising has a corresponding falling + if rising.size > 0 and falling.size > 0: + if rising[0] > falling[0]: + falling = falling[1:] + if rising.size > falling.size: + rising = rising[:-1] + + rising_indices.extend(rising) + falling_indices.extend(falling) + + rising_indices = np.array(rising_indices) + falling_indices = np.array(falling_indices) + + # Sort the indices to maintain chronological order + sorted_order = np.argsort(rising_indices) + rising_indices = rising_indices[sorted_order] + falling_indices = falling_indices[sorted_order] + durations = None - if len(states) > 0: - # make sure first event is rising and last is falling - if states[0] < 0: - falling = falling[1:] - if states[-1] > 0: - rising = rising[:-1] - - if len(rising) == len(falling): - durations = timestamps[falling] - timestamps[rising] - if not self._use_direct_evt_timestamps: - timestamps = timestamps / info["sample_rate"] - durations = durations / info["sample_rate"] - - info["rising"] = rising - info["timestamps"] = timestamps[rising] - info["labels"] = labels[rising] + if len(rising_indices) == len(falling_indices): + durations = timestamps[falling_indices] - timestamps[rising_indices] + if not self._use_direct_evt_timestamps: + timestamps = timestamps / info["sample_rate"] + durations = durations / info["sample_rate"] + + info["rising"] = rising_indices + info["timestamps"] = timestamps[rising_indices] + info["labels"] = labels[rising_indices] info["durations"] = durations # no spike read yet diff --git a/neo/test/rawiotest/test_openephysbinaryrawio.py b/neo/test/rawiotest/test_openephysbinaryrawio.py index 7df22e93d..0959ddf62 100644 --- a/neo/test/rawiotest/test_openephysbinaryrawio.py +++ b/neo/test/rawiotest/test_openephysbinaryrawio.py @@ -3,6 +3,8 @@ from neo.rawio.openephysbinaryrawio import OpenEphysBinaryRawIO from neo.test.rawiotest.common_rawio_test import BaseTestRawIO +import numpy as np + class TestOpenEphysBinaryRawIO(BaseTestRawIO, unittest.TestCase): rawioclass = OpenEphysBinaryRawIO @@ -57,6 +59,25 @@ def test_missing_folders(self): ) rawio.parse_header() + def test_multiple_ttl_events_parsing(self): + rawio = OpenEphysBinaryRawIO( + self.get_local_path("openephysbinary/v0.6.x_neuropixels_with_sync"), load_sync_channel=False + ) + rawio.parse_header() + rawio.header = rawio.header + # Testing co + # This is the TTL events from the NI Board channel + ttl_events = rawio._evt_streams[0][0][1] + assert "rising" in ttl_events.keys() + assert "labels" in ttl_events.keys() + assert "durations" in ttl_events.keys() + assert "timestamps" in ttl_events.keys() + + # Check that durations of different event streams are correctly parsed: + assert np.allclose(ttl_events["durations"][ttl_events["labels"] == "1"], 0.5, atol=0.001) + assert np.allclose(ttl_events["durations"][ttl_events["labels"] == "6"], 0.025, atol=0.001) + assert np.allclose(ttl_events["durations"][ttl_events["labels"] == "7"], 0.016666, atol=0.001) + if __name__ == "__main__": unittest.main()