diff --git a/Cargo.lock b/Cargo.lock index c6d5c8f43a..ac2b00da30 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1741,6 +1741,7 @@ dependencies = [ "daft-dsl", "daft-functions", "daft-functions-json", + "daft-hash", "daft-image", "daft-io", "daft-json", diff --git a/Cargo.toml b/Cargo.toml index 1195d7373a..bce444fdee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ daft-csv = {path = "src/daft-csv", default-features = false} daft-dsl = {path = "src/daft-dsl", default-features = false} daft-functions = {path = "src/daft-functions", default-features = false} daft-functions-json = {path = "src/daft-functions-json", default-features = false} +daft-hash = {path = "src/daft-hash", default-features = false} daft-image = {path = "src/daft-image", default-features = false} daft-io = {path = "src/daft-io", default-features = false} daft-json = {path = "src/daft-json", default-features = false} @@ -36,29 +37,30 @@ sysinfo = {workspace = true} [features] # maturin will turn this on python = [ - "dep:pyo3", - "dep:pyo3-log", + "common-daft-config/python", + "common-display/python", + "common-resource-request/python", + "common-system-info/python", "daft-core/python", "daft-csv/python", "daft-dsl/python", - "daft-local-execution/python", - "daft-io/python", + "daft-functions-json/python", + "daft-functions/python", + "daft-hash/python", "daft-image/python", + "daft-io/python", "daft-json/python", + "daft-local-execution/python", "daft-micropartition/python", "daft-parquet/python", "daft-plan/python", "daft-scan/python", "daft-scheduler/python", - "daft-stats/python", "daft-sql/python", + "daft-stats/python", "daft-table/python", - "daft-functions/python", - "daft-functions-json/python", - "common-daft-config/python", - "common-system-info/python", - "common-display/python", - "common-resource-request/python" + "dep:pyo3", + "dep:pyo3-log" ] [lib] diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 6b04402701..ac2673e702 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1220,7 +1220,7 @@ def minhash( num_hashes: int, ngram_size: int, seed: int = 1, - hash_function: Literal["murmur3", "xxhash", "sha1"] = "murmur3", + hash_function: HashFunctionKind = HashFunctionKind.MurmurHash3, ) -> PyExpr: ... # ----- @@ -1360,7 +1360,7 @@ class PySeries: num_hashes: int, ngram_size: int, seed: int = 1, - hash_function: Literal["murmur3", "xxhash", "sha1"] = "murmur3", + hash_function: HashFunctionKind = HashFunctionKind.MurmurHash3, ) -> PySeries: ... def __invert__(self) -> PySeries: ... def count(self, mode: CountMode) -> PySeries: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index f005b58c59..cb2f277834 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1186,7 +1186,7 @@ def minhash( num_hashes: int, ngram_size: int, seed: int = 1, - hash_function: Literal["murmur3", "xxhash", "sha1"] = "murmur3", + hash_function: native.HashFunctionKind = native.HashFunctionKind.MurmurHash3, ) -> Expression: """ Runs the MinHash algorithm on the series. @@ -1195,7 +1195,6 @@ def minhash( repeating with `num_hashes` permutations. Returns as a list of 32-bit unsigned integers. Tokens for the ngrams are delimited by spaces. - MurmurHash is used for the initial hash. The strings are not normalized or pre-processed, so it is recommended to normalize the strings yourself. @@ -1203,12 +1202,14 @@ def minhash( num_hashes: The number of hash permutations to compute. ngram_size: The number of tokens in each shingle/ngram. seed (optional): Seed used for generating permutations and the initial string hashes. Defaults to 1. + hash_function (optional): Hash function to use for initial string hashing. One of "murmur3", "xxhash", or "sha1". Defaults to "murmur3". + """ assert isinstance(num_hashes, int) assert isinstance(ngram_size, int) assert isinstance(seed, int) assert isinstance(hash_function, str) - assert hash_function in ("murmur3", "xxhash", "sha1"), f"Hash function {hash_function} not found" + assert isinstance(hash_function, native.HashFunctionKind), f"Hash function {hash_function} not found" return Expression._from_pyexpr(native.minhash(self._expr, num_hashes, ngram_size, seed, hash_function)) diff --git a/daft/series.py b/daft/series.py index fd85d33f13..88df6f20d3 100644 --- a/daft/series.py +++ b/daft/series.py @@ -4,7 +4,7 @@ from typing import Any, Literal, TypeVar from daft.arrow_utils import ensure_array, ensure_chunked_array -from daft.daft import CountMode, ImageFormat, ImageMode, PySeries, image +from daft.daft import CountMode, HashFunctionKind, ImageFormat, ImageMode, PySeries, image from daft.datatype import DataType, _ensure_registered_super_ext_type from daft.dependencies import np, pa, pd from daft.utils import pyarrow_supports_fixed_shape_tensor @@ -562,6 +562,7 @@ def minhash( num_hashes: int, ngram_size: int, seed: int = 1, + hash_function: HashFunctionKind = HashFunctionKind.MurmurHash3, ) -> Series: """ Runs the MinHash algorithm on the series. @@ -576,6 +577,7 @@ def minhash( num_hashes: The number of hash permutations to compute. ngram_size: The number of tokens in each shingle/ngram. seed (optional): Seed used for generating permutations and the initial string hashes. Defaults to 1. + hash_function (optional): Hash function to use for initial string hashing. One of "murmur3", "xxhash", or "sha1". Defaults to "murmur3". """ if not isinstance(num_hashes, int): raise ValueError(f"expected an integer for num_hashes but got {type(num_hashes)}") @@ -583,6 +585,10 @@ def minhash( raise ValueError(f"expected an integer for ngram_size but got {type(ngram_size)}") if seed is not None and not isinstance(seed, int): raise ValueError(f"expected an integer or None for seed but got {type(seed)}") + if not isinstance(hash_function, str): + raise ValueError(f"expected a string for hash_function but got {type(hash_function)}") + if not isinstance(hash_function, HashFunctionKind): + raise ValueError(f"expected HashFunctionKind for hash_function but got {type(hash_function)}") return Series._from_pyseries(self._series.minhash(num_hashes, ngram_size, seed)) diff --git a/src/daft-hash/src/lib.rs b/src/daft-hash/src/lib.rs index fb6bcb61e4..06af25362e 100644 --- a/src/daft-hash/src/lib.rs +++ b/src/daft-hash/src/lib.rs @@ -76,3 +76,10 @@ impl FromStr for HashFunctionKind { } } } + +#[cfg(feature = "python")] +pub fn register_modules(parent: &Bound) -> PyResult<()> { + parent.add_class::()?; + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index a7a1382538..e8d007e6da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -118,6 +118,7 @@ pub mod pylib { daft_sql::register_modules(m)?; daft_functions::register_modules(m)?; daft_functions_json::register_modules(m)?; + daft_hash::register_modules(m)?; m.add_wrapped(wrap_pyfunction!(version))?; m.add_wrapped(wrap_pyfunction!(build_type))?; diff --git a/tests/series/test_minhash.py b/tests/series/test_minhash.py index d1f34471dd..9fffc279b9 100644 --- a/tests/series/test_minhash.py +++ b/tests/series/test_minhash.py @@ -1,21 +1,33 @@ -from typing import Literal +from __future__ import annotations + +from enum import Enum import pytest from daft import DataType, Series +class HashFunctionKind(Enum): + """ + Kind of hash function to use for minhash. + """ + + MurmurHash3 = 0 + XxHash = 1 + Sha1 = 2 + + def minhash_none( series: Series, num_hashes: int, ngram_size: int, seed: int | None, - hash_function: Literal["murmur3", "xxhash", "sha1"], + hash_function: HashFunctionKind, ) -> list[list[int] | None]: if seed is None: - return series.minhash(num_hashes, ngram_size, hash_function=hash_function).to_pylist() + return series.minhash(num_hashes, ngram_size, hash_function=hash_function.name.lower()).to_pylist() else: - return series.minhash(num_hashes, ngram_size, seed, hash_function=hash_function).to_pylist() + return series.minhash(num_hashes, ngram_size, seed, hash_function=hash_function.name.lower()).to_pylist() test_series = Series.from_pylist( @@ -39,7 +51,9 @@ def minhash_none( @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +@pytest.mark.parametrize( + "hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1] +) def test_minhash(num_hashes, ngram_size, seed, hash_function): minhash = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) assert minhash[4] is None and minhash[-1] is None @@ -55,7 +69,9 @@ def test_minhash(num_hashes, ngram_size, seed, hash_function): @pytest.mark.parametrize("num_hashes", [0, -1, -100]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +@pytest.mark.parametrize( + "hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1] +) def test_minhash_fails_nonpositive_num_hashes(num_hashes, ngram_size, seed, hash_function): with pytest.raises(ValueError, match="num_hashes must be positive"): minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) @@ -64,7 +80,9 @@ def test_minhash_fails_nonpositive_num_hashes(num_hashes, ngram_size, seed, hash @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [0, -1, -100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +@pytest.mark.parametrize( + "hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1] +) def test_minhash_fails_nonpositive_ngram_size(num_hashes, ngram_size, seed, hash_function): with pytest.raises(ValueError, match="ngram_size must be positive"): minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) @@ -73,7 +91,9 @@ def test_minhash_fails_nonpositive_ngram_size(num_hashes, ngram_size, seed, hash @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +@pytest.mark.parametrize( + "hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1] +) def test_minhash_empty_series(num_hashes, ngram_size, seed, hash_function): series = Series.from_pylist([]).cast(DataType.string()) @@ -84,7 +104,9 @@ def test_minhash_empty_series(num_hashes, ngram_size, seed, hash_function): @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +@pytest.mark.parametrize( + "hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1] +) def test_minhash_seed_consistency(num_hashes, ngram_size, seed, hash_function): minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function) @@ -94,7 +116,9 @@ def test_minhash_seed_consistency(num_hashes, ngram_size, seed, hash_function): @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed_pair", [[1, 2], [1, 5], [None, 2], [123, 234]]) -@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +@pytest.mark.parametrize( + "hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1] +) def test_minhash_seed_differences(num_hashes, ngram_size, seed_pair, hash_function): minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[0], hash_function) minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[1], hash_function) diff --git a/tests/table/test_minhash.py b/tests/table/test_minhash.py index 5c56e95528..e485fa239c 100644 --- a/tests/table/test_minhash.py +++ b/tests/table/test_minhash.py @@ -2,12 +2,15 @@ import daft from daft import col +from daft.daft import HashFunctionKind @pytest.mark.parametrize("num_hashes", [1, 2, 16, 128]) @pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100]) @pytest.mark.parametrize("seed", [1, -1, 123, None]) -@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"]) +@pytest.mark.parametrize( + "hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1] +) def test_table_expr_minhash(num_hashes, ngram_size, seed, hash_function): df = daft.from_pydict( {