
Commit

Hard-code setting 'exists_*' columns to 'exists' values instead of NaN to avoid pandas dtype change
k1o0 committed Sep 26, 2024
1 parent 8edc6d2 commit 2b72842
Showing 5 changed files with 24 additions and 8 deletions.
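
A minimal sketch of the pandas dtype issue the commit message refers to (illustrative only; the column names mirror those in the api.py hunk below, and the `DataFrame.get` fallback mirrors the new logic there):

```python
import numpy as np
import pandas as pd

# Filling a missing column with NaN gives it a float dtype, which later
# propagates into the boolean cache table when the records are merged in.
records = pd.DataFrame({'exists': [True, False]})
records.insert(0, 'exists_aws', np.nan)
print(records.dtypes)  # exists_aws: float64

# Copying the boolean 'exists' values (falling back to True when absent)
# keeps the new column boolean, so no dtype change occurs downstream.
records = pd.DataFrame({'exists': [True, False]})
records.insert(0, 'exists_aws', records.get('exists', True))
print(records.dtypes)  # exists_aws: bool
```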
10 changes: 6 additions & 4 deletions one/api.py
@@ -287,9 +287,10 @@ def _update_cache_from_records(self, strict=False, **kwargs):
if not strict:
# Deal with case where there are extra columns in the cache
extra_columns = set(self._cache[table].columns) - set(records.columns)
for col in extra_columns:
n = list(self._cache[table].columns).index(col)
records.insert(n, col, np.nan)
column_ids = map(list(self._cache[table].columns).index, extra_columns)
for col, n in sorted(zip(extra_columns, column_ids), key=lambda x: x[1]):
val = records.get('exists', True) if col.startswith('exists_') else np.nan
records.insert(n, col, val)
# Drop any extra columns in the records that aren't in cache table
to_drop = set(records.columns) - set(self._cache[table].columns)
records.drop(to_drop, axis=1, inplace=True)
@@ -302,7 +303,8 @@ def _update_cache_from_records(self, strict=False, **kwargs):
to_assign = records[~to_update]
if isinstance(self._cache[table].index, pd.MultiIndex) and not to_assign.empty:
# Concatenate and sort (no other way for non-unique index within MultiIndex)
self._cache[table] = pd.concat([self._cache[table], to_assign]).sort_index()
frames = filter(lambda x: not x.empty, [self._cache[table], to_assign])
self._cache[table] = pd.concat(frames).sort_index()
else:
for index, record in to_assign.iterrows():
self._cache[table].loc[index, :] = record[self._cache[table].columns].values
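
A rough illustration of the likely motivation for filtering out empty frames before `pd.concat` in the second hunk above (an assumption, not stated in the commit: recent pandas versions warn that empty or all-NA entries will stop influencing dtype inference during concatenation):

```python
import pandas as pd

# A cache table keyed by a MultiIndex, similar to the one the diff handles
idx = pd.MultiIndex.from_tuples([('eid0', 'did0')], names=('eid', 'id'))
cache = pd.DataFrame({'exists': [True]}, index=idx)
to_assign = cache.iloc[0:0]  # nothing new to append this time

# Dropping empty frames before concatenation side-steps the dtype-inference
# warning and leaves the boolean columns untouched.
frames = [f for f in (cache, to_assign) if not f.empty]
merged = pd.concat(frames).sort_index()
print(merged.dtypes)  # exists stays bool
```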
2 changes: 1 addition & 1 deletion one/tests/alf/test_alf_io.py
@@ -727,7 +727,7 @@ def setUp(self):
self.session_path.joinpath(f'bar.baz_y.{uuid.uuid4()}.npy'),
self.session_path.joinpath('#2021-01-01#', f'bar.baz.{uuid.uuid4()}.npy'),
self.session_path.joinpath('task_00', 'x.y.z'),
self.session_path.joinpath('x.y.z'),
self.session_path.joinpath('x.y.z')
]
for f in self.dsets:
f.parent.mkdir(exist_ok=True, parents=True)
13 changes: 12 additions & 1 deletion one/tests/test_one.py
@@ -906,13 +906,24 @@ def test_update_cache_from_records(self):
# Check behaviour when columns don't match
datasets.loc[:, 'exists'] = ~datasets.loc[:, 'exists']
datasets['extra_column'] = True
self.one._cache.datasets['foo_bar'] = 12 # this column is missing in our new records
self.one._cache.datasets['new_column'] = False
self.addCleanup(self.one._cache.datasets.drop, 'new_column', axis=1, inplace=True)
self.addCleanup(self.one._cache.datasets.drop, 'foo_bar', axis=1, inplace=True)
# An exception is exists_*, as the Alyx cache contains exists_aws and exists_flatiron.
# These should simply be filled with the values of exists, as Alyx won't return datasets
# that don't exist on FlatIron, and if they don't exist on AWS it falls back to this.
self.one._cache.datasets['exists_aws'] = False
with self.assertRaises(AssertionError):
self.one._update_cache_from_records(datasets=datasets, strict=True)
self.one._update_cache_from_records(datasets=datasets)
verifiable = self.one._cache.datasets.loc[datasets.index.values, 'exists']
self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists']))
self.one._update_cache_from_records(datasets=datasets)
verifiable = self.one._cache.datasets.loc[datasets.index.values, 'exists_aws']
self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists']))
# If the extra column does not start with 'exists' it should be set to NaN
verifiable = self.one._cache.datasets.loc[datasets.index.values, 'foo_bar']
self.assertTrue(np.isnan(verifiable).all())

# Check fringe cases
with self.assertRaises(KeyError):
4 changes: 2 additions & 2 deletions one/util.py
@@ -392,8 +392,8 @@ def filter_datasets(
# Convert to regex if necessary and assert end of string
flagless_token = re.escape(r'(?s:') # fnmatch.translate may wrap input in flagless group
# If there is a wildcard at the start of the filename we must exclude capture of slashes to
# avoid capture of collection part, e.g. * -> .* -> [^/]+ (one or more non-slash chars)
exclude_slash = partial(re.sub, fr'^({flagless_token})?\.[*?]', r'\g<1>[^/]+')
# avoid capture of collection part, e.g. * -> .* -> [^/]* (zero or more non-slash chars)
exclude_slash = partial(re.sub, fr'^({flagless_token})?\.\*', r'\g<1>[^/]*')
spec_str += '|'.join(
exclude_slash(fnmatch.translate(x)) if wildcards else x + '$'
for x in ensure_list(filename)
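
A quick check of what the updated substitution does, assuming the standard output format of `fnmatch.translate`; this reproduces only the relevant fragment of `filter_datasets`:

```python
import fnmatch
import re
from functools import partial

flagless_token = re.escape(r'(?s:')  # fnmatch.translate wraps patterns in a flagless group
exclude_slash = partial(re.sub, fr'^({flagless_token})?\.\*', r'\g<1>[^/]*')

# A leading * becomes [^/]* so it cannot swallow the collection part of a path
pattern = exclude_slash(fnmatch.translate('*.npy'))
print(pattern)                                    # (?s:[^/]*\.npy)\Z
print(bool(re.match(pattern, 'spikes.npy')))      # True: plain filename matches
print(bool(re.match(pattern, 'alf/spikes.npy')))  # False: slash is not captured
```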
3 changes: 3 additions & 0 deletions one/webclient.py
@@ -802,6 +802,9 @@ def download_cache_tables(self, source=None, destination=None):
headers = self._headers

with tempfile.TemporaryDirectory(dir=destination) as tmp:
if source.startswith('s3://'):
from one.remote import aws
aws.s3_download_file(source, destination, s3='.')
file = http_download_file(source,
headers=headers,
silent=self.silent,
