From 88cbe3b7b44c328bc7996bd4b0d4d379dc25d177 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Tue, 30 Jul 2024 14:50:33 -0700 Subject: [PATCH] cache models.get_originals in memcache with new memcache_memoize decorator for #1149 --- common.py | 28 ++++++++++++++++++++++++++++ config.py | 2 +- models.py | 4 +++- tests/test_models.py | 7 +++++++ 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/common.py b/common.py index b498128e..70ee0450 100644 --- a/common.py +++ b/common.py @@ -1,6 +1,7 @@ """Misc common utilities.""" import base64 from datetime import timedelta +import functools import logging from pathlib import Path import re @@ -9,6 +10,7 @@ from urllib.parse import urljoin, urlparse import cachetools +from cachetools.keys import hashkey from Crypto.Util import number from flask import abort, g, has_request_context, make_response, request from google.cloud.error_reporting.util import build_flask_context @@ -413,3 +415,29 @@ def memcache_key(key): pymemcache Client's allow_unicode_keys constructor kwarg. """ return key[:MEMCACHE_KEY_MAX_LEN].replace(' ', '%20').encode() + + +def memcache_memoize(expire=None): + """Memoize function decorator that stores the cached value in memcache. + + Only caches non-null/empty values. + + Args: + expire (int): optional, expiration in seconds + """ + def decorator(fn): + @functools.wraps(fn) + def wrapped(*args, **kwargs): + key = memcache_key(f'{fn.__name__}-{str(hashkey(*args, **kwargs))}') + if val := memcache.get(key): + logger.debug(f'cache hit {key}') + return val + + logger.debug(f'cache miss {key}') + val = fn(*args, **kwargs) + memcache.set(key, val) + return val + + return wrapped + + return decorator diff --git a/config.py b/config.py index 44bf8fb3..3a157788 100644 --- a/config.py +++ b/config.py @@ -29,7 +29,7 @@ if logging_client := getattr(appengine_config, 'logging_client'): logging_client.setup_logging(log_level=logging.INFO) - for logger in ('oauth_dropins.webutil.webmention', 'lexrpc'): + for logger in ('common', 'oauth_dropins.webutil.webmention', 'lexrpc'): logging.getLogger(logger).setLevel(logging.DEBUG) os.environ.setdefault('APPVIEW_HOST', 'api.bsky.local') diff --git a/models.py b/models.py index eec14164..38ce86c8 100644 --- a/models.py +++ b/models.py @@ -30,6 +30,7 @@ base64_to_long, DOMAIN_RE, long_to_base64, + memcache_memoize, OLD_ACCOUNT_AGE, remove, report_error, @@ -1563,6 +1564,7 @@ def get_original(copy_id, keys_only=None): return got[0] +@memcache_memoize(expire=60 * 60 * 24) # 1d def get_originals(copy_ids, keys_only=None): """Fetches users (across all protocols) for a given set of copies. @@ -1577,7 +1579,7 @@ def get_originals(copy_ids, keys_only=None): """ assert copy_ids - classes = set(cls for cls in PROTOCOLS.values() if cls) + classes = set(cls for cls in PROTOCOLS.values() if cls and cls.LABEL != 'ui') classes.add(Object) return list(itertools.chain(*( diff --git a/tests/test_models.py b/tests/test_models.py index 1801cbb3..f2559984 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1102,6 +1102,13 @@ def test_get_originals(self): user = self.make_user('other:user', cls=OtherFake, copies=[Target(uri='fake:bar', protocol='fake')]) + memcache_key = "get_originals-(['other:foo',%20'fake:bar',%20'baz'],)" + self.assertIsNone(common.memcache.get(memcache_key)) + + self.assert_entities_equal( + [obj, user], models.get_originals(['other:foo', 'fake:bar', 'baz'])) + + self.assertIsNotNone(common.memcache.get(memcache_key)) self.assert_entities_equal( [obj, user], models.get_originals(['other:foo', 'fake:bar', 'baz']))