Update cache from records in OneAlyx.search
k1o0 committed Oct 14, 2024
1 parent d1a8331 commit 47e5b27
Showing 4 changed files with 126 additions and 10 deletions.
50 changes: 41 additions & 9 deletions one/api.py
@@ -3,6 +3,7 @@
import urllib.parse
import warnings
import logging
from weakref import WeakMethod
from datetime import datetime, timedelta
from functools import lru_cache, partial
from inspect import unwrap
@@ -98,14 +99,17 @@ def search_terms(self, query_type=None) -> tuple:

def _reset_cache(self):
"""Replace the cache object with a Bunch that contains the right fields."""
self._cache = Bunch({'_meta': {
'expired': False,
'created_time': None,
'loaded_time': None,
'modified_time': None,
'saved_time': None,
'raw': {} # map of original table metadata
}})
self._cache = Bunch({
'datasets': pd.DataFrame(columns=DATASETS_COLUMNS).set_index(['eid', 'id']),
'sessions': pd.DataFrame(columns=SESSIONS_COLUMNS).set_index('id'),
'_meta': {
'expired': False,
'created_time': None,
'loaded_time': None,
'modified_time': None,
'saved_time': None,
'raw': {}} # map of original table metadata
})
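For context, a sketch (not part of this commit) of the empty-but-indexed tables the new code builds; the column constants here are an assumed, abbreviated subset of the package's real `DATASETS_COLUMNS` and `SESSIONS_COLUMNS`:

```python
import pandas as pd

# Assumed, abbreviated column sets for illustration only
SESSIONS_COLUMNS = ['id', 'lab', 'subject', 'date', 'number']
DATASETS_COLUMNS = ['eid', 'id', 'rel_path', 'exists']

sessions = pd.DataFrame(columns=SESSIONS_COLUMNS).set_index('id')
datasets = pd.DataFrame(columns=DATASETS_COLUMNS).set_index(['eid', 'id'])
assert sessions.empty and datasets.empty  # well-formed, with no rows
```

Starting from well-formed empty tables means downstream code can concatenate or index into them without special-casing a missing table.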

def load_cache(self, tables_dir=None, **kwargs):
"""
@@ -187,7 +191,7 @@ def _save_cache(self, save_dir=None, force=False):
If True, the cache is saved regardless of modification time.
"""
TIMEOUT = 5 # Delete lock file this many seconds after creation/modification or waiting
lock_file = Path(self.cache_dir).joinpath('.cache.lock')
lock_file = Path(self.cache_dir).joinpath('.cache.lock') # TODO use iblutil method here
save_dir = Path(save_dir or self.cache_dir)
meta = self._cache['_meta']
modified = meta.get('modified_time') or datetime.min
@@ -2265,6 +2269,18 @@ def search(self, details=False, query_type=None, **kwargs):
params.pop('django')
# Make GET request
ses = self.alyx.rest(self._search_endpoint, 'list', **params)

# Update cache table with results
if len(ses) == 0:
pass # no need to update cache here
elif isinstance(ses, list): # not a paginated response
self._update_sessions_table(ses)
else:
# populate first page
self._update_sessions_table(ses._cache[:ses.limit])
# Add callback for updating cache on future fetches
ses.add_callback(WeakMethod(self._update_sessions_table))

# LazyId only transforms records when indexed
eids = util.LazyId(ses)
if not details:
@@ -2278,6 +2294,22 @@ def _add_date(records):

return eids, util.LazyId(ses, func=_add_date)

def _update_sessions_table(self, session_records):
"""Update the sessions tables with a list of session records.
Parameters
----------
session_records : list of dict
A list of session records from the /sessions list endpoint.
Returns
-------
datetime.datetime:
A timestamp of when the cache was updated.
"""
df = pd.DataFrame(next(zip(*map(util.ses2records, session_records))))
return self._update_cache_from_records(sessions=df)
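The `next(zip(*map(...)))` idiom above keeps only the first element of each pair returned by `util.ses2records`, assumed here to be `(session_record, datasets_frame)` tuples. A more explicit equivalent sketch:

```python
# Spelled-out equivalent, assuming ses2records(rec) -> (pd.Series, pd.DataFrame)
pairs = [util.ses2records(rec) for rec in session_records]
sessions_only = [session for session, _datasets in pairs]
df = pd.DataFrame(sessions_only)
```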

def _download_datasets(self, dsets, **kwargs) -> List[Path]:
"""
Download a single or multitude of datasets if stored on AWS, otherwise calls
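One detail worth noting from the `search` change above: the callback is registered as a `WeakMethod` so the paginated response never keeps the `OneAlyx` instance alive just to update its cache. A self-contained sketch of that behaviour (names are illustrative):

```python
from weakref import WeakMethod

class CacheOwner:
    def update(self, records):
        print(f'cached {len(records)} records')

owner = CacheOwner()
cb = WeakMethod(owner.update)
if (fn := cb()) is not None:  # resolve the weak reference before calling
    fn([{'id': 1}])           # prints: cached 1 records
del fn, owner                 # drop all strong references (CPython refcounting)
assert cb() is None           # the callback is dead and can be discarded
```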
22 changes: 22 additions & 0 deletions one/tests/test_alyxclient.py
@@ -3,6 +3,7 @@
from unittest import mock
import urllib.parse
import random
import weakref
import os
import one.webclient as wc
import one.params
@@ -498,12 +499,24 @@ def test_paginated_response(self):
self.assertTrue(not any(pg._cache[lim:]))
self.assertIs(pg.alyx, alyx)

# Check adding callbacks
self.assertRaises(TypeError, pg.add_callback, None)
wf = mock.Mock(spec_set=weakref.ref)
cb1, cb2 = mock.MagicMock(), wf()
pg.add_callback(cb1)
pg.add_callback(wf)
self.assertEqual(2, len(pg._callbacks))

# Check fetching cached item with +ve int
self.assertEqual({'id': 1}, pg[1])
alyx._generic_request.assert_not_called()
for cb in [cb1, cb2]:
cb.assert_not_called()
# Check fetching cached item with +ve slice
self.assertEqual([{'id': 1}, {'id': 2}], pg[1:3])
alyx._generic_request.assert_not_called()
for cb in [cb1, cb2]:
cb.assert_not_called()
# Check fetching cached item with -ve int
self.assertEqual({'id': 100}, pg[-1900])
alyx._generic_request.assert_not_called()
@@ -518,6 +531,10 @@ def test_paginated_response(self):
self.assertEqual(res['results'], pg._cache[offset:offset + lim])
alyx._generic_request.assert_called_once_with(requests.get, mock.ANY, clobber=True)
self._check_get_query(alyx._generic_request.call_args, lim, offset)
for cb in [cb1, cb2]:
cb.assert_called_once_with(res['results'])
# Check that a dead weakref will be removed from the list on the next call
wf.return_value = None
# Check fetching uncached item with -ve int
offset = lim * 3
res['results'] = [{'id': i} for i in range(offset, offset + lim)]
@@ -527,6 +544,7 @@ def test_paginated_response(self):
self.assertEqual(res['results'], pg._cache[offset:offset + lim])
alyx._generic_request.assert_called_with(requests.get, mock.ANY, clobber=True)
self._check_get_query(alyx._generic_request.call_args, lim, offset)
self.assertEqual(1, len(pg._callbacks), 'failed to remove weakref callback')
# Check fetching uncached item with +ve slice
offset = lim * 5
res['results'] = [{'id': i} for i in range(offset, offset + lim)]
@@ -548,6 +566,10 @@
self.assertEqual(expected_calls := 4, alyx._generic_request.call_count)
self.assertEqual((expected_calls + 1) * lim, sum(list(map(bool, pg._cache))))

# Check callbacks cleared when cache fully populated
self.assertTrue(all(map(bool, pg)))
self.assertEqual(0, len(pg._callbacks))

def _check_get_query(self, call_args, limit, offset):
"""Check URL get query contains the expected limit and offset params."""
(_, url), _ = call_args
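The mocked weak reference used in the test above can be puzzling at first; a standalone sketch of the same pattern, standard library only:

```python
import weakref
from unittest import mock

wf = mock.Mock(spec_set=weakref.ref)  # stands in for a weakref/WeakMethod
cb = wf()                             # the 'live' referent, itself a mock
cb([{'id': 0}])                       # callable, like a real callback
cb.assert_called_once()
wf.return_value = None                # simulate the referent being collected
assert wf() is None                   # populate() should now discard wf
```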
21 changes: 20 additions & 1 deletion one/tests/test_one.py
@@ -1470,8 +1470,17 @@ def test_list_datasets(self):

def test_search(self):
"""Test OneAlyx.search method in remote mode."""
# Modify sessions dataframe so we can check that the records get updated
records = self.one._cache.sessions[self.one._cache.sessions.subject == 'SWC_043']
self.one._cache.sessions.loc[records.index, 'lab'] = 'foolab' # change a field
self.one._cache.sessions.drop(self.eid, inplace=True) # remove a row

# Check remote search of subject
eids = self.one.search(subject='SWC_043', query_type='remote')
self.assertIn(self.eid, list(eids))
updated = self.one._cache.sessions[self.one._cache.sessions.subject == 'SWC_043']
self.assertCountEqual(eids, updated.index)
self.assertFalse(any(updated['lab'] == 'foolab'))  # NB: `in` on a Series checks the index

eids, det = self.one.search(subject='SWC_043', query_type='remote', details=True)
correct = len(det) == len(eids) and 'url' in det[0] and det[0]['url'].endswith(eids[0])
@@ -1496,10 +1505,20 @@ def test_search(self):
dates = set(map(lambda x: self.one.get_details(x)['date'], eids))
self.assertTrue(dates <= set(date_range))

# Test limit arg and LazyId
# Test limit arg, LazyId, and update with paginated response callback
self.one._reset_cache() # Remove sessions table
assert self.one._cache.sessions.empty
eids = self.one.search(date='2020-03-23', limit=2, query_type='remote')
self.assertEqual(2, len(self.one._cache.sessions),
'failed to update cache with first page of search results')
self.assertIsInstance(eids, LazyId)
assert len(eids) > 5, 'in order to check paginated response callback we need several pages'
e = eids[-3] # access an uncached value
self.assertEqual(
4, len(self.one._cache.sessions), 'failed to update cache after page access')
self.assertTrue(e in self.one._cache.sessions.index)
self.assertTrue(all(len(x) == 36 for x in eids))
self.assertEqual(len(eids), len(self.one._cache.sessions))

# Test laboratory kwarg
eids = self.one.search(laboratory='hoferlab', query_type='remote')
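A hedged illustration of the behaviour these new assertions exercise, assuming a `ONE` instance connected to an Alyx database with several pages of results for the query (URL hypothetical):

```python
from one.api import ONE

one = ONE(base_url='https://alyx.example.com')  # hypothetical Alyx instance
one._reset_cache()                              # empty, well-formed tables
eids = one.search(date='2020-03-23', limit=2, query_type='remote')
assert len(one._cache.sessions) == 2      # first page cached on return
eid = eids[-1]                            # index an uncached record...
assert eid in one._cache.sessions.index  # ...its page was fetched and cached
```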
43 changes: 43 additions & 0 deletions one/webclient.py
@@ -40,6 +40,7 @@
from typing import Optional
from datetime import datetime, timedelta
from pathlib import Path
from weakref import ReferenceType
import warnings
import hashlib
import zipfile
@@ -206,6 +207,23 @@ def __init__(self, alyx, rep, cache_args=None):
# fill the cache with results of the query
for i in range(self.limit):
self._cache[i] = rep['results'][i]
self._callbacks = set()

def add_callback(self, cb):
"""Add a callback function to use each time a new page is fetched.
The callback function will be called with the page results each time :meth:`populate`
is called.
Parameters
----------
cb : callable
A callable that takes the results of each paginated resonse.
"""
if not callable(cb):
raise TypeError(f'Expected type "callable", got "{type(cb)}" instead')
else:
self._callbacks.add(cb)

def __len__(self):
return self.count
@@ -222,6 +240,16 @@ def __getitem__(self, item):
return self._cache[item]

def populate(self, idx):
"""Populate response cache with new page of results.
Fetches the specific page of results containing the index passed and populates
stores the results in the :prop:`_cache` property.
Parameters
----------
idx : int
The index of a given record to fetch.
"""
offset = self.limit * math.floor(idx / self.limit)
query = update_url_params(self.query, {'limit': self.limit, 'offset': offset})
res = self.alyx._generic_request(requests.get, query, **self._cache_args)
@@ -231,6 +259,21 @@ def populate(self, idx):
f'results may be inconsistent', RuntimeWarning)
for i, r in enumerate(res['results'][:self.count - offset]):
self._cache[i + offset] = res['results'][i]
# Notify callbacks
pending_removal = []
for callback in self._callbacks:
# Handle weak reference callbacks first
if isinstance(callback, ReferenceType):
wf = callback
if (callback := wf()) is None:
pending_removal.append(wf)
continue
callback(res['results'])
for wf in pending_removal:
self._callbacks.discard(wf)
# When cache is complete, clear our callbacks
if all(reversed(self._cache)):
self._callbacks.clear()

def __iter__(self):
for i in range(self.count):
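A usage sketch of the new hook (endpoint and query parameters illustrative; `AlyxClient.rest` returns a `PaginatedResponse` when the result set spans more than one page):

```python
from one.webclient import AlyxClient

alyx = AlyxClient(base_url='https://alyx.example.com')  # hypothetical instance
ses = alyx.rest('sessions', 'list', lab='somelab')
if not isinstance(ses, list):  # paginated response
    ses.add_callback(lambda page: print(f'fetched {len(page)} records'))
    _ = ses[-1]  # touching an uncached index fires populate() and the callback
```

Note a plain function or lambda registered this way is held strongly for the life of the response; pass a weak reference (as `OneAlyx.search` does) when the callback's owner should remain collectable.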
