Skip to content

Commit

Permalink
rework response structure to be more consistent across modalities
Browse files Browse the repository at this point in the history
  • Loading branch information
DGaffney committed Nov 3, 2023
1 parent e7fec72 commit 5712a19
Show file tree
Hide file tree
Showing 16 changed files with 40 additions and 80 deletions.
2 changes: 1 addition & 1 deletion lib/model/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self):
self.model = fasttext.load_model(model_path)


def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.TextOutput]:
def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.MediaItem]:
"""
Force messages as list of messages in case we get a singular item. Then, run fingerprint routine.
Respond can probably be genericized across all models.
Expand Down
2 changes: 1 addition & 1 deletion lib/model/generic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self, model_name: str):
if model_name:
self.model = SentenceTransformer(model_name, cache_folder=os.getenv("MODEL_DIR", "./models"))

def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.TextOutput]:
def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.MediaItem]:
"""
Force messages as list of messages in case we get a singular item. Then, run fingerprint routine.
Respond can probably be genericized across all models.
Expand Down
2 changes: 1 addition & 1 deletion lib/model/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_iobytes_for_image(self, image: schemas.Message) -> io.BytesIO:
).read()
)

def process(self, image: schemas.Message) -> schemas.ImageOutput:
def process(self, image: schemas.Message) -> schemas.MediaItem:
"""
Generic function for returning the actual response.
"""
Expand Down
2 changes: 1 addition & 1 deletion lib/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> Li
if not isinstance(messages, list):
messages = [messages]
for message in messages:
message.response = self.process(message)
message.body.hash_value = self.process(message)
return messages

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion lib/model/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def tmk_bucket(self) -> str:
"""
return "presto_tmk_videos"

def process(self, video: schemas.Message) -> schemas.VideoOutput:
def process(self, video: schemas.Message) -> schemas.VideoItem:
"""
Main fingerprinting routine - download video to disk, get short hash,
then calculate larger TMK hash and upload that to S3.
Expand Down
62 changes: 11 additions & 51 deletions lib/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,59 +3,19 @@

# Output hash values can be of different types.
HashValue = Union[List[float], str, int]
class TextInput(BaseModel):
id: str
callback_url: str
text: str

class TextOutput(BaseModel):
id: str
callback_url: str
text: str
hash_value: HashValue

class VideoInput(BaseModel):
id: str
callback_url: str
url: str

class VideoOutput(BaseModel):
id: str
callback_url: str
url: str
bucket: str
outfile: str
hash_value: HashValue

class AudioInput(BaseModel):
id: str
callback_url: str
url: str

class AudioOutput(BaseModel):
id: str
callback_url: str
url: str
hash_value: HashValue

class ImageInput(BaseModel):
id: str
callback_url: str
url: str

class ImageOutput(BaseModel):
id: str
callback_url: str
url: str
hash_value: HashValue

class GenericInput(BaseModel):
id: str
callback_url: str
class GenericItem(BaseModel):
id: Optional[str] = None
callback_url: Optional[str] = None
url: Optional[str] = None
text: Optional[str] = None
raw: Optional[dict] = {}

class MediaItem(GenericItem):
hash_value: Optional[HashValue] = None

class VideoItem(MediaItem):
bucket: Optional[str] = None
outfile: Optional[str] = None

class Message(BaseModel):
body: GenericInput
response: Any
body: Union[GenericItem, MediaItem, VideoItem]
4 changes: 2 additions & 2 deletions test/lib/model/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_process_audio_success(self, mock_fingerprint_file, mock_request, mock_u

mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents))

audio = schemas.Message(body=schemas.AudioInput(id="123", callback_url="http://example.com/callback", url="https://example.com/audio.mp3"))
audio = schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", url="https://example.com/audio.mp3"))
result = self.audio_model.process(audio)
mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'})
mock_urlopen.assert_called_once_with(mock_request)
Expand All @@ -45,7 +45,7 @@ def test_process_audio_failure(self, mock_decode_fingerprint, mock_fingerprint_f

mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents))

