diff --git a/CHANGELOG.md b/CHANGELOG.md
index dddf14ff..d9746329 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@
 - prompt user to strip quotation marks if used during ONE setup
 - indicate when downloading from S3
+- added 'keep_eid_index' kwarg to One.list_datasets which will return the data frame with the eid index level reinstated
 
 ## [2.10.0]
 This version improves behaviour of loading revisions and loading datasets from list_datasets output.
diff --git a/one/api.py b/one/api.py
index c375f735..d09c4cad 100644
--- a/one/api.py
+++ b/one/api.py
@@ -712,7 +712,8 @@ def list_subjects(self) -> List[str]:
     @util.refresh
     def list_datasets(
             self, eid=None, filename=None, collection=None, revision=None, qc=QC.FAIL,
-            ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False
+            ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False,
+            keep_eid_index=False
     ) -> Union[np.ndarray, pd.DataFrame]:
         """
         Given an eid, return the datasets for those sessions.
@@ -748,6 +749,10 @@ def list_datasets(
         default_revisions_only : bool
             When true, only matching datasets that are considered default revisions are returned.
            If no 'default_revision' column is present, and ALFError is raised.
+        keep_eid_index : bool
+            If details is true, this determines whether the returned data frame contains the eid
+            in the index. When false (default) the returned data frame index is the dataset id
+            only, otherwise the index is a MultiIndex with levels (eid, id).
 
         Returns
         -------
@@ -798,8 +803,15 @@
             return datasets.iloc[0:0]  # Return empty
         datasets = util.filter_datasets(datasets, **filter_args)
-        # Return only the relative path
-        return datasets if details else datasets['rel_path'].sort_values().values.tolist()
+        if details:
+            if keep_eid_index and datasets.index.nlevels == 1:
+                # Reinstate eid index
+                datasets = pd.concat({str(eid): datasets}, names=['eid'])
+            # Return the full data frame
+            return datasets
+        else:
+            # Return only the relative path
+            return datasets['rel_path'].sort_values().values.tolist()
 
     @util.refresh
     def list_collections(self, eid=None, filename=None, collection=None, revision=None,
@@ -1000,7 +1012,8 @@ def load_object(self,
         >>> load_object(eid, 'spikes', attribute=['times*', 'clusters'])
         """
         query_type = query_type or self.mode
-        datasets = self.list_datasets(eid, details=True, query_type=query_type)
+        datasets = self.list_datasets(
+            eid, details=True, query_type=query_type, keep_eid_index=True)
         if len(datasets) == 0:
             raise alferr.ALFObjectNotFound(obj)
@@ -1022,9 +1035,6 @@ def load_object(self,
         # For those that don't exist, download them
         offline = None if query_type == 'auto' else self.mode == 'local'
-        if datasets.index.nlevels == 1:
-            # Reinstate eid index
-            datasets = pd.concat({str(eid): datasets}, names=['eid'])
         files = self._check_filesystem(datasets, offline=offline, check_hash=check_hash)
         files = [x for x in files if x]
         if not files:
@@ -1112,7 +1122,8 @@ def load_dataset(self,
             wildcards/regular expressions must not be used. To use wildcards, pass the collection
             and revision as separate keyword arguments.
""" - datasets = self.list_datasets(eid, details=True, query_type=query_type or self.mode) + datasets = self.list_datasets( + eid, details=True, query_type=query_type or self.mode, keep_eid_index=True) # If only two parts and wildcards are on, append ext wildcard if self.wildcards and isinstance(dataset, str) and len(dataset.split('.')) == 2: dataset += '.*' @@ -1129,9 +1140,6 @@ def load_dataset(self, wildcards=self.wildcards, assert_unique=assert_unique) if len(datasets) == 0: raise alferr.ALFObjectNotFound(f'Dataset "{dataset}" not found') - if datasets.index.nlevels == 1: - # Reinstate eid index - datasets = pd.concat({str(eid): datasets}, names=['eid']) # Check files exist / download remote files offline = None if query_type == 'auto' else self.mode == 'local' @@ -1271,7 +1279,8 @@ def _verify_specifiers(specifiers): # Short circuit query_type = query_type or self.mode - all_datasets = self.list_datasets(eid, details=True, query_type=query_type) + all_datasets = self.list_datasets( + eid, details=True, query_type=query_type, keep_eid_index=True) if len(all_datasets) == 0: if assert_present: raise alferr.ALFObjectNotFound(f'No datasets found for session {eid}') @@ -1303,9 +1312,6 @@ def _verify_specifiers(specifiers): for x, y, z in zip(datasets, collections, revisions)] present = [len(x) == 1 for x in slices] present_datasets = pd.concat(slices) - if present_datasets.index.nlevels == 1: - # Reinstate eid index - present_datasets = pd.concat({str(eid): present_datasets}, names=['eid']) # Check if user is blindly downloading all data and warn of non-default revisions if 'default_revision' in present_datasets and \ @@ -1463,8 +1469,8 @@ def load_collection(self, No datasets match the object, attribute or revision filters for this collection. """ query_type = query_type or self.mode - datasets = self.list_datasets(eid, details=True, collection=collection, - query_type=query_type) + datasets = self.list_datasets( + eid, details=True, collection=collection, query_type=query_type, keep_eid_index=True) if len(datasets) == 0: raise alferr.ALFError(f'{collection} not found for session {eid}') @@ -1477,9 +1483,6 @@ def load_collection(self, if len(datasets) == 0: raise alferr.ALFObjectNotFound(object or '') parts = [alfiles.rel_path_parts(x) for x in datasets.rel_path] - if datasets.index.nlevels == 1: - # Reinstate eid index - datasets = pd.concat({str(eid): datasets}, names=['eid']) # For those that don't exist, download them offline = None if query_type == 'auto' else self.mode == 'local' @@ -1829,7 +1832,8 @@ def describe_dataset(self, dataset_type=None): @util.refresh def list_datasets( self, eid=None, filename=None, collection=None, revision=None, qc=QC.FAIL, - ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False + ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False, + keep_eid_index=False ) -> Union[np.ndarray, pd.DataFrame]: filters = dict( collection=collection, filename=filename, revision=revision, qc=qc, diff --git a/one/tests/test_one.py b/one/tests/test_one.py index c88351be..dd15fd38 100644 --- a/one/tests/test_one.py +++ b/one/tests/test_one.py @@ -404,6 +404,11 @@ def test_list_datasets(self): self.assertEqual(27, len(dsets)) self.assertEqual(1, dsets.index.nlevels, 'details data frame should be without eid index') + # Test keep_eid_index parameter + dsets = self.one.list_datasets('KS005/2019-04-02/001', details=True, keep_eid_index=True) + self.assertEqual(27, len(dsets)) + self.assertEqual(2, 
+        self.assertEqual(2, dsets.index.nlevels, 'details data frame should be with eid index')
+
         # Test filters
         filename = {'attribute': ['times', 'intervals'], 'extension': 'npy'}
         dsets = self.one.list_datasets('ZFM-01935/2021-02-05/001', filename)
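
For reference, a minimal usage sketch of the new keep_eid_index kwarg (not part of the diff). It assumes a configured ONE instance; the session path is taken from the test above and is only an example, and the (eid, id) index levels follow the docstring added in this change.

from one.api import ONE

# Assumes ONE has already been set up (local cache or Alyx connection).
one = ONE()
eid = 'KS005/2019-04-02/001'  # example session from the test above; substitute your own

# Default behaviour: the details data frame is indexed by dataset id only.
dsets = one.list_datasets(eid, details=True)
assert dsets.index.nlevels == 1

# With keep_eid_index=True the eid index level is reinstated, giving a
# MultiIndex with levels (eid, id), which is what the load_* methods now request internally.
dsets = one.list_datasets(eid, details=True, keep_eid_index=True)
assert dsets.index.nlevels == 2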