Skip to content

Commit

Permalink
Merge pull request #48 from meedan/cv2-2875-schema-fixes
Browse files Browse the repository at this point in the history
CV2-2875 tweak response structure
  • Loading branch information
DGaffney authored Nov 9, 2023
2 parents 17d0357 + 6165fdc commit 4ed34f0
Show file tree
Hide file tree
Showing 23 changed files with 72 additions and 98 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,4 +206,4 @@ If no callback_url is provided:
{
"message": "No Message Callback, Passing"
}
```
```
2 changes: 1 addition & 1 deletion lib/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def process_item(process_name: str, message: Dict[str, Any]):
logger.info(message)
queue_prefix = Queue.get_queue_prefix()
queue = QueueWorker.create(process_name)
queue.push_message(f"{queue_prefix}{process_name}", schemas.Message(body=message))
queue.push_message(f"{queue_prefix}{process_name}", schemas.Message(body=message, model_name=process_name))
return {"message": "Message pushed successfully", "queue": process_name, "body": message}

@app.post("/trigger_callback")
Expand Down
2 changes: 1 addition & 1 deletion lib/model/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def process(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]:
hash_value = self.audio_hasher(temp_file_name)
finally:
os.remove(temp_file_name)
return {"hash_value": hash_value}
return hash_value
4 changes: 2 additions & 2 deletions lib/model/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self):
self.model = fasttext.load_model(model_path)


def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.TextOutput]:
def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.GenericItem]:
"""
Force messages as list of messages in case we get a singular item. Then, run fingerprint routine.
Respond can probably be genericized across all models.
Expand All @@ -42,5 +42,5 @@ def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[s
detected_langs.append({'language': model_language, 'script': model_script, 'score': model_certainty})

for doc, detected_lang in zip(docs, detected_langs):
doc.response = detected_lang
doc.body.hash_value = detected_lang
return docs
4 changes: 2 additions & 2 deletions lib/model/generic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self, model_name: str):
if model_name:
self.model = SentenceTransformer(model_name, cache_folder=os.getenv("MODEL_DIR", "./models"))

def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.TextOutput]:
def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.GenericItem]:
"""
Force messages as list of messages in case we get a singular item. Then, run fingerprint routine.
Respond can probably be genericized across all models.
Expand All @@ -25,7 +25,7 @@ def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[s
vectorizable_texts = [e.body.text for e in docs]
vectorized = self.vectorize(vectorizable_texts)
for doc, vector in zip(docs, vectorized):
doc.response = vector
doc.body.hash_value = vector
return docs

def vectorize(self, texts: List[str]) -> List[List[float]]:
Expand Down
4 changes: 2 additions & 2 deletions lib/model/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def get_iobytes_for_image(self, image: schemas.Message) -> io.BytesIO:
).read()
)

def process(self, image: schemas.Message) -> schemas.ImageOutput:
def process(self, image: schemas.Message) -> schemas.GenericItem:
"""
Generic function for returning the actual response.
"""
return {"hash_value": self.compute_pdq(self.get_iobytes_for_image(image))}
return self.compute_pdq(self.get_iobytes_for_image(image))
5 changes: 4 additions & 1 deletion lib/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from lib import schemas
class Model(ABC):
BATCH_SIZE = 1
def __init__(self):
self.model_name = os.environ.get("MODEL_NAME")

def get_tempfile_for_url(self, url: str) -> str:
"""
Loads a file based on specified URL into a named tempfile.
Expand Down Expand Up @@ -43,7 +46,7 @@ def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> Li
if not isinstance(messages, list):
messages = [messages]
for message in messages:
message.response = self.process(message)
message.body.hash_value = self.process(message)
return messages

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion lib/model/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def tmk_bucket(self) -> str:
"""
return "presto_tmk_videos"

def process(self, video: schemas.Message) -> schemas.VideoOutput:
def process(self, video: schemas.Message) -> schemas.GenericItem:
"""
Main fingerprinting routine - download video to disk, get short hash,
then calculate larger TMK hash and upload that to S3.
Expand Down
2 changes: 1 addition & 1 deletion lib/queue/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def send_callbacks(self) -> List[schemas.Message]:
"""
messages_with_queues = self.receive_messages(self.batch_size)
if messages_with_queues:
logger.debug(f"About to respond to: ({messages_with_queues})")
logger.info(f"About to respond to: ({messages_with_queues})")
bodies = [schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues]
for body in bodies:
self.send_callback(body)
Expand Down
2 changes: 1 addition & 1 deletion lib/queue/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def safely_respond(self, model: Model) -> List[schemas.Message]:
if messages_with_queues:
logger.debug(f"About to respond to: ({messages_with_queues})")
try:
responses = model.respond([schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues])
responses = model.respond([schemas.Message(**{**json.loads(message.body), **{"model_name": model.model_name}}) for message, queue in messages_with_queues])
except Exception as e:
logger.error(e)
self.delete_messages(messages_with_queues)
Expand Down
72 changes: 21 additions & 51 deletions lib/schemas.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,30 @@
from typing import Any, List, Optional, Union
from pydantic import BaseModel
from pydantic import BaseModel, root_validator