audio = schemas.Message(body=schemas.AudioInput(id="123", callback_url="http://example.com/callback", url="https://example.com/audio.mp3"))
audio = schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", url="https://example.com/audio.mp3"))
result = self.audio_model.process(audio)
mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'})
mock_urlopen.assert_called_once_with(mock_request)
Expand Down
4 changes: 2 additions & 2 deletions test/lib/model/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ def test_respond(self, mock_fasttext_load_model, mock_hf_hub_download):
mock_fasttext_load_model.return_value = self.mock_model
self.mock_model.predict.return_value = (['__label__eng_Latn'], np.array([0.9]))
model = FasttextModel() # Now it uses mocked functions
query = [schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Hello, how are you?"))]
query = [schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", text="Hello, how are you?"))]
response = model.respond(query)
self.assertEqual(len(response), 1)
self.assertEqual(response[0].response, {'language': 'en', 'script': None, 'score': 0.9})
self.assertEqual(response[0].body.hash_value, {'language': 'en', 'script': None, 'score': 0.9})

if __name__ == '__main__':
unittest.main()
6 changes: 3 additions & 3 deletions test/lib/model/test_fptg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setUp(self):
self.mock_model = MagicMock()

def test_vectorize(self):
texts = [schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="I'm doing great, thanks!"))]
texts = [schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", text="I'm doing great, thanks!"))]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)
Expand All @@ -22,11 +22,11 @@ def test_vectorize(self):
self.assertEqual(vectors[1], [7, 8, 9])

def test_respond(self):
query = schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Anong pangalan mo?"))
query = schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", text="Anong pangalan mo?"))
self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])
response = self.model.respond(query)
self.assertEqual(len(response), 1)
self.assertEqual(response[0].response, [1, 2, 3])
self.assertEqual(response[0].body.hash_value, [1, 2, 3])

if __name__ == '__main__':
unittest.main()
6 changes: 3 additions & 3 deletions test/lib/model/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setUp(self):
self.mock_model = MagicMock()

def test_vectorize(self):
texts = [schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="I'm doing great, thanks!"))]
texts = [schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", text="I'm doing great, thanks!"))]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)
Expand All @@ -22,11 +22,11 @@ def test_vectorize(self):
self.assertEqual(vectors[1], [7, 8, 9])

def test_respond(self):
query = schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Anong pangalan mo?"))
query = schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", text="Anong pangalan mo?"))
self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])
response = self.model.respond(query)
self.assertEqual(len(response), 1)
self.assertEqual(response[0].response, [1, 2, 3])
self.assertEqual(response[0].body.hash_value, [1, 2, 3])

if __name__ == '__main__':
unittest.main()
6 changes: 3 additions & 3 deletions test/lib/model/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ def test_get_iobytes_for_image(self, mock_urlopen):
mock_response = Mock()
mock_response.read.return_value = image_content
mock_urlopen.return_value = mock_response
image = schemas.Message(body=schemas.ImageInput(id="123", callback_url="http://example.com?callback", url="http://example.com/image.jpg"))
image = schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com?callback", url="http://example.com/image.jpg"))
result = Model().get_iobytes_for_image(image)
self.assertIsInstance(result, io.BytesIO)
self.assertEqual(result.read(), image_content)

@patch("urllib.request.urlopen")
def test_get_iobytes_for_image_raises_error(self, mock_urlopen):
mock_urlopen.side_effect = URLError('test error')
image = schemas.Message(body=schemas.ImageInput(id="123", callback_url="http://example.com?callback", url="http://example.com/image.jpg"))
image = schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com?callback", url="http://example.com/image.jpg"))
with self.assertRaises(URLError):
Model().get_iobytes_for_image(image)

Expand All @@ -42,7 +42,7 @@ def test_get_iobytes_for_image_raises_error(self, mock_urlopen):
def test_process(self, mock_compute_pdq, mock_get_iobytes_for_image):
mock_compute_pdq.return_value = "1001"
mock_get_iobytes_for_image.return_value = io.BytesIO(b"image_bytes")
image = schemas.Message(body=schemas.ImageInput(id="123", callback_url="http://example.com?callback", url="http://example.com/image.jpg"))
image = schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com?callback", url="http://example.com/image.jpg"))
result = Model().process(image)
self.assertEqual(result, {"hash_value": "1001"})

Expand Down
6 changes: 3 additions & 3 deletions test/lib/model/test_indian_sbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setUp(self):
self.mock_model = MagicMock()

def test_vectorize(self):
texts = [schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="I'm doing great, thanks!"))]
texts = [schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", text="I'm doing great, thanks!"))]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)
Expand All @@ -22,11 +22,11 @@ def test_vectorize(self):
self.assertEqual(vectors[1], [7, 8, 9])

def test_respond(self):
query = schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="What is the capital of India?"))
query = schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", text="What is the capital of India?"))
self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])
response = self.model.respond(query)
self.assertEqual(len(response), 1)
self.assertEqual(response[0].response, [1, 2, 3])
self.assertEqual(response[0].body.hash_value, [1, 2, 3])

if __name__ == '__main__':
unittest.main()
6 changes: 3 additions & 3 deletions test/lib/model/test_meantokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setUp(self):
self.mock_model = MagicMock()

