Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate from aioredis to redis-py #38

Merged
merged 7 commits into from
Jul 9, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ jobs:
strategy:
max-parallel: 3
matrix:
python-version: [3.6, 3.7]

python-version: ['3.6', '3.7', '3.8', '3.9', '3.10']
steps:
- uses: actions/[email protected]
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -62,8 +61,7 @@ jobs:
strategy:
max-parallel: 3
matrix:
redis-version: [4, 5, 6]

redis-version: [4, 5, 6, 7]
steps:
- uses: actions/[email protected]
- name: Set up Python 3.6
Expand Down
5 changes: 3 additions & 2 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ verify_ssl = true
name = "pypi"

[packages]
aioredis = ">=1.3"
aioworkers = ">=0.14.3"
redis = ">=4.3"
aioworkers = ">=0.20"

[dev-packages]
pytest-aioworkers = "*"
Expand All @@ -15,6 +15,7 @@ pyyaml = "*"
isort = "*"
flake8 = "*"
mypy = "*"
types-redis = "*"

[requires]
python_version = "3.6"
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Redis plugin for `aioworkers`.
Features
--------

* Works on `aioredis <https://pypi.org/project/aioredis/>`_
* Works on `redis-py <https://pypi.org/project/redis/>`_

* Queue based on
`RPUSH <https://redis.io/commands/rpush>`_,
Expand Down Expand Up @@ -71,7 +71,7 @@ Add this to aioworkers config.yaml:
connection:
host: localhost
port: 6379
maxsize: 20
max_connections: 20
rossnomann marked this conversation as resolved.
Show resolved Hide resolved
queue:
cls: aioworkers_redis.queue.Queue
connection: .redis
Expand Down
55 changes: 20 additions & 35 deletions aioworkers_redis/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import Optional, Union

import aioredis
from aioworkers.core.base import (
AbstractConnector,
AbstractNestedEntity,
LoggingEntity,
)
from aioworkers.core.config import ValueExtractor
from aioworkers.core.formatter import FormattedEntity
from redis.asyncio import Redis


DEFAULT_HOST = 'localhost'
DEFAULT_PORT = 6379


