Skip to content

Commit

Permalink
fix some things in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Oct 23, 2024
1 parent 4577a51 commit 85ad3cc
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 28 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 13 additions & 11 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ daft-csv = {path = "src/daft-csv", default-features = false}
daft-dsl = {path = "src/daft-dsl", default-features = false}
daft-functions = {path = "src/daft-functions", default-features = false}
daft-functions-json = {path = "src/daft-functions-json", default-features = false}
daft-hash = {path = "src/daft-hash", default-features = false}
daft-image = {path = "src/daft-image", default-features = false}
daft-io = {path = "src/daft-io", default-features = false}
daft-json = {path = "src/daft-json", default-features = false}
Expand All @@ -36,29 +37,30 @@ sysinfo = {workspace = true}
[features]
# maturin will turn this on
python = [
  # Keep this list sorted and free of duplicates: every workspace crate that
  # exposes Python bindings must have its own `python` feature enabled here.
  "common-daft-config/python",
  "common-display/python",
  "common-resource-request/python",
  "common-system-info/python",
  "daft-core/python",
  "daft-csv/python",
  "daft-dsl/python",
  "daft-functions-json/python",
  "daft-functions/python",
  "daft-hash/python",
  "daft-image/python",
  "daft-io/python",
  "daft-json/python",
  "daft-local-execution/python",
  "daft-micropartition/python",
  "daft-parquet/python",
  "daft-plan/python",
  "daft-scan/python",
  "daft-scheduler/python",
  "daft-sql/python",
  "daft-stats/python",
  "daft-table/python",
  "dep:pyo3",
  "dep:pyo3-log"
]

[lib]
Expand Down
4 changes: 2 additions & 2 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,7 @@ def minhash(
num_hashes: int,
ngram_size: int,
seed: int = 1,
hash_function: Literal["murmur3", "xxhash", "sha1"] = "murmur3",
hash_function: HashFunctionKind = HashFunctionKind.MurmurHash3,
) -> PyExpr: ...

# -----
Expand Down Expand Up @@ -1360,7 +1360,7 @@ class PySeries:
num_hashes: int,
ngram_size: int,
seed: int = 1,
hash_function: Literal["murmur3", "xxhash", "sha1"] = "murmur3",
hash_function: HashFunctionKind = HashFunctionKind.MurmurHash3,
) -> PySeries: ...
def __invert__(self) -> PySeries: ...
def count(self, mode: CountMode) -> PySeries: ...
Expand Down
7 changes: 4 additions & 3 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,7 @@ def minhash(
num_hashes: int,
ngram_size: int,
seed: int = 1,
hash_function: Literal["murmur3", "xxhash", "sha1"] = "murmur3",
hash_function: native.HashFunctionKind = native.HashFunctionKind.MurmurHash3,
) -> Expression:
"""
Runs the MinHash algorithm on the series.
Expand All @@ -1195,20 +1195,21 @@ def minhash(
repeating with `num_hashes` permutations. Returns as a list of 32-bit unsigned integers.
Tokens for the ngrams are delimited by spaces.
MurmurHash3 is used for the initial hash by default; pass `hash_function` to select a different one.
The strings are not normalized or pre-processed, so it is recommended
to normalize the strings yourself.
Args:
num_hashes: The number of hash permutations to compute.
ngram_size: The number of tokens in each shingle/ngram.
seed (optional): Seed used for generating permutations and the initial string hashes. Defaults to 1.
hash_function (optional): Hash function to use for initial string hashing. One of `HashFunctionKind.MurmurHash3`, `HashFunctionKind.XxHash`, or `HashFunctionKind.Sha1`. Defaults to `HashFunctionKind.MurmurHash3`.
"""
assert isinstance(num_hashes, int)
assert isinstance(ngram_size, int)
assert isinstance(seed, int)
assert isinstance(hash_function, str)
assert hash_function in ("murmur3", "xxhash", "sha1"), f"Hash function {hash_function} not found"
assert isinstance(hash_function, native.HashFunctionKind), f"Hash function {hash_function} not found"

