-
Notifications
You must be signed in to change notification settings - Fork 1
/
asyncio_celery_client.py
145 lines (127 loc) · 5.29 KB
/
asyncio_celery_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from asyncio import Transport
from contextlib import asynccontextmanager
from typing import Optional, List
from urllib.parse import urlparse
import aioamqp
from aioamqp import AmqpProtocol
from aioamqp.channel import Channel
from aioamqp.protocol import OPEN
from aioredis import ConnectionsPool, Redis
from aioredis.util import parse_url
from celery import Celery, states
from celery.app.amqp import task_message
from celery.app.utils import Settings
from celery.backends.base import BaseKeyValueStoreBackend
from kombu import uuid, serialization
class AsyncResult:
    """Handle to a queued task that lets the caller await its outcome."""

    def __init__(self, client, task_id):
        # Remember the owning client and the task's unique id so the
        # result can be fetched later from the client's result backend.
        self.client: AsyncCeleryClient = client
        self.task_id = task_id

    async def get(self):
        """Wait (asynchronously) for the task to finish and return its result."""
        backend = self.client.result_backend
        return await backend.wait_for_result(self.task_id)
class AmqpBackend:
    """Lazily-connected AMQP broker wrapper with a simple channel pool.

    One connection (transport/protocol pair) is shared by all callers;
    idle channels are kept in a LIFO pool so each publish does not have
    to open a fresh channel.
    """

    def __init__(self, conf):
        self.conf: Settings = conf
        self.transport: Optional[Transport] = None
        self.protocol: Optional[AmqpProtocol] = None
        # Pool of idle channels belonging to the *current* connection.
        self.channels: List[Channel] = []

    @asynccontextmanager
    async def get_channel(self):
        """Yield a usable channel, (re)connecting first if needed.

        The channel is returned to the pool on exit, but only if it is
        still open — a channel killed by a channel-level error (e.g. a
        failed passive queue_declare) must not be handed out again.
        """
        # TODO fix race to create connection
        # TODO be more persistent at creating the connection. Mimic
        # what celery does perhaps.
        if self.protocol is None or self.protocol.state != OPEN:
            parts = urlparse(self.conf.broker_url)
            # Channels from the dead connection are unusable; drop them.
            self.channels = []
            # TODO ssl support
            connect_kwargs = {
                "host": parts.hostname,
                "port": parts.port,
            }
            # Only forward credentials / vhost when the URL carries them,
            # so aioamqp's defaults (guest/guest on "/") apply otherwise.
            if parts.username:
                connect_kwargs["login"] = parts.username
            if parts.password:
                connect_kwargs["password"] = parts.password
            vhost = parts.path.lstrip("/")
            if vhost:
                connect_kwargs["virtualhost"] = vhost
            self.transport, self.protocol = await aioamqp.connect(**connect_kwargs)
        if self.channels:
            channel = self.channels.pop()
        else:
            channel = await self.protocol.channel()
        try:
            yield channel
        finally:
            # Pool the channel only if it survived whatever the caller did.
            if channel.is_open:
                self.channels.append(channel)
class RedisResultBackend:
    """Awaitable view onto celery's redis result backend.

    Subscribes to the key celery publishes result updates on, and polls
    the stored result metadata whenever a notification arrives.
    """

    # Reuse celery's prefix so we read the same keys the workers write.
    task_keyprefix = BaseKeyValueStoreBackend.task_keyprefix

    def __init__(self, celery):
        self.celery: Celery = celery
        address, options = parse_url(celery.conf.result_backend)
        # minsize=0: hold no connections while idle.
        self.redis_pool: Optional[ConnectionsPool] = ConnectionsPool(
            address,
            minsize=0,
            maxsize=10,
            **options
        )

    async def wait_for_result(self, task_id):
        """Wait until *task_id* reaches a ready state and return its result.

        Raises the stored exception when the task ended in a propagate
        (failure) state, mirroring celery's AsyncResult.get behaviour.
        """
        key = (self.task_keyprefix + task_id).encode()
        # TODO share a single connection for all waiting results
        # connections effectively get a task to process messages
        # coming in. The built-in Connection then wants to put that
        # on a queue. Do we then need another task processing those
        # queues? The case of two concurrent waits on the same
        # celery task is the tricky one. We can't just have the call
        # to wait_for_result own that queue. Maybe the first call
        # could own that queue and setup a future for any other calls
        # to wait on.
        # A subscribed connection cannot issue regular commands, so GETs
        # go through a second wrapper drawing on the shared pool. Create
        # it once rather than on every loop iteration.
        reader = Redis(self.redis_pool)
        async with self.redis_pool.get() as conn:
            conn = Redis(conn)
            chan, = await conn.subscribe(key)
            try:
                while True:
                    encodedmeta = await reader.get(key)
                    if not encodedmeta:
                        # No metadata stored yet: the task has not started.
                        meta = {'status': states.PENDING, 'result': None}
                    else:
                        meta = self.celery.backend.decode_result(encodedmeta)
                    if meta['status'] in states.READY_STATES:
                        if meta['status'] in states.PROPAGATE_STATES:
                            raise meta['result']
                        return meta['result']
                    # Block until the backend publishes an update for this key.
                    await chan.get()
            finally:
                await conn.unsubscribe(key)
class AsyncCeleryClient:
    """Minimal asyncio-native celery client: queue tasks and await results."""

    def __init__(self, celery):
        self.celery: Celery = celery
        self.broker = AmqpBackend(celery.conf)
        self.result_backend = RedisResultBackend(celery)

    async def queue_task(self, task, args, kwargs, queue=None):
        """Publish *task* to the broker and return an AsyncResult handle.

        ``queue`` names the destination queue; when omitted, the app's
        ``task_default_queue`` is used (the original hard-coded behaviour).
        """
        # TODO implement more options from apply_async beyond `queue`.
        # I don't know what the other commonly used ones are.
        async with self.broker.get_channel() as channel:
            if queue is None:
                queue = self.celery.conf.task_default_queue
            # TODO who's responsible in celery for creating the queue
            # I hoped that the consumer would create it so there'd
            # never be any need to here. But I've seen errors that suggest
            # other wise.
            await channel.queue_declare(queue, passive=True)
            task_id = uuid()
            message: task_message = self.celery.amqp.create_task_message(
                task_id, task.name, args, kwargs
            )
            content_type, content_encoding, body = serialization.dumps(
                message.body, 'json',
            )
            properties = {
                "content_type": content_type,
                "content_encoding": content_encoding,
                "headers": message.headers,
                **message.properties
            }
            body = body.encode(content_encoding)
            # Publishing through the default exchange ('') routes the
            # message directly to the queue of the same name.
            await channel.publish(
                body,
                '',
                queue,
                properties=properties
            )
            return AsyncResult(self, task_id)

    async def wait_task(self, task, args, kwargs, queue=None):
        """Queue *task* and wait for its result in one call."""
        r = await self.queue_task(task, args, kwargs, queue=queue)
        return await r.get()