diff --git a/share/search/daemon.py b/share/search/daemon.py
index 90aedb855..6f400afd7 100644
--- a/share/search/daemon.py
+++ b/share/search/daemon.py
@@ -7,7 +7,9 @@
 import threading
 import time
 
+import amqp.exceptions
 from django.conf import settings
+import kombu
 from kombu.mixins import ConsumerMixin
 import sentry_sdk
 
@@ -51,9 +53,16 @@ def start_daemonthreads_for_strategy(self, index_strategy):
         # spin up daemonthreads, ready for messages
+        # (the consumer is created first so the daemon already knows its connection
+        # when `start()` runs -- presumably that is where each handling loop copies it)
+        _consumer = CeleryMessageConsumer(
+            celery_app=self.celery_app,
+            stop_event=self.stop_event,
+            index_strategy=index_strategy,
+            message_callback=_daemon.on_message,
+        )
+        _daemon.kombu_connection = _consumer.connection
         self._daemonthreads.extend(_daemon.start())
         # assign a thread to pass messages to this daemon
-        threading.Thread(
-            target=CeleryMessageConsumer(self.celery_app, _daemon).run,
-        ).start()
+        threading.Thread(target=_consumer.run).start()
         return _daemon
 
     def start_all_daemonthreads(self):
@@ -73,12 +82,12 @@ class CeleryMessageConsumer(ConsumerMixin):
     # (from ConsumerMixin)
     # should_stop: bool
 
-    def __init__(self, celery_app, indexer_daemon):
+    def __init__(self, *, celery_app, stop_event, message_callback, index_strategy):
         self.connection = celery_app.pool.acquire(block=True)
         self.celery_app = celery_app
-        self.__stop_event = indexer_daemon.stop_event
-        self.__message_callback = indexer_daemon.on_message
-        self.__index_strategy = indexer_daemon.index_strategy
+        self.__stop_event = stop_event
+        self.__message_callback = message_callback
+        self.__index_strategy = index_strategy
 
     # overrides ConsumerMixin.run
     def run(self):
@@ -115,6 +124,8 @@ def __repr__(self):
 
 class IndexerDaemon:
     MAX_LOCAL_QUEUE_SIZE = 5000
+    # connection of the consumer feeding this daemon (None until one is assigned)
+    kombu_connection: kombu.Connection | None = None
 
     def __init__(self, index_strategy, *, stop_event=None, daemonthread_context=None):
         self.stop_event = (
@@ -154,6 +165,7 @@ def start_typed_loop_and_queue(self, message_type) -> threading.Thread:
             local_message_queue=_queue_from_rabbit_to_daemon,
             log_prefix=f'{repr(self)} MessageHandlingLoop: ',
             daemonthread_context=self.__daemonthread_context,
+            kombu_connection=self.kombu_connection,
         )
         return _handling_loop.start_thread()
 
@@ -187,6 +199,7 @@ class MessageHandlingLoop:
     local_message_queue: queue.Queue
     log_prefix: str
     daemonthread_context: contextlib.AbstractContextManager
+    kombu_connection: kombu.Connection | None = None
     _leftover_daemon_messages_by_target_id = None
 
     def __post_init__(self):
@@ -270,7 +283,7 @@ def _handle_some_messages(self):
                     sentry_sdk.capture_message('error handling message', extras={'message_response': message_response})
                 target_id = message_response.index_message.target_id
                 for daemon_message in daemon_messages_by_target_id.pop(target_id, ()):
-                    daemon_message.ack()  # finally set it free
+                    self._ensure_ack(daemon_message)
             if daemon_messages_by_target_id:  # should be empty by now
                 logger.error('%sUnhandled messages?? %s', self.log_prefix, len(daemon_messages_by_target_id))
                 sentry_sdk.capture_message(
@@ -296,6 +309,34 @@ def _back_off(self):
         logger.warning(f'{self.log_prefix}Backing off (pause for {self._backoff_timeout:.2} seconds)')
         _backoff_wait(stop_event=self.stop_event, backoff_timeout=self._backoff_timeout)
 
+    def _ensure_ack(self, daemon_message: messages.DaemonMessage):
+        '''ack the given message, retrying if the connection was lost
+
+        on `RecoverableConnectionError` the ack is retried via `kombu.Connection.autoretry`;
+        when this loop has no connection to retry with, log it and leave the message unacked
+        '''
+        try:
+            daemon_message.ack()
+        except amqp.exceptions.RecoverableConnectionError:
+            if self.kombu_connection is None:
+                # nothing to retry with -- say so instead of failing silently
+                logger.exception('%sCould not ack message (no connection for retry)', self.log_prefix)
+                return
+            _kombu_message = daemon_message.kombu_message
+
+            def _do_ack(*, channel):
+                channel.basic_ack(_kombu_message.delivery_tag)
+
+            # `autoretry` takes the function as its first argument (it is not a
+            # decorator factory); the wrapped call returns `(retval, channel)`
+            # NOTE(review): AMQP delivery tags are scoped to a channel -- confirm the
+            # broker accepts this ack if `autoretry` has to open a replacement channel
+            _ack_with_retry = self.kombu_connection.autoretry(
+                _do_ack,
+                channel=_kombu_message.channel,
+            )
+            _ack_with_retry()
+
 
 # helper function for easier testing of backoff logic
 def _backoff_wait(*, stop_event, backoff_timeout):