Added keep_eid_index kwarg to One.list_datasets
k1o0 committed Oct 25, 2024
1 parent e7e20e3 commit 5f5cc9d
Showing 3 changed files with 31 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@

- prompt user to strip quotation marks if used during ONE setup
- indicate when downloading from S3
- added 'keep_eid_index' kwarg to One.list_datasets which, when true, returns the data frame with the eid index level reinstated

## [2.10.0]
This version improves behaviour of loading revisions and loading datasets from list_datasets output.
46 changes: 25 additions & 21 deletions one/api.py
@@ -712,7 +712,8 @@ def list_subjects(self) -> List[str]:
@util.refresh
def list_datasets(
self, eid=None, filename=None, collection=None, revision=None, qc=QC.FAIL,
ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False
ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False,
keep_eid_index=False
) -> Union[np.ndarray, pd.DataFrame]:
"""
Given an eid, return the datasets for that session.
@@ -748,6 +749,10 @@ def list_datasets(
default_revisions_only : bool
When true, only matching datasets that are considered default revisions are returned.
If no 'default_revision' column is present, an ALFError is raised.
keep_eid_index : bool
If details is true, this determines whether the returned data frame contains the eid
in the index. When false (default) the returned data frame index is the dataset id
only, otherwise the index is a MultiIndex with levels (eid, id).
Returns
-------
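For reference, a minimal usage sketch of the new keyword, assuming a One instance with an already-configured cache directory; the session path is the one used in test_list_datasets below:

from one.api import One

one = One()  # assumes ONE has already been set up with a cache directory
eid = 'KS005/2019-04-02/001'  # session used in test_list_datasets below

# Default: the details frame is indexed by dataset id only
dsets = one.list_datasets(eid, details=True)
assert dsets.index.nlevels == 1

# keep_eid_index=True reinstates the eid as an outer index level
dsets = one.list_datasets(eid, details=True, keep_eid_index=True)
assert dsets.index.nlevels == 2  # MultiIndex with levels (eid, id)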
@@ -798,8 +803,15 @@ def list_datasets(
return datasets.iloc[0:0] # Return empty

datasets = util.filter_datasets(datasets, **filter_args)
# Return only the relative path
return datasets if details else datasets['rel_path'].sort_values().values.tolist()
if details:
if keep_eid_index and datasets.index.nlevels == 1:
# Reinstate eid index
datasets = pd.concat({str(eid): datasets}, names=['eid'])
# Return the full data frame
return datasets
else:
# Return only the relative path
return datasets['rel_path'].sort_values().values.tolist()

@util.refresh
def list_collections(self, eid=None, filename=None, collection=None, revision=None,
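For context, the eid reinstatement above relies on pandas using dict keys as an extra outer index level when concatenating; a stand-alone sketch with made-up values:

import pandas as pd

# A details-style frame indexed by dataset id only (values invented for illustration)
df = pd.DataFrame({'rel_path': ['alf/spikes.times.npy']},
                  index=pd.Index(['d4b8a6c1'], name='id'))

# Wrapping the frame in a single-key dict makes that key an outer index level
wrapped = pd.concat({'some-eid': df}, names=['eid'])
print(wrapped.index.names)    # ['eid', 'id']
print(wrapped.index.nlevels)  # 2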
@@ -1000,7 +1012,8 @@ def load_object(self,
>>> load_object(eid, 'spikes', attribute=['times*', 'clusters'])
"""
query_type = query_type or self.mode
datasets = self.list_datasets(eid, details=True, query_type=query_type)
datasets = self.list_datasets(
eid, details=True, query_type=query_type, keep_eid_index=True)

if len(datasets) == 0:
raise alferr.ALFObjectNotFound(obj)
@@ -1022,9 +1035,6 @@

# For those that don't exist, download them
offline = None if query_type == 'auto' else self.mode == 'local'
if datasets.index.nlevels == 1:
# Reinstate eid index
datasets = pd.concat({str(eid): datasets}, names=['eid'])
files = self._check_filesystem(datasets, offline=offline, check_hash=check_hash)
files = [x for x in files if x]
if not files:
@@ -1112,7 +1122,8 @@ def load_dataset(self,
wildcards/regular expressions must not be used. To use wildcards, pass the collection
and revision as separate keyword arguments.
"""
datasets = self.list_datasets(eid, details=True, query_type=query_type or self.mode)
datasets = self.list_datasets(
eid, details=True, query_type=query_type or self.mode, keep_eid_index=True)
# If only two parts and wildcards are on, append ext wildcard
if self.wildcards and isinstance(dataset, str) and len(dataset.split('.')) == 2:
dataset += '.*'
@@ -1129,9 +1140,6 @@
wildcards=self.wildcards, assert_unique=assert_unique)
if len(datasets) == 0:
raise alferr.ALFObjectNotFound(f'Dataset "{dataset}" not found')
if datasets.index.nlevels == 1:
# Reinstate eid index
datasets = pd.concat({str(eid): datasets}, names=['eid'])

# Check files exist / download remote files
offline = None if query_type == 'auto' else self.mode == 'local'
@@ -1271,7 +1279,8 @@ def _verify_specifiers(specifiers):

# Short circuit
query_type = query_type or self.mode
all_datasets = self.list_datasets(eid, details=True, query_type=query_type)
all_datasets = self.list_datasets(
eid, details=True, query_type=query_type, keep_eid_index=True)
if len(all_datasets) == 0:
if assert_present:
raise alferr.ALFObjectNotFound(f'No datasets found for session {eid}')
@@ -1303,9 +1312,6 @@ def _verify_specifiers(specifiers):
for x, y, z in zip(datasets, collections, revisions)]
present = [len(x) == 1 for x in slices]
present_datasets = pd.concat(slices)
if present_datasets.index.nlevels == 1:
# Reinstate eid index
present_datasets = pd.concat({str(eid): present_datasets}, names=['eid'])

# Check if user is blindly downloading all data and warn of non-default revisions
if 'default_revision' in present_datasets and \
@@ -1463,8 +1469,8 @@ def load_collection(self,
No datasets match the object, attribute or revision filters for this collection.
"""
query_type = query_type or self.mode
datasets = self.list_datasets(eid, details=True, collection=collection,
query_type=query_type)
datasets = self.list_datasets(
eid, details=True, collection=collection, query_type=query_type, keep_eid_index=True)

if len(datasets) == 0:
raise alferr.ALFError(f'{collection} not found for session {eid}')
@@ -1477,9 +1483,6 @@
if len(datasets) == 0:
raise alferr.ALFObjectNotFound(object or '')
parts = [alfiles.rel_path_parts(x) for x in datasets.rel_path]
if datasets.index.nlevels == 1:
# Reinstate eid index
datasets = pd.concat({str(eid): datasets}, names=['eid'])

# For those that don't exist, download them
offline = None if query_type == 'auto' else self.mode == 'local'
@@ -1829,7 +1832,8 @@ def describe_dataset(self, dataset_type=None):
@util.refresh
def list_datasets(
self, eid=None, filename=None, collection=None, revision=None, qc=QC.FAIL,
ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False
ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False,
keep_eid_index=False
) -> Union[np.ndarray, pd.DataFrame]:
filters = dict(
collection=collection, filename=filename, revision=revision, qc=qc,
5 changes: 5 additions & 0 deletions one/tests/test_one.py
@@ -404,6 +404,11 @@ def test_list_datasets(self):
self.assertEqual(27, len(dsets))
self.assertEqual(1, dsets.index.nlevels, 'details data frame should be without eid index')

# Test keep_eid_index parameter
dsets = self.one.list_datasets('KS005/2019-04-02/001', details=True, keep_eid_index=True)
self.assertEqual(27, len(dsets))
self.assertEqual(2, dsets.index.nlevels, 'details data frame should be with eid index')

# Test filters
filename = {'attribute': ['times', 'intervals'], 'extension': 'npy'}
dsets = self.one.list_datasets('ZFM-01935/2021-02-05/001', filename)
