Skip to content

Commit

Permalink
Enforce data frame dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
k1o0 committed Oct 31, 2024
1 parent 61e6427 commit 0f05989
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 16 deletions.
21 changes: 12 additions & 9 deletions one/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,30 +310,33 @@ def _update_cache_from_records(self, strict=False, **kwargs):
raise KeyError(f'Table "{table}" not in cache')
if isinstance(records, pd.Series):
records = pd.DataFrame([records])
records.index.set_names(self._cache[table].index.names, inplace=True)
if not strict:
# Deal with case where there are extra columns in the cache
extra_columns = set(self._cache[table].columns) - set(records.columns)
extra_columns = list(set(self._cache[table].columns) - set(records.columns))
# Convert these columns to nullable, if required
cache_columns = self._cache[table][extra_columns]
self._cache[table][extra_columns] = cache_columns.convert_dtypes()
column_ids = map(list(self._cache[table].columns).index, extra_columns)
for col, n in sorted(zip(extra_columns, column_ids), key=lambda x: x[1]):
val = records.get('exists', True) if col.startswith('exists_') else np.nan
dtype = self._cache[table][col].dtype
nan = getattr(dtype, 'na_value', np.nan)
val = records.get('exists', True) if col.startswith('exists_') else nan
records.insert(n, col, val)
# Drop any extra columns in the records that aren't in cache table
to_drop = set(records.columns) - set(self._cache[table].columns)
records.drop(to_drop, axis=1, inplace=True)
records = records.reindex(columns=self._cache[table].columns)
assert set(self._cache[table].columns) == set(records.columns)
records = records.astype(self._cache[table].dtypes)
# Update existing rows
to_update = records.index.isin(self._cache[table].index)
self._cache[table].loc[records.index[to_update], :] = records[to_update]
# Assign new rows
to_assign = records[~to_update]
if isinstance(self._cache[table].index, pd.MultiIndex) and not to_assign.empty:
# Concatenate and sort (no other way for non-unique index within MultiIndex)
frames = filter(lambda x: not x.empty, [self._cache[table], to_assign])
self._cache[table] = pd.concat(frames).sort_index()
else:
for index, record in to_assign.iterrows():
self._cache[table].loc[index, :] = record[self._cache[table].columns].values
frames = [self._cache[table], to_assign]
# Concatenate and sort
self._cache[table] = pd.concat(frames).sort_index()
updated = datetime.now()
self._cache['_meta']['modified_time'] = updated
return updated
Expand Down
12 changes: 6 additions & 6 deletions one/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from iblutil.util import Bunch, Listable, ensure_list

from one.alf.spec import is_session_path, is_uuid_string
from one.alf.cache import QC_TYPE, EMPTY_DATASETS_FRAME
from one.alf.cache import EMPTY_DATASETS_FRAME
from one.alf.files import (
get_session_path, add_uuid_string, session_path_parts, get_alf_path, remove_uuid_string)

Expand Down Expand Up @@ -783,7 +783,8 @@ def _to_record(d):
return session, EMPTY_DATASETS_FRAME.copy()
records = map(_to_record, ses['data_dataset_session_related'])
index = ['eid', 'id']
datasets = pd.DataFrame(records).set_index(index).sort_index().astype({'qc': QC_TYPE})
dtypes = EMPTY_DATASETS_FRAME.dtypes
datasets = pd.DataFrame(records).astype(dtypes).set_index(index).sort_index()
return session, datasets


Expand Down Expand Up @@ -829,8 +830,7 @@ def datasets2records(datasets, additional=None) -> pd.DataFrame:
rec[field] = d.get(field)
records.append(rec)

index = ['eid', 'id']
if not records:
keys = (*index, 'file_size', 'hash', 'session_path', 'rel_path', 'default_revision', 'qc')
return pd.DataFrame(columns=keys).set_index(index)
return pd.DataFrame(records).set_index(index).sort_index().astype({'qc': QC_TYPE})
return EMPTY_DATASETS_FRAME
index = EMPTY_DATASETS_FRAME.index.names
return pd.DataFrame(records).set_index(index).sort_index().astype(EMPTY_DATASETS_FRAME.dtypes)
41 changes: 40 additions & 1 deletion one/tests/test_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from one.converters import datasets2records
from one.alf import spec
from one.alf.files import get_alf_path
from one.alf.cache import EMPTY_DATASETS_FRAME, EMPTY_SESSIONS_FRAME
from . import util
from . import OFFLINE_ONLY, TEST_DB_1, TEST_DB_2 # 1 = TestAlyx; 2 = OpenAlyx

Expand Down Expand Up @@ -905,8 +906,28 @@ def test_save_cache(self):
self.assertFalse(lock_file.exists(), 'failed to remove stale lock file')
time_mock.sleep.assert_called()

