From 5372d570a96d47a4400f32264cadbb9082861e7c Mon Sep 17 00:00:00 2001 From: computermacgyver Date: Sat, 11 Nov 2023 16:21:37 +0000 Subject: [PATCH] Large refactor * Create GenericImageModel and have pdq and sscd inherit form it * Move model loading for sscd to __init__ * Update requirements.txt for sscd --- .env_file | 2 + Dockerfile | 3 +- lib/model/image.py | 37 ------------ lib/model/image_sscd.py | 58 +++++-------------- lib/queue/worker.py | 10 ++-- requirements.txt | 4 +- .../{test_image.py => test_image_pdq.py} | 0 7 files changed, 26 insertions(+), 88 deletions(-) delete mode 100644 lib/model/image.py rename test/lib/model/{test_image.py => test_image_pdq.py} (100%) diff --git a/.env_file b/.env_file index 1b880ab..9976955 100644 --- a/.env_file +++ b/.env_file @@ -3,5 +3,7 @@ PRESTO_PORT=8000 DEPLOY_ENV=local # MODEL_NAME=mean_tokens.Model MODEL_NAME=audio.Model +# MODEL_NAME=image_sscd.Model +# MODEL_NAME=image_pdq.Model AWS_ACCESS_KEY_ID=SOMETHING AWS_SECRET_ACCESS_KEY=OTHERTHING diff --git a/Dockerfile b/Dockerfile index ac66fed..70607fc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,8 +31,7 @@ RUN pip install pact-python RUN pip install --no-cache-dir -r requirements.txt RUN cd threatexchange/pdq/python && pip install . -RUN mkdir models_files -RUN wget https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_mixup.torchscript.pt +RUN wget -O "sscd_disc_mixup.torchscript.pt" "https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_mixup.torchscript.pt" COPY . . CMD ["make", "run"] diff --git a/lib/model/image.py b/lib/model/image.py deleted file mode 100644 index 346e320..0000000 --- a/lib/model/image.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Dict -import io -import urllib.request - -from lib.model.model import Model - -from pdqhashing.hasher.pdq_hasher import PDQHasher -from lib import schemas - -class Model(Model): - def compute_pdq(self, iobytes: io.BytesIO) -> str: - """Compute perceptual hash using ImageHash library - :param im: Numpy.ndarray - :returns: Imagehash.ImageHash - """ - pdq_hasher = PDQHasher() - hash_and_qual = pdq_hasher.fromBufferedImage(iobytes) - return hash_and_qual.getHash().dumpBitsFlat() - - def get_iobytes_for_image(self, image: schemas.Message) -> io.BytesIO: - """ - Read file as bytes after requesting based on URL. - """ - return io.BytesIO( - urllib.request.urlopen( - urllib.request.Request( - image.body.url, - headers={'User-Agent': 'Mozilla/5.0'} - ) - ).read() - ) - - def process(self, image: schemas.Message) -> schemas.GenericItem: - """ - Generic function for returning the actual response. - """ - return self.compute_pdq(self.get_iobytes_for_image(image)) diff --git a/lib/model/image_sscd.py b/lib/model/image_sscd.py index 5f2fa6c..d462c18 100644 --- a/lib/model/image_sscd.py +++ b/lib/model/image_sscd.py @@ -1,27 +1,24 @@ from typing import Dict import io -import urllib.request -from lib.model.model import Model - -from pdqhashing.hasher.pdq_hasher import PDQHasher +from lib.model.generic_image import GenericImageModel from lib import schemas from torchvision import transforms -from PIL import Image import torch from lib.logger import logger -import requests import numpy as np +from PIL import Image + +class Model(GenericImageModel): + def __init__(self): + super().__init__() + self.model = torch.jit.load("sscd_disc_mixup.torchscript.pt") -class Model(Model): - def compute_sscd(self, image_url: str) -> str: + def compute_sscd(self, iobytes: io.BytesIO) -> str: """Compute perceptual hash using ImageHash library - :param im: Numpy.ndarray - :returns: Imagehash.ImageHash + :param im: Numpy.ndarray #FIXME + :returns: Imagehash.ImageHash #FIXME """ - # pdq_hasher = PDQHasher() - # hash_and_qual = pdq_hasher.fromBufferedImage(iobytes) - # return hash_and_qual.getHash().dumpBitsFlat() normalize = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ) @@ -36,35 +33,10 @@ def compute_sscd(self, image_url: str) -> str: normalize, ]) - model = torch.jit.load("sscd_disc_mixup.torchscript.pt") - # img = Image.open(image_file_path).convert('RGB') - - response = requests.get(image_url) - img = Image.open(io.BytesIO(response.content)) - # img = Image.open(image.body.url).convert('RGB') - - batch = small_288(img).unsqueeze(0) - embedding = model(batch)[0, :] + image = Image.open(iobytes) + batch = small_288(image).unsqueeze(0) + embedding = self.model(batch)[0, :] return np.asarray(embedding.detach().numpy()).tolist() - def get_iobytes_for_image(self, image: schemas.Message) -> io.BytesIO: - """ - Read file as bytes after requesting based on URL. - """ - return io.BytesIO( - urllib.request.urlopen( - urllib.request.Request( - image.body.url, - headers={'User-Agent': 'Mozilla/5.0'} - ) - ).read() - ) - - def process(self, image: schemas.Message) -> schemas.ImageOutput: - """ - Generic function for returning the actual response. - """ - - # get_image_embeddings("example-image-airplane1.png", - # "/content/sscd-copy-detection/models_files/sscd_disc_mixup.torchscript.pt") - return {"embeddings": self.compute_sscd(image.body.url)} + def compute_imagehash(self, iobytes: io.BytesIO) -> str: + return self.compute_sscd(iobytes) diff --git a/lib/queue/worker.py b/lib/queue/worker.py index b45bbcb..4f6aacf 100644 --- a/lib/queue/worker.py +++ b/lib/queue/worker.py @@ -51,10 +51,12 @@ def safely_respond(self, model: Model) -> List[schemas.Message]: responses = [] if messages_with_queues: logger.debug(f"About to respond to: ({messages_with_queues})") - try: - responses = model.respond([schemas.Message(**{**json.loads(message.body), **{"model_name": model.model_name}}) for message, queue in messages_with_queues]) - except Exception as e: - logger.error(e) + #try: + responses = model.respond([schemas.Message(**{**json.loads(message.body), **{"model_name": model.model_name}}) for message, queue in messages_with_queues]) + logger.info("!!!!") + logger.info(responses) + #except Exception as e: + # logger.error(e) self.delete_messages(messages_with_queues) return responses diff --git a/requirements.txt b/requirements.txt index 7afdb62..730e91b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,5 +15,5 @@ pytest==7.4.0 sentry-sdk==1.30.0 pytorch-lightning==1.5.10 lightning-bolts==0.4.0 -torch -torchvision +torch==1.9.0 +torchvision==0.10.0 diff --git a/test/lib/model/test_image.py b/test/lib/model/test_image_pdq.py similarity index 100% rename from test/lib/model/test_image.py rename to test/lib/model/test_image_pdq.py