Merge pull request #121 from int-brain-lab/npzLoader
Resolves #120
k1o0 authored Jul 2, 2024
2 parents 4bed1ae + e0db315 commit eb90285
Showing 10 changed files with 45 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
@@ -35,7 +35,7 @@ jobs:
       - name: run tests
         run: |
           flake8 .
-          coverage run --omit=*one/tests/* -m unittest discover
+          coverage run --omit=one/tests/* -m unittest discover
       - name: Upload coverage data to coveralls.io
         run: coveralls --service=github
         env:
11 changes: 10 additions & 1 deletion CHANGELOG.md
@@ -1,5 +1,14 @@
 # Changelog
-## [Latest](https://github.com/int-brain-lab/ONE/commits/main) [2.7.0]
+## [Latest](https://github.com/int-brain-lab/ONE/commits/main) [2.8.0]
+This version of ONE adds support for loading .npz files.
+
+### Modified
+
+- one.alf.io.load_file_content now loads .npz files and returns only the array when the file contains a single compressed array with the default name 'arr_0'
+- log a warning when instantiating RegistrationClient with the AlyxClient REST cache active
+- bugfix in load_collection when one or more files are missing
+
+## [2.7.0]
 This version of ONE adds support for Alyx 2.0.0 and pandas 3.0.0 with dataset QC filters. This version no longer supports the 'data' search filter.

 ### Added
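For context, a minimal sketch of the new .npz behaviour described in the changelog above (the file names are hypothetical; load_file_content also accepts string paths, as the tests further down show):

```python
import numpy as np
from one.alf.io import load_file_content

# np.savez_compressed stores positional arrays under default names 'arr_0', 'arr_1', ...
np.savez_compressed('single.npz', np.arange(5))
np.savez_compressed('multi.npz', x=np.arange(5), y=np.ones(3))

loaded = load_file_content('single.npz')
assert isinstance(loaded, np.ndarray)  # the lone 'arr_0' array is unpacked for you

archive = load_file_content('multi.npz')
assert isinstance(archive, np.lib.npyio.NpzFile)  # several arrays: access by name
y = archive['y']
```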
7 changes: 3 additions & 4 deletions docs/notebooks/one_load/one_load.ipynb
@@ -286,7 +286,7 @@
"revision will be returned. The revisions are ordered lexicographically.\n",
"\n",
"```python\n",
"probe1_spikes = one.load_dataset(eid, 'trials.intervals.npy', revision='2021-03-15a')\n",
"intervals = one.load_dataset(eid, 'trials.intervals.npy', revision='2021-03-15a')\n",
"```\n",
"\n",
"## Download only\n",
@@ -662,8 +662,7 @@
" filepath = one.load_dataset(eid '_ibl_trials.intervals.npy', download_only=True)\n",
" spike_times = one.load_dataset(eid 'spikes.times.npy', collection='alf/probe01')\n",
" old_spikes = one.load_dataset(eid, 'spikes.times.npy',\n",
" collection='alf/probe01', revision='2020-08-31')\n",
"\n"
" collection='alf/probe01', revision='2020-08-31')\n"
]
}
],
@@ -733,4 +732,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}
2 changes: 1 addition & 1 deletion one/__init__.py
@@ -1,2 +1,2 @@
"""The Open Neurophysiology Environment (ONE) API."""
__version__ = '2.7.0'
__version__ = '2.8.0'
4 changes: 4 additions & 0 deletions one/alf/io.py
@@ -307,6 +307,10 @@ def load_file_content(fil):
         return jsonable.read(fil)
     if fil.suffix == '.npy':
         return _ensure_flat(np.load(file=fil, allow_pickle=True))
+    if fil.suffix == '.npz':
+        arr = np.load(file=fil)
+        # If there is a single array with the default name ('arr_0'), return that array alone
+        return arr['arr_0'] if set(arr.files) == {'arr_0'} else arr
     if fil.suffix == '.pqt':
         return parquet.load(fil)[0]
     if fil.suffix == '.ssv':
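A hedged design note on the hunk above: when the archive holds more than one array, np.load returns a lazy NpzFile that keeps the underlying zip handle open until it is closed, so a caller wanting deterministic cleanup might materialise the arrays and close it explicitly (file name hypothetical):

```python
import numpy as np
from one.alf.io import load_file_content

content = load_file_content('several.npz')  # hypothetical multi-array file
if isinstance(content, np.lib.npyio.NpzFile):
    try:
        arrays = {name: content[name] for name in content.files}  # read everything now
    finally:
        content.close()  # release the underlying file handle
else:
    arrays = {'arr_0': content}  # a single array was already unpacked
```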
15 changes: 8 additions & 7 deletions one/api.py
@@ -1357,7 +1357,7 @@ def load_collection(self,
             Query cache ('local') or Alyx database ('remote')
         download_only : bool
             When true the data are downloaded and the file path is returned.
-        **kwargs
+        kwargs
             Additional filters for datasets, including namespace and timescale. For full list
             see the one.alf.spec.describe function.
@@ -1394,18 +1394,19 @@
         if len(datasets) == 0:
             raise alferr.ALFObjectNotFound(object or '')
         parts = [alfiles.rel_path_parts(x) for x in datasets.rel_path]
-        unique_objects = set(x[3] or '' for x in parts)

         # For those that don't exist, download them
         offline = None if query_type == 'auto' else self.mode == 'local'
         files = self._check_filesystem(datasets, offline=offline)
-        files = [x for x in files if x]
-        if not files:
+        if not any(files):
             raise alferr.ALFObjectNotFound(f'ALF collection "{collection}" not found on disk')
+        # Remove missing items
+        files, parts = zip(*[(x, y) for x, y in zip(files, parts) if x])

         if download_only:
             return files

+        unique_objects = set(x[3] or '' for x in parts)
         kwargs.update(wildcards=self.wildcards)
         collection = {
             obj: alfio.load_object([x for x, y in zip(files, parts) if y[3] == obj], **kwargs)
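The load_collection fix is easiest to see with a toy sketch, assuming (as the hunk above implies) that _check_filesystem returns one entry per dataset, with a falsy value for each file missing from disk; the string `parts` values here are stand-ins for the real rel-path part tuples:

```python
# Hypothetical values for illustration only
files = ['/data/alf/spikes.times.npy', None, '/data/alf/clusters.depths.npy']
parts = ['spikes', 'trials', 'clusters']

# Old behaviour: filtering `files` alone knocks it out of register with `parts`,
# so the later zip(files, parts) pairs files with the wrong objects
old_files = [x for x in files if x]
assert len(old_files) == 2 and len(parts) == 3

# New behaviour: filter both sequences together so they stay aligned
files, parts = zip(*[(x, y) for x, y in zip(files, parts) if x])
assert parts == ('spikes', 'clusters')  # 'trials' dropped with its missing file
```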
@@ -1424,7 +1425,7 @@ def setup(cache_dir=None, silent=False, **kwargs):
     silent : (False) bool
         When True will prompt for cache_dir, if cache_dir is None, and overwrite cache if any.
         When False will use cwd for cache_dir, if cache_dir is None, and use existing cache.
-    **kwargs
+    kwargs
         Optional arguments to pass to one.alf.cache.make_parquet_db.

     Returns
@@ -2498,7 +2499,7 @@ def setup(base_url=None, **kwargs):
     ----------
     base_url : str
         An Alyx database URL. If None, the current default database is used.
-    **kwargs
+    kwargs
         Optional arguments to pass to one.params.setup.

     Returns
@@ -2785,7 +2786,7 @@ def _setup(**kwargs):
     Parameters
     ----------
-    **kwargs
+    kwargs
         See one.params.setup.

     Returns
3 changes: 3 additions & 0 deletions one/registration.py
@@ -94,6 +94,9 @@ def __init__(self, one=None):
         self.one = one
         if not one:
             self.one = ONE(cache_rest=None)
+        elif one.alyx.cache_mode == 'GET':
+            _logger.warning('AlyxClient REST cache active; '
+                            'this may cause issues with registration.')
         self.dtypes = list(map(Bunch, self.one.alyx.rest('dataset-types', 'list')))
         self.registration_patterns = [
             dt['filename_pattern'] for dt in self.dtypes if dt['filename_pattern']]
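As a usage note (a sketch under the assumption, per the hunk above, that the warning fires whenever the client's alyx.cache_mode is 'GET'): passing in a ONE instance created with REST caching disabled mirrors what the constructor does by default and avoids the warning:

```python
from one.api import ONE
from one.registration import RegistrationClient

one = ONE(cache_rest=None)        # disable the AlyxClient REST cache
client = RegistrationClient(one)  # no REST-cache warning with caching off
```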
11 changes: 11 additions & 0 deletions one/tests/alf/test_alf_io.py
@@ -520,6 +520,10 @@ def setUp(self) -> None:
         self.xyz = Path(self.tmpdir.name) / 'foo.baz.xyz'
         with open(self.xyz, 'wb') as f:
             f.write(b'\x00\x00')
+        self.npz1 = Path(self.tmpdir.name) / 'foo.baz.npz'
+        np.savez_compressed(self.npz1, np.random.rand(5))
+        self.npz2 = Path(self.tmpdir.name) / 'foo.bar.npz'
+        np.savez_compressed(self.npz2, np.random.rand(5), np.random.rand(5))

     def test_load_file_content(self):
         """Test for one.alf.io.load_file_content"""
@@ -550,6 +554,13 @@ def test_load_file_content(self):
         # Load YAML file
         loaded = alfio.load_file_content(str(self.yaml))
         self.assertCountEqual(loaded.keys(), ['a', 'b'])
+        # Load npz file
+        loaded = alfio.load_file_content(str(self.npz1))
+        self.assertIsInstance(loaded, np.ndarray, 'failed to unpack')
+        self.assertEqual(loaded.shape, (5,))
+        loaded = alfio.load_file_content(str(self.npz2))
+        self.assertIsInstance(loaded, np.lib.npyio.NpzFile, 'failed to return npz array')
+        self.assertEqual(loaded['arr_0'].shape, (5,))

     def tearDown(self) -> None:
         self.tmpdir.cleanup()
1 change: 1 addition & 0 deletions one/tests/test_alyxrest.py
@@ -258,6 +258,7 @@ def test_channels(self):
         self.addCleanup(self.alyx.rest, 'insertions', 'delete', id=probe_insertion['id'])
         trajectory = self.alyx.rest('trajectories', 'create', data={
             'probe_insertion': probe_insertion['id'],
+            'chronic_insertion': None,
             'x': 1500,
             'y': -2000,
             'z': 0,
6 changes: 3 additions & 3 deletions one/tests/test_one.py
@@ -1338,7 +1338,7 @@ def test_list_datasets(self):
         self.one._cache['datasets'] = self.one._cache['datasets'].iloc[0:0].copy()

         dsets = self.one.list_datasets(self.eid, details=True, query_type='remote')
-        self.assertEqual(171, len(dsets))  # this may change after a BWM release or patch
+        self.assertEqual(183, len(dsets))  # this may change after a BWM release or patch
         self.assertEqual(1, dsets.index.nlevels, 'details data frame should be without eid index')

         # Test missing eid
@@ -1355,12 +1355,12 @@
         # Test details=False, with eid
         dsets = self.one.list_datasets(self.eid, details=False, query_type='remote')
         self.assertIsInstance(dsets, list)
-        self.assertEqual(171, len(dsets))  # this may change after a BWM release or patch
+        self.assertEqual(183, len(dsets))  # this may change after a BWM release or patch

         # Test with other filters
         dsets = self.one.list_datasets(self.eid, collection='*probe*', filename='*channels*',
                                        details=False, query_type='remote')
-        self.assertEqual(20, len(dsets))
+        self.assertEqual(24, len(dsets))
         self.assertTrue(all(x in y for x in ('probe', 'channels') for y in dsets))

         with self.assertWarns(Warning):
