diff --git a/Dockerfile b/Dockerfile
index 5a63d8df..e182b415 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -6,6 +6,9 @@ ENV DEBIAN_FRONTEND=noninteractive
 RUN apt-get update && apt-get install -y ffmpeg cmake swig libavcodec-dev libavformat-dev
 RUN ln -s /usr/bin/ffmpeg /usr/local/bin/ffmpeg
 
+#RUN git clone https://github.com/pgvector/pgvector-python.git
+#RUN cd pgvector-python && pip install -r requirements.txt
+
 # Copy necessary threatexchange folders
 COPY ./threatexchange/tmk/cpp /app/threatexchange/tmk/cpp
 COPY ./threatexchange/pdq/cpp /app/threatexchange/pdq/cpp
diff --git a/app/main/controller/image_similarity_controller.py b/app/main/controller/image_similarity_controller.py
index ae4604a9..6ee27ba4 100644
--- a/app/main/controller/image_similarity_controller.py
+++ b/app/main/controller/image_similarity_controller.py
@@ -23,6 +23,11 @@ def delete(self):
 
     @api.response(200, 'image signature successfully stored in the similarity database.')
     @api.doc('Store an image signature in the similarity database')
+    @api.doc(params={
+        'url': 'image URL to be stored or queried for similarity',
+        'models': 'A list of image similarity models to use. Supported models are phash, pdq, and sscd',
+        'context': 'context'
+    })
     @api.expect(image_similarity_request, validate=True)
     def post(self):
         return similarity.add_item(request.json, "image")
@@ -36,10 +41,16 @@ def request_package(self, request):
             "context": self.get_from_args_or_json(request, 'context'),
             "threshold": self.get_from_args_or_json(request, 'threshold'),
             "limit": (self.get_from_args_or_json(request, 'limit') or similarity.DEFAULT_SEARCH_LIMIT),
+            "models": (self.get_from_args_or_json(request, 'models') or [app.config['IMAGE_MODEL']]),
         }
 
     @api.response(200, 'image similarity successfully queried.')
     @api.doc('Make a image similarity query. Note that we currently require GET requests with a JSON body rather than embedded params in the URL. You can achieve this via curl -X GET -H "Content-type: application/json" -H "Accept: application/json" -d \'{"url":"http://some.link/video.mp4", "threshold": 0.5}\' "http://[ALEGRE_HOST]/image/similarity"')
-    @api.doc(params={'url': 'image URL to be stored or queried for similarity', 'threshold': 'minimum score to consider, between 0.0 and 1.0 (defaults to 0.9)', 'context': 'context'})
+    @api.doc(params={
+        'url': 'image URL to be stored or queried for similarity',
+        'models': 'A list of image similarity models to use. Supported models are "phash", "pdq", and "sscd"',
+        'threshold': 'minimum score to consider, between 0.0 and 1.0 (defaults to 0.9)',
+        'context': 'context'
+    })
     def get(self):
         return similarity.get_similar_items(self.request_package(request), "image")
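
Note: with the new models parameter documented above, the GET query from the existing curl example in the docstring can be extended in kind. Illustrative only, reusing the docstring's placeholder host and an assumed image URL:

    curl -X GET -H "Content-type: application/json" -H "Accept: application/json" \
      -d '{"url": "http://some.link/image.jpg", "threshold": 0.5, "models": ["phash", "sscd"]}' \
      "http://[ALEGRE_HOST]/image/similarity"

When models is omitted, request_package falls back to [app.config['IMAGE_MODEL']], so existing callers keep their current behaviour.
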
Supported models are "phash", "pdq", and "sscd"', + 'threshold': 'minimum score to consider, between 0.0 and 1.0 (defaults to 0.9)', + 'context': 'context' + }) def get(self): return similarity.get_similar_items(self.request_package(request), "image") diff --git a/app/main/lib/image_similarity.py b/app/main/lib/image_similarity.py index f5b17fd5..8c444546 100644 --- a/app/main/lib/image_similarity.py +++ b/app/main/lib/image_similarity.py @@ -51,7 +51,7 @@ def add_image(save_params): try: if save_params.get("doc_id"): delete_image(save_params) - image = ImageModel.from_url(save_params['url'], save_params.get('doc_id'), save_params['context'], save_params.get("created_at")) + image = ImageModel.from_url(save_params['url'], save_params.get('doc_id'), save_params['context'], save_params.get("created_at"), save_params.get('models')) save(image) return { 'success': True @@ -65,19 +65,31 @@ def search_image(params): context = params.get("context") threshold = params.get("threshold") limit = params.get("limit") + models = params.get("models") + # if models is None: + # models = ['phash'] + try: + models = [m.lower() for m in models] + except: + app.logger.warn(f"Unable to lowercase list of models in search_image. {models}") + if not context: context = {} if not threshold: threshold = 0.9 if url: - image = ImageModel.from_url(url, None) + result = [] + image = ImageModel.from_url(url, None,models = models) model=app.config['IMAGE_MODEL'] - if model and model.lower()=="pdq": + if "pdq" in models: app.logger.info(f"Searching with PDQ.") - result = search_by_pdq(image.pdq, threshold, context, limit) - else: + result += search_by_pdq(image.pdq, threshold, context, limit) + if "sscd" in models: # and image.sscd is not None: + app.logger.info(f"Searching with sscd.") + result += search_by_sscd(image.sscd, threshold, context, limit) + if "phash" in models: app.logger.info(f"Searching with phash.") - result = search_by_phash(image.phash, threshold, context, limit) + result += search_by_phash(image.phash, threshold, context, limit) else: result = search_by_context(context, limit) return { @@ -196,3 +208,48 @@ def search_by_pdq(pdq, threshold, context, limit=None): except Exception as e: db.session.rollback() raise e + + +@tenacity.retry(wait=tenacity.wait_fixed(0.5), stop=tenacity.stop_after_delay(5), after=_after_log) +def search_by_sscd(sscd, threshold, context, limit=None): + + try: + context_query, context_hash = get_context_query(context) + # operator <=> is cosine distance (1 - cosine distance to give the similarity) + if context_query: + cmd = """ + SELECT * FROM ( + SELECT id, sha256, sscd, url, context, 1 - (sscd <=> :sscd) + AS score FROM images + ) f + WHERE score >= :threshold + AND + """+context_query+""" + ORDER BY score DESC + """ + else: + cmd = """ + SELECT * FROM ( + SELECT id, sha256, sscd, url, context, 1 - (sscd <=> :sscd) + AS score FROM images + ) f + WHERE score >= :threshold + ORDER BY score DESC + """ + if limit: + cmd = cmd+" LIMIT :limit" + matches = db.session.execute(text(cmd), dict(**{ + 'sscd': str(sscd), + 'threshold': threshold, + 'limit': limit, + }, **context_hash)).fetchall() + keys = ('id', 'sha256', 'sscd', 'url', 'context', 'score') + rows = [] + for values in matches: + row = dict(zip(keys, values)) + row["model"] = "image/sscd" + rows.append(row) + return rows + except Exception as e: + db.session.rollback() + raise e diff --git a/app/main/lib/similarity.py b/app/main/lib/similarity.py index 8fd9f608..b51fa39a 100644 --- a/app/main/lib/similarity.py +++ 
diff --git a/app/main/lib/similarity.py b/app/main/lib/similarity.py
index 8fd9f608..b51fa39a 100644
--- a/app/main/lib/similarity.py
+++ b/app/main/lib/similarity.py
@@ -126,6 +126,8 @@ def callback_add_item(item, similarity_type):
     if similarity_type == "audio":
         response = audio_model().add(item)
         app.logger.info(f"[Alegre Similarity] CallbackAddItem: [Item {item}, Similarity type: {similarity_type}] Response looks like {response}")
+    elif similarity_type == "image_sscd__Model":
+        return None
     else:
         app.logger.warning(f"[Alegre Similarity] InvalidCallbackAddItem: [Item {item}, Similarity type: {similarity_type}] Response looks like {response}")
     return response
{models}") + + phash = None + pdq = None + sscd = None + if "phash" in models: + im = Image.open(io.BytesIO(raw)).convert('RGB') + phash = compute_phash_int(im) + if "pdq" in models: + try: + pdq = compute_pdq(io.BytesIO(raw)) + except: + pdq=None + e = sys.exc_info()[0] + app.logger.error(f"PDQ failure: {e}") + if "sscd" in models: + try: + # Call presto to calculate SSCD embeddings + callback_url = Presto.add_item_callback_url(app.config['ALEGRE_HOST'], "image_sscd__Model") + model_response_package = {"url": url + , "command": "add_item"} + response = Presto.send_request(app.config['PRESTO_HOST'], "image_sscd__Model", callback_url, + model_response_package).text + response = json.loads(response) + result = Presto.blocked_response(response, "image_sscd__Model") + sscd = result['body']['hash_value'] + except: + sscd = None + e = sys.exc_info()[0] + app.logger.error(f"SSCD failure: {e}") + sha256 = sha256_stream(io.BytesIO(raw)) - return ImageModel(sha256=sha256, phash=phash, pdq=pdq, url=url, context=context, doc_id=doc_id, created_at=created_at) + return ImageModel(sha256=sha256, phash=phash, pdq=pdq, url=url, context=context, doc_id=doc_id, created_at=created_at, sscd=sscd) diff --git a/app/test/test_image_similarity.py b/app/test/test_image_similarity.py index e9b89c92..733222c9 100644 --- a/app/test/test_image_similarity.py +++ b/app/test/test_image_similarity.py @@ -46,23 +46,23 @@ def test_bit_count(self): self.assertEqual(result['test_count'], 0.5625) def test_truncated_image_fetch(self): - image = ImageModel.from_url('file:///app/app/test/data/truncated_img.jpg', '1-2-3') + image = ImageModel.from_url('file:///app/app/test/data/truncated_img.jpg', '1-2-3', models=['phash']) self.assertEqual(image.phash, 25444816931300591) def test_image_fetch(self): - image = ImageModel.from_url('file:///app/app/test/data/lenna-512.png', '1-2-3') + image = ImageModel.from_url('file:///app/app/test/data/lenna-512.png', '1-2-3', models=['phash']) self.assertEqual(image.phash, 45655524591978137) def test_image_api(self): url = 'file:///app/app/test/data/lenna-512.png' - # Test adding an image. 
diff --git a/app/test/test_image_similarity.py b/app/test/test_image_similarity.py
index e9b89c92..733222c9 100644
--- a/app/test/test_image_similarity.py
+++ b/app/test/test_image_similarity.py
@@ -46,23 +46,23 @@ def test_bit_count(self):
         self.assertEqual(result['test_count'], 0.5625)
 
     def test_truncated_image_fetch(self):
-        image = ImageModel.from_url('file:///app/app/test/data/truncated_img.jpg', '1-2-3')
+        image = ImageModel.from_url('file:///app/app/test/data/truncated_img.jpg', '1-2-3', models=['phash'])
         self.assertEqual(image.phash, 25444816931300591)
 
     def test_image_fetch(self):
-        image = ImageModel.from_url('file:///app/app/test/data/lenna-512.png', '1-2-3')
+        image = ImageModel.from_url('file:///app/app/test/data/lenna-512.png', '1-2-3', models=['phash'])
         self.assertEqual(image.phash, 45655524591978137)
 
     def test_image_api(self):
         url = 'file:///app/app/test/data/lenna-512.png'
-        # Test adding an image.
         response = self.client.post('/image/similarity/', data=json.dumps({
             'url': url,
             'context': {
                 'team_id': 1,
                 'project_media_id': 1
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(True, result['success'])
 
@@ -74,11 +74,12 @@ def test_image_api(self):
             'context': {
                 'team_id': 2,
                 'project_media_id': 2
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(True, result['success'])
-        image = ImageModel.from_url(url, '1-2-3')
+        image = ImageModel.from_url(url, '1-2-3', models=['phash'])
         self.assertListEqual([
             {
                 'team_id': 1,
@@ -94,7 +95,8 @@ def test_image_api(self):
         response = self.client.get('/image/similarity/', data=json.dumps({
             'context': {
                 'team_id': 2
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(1, len(result['result']))
@@ -103,7 +105,8 @@ def test_image_api(self):
         response = self.client.get('/image/similarity/', data=json.dumps({
             'context': {
                 'team_id': [2, 3]
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(1, len(result['result']))
@@ -112,7 +115,8 @@ def test_image_api(self):
         response = self.client.get('/image/similarity/', data=json.dumps({
             'context': {
                 'team_id': [-1, -2]
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(0, len(result['result']))
@@ -122,7 +126,8 @@ def test_image_api(self):
         response = self.client.get('/image/similarity/', data=json.dumps({
             'url': url,
             'threshold': 1.0,
-            'context': {}
+            'context': {},
+            'models': ['phash']
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(1, len(result['result']))
@@ -133,7 +138,9 @@ def test_image_api(self):
             'threshold': 1.0,
             'context': {
                 'team_id': 1
-            }
+            },
+            'models': ['phash']
+
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(1, len(result['result']))
@@ -144,7 +151,9 @@ def test_image_api(self):
             'threshold': 1.0,
             'context': {
                 'team_id': [1, 2, 3]
-            }
+            },
+            'models': ['phash']
+
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(1, len(result['result']))
@@ -155,7 +164,9 @@ def test_image_api(self):
             'threshold': 1.0,
             'context': {
                 'team_id': [-1, -2]
-            }
+            },
+            'models': ['phash']
+
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(0, len(result['result']))
@@ -165,7 +176,9 @@ def test_image_api(self):
         response = self.client.get('/image/similarity/', data=json.dumps({
             'url': url,
             'threshold': 1.0,
-            'context': {}
+            'context': {},
+            'models': ['phash']
+
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(0, len(result['result']))
@@ -173,7 +186,9 @@ def test_image_api(self):
             'url': url,
             'context': {
                 'team_id': 2
-            }
+            },
+            'models': ['phash']
+
         }), content_type='application/json')
         # threshold should default to 0.9 == round(1 - 0.9) * 64.0 == 6
         result = json.loads(response.data.decode())
         self.assertEqual(1, len(result['result']))
@@ -187,7 +202,8 @@ def test_update_image(self):
             'context': {
                 'team_id': 1,
                 'project_media_id': 1
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         url = 'file:///app/app/test/data/lenna-512.png'
@@ -198,7 +214,8 @@ def test_update_image(self):
             'context': {
                 'team_id': 2,
                 'project_media_id': 2
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         image = ImageModel.query.filter_by(url=url).all()[0]
@@ -213,7 +230,8 @@ def test_delete_image(self):
             'context': {
                 'team_id': 1,
                 'project_media_id': 1
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(True, result['success'])
@@ -224,7 +242,8 @@ def test_delete_image(self):
             'context': {
                 'team_id': 1,
                 'project_media_id': 1
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         # threshold should default to 0.9 == round(1 - 0.9) * 64.0 == 6
         result = json.loads(response.data.decode())
         self.assertEqual(True, result['deleted'])
@@ -242,7 +261,8 @@ def test_image_api_error(self):
             'context': {
                 'team_id': 1,
                 'project_media_id': 1
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(500, response.status_code)
@@ -256,7 +276,8 @@ def test_image_api_error(self):
             'context': {
                 'team_id': 1,
                 'project_media_id': 1
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         self.assertEqual(500, response.status_code)
 
@@ -271,7 +292,8 @@ def test_add_image_error(self):
             'context': {
                 'team_id': 1,
                 'project_media_id': 1
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         self.assertEqual(500, response.status_code)
 
@@ -282,14 +304,16 @@ def test_search_by_context_error(self):
             'context': {
                 'team_id': 2,
                 'project_media_id': 1
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(True, result['success'])
         response = self.client.get('/image/similarity/', data=json.dumps({
             'context': {
                 'team_id': 'aa'
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         self.assertEqual(500, response.status_code)
         result = json.loads(response.data.decode())
@@ -301,13 +325,15 @@ def test_search_with_empty_context(self):
             'context': {
                 'team_id': 2,
                 'project_media_id': 1
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(True, result['success'])
         response = self.client.get('/image/similarity/', data=json.dumps({
             'context': {
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         self.assertEqual(200, response.status_code)
 
@@ -319,12 +345,14 @@ def test_search_using_url(self):
             'context': {
                 'team_id': 2,
                 'project_media_id': 1
-            }
+            },
+            'models': ['phash']
         }), content_type='application/json')
         result = json.loads(response.data.decode())
         self.assertEqual(True, result['success'])
         response = self.client.get('/image/similarity/', data=json.dumps({
-            'url': url
+            'url': url,
+            'models': ['phash']
         }), content_type='application/json')
         result = get_context_query(context, False, True)
         self.assertIn({'context_team_id': 2}, result)
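
Note: the tests above pin models to ['phash']. Because search_image concatenates results across all requested models, a request may also name several models at once; an illustrative query in the same style as the tests, assuming the same url and client fixtures and that the corresponding hashes were stored:

    response = self.client.get('/image/similarity/', data=json.dumps({
        'url': url,
        'threshold': 0.9,
        'context': {'team_id': 1},
        'models': ['phash', 'pdq']
    }), content_type='application/json')
    result = json.loads(response.data.decode())
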
diff --git a/manage.py b/manage.py
index bffdb1d9..91a63eb8 100644
--- a/manage.py
+++ b/manage.py
@@ -19,6 +19,7 @@
 from app.main.lib.language_analyzers import init_indices
 from app.main.lib.image_hash import compute_phash_int
 from PIL import Image
+from sqlalchemy import text
 
 # Don't remove this line until https://github.com/tensorflow/tensorflow/issues/34607 is fixed
 # (by upgrading to tensorflow 2.2 or higher)
@@ -229,6 +230,13 @@ def init_perl_functions():
             LANGUAGE plperl;
         """)
     )
+    sqlalchemy.event.listen(
+        db.metadata,
+        'before_create',
+        DDL("""
+            CREATE EXTENSION IF NOT EXISTS vector;
+        """)
+    )
     db.create_all()
 
 @manager.command
diff --git a/migrations/versions/61ac93be86b2_create_sscd_column.py b/migrations/versions/61ac93be86b2_create_sscd_column.py
new file mode 100644
index 00000000..701d2814
--- /dev/null
+++ b/migrations/versions/61ac93be86b2_create_sscd_column.py
@@ -0,0 +1,37 @@
+"""create sscd column
+
+Revision ID: 61ac93be86b2
+Revises: e495509fad52
+Create Date: 2023-11-06 20:57:37.335903
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from pgvector.sqlalchemy import Vector
+from app.main import create_app, db
+import sqlalchemy
+from sqlalchemy.schema import DDL
+
+
+# revision identifiers, used by Alembic.
+revision = '61ac93be86b2'
+down_revision = 'e495509fad52'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    sqlalchemy.event.listen(
+        db.metadata,
+        'before_create',
+        DDL("""
+            CREATE EXTENSION IF NOT EXISTS vector;
+        """)
+    )
+    op.add_column('images', sa.Column('sscd', Vector(512), nullable=True))
+    op.create_index(op.f('ix_images_sscd'), 'images', ['sscd'], unique=False)
+
+def downgrade():
+    op.drop_index(op.f('ix_images_sscd'), table_name='images')
+    op.drop_column('images', 'sscd')
+
diff --git a/postgres/Dockerfile b/postgres/Dockerfile
index 2d9ccedf..1efa94eb 100644
--- a/postgres/Dockerfile
+++ b/postgres/Dockerfile
@@ -15,6 +15,7 @@ RUN apt-get update && \
     apt-transport-https \
     libcurl3-gnutls \
     gawk \
+    postgresql-13-pgvector \
    postgresql-plperl-13 \
    && localedef -i ru_RU -c -f UTF-8 -A /usr/share/locale/locale.alias ru_RU.UTF-8 \
    && rm -rf /var/lib/apt/lists/*
diff --git a/requirements.txt b/requirements.txt
index 9d18e43b..f77c0991 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+pgvector==0.1.8
 openai[embeddings]==0.27.4
 matplotlib==3.5.3
 plotly==5.14.1
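
Note: the sscd column is declared as Vector(512) in both the model and the migration so that inserted SSCD embeddings match the declared dimensionality. One caveat on the migration above: the before_create listener it registers only fires when metadata.create_all() runs (as in manage.py), not during an alembic upgrade, so the vector extension may not yet exist when add_column executes. A hypothetical equivalent, not part of this change, that creates the extension inside upgrade() directly:

    from alembic import op
    import sqlalchemy as sa
    from pgvector.sqlalchemy import Vector

    def upgrade():
        # Create the extension directly so the vector type exists before the column is added.
        op.execute("CREATE EXTENSION IF NOT EXISTS vector")
        op.add_column('images', sa.Column('sscd', Vector(512), nullable=True))
        op.create_index(op.f('ix_images_sscd'), 'images', ['sscd'], unique=False)

    def downgrade():
        op.drop_index(op.f('ix_images_sscd'), table_name='images')
        op.drop_column('images', 'sscd')
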