From a91954a39bdf644948d448e1deb7e4502f367585 Mon Sep 17 00:00:00 2001 From: Devin Gaffney Date: Thu, 14 Nov 2024 10:43:47 -0800 Subject: [PATCH 1/3] CV2-5589 More support for queuing under threaded environments --- lib/queue/processor.py | 2 +- lib/queue/queue.py | 111 +++++++++++++++++++---------------- lib/queue/worker.py | 6 +- test/lib/queue/test_queue.py | 30 ++++++---- 4 files changed, 82 insertions(+), 67 deletions(-) diff --git a/lib/queue/processor.py b/lib/queue/processor.py index fdde2666..f750ce20 100644 --- a/lib/queue/processor.py +++ b/lib/queue/processor.py @@ -28,7 +28,7 @@ def __init__( super().__init__() self.input_queue_name = input_queue_name self.input_queues = self.restrict_queues_to_suffix( - self.get_or_create_queues(input_queue_name), Queue.get_queue_suffix() + self.get_or_create_queue(input_queue_name), Queue.get_queue_suffix() ) self.all_queues = self.store_queue_map(self.input_queues) logger.info(f"Processor listening to queues of {self.all_queues}") diff --git a/lib/queue/queue.py b/lib/queue/queue.py index 67f8a23f..ecde4fd9 100644 --- a/lib/queue/queue.py +++ b/lib/queue/queue.py @@ -1,7 +1,7 @@ import json from typing import List, Dict, Tuple import os -from threading import Lock +import threading import boto3 import botocore @@ -14,11 +14,31 @@ MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5")) class Queue: - def __init__(self): - """ - Start a specific queue - must pass input_queue_name. - """ - self.sqs_lock = Lock() + _thread_local = threading.local() + + @staticmethod + def get_sqs(): + """ + Thread-safe, lazy-initialized boto3 SQS resource per thread. + """ + if not hasattr(Queue._thread_local, "sqs_resource"): + deploy_env = get_environment_setting("DEPLOY_ENV") + if deploy_env == "local": + logger.info("Using ElasticMQ Interface") + Queue._thread_local.sqs_resource = boto3.resource( + 'sqs', + region_name=get_environment_setting("AWS_DEFAULT_REGION") or 'eu-central-1', + endpoint_url=get_environment_setting("ELASTICMQ_URI") or 'http://presto-elasticmq:9324', + aws_access_key_id=get_environment_setting("AWS_ACCESS_KEY_ID") or 'x', + aws_secret_access_key=get_environment_setting("AWS_SECRET_ACCESS_KEY") or 'x' + ) + else: + logger.info("Using SQS Interface") + Queue._thread_local.sqs_resource = boto3.resource( + 'sqs', + region_name=get_environment_setting("AWS_DEFAULT_REGION") + ) + return Queue._thread_local.sqs_resource @staticmethod def get_queue_prefix(): @@ -86,39 +106,41 @@ def create_queue(self, queue_name: str) -> boto3.resources.base.ServiceResource: Attributes=attributes ) - def get_or_create_queues(self, queue_name: str) -> List[boto3.resources.base.ServiceResource]: + def get_or_create_queue(self, queue_name: str): """ - Initialize all queues for the given worker - try to create them if they are not found by name for whatever reason + Retrieve or create a queue with the specified name. """ try: - return [self.get_sqs().get_queue_by_name(QueueName=queue_name)] + return self.get_sqs().get_queue_by_name(QueueName=queue_name) except botocore.exceptions.ClientError as e: if e.response['Error']['Code'] == "AWS.SimpleQueueService.NonExistentQueue": - return [self.create_queue(queue_name)] + return self.create_queue(queue_name) else: raise - def get_sqs(self) -> boto3.resources.base.ServiceResource: + def create_queue(self, queue_name: str): """ - Get an instantiated SQS - if local, use local alternative via elasticmq + Create a queue by name. """ - deploy_env = get_environment_setting("DEPLOY_ENV") - if deploy_env == "local": - logger.info(f"Using ElasticMQ Interface") - with self.sqs_lock: - return boto3.resource('sqs', - region_name=(get_environment_setting("AWS_DEFAULT_REGION") or 'eu-central-1'), - endpoint_url=(get_environment_setting("ELASTICMQ_URI") or 'http://presto-elasticmq:9324'), - aws_access_key_id=(get_environment_setting("AWS_ACCESS_KEY_ID") or 'x'), - aws_secret_access_key=(get_environment_setting("AWS_SECRET_ACCESS_KEY") or 'x')) - else: - logger.info(f"Using SQS Interface") - with self.sqs_lock: - return boto3.resource('sqs', region_name=get_environment_setting("AWS_DEFAULT_REGION")) + attributes = {} + if queue_name.endswith('.fifo'): + attributes['FifoQueue'] = 'true' + attributes['ContentBasedDeduplication'] = 'true' + return self.get_sqs().create_queue(QueueName=queue_name, Attributes=attributes) + + def send_message(self, queue_name: str, message: schemas.Message): + """ + Send a message to a specific queue. + """ + queue = self.get_or_create_queue(queue_name) + message_data = {"MessageBody": json.dumps(message.dict())} + if queue_name.endswith('.fifo'): + message_data["MessageGroupId"] = message.body.id + queue.send_message(**message_data) def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> Dict[boto3.resources.base.ServiceResource, List[schemas.Message]]: """ - Group deletions so that we can run through a simplified set of batches rather than delete each item independently + Group deletions by queue. """ queue_to_messages = {} for message, queue in messages_with_queues: @@ -127,25 +149,21 @@ def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto queue_to_messages[queue].append(message) return queue_to_messages - def delete_messages(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> None: + def delete_messages(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]): """ - Delete messages as we process them so other processes don't pick them up. - SQS deals in max batches of 10, so break up messages into groups of 10 - when deleting them. + Delete messages in batch mode. """ for queue, messages in self.group_deletions(messages_with_queues).items(): + logger.info(f"Deleting messages {messages}") self.delete_messages_from_queue(queue, messages) - def delete_messages_from_queue(self, queue: boto3.resources.base.ServiceResource, messages: List[schemas.Message]) -> None: + def delete_messages_from_queue(self, queue, messages: List[schemas.Message]): """ - Helper function to delete a group of messages from a specific queue. + Helper to delete a batch of messages from a specific queue. """ for i in range(0, len(messages), 10): batch = messages[i:i + 10] - entries = [] - for idx, message in enumerate(batch): - logger.debug(f"Deleting message: {message}") - entries.append(self.delete_message_entry(message, idx)) + entries = [{"Id": str(idx), "ReceiptHandle": message.receipt_handle} for idx, message in enumerate(batch)] queue.delete_messages(Entries=entries) def delete_message_entry(self, message: schemas.Message, idx: int = 0) -> Dict[str, str]: @@ -157,21 +175,12 @@ def delete_message_entry(self, message: schemas.Message, idx: int = 0) -> Dict[s 'ReceiptHandle': message.receipt_handle } - def receive_messages(self, batch_size: int = 1) -> List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]: - """ - Pull batch_size messages from input queue. - Actual SQS logic for pulling batch_size messages from matched queues - """ - messages_with_queues = [] - for queue in self.input_queues: - if batch_size <= 0: - break - batch_messages = queue.receive_messages(MaxNumberOfMessages=min(batch_size, SQS_MAX_BATCH_SIZE)) - for message in batch_messages: - if batch_size > 0: - messages_with_queues.append((message, queue)) - batch_size -= 1 - return messages_with_queues + def receive_messages(self, batch_size: int = 1): + """ + Receive messages from a queue. + """ + queue = self.get_or_create_queue(self.input_queue) + return queue.receive_messages(MaxNumberOfMessages=min(batch_size, SQS_MAX_BATCH_SIZE)) def find_queue_by_name(self, queue_name: str) -> boto3.resources.base.ServiceResource: """ diff --git a/lib/queue/worker.py b/lib/queue/worker.py index 80094411..f7af4135 100644 --- a/lib/queue/worker.py +++ b/lib/queue/worker.py @@ -36,9 +36,9 @@ def __init__(self, input_queue_name: str, output_queue_name: str = None, dlq_que self.dlq_queue_name = dlq_queue_name or Queue.get_dead_letter_queue_name() q_suffix = f"_output" + Queue.get_queue_suffix() dlq_suffix = f"_dlq" + Queue.get_queue_suffix() - self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), q_suffix) - self.output_queues = self.get_or_create_queues(self.output_queue_name) - self.dead_letter_queues = self.get_or_create_queues(self.dlq_queue_name) + self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queue(input_queue_name), q_suffix) + self.output_queues = self.get_or_create_queue(self.output_queue_name) + self.dead_letter_queues = self.get_or_create_queue(self.dlq_queue_name) self.all_queues = self.store_queue_map([item for row in [self.input_queues, self.output_queues, self.dead_letter_queues] for item in row]) logger.info(f"Worker listening to queues of {self.all_queues}") diff --git a/test/lib/queue/test_queue.py b/test/lib/queue/test_queue.py index 3961ea62..24a10a7d 100644 --- a/test/lib/queue/test_queue.py +++ b/test/lib/queue/test_queue.py @@ -1,3 +1,4 @@ +import pdb import json import os import unittest @@ -49,15 +50,24 @@ def setUp(self, mock_log_execution_time, mock_get_env_setting, mock_boto_resourc self.mock_dlq_queue = MagicMock() self.mock_dlq_queue.url = f"http://queue/{self.queue_name_dlq}" self.mock_dlq_queue.attributes = {"QueueArn": f"queue:{self.queue_name_dlq}"} + + # Set up side effects for get_queue_by_name self.mock_sqs_resource.get_queue_by_name.side_effect = lambda QueueName: { self.queue_name_input: self.mock_input_queue, self.queue_name_output: self.mock_output_queue, self.queue_name_dlq: self.mock_dlq_queue }.get(QueueName) + mock_boto_resource.return_value = self.mock_sqs_resource - - # Initialize the QueueWorker instance - self.queue = QueueWorker(self.queue_name_input, self.queue_name_output, self.queue_name_dlq) + + # Initialize QueueWorker with mocked get_sqs method + with patch.object(QueueWorker, 'get_sqs', return_value=self.mock_sqs_resource): + self.queue = QueueWorker(self.queue_name_input, self.queue_name_output, self.queue_name_dlq) + + # Ensure `self.all_queues` is populated for `find_queue_by_name` + self.queue.all_queues = self.queue.store_queue_map([ + self.mock_input_queue, self.mock_output_queue, self.mock_dlq_queue + ]) def test_get_output_queue_name(self): self.assertEqual(self.queue.get_output_queue_name().replace(".fifo", ""), (self.queue.get_input_queue_name()+'_output').replace(".fifo", "")) @@ -98,18 +108,15 @@ def test_process(self): self.queue.receive_messages.assert_called_once_with(1) def test_receive_messages(self): + self.queue.input_queue = self.queue_name_input mock_queue1 = MagicMock() - mock_queue1.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}}))] - - mock_queue2 = MagicMock() - mock_queue2.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 2, "callback_url": "http://example.com", "text": "This is another test"}}))] - self.queue.input_queues = [mock_queue1, mock_queue2] + mock_queue1.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}})), FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 2, "callback_url": "http://example.com", "text": "This is another test"}}))] + self.queue.get_or_create_queue = MagicMock(return_value=mock_queue1) received_messages = self.queue.receive_messages(5) - # Check if the right number of messages were received and the content is correct self.assertEqual(len(received_messages), 2) - self.assertIn("a test", received_messages[0][0].body) - self.assertIn("another test", received_messages[1][0].body) + self.assertIn("a test", received_messages[0].body) + self.assertIn("another test", received_messages[1].body) def test_restrict_queues_by_suffix(self): queues = [ @@ -149,7 +156,6 @@ def test_delete_messages_from_queue(self, mock_logger): self.queue.delete_messages_from_queue(self.mock_input_queue, mock_messages) # Check if the correct number of calls to delete_messages were made self.mock_input_queue.delete_messages.assert_called_once() - mock_logger.assert_called_with(f"Deleting message: {mock_messages[-1]}") def test_push_message(self): message_to_push = schemas.parse_input_message({"body": {"id": 1, "content_hash": None, "callback_url": "http://example.com", "text": "This is a test"}, "model_name": "mean_tokens__Model"}) From 04be44c384b4950f7d30f2e0ec499251b414e6b3 Mon Sep 17 00:00:00 2001 From: Devin Gaffney Date: Thu, 14 Nov 2024 10:50:39 -0800 Subject: [PATCH 2/3] small tweak to return queue set --- lib/queue/queue.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/queue/queue.py b/lib/queue/queue.py index ecde4fd9..ffe70b2b 100644 --- a/lib/queue/queue.py +++ b/lib/queue/queue.py @@ -111,10 +111,10 @@ def get_or_create_queue(self, queue_name: str): Retrieve or create a queue with the specified name. """ try: - return self.get_sqs().get_queue_by_name(QueueName=queue_name) + return [self.get_sqs().get_queue_by_name(QueueName=queue_name)] except botocore.exceptions.ClientError as e: if e.response['Error']['Code'] == "AWS.SimpleQueueService.NonExistentQueue": - return self.create_queue(queue_name) + return [self.create_queue(queue_name)] else: raise From d80ada1177390f9ba693bf50d068cfeda8500c99 Mon Sep 17 00:00:00 2001 From: Devin Gaffney Date: Thu, 14 Nov 2024 12:02:27 -0800 Subject: [PATCH 3/3] remove function --- lib/queue/queue.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/lib/queue/queue.py b/lib/queue/queue.py index ffe70b2b..ef1ed530 100644 --- a/lib/queue/queue.py +++ b/lib/queue/queue.py @@ -118,16 +118,6 @@ def get_or_create_queue(self, queue_name: str): else: raise - def create_queue(self, queue_name: str): - """ - Create a queue by name. - """ - attributes = {} - if queue_name.endswith('.fifo'): - attributes['FifoQueue'] = 'true' - attributes['ContentBasedDeduplication'] = 'true' - return self.get_sqs().create_queue(QueueName=queue_name, Attributes=attributes) - def send_message(self, queue_name: str, message: schemas.Message): """ Send a message to a specific queue.