From dd26eb703bc23642a5bb752f1796e1cff3183ea4 Mon Sep 17 00:00:00 2001 From: Devin Gaffney Date: Tue, 29 Aug 2023 06:16:31 -0700 Subject: [PATCH] CV2-3551 add local queue consumption and re-work a ton of the startup flow to accommodate --- Makefile | 9 ++- lib/http.py | 6 +- lib/queue/processor.py | 55 ++++++++++++++ lib/queue/queue.py | 73 +++---------------- lib/queue/worker.py | 56 ++++++++++++++ lib/schemas.py | 3 +- requirements.txt | 2 + run_processor.py | 11 +++ run.py => run_worker.py | 4 +- ...hcheck_and_model_engine.sh => start_all.sh | 3 +- test/lib/model/test_audio.py | 8 +- test/lib/queue/fake_sqs_message.py | 5 ++ test/lib/queue/test_processor.py | 57 +++++++++++++++ test/lib/queue/test_queue.py | 22 ++---- test/lib/test_http.py | 6 +- 15 files changed, 225 insertions(+), 95 deletions(-) create mode 100644 lib/queue/processor.py create mode 100644 lib/queue/worker.py create mode 100644 run_processor.py rename run.py => run_worker.py (75%) rename start_healthcheck_and_model_engine.sh => start_all.sh (81%) create mode 100644 test/lib/queue/fake_sqs_message.py create mode 100644 test/lib/queue/test_processor.py diff --git a/Makefile b/Makefile index 2b6ba87..69c6d8d 100644 --- a/Makefile +++ b/Makefile @@ -1,13 +1,16 @@ .PHONY: run run_http run_worker run_test run: - ./start_healthcheck_and_model_engine.sh + ./start_all.sh run_http: uvicorn main:app --host 0.0.0.0 --reload run_worker: - python run.py + python run_worker.py + +run_processor: + python run_processor.py run_test: - python -m unittest discover . \ No newline at end of file + python -m pytest test \ No newline at end of file diff --git a/lib/http.py b/lib/http.py index 26bf8e3..b250301 100644 --- a/lib/http.py +++ b/lib/http.py @@ -5,7 +5,7 @@ from httpx import HTTPStatusError from fastapi import FastAPI, Request from pydantic import BaseModel -from lib.queue.queue import Queue +from lib.queue.worker import QueueWorker from lib.logger import logger from lib import schemas @@ -29,8 +29,8 @@ async def post_url(url: str, params: dict) -> Dict[str, Any]: @app.post("/fingerprint_item/{fingerprinter}") def fingerprint_item(fingerprinter: str, message: Dict[str, Any]): - queue = Queue.create(fingerprinter) - queue.push_message(fingerprinter, schemas.Message(body=message, input_queue=queue.input_queue_name, output_queue=queue.output_queue_name, start_time=str(datetime.datetime.now()))) + queue = QueueWorker.create(fingerprinter) + queue.push_message(fingerprinter, schemas.Message(body=message)) return {"message": "Message pushed successfully"} @app.post("/trigger_callback") diff --git a/lib/queue/processor.py b/lib/queue/processor.py new file mode 100644 index 0000000..1b807c3 --- /dev/null +++ b/lib/queue/processor.py @@ -0,0 +1,55 @@ +from typing import List +import json + +import requests + +from lib import schemas +from lib.logger import logger +from lib.helpers import get_setting +from lib.queue.queue import Queue +class QueueProcessor(Queue): + @classmethod + def create(cls, input_queue_name: str = None, batch_size: int = 10): + """ + Instantiate a queue. Must pass input_queue_name, output_queue_name, and batch_size. + Pulls settings and then inits instance. + """ + input_queue_name = get_setting(input_queue_name, "MODEL_NAME").replace(".", "__") + logger.info(f"Starting queue with: ('{input_queue_name}', {batch_size})") + return QueueProcessor(input_queue_name, batch_size) + + def __init__(self, input_queue_name: str, output_queue_name: str = None, batch_size: int = 1): + """ + Start a specific queue - must pass input_queue_name - optionally pass output_queue_name, batch_size. + """ + super().__init__() + self.input_queue_name = input_queue_name + self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), "_output") + self.all_queues = self.store_queue_map(self.input_queues) + self.batch_size = batch_size + + def send_callbacks(self) -> List[schemas.Message]: + """ + Main routine. Given a model, in a loop, read tasks from input_queue_name at batch_size depth, + pass messages to model to respond (i.e. fingerprint) them, then pass responses to output queue. + If failures happen at any point, resend failed messages to input queue. + """ + messages_with_queues = self.receive_messages(self.batch_size) + if messages_with_queues: + logger.debug(f"About to respond to: ({messages_with_queues})") + bodies = [schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues] + for body in bodies: + self.send_callback(body) + self.delete_messages(messages_with_queues) + + + def send_callback(self, body): + """ + Rescue against failures when attempting to respond (i.e. fingerprint) from models. + Return responses if no failure. + """ + try: + callback_url = body.get("callback_url") + requests.post(callback_url, json=body) + except Exception as e: + logger.error(f"Callback fail! Failed with {e} on {callback_url} with body of {body}") \ No newline at end of file diff --git a/lib/queue/queue.py b/lib/queue/queue.py index 7e5ea0f..82f5fd6 100644 --- a/lib/queue/queue.py +++ b/lib/queue/queue.py @@ -1,48 +1,28 @@ import json -from typing import Any, List, Dict, Tuple, Union +from typing import List, Dict, Tuple import os import boto3 import botocore -from lib.helpers import get_class, get_setting, get_environment_setting -from lib.model.model import Model +from lib.helpers import get_environment_setting from lib.logger import logger from lib import schemas +SQS_MAX_BATCH_SIZE = 10 class Queue: - @classmethod - def create(cls, input_queue_name: str = None, output_queue_name: str = None, batch_size: int = 10): + def __init__(self): """ - Instantiate a queue. Must pass input_queue_name, output_queue_name, and batch_size. - Pulls settings and then inits instance. - """ - input_queue_name = get_setting(input_queue_name, "MODEL_NAME").replace(".", "__") - output_queue_name = output_queue_name or f"{input_queue_name}_output" - logger.info(f"Starting queue with: ('{input_queue_name}', '{output_queue_name}', {batch_size})") - return Queue(input_queue_name, output_queue_name, batch_size) - - def __init__(self, input_queue_name: str, output_queue_name: str = None, batch_size: int = 1): - """ - Start a specific queue - must pass input_queue_name - optionally pass output_queue_name, batch_size. + Start a specific queue - must pass input_queue_name. """ self.sqs = self.get_sqs() - self.input_queue_name = input_queue_name - self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), "_output") - if output_queue_name: - self.output_queue_name = self.get_output_queue_name(input_queue_name, output_queue_name) - self.output_queues = self.get_or_create_queues(output_queue_name) - self.all_queues = self.store_queue_map() - self.batch_size = batch_size - - def store_queue_map(self) -> Dict[str, boto3.resources.base.ServiceResource]: + def store_queue_map(self, all_queues: List[boto3.resources.base.ServiceResource]) -> Dict[str, boto3.resources.base.ServiceResource]: """ Store a quick lookup so that we dont loop through this over and over in other places. """ queue_map = {} - for group in [self.input_queues, self.output_queues]: - for q in group: - queue_map[self.queue_name(q)] = q + for queue in all_queues: + queue_map[self.queue_name(queue)] = queue return queue_map def queue_name(self, queue: boto3.resources.base.ServiceResource) -> str: @@ -101,7 +81,7 @@ def get_output_queue_name(self, input_queue_name: str, output_queue_name: str = If output_queue_name was empty or None, set name for queue. """ if not output_queue_name: - output_queue_name = f'{input_queue_name}-output' + output_queue_name = f'{input_queue_name}_output' return output_queue_name def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> Dict[boto3.resources.base.ServiceResource, List[schemas.Message]]: @@ -115,7 +95,7 @@ def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto queue_to_messages[queue].append(message) return queue_to_messages - def delete_messages(self, messages_with_queues: List[Tuple[Dict[str, Any], boto3.resources.base.ServiceResource]]) -> None: + def delete_messages(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> None: """ Delete messages as we process them so other processes don't pick them up. SQS deals in max batches of 10, so break up messages into groups of 10 @@ -141,31 +121,6 @@ def delete_messages_from_queue(self, queue: boto3.resources.base.ServiceResource entries.append(entry) queue.delete_messages(Entries=entries) - def safely_respond(self, model: Model) -> List[schemas.Message]: - """ - Rescue against failures when attempting to respond (i.e. fingerprint) from models. - Return responses if no failure. - """ - messages_with_queues = self.receive_messages(model.BATCH_SIZE) - responses = [] - if messages_with_queues: - logger.debug(f"About to respond to: ({messages_with_queues})") - responses = model.respond([schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues]) - self.delete_messages(messages_with_queues) - return responses - - def fingerprint(self, model: Model): - """ - Main routine. Given a model, in a loop, read tasks from input_queue_name at batch_size depth, - pass messages to model to respond (i.e. fingerprint) them, then pass responses to output queue. - If failures happen at any point, resend failed messages to input queue. - """ - responses = self.safely_respond(model) - if responses: - for response in responses: - logger.info(f"Processing message of: ({response})") - self.return_response(response) - def receive_messages(self, batch_size: int = 1) -> List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]: """ Pull batch_size messages from input queue. @@ -175,19 +130,13 @@ def receive_messages(self, batch_size: int = 1) -> List[Tuple[schemas.Message, b for queue in self.input_queues: if batch_size <= 0: break - batch_messages = queue.receive_messages(MaxNumberOfMessages=min(batch_size, self.batch_size)) + batch_messages = queue.receive_messages(MaxNumberOfMessages=min(batch_size, SQS_MAX_BATCH_SIZE)) for message in batch_messages: if batch_size > 0: messages_with_queues.append((message, queue)) batch_size -= 1 return messages_with_queues - def return_response(self, message: schemas.Message): - """ - Send message to output queue - """ - return self.push_message(self.output_queue_name, message) - def find_queue_by_name(self, queue_name: str) -> boto3.resources.base.ServiceResource: """ Search through queues to find the right one diff --git a/lib/queue/worker.py b/lib/queue/worker.py new file mode 100644 index 0000000..8bbc729 --- /dev/null +++ b/lib/queue/worker.py @@ -0,0 +1,56 @@ +import json +from typing import List +from lib import schemas +from lib.logger import logger +from lib.queue.queue import Queue +from lib.model.model import Model +from lib.helpers import get_setting +class QueueWorker(Queue): + @classmethod + def create(cls, input_queue_name: str = None): + """ + Instantiate a queue worker. Must pass input_queue_name. + Pulls settings and then inits instance. + """ + input_queue_name = get_setting(input_queue_name, "MODEL_NAME").replace(".", "__") + output_queue_name = f"{input_queue_name}_output" + logger.info(f"Starting queue with: ('{input_queue_name}', '{output_queue_name}')") + return QueueWorker(input_queue_name, output_queue_name) + + def __init__(self, input_queue_name: str, output_queue_name: str = None): + """ + Start a specific queue - must pass input_queue_name - optionally pass output_queue_name. + """ + super().__init__() + self.input_queue_name = input_queue_name + self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), "_output") + if output_queue_name: + self.output_queue_name = self.get_output_queue_name(input_queue_name, output_queue_name) + self.output_queues = self.get_or_create_queues(output_queue_name) + self.all_queues = self.store_queue_map([item for row in [self.input_queues, self.output_queues] for item in row]) + + def fingerprint(self, model: Model): + """ + Main routine. Given a model, in a loop, read tasks from input_queue_name, + pass messages to model to respond (i.e. fingerprint) them, then pass responses to output queue. + If failures happen at any point, resend failed messages to input queue. + """ + responses = self.safely_respond(model) + if responses: + for response in responses: + logger.info(f"Processing message of: ({response})") + self.return_response(response) + + def safely_respond(self, model: Model) -> List[schemas.Message]: + """ + Rescue against failures when attempting to respond (i.e. fingerprint) from models. + Return responses if no failure. + """ + messages_with_queues = self.receive_messages(model.BATCH_SIZE) + responses = [] + if messages_with_queues: + logger.debug(f"About to respond to: ({messages_with_queues})") + responses = model.respond([schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues]) + self.delete_messages(messages_with_queues) + return responses + diff --git a/lib/schemas.py b/lib/schemas.py index 3b6ed34..30ca05e 100644 --- a/lib/schemas.py +++ b/lib/schemas.py @@ -50,5 +50,4 @@ class ImageOutput(BaseModel): class Message(BaseModel): body: Union[TextInput, VideoInput, AudioInput, ImageInput] - response: Any - + response: Any \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d18ca8c..867696e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,5 @@ uvicorn[standard]==0.19.0 httpx==0.23.1 huggingface-hub==0.11.0 fasttext==0.9.2 +requests==2.31.0 +pytest==7.4.0 \ No newline at end of file diff --git a/run_processor.py b/run_processor.py new file mode 100644 index 0000000..960257c --- /dev/null +++ b/run_processor.py @@ -0,0 +1,11 @@ +import time +import os +import importlib +from lib.queue.processor import QueueProcessor +from lib.model.model import Model +from lib.logger import logger +queue = QueueProcessor.create() + +logger.info("Beginning callback loop...") +while True: + queue.send_callbacks() \ No newline at end of file diff --git a/run.py b/run_worker.py similarity index 75% rename from run.py rename to run_worker.py index e52d7d2..6376fbb 100644 --- a/run.py +++ b/run_worker.py @@ -1,10 +1,10 @@ import time import os import importlib -from lib.queue.queue import Queue +from lib.queue.worker import QueueWorker from lib.model.model import Model from lib.logger import logger -queue = Queue.create() +queue = QueueWorker.create() model = Model.create() diff --git a/start_healthcheck_and_model_engine.sh b/start_all.sh similarity index 81% rename from start_healthcheck_and_model_engine.sh rename to start_all.sh index f3a2022..f3c9197 100755 --- a/start_healthcheck_and_model_engine.sh +++ b/start_all.sh @@ -5,4 +5,5 @@ uvicorn main:app --host 0.0.0.0 --reload & # Start the second process in the foreground # This will ensure the script won't exit until this process does -python run.py \ No newline at end of file +python run_worker.py & +python run_processor.py \ No newline at end of file diff --git a/test/lib/model/test_audio.py b/test/lib/model/test_audio.py index 26f4443..9f226df 100644 --- a/test/lib/model/test_audio.py +++ b/test/lib/model/test_audio.py @@ -15,13 +15,13 @@ def setUp(self): @patch('urllib.request.Request') def test_fingerprint_audio_success(self, mock_request, mock_urlopen): mock_request.return_value = mock_request - + # Use the `with` statement for proper file handling with open("data/test-audio.mp3", 'rb') as f: contents = f.read() mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents)) - + audio = schemas.Message(body=schemas.AudioInput(id="123", callback_url="http://example.com/callback", url="https://example.com/audio.mp3")) result = self.audio_model.fingerprint(audio) mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'}) @@ -36,13 +36,13 @@ def test_fingerprint_audio_failure(self, mock_decode_fingerprint, mock_fingerpri mock_request, mock_urlopen): mock_fingerprint_file.side_effect = FingerprintGenerationError("Failed to generate fingerprint") mock_request.return_value = mock_request - + # Use the `with` statement for proper file handling with open("data/test-audio.mp3", 'rb') as f: contents = f.read() mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents)) - + audio = schemas.Message(body=schemas.AudioInput(id="123", callback_url="http://example.com/callback", url="https://example.com/audio.mp3")) result = self.audio_model.fingerprint(audio) mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'}) diff --git a/test/lib/queue/fake_sqs_message.py b/test/lib/queue/fake_sqs_message.py new file mode 100644 index 0000000..490cdaf --- /dev/null +++ b/test/lib/queue/fake_sqs_message.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel +class FakeSQSMessage(BaseModel): + body: str + receipt_handle: str + diff --git a/test/lib/queue/test_processor.py b/test/lib/queue/test_processor.py new file mode 100644 index 0000000..bd0299e --- /dev/null +++ b/test/lib/queue/test_processor.py @@ -0,0 +1,57 @@ +import unittest +from unittest.mock import MagicMock, patch +import json + +from lib.queue.processor import QueueProcessor +from lib import schemas +from test.lib.queue.fake_sqs_message import FakeSQSMessage +class TestQueueProcessor(unittest.TestCase): + + @patch('lib.queue.queue.boto3.resource') + @patch('lib.helpers.get_environment_setting', return_value='us-west-1') + def setUp(self, mock_get_env_setting, mock_boto_resource): + self.queue_name_input = 'mean_tokens__Model' + + # Mock the SQS resource and the queues + self.mock_sqs_resource = MagicMock() + self.mock_input_queue = MagicMock() + self.mock_input_queue.url = "http://queue/mean_tokens__Model" + self.mock_sqs_resource.queues.filter.return_value = [self.mock_input_queue] + mock_boto_resource.return_value = self.mock_sqs_resource + + # Initialize the QueueProcessor instance + self.queue_processor = QueueProcessor(self.queue_name_input, batch_size=2) + + def test_send_callbacks(self): + # Mocking necessary methods and creating fake data + self.queue_processor.receive_messages = MagicMock( + return_value=[(FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"callback_url": "http://example.com", "text": "This is a test", "id": 1}, "response": [1,2,3]})), self.mock_input_queue)] + ) + self.queue_processor.send_callback = MagicMock(return_value=None) + self.queue_processor.delete_messages = MagicMock(return_value=None) + + responses = self.queue_processor.send_callbacks() + + self.queue_processor.receive_messages.assert_called_once_with(2) + self.queue_processor.send_callback.assert_called() + self.queue_processor.delete_messages.assert_called() + + @patch('lib.queue.processor.requests.post') + def test_send_callback(self, mock_post): + message_body = {"callback_url": "http://example.com", "text": "This is a test"} + self.queue_processor.send_callback(message_body) + + mock_post.assert_called_once_with("http://example.com", json=message_body) + + @patch('lib.queue.processor.requests.post') + def test_send_callback_failure(self, mock_post): + mock_post.side_effect = Exception("Request Failed!") + message_body = {"callback_url": "http://example.com", "text": "This is a test"} + + with self.assertLogs(level='ERROR') as cm: + self.queue_processor.send_callback(message_body) + + self.assertIn("Callback fail! Failed with Request Failed! on http://example.com with body of", cm.output[0]) + +if __name__ == '__main__': + unittest.main() diff --git a/test/lib/queue/test_queue.py b/test/lib/queue/test_queue.py index cb33902..3d317c2 100644 --- a/test/lib/queue/test_queue.py +++ b/test/lib/queue/test_queue.py @@ -2,21 +2,14 @@ import os import unittest from unittest.mock import MagicMock, patch -from pydantic import BaseModel import numpy as np from lib.model.generic_transformer import GenericTransformerModel -from lib.queue.queue import Queue +from lib.queue.worker import QueueWorker from lib import schemas -class FakeSQSMessage(BaseModel): - body: str - receipt_handle: str +from test.lib.queue.fake_sqs_message import FakeSQSMessage -class TestQueue(unittest.TestCase): - # def overwrite_restrict_queues_by_suffix(queues, suffix): - # return [MagicMock()] - # - # @patch('lib.queue.queue.Queue.restrict_queues_by_suffix', side_effect=overwrite_restrict_queues_by_suffix) +class TestQueueWorker(unittest.TestCase): @patch('lib.queue.queue.boto3.resource') @patch('lib.helpers.get_environment_setting', return_value='us-west-1') def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queues_by_suffix): @@ -24,7 +17,6 @@ def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queue self.mock_model = MagicMock() self.queue_name_input = 'mean_tokens__Model' self.queue_name_output = 'mean_tokens__Model_output' - self.batch_size = 5 # Mock the SQS resource and the queues self.mock_sqs_resource = MagicMock() @@ -36,10 +28,10 @@ def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queue mock_boto_resource.return_value = self.mock_sqs_resource # Initialize the SQSQueue instance - self.queue = Queue(self.queue_name_input, self.queue_name_output, self.batch_size) + self.queue = QueueWorker(self.queue_name_input, self.queue_name_output) def test_get_output_queue_name(self): - self.assertEqual(self.queue.get_output_queue_name('test'), 'test-output') + self.assertEqual(self.queue.get_output_queue_name('test'), 'test_output') self.assertEqual(self.queue.get_output_queue_name('test', 'new-output'), 'new-output') def test_fingerprint(self): @@ -58,7 +50,7 @@ def test_receive_messages(self): mock_queue2 = MagicMock() mock_queue2.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 2, "callback_url": "http://example.com", "text": "This is another test"}}))] self.queue.input_queues = [mock_queue1, mock_queue2] - received_messages = self.queue.receive_messages(self.batch_size) + received_messages = self.queue.receive_messages(5) # Check if the right number of messages were received and the content is correct self.assertEqual(len(received_messages), 2) @@ -105,4 +97,4 @@ def test_push_message(self): self.assertEqual(returned_message, message_to_push) if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/lib/test_http.py b/test/lib/test_http.py index a4ec4bd..55c3d6a 100644 --- a/test/lib/test_http.py +++ b/test/lib/test_http.py @@ -2,14 +2,14 @@ import unittest from unittest.mock import patch from lib.http import app -from lib.queue.queue import Queue +from lib.queue.worker import QueueWorker class TestFingerprintItem(unittest.TestCase): def setUp(self): self.client = TestClient(app) - @patch.object(Queue, 'create') - @patch.object(Queue, 'push_message') + @patch.object(QueueWorker, 'create') + @patch.object(QueueWorker, 'push_message') def test_fingerprint_item(self, mock_push_message, mock_create): mock_queue = mock_create.return_value mock_queue.input_queue_name = "input_queue"