From 4577a51f0d4cb928c2117f45c385d0dd29aa4162 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Tue, 22 Oct 2024 16:27:47 -0700 Subject: [PATCH] update tests --- daft/expressions/expressions.py | 6 ++++- tests/series/test_minhash.py | 48 +++++++++++++++++++++------------ tests/table/test_minhash.py | 7 ++--- 3 files changed, 40 insertions(+), 21 deletions(-) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 273a73a850..f005b58c59 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1186,6 +1186,7 @@ def minhash( num_hashes: int, ngram_size: int, seed: int = 1, + hash_function: Literal["murmur3", "xxhash", "sha1"] = "murmur3", ) -> Expression: """ Runs the MinHash algorithm on the series. @@ -1206,7 +1207,10 @@ def minhash( assert isinstance(num_hashes, int) assert isinstance(ngram_size, int) assert isinstance(seed, int) - return Expression._from_pyexpr(native.minhash(self._expr, num_hashes, ngram_size, seed)) + assert isinstance(hash_function, str) + assert hash_function in ("murmur3", "xxhash", "sha1"), f"Hash function {hash_function} not found" + + return Expression._from_pyexpr(native.minhash(self._expr, num_hashes, ngram_size, seed, hash_function)) def name(self) -> builtins.str: return self._expr.name() diff --git a/tests/series/test_minhash.py b/tests/series/test_minhash.py index 28019d9d1b..d1f34471dd 100644 --- a/tests/series/test_minhash.py +++ b/tests/series/test_minhash.py @@ -1,13 +1,21 @@ +from typing import Literal + import pytest from daft import DataType, Series -def minhash_none(series, num_hashes, ngram_size, seed): +def minhash_none( + series: Series, + num_hashes: int, + ngram_size: int, + seed: int | None, + hash_function: Literal["murmur3", "xxhash", "sha1"], +) -> list[list[int] | None]: if seed is None: - return series.minhash(num_hashes, ngram_size).to_pylist() + return series.minhash(num_hashes, ngram_size, hash_function=hash_function).to_pylist() else: - return series.minhash(num_hashes, ngram_size, seed).to_pylist() + return series.minhash(num_hashes, ngram_size, seed, hash_function=hash_function).to_pylist() test_series = Series.from_pylist( @@ -31,8 +39,9 @@ def minhash_none(series, num_hashes, ngram_size, seed): @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -def test_minhash(num_hashes, ngram_size, seed): - minhash = minhash_none(test_series, num_hashes, ngram_size, seed) +@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +def test_minhash(num_hashes, ngram_size, seed, hash_function): + minhash = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) assert minhash[4] is None and minhash[-1] is None for lst in minhash: if lst is not None: @@ -46,42 +55,47 @@ def test_minhash(num_hashes, ngram_size, seed): @pytest.mark.parametrize("num_hashes", [0, -1, -100]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -def test_minhash_fails_nonpositive_num_hashes(num_hashes, ngram_size, seed): +@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +def test_minhash_fails_nonpositive_num_hashes(num_hashes, ngram_size, seed, hash_function): with pytest.raises(ValueError, match="num_hashes must be positive"): - minhash_none(test_series, num_hashes, ngram_size, seed) + minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [0, -1, -100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -def test_minhash_fails_nonpositive_ngram_size(num_hashes, ngram_size, seed): +@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +def test_minhash_fails_nonpositive_ngram_size(num_hashes, ngram_size, seed, hash_function): with pytest.raises(ValueError, match="ngram_size must be positive"): - minhash_none(test_series, num_hashes, ngram_size, seed) + minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -def test_minhash_empty_series(num_hashes, ngram_size, seed): +@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +def test_minhash_empty_series(num_hashes, ngram_size, seed, hash_function): series = Series.from_pylist([]).cast(DataType.string()) - minhash = minhash_none(series, num_hashes, ngram_size, seed) + minhash = minhash_none(series, num_hashes, ngram_size, seed, hash_function) assert len(minhash) == 0 @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -def test_minhash_seed_consistency(num_hashes, ngram_size, seed): - minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed) - minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed) +@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +def test_minhash_seed_consistency(num_hashes, ngram_size, seed, hash_function): + minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) + minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) assert minhash1 == minhash2 @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed_pair", [[1, 2], [1, 5], [None, 2], [123, 234]]) -def test_minhash_seed_differences(num_hashes, ngram_size, seed_pair): - minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[0]) - minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[1]) +@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +def test_minhash_seed_differences(num_hashes, ngram_size, seed_pair, hash_function): + minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[0], hash_function) + minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[1], hash_function) assert minhash1 != minhash2 diff --git a/tests/table/test_minhash.py b/tests/table/test_minhash.py index f8aa7dba3e..5c56e95528 100644 --- a/tests/table/test_minhash.py +++ b/tests/table/test_minhash.py @@ -7,7 +7,8 @@ @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -def test_table_expr_minhash(num_hashes, ngram_size, seed): +@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +def test_table_expr_minhash(num_hashes, ngram_size, seed, hash_function): df = daft.from_pydict( { "data": [ @@ -25,9 +26,9 @@ def test_table_expr_minhash(num_hashes, ngram_size, seed): res = None if seed is None: - res = df.select(col("data").minhash(num_hashes, ngram_size)) + res = df.select(col("data").minhash(num_hashes, ngram_size, hash_function=hash_function)) else: - res = df.select(col("data").minhash(num_hashes, ngram_size, seed)) + res = df.select(col("data").minhash(num_hashes, ngram_size, seed, hash_function=hash_function)) minhash = res.to_pydict()["data"] assert minhash[4] is None and minhash[-1] is None for lst in minhash: