Skip to content

Commit

Permalink
Avoid caching trodes timestamps for memory efficiency
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelbray32 committed Oct 3, 2023
1 parent b59d063 commit 94592ad
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 10 deletions.
43 changes: 33 additions & 10 deletions src/spikegadgets_to_nwb/spike_gadgets_raw_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,22 +436,45 @@ def _get_analogsignal_chunk(

return raw_unit16

@functools.lru_cache(maxsize=None)
def get_analogsignal_timestamps(self, i_start, i_stop):
raw_uint8 = self._raw_memmap[
i_start:i_stop, self._timestamp_byte : self._timestamp_byte + 4
]
raw_uint32 = raw_uint8.view("uint8").reshape(-1, 4).view("uint32").reshape(-1)
if not self.interpolate_dropped_packets:
# no interpolation
raw_uint8 = self._raw_memmap[
i_start:i_stop, self._timestamp_byte : self._timestamp_byte + 4
]
raw_uint32 = (
raw_uint8.view("uint8").reshape(-1, 4).view("uint32").reshape(-1)
)
return raw_uint32

if self.interpolate_dropped_packets and self.interpolate_index is None:
# first call in a interpolation iterator, needs to find the dropped packets
# has to run through the entire file to find missing packets
raw_uint8 = self._raw_memmap[
:, self._timestamp_byte : self._timestamp_byte + 4
]
raw_uint32 = (
raw_uint8.view("uint8").reshape(-1, 4).view("uint32").reshape(-1)
)
self.interpolate_index = np.where(np.diff(raw_uint32) == 2)[
0
] # find locations of single dropped packets
self._interpolate_raw_memmap() # interpolates in the memmap
return np.insert(
raw_uint32,
self.interpolate_index + 1,
raw_uint32[self.interpolate_index] + 1,
)[i_start:i_stop]

# subsequent calls in a interpolation iterator don't remake the interpolated memmap, start here
if i_stop is None:
i_stop = self._raw_memmap.shape[0]
raw_uint8 = self._raw_memmap[
i_start:i_stop, self._timestamp_byte : self._timestamp_byte + 4
]
raw_uint32 = raw_uint8.view("uint8").reshape(-1, 4).view("uint32").reshape(-1)
# add +1 to the inserted locations
inserted_locations = np.array(self._raw_memmap.inserted_locations) - i_start + 1
inserted_locations = inserted_locations[
(inserted_locations >= 0) & (inserted_locations < i_stop - i_start)
]
if not len(inserted_locations) == 0:
raw_uint32[inserted_locations] += 1
return raw_uint32

def get_sys_clock(self, i_start, i_stop):
Expand Down
9 changes: 9 additions & 0 deletions src/spikegadgets_to_nwb/tests/test_spikegadgets_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def test_spikegadgets_raw_io_interpolation():
# get the trodes timestamps from each to compare. This also generates the interpolation
trodes_timestamps = neo_io.get_analogsignal_timestamps(0, 10)
trodes_timestamps_dropped = neo_io_dropped.get_analogsignal_timestamps(0, 10)
trodes_timestamps_dropped_secondary = neo_io_dropped.get_analogsignal_timestamps(
0, 10
)

# check that the interpolated memmap returns the same shape value
assert isinstance(neo_io_dropped._raw_memmap, InsertedMemmap)
Expand All @@ -49,6 +52,12 @@ def test_spikegadgets_raw_io_interpolation():
assert np.isclose(
trodes_timestamps, trodes_timestamps_dropped, atol=1e-6, rtol=0
).all()
assert np.isclose(
trodes_timestamps_dropped,
trodes_timestamps_dropped_secondary,
atol=1e-6,
rtol=0,
).all()
# make sure systime behaves expectedly
systime = neo_io.get_sys_clock(0, 10)
systime_dropped = neo_io_dropped.get_sys_clock(0, 10)
Expand Down

0 comments on commit 94592ad

Please sign in to comment.