return Expression._from_pyexpr(native.minhash(self._expr, num_hashes, ngram_size, seed, hash_function))

Expand Down
8 changes: 7 additions & 1 deletion daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Literal, TypeVar

from daft.arrow_utils import ensure_array, ensure_chunked_array
from daft.daft import CountMode, ImageFormat, ImageMode, PySeries, image
from daft.daft import CountMode, HashFunctionKind, ImageFormat, ImageMode, PySeries, image
from daft.datatype import DataType, _ensure_registered_super_ext_type
from daft.dependencies import np, pa, pd
from daft.utils import pyarrow_supports_fixed_shape_tensor
Expand Down Expand Up @@ -562,6 +562,7 @@ def minhash(
num_hashes: int,
ngram_size: int,
seed: int = 1,
hash_function: HashFunctionKind = HashFunctionKind.MurmurHash3,
) -> Series:
"""
Runs the MinHash algorithm on the series.
Expand All @@ -576,13 +577,18 @@ def minhash(
num_hashes: The number of hash permutations to compute.
ngram_size: The number of tokens in each shingle/ngram.
seed (optional): Seed used for generating permutations and the initial string hashes. Defaults to 1.
hash_function (optional): Hash function to use for initial string hashing. One of `HashFunctionKind.MurmurHash3`, `HashFunctionKind.XxHash`, or `HashFunctionKind.Sha1`. Defaults to `HashFunctionKind.MurmurHash3`.
"""
if not isinstance(num_hashes, int):
raise ValueError(f"expected an integer for num_hashes but got {type(num_hashes)}")
if not isinstance(ngram_size, int):
raise ValueError(f"expected an integer for ngram_size but got {type(ngram_size)}")
if seed is not None and not isinstance(seed, int):
raise ValueError(f"expected an integer or None for seed but got {type(seed)}")
if not isinstance(hash_function, str):
raise ValueError(f"expected a string for hash_function but got {type(hash_function)}")
if not isinstance(hash_function, HashFunctionKind):
raise ValueError(f"expected HashFunctionKind for hash_function but got {type(hash_function)}")

return Series._from_pyseries(self._series.minhash(num_hashes, ngram_size, seed))

Expand Down
7 changes: 7 additions & 0 deletions src/daft-hash/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,10 @@ impl FromStr for HashFunctionKind {
}
}
}

