Skip to content

Commit

Permalink
[py-tx] Add a new match command line option (--rotations) (#1672)
Browse files Browse the repository at this point in the history
  • Loading branch information
haianhng31 authored Nov 9, 2024
1 parent 65c92ce commit 21f4f41
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 19 deletions.
105 changes: 90 additions & 15 deletions python-threatexchange/threatexchange/cli/match_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,30 @@
Match command for parsing simple data sources against the dataset.
"""

from dataclasses import dataclass, field
import argparse
import logging
import pathlib
import typing as t

import tempfile

from threatexchange import common
from threatexchange.cli.fetch_cmd import FetchCommand
from threatexchange.cli.helpers import FlexFilesInputAction
from threatexchange.exchanges.fetch_state import FetchedSignalMetadata

from threatexchange.signal_type.index import IndexMatch, SignalTypeIndex
from threatexchange.signal_type.index import (
IndexMatch,
SignalTypeIndex,
IndexMatchUntyped,
SignalSimilarityInfo,
T,
)
from threatexchange.cli.exceptions import CommandError
from threatexchange.signal_type.signal_base import BytesHasher, SignalType
from threatexchange.cli.cli_config import CLISettings
from threatexchange.content_type.content_base import ContentType
from threatexchange.content_type.content_base import ContentType, RotationType
from threatexchange.content_type.photo import PhotoContent

from threatexchange.signal_type.signal_base import MatchesStr, TextHasher, FileHasher
from threatexchange.cli import command_base
Expand All @@ -29,6 +37,19 @@
TMatcher = t.Callable[[pathlib.Path], t.List[IndexMatch]]


@dataclass
class _IndexMatchWithRotation(t.Generic[T]):
match: IndexMatchUntyped[SignalSimilarityInfo, T]
rotation_type: t.Optional[RotationType] = field(default=None)

def __str__(self):
# Supposed to be without whitespace, but let's make sure
distance_str = "".join(self.match.similarity_info.pretty_str().split())
if self.rotation_type is None:
return distance_str
return f"{self.rotation_type.name} {distance_str}"


class MatchCommand(command_base.Command):
"""
Match content to fetched signals
Expand Down Expand Up @@ -126,6 +147,12 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No
action="store_true",
help="show all matches, not just one per collaboration",
)
ap.add_argument(
"--rotations",
"-R",
action="store_true",
help="for photos, generate and match all 8 simple rotations",
)

