diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 6fd79cf..602bd19 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -34,8 +34,7 @@ jobs:
strategy:
max-parallel: 3
matrix:
- python-version: [3.6, 3.7]
-
+ python-version: ['3.6', '3.7', '3.8', '3.9', '3.10']
steps:
- uses: actions/checkout@v2.3.4
- name: Set up Python ${{ matrix.python-version }}
@@ -62,8 +61,7 @@ jobs:
strategy:
max-parallel: 3
matrix:
- redis-version: [4, 5, 6]
-
+ redis-version: [4, 5, 6, 7]
steps:
- uses: actions/checkout@v2.3.4
- name: Set up Python 3.6
diff --git a/Pipfile b/Pipfile
index d974fc6..a5a722d 100644
--- a/Pipfile
+++ b/Pipfile
@@ -4,8 +4,8 @@ verify_ssl = true
name = "pypi"
[packages]
-aioredis = ">=1.3"
-aioworkers = ">=0.14.3"
+redis = ">=4.3"
+aioworkers = ">=0.20"
[dev-packages]
pytest-aioworkers = "*"
@@ -15,6 +15,7 @@ pyyaml = "*"
isort = "*"
flake8 = "*"
mypy = "*"
+types-redis = "*"
[requires]
python_version = "3.6"
diff --git a/README.rst b/README.rst
index 57639a8..8428f03 100644
--- a/README.rst
+++ b/README.rst
@@ -23,7 +23,7 @@ Redis plugin for `aioworkers`.
Features
--------
-* Works on `aioredis `_
+* Works on `redis-py <https://github.com/redis/redis-py>`_
* Queue based on
`RPUSH `_,
diff --git a/aioworkers_redis/base.py b/aioworkers_redis/base.py
index e33b948..f669554 100644
--- a/aioworkers_redis/base.py
+++ b/aioworkers_redis/base.py
@@ -1,6 +1,5 @@
from typing import Optional, Union
-import aioredis
from aioworkers.core.base import (
AbstractConnector,
AbstractNestedEntity,
@@ -8,6 +7,11 @@
)
from aioworkers.core.config import ValueExtractor
from aioworkers.core.formatter import FormattedEntity
+from redis.asyncio import Redis
+
+
+DEFAULT_HOST = 'localhost'
+DEFAULT_PORT = 6379
class Connector(
@@ -20,7 +24,7 @@ def __init__(self, *args, **kwargs):
self._joiner: str = ':'
self._prefix: str = ''
self._connector: Optional[Connector] = None
- self._pool: Optional[aioredis.Redis] = None
+ self._client: Optional[Redis] = None
super().__init__(*args, **kwargs)
def set_config(self, config):
@@ -38,10 +42,10 @@ def set_config(self, config):
super().set_config(cfg)
@property
- def pool(self) -> aioredis.Redis:
+ def pool(self) -> Redis:
connector = self._connector or self._get_connector()
- assert connector._pool is not None, 'Pool not ready'
- return connector._pool
+ assert connector._client is not None, 'Client is not ready'
+ return connector._client
def _get_connector(self) -> 'Connector':
cfg = self.config.get('connection')
@@ -91,9 +95,6 @@ def clean_key(self, raw_key: Union[str, bytes]) -> str:
return result
return result.decode()
- def acquire(self):
- return AsyncConnectionContextManager(self)
-
async def connect(self):
connector = self._connector or self._get_connector()
if connector is not self:
@@ -114,26 +115,27 @@ async def connect(self):
cfg = dict(cfg)
else:
cfg = {}
- self._pool = await self.pool_factory(cfg)
+ self._client = await self.client_factory(cfg)
- async def pool_factory(self, cfg: dict) -> aioredis.Redis:
+ async def client_factory(self, cfg: dict) -> Redis:
if cfg.get('dsn'):
address = cfg.pop('dsn')
elif cfg.get('address'):
address = cfg.pop('address')
else:
- address = cfg.pop('host', 'localhost'), cfg.pop('port', 6379)
- self.logger.debug('Create pool with address %s', address)
- return await aioredis.create_redis_pool(
- address, **cfg, loop=self.loop,
- )
+ host = cfg.pop('host', DEFAULT_HOST)
+ port = cfg.pop('port', DEFAULT_PORT)
+ address = 'redis://{}:{}'.format(host, port)
+ if 'maxsize' in cfg:
+ cfg['max_connections'] = cfg.pop('maxsize')
+ self.logger.debug('Create client with address %s', address)
+ return Redis.from_url(address, **cfg)
async def disconnect(self):
- pool = self._pool
- if pool is not None:
- self.logger.debug('Close pool')
- pool.close()
- await pool.wait_closed()
+ client = self._client
+ if client is not None:
+ self.logger.debug('Close connection')
+ await client.close()
def decode(self, b):
if b is not None:
@@ -151,21 +153,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.disconnect()
-class AsyncConnectionContextManager:
- __slots__ = ('_connector', '_connection')
-
- def __init__(self, connector: Connector):
- self._connector: Connector = connector
-
- async def __aenter__(self):
- self._connection = await self._connector.pool
- self._connection.__enter__()
- return self._connection
-
- async def __aexit__(self, exc_type, exc_value, tb):
- self._connection.__exit__(exc_type, exc_value, tb)
-
-
class KeyEntity(Connector):
@property
def key(self):
diff --git a/aioworkers_redis/queue.py b/aioworkers_redis/queue.py
index b209661..3364b1c 100644
--- a/aioworkers_redis/queue.py
+++ b/aioworkers_redis/queue.py
@@ -9,45 +9,37 @@
class Queue(KeyEntity, AbstractQueue):
async def init(self):
await super().init()
- self._lock = asyncio.Lock(loop=self.loop)
+ self._lock = asyncio.Lock()
def factory(self, item, config=None):
inst = super().factory(item, config=config)
- inst._lock = asyncio.Lock(loop=self.loop)
+ inst._lock = asyncio.Lock()
return inst
async def put(self, value):
value = self.encode(value)
- async with self.acquire() as conn:
- return await conn.execute('rpush', self.key, value)
+ return await self.pool.rpush(self.key, value)
async def get(self, *, timeout=0):
async with self._lock:
- async with self.acquire() as conn:
- result = await conn.execute('blpop', self.key, timeout)
+ result = await self.pool.blpop(self.key, timeout)
if timeout and result is None:
raise TimeoutError
value = self.decode(result[-1])
return value
async def length(self):
- async with self.acquire() as conn:
- return await conn.execute('llen', self.key)
+ return await self.pool.llen(self.key)
async def list(self):
- async with self.acquire() as conn:
- return [
- self.decode(i)
- for i in await conn.execute('lrange', self.key, 0, -1)]
+ return [self.decode(i) for i in await self.pool.lrange(self.key, 0, -1)]
async def remove(self, value):
value = self.encode(value)
- async with self.acquire() as conn:
- await conn.execute('lrem', self.key, 0, value)
+ await self.pool.lrem(self.key, 0, value)
async def clear(self):
- async with self.acquire() as conn:
- return await conn.execute('del', self.key)
+ return await self.pool.delete(self.key)
class BaseZQueue(Queue):
@@ -56,33 +48,27 @@ class BaseZQueue(Queue):
async def put(self, value):
score, val = value
val = self.encode(val)
- async with self.acquire() as conn:
- return await conn.execute('zadd', self.key, score, val)
+ return await self.pool.zadd(self.key, {val: score})
async def get(self):
async with self._lock:
while True:
- async with self.acquire() as conn:
- lv = await conn.execute('eval', self.script, 1, self.key)
+ lv = await self.pool.eval(self.script, 1, self.key)
if lv:
break
- await asyncio.sleep(self.config.timeout, loop=self.loop)
+ await asyncio.sleep(self.config.timeout)
value, score = lv
return float(score), self.decode(value)
async def length(self):
- async with self.acquire() as conn:
- return await conn.execute('zcard', self.key)
+ return await self.pool.zcard(self.key)
async def list(self):
- async with self.acquire() as conn:
- return [self.decode(i)
- for i in await conn.execute('zrange', self.key, 0, -1)]
+ return [self.decode(i) for i in await self.pool.zrange(self.key, 0, -1)]
async def remove(self, value):
value = self.encode(value)
- async with self.acquire() as conn:
- await conn.execute('zrem', self.key, value)
+ await self.pool.zrem(self.key, value)
@score_queue('time.time')
@@ -111,11 +97,9 @@ class TimestampZQueue(BaseZQueue):
async def get(self):
async with self._lock:
while True:
- async with self.acquire() as conn:
- lv = await conn.execute(
- 'eval', self.script, 1, self.key, time.time())
+ lv = await self.pool.eval(self.script, 1, self.key, time.time())
if lv:
break
- await asyncio.sleep(self.config.timeout, loop=self.loop)
+ await asyncio.sleep(self.config.timeout)
value, score = lv
return float(score), self.decode(value)
diff --git a/aioworkers_redis/storage.py b/aioworkers_redis/storage.py
index e8d0f80..a67a85a 100644
--- a/aioworkers_redis/storage.py
+++ b/aioworkers_redis/storage.py
@@ -18,13 +18,11 @@ def set_config(self, config):
)
async def list(self):
- async with self.acquire() as conn:
- keys = await conn.execute('keys', self.raw_key('*'))
+ keys = await self.pool.keys(self.raw_key('*'))
return [self.clean_key(i) for i in keys]
async def length(self):
- async with self.acquire() as conn:
- keys = await conn.execute('keys', self.raw_key('*'))
+ keys = await self.pool.keys(self.raw_key('*'))
return len(keys)
async def set(self, key, value):
@@ -32,27 +30,22 @@ async def set(self, key, value):
is_null = value is None
if not is_null:
value = self.encode(value)
- async with self.acquire() as conn:
- if is_null:
- return await conn.execute('del', raw_key)
- elif self._expiry:
- return await conn.execute(
- 'setex', raw_key, self._expiry, value,
- )
- else:
- return await conn.execute('set', raw_key, value)
+ if is_null:
+ return await self.pool.delete(raw_key)
+ elif self._expiry:
+ return await self.pool.setex(raw_key, self._expiry, value)
+ else:
+ return await self.pool.set(raw_key, value)
async def get(self, key):
raw_key = self.raw_key(key)
- async with self.acquire() as conn:
- value = await conn.execute('get', raw_key)
+ value = await self.pool.get(raw_key)
if value is not None:
return self.decode(value)
async def expiry(self, key, expiry):
raw_key = self.raw_key(key)
- async with self.acquire() as conn:
- await conn.execute('expire', raw_key, expiry)
+ await self.pool.expire(raw_key, expiry)
class HashStorage(FieldStorageMixin, Storage):
@@ -60,65 +53,57 @@ class HashStorage(FieldStorageMixin, Storage):
async def set(self, key, value, *, field=None, fields=None):
raw_key = self.raw_key(key)
to_del = []
- async with self.acquire() as conn:
- if field:
- if value is None:
- to_del.append(field)
- else:
- return await conn.execute(
- 'hset', raw_key, field, self.encode(value),
- )
- elif value is None:
- return await conn.execute('del', raw_key)
+ if field:
+ if value is None:
+ to_del.append(field)
else:
- pairs = []
- for f in fields or value:
- v = value[f]
- if v is None:
- to_del.append(f)
- else:
- pairs.extend((f, self.encode(v)))
- if pairs:
- await conn.execute('hmset', raw_key, *pairs)
- if to_del:
- await conn.execute('hdel', raw_key, *to_del)
- if self._expiry:
- await conn.execute('expire', raw_key, self._expiry)
+ return await self.pool.hset(raw_key, field, self.encode(value))
+ elif value is None:
+ return await self.pool.delete(raw_key)
+ else:
+ pairs = {}
+ for f in fields or value:
+ v = value[f]
+ if v is None:
+ to_del.append(f)
+ else:
+ pairs[f] = self.encode(v)
+ if pairs:
+ await self.pool.hset(raw_key, mapping=pairs)
+ if to_del:
+ await self.pool.hdel(raw_key, *to_del)
+ if self._expiry:
+ await self.pool.expire(raw_key, self._expiry)
async def get(self, key, *, field=None, fields=None):
raw_key = self.raw_key(key)
- async with self.acquire() as conn:
- if field:
- return self.decode(await conn.execute('hget', raw_key, field))
- elif fields:
- v = await conn.execute('hmget', raw_key, *fields)
- m = self.model()
- for f, v in zip(fields, v):
- m[f] = self.decode(v)
- else:
- a = await conn.execute('hgetall', raw_key)
- m = self.model()
- a = iter(a)
- for f, v in zip(a, a):
- m[f.decode()] = self.decode(v)
- return m
+ if field:
+ return self.decode(await self.pool.hget(raw_key, field))
+ elif fields:
+ v = await self.pool.hmget(raw_key, *fields)
+ m = self.model()
+ for f, v in zip(fields, v):
+ m[f] = self.decode(v)
+ else:
+ a = await self.pool.hgetall(raw_key)
+ m = self.model()
+ for f, v in a.items():
+ m[f.decode()] = self.decode(v)
+ return m
class HyperLogLogStorage(KeyEntity, AbstractBaseStorage):
async def set(self, key, value=True):
assert value is True
- async with self.acquire() as conn:
- await conn.execute('pfadd', self.key, key)
+ await self.pool.pfadd(self.key, key)
async def get(self, key):
tmp_key = self.raw_key('tmp:hhl:' + key)
- async with self.acquire() as conn:
- await conn.execute('pfmerge', tmp_key, self.key)
- result = await conn.execute('pfadd', tmp_key, key)
- await conn.execute('del', tmp_key)
+ await self.pool.pfmerge(tmp_key, self.key)
+ result = await self.pool.pfadd(tmp_key, key)
+ await self.pool.delete(tmp_key)
return result == 0
async def length(self):
- async with self.acquire() as conn:
- c = await conn.execute('pfcount', self.key)
+ c = await self.pool.pfcount(self.key)
return c
diff --git a/setup.cfg b/setup.cfg
index 115b366..fce6db0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,12 +1,12 @@
[tool:pytest]
testpaths = aioworkers_redis tests
-codestyle_max_line_length = 79
+asyncio_mode = auto
; filterwarnings =
; ignore::DeprecationWarning
; ignore::PendingDeprecationWarning
[flake8]
-max_line_length = 79
+max_line_length = 120
[mypy]
ignore_missing_imports = True
diff --git a/setup.py b/setup.py
index 4a601e3..3901ce6 100644
--- a/setup.py
+++ b/setup.py
@@ -27,8 +27,8 @@ def get_version():
requirements = [
- 'aioworkers>=0.14.3',
- 'aioredis>=1.3.0',
+ 'aioworkers>=0.20',
+ 'redis>=4.3',
]
setup(
diff --git a/tests/test_queue.py b/tests/test_queue.py
index 45edb88..5b83b44 100644
--- a/tests/test_queue.py
+++ b/tests/test_queue.py
@@ -86,8 +86,7 @@ async def test_zqueue(config, loop):
assert 1 == await q.length()
await q.remove('3')
assert not await q.length()
-
- with mock.patch('asyncio.sleep'):
+ with mock.patch('asyncio.sleep', object()):
with pytest.raises(TypeError):
await q.get()