Migrate from aioredis to redis-py (#38 from rossnomann)
aamalev authored Jul 9, 2022
2 parents 9f441cb + 35109a5 commit 813f133
Showing 9 changed files with 97 additions and 143 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/tests.yml
@@ -34,8 +34,7 @@ jobs:
strategy:
max-parallel: 3
matrix:
python-version: [3.6, 3.7]

python-version: ['3.6', '3.7', '3.8', '3.9', '3.10']
steps:
- uses: actions/[email protected]
- name: Set up Python ${{ matrix.python-version }}
@@ -62,8 +61,7 @@ jobs:
strategy:
max-parallel: 3
matrix:
redis-version: [4, 5, 6]

redis-version: [4, 5, 6, 7]
steps:
- uses: actions/[email protected]
- name: Set up Python 3.6
5 changes: 3 additions & 2 deletions Pipfile
@@ -4,8 +4,8 @@ verify_ssl = true
name = "pypi"

[packages]
aioredis = ">=1.3"
aioworkers = ">=0.14.3"
redis = ">=4.3"
aioworkers = ">=0.20"

[dev-packages]
pytest-aioworkers = "*"
@@ -15,6 +15,7 @@ pyyaml = "*"
isort = "*"
flake8 = "*"
mypy = "*"
types-redis = "*"

[requires]
python_version = "3.6"
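
The dependency swap works because redis >= 4.3 ships the asyncio client that previously lived in the separate aioredis package, while types-redis provides the matching stubs for mypy. A quick, purely illustrative import check (not part of the change set):

    import redis
    from redis.asyncio import Redis  # bundled with redis>=4.3; replaces aioredis

    print(redis.__version__, Redis)
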
2 changes: 1 addition & 1 deletion README.rst
@@ -23,7 +23,7 @@ Redis plugin for `aioworkers`.
Features
--------

* Works on `aioredis <https://pypi.org/project/aioredis/>`_
* Works on `redis-py <https://pypi.org/project/redis/>`_

* Queue based on
`RPUSH <https://redis.io/commands/rpush>`_,
57 changes: 22 additions & 35 deletions aioworkers_redis/base.py
@@ -1,13 +1,17 @@
from typing import Optional, Union

import aioredis
from aioworkers.core.base import (
AbstractConnector,
AbstractNestedEntity,
LoggingEntity,
)
from aioworkers.core.config import ValueExtractor
from aioworkers.core.formatter import FormattedEntity
from redis.asyncio import Redis


DEFAULT_HOST = 'localhost'
DEFAULT_PORT = 6379


class Connector(
@@ -20,7 +24,7 @@ def __init__(self, *args, **kwargs):
self._joiner: str = ':'
self._prefix: str = ''
self._connector: Optional[Connector] = None
self._pool: Optional[aioredis.Redis] = None
self._client: Optional[Redis] = None
super().__init__(*args, **kwargs)

def set_config(self, config):
@@ -38,10 +42,10 @@ def set_config(self, config):
super().set_config(cfg)

@property
def pool(self) -> aioredis.Redis:
def pool(self) -> Redis:
connector = self._connector or self._get_connector()
assert connector._pool is not None, 'Pool not ready'
return connector._pool
assert connector._client is not None, 'Client is not ready'
return connector._client

def _get_connector(self) -> 'Connector':
cfg = self.config.get('connection')
@@ -91,9 +95,6 @@ def clean_key(self, raw_key: Union[str, bytes]) -> str:
return result
return result.decode()

def acquire(self):
return AsyncConnectionContextManager(self)

async def connect(self):
connector = self._connector or self._get_connector()
if connector is not self:
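
The acquire() helper (and the AsyncConnectionContextManager class removed further down) is no longer needed: redis.asyncio.Redis manages its own connection pool, so callers await commands directly on the shared client. A hypothetical illustration — read_key and the key name are not part of the plugin:

    from redis.asyncio import Redis


    async def read_key(client: Redis, key: str = 'example:key'):
        # No explicit acquire(): the client checks a connection out of its
        # internal ConnectionPool for each awaited command and returns it afterwards.
        return await client.get(key)
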
@@ -114,26 +115,27 @@ async def connect(self):
cfg = dict(cfg)
else:
cfg = {}
self._pool = await self.pool_factory(cfg)
self._client = await self.client_factory(cfg)

async def pool_factory(self, cfg: dict) -> aioredis.Redis:
async def client_factory(self, cfg: dict) -> Redis:
if cfg.get('dsn'):
address = cfg.pop('dsn')
elif cfg.get('address'):
address = cfg.pop('address')
else:
address = cfg.pop('host', 'localhost'), cfg.pop('port', 6379)
self.logger.debug('Create pool with address %s', address)
return await aioredis.create_redis_pool(
address, **cfg, loop=self.loop,
)
host = cfg.pop('host', DEFAULT_HOST)
port = cfg.pop('port', DEFAULT_PORT)
address = 'redis://{}:{}'.format(host, port)
if 'maxsize' in cfg:
cfg['max_connections'] = cfg.pop('maxsize')
self.logger.debug('Create client with address %s', address)
return Redis.from_url(address, **cfg)

async def disconnect(self):
pool = self._pool
if pool is not None:
self.logger.debug('Close pool')
pool.close()
await pool.wait_closed()
client = self._client
if client is not None:
self.logger.debug('Close connection')
await client.close()

def decode(self, b):
if b is not None:
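
A minimal sketch of the connection lifecycle that the new client_factory and disconnect implement, assuming redis >= 4.3; the URL and the max_connections value are illustrative. Note how the old aioredis maxsize option is translated to redis-py's max_connections:

    import asyncio

    from redis.asyncio import Redis


    async def main():
        # client_factory builds a redis://host:port URL and passes the remaining
        # options (e.g. max_connections, formerly maxsize) to Redis.from_url.
        client = Redis.from_url('redis://localhost:6379', max_connections=10)
        try:
            await client.ping()
        finally:
            # disconnect() now awaits client.close() instead of the aioredis
            # pool.close() / pool.wait_closed() pair.
            await client.close()


    asyncio.run(main())
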
@@ -151,21 +153,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.disconnect()


class AsyncConnectionContextManager:
__slots__ = ('_connector', '_connection')

def __init__(self, connector: Connector):
self._connector: Connector = connector

async def __aenter__(self):
self._connection = await self._connector.pool
self._connection.__enter__()
return self._connection

async def __aexit__(self, exc_type, exc_value, tb):
self._connection.__exit__(exc_type, exc_value, tb)


class KeyEntity(Connector):
@property
def key(self):
48 changes: 16 additions & 32 deletions aioworkers_redis/queue.py
@@ -9,45 +9,37 @@
class Queue(KeyEntity, AbstractQueue):
async def init(self):
await super().init()
self._lock = asyncio.Lock(loop=self.loop)
self._lock = asyncio.Lock()

def factory(self, item, config=None):
inst = super().factory(item, config=config)
inst._lock = asyncio.Lock(loop=self.loop)
inst._lock = asyncio.Lock()
return inst

async def put(self, value):
value = self.encode(value)
async with self.acquire() as conn:
return await conn.execute('rpush', self.key, value)
return await self.pool.rpush(self.key, value)

async def get(self, *, timeout=0):
async with self._lock:
async with self.acquire() as conn:
result = await conn.execute('blpop', self.key, timeout)
result = await self.pool.blpop(self.key, timeout)
if timeout and result is None:
raise TimeoutError
value = self.decode(result[-1])
return value

async def length(self):
async with self.acquire() as conn:
return await conn.execute('llen', self.key)
return await self.pool.llen(self.key)

async def list(self):
async with self.acquire() as conn:
return [
self.decode(i)
for i in await conn.execute('lrange', self.key, 0, -1)]
return [self.decode(i) for i in await self.pool.lrange(self.key, 0, -1)]

async def remove(self, value):
value = self.encode(value)
async with self.acquire() as conn:
await conn.execute('lrem', self.key, 0, value)
await self.pool.lrem(self.key, 0, value)

async def clear(self):
async with self.acquire() as conn:
return await conn.execute('del', self.key)
return await self.pool.delete(self.key)


class BaseZQueue(Queue):
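
The Queue methods above move from conn.execute('<command>', ...) on an acquired aioredis connection to redis-py's typed command methods awaited on the shared client (asyncio.Lock also loses its loop argument, which was removed in Python 3.10). A sketch of the same list-based operations, assuming a local Redis and an illustrative key name:

    import asyncio

    from redis.asyncio import Redis


    async def main():
        client = Redis.from_url('redis://localhost:6379')
        key = 'example:queue'
        try:
            # put: RPUSH appends the encoded value to the tail of the list.
            await client.rpush(key, b'payload')
            # length / list map to LLEN and LRANGE.
            print(await client.llen(key), await client.lrange(key, 0, -1))
            # get: BLPOP returns a (key, value) pair once an item is available,
            # which is why the plugin decodes result[-1].
            print(await client.blpop(key, 1))
            # remove / clear map to LREM and DEL.
            await client.lrem(key, 0, b'payload')
            await client.delete(key)
        finally:
            await client.close()


    asyncio.run(main())
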
@@ -56,33 +48,27 @@ class BaseZQueue(Queue):
async def put(self, value):
score, val = value
val = self.encode(val)
async with self.acquire() as conn:
return await conn.execute('zadd', self.key, score, val)
return await self.pool.zadd(self.key, {val: score})

async def get(self):
async with self._lock:
while True:
async with self.acquire() as conn:
lv = await conn.execute('eval', self.script, 1, self.key)
lv = await self.pool.eval(self.script, 1, self.key)
if lv:
break
await asyncio.sleep(self.config.timeout, loop=self.loop)
await asyncio.sleep(self.config.timeout)
value, score = lv
return float(score), self.decode(value)

async def length(self):
async with self.acquire() as conn:
return await conn.execute('zcard', self.key)
return await self.pool.zcard(self.key)

async def list(self):
async with self.acquire() as conn:
return [self.decode(i)
for i in await conn.execute('zrange', self.key, 0, -1)]
return [self.decode(i) for i in await self.pool.zrange(self.key, 0, -1)]

async def remove(self, value):
value = self.encode(value)
async with self.acquire() as conn:
await conn.execute('zrem', self.key, value)
await self.pool.zrem(self.key, value)


@score_queue('time.time')
@@ -111,11 +97,9 @@ class TimestampZQueue(BaseZQueue):
async def get(self):
async with self._lock:
while True:
async with self.acquire() as conn:
lv = await conn.execute(
'eval', self.script, 1, self.key, time.time())
lv = await self.pool.eval(self.script, 1, self.key, time.time())
if lv:
break
await asyncio.sleep(self.config.timeout, loop=self.loop)
await asyncio.sleep(self.config.timeout)
value, score = lv
return float(score), self.decode(value)
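
For the sorted-set queues, the most visible API change is ZADD: aioredis took positional score/member arguments, while redis-py expects a {member: score} mapping, as the zadd call above shows; eval keeps the (script, numkeys, *keys_and_args) convention used by the atomic pop scripts. A small sketch under the same assumptions (local Redis, illustrative key and member):

    import asyncio
    import time

    from redis.asyncio import Redis


    async def main():
        client = Redis.from_url('redis://localhost:6379')
        key = 'example:zqueue'
        try:
            # redis-py: zadd(key, {member: score}) instead of zadd(key, score, member).
            await client.zadd(key, {b'task': time.time()})
            # eval(script, numkeys, *keys_and_args), as in BaseZQueue.get above;
            # this trivial script just reads the lowest-scored member.
            head = await client.eval("return redis.call('ZRANGE', KEYS[1], 0, 0)", 1, key)
            # zcard / zrange / zrem mirror the length / list / remove methods.
            print(head, await client.zcard(key), await client.zrange(key, 0, -1))
            await client.zrem(key, b'task')
        finally:
            await client.close()


    asyncio.run(main())
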