Skip to content

Commit

Permalink
[py-tx] embeded tx hash for unletterboxing (#1684)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mackay-Fisher authored Nov 27, 2024
1 parent 2ae7f50 commit cb1f21f
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 29 deletions.
108 changes: 83 additions & 25 deletions python-threatexchange/threatexchange/cli/hash_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
import pathlib
import typing as t
import tempfile
from pathlib import Path

from threatexchange import common
from threatexchange.cli.cli_config import CLISettings
from threatexchange.cli.exceptions import CommandError
from threatexchange.content_type.content_base import ContentType
from threatexchange.content_type.photo import PhotoContent
from threatexchange.content_type.content_base import RotationType

from threatexchange.signal_type.signal_base import FileHasher, SignalType
from threatexchange.cli import command_base
Expand Down Expand Up @@ -53,6 +55,7 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No
signal_choices = sorted(
s.get_name() for s in signal_types if issubclass(s, FileHasher)
)

ap.add_argument(
"content_type",
**common.argparse_choices_pre_type_kwargs(
Expand Down Expand Up @@ -80,25 +83,50 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No
)

ap.add_argument(
"--rotations",
"--R",
"--photo-preprocess",
choices=["unletterbox", "rotations"],
help=(
"Apply one of the preprocessing steps to the image before hashing. "
"'unletterbox' removes black borders, and 'rotations' generates all 8 "
"simple rotations."
),
)

ap.add_argument(
"--black-threshold",
type=int,
default=15,
help=(
"Set the black threshold for unletterboxing (default: 15)."
"Only applies when 'unletterbox' is selected in --preprocess."
),
)

ap.add_argument(
"--save-preprocess",
action="store_true",
help="for photos, generate all 8 simple rotations",
help="save the preprocessed image data as new files",
)

def __init__(
self,
content_type: t.Type[ContentType],
signal_type: t.Optional[t.Type[SignalType]],
files: t.List[pathlib.Path],
rotations: bool = False,
photo_preprocess: t.Optional[str] = None,
black_threshold: int = 0,
save_preprocess: bool = False,
) -> None:
self.content_type = content_type
self.signal_type = signal_type

self.photo_preprocess = photo_preprocess
self.black_threshold = black_threshold
self.save_preprocess = save_preprocess
self.files = files

self.rotations = rotations
if self.photo_preprocess and not issubclass(self.content_type, PhotoContent):
raise CommandError(
"--photo-preprocess flag is only available for Photo content type", 2
)

def execute(self, settings: CLISettings) -> None:
hashers = [
Expand All @@ -115,28 +143,58 @@ def execute(self, settings: CLISettings) -> None:

hashers = [self.signal_type] # type: ignore # can't detect intersection types

if not self.rotations:
if not self.photo_preprocess:
for file in self.files:
for hasher in hashers:
hash_str = hasher.hash_from_file(file)
if hash_str:
print(hasher.get_name(), hash_str)
return

if not issubclass(self.content_type, PhotoContent):
raise CommandError(
"--rotations flag is only available for Photo content type", 2
)

for file in self.files:
with open(file, "rb") as f:
image_bytes = f.read()
rotated_images = PhotoContent.all_simple_rotations(image_bytes)
for rotation_type, rotated_bytes in rotated_images.items():
with tempfile.NamedTemporaryFile() as temp_file: # Create a temporary file to hold the byte data
temp_file.write(rotated_bytes)
temp_file_path = pathlib.Path(temp_file.name)
for hasher in hashers:
hash_str = hasher.hash_from_file(temp_file_path)
if hash_str:
print(rotation_type.name, hasher.get_name(), hash_str)
def pre_processed_files() -> (
t.Iterator[t.Tuple[Path, bytes, t.Union[None, RotationType], str]]
):
"""
Generator that yields preprocessed files and their metadata.
Each item is a tuple of (file path, processed bytes, rotation name, image format).
"""
for file in self.files:
image_format = file.suffix.lower().lstrip(".")
if self.photo_preprocess == "unletterbox":
processed_bytes = PhotoContent.unletterbox(
file, self.black_threshold
)
yield file, processed_bytes, None, image_format
elif self.photo_preprocess == "rotations":
with open(file, "rb") as f:
image_bytes = f.read()
rotations = PhotoContent.all_simple_rotations(image_bytes)
for rotation_type, processed_bytes in rotations.items():
yield file, processed_bytes, rotation_type, image_format

for (
file,
processed_bytes,
rotation_type,
image_format,
) in pre_processed_files():
output_extension = f".{image_format.lower()}" if image_format else ".png"
with tempfile.NamedTemporaryFile(
delete=not self.save_preprocess, suffix=output_extension
) as temp_file:
temp_file.write(processed_bytes)
temp_file_path = Path(temp_file.name)
for hasher in hashers:
hash_str = hasher.hash_from_file(temp_file_path)
if hash_str:
prefix = rotation_type.name if rotation_type else ""
print(f"{prefix} {hasher.get_name()} {hash_str}")
if self.save_preprocess:
suffix = (
f"_{rotation_type.name}" if rotation_type else "_unletterboxed"
)
output_path = file.with_stem(f"{file.stem}{suffix}").with_suffix(
output_extension
)
temp_file_path.rename(output_path)
print(f"Processed image saved to: {output_path}")
58 changes: 54 additions & 4 deletions python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,20 @@ def test_rotations_with_non_photo_content(
"""Test that rotation flag raises error with non-photo content"""
for content_type in ["url", "text", "video"]:
hash_cli.assert_cli_usage_error(
("--rotations", content_type, str(tmp_file)),
msg_regex="--rotations flag is only available for Photo content type",
("--photo-preprocess=rotations", content_type, str(tmp_file)),
msg_regex="--photo-preprocess flag is only available for Photo content type",
)


def test_rotations_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper):
"""Test that photo rotations are properly processed"""
test_file = pathlib.Path("threatexchange/tests/hashing/resources/LA.png")
resources_dir = (
pathlib.Path(__file__).parent.parent.parent / "tests/hashing/resources"
)
test_file = resources_dir / "LA.png"

hash_cli.assert_cli_output(
("--rotations", "photo", str(test_file)),
("--photo-preprocess=rotations", "photo", str(test_file)),
[
"ORIGINAL pdq accb6d39648035f8125c8ce6ba65007de7b54c67a2d93ef7b8f33b0611306715",
"ROTATE90 pdq 1f70cbbc77edc5f9524faa1b18f3b76cd0a04a833e20f645d229d0acc8499c56",
Expand All @@ -105,3 +108,50 @@ def test_rotations_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper):
"FLIPMINUS1 pdq 5bb15db9e8a1f03c174a380a55aeaa2985bde9c60abce301bde48df918b5c15b",
],
)


def test_unletterbox_with_non_photo_content(
hash_cli: ThreatExchangeCLIE2eHelper, tmp_file: pathlib.Path
):
"""Test that unletterbox flag raises error with non-photo content"""
for content_type in ["url", "text", "video"]:
hash_cli.assert_cli_usage_error(
("--photo-preprocess=unletterbox", content_type, str(tmp_file)),
msg_regex="--photo-preprocess flag is only available for Photo content type",
)


def test_unletterbox_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper):
"""Test that photo unletterboxing is properly processed"""
resources_dir = (
pathlib.Path(__file__).parent.parent.parent / "tests/hashing/resources"
)
test_file = resources_dir / "letterboxed_sample-b.jpg"
clean_file = resources_dir / "sample-b.jpg"

hash_cli.assert_cli_output(
("photo", str(clean_file)),
[
"pdq f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22",
],
)

"""Test that photo unletterboxing is changed based on allowed threshold"""
hash_cli.assert_cli_output(
("--photo-preprocess=unletterbox", "photo", str(test_file)),
[
"pdq d8f871cce0f4e84d8a370a32028f63f4b36e27d597621e1d33e6b39c4a9c9b22",
],
)

hash_cli.assert_cli_output(
(
"--photo-preprocess=unletterbox",
"--black-threshold=25",
"photo",
str(test_file),
),
[
"pdq f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22",
],
)
23 changes: 23 additions & 0 deletions python-threatexchange/threatexchange/content_type/photo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
Wrapper around the video content type.
"""
from PIL import Image
from pathlib import Path
import io
import typing as t

from .content_base import ContentType, RotationType
from threatexchange.content_type.preprocess import unletterboxing


class PhotoContent(ContentType):
Expand Down Expand Up @@ -102,3 +104,24 @@ def all_simple_rotations(cls, image_data: bytes) -> t.Dict[RotationType, bytes]:
RotationType.FLIPMINUS1: cls.flip_minus1(image_data),
}
return rotations

@classmethod
def unletterbox(cls, file_path: Path, black_threshold: int = 0) -> bytes:
"""
Remove black letterbox borders from the sides and top of the image based on the specified black_threshold.
Returns the cleaned image as raw bytes.
"""
with file_path.open("rb") as file:
with Image.open(file) as image:
img = image.convert("RGB")
top = unletterboxing.detect_top_border(img, black_threshold)
bottom = unletterboxing.detect_bottom_border(img, black_threshold)
left = unletterboxing.detect_left_border(img, black_threshold)
right = unletterboxing.detect_right_border(img, black_threshold)

width, height = image.size
cropped_img = image.crop((left, top, width - right, height - bottom))

with io.BytesIO() as buffer:
cropped_img.save(buffer, format=image.format)
return buffer.getvalue()
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
from PIL import Image


def is_pixel_black(pixel: tuple, black_threshold: int):
"""
Check if each color channel in the pixel is below the threshold
"""
r, g, b = pixel
return r <= black_threshold and g <= black_threshold and b <= black_threshold


def detect_top_border(image: Image.Image, black_threshold: int = 0) -> int:
"""
Detect the top black border by counting rows with only black pixels.
Checks each RGB channel of each pixel in each row.
Returns the first row that is not all black from the top.
"""
width, height = image.size
for y in range(height):
row_pixels = list(image.crop((0, y, width, y + 1)).getdata())
if all(is_pixel_black(pixel, black_threshold) for pixel in row_pixels):
continue
return y
return height


def detect_bottom_border(image: Image.Image, black_threshold: int = 0) -> int:
"""
Detect the bottom black border by counting rows with only black pixels from the bottom up.
Checks each RGB channel of each pixel in each row.
Returns the first row that is not all black from the bottom.
"""
width, height = image.size
for y in range(height - 1, -1, -1):
row_pixels = list(image.crop((0, y, width, y + 1)).getdata())
if all(is_pixel_black(pixel, black_threshold) for pixel in row_pixels):
continue
return height - y - 1
return height


def detect_left_border(image: Image.Image, black_threshold: int = 0) -> int:
"""
Detect the left black border by counting columns with only black pixels.
Checks each RGB channel of each pixel in each column.
Returns the first column from the left that is not all black.
"""
width, height = image.size
for x in range(width):
col_pixels = list(image.crop((x, 0, x + 1, height)).getdata())
if all(is_pixel_black(pixel, black_threshold) for pixel in col_pixels):
continue
return x
return width


def detect_right_border(image: Image.Image, black_threshold: int = 0) -> int:
"""
Detect the right black border by counting columns with only black pixels from the right.
Checks each RGB channel of each pixel in each column.
Returns the first column from the right that is not all black.
"""
width, height = image.size
for x in range(width - 1, -1, -1):
col_pixels = list(image.crop((x, 0, x + 1, height)).getdata())
if all(is_pixel_black(pixel, black_threshold) for pixel in col_pixels):
continue
return width - x - 1
return width
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit cb1f21f

Please sign in to comment.