Merge pull request #146 from int-brain-lab/v2.10.1
v2.10.1
k1o0 authored Oct 30, 2024
2 parents c83b7e6 + 43ce76b commit b6016c9
Showing 7 changed files with 69 additions and 31 deletions.
11 changes: 10 additions & 1 deletion CHANGELOG.md
@@ -1,5 +1,14 @@
# Changelog
## [Latest](https://github.com/int-brain-lab/ONE/commits/main) [2.10.0]
## [Latest](https://github.com/int-brain-lab/ONE/commits/main) [2.10.1]

### Modified

- prompt user to strip quotation marks if used during ONE setup
- indicate when downloading from S3
- added 'keep_eid_index' kwarg to One.list_datasets which, when true, returns the data frame with the eid index level reinstated
- HOTFIX: include Subject/lab part in destination path when downloading from S3

## [2.10.0]
This version improves behaviour of loading revisions and loading datasets from list_datasets output.

### Modified
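For orientation, a minimal usage sketch of the 'keep_eid_index' kwarg noted above. This is a sketch, not library documentation: the cache directory is a placeholder, and the session path is the one used in the tests further down.

```python
from one.api import One

one = One(cache_dir='path/to/cache')  # placeholder cache location
eid = 'KS005/2019-04-02/001'  # session path borrowed from the tests below

# Default behaviour: the details frame is indexed by dataset id only
dsets = one.list_datasets(eid, details=True)
assert dsets.index.nlevels == 1

# keep_eid_index=True reinstates the eid level, giving an (eid, id) MultiIndex
dsets = one.list_datasets(eid, details=True, keep_eid_index=True)
assert dsets.index.nlevels == 2
```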
2 changes: 1 addition & 1 deletion one/__init__.py
@@ -1,2 +1,2 @@
"""The Open Neurophysiology Environment (ONE) API."""
__version__ = '2.10.0'
__version__ = '2.10.1'
55 changes: 31 additions & 24 deletions one/api.py
@@ -712,7 +712,8 @@ def list_subjects(self) -> List[str]:
@util.refresh
def list_datasets(
self, eid=None, filename=None, collection=None, revision=None, qc=QC.FAIL,
ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False
ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False,
keep_eid_index=False
) -> Union[np.ndarray, pd.DataFrame]:
"""
Given an eid, return the datasets for that session.
@@ -748,6 +749,10 @@ def list_datasets(
default_revisions_only : bool
When true, only matching datasets that are considered default revisions are returned.
If no 'default_revision' column is present, an ALFError is raised.
keep_eid_index : bool
If details is true, this determines whether the returned data frame contains the eid
in the index. When false (default), the returned data frame index is the dataset id
only; otherwise the index is a MultiIndex with levels (eid, id).

Returns
-------
@@ -798,8 +803,15 @@ def list_datasets(
return datasets.iloc[0:0] # Return empty

datasets = util.filter_datasets(datasets, **filter_args)
# Return only the relative path
return datasets if details else datasets['rel_path'].sort_values().values.tolist()
if details:
if keep_eid_index and datasets.index.nlevels == 1:
# Reinstate eid index
datasets = pd.concat({str(eid): datasets}, names=['eid'])
# Return the full data frame
return datasets
else:
# Return only the relative path
return datasets['rel_path'].sort_values().values.tolist()

@util.refresh
def list_collections(self, eid=None, filename=None, collection=None, revision=None,
@@ -1000,7 +1012,8 @@ def load_object(self,
>>> load_object(eid, 'spikes', attribute=['times*', 'clusters'])
"""
query_type = query_type or self.mode
datasets = self.list_datasets(eid, details=True, query_type=query_type)
datasets = self.list_datasets(
eid, details=True, query_type=query_type, keep_eid_index=True)

if len(datasets) == 0:
raise alferr.ALFObjectNotFound(obj)
@@ -1022,9 +1035,6 @@

# For those that don't exist, download them
offline = None if query_type == 'auto' else self.mode == 'local'
if datasets.index.nlevels == 1:
# Reinstate eid index
datasets = pd.concat({str(eid): datasets}, names=['eid'])
files = self._check_filesystem(datasets, offline=offline, check_hash=check_hash)
files = [x for x in files if x]
if not files:
@@ -1112,7 +1122,8 @@ def load_dataset(self,
wildcards/regular expressions must not be used. To use wildcards, pass the collection
and revision as separate keyword arguments.
"""
datasets = self.list_datasets(eid, details=True, query_type=query_type or self.mode)
datasets = self.list_datasets(
eid, details=True, query_type=query_type or self.mode, keep_eid_index=True)
# If only two parts and wildcards are on, append ext wildcard
if self.wildcards and isinstance(dataset, str) and len(dataset.split('.')) == 2:
dataset += '.*'
@@ -1129,9 +1140,6 @@
wildcards=self.wildcards, assert_unique=assert_unique)
if len(datasets) == 0:
raise alferr.ALFObjectNotFound(f'Dataset "{dataset}" not found')
if datasets.index.nlevels == 1:
# Reinstate eid index
datasets = pd.concat({str(eid): datasets}, names=['eid'])

# Check files exist / download remote files
offline = None if query_type == 'auto' else self.mode == 'local'
@@ -1271,7 +1279,8 @@ def _verify_specifiers(specifiers):

# Short circuit
query_type = query_type or self.mode
all_datasets = self.list_datasets(eid, details=True, query_type=query_type)
all_datasets = self.list_datasets(
eid, details=True, query_type=query_type, keep_eid_index=True)
if len(all_datasets) == 0:
if assert_present:
raise alferr.ALFObjectNotFound(f'No datasets found for session {eid}')
@@ -1303,9 +1312,6 @@
for x, y, z in zip(datasets, collections, revisions)]
present = [len(x) == 1 for x in slices]
present_datasets = pd.concat(slices)
if present_datasets.index.nlevels == 1:
# Reinstate eid index
present_datasets = pd.concat({str(eid): present_datasets}, names=['eid'])

# Check if user is blindly downloading all data and warn of non-default revisions
if 'default_revision' in present_datasets and \
@@ -1463,8 +1469,8 @@ def load_collection(self,
No datasets match the object, attribute or revision filters for this collection.
"""
query_type = query_type or self.mode
datasets = self.list_datasets(eid, details=True, collection=collection,
query_type=query_type)
datasets = self.list_datasets(
eid, details=True, collection=collection, query_type=query_type, keep_eid_index=True)

if len(datasets) == 0:
raise alferr.ALFError(f'{collection} not found for session {eid}')
@@ -1477,9 +1483,6 @@
if len(datasets) == 0:
raise alferr.ALFObjectNotFound(object or '')
parts = [alfiles.rel_path_parts(x) for x in datasets.rel_path]
if datasets.index.nlevels == 1:
# Reinstate eid index
datasets = pd.concat({str(eid): datasets}, names=['eid'])

# For those that don't exist, download them
offline = None if query_type == 'auto' else self.mode == 'local'
@@ -1829,16 +1832,19 @@ def describe_dataset(self, dataset_type=None):
@util.refresh
def list_datasets(
self, eid=None, filename=None, collection=None, revision=None, qc=QC.FAIL,
ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False
ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False,
keep_eid_index=False
) -> Union[np.ndarray, pd.DataFrame]:
filters = dict(
collection=collection, filename=filename, revision=revision, qc=qc,
ignore_qc_not_set=ignore_qc_not_set, default_revisions_only=default_revisions_only)
if (query_type or self.mode) != 'remote':
return super().list_datasets(eid, details=details, query_type=query_type, **filters)
return super().list_datasets(eid, details=details, keep_eid_index=keep_eid_index,
query_type=query_type, **filters)
elif not eid:
warnings.warn('Unable to list all remote datasets')
return super().list_datasets(eid, details=details, query_type=query_type, **filters)
return super().list_datasets(eid, details=details, keep_eid_index=keep_eid_index,
query_type=query_type, **filters)
eid = self.to_eid(eid) # Ensure we have a UUID str list
if not eid:
return self._cache['datasets'].iloc[0:0] if details else [] # Return empty
@@ -2367,8 +2373,9 @@ def _download_aws(self, dsets, update_exists=True, keep_uuid=None, **_) -> List[
assert record['relative_path'].endswith(dset['rel_path']), \
f'Relative path for dataset {uuid} does not match Alyx record'
source_path = PurePosixPath(record['data_repository_path'], record['relative_path'])
local_path = self.cache_dir.joinpath(alfiles.get_alf_path(source_path))
# Add UUIDs to filenames, if required
source_path = alfiles.add_uuid_string(source_path, uuid)
local_path = self.cache_dir.joinpath(record['relative_path'])
if keep_uuid is True or (keep_uuid is None and self.uuid_filenames is True):
local_path = alfiles.add_uuid_string(local_path, uuid)
local_path.parent.mkdir(exist_ok=True, parents=True)
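To make the S3 hotfix above concrete, here is a minimal, pathlib-only sketch of the corrected destination path. The record values and cache directory are placeholders, and the real method additionally inserts the dataset UUID into the remote filename via alfiles.add_uuid_string.

```python
from pathlib import Path, PurePosixPath

# Placeholder Alyx file record and cache directory, for illustration only
cache_dir = Path('path/to/cache')
record = {
    'data_repository_path': '/data',
    'relative_path': 'cortexlab/Subjects/KS005/2019-04-02/001/alf/spikes.times.npy',
}

# Remote S3 key: the repository path joined with the record's relative path
source_path = PurePosixPath(record['data_repository_path'], record['relative_path'])

# Fixed local destination: cache_dir joined with the full relative path,
# so the lab/Subjects part is preserved (it was dropped before the hotfix)
local_path = cache_dir.joinpath(record['relative_path'])

print(source_path)  # /data/cortexlab/Subjects/KS005/2019-04-02/001/alf/spikes.times.npy
print(local_path)   # e.g. path/to/cache/cortexlab/Subjects/KS005/2019-04-02/001/alf/spikes.times.npy
```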
13 changes: 11 additions & 2 deletions one/params.py
@@ -124,6 +124,7 @@ def setup(client=None, silent=False, make_default=None, username=None, cache_dir
if not silent:
prompt = 'Param %s, current value is ["%s"]:'
par = iopar.as_dict(par_default)
quotes = '"\'`'
# Iterate through non-password pars
for k in filter(lambda k: 'PWD' not in k, par.keys()):
cpar = _get_current_par(k, par_current)
@@ -137,10 +138,18 @@ def setup(client=None, silent=False, make_default=None, username=None, cache_dir
url_parsed = urlsplit(par[k])
if not (url_parsed.netloc and re.match('https?', url_parsed.scheme)):
raise ValueError(f'{k} must be valid HTTP URL')
if k == 'ALYX_URL':
client = par[k]
else:
par[k] = input(prompt % (k, cpar)).strip() or cpar
# Check whether user erroneously entered quotation marks
# Prompting the user here (hopefully) corrects them before they input a password
# where the use of quotation marks may be legitimate
if par[k] and len(par[k]) >= 2 and par[k][0] in quotes and par[k][-1] in quotes:
warnings.warn('Do not use quotation marks with input answers', UserWarning)
ans = input('Strip quotation marks from response? [Y/n]:').strip() or 'y'
if ans.lower()[0] == 'y':
par[k] = par[k].strip(quotes)
if k == 'ALYX_URL':
client = par[k]

cpar = _get_current_par('HTTP_DATA_SERVER_PWD', par_current)
prompt = f'Enter the FlatIron HTTP password for {par["HTTP_DATA_SERVER_LOGIN"]} '\
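For clarity, a self-contained sketch of the quotation-mark check added to setup(), with the interactive prompt replaced by a hard-coded response (a placeholder) and an assumed 'yes' answer to the strip prompt.

```python
import warnings

quotes = '"\'`'
response = '"foo"'  # placeholder user input wrapped in quotation marks

# Mirror the check added to setup(): warn, then strip the surrounding quotes
if response and len(response) >= 2 and response[0] in quotes and response[-1] in quotes:
    warnings.warn('Do not use quotation marks with input answers', UserWarning)
    # In setup() the user is asked whether to strip; here we assume they answer 'y'
    response = response.strip(quotes)

assert response == 'foo'
```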
2 changes: 1 addition & 1 deletion one/remote/aws.py
@@ -254,7 +254,7 @@ def s3_download_file(source, destination, s3=None, bucket_name=None, overwrite=F
_logger.debug(f"{destination} exists and match size -- skipping")
return destination
with tqdm(total=filesize, unit='B',
unit_scale=True, desc=str(destination)) as t:
unit_scale=True, desc=f'(S3) {destination}') as t:
file_object.download_file(Filename=str(destination), Callback=_callback_hook(t))
except (NoCredentialsError, PartialCredentialsError) as ex:
raise ex # Credentials need updating in Alyx # pragma: no cover
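The aws.py change only touches the progress-bar label. A small sketch, assuming a dummy 1 MiB transfer with simulated callbacks, shows the '(S3)' prefix that now flags S3 downloads.

```python
from tqdm import tqdm

destination = 'cache/alf/spikes.times.npy'  # placeholder local target
filesize = 1024 * 1024  # pretend 1 MiB transfer

# The desc string is all that changed: the '(S3)' prefix marks S3 downloads
with tqdm(total=filesize, unit='B', unit_scale=True, desc=f'(S3) {destination}') as t:
    for _ in range(16):
        t.update(filesize // 16)  # simulate chunked download callbacks
```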
5 changes: 5 additions & 0 deletions one/tests/test_one.py
@@ -404,6 +404,11 @@ def test_list_datasets(self):
self.assertEqual(27, len(dsets))
self.assertEqual(1, dsets.index.nlevels, 'details data frame should be without eid index')

# Test keep_eid_index parameter
dsets = self.one.list_datasets('KS005/2019-04-02/001', details=True, keep_eid_index=True)
self.assertEqual(27, len(dsets))
self.assertEqual(2, dsets.index.nlevels, 'details data frame should be with eid index')

# Test filters
filename = {'attribute': ['times', 'intervals'], 'extension': 'npy'}
dsets = self.one.list_datasets('ZFM-01935/2021-02-05/001', filename)
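The two nlevels assertions above hinge on the pd.concat idiom that list_datasets now uses to reinstate the eid level. A standalone pandas sketch with made-up data shows the index going from a single level to an (eid, id) MultiIndex.

```python
import pandas as pd

# Hypothetical single-session details frame, indexed by dataset id only
datasets = pd.DataFrame(
    {'rel_path': ['alf/spikes.times.npy', 'alf/spikes.clusters.npy']},
    index=pd.Index(['d1', 'd2'], name='id'))
eid = '00000000-0000-0000-0000-000000000000'  # placeholder session UUID

# Wrapping the frame in a dict keyed by the eid prepends an index level,
# turning the single-level index into a MultiIndex with levels (eid, id)
datasets = pd.concat({str(eid): datasets}, names=['eid'])
assert datasets.index.nlevels == 2
assert list(datasets.index.names) == ['eid', 'id']
```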
12 changes: 10 additions & 2 deletions one/tests/test_params.py
@@ -57,8 +57,16 @@ def test_setup(self, _):
# Check verification prompt
resp_map = {'ALYX_LOGIN': 'mistake', 'settings correct?': 'N'}
with mock.patch('one.params.input', new=partial(self._mock_input, **resp_map)):
cache = one.params.setup()
self.assertNotEqual(cache.ALYX_LOGIN, 'mistake')
one.params.setup()
par = one.params.get(self.url, silent=True)
self.assertNotEqual(par.ALYX_LOGIN, 'mistake')

# Check prompt when quotation marks used
resp_map = {'ALYX_LOGIN': '"foo"', 'Strip quotation marks': 'y', 'settings correct?': 'Y'}
with mock.patch('one.params.input', new=partial(self._mock_input, **resp_map)):
self.assertWarnsRegex(UserWarning, 'quotation marks', one.params.setup)
par = one.params.get(self.url, silent=True)
self.assertEqual(par.ALYX_LOGIN, 'foo', 'failed to strip quotes from user input')

# Check that raises ValueError when bad URL provided
self.url = 'ftp://foo.bar.org'