class Connector(
Expand All @@ -20,7 +24,7 @@ def __init__(self, *args, **kwargs):
self._joiner: str = ':'
self._prefix: str = ''
self._connector: Optional[Connector] = None
self._pool: Optional[aioredis.Redis] = None
self._client: Optional[Redis] = None
super().__init__(*args, **kwargs)

def set_config(self, config):
Expand All @@ -38,10 +42,10 @@ def set_config(self, config):
super().set_config(cfg)

@property
def pool(self) -> aioredis.Redis:
rossnomann marked this conversation as resolved.
Show resolved Hide resolved
def client(self) -> Redis:
connector = self._connector or self._get_connector()
assert connector._pool is not None, 'Pool not ready'
return connector._pool
assert connector._client is not None, 'Client is not ready'
return connector._client

def _get_connector(self) -> 'Connector':
cfg = self.config.get('connection')
Expand Down Expand Up @@ -91,9 +95,6 @@ def clean_key(self, raw_key: Union[str, bytes]) -> str:
return result
return result.decode()

def acquire(self):
return AsyncConnectionContextManager(self)

async def connect(self):
connector = self._connector or self._get_connector()
if connector is not self:
Expand All @@ -114,26 +115,25 @@ async def connect(self):
cfg = dict(cfg)
else:
cfg = {}
self._pool = await self.pool_factory(cfg)
self._client = await self.client_factory(cfg)

async def pool_factory(self, cfg: dict) -> aioredis.Redis:
async def client_factory(self, cfg: dict) -> Redis:
if cfg.get('dsn'):
address = cfg.pop('dsn')
elif cfg.get('address'):
address = cfg.pop('address')
else:
address = cfg.pop('host', 'localhost'), cfg.pop('port', 6379)
self.logger.debug('Create pool with address %s', address)
return await aioredis.create_redis_pool(
address, **cfg, loop=self.loop,
)
host = cfg.pop('host', DEFAULT_HOST)
port = cfg.pop('port', DEFAULT_PORT)
address = 'redis://{}:{}'.format(host, port)
self.logger.debug('Create client with address %s', address)
return Redis.from_url(address, **cfg)

async def disconnect(self):
pool = self._pool
if pool is not None:
self.logger.debug('Close pool')
pool.close()
await pool.wait_closed()
client = self._client
if client is not None:
self.logger.debug('Close connection')
await client.close()

def decode(self, b):
if b is not None:
Expand All @@ -151,21 +151,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.disconnect()


class AsyncConnectionContextManager:
__slots__ = ('_connector', '_connection')

def __init__(self, connector: Connector):
self._connector: Connector = connector

async def __aenter__(self):
self._connection = await self._connector.pool
self._connection.__enter__()
return self._connection

async def __aexit__(self, exc_type, exc_value, tb):
self._connection.__exit__(exc_type, exc_value, tb)


class KeyEntity(Connector):
@property
def key(self):
Expand Down
48 changes: 16 additions & 32 deletions aioworkers_redis/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,45 +9,37 @@
class Queue(KeyEntity, AbstractQueue):
async def init(self):
await super().init()
self._lock = asyncio.Lock(loop=self.loop)
self._lock = asyncio.Lock()

def factory(self, item, config=None):
inst = super().factory(item, config=config)
inst._lock = asyncio.Lock(loop=self.loop)
inst._lock = asyncio.Lock()
return inst

async def put(self, value):
value = self.encode(value)
async with self.acquire() as conn:
return await conn.execute('rpush', self.key, value)
return await self.client.rpush(self.key, value)

async def get(self, *, timeout=0):
async with self._lock:
async with self.acquire() as conn:
result = await conn.execute('blpop', self.key, timeout)
result = await self.client.blpop(self.key, timeout)
if timeout and result is None:
raise TimeoutError
value = self.decode(result[-1])
return value

async def length(self):
async with self.acquire() as conn:
return await conn.execute('llen', self.key)
return await self.client.llen(self.key)

async def list(self):
async with self.acquire() as conn:
return [
self.decode(i)
for i in await conn.execute('lrange', self.key, 0, -1)]
return [self.decode(i) for i in await self.client.lrange(self.key, 0, -1)]

async def remove(self, value):
value = self.encode(value)
async with self.acquire() as conn:
await conn.execute('lrem', self.key, 0, value)
await self.client.lrem(self.key, 0, value)

async def clear(self):
async with self.acquire() as conn:
return await conn.execute('del', self.key)
return await self.client.delete(self.key)


class BaseZQueue(Queue):
Expand All @@ -56,33 +48,27 @@ class BaseZQueue(Queue):
async def put(self, value):
score, val = value
val = self.encode(val)
async with self.acquire() as conn:
return await conn.execute('zadd', self.key, score, val)
return await self.client.zadd(self.key, {val: score})

async def get(self):
async with self._lock:
while True:
async with self.acquire() as conn:
lv = await conn.execute('eval', self.script, 1, self.key)
lv = await self.client.eval(self.script, 1, self.key)
if lv:
break
await asyncio.sleep(self.config.timeout, loop=self.loop)
await asyncio.sleep(self.config.timeout)
value, score = lv
return float(score), self.decode(value)

async def length(self):
async with self.acquire() as conn:
return await conn.execute('zcard', self.key)
return await self.client.zcard(self.key)

async def list(self):
async with self.acquire() as conn:
return [self.decode(i)
for i in await conn.execute('zrange', self.key, 0, -1)]
return [self.decode(i) for i in await self.client.zrange(self.key, 0, -1)]

async def remove(self, value):
value = self.encode(value)
async with self.acquire() as conn:
await conn.execute('zrem', self.key, value)
await self.client.zrem(self.key, value)


@score_queue('time.time')
Expand Down Expand Up @@ -111,11 +97,9 @@ class TimestampZQueue(BaseZQueue):
async def get(self):
async with self._lock:
while True:
async with self.acquire() as conn:
lv = await conn.execute(
'eval', self.script, 1, self.key, time.time())
lv = await self.client.eval(self.script, 1, self.key, time.time())
if lv:
break
await asyncio.sleep(self.config.timeout, loop=self.loop)
await asyncio.sleep(self.config.timeout)
value, score = lv
return float(score), self.decode(value)
Loading