Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Oct 23, 2024
1 parent 8654d09 commit 4577a51
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 21 deletions.
6 changes: 5 additions & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,7 @@ def minhash(
num_hashes: int,
ngram_size: int,
seed: int = 1,
hash_function: Literal["murmur3", "xxhash", "sha1"] = "murmur3",
) -> Expression:
"""
Runs the MinHash algorithm on the series.
Expand All @@ -1206,7 +1207,10 @@ def minhash(
assert isinstance(num_hashes, int)
assert isinstance(ngram_size, int)
assert isinstance(seed, int)
return Expression._from_pyexpr(native.minhash(self._expr, num_hashes, ngram_size, seed))
assert isinstance(hash_function, str)
assert hash_function in ("murmur3", "xxhash", "sha1"), f"Hash function {hash_function} not found"

return Expression._from_pyexpr(native.minhash(self._expr, num_hashes, ngram_size, seed, hash_function))

def name(self) -> builtins.str:
return self._expr.name()
Expand Down
48 changes: 31 additions & 17 deletions tests/series/test_minhash.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from typing import Literal

import pytest

from daft import DataType, Series


def minhash_none(
    series: "Series",
    num_hashes: int,
    ngram_size: int,
    seed: "int | None",
    hash_function: Literal["murmur3", "xxhash", "sha1"],
) -> "list[list[int] | None]":
    """Run ``series.minhash`` and return the result as a Python list.

    The annotations that use ``|`` are written as strings so they are not
    evaluated when the function is defined; PEP 604 unions raise ``TypeError``
    at definition time on interpreters older than Python 3.10.

    Args:
        series: Object exposing ``minhash(...)`` whose result has ``to_pylist()``.
        num_hashes: Forwarded positionally to ``minhash``.
        ngram_size: Forwarded positionally to ``minhash``.
        seed: Forwarded positionally when not ``None``; when ``None`` the
            argument is omitted so the callee falls back to its own default.
        hash_function: Forwarded as the ``hash_function`` keyword argument.

    Returns:
        Whatever ``to_pylist()`` returns for the computed minhash.
    """
    if seed is None:
        # Leave the seed out entirely rather than passing None through.
        return series.minhash(num_hashes, ngram_size, hash_function=hash_function).to_pylist()
    return series.minhash(num_hashes, ngram_size, seed, hash_function=hash_function).to_pylist()


test_series = Series.from_pylist(
Expand All @@ -31,8 +39,9 @@ def minhash_none(series, num_hashes, ngram_size, seed):
@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
def test_minhash(num_hashes, ngram_size, seed):
minhash = minhash_none(test_series, num_hashes, ngram_size, seed)
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
def test_minhash(num_hashes, ngram_size, seed, hash_function):
minhash = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)
assert minhash[4] is None and minhash[-1] is None
for lst in minhash:
if lst is not None:
Expand All @@ -46,42 +55,47 @@ def test_minhash(num_hashes, ngram_size, seed):
@pytest.mark.parametrize("num_hashes", [0, -1, -100])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
def test_minhash_fails_nonpositive_num_hashes(num_hashes, ngram_size, seed, hash_function):
    """Expect a ``ValueError`` when ``num_hashes`` is zero or negative.

    Runs for every combination of the parametrized ``ngram_size``, ``seed``
    and ``hash_function`` values.
    """
    with pytest.raises(ValueError, match="num_hashes must be positive"):
        minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)


@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [0, -1, -100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
def test_minhash_fails_nonpositive_ngram_size(num_hashes, ngram_size, seed, hash_function):
    """Expect a ``ValueError`` when ``ngram_size`` is zero or negative.

    Runs for every combination of the parametrized ``num_hashes``, ``seed``
    and ``hash_function`` values.
    """
    with pytest.raises(ValueError, match="ngram_size must be positive"):
        minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)


@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
def test_minhash_empty_series(num_hashes, ngram_size, seed, hash_function):
    """Check that minhashing an empty string series gives an empty result."""
    # Cast explicitly to string so the empty series has the dtype being tested.
    series = Series.from_pylist([]).cast(DataType.string())

    minhash = minhash_none(series, num_hashes, ngram_size, seed, hash_function)
    assert len(minhash) == 0


@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
def test_minhash_seed_consistency(num_hashes, ngram_size, seed, hash_function):
    """Check that two runs with identical arguments give identical output."""
    minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)
    minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)
    assert minhash1 == minhash2


@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed_pair", [[1, 2], [1, 5], [None, 2], [123, 234]])
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
def test_minhash_seed_differences(num_hashes, ngram_size, seed_pair, hash_function):
    """Check that the two seeds of each pair give different output.

    Each ``seed_pair`` holds two seeds; a ``None`` entry means the seed
    argument is omitted by ``minhash_none``.
    """
    minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[0], hash_function)
    minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[1], hash_function)
    assert minhash1 != minhash2
7 changes: 4 additions & 3 deletions tests/table/test_minhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
def test_table_expr_minhash(num_hashes, ngram_size, seed):
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
def test_table_expr_minhash(num_hashes, ngram_size, seed, hash_function):
df = daft.from_pydict(
{
"data": [
Expand All @@ -25,9 +26,9 @@ def test_table_expr_minhash(num_hashes, ngram_size, seed):

res = None
if seed is None:
res = df.select(col("data").minhash(num_hashes, ngram_size))
res = df.select(col("data").minhash(num_hashes, ngram_size, hash_function=hash_function))
else:
res = df.select(col("data").minhash(num_hashes, ngram_size, seed))
res = df.select(col("data").minhash(num_hashes, ngram_size, seed, hash_function=hash_function))
minhash = res.to_pydict()["data"]
assert minhash[4] is None and minhash[-1] is None
for lst in minhash:
Expand Down

0 comments on commit 4577a51

Please sign in to comment.