# Output hash values can be of different types.
HashValue = Union[List[float], str, int]
class TextInput(BaseModel):
class GenericItem(BaseModel):
id: str
callback_url: str
text: str

class TextOutput(BaseModel):
id: str
callback_url: str
text: str

class VideoInput(BaseModel):
id: str
callback_url: str
url: str

class VideoOutput(BaseModel):
id: str
callback_url: str
url: str
bucket: str
outfile: str
hash_value: HashValue

class AudioInput(BaseModel):
id: str
callback_url: str
url: str

class AudioOutput(BaseModel):
id: str
callback_url: str
url: str
hash_value: HashValue

class ImageInput(BaseModel):
id: str
callback_url: str
url: str

class ImageOutput(BaseModel):
id: str
callback_url: str
url: str
hash_value: HashValue

class GenericInput(BaseModel):
id: str
callback_url: str
callback_url: Optional[str] = None
url: Optional[str] = None
text: Optional[str] = None
raw: Optional[dict] = {}

class MediaItem(GenericItem):
hash_value: Optional[Any] = None

class VideoItem(MediaItem):
bucket: Optional[str] = None
outfile: Optional[str] = None

class Message(BaseModel):
body: GenericInput
response: Any
body: Union[MediaItem, VideoItem]
model_name: str
@root_validator(pre=True)
def set_body(cls, values):
body = values.get("body")
model_name = values.get("model_name")
if model_name == "video__Model":
values["body"] = VideoItem(**values["body"]).dict()
if model_name in ["audio__Model", "image__Model", "fptg__Model", "indian_sbert__Model", "mean_tokens__Model", "fasttext__Model"]:
values["body"] = MediaItem(**values["body"]).dict()
return values
8 changes: 4 additions & 4 deletions test/lib/model/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ def test_process_audio_success(self, mock_fingerprint_file, mock_request, mock_u

mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents))

audio = schemas.Message(body=schemas.AudioInput(id="123", callback_url="http://example.com/callback", url="https://example.com/audio.mp3"))
audio = schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "url": "https://example.com/audio.mp3"}, model_name="audio__Model")
result = self.audio_model.process(audio)
mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'})
mock_urlopen.assert_called_once_with(mock_request)
self.assertEqual(list, type(result["hash_value"]))
self.assertEqual(list, type(result))

@patch('urllib.request.urlopen')
@patch('urllib.request.Request')
Expand All @@ -45,11 +45,11 @@ def test_process_audio_failure(self, mock_decode_fingerprint, mock_fingerprint_f

mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents))

audio = schemas.Message(body=schemas.AudioInput(id="123", callback_url="http://example.com/callback", url="https://example.com/audio.mp3"))
audio = schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "url": "https://example.com/audio.mp3"}, model_name="audio__Model")
result = self.audio_model.process(audio)
mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'})
mock_urlopen.assert_called_once_with(mock_request)
self.assertEqual([], result["hash_value"])
self.assertEqual([], result)

if __name__ == '__main__':
unittest.main()
4 changes: 2 additions & 2 deletions test/lib/model/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ def test_respond(self, mock_fasttext_load_model, mock_hf_hub_download):
mock_fasttext_load_model.return_value = self.mock_model
self.mock_model.predict.return_value = (['__label__eng_Latn'], np.array([0.9]))
model = FasttextModel() # Now it uses mocked functions
query = [schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Hello, how are you?"))]
query = [schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, model_name="fasttext__Model")]
response = model.respond(query)
self.assertEqual(len(response), 1)
self.assertEqual(response[0].response, {'language': 'en', 'script': None, 'score': 0.9})
self.assertEqual(response[0].body.hash_value, {'language': 'en', 'script': None, 'score': 0.9})

if __name__ == '__main__':
unittest.main()
6 changes: 3 additions & 3 deletions test/lib/model/test_fptg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setUp(self):
self.mock_model = MagicMock()

def test_vectorize(self):
texts = [schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="I'm doing great, thanks!"))]
texts = [schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, model_name="fptg__Model"), schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, model_name="fptg__Model")]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)
Expand All @@ -22,11 +22,11 @@ def test_vectorize(self):
self.assertEqual(vectors[1], [7, 8, 9])

def test_respond(self):
query = schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Anong pangalan mo?"))
query = schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "Anong pangalan mo?"}, model_name="fptg__Model")
self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])
response = self.model.respond(query)
self.assertEqual(len(response), 1)
self.assertEqual(response[0].response, [1, 2, 3])
self.assertEqual(response[0].body.hash_value, [1, 2, 3])

if __name__ == '__main__':
unittest.main()
6 changes: 3 additions & 3 deletions test/lib/model/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setUp(self):
self.mock_model = MagicMock()

def test_vectorize(self):
texts = [schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="I'm doing great, thanks!"))]
texts = [schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, model_name="fptg__Model"), schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, model_name="fptg__Model")]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)
Expand All @@ -22,11 +22,11 @@ def test_vectorize(self):
self.assertEqual(vectors[1], [7, 8, 9])

