Skip to content

Commit

Permalink
Merge pull request #23 from sendbird/feature/multiple_consumer
Browse files: browse the repository at this point in the history
Enable redis-cluster backend to consume from multiple redis clusters
  • Loading branch information
dlunch authored Nov 7, 2023
2 parents 32b97be + 94d2154 commit e5c74c9
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 73 deletions.
156 changes: 84 additions & 72 deletions kombu/transport/redis_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from kombu.utils.eventio import READ, ERR
from kombu.utils.json import loads, dumps
from kombu.utils.objects import cached_property
from kombu.utils.url import parse_url

from . import virtual
from .redis import (
Expand Down Expand Up @@ -142,8 +143,9 @@ def restore_by_tag(self, tag, client=None, leftmost=False, queue=''):


class RedisNodeConnection():
def __init__(self, key):
self.client = None
def __init__(self, key, cluster_connection):
self.cluster_connection = cluster_connection
self.redis_connection = None
self.in_poll = False
self.key = key
self.timeout = None
Expand All @@ -153,13 +155,13 @@ def __init__(self):
super().__init__()
self._sock_to_fd = {}

def _register(self, channel, client, conn, cmd):
ident = (channel, client, conn, cmd)
def _register(self, channel, conn, cmd):
ident = (channel, conn, cmd)

if ident in self._chan_to_sock:
self._unregister(*ident)

if not conn.client:
if not conn.redis_connection:
tries = 0
backoff = [0, 0.1, 0.2, 0.4]
while True:
Expand All @@ -168,41 +170,41 @@ def _register(self, channel, client, conn, cmd):
try:
if conn.key in channel.ask_errors:
ask_error = channel.ask_errors[conn.key]
node = channel.consumer_client.get_node(ask_error.host, ask_error.port)
node = conn.cluster_connection.get_node(ask_error.host, ask_error.port)
else:
node = channel.consumer_client.get_node_from_key(conn.key)
node = conn.cluster_connection.get_node_from_key(conn.key)
if node:
break
except:
logger.exception('Error while getting node from key', extra={"key": conn.key})

sleep(backoff[tries])
channel.consumer_client.nodes_manager.initialize()
conn.cluster_connection.nodes_manager.initialize()
tries += 1

redis_connection = channel.consumer_client.get_redis_connection(node)
conn.client = redis_connection.client()
redis_connection = conn.cluster_connection.get_redis_connection(node)
conn.redis_connection = redis_connection.client()

sock = conn.client.connection._sock
sock = conn.redis_connection.connection._sock
self._fd_to_chan[sock.fileno()] = (channel, conn, cmd)
self._chan_to_sock[ident] = sock
self._sock_to_fd[sock] = sock.fileno()
self.poller.register(sock, self.eventflags)

def _unregister(self, channel, client, conn, cmd):
sock = self._chan_to_sock[(channel, client, conn, cmd)]
def _unregister(self, channel, conn, cmd):
sock = self._chan_to_sock[(channel, conn, cmd)]
fd = self._sock_to_fd[sock]

self.poller.unregister(sock)
if conn.client:
if conn.client.connection:
if conn.redis_connection:
if conn.redis_connection.connection:
# There might be pending BRPOP response on the connection, so we disconnect to ensure safety
conn.client.connection.disconnect()
conn.client.close()
conn.client = None
conn.redis_connection.connection.disconnect()
conn.redis_connection.close()
conn.redis_connection = None

del self._fd_to_chan[fd]
del self._chan_to_sock[(channel, client, conn, cmd)]
del self._chan_to_sock[(channel, conn, cmd)]
del self._sock_to_fd[sock]

def discard(self, channel):
Expand All @@ -217,7 +219,7 @@ def _register_BRPOP(self, channel):
conns = self._get_conns_for_channel(channel)

for conn in conns:
ident = (channel, channel.consumer_client, conn, 'BRPOP')
ident = (channel, conn, 'BRPOP')

if (ident not in self._chan_to_sock):
try:
Expand Down Expand Up @@ -247,14 +249,15 @@ def maybe_restore_messages(self):

def _get_conns_for_channel(self, channel):
result = []
conns = [conn for _, _, conn, _ in self._chan_to_sock]
conns = [conn for _, conn, _ in self._chan_to_sock]
for key in channel.active_queues:
try:
conn = next(x for x in conns if x.key == key)
conns.remove(conn)
except StopIteration:
conn = RedisNodeConnection(key)
result.append(conn)
for client in channel.consumer_clients:
try:
conn = next(x for x in conns if x.key == key and x.cluster_connection == client)
conns.remove(conn)
except StopIteration:
conn = RedisNodeConnection(key, client)
result.append(conn)

return result

Expand Down Expand Up @@ -375,62 +378,70 @@ def conn_or_acquire(self, client=None):
yield self.client

@cached_property
def consumer_client(self):
def consumer_clients(self):
self.consumer_created = True

conninfo = self.connection.client

hostname = conninfo.hostname
port = conninfo.port
password = conninfo.password
transport = self.connection.client.transport_cls
ssl = transport == 'rediss-cluster'
parsed = parse_url(conninfo.hostname)
ssl = parsed['transport'] == 'rediss-cluster'

connection = RedisClusterConnection.get_consumer_connection(parsed['hostname'], parsed['port'], parsed['password'], ssl)

return RedisClusterConnection.get_consumer_connection(hostname, port, password, ssl)
# Additional redis cluster
# redis-cluster://172.16.0.1:7000?alt=redis-cluster://172.16.0.2:7000
if 'alt' in parsed:
alt_parsed = parse_url(parsed['alt'])
alt_ssl = alt_parsed['transport'] == 'rediss-cluster'
alt_connection = RedisClusterConnection.get_consumer_connection(alt_parsed['hostname'], alt_parsed['port'], alt_parsed['password'], alt_ssl)

return [connection, alt_connection]

return [connection]

@cached_property
def client(self):
conninfo = self.connection.client

hostname = conninfo.hostname
port = conninfo.port
password = conninfo.password
parsed = parse_url(conninfo.hostname)
transport = self.connection.client.transport_cls
ssl = transport == 'rediss-cluster'

return RedisClusterConnection.get_producer_connection(hostname, port, password, ssl)
return RedisClusterConnection.get_producer_connection(parsed['hostname'], parsed['port'], parsed['password'], ssl)

def close(self):
super().close()

RedisClusterConnection.close(self.client)
if self.consumer_created is True:
RedisClusterConnection.close(self.consumer_client)
for client in self.consumer_clients:
RedisClusterConnection.close(client)

def _brpop_start(self, timeout):
queues = self._queue_cycle.consume(len(self.active_queues))
if not queues:
return

for key in queues:
for _, _, conn, _ in self.connection.cycle._chan_to_sock:
if conn.key == key and conn.in_poll == False:
conn.in_poll = True
conn.timeout = timeout
if conn.key in self.ask_errors:
del self.ask_errors[conn.key]
for client in self.consumer_clients:
for _, conn, _ in self.connection.cycle._chan_to_sock:
if conn.key == key and conn.in_poll == False and conn.cluster_connection == client:
conn.in_poll = True
conn.timeout = timeout
if conn.key in self.ask_errors:
del self.ask_errors[conn.key]
try:
conn.redis_connection.execute_command('ASKING')
except:
logger.exception('Error while sending ASKING', extra={"key": conn.key})
continue

try:
conn.client.execute_command('ASKING')
conn.redis_connection.connection.send_command('BRPOP', key, timeout)
except:
logger.exception('Error while sending ASKING', extra={"key": conn.key})
continue

try:
conn.client.connection.send_command('BRPOP', key, timeout)
except:
logger.exception('Error while sending BRPOP', extra={"key": conn.key})
self.connection.cycle._unregister(self, self.consumer_client, conn, 'BRPOP')
break
logger.exception('Error while sending BRPOP', extra={"key": conn.key})
self.connection.cycle._unregister(self, conn, 'BRPOP')
break

def _brpop_read(self, **options):
conn = options.pop('conn')
Expand All @@ -441,7 +452,7 @@ def _brpop_read(self, **options):
# We should not throw error on this method to make kombu to continue operation
raise Empty()

conn.client.connection.send_command('BRPOP', conn.key, conn.timeout) # schedule next BRPOP
conn.redis_connection.connection.send_command('BRPOP', conn.key, conn.timeout) # schedule next BRPOP

if resp:
self.deliver_response(resp)
Expand All @@ -457,7 +468,7 @@ def _poll_error(self, cmd, conn, **options):
# Error is logged at `parse_response`
pass

self.connection.cycle._unregister(self, self.consumer_client, conn, 'BRPOP')
self.connection.cycle._unregister(self, conn, 'BRPOP')

def deliver_response(self, resp):
dest, item = resp
Expand All @@ -467,38 +478,38 @@ def deliver_response(self, resp):

def parse_response(self, conn, cmd, **options):
try:
return conn.client.parse_response(conn.client.connection, cmd, **options)
return conn.redis_connection.parse_response(conn.redis_connection.connection, cmd, **options)
except Exception as e:
logger.exception('Error while reading from Redis', extra={"key": conn.key})

# Mostly copied from https://github.com/sendbird/redis-py/blob/master/redis/cluster.py#L1173
if isinstance(e, ConnectionError) or isinstance(e, TimeoutError):
try:
node = channel.consumer_client.get_node_from_key(conn.key)
self.client.nodes_manager.startup_nodes.pop(node.name, None)
node = conn.cluster_connection.get_node_from_key(conn.key)
conn.cluster_connection.nodes_manager.startup_nodes.pop(node.name, None)
except:
logger.exception('Error while removing node', extra={"key": conn.key})
self.consumer_client.nodes_manager.initialize()
conn.cluster_connection.nodes_manager.initialize()
elif isinstance(e, MovedError):
self.consumer_client.reinitialize_counter += 1
if self.consumer_client._should_reinitialized():
self.consumer_client.nodes_manager.initialize()
self.consumer_client.reinitialize_counter = 0
conn.cluster_connection.reinitialize_counter += 1
if conn.cluster_connection._should_reinitialized():
conn.cluster_connection.nodes_manager.initialize()
conn.cluster_connection.reinitialize_counter = 0
else:
self.consumer_client.nodes_manager.update_moved_exception(e)
conn.cluster_connection.nodes_manager.update_moved_exception(e)
elif isinstance(e, SlotNotCoveredError):
self.consumer_client.reinitialize_counter += 1
if self.consumer_client._should_reinitialized():
self.consumer_client.nodes_manager.initialize()
self.consumer_client.reinitialize_counter = 0
conn.cluster_connection.reinitialize_counter += 1
if conn.cluster_connection._should_reinitialized():
conn.cluster_connection.nodes_manager.initialize()
conn.cluster_connection.reinitialize_counter = 0
elif isinstance(e, TryAgainError):
return # try again in next BRPOP
elif isinstance(e, AskError):
self.add_ask_error(e, conn)
elif isinstance(e, ClusterDownError):
self.consumer_client.nodes_manager.initialize()
conn.cluster_connection.nodes_manager.initialize()

self.connection.cycle._unregister(self, self.consumer_client, conn, cmd)
self.connection.cycle._unregister(self, conn, cmd)
raise

def add_ask_error(self, e, conn):
Expand All @@ -511,6 +522,7 @@ class Transport(RedisTransport):

driver_type = 'redis-cluster'
driver_name = driver_type
can_parse_url = True

implements = virtual.Transport.implements.extend(
asynchronous=True, exchange_type=frozenset(['direct'])
Expand Down
38 changes: 37 additions & 1 deletion t/integration/test_redis_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def parse_response(*args, **kwargs):
pass
except:
raise
assert conn.default_channel.consumer_client.reinitialize_counter != 0
assert conn.default_channel.consumer_clients[0].reinitialize_counter != 0


def test_askerror(connection):
Expand Down Expand Up @@ -209,3 +209,39 @@ def test_many_queue(connection):
assert message.content_encoding == 'utf-8'
assert message.headers == {'k1': 'v1'}
message.ack()


def test_multiple_consume():
    """Consume from two Redis clusters through one multi-cluster connection.

    The consumer URL names a primary cluster (localhost:7000) and an ``alt``
    cluster (localhost:8000). In each round a message is published through a
    single-cluster connection, then read back through the shared consumer
    connection, which must deliver it with payload and metadata intact.
    """
    consumer_conn = kombu.Connection('redis-cluster://localhost:7000?alt=redis-cluster://localhost:8000')
    producer_conn1 = kombu.Connection('redis-cluster://localhost:7000')
    producer_conn2 = kombu.Connection('redis-cluster://localhost:8000')

    # One round per cluster: (connection to publish with, queue for the round).
    rounds = (
        (producer_conn1, 'test_multiple_consume'),
        (producer_conn2, 'test_multiple_consume1'),
    )

    for producer_conn, queue_name in rounds:
        with producer_conn as producer:
            outbound = producer.SimpleQueue(queue_name)
            outbound.put({'Hello': 'World'}, headers={'k1': 'v1'})
            outbound.close()

        # The same consumer connection object is re-entered on every round.
        with consumer_conn as consumer:
            inbound = consumer.SimpleQueue(queue_name)
            message = inbound.get(timeout=10)
            assert message.payload == {'Hello': 'World'}
            assert message.content_type == 'application/json'
            assert message.content_encoding == 'utf-8'
            assert message.headers == {'k1': 'v1'}
            message.ack()
            inbound.close()

0 comments on commit e5c74c9

Please sign in to comment.