From 4374a24ab0b0997662d4ad0d73887796db90fc34 Mon Sep 17 00:00:00 2001
From: Miles Wells
Date: Fri, 11 Oct 2024 14:01:00 +0300
Subject: [PATCH] Update cache from records in OneAlyx.search

---
 one/api.py                   | 65 +++++++++++++++++++++++++++++++-----
 one/tests/test_alyxclient.py | 22 ++++++++++++
 one/tests/test_one.py        | 20 ++++++++++-
 one/webclient.py             | 33 ++++++++++++++++++
 4 files changed, 130 insertions(+), 10 deletions(-)

diff --git a/one/api.py b/one/api.py
index a00ef38a..de596ec5 100644
--- a/one/api.py
+++ b/one/api.py
@@ -98,14 +98,45 @@ def search_terms(self, query_type=None) -> tuple:

     def _reset_cache(self):
         """Replace the cache object with a Bunch that contains the right fields."""
-        self._cache = Bunch({'_meta': {
-            'expired': False,
-            'created_time': None,
-            'loaded_time': None,
-            'modified_time': None,
-            'saved_time': None,
-            'raw': {}  # map of original table metadata
-        }})
+        self._cache = Bunch({
+            'datasets': pd.DataFrame(columns=DATASETS_COLUMNS).set_index(['eid', 'id']),
+            'sessions': pd.DataFrame(columns=SESSIONS_COLUMNS).set_index('id'),
+            '_meta': {
+                'expired': False,
+                'created_time': None,
+                'loaded_time': None,
+                'modified_time': None,
+                'saved_time': None,
+                'raw': {}}  # map of original table metadata
+        })
+
+    def _remove_cache_table_files(self, tables=None):
+        """Delete cache tables on disk.
+
+        Parameters
+        ----------
+        tables : list of str
+            A list of table names to remove, e.g. ['sessions', 'datasets'].
+            If None, the currently loaded table names are removed. NB: In
+            either case the cache_info.json metadata file is also deleted.
+
+        Returns
+        -------
+        list of pathlib.Path
+            A list of the removed files.
+
+        TODO Add test.
+        """
+        tables = tables or filter(lambda x: x[0] != '_', self._cache)
+        filenames = ('cache_info.json', *(f'{t}.pqt' for t in tables))
+        removed = []
+        for file in map(self._tables_dir.joinpath, filenames):
+            if file.exists():
+                file.unlink()
+                removed.append(file)
+            else:
+                _logger.warning('%s not found', file)
+        return removed

     def load_cache(self, tables_dir=None, **kwargs):
         """
@@ -187,7 +218,7 @@ def _save_cache(self, save_dir=None, force=False):
             If True, the cache is saved regardless of modification time.
""" TIMEOUT = 5 # Delete lock file this many seconds after creation/modification or waiting - lock_file = Path(self.cache_dir).joinpath('.cache.lock') + lock_file = Path(self.cache_dir).joinpath('.cache.lock') # TODO use iblutil method here save_dir = Path(save_dir or self.cache_dir) meta = self._cache['_meta'] modified = meta.get('modified_time') or datetime.min @@ -2265,6 +2296,22 @@ def search(self, details=False, query_type=None, **kwargs): params.pop('django') # Make GET request ses = self.alyx.rest(self._search_endpoint, 'list', **params) + + def _update_cache_from_records(session_records): + """Update the cache tables with a list of session records.""" + df = pd.DataFrame(next(zip(*map(util.ses2records, session_records)))) + return self._update_cache_from_records(sessions=df) + + if len(ses) == 0: + pass # no need to update cache here + elif isinstance(ses, list): # not a paginated response + _update_cache_from_records(ses) + else: + # populate first page + _update_cache_from_records(ses._cache[:ses.limit]) + # Add callback for updating cache on future fetches + ses.add_callback(_update_cache_from_records) + # LazyId only transforms records when indexed eids = util.LazyId(ses) if not details: diff --git a/one/tests/test_alyxclient.py b/one/tests/test_alyxclient.py index 09615e6e..021b4aba 100644 --- a/one/tests/test_alyxclient.py +++ b/one/tests/test_alyxclient.py @@ -3,6 +3,7 @@ from unittest import mock import urllib.parse import random +import gc import os import one.webclient as wc import one.params @@ -498,12 +499,23 @@ def test_paginated_response(self): self.assertTrue(not any(pg._cache[lim:])) self.assertIs(pg.alyx, alyx) + # Check adding callbacks + self.assertRaises(TypeError, pg.add_callback, None) + cb1, cb2 = mock.MagicMock(), mock.MagicMock() + pg.add_callback(cb1) + pg.add_callback(cb2) + self.assertEqual(2, len(pg._callbacks)) + # Check fetching cached item with +ve int self.assertEqual({'id': 1}, pg[1]) alyx._generic_request.assert_not_called() + for cb in [cb1, cb2]: + cb.assert_not_called() # Check fetching cached item with +ve slice self.assertEqual([{'id': 1}, {'id': 2}], pg[1:3]) alyx._generic_request.assert_not_called() + for cb in [cb1, cb2]: + cb.assert_not_called() # Check fetching cached item with -ve int self.assertEqual({'id': 100}, pg[-1900]) alyx._generic_request.assert_not_called() @@ -518,6 +530,12 @@ def test_paginated_response(self): self.assertEqual(res['results'], pg._cache[offset:offset + lim]) alyx._generic_request.assert_called_once_with(requests.get, mock.ANY, clobber=True) self._check_get_query(alyx._generic_request.call_args, lim, offset) + for cb in [cb1, cb2]: + cb.assert_called_once_with(res['results']) + # Check that deleting one of the callbacks with remove it from the list (weakref) + del cb1 + gc.collect() + self.assertEqual(1, len(pg._callbacks)) # Check fetching uncached item with -ve int offset = lim * 3 res['results'] = [{'id': i} for i in range(offset, offset + lim)] @@ -548,6 +566,10 @@ def test_paginated_response(self): self.assertEqual(expected_calls := 4, alyx._generic_request.call_count) self.assertEqual((expected_calls + 1) * lim, sum(list(map(bool, pg._cache)))) + # Check callbacks cleared when cache fully populated + self.assertTrue(all(map(bool, pg))) + self.assertEqual(0, len(pg._callbacks)) + def _check_get_query(self, call_args, limit, offset): """Check URL get query contains the expected limit and offset params.""" (_, url), _ = call_args diff --git a/one/tests/test_one.py b/one/tests/test_one.py index 
index de8d5341..261e91a9 100644
--- a/one/tests/test_one.py
+++ b/one/tests/test_one.py
@@ -1470,8 +1470,17 @@ def test_list_datasets(self):

     def test_search(self):
         """Test OneAlyx.search method in remote mode."""