def test_respond(self):
query = schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Anong pangalan mo?"))
query = schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "Anong pangalan mo?"}, model_name="fptg__Model")
self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])
response = self.model.respond(query)
self.assertEqual(len(response), 1)
self.assertEqual(response[0].response, [1, 2, 3])
self.assertEqual(response[0].body.hash_value, [1, 2, 3])

if __name__ == '__main__':
unittest.main()
8 changes: 4 additions & 4 deletions test/lib/model/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ def test_get_iobytes_for_image(self, mock_urlopen):
mock_response = Mock()
mock_response.read.return_value = image_content
mock_urlopen.return_value = mock_response
image = schemas.Message(body=schemas.ImageInput(id="123", callback_url="http://example.com?callback", url="http://example.com/image.jpg"))
image = schemas.Message(body={"id": "123", "callback_url": "http://example.com?callback", "url": "http://example.com/image.jpg"}, model_name="audio__Model")
result = Model().get_iobytes_for_image(image)
self.assertIsInstance(result, io.BytesIO)
self.assertEqual(result.read(), image_content)

@patch("urllib.request.urlopen")
def test_get_iobytes_for_image_raises_error(self, mock_urlopen):
mock_urlopen.side_effect = URLError('test error')
image = schemas.Message(body=schemas.ImageInput(id="123", callback_url="http://example.com?callback", url="http://example.com/image.jpg"))
image = schemas.Message(body={"id": "123", "callback_url": "http://example.com?callback", "url": "http://example.com/image.jpg"}, model_name="audio__Model")
with self.assertRaises(URLError):
Model().get_iobytes_for_image(image)

Expand All @@ -42,9 +42,9 @@ def test_get_iobytes_for_image_raises_error(self, mock_urlopen):
def test_process(self, mock_compute_pdq, mock_get_iobytes_for_image):
mock_compute_pdq.return_value = "1001"
mock_get_iobytes_for_image.return_value = io.BytesIO(b"image_bytes")
image = schemas.Message(body=schemas.ImageInput(id="123", callback_url="http://example.com?callback", url="http://example.com/image.jpg"))
image = schemas.Message(body={"id": "123", "callback_url": "http://example.com?callback", "url": "http://example.com/image.jpg"}, model_name="audio__Model")
result = Model().process(image)
self.assertEqual(result, {"hash_value": "1001"})
self.assertEqual(result, "1001")


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions test/lib/model/test_indian_sbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setUp(self):
self.mock_model = MagicMock()

def test_vectorize(self):
texts = [schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="I'm doing great, thanks!"))]
texts = [schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, model_name="indian_sbert__Model"), schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, model_name="indian_sbert__Model")]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)
Expand All @@ -22,11 +22,11 @@ def test_vectorize(self):
self.assertEqual(vectors[1], [7, 8, 9])

def test_respond(self):
query = schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="What is the capital of India?"))
query = schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "What is the capital of India?"}, model_name="indian_sbert__Model")
self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])
response = self.model.respond(query)
self.assertEqual(len(response), 1)
self.assertEqual(response[0].response, [1, 2, 3])
self.assertEqual(response[0].body.hash_value, [1, 2, 3])

if __name__ == '__main__':
unittest.main()
6 changes: 3 additions & 3 deletions test/lib/model/test_meantokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setUp(self):
self.mock_model = MagicMock()

def test_vectorize(self):
texts = [schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="I'm doing great, thanks!"))]
texts = [schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, model_name="mean_tokens__Model"), schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, model_name="mean_tokens__Model")]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)
Expand All @@ -22,11 +22,11 @@ def test_vectorize(self):
self.assertEqual(vectors[1], [7, 8, 9])

def test_respond(self):
query = schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="What is the capital of France?"))
query = schemas.Message(body={"id": "123", "callback_url": "http://example.com/callback", "text": "What is the capital of France?"}, model_name="mean_tokens__Model")
self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])
response = self.model.respond(query)
self.assertEqual(len(response), 1)
self.assertEqual(response[0].response, [1, 2, 3])
self.assertEqual(response[0].body.hash_value, [1, 2, 3])

if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion test/lib/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# class TestModel(unittest.TestCase):
# def test_respond(self):
# model = Model()
# self.assertEqual(model.respond(schemas.Message(body=schemas.TextInput(id='123', callback_url="http://example.com/callback", text="hello"))), model.respond(schemas.Message(body=schemas.TextInput(id='123', callback_url="http://example.com/callback", text="hello"), response=[])))
# self.assertEqual(model.respond(schemas.Message(body=schemas.GenericItem(id='123', callback_url="http://example.com/callback", text="hello"))), model.respond(schemas.Message(body=schemas.GenericItem(id='123', callback_url="http://example.com/callback", text="hello"), response=[])))
#
# if __name__ == '__main__':
# unittest.main()
Loading

0 comments on commit 4ed34f0

Please sign in to comment.