Skip to content

Commit

Permalink
cache models.get_originals in memcache with new memcache_memoize deco…
Browse files Browse the repository at this point in the history
…rator

for #1149
  • Loading branch information
snarfed committed Jul 30, 2024
1 parent 33e0d0b commit 88cbe3b
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 2 deletions.
28 changes: 28 additions & 0 deletions common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Misc common utilities."""
import base64
from datetime import timedelta
import functools
import logging
from pathlib import Path
import re
Expand All @@ -9,6 +10,7 @@
from urllib.parse import urljoin, urlparse

import cachetools
from cachetools.keys import hashkey
from Crypto.Util import number
from flask import abort, g, has_request_context, make_response, request
from google.cloud.error_reporting.util import build_flask_context
Expand Down Expand Up @@ -413,3 +415,29 @@ def memcache_key(key):
pymemcache Client's allow_unicode_keys constructor kwarg.
"""
return key[:MEMCACHE_KEY_MAX_LEN].replace(' ', '%20').encode()


def memcache_memoize(expire=None):
"""Memoize function decorator that stores the cached value in memcache.
Only caches non-null/empty values.
Args:
expire (int): optional, expiration in seconds
"""
def decorator(fn):
@functools.wraps(fn)
def wrapped(*args, **kwargs):
key = memcache_key(f'{fn.__name__}-{str(hashkey(*args, **kwargs))}')
if val := memcache.get(key):
logger.debug(f'cache hit {key}')
return val

logger.debug(f'cache miss {key}')
val = fn(*args, **kwargs)
memcache.set(key, val)
return val

return wrapped

return decorator
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
if logging_client := getattr(appengine_config, 'logging_client'):
logging_client.setup_logging(log_level=logging.INFO)

for logger in ('oauth_dropins.webutil.webmention', 'lexrpc'):
for logger in ('common', 'oauth_dropins.webutil.webmention', 'lexrpc'):
logging.getLogger(logger).setLevel(logging.DEBUG)

os.environ.setdefault('APPVIEW_HOST', 'api.bsky.local')
Expand Down
4 changes: 3 additions & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
base64_to_long,
DOMAIN_RE,
long_to_base64,
memcache_memoize,
OLD_ACCOUNT_AGE,
remove,
report_error,
Expand Down Expand Up @@ -1563,6 +1564,7 @@ def get_original(copy_id, keys_only=None):
return got[0]


@memcache_memoize(expire=60 * 60 * 24) # 1d
def get_originals(copy_ids, keys_only=None):
"""Fetches users (across all protocols) for a given set of copies.
Expand All @@ -1577,7 +1579,7 @@ def get_originals(copy_ids, keys_only=None):
"""
assert copy_ids

classes = set(cls for cls in PROTOCOLS.values() if cls)
classes = set(cls for cls in PROTOCOLS.values() if cls and cls.LABEL != 'ui')
classes.add(Object)

return list(itertools.chain(*(
Expand Down
7 changes: 7 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,13 @@ def test_get_originals(self):
user = self.make_user('other:user', cls=OtherFake,
copies=[Target(uri='fake:bar', protocol='fake')])

memcache_key = "get_originals-(['other:foo',%20'fake:bar',%20'baz'],)"
self.assertIsNone(common.memcache.get(memcache_key))

self.assert_entities_equal(
[obj, user], models.get_originals(['other:foo', 'fake:bar', 'baz']))

self.assertIsNotNone(common.memcache.get(memcache_key))
self.assert_entities_equal(
[obj, user], models.get_originals(['other:foo', 'fake:bar', 'baz']))

Expand Down

0 comments on commit 88cbe3b

Please sign in to comment.