Skip to content

Commit

Permalink
Merge branch 'main' into Issue-1666-Letter-Unboxing
Browse files Browse the repository at this point in the history
  • Loading branch information
Mackay-Fisher authored Nov 13, 2024
2 parents 9dba274 + 8823b84 commit bfc10da
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 46 deletions.
6 changes: 6 additions & 0 deletions pdq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ Before evaluating the results on your own to choose the thresholds that work for
* **Distance Threshold to consider two hashes to be similar/matching**: <=31
* **Quality Threshold where we recommend discarding hashes**: <=49

## Note on Dihedral PDQ Hashes

The PDQ hashing algorithm can easily produce eight "dihedral" hashes: one for each of the four 90-degree rotations (including the unrotated image) and one for each of the four flips (across the horizontal axis, the vertical axis, and the two diagonals). However, PDQ does not guarantee exact rotational invariance. The hash values computed for each rotation can vary slightly because of how PDQ processes the image's grid alignment in its DCT (Discrete Cosine Transform) phase.

For example, two rotated versions of the same image can have slightly different sets of eight dihedral hashes. Selecting a "minimal" hash from these transformations (e.g., the lexicographically smallest one) may therefore yield inconsistent results: small bit variations can cause a different hash to be identified as "minimal" for each rotation, so there is no guarantee that the same hash will be selected across different rotations of an image. For a concrete example, see this issue comment: [https://github.com/facebook/ThreatExchange/issues/1676#issuecomment-2466331532](https://github.com/facebook/ThreatExchange/issues/1676#issuecomment-2466331532).

## Contact

[email protected]
105 changes: 90 additions & 15 deletions python-threatexchange/threatexchange/cli/match_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,30 @@
Match command for parsing simple data sources against the dataset.
"""

from dataclasses import dataclass, field
import argparse
import logging
import pathlib
import typing as t

import tempfile

from threatexchange import common
from threatexchange.cli.fetch_cmd import FetchCommand
from threatexchange.cli.helpers import FlexFilesInputAction
from threatexchange.exchanges.fetch_state import FetchedSignalMetadata

from threatexchange.signal_type.index import IndexMatch, SignalTypeIndex
from threatexchange.signal_type.index import (
IndexMatch,
SignalTypeIndex,
IndexMatchUntyped,
SignalSimilarityInfo,
T,
)
from threatexchange.cli.exceptions import CommandError
from threatexchange.signal_type.signal_base import BytesHasher, SignalType
from threatexchange.cli.cli_config import CLISettings
from threatexchange.content_type.content_base import ContentType
from threatexchange.content_type.content_base import ContentType, RotationType
from threatexchange.content_type.photo import PhotoContent

from threatexchange.signal_type.signal_base import MatchesStr, TextHasher, FileHasher
from threatexchange.cli import command_base
Expand All @@ -29,6 +37,19 @@
TMatcher = t.Callable[[pathlib.Path], t.List[IndexMatch]]


@dataclass
class _IndexMatchWithRotation(t.Generic[T]):
    """
    An index match paired with the rotation that produced it, if any.

    `rotation_type` stays None when no rotation is associated with the match.
    """

    match: IndexMatchUntyped[SignalSimilarityInfo, T]
    rotation_type: t.Optional[RotationType] = field(default=None)

    def __str__(self):
        # pretty_str() is supposed to be whitespace-free already; strip any
        # whitespace anyway so the result stays a single token.
        compact = "".join(self.match.similarity_info.pretty_str().split())
        rotation = self.rotation_type
        return compact if rotation is None else f"{rotation.name} {compact}"


class MatchCommand(command_base.Command):
"""
Match content to fetched signals
Expand Down Expand Up @@ -126,6 +147,12 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No
action="store_true",
help="show all matches, not just one per collaboration",
)
ap.add_argument(
"--rotations",
"-R",
action="store_true",
help="for photos, generate and match all 8 simple rotations",
)