/// Adds this crate's Python-visible classes to `parent`.
///
/// Only `HashFunctionKind` is registered here; an error from `add_class`
/// is returned to the caller unchanged.
#[cfg(feature = "python")]
pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
    parent.add_class::<HashFunctionKind>()
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ pub mod pylib {
daft_sql::register_modules(m)?;
daft_functions::register_modules(m)?;
daft_functions_json::register_modules(m)?;
daft_hash::register_modules(m)?;

m.add_wrapped(wrap_pyfunction!(version))?;
m.add_wrapped(wrap_pyfunction!(build_type))?;
Expand Down
44 changes: 34 additions & 10 deletions tests/series/test_minhash.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@
from typing import Literal
from __future__ import annotations

from enum import Enum

import pytest

from daft import DataType, Series


class HashFunctionKind(Enum):
    """Hash function choices exercised by the minhash tests."""

    # Explicit integer values; members iterate in this declaration order.
    MurmurHash3 = 0
    XxHash = 1
    Sha1 = 2


def minhash_none(
    series: Series,
    num_hashes: int,
    ngram_size: int,
    seed: int | None,
    hash_function: HashFunctionKind,
) -> list[list[int] | None]:
    """Run ``series.minhash`` and return the result as a Python list.

    Args:
        series: Series to hash.
        num_hashes: Number of hash permutations to compute.
        ngram_size: Number of tokens per ngram.
        seed: Seed to pass through; when ``None`` the argument is omitted so
            that ``Series.minhash`` uses its own default.
        hash_function: Member of the test-local ``HashFunctionKind`` enum.

    Returns:
        One entry per input row, as produced by ``to_pylist()``.
    """
    # Series.minhash type-checks against the enum exported by the native
    # extension, so a lowercase name string (or the test-local enum member)
    # would be rejected. Map to the native member of the same name.
    from daft.daft import HashFunctionKind as NativeHashFunctionKind

    native_kind = getattr(NativeHashFunctionKind, hash_function.name)

    if seed is None:
        return series.minhash(num_hashes, ngram_size, hash_function=native_kind).to_pylist()
    return series.minhash(num_hashes, ngram_size, seed, hash_function=native_kind).to_pylist()


test_series = Series.from_pylist(
Expand All @@ -39,7 +51,9 @@ def minhash_none(
@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
@pytest.mark.parametrize(
"hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1]
)
def test_minhash(num_hashes, ngram_size, seed, hash_function):
minhash = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)
assert minhash[4] is None and minhash[-1] is None
Expand All @@ -55,7 +69,9 @@ def test_minhash(num_hashes, ngram_size, seed, hash_function):
@pytest.mark.parametrize("num_hashes", [0, -1, -100])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
@pytest.mark.parametrize(
"hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1]
)
def test_minhash_fails_nonpositive_num_hashes(num_hashes, ngram_size, seed, hash_function):
with pytest.raises(ValueError, match="num_hashes must be positive"):
minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)
Expand All @@ -64,7 +80,9 @@ def test_minhash_fails_nonpositive_num_hashes(num_hashes, ngram_size, seed, hash
@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [0, -1, -100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
@pytest.mark.parametrize(
"hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1]
)
def test_minhash_fails_nonpositive_ngram_size(num_hashes, ngram_size, seed, hash_function):
with pytest.raises(ValueError, match="ngram_size must be positive"):
minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)
Expand All @@ -73,7 +91,9 @@ def test_minhash_fails_nonpositive_ngram_size(num_hashes, ngram_size, seed, hash
@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
@pytest.mark.parametrize(
"hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1]
)
def test_minhash_empty_series(num_hashes, ngram_size, seed, hash_function):
series = Series.from_pylist([]).cast(DataType.string())

Expand All @@ -84,7 +104,9 @@ def test_minhash_empty_series(num_hashes, ngram_size, seed, hash_function):
@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
@pytest.mark.parametrize(
"hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1]
)
def test_minhash_seed_consistency(num_hashes, ngram_size, seed, hash_function):
minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)
minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed, hash_function)
Expand All @@ -94,7 +116,9 @@ def test_minhash_seed_consistency(num_hashes, ngram_size, seed, hash_function):
@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed_pair", [[1, 2], [1, 5], [None, 2], [123, 234]])
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
@pytest.mark.parametrize(
"hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1]
)
def test_minhash_seed_differences(num_hashes, ngram_size, seed_pair, hash_function):
minhash1 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[0], hash_function)
minhash2 = minhash_none(test_series, num_hashes, ngram_size, seed_pair[1], hash_function)
Expand Down
5 changes: 4 additions & 1 deletion tests/table/test_minhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

import daft
from daft import col
from daft.daft import HashFunctionKind


@pytest.mark.parametrize("num_hashes", [1, 2, 16, 128])
@pytest.mark.parametrize("ngram_size", [1, 2, 4, 5, 100])
@pytest.mark.parametrize("seed", [1, -1, 123, None])
@pytest.mark.parametrize("hash_function", ["murmur3", "xxhash", "sha1"])
@pytest.mark.parametrize(
"hash_function", [HashFunctionKind.MurmurHash3, HashFunctionKind.XxHash, HashFunctionKind.Sha1]
)
def test_table_expr_minhash(num_hashes, ngram_size, seed, hash_function):
df = daft.from_pydict(
{
Expand Down

0 comments on commit 85ad3cc

Please sign in to comment.