def test_reset_cache(self):
"""Test One._reset_cache method, namely that cache types are correct."""
# Assert cache dtypes are indeed what are expected
self.one._reset_cache()
self.assertCountEqual(['datasets', 'sessions', '_meta'], self.one._cache.keys())
self.assertTrue(self.one._cache.datasets.empty)
self.assertCountEqual(EMPTY_SESSIONS_FRAME.columns, self.one._cache.sessions.columns)
self.assertTrue(self.one._cache.sessions.empty)
self.assertCountEqual(EMPTY_DATASETS_FRAME.columns, self.one._cache.datasets.columns)
# Check sessions data frame types
sessions_types = EMPTY_SESSIONS_FRAME.reset_index().dtypes.to_dict()
s_types = self.one._cache.sessions.reset_index().dtypes.to_dict()
self.assertDictEqual(sessions_types, s_types)
# Check datasets data frame types
datasets_types = EMPTY_DATASETS_FRAME.reset_index().dtypes.to_dict()
d_types = self.one._cache.datasets.reset_index().dtypes.to_dict()
self.assertDictEqual(datasets_types, d_types)

def test_update_cache_from_records(self):
"""Test One._update_cache_from_records"""
sessions_types = self.one._cache.sessions.reset_index().dtypes.to_dict()
datasets_types = self.one._cache.datasets.reset_index().dtypes.to_dict()
# Update with single record (pandas.Series), one exists, one doesn't
session = self.one._cache.sessions.iloc[0].squeeze()
session.name = str(uuid4()) # New record
Expand All @@ -916,24 +937,33 @@ def test_update_cache_from_records(self):
self.assertTrue(session.name in self.one._cache.sessions.index)
updated, = dataset['exists'] == self.one._cache.datasets.loc[dataset.name, 'exists']
self.assertTrue(updated)
# Check that the updated data frame has kept its original dtypes
types = self.one._cache.sessions.reset_index().dtypes.to_dict()
self.assertDictEqual(sessions_types, types)
types = self.one._cache.datasets.reset_index().dtypes.to_dict()
self.assertDictEqual(datasets_types, types)

# Update a number of records
datasets = self.one._cache.datasets.iloc[:3].copy()
datasets.loc[:, 'exists'] = ~datasets.loc[:, 'exists']
# Make one of the datasets a new record
idx = datasets.index.values
idx[-1] = (idx[-1][0], str(uuid4()))
datasets.index = pd.MultiIndex.from_tuples(idx)
datasets.index = pd.MultiIndex.from_tuples(idx, names=('eid', 'id'))
self.one._update_cache_from_records(datasets=datasets)
self.assertTrue(idx[-1] in self.one._cache.datasets.index)
verifiable = self.one._cache.datasets.loc[datasets.index.values, 'exists']
self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists']))
# Check that the updated data frame has kept its original dtypes
types = self.one._cache.datasets.reset_index().dtypes.to_dict()
self.assertDictEqual(datasets_types, types)

# Check behaviour when columns don't match
datasets.loc[:, 'exists'] = ~datasets.loc[:, 'exists']
datasets['extra_column'] = True
self.one._cache.datasets['foo_bar'] = 12 # this column is missing in our new records
self.one._cache.datasets['new_column'] = False
expected_datasets_types = self.one._cache.datasets.reset_index().dtypes.to_dict()
self.addCleanup(self.one._cache.datasets.drop, 'foo_bar', axis=1, inplace=True)
# An exception is exists_* as the Alyx cache contains exists_aws and exists_flatiron
# These should simply be filled with the values of exists as Alyx won't return datasets
Expand All @@ -950,18 +980,27 @@ def test_update_cache_from_records(self):
# If the extra column does not start with 'exists' it should be set to NaN
verifiable = self.one._cache.datasets.loc[datasets.index.values, 'foo_bar']
self.assertTrue(np.isnan(verifiable).all())
# Check that the missing columns were updated to nullable fields
expected_datasets_types.update(
foo_bar=pd.Int64Dtype(), exists_aws=pd.BooleanDtype(), new_column=pd.BooleanDtype())
types = self.one._cache.datasets.reset_index().dtypes.to_dict()
self.assertDictEqual(expected_datasets_types, types)

# Check fringe cases
with self.assertRaises(KeyError):
self.one._update_cache_from_records(unknown=datasets)
self.assertIsNone(self.one._update_cache_from_records(datasets=None))
# Absent cache table
self.one.load_cache(tables_dir='/foo')
sessions_types = self.one._cache.sessions.reset_index().dtypes.to_dict()
datasets_types = self.one._cache.datasets.reset_index().dtypes.to_dict()
self.one._update_cache_from_records(sessions=session, datasets=dataset)
self.assertTrue(all(self.one._cache.sessions == pd.DataFrame([session])))
self.assertEqual(1, len(self.one._cache.datasets))
self.assertEqual(self.one._cache.datasets.squeeze().name, dataset.name)
self.assertCountEqual(self.one._cache.datasets.squeeze().to_dict(), dataset.to_dict())
types = self.one._cache.datasets.reset_index().dtypes.to_dict()
self.assertDictEqual(datasets_types, types)

def test_save_loaded_ids(self):
"""Test One.save_loaded_ids and logic within One._check_filesystem"""
Expand Down

0 comments on commit 0f05989

Please sign in to comment.