def __init__(
self,
Expand All @@ -136,6 +163,7 @@ def __init__(
show_false_positives: bool,
hide_disputed: bool,
all: bool,
rotations: bool = False,
) -> None:
self.content_type = content_type
self.only_signal = only_signal
Expand All @@ -144,6 +172,7 @@ def __init__(
self.hide_disputed = hide_disputed
self.files = files
self.all = all
self.rotations = rotations

if only_signal and content_type not in only_signal.get_content_types():
raise CommandError(
Expand All @@ -152,6 +181,11 @@ def __init__(
2,
)

if self.rotations and not issubclass(content_type, PhotoContent):
raise CommandError(
"--rotations flag is only available for Photo content type", 2
)

def execute(self, settings: CLISettings) -> None:
if not settings.index.list():
if not settings.in_demo_mode:
Expand Down Expand Up @@ -196,18 +230,23 @@ def execute(self, settings: CLISettings) -> None:
for s_type, index in indices:
seen = set() # TODO - maybe take the highest certainty?
if self.as_hashes:
results = _match_hashes(path, s_type, index)
results: t.Sequence[_IndexMatchWithRotation] = _match_hashes(
path, s_type, index
)
else:
results = _match_file(path, s_type, index)
results = _match_file(path, s_type, index, rotations=self.rotations)

for r in results:
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = r.metadata
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = (
r.match.metadata
)
distance_str = str(r)

for collab, fetched_data in metadatas:
if not self.all and collab in seen:
continue
seen.add(collab)
# Supposed to be without whitespace, but let's make sure
distance_str = "".join(r.similarity_info.pretty_str().split())

print(
s_type.get_name(),
distance_str,
Expand All @@ -217,18 +256,53 @@ def execute(self, settings: CLISettings) -> None:


def _match_file(
    path: pathlib.Path,
    s_type: t.Type[SignalType],
    index: SignalTypeIndex,
    rotations: bool = False,
) -> t.Sequence[_IndexMatchWithRotation]:
    """
    Match the content stored at `path` against `index` using `s_type`.

    Signal types that implement MatchesStr are queried with the file's text
    and never rotated. Every other signal type must be a FileHasher: the file
    is hashed and the hash is used as the query.

    When `rotations` is true (and the signal type is not MatchesStr), the file
    is read as an image, every rotation returned by
    PhotoContent.all_simple_rotations() is hashed and queried, and each match
    is tagged with the rotation that produced it. Otherwise the returned
    matches carry no rotation.
    """
    if issubclass(s_type, MatchesStr):
        matches = index.query(path.read_text())
        return [_IndexMatchWithRotation(match=match) for match in matches]

    assert issubclass(s_type, FileHasher)

    if not rotations:
        matches = index.query(s_type.hash_from_file(path))
        return [_IndexMatchWithRotation(match=match) for match in matches]

    # Handle rotations for photos
    rotated_images: t.Dict[RotationType, bytes] = PhotoContent.all_simple_rotations(
        path.read_bytes()
    )
    all_matches: t.List[_IndexMatchWithRotation] = []

    for rotation_type, rotated_bytes in rotated_images.items():
        # hash_from_file() needs a real path, so stage the rotated bytes in a
        # temporary file.
        with tempfile.NamedTemporaryFile() as temp_file:
            temp_file.write(rotated_bytes)
            # The hasher reopens the file by name; without a flush the bytes
            # may still be sitting in Python's write buffer and the hasher
            # would see an empty or truncated image.
            temp_file.flush()
            matches = index.query(s_type.hash_from_file(pathlib.Path(temp_file.name)))

        # Record which rotation produced each match
        all_matches.extend(
            _IndexMatchWithRotation(match=match, rotation_type=rotation_type)
            for match in matches
        )

    return all_matches


def _match_hashes(
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
) -> t.Sequence[IndexMatch]:
ret: t.List[IndexMatch] = []
) -> t.Sequence[_IndexMatchWithRotation]:
ret: t.List[_IndexMatchWithRotation] = []
for hash in path.read_text().splitlines():
hash = hash.strip()
if not hash:
Expand All @@ -244,5 +318,6 @@ def _match_hashes(
f"{hash_repr} from {path} is not a valid hash for {s_type.get_name()}",
2,
)
ret.extend(index.query(hash))
matches = index.query(hash)
ret.extend([_IndexMatchWithRotation(match=match) for match in matches])
return ret
38 changes: 37 additions & 1 deletion python-threatexchange/threatexchange/cli/tests/match_cmd_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import pathlib
import tempfile
from threatexchange.cli.tests.e2e_test_helper import ThreatExchangeCLIE2eTest
import os
from threatexchange.cli.tests.e2e_test_helper import (
ThreatExchangeCLIE2eHelper,
ThreatExchangeCLIE2eTest,
)
from threatexchange.content_type.content_base import RotationType
from threatexchange.content_type.photo import PhotoContent
from threatexchange.signal_type.md5 import VideoMD5Signal


Expand Down Expand Up @@ -31,3 +38,32 @@ def test_invalid_hash(self):
("-H", "video", "--", not_hash),
f"{not_hash!r} from .* is not a valid hash for video_md5",
)

def test_non_photo_match_with_rotations(self):
    """--rotations must be rejected as a usage error for non-photo content."""
    expected_error = "--rotations flag is only available for Photo content type"
    with tempfile.NamedTemporaryFile() as placeholder:
        for non_photo in ("url", "text", "video"):
            cli_args = ("--rotations", non_photo, placeholder.name)
            self.assert_cli_usage_error(cli_args, msg_regex=expected_error)

def test_photo_hash_with_rotations(self):
    """
    Each simple rotation of the sample photo should match via --rotations.
    """
    # This file lives at python-threatexchange/threatexchange/cli/tests/, so
    # parents[4] is the repository root that contains pdq/.
    repo_root = pathlib.Path(__file__).resolve().parents[4]
    test_file = repo_root / "pdq" / "data" / "bridge-mods" / "aaa-orig.jpg"

    rotated_images = PhotoContent.all_simple_rotations(test_file.read_bytes())

    for rotation, image in rotated_images.items():
        with tempfile.NamedTemporaryFile() as tmp_file:
            tmp_file.write(image)
            # The CLI opens the file by name, so the bytes must actually be
            # on disk rather than in the write buffer.
            tmp_file.flush()

            # The 90 and 270 degree cases are reported under each other's
            # name (presumably because a 90 degree turn is undone by a 270
            # degree one); every other rotation is reported as itself.
            if rotation == RotationType.ROTATE270:
                expected_rotation = RotationType.ROTATE90
            elif rotation == RotationType.ROTATE90:
                expected_rotation = RotationType.ROTATE270
            else:
                expected_rotation = rotation

            self.assert_cli_output(
                ("--rotations", "photo", tmp_file.name),
                f"pdq {expected_rotation.name} 16 (Sample Signals) INVESTIGATION_SEED",
            )
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
This records all the valid signal types for a piece of content.
"""

from enum import Enum, auto
from enum import Enum
import typing as t

from threatexchange import common
Expand Down
4 changes: 2 additions & 2 deletions python-threatexchange/threatexchange/content_type/photo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from PIL import Image
from pathlib import Path
import io
import os
import typing as t

from .content_base import ContentType, RotationType
from threatexchange.content_type.preprocess import unletterboxing
Expand Down Expand Up @@ -85,7 +85,7 @@ def flip_minus1(cls, image_data: bytes) -> bytes:
return buffer.getvalue()

@classmethod
def all_simple_rotations(cls, image_data: bytes):
def all_simple_rotations(cls, image_data: bytes) -> t.Dict[RotationType, bytes]:
"""
Generate the 8 naive rotations of an image.
Expand Down
1 change: 0 additions & 1 deletion python-threatexchange/threatexchange/signal_type/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import pickle
import typing as t


T = t.TypeVar("T")
S_Co = t.TypeVar("S_Co", covariant=True, bound="SignalSimilarityInfo")
CT = t.TypeVar("CT", bound="Comparable")
Expand Down
43 changes: 22 additions & 21 deletions python-threatexchange/threatexchange/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import os
import unittest
import pytest
import collections.abc

from threatexchange.exchanges.clients.fb_threatexchange.api import ThreatExchangeAPI
Expand All @@ -14,24 +14,25 @@
)


@unittest.skipUnless(
THREAT_EXCHANGE_INTEGRATION_TEST_TOKEN,
"Integration Test requires tokens. Use THREAT_EXCHANGE_INTEGRATION_TEST_TOKEN environment variable.",
@pytest.fixture
def api():
    """Provide a ThreatExchangeAPI client built from the integration test token."""
    client = ThreatExchangeAPI(THREAT_EXCHANGE_INTEGRATION_TEST_TOKEN)
    return client


need_token = pytest.mark.skipif(
not THREAT_EXCHANGE_INTEGRATION_TEST_TOKEN,
reason="Integration Test requires tokens. Use THREAT_EXCHANGE_INTEGRATION_TEST_TOKEN environment variable.",
)
class APIIntegrationTest(unittest.TestCase):
def setUp(self):
self.api = ThreatExchangeAPI(THREAT_EXCHANGE_INTEGRATION_TEST_TOKEN)

def test_get_threat_privacy_groups_member(self):
"""
Assumes that the app (if token is provided) will have at least one
privacy group.
"""
response = self.api.get_threat_privacy_groups_member()
self.assertTrue(
isinstance(response, collections.abc.Sequence)
and not isinstance(response, staticmethod),
"API returned something that's not a list!",
)

self.assertTrue(isinstance(response[0], ThreatPrivacyGroup))


@need_token
def test_get_threat_privacy_groups_member(api):
    """
    Check that the privacy-group listing is a list-like of ThreatPrivacyGroup.

    Assumes that the app (if token is provided) will have at least one
    privacy group.
    """
    groups = api.get_threat_privacy_groups_member()

    # str and bytes are Sequences too, but they are not a list of groups.
    is_sequence = isinstance(groups, collections.abc.Sequence)
    is_text = isinstance(groups, (str, bytes))
    assert is_sequence and not is_text, "API returned something that's not a list!"

    first_group = groups[0]
    assert isinstance(first_group, ThreatPrivacyGroup)
7 changes: 2 additions & 5 deletions python-threatexchange/threatexchange/tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import unittest

import threatexchange.common


class TestCommon(unittest.TestCase):
def test_camel_case_to_underscore(self):
assert threatexchange.common.camel_case_to_underscore("AbcXyz") == "abc_xyz"
def test_camel_case_to_underscore():
    """camel_case_to_underscore should turn "AbcXyz" into "abc_xyz"."""
    converted = threatexchange.common.camel_case_to_underscore("AbcXyz")
    assert converted == "abc_xyz"

0 comments on commit bfc10da

Please sign in to comment.