Skip to content

Commit

Permalink
Resolves issue #123 and #65
Browse files Browse the repository at this point in the history
  • Loading branch information
k1o0 committed Sep 26, 2024
1 parent a01c5f7 commit 21366f5
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 29 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ This version improves behaviour of loading revisions and loading datasets from l
- sub-collections no longer captured when filtering with filename that starts with wildcard in wildcard mode
- bugfix of spurious error raised when loading dataset with a revision provided
- default_revisions_only parameter in One.list_datasets filters non-default datasets
- permit data frame input to One.load_datasets and load precise relative paths provided (instead of default revisions)

### Added

Expand Down
76 changes: 54 additions & 22 deletions one/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
import one.alf.files as alfiles
import one.alf.exceptions as alferr
from .alf.cache import make_parquet_db, DATASETS_COLUMNS, SESSIONS_COLUMNS
from .alf.spec import is_uuid_string, QC
from .alf.spec import is_uuid_string, QC, to_alf
from . import __version__
from one.converters import ConversionMixin, session_record2path
import one.util as util
from one import util

_logger = logging.getLogger(__name__)
__all__ = ['ONE', 'One', 'OneAlyx']
Expand Down Expand Up @@ -1138,9 +1138,11 @@ def load_datasets(self,
download_only: bool = False,
check_hash: bool = True) -> Any:
"""
Load datasets for a given session id. Returns two lists the length of datasets. The
first is the data (or file paths if download_data is false), the second is a list of
meta data Bunches. If assert_present is false, missing data will be returned as None.
Load datasets for a given session id.

Returns two lists the length of datasets. The first is the data (or file paths if
download_data is false), the second is a list of meta data Bunches. If assert_present is
false, missing data will be returned as None.

Parameters
----------
Expand Down Expand Up @@ -1172,9 +1174,9 @@ def load_datasets(self,
Returns
-------
list
A list of data (or file paths) the length of datasets
A list of data (or file paths) the length of datasets.
list
A list of meta data Bunches. If assert_present is False, missing data will be None
A list of meta data Bunches. If assert_present is False, missing data will be None.

Notes
-----
Expand All @@ -1186,6 +1188,8 @@ def load_datasets(self,
revision as separate keyword arguments.
- To ensure you are loading the correct revision, use the revisions kwarg instead of
relative paths.
- To load an exact revision (i.e. not the last revision before a given date), pass in
a list of relative paths or a data frame.

Raises
------
Expand Down Expand Up @@ -1228,8 +1232,25 @@ def _verify_specifiers(specifiers):

if isinstance(datasets, str):
raise TypeError('`datasets` must be a non-string iterable')
# Check input args
collections, revisions = _verify_specifiers([collections, revisions])

# Check if rel paths have been used (e.g. the output of list_datasets)
is_frame = isinstance(datasets, pd.DataFrame)
if is_rel_paths := (is_frame or any('/' in x for x in datasets)):
if not (collections, revisions) == (None, None):
raise ValueError(
'collection and revision kwargs must be None when dataset is a relative path')
if is_frame:
if 'eid' in datasets.index.names:
assert set(datasets.index.get_level_values('eid')) == {eid}
datasets = datasets['rel_path'].tolist()
datasets = list(map(partial(alfiles.rel_path_parts, as_dict=True), datasets))
if len(datasets) > 0:
# Extract collection and revision from each of the parsed datasets
# None -> '' ensures exact collections and revisions are used in filter
# NB: f user passes in dicts, any collection/revision keys will be ignored.
collections, revisions = zip(
*((x.pop('collection') or '', x.pop('revision') or '') for x in datasets)
)

# Short circuit
query_type = query_type or self.mode
Expand All @@ -1243,35 +1264,46 @@ def _verify_specifiers(specifiers):
if len(datasets) == 0:
return None, all_datasets.iloc[0:0] # Return empty

# Filter and load missing
if self.wildcards: # Append extension wildcard if 'object.attribute' string
datasets = [x + ('.*' if isinstance(x, str) and len(x.split('.')) == 2 else '')
for x in datasets]
# More input validation
input_types = [(isinstance(x, str), isinstance(x, dict)) for x in datasets]
if not all(map(any, input_types)) or not any(map(all, zip(*input_types))):
raise ValueError('`datasets` must be iterable of only str or only dicts')
if self.wildcards and input_types[0][0]: # if wildcards and input is iter of str
# Append extension wildcard if 'object.attribute' string
datasets = [
x + ('.*' if isinstance(x, str) and len(x.split('.')) == 2 else '')
for x in datasets
]

# Check input args
collections, revisions = _verify_specifiers([collections, revisions])

# If collections provided in datasets list, e.g. [collection/x.y.z], do not assert unique
validate = not any(('/' if isinstance(d, str) else 'collection') in d for d in datasets)
if not validate and not all(x is None for x in collections + revisions):
raise ValueError(
'collection and revision kwargs must be None when dataset is a relative path')
ops = dict(wildcards=self.wildcards, assert_unique=validate)
# If not a dataframe, use revision last before (we've asserted no revision in rel_path)
ops = dict(
wildcards=self.wildcards, assert_unique=True, revision_last_before=not is_rel_paths)
slices = [util.filter_datasets(all_datasets, x, y, z, **ops)
for x, y, z in zip(datasets, collections, revisions)]
present = [len(x) == 1 for x in slices]
present_datasets = pd.concat(slices)

# Check if user is blindly downloading all data and warn of non-default revisions
if 'default_revision' in present_datasets and \
not any(revisions) and not all(present_datasets['default_revision']):
is_rel_paths and not all(present_datasets['default_revision']):
old = present_datasets.loc[~present_datasets['default_revision'], 'rel_path'].to_list()
warnings.warn(
'The following datasets may have been revised and ' +
'are therefore not recommended for analysis:\n\t' +
'\n\t'.join(old) + '\n'
'To avoid this warning, specify the revision as a kwarg or use load_dataset.'
'To avoid this warning, specify the revision as a kwarg or use load_dataset.',
alferr.ALFWarning
)

if not all(present):
missing_list = ', '.join(x for x, y in zip(datasets, present) if not y)
# FIXME include collection and revision also
missing_list = (x if isinstance(x, str) else to_alf(**x) for x in datasets)
missing_list = ('/'.join(filter(None, [c, f'#{r}#' if r else None, d]))
for c, r, d in zip(collections, revisions, missing_list))
missing_list = ', '.join(x for x, y in zip(missing_list, present) if not y)
message = f'The following datasets are not in the cache: {missing_list}'
if assert_present:
raise alferr.ALFObjectNotFound(message)
Expand Down
39 changes: 33 additions & 6 deletions one/tests/test_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,16 +654,43 @@ def test_load_datasets(self):
files, meta = self.one.load_datasets(eid, dsets, download_only=True)
self.assertTrue(all(isinstance(x, Path) for x in files))

# Check behaviour when loading with a data frame (undocumented)
eid = '01390fcc-4f86-4707-8a3b-4d9309feb0a1'
datasets = self.one._cache.datasets.loc[([eid],), :].iloc[:3, :]
files, meta = self.one.load_datasets(eid, datasets, download_only=True)
self.assertTrue(all(isinstance(x, Path) for x in files))
# Should raise when data frame contains a different eid
self.assertRaises(AssertionError, self.one.load_datasets, uuid4(), datasets)

# Mix of str and dict
# Check download only
dsets = [
spec.regex(spec.FILE_SPEC).match('_ibl_wheel.position.npy').groupdict(),
'_ibl_wheel.timestamps.npy'
]
with self.assertRaises(ValueError):
self.one.load_datasets('KS005/2019-04-02/001', dsets, assert_present=False)

# Loading of non default revisions without using the revision kwarg causes user warning.
# With relative paths provided as input, dataset uniqueness validation is supressed.
# With relative paths provided as input, dataset uniqueness validation is suppressed.
eid = self.one._cache.sessions.iloc[0].name
datasets = util.revisions_datasets_table(
revisions=('', '2020-01-08'), attributes=('times',))
revisions=('', '2020-01-08'), attributes=('times',), touch_path=self.one.cache_dir)
datasets['default_revision'] = [False, True] * 3
eid = datasets.iloc[0].name[0]
datasets.index = datasets.index.set_levels([eid], level=0)
self.one._cache.datasets = datasets
with self.assertWarns(UserWarning):
self.one.load_datasets(eid, datasets['rel_path'].to_list(),
download_only=True, assert_present=False)
with self.assertWarns(alferr.ALFWarning):
self.one.load_datasets(eid, datasets['rel_path'].to_list(), download_only=True)

# Ensure that when rel paths are passed, a null collection/revision is not interpreted as
# an ANY. NB this means the output of 'spikes.times.npy' will be different depending on
# weather other datasets in list include a collection or revision.
self.one._cache.datasets = datasets.iloc[:2, :].copy() # only two datasets, one default
(file,), _ = self.one.load_datasets(eid, ['spikes.times.npy', ], download_only=True)
self.assertTrue(file.as_posix().endswith('001/#2020-01-08#/spikes.times.npy'))
(file, _), _ = self.one.load_datasets(
eid, ['spikes.times.npy', 'xx/obj.attr.ext'], download_only=True, assert_present=False)
self.assertTrue(file.as_posix().endswith('001/spikes.times.npy'))

# When loading without collections in the dataset list (i.e. just the dataset names)
# an exception should be raised when datasets belong to multiple collections.
Expand Down
11 changes: 10 additions & 1 deletion one/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def setup_test_params(token=False, cache_dir=None):
def revisions_datasets_table(collections=('', 'alf/probe00', 'alf/probe01'),
revisions=('', '2020-01-08', '2021-07-06'),
object='spikes',
attributes=('times', 'waveforems')):
attributes=('times', 'waveforems'),
touch_path=None):
"""Returns a datasets cache DataFrame containing datasets with revision folders.

As there are no revised datasets on the test databases, this function acts as a fixture for
Expand All @@ -132,6 +133,8 @@ def revisions_datasets_table(collections=('', 'alf/probe00', 'alf/probe01'),
An ALF object
attributes : tuple
A list of ALF attributes
touch_path : pathlib.Path, str
If provided, files are created in this directory.

Returns
-------
Expand All @@ -155,6 +158,12 @@ def revisions_datasets_table(collections=('', 'alf/probe00', 'alf/probe01'),
'id': map(str, (uuid4() for _ in rel_path))
}

if touch_path:
for p in rel_path:
path = Path(touch_path).joinpath(d['session_path'] + '/' + p)
path.parent.mkdir(parents=True, exist_ok=True)
path.touch()

return pd.DataFrame(data=d).astype({'qc': QC_TYPE}).set_index(['eid', 'id'])


Expand Down

0 comments on commit 21366f5

Please sign in to comment.