From a543e083cf36932a64db4d74e99d60df4b53a533 Mon Sep 17 00:00:00 2001
From: Miles Wells
Date: Wed, 25 Sep 2024 15:53:49 +0300
Subject: [PATCH] default_revisions_only parameter in One.list_datasets

---
 CHANGELOG.md          |  1 +
 one/api.py            | 17 +++++++++++++----
 one/tests/test_one.py | 22 ++++++++++++++++++++--
 one/tests/util.py     |  2 +-
 one/util.py           |  6 +++++-
 5 files changed, 40 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 09a095ad..366acc98 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,7 @@ This version improves behaviour of loading revisions and loading datasets from l
 - sub-collections no longer captured when filtering with filename that starts with wildcard in wildcard mode
 - bugfix of spurious error raised when loading dataset with a revision provided
+- default_revisions_only parameter in One.list_datasets filters non-default datasets
 
 ### Added
diff --git a/one/api.py b/one/api.py
index 2e39dd2d..ab0d1fd7 100644
--- a/one/api.py
+++ b/one/api.py
@@ -701,7 +701,7 @@ def list_subjects(self) -> List[str]:
     @util.refresh
     def list_datasets(
             self, eid=None, filename=None, collection=None, revision=None, qc=QC.FAIL,
-            ignore_qc_not_set=False, details=False, query_type=None
+            ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False
     ) -> Union[np.ndarray, pd.DataFrame]:
         """
         Given an eid, return the datasets for those sessions.
@@ -734,6 +734,9 @@ def list_datasets(
             relative paths (collection/revision/filename) - see one.alf.spec.describe for details.
         query_type : str
             Query cache ('local') or Alyx database ('remote').
+        default_revisions_only : bool
+            When true, only matching datasets that are considered default revisions are returned.
+            If no 'default_revision' column is present, an ALFError is raised.
 
         Returns
         -------
@@ -763,6 +766,11 @@ def list_datasets(
         >>> datasets = one.list_datasets(eid, {'object': ['wheel', 'trial?']})
         """
         datasets = self._cache['datasets']
+        if default_revisions_only:
+            if 'default_revision' not in datasets.columns:
+                raise alferr.ALFError('No default revisions specified')
+            datasets = datasets[datasets['default_revision']]
+
         filter_args = dict(
             collection=collection, filename=filename, wildcards=self.wildcards,
             revision=revision, revision_last_before=False, assert_unique=False, qc=qc,
@@ -1766,11 +1774,11 @@ def describe_dataset(self, dataset_type=None):
     @util.refresh
     def list_datasets(
             self, eid=None, filename=None, collection=None, revision=None, qc=QC.FAIL,
-            ignore_qc_not_set=False, details=False, query_type=None
+            ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False
     ) -> Union[np.ndarray, pd.DataFrame]:
         filters = dict(
-            collection=collection, filename=filename, revision=revision,
-            qc=qc, ignore_qc_not_set=ignore_qc_not_set)
+            collection=collection, filename=filename, revision=revision, qc=qc,
+            ignore_qc_not_set=ignore_qc_not_set, default_revisions_only=default_revisions_only)
         if (query_type or self.mode) != 'remote':
             return super().list_datasets(eid, details=details, query_type=query_type, **filters)
         elif not eid:
@@ -1785,6 +1793,7 @@ def list_datasets(
         if datasets is None or datasets.empty:
             return self._cache['datasets'].iloc[0:0] if details else []  # Return empty
         assert set(datasets.index.unique('eid')) == {eid}
+        del filters['default_revisions_only']
         datasets = util.filter_datasets(
             datasets.droplevel('eid'), assert_unique=False, wildcards=self.wildcards, **filters)
         # Return only the relative path
diff --git a/one/tests/test_one.py b/one/tests/test_one.py
index e487708d..935f75d7 100644
--- a/one/tests/test_one.py
+++ b/one/tests/test_one.py
@@ -426,6 +426,25 @@ def test_list_datasets(self):
         self.assertIsInstance(dsets, list)
         self.assertTrue(len(dsets) == np.unique(dsets).size)
 
+        # Test default_revisions_only=True
+        with self.assertRaises(alferr.ALFError):  # should raise as no 'default_revision' column
+            self.one.list_datasets('KS005/2019-04-02/001', default_revisions_only=True)
+        # Add the column and add some alternates
+        datasets = util.revisions_datasets_table(collections=['alf'], revisions=['', '2023-01-01'])
+        datasets['default_revision'] = [False, True] * 2
+        self.one._cache.datasets['default_revision'] = True
+        self.one._cache.datasets = pd.concat([self.one._cache.datasets, datasets]).sort_index()
+        eid, *_ = datasets.index.get_level_values(0)
+        dsets = self.one.list_datasets(eid, 'spikes.*', default_revisions_only=False)
+        self.assertEqual(4, len(dsets))
+        dsets = self.one.list_datasets(eid, 'spikes.*', default_revisions_only=True)
+        self.assertEqual(2, len(dsets))
+        self.assertTrue(all('#2023-01-01#' in x for x in dsets))
+        # Should be the same with details=True
+        dsets = self.one.list_datasets(eid, 'spikes.*', default_revisions_only=True, details=True)
+        self.assertEqual(2, len(dsets))
+        self.assertTrue(all('#2023-01-01#' in x for x in dsets.rel_path))
+
     def test_list_collections(self):
         """Test One.list_collections"""
         # Test no eid
@@ -1994,8 +2013,7 @@ def test_revision_last_before(self):
         """Test one.util.filter_revision_last_before"""
         datasets = util.revisions_datasets_table()
         df = datasets[datasets.rel_path.str.startswith('alf/probe00')].copy()
-        verifiable = filter_revision_last_before(df,
-                                                 revision='2020-09-01', assert_unique=False)
+        verifiable = filter_revision_last_before(df, revision='2020-09-01', assert_unique=False)
         self.assertTrue(len(verifiable) == 2)
 
         # Remove one of the datasets' revisions to test assert unique on mixed revisions
diff --git a/one/tests/util.py b/one/tests/util.py
index fee30011..c1aa4263 100644
--- a/one/tests/util.py
+++ b/one/tests/util.py
@@ -147,7 +147,7 @@ def revisions_datasets_table(collections=('', 'alf/probe00', 'alf/probe01'),
     d = {
         'rel_path': rel_path,
         'session_path': 'subject/1900-01-01/001',
-        'file_size': None,
+        'file_size': 0,
         'hash': None,
         'exists': True,
         'qc': 'NOT_SET',
diff --git a/one/util.py b/one/util.py
index da8359c5..743b0101 100644
--- a/one/util.py
+++ b/one/util.py
@@ -445,7 +445,11 @@ def filter_datasets(
     else:
         return match
 
-    return filter_revision_last_before(match, revision, assert_unique=assert_unique)
+    match = filter_revision_last_before(match, revision, assert_unique=assert_unique)
+    if assert_unique and len(match) > 1:
+        _list = '"' + '", "'.join(match['rel_path']) + '"'
+        raise alferr.ALFMultipleObjectsFound(_list)
+    return match
 
 
 def filter_revision_last_before(
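
Usage note (not part of the patch): a minimal sketch of how the new parameter could be called once this change is applied. The cache directory path is hypothetical, and the session path and 'spikes.*' filename filter are borrowed from the tests above.

    from one.api import One

    # Assumed local cache location; any directory containing ONE cache tables works here
    one = One(cache_dir='/path/to/cache')
    eid = 'KS005/2019-04-02/001'  # example session path, as used in the tests above

    # All matching datasets, including superseded revisions
    all_datasets = one.list_datasets(eid, 'spikes.*')

    # Only datasets flagged as default revisions; raises ALFError if the
    # cache table has no 'default_revision' column
    default_only = one.list_datasets(eid, 'spikes.*', default_revisions_only=True)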