diff --git a/kombu/transport/redis_cluster.py b/kombu/transport/redis_cluster.py index 8b67281ca..b9802acb7 100644 --- a/kombu/transport/redis_cluster.py +++ b/kombu/transport/redis_cluster.py @@ -8,6 +8,7 @@ from kombu.utils.eventio import READ, ERR from kombu.utils.json import loads, dumps from kombu.utils.objects import cached_property +from kombu.utils.url import parse_url from . import virtual from .redis import ( @@ -142,8 +143,9 @@ def restore_by_tag(self, tag, client=None, leftmost=False, queue=''): class RedisNodeConnection(): - def __init__(self, key): - self.client = None + def __init__(self, key, cluster_connection): + self.cluster_connection = cluster_connection + self.redis_connection = None self.in_poll = False self.key = key self.timeout = None @@ -153,13 +155,13 @@ def __init__(self): super().__init__() self._sock_to_fd = {} - def _register(self, channel, client, conn, cmd): - ident = (channel, client, conn, cmd) + def _register(self, channel, conn, cmd): + ident = (channel, conn, cmd) if ident in self._chan_to_sock: self._unregister(*ident) - if not conn.client: + if not conn.redis_connection: tries = 0 backoff = [0, 0.1, 0.2, 0.4] while True: @@ -168,41 +170,41 @@ def _register(self, channel, client, conn, cmd): try: if conn.key in channel.ask_errors: ask_error = channel.ask_errors[conn.key] - node = channel.consumer_client.get_node(ask_error.host, ask_error.port) + node = conn.cluster_connection.get_node(ask_error.host, ask_error.port) else: - node = channel.consumer_client.get_node_from_key(conn.key) + node = conn.cluster_connection.get_node_from_key(conn.key) if node: break except: logger.exception('Error while getting node from key', extra={"key": conn.key}) sleep(backoff[tries]) - channel.consumer_client.nodes_manager.initialize() + conn.cluster_connection.nodes_manager.initialize() tries += 1 - redis_connection = channel.consumer_client.get_redis_connection(node) - conn.client = redis_connection.client() + redis_connection = conn.cluster_connection.get_redis_connection(node) + conn.redis_connection = redis_connection.client() - sock = conn.client.connection._sock + sock = conn.redis_connection.connection._sock self._fd_to_chan[sock.fileno()] = (channel, conn, cmd) self._chan_to_sock[ident] = sock self._sock_to_fd[sock] = sock.fileno() self.poller.register(sock, self.eventflags) - def _unregister(self, channel, client, conn, cmd): - sock = self._chan_to_sock[(channel, client, conn, cmd)] + def _unregister(self, channel, conn, cmd): + sock = self._chan_to_sock[(channel, conn, cmd)] fd = self._sock_to_fd[sock] self.poller.unregister(sock) - if conn.client: - if conn.client.connection: + if conn.redis_connection: + if conn.redis_connection.connection: # There might be pending BRPOP response on the connection, so we disconnect to ensure safety - conn.client.connection.disconnect() - conn.client.close() - conn.client = None + conn.redis_connection.connection.disconnect() + conn.redis_connection.close() + conn.redis_connection = None del self._fd_to_chan[fd] - del self._chan_to_sock[(channel, client, conn, cmd)] + del self._chan_to_sock[(channel, conn, cmd)] del self._sock_to_fd[sock] def discard(self, channel): @@ -217,7 +219,7 @@ def _register_BRPOP(self, channel): conns = self._get_conns_for_channel(channel) for conn in conns: - ident = (channel, channel.consumer_client, conn, 'BRPOP') + ident = (channel, conn, 'BRPOP') if (ident not in self._chan_to_sock): try: @@ -247,14 +249,15 @@ def maybe_restore_messages(self): def _get_conns_for_channel(self, channel): result = [] - conns = [conn for _, _, conn, _ in self._chan_to_sock] + conns = [conn for _, conn, _ in self._chan_to_sock] for key in channel.active_queues: - try: - conn = next(x for x in conns if x.key == key) - conns.remove(conn) - except StopIteration: - conn = RedisNodeConnection(key) - result.append(conn) + for client in channel.consumer_clients: + try: + conn = next(x for x in conns if x.key == key and x.cluster_connection == client) + conns.remove(conn) + except StopIteration: + conn = RedisNodeConnection(key, client) + result.append(conn) return result @@ -375,37 +378,44 @@ def conn_or_acquire(self, client=None): yield self.client @cached_property - def consumer_client(self): + def consumer_clients(self): self.consumer_created = True conninfo = self.connection.client - hostname = conninfo.hostname - port = conninfo.port - password = conninfo.password - transport = self.connection.client.transport_cls - ssl = transport == 'rediss-cluster' + parsed = parse_url(conninfo.hostname) + ssl = parsed['transport'] == 'rediss-cluster' + + connection = RedisClusterConnection.get_consumer_connection(parsed['hostname'], parsed['port'], parsed['password'], ssl) - return RedisClusterConnection.get_consumer_connection(hostname, port, password, ssl) + # Additional redis cluster + # redis-cluster://172.16.0.1:7000?alt=redis-cluster://172.16.0.2:7000 + if 'alt' in parsed: + alt_parsed = parse_url(parsed['alt']) + alt_ssl = alt_parsed['transport'] == 'rediss-cluster' + alt_connection = RedisClusterConnection.get_consumer_connection(alt_parsed['hostname'], alt_parsed['port'], alt_parsed['password'], alt_ssl) + + return [connection, alt_connection] + + return [connection] @cached_property def client(self): conninfo = self.connection.client - hostname = conninfo.hostname - port = conninfo.port - password = conninfo.password + parsed = parse_url(conninfo.hostname) transport = self.connection.client.transport_cls ssl = transport == 'rediss-cluster' - return RedisClusterConnection.get_producer_connection(hostname, port, password, ssl) + return RedisClusterConnection.get_producer_connection(parsed['hostname'], parsed['port'], parsed['password'], ssl) def close(self): super().close() RedisClusterConnection.close(self.client) if self.consumer_created is True: - RedisClusterConnection.close(self.consumer_client) + for client in self.consumer_clients: + RedisClusterConnection.close(client) def _brpop_start(self, timeout): queues = self._queue_cycle.consume(len(self.active_queues)) @@ -413,24 +423,25 @@ def _brpop_start(self, timeout): return for key in queues: - for _, _, conn, _ in self.connection.cycle._chan_to_sock: - if conn.key == key and conn.in_poll == False: - conn.in_poll = True - conn.timeout = timeout - if conn.key in self.ask_errors: - del self.ask_errors[conn.key] + for client in self.consumer_clients: + for _, conn, _ in self.connection.cycle._chan_to_sock: + if conn.key == key and conn.in_poll == False and conn.cluster_connection == client: + conn.in_poll = True + conn.timeout = timeout + if conn.key in self.ask_errors: + del self.ask_errors[conn.key] + try: + conn.redis_connection.execute_command('ASKING') + except: + logger.exception('Error while sending ASKING', extra={"key": conn.key}) + continue + try: - conn.client.execute_command('ASKING') + conn.redis_connection.connection.send_command('BRPOP', key, timeout) except: - logger.exception('Error while sending ASKING', extra={"key": conn.key}) - continue - - try: - conn.client.connection.send_command('BRPOP', key, timeout) - except: - logger.exception('Error while sending BRPOP', extra={"key": conn.key}) - self.connection.cycle._unregister(self, self.consumer_client, conn, 'BRPOP') - break + logger.exception('Error while sending BRPOP', extra={"key": conn.key}) + self.connection.cycle._unregister(self, conn, 'BRPOP') + break def _brpop_read(self, **options): conn = options.pop('conn') @@ -441,7 +452,7 @@ def _brpop_read(self, **options): # We should not throw error on this method to make kombu to continue operation raise Empty() - conn.client.connection.send_command('BRPOP', conn.key, conn.timeout) # schedule next BRPOP + conn.redis_connection.connection.send_command('BRPOP', conn.key, conn.timeout) # schedule next BRPOP if resp: self.deliver_response(resp) @@ -457,7 +468,7 @@ def _poll_error(self, cmd, conn, **options): # Error is logged at `parse_response` pass - self.connection.cycle._unregister(self, self.consumer_client, conn, 'BRPOP') + self.connection.cycle._unregister(self, conn, 'BRPOP') def deliver_response(self, resp): dest, item = resp @@ -467,38 +478,38 @@ def deliver_response(self, resp): def parse_response(self, conn, cmd, **options): try: - return conn.client.parse_response(conn.client.connection, cmd, **options) + return conn.redis_connection.parse_response(conn.redis_connection.connection, cmd, **options) except Exception as e: logger.exception('Error while reading from Redis', extra={"key": conn.key}) # Mostly copied from https://github.com/sendbird/redis-py/blob/master/redis/cluster.py#L1173 if isinstance(e, ConnectionError) or isinstance(e, TimeoutError): try: - node = channel.consumer_client.get_node_from_key(conn.key) - self.client.nodes_manager.startup_nodes.pop(node.name, None) + node = conn.cluster_connection.get_node_from_key(conn.key) + conn.cluster_connection.nodes_manager.startup_nodes.pop(node.name, None) except: logger.exception('Error while removing node', extra={"key": conn.key}) - self.consumer_client.nodes_manager.initialize() + conn.cluster_connection.nodes_manager.initialize() elif isinstance(e, MovedError): - self.consumer_client.reinitialize_counter += 1 - if self.consumer_client._should_reinitialized(): - self.consumer_client.nodes_manager.initialize() - self.consumer_client.reinitialize_counter = 0 + conn.cluster_connection.reinitialize_counter += 1 + if conn.cluster_connection._should_reinitialized(): + conn.cluster_connection.nodes_manager.initialize() + conn.cluster_connection.reinitialize_counter = 0 else: - self.consumer_client.nodes_manager.update_moved_exception(e) + conn.cluster_connection.nodes_manager.update_moved_exception(e) elif isinstance(e, SlotNotCoveredError): - self.consumer_client.reinitialize_counter += 1 - if self.consumer_client._should_reinitialized(): - self.consumer_client.nodes_manager.initialize() - self.consumer_client.reinitialize_counter = 0 + conn.cluster_connection.reinitialize_counter += 1 + if conn.cluster_connection._should_reinitialized(): + conn.cluster_connection.nodes_manager.initialize() + conn.cluster_connection.reinitialize_counter = 0 elif isinstance(e, TryAgainError): return # try again in next BRPOP elif isinstance(e, AskError): self.add_ask_error(e, conn) elif isinstance(e, ClusterDownError): - self.consumer_client.nodes_manager.initialize() + conn.cluster_connection.nodes_manager.initialize() - self.connection.cycle._unregister(self, self.consumer_client, conn, cmd) + self.connection.cycle._unregister(self, conn, cmd) raise def add_ask_error(self, e, conn): @@ -511,6 +522,7 @@ class Transport(RedisTransport): driver_type = 'redis-cluster' driver_name = driver_type + can_parse_url = True implements = virtual.Transport.implements.extend( asynchronous=True, exchange_type=frozenset(['direct']) diff --git a/t/integration/test_redis_cluster.py b/t/integration/test_redis_cluster.py index 47fab30e8..206ad5232 100644 --- a/t/integration/test_redis_cluster.py +++ b/t/integration/test_redis_cluster.py @@ -154,7 +154,7 @@ def parse_response(*args, **kwargs): pass except: raise - assert conn.default_channel.consumer_client.reinitialize_counter != 0 + assert conn.default_channel.consumer_clients[0].reinitialize_counter != 0 def test_askerror(connection): @@ -209,3 +209,39 @@ def test_many_queue(connection): assert message.content_encoding == 'utf-8' assert message.headers == {'k1': 'v1'} message.ack() + + +def test_multiple_consume(): + consumer_conn = kombu.Connection('redis-cluster://localhost:7000?alt=redis-cluster://localhost:8000') + producer_conn1 = kombu.Connection('redis-cluster://localhost:7000') + producer_conn2 = kombu.Connection('redis-cluster://localhost:8000') + + with producer_conn1 as producer: + queue = producer.SimpleQueue('test_multiple_consume') + queue.put({'Hello': 'World'}, headers={'k1': 'v1'}) + queue.close() + + with consumer_conn as consumer: + queue = consumer.SimpleQueue('test_multiple_consume') + message = queue.get(timeout=10) + assert message.payload == {'Hello': 'World'} + assert message.content_type == 'application/json' + assert message.content_encoding == 'utf-8' + assert message.headers == {'k1': 'v1'} + message.ack() + queue.close() + + with producer_conn2 as producer: + queue = producer.SimpleQueue('test_multiple_consume1') + queue.put({'Hello': 'World'}, headers={'k1': 'v1'}) + queue.close() + + with consumer_conn as consumer: + queue = consumer.SimpleQueue('test_multiple_consume1') + message = queue.get(timeout=10) + assert message.payload == {'Hello': 'World'} + assert message.content_type == 'application/json' + assert message.content_encoding == 'utf-8' + assert message.headers == {'k1': 'v1'} + message.ack() + queue.close()