diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6fd79cf..602bd19 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -34,8 +34,7 @@ jobs: strategy: max-parallel: 3 matrix: - python-version: [3.6, 3.7] - + python-version: ['3.6', '3.7', '3.8', '3.9', '3.10'] steps: - uses: actions/checkout@v2.3.4 - name: Set up Python ${{ matrix.python-version }} @@ -62,8 +61,7 @@ jobs: strategy: max-parallel: 3 matrix: - redis-version: [4, 5, 6] - + redis-version: [4, 5, 6, 7] steps: - uses: actions/checkout@v2.3.4 - name: Set up Python 3.6 diff --git a/Pipfile b/Pipfile index d974fc6..a5a722d 100644 --- a/Pipfile +++ b/Pipfile @@ -4,8 +4,8 @@ verify_ssl = true name = "pypi" [packages] -aioredis = ">=1.3" -aioworkers = ">=0.14.3" +redis = ">=4.3" +aioworkers = ">=0.20" [dev-packages] pytest-aioworkers = "*" @@ -15,6 +15,7 @@ pyyaml = "*" isort = "*" flake8 = "*" mypy = "*" +types-redis = "*" [requires] python_version = "3.6" diff --git a/README.rst b/README.rst index 57639a8..8428f03 100644 --- a/README.rst +++ b/README.rst @@ -23,7 +23,7 @@ Redis plugin for `aioworkers`. Features -------- -* Works on `aioredis `_ +* Works on `redis-py `_ * Queue based on `RPUSH `_, diff --git a/aioworkers_redis/base.py b/aioworkers_redis/base.py index e33b948..f669554 100644 --- a/aioworkers_redis/base.py +++ b/aioworkers_redis/base.py @@ -1,6 +1,5 @@ from typing import Optional, Union -import aioredis from aioworkers.core.base import ( AbstractConnector, AbstractNestedEntity, @@ -8,6 +7,11 @@ ) from aioworkers.core.config import ValueExtractor from aioworkers.core.formatter import FormattedEntity +from redis.asyncio import Redis + + +DEFAULT_HOST = 'localhost' +DEFAULT_PORT = 6379 class Connector( @@ -20,7 +24,7 @@ def __init__(self, *args, **kwargs): self._joiner: str = ':' self._prefix: str = '' self._connector: Optional[Connector] = None - self._pool: Optional[aioredis.Redis] = None + self._client: Optional[Redis] = None super().__init__(*args, **kwargs) def set_config(self, config): @@ -38,10 +42,10 @@ def set_config(self, config): super().set_config(cfg) @property - def pool(self) -> aioredis.Redis: + def pool(self) -> Redis: connector = self._connector or self._get_connector() - assert connector._pool is not None, 'Pool not ready' - return connector._pool + assert connector._client is not None, 'Client is not ready' + return connector._client def _get_connector(self) -> 'Connector': cfg = self.config.get('connection') @@ -91,9 +95,6 @@ def clean_key(self, raw_key: Union[str, bytes]) -> str: return result return result.decode() - def acquire(self): - return AsyncConnectionContextManager(self) - async def connect(self): connector = self._connector or self._get_connector() if connector is not self: @@ -114,26 +115,27 @@ async def connect(self): cfg = dict(cfg) else: cfg = {} - self._pool = await self.pool_factory(cfg) + self._client = await self.client_factory(cfg) - async def pool_factory(self, cfg: dict) -> aioredis.Redis: + async def client_factory(self, cfg: dict) -> Redis: if cfg.get('dsn'): address = cfg.pop('dsn') elif cfg.get('address'): address = cfg.pop('address') else: - address = cfg.pop('host', 'localhost'), cfg.pop('port', 6379) - self.logger.debug('Create pool with address %s', address) - return await aioredis.create_redis_pool( - address, **cfg, loop=self.loop, - ) + host = cfg.pop('host', DEFAULT_HOST) + port = cfg.pop('port', DEFAULT_PORT) + address = 'redis://{}:{}'.format(host, port) + if 'maxsize' in cfg: + cfg['max_connections'] = cfg.pop('maxsize') + self.logger.debug('Create client with address %s', address) + return Redis.from_url(address, **cfg) async def disconnect(self): - pool = self._pool - if pool is not None: - self.logger.debug('Close pool') - pool.close() - await pool.wait_closed() + client = self._client + if client is not None: + self.logger.debug('Close connection') + await client.close() def decode(self, b): if b is not None: @@ -151,21 +153,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self.disconnect() -class AsyncConnectionContextManager: - __slots__ = ('_connector', '_connection') - - def __init__(self, connector: Connector): - self._connector: Connector = connector - - async def __aenter__(self): - self._connection = await self._connector.pool - self._connection.__enter__() - return self._connection - - async def __aexit__(self, exc_type, exc_value, tb): - self._connection.__exit__(exc_type, exc_value, tb) - - class KeyEntity(Connector): @property def key(self): diff --git a/aioworkers_redis/queue.py b/aioworkers_redis/queue.py index b209661..3364b1c 100644 --- a/aioworkers_redis/queue.py +++ b/aioworkers_redis/queue.py @@ -9,45 +9,37 @@ class Queue(KeyEntity, AbstractQueue): async def init(self): await super().init() - self._lock = asyncio.Lock(loop=self.loop) + self._lock = asyncio.Lock() def factory(self, item, config=None): inst = super().factory(item, config=config) - inst._lock = asyncio.Lock(loop=self.loop) + inst._lock = asyncio.Lock() return inst async def put(self, value): value = self.encode(value) - async with self.acquire() as conn: - return await conn.execute('rpush', self.key, value) + return await self.pool.rpush(self.key, value) async def get(self, *, timeout=0): async with self._lock: - async with self.acquire() as conn: - result = await conn.execute('blpop', self.key, timeout) + result = await self.pool.blpop(self.key, timeout) if timeout and result is None: raise TimeoutError value = self.decode(result[-1]) return value async def length(self): - async with self.acquire() as conn: - return await conn.execute('llen', self.key) + return await self.pool.llen(self.key) async def list(self): - async with self.acquire() as conn: - return [ - self.decode(i) - for i in await conn.execute('lrange', self.key, 0, -1)] + return [self.decode(i) for i in await self.pool.lrange(self.key, 0, -1)] async def remove(self, value): value = self.encode(value) - async with self.acquire() as conn: - await conn.execute('lrem', self.key, 0, value) + await self.pool.lrem(self.key, 0, value) async def clear(self): - async with self.acquire() as conn: - return await conn.execute('del', self.key) + return await self.pool.delete(self.key) class BaseZQueue(Queue): @@ -56,33 +48,27 @@ class BaseZQueue(Queue): async def put(self, value): score, val = value val = self.encode(val) - async with self.acquire() as conn: - return await conn.execute('zadd', self.key, score, val) + return await self.pool.zadd(self.key, {val: score}) async def get(self): async with self._lock: while True: - async with self.acquire() as conn: - lv = await conn.execute('eval', self.script, 1, self.key) + lv = await self.pool.eval(self.script, 1, self.key) if lv: break - await asyncio.sleep(self.config.timeout, loop=self.loop) + await asyncio.sleep(self.config.timeout) value, score = lv return float(score), self.decode(value) async def length(self): - async with self.acquire() as conn: - return await conn.execute('zcard', self.key) + return await self.pool.zcard(self.key) async def list(self): - async with self.acquire() as conn: - return [self.decode(i) - for i in await conn.execute('zrange', self.key, 0, -1)] + return [self.decode(i) for i in await self.pool.zrange(self.key, 0, -1)] async def remove(self, value): value = self.encode(value) - async with self.acquire() as conn: - await conn.execute('zrem', self.key, value) + await self.pool.zrem(self.key, value) @score_queue('time.time') @@ -111,11 +97,9 @@ class TimestampZQueue(BaseZQueue): async def get(self): async with self._lock: while True: - async with self.acquire() as conn: - lv = await conn.execute( - 'eval', self.script, 1, self.key, time.time()) + lv = await self.pool.eval(self.script, 1, self.key, time.time()) if lv: break - await asyncio.sleep(self.config.timeout, loop=self.loop) + await asyncio.sleep(self.config.timeout) value, score = lv return float(score), self.decode(value) diff --git a/aioworkers_redis/storage.py b/aioworkers_redis/storage.py index e8d0f80..a67a85a 100644 --- a/aioworkers_redis/storage.py +++ b/aioworkers_redis/storage.py @@ -18,13 +18,11 @@ def set_config(self, config): ) async def list(self): - async with self.acquire() as conn: - keys = await conn.execute('keys', self.raw_key('*')) + keys = await self.pool.keys(self.raw_key('*')) return [self.clean_key(i) for i in keys] async def length(self): - async with self.acquire() as conn: - keys = await conn.execute('keys', self.raw_key('*')) + keys = await self.pool.keys(self.raw_key('*')) return len(keys) async def set(self, key, value): @@ -32,27 +30,22 @@ async def set(self, key, value): is_null = value is None if not is_null: value = self.encode(value) - async with self.acquire() as conn: - if is_null: - return await conn.execute('del', raw_key) - elif self._expiry: - return await conn.execute( - 'setex', raw_key, self._expiry, value, - ) - else: - return await conn.execute('set', raw_key, value) + if is_null: + return await self.pool.delete(raw_key) + elif self._expiry: + return await self.pool.setex(raw_key, self._expiry, value) + else: + return await self.pool.set(raw_key, value) async def get(self, key): raw_key = self.raw_key(key) - async with self.acquire() as conn: - value = await conn.execute('get', raw_key) + value = await self.pool.get(raw_key) if value is not None: return self.decode(value) async def expiry(self, key, expiry): raw_key = self.raw_key(key) - async with self.acquire() as conn: - await conn.execute('expire', raw_key, expiry) + await self.pool.expire(raw_key, expiry) class HashStorage(FieldStorageMixin, Storage): @@ -60,65 +53,57 @@ class HashStorage(FieldStorageMixin, Storage): async def set(self, key, value, *, field=None, fields=None): raw_key = self.raw_key(key) to_del = [] - async with self.acquire() as conn: - if field: - if value is None: - to_del.append(field) - else: - return await conn.execute( - 'hset', raw_key, field, self.encode(value), - ) - elif value is None: - return await conn.execute('del', raw_key) + if field: + if value is None: + to_del.append(field) else: - pairs = [] - for f in fields or value: - v = value[f] - if v is None: - to_del.append(f) - else: - pairs.extend((f, self.encode(v))) - if pairs: - await conn.execute('hmset', raw_key, *pairs) - if to_del: - await conn.execute('hdel', raw_key, *to_del) - if self._expiry: - await conn.execute('expire', raw_key, self._expiry) + return await self.pool.hset(raw_key, field, self.encode(value)) + elif value is None: + return await self.pool.delete(raw_key) + else: + pairs = {} + for f in fields or value: + v = value[f] + if v is None: + to_del.append(f) + else: + pairs[f] = self.encode(v) + if pairs: + await self.pool.hset(raw_key, mapping=pairs) + if to_del: + await self.pool.hdel(raw_key, *to_del) + if self._expiry: + await self.pool.expire(raw_key, self._expiry) async def get(self, key, *, field=None, fields=None): raw_key = self.raw_key(key) - async with self.acquire() as conn: - if field: - return self.decode(await conn.execute('hget', raw_key, field)) - elif fields: - v = await conn.execute('hmget', raw_key, *fields) - m = self.model() - for f, v in zip(fields, v): - m[f] = self.decode(v) - else: - a = await conn.execute('hgetall', raw_key) - m = self.model() - a = iter(a) - for f, v in zip(a, a): - m[f.decode()] = self.decode(v) - return m + if field: + return self.decode(await self.pool.hget(raw_key, field)) + elif fields: + v = await self.pool.hmget(raw_key, *fields) + m = self.model() + for f, v in zip(fields, v): + m[f] = self.decode(v) + else: + a = await self.pool.hgetall(raw_key) + m = self.model() + for f, v in a.items(): + m[f.decode()] = self.decode(v) + return m class HyperLogLogStorage(KeyEntity, AbstractBaseStorage): async def set(self, key, value=True): assert value is True - async with self.acquire() as conn: - await conn.execute('pfadd', self.key, key) + await self.pool.pfadd(self.key, key) async def get(self, key): tmp_key = self.raw_key('tmp:hhl:' + key) - async with self.acquire() as conn: - await conn.execute('pfmerge', tmp_key, self.key) - result = await conn.execute('pfadd', tmp_key, key) - await conn.execute('del', tmp_key) + await self.pool.pfmerge(tmp_key, self.key) + result = await self.pool.pfadd(tmp_key, key) + await self.pool.delete(tmp_key) return result == 0 async def length(self): - async with self.acquire() as conn: - c = await conn.execute('pfcount', self.key) + c = await self.pool.pfcount(self.key) return c diff --git a/setup.cfg b/setup.cfg index 115b366..fce6db0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,12 +1,12 @@ [tool:pytest] testpaths = aioworkers_redis tests -codestyle_max_line_length = 79 +asyncio_mode = auto ; filterwarnings = ; ignore::DeprecationWarning ; ignore::PendingDeprecationWarning [flake8] -max_line_length = 79 +max_line_length = 120 [mypy] ignore_missing_imports = True diff --git a/setup.py b/setup.py index 4a601e3..3901ce6 100644 --- a/setup.py +++ b/setup.py @@ -27,8 +27,8 @@ def get_version(): requirements = [ - 'aioworkers>=0.14.3', - 'aioredis>=1.3.0', + 'aioworkers>=0.20', + 'redis>=4.3', ] setup( diff --git a/tests/test_queue.py b/tests/test_queue.py index 45edb88..5b83b44 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -86,8 +86,7 @@ async def test_zqueue(config, loop): assert 1 == await q.length() await q.remove('3') assert not await q.length() - - with mock.patch('asyncio.sleep'): + with mock.patch('asyncio.sleep', object()): with pytest.raises(TypeError): await q.get()