+        # Modify sessions dataframe so we can check that the records get updated
+        records = self.one._cache.sessions[self.one._cache.sessions.subject == 'SWC_043']
+        self.one._cache.sessions.loc[records.index, 'lab'] = 'foolab'  # change a field
+        self.one._cache.sessions.drop(self.eid, inplace=True)  # remove a row
+
+        # Check remote search by subject
         eids = self.one.search(subject='SWC_043', query_type='remote')
         self.assertIn(self.eid, list(eids))
+        updated = self.one._cache.sessions[self.one._cache.sessions.subject == 'SWC_043']
+        self.assertCountEqual(eids, updated.index)
+        self.assertNotIn('foolab', updated['lab'].values)

         eids, det = self.one.search(subject='SWC_043', query_type='remote', details=True)
         correct = len(det) == len(eids) and 'url' in det[0] and det[0]['url'].endswith(eids[0])
@@ -1496,10 +1505,19 @@ def test_search(self):
         dates = set(map(lambda x: self.one.get_details(x)['date'], eids))
         self.assertTrue(dates <= set(date_range))

-        # Test limit arg and LazyId
+        # Test limit arg, LazyId, and update with paginated response callback
+        self.one._reset_cache()  # Remove sessions table
+        assert self.one._cache.sessions.empty
         eids = self.one.search(date='2020-03-23', limit=2, query_type='remote')
+        self.assertEqual(2, len(self.one._cache.sessions),
+                         'failed to update cache with first page of search results')
         self.assertIsInstance(eids, LazyId)
+        e = eids[-3]  # access an uncached value
+        self.assertEqual(
+            4, len(self.one._cache.sessions), 'failed to update cache after page access')
+        self.assertTrue(e in self.one._cache.sessions.index)
         self.assertTrue(all(len(x) == 36 for x in eids))
+        self.assertEqual(len(eids), len(self.one._cache.sessions))

         # Test laboratory kwarg
         eids = self.one.search(laboratory='hoferlab', query_type='remote')
diff --git a/one/webclient.py b/one/webclient.py
index 155beeca..66901e9e 100644
--- a/one/webclient.py
+++ b/one/webclient.py
@@ -40,6 +40,7 @@
 from typing import Optional
 from datetime import datetime, timedelta
 from pathlib import Path
+from weakref import WeakSet
 import warnings
 import hashlib
 import zipfile
@@ -206,6 +207,22 @@ def __init__(self, alyx, rep, cache_args=None):
         # fill the cache with results of the query
         for i in range(self.limit):
             self._cache[i] = rep['results'][i]
+        self._callbacks = WeakSet()
+
+    def add_callback(self, cb):
+        """Add a callback function to use each time a new page is fetched.
+
+        The callback function will be called with the page results each time
+        :meth:`populate` is called.
+
+        Parameters
+        ----------
+        cb : callable
+            A callable that takes the results of each paginated response.
+        """
+        if not callable(cb):
+            raise TypeError(f'Expected type "callable", got "{type(cb)}" instead')
+        self._callbacks.add(cb)

     def __len__(self):
         return self.count
@@ -222,6 +239,16 @@ def __getitem__(self, item):
         return self._cache[item]

     def populate(self, idx):
+        """Populate response cache with a new page of results.
+
+        Fetches the page of results containing the given index and stores the
+        records in the :attr:`_cache` attribute.
+
+        Parameters
+        ----------
+        idx : int
+            The index of a given record to fetch.
+ """ offset = self.limit * math.floor(idx / self.limit) query = update_url_params(self.query, {'limit': self.limit, 'offset': offset}) res = self.alyx._generic_request(requests.get, query, **self._cache_args) @@ -231,6 +258,12 @@ def populate(self, idx): f'results may be inconsistent', RuntimeWarning) for i, r in enumerate(res['results'][:self.count - offset]): self._cache[i + offset] = res['results'][i] + # Notify callbacks + for cb in self._callbacks: + cb(res['results']) + # When cache is complete, clear our callbacks + if all(reversed(self._cache)): + self._callbacks.clear() def __iter__(self): for i in range(self.count):