diff --git a/local.env b/.env_file
similarity index 52%
rename from local.env
rename to .env_file
index cb4e7c6..b9c2b3e 100644
--- a/local.env
+++ b/.env_file
@@ -1,4 +1,6 @@
+PRESTO_PORT=8000
 DEPLOY_ENV=local
-MODEL_NAME=mean_tokens.Model
+# MODEL_NAME=mean_tokens.Model
+MODEL_NAME=audio.Model
 AWS_ACCESS_KEY_ID=SOMETHING
 AWS_SECRET_ACCESS_KEY=OTHERTHING
diff --git a/.env_file.test b/.env_file.test
new file mode 100644
index 0000000..faeaf7b
--- /dev/null
+++ b/.env_file.test
@@ -0,0 +1,6 @@
+PRESTO_PORT=8000
+DEPLOY_ENV=local
+# MODEL_NAME=mean_tokens.Model
+MODEL_NAME=audio.Model
+AWS_ACCESS_KEY_ID=SOMETHING
+AWS_SECRET_ACCESS_KEY=OTHERTHING
diff --git a/Dockerfile b/Dockerfile
index 40972b2..9d27bf9 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,10 @@
 FROM python:3.9
-EXPOSE 8000
+
+ARG PRESTO_PORT
+
+ENV PRESTO_PORT=${PRESTO_PORT}
+EXPOSE ${PRESTO_PORT}
+
 WORKDIR /app
 
 ENV DEBIAN_FRONTEND=noninteractive
diff --git a/Makefile b/Makefile
index 2b6ba87..69c6d8d 100644
--- a/Makefile
+++ b/Makefile
@@ -1,13 +1,16 @@
-.PHONY: run run_http run_worker run_test
+.PHONY: run run_http run_worker run_processor run_test
 
 run:
-	./start_healthcheck_and_model_engine.sh
+	./start_all.sh
 
 run_http:
 	uvicorn main:app --host 0.0.0.0 --reload
 
 run_worker:
-	python run.py
+	python run_worker.py
+
+run_processor:
+	python run_processor.py
 
 run_test:
-	python -m unittest discover .
\ No newline at end of file
+	python -m pytest test
\ No newline at end of file
diff --git a/README.md b/README.md
index 6ff5eb1..fb9cb9a 100644
--- a/README.md
+++ b/README.md
@@ -168,12 +168,12 @@ Output Message:
 ```
 
 ### Endpoints
-#### /fingerprint_item/{fingerprinter}
+#### /process_item/{process_name}
 This endpoint pushes a message into a queue. It's an async operation, meaning the server will respond before the operation is complete. This is useful when working with slow or unreliable external resources.
 
 Request
 ```
-curl -X POST "http://127.0.0.1:8000/fingerprint_item/sample_fingerprinter" -H "accept: application/json" -H "Content-Type: application/json" -d "{\"message_key\":\"message_value\"}"
+curl -X POST "http://127.0.0.1:8000/process_item/mean_tokens__Model" -H "accept: application/json" -H "Content-Type: application/json" -d "{\"message_key\":\"message_value\"}"
 ```
-Replace sample_fingerprinter with the name of your fingerprinter, and message_key and message_value with your actual message data.
+Replace mean_tokens__Model with the name of the model you want to queue work for, and message_key and message_value with your actual message data.
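For orientation, the same request from Python — a minimal sketch, not part of the patch, assuming the service listens on PRESTO_PORT=8000 and the body follows the GenericInput schema introduced in lib/schemas.py further down:

```python
import requests

# Hypothetical payload: GenericInput (see lib/schemas.py below) requires an id and a
# callback_url; text, url and raw are optional and depend on the model being called.
payload = {
    "id": "123",
    "callback_url": "http://example.com/callback",
    "text": "This is a test",
}

# "mean_tokens__Model" is the queue-safe form of MODEL_NAME=mean_tokens.Model
resp = requests.post("http://127.0.0.1:8000/process_item/mean_tokens__Model", json=payload)
print(resp.json())  # {"message": "Message pushed successfully", "queue": "mean_tokens__Model", "body": {...}}
```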
diff --git a/docker-compose.yml b/docker-compose.yml
index 1ae1eaf..dcf5413 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -5,13 +5,18 @@ services:
     platform: linux/amd64
-    build: .
+    build:
+      context: .
+      args:
+        - PRESTO_PORT=${PRESTO_PORT}
     env_file:
-      - ./local.env
+      - ./.env_file
     depends_on:
       - elasticmq
     links:
       - elasticmq
     volumes:
       - ./:/app
+    ports:
+      - "${PRESTO_PORT}:${PRESTO_PORT}"
   elasticmq:
     image: softwaremill/elasticmq
     hostname: presto-elasticmq
diff --git a/lib/http.py b/lib/http.py
index 26bf8e3..65070de 100644
--- a/lib/http.py
+++ b/lib/http.py
@@ -5,9 +5,10 @@ from httpx import HTTPStatusError
 from fastapi import FastAPI, Request
 from pydantic import BaseModel
 
-from lib.queue.queue import Queue
+from lib.queue.worker import QueueWorker
 from lib.logger import logger
 from lib import schemas
+from lib.sentry import sentry_sdk
 
 app = FastAPI()
 
@@ -27,14 +28,15 @@ async def post_url(url: str, params: dict) -> Dict[str, Any]:
     except HTTPStatusError:
         return {"error": f"HTTP Error on Attempt to call {url} with {params}"}
 
-@app.post("/fingerprint_item/{fingerprinter}")
-def fingerprint_item(fingerprinter: str, message: Dict[str, Any]):
-    queue = Queue.create(fingerprinter)
-    queue.push_message(fingerprinter, schemas.Message(body=message, input_queue=queue.input_queue_name, output_queue=queue.output_queue_name, start_time=str(datetime.datetime.now())))
-    return {"message": "Message pushed successfully"}
+@app.post("/process_item/{process_name}")
+def process_item(process_name: str, message: Dict[str, Any]):
+    logger.info(message)
+    queue = QueueWorker.create(process_name)
+    queue.push_message(process_name, schemas.Message(body=message))
+    return {"message": "Message pushed successfully", "queue": process_name, "body": message}
 
 @app.post("/trigger_callback")
-async def fingerprint_item(message: Dict[str, Any]):
+async def trigger_callback(message: Dict[str, Any]):
     url = message.get("callback_url")
     if url:
         response = await post_url(url, message)
@@ -46,5 +48,10 @@ async def fingerprint_item(message: Dict[str, Any]):
     return {"message": "No Message Callback, Passing"}
 
 @app.get("/ping")
-def fingerprint_item():
+def ping():
     return {"pong": 1}
+
+@app.post("/echo")
+async def echo(message: Dict[str, Any]):
+    logger.info(f"About to echo message of {message}")
+    return {"echo": message}
\ No newline at end of file
diff --git a/lib/model/audio.py b/lib/model/audio.py
index f6fccec..fea03cb 100644
--- a/lib/model/audio.py
+++ b/lib/model/audio.py
@@ -20,7 +20,7 @@ def audio_hasher(self, filename: str) -> List[int]:
         except acoustid.FingerprintGenerationError:
             return []
 
-    def fingerprint(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]:
+    def process(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]:
         temp_file_name = self.get_tempfile_for_url(audio.body.url)
         try:
             hash_value = self.audio_hasher(temp_file_name)
diff --git a/lib/model/generic_transformer.py b/lib/model/generic_transformer.py
index 9b40769..56ded7c 100644
--- a/lib/model/generic_transformer.py
+++ b/lib/model/generic_transformer.py
@@ -1,7 +1,7 @@
 import os
 from typing import Union, Dict, List
 from sentence_transformers import SentenceTransformer
-
+from lib.logger import logger
 from lib.model.model import Model
 from lib import schemas
 
@@ -21,7 +21,7 @@ def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[s
         """
         if not isinstance(docs, list):
             docs = [docs]
-        print(docs)
+        logger.info(docs)
         vectorizable_texts = [e.body.text for e in docs]
         vectorized = self.vectorize(vectorizable_texts)
         for doc, vector in zip(docs, vectorized):
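The QueueProcessor added below POSTs each finished Message back to its callback_url as JSON (requests.post(callback_url, json=message.dict())). A throwaway receiver for inspecting those callbacks locally might look like this — a sketch with an assumed path and port, not part of the patch:

```python
from typing import Any, Dict

from fastapi import FastAPI

app = FastAPI()

@app.post("/callback")  # hypothetical path; use whatever you register as callback_url
async def receive_callback(message: Dict[str, Any]):
    # The processor posts the whole Message dict: {"body": {...}, "response": ...}
    print("callback response:", message.get("response"))
    return {"ok": True}

# Run with, e.g.: uvicorn callback_receiver:app --port 9000
# and point callback_url at an address the presto container can reach.
```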
diff --git a/lib/model/image.py b/lib/model/image.py
index 33acefb..41cdf4c 100644
--- a/lib/model/image.py
+++ b/lib/model/image.py
@@ -30,7 +30,7 @@ def get_iobytes_for_image(self, image: schemas.Message) -> io.BytesIO:
             ).read()
         )
 
-    def fingerprint(self, image: schemas.Message) -> schemas.ImageOutput:
+    def process(self, image: schemas.Message) -> schemas.ImageOutput:
         """
         Generic function for returning the actual response.
         """
diff --git a/lib/model/model.py b/lib/model/model.py
index 9140fbe..0f9d938 100644
--- a/lib/model/model.py
+++ b/lib/model/model.py
@@ -33,7 +33,7 @@ def get_tempfile(self) -> Any:
         """
         return tempfile.NamedTemporaryFile()
 
-    def fingerprint(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
+    def process(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
         return []
 
     def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
@@ -43,7 +43,7 @@ def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> Li
         if not isinstance(messages, list):
             messages = [messages]
         for message in messages:
-            message.response = self.fingerprint(message)
+            message.response = self.process(message)
         return messages
 
     @classmethod
diff --git a/lib/model/video.py b/lib/model/video.py
index 9507988..8420041 100644
--- a/lib/model/video.py
+++ b/lib/model/video.py
@@ -41,7 +41,7 @@ def tmk_bucket(self) -> str:
         """
         return "presto_tmk_videos"
 
-    def fingerprint(self, video: schemas.Message) -> schemas.VideoOutput:
+    def process(self, video: schemas.Message) -> schemas.VideoOutput:
         """
         Main fingerprinting routine - download video to disk, get short hash,
         then calculate larger TMK hash and upload that to S3.
diff --git a/lib/queue/processor.py b/lib/queue/processor.py
new file mode 100644
index 0000000..089900c
--- /dev/null
+++ b/lib/queue/processor.py
@@ -0,0 +1,57 @@
+from typing import List
+import json
+
+import requests
+
+from lib import schemas
+from lib.logger import logger
+from lib.helpers import get_setting
+from lib.queue.queue import Queue
+class QueueProcessor(Queue):
+    @classmethod
+    def create(cls, input_queue_name: str = None, batch_size: int = 10):
+        """
+        Instantiate a queue processor. Optionally pass input_queue_name and batch_size.
+        Pulls settings and then inits instance.
+        """
+        input_queue_name = get_setting(input_queue_name, "MODEL_NAME").replace(".", "__")
+        logger.info(f"Starting queue with: ('{input_queue_name}', {batch_size})")
+        return QueueProcessor(input_queue_name, batch_size=batch_size)
+
+    def __init__(self, input_queue_name: str, batch_size: int = 1):
+        """
+        Start a specific queue processor - must pass input_queue_name - optionally pass batch_size.
+        """
+        super().__init__()
+        self.input_queue_name = input_queue_name
+        self.input_queues = self.restrict_queues_to_suffix(self.get_or_create_queues(input_queue_name+"_output"), "_output")
+        self.all_queues = self.store_queue_map(self.input_queues)
+        logger.info(f"Processor listening to queues of {self.all_queues}")
+        self.batch_size = batch_size
+
+    def send_callbacks(self) -> List[schemas.Message]:
+        """
+        Main routine. In a loop, read already-processed messages from the output queues at batch_size depth,
+        POST each response to the callback_url recorded in its message body, and delete the batch from the queue.
+        Failed callbacks are logged and dropped rather than re-queued.
+ """ + messages_with_queues = self.receive_messages(self.batch_size) + if messages_with_queues: + logger.debug(f"About to respond to: ({messages_with_queues})") + bodies = [schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues] + for body in bodies: + self.send_callback(body) + self.delete_messages(messages_with_queues) + + + def send_callback(self, message): + """ + Rescue against failures when attempting to respond (i.e. fingerprint) from models. + Return responses if no failure. + """ + logger.info(f"Message for callback is: {message}") + try: + callback_url = message.body.callback_url + requests.post(callback_url, json=message.dict()) + except Exception as e: + logger.error(f"Callback fail! Failed with {e} on {callback_url} with message of {message}") \ No newline at end of file diff --git a/lib/queue/queue.py b/lib/queue/queue.py index 7e5ea0f..99d1565 100644 --- a/lib/queue/queue.py +++ b/lib/queue/queue.py @@ -1,48 +1,28 @@ import json -from typing import Any, List, Dict, Tuple, Union +from typing import List, Dict, Tuple import os import boto3 import botocore -from lib.helpers import get_class, get_setting, get_environment_setting -from lib.model.model import Model +from lib.helpers import get_environment_setting from lib.logger import logger from lib import schemas +SQS_MAX_BATCH_SIZE = 10 class Queue: - @classmethod - def create(cls, input_queue_name: str = None, output_queue_name: str = None, batch_size: int = 10): + def __init__(self): """ - Instantiate a queue. Must pass input_queue_name, output_queue_name, and batch_size. - Pulls settings and then inits instance. - """ - input_queue_name = get_setting(input_queue_name, "MODEL_NAME").replace(".", "__") - output_queue_name = output_queue_name or f"{input_queue_name}_output" - logger.info(f"Starting queue with: ('{input_queue_name}', '{output_queue_name}', {batch_size})") - return Queue(input_queue_name, output_queue_name, batch_size) - - def __init__(self, input_queue_name: str, output_queue_name: str = None, batch_size: int = 1): - """ - Start a specific queue - must pass input_queue_name - optionally pass output_queue_name, batch_size. + Start a specific queue - must pass input_queue_name. """ self.sqs = self.get_sqs() - self.input_queue_name = input_queue_name - self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), "_output") - if output_queue_name: - self.output_queue_name = self.get_output_queue_name(input_queue_name, output_queue_name) - self.output_queues = self.get_or_create_queues(output_queue_name) - self.all_queues = self.store_queue_map() - self.batch_size = batch_size - - def store_queue_map(self) -> Dict[str, boto3.resources.base.ServiceResource]: + def store_queue_map(self, all_queues: List[boto3.resources.base.ServiceResource]) -> Dict[str, boto3.resources.base.ServiceResource]: """ Store a quick lookup so that we dont loop through this over and over in other places. 
""" queue_map = {} - for group in [self.input_queues, self.output_queues]: - for q in group: - queue_map[self.queue_name(q)] = q + for queue in all_queues: + queue_map[self.queue_name(queue)] = queue return queue_map def queue_name(self, queue: boto3.resources.base.ServiceResource) -> str: @@ -51,6 +31,12 @@ def queue_name(self, queue: boto3.resources.base.ServiceResource) -> str: """ return queue.url.split('/')[-1] + def restrict_queues_to_suffix(self, queues: List[boto3.resources.base.ServiceResource], suffix: str) -> List[boto3.resources.base.ServiceResource]: + """ + When plucking input queues, we want to omit any queues that are our paired suffix queues.. + """ + return [queue for queue in queues if self.queue_name(queue).endswith(suffix)] + def restrict_queues_by_suffix(self, queues: List[boto3.resources.base.ServiceResource], suffix: str) -> List[boto3.resources.base.ServiceResource]: """ When plucking input queues, we want to omit any queues that are our paired suffix queues.. @@ -70,8 +56,9 @@ def get_or_create_queues(self, queue_name: str) -> List[boto3.resources.base.Ser """ try: found_queues = [q for q in self.sqs.queues.filter(QueueNamePrefix=queue_name)] - if found_queues: - return found_queues + exact_match_queues = [queue for queue in found_queues if queue.attributes['QueueArn'].split(':')[-1] == queue_name] + if exact_match_queues: + return exact_match_queues else: return [self.create_queue(queue_name)] except botocore.exceptions.ClientError as e: @@ -101,7 +88,7 @@ def get_output_queue_name(self, input_queue_name: str, output_queue_name: str = If output_queue_name was empty or None, set name for queue. """ if not output_queue_name: - output_queue_name = f'{input_queue_name}-output' + output_queue_name = f'{input_queue_name}_output' return output_queue_name def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> Dict[boto3.resources.base.ServiceResource, List[schemas.Message]]: @@ -115,7 +102,7 @@ def group_deletions(self, messages_with_queues: List[Tuple[schemas.Message, boto queue_to_messages[queue].append(message) return queue_to_messages - def delete_messages(self, messages_with_queues: List[Tuple[Dict[str, Any], boto3.resources.base.ServiceResource]]) -> None: + def delete_messages(self, messages_with_queues: List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]) -> None: """ Delete messages as we process them so other processes don't pick them up. SQS deals in max batches of 10, so break up messages into groups of 10 @@ -141,31 +128,6 @@ def delete_messages_from_queue(self, queue: boto3.resources.base.ServiceResource entries.append(entry) queue.delete_messages(Entries=entries) - def safely_respond(self, model: Model) -> List[schemas.Message]: - """ - Rescue against failures when attempting to respond (i.e. fingerprint) from models. - Return responses if no failure. - """ - messages_with_queues = self.receive_messages(model.BATCH_SIZE) - responses = [] - if messages_with_queues: - logger.debug(f"About to respond to: ({messages_with_queues})") - responses = model.respond([schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues]) - self.delete_messages(messages_with_queues) - return responses - - def fingerprint(self, model: Model): - """ - Main routine. Given a model, in a loop, read tasks from input_queue_name at batch_size depth, - pass messages to model to respond (i.e. fingerprint) them, then pass responses to output queue. 
- If failures happen at any point, resend failed messages to input queue. - """ - responses = self.safely_respond(model) - if responses: - for response in responses: - logger.info(f"Processing message of: ({response})") - self.return_response(response) - def receive_messages(self, batch_size: int = 1) -> List[Tuple[schemas.Message, boto3.resources.base.ServiceResource]]: """ Pull batch_size messages from input queue. @@ -175,19 +137,13 @@ def receive_messages(self, batch_size: int = 1) -> List[Tuple[schemas.Message, b for queue in self.input_queues: if batch_size <= 0: break - batch_messages = queue.receive_messages(MaxNumberOfMessages=min(batch_size, self.batch_size)) + batch_messages = queue.receive_messages(MaxNumberOfMessages=min(batch_size, SQS_MAX_BATCH_SIZE)) for message in batch_messages: if batch_size > 0: messages_with_queues.append((message, queue)) batch_size -= 1 return messages_with_queues - def return_response(self, message: schemas.Message): - """ - Send message to output queue - """ - return self.push_message(self.output_queue_name, message) - def find_queue_by_name(self, queue_name: str) -> boto3.resources.base.ServiceResource: """ Search through queues to find the right one diff --git a/lib/queue/worker.py b/lib/queue/worker.py new file mode 100644 index 0000000..21c046e --- /dev/null +++ b/lib/queue/worker.py @@ -0,0 +1,60 @@ +import json +from typing import List +from lib import schemas +from lib.logger import logger +from lib.queue.queue import Queue +from lib.model.model import Model +from lib.helpers import get_setting +class QueueWorker(Queue): + @classmethod + def create(cls, input_queue_name: str = None): + """ + Instantiate a queue worker. Must pass input_queue_name. + Pulls settings and then inits instance. + """ + input_queue_name = get_setting(input_queue_name, "MODEL_NAME").replace(".", "__") + output_queue_name = f"{input_queue_name}_output" + logger.info(f"Starting queue with: ('{input_queue_name}', '{output_queue_name}')") + return QueueWorker(input_queue_name, output_queue_name) + + def __init__(self, input_queue_name: str, output_queue_name: str = None): + """ + Start a specific queue - must pass input_queue_name - optionally pass output_queue_name. + """ + super().__init__() + self.input_queue_name = input_queue_name + self.input_queues = self.restrict_queues_by_suffix(self.get_or_create_queues(input_queue_name), "_output") + if output_queue_name: + self.output_queue_name = self.get_output_queue_name(input_queue_name, output_queue_name) + self.output_queues = self.get_or_create_queues(output_queue_name) + self.all_queues = self.store_queue_map([item for row in [self.input_queues, self.output_queues] for item in row]) + logger.info(f"Worker listening to queues of {self.all_queues}") + + def process(self, model: Model): + """ + Main routine. Given a model, in a loop, read tasks from input_queue_name, + pass messages to model to respond (i.e. fingerprint) them, then pass responses to output queue. + If failures happen at any point, resend failed messages to input queue. + """ + responses = self.safely_respond(model) + if responses: + for response in responses: + logger.info(f"Processing message of: ({response})") + self.push_message(self.output_queue_name, response) + + def safely_respond(self, model: Model) -> List[schemas.Message]: + """ + Rescue against failures when attempting to respond (i.e. fingerprint) from models. + Return responses if no failure. 
+ """ + messages_with_queues = self.receive_messages(model.BATCH_SIZE) + responses = [] + if messages_with_queues: + logger.debug(f"About to respond to: ({messages_with_queues})") + try: + responses = model.respond([schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues]) + except Exception as e: + logger.error(e) + self.delete_messages(messages_with_queues) + return responses + diff --git a/lib/s3.py b/lib/s3.py index b2fb416..c19a122 100644 --- a/lib/s3.py +++ b/lib/s3.py @@ -1,5 +1,5 @@ import boto3 - +from lib.logger import logger def upload_file_to_s3(bucket: str, filename: str): """ Generic upload helper for s3. Could be moved over to helpers folder... @@ -11,6 +11,6 @@ def upload_file_to_s3(bucket: str, filename: str): # Upload the file to S3 try: s3_client.upload_file(filename, bucket, file_name) - print(f'Successfully uploaded file {file_name} to S3 bucket.') + logger.info(f'Successfully uploaded file {file_name} to S3 bucket.') except Exception as e: - print(f'Failed to upload file {file_name} to S3 bucket: {e}') + logger.error(f'Failed to upload file {file_name} to S3 bucket: {e}') diff --git a/lib/schemas.py b/lib/schemas.py index 3b6ed34..26bdc03 100644 --- a/lib/schemas.py +++ b/lib/schemas.py @@ -1,54 +1,60 @@ -from typing import Any, List, Union -from pydantic import BaseModel, HttpUrl +from typing import Any, List, Optional, Union +from pydantic import BaseModel # Output hash values can be of different types. HashValue = Union[List[float], str, int] class TextInput(BaseModel): id: str - callback_url: HttpUrl + callback_url: str text: str class TextOutput(BaseModel): id: str - callback_url: HttpUrl + callback_url: str text: str class VideoInput(BaseModel): id: str - callback_url: HttpUrl - url: HttpUrl + callback_url: str + url: str class VideoOutput(BaseModel): id: str - callback_url: HttpUrl - url: HttpUrl + callback_url: str + url: str bucket: str outfile: str hash_value: HashValue class AudioInput(BaseModel): id: str - callback_url: HttpUrl - url: HttpUrl + callback_url: str + url: str class AudioOutput(BaseModel): id: str - callback_url: HttpUrl - url: HttpUrl + callback_url: str + url: str hash_value: HashValue class ImageInput(BaseModel): id: str - callback_url: HttpUrl - url: HttpUrl + callback_url: str + url: str class ImageOutput(BaseModel): id: str - callback_url: HttpUrl - url: HttpUrl + callback_url: str + url: str hash_value: HashValue -class Message(BaseModel): - body: Union[TextInput, VideoInput, AudioInput, ImageInput] - response: Any +class GenericInput(BaseModel): + id: str + callback_url: str + url: Optional[str] = None + text: Optional[str] = None + raw: Optional[dict] = {} +class Message(BaseModel): + body: GenericInput + response: Any \ No newline at end of file diff --git a/lib/sentry.py b/lib/sentry.py new file mode 100644 index 0000000..8804a1e --- /dev/null +++ b/lib/sentry.py @@ -0,0 +1,9 @@ +import os +import sentry_sdk +from lib.helpers import get_environment_setting + +sentry_sdk.init( + dsn=get_environment_setting('sentry_sdk_dsn'), + environment=get_environment_setting("DEPLOY_ENV"), + traces_sample_rate=1.0, +) diff --git a/requirements.txt b/requirements.txt index 1594b52..a95670b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,6 @@ httpx==0.23.1 huggingface-hub==0.11.0 fasttext==0.9.2 langcodes==3.3.0 +requests==2.31.0 +pytest==7.4.0 +sentry-sdk==1.30.0 \ No newline at end of file diff --git a/run.py b/run.py deleted file mode 100644 index e52d7d2..0000000 --- a/run.py +++ /dev/null @@ 
-1,13 +0,0 @@ -import time -import os -import importlib -from lib.queue.queue import Queue -from lib.model.model import Model -from lib.logger import logger -queue = Queue.create() - -model = Model.create() - -logger.info("Beginning fingerprinter loop...") -while True: - queue.fingerprint(model) diff --git a/run_processor.py b/run_processor.py new file mode 100644 index 0000000..e7c6ee9 --- /dev/null +++ b/run_processor.py @@ -0,0 +1,12 @@ +import time +import os +import importlib +from lib.queue.processor import QueueProcessor +from lib.model.model import Model +from lib.logger import logger +from lib.sentry import sentry_sdk +queue = QueueProcessor.create() + +logger.info("Beginning callback loop...") +while True: + queue.send_callbacks() \ No newline at end of file diff --git a/run_worker.py b/run_worker.py new file mode 100644 index 0000000..fc7353e --- /dev/null +++ b/run_worker.py @@ -0,0 +1,14 @@ +import time +import os +import importlib +from lib.queue.worker import QueueWorker +from lib.model.model import Model +from lib.logger import logger +from lib.sentry import sentry_sdk +queue = QueueWorker.create() + +model = Model.create() + +logger.info("Beginning work loop...") +while True: + queue.process(model) diff --git a/start_healthcheck_and_model_engine.sh b/start_all.sh similarity index 59% rename from start_healthcheck_and_model_engine.sh rename to start_all.sh index f3a2022..eb4eaf5 100755 --- a/start_healthcheck_and_model_engine.sh +++ b/start_all.sh @@ -1,8 +1,9 @@ #!/bin/sh # Start the first process in the background -uvicorn main:app --host 0.0.0.0 --reload & +uvicorn main:app --host 0.0.0.0 --port ${PRESTO_PORT} --reload & # Start the second process in the foreground # This will ensure the script won't exit until this process does -python run.py \ No newline at end of file +python run_worker.py & +python run_processor.py \ No newline at end of file diff --git a/test/lib/model/test_audio.py b/test/lib/model/test_audio.py index 26f4443..c6359ba 100644 --- a/test/lib/model/test_audio.py +++ b/test/lib/model/test_audio.py @@ -6,24 +6,26 @@ from acoustid import FingerprintGenerationError from lib import schemas - +FINGERPRINT_RESPONSE = (170.6, b'AQAA3VFYJYrGJMj74EOZUCfCHGqYLBZO8UiX5bie47sCV0xwBTe49IiVHHrQnIImJyP-44rxI2cYiuiHMCMDPcqJrBcwnYeryBX6rccR_4Iy_YhfXESzqELJ5ASTLwhvNM94KDp9_IB_6NqDZ5I9_IWYvDiNCc1z8IeuHXkYpfhSg8su3M2K5lkrFM-PK3mQH8lznEpidLEoNAeLyWispQpqvfgRZjp0lHaENAmzBeamoRIZMrha5IsyHM6H7-jRhJlSBU1FLgiv4xlKUQmNptGOU3jzIj80Jk5xsQp0UegxJtmSCpeS5PiDozz0MAb5BG5z9MEPIcy0HeWD58M_4sotlNOF8UeuLJEgJt4xkUee4cflI1nMI4uciBLeGu9z9NjH4x9iSXoELYs04pqCSCvx5ei1Tzi3NMFRmsa2DD2POxVCR4IPMSfySC-u0EKuE6IOqz_6zJh8BzZlgc1IQkyTGdeLa4cT7bi2E30e_OgTI4xDPCGLJ_gvZHlwT7EgJc2XIBY_4fnBPENC_YilsGjDJzhJoeyCJn9A1kaeDUw4VA_-41uDGycO8w_eWlCU66iio0eYL8hVK_gD5QlyMR7hzzh-vDm6JE_hcTpq5cFTdFcKZfHxRMTZCS2VHKdOfDve5Hh0hCV9JEtMSbhxSSMuHU9y4kaTx5guHIGsoEAAwoASjmDlkSAEOCSoQEw4IDgghiguAEZAAMaAAYYAhBhACBEiiAGAIUCUUUgSESjgSBlKjZEEEIAFUEIBBRBinAAplFJKAIYQEAQSA4ywACkjgBFMAEoAQgYQwARB1gFmBCAECAAIMYYIoBxBBAAAFCKAAEgIBAQgAghgihIWBACEIUEIJEZIZIBRACGAGAEEIAGAUIBIhBCgRkI') class TestAudio(unittest.TestCase): def setUp(self): self.audio_model = Model() @patch('urllib.request.urlopen') @patch('urllib.request.Request') - def test_fingerprint_audio_success(self, mock_request, mock_urlopen): + @patch('acoustid.fingerprint_file') + def test_process_audio_success(self, mock_fingerprint_file, mock_request, mock_urlopen): + mock_fingerprint_file.return_value = FINGERPRINT_RESPONSE mock_request.return_value = mock_request - + # Use the 
`with` statement for proper file handling with open("data/test-audio.mp3", 'rb') as f: contents = f.read() mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents)) - + audio = schemas.Message(body=schemas.AudioInput(id="123", callback_url="http://example.com/callback", url="https://example.com/audio.mp3")) - result = self.audio_model.fingerprint(audio) + result = self.audio_model.process(audio) mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'}) mock_urlopen.assert_called_once_with(mock_request) self.assertEqual(list, type(result["hash_value"])) @@ -32,19 +34,19 @@ def test_fingerprint_audio_success(self, mock_request, mock_urlopen): @patch('urllib.request.Request') @patch('acoustid.fingerprint_file') @patch('acoustid.chromaprint.decode_fingerprint') - def test_fingerprint_audio_failure(self, mock_decode_fingerprint, mock_fingerprint_file, + def test_process_audio_failure(self, mock_decode_fingerprint, mock_fingerprint_file, mock_request, mock_urlopen): mock_fingerprint_file.side_effect = FingerprintGenerationError("Failed to generate fingerprint") mock_request.return_value = mock_request - + # Use the `with` statement for proper file handling with open("data/test-audio.mp3", 'rb') as f: contents = f.read() mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents)) - + audio = schemas.Message(body=schemas.AudioInput(id="123", callback_url="http://example.com/callback", url="https://example.com/audio.mp3")) - result = self.audio_model.fingerprint(audio) + result = self.audio_model.process(audio) mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'}) mock_urlopen.assert_called_once_with(mock_request) self.assertEqual([], result["hash_value"]) diff --git a/test/lib/model/test_fasttext.py b/test/lib/model/test_fasttext.py index 29decb7..ab2129e 100644 --- a/test/lib/model/test_fasttext.py +++ b/test/lib/model/test_fasttext.py @@ -1,33 +1,23 @@ -import os import unittest -from unittest.mock import MagicMock - -import numpy as np - +from unittest.mock import patch, MagicMock from lib.model.fasttext import FasttextModel from lib import schemas class TestFasttextModel(unittest.TestCase): def setUp(self): - self.model = FasttextModel() self.mock_model = MagicMock() - def test_respond(self): - query = [schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), - schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="今天是星期二")), - schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="چھِ کٲشرۍ نٹن گۅرزٕ خنجر وُچھِتھ اَژان لرزٕ چھُکھ کانہہ دِلاور وُچھِتھ")), - schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="🐐🐐🐐🐐123")), - schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text=""))] - - - response = self.model.respond(query) - - self.assertEqual(len(response), 5) - self.assertEqual(response[0].response, {'language': 'en', 'script': None, 'score': 1.0}) - self.assertEqual(response[1].response, {'language': 'zh', 'script': 'Hans', 'score': 0.8305}) - self.assertEqual(response[2].response, {'language': 'ks', 'script': 'Arab', 'score': 0.9999}) - self.assertEqual(response[3].response, {'language': 'bo', 'script': 'Tibt', 'score': 0.2168}) #non-text content returns random language with low certainty - self.assertEqual(response[4].response, {'language': 
'en', 'script': None, 'score': 0.8267}) #empty string returns english with high-ish confidence + @patch('lib.model.fasttext.hf_hub_download') + @patch('lib.model.fasttext.fasttext.load_model') + def test_respond(self, mock_fasttext_load_model, mock_hf_hub_download): + mock_hf_hub_download.return_value = 'mocked_path' + mock_fasttext_load_model.return_value = self.mock_model + self.mock_model.predict.return_value = (['__label__eng_Latn'], [0.9]) + model = FasttextModel() # Now it uses mocked functions + query = [schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Hello, how are you?"))] + response = model.respond(query) + self.assertEqual(len(response), 1) + self.assertEqual(response[0].response, {'language': 'en', 'script': None, 'score': 0.9}) if __name__ == '__main__': unittest.main() diff --git a/test/lib/model/test_image.py b/test/lib/model/test_image.py index fee690a..54ddb7e 100644 --- a/test/lib/model/test_image.py +++ b/test/lib/model/test_image.py @@ -39,11 +39,11 @@ def test_get_iobytes_for_image_raises_error(self, mock_urlopen): @patch.object(Model, "get_iobytes_for_image") @patch.object(Model, "compute_pdq") - def test_fingerprint(self, mock_compute_pdq, mock_get_iobytes_for_image): + def test_process(self, mock_compute_pdq, mock_get_iobytes_for_image): mock_compute_pdq.return_value = "1001" mock_get_iobytes_for_image.return_value = io.BytesIO(b"image_bytes") image = schemas.Message(body=schemas.ImageInput(id="123", callback_url="http://example.com?callback", url="http://example.com/image.jpg")) - result = Model().fingerprint(image) + result = Model().process(image) self.assertEqual(result, {"hash_value": "1001"}) diff --git a/test/lib/model/test_video.py b/test/lib/model/test_video.py index 62a67bb..700438a 100644 --- a/test/lib/model/test_video.py +++ b/test/lib/model/test_video.py @@ -24,7 +24,7 @@ def test_get_tempfile(self, mock_named_tempfile): @patch('tmkpy.hashVideo') @patch('s3.upload_file_to_s3') @patch('pathlib.Path') - def test_fingerprint_video(self, mock_pathlib, mock_upload_file_to_s3, + def test_process_video(self, mock_pathlib, mock_upload_file_to_s3, mock_hash_video, mock_urlopen): with open("data/test-video.mp4", "rb") as video_file: video_contents = video_file.read() @@ -32,7 +32,7 @@ def test_fingerprint_video(self, mock_pathlib, mock_upload_file_to_s3, mock_hash_video_output.getPureAverageFeature.return_value = "hash_value" mock_hash_video.return_value = mock_hash_video_output mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=video_contents)) - self.video_model.fingerprint(schemas.Message(body=schemas.VideoInput(id="123", callback_url="http://blah.com?callback_id=123", url="http://example.com/video.mp4"))) + self.video_model.process(schemas.Message(body=schemas.VideoInput(id="123", callback_url="http://blah.com?callback_id=123", url="http://example.com/video.mp4"))) mock_urlopen.assert_called_once() mock_hash_video.assert_called_once_with(ANY, "/usr/local/bin/ffmpeg") @@ -50,16 +50,16 @@ def test_tmk_program_name(self): def test_respond_with_single_video(self): video = schemas.Message(body=schemas.VideoInput(id="123", callback_url="http://blah.com?callback_id=123", url="http://example.com/video.mp4")) - mock_fingerprint = MagicMock() - self.video_model.fingerprint = mock_fingerprint + mock_process = MagicMock() + self.video_model.process = mock_process result = self.video_model.respond(video) - mock_fingerprint.assert_called_once_with(video) + mock_process.assert_called_once_with(video) 
self.assertEqual(result, [video]) def test_respond_with_multiple_videos(self): videos = [schemas.Message(body=schemas.VideoInput(id="123", callback_url="http://blah.com?callback_id=123", url="http://example.com/video1.mp4")), schemas.Message(body=schemas.VideoInput(id="123", callback_url="http://blah.com?callback_id=123", url="http://example.com/video2.mp4"))] - mock_fingerprint = MagicMock() - self.video_model.fingerprint = mock_fingerprint + mock_process = MagicMock() + self.video_model.process = mock_process result = self.video_model.respond(videos) - mock_fingerprint.assert_called_with(videos[1]) + mock_process.assert_called_with(videos[1]) self.assertEqual(result, videos) \ No newline at end of file diff --git a/test/lib/queue/fake_sqs_message.py b/test/lib/queue/fake_sqs_message.py new file mode 100644 index 0000000..490cdaf --- /dev/null +++ b/test/lib/queue/fake_sqs_message.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel +class FakeSQSMessage(BaseModel): + body: str + receipt_handle: str + diff --git a/test/lib/queue/test_processor.py b/test/lib/queue/test_processor.py new file mode 100644 index 0000000..85dc7ac --- /dev/null +++ b/test/lib/queue/test_processor.py @@ -0,0 +1,55 @@ +import unittest +from unittest.mock import MagicMock, patch +import json + +from lib.queue.processor import QueueProcessor +from lib import schemas +from test.lib.queue.fake_sqs_message import FakeSQSMessage +class TestQueueProcessor(unittest.TestCase): + + @patch('lib.queue.queue.boto3.resource') + @patch('lib.helpers.get_environment_setting', return_value='us-west-1') + def setUp(self, mock_get_env_setting, mock_boto_resource): + self.queue_name_input = 'mean_tokens__Model' + + # Mock the SQS resource and the queues + self.mock_sqs_resource = MagicMock() + self.mock_input_queue = MagicMock() + self.mock_input_queue.url = "http://queue/mean_tokens__Model" + self.mock_sqs_resource.queues.filter.return_value = [self.mock_input_queue] + mock_boto_resource.return_value = self.mock_sqs_resource + + # Initialize the QueueProcessor instance + self.queue_processor = QueueProcessor(self.queue_name_input, batch_size=2) + + def test_send_callbacks(self): + # Mocking necessary methods and creating fake data + self.queue_processor.receive_messages = MagicMock( + return_value=[(FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"callback_url": "http://example.com", "text": "This is a test", "id": 1}, "response": [1,2,3]})), self.mock_input_queue)] + ) + self.queue_processor.send_callback = MagicMock(return_value=None) + self.queue_processor.delete_messages = MagicMock(return_value=None) + + responses = self.queue_processor.send_callbacks() + + self.queue_processor.receive_messages.assert_called_once_with(2) + self.queue_processor.send_callback.assert_called() + self.queue_processor.delete_messages.assert_called() + + @patch('lib.queue.processor.requests.post') + def test_send_callback(self, mock_post): + message_body = schemas.Message(body={"callback_url": "http://example.com", "text": "This is a test", "id": 123}, response=[1,2,3]) + self.queue_processor.send_callback(message_body) + + mock_post.assert_called_once_with("http://example.com", json=message_body) + + @patch('lib.queue.processor.requests.post') + def test_send_callback_failure(self, mock_post): + mock_post.side_effect = Exception("Request Failed!") + message_body = schemas.Message(body={"callback_url": "http://example.com", "text": "This is a test", "id": 123}, response=[1,2,3]) + with self.assertLogs(level='ERROR') as cm: + 
self.queue_processor.send_callback(message_body) + self.assertIn("Failed with Request Failed! on http://example.com with message of", cm.output[0]) + +if __name__ == '__main__': + unittest.main() diff --git a/test/lib/queue/test_queue.py b/test/lib/queue/test_queue.py index cb33902..ee0f536 100644 --- a/test/lib/queue/test_queue.py +++ b/test/lib/queue/test_queue.py @@ -2,21 +2,14 @@ import os import unittest from unittest.mock import MagicMock, patch -from pydantic import BaseModel import numpy as np from lib.model.generic_transformer import GenericTransformerModel -from lib.queue.queue import Queue +from lib.queue.worker import QueueWorker from lib import schemas -class FakeSQSMessage(BaseModel): - body: str - receipt_handle: str +from test.lib.queue.fake_sqs_message import FakeSQSMessage -class TestQueue(unittest.TestCase): - # def overwrite_restrict_queues_by_suffix(queues, suffix): - # return [MagicMock()] - # - # @patch('lib.queue.queue.Queue.restrict_queues_by_suffix', side_effect=overwrite_restrict_queues_by_suffix) +class TestQueueWorker(unittest.TestCase): @patch('lib.queue.queue.boto3.resource') @patch('lib.helpers.get_environment_setting', return_value='us-west-1') def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queues_by_suffix): @@ -24,31 +17,32 @@ def setUp(self, mock_get_env_setting, mock_boto_resource):#, mock_restrict_queue self.mock_model = MagicMock() self.queue_name_input = 'mean_tokens__Model' self.queue_name_output = 'mean_tokens__Model_output' - self.batch_size = 5 # Mock the SQS resource and the queues self.mock_sqs_resource = MagicMock() self.mock_input_queue = MagicMock() self.mock_input_queue.url = "http://queue/mean_tokens__Model" + self.mock_input_queue.attributes = {"QueueArn": "queue:mean_tokens__Model"} self.mock_output_queue = MagicMock() self.mock_output_queue.url = "http://queue/mean_tokens__Model_output" + self.mock_output_queue.attributes = {"QueueArn": "queue:mean_tokens__Model_output"} self.mock_sqs_resource.queues.filter.return_value = [self.mock_input_queue, self.mock_output_queue] mock_boto_resource.return_value = self.mock_sqs_resource # Initialize the SQSQueue instance - self.queue = Queue(self.queue_name_input, self.queue_name_output, self.batch_size) + self.queue = QueueWorker(self.queue_name_input, self.queue_name_output) def test_get_output_queue_name(self): - self.assertEqual(self.queue.get_output_queue_name('test'), 'test-output') + self.assertEqual(self.queue.get_output_queue_name('test'), 'test_output') self.assertEqual(self.queue.get_output_queue_name('test', 'new-output'), 'new-output') - def test_fingerprint(self): + def test_process(self): self.queue.receive_messages = MagicMock(return_value=[(FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}})), self.mock_input_queue)]) self.queue.input_queue = MagicMock(return_value=None) self.model.model = self.mock_model self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]])) self.queue.return_response = MagicMock(return_value=None) - self.queue.fingerprint(self.model) + self.queue.process(self.model) self.queue.receive_messages.assert_called_once_with(1) def test_receive_messages(self): @@ -58,7 +52,7 @@ def test_receive_messages(self): mock_queue2 = MagicMock() mock_queue2.receive_messages.return_value = [FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": {"id": 2, "callback_url": "http://example.com", "text": "This is another test"}}))] 
self.queue.input_queues = [mock_queue1, mock_queue2] - received_messages = self.queue.receive_messages(self.batch_size) + received_messages = self.queue.receive_messages(5) # Check if the right number of messages were received and the content is correct self.assertEqual(len(received_messages), 2) @@ -74,6 +68,15 @@ def test_restrict_queues_by_suffix(self): restricted_queues = self.queue.restrict_queues_by_suffix(queues, "_output") self.assertEqual(len(restricted_queues), 2) # expecting two queues that don't end with _output + def test_restrict_queues_to_suffix(self): + queues = [ + MagicMock(url='http://test.com/test_input'), + MagicMock(url='http://test.com/test_input_output'), + MagicMock(url='http://test.com/test_another_input') + ] + restricted_queues = self.queue.restrict_queues_to_suffix(queues, "_output") + self.assertEqual(len(restricted_queues), 1) # expecting one queue that ends with _output + def test_group_deletions(self): messages_with_queues = [ (FakeSQSMessage(receipt_handle="blah", body=json.dumps({"body": "msg1"})), self.mock_input_queue), @@ -101,8 +104,8 @@ def test_push_message(self): # Call push_message returned_message = self.queue.push_message(self.queue_name_output, message_to_push) # Check if the message was correctly serialized and sent - self.mock_output_queue.send_message.assert_called_once_with(MessageBody='{"body": {"id": "1", "callback_url": "http://example.com", "text": "This is a test"}, "response": null}') + self.mock_output_queue.send_message.assert_called_once_with(MessageBody='{"body": {"id": "1", "callback_url": "http://example.com", "url": null, "text": "This is a test", "raw": {}}, "response": null}') self.assertEqual(returned_message, message_to_push) if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/lib/test_http.py b/test/lib/test_http.py index a4ec4bd..773d61f 100644 --- a/test/lib/test_http.py +++ b/test/lib/test_http.py @@ -2,25 +2,25 @@ import unittest from unittest.mock import patch from lib.http import app -from lib.queue.queue import Queue +from lib.queue.worker import QueueWorker -class TestFingerprintItem(unittest.TestCase): +class TestProcessItem(unittest.TestCase): def setUp(self): self.client = TestClient(app) - @patch.object(Queue, 'create') - @patch.object(Queue, 'push_message') - def test_fingerprint_item(self, mock_push_message, mock_create): + @patch.object(QueueWorker, 'create') + @patch.object(QueueWorker, 'push_message') + def test_process_item(self, mock_push_message, mock_create): mock_queue = mock_create.return_value mock_queue.input_queue_name = "input_queue" mock_queue.output_queue_name = "output_queue" test_data = {"id": 1, "callback_url": "http://example.com", "text": "This is a test"} - response = self.client.post("/fingerprint_item/test_fingerprinter", json=test_data) - mock_create.assert_called_once_with("test_fingerprinter") + response = self.client.post("/process_item/test_process", json=test_data) + mock_create.assert_called_once_with("test_process") self.assertEqual(response.status_code, 200) - self.assertEqual(response.json(), {"message": "Message pushed successfully"}) + self.assertEqual(response.json(), {"message": "Message pushed successfully", "queue": "test_process", "body": test_data}) @patch('lib.http.post_url')