From 0f8e40299b1df3547a4016ce6af57e232120a0d3 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Wed, 25 Sep 2024 23:24:39 +0300 Subject: [PATCH] Resolves issue #123 and #65 --- CHANGELOG.md | 1 + one/api.py | 77 ++++++++++++++++++++++++++++++------------- one/tests/test_one.py | 39 ++++++++++++++++++---- one/tests/util.py | 11 ++++++- 4 files changed, 99 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 366acc98..a57eb883 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ This version improves behaviour of loading revisions and loading datasets from l - sub-collections no longer captured when filtering with filename that starts with wildcard in wildcard mode - bugfix of spurious error raised when loading dataset with a revision provided - default_revisions_only parameter in One.list_datasets filters non-default datasets +- permit data frame input to One.load_datasets and load precise relative paths provided (instead of default revisions) ### Added diff --git a/one/api.py b/one/api.py index ef30d6ee..3a89fad3 100644 --- a/one/api.py +++ b/one/api.py @@ -28,10 +28,10 @@ import one.alf.files as alfiles import one.alf.exceptions as alferr from .alf.cache import make_parquet_db, DATASETS_COLUMNS, SESSIONS_COLUMNS -from .alf.spec import is_uuid_string, QC +from .alf.spec import is_uuid_string, QC, to_alf from . import __version__ from one.converters import ConversionMixin, session_record2path -import one.util as util +from one import util _logger = logging.getLogger(__name__) __all__ = ['ONE', 'One', 'OneAlyx'] @@ -1138,9 +1138,11 @@ def load_datasets(self, download_only: bool = False, check_hash: bool = True) -> Any: """ - Load datasets for a given session id. Returns two lists the length of datasets. The - first is the data (or file paths if download_data is false), the second is a list of - meta data Bunches. If assert_present is false, missing data will be returned as None. + Load datasets for a given session id. + + Returns two lists the length of datasets. The first is the data (or file paths if + download_data is false), the second is a list of meta data Bunches. If assert_present is + false, missing data will be returned as None. Parameters ---------- @@ -1172,9 +1174,9 @@ def load_datasets(self, Returns ------- list - A list of data (or file paths) the length of datasets + A list of data (or file paths) the length of datasets. list - A list of meta data Bunches. If assert_present is False, missing data will be None + A list of meta data Bunches. If assert_present is False, missing data will be None. Notes ----- @@ -1186,6 +1188,8 @@ def load_datasets(self, revision as separate keyword arguments. - To ensure you are loading the correct revision, use the revisions kwarg instead of relative paths. + - To load an exact revision (i.e. not the last revision before a given date), pass in + a list of relative paths or a data frame. Raises ------ @@ -1228,8 +1232,25 @@ def _verify_specifiers(specifiers): if isinstance(datasets, str): raise TypeError('`datasets` must be a non-string iterable') - # Check input args - collections, revisions = _verify_specifiers([collections, revisions]) + + # Check if rel paths have been used (e.g. the output of list_datasets) + is_frame = isinstance(datasets, pd.DataFrame) + if is_rel_paths := (is_frame or any('/' in x for x in datasets)): + if not (collections, revisions) == (None, None): + raise ValueError( + 'collection and revision kwargs must be None when dataset is a relative path') + if is_frame: + if 'eid' in datasets.index.names: + assert set(datasets.index.get_level_values('eid')) == {eid} + datasets = datasets['rel_path'].tolist() + datasets = list(map(partial(alfiles.rel_path_parts, as_dict=True), datasets)) + if len(datasets) > 0: + # Extract collection and revision from each of the parsed datasets + # None -> '' ensures exact collections and revisions are used in filter + # NB: f user passes in dicts, any collection/revision keys will be ignored. + collections, revisions = zip( + *((x.pop('collection') or '', x.pop('revision') or '') for x in datasets) + ) # Short circuit query_type = query_type or self.mode @@ -1243,16 +1264,25 @@ def _verify_specifiers(specifiers): if len(datasets) == 0: return None, all_datasets.iloc[0:0] # Return empty - # Filter and load missing - if self.wildcards: # Append extension wildcard if 'object.attribute' string - datasets = [x + ('.*' if isinstance(x, str) and len(x.split('.')) == 2 else '') - for x in datasets] + # More input validation + input_types = [(isinstance(x, str), isinstance(x, dict)) for x in datasets] + if not all(map(any, input_types)) or not any(map(all, zip(*input_types))): + raise ValueError('`datasets` must be iterable of only str or only dicts') + if self.wildcards and input_types[0][0]: # if wildcards and input is iter of str + # Append extension wildcard if 'object.attribute' string + datasets = [ + x + ('.*' if isinstance(x, str) and len(x.split('.')) == 2 else '') + for x in datasets + ] + + # Check input args + collections, revisions = _verify_specifiers([collections, revisions]) + # If collections provided in datasets list, e.g. [collection/x.y.z], do not assert unique - validate = not any(('/' if isinstance(d, str) else 'collection') in d for d in datasets) - if not validate and not all(x is None for x in collections + revisions): - raise ValueError( - 'collection and revision kwargs must be None when dataset is a relative path') - ops = dict(wildcards=self.wildcards, assert_unique=validate) + # If not a dataframe, use revision last before (we've asserted no revision in rel_path) + validate = not is_rel_paths + ops = dict( + wildcards=self.wildcards, assert_unique=True, revision_last_before=not is_rel_paths) slices = [util.filter_datasets(all_datasets, x, y, z, **ops) for x, y, z in zip(datasets, collections, revisions)] present = [len(x) == 1 for x in slices] @@ -1260,18 +1290,21 @@ def _verify_specifiers(specifiers): # Check if user is blindly downloading all data and warn of non-default revisions if 'default_revision' in present_datasets and \ - not any(revisions) and not all(present_datasets['default_revision']): + is_rel_paths and not all(present_datasets['default_revision']): old = present_datasets.loc[~present_datasets['default_revision'], 'rel_path'].to_list() warnings.warn( 'The following datasets may have been revised and ' + 'are therefore not recommended for analysis:\n\t' + '\n\t'.join(old) + '\n' - 'To avoid this warning, specify the revision as a kwarg or use load_dataset.' + 'To avoid this warning, specify the revision as a kwarg or use load_dataset.', + alferr.ALFWarning ) if not all(present): - missing_list = ', '.join(x for x, y in zip(datasets, present) if not y) - # FIXME include collection and revision also + missing_list = (x if isinstance(x, str) else to_alf(**x) for x in datasets) + missing_list = ('/'.join(filter(None, [c, f'#{r}#' if r else None, d])) + for c, r, d in zip(collections, revisions, missing_list)) + missing_list = ', '.join(x for x, y in zip(missing_list, present) if not y) message = f'The following datasets are not in the cache: {missing_list}' if assert_present: raise alferr.ALFObjectNotFound(message) diff --git a/one/tests/test_one.py b/one/tests/test_one.py index 2d240f9f..8bdc0901 100644 --- a/one/tests/test_one.py +++ b/one/tests/test_one.py @@ -654,16 +654,43 @@ def test_load_datasets(self): files, meta = self.one.load_datasets(eid, dsets, download_only=True) self.assertTrue(all(isinstance(x, Path) for x in files)) + # Check behaviour when loading with a data frame (undocumented) + eid = '01390fcc-4f86-4707-8a3b-4d9309feb0a1' + datasets = self.one._cache.datasets.loc[([eid],), :].iloc[:3, :] + files, meta = self.one.load_datasets(eid, datasets, download_only=True) + self.assertTrue(all(isinstance(x, Path) for x in files)) + # Should raise when data frame contains a different eid + self.assertRaises(AssertionError, self.one.load_datasets, uuid4(), datasets) + + # Mix of str and dict + # Check download only + dsets = [ + spec.regex(spec.FILE_SPEC).match('_ibl_wheel.position.npy').groupdict(), + '_ibl_wheel.timestamps.npy' + ] + with self.assertRaises(ValueError): + self.one.load_datasets('KS005/2019-04-02/001', dsets, assert_present=False) + # Loading of non default revisions without using the revision kwarg causes user warning. - # With relative paths provided as input, dataset uniqueness validation is supressed. + # With relative paths provided as input, dataset uniqueness validation is suppressed. + eid = self.one._cache.sessions.iloc[0].name datasets = util.revisions_datasets_table( - revisions=('', '2020-01-08'), attributes=('times',)) + revisions=('', '2020-01-08'), attributes=('times',), touch_path=self.one.cache_dir) datasets['default_revision'] = [False, True] * 3 - eid = datasets.iloc[0].name[0] + datasets.index = datasets.index.set_levels([eid], level=0) self.one._cache.datasets = datasets - with self.assertWarns(UserWarning): - self.one.load_datasets(eid, datasets['rel_path'].to_list(), - download_only=True, assert_present=False) + with self.assertWarns(alferr.ALFWarning): + self.one.load_datasets(eid, datasets['rel_path'].to_list(), download_only=True) + + # Ensure that when rel paths are passed, a null collection/revision is not interpreted as + # an ANY. NB this means the output of 'spikes.times.npy' will be different depending on + # weather other datasets in list include a collection or revision. + self.one._cache.datasets = datasets.iloc[:2, :].copy() # only two datasets, one default + (file,), _ = self.one.load_datasets(eid, ['spikes.times.npy', ], download_only=True) + self.assertTrue(file.as_posix().endswith('001/#2020-01-08#/spikes.times.npy')) + (file, _), _ = self.one.load_datasets( + eid, ['spikes.times.npy', 'xx/obj.attr.ext'], download_only=True, assert_present=False) + self.assertTrue(file.as_posix().endswith('001/spikes.times.npy')) # When loading without collections in the dataset list (i.e. just the dataset names) # an exception should be raised when datasets belong to multiple collections. diff --git a/one/tests/util.py b/one/tests/util.py index c1aa4263..b215ea09 100644 --- a/one/tests/util.py +++ b/one/tests/util.py @@ -116,7 +116,8 @@ def setup_test_params(token=False, cache_dir=None): def revisions_datasets_table(collections=('', 'alf/probe00', 'alf/probe01'), revisions=('', '2020-01-08', '2021-07-06'), object='spikes', - attributes=('times', 'waveforems')): + attributes=('times', 'waveforems'), + touch_path=None): """Returns a datasets cache DataFrame containing datasets with revision folders. As there are no revised datasets on the test databases, this function acts as a fixture for @@ -132,6 +133,8 @@ def revisions_datasets_table(collections=('', 'alf/probe00', 'alf/probe01'), An ALF object attributes : tuple A list of ALF attributes + touch_path : pathlib.Path, str + If provided, files are created in this directory. Returns ------- @@ -155,6 +158,12 @@ def revisions_datasets_table(collections=('', 'alf/probe00', 'alf/probe01'), 'id': map(str, (uuid4() for _ in rel_path)) } + if touch_path: + for p in rel_path: + path = Path(touch_path).joinpath(d['session_path'] + '/' + p) + path.parent.mkdir(parents=True, exist_ok=True) + path.touch() + return pd.DataFrame(data=d).astype({'qc': QC_TYPE}).set_index(['eid', 'id'])