From 96ca8433194a7b623e21f286ab42c2c177313ea8 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Tue, 22 Oct 2024 20:51:41 -0700 Subject: [PATCH] fix --- daft/expressions/expressions.py | 1 - daft/series.py | 4 +--- tests/series/test_minhash.py | 17 +++-------------- 3 files changed, 4 insertions(+), 18 deletions(-) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index cb2f277834..2c402e14ba 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1208,7 +1208,6 @@ def minhash( assert isinstance(num_hashes, int) assert isinstance(ngram_size, int) assert isinstance(seed, int) - assert isinstance(hash_function, str) assert isinstance(hash_function, native.HashFunctionKind), f"Hash function {hash_function} not found" return Expression._from_pyexpr(native.minhash(self._expr, num_hashes, ngram_size, seed, hash_function)) diff --git a/daft/series.py b/daft/series.py index 88df6f20d3..9bbbef8fcf 100644 --- a/daft/series.py +++ b/daft/series.py @@ -585,12 +585,10 @@ def minhash( raise ValueError(f"expected an integer for ngram_size but got {type(ngram_size)}") if seed is not None and not isinstance(seed, int): raise ValueError(f"expected an integer or None for seed but got {type(seed)}") - if not isinstance(hash_function, str): - raise ValueError(f"expected a string for hash_function but got {type(hash_function)}") if not isinstance(hash_function, HashFunctionKind): raise ValueError(f"expected HashFunctionKind for hash_function but got {type(hash_function)}") - return Series._from_pyseries(self._series.minhash(num_hashes, ngram_size, seed)) + return Series._from_pyseries(self._series.minhash(num_hashes, ngram_size, seed, hash_function)) def _to_str_values(self) -> Series: return Series._from_pyseries(self._series.to_str_values()) diff --git a/tests/series/test_minhash.py b/tests/series/test_minhash.py index 9fffc279b9..7dd68cd7b0 100644 --- a/tests/series/test_minhash.py +++ b/tests/series/test_minhash.py @@ -1,20 +1,9 @@ from __future__ import annotations -from enum import Enum - import pytest from daft import DataType, Series - - -class HashFunctionKind(Enum): - """ - Kind of hash function to use for minhash. - """ - - MurmurHash3 = 0 - XxHash = 1 - Sha1 = 2 +from daft.daft import HashFunctionKind def minhash_none( @@ -25,9 +14,9 @@ def minhash_none( hash_function: HashFunctionKind, ) -> list[list[int] | None]: if seed is None: - return series.minhash(num_hashes, ngram_size, hash_function=hash_function.name.lower()).to_pylist() + return series.minhash(num_hashes, ngram_size, hash_function=hash_function).to_pylist() else: - return series.minhash(num_hashes, ngram_size, seed, hash_function=hash_function.name.lower()).to_pylist() + return series.minhash(num_hashes, ngram_size, seed, hash_function=hash_function).to_pylist() test_series = Series.from_pylist(