Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding nullable column to postgres images table with type Vector to s… #357

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y ffmpeg cmake swig libavcodec-dev libavformat-dev
RUN ln -s /usr/bin/ffmpeg /usr/local/bin/ffmpeg

#RUN git clone https://github.com/pgvector/pgvector-python.git
#RUN cd pgvector-python && pip install -r requirements.txt

# Copy necessary threatexchange folders
COPY ./threatexchange/tmk/cpp /app/threatexchange/tmk/cpp
COPY ./threatexchange/pdq/cpp /app/threatexchange/pdq/cpp
Expand Down
13 changes: 12 additions & 1 deletion app/main/controller/image_similarity_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def delete(self):

@api.response(200, 'image signature successfully stored in the similarity database.')
@api.doc('Store an image signature in the similarity database')
@api.doc(params={
'url': 'image URL to be stored or queried for similarity',
'models': 'A list of image similarity models to use. Supported models are phash, pdq, and sscd',
'context': 'context'
})
@api.expect(image_similarity_request, validate=True)
def post(self):
return similarity.add_item(request.json, "image")
Expand All @@ -36,10 +41,16 @@ def request_package(self, request):
"context": self.get_from_args_or_json(request, 'context'),
"threshold": self.get_from_args_or_json(request, 'threshold'),
"limit": (self.get_from_args_or_json(request, 'limit') or similarity.DEFAULT_SEARCH_LIMIT),
"models": (self.get_from_args_or_json(request, 'models') or [app.config['IMAGE_MODEL']]),
}