def test_vectorize(self):
texts = [schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="I'm doing great, thanks!"))]
texts = [schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", text="Hello, how are you?")), schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", text="I'm doing great, thanks!"))]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)
Expand All @@ -22,11 +22,11 @@ def test_vectorize(self):
self.assertEqual(vectors[1], [7, 8, 9])

def test_respond(self):
query = schemas.Message(body=schemas.TextInput(id="123", callback_url="http://example.com/callback", text="What is the capital of France?"))
query = schemas.Message(body=schemas.MediaItem(id="123", callback_url="http://example.com/callback", text="What is the capital of France?"))
self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])
response = self.model.respond(query)
self.assertEqual(len(response), 1)
self.assertEqual(response[0].response, [1, 2, 3])
self.assertEqual(response[0].body.hash_value, [1, 2, 3])

if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion test/lib/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# class TestModel(unittest.TestCase):
# def test_respond(self):
# model = Model()
# self.assertEqual(model.respond(schemas.Message(body=schemas.TextInput(id='123', callback_url="http://example.com/callback", text="hello"))), model.respond(schemas.Message(body=schemas.TextInput(id='123', callback_url="http://example.com/callback", text="hello"), response=[])))
# self.assertEqual(model.respond(schemas.Message(body=schemas.MediaItem(id='123', callback_url="http://example.com/callback", text="hello"))), model.respond(schemas.Message(body=schemas.MediaItem(id='123', callback_url="http://example.com/callback", text="hello"), response=[])))
#
# if __name__ == '__main__':
# unittest.main()
6 changes: 3 additions & 3 deletions test/lib/model/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_process_video(self, mock_pathlib, mock_upload_file_to_s3,
mock_hash_video_output.getPureAverageFeature.return_value = "hash_value"
mock_hash_video.return_value = mock_hash_video_output
mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=video_contents))
self.video_model.process(schemas.Message(body=schemas.VideoInput(id="123", callback_url="http://blah.com?callback_id=123", url="http://example.com/video.mp4")))
self.video_model.process(schemas.Message(body=schemas.VideoItem(id="123", callback_url="http://blah.com?callback_id=123", url="http://example.com/video.mp4")))
mock_urlopen.assert_called_once()
mock_hash_video.assert_called_once_with(ANY, "/usr/local/bin/ffmpeg")

Expand All @@ -49,15 +49,15 @@ def test_tmk_program_name(self):
self.assertEqual(result, "PrestoVideoEncoder")

def test_respond_with_single_video(self):
video = schemas.Message(body=schemas.VideoInput(id="123", callback_url="http://blah.com?callback_id=123", url="http://example.com/video.mp4"))
video = schemas.Message(body=schemas.VideoItem(id="123", callback_url="http://blah.com?callback_id=123", url="http://example.com/video.mp4"))
mock_process = MagicMock()
self.video_model.process = mock_process
result = self.video_model.respond(video)
mock_process.assert_called_once_with(video)
self.assertEqual(result, [video])

def test_respond_with_multiple_videos(self):
videos = [schemas.Message(body=schemas.VideoInput(id="123", callback_url="http://blah.com?callback_id=123", url="http://example.com/video1.mp4")), schemas.Message(body=schemas.VideoInput(id="123", callback_url="http://blah.com?callback_id=123", url="http://example.com/video2.mp4"))]
videos = [schemas.Message(body=schemas.VideoItem(id="123", callback_url="http://blah.com?callback_id=123", url="http://example.com/video1.mp4")), schemas.Message(body=schemas.VideoItem(id="123", callback_url="http://blah.com?callback_id=123", url="http://example.com/video2.mp4"))]
mock_process = MagicMock()
self.video_model.process = mock_process
result = self.video_model.respond(videos)
Expand Down
2 changes: 1 addition & 1 deletion test/lib/queue/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_push_message(self):
# Call push_message
returned_message = self.queue.push_message(self.queue_name_output, message_to_push)
# Check if the message was correctly serialized and sent
self.mock_output_queue.send_message.assert_called_once_with(MessageBody='{"body": {"id": "1", "callback_url": "http://example.com", "url": null, "text": "This is a test", "raw": {}}, "response": null}')
self.mock_output_queue.send_message.assert_called_once_with(MessageBody='{"body": {"id": "1", "callback_url": "http://example.com", "url": null, "text": "This is a test", "raw": {}}}')
self.assertEqual(returned_message, message_to_push)

if __name__ == '__main__':
Expand Down

0 comments on commit 5712a19

Please sign in to comment.