Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpenEphysBinaryRawIO: Fixing ttl multichan #1603

Merged
merged 6 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 53 additions & 20 deletions neo/rawio/openephysbinaryrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def _parse_header(self):
if name + "_npy" in info:
data = np.load(info[name + "_npy"], mmap_mode="r")
info[name] = data

# check that events have timestamps
assert "timestamps" in info, "Event stream does not have timestamps!"
# Updates for OpenEphys v0.6:
Expand Down Expand Up @@ -253,30 +252,64 @@ def _parse_header(self):
# 'states' was introduced in OpenEphys v0.6. For previous versions, events used 'channel_states'
if "states" in info or "channel_states" in info:
states = info["channel_states"] if "channel_states" in info else info["states"]

if states.size > 0:
timestamps = info["timestamps"]
labels = info["labels"]
rising = np.where(states > 0)[0]
falling = np.where(states < 0)[0]

# infer durations
# Identify unique channels based on state values
channels = np.unique(np.abs(states))

rising_indices = []
falling_indices = []

# all channels are packed into the same `states` array.
# So the states array includes positive and negative values for each channel:
# for example channel one rising would be +1 and channel one falling would be -1,
# channel two rising would be +2 and channel two falling would be -2, etc.
# This is the case for sure for version >= 0.6.x.
for channel in channels:
# Find rising and falling edges for each channel
vigji marked this conversation as resolved.
Show resolved Hide resolved
rising = np.where(states == channel)[0]
falling = np.where(states == -channel)[0]

# Ensure each rising has a corresponding falling
if rising.size > 0 and falling.size > 0:
if rising[0] > falling[0]:
falling = falling[1:]
if rising.size > falling.size:
rising = rising[:-1]

# ensure that the number of rising and falling edges are the same:
if len(rising) != len(falling):
warn(
f"Channel {channel} has {len(rising)} rising edges and "
f"{len(falling)} falling edges. The number of rising and "
f"falling edges should be equal. Skipping events from this channel."
)
continue

rising_indices.extend(rising)
falling_indices.extend(falling)
vigji marked this conversation as resolved.
Show resolved Hide resolved

rising_indices = np.array(rising_indices)
falling_indices = np.array(falling_indices)

# Sort the indices to maintain chronological order
sorted_order = np.argsort(rising_indices)
rising_indices = rising_indices[sorted_order]
falling_indices = falling_indices[sorted_order]

durations = None
if len(states) > 0:
# make sure first event is rising and last is falling
if states[0] < 0:
falling = falling[1:]
if states[-1] > 0:
rising = rising[:-1]

if len(rising) == len(falling):
durations = timestamps[falling] - timestamps[rising]
if not self._use_direct_evt_timestamps:
timestamps = timestamps / info["sample_rate"]
durations = durations / info["sample_rate"]

info["rising"] = rising
info["timestamps"] = timestamps[rising]
info["labels"] = labels[rising]
# if len(rising_indices) == len(falling_indices):
durations = timestamps[falling_indices] - timestamps[rising_indices]
if not self._use_direct_evt_timestamps:
timestamps = timestamps / info["sample_rate"]
durations = durations / info["sample_rate"]

info["rising"] = rising_indices
info["timestamps"] = timestamps[rising_indices]
info["labels"] = labels[rising_indices]
info["durations"] = durations

# no spike read yet
Expand Down
21 changes: 21 additions & 0 deletions neo/test/rawiotest/test_openephysbinaryrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from neo.rawio.openephysbinaryrawio import OpenEphysBinaryRawIO
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO

import numpy as np


class TestOpenEphysBinaryRawIO(BaseTestRawIO, unittest.TestCase):
rawioclass = OpenEphysBinaryRawIO
Expand Down Expand Up @@ -57,6 +59,25 @@ def test_missing_folders(self):
)
rawio.parse_header()

def test_multiple_ttl_events_parsing(self):
rawio = OpenEphysBinaryRawIO(
self.get_local_path("openephysbinary/v0.6.x_neuropixels_with_sync"), load_sync_channel=False
)
rawio.parse_header()
rawio.header = rawio.header
# Testing co
# This is the TTL events from the NI Board channel
ttl_events = rawio._evt_streams[0][0][1]
assert "rising" in ttl_events.keys()
assert "labels" in ttl_events.keys()
assert "durations" in ttl_events.keys()
assert "timestamps" in ttl_events.keys()

# Check that durations of different event streams are correctly parsed:
assert np.allclose(ttl_events["durations"][ttl_events["labels"] == "1"], 0.5, atol=0.001)
assert np.allclose(ttl_events["durations"][ttl_events["labels"] == "6"], 0.025, atol=0.001)
assert np.allclose(ttl_events["durations"][ttl_events["labels"] == "7"], 0.016666, atol=0.001)


if __name__ == "__main__":
unittest.main()
Loading