CV2-5589 More support for queuing under threaded environments #121

Merged
merged 3 commits into from Nov 14, 2024
Changes from 2 commits
2 changes: 1 addition & 1 deletion lib/queue/processor.py
@@ -28,7 +28,7 @@ def __init__(
super().__init__()
self.input_queue_name = input_queue_name
self.input_queues = self.restrict_queues_to_suffix(
-            self.get_or_create_queues(input_queue_name), Queue.get_queue_suffix()
+            self.get_or_create_queue(input_queue_name), Queue.get_queue_suffix()
)
self.all_queues = self.store_queue_map(self.input_queues)
logger.info(f"Processor listening to queues of {self.all_queues}")
107 changes: 58 additions & 49 deletions lib/queue/queue.py
@@ -1,7 +1,7 @@
import json
from typing import List, Dict, Tuple
import os
-from threading import Lock
+import threading

import boto3
import botocore
@@ -14,11 +14,31 @@
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5"))

class Queue:
-    def __init__(self):
-        """
-        Start a specific queue - must pass input_queue_name.
-        """
-        self.sqs_lock = Lock()
+    _thread_local = threading.local()
+
+    @staticmethod
+    def get_sqs():
+        """
+        Thread-safe, lazy-initialized boto3 SQS resource per thread.
+        """
+        if not hasattr(Queue._thread_local, "sqs_resource"):
+            deploy_env = get_environment_setting("DEPLOY_ENV")
+            if deploy_env == "local":
+                logger.info("Using ElasticMQ Interface")
+                Queue._thread_local.sqs_resource = boto3.resource(
+                    'sqs',
+                    region_name=get_environment_setting("AWS_DEFAULT_REGION") or 'eu-central-1',
+                    endpoint_url=get_environment_setting("ELASTICMQ_URI") or 'http://presto-elasticmq:9324',
+                    aws_access_key_id=get_environment_setting("AWS_ACCESS_KEY_ID") or 'x',
+                    aws_secret_access_key=get_environment_setting("AWS_SECRET_ACCESS_KEY") or 'x'
+                )
+            else:
+                logger.info("Using SQS Interface")
+                Queue._thread_local.sqs_resource = boto3.resource(
+                    'sqs',
+                    region_name=get_environment_setting("AWS_DEFAULT_REGION")
+                )
+        return Queue._thread_local.sqs_resource

@staticmethod
def get_queue_prefix():
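The heart of this PR is the pattern in the hunk above: instead of guarding one shared boto3 resource with a lock, each thread lazily builds and caches its own, since boto3 resources are not thread-safe. A minimal standalone sketch of the same idea, with a hardcoded region standing in for the project's environment lookups:

```python
import threading

import boto3

_thread_local = threading.local()

def get_sqs():
    # The first call in a thread builds the resource; later calls in the same
    # thread return the cached one. Threads never share a resource object.
    if not hasattr(_thread_local, "sqs_resource"):
        _thread_local.sqs_resource = boto3.resource("sqs", region_name="eu-central-1")
    return _thread_local.sqs_resource
```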
@@ -86,9 +106,9 @@ def create_queue(self, queue_name: str) -> boto3.resources.base.ServiceResource:
Attributes=attributes
)

-    def get_or_create_queues(self, queue_name: str) -> List[boto3.resources.base.ServiceResource]:
+    def get_or_create_queue(self, queue_name: str):
"""
-        Initialize all queues for the given worker - try to create them if they are not found by name for whatever reason
+        Retrieve or create a queue with the specified name.
"""
try:
return [self.get_sqs().get_queue_by_name(QueueName=queue_name)]
@@ -98,27 +118,29 @@ def get_or_create_queues(self, queue_name: str) -> List[boto3.resources.base.ServiceResource]:
else:
raise

-    def get_sqs(self) -> boto3.resources.base.ServiceResource:
+    def create_queue(self, queue_name: str):
"""
-        Get an instantiated SQS - if local, use local alternative via elasticmq
+        Create a queue by name.
"""
-        deploy_env = get_environment_setting("DEPLOY_ENV")
-        if deploy_env == "local":
-            logger.info(f"Using ElasticMQ Interface")
-            with self.sqs_lock:
-                return boto3.resource('sqs',
-                    region_name=(get_environment_setting("AWS_DEFAULT_REGION") or 'eu-central-1'),
-                    endpoint_url=(get_environment_setting("ELASTICMQ_URI") or 'http://presto-elasticmq:9324'),
-                    aws_access_key_id=(get_environment_setting("AWS_ACCESS_KEY_ID") or 'x'),
-                    aws_secret_access_key=(get_environment_setting("AWS_SECRET_ACCESS_KEY") or 'x'))
-        else:
-            logger.info(f"Using SQS Interface")
-            with self.sqs_lock:
-                return boto3.resource('sqs', region_name=get_environment_setting("AWS_DEFAULT_REGION"))
+        attributes = {}
+        if queue_name.endswith('.fifo'):
+            attributes['FifoQueue'] = 'true'
+            attributes['ContentBasedDeduplication'] = 'true'
Collaborator:
Is attributes['ContentBasedDeduplication'] related to the queue ending in .fifo, or is it something we want all the time?

Collaborator (Author):
It's a FIFO thing - and I realized I copy-pasted this here twice, so I'm removing one of them!

Contributor:
Right; the deduplication is an aspect of how FIFO queues work (it is how you avoid processing the same message more than once). I don't think we'll ever have a scenario with FIFO but without deduplication, but it doesn't seem unreasonable to keep the two aspects separate.
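A small illustrative sketch of the point discussed above (queue name and region are hypothetical): FifoQueue requires a name ending in .fifo, and ContentBasedDeduplication makes SQS hash the message body (SHA-256) so identical bodies sent within the five-minute deduplication window collapse to one message, with no explicit MessageDeduplicationId needed.

```python
import boto3

sqs = boto3.resource("sqs", region_name="eu-central-1")

# FIFO queue names must end in .fifo.
queue = sqs.create_queue(
    QueueName="example_input.fifo",
    Attributes={
        "FifoQueue": "true",
        "ContentBasedDeduplication": "true",
    },
)

# FIFO sends require a MessageGroupId (the ordering unit). With content-based
# deduplication enabled, the second identical body within the window is dropped.
queue.send_message(MessageBody='{"id": 1}', MessageGroupId="1")
queue.send_message(MessageBody='{"id": 1}', MessageGroupId="1")  # deduplicated
```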

+        return self.get_sqs().create_queue(QueueName=queue_name, Attributes=attributes)
+
+    def send_message(self, queue_name: str, message: schemas.Message):
+        """
+        Send a message to a specific queue.
+        """
+        queue = self.get_or_create_queue(queue_name)
Collaborator:
Do we really need to look up the queue each time we send a message? I would have hoped we could store a reference after looking it up once.

Collaborator (Author):
My understanding is that if we do store it, we introduce a bunch of state, which is what caused our problem in the first place.

Contributor:
The first call will create the resource object, but then it will return the thread-local reference on every subsequent request (within that thread). I think this is necessary due to the context switching that could happen otherwise...
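A runnable sketch of the behavior described in this thread, using a plain object in place of a boto3 resource: within one thread every call returns the same cached instance, while two threads end up with distinct instances.

```python
import threading

_local = threading.local()

def get_resource():
    # Stand-in for boto3.resource("sqs"); built once per thread, then cached.
    if not hasattr(_local, "resource"):
        _local.resource = object()
    return _local.resource

def worker(results, i):
    assert get_resource() is get_resource()  # same thread: cached reference
    results[i] = get_resource()

results = {}
threads = [threading.Thread(target=worker, args=(results, i)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert results[0] is not results[1]  # different threads: different resources
```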

+        message_data = {"MessageBody": json.dumps(message.dict())}
+        if queue_name.endswith('.fifo'):
+            message_data["MessageGroupId"] = message.body.id
+        queue.send_message(**message_data)

def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> Dict[boto3.resources.base.ServiceResource, List[schemas.Message]]:
"""
-        Group deletions so that we can run through a simplified set of batches rather than delete each item independently
+        Group deletions by queue.
"""
queue_to_messages = {}
for message, queue in messages_with_queues:
@@ -127,25 +149,21 @@ def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> Dict[boto3.resources.base.ServiceResource, List[schemas.Message]]:
queue_to_messages[queue].append(message)
return queue_to_messages

-    def delete_messages(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> None:
+    def delete_messages(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]):
"""
-        Delete messages as we process them so other processes don't pick them up.
-        SQS deals in max batches of 10, so break up messages into groups of 10
-        when deleting them.
+        Delete messages in batch mode.
"""
for queue, messages in self.group_deletions(messages_with_queues).items():
logger.info(f"Deleting messages {messages}")
self.delete_messages_from_queue(queue, messages)

-    def delete_messages_from_queue(self, queue: boto3.resources.base.ServiceResource, messages: List[schemas.Message]) -> None:
+    def delete_messages_from_queue(self, queue, messages: List[schemas.Message]):
"""
-        Helper function to delete a group of messages from a specific queue.
+        Helper to delete a batch of messages from a specific queue.
"""
for i in range(0, len(messages), 10):
batch = messages[i:i + 10]
-            entries = []
-            for idx, message in enumerate(batch):
-                logger.debug(f"Deleting message: {message}")
-                entries.append(self.delete_message_entry(message, idx))
+            entries = [{"Id": str(idx), "ReceiptHandle": message.receipt_handle} for idx, message in enumerate(batch)]
queue.delete_messages(Entries=entries)

def delete_message_entry(self, message: schemas.Message, idx: int = 0) -> Dict[str, str]:
@@ -157,21 +175,12 @@ def delete_message_entry(self, message: schemas.Message, idx: int = 0) -> Dict[str, str]:
'ReceiptHandle': message.receipt_handle
}

-    def receive_messages(self, batch_size: int = 1) -> List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]:
-        """
-        Pull batch_size messages from input queue.
-        Actual SQS logic for pulling batch_size messages from matched queues
-        """
-        messages_with_queues = []
-        for queue in self.input_queues:
-            if batch_size <= 0:
-                break
-            batch_messages = queue.receive_messages(MaxNumberOfMessages=min(batch_size, SQS_MAX_BATCH_SIZE))
-            for message in batch_messages:
-                if batch_size > 0:
-                    messages_with_queues.append((message, queue))
-                    batch_size -= 1
-        return messages_with_queues
+    def receive_messages(self, batch_size: int = 1):
+        """
+        Receive messages from a queue.
+        """
+        queue = self.get_or_create_queue(self.input_queue)
+        return queue.receive_messages(MaxNumberOfMessages=min(batch_size, SQS_MAX_BATCH_SIZE))

def find_queue_by_name(self, queue_name: str) -> boto3.resources.base.ServiceResource:
"""
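The slicing in delete_messages_from_queue reflects a hard SQS limit: DeleteMessageBatch accepts at most 10 entries per request (the same cap of 10 applies to MaxNumberOfMessages on receive, which is what SQS_MAX_BATCH_SIZE guards in receive_messages). A hedged sketch of the same batching, extended with the partial-failure check that a batch call allows; the queue object and messages are assumed to be boto3 SQS resources:

```python
def delete_in_batches(queue, messages):
    # DeleteMessageBatch takes at most 10 entries, so slice into groups of 10.
    for i in range(0, len(messages), 10):
        batch = messages[i:i + 10]
        # Entry Ids only need to be unique within a single request.
        entries = [
            {"Id": str(idx), "ReceiptHandle": m.receipt_handle}
            for idx, m in enumerate(batch)
        ]
        response = queue.delete_messages(Entries=entries)
        # A batch delete can partially fail; failures are reported per entry.
        for failure in response.get("Failed", []):
            print("delete failed:", failure["Id"], failure.get("Code"))
```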
6 changes: 3 additions & 3 deletions lib/queue/worker.py
@@ -36,9 +36,9 @@ def __init__(self, input_queue_name: str, output_queue_name: str = None, dlq_que
self.dlq_queue_name = dlq_queue_name or Queue.get_dead_letter_queue_name()
q_suffix = f"_output" + Queue.get_queue_suffix()
dlq_suffix = f"_dlq" + Queue.get_queue_suffix()
-        self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), q_suffix)
-        self.output_queues = self.get_or_create_queues(self.output_queue_name)
-        self.dead_letter_queues = self.get_or_create_queues(self.dlq_queue_name)
+        self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queue(input_queue_name), q_suffix)
+        self.output_queues = self.get_or_create_queue(self.output_queue_name)
+        self.dead_letter_queues = self.get_or_create_queue(self.dlq_queue_name)
self.all_queues = self.store_queue_map([item for row in [self.input_queues, self.output_queues, self.dead_letter_queues] for item in row])
logger.info(f"Worker listening to queues of {self.all_queues}")

30 changes: 18 additions & 12 deletions test/lib/queue/test_queue.py
@@ -1,3 +1,4 @@
+import pdb
import json
import os
import unittest
@@ -49,15 +50,24 @@ def setUp(self, mock_log_execution_time, mock_get_env_setting, mock_boto_resource):
self.mock_dlq_queue = MagicMock()
self.mock_dlq_queue.url = f"http://queue/{self.queue_name_dlq}"
self.mock_dlq_queue.attributes = {"QueueArn": f"queue:{self.queue_name_dlq}"}

+        # Set up side effects for get_queue_by_name
self.mock_sqs_resource.get_queue_by_name.side_effect = lambda QueueName: {
self.queue_name_input: self.mock_input_queue,
self.queue_name_output: self.mock_output_queue,
self.queue_name_dlq: self.mock_dlq_queue
}.get(QueueName)

mock_boto_resource.return_value = self.mock_sqs_resource

-        # Initialize the QueueWorker instance
-        self.queue = QueueWorker(self.queue_name_input, self.queue_name_output, self.queue_name_dlq)
+        # Initialize QueueWorker with mocked get_sqs method
+        with patch.object(QueueWorker, 'get_sqs', return_value=self.mock_sqs_resource):
+            self.queue = QueueWorker(self.queue_name_input, self.queue_name_output, self.queue_name_dlq)
+
+        # Ensure `self.all_queues` is populated for `find_queue_by_name`
+        self.queue.all_queues = self.queue.store_queue_map([
+            self.mock_input_queue, self.mock_output_queue, self.mock_dlq_queue
+        ])

def test_get_output_queue_name(self):
self.assertEqual(self.queue.get_output_queue_name().replace(".fifo", ""), (self.queue.get_input_queue_name()+'_output').replace(".fifo", ""))
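A standalone sketch of the patch.object pattern the new setUp relies on (the FakeWorker class here is hypothetical): the method is replaced only for the duration of the with block, which is exactly long enough for the constructor to capture the mocked return value.

```python
from unittest.mock import MagicMock, patch

class FakeWorker:
    def __init__(self):
        self.sqs = self.get_sqs()  # would normally build a real boto3 resource

    def get_sqs(self):
        raise RuntimeError("real network call")

mock_sqs = MagicMock()
with patch.object(FakeWorker, "get_sqs", return_value=mock_sqs):
    worker = FakeWorker()  # the constructor sees the patched method
assert worker.sqs is mock_sqs  # the captured reference outlives the patch
```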
@@ -98,18 +108,15 @@ def test_process(self):
self.queue.receive_messages.assert_called_once_with(1)

def test_receive_messages(self):
+        self.queue.input_queue = self.queue_name_input
mock_queue1 = MagicMock()
mock_queue1.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}}))]

mock_queue2 = MagicMock()
mock_queue2.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 2, "callback_url": "http://example.com", "text": "This is another test"}}))]
self.queue.input_queues = [mock_queue1, mock_queue2]
mock_queue1.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}})), FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 2, "callback_url": "http://example.com", "text": "This is another test"}}))]
self.queue.get_or_create_queue = MagicMock(return_value=mock_queue1)
received_messages = self.queue.receive_messages(5)

# Check if the right number of messages were received and the content is correct
self.assertEqual(len(received_messages), 2)
self.assertIn("a test", received_messages[0][0].body)
self.assertIn("another test", received_messages[1][0].body)
self.assertIn("a test", received_messages[0].body)
self.assertIn("another test", received_messages[1].body)

def test_restrict_queues_by_suffix(self):
queues = [
@@ -149,7 +156,6 @@ def test_delete_messages_from_queue(self, mock_logger):
self.queue.delete_messages_from_queue(self.mock_input_queue, mock_messages)
# Check if the correct number of calls to delete_messages were made
self.mock_input_queue.delete_messages.assert_called_once()
-        mock_logger.assert_called_with(f"Deleting message: {mock_messages[-1]}")

def test_push_message(self):
message_to_push = schemas.parse_input_message({"body": {"id": 1, "content_hash": None, "callback_url": "http://example.com", "text": "This is a test"}, "model_name": "mean_tokens__Model"})