From b2a6bc165910e0157706f767b44e1bc8052ba7e0 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Fri, 25 Oct 2024 04:36:40 -0700 Subject: [PATCH] fix test --- tests/series/test_minhash.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/tests/series/test_minhash.py b/tests/series/test_minhash.py index 06c5d5ccfa..8ddeff2662 100644 --- a/tests/series/test_minhash.py +++ b/tests/series/test_minhash.py @@ -1,9 +1,10 @@ from __future__ import annotations +from typing import Literal + import pytest from daft import DataType, Series -from daft.daft import HashFunctionKind def minhash_none( @@ -11,7 +12,7 @@ def minhash_none( num_hashes: int, ngram_size: int, seed: int | None, - hash_function: HashFunctionKind, + hash_function: Literal["murmurhash3", "xxhash", "sha1"], ) -> list[list[int] | None]: if seed is None: return series.minhash(num_hashes, ngram_size, hash_function=hash_function).to_pylist() @@ -56,9 +57,7 @@ def test_minhash(num_hashes, ngram_size, seed, hash_function): @pytest.mark.parametrize("num_hashes", [0, -1, -100]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -@pytest.mark.parametrize( - "hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1] -) +@pytest.mark.parametrize("hash_function", ["murmurhash3", "xxhash", "sha1"]) def test_minhash_fails_nonpositive_num_hashes(num_hashes, ngram_size, seed, hash_function): with pytest.raises(ValueError, match="num_hashes must be positive"): minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) @@ -67,9 +66,7 @@ def test_minhash_fails_nonpositive_num_hashes(num_hashes, ngram_size, seed, hash @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [0, -1, -100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -@pytest.mark.parametrize( - "hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1] -) +@pytest.mark.parametrize("hash_function", ["murmurhash3", "xxhash", "sha1"]) def test_minhash_fails_nonpositive_ngram_size(num_hashes, ngram_size, seed, hash_function): with pytest.raises(ValueError, match="ngram_size must be positive"): minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) @@ -78,9 +75,7 @@ def test_minhash_fails_nonpositive_ngram_size(num_hashes, ngram_size, seed, hash @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -@pytest.mark.parametrize( - "hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1] -) +@pytest.mark.parametrize("hash_function", ["murmurhash3", "xxhash", "sha1"]) def test_minhash_empty_series(num_hashes, ngram_size, seed, hash_function): series = Series.from_pylist([]).cast(DataType.string()) @@ -91,9 +86,7 @@ def test_minhash_empty_series(num_hashes, ngram_size, seed, hash_function): @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -@pytest.mark.parametrize( - "hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1] -) +@pytest.mark.parametrize("hash_function", ["murmurhash3", "xxhash", "sha1"]) def test_minhash_seed_consistency(num_hashes, ngram_size, seed, hash_function): minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) @@ -103,9 +96,7 @@ def test_minhash_seed_consistency(num_hashes, ngram_size, seed, hash_function): @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed_pair", [[1, 2], [1, 5], [None, 2], [123, 234]]) -@pytest.mark.parametrize( - "hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1] -) +@pytest.mark.parametrize("hash_function", ["murmurhash3", "xxhash", "sha1"]) def test_minhash_seed_differences(num_hashes, ngram_size, seed_pair, hash_function): minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[0], hash_function) minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[1], hash_function)