Skip to content

Commit

Permalink
Merge branch 'master' into cv2-3408-language-codes
Browse files Browse the repository at this point in the history
  • Loading branch information
DGaffney authored Sep 22, 2023
2 parents ea519b3 + 72487b0 commit 59a8bd0
Show file tree
Hide file tree
Showing 31 changed files with 382 additions and 191 deletions.
4 changes: 3 additions & 1 deletion local.env → .env_file
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
PRESTO_PORT=8000
DEPLOY_ENV=local
MODEL_NAME=mean_tokens.Model
# MODEL_NAME=mean_tokens.Model
MODEL_NAME=audio.Model
AWS_ACCESS_KEY_ID=SOMETHING
AWS_SECRET_ACCESS_KEY=OTHERTHING
10 changes: 10 additions & 0 deletions .env_file.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
PRESTO_PORT=8000
DEPLOY_ENV=local
# MODEL_NAME=mean_tokens.Model
MODEL_NAME=audio.Model
AWS_ACCESS_KEY_ID=SOMETHING
AWS_SECRET_ACCESS_KEY=OTHERTHING
<<<<<<< HEAD
=======

>>>>>>> master
7 changes: 6 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
ARG PRESTO_PORT

FROM python:3.9
EXPOSE 8000

ENV PRESTO_PORT=${PRESTO_PORT}
EXPOSE ${PRESTO_PORT}

WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive

Expand Down
9 changes: 6 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
.PHONY: run run_http run_worker run_test

run:
./start_healthcheck_and_model_engine.sh
./start_all.sh

run_http:
uvicorn main:app --host 0.0.0.0 --reload

run_worker:
python run.py
python run_worker.py

run_processor:
python run_processor.py

