Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Oct 25, 2024
1 parent c9d59a5 commit b2a6bc1
Showing 1 changed file with 8 additions and 17 deletions.
25 changes: 8 additions & 17 deletions tests/series/test_minhash.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations

from typing import Literal

import pytest

from daft import DataType, Series
from daft.daft import HashFunctionKind


def minhash_none(
series: Series,
num_hashes: int,
ngram_size: int,
seed: int | None,
hash_function: HashFunctionKind,
hash_function: Literal["murmurhash3", "xxhash", "sha1"],
) -> list[list[int] | None]:
if seed is None:
return series.minhash(num_hashes, ngram_size, hash_function=hash_function).to_pylist()
Expand Down Expand Up @@ -56,9 +57,7 @@ def test_minhash(num_hashes, ngram_size, seed, hash_function):
@pytest.mark.parametrize("num_hashes", [0, -1, -100])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize(
"hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1]
)
@pytest.mark.parametrize("hash_function", ["murmurhash3", "xxhash", "sha1"])
def test_minhash_fails_nonpositive_num_hashes(num_hashes, ngram_size, seed, hash_function):
with pytest.raises(ValueError, match="num_hashes must be positive"):
minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)
Expand All @@ -67,9 +66,7 @@ def test_minhash_fails_nonpositive_num_hashes(num_hashes, ngram_size, seed, hash
@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [0, -1, -100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize(
"hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1]
)
@pytest.mark.parametrize("hash_function", ["murmurhash3", "xxhash", "sha1"])
def test_minhash_fails_nonpositive_ngram_size(num_hashes, ngram_size, seed, hash_function):
with pytest.raises(ValueError, match="ngram_size must be positive"):
minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)
Expand All @@ -78,9 +75,7 @@ def test_minhash_fails_nonpositive_ngram_size(num_hashes, ngram_size, seed, hash
@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize(
"hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1]
)
@pytest.mark.parametrize("hash_function", ["murmurhash3", "xxhash", "sha1"])
def test_minhash_empty_series(num_hashes, ngram_size, seed, hash_function):
series = Series.from_pylist([]).cast(DataType.string())

Expand All @@ -91,9 +86,7 @@ def test_minhash_empty_series(num_hashes, ngram_size, seed, hash_function):
@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize(
"hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1]
)
@pytest.mark.parametrize("hash_function", ["murmurhash3", "xxhash", "sha1"])
def test_minhash_seed_consistency(num_hashes, ngram_size, seed, hash_function):
minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)
minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)
Expand All @@ -103,9 +96,7 @@ def test_minhash_seed_consistency(num_hashes, ngram_size, seed, hash_function):
@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed_pair", [[1, 2], [1, 5], [None, 2], [123, 234]])
@pytest.mark.parametrize(
"hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1]
)
@pytest.mark.parametrize("hash_function", ["murmurhash3", "xxhash", "sha1"])
def test_minhash_seed_differences(num_hashes, ngram_size, seed_pair, hash_function):
minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[0], hash_function)
minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[1], hash_function)
Expand Down

0 comments on commit b2a6bc1

Please sign in to comment.