diff --git a/python-threatexchange/threatexchange/cli/match_cmd.py b/python-threatexchange/threatexchange/cli/match_cmd.py index 67d0878d4..3f01db6af 100644 --- a/python-threatexchange/threatexchange/cli/match_cmd.py +++ b/python-threatexchange/threatexchange/cli/match_cmd.py @@ -5,22 +5,30 @@ Match command for parsing simple data sources against the dataset. """ +from dataclasses import dataclass, field import argparse import logging import pathlib import typing as t - +import tempfile from threatexchange import common from threatexchange.cli.fetch_cmd import FetchCommand from threatexchange.cli.helpers import FlexFilesInputAction from threatexchange.exchanges.fetch_state import FetchedSignalMetadata -from threatexchange.signal_type.index import IndexMatch, SignalTypeIndex +from threatexchange.signal_type.index import ( + IndexMatch, + SignalTypeIndex, + IndexMatchUntyped, + SignalSimilarityInfo, + T, +) from threatexchange.cli.exceptions import CommandError from threatexchange.signal_type.signal_base import BytesHasher, SignalType from threatexchange.cli.cli_config import CLISettings -from threatexchange.content_type.content_base import ContentType +from threatexchange.content_type.content_base import ContentType, RotationType +from threatexchange.content_type.photo import PhotoContent from threatexchange.signal_type.signal_base import MatchesStr, TextHasher, FileHasher from threatexchange.cli import command_base @@ -29,6 +37,19 @@ TMatcher = t.Callable[[pathlib.Path], t.List[IndexMatch]] +@dataclass +class _IndexMatchWithRotation(t.Generic[T]): + match: IndexMatchUntyped[SignalSimilarityInfo, T] + rotation_type: t.Optional[RotationType] = field(default=None) + + def __str__(self): + # Supposed to be without whitespace, but let's make sure + distance_str = "".join(self.match.similarity_info.pretty_str().split()) + if self.rotation_type is None: + return distance_str + return f"{self.rotation_type.name} {distance_str}" + + class MatchCommand(command_base.Command): """ Match content to fetched signals @@ -126,6 +147,12 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No action="store_true", help="show all matches, not just one per collaboration", ) + ap.add_argument( + "--rotations", + "-R", + action="store_true", + help="for photos, generate and match all 8 simple rotations", + ) def __init__( self, @@ -136,6 +163,7 @@ def __init__( show_false_positives: bool, hide_disputed: bool, all: bool, + rotations: bool = False, ) -> None: self.content_type = content_type self.only_signal = only_signal @@ -144,6 +172,7 @@ def __init__( self.hide_disputed = hide_disputed self.files = files self.all = all + self.rotations = rotations if only_signal and content_type not in only_signal.get_content_types(): raise CommandError( @@ -152,6 +181,11 @@ def __init__( 2, ) + if self.rotations and not issubclass(content_type, PhotoContent): + raise CommandError( + "--rotations flag is only available for Photo content type", 2 + ) + def execute(self, settings: CLISettings) -> None: if not settings.index.list(): if not settings.in_demo_mode: @@ -196,18 +230,23 @@ def execute(self, settings: CLISettings) -> None: for s_type, index in indices: seen = set() # TODO - maybe take the highest certainty? if self.as_hashes: - results = _match_hashes(path, s_type, index) + results: t.Sequence[_IndexMatchWithRotation] = _match_hashes( + path, s_type, index + ) else: - results = _match_file(path, s_type, index) + results = _match_file(path, s_type, index, rotations=self.rotations) for r in results: - metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = r.metadata + metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = ( + r.match.metadata + ) + distance_str = str(r) + for collab, fetched_data in metadatas: if not self.all and collab in seen: continue seen.add(collab) - # Supposed to be without whitespace, but let's make sure - distance_str = "".join(r.similarity_info.pretty_str().split()) + print( s_type.get_name(), distance_str, @@ -217,18 +256,53 @@ def execute(self, settings: CLISettings) -> None: def _match_file( - path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex -) -> t.Sequence[IndexMatch]: + path: pathlib.Path, + s_type: t.Type[SignalType], + index: SignalTypeIndex, + rotations: bool = False, +) -> t.Sequence[_IndexMatchWithRotation]: if issubclass(s_type, MatchesStr): - return index.query(path.read_text()) + matches = index.query(path.read_text()) + return [_IndexMatchWithRotation(match=match) for match in matches] + assert issubclass(s_type, FileHasher) - return index.query(s_type.hash_from_file(path)) + + if not rotations: + matches = index.query(s_type.hash_from_file(path)) + return [_IndexMatchWithRotation(match=match) for match in matches] + + # Handle rotations for photos + with open(path, "rb") as f: + image_data = f.read() + + rotated_images: t.Dict[RotationType, bytes] = PhotoContent.all_simple_rotations( + image_data + ) + all_matches = [] + + for rotation_type, rotated_bytes in rotated_images.items(): + # Create a temporary file to hold the image bytes + with tempfile.NamedTemporaryFile() as temp_file: + temp_file.write(rotated_bytes) + temp_file_path = pathlib.Path(temp_file.name) + matches = index.query(s_type.hash_from_file(temp_file_path)) + + # Add rotation information if any matches were found + matches_with_rotations = [] + for match in matches: + matches_with_rotations.append( + _IndexMatchWithRotation(match=match, rotation_type=rotation_type) + ) + + all_matches.extend(matches_with_rotations) + + return all_matches def _match_hashes( path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex -) -> t.Sequence[IndexMatch]: - ret: t.List[IndexMatch] = [] +) -> t.Sequence[_IndexMatchWithRotation]: + ret: t.List[_IndexMatchWithRotation] = [] for hash in path.read_text().splitlines(): hash = hash.strip() if not hash: @@ -244,5 +318,6 @@ def _match_hashes( f"{hash_repr} from {path} is not a valid hash for {s_type.get_name()}", 2, ) - ret.extend(index.query(hash)) + matches = index.query(hash) + ret.extend([_IndexMatchWithRotation(match=match) for match in matches]) return ret diff --git a/python-threatexchange/threatexchange/cli/tests/match_cmd_test.py b/python-threatexchange/threatexchange/cli/tests/match_cmd_test.py index d774ffc36..fad596841 100644 --- a/python-threatexchange/threatexchange/cli/tests/match_cmd_test.py +++ b/python-threatexchange/threatexchange/cli/tests/match_cmd_test.py @@ -1,7 +1,14 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import pathlib import tempfile -from threatexchange.cli.tests.e2e_test_helper import ThreatExchangeCLIE2eTest +import os +from threatexchange.cli.tests.e2e_test_helper import ( + ThreatExchangeCLIE2eHelper, + ThreatExchangeCLIE2eTest, +) +from threatexchange.content_type.content_base import RotationType +from threatexchange.content_type.photo import PhotoContent from threatexchange.signal_type.md5 import VideoMD5Signal @@ -31,3 +38,32 @@ def test_invalid_hash(self): ("-H", "video", "--", not_hash), f"{not_hash!r} from .* is not a valid hash for video_md5", ) + + def test_non_photo_match_with_rotations(self): + with tempfile.NamedTemporaryFile() as f: + for content_type in ["url", "text", "video"]: + self.assert_cli_usage_error( + ("--rotations", content_type, f.name), + msg_regex="--rotations flag is only available for Photo content type", + ) + + def test_photo_hash_with_rotations(self): + test_file = pathlib.Path( + __file__ + "../../../../../../pdq/data/bridge-mods/aaa-orig.jpg" + ).resolve() + + rotated_images = PhotoContent.all_simple_rotations(test_file.read_bytes()) + + for rotation, image in rotated_images.items(): + with tempfile.NamedTemporaryFile() as tmp_file: + tmp_file.write(image) + + if rotation == RotationType.ROTATE270: + rotation = RotationType.ROTATE90 + elif rotation == RotationType.ROTATE90: + rotation = RotationType.ROTATE270 + + self.assert_cli_output( + ("--rotations", "photo", tmp_file.name), + f"pdq {rotation.name} 16 (Sample Signals) INVESTIGATION_SEED", + ) diff --git a/python-threatexchange/threatexchange/content_type/content_base.py b/python-threatexchange/threatexchange/content_type/content_base.py index 81e600877..e9ea6aada 100644 --- a/python-threatexchange/threatexchange/content_type/content_base.py +++ b/python-threatexchange/threatexchange/content_type/content_base.py @@ -7,7 +7,7 @@ This records all the valid signal types for a piece of content. """ -from enum import Enum, auto +from enum import Enum import typing as t from threatexchange import common diff --git a/python-threatexchange/threatexchange/content_type/photo.py b/python-threatexchange/threatexchange/content_type/photo.py index 6bdc50022..aa0f7e922 100644 --- a/python-threatexchange/threatexchange/content_type/photo.py +++ b/python-threatexchange/threatexchange/content_type/photo.py @@ -6,6 +6,7 @@ """ from PIL import Image import io +import typing as t from .content_base import ContentType, RotationType @@ -82,7 +83,7 @@ def flip_minus1(cls, image_data: bytes) -> bytes: return buffer.getvalue() @classmethod - def all_simple_rotations(cls, image_data: bytes): + def all_simple_rotations(cls, image_data: bytes) -> t.Dict[RotationType, bytes]: """ Generate the 8 naive rotations of an image. diff --git a/python-threatexchange/threatexchange/signal_type/index.py b/python-threatexchange/threatexchange/signal_type/index.py index ea3205bbd..a26f879e2 100644 --- a/python-threatexchange/threatexchange/signal_type/index.py +++ b/python-threatexchange/threatexchange/signal_type/index.py @@ -23,7 +23,6 @@ import pickle import typing as t - T = t.TypeVar("T") S_Co = t.TypeVar("S_Co", covariant=True, bound="SignalSimilarityInfo") CT = t.TypeVar("CT", bound="Comparable")