Skip to content

Commit

Permalink
fix: DX - added nparray as supported/validated type for embeddings
Browse files Browse the repository at this point in the history
- Fixed an issue where List[List[bool]] passed the validation check because bool is a subclass of int.
- Added a test strategy that generates embeddings purely as a NumPy array.
  • Loading branch information
tazarov committed Sep 12, 2023
1 parent 154b82b commit 48b6730
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 23 deletions.
42 changes: 24 additions & 18 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Optional, Union, Sequence, TypeVar, List, Dict, Any

import numpy as np
from typing_extensions import Literal, TypedDict, Protocol
import chromadb.errors as errors
from chromadb.types import (
Expand Down Expand Up @@ -325,33 +327,37 @@ def validate_n_results(n_results: int) -> int:
return n_results


def validate_embeddings(
    embeddings: Union[Embeddings, np.ndarray[Any, Any]]
) -> Embeddings:
    """Validate embeddings and return them as plain nested Python lists.

    The input may be a list whose items are lists or numpy arrays, or a numpy
    array whose rows are the individual embeddings.

    Args:
        embeddings: The embeddings to validate.

    Returns:
        A new list holding one list per embedding. Rows supplied as numpy
        arrays are converted with ``tolist()``; rows supplied as lists are
        reused as-is.

    Raises:
        ValueError: If ``embeddings`` is not a list/array, is empty, contains
            a row that is not a non-empty list/array, or contains a value
            that is not an ``int`` or ``float`` (``bool`` is rejected).
    """
    if not isinstance(embeddings, (list, np.ndarray)):
        raise ValueError(f"Expected embeddings to be a list, got {embeddings}")
    # A 0-d array has no len(); report it like any other non-list input
    # instead of leaking a TypeError from len().
    if isinstance(embeddings, np.ndarray) and embeddings.ndim == 0:
        raise ValueError(f"Expected embeddings to be a list, got {embeddings}")
    if len(embeddings) == 0:
        raise ValueError(
            f"Expected embeddings to be a list with at least one item, got {embeddings}"
        )
    if not all(isinstance(e, (list, np.ndarray)) for e in embeddings):
        raise ValueError(
            f"Expected each embedding in the embeddings to be a list, got {embeddings}"
        )

    validated: Embeddings = []
    for embedding in embeddings:
        if isinstance(embedding, np.ndarray):
            # Reject bool/object/str arrays up front: bool_ is neither
            # np.integer nor np.floating.
            if not (
                np.issubdtype(embedding.dtype, np.integer)
                or np.issubdtype(embedding.dtype, np.floating)
            ):
                raise ValueError(
                    f"Expected each value in the embedding to be a int or float, got {embeddings}"
                )
            embedding = embedding.tolist()
            # tolist() on a 0-d array yields a bare scalar, not a list.
            if not isinstance(embedding, list):
                raise ValueError(
                    f"Expected each embedding in the embeddings to be a list, got {embeddings}"
                )
        if len(embedding) == 0:
            raise ValueError(
                f"Expected each embedding in the embeddings to be a non-empty list, got {embeddings}"
            )
        # bool is a subclass of int, so it has to be excluded explicitly, and
        # per value: checking "all values are bool" lets [True, 1.0] through.
        if any(
            isinstance(value, bool) or not isinstance(value, (int, float))
            for value in embedding
        ):
            raise ValueError(
                f"Expected each value in the embedding to be a int or float, got {embeddings}"
            )
        validated.append(embedding)
    return validated
12 changes: 12 additions & 0 deletions chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,18 @@ def create_embeddings(
return embeddings


def create_embeddings_ndarray(
    dim: int,
    count: int,
    dtype: npt.DTypeLike = np.float64,
) -> np.ndarray[Any, Any]:
    """Create a random 2-D numpy array of embeddings.

    Args:
        dim: Number of values in each embedding (columns).
        count: Number of embeddings to generate (rows).
        dtype: dtype the result is cast to. Defaults to ``np.float64``, the
            dtype ``np.random.uniform`` produces, so callers that omit it get
            the samples uncast.

    Returns:
        An array of shape ``(count, dim)`` sampled uniformly from
        ``[-1.0, 1.0)`` and cast to ``dtype``.
    """
    samples = np.random.uniform(
        low=-1.0,
        high=1.0,
        size=(count, dim),
    )
    # NOTE: casting to an integer dtype truncates toward zero, so samples
    # drawn from [-1, 1) collapse to 0 in practice.
    return samples.astype(dtype)


class hashing_embedding_function(types.EmbeddingFunction):
def __init__(self, dim: int, dtype: npt.DTypeLike) -> None:
self.dim = dim
Expand Down
16 changes: 11 additions & 5 deletions chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
)
from collections import defaultdict
import chromadb.test.property.invariants as invariants
import numpy as np

traces: DefaultDict[str, int] = defaultdict(lambda: 0)

Expand Down Expand Up @@ -406,17 +405,24 @@ def test_delete_empty_fails(api: API):
coll.delete(ids=["foo"], where_document={"$contains": "bar"}, where={"foo": "bar"})


@given(supported_types=st.sampled_from([np.float32, np.int32, np.int64, int, float]))
def test_autocasting_validate_embeddings_for_compatible_types(
    supported_types: Type[Any], caplog: pytest.LogCaptureFixture
) -> None:
    """Check that embeddings built from a supported numeric type validate.

    ``supported_types`` is a single type per example (``sampled_from`` draws
    one element), not a list of types.
    """
    embds = strategies.create_embeddings(10, 10, supported_types)
    # Passing means validate_embeddings did not raise for this type.
    validate_embeddings(embds)


@given(unsupported_types=st.sampled_from([str]))
def test_autocasting_validate_embeddings_with_ndarray(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Check that a 2-D numpy array of embeddings validates to nested lists."""
    # dtype is passed explicitly: create_embeddings_ndarray declares it as a
    # required parameter, so omitting it raises TypeError before validation.
    embds = strategies.create_embeddings_ndarray(10, 10, np.float32)
    assert validate_embeddings(embds) == embds.tolist()


@given(unsupported_types=st.sampled_from([str, bool]))
def test_autocasting_validate_embeddings_incompatible_types(
unsupported_types: list[Type[Any]], api: API
unsupported_types: list[Type[Any]],
) -> None:
embds = strategies.create_embeddings(10, 10, unsupported_types)
with pytest.raises(ValueError) as e:
Expand Down

0 comments on commit 48b6730

Please sign in to comment.