Merge pull request #121 from meedan/cv2-5589-queue-concurrency
CV2-5589 More support for queuing under threaded environments
DGaffney authored Nov 14, 2024
2 parents 6e3d1db + d80ada1 commit eb14ec3
Showing 4 changed files with 70 additions and 65 deletions.
2 changes: 1 addition & 1 deletion lib/queue/processor.py
@@ -28,7 +28,7 @@ def __init__(
super().__init__()
self.input_queue_name = input_queue_name
self.input_queues = self.restrict_queues_to_suffix(
- self.get_or_create_queues(input_queue_name), Queue.get_queue_suffix()
+ self.get_or_create_queue(input_queue_name), Queue.get_queue_suffix()
)
self.all_queues = self.store_queue_map(self.input_queues)
logger.info(f"Processor listening to queues of {self.all_queues}")
97 changes: 48 additions & 49 deletions lib/queue/queue.py
@@ -1,7 +1,7 @@
import json
from typing import List, Dict, Tuple
import os
- from threading import Lock
+ import threading

import boto3
import botocore
@@ -14,11 +14,31 @@
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5"))

class Queue:
- def __init__(self):
-     """
-     Start a specific queue - must pass input_queue_name.
-     """
-     self.sqs_lock = Lock()
+ _thread_local = threading.local()
+
+ @staticmethod
+ def get_sqs():
+     """
+     Thread-safe, lazy-initialized boto3 SQS resource per thread.
+     """
+     if not hasattr(Queue._thread_local, "sqs_resource"):
+         deploy_env = get_environment_setting("DEPLOY_ENV")
+         if deploy_env == "local":
+             logger.info("Using ElasticMQ Interface")
+             Queue._thread_local.sqs_resource = boto3.resource(
+                 'sqs',
+                 region_name=get_environment_setting("AWS_DEFAULT_REGION") or 'eu-central-1',
+                 endpoint_url=get_environment_setting("ELASTICMQ_URI") or 'http://presto-elasticmq:9324',
+                 aws_access_key_id=get_environment_setting("AWS_ACCESS_KEY_ID") or 'x',
+                 aws_secret_access_key=get_environment_setting("AWS_SECRET_ACCESS_KEY") or 'x'
+             )
+         else:
+             logger.info("Using SQS Interface")
+             Queue._thread_local.sqs_resource = boto3.resource(
+                 'sqs',
+                 region_name=get_environment_setting("AWS_DEFAULT_REGION")
+             )
+     return Queue._thread_local.sqs_resource
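
This per-thread pattern replaces the old shared Lock around a single boto3 resource: boto3 sessions and the resource objects built from them are not guaranteed thread-safe to share, so each thread lazily builds and caches its own. A minimal standalone sketch of the idea (region name and thread count are illustrative assumptions, not values from this repo):

import threading

import boto3

_local = threading.local()

def get_sqs():
    # threading.local() gives each thread its own attribute namespace,
    # so every thread builds and caches exactly one SQS resource.
    if not hasattr(_local, "sqs"):
        _local.sqs = boto3.resource("sqs", region_name="eu-central-1")  # assumed region
    return _local.sqs

def worker():
    first = get_sqs()
    assert first is get_sqs()  # cached: same object within one thread

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()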

@staticmethod
def get_queue_prefix():
@@ -86,9 +106,9 @@ def create_queue(self, queue_name: str) -> boto3.resources.base.ServiceResource:
Attributes=attributes
)

- def get_or_create_queues(self, queue_name: str) -> List[boto3.resources.base.ServiceResource]:
+ def get_or_create_queue(self, queue_name: str):
"""
- Initialize all queues for the given worker - try to create them if they are not found by name for whatever reason
+ Retrieve or create a queue with the specified name.
"""
try:
return [self.get_sqs().get_queue_by_name(QueueName=queue_name)]
@@ -98,27 +118,19 @@ def get_or_create_queues(self, queue_name: str) -> List[boto3.resources.base.ServiceResource]:
else:
raise

- def get_sqs(self) -> boto3.resources.base.ServiceResource:
+ def send_message(self, queue_name: str, message: schemas.Message):
"""
- Get an instantiated SQS - if local, use local alternative via elasticmq
+ Send a message to a specific queue.
"""
- deploy_env = get_environment_setting("DEPLOY_ENV")
- if deploy_env == "local":
-     logger.info(f"Using ElasticMQ Interface")
-     with self.sqs_lock:
-         return boto3.resource('sqs',
-             region_name=(get_environment_setting("AWS_DEFAULT_REGION") or 'eu-central-1'),
-             endpoint_url=(get_environment_setting("ELASTICMQ_URI") or 'http://presto-elasticmq:9324'),
-             aws_access_key_id=(get_environment_setting("AWS_ACCESS_KEY_ID") or 'x'),
-             aws_secret_access_key=(get_environment_setting("AWS_SECRET_ACCESS_KEY") or 'x'))
- else:
-     logger.info(f"Using SQS Interface")
-     with self.sqs_lock:
-         return boto3.resource('sqs', region_name=get_environment_setting("AWS_DEFAULT_REGION"))
+ queue = self.get_or_create_queue(queue_name)
+ message_data = {"MessageBody": json.dumps(message.dict())}
+ if queue_name.endswith('.fifo'):
+     message_data["MessageGroupId"] = message.body.id
+ queue.send_message(**message_data)
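
Note the FIFO branch above: SQS rejects sends to a .fifo queue that lack a MessageGroupId, and also requires a MessageDeduplicationId unless content-based deduplication is enabled on the queue. A hedged sketch of the same branching as a free function (queue name, payload, and region are hypothetical):

import json

import boto3

def send_to_queue(queue_name: str, payload: dict) -> None:
    # Hypothetical helper mirroring send_message above.
    sqs = boto3.resource("sqs", region_name="eu-central-1")  # assumed region
    queue = sqs.get_queue_by_name(QueueName=queue_name)
    kwargs = {"MessageBody": json.dumps(payload)}
    if queue_name.endswith(".fifo"):
        # Messages sharing a group id are delivered in order, one group at a time.
        kwargs["MessageGroupId"] = str(payload["id"])
    queue.send_message(**kwargs)

# send_to_queue("my_queue_output.fifo", {"id": 1, "text": "hello"})  # example call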

def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> Dict[boto3.resources.base.ServiceResource, List[schemas.Message]]:
"""
- Group deletions so that we can run through a simplified set of batches rather than delete each item independently
+ Group deletions by queue.
"""
queue_to_messages = {}
for message, queue in messages_with_queues:
@@ -127,25 +139,21 @@ def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> Dict[boto3.resources.base.ServiceResource, List[schemas.Message]]:
queue_to_messages[queue].append(message)
return queue_to_messages
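
The grouping step simply buckets (message, queue) pairs by queue object so each queue later gets one batched delete call. A toy illustration with strings standing in for SQS objects:

from collections import defaultdict

pairs = [("m1", "queue_a"), ("m2", "queue_b"), ("m3", "queue_a")]
grouped = defaultdict(list)
for message, queue in pairs:
    grouped[queue].append(message)
print(dict(grouped))  # {'queue_a': ['m1', 'm3'], 'queue_b': ['m2']}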

- def delete_messages(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> None:
+ def delete_messages(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]):
"""
- Delete messages as we process them so other processes don't pick them up.
- SQS deals in max batches of 10, so break up messages into groups of 10
- when deleting them.
+ Delete messages in batch mode.
"""
for queue, messages in self.group_deletions(messages_with_queues).items():
logger.info(f"Deleting messages {messages}")
self.delete_messages_from_queue(queue, messages)

- def delete_messages_from_queue(self, queue: boto3.resources.base.ServiceResource, messages: List[schemas.Message]) -> None:
+ def delete_messages_from_queue(self, queue, messages: List[schemas.Message]):
"""
- Helper function to delete a group of messages from a specific queue.
+ Helper to delete a batch of messages from a specific queue.
"""
for i in range(0, len(messages), 10):
batch = messages[i:i + 10]
- entries = []
- for idx, message in enumerate(batch):
-     logger.debug(f"Deleting message: {message}")
-     entries.append(self.delete_message_entry(message, idx))
+ entries = [{"Id": str(idx), "ReceiptHandle": message.receipt_handle} for idx, message in enumerate(batch)]
queue.delete_messages(Entries=entries)
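
The stride of 10 in range(0, len(messages), 10) exists because SQS batch APIs (batch delete and batch send) accept at most 10 entries per request. A self-contained sketch of the chunking with hypothetical receipt handles:

def chunks(seq, size=10):
    # SQS batch calls cap out at 10 entries per request.
    for i in range(0, len(seq), size):
        yield seq[i:i + size]

receipt_handles = [f"rh-{n}" for n in range(23)]  # hypothetical handles
for batch in chunks(receipt_handles):
    entries = [{"Id": str(idx), "ReceiptHandle": rh} for idx, rh in enumerate(batch)]
    print(len(entries))  # -> 10, 10, 3; a real caller passes Entries=entries to delete_messages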

def delete_message_entry(self, message: schemas.Message, idx: int = 0) -> Dict[str, str]:
@@ -157,21 +165,12 @@ def delete_message_entry(self, message: schemas.Message, idx: int = 0) -> Dict[str, str]:
'ReceiptHandle': message.receipt_handle
}

- def receive_messages(self, batch_size: int = 1) -> List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]:
-     """
-     Pull batch_size messages from input queue.
-     Actual SQS logic for pulling batch_size messages from matched queues
-     """
-     messages_with_queues = []
-     for queue in self.input_queues:
-         if batch_size <= 0:
-             break
-         batch_messages = queue.receive_messages(MaxNumberOfMessages=min(batch_size, SQS_MAX_BATCH_SIZE))
-         for message in batch_messages:
-             if batch_size > 0:
-                 messages_with_queues.append((message, queue))
-                 batch_size -= 1
-     return messages_with_queues
+ def receive_messages(self, batch_size: int = 1):
+     """
+     Receive messages from a queue.
+     """
+     queue = self.get_or_create_queue(self.input_queue)
+     return queue.receive_messages(MaxNumberOfMessages=min(batch_size, SQS_MAX_BATCH_SIZE))
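
Callers of receive_messages now get raw SQS message objects back rather than the old (message, queue) tuples, as the updated test below reflects. A hedged usage sketch, assuming a worker whose input_queue attribute is already set:

import json

def drain(worker, batch_size=5):
    # `worker` is assumed to be a QueueWorker with input_queue configured.
    for m in worker.receive_messages(batch_size=batch_size):
        payload = json.loads(m.body)  # raw SQS Message, not a (message, queue) tuple
        print(payload["body"]["id"])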

def find_queue_by_name(self, queue_name: str) -> boto3.resources.base.ServiceResource:
"""
6 changes: 3 additions & 3 deletions lib/queue/worker.py
@@ -36,9 +36,9 @@ def __init__(self, input_queue_name: str, output_queue_name: str = None, dlq_queue_name: str = None):
self.dlq_queue_name = dlq_queue_name or Queue.get_dead_letter_queue_name()
q_suffix = f"_output" + Queue.get_queue_suffix()
dlq_suffix = f"_dlq" + Queue.get_queue_suffix()
- self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), q_suffix)
- self.output_queues = self.get_or_create_queues(self.output_queue_name)
- self.dead_letter_queues = self.get_or_create_queues(self.dlq_queue_name)
+ self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queue(input_queue_name), q_suffix)
+ self.output_queues = self.get_or_create_queue(self.output_queue_name)
+ self.dead_letter_queues = self.get_or_create_queue(self.dlq_queue_name)
self.all_queues = self.store_queue_map([item for row in [self.input_queues, self.output_queues, self.dead_letter_queues] for item in row])
logger.info(f"Worker listening to queues of {self.all_queues}")

Expand Down
30 changes: 18 additions & 12 deletions test/lib/queue/test_queue.py
@@ -1,3 +1,4 @@
+ import pdb
import json
import os
import unittest
@@ -49,15 +50,24 @@ def setUp(self, mock_log_execution_time, mock_get_env_setting, mock_boto_resource):
self.mock_dlq_queue = MagicMock()
self.mock_dlq_queue.url = f"http://queue/{self.queue_name_dlq}"
self.mock_dlq_queue.attributes = {"QueueArn": f"queue:{self.queue_name_dlq}"}

# Set up side effects for get_queue_by_name
self.mock_sqs_resource.get_queue_by_name.side_effect = lambda QueueName: {
self.queue_name_input: self.mock_input_queue,
self.queue_name_output: self.mock_output_queue,
self.queue_name_dlq: self.mock_dlq_queue
}.get(QueueName)

mock_boto_resource.return_value = self.mock_sqs_resource

- # Initialize the QueueWorker instance
- self.queue = QueueWorker(self.queue_name_input, self.queue_name_output, self.queue_name_dlq)

+ # Initialize QueueWorker with mocked get_sqs method
+ with patch.object(QueueWorker, 'get_sqs', return_value=self.mock_sqs_resource):
+ self.queue = QueueWorker(self.queue_name_input, self.queue_name_output, self.queue_name_dlq)
+
+ # Ensure `self.all_queues` is populated for `find_queue_by_name`
+ self.queue.all_queues = self.queue.store_queue_map([
+ self.mock_input_queue, self.mock_output_queue, self.mock_dlq_queue
+ ])

def test_get_output_queue_name(self):
self.assertEqual(self.queue.get_output_queue_name().replace(".fifo", ""), (self.queue.get_input_queue_name()+'_output').replace(".fifo", ""))
@@ -98,18 +108,15 @@ def test_process(self):
self.queue.receive_messages.assert_called_once_with(1)

def test_receive_messages(self):
+ self.queue.input_queue = self.queue_name_input
mock_queue1 = MagicMock()
- mock_queue1.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}}))]
-
- mock_queue2 = MagicMock()
- mock_queue2.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 2, "callback_url": "http://example.com", "text": "This is another test"}}))]
- self.queue.input_queues = [mock_queue1, mock_queue2]
+ mock_queue1.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}})), FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 2, "callback_url": "http://example.com", "text": "This is another test"}}))]
+ self.queue.get_or_create_queue = MagicMock(return_value=mock_queue1)
received_messages = self.queue.receive_messages(5)

# Check if the right number of messages were received and the content is correct
self.assertEqual(len(received_messages), 2)
self.assertIn("a test", received_messages[0][0].body)
self.assertIn("another test", received_messages[1][0].body)
self.assertIn("a test", received_messages[0].body)
self.assertIn("another test", received_messages[1].body)

def test_restrict_queues_by_suffix(self):
queues = [
@@ -149,7 +156,6 @@ def test_delete_messages_from_queue(self, mock_logger):
self.queue.delete_messages_from_queue(self.mock_input_queue, mock_messages)
# Check if the correct number of calls to delete_messages were made
self.mock_input_queue.delete_messages.assert_called_once()
- mock_logger.assert_called_with(f"Deleting message: {mock_messages[-1]}")

def test_push_message(self):
message_to_push = schemas.parse_input_message({"body": {"id": 1, "content_hash": None, "callback_url": "http://example.com", "text": "This is a test"}, "model_name": "mean_tokens__Model"})
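
A possible companion test, not part of this commit, that would pin down the per-thread caching in Queue.get_sqs (import path and patch target are assumed from this repo's layout):

import threading
import unittest
from unittest.mock import patch

from lib.queue.queue import Queue  # assumed import path

class TestThreadLocalSQS(unittest.TestCase):
    def test_each_thread_gets_its_own_resource(self):
        seen = []
        # Every boto3.resource call returns a fresh sentinel object.
        with patch("lib.queue.queue.boto3.resource", side_effect=lambda *a, **k: object()):
            def grab():
                seen.append(Queue.get_sqs())
            threads = [threading.Thread(target=grab) for _ in range(3)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
        # Three threads should cache three distinct resources.
        self.assertEqual(len({id(obj) for obj in seen}), 3)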
