Commit
default_revisions_only parameter in One.list_datasets
k1o0 committed Sep 26, 2024
1 parent 124d0b4 commit a543e08
Showing 5 changed files with 40 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@ This version improves behaviour of loading revisions and loading datasets from l

- sub-collections no longer captured when filtering with filename that starts with wildcard in wildcard mode
- bugfix of spurious error raised when loading dataset with a revision provided
- default_revisions_only parameter in One.list_datasets filters non-default datasets

### Added

17 changes: 13 additions & 4 deletions one/api.py
@@ -701,7 +701,7 @@ def list_subjects(self) -> List[str]:
@util.refresh
def list_datasets(
self, eid=None, filename=None, collection=None, revision=None, qc=QC.FAIL,
ignore_qc_not_set=False, details=False, query_type=None
ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False
) -> Union[np.ndarray, pd.DataFrame]:
"""
Given an eid, return the datasets for those sessions.
@@ -734,6 +734,9 @@ def list_datasets(
relative paths (collection/revision/filename) - see one.alf.spec.describe for details.
query_type : str
Query cache ('local') or Alyx database ('remote').
default_revisions_only : bool
When true, only matching datasets that are considered default revisions are returned.
If no 'default_revision' column is present, an ALFError is raised.

Returns
-------
@@ -763,6 +766,11 @@ def list_datasets(
>>> datasets = one.list_datasets(eid, {'object': ['wheel', 'trial?']})
"""
datasets = self._cache['datasets']
if default_revisions_only:
if 'default_revision' not in datasets.columns:
raise alferr.ALFError('No default revisions specified')
datasets = datasets[datasets['default_revision']]

filter_args = dict(
collection=collection, filename=filename, wildcards=self.wildcards, revision=revision,
revision_last_before=False, assert_unique=False, qc=qc,
@@ -1766,11 +1774,11 @@ def describe_dataset(self, dataset_type=None):
@util.refresh
def list_datasets(
self, eid=None, filename=None, collection=None, revision=None, qc=QC.FAIL,
ignore_qc_not_set=False, details=False, query_type=None
ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False
) -> Union[np.ndarray, pd.DataFrame]:
filters = dict(
collection=collection, filename=filename, revision=revision,
qc=qc, ignore_qc_not_set=ignore_qc_not_set)
collection=collection, filename=filename, revision=revision, qc=qc,
ignore_qc_not_set=ignore_qc_not_set, default_revisions_only=default_revisions_only)
if (query_type or self.mode) != 'remote':
return super().list_datasets(eid, details=details, query_type=query_type, **filters)
elif not eid:
@@ -1785,6 +1793,7 @@ def list_datasets(
if datasets is None or datasets.empty:
return self._cache['datasets'].iloc[0:0] if details else [] # Return empty
assert set(datasets.index.unique('eid')) == {eid}
del filters['default_revisions_only']
datasets = util.filter_datasets(
datasets.droplevel('eid'), assert_unique=False, wildcards=self.wildcards, **filters)
# Return only the relative path
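Taken together, the api.py changes let both the local One and the remote OneAlyx paths honour the new flag. Below is a minimal usage sketch, not part of the diff itself; the eid mirrors the session path used in the tests that follow, and it assumes a cache whose datasets table carries a 'default_revision' column (otherwise the ALFError branch is taken).

from one.api import One
import one.alf.exceptions as alferr

one = One()  # local-mode instance; OneAlyx forwards the flag through the same filters
eid = 'KS005/2019-04-02/001'  # illustrative session path, as in the tests below

# Previous behaviour: every matching dataset, default revision or not.
all_dsets = one.list_datasets(eid, 'spikes.*', default_revisions_only=False)

try:
    # New behaviour: only datasets flagged as the default revision are returned.
    default_dsets = one.list_datasets(eid, 'spikes.*', default_revisions_only=True)
except alferr.ALFError:
    # Raised when the cache table has no 'default_revision' column.
    print('cache has no default_revision column')
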
22 changes: 20 additions & 2 deletions one/tests/test_one.py
@@ -426,6 +426,25 @@ def test_list_datasets(self):
self.assertIsInstance(dsets, list)
self.assertTrue(len(dsets) == np.unique(dsets).size)

# Test default_revisions_only=True
with self.assertRaises(alferr.ALFError): # should raise as no 'default_revision' column
self.one.list_datasets('KS005/2019-04-02/001', default_revisions_only=True)
# Add the column and add some alternates
datasets = util.revisions_datasets_table(collections=['alf'], revisions=['', '2023-01-01'])
datasets['default_revision'] = [False, True] * 2
self.one._cache.datasets['default_revision'] = True
self.one._cache.datasets = pd.concat([self.one._cache.datasets, datasets]).sort_index()
eid, *_ = datasets.index.get_level_values(0)
dsets = self.one.list_datasets(eid, 'spikes.*', default_revisions_only=False)
self.assertEqual(4, len(dsets))
dsets = self.one.list_datasets(eid, 'spikes.*', default_revisions_only=True)
self.assertEqual(2, len(dsets))
self.assertTrue(all('#2023-01-01#' in x for x in dsets))
# Should be the same with details=True
dsets = self.one.list_datasets(eid, 'spikes.*', default_revisions_only=True, details=True)
self.assertEqual(2, len(dsets))
self.assertTrue(all('#2023-01-01#' in x for x in dsets.rel_path))

def test_list_collections(self):
"""Test One.list_collections"""
# Test no eid
@@ -1994,8 +2013,7 @@ def test_revision_last_before(self):
"""Test one.util.filter_revision_last_before"""
datasets = util.revisions_datasets_table()
df = datasets[datasets.rel_path.str.startswith('alf/probe00')].copy()
verifiable = filter_revision_last_before(df,
revision='2020-09-01', assert_unique=False)
verifiable = filter_revision_last_before(df, revision='2020-09-01', assert_unique=False)
self.assertTrue(len(verifiable) == 2)

# Remove one of the datasets' revisions to test assert unique on mixed revisions
2 changes: 1 addition & 1 deletion one/tests/util.py
@@ -147,7 +147,7 @@ def revisions_datasets_table(collections=('', 'alf/probe00', 'alf/probe01'),
d = {
'rel_path': rel_path,
'session_path': 'subject/1900-01-01/001',
'file_size': None,
'file_size': 0,
'hash': None,
'exists': True,
'qc': 'NOT_SET',
6 changes: 5 additions & 1 deletion one/util.py
@@ -445,7 +445,11 @@ def filter_datasets(
else:
return match

return filter_revision_last_before(match, revision, assert_unique=assert_unique)
match = filter_revision_last_before(match, revision, assert_unique=assert_unique)
if assert_unique and len(match) > 1:
_list = '"' + '", "'.join(match['rel_path']) + '"'
raise alferr.ALFMultipleObjectsFound(_list)
return match


def filter_revision_last_before(
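The extra check in one/util.py tightens assert_unique: it now applies across every row left after the revision filter, not just within each dataset's revision group. A short sketch of the stricter behaviour, under the assumption that the revisions_datasets_table helper (shown in the one/tests/util.py hunk) yields a table compatible with filter_datasets, as the tests above suggest.

import one.alf.exceptions as alferr
from one.util import filter_datasets
from one.tests.util import revisions_datasets_table

# Several objects across collections, each with multiple revisions.
datasets = revisions_datasets_table()

try:
    # More than one dataset survives the last-before-revision selection, so with
    # assert_unique=True the new guard raises, listing every matching rel_path.
    filter_datasets(datasets, filename='spikes.*', wildcards=True, assert_unique=True)
except alferr.ALFMultipleObjectsFound as ex:
    print(ex)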
