From 2b728422981b41e72c72e761f3eb84da43a018c7 Mon Sep 17 00:00:00 2001
From: Miles Wells
Date: Thu, 26 Sep 2024 18:45:50 +0300
Subject: [PATCH] Hard-code setting 'exists_*' columns to 'exists' values
 instead of NaN to avoid pandas dtype change

---
 one/api.py                   | 10 ++++++----
 one/tests/alf/test_alf_io.py |  2 +-
 one/tests/test_one.py        | 13 ++++++++++++-
 one/util.py                  |  4 ++--
 one/webclient.py             |  3 +++
 5 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/one/api.py b/one/api.py
index d552b3d1..6663f6b5 100644
--- a/one/api.py
+++ b/one/api.py
@@ -287,9 +287,10 @@ def _update_cache_from_records(self, strict=False, **kwargs):
             if not strict:
                 # Deal with case where there are extra columns in the cache
                 extra_columns = set(self._cache[table].columns) - set(records.columns)
-                for col in extra_columns:
-                    n = list(self._cache[table].columns).index(col)
-                    records.insert(n, col, np.nan)
+                column_ids = map(list(self._cache[table].columns).index, extra_columns)
+                for col, n in sorted(zip(extra_columns, column_ids), key=lambda x: x[1]):
+                    val = records.get('exists', True) if col.startswith('exists_') else np.nan
+                    records.insert(n, col, val)
                 # Drop any extra columns in the records that aren't in cache table
                 to_drop = set(records.columns) - set(self._cache[table].columns)
                 records.drop(to_drop, axis=1, inplace=True)
@@ -302,7 +303,8 @@ def _update_cache_from_records(self, strict=False, **kwargs):
             to_assign = records[~to_update]
             if isinstance(self._cache[table].index, pd.MultiIndex) and not to_assign.empty:
                 # Concatenate and sort (no other way for non-unique index within MultiIndex)
-                self._cache[table] = pd.concat([self._cache[table], to_assign]).sort_index()
+                frames = filter(lambda x: not x.empty, [self._cache[table], to_assign])
+                self._cache[table] = pd.concat(frames).sort_index()
             else:
                 for index, record in to_assign.iterrows():
                     self._cache[table].loc[index, :] = record[self._cache[table].columns].values
diff --git a/one/tests/alf/test_alf_io.py b/one/tests/alf/test_alf_io.py
index 21e56e0d..cfe05b33 100644
--- a/one/tests/alf/test_alf_io.py
+++ b/one/tests/alf/test_alf_io.py
@@ -727,7 +727,7 @@ def setUp(self):
             self.session_path.joinpath(f'bar.baz_y.{uuid.uuid4()}.npy'),
             self.session_path.joinpath('#2021-01-01#', f'bar.baz.{uuid.uuid4()}.npy'),
             self.session_path.joinpath('task_00', 'x.y.z'),
-            self.session_path.joinpath('x.y.z'),
+            self.session_path.joinpath('x.y.z')
         ]
         for f in self.dsets:
             f.parent.mkdir(exist_ok=True, parents=True)
diff --git a/one/tests/test_one.py b/one/tests/test_one.py
index 1fc006b5..013542a5 100644
--- a/one/tests/test_one.py
+++ b/one/tests/test_one.py
@@ -906,13 +906,24 @@ def test_update_cache_from_records(self):
         # Check behaviour when columns don't match
         datasets.loc[:, 'exists'] = ~datasets.loc[:, 'exists']
         datasets['extra_column'] = True
+        self.one._cache.datasets['foo_bar'] = 12  # this column is missing in our new records
         self.one._cache.datasets['new_column'] = False
-        self.addCleanup(self.one._cache.datasets.drop, 'new_column', axis=1, inplace=True)
+        self.addCleanup(self.one._cache.datasets.drop, 'foo_bar', axis=1, inplace=True)
+        # An exception is exists_*, as the Alyx cache contains exists_aws and exists_flatiron.
+        # These should simply be filled with the values of exists, as Alyx won't return datasets
+        # that don't exist on FlatIron, and if they don't exist on AWS it falls back to FlatIron.
+        self.one._cache.datasets['exists_aws'] = False
         with self.assertRaises(AssertionError):
             self.one._update_cache_from_records(datasets=datasets, strict=True)
         self.one._update_cache_from_records(datasets=datasets)
         verifiable = self.one._cache.datasets.loc[datasets.index.values, 'exists']
         self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists']))
+        self.one._update_cache_from_records(datasets=datasets)
+        verifiable = self.one._cache.datasets.loc[datasets.index.values, 'exists_aws']
+        self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists']))
+        # If the extra column does not start with 'exists' it should be set to NaN
+        verifiable = self.one._cache.datasets.loc[datasets.index.values, 'foo_bar']
+        self.assertTrue(np.isnan(verifiable).all())
 
         # Check fringe cases
         with self.assertRaises(KeyError):
diff --git a/one/util.py b/one/util.py
index 743b0101..e9766a1b 100644
--- a/one/util.py
+++ b/one/util.py
@@ -392,8 +392,8 @@ def filter_datasets(
     # Convert to regex if necessary and assert end of string
     flagless_token = re.escape(r'(?s:')  # fnmatch.translate may wrap input in flagless group
     # If there is a wildcard at the start of the filename we must exclude capture of slashes to
-    # avoid capture of collection part, e.g. * -> .* -> [^/]+ (one or more non-slash chars)
-    exclude_slash = partial(re.sub, fr'^({flagless_token})?\.[*?]', r'\g<1>[^/]+')
+    # avoid capture of collection part, e.g. * -> .* -> [^/]* (zero or more non-slash chars)
+    exclude_slash = partial(re.sub, fr'^({flagless_token})?\.\*', r'\g<1>[^/]*')
     spec_str += '|'.join(
         exclude_slash(fnmatch.translate(x)) if wildcards else x + '$'
         for x in ensure_list(filename)
diff --git a/one/webclient.py b/one/webclient.py
index 0782122c..9682df43 100644
--- a/one/webclient.py
+++ b/one/webclient.py
@@ -802,6 +802,9 @@ def download_cache_tables(self, source=None, destination=None):
             headers = self._headers
 
         with tempfile.TemporaryDirectory(dir=destination) as tmp:
+            if source.startswith('s3://'):
+                from one.remote import aws
+                aws.s3_download_file(source, destination, s3='.')
             file = http_download_file(source,
                                       headers=headers,
                                       silent=self.silent,
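
Background for the subject line, as an illustrative sketch rather than part of the patch: padding a missing boolean 'exists_*' column with NaN forces pandas to upcast it, which is the dtype change being avoided. The 'records' frame below is made up for the example; only the column names come from the patch.

import numpy as np
import pandas as pd

# Toy stand-in for the incoming records (not the real ONE test fixtures)
records = pd.DataFrame({'exists': [False, True]}, index=['dset0', 'dset1'])

# Old behaviour: pad the missing 'exists_aws' column with NaN -> bool becomes float64
padded = records.copy()
padded.insert(1, 'exists_aws', np.nan)
assert padded['exists_aws'].dtype == 'float64'

# New behaviour: fill 'exists_*' columns from 'exists' (defaulting to True) -> dtype stays bool
filled = records.copy()
filled.insert(1, 'exists_aws', filled.get('exists', True))
assert filled['exists_aws'].dtype == 'bool'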
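
For the one/util.py change, a small sketch of how a leading wildcard is rewritten so it cannot swallow the collection part of a path; the '*.times.npy' pattern here is made up for illustration, not taken from the test suite.

import fnmatch
import re
from functools import partial

flagless_token = re.escape(r'(?s:')  # fnmatch.translate may wrap the pattern in this group
exclude_slash = partial(re.sub, fr'^({flagless_token})?\.\*', r'\g<1>[^/]*')

pattern = fnmatch.translate('*.times.npy')  # (?s:.*\.times\.npy)\Z on recent Pythons
print(exclude_slash(pattern))               # (?s:[^/]*\.times\.npy)\Z - '*' no longer matches '/'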