Skip to content

Commit

Permalink
Merge pull request #20 from sendbird/feat/separate_connection
Browse files Browse the repository at this point in the history
Separate producer/consumer connection
  • Loading branch information
dlunch authored Sep 18, 2023
2 parents e11979b + e5f6b46 commit 5d6db08
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 37 deletions.
107 changes: 74 additions & 33 deletions kombu/transport/redis_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from kombu.utils.encoding import bytes_to_str
from kombu.utils.eventio import READ, ERR
from kombu.utils.json import loads, dumps
from kombu.utils.objects import cached_property

from . import virtual
from .redis import (
Expand All @@ -27,15 +28,18 @@
logger = get_logger(__name__)


# Override this method to use other redis client
def create_redis_cluster_connection(hostname, port, password, ssl):
params = {'skip_full_coverage_check': True, 'host': hostname, 'port': port, 'password': password}
# Override create_redis_cluster_connection_for_{producer,consumer} to use other redis client
def create_redis_cluster_connection_for_consumer(hostname, port, password, ssl):
params = {'require_full_coverage': False, 'host': hostname, 'port': port, 'password': password, 'dynamic_startup_nodes': True}
if ssl:
params['ssl'] = True

return redis.RedisCluster(**params)


create_redis_cluster_connection_for_producer = create_redis_cluster_connection_for_consumer


class QoS(RedisQoS):
def __init__(self, *args, **kwargs):
super(QoS, self).__init__(*args, **kwargs)
Expand Down Expand Up @@ -164,19 +168,19 @@ def _register(self, channel, client, conn, cmd):
try:
if conn.key in channel.ask_errors:
ask_error = channel.ask_errors[conn.key]
node = channel.client.get_node(ask_error.host, ask_error.port)
node = channel.consumer_client.get_node(ask_error.host, ask_error.port)
else:
node = channel.client.get_node_from_key(conn.key)
node = channel.consumer_client.get_node_from_key(conn.key)
if node:
break
except Exception as e:
logger.error('Error while getting node from key', extra={"e": e, "key": conn.key})

sleep(backoff[tries])
channel.client.nodes_manager.initialize()
channel.consumer_client.nodes_manager.initialize()
tries += 1

redis_connection = channel.client.get_redis_connection(node)
redis_connection = channel.consumer_client.get_redis_connection(node)
conn.client = redis_connection.client()

sock = conn.client.connection._sock
Expand Down Expand Up @@ -213,7 +217,7 @@ def _register_BRPOP(self, channel):
conns = self._get_conns_for_channel(channel)

for conn in conns:
ident = (channel, channel.client, conn, 'BRPOP')
ident = (channel, channel.consumer_client, conn, 'BRPOP')

if (ident not in self._chan_to_sock):
try:
Expand Down Expand Up @@ -273,22 +277,36 @@ def on_readable(self, fileno):


class RedisClusterConnection():
connections = {}
producer_connections = {}
consumer_connections = {}
connection_to_key = {}
refcounts = {}

@classmethod
def get_connection(cls, host, port, password, ssl):
key = (host, port, password, ssl)
if key not in cls.connections:
connection = create_redis_cluster_connection(host, port, password, ssl)
cls.connections[key] = connection
def get_consumer_connection(cls, host, port, password, ssl):
key = (host, port, password, ssl, 'consumer')
if key not in cls.consumer_connections:
connection = create_redis_cluster_connection_for_consumer(host, port, password, ssl)
cls.consumer_connections[key] = connection
cls.connection_to_key[connection] = key
cls.refcounts[key] = 0

cls.refcounts[key] += 1

return cls.consumer_connections[key]

@classmethod
def get_producer_connection(cls, host, port, password, ssl):
key = (host, port, password, ssl, 'producer')
if key not in cls.producer_connections:
connection = create_redis_cluster_connection_for_producer(host, port, password, ssl)
cls.producer_connections[key] = connection
cls.connection_to_key[connection] = key
cls.refcounts[key] = 0

cls.refcounts[key] += 1

return cls.connections[key]
return cls.producer_connections[key]

@classmethod
def close(cls, connection):
Expand All @@ -299,7 +317,12 @@ def close(cls, connection):
connection.close()
del cls.refcounts[key]
del cls.connection_to_key[connection]
del cls.connections[key]
if key[4] == 'producer':
del cls.producer_connections[key]
elif key[4] == 'consumer':
del cls.consumer_connections[key]
else:
raise ValueError(f'Unknown connection type: {key[4]}')


class Channel(RedisChannel):
Expand All @@ -326,6 +349,7 @@ def __init__(self, conn, *args, **kwargs):
super().__init__(conn, *args, **kwargs)

self.ask_errors = {}
self.consumer_created = False

def _restore(self, message, leftmost=False):
if not self.ack_emulation:
Expand All @@ -350,7 +374,22 @@ def conn_or_acquire(self, client=None):
else:
yield self.client

def _create_client(self, asynchronous=False):
@cached_property
def consumer_client(self):
self.consumer_created = True

conninfo = self.connection.client

hostname = conninfo.hostname
port = conninfo.port
password = conninfo.password
transport = self.connection.client.transport_cls
ssl = transport == 'rediss-cluster'

return RedisClusterConnection.get_consumer_connection(hostname, port, password, ssl)

@cached_property
def client(self):
conninfo = self.connection.client

hostname = conninfo.hostname
Expand All @@ -359,12 +398,14 @@ def _create_client(self, asynchronous=False):
transport = self.connection.client.transport_cls
ssl = transport == 'rediss-cluster'

return RedisClusterConnection.get_connection(hostname, port, password, ssl)
return RedisClusterConnection.get_producer_connection(hostname, port, password, ssl)

def close(self):
super().close()

RedisClusterConnection.close(self.client)
if self.consumer_created is True:
RedisClusterConnection.close(self.consumer_client)

def _brpop_start(self, timeout):
queues = self._queue_cycle.consume(len(self.active_queues))
Expand All @@ -388,7 +429,7 @@ def _brpop_start(self, timeout):
conn.client.connection.send_command('BRPOP', key, timeout)
except:
logger.exception('Error while sending BRPOP', extra={"key": conn.key})
self.connection.cycle._unregister(self, self.client, conn, 'BRPOP')
self.connection.cycle._unregister(self, self.consumer_client, conn, 'BRPOP')
break

def _brpop_read(self, **options):
Expand Down Expand Up @@ -416,7 +457,7 @@ def _poll_error(self, cmd, conn, **options):
# Error is logged at `parse_response`
pass

self.connection.cycle._unregister(self, self.client, conn, 'BRPOP')
self.connection.cycle._unregister(self, self.consumer_client, conn, 'BRPOP')

def deliver_response(self, resp):
dest, item = resp
Expand All @@ -433,31 +474,31 @@ def parse_response(self, conn, cmd, **options):
# Mostly copied from https://github.com/sendbird/redis-py/blob/master/redis/cluster.py#L1173
if isinstance(e, ConnectionError) or isinstance(e, TimeoutError):
try:
node = channel.client.get_node_from_key(conn.key)
node = channel.consumer_client.get_node_from_key(conn.key)
self.client.nodes_manager.startup_nodes.pop(node.name, None)
except Exception as e:
logger.error('Error while removing node', extra={"e": e, "key": conn.key})
self.client.nodes_manager.initialize()
self.consumer_client.nodes_manager.initialize()
elif isinstance(e, MovedError):
self.client.reinitialize_counter += 1
if self.client._should_reinitialized():
self.client.nodes_manager.initialize()
self.client.reinitialize_counter = 0
self.consumer_client.reinitialize_counter += 1
if self.consumer_client._should_reinitialized():
self.consumer_client.nodes_manager.initialize()
self.consumer_client.reinitialize_counter = 0
else:
self.client.nodes_manager.update_moved_exception(e)
self.consumer_client.nodes_manager.update_moved_exception(e)
elif isinstance(e, SlotNotCoveredError):
self.client.reinitialize_counter += 1
if self.client._should_reinitialized():
self.client.nodes_manager.initialize()
self.client.reinitialize_counter = 0
self.consumer_client.reinitialize_counter += 1
if self.consumer_client._should_reinitialized():
self.consumer_client.nodes_manager.initialize()
self.consumer_client.reinitialize_counter = 0
elif isinstance(e, TryAgainError):
return # try again in next BRPOP
elif isinstance(e, AskError):
self.add_ask_error(e, conn)
elif isinstance(e, ClusterDownError):
self.client.nodes_manager.initialize()
self.consumer_client.nodes_manager.initialize()

self.connection.cycle._unregister(self, self.client, conn, cmd)
self.connection.cycle._unregister(self, self.consumer_client, conn, cmd)
raise

def add_ask_error(self, e, conn):
Expand Down
11 changes: 7 additions & 4 deletions t/integration/test_redis_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,18 @@ def patched_brpop_start(self, timeout):
def test_connection_reuse(connection):
from kombu.transport.redis_cluster import RedisClusterConnection

assert len(RedisClusterConnection.connections) == 0
assert len(RedisClusterConnection.producer_connections) == 0
assert len(RedisClusterConnection.consumer_connections) == 0
with connection as conn:
queue = conn.SimpleQueue('test_connectionerror')
queue.put({'Hello': 'World'}, headers={'k1': 'v1'})
_ = queue.get(timeout=1)

assert len(RedisClusterConnection.connections) == 1
assert len(RedisClusterConnection.producer_connections) == 1
assert len(RedisClusterConnection.consumer_connections) == 1

assert len(RedisClusterConnection.connections) == 0
assert len(RedisClusterConnection.producer_connections) == 0
assert len(RedisClusterConnection.consumer_connections) == 0


def test_brpop_send_error(connection):
Expand Down Expand Up @@ -151,7 +154,7 @@ def parse_response(*args, **kwargs):
pass
except:
raise
assert conn.default_channel.client.reinitialize_counter != 0
assert conn.default_channel.consumer_client.reinitialize_counter != 0


def test_askerror(connection):
Expand Down

0 comments on commit 5d6db08

Please sign in to comment.