Update cache from records in OneAlyx.search
k1o0 committed Oct 14, 2024
1 parent f042aaa commit 4374a24
Showing 4 changed files with 130 additions and 10 deletions.
65 changes: 56 additions & 9 deletions one/api.py
@@ -98,14 +98,45 @@ def search_terms(self, query_type=None) -> tuple:

def _reset_cache(self):
"""Replace the cache object with a Bunch that contains the right fields."""
self._cache = Bunch({'_meta': {
'expired': False,
'created_time': None,
'loaded_time': None,
'modified_time': None,
'saved_time': None,
'raw': {} # map of original table metadata
}})
self._cache = Bunch({
'datasets': pd.DataFrame(columns=DATASETS_COLUMNS).set_index(['eid', 'id']),
'sessions': pd.DataFrame(columns=SESSIONS_COLUMNS).set_index('id'),
'_meta': {
'expired': False,
'created_time': None,
'loaded_time': None,
'modified_time': None,
'saved_time': None,
'raw': {}} # map of original table metadata
})

def _remove_cache_table_files(self, tables=None):
"""Delete cache tables on disk.
Parameters
----------
tables : list of str
A list of table names to remove, e.g. ['sessions', 'datasets'].
If None, the currently loaded table names are removed. NB: This
will also delete the cache_info.json metadata file.
Returns
-------
list of pathlib.Path
A list of the removed files.
TODO Add test.
"""
tables = tables or filter(lambda x: x[0] != '_', self._cache)
filenames = ('cache_info.json', *(f'{t}.pqt' for t in tables))
removed = []
for file in map(self._tables_dir.joinpath, filenames):
if file.exists():
file.unlink()
removed.append(file)
else:
_logger.warning('%s not found', file)
return removed
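
A minimal usage sketch of the new helper (illustrative only; assumes `one` is an
instance of the class above with its cache tables already loaded from disk):

    # Delete the on-disk sessions/datasets parquet files plus cache_info.json,
    # then rebuild the in-memory cache from the empty tables defined in _reset_cache.
    removed = one._remove_cache_table_files(['sessions', 'datasets'])
    print([p.name for p in removed])   # e.g. ['cache_info.json', 'sessions.pqt', ...]
    one._reset_cache()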

def load_cache(self, tables_dir=None, **kwargs):
"""
@@ -187,7 +218,7 @@ def _save_cache(self, save_dir=None, force=False):
If True, the cache is saved regardless of modification time.
"""
TIMEOUT = 5 # Delete lock file this many seconds after creation/modification or waiting
lock_file = Path(self.cache_dir).joinpath('.cache.lock')
lock_file = Path(self.cache_dir).joinpath('.cache.lock') # TODO use iblutil method here
save_dir = Path(save_dir or self.cache_dir)
meta = self._cache['_meta']
modified = meta.get('modified_time') or datetime.min
@@ -2265,6 +2296,22 @@ def search(self, details=False, query_type=None, **kwargs):
params.pop('django')
# Make GET request
ses = self.alyx.rest(self._search_endpoint, 'list', **params)

def _update_cache_from_records(session_records):
"""Update the cache tables with a list of session records."""
df = pd.DataFrame(next(zip(*map(util.ses2records, session_records))))
return self._update_cache_from_records(sessions=df)

if len(ses) == 0:
pass # no need to update cache here
elif isinstance(ses, list): # not a paginated response
_update_cache_from_records(ses)
else:
# populate first page
_update_cache_from_records(ses._cache[:ses.limit])
# Add callback for updating cache on future fetches
ses.add_callback(_update_cache_from_records)

# LazyId only transforms records when indexed
eids = util.LazyId(ses)
if not details:
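The search() change above keeps the local sessions table in sync with remote results:
the first page of the paginated response is converted to records and merged immediately,
and a registered callback handles every page fetched later on demand. A self-contained
sketch of that callback pattern (a toy `Paginated` class standing in for the real
paginated response; the real implementation stores callbacks in a WeakSet and merges
records into a pandas DataFrame rather than a list):

    class Paginated:
        """Toy paginated response that notifies callbacks as pages arrive."""
        def __init__(self, pages):
            self._pages = pages         # list of pages, each a list of records
            self._callbacks = []        # NB: the real class uses a WeakSet

        def add_callback(self, cb):
            self._callbacks.append(cb)

        def fetch(self, i):
            results = self._pages[i]
            for cb in self._callbacks:  # notify listeners with the new page
                cb(results)
            return results

    cache = []                                       # stands in for the sessions table
    resp = Paginated([[{'id': 1}, {'id': 2}], [{'id': 3}, {'id': 4}]])
    cache.extend(resp.fetch(0))                      # first page merged up front
    resp.add_callback(cache.extend)                  # later pages merged lazily
    resp.fetch(1)
    assert [r['id'] for r in cache] == [1, 2, 3, 4]
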
22 changes: 22 additions & 0 deletions one/tests/test_alyxclient.py
@@ -3,6 +3,7 @@
from unittest import mock
import urllib.parse
import random
import gc
import os
import one.webclient as wc
import one.params
@@ -498,12 +499,23 @@ def test_paginated_response(self):
self.assertTrue(not any(pg._cache[lim:]))
self.assertIs(pg.alyx, alyx)

# Check adding callbacks
self.assertRaises(TypeError, pg.add_callback, None)
cb1, cb2 = mock.MagicMock(), mock.MagicMock()
pg.add_callback(cb1)
pg.add_callback(cb2)
self.assertEqual(2, len(pg._callbacks))

# Check fetching cached item with +ve int
self.assertEqual({'id': 1}, pg[1])
alyx._generic_request.assert_not_called()
for cb in [cb1, cb2]:
cb.assert_not_called()
# Check fetching cached item with +ve slice
self.assertEqual([{'id': 1}, {'id': 2}], pg[1:3])
alyx._generic_request.assert_not_called()
for cb in [cb1, cb2]:
cb.assert_not_called()
# Check fetching cached item with -ve int
self.assertEqual({'id': 100}, pg[-1900])
alyx._generic_request.assert_not_called()
@@ -518,6 +530,12 @@ def test_paginated_response(self):
self.assertEqual(res['results'], pg._cache[offset:offset + lim])
alyx._generic_request.assert_called_once_with(requests.get, mock.ANY, clobber=True)
self._check_get_query(alyx._generic_request.call_args, lim, offset)
for cb in [cb1, cb2]:
cb.assert_called_once_with(res['results'])
# Check that deleting one of the callbacks will remove it from the list (weakref)
del cb1
gc.collect()
self.assertEqual(1, len(pg._callbacks))
# Check fetching uncached item with -ve int
offset = lim * 3
res['results'] = [{'id': i} for i in range(offset, offset + lim)]
@@ -548,6 +566,10 @@ def test_paginated_response(self):
self.assertEqual(expected_calls := 4, alyx._generic_request.call_count)
self.assertEqual((expected_calls + 1) * lim, sum(list(map(bool, pg._cache))))

# Check callbacks cleared when cache fully populated
self.assertTrue(all(map(bool, pg)))
self.assertEqual(0, len(pg._callbacks))

def _check_get_query(self, call_args, limit, offset):
"""Check URL get query contains the expected limit and offset params."""
(_, url), _ = call_args
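The `gc` import and the `del cb1; gc.collect()` step in the test above rely on the
callbacks being stored in a WeakSet: once the last strong reference to a callback is
dropped, it leaves the set automatically. A standalone illustration of that behaviour
(plain Python, independent of the test fixtures):

    import gc
    from weakref import WeakSet

    class Callback:
        def __call__(self, results):
            print('received', len(results), 'records')

    callbacks = WeakSet()
    cb = Callback()
    callbacks.add(cb)
    assert len(callbacks) == 1

    del cb           # drop the only strong reference
    gc.collect()     # make collection deterministic before asserting
    assert len(callbacks) == 0
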
20 changes: 19 additions & 1 deletion one/tests/test_one.py
@@ -1470,8 +1470,17 @@ def test_list_datasets(self):

def test_search(self):
"""Test OneAlyx.search method in remote mode."""
# Modify sessions dataframe so we can check that the records get updated
records = self.one._cache.sessions[self.one._cache.sessions.subject == 'SWC_043']
self.one._cache.sessions.loc[records.index, 'lab'] = 'foolab' # change a field
self.one._cache.sessions.drop(self.eid, inplace=True) # remove a row

# Check remote search of subject
eids = self.one.search(subject='SWC_043', query_type='remote')
self.assertIn(self.eid, list(eids))
updated = self.one._cache.sessions[self.one._cache.sessions.subject == 'SWC_043']
self.assertCountEqual(eids, updated.index)
self.assertFalse(any(updated['lab'] == 'foolab'))

eids, det = self.one.search(subject='SWC_043', query_type='remote', details=True)
correct = len(det) == len(eids) and 'url' in det[0] and det[0]['url'].endswith(eids[0])
@@ -1496,10 +1505,19 @@ def test_search(self):
dates = set(map(lambda x: self.one.get_details(x)['date'], eids))
self.assertTrue(dates <= set(date_range))

# Test limit arg and LazyId
# Test limit arg, LazyId, and update with paginated response callback
self.one._reset_cache() # Remove sessions table
assert self.one._cache.sessions.empty
eids = self.one.search(date='2020-03-23', limit=2, query_type='remote')
self.assertEqual(2, len(self.one._cache.sessions),
'failed to update cache with first page of search results')
self.assertIsInstance(eids, LazyId)
e = eids[-3] # access an uncached value
self.assertEqual(
4, len(self.one._cache.sessions), 'failed to update cache after page access')
self.assertTrue(e in self.one._cache.sessions.index)
self.assertTrue(all(len(x) == 36 for x in eids))
self.assertEqual(len(eids), len(self.one._cache.sessions))

# Test laboratory kwarg
eids = self.one.search(laboratory='hoferlab', query_type='remote')
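The numbers in the paginated-search assertions above follow from the page arithmetic in
populate() (see one/webclient.py below): with `limit=2`, the initial request caches only
the first page of 2 session records; indexing `eids[-3]` maps to a record beyond that
page, so exactly one more page is fetched and merged via the callback, giving the 4
cached rows asserted. A rough sketch of the offset arithmetic, with a hypothetical total
`count` of results:

    import math

    limit = 2                                 # page size used in the test above
    count = 5                                 # hypothetical total number of results
    idx = count - 3                           # eids[-3] resolved to a positive index
    offset = limit * math.floor(idx / limit)  # start of the page containing idx
    print(offset, offset + limit)             # records offset..offset+limit-1 fetched
    # One extra page of `limit` records is merged via the callback, so the cached
    # sessions grow from 2 (first page) to 4 rows.
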
33 changes: 33 additions & 0 deletions one/webclient.py
@@ -40,6 +40,7 @@
from typing import Optional
from datetime import datetime, timedelta
from pathlib import Path
from weakref import WeakSet
import warnings
import hashlib
import zipfile
@@ -206,6 +207,22 @@ def __init__(self, alyx, rep, cache_args=None):
# fill the cache with results of the query
for i in range(self.limit):
self._cache[i] = rep['results'][i]
self._callbacks = WeakSet()

def add_callback(self, cb):
"""Add a callback function to use each time a new page is fetched.
The callback function will be called with the page results each time :meth:`populate`
is called.
Parameters
----------
cb : callable
A callable that takes the results of each paginated response.
"""
if not callable(cb):
raise TypeError(f'Expected type "callable", got "{type(cb)}" instead')
self._callbacks.add(cb)

def __len__(self):
return self.count
@@ -222,6 +239,16 @@ def __getitem__(self, item):
return self._cache[item]

def populate(self, idx):
"""Populate response cache with new page of results.
Fetches the specific page of results containing the index passed and stores
the results in the :attr:`_cache` property.
Parameters
----------
idx : int
The index of a given record to fetch.
"""
offset = self.limit * math.floor(idx / self.limit)
query = update_url_params(self.query, {'limit': self.limit, 'offset': offset})
res = self.alyx._generic_request(requests.get, query, **self._cache_args)
@@ -231,6 +258,12 @@ def populate(self, idx):
f'results may be inconsistent', RuntimeWarning)
for i, r in enumerate(res['results'][:self.count - offset]):
self._cache[i + offset] = res['results'][i]
# Notify callbacks
for cb in self._callbacks:
cb(res['results'])
# When cache is complete, clear our callbacks
if all(reversed(self._cache)):
self._callbacks.clear()

def __iter__(self):
for i in range(self.count):
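Taken together, the webclient additions let any consumer hook into page fetches.
A hedged usage sketch, assuming `alyx` is an authenticated AlyxClient and the
'sessions' list endpoint returns more records than one page (so `rest` gives back a
paginated response):

    collected = []

    def on_page(results):
        """Called by populate() with the raw records of each newly fetched page."""
        collected.extend(results)

    sessions = alyx.rest('sessions', 'list')
    if hasattr(sessions, 'add_callback'):   # only paginated responses expose this
        # NB: callbacks are stored weakly, so keep `on_page` referenced while iterating
        sessions.add_callback(on_page)
    last = sessions[-1]                     # may trigger populate() and the callback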
