Add asynchronous support #43

Open · wants to merge 1 commit into base: master
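The headline change: functions decorated with cache() may now be coroutines, backed by redis.asyncio. A minimal usage sketch of the new surface (slow_lookup and the connection details are hypothetical; the decorator and client types come from the diff below):

import asyncio
from redis.asyncio import Redis as RedisAsync
from redis_cache import RedisCache

client = RedisAsync(host="localhost", decode_responses=True)
cache = RedisCache(redis_client=client)

@cache.cache(ttl=60)
async def slow_lookup(user_id):
    await asyncio.sleep(1)  # stand-in for a slow upstream call
    return {"id": user_id}

async def main():
    first = await slow_lookup(42)   # cache miss: runs the coroutine, stores the result
    second = await slow_lookup(42)  # cache hit: served from Redis
    assert first == second

asyncio.run(main())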
183 changes: 155 additions & 28 deletions redis_cache/__init__.py
@@ -1,7 +1,10 @@
import asyncio
from functools import wraps
from json import dumps, loads
from base64 import b64encode
from inspect import signature, Parameter
from redis import Redis
from redis.asyncio import Redis as RedisAsync

def compact_dump(value):
    return dumps(value, separators=(',', ':'), sort_keys=True)
@@ -101,6 +104,23 @@ def chunks(iterable, n):

        yield elements

async def async_chunks(iterable, n):
    """Yield successive n-sized chunks from an async iterator."""
    _iterable = aiter(iterable)
    while True:
        elements = []
        for _ in range(n):
            try:
                _e = await anext(_iterable)
                elements.append(_e)
            except StopAsyncIteration:
                break

        if not len(elements):
            break

        yield elements


class RedisCache:
    def __init__(self, redis_client, prefix="rc", serializer=compact_dump, deserializer=loads, key_serializer=None, support_cluster=True, exception_handler=None, active:bool=True):
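One side effect of this helper: aiter() and anext() are builtins only from Python 3.10, so it quietly raises the package's minimum supported Python version. A quick standalone illustration of the chunking behaviour (assumes async_chunks is importable from redis_cache):

import asyncio
from redis_cache import async_chunks

async def numbers(n):
    for i in range(n):
        yield i

async def main():
    # Batches an async stream into lists of at most 3 elements
    async for chunk in async_chunks(numbers(7), 3):
        print(chunk)  # [0, 1, 2] then [3, 4, 5] then [6]

asyncio.run(main())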
@@ -129,6 +149,8 @@ def cache(self, ttl=0, limit=0, namespace=None, exception_handler=None):
        )

    def mget(self, *fns_with_args):
        if self.client and not isinstance(self.client, Redis):
            raise RuntimeError("This method can only be used with a synchronous Redis client")
        keys = []
        for fn_and_args in fns_with_args:
            fn = fn_and_args['fn']
@@ -160,6 +182,44 @@ def mget(self, *fns_with_args):
        pipeline.execute()
        return deserialized_results


    async def async_mget(self, *fns_with_args):
        if not isinstance(self.client, RedisAsync):
            raise RuntimeError("This method can only be used with an async Redis client")
        keys = []
        for fn_and_args in fns_with_args:
            fn = fn_and_args['fn']
            args = fn_and_args['args'] if 'args' in fn_and_args else []
            kwargs = fn_and_args['kwargs'] if 'kwargs' in fn_and_args else {}
            keys.append(fn.instance.get_key(args=args, kwargs=kwargs))

        results = await self.client.mget(*keys)
        pipeline = self.client.pipeline()

        deserialized_results = []
        needs_pipeline = False
        for i, result in enumerate(results):
            if result is None:
                needs_pipeline = True

                fn_and_args = fns_with_args[i]
                fn = fn_and_args['fn']
                args = fn_and_args['args'] if 'args' in fn_and_args else []
                kwargs = fn_and_args['kwargs'] if 'kwargs' in fn_and_args else {}
                if asyncio.iscoroutinefunction(fn.instance.original_fn):
                    result = await fn.instance.original_fn(*args, **kwargs)
                else:
                    # Run sync functions off the event loop
                    result = await asyncio.to_thread(fn.instance.original_fn, *args, **kwargs)
                result_serialized = self.serializer(result)
                await get_cache_lua_fn(self.client)(keys=[keys[i], fn.instance.keys_key], args=[result_serialized, fn.instance.ttl, fn.instance.limit], client=pipeline)
            else:
                result = self.deserializer(result)
            deserialized_results.append(result)

        if needs_pipeline:
            await pipeline.execute()
        return deserialized_results

class CacheDecorator:
    def __init__(self, redis_client, prefix="rc", serializer=compact_dump, deserializer=loads, key_serializer=None, ttl=0, limit=0, namespace=None, support_cluster=True, exception_handler=None, active:bool=True):
        self.client = redis_client
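async_mget mirrors the synchronous mget: hits come back from one MGET round trip, while misses are recomputed (coroutines awaited directly, plain functions pushed through asyncio.to_thread) and written back via a single pipeline. A hedged usage sketch, reusing the hypothetical slow_lookup from above:

# inside a coroutine, with cache built on an async client:
results = await cache.async_mget(
    {'fn': slow_lookup, 'args': [1]},
    {'fn': slow_lookup, 'args': [2]},
)
# results == [{'id': 1}, {'id': 2}], and both entries are now cached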
@@ -208,34 +268,83 @@ def __call__(self, fn):
        self.keys_key = f'{self.get_full_prefix()}:keys'
        self.original_fn = fn

        @wraps(fn)
        def inner(*args, **kwargs):
            nonlocal self
            # Return the original function if we're not in active mode
            if not self.active:
                return fn(*args, **kwargs)
            key = self.get_key(args, kwargs)
            result = None

            exception_handled = False
            try:
                result = self.client.get(key)
            except Exception as e:
                if self.exception_handler:
                    # This allows people to handle failures in cache lookups
                    exception_handled = True
                    parsed_result = self.exception_handler(e, self.original_fn, args, kwargs)
            if result:
                parsed_result = self.deserializer(result)
            elif not exception_handled:
                parsed_result = fn(*args, **kwargs)
                result_serialized = self.serializer(parsed_result)
                get_cache_lua_fn(self.client)(keys=[key, self.keys_key], args=[result_serialized, self.ttl, self.limit])

            return parsed_result

        inner.invalidate = self.invalidate
        inner.invalidate_all = self.invalidate_all
        if asyncio.iscoroutinefunction(fn):

            @wraps(fn)
            async def inner(*args, **kwargs):
                nonlocal self
                # Call the original function directly if we're not in active mode
                if not self.active:
                    return await fn(*args, **kwargs)
                key = self.get_key(args, kwargs)
                result = None

                exception_handled = False
                try:
                    if isinstance(self.client, Redis):
                        # Sync client inside an event loop: do the lookup in a thread
                        result = await asyncio.to_thread(self.client.get, key)
                    else:
                        result = await self.client.get(key)
                except Exception as e:
                    if self.exception_handler:
                        # This allows people to handle failures in cache lookups
                        exception_handled = True
                        if asyncio.iscoroutinefunction(self.exception_handler):
                            parsed_result = await self.exception_handler(e, self.original_fn, args, kwargs)
                        else:
                            parsed_result = await asyncio.to_thread(self.exception_handler, e, self.original_fn, args, kwargs)
                if result:
                    parsed_result = self.deserializer(result)
                elif not exception_handled:
                    parsed_result = await fn(*args, **kwargs)
                    result_serialized = self.serializer(parsed_result)
                    if isinstance(self.client, Redis):
                        await asyncio.to_thread(
                            get_cache_lua_fn(self.client),
                            keys=[key, self.keys_key],
                            args=[result_serialized, self.ttl, self.limit]
                        )
                    else:
                        await get_cache_lua_fn(self.client)(
                            keys=[key, self.keys_key],
                            args=[result_serialized, self.ttl, self.limit]
                        )

                return parsed_result

            inner.invalidate = self.async_invalidate
            inner.invalidate_all = self.async_invalidate_all
        else:
            if self.client and not isinstance(self.client, Redis):
                raise RuntimeError("This method can only be used with a synchronous Redis client")

            @wraps(fn)
            def inner(*args, **kwargs):
                nonlocal self
                # Call the original function directly if we're not in active mode
                if not self.active:
                    return fn(*args, **kwargs)
                key = self.get_key(args, kwargs)
                result = None

                exception_handled = False
                try:
                    result = self.client.get(key)
                except Exception as e:
                    if self.exception_handler:
                        # This allows people to handle failures in cache lookups
                        exception_handled = True
                        parsed_result = self.exception_handler(e, self.original_fn, args, kwargs)
                if result:
                    parsed_result = self.deserializer(result)
                elif not exception_handled:
                    parsed_result = fn(*args, **kwargs)
                    result_serialized = self.serializer(parsed_result)
                    get_cache_lua_fn(self.client)(keys=[key, self.keys_key], args=[result_serialized, self.ttl, self.limit])

                return parsed_result

            inner.invalidate = self.invalidate
            inner.invalidate_all = self.invalidate_all
        inner.get_full_prefix = self.get_full_prefix
        inner.instance = self
        return inner
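Because the async wrapper awaits the exception_handler when it is itself a coroutine function, cache-lookup failures can now be handled with async fallbacks. A sketch (on_cache_error is hypothetical; its signature matches the call above):

async def on_cache_error(exc, original_fn, args, kwargs):
    # Cache lookup failed; fall back to calling the wrapped coroutine directly
    return await original_fn(*args, **kwargs)

@cache.cache(ttl=60, exception_handler=on_cache_error)
async def slow_lookup(user_id):
    ...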
@@ -247,7 +356,25 @@ def invalidate(self, *args, **kwargs):
        pipe.zrem(self.keys_key, key)
        pipe.execute()

    async def async_invalidate(self, *args, **kwargs):
        if isinstance(self.client, Redis):
            await asyncio.to_thread(self.invalidate, *args, **kwargs)
        else:
            key = self.get_key(args, kwargs)
            async with self.client.pipeline() as pipe:
                # Queueing commands on an async pipeline is synchronous;
                # only execute() needs to be awaited
                pipe.delete(key)
                pipe.zrem(self.keys_key, key)
                await pipe.execute()

    def invalidate_all(self, *args, **kwargs):
        chunks_gen = chunks(self.client.scan_iter(f'{self.get_full_prefix()}:*'), 500)
        for keys in chunks_gen:
            self.client.delete(*keys)

    async def async_invalidate_all(self, *args, **kwargs):
        if isinstance(self.client, Redis):
            await asyncio.to_thread(self.invalidate_all, *args, **kwargs)
        else:
            chunks_gen = async_chunks(self.client.scan_iter(f'{self.get_full_prefix()}:*'), 500)
            async for keys in chunks_gen:
                await self.client.delete(*keys)
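With an async client, the invalidation helpers attached to the wrapper become awaitable (inner.invalidate points at async_invalidate); with a sync client they fall back to asyncio.to_thread internally. Continuing the hypothetical sketch from the top:

# inside a coroutine:
await slow_lookup.invalidate(42)    # drop the cached entry for these arguments
await slow_lookup.invalidate_all()  # drop every key under this function's prefix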
2 changes: 1 addition & 1 deletion setup.py
@@ -21,5 +21,5 @@
    packages=find_packages(),
    install_requires=['redis'],
    setup_requires=['pytest-runner==5.3.1'],
    tests_require=['pytest==6.2.5', 'redis==4.4.4'],
    tests_require=['pytest==8.3.2', 'pytest-asyncio==0.24.0', 'redis==4.4.4'],
)
4 changes: 2 additions & 2 deletions tests/test_redis_cache.py
@@ -10,12 +10,12 @@
import zlib


redis_host = "redis-test-host"
redis_host = "localhost"
client = StrictRedis(host=redis_host, decode_responses=True)
client_no_decode = StrictRedis(host=redis_host)


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(scope="module", autouse=True)
def clear_cache(request):
    client.flushall()

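With pytest-asyncio added to tests_require, the new paths can be exercised as native async tests. A hedged sketch of such a test (hypothetical, not part of this diff; assumes a module-level cache built on an async client):

import pytest

@pytest.mark.asyncio
async def test_async_cache_roundtrip():
    calls = 0

    @cache.cache(ttl=60)
    async def double(x):
        nonlocal calls
        calls += 1
        return x * 2

    assert await double(2) == 4
    assert await double(2) == 4
    assert calls == 1  # second call was served from the cache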