memmap refactoring #1265

Draft
wants to merge 1 commit into base: master
39 changes: 31 additions & 8 deletions neo/rawio/openephysbinaryrawio.py
@@ -20,6 +20,7 @@

from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
_spike_channel_dtype, _event_channel_dtype)
from .utils import create_memmap_buffer, get_memmap_shape


class OpenEphysBinaryRawIO(BaseRawIO):
@@ -133,8 +134,8 @@ def _parse_header(self):
for seg_index in range(nb_segment_per_block[block_index]):
for stream_index, d in self._sig_streams[block_index][seg_index].items():
num_channels = len(d['channels'])
memmap_sigs = np.memmap(d['raw_filename'], d['dtype'],
order='C', mode='r').reshape(-1, num_channels)
#~ memmap_sigs = np.memmap(d['raw_filename'], d['dtype'],
#~ order='C', mode='r').reshape(-1, num_channels)
channel_names = [ch["channel_name"] for ch in d["channels"]]
# if there is a sync channel and it should not be loaded,
# find the right channel index and slice the memmap
@@ -145,12 +146,19 @@

# only sync channel in last position is supported to keep memmap
if sync_channel_index == num_channels - 1:
memmap_sigs = memmap_sigs[:, :-1]
#~ memmap_sigs = memmap_sigs[:, :-1]
#~ pass
d['remove_last_channel'] = True
else:
raise NotImplementedError("SYNC channel removal is only supported "
"when the sync channel is in the last "
"position")
d['memmap'] = memmap_sigs
else:
d['remove_last_channel'] = False
# d['memmap'] = memmap_sigs
shape = get_memmap_shape(d['raw_filename'], d['dtype'], num_channels=num_channels)
fid = open(d['raw_filename'], mode="rb")
d['memmap_args'] = (fid, shape, np.dtype(d['dtype']), 0)


# events zone
@@ -248,7 +256,9 @@ def _parse_header(self):
# loop over signals
for stream_index, d in self._sig_streams[block_index][seg_index].items():
t_start = d['t_start']
dur = d['memmap'].shape[0] / float(d['sample_rate'])
#~ dur = d['memmap'].shape[0] / float(d['sample_rate'])
memmap_sigs = create_memmap_buffer(*d['memmap_args'])
dur = memmap_sigs.shape[0] / float(d['sample_rate'])
t_stop = t_start + dur
if global_t_start is None or global_t_start > t_start:
global_t_start = t_start
@@ -327,6 +337,14 @@ def _parse_header(self):
arr_ann = arr_ann[selected_indices]
ev_ann['__array_annotations__'][k] = arr_ann

def __del__(self):
# need an explicit close
for block_index in range(self.header['nb_block']):
for seg_index in range(self.header['nb_segment'][block_index]):
for stream_index, d in self._sig_streams[block_index][seg_index].items():
fid, *_ = d['memmap_args']
fid.close()
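
As a purely illustrative sketch (not part of this diff), the cleanup could be made robust to a failed or partial _parse_header(), in which case the header, _sig_streams, or the stored file handles may not exist yet:

def __del__(self):
    # close the file handles that back the lazily created memmap buffers
    if not hasattr(self, 'header') or not hasattr(self, '_sig_streams'):
        return  # _parse_header() never completed, nothing to close
    for block_index in range(self.header['nb_block']):
        for seg_index in range(self.header['nb_segment'][block_index]):
            for d in self._sig_streams[block_index][seg_index].values():
                fid, *_ = d.get('memmap_args', (None,))
                if fid is not None and not fid.closed:
                    fid.close()

Relying on __del__ at all is fragile during interpreter shutdown, so an explicit close() method would be the safer long-term design.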

def _segment_t_start(self, block_index, seg_index):
return self._t_start_segments[block_index][seg_index]

@@ -343,16 +361,21 @@ def _channels_to_group_id(self, channel_indexes):
return group_id

def _get_signal_size(self, block_index, seg_index, stream_index):
sigs = self._sig_streams[block_index][seg_index][stream_index]['memmap']
return sigs.shape[0]
#~ sigs = self._sig_streams[block_index][seg_index][stream_index]['memmap']
memmap_sigs = create_memmap_buffer(*self._sig_streams[block_index][seg_index][stream_index]['memmap_args'])
return memmap_sigs.shape[0]

def _get_signal_t_start(self, block_index, seg_index, stream_index):
t_start = self._sig_streams[block_index][seg_index][stream_index]['t_start']
return t_start

def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
stream_index, channel_indexes):
sigs = self._sig_streams[block_index][seg_index][stream_index]['memmap']
#~ sigs = self._sig_streams[block_index][seg_index][stream_index]['memmap']
d = self._sig_streams[block_index][seg_index][stream_index]
sigs = create_memmap_buffer(*d['memmap_args'])
if d['remove_last_channel']:
sigs = sigs[:, :-1]
sigs = sigs[i_start:i_stop, :]
if channel_indexes is not None:
sigs = sigs[:, channel_indexes]
34 changes: 25 additions & 9 deletions neo/rawio/spikeglxrawio.py
@@ -52,6 +52,7 @@

from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
_spike_channel_dtype, _event_channel_dtype)
from .utils import create_memmap_buffer, get_memmap_shape

from pathlib import Path
import os
@@ -91,7 +92,8 @@ def _parse_header(self):

nb_segment = np.unique([info['seg_index'] for info in self.signals_info_list]).size

self._memmaps = {}
# self._memmaps = {}
self._memmap_args = {}
self.signals_info_dict = {}
for info in self.signals_info_list:
# key is (seg_index, stream_name)
@@ -100,11 +102,14 @@
self.signals_info_dict[key] = info

# create memmap
data = np.memmap(info['bin_file'], dtype='int16', mode='r', offset=0, order='C')
# this should be (info['sample_length'], info['num_chan'])
# be some file are shorten
data = data.reshape(-1, info['num_chan'])
self._memmaps[key] = data
#~ data = np.memmap(info['bin_file'], dtype='int16', mode='r', offset=0, order='C')
#~ # this should be (info['sample_length'], info['num_chan'])
#~ # be some file are shorten
#~ data = data.reshape(-1, info['num_chan'])
#~ self._memmaps[key] = data
shape = get_memmap_shape(info['bin_file'], 'int16', num_channels= info['num_chan'])
fid = open(info['bin_file'], "rb")
self._memmap_args[key] = (fid, shape, np.dtype('int16'), 0)

# create channel header
signal_streams = []
@@ -182,7 +187,13 @@ def _parse_header(self):
loc = np.concatenate((loc, [[0., 0.]]), axis=0)
for ndim in range(loc.shape[1]):
sig_ann['__array_annotations__'][f'channel_location_{ndim}'] = loc[:, ndim]


def __del__(self):
# need an explicit close
for k, args in self._memmap_args.items():
fid, *_ = args
fid.close()

def _segment_t_start(self, block_index, seg_index):
return 0.

@@ -191,7 +202,9 @@ def _segment_t_stop(self, block_index, seg_index):

def _get_signal_size(self, block_index, seg_index, stream_index):
stream_id = self.header['signal_streams'][stream_index]['id']
memmap = self._memmaps[seg_index, stream_id]
#~ memmap = self._memmaps[seg_index, stream_id]
key = (seg_index, stream_id)
memmap = create_memmap_buffer(*self._memmap_args[key])
Contributor:
Can't you use get_memmap_shape to avoid creating a buffer here?
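
For illustration, the reviewer's suggestion could look roughly like this (a sketch only; it assumes signals_info_dict is keyed the same way as _memmap_args, i.e. by (seg_index, stream_id), and that the stream dtype is always 'int16' as elsewhere in this file):

def _get_signal_size(self, block_index, seg_index, stream_index):
    stream_id = self.header['signal_streams'][stream_index]['id']
    info = self.signals_info_dict[seg_index, stream_id]
    # derive the sample count from the file size, without mapping the file
    shape = get_memmap_shape(info['bin_file'], 'int16', num_channels=info['num_chan'])
    return int(shape[0])

Note that get_memmap_shape reopens the file to read its size, so this trades an mmap creation for an open/seek per call.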

return int(memmap.shape[0])

def _get_signal_t_start(self, block_index, seg_index, stream_index):
@@ -200,7 +213,10 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index):
def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
stream_index, channel_indexes):
stream_id = self.header['signal_streams'][stream_index]['id']
memmap = self._memmaps[seg_index, stream_id]
#~ memmap = self._memmaps[seg_index, stream_id]
key = (seg_index, stream_id)
memmap = create_memmap_buffer(*self._memmap_args[key])

if channel_indexes is None:
if self.load_sync_channel:
channel_selection = slice(None)
34 changes: 34 additions & 0 deletions neo/rawio/utils.py
@@ -0,0 +1,34 @@
import mmap
import numpy as np

def get_memmap_shape(filename, dtype, num_channels=None, offset=0):
dtype = np.dtype(dtype)
with open(filename, mode='rb') as f:
f.seek(0, 2)
flen = f.tell()
bytes = flen - offset
if bytes % dtype.itemsize != 0:
raise ValueError("Size of available data is not a multiple of the data-type size.")
size = bytes // dtype.itemsize
if num_channels is None:
shape = (size,)
else:
shape = (size // num_channels, num_channels)
return shape

def create_memmap_buffer(fid, shape, dtype, offset=0):
"""
Mimics np.memmap, but:
* uses an already opened file object as input, without checking the file size
* handles only the read-only case
This should be faster than creating a np.memmap each time.
"""
dtype = np.dtype(dtype)
size = np.prod(shape, dtype='int64')
nbytes = dtype.itemsize * size
# the mapping must start on an ALLOCATIONGRANULARITY boundary at or before `offset`
# and must be long enough to cover the whole array
start = offset - offset % mmap.ALLOCATIONGRANULARITY
array_offset = offset - start
length = nbytes + array_offset
mmap_buffer = mmap.mmap(fid.fileno(), length, access=mmap.ACCESS_READ, offset=start)
arr = np.ndarray.__new__(np.ndarray, shape, dtype=dtype, buffer=mmap_buffer, offset=array_offset, order='C')
return arr
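
As a rough usage sketch (the filename, channel count, and chunk bounds below are made up for illustration), the two helpers are meant to be used together: compute the shape once from the file size, keep only the open file handle around, and rebuild the cheap mmap-backed array whenever data is actually read:

import numpy as np
from neo.rawio.utils import get_memmap_shape, create_memmap_buffer

binary_file = "continuous.dat"          # hypothetical int16 file with 16 channels
shape = get_memmap_shape(binary_file, 'int16', num_channels=16)   # (n_samples, 16)

fid = open(binary_file, "rb")
try:
    # cheap to create: no size check, read-only mapping
    sigs = create_memmap_buffer(fid, shape, np.dtype('int16'), offset=0)
    chunk = np.array(sigs[0:30000, :])  # copy the slice before the handle is released
finally:
    fid.close()                         # the caller owns the file handle

The slice is copied with np.array so the example does not depend on the lifetime of the underlying mapping.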