diff --git a/python-threatexchange/threatexchange/cli/hash_cmd.py b/python-threatexchange/threatexchange/cli/hash_cmd.py index 3848c41be..947b45a6a 100644 --- a/python-threatexchange/threatexchange/cli/hash_cmd.py +++ b/python-threatexchange/threatexchange/cli/hash_cmd.py @@ -19,7 +19,6 @@ from threatexchange.signal_type.signal_base import FileHasher, SignalType from threatexchange.cli import command_base from threatexchange.cli.helpers import FlexFilesInputAction -from threatexchange.signal_type.pdq.signal import PdqSignal class HashCommand(command_base.Command): @@ -54,6 +53,7 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No signal_choices = sorted( s.get_name() for s in signal_types if issubclass(s, FileHasher) ) + ap.add_argument( "content_type", **common.argparse_choices_pre_type_kwargs( @@ -81,30 +81,29 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No ) ap.add_argument( - "--preprocess", - choices=["unletterbox"], - help="Apply preprocessing steps to the image before hashing.", + "--photo-preprocess", + choices=["unletterbox", "rotations"], + help=( + "Apply one of the preprocessing steps to the image before hashing. " + "'unletterbox' removes black borders, and 'rotations' generates all 8 " + "simple rotations." + ), ) ap.add_argument( "--black-threshold", type=int, - default=40, - help="Set the black threshold for unletterboxing. Default is 40.", + default=10, + help=( + "Set the black threshold for unletterboxing (default: 5)." + "Only applies when 'unletterbox' is selected in --preprocess." + ), ) ap.add_argument( "--save-output", - type=bool, - default=False, - help="If true, save the processed image as a new file.", - ) - - ap.add_argument( - "--rotations", - "--R", action="store_true", - help="for photos, generate all 8 simple rotations", + help="If true, saves the processed image as a new file.", ) def __init__( @@ -112,19 +111,20 @@ def __init__( content_type: t.Type[ContentType], signal_type: t.Optional[t.Type[SignalType]], files: t.List[pathlib.Path], - rotations: bool = False, - preprocess: t.Optional[str] = None, - black_threshold: int = 40, + photo_preprocess: t.Optional[str] = None, + black_threshold: int = 0, save_output: bool = False, ) -> None: self.content_type = content_type self.signal_type = signal_type - self.preprocess = preprocess + self.photo_preprocess = photo_preprocess self.black_threshold = black_threshold self.save_output = save_output self.files = files - - self.rotations = rotations + if self.photo_preprocess and not issubclass(self.content_type, PhotoContent): + raise CommandError( + "--photo-preprocess flag is only available for Photo content type", 2 + ) def execute(self, settings: CLISettings) -> None: hashers = [ @@ -141,46 +141,44 @@ def execute(self, settings: CLISettings) -> None: hashers = [self.signal_type] # type: ignore # can't detect intersection types - if not self.rotations: + if self.photo_preprocess: for file in self.files: - for hasher in hashers: - if isinstance(hasher, PdqSignal) and ( - self.content_type.get_name() == "photo" - and self.preprocess == "unletterbox" - ): - hash_str = PdqSignal.hash_from_bytes( - PhotoContent.unletterbox( - file, self.save_output, self.black_threshold - ) - ) - else: - hash_str = hasher.hash_from_file(file) - if hash_str: - print(hasher.get_name(), hash_str) - return - - if not issubclass(self.content_type, PhotoContent): - raise CommandError( - "--rotations flag is only available for Photo content type", 2 - ) - - for file in self.files: - with open(file, "rb") as f: - if ( - self.content_type.get_name() == "photo" - and self.preprocess == "unletterbox" - ): - image_bytes = PhotoContent.unletterbox( - file, self.save_output, self.black_threshold + updated_bytes: t.List[bytes] = [] + rotation_type = [] + if self.photo_preprocess == "unletterbox": + updated_bytes.append( + PhotoContent.unletterbox(str(file), self.black_threshold) ) - else: - image_bytes = f.read() - rotated_images = PhotoContent.all_simple_rotations(image_bytes) - for rotation_type, rotated_bytes in rotated_images.items(): - with tempfile.NamedTemporaryFile() as temp_file: # Create a temporary file to hold the byte data - temp_file.write(rotated_bytes) + elif self.photo_preprocess == "rotations": + with open(file, "rb") as f: + image_bytes = f.read() + rotations = PhotoContent.all_simple_rotations(image_bytes) + rotation_type, updated_bytes = list(rotations.keys()), list( + rotations.values() + ) + for idx, bytes_data in enumerate(updated_bytes): + with tempfile.NamedTemporaryFile() as temp_file: + temp_file.write(bytes_data) temp_file_path = pathlib.Path(temp_file.name) for hasher in hashers: hash_str = hasher.hash_from_file(temp_file_path) if hash_str: - print(rotation_type.name, hasher.get_name(), hash_str) + print( + f"{rotation_type[idx].name if rotation_type else ''} {hasher.get_name()} {hash_str}" + ) + if self.save_output: + suffix = ( + f"_{rotation_type[idx].name}" + if rotation_type + else "_unletterboxed" + ) + output_path = file.with_stem(f"{file.stem}{suffix}") + with open(output_path, "wb") as output_file: + output_file.write(bytes_data) + print(f"Processed image saved to: {output_path}") + else: + for file in self.files: + for hasher in hashers: + hash_str = hasher.hash_from_file(file) + if hash_str: + print(hasher.get_name(), hash_str) diff --git a/python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py b/python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py index 0eaf96a1e..789e03b0f 100644 --- a/python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py +++ b/python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py @@ -83,8 +83,8 @@ def test_rotations_with_non_photo_content( """Test that rotation flag raises error with non-photo content""" for content_type in ["url", "text", "video"]: hash_cli.assert_cli_usage_error( - ("--rotations", content_type, str(tmp_file)), - msg_regex="--rotations flag is only available for Photo content type", + ("--photo-preprocess=rotations", content_type, str(tmp_file)), + msg_regex="--photo-preprocess flag is only available for Photo content type", ) @@ -93,7 +93,7 @@ def test_rotations_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper): test_file = pathlib.Path("threatexchange/tests/hashing/resources/LA.png") hash_cli.assert_cli_output( - ("--rotations", "photo", str(test_file)), + ("--photo-preprocess=rotations", "photo", str(test_file)), [ "ORIGINAL pdq accb6d39648035f8125c8ce6ba65007de7b54c67a2d93ef7b8f33b0611306715", "ROTATE90 pdq 1f70cbbc77edc5f9524faa1b18f3b76cd0a04a833e20f645d229d0acc8499c56", @@ -105,3 +105,49 @@ def test_rotations_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper): "FLIPMINUS1 pdq 5bb15db9e8a1f03c174a380a55aeaa2985bde9c60abce301bde48df918b5c15b", ], ) + + +def test_unletterbox_with_non_photo_content( + hash_cli: ThreatExchangeCLIE2eHelper, tmp_file: pathlib.Path +): + """Test that unletterbox flag raises error with non-photo content""" + for content_type in ["url", "text", "video"]: + hash_cli.assert_cli_usage_error( + ("--photo-preprocess=unletterbox", content_type, str(tmp_file)), + msg_regex="--photo-preprocess flag is only available for Photo content type", + ) + + +def test_unletterbox_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper): + """Test that photo unletterboxing is properly processed""" + test_file = pathlib.Path( + "threatexchange/tests/hashing/resources/letterboxed_sample-b.jpg" + ) + clean_file = pathlib.Path("threatexchange/tests/hashing/resources/sample-b.jpg") + + hash_cli.assert_cli_output( + ("photo", str(clean_file)), + [ + "pdq f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22", + ], + ) + + """Test that photo unletterboxing is chnaged based off of allowed threshold""" + hash_cli.assert_cli_output( + ("--photo-preprocess=unletterbox", "photo", str(test_file)), + [ + "pdq 58f870cce0f4e84d8e378a32028f63f4b36e26f597621e1d33e6b39c4a9c9b22", + ], + ) + + hash_cli.assert_cli_output( + ( + "--photo-preprocess=unletterbox", + "--black-threshold=25", + "photo", + str(test_file), + ), + [ + "pdq f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22", + ], + ) diff --git a/python-threatexchange/threatexchange/content_type/photo.py b/python-threatexchange/threatexchange/content_type/photo.py index 1ee00a217..9225a1eed 100644 --- a/python-threatexchange/threatexchange/content_type/photo.py +++ b/python-threatexchange/threatexchange/content_type/photo.py @@ -10,6 +10,7 @@ import os from .content_base import ContentType, RotationType +from threatexchange.content_type.preprocess import unletterboxing class PhotoContent(ContentType): @@ -105,118 +106,20 @@ def all_simple_rotations(cls, image_data: bytes): return rotations @classmethod - def detect_top_border( - cls, grayscale_img: Image.Image, black_threshold: int = 10 - ) -> int: + def unletterbox(cls, file_path: str, black_threshold: int = 0) -> bytes: """ - Detect the top black border by counting rows with only black pixels. - Uses a defualt black threshold of 10 so that only rows with pixel brightness - of 10 or lower will be removed. - - Returns the first row that is not all blacked out from the top. - """ - width, height = grayscale_img.size - for y in range(height): - row_pixels = list(grayscale_img.crop((0, y, width, y + 1)).getdata()) - if all(pixel < black_threshold for pixel in row_pixels): - continue - return y - return height - - @classmethod - def detect_bottom_border( - cls, grayscale_img: Image.Image, black_threshold: int = 10 - ) -> int: - """ - Detect the bottom black border by counting rows with only black pixels from the bottom up. - Uses a defualt black threshold of 10 so that only rows with pixel brightness - of 10 or lower will be removed. - - Returns the first row that is not all blacked out from the bottom. - """ - width, height = grayscale_img.size - for y in range(height - 1, -1, -1): - row_pixels = list(grayscale_img.crop((0, y, width, y + 1)).getdata()) - if all(pixel < black_threshold for pixel in row_pixels): - continue - return height - y - 1 - return height - - @classmethod - def detect_left_border( - cls, grayscale_img: Image.Image, black_threshold: int = 10 - ) -> int: + Remove black letterbox borders from the sides and top of the image based on the specified black_threshold. + Returns the cleaned image as raw bytes. """ - Detect the left black border by counting columns with only black pixels. - Uses a defualt black threshold of 10 so that only colums with pixel brightness - of 10 or lower will be removed. + with Image.open(file_path) as image: + top = unletterboxing.detect_top_border(image, black_threshold) + bottom = unletterboxing.detect_bottom_border(image, black_threshold) + left = unletterboxing.detect_left_border(image, black_threshold) + right = unletterboxing.detect_right_border(image, black_threshold) - Returns the first column from the left that is not all blacked out in the column. - """ - width, height = grayscale_img.size - for x in range(width): - col_pixels = list(grayscale_img.crop((x, 0, x + 1, height)).getdata()) - if all(pixel < black_threshold for pixel in col_pixels): - continue - return x - return width + width, height = image.size + cropped_img = image.crop((left, top, width - right, height - bottom)) - @classmethod - def detect_right_border( - cls, grayscale_img: Image.Image, black_threshold: int = 10 - ) -> int: - """ - Detect the right black border by counting columns with only black pixels from the right. - Uses a defualt black threshold of 10 so that only colums with pixel brightness - of 10 or lower will be removed. - - Returns the first column from the right that is not all blacked out in the column. - """ - width, height = grayscale_img.size - for x in range(width - 1, -1, -1): - col_pixels = list(grayscale_img.crop((x, 0, x + 1, height)).getdata()) - if all(pixel < black_threshold for pixel in col_pixels): - continue - return width - x - 1 - return width - - @classmethod - def unletterbox( - cls, file_path: Path, save_output: bool = False, black_threshold: int = 40 - ) -> bytes: - """ - Remove black letterbox borders from the sides and top of the image. - - Converts the image to grescale then remove the columns and rows that - are all completly blacked out. - - Then removing the edges to give back a cleaned image bytes. - - Return the new hash of the cleaned image with an option to create a new output file as well - """ - # Open the original image - with Image.open(file_path) as img: - grayscale_img = img.convert("L") - - top = cls.detect_top_border(grayscale_img, black_threshold) - bottom = cls.detect_bottom_border(grayscale_img, black_threshold) - left = cls.detect_left_border(grayscale_img, black_threshold) - right = cls.detect_right_border(grayscale_img, black_threshold) - - width, height = grayscale_img.size - cropped_box = (left, top, width - right, height - bottom) - - cropped_img = img.crop(cropped_box) - - # Optionally save the unletterboxed image to a new file in the same directory - if save_output: - path = Path(file_path) - output_path = path.parent / f"{path.stem}_unletterboxed{path.suffix}" - cropped_img.save(output_path) - print(f"Unletterboxed image saved to: {output_path}") - - # Convert the cropped image to bytes for hashing with io.BytesIO() as buffer: - cropped_img.save(buffer, format=img.format) - cropped_image_data = buffer.getvalue() - return cropped_image_data + cropped_img.save(buffer, format=image.format) + return buffer.getvalue() diff --git a/python-threatexchange/threatexchange/content_type/preprocess/unletterboxing.py b/python-threatexchange/threatexchange/content_type/preprocess/unletterboxing.py new file mode 100644 index 000000000..31a505e56 --- /dev/null +++ b/python-threatexchange/threatexchange/content_type/preprocess/unletterboxing.py @@ -0,0 +1,69 @@ +from PIL import Image + + +def is_pixel_black(pixel, threshold): + """ + Check if each color channel in the pixel is below the threshold + """ + r, g, b = pixel + return r < threshold and g < threshold and b < threshold + + +def detect_top_border(image: Image.Image, black_threshold: int = 0) -> int: + """ + Detect the top black border by counting rows with only black pixels. + Checks each RGB channel of each pixel in each row. + Returns the first row that is not all black from the top. + """ + width, height = image.size + for y in range(height): + row_pixels = list(image.crop((0, y, width, y + 1)).getdata()) + if all(is_pixel_black(pixel, black_threshold) for pixel in row_pixels): + continue + return y + return height + + +def detect_bottom_border(image: Image.Image, black_threshold: int = 0) -> int: + """ + Detect the bottom black border by counting rows with only black pixels from the bottom up. + Checks each RGB channel of each pixel in each row. + Returns the first row that is not all black from the bottom. + """ + width, height = image.size + for y in range(height - 1, -1, -1): + row_pixels = list(image.crop((0, y, width, y + 1)).getdata()) + if all(is_pixel_black(pixel, black_threshold) for pixel in row_pixels): + continue + return height - y - 1 + return height + + +def detect_left_border(image: Image.Image, black_threshold: int = 0) -> int: + """ + Detect the left black border by counting columns with only black pixels. + Checks each RGB channel of each pixel in each column. + Returns the first column from the left that is not all black. + """ + width, height = image.size + for x in range(width): + col_pixels = list(image.crop((x, 0, x + 1, height)).getdata()) + if all(is_pixel_black(pixel, black_threshold) for pixel in col_pixels): + continue + return x + return width + + +def detect_right_border(image: Image.Image, black_threshold: int = 0) -> int: + """ + Detect the right black border by counting columns with only black pixels from the right. + Checks each RGB channel of each pixel in each column. + Returns the first column from the right that is not all black. + """ + width, height = image.size + for x in range(width - 1, -1, -1): + col_pixels = list(image.crop((x, 0, x + 1, height)).getdata()) + if all(is_pixel_black(pixel, black_threshold) for pixel in col_pixels): + continue + return width - x - 1 + return width diff --git a/python-threatexchange/threatexchange/tests/hashing/resources/clean.png b/python-threatexchange/threatexchange/tests/hashing/resources/clean.png deleted file mode 100644 index 679aacbb1..000000000 Binary files a/python-threatexchange/threatexchange/tests/hashing/resources/clean.png and /dev/null differ diff --git a/python-threatexchange/threatexchange/tests/hashing/resources/letterbox.png b/python-threatexchange/threatexchange/tests/hashing/resources/letterbox.png deleted file mode 100644 index cb53935a7..000000000 Binary files a/python-threatexchange/threatexchange/tests/hashing/resources/letterbox.png and /dev/null differ diff --git a/python-threatexchange/threatexchange/tests/hashing/resources/letterboxed_sample-b.jpg b/python-threatexchange/threatexchange/tests/hashing/resources/letterboxed_sample-b.jpg new file mode 100644 index 000000000..d2e23eb6c Binary files /dev/null and b/python-threatexchange/threatexchange/tests/hashing/resources/letterboxed_sample-b.jpg differ diff --git a/python-threatexchange/threatexchange/tests/hashing/resources/sample-b.jpg b/python-threatexchange/threatexchange/tests/hashing/resources/sample-b.jpg new file mode 100644 index 000000000..66ad092df Binary files /dev/null and b/python-threatexchange/threatexchange/tests/hashing/resources/sample-b.jpg differ diff --git a/python-threatexchange/threatexchange/tests/hashing/test_pdq_letterboxing.py b/python-threatexchange/threatexchange/tests/hashing/test_pdq_letterboxing.py deleted file mode 100644 index bd644aa98..000000000 --- a/python-threatexchange/threatexchange/tests/hashing/test_pdq_letterboxing.py +++ /dev/null @@ -1,81 +0,0 @@ -import unittest -from pathlib import Path -from threatexchange.signal_type.pdq.signal import PdqSignal -from threatexchange.content_type.photo import PhotoContent - - -class TestUnletterboxFunction(unittest.TestCase): - def setUp(self): - # Load the file paths - current_path = Path(__file__).parent - self.letterbox_path = Path(f"{current_path}/resources/letterbox.png") - self.clean_path = Path(f"{current_path}/resources/clean.png") - self.output_path = Path(f"{current_path}/resources/letterbox_unletterboxed.png") - - def clean(self): - # Removes generated output file if already exists - if self.output_path.exists(): - self.output_path.unlink() - - def test_letterbox_image_without_unletterbox(self): - with self.letterbox_path.open("rb") as f: - letterbox_data = f.read() - - letterbox_hash = PdqSignal.hash_from_bytes(letterbox_data) - - with self.clean_path.open("rb") as f: - clean_data = f.read() - clean_hash = PdqSignal.hash_from_bytes(clean_data) - - # Assert that the hash of the original letterbox image is different from the clean image's hash - self.assertNotEqual( - letterbox_hash, - clean_hash, - "Letterbox image unexpectedly matches the clean image", - ) - - def test_unletterbox_image(self): - # Generate PDQ hash for the unletterboxed image - unletterboxed_hash = PdqSignal.hash_from_bytes( - PhotoContent.unletterbox(self.letterbox_path) - ) - - # Read the clean image data and generate PDQ hash - with self.clean_path.open("rb") as f: - clean_data = f.read() - clean_hash = PdqSignal.hash_from_bytes(clean_data) - - self.assertEqual( - unletterboxed_hash, - clean_hash, - "Unletterboxed image does not match the clean image", - ) - - def test_unletterboxfile_creates_matching_image(self): - # Created generated hash and also create new output file - generated_hash = PdqSignal.hash_from_bytes( - PhotoContent.unletterbox(self.letterbox_path, True) - ) - self.assertTrue( - self.output_path.exists(), "The unletterboxed output file was not created." - ) - - # Generate PDQ hash for the clean image - with self.clean_path.open("rb") as f: - clean_data = f.read() - clean_hash = PdqSignal.hash_from_bytes(clean_data) - - # Assert that the hash of the generated unletterboxed image matches the clean image's hash - self.assertEqual( - generated_hash, - clean_hash, - "Unletterboxfile output does not match the clean image", - ) - - # Removes created file - if self.output_path.exists(): - self.output_path.unlink() - - -if __name__ == "__main__": - unittest.main()