Skip to content

Commit

Permalink
pytx - edit pdq index2
Browse files Browse the repository at this point in the history
  • Loading branch information
haianhng31 committed Nov 21, 2024
1 parent 610af2e commit 2ea86a1
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def query(self, hash: str) -> t.Sequence[PDQIndexMatch[IndexT]]:
# Create match objects for each entry
results.extend(
PDQIndexMatch(
SignalSimilarityInfoWithIntDistance(distance=int(distance)),
SignalSimilarityInfoWithIntDistance(distance=distance),
entry,
)
for entry in entries
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import typing as t
import numpy as np
import random
import io
import faiss

from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2
from threatexchange.signal_type.pdq.signal import PdqSignal
from threatexchange.signal_type.pdq.pdq_utils import simple_distance


def _get_hash_generator(seed: int = 42):
    """Return a callable that produces deterministic random PDQ hashes.

    Seeds the global ``random`` module once, so successive calls to the
    returned generator continue the same reproducible pseudo-random
    sequence for a given ``seed``.

    Args:
        seed: Seed for ``random.seed``; defaults to 42 for repeatable tests.

    Returns:
        A function ``get_n_hashes(n)`` yielding a list of ``n`` random
        PDQ signal hashes.
    """
    random.seed(seed)

    def get_n_hashes(n: int):
        # Each call draws the next n hashes from the seeded global stream.
        return [PdqSignal.get_random_signal() for _ in range(n)]

    return get_n_hashes


def _brute_force_match(
Expand Down Expand Up @@ -44,18 +46,17 @@ def _generate_random_hash_with_distance(hash: str, distance: int) -> str:


def test_pdq_index():
get_random_hashes = _get_hash_generator()
base_hashes = get_random_hashes(100)
# Make sure base_hashes and query_hashes have at least 10 similar hashes
base_hashes = SAMPLE_HASHES
query_hashes = SAMPLE_HASHES[:10] + _generate_sample_hashes(10)
query_hashes = base_hashes[:10] + get_random_hashes(1000)

brute_force_matches = {
query_hash: _brute_force_match(base_hashes, query_hash)
for query_hash in query_hashes
}

index = PDQIndex2()
for i, base_hash in enumerate(base_hashes):
index.add(base_hash, i)
index = PDQIndex2(entries=[(h, base_hashes.index(h)) for h in base_hashes])

for query_hash in query_hashes:
expected_indices = brute_force_matches[query_hash]
Expand All @@ -73,18 +74,21 @@ def test_pdq_index():


def test_pdq_index_with_exact_distance():
get_random_hashes = _get_hash_generator()
base_hashes = get_random_hashes(100)

thresholds: t.List[int] = [10, 31, 50]

indexes = [
PDQIndex2(
entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES],
entries=[(h, base_hashes.index(h)) for h in base_hashes],
threshold=thres,
)
for thres in thresholds
]

distances: t.List[int] = [0, 1, 20, 30, 31, 60]
query_hash = SAMPLE_HASHES[0]
query_hash = base_hashes[0]

for i in range(len(indexes)):
index = indexes[i]
Expand All @@ -97,6 +101,23 @@ def test_pdq_index_with_exact_distance():
assert dist in result_indices


def test_serialize_deserialize_index():
    """Round-trip an index through serialize/deserialize and verify its state."""
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)
    # enumerate gives (position, hash) directly — avoids the O(n^2)
    # list.index() scan and stays correct even if a duplicate hash appears.
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(base_hashes)])

    buffer = io.BytesIO()
    index.serialize(buffer)
    buffer.seek(0)  # rewind so deserialize reads from the start
    deserialized_index = PDQIndex2.deserialize(buffer)

    # The round-tripped copy must preserve the faiss backing index type,
    # the match threshold, the dedupe set, and the id -> entries mapping.
    assert isinstance(deserialized_index, PDQIndex2)
    assert isinstance(deserialized_index._index.faiss_index, faiss.IndexFlatL2)
    assert deserialized_index.threshold == index.threshold
    assert deserialized_index._deduper == index._deduper
    assert deserialized_index._idx_to_entries == index._idx_to_entries


def test_empty_index_query():
"""Test querying an empty index."""
index = PDQIndex2()
Expand All @@ -108,19 +129,23 @@ def test_empty_index_query():

def test_sample_set_no_match():
    """A query hash far from every indexed hash must return no matches."""
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)
    # enumerate gives (position, hash) directly — avoids the O(n^2)
    # list.index() scan and stays correct even if a duplicate hash appears.
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(base_hashes)])
    # "b" * 64 is a syntactically valid 64-hex-char PDQ hash that is
    # overwhelmingly unlikely to fall within the match threshold.
    results = index.query("b" * 64)
    assert len(results) == 0


def test_duplicate_handling():
"""Test how the index handles duplicate entries."""
index = PDQIndex2(entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES])
get_random_hashes = _get_hash_generator()
base_hashes = get_random_hashes(100)
index = PDQIndex2(entries=[(h, base_hashes.index(h)) for h in base_hashes])

# Add same hash multiple times
index.add_all(entries=[(SAMPLE_HASHES[0], i) for i in range(3)])
index.add_all(entries=[(base_hashes[0], i) for i in range(3)])

results = index.query(SAMPLE_HASHES[0])
results = index.query(base_hashes[0])

# Should find all entries associated with the hash
assert len(results) == 4
Expand All @@ -129,11 +154,17 @@ def test_duplicate_handling():


def test_one_entry_sample_index():
"""Test how the index handles when it only has one entry."""
index = PDQIndex2(entries=[(SAMPLE_HASHES[0], 0)])
"""
Test how the index handles when it only has one entry.
See issue github.com/facebook/ThreatExchange/issues/1318
"""
get_random_hashes = _get_hash_generator()
base_hashes = get_random_hashes(100)
index = PDQIndex2(entries=[(base_hashes[0], 0)])

matching_test_hash = SAMPLE_HASHES[0] # This is the existing hash in index
unmatching_test_hash = SAMPLE_HASHES[1]
matching_test_hash = base_hashes[0] # This is the existing hash in index
unmatching_test_hash = base_hashes[1]

results = index.query(matching_test_hash)
# Should find 1 entry associated with the hash
Expand Down

0 comments on commit 2ea86a1

Please sign in to comment.