Skip to content

Commit

Permalink
Large refactor
Browse files Browse the repository at this point in the history
* Create GenericImageModel and have pdq and sscd inherit from it
* Move model loading for sscd to __init__
* Update requirements.txt for sscd
  • Loading branch information
computermacgyver committed Nov 11, 2023
1 parent 11fd99c commit 5372d57
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 88 deletions.
2 changes: 2 additions & 0 deletions .env_file
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@ PRESTO_PORT=8000
DEPLOY_ENV=local
# MODEL_NAME=mean_tokens.Model
MODEL_NAME=audio.Model
# MODEL_NAME=image_sscd.Model
# MODEL_NAME=image_pdq.Model
AWS_ACCESS_KEY_ID=SOMETHING
AWS_SECRET_ACCESS_KEY=OTHERTHING
3 changes: 1 addition & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ RUN pip install pact-python
RUN pip install --no-cache-dir -r requirements.txt
RUN cd threatexchange/pdq/python && pip install .

RUN mkdir models_files
RUN wget https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_mixup.torchscript.pt
RUN wget -O "sscd_disc_mixup.torchscript.pt" "https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_mixup.torchscript.pt"

COPY . .
CMD ["make", "run"]
37 changes: 0 additions & 37 deletions lib/model/image.py

This file was deleted.

58 changes: 15 additions & 43 deletions lib/model/image_sscd.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,24 @@
from typing import Dict
import io
import urllib.request

from lib.model.model import Model

from pdqhashing.hasher.pdq_hasher import PDQHasher
from lib.model.generic_image import GenericImageModel
from lib import schemas
from torchvision import transforms
from PIL import Image
import torch
from lib.logger import logger
import requests
import numpy as np
from PIL import Image

class Model(GenericImageModel):
def __init__(self):
super().__init__()
self.model = torch.jit.load("sscd_disc_mixup.torchscript.pt")

class Model(Model):
def compute_sscd(self, image_url: str) -> str:
def compute_sscd(self, iobytes: io.BytesIO) -> str:
"""Compute perceptual hash using ImageHash library
:param im: Numpy.ndarray
:returns: Imagehash.ImageHash
:param im: Numpy.ndarray #FIXME
:returns: Imagehash.ImageHash #FIXME
"""
# pdq_hasher = PDQHasher()
# hash_and_qual = pdq_hasher.fromBufferedImage(iobytes)
# return hash_and_qual.getHash().dumpBitsFlat()
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
)
Expand All @@ -36,35 +33,10 @@ def compute_sscd(self, image_url: str) -> str:
normalize,
])

model = torch.jit.load("sscd_disc_mixup.torchscript.pt")
# img = Image.open(image_file_path).convert('RGB')

response = requests.get(image_url)
img = Image.open(io.BytesIO(response.content))
# img = Image.open(image.body.url).convert('RGB')

batch = small_288(img).unsqueeze(0)
embedding = model(batch)[0, :]
image = Image.open(iobytes)
batch = small_288(image).unsqueeze(0)
embedding = self.model(batch)[0, :]
return np.asarray(embedding.detach().numpy()).tolist()

def get_iobytes_for_image(self, image: schemas.Message) -> io.BytesIO:
"""
Read file as bytes after requesting based on URL.
"""
return io.BytesIO(
urllib.request.urlopen(
urllib.request.Request(
image.body.url,
headers={'User-Agent': 'Mozilla/5.0'}
)
).read()
)

def process(self, image: schemas.Message) -> schemas.ImageOutput:
"""
Generic function for returning the actual response.
"""

# get_image_embeddings("example-image-airplane1.png",
# "/content/sscd-copy-detection/models_files/sscd_disc_mixup.torchscript.pt")
return {"embeddings": self.compute_sscd(image.body.url)}
def compute_imagehash(self, iobytes: io.BytesIO) -> str:
return self.compute_sscd(iobytes)
10 changes: 6 additions & 4 deletions lib/queue/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ def safely_respond(self, model: Model) -> List[schemas.Message]:
responses = []
if messages_with_queues:
logger.debug(f"About to respond to: ({messages_with_queues})")
try:
responses = model.respond([schemas.Message(**{**json.loads(message.body), **{"model_name": model.model_name}}) for message, queue in messages_with_queues])
except Exception as e:
logger.error(e)
#try:
responses = model.respond([schemas.Message(**{**json.loads(message.body), **{"model_name": model.model_name}}) for message, queue in messages_with_queues])
logger.info("!!!!")
logger.info(responses)
#except Exception as e:
# logger.error(e)
self.delete_messages(messages_with_queues)
return responses

4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ pytest==7.4.0
sentry-sdk==1.30.0
pytorch-lightning==1.5.10
lightning-bolts==0.4.0
torch
torchvision
torch==1.9.0
torchvision==0.10.0
File renamed without changes.

0 comments on commit 5372d57

Please sign in to comment.