From 0f05989aff479a6e270e08aab60cec472275811b Mon Sep 17 00:00:00 2001
From: Miles Wells
Date: Thu, 31 Oct 2024 17:18:22 +0200
Subject: [PATCH] Enforce data frame dtypes

---
 one/api.py            | 21 ++++++++++++---------
 one/converters.py     | 12 ++++++------
 one/tests/test_one.py | 41 ++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 58 insertions(+), 16 deletions(-)

diff --git a/one/api.py b/one/api.py
index 7589f7fa..ef5a4a08 100644
--- a/one/api.py
+++ b/one/api.py
@@ -310,30 +310,33 @@ def _update_cache_from_records(self, strict=False, **kwargs):
                 raise KeyError(f'Table "{table}" not in cache')
             if isinstance(records, pd.Series):
                 records = pd.DataFrame([records])
+                records.index.set_names(self._cache[table].index.names, inplace=True)
             if not strict:
                 # Deal with case where there are extra columns in the cache
-                extra_columns = set(self._cache[table].columns) - set(records.columns)
+                extra_columns = list(set(self._cache[table].columns) - set(records.columns))
+                # Convert these columns to nullable, if required
+                cache_columns = self._cache[table][extra_columns]
+                self._cache[table][extra_columns] = cache_columns.convert_dtypes()
                 column_ids = map(list(self._cache[table].columns).index, extra_columns)
                 for col, n in sorted(zip(extra_columns, column_ids), key=lambda x: x[1]):
-                    val = records.get('exists', True) if col.startswith('exists_') else np.nan
+                    dtype = self._cache[table][col].dtype
+                    nan = getattr(dtype, 'na_value', np.nan)
+                    val = records.get('exists', True) if col.startswith('exists_') else nan
                     records.insert(n, col, val)
                 # Drop any extra columns in the records that aren't in cache table
                 to_drop = set(records.columns) - set(self._cache[table].columns)
                 records.drop(to_drop, axis=1, inplace=True)
                 records = records.reindex(columns=self._cache[table].columns)
             assert set(self._cache[table].columns) == set(records.columns)
+            records = records.astype(self._cache[table].dtypes)
             # Update existing rows
             to_update = records.index.isin(self._cache[table].index)
             self._cache[table].loc[records.index[to_update], :] = records[to_update]
             # Assign new rows
             to_assign = records[~to_update]
-            if isinstance(self._cache[table].index, pd.MultiIndex) and not to_assign.empty:
-                # Concatenate and sort (no other way for non-unique index within MultiIndex)
-                frames = filter(lambda x: not x.empty, [self._cache[table], to_assign])
-                self._cache[table] = pd.concat(frames).sort_index()
-            else:
-                for index, record in to_assign.iterrows():
-                    self._cache[table].loc[index, :] = record[self._cache[table].columns].values
+            frames = [self._cache[table], to_assign]
+            # Concatenate and sort
+            self._cache[table] = pd.concat(frames).sort_index()
         updated = datetime.now()
         self._cache['_meta']['modified_time'] = updated
         return updated
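Illustration of the one/api.py change above (not part of the patch): the new logic amounts to three pandas steps: make cache columns that are missing from the incoming records nullable, fill them with the dtype's own NA value, and cast the records to the cache table's dtypes before a single concat-and-sort. The self-contained sketch below reproduces that pattern with made-up table contents and column names ('exists', 'file_size'); it is a sketch of the idea, not code from the ONE library.

import numpy as np
import pandas as pd

# A toy "cache" table with a fixed schema
cache = pd.DataFrame(
    {'exists': [True, False], 'file_size': [10, 20]},
    index=pd.Index(['a', 'b'], name='id'),
).astype({'exists': bool, 'file_size': 'uint64'})

# Incoming records lack the 'file_size' column
records = pd.DataFrame({'exists': [True]}, index=pd.Index(['c'], name='id'))

# Make the missing cache columns nullable so NA can be stored without
# silently upcasting to float/object
extra = list(set(cache.columns) - set(records.columns))
cache[extra] = cache[extra].convert_dtypes()

# Fill the missing columns with the dtype's own NA value, then enforce the
# cache dtypes on the records before concatenating
for col in extra:
    dtype = cache[col].dtype
    records[col] = getattr(dtype, 'na_value', np.nan)
records = records.astype(cache.dtypes)

# New rows are appended and the index re-sorted; dtypes are unchanged
cache = pd.concat([cache, records[~records.index.isin(cache.index)]]).sort_index()
print(cache.dtypes)  # exists -> bool, file_size -> UInt64 (nullable)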
diff --git a/one/converters.py b/one/converters.py
index e2259ab1..2266ec3b 100644
--- a/one/converters.py
+++ b/one/converters.py
@@ -20,7 +20,7 @@
 from iblutil.util import Bunch, Listable, ensure_list
 
 from one.alf.spec import is_session_path, is_uuid_string
-from one.alf.cache import QC_TYPE, EMPTY_DATASETS_FRAME
+from one.alf.cache import EMPTY_DATASETS_FRAME
 from one.alf.files import (
     get_session_path, add_uuid_string, session_path_parts, get_alf_path, remove_uuid_string)
 
@@ -783,7 +783,8 @@ def _to_record(d):
         return session, EMPTY_DATASETS_FRAME.copy()
     records = map(_to_record, ses['data_dataset_session_related'])
     index = ['eid', 'id']
-    datasets = pd.DataFrame(records).set_index(index).sort_index().astype({'qc': QC_TYPE})
+    dtypes = EMPTY_DATASETS_FRAME.dtypes
+    datasets = pd.DataFrame(records).astype(dtypes).set_index(index).sort_index()
     return session, datasets
 
 
@@ -829,8 +830,7 @@ def datasets2records(datasets, additional=None) -> pd.DataFrame:
             rec[field] = d.get(field)
         records.append(rec)
 
-    index = ['eid', 'id']
     if not records:
-        keys = (*index, 'file_size', 'hash', 'session_path', 'rel_path', 'default_revision', 'qc')
-        return pd.DataFrame(columns=keys).set_index(index)
-    return pd.DataFrame(records).set_index(index).sort_index().astype({'qc': QC_TYPE})
+        return EMPTY_DATASETS_FRAME
+    index = EMPTY_DATASETS_FRAME.index.names
+    return pd.DataFrame(records).set_index(index).sort_index().astype(EMPTY_DATASETS_FRAME.dtypes)
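Illustration of the one/converters.py change above (not part of the patch): the canonical empty frame becomes the single source of truth for index names and column dtypes, replacing a hand-written column tuple and a lone {'qc': QC_TYPE} cast. The sketch below shows the same pattern with a stand-in frame; EMPTY_FRAME, to_records and their columns are illustrative assumptions, not the actual definition of one.alf.cache.EMPTY_DATASETS_FRAME.

import pandas as pd

# Stand-in for the canonical empty datasets frame (reduced set of columns)
EMPTY_FRAME = (
    pd.DataFrame(columns=['eid', 'id', 'rel_path', 'file_size', 'exists'])
    .astype({'file_size': 'UInt64', 'exists': bool})
    .set_index(['eid', 'id'])
)

def to_records(rows):
    """Return a datasets frame with the same index names and dtypes as EMPTY_FRAME."""
    if not rows:
        return EMPTY_FRAME.copy()
    index = EMPTY_FRAME.index.names  # ('eid', 'id')
    frame = pd.DataFrame(rows).set_index(list(index)).sort_index()
    return frame.astype(EMPTY_FRAME.dtypes)

rows = [
    {'eid': 'e1', 'id': 'd1', 'rel_path': 'alf/spikes.times.npy', 'file_size': 1024, 'exists': True},
    {'eid': 'e1', 'id': 'd0', 'rel_path': 'alf/spikes.clusters.npy', 'file_size': None, 'exists': True},
]
print(to_records(rows).dtypes)     # matches EMPTY_FRAME.dtypes
print(to_records([]).index.names)  # ['eid', 'id']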
diff --git a/one/tests/test_one.py b/one/tests/test_one.py
index fbf4e750..1612e6ba 100644
--- a/one/tests/test_one.py
+++ b/one/tests/test_one.py
@@ -55,6 +55,7 @@
 from one.converters import datasets2records
 from one.alf import spec
 from one.alf.files import get_alf_path
+from one.alf.cache import EMPTY_DATASETS_FRAME, EMPTY_SESSIONS_FRAME
 
 from . import util
 from . import OFFLINE_ONLY, TEST_DB_1, TEST_DB_2  # 1 = TestAlyx; 2 = OpenAlyx
@@ -905,8 +906,28 @@ def test_save_cache(self):
         self.assertFalse(lock_file.exists(), 'failed to remove stale lock file')
         time_mock.sleep.assert_called()
 
+    def test_reset_cache(self):
+        """Test One._reset_cache method, namely that cache types are correct."""
+        # Assert cache dtypes are indeed what are expected
+        self.one._reset_cache()
+        self.assertCountEqual(['datasets', 'sessions', '_meta'], self.one._cache.keys())
+        self.assertTrue(self.one._cache.datasets.empty)
+        self.assertCountEqual(EMPTY_SESSIONS_FRAME.columns, self.one._cache.sessions.columns)
+        self.assertTrue(self.one._cache.sessions.empty)
+        self.assertCountEqual(EMPTY_DATASETS_FRAME.columns, self.one._cache.datasets.columns)
+        # Check sessions data frame types
+        sessions_types = EMPTY_SESSIONS_FRAME.reset_index().dtypes.to_dict()
+        s_types = self.one._cache.sessions.reset_index().dtypes.to_dict()
+        self.assertDictEqual(sessions_types, s_types)
+        # Check datasets data frame types
+        datasets_types = EMPTY_DATASETS_FRAME.reset_index().dtypes.to_dict()
+        d_types = self.one._cache.datasets.reset_index().dtypes.to_dict()
+        self.assertDictEqual(datasets_types, d_types)
+
     def test_update_cache_from_records(self):
         """Test One._update_cache_from_records"""
+        sessions_types = self.one._cache.sessions.reset_index().dtypes.to_dict()
+        datasets_types = self.one._cache.datasets.reset_index().dtypes.to_dict()
         # Update with single record (pandas.Series), one exists, one doesn't
         session = self.one._cache.sessions.iloc[0].squeeze()
         session.name = str(uuid4())  # New record
@@ -916,6 +937,11 @@ def test_update_cache_from_records(self):
         self.assertTrue(session.name in self.one._cache.sessions.index)
         updated, = dataset['exists'] == self.one._cache.datasets.loc[dataset.name, 'exists']
         self.assertTrue(updated)
+        # Check that the updated data frame has kept its original dtypes
+        types = self.one._cache.sessions.reset_index().dtypes.to_dict()
+        self.assertDictEqual(sessions_types, types)
+        types = self.one._cache.datasets.reset_index().dtypes.to_dict()
+        self.assertDictEqual(datasets_types, types)
 
         # Update a number of records
         datasets = self.one._cache.datasets.iloc[:3].copy()
@@ -923,17 +949,21 @@ def test_update_cache_from_records(self):
         # Make one of the datasets a new record
         idx = datasets.index.values
         idx[-1] = (idx[-1][0], str(uuid4()))
-        datasets.index = pd.MultiIndex.from_tuples(idx)
+        datasets.index = pd.MultiIndex.from_tuples(idx, names=('eid', 'id'))
         self.one._update_cache_from_records(datasets=datasets)
         self.assertTrue(idx[-1] in self.one._cache.datasets.index)
         verifiable = self.one._cache.datasets.loc[datasets.index.values, 'exists']
         self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists']))
+        # Check that the updated data frame has kept its original dtypes
+        types = self.one._cache.datasets.reset_index().dtypes.to_dict()
+        self.assertDictEqual(datasets_types, types)
 
         # Check behaviour when columns don't match
         datasets.loc[:, 'exists'] = ~datasets.loc[:, 'exists']
         datasets['extra_column'] = True
         self.one._cache.datasets['foo_bar'] = 12  # this column is missing in our new records
         self.one._cache.datasets['new_column'] = False
+        expected_datasets_types = self.one._cache.datasets.reset_index().dtypes.to_dict()
         self.addCleanup(self.one._cache.datasets.drop, 'foo_bar', axis=1, inplace=True)
         # An exception is exists_* as the Alyx cache contains exists_aws and exists_flatiron
         # These should simply be filled with the values of exists as Alyx won't return datasets
@@ -950,6 +980,11 @@ def test_update_cache_from_records(self):
         # If the extra column does not start with 'exists' it should be set to NaN
         verifiable = self.one._cache.datasets.loc[datasets.index.values, 'foo_bar']
         self.assertTrue(np.isnan(verifiable).all())
+        # Check that the missing columns were updated to nullable fields
+        expected_datasets_types.update(
+            foo_bar=pd.Int64Dtype(), exists_aws=pd.BooleanDtype(), new_column=pd.BooleanDtype())
+        types = self.one._cache.datasets.reset_index().dtypes.to_dict()
+        self.assertDictEqual(expected_datasets_types, types)
 
         # Check fringe cases
         with self.assertRaises(KeyError):
@@ -957,11 +992,15 @@ def test_update_cache_from_records(self):
         self.assertIsNone(self.one._update_cache_from_records(datasets=None))
         # Absent cache table
         self.one.load_cache(tables_dir='/foo')
+        sessions_types = self.one._cache.sessions.reset_index().dtypes.to_dict()
+        datasets_types = self.one._cache.datasets.reset_index().dtypes.to_dict()
         self.one._update_cache_from_records(sessions=session, datasets=dataset)
         self.assertTrue(all(self.one._cache.sessions == pd.DataFrame([session])))
         self.assertEqual(1, len(self.one._cache.datasets))
         self.assertEqual(self.one._cache.datasets.squeeze().name, dataset.name)
         self.assertCountEqual(self.one._cache.datasets.squeeze().to_dict(), dataset.to_dict())
+        types = self.one._cache.datasets.reset_index().dtypes.to_dict()
+        self.assertDictEqual(datasets_types, types)
 
     def test_save_loaded_ids(self):
         """Test One.save_loaded_ids and logic within One._check_filesystem"""
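Illustration of what the new tests assert (not part of the patch): they compare reset_index().dtypes.to_dict() before and after each update to catch silent dtype changes. The toy example below, using a hypothetical 'foo_bar' column, shows the regression this guards against: appending a row with a missing value upcasts a plain int64 column to float64, whereas a nullable Int64 column keeps its dtype and simply stores pd.NA.

import numpy as np
import pandas as pd

table = pd.DataFrame({'foo_bar': [12, 12]}, index=pd.Index(['a', 'b'], name='id'))
new_row = pd.DataFrame({'foo_bar': [np.nan]}, index=pd.Index(['c'], name='id'))

# Without nullable dtypes the int64 column silently becomes float64
naive = pd.concat([table, new_row])
assert naive['foo_bar'].dtype == np.float64

# With convert_dtypes() the column becomes Int64 and stays Int64 after the update
nullable = table.convert_dtypes()
updated = pd.concat([nullable, new_row.astype(nullable.dtypes)])
assert updated['foo_bar'].dtype == pd.Int64Dtype()
assert updated.reset_index().dtypes.to_dict() != naive.reset_index().dtypes.to_dict()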