run_test:
python -m unittest discover .
python -m pytest test
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,12 @@ Output Message:
```

### Endpoints
#### /fingerprint_item/{fingerprinter}
#### /process_item/{process_name}
This endpoint pushes a message into a queue. It's an async operation, meaning the server will respond before the operation is complete. This is useful when working with slow or unreliable external resources.

Request
```
curl -X POST "http://127.0.0.1:8000/fingerprint_item/sample_fingerprinter" -H "accept: application/json" -H "Content-Type: application/json" -d "{\"message_key\":\"message_value\"}"
curl -X POST "http://127.0.0.1:8000/process_item/mean_tokens__Model" -H "accept: application/json" -H "Content-Type: application/json" -d "{\"message_key\":\"message_value\"}"
```

Replace sample_fingerprinter with the name of your fingerprinter, and message_key and message_value with your actual message data.
Expand Down
6 changes: 5 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@ services:
platform: linux/amd64
build: .
env_file:
- ./local.env
- ./.env_file
depends_on:
- elasticmq
links:
- elasticmq
volumes:
- ./:/app
ports:
- "${PRESTO_PORT}:${PRESTO_PORT}"
args:
- PRESTO_PORT=${PRESTO_PORT}
elasticmq:
image: softwaremill/elasticmq
hostname: presto-elasticmq
Expand Down
23 changes: 15 additions & 8 deletions lib/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from httpx import HTTPStatusError
from fastapi import FastAPI, Request
from pydantic import BaseModel
from lib.queue.queue import Queue
from lib.queue.worker import QueueWorker
from lib.logger import logger
from lib import schemas
from lib.sentry import sentry_sdk

app = FastAPI()

Expand All @@ -27,14 +28,15 @@ async def post_url(url: str, params: dict) -> Dict[str, Any]:
except HTTPStatusError:
return {"error": f"HTTP Error on Attempt to call {url} with {params}"}

@app.post("/fingerprint_item/{fingerprinter}")
def fingerprint_item(fingerprinter: str, message: Dict[str, Any]):
queue = Queue.create(fingerprinter)
queue.push_message(fingerprinter, schemas.Message(body=message, input_queue=queue.input_queue_name, output_queue=queue.output_queue_name, start_time=str(datetime.datetime.now())))
return {"message": "Message pushed successfully"}
@app.post("/process_item/{process_name}")
def process_item(process_name: str, message: Dict[str, Any]):
logger.info(message)
queue = QueueWorker.create(process_name)
queue.push_message(process_name, schemas.Message(body=message))
return {"message": "Message pushed successfully", "queue": process_name, "body": message}

@app.post("/trigger_callback")
async def fingerprint_item(message: Dict[str, Any]):
async def process_item(message: Dict[str, Any]):
url = message.get("callback_url")
if url:
response = await post_url(url, message)
Expand All @@ -46,5 +48,10 @@ async def fingerprint_item(message: Dict[str, Any]):
return {"message": "No Message Callback, Passing"}

@app.get("/ping")
def fingerprint_item():
def process_item():
return {"pong": 1}

@app.post("/echo")
async def echo(message: Dict[str, Any]):
logger.info(f"About to echo message of {message}")
return {"echo": message}
2 changes: 1 addition & 1 deletion lib/model/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def audio_hasher(self, filename: str) -> List[int]:
except acoustid.FingerprintGenerationError:
return []

def fingerprint(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]:
def process(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]:
temp_file_name = self.get_tempfile_for_url(audio.body.url)
try:
hash_value = self.audio_hasher(temp_file_name)
Expand Down
4 changes: 2 additions & 2 deletions lib/model/generic_transformer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import Union, Dict, List
from sentence_transformers import SentenceTransformer

from lib.logger import logger
from lib.model.model import Model
from lib import schemas

Expand All @@ -21,7 +21,7 @@ def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[s
"""
if not isinstance(docs, list):
docs = [docs]
print(docs)
logger.info(docs)
vectorizable_texts = [e.body.text for e in docs]
vectorized = self.vectorize(vectorizable_texts)
for doc, vector in zip(docs, vectorized):
Expand Down
2 changes: 1 addition & 1 deletion lib/model/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_iobytes_for_image(self, image: schemas.Message) -> io.BytesIO:
).read()
)

def fingerprint(self, image: schemas.Message) -> schemas.ImageOutput:
def process(self, image: schemas.Message) -> schemas.ImageOutput:
"""
Generic function for returning the actual response.
"""
Expand Down
4 changes: 2 additions & 2 deletions lib/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_tempfile(self) -> Any:
"""
return tempfile.NamedTemporaryFile()

def fingerprint(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
def process(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
return []

def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
Expand All @@ -43,7 +43,7 @@ def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> Li
if not isinstance(messages, list):
messages = [messages]
for message in messages:
message.response = self.fingerprint(message)
message.response = self.process(message)
return messages

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion lib/model/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def tmk_bucket(self) -> str:
"""
return "presto_tmk_videos"

def fingerprint(self, video: schemas.Message) -> schemas.VideoOutput:
def process(self, video: schemas.Message) -> schemas.VideoOutput:
"""
Main fingerprinting routine - download video to disk, get short hash,
then calculate larger TMK hash and upload that to S3.
Expand Down
57 changes: 57 additions & 0 deletions lib/queue/processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import List
import json

import requests

from lib import schemas
from lib.logger import logger
from lib.helpers import get_setting
from lib.queue.queue import Queue
class QueueProcessor(Queue):
@classmethod
def create(cls, input_queue_name: str = None, batch_size: int = 10):
"""
Instantiate a queue. Must pass input_queue_name, output_queue_name, and batch_size.
Pulls settings and then inits instance.
"""
input_queue_name = get_setting(input_queue_name, "MODEL_NAME").replace(".", "__")
logger.info(f"Starting queue with: ('{input_queue_name}', {batch_size})")
return QueueProcessor(input_queue_name, batch_size)

def __init__(self, input_queue_name: str, output_queue_name: str = None, batch_size: int = 1):
"""
Start a specific queue - must pass input_queue_name - optionally pass output_queue_name, batch_size.
"""
super().__init__()
self.input_queue_name = input_queue_name
self.input_queues = self.restrict_queues_to_suffix(self.get_or_create_queues(input_queue_name+"_output"), "_output")
self.all_queues = self.store_queue_map(self.input_queues)
logger.info(f"Processor listening to queues of {self.all_queues}")
self.batch_size = batch_size

def send_callbacks(self) -> List[schemas.Message]:
"""
Main routine. Given a model, in a loop, read tasks from input_queue_name at batch_size depth,
pass messages to model to respond (i.e. fingerprint) them, then pass responses to output queue.
If failures happen at any point, resend failed messages to input queue.
"""
messages_with_queues = self.receive_messages(self.batch_size)
if messages_with_queues:
logger.debug(f"About to respond to: ({messages_with_queues})")
bodies = [schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues]
for body in bodies:
self.send_callback(body)
self.delete_messages(messages_with_queues)


def send_callback(self, message):
"""
Rescue against failures when attempting to respond (i.e. fingerprint) from models.
Return responses if no failure.
"""
logger.info(f"Message for callback is: {message}")
try:
callback_url = message.body.callback_url
requests.post(callback_url, json=message.dict())
except Exception as e:
logger.error(f"Callback fail! Failed with {e} on {callback_url} with message of {message}")
Loading

0 comments on commit 59a8bd0

Please sign in to comment.