@api.response(200, 'image similarity successfully queried.')
@api.doc('Make a image similarity query. Note that we currently require GET requests with a JSON body rather than embedded params in the URL. You can achieve this via curl -X GET -H "Content-type: application/json" -H "Accept: application/json" -d \'{"url":"http://some.link/video.mp4", "threshold": 0.5}\' "http://[ALEGRE_HOST]/image/similarity"')
@api.doc(params={'url': 'image URL to be stored or queried for similarity', 'threshold': 'minimum score to consider, between 0.0 and 1.0 (defaults to 0.9)', 'context': 'context'})
@api.doc(params={
'url': 'image URL to be stored or queried for similarity',
'models': 'A list of image similarity models to use. Supported models are "phash", "pdq", and "sscd"',
'threshold': 'minimum score to consider, between 0.0 and 1.0 (defaults to 0.9)',
'context': 'context'
})
def get(self):
return similarity.get_similar_items(self.request_package(request), "image")
69 changes: 63 additions & 6 deletions app/main/lib/image_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def add_image(save_params):
try:
if save_params.get("doc_id"):
delete_image(save_params)
image = ImageModel.from_url(save_params['url'], save_params.get('doc_id'), save_params['context'], save_params.get("created_at"))
image = ImageModel.from_url(save_params['url'], save_params.get('doc_id'), save_params['context'], save_params.get("created_at"), save_params.get('models'))
save(image)
return {
'success': True
Expand All @@ -65,19 +65,31 @@ def search_image(params):
context = params.get("context")
threshold = params.get("threshold")
limit = params.get("limit")
models = params.get("models")
# if models is None:
# models = ['phash']
try:
models = [m.lower() for m in models]
except:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No exception type(s) specified

app.logger.warn(f"Unable to lowercase list of models in search_image. {models}")

if not context:
context = {}
if not threshold:
threshold = 0.9
if url:
image = ImageModel.from_url(url, None)
result = []
image = ImageModel.from_url(url, None,models = models)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No space allowed around keyword argument assignment

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly one space required after comma

model=app.config['IMAGE_MODEL']
if model and model.lower()=="pdq":
if "pdq" in models:
app.logger.info(f"Searching with PDQ.")
result = search_by_pdq(image.pdq, threshold, context, limit)
else:
result += search_by_pdq(image.pdq, threshold, context, limit)
if "sscd" in models: # and image.sscd is not None:
app.logger.info(f"Searching with sscd.")
result += search_by_sscd(image.sscd, threshold, context, limit)
if "phash" in models:
app.logger.info(f"Searching with phash.")
result = search_by_phash(image.phash, threshold, context, limit)
result += search_by_phash(image.phash, threshold, context, limit)
else:
result = search_by_context(context, limit)
return {
Expand Down Expand Up @@ -196,3 +208,48 @@ def search_by_pdq(pdq, threshold, context, limit=None):
except Exception as e:
db.session.rollback()
raise e


@tenacity.retry(wait=tenacity.wait_fixed(0.5), stop=tenacity.stop_after_delay(5), after=_after_log)
def search_by_sscd(sscd, threshold, context, limit=None):

try:
context_query, context_hash = get_context_query(context)
# operator <=> is cosine distance (1 - cosine distance to give the similarity)
if context_query:
cmd = """
SELECT * FROM (
SELECT id, sha256, sscd, url, context, 1 - (sscd <=> :sscd)
AS score FROM images
) f
WHERE score >= :threshold
AND
"""+context_query+"""
ORDER BY score DESC
"""
else:
cmd = """
SELECT * FROM (
SELECT id, sha256, sscd, url, context, 1 - (sscd <=> :sscd)
AS score FROM images
) f
WHERE score >= :threshold
ORDER BY score DESC
"""
if limit:
cmd = cmd+" LIMIT :limit"
matches = db.session.execute(text(cmd), dict(**{
'sscd': str(sscd),
'threshold': threshold,
'limit': limit,
}, **context_hash)).fetchall()
keys = ('id', 'sha256', 'sscd', 'url', 'context', 'score')
rows = []
for values in matches:
row = dict(zip(keys, values))
row["model"] = "image/sscd"
rows.append(row)
return rows
except Exception as e:
db.session.rollback()
raise e
2 changes: 2 additions & 0 deletions app/main/lib/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def callback_add_item(item, similarity_type):
if similarity_type == "audio":
response = audio_model().add(item)
app.logger.info(f"[Alegre Similarity] CallbackAddItem: [Item {item}, Similarity type: {similarity_type}] Response looks like {response}")
elif similarity_type == "image_sscd__Model":
return None
else:
app.logger.warning(f"[Alegre Similarity] InvalidCallbackAddItem: [Item {item}, Similarity type: {similarity_type}] Response looks like {response}")
return response
Expand Down
51 changes: 41 additions & 10 deletions app/main/model/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@

from app.main import db
from app.main.lib.image_hash import compute_phash_int, sha256_stream, compute_phash_int, compute_pdq
from pgvector.sqlalchemy import Vector

from app.main.lib.presto import Presto
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imports from package app are not grouped

import json
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

standard import import json should be placed before from flask import current_app as app

logging.basicConfig(level=logging.INFO)

class ImageModel(db.Model):
Expand All @@ -22,7 +25,7 @@ class ImageModel(db.Model):
doc_id = db.Column(db.String(64, convert_unicode=True), nullable=True, index=True, unique=True)
phash = db.Column(db.BigInteger, nullable=True, index=True)
pdq = db.Column(BIT(256), nullable=True, index=True)

sscd = db.Column(Vector(512), nullable=True, index=True)
url = db.Column(db.String(255, convert_unicode=True), nullable=False, index=True)
context = db.Column(JSONB(), default=[], nullable=False)
created_at = db.Column(db.DateTime, nullable=True)
Expand All @@ -31,23 +34,51 @@ class ImageModel(db.Model):
)

@staticmethod
def from_url(url, doc_id, context={}, created_at=None):
def from_url(url, doc_id, context={}, created_at=None, models=None):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too many local variables (18/15)

"""Fetch an image from a URL and load it
:param url: Image URL
:returns: ImageModel object
"""
app.logger.info(f"Starting image hash for doc_id {doc_id}.")
app.logger.info(f"Starting image hash for doc_id {doc_id} and models {models}.")
ImageFile.LOAD_TRUNCATED_IMAGES = True
remote_request = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
remote_response = urllib.request.urlopen(remote_request)
raw = remote_response.read()
im = Image.open(io.BytesIO(raw)).convert('RGB')
phash = compute_phash_int(im)
try:
pdq = compute_pdq(io.BytesIO(raw))
if not isinstance(models, list):
models = [models]
models = [m.lower() for m in models]
except:
pdq=None
e = sys.exc_info()[0]
app.logger.error(f"PDQ failure: {e}")
app.logger.warn(f"Unable to lowercase list of models in from_url. {models}")

phash = None
pdq = None
sscd = None
if "phash" in models:
im = Image.open(io.BytesIO(raw)).convert('RGB')
phash = compute_phash_int(im)
if "pdq" in models:
try:
pdq = compute_pdq(io.BytesIO(raw))
except:
pdq=None
e = sys.exc_info()[0]
app.logger.error(f"PDQ failure: {e}")
if "sscd" in models:
try:
# Call presto to calculate SSCD embeddings
callback_url = Presto.add_item_callback_url(app.config['ALEGRE_HOST'], "image_sscd__Model")
model_response_package = {"url": url
, "command": "add_item"}
response = Presto.send_request(app.config['PRESTO_HOST'], "image_sscd__Model", callback_url,
model_response_package).text
response = json.loads(response)
result = Presto.blocked_response(response, "image_sscd__Model")
sscd = result['body']['hash_value']
except:
sscd = None
e = sys.exc_info()[0]
app.logger.error(f"SSCD failure: {e}")

sha256 = sha256_stream(io.BytesIO(raw))
return ImageModel(sha256=sha256, phash=phash, pdq=pdq, url=url, context=context, doc_id=doc_id, created_at=created_at)
return ImageModel(sha256=sha256, phash=phash, pdq=pdq, url=url, context=context, doc_id=doc_id, created_at=created_at, sscd=sscd)
Loading