def __init__(
self,
Expand All @@ -136,6 +163,7 @@ def __init__(
show_false_positives: bool,
hide_disputed: bool,
all: bool,
rotations: bool = False,
) -> None:
self.content_type = content_type
self.only_signal = only_signal
Expand All @@ -144,6 +172,7 @@ def __init__(
self.hide_disputed = hide_disputed
self.files = files
self.all = all
self.rotations = rotations

if only_signal and content_type not in only_signal.get_content_types():
raise CommandError(
Expand All @@ -152,6 +181,11 @@ def __init__(
2,
)

if self.rotations and not issubclass(content_type, PhotoContent):
raise CommandError(
"--rotations flag is only available for Photo content type", 2
)

def execute(self, settings: CLISettings) -> None:
if not settings.index.list():
if not settings.in_demo_mode:
Expand Down Expand Up @@ -196,18 +230,23 @@ def execute(self, settings: CLISettings) -> None:
for s_type, index in indices:
seen = set() # TODO - maybe take the highest certainty?
if self.as_hashes:
results = _match_hashes(path, s_type, index)
results: t.Sequence[_IndexMatchWithRotation] = _match_hashes(
path, s_type, index
)
else:
results = _match_file(path, s_type, index)
results = _match_file(path, s_type, index, rotations=self.rotations)

for r in results:
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = r.metadata
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = (
r.match.metadata
)
distance_str = str(r)

for collab, fetched_data in metadatas:
if not self.all and collab in seen:
continue
seen.add(collab)
# Supposed to be without whitespace, but let's make sure
distance_str = "".join(r.similarity_info.pretty_str().split())

print(
s_type.get_name(),
distance_str,
Expand All @@ -217,18 +256,53 @@ def execute(self, settings: CLISettings) -> None:


def _match_file(
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
) -> t.Sequence[IndexMatch]:
path: pathlib.Path,
s_type: t.Type[SignalType],
index: SignalTypeIndex,
rotations: bool = False,
) -> t.Sequence[_IndexMatchWithRotation]:
if issubclass(s_type, MatchesStr):
return index.query(path.read_text())
matches = index.query(path.read_text())
return [_IndexMatchWithRotation(match=match) for match in matches]

assert issubclass(s_type, FileHasher)
return index.query(s_type.hash_from_file(path))

if not rotations:
matches = index.query(s_type.hash_from_file(path))
return [_IndexMatchWithRotation(match=match) for match in matches]

# Handle rotations for photos
with open(path, "rb") as f:
image_data = f.read()

rotated_images: t.Dict[RotationType, bytes] = PhotoContent.all_simple_rotations(
image_data
)
all_matches = []

for rotation_type, rotated_bytes in rotated_images.items():
# Create a temporary file to hold the image bytes
with tempfile.NamedTemporaryFile() as temp_file:
temp_file.write(rotated_bytes)
temp_file_path = pathlib.Path(temp_file.name)
matches = index.query(s_type.hash_from_file(temp_file_path))

# Add rotation information if any matches were found
matches_with_rotations = []
for match in matches:
matches_with_rotations.append(
_IndexMatchWithRotation(match=match, rotation_type=rotation_type)
)

all_matches.extend(matches_with_rotations)

return all_matches


def _match_hashes(
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
) -> t.Sequence[IndexMatch]:
ret: t.List[IndexMatch] = []
) -> t.Sequence[_IndexMatchWithRotation]:
ret: t.List[_IndexMatchWithRotation] = []
for hash in path.read_text().splitlines():
hash = hash.strip()
if not hash:
Expand All @@ -244,5 +318,6 @@ def _match_hashes(
f"{hash_repr} from {path} is not a valid hash for {s_type.get_name()}",
2,
)
ret.extend(index.query(hash))
matches = index.query(hash)
ret.extend([_IndexMatchWithRotation(match=match) for match in matches])
return ret
38 changes: 37 additions & 1 deletion python-threatexchange/threatexchange/cli/tests/match_cmd_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import pathlib
import tempfile
from threatexchange.cli.tests.e2e_test_helper import ThreatExchangeCLIE2eTest
import os
from threatexchange.cli.tests.e2e_test_helper import (
ThreatExchangeCLIE2eHelper,
ThreatExchangeCLIE2eTest,
)
from threatexchange.content_type.content_base import RotationType
from threatexchange.content_type.photo import PhotoContent
from threatexchange.signal_type.md5 import VideoMD5Signal


Expand Down Expand Up @@ -31,3 +38,32 @@ def test_invalid_hash(self):
("-H", "video", "--", not_hash),
f"{not_hash!r} from .* is not a valid hash for video_md5",
)

def test_non_photo_match_with_rotations(self):
with tempfile.NamedTemporaryFile() as f:
for content_type in ["url", "text", "video"]:
self.assert_cli_usage_error(
("--rotations", content_type, f.name),
msg_regex="--rotations flag is only available for Photo content type",
)

def test_photo_hash_with_rotations(self):
test_file = pathlib.Path(
__file__ + "../../../../../../pdq/data/bridge-mods/aaa-orig.jpg"
).resolve()

rotated_images = PhotoContent.all_simple_rotations(test_file.read_bytes())

for rotation, image in rotated_images.items():
with tempfile.NamedTemporaryFile() as tmp_file:
tmp_file.write(image)

if rotation == RotationType.ROTATE270:
rotation = RotationType.ROTATE90
elif rotation == RotationType.ROTATE90:
rotation = RotationType.ROTATE270

self.assert_cli_output(
("--rotations", "photo", tmp_file.name),
f"pdq {rotation.name} 16 (Sample Signals) INVESTIGATION_SEED",
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
This records all the valid signal types for a piece of content.
"""

from enum import Enum, auto
from enum import Enum
import typing as t

from threatexchange import common
Expand Down
3 changes: 2 additions & 1 deletion python-threatexchange/threatexchange/content_type/photo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
from PIL import Image
import io
import typing as t

from .content_base import ContentType, RotationType

Expand Down Expand Up @@ -82,7 +83,7 @@ def flip_minus1(cls, image_data: bytes) -> bytes:
return buffer.getvalue()

@classmethod
def all_simple_rotations(cls, image_data: bytes):
def all_simple_rotations(cls, image_data: bytes) -> t.Dict[RotationType, bytes]:
"""
Generate the 8 naive rotations of an image.
Expand Down
1 change: 0 additions & 1 deletion python-threatexchange/threatexchange/signal_type/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import pickle
import typing as t


T = t.TypeVar("T")
S_Co = t.TypeVar("S_Co", covariant=True, bound="SignalSimilarityInfo")
CT = t.TypeVar("CT", bound="Comparable")
Expand Down

0 comments on commit 21f4f41

Please sign in to comment.