diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index b68d7a27d10..ab8d22499bc 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -8,6 +8,7 @@ from chromadb.api.types import ( CollectionMetadata, Documents, + Embeddable, EmbeddingFunction, Embeddings, IDs, @@ -58,7 +59,9 @@ def create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[ + EmbeddingFunction[Embeddable] + ] = ef.DefaultEmbeddingFunction(), # type: ignore get_or_create: bool = False, ) -> Collection: """Create a new collection with the given name and metadata. @@ -90,9 +93,11 @@ def create_collection( @abstractmethod def get_collection( self, - name: Optional[str] = None, + name: str, id: Optional[UUID] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[ + EmbeddingFunction[Embeddable] + ] = ef.DefaultEmbeddingFunction(), # type: ignore ) -> Collection: """Get a collection with the given name. Args: @@ -119,7 +124,9 @@ def get_or_create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[ + EmbeddingFunction[Embeddable] + ] = ef.DefaultEmbeddingFunction(), # type: ignore ) -> Collection: """Get or create a collection with the given name and metadata. Args: @@ -486,7 +493,9 @@ def create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[ + EmbeddingFunction[Embeddable] + ] = ef.DefaultEmbeddingFunction(), # type: ignore get_or_create: bool = False, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, @@ -497,9 +506,11 @@ def create_collection( @override def get_collection( self, - name: Optional[str] = None, + name: str, id: Optional[UUID] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[ + EmbeddingFunction[Embeddable] + ] = ef.DefaultEmbeddingFunction(), # type: ignore tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: @@ -511,7 +522,9 @@ def get_or_create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[ + EmbeddingFunction[Embeddable] + ] = ef.DefaultEmbeddingFunction(), # type: ignore tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 38af7e52a91..6dcaaf84c44 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -14,6 +14,7 @@ from chromadb.api.models.Collection import Collection from chromadb.api.types import ( Documents, + Embeddable, Embeddings, EmbeddingFunction, IDs, @@ -219,7 +220,9 @@ def create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[ + EmbeddingFunction[Embeddable] + ] = ef.DefaultEmbeddingFunction(), # type: ignore get_or_create: bool = False, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, @@ -250,9 +253,9 @@ def create_collection( @override def get_collection( self, - name: Optional[str] = None, + name: str, id: Optional[UUID] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[EmbeddingFunction[Embeddable]] = ef.DefaultEmbeddingFunction(), # type: ignore tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: @@ -284,17 +287,20 @@ def get_or_create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[EmbeddingFunction[Embeddable]] = ef.DefaultEmbeddingFunction(), # type: ignore tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: - return self.create_collection( - name, - metadata, - embedding_function, - get_or_create=True, - tenant=tenant, - database=database, + return cast( + Collection, + self.create_collection( + name, + metadata, + embedding_function, + get_or_create=True, + tenant=tenant, + database=database, + ), ) @trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION) @@ -347,10 +353,13 @@ def _peek( collection_id: UUID, n: int = 10, ) -> GetResult: - return self._get( - collection_id, - limit=n, - include=["embeddings", "documents", "metadatas"], + return cast( + GetResult, + self._get( + collection_id, + limit=n, + include=["embeddings", "documents", "metadatas"], + ), ) @trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION) diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index ef7c66139d2..058c9c86f8f 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Tuple, cast, List +from typing import TYPE_CHECKING, Optional, Tuple, Any from pydantic import BaseModel, PrivateAttr from uuid import UUID @@ -7,9 +7,15 @@ from chromadb.api.types import ( CollectionMetadata, Embedding, + Embeddings, + Embeddable, Include, Metadata, + Metadatas, Document, + Documents, + Image, + Images, Where, IDs, EmbeddingFunction, @@ -18,7 +24,11 @@ ID, OneOrMany, WhereDocument, - maybe_cast_one_to_many, + maybe_cast_one_to_many_ids, + maybe_cast_one_to_many_embedding, + maybe_cast_one_to_many_metadata, + maybe_cast_one_to_many_document, + maybe_cast_one_to_many_image, validate_ids, validate_include, validate_metadata, @@ -27,6 +37,7 @@ validate_where_document, validate_n_results, validate_embeddings, + validate_embedding_function, ) import logging @@ -43,14 +54,16 @@ class Collection(BaseModel): tenant: Optional[str] = None database: Optional[str] = None _client: "ServerAPI" = PrivateAttr() - _embedding_function: Optional[EmbeddingFunction] = PrivateAttr() + _embedding_function: Optional[EmbeddingFunction[Embeddable]] = PrivateAttr() def __init__( self, client: "ServerAPI", name: str, id: UUID, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[ + EmbeddingFunction[Embeddable] + ] = ef.DefaultEmbeddingFunction(), # type: ignore tenant: Optional[str] = None, database: Optional[str] = None, metadata: Optional[CollectionMetadata] = None, @@ -59,6 +72,11 @@ def __init__( name=name, metadata=metadata, id=id, tenant=tenant, database=database ) self._client = client + + # Check to make sure the embedding function has the right signature, as defined by the EmbeddingFunction protocol + if embedding_function is not None: + validate_embedding_function(embedding_function) + self._embedding_function = embedding_function def __repr__(self) -> str: @@ -79,13 +97,15 @@ def add( embeddings: Optional[OneOrMany[Embedding]] = None, metadatas: Optional[OneOrMany[Metadata]] = None, documents: Optional[OneOrMany[Document]] = None, + images: Optional[OneOrMany[Image]] = None, ) -> None: """Add embeddings to the data store. Args: ids: The ids of the embeddings you wish to add - embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional. + embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional. metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional. documents: The documents to associate with the embeddings. Optional. + images: The images to associate with the embeddings. Optional. Returns: None @@ -99,10 +119,22 @@ def add( """ - ids, embeddings, metadatas, documents = self._validate_embedding_set( - ids, embeddings, metadatas, documents + ids, embeddings, metadatas, documents, images = self._validate_embedding_set( + ids, embeddings, metadatas, documents, images ) + # We need to compute the embeddings if they're not provided + if embeddings is None: + # At this point, we know that one of documents or images are provided from the validation above + if documents is not None: + embeddings = self._embed(input=documents) + elif images is not None: + embeddings = self._embed(input=images) + else: + raise ValueError( + "You must provide embeddings, documents, or images, or an embedding function." + ) + self._client._add(ids, self.id, embeddings, metadatas, documents) def get( @@ -133,7 +165,7 @@ def get( where_document = ( validate_where_document(where_document) if where_document else None ) - ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None + ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None include = validate_include(include, allow_distances=False) return self._client._get( self.id, @@ -161,6 +193,7 @@ def query( self, query_embeddings: Optional[OneOrMany[Embedding]] = None, query_texts: Optional[OneOrMany[Document]] = None, + query_images: Optional[OneOrMany[Image]] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, @@ -171,6 +204,7 @@ def query( Args: query_embeddings: The embeddings to get the closes neighbors of. Optional. query_texts: The document texts to get the closes neighbors of. Optional. + query_images: The images to get the closes neighbors of. Optional. n_results: The number of neighbors to return for each query_embedding or query_texts. Optional. where: A Where type dict used to filter results by. E.g. `{"$and": ["color" : "red", "price": {"$gte": 4.20}]}`. Optional. where_document: A WhereDocument type dict used to filter by the documents. E.g. `{$contains: {"text": "hello"}}`. Optional. @@ -180,43 +214,58 @@ def query( QueryResult: A QueryResult object containing the results. Raises: - ValueError: If you don't provide either query_embeddings or query_texts + ValueError: If you don't provide either query_embeddings, query_texts, or query_images ValueError: If you provide both query_embeddings and query_texts + ValueError: If you provide both query_embeddings and query_images + ValueError: If you provide both query_texts and query_images """ + # If neither query_embeddings nor query_texts are provided, or both are provided, raise an error + if ( + (query_embeddings is None and query_texts is None and query_images is None) + or ( + query_embeddings is not None + and (query_texts is not None or query_images is not None) + ) + or (query_texts is not None and query_images is not None) + ): + raise ValueError( + "You must provide either query embeddings, or else one of query texts or query images." + ) + where = validate_where(where) if where else None where_document = ( validate_where_document(where_document) if where_document else None ) query_embeddings = ( - validate_embeddings(maybe_cast_one_to_many(query_embeddings)) + validate_embeddings(maybe_cast_one_to_many_embedding(query_embeddings)) if query_embeddings is not None else None ) query_texts = ( - maybe_cast_one_to_many(query_texts) if query_texts is not None else None + maybe_cast_one_to_many_document(query_texts) + if query_texts is not None + else None + ) + query_images = ( + maybe_cast_one_to_many_image(query_images) + if query_images is not None + else None ) include = validate_include(include, allow_distances=True) n_results = validate_n_results(n_results) - # If neither query_embeddings nor query_texts are provided, or both are provided, raise an error - if (query_embeddings is None and query_texts is None) or ( - query_embeddings is not None and query_texts is not None - ): - raise ValueError( - "You must provide either query embeddings or query texts, but not both" - ) - - # If query_embeddings are not provided, we need to compute them from the query_texts + # If query_embeddings are not provided, we need to compute them from the inputs if query_embeddings is None: - if self._embedding_function is None: + # At this point, we know that one of query_texts or query_images are provided from the validation above + if query_texts is not None: + query_embeddings = self._embed(input=query_texts) + elif query_images is not None: + query_embeddings = self._embed(input=query_images) + else: raise ValueError( - "You must provide embeddings or a function to compute them" + "You must provide either query embeddings, or else one of query texts or query images." ) - # We know query texts is not None at this point, cast for the typechecker - query_embeddings = self._embedding_function( - cast(List[Document], query_texts) - ) if where is None: where = {} @@ -260,23 +309,35 @@ def update( embeddings: Optional[OneOrMany[Embedding]] = None, metadatas: Optional[OneOrMany[Metadata]] = None, documents: Optional[OneOrMany[Document]] = None, + images: Optional[OneOrMany[Image]] = None, ) -> None: """Update the embeddings, metadatas or documents for provided ids. Args: ids: The ids of the embeddings to update - embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional. + embeddings: The embeddings to update. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional. metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional. documents: The documents to associate with the embeddings. Optional. - + images: The images to associate with the embeddings. Optional. Returns: None """ - ids, embeddings, metadatas, documents = self._validate_embedding_set( - ids, embeddings, metadatas, documents, require_embeddings_or_documents=False + ids, embeddings, metadatas, documents, images = self._validate_embedding_set( + ids, + embeddings, + metadatas, + documents, + images, + require_embeddings_or_data=False, ) + if embeddings is None: + if documents is not None: + embeddings = self._embed(input=documents) + elif images is not None: + embeddings = self._embed(input=images) + self._client._update(self.id, ids, embeddings, metadatas, documents) def upsert( @@ -285,6 +346,7 @@ def upsert( embeddings: Optional[OneOrMany[Embedding]] = None, metadatas: Optional[OneOrMany[Metadata]] = None, documents: Optional[OneOrMany[Document]] = None, + images: Optional[OneOrMany[Image]] = None, ) -> None: """Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist. @@ -298,10 +360,16 @@ def upsert( None """ - ids, embeddings, metadatas, documents = self._validate_embedding_set( - ids, embeddings, metadatas, documents + ids, embeddings, metadatas, documents, images = self._validate_embedding_set( + ids, embeddings, metadatas, documents, images ) + if embeddings is None: + if documents is not None: + embeddings = self._embed(input=documents) + else: + embeddings = self._embed(input=images) + self._client._upsert( collection_id=self.id, ids=ids, @@ -329,7 +397,7 @@ def delete( Raises: ValueError: If you don't provide either ids, where, or where_document """ - ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None + ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None where = validate_where(where) if where else None where_document = ( validate_where_document(where_document) if where_document else None @@ -343,58 +411,74 @@ def _validate_embedding_set( embeddings: Optional[OneOrMany[Embedding]], metadatas: Optional[OneOrMany[Metadata]], documents: Optional[OneOrMany[Document]], - require_embeddings_or_documents: bool = True, + images: Optional[OneOrMany[Image]] = None, + require_embeddings_or_data: bool = True, ) -> Tuple[ IDs, - List[Embedding], - Optional[List[Metadata]], - Optional[List[Document]], + Optional[Embeddings], + Optional[Metadatas], + Optional[Documents], + Optional[Images], ]: - ids = validate_ids(maybe_cast_one_to_many(ids)) - embeddings = ( - validate_embeddings(maybe_cast_one_to_many(embeddings)) + valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids)) + valid_embeddings = ( + validate_embeddings(maybe_cast_one_to_many_embedding(embeddings)) if embeddings is not None else None ) - metadatas = ( - validate_metadatas(maybe_cast_one_to_many(metadatas)) + valid_metadatas = ( + validate_metadatas(maybe_cast_one_to_many_metadata(metadatas)) if metadatas is not None else None ) - documents = maybe_cast_one_to_many(documents) if documents is not None else None + valid_documents = ( + maybe_cast_one_to_many_document(documents) + if documents is not None + else None + ) + valid_images = ( + maybe_cast_one_to_many_image(images) if images is not None else None + ) - # Check that one of embeddings or documents is provided - if require_embeddings_or_documents: - if embeddings is None and documents is None: - raise ValueError( - "You must provide either embeddings or documents, or both" - ) + # Check that one of embeddings or ducuments or images is provided + if require_embeddings_or_data: + if ( + valid_embeddings is None + and valid_documents is None + and valid_images is None + ): + raise ValueError("You must provide embeddings, documents, or images.") + + # Only one of documents or images can be provided + if valid_documents is not None and valid_images is not None: + raise ValueError("You can only provide documents or images, not both.") # Check that, if they're provided, the lengths of the arrays match the length of ids - if embeddings is not None and len(embeddings) != len(ids): + if valid_embeddings is not None and len(valid_embeddings) != len(valid_ids): raise ValueError( - f"Number of embeddings {len(embeddings)} must match number of ids {len(ids)}" + f"Number of embeddings {len(valid_embeddings)} must match number of ids {len(valid_ids)}" ) - if metadatas is not None and len(metadatas) != len(ids): + if valid_metadatas is not None and len(valid_metadatas) != len(valid_ids): raise ValueError( - f"Number of metadatas {len(metadatas)} must match number of ids {len(ids)}" + f"Number of metadatas {len(valid_metadatas)} must match number of ids {len(valid_ids)}" ) - if documents is not None and len(documents) != len(ids): + if valid_documents is not None and len(valid_documents) != len(valid_ids): raise ValueError( - f"Number of documents {len(documents)} must match number of ids {len(ids)}" + f"Number of documents {len(valid_documents)} must match number of ids {len(valid_ids)}" ) - # If document embeddings are not provided, we need to compute them - if embeddings is None and documents is not None: - if self._embedding_function is None: - raise ValueError( - "You must provide embeddings or a function to compute them" - ) - embeddings = self._embedding_function(documents) - - # if embeddings is None: - # raise ValueError( - # "Something went wrong. Embeddings should be computed at this point" - # ) + return ( + valid_ids, + valid_embeddings, + valid_metadatas, + valid_documents, + valid_images, + ) - return ids, embeddings, metadatas, documents # type: ignore + def _embed(self, input: Any) -> Embeddings: + if self._embedding_function is None: + raise ValueError( + "You must provide an embedding function to compute embeddings." + "https://docs.trychroma.com/embeddings" + ) + return self._embedding_function(input=input) diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 017e356ffac..84c55257dcb 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -1,4 +1,6 @@ -from typing import Optional, Union, Sequence, TypeVar, List, Dict, Any, Tuple +from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast +from numpy.typing import NDArray +import numpy as np from typing_extensions import Literal, TypedDict, Protocol import chromadb.errors as errors from chromadb.types import ( @@ -13,27 +15,97 @@ WhereDocumentOperator, WhereDocument, ) +from inspect import signature # Re-export types from chromadb.types __all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"] +T = TypeVar("T") +OneOrMany = Union[T, List[T]] + +# IDs ID = str IDs = List[ID] + +def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs: + if isinstance(target, str): + # One ID + return cast(IDs, [target]) + # Already a sequence + return cast(IDs, target) + + +# Embeddings Embedding = Vector Embeddings = List[Embedding] + +def maybe_cast_one_to_many_embedding(target: OneOrMany[Embedding]) -> Embeddings: + if isinstance(target, List): + # One Embedding + if isinstance(target[0], (int, float)): + return cast(Embeddings, [target]) + # Already a sequence + return cast(Embeddings, target) + + +# Metadatas Metadatas = List[Metadata] + +def maybe_cast_one_to_many_metadata(target: OneOrMany[Metadata]) -> Metadatas: + # One Metadata dict + if isinstance(target, dict): + return cast(Metadatas, [target]) + # Already a sequence + return cast(Metadatas, target) + + CollectionMetadata = Dict[str, Any] UpdateCollectionMetadata = UpdateMetadata +# Documents Document = str Documents = List[Document] -Parameter = TypeVar("Parameter", Embedding, Document, Metadata, ID) -T = TypeVar("T") -OneOrMany = Union[T, List[T]] + +def is_document(target: Any) -> bool: + if not isinstance(target, str): + return False + return True + + +def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents: + # One Document + if is_document(target): + return cast(Documents, [target]) + # Already a sequence + return cast(Documents, target) + + +# Images +ImageDType = Union[np.uint, np.int_, np.float_] +Image = NDArray[ImageDType] +Images = List[Image] + + +def is_image(target: Any) -> bool: + if not isinstance(target, np.ndarray): + return False + if len(target.shape) < 2: + return False + return True + + +def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images: + if is_image(target): + return cast(Images, [target]) + # Already a sequence + return cast(Images, target) + + +Parameter = TypeVar("Parameter", Document, Image, Embedding, Metadata, ID) # This should ust be List[Literal["documents", "embeddings", "metadatas", "distances"]] # However, this provokes an incompatibility with the Overrides library and Python 3.7 @@ -81,28 +153,29 @@ class IndexMetadata(TypedDict): time_created: float -class EmbeddingFunction(Protocol): - def __call__(self, texts: Documents) -> Embeddings: +Embeddable = Union[Documents, Images] +D = TypeVar("D", bound=Embeddable, contravariant=True) + + +class EmbeddingFunction(Protocol[D]): + def __call__(self, input: D) -> Embeddings: ... -def maybe_cast_one_to_many( - target: OneOrMany[Parameter], -) -> List[Parameter]: - """Infers if target is Embedding, Metadata, or Document and casts it to a many object if its one""" +def validate_embedding_function( + embedding_function: EmbeddingFunction[Embeddable], +) -> None: + function_signature = signature( + embedding_function.__class__.__call__ + ).parameters.keys() + protocol_signature = signature(EmbeddingFunction.__call__).parameters.keys() - if isinstance(target, Sequence): - # One Document or ID - if isinstance(target, str) and target is not None: - return [target] - # One Embedding - if isinstance(target[0], (int, float)): - return [target] # type: ignore - # One Metadata dict - if isinstance(target, dict): - return [target] - # Already a sequence - return target # type: ignore + if not function_signature == protocol_signature: + raise ValueError( + f"Expected EmbeddingFunction.__call__ to have the following signature: {protocol_signature}, got {function_signature}\n" + "Please see https://docs.trychroma.com/embeddings for details of the EmbeddingFunction interface.\n" + "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n" + ) def validate_ids(ids: IDs) -> IDs: diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index 401139684ab..087cb2271bd 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -429,6 +429,7 @@ def client(system: System) -> Generator[ClientAPI, None, None]: system.reset_state() client = ClientCreator.from_system(system) yield client + client.clear_system_cache() @pytest.fixture(scope="function") diff --git a/chromadb/test/ef/test_multimodal_ef.py b/chromadb/test/ef/test_multimodal_ef.py new file mode 100644 index 00000000000..52213c77a4c --- /dev/null +++ b/chromadb/test/ef/test_multimodal_ef.py @@ -0,0 +1,152 @@ +from typing import Generator, cast +import numpy as np +import pytest +import chromadb +from chromadb.api.types import ( + Embeddable, + EmbeddingFunction, + Embeddings, + Image, + Document, +) +from chromadb.test.property.strategies import hashing_embedding_function +from chromadb.test.property.invariants import _exact_distances + + +# A 'standard' multimodal embedding function, which converts inputs to strings +# then hashes them to a fixed dimension. +class hashing_multimodal_ef(EmbeddingFunction[Embeddable]): + def __init__(self) -> None: + self._hef = hashing_embedding_function(dim=10, dtype=np.float_) + + def __call__(self, input: Embeddable) -> Embeddings: + to_texts = [str(i) for i in input] + embeddings = np.array(self._hef(to_texts)) + # Normalize the embeddings + # This is so we can generate random unit vectors and have them be close to the embeddings + embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True) + return cast(Embeddings, embeddings.tolist()) + + +def random_image() -> Image: + return np.random.randint(0, 255, size=(10, 10, 3), dtype=np.int32) + + +def random_document() -> Document: + return str(random_image()) + + +@pytest.fixture +def multimodal_collection( + default_ef: EmbeddingFunction[Embeddable] = hashing_multimodal_ef(), +) -> Generator[chromadb.Collection, None, None]: + client = chromadb.Client() + collection = client.create_collection( + name="multimodal_collection", embedding_function=default_ef + ) + yield collection + client.clear_system_cache() + + +# Test adding and querying of a multimodal collection consisting of images and documents +def test_multimodal( + multimodal_collection: chromadb.Collection, + default_ef: EmbeddingFunction[Embeddable] = hashing_multimodal_ef(), + n_examples: int = 10, + n_query_results: int = 3, +) -> None: + image_ids = [str(i) for i in range(n_examples)] + images = [random_image() for _ in range(n_examples)] + image_embeddings = default_ef(images) + + document_ids = [str(i) for i in range(n_examples, 2 * n_examples)] + documents = [random_document() for _ in range(n_examples)] + document_embeddings = default_ef(documents) + + # Trying to add a document and an image at the same time should fail + with pytest.raises( + ValueError, match="You can only provide documents or images, not both." + ): + multimodal_collection.add( + ids=image_ids[0], documents=documents[0], images=images[0] + ) + + # Add some documents + multimodal_collection.add(ids=document_ids, documents=documents) + # Add some images + multimodal_collection.add(ids=image_ids, images=images) + + # get() should return all the documents and images + # ids corresponding to images should not have documents + get_result = multimodal_collection.get(include=["documents"]) + assert len(get_result["ids"]) == len(document_ids) + len(image_ids) + for i, id in enumerate(get_result["ids"]): + assert id in document_ids or id in image_ids + assert get_result["documents"] is not None + if id in document_ids: + assert get_result["documents"][i] == documents[document_ids.index(id)] + if id in image_ids: + assert get_result["documents"][i] is None + + # Generate a random query image + query_image = random_image() + query_image_embedding = default_ef([query_image]) + + image_neighbor_indices, _ = _exact_distances( + query_image_embedding, image_embeddings + document_embeddings + ) + # Get the ids of the nearest neighbors + nearest_image_neighbor_ids = [ + image_ids[i] if i < n_examples else document_ids[i % n_examples] + for i in image_neighbor_indices[0][:n_query_results] + ] + + # Generate a random query document + query_document = random_document() + query_document_embedding = default_ef([query_document]) + document_neighbor_indices, _ = _exact_distances( + query_document_embedding, image_embeddings + document_embeddings + ) + nearest_document_neighbor_ids = [ + image_ids[i] if i < n_examples else document_ids[i % n_examples] + for i in document_neighbor_indices[0][:n_query_results] + ] + + # Querying with both images and documents should fail + with pytest.raises(ValueError): + multimodal_collection.query( + query_images=[query_image], query_texts=[query_document] + ) + + # Query with images + query_result = multimodal_collection.query( + query_images=[query_image], n_results=n_query_results, include=["documents"] + ) + + assert query_result["ids"][0] == nearest_image_neighbor_ids + + # Query with documents + query_result = multimodal_collection.query( + query_texts=[query_document], n_results=n_query_results, include=["documents"] + ) + + assert query_result["ids"][0] == nearest_document_neighbor_ids + + +@pytest.mark.xfail +def test_multimodal_update_with_image( + multimodal_collection: chromadb.Collection, +) -> None: + # Updating an entry with an existing document should remove the documentß + + document = random_document() + image = random_image() + id = "0" + + multimodal_collection.add(ids=id, documents=document) + + multimodal_collection.update(ids=id, images=image) + + get_result = multimodal_collection.get(ids=id, include=["documents"]) + assert get_result["documents"] is not None + assert get_result["documents"][0] is None diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index 3583dadfba9..142fbc8b3f2 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -1,7 +1,7 @@ import hashlib import hypothesis import hypothesis.strategies as st -from typing import Any, Optional, List, Dict, Union +from typing import Any, Optional, List, Dict, Union, cast from typing_extensions import TypedDict import numpy as np import numpy.typing as npt @@ -13,8 +13,14 @@ from dataclasses import dataclass -from chromadb.api.types import Documents, Embeddings, Metadata -from chromadb.types import LiteralValue +from chromadb.api.types import ( + Documents, + Embeddable, + EmbeddingFunction, + Embeddings, + Metadata, +) +from chromadb.types import LiteralValue, WhereOperator, LogicalOperator # Set the random seed for reproducibility np.random.seed(0) # unnecessary, hypothesis does this for us @@ -178,15 +184,15 @@ def create_embeddings( return embeddings -class hashing_embedding_function(types.EmbeddingFunction): +class hashing_embedding_function(types.EmbeddingFunction[Documents]): def __init__(self, dim: int, dtype: npt.DTypeLike) -> None: self.dim = dim self.dtype = dtype - def __call__(self, texts: types.Documents) -> types.Embeddings: + def __call__(self, input: types.Documents) -> types.Embeddings: # Hash the texts and convert to hex strings hashed_texts = [ - list(hashlib.sha256(text.encode("utf-8")).hexdigest()) for text in texts + list(hashlib.sha256(text.encode("utf-8")).hexdigest()) for text in input ] # Pad with repetition, or truncate the hex strings to the desired dimension padded_texts = [ @@ -203,15 +209,17 @@ def __call__(self, texts: types.Documents) -> types.Embeddings: return embeddings -class not_implemented_embedding_function(types.EmbeddingFunction): - def __call__(self, texts: Documents) -> Embeddings: +class not_implemented_embedding_function(types.EmbeddingFunction[Documents]): + def __call__(self, input: Documents) -> Embeddings: assert False, "This embedding function is not implemented" def embedding_function_strategy( dim: int, dtype: npt.DTypeLike -) -> st.SearchStrategy[types.EmbeddingFunction]: - return st.just(hashing_embedding_function(dim, dtype)) +) -> st.SearchStrategy[types.EmbeddingFunction[Embeddable]]: + return st.just( + cast(EmbeddingFunction[Embeddable], hashing_embedding_function(dim, dtype)) + ) @dataclass @@ -224,7 +232,7 @@ class Collection: known_document_keywords: List[str] has_documents: bool = False has_embeddings: bool = False - embedding_function: Optional[types.EmbeddingFunction] = None + embedding_function: Optional[types.EmbeddingFunction[Embeddable]] = None @st.composite @@ -311,12 +319,12 @@ def metadata(draw: st.DrawFn, collection: Collection) -> types.Metadata: if collection.known_metadata_keys: for key in collection.known_metadata_keys.keys(): if key in metadata: - del metadata[key] + del metadata[key] # type: ignore # Finally, add in some of the known keys for the collection sampling_dict: Dict[str, st.SearchStrategy[Union[str, int, float]]] = { k: st.just(v) for k, v in collection.known_metadata_keys.items() } - metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict))) + metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict))) # type: ignore return metadata @@ -332,11 +340,11 @@ def document(draw: st.DrawFn, collection: Collection) -> types.Document: else: known_words_st = st.text( min_size=1, - alphabet=st.characters(blacklist_categories=blacklist_categories), + alphabet=st.characters(blacklist_categories=blacklist_categories), # type: ignore ) random_words_st = st.text( - min_size=1, alphabet=st.characters(blacklist_categories=blacklist_categories) + min_size=1, alphabet=st.characters(blacklist_categories=blacklist_categories) # type: ignore ) words = draw(st.lists(st.one_of(known_words_st, random_words_st), min_size=1)) return " ".join(words) @@ -487,20 +495,20 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where: # Add or subtract a small number to avoid floating point rounding errors value = value + draw(st.sampled_from([1e-6, -1e-6])) - op: types.WhereOperator = draw(st.sampled_from(legal_ops)) + op: WhereOperator = draw(st.sampled_from(legal_ops)) if op is None: return {key: value} - elif op == "$in": + elif op == "$in": # type: ignore if isinstance(value, str) and not value: return {} return {key: {op: [value, *[draw(opposite_value(value)) for _ in range(3)]]}} - elif op == "$nin": + elif op == "$nin": # type: ignore if isinstance(value, str) and not value: return {} return {key: {op: [draw(opposite_value(value)) for _ in range(3)]}} else: - return {key: {op: value}} + return {key: {op: value}} # type: ignore @st.composite @@ -516,7 +524,7 @@ def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocu def binary_operator_clause( base_st: SearchStrategy[types.Where], ) -> SearchStrategy[types.Where]: - op: SearchStrategy[types.LogicalOperator] = st.sampled_from(["$and", "$or"]) + op: SearchStrategy[LogicalOperator] = st.sampled_from(["$and", "$or"]) return st.dictionaries( keys=op, values=st.lists(base_st, max_size=2, min_size=2), @@ -528,7 +536,7 @@ def binary_operator_clause( def binary_document_operator_clause( base_st: SearchStrategy[types.WhereDocument], ) -> SearchStrategy[types.WhereDocument]: - op: SearchStrategy[types.LogicalOperator] = st.sampled_from(["$and", "$or"]) + op: SearchStrategy[LogicalOperator] = st.sampled_from(["$and", "$or"]) return st.dictionaries( keys=op, values=st.lists(base_st, max_size=2, min_size=2), diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py index 5f8991b00ed..f97e33aa305 100644 --- a/chromadb/test/property/test_add.py +++ b/chromadb/test/property/test_add.py @@ -26,7 +26,7 @@ def test_add( # TODO: Generative embedding functions coll = api.create_collection( name=collection.name, - metadata=collection.metadata, + metadata=collection.metadata, # type: ignore embedding_function=collection.embedding_function, ) normalized_record_set = invariants.wrap_all(record_set) @@ -64,7 +64,7 @@ def create_large_recordset( "metadatas": metadatas, "documents": documents, } - return record_set + return cast(strategies.RecordSet, record_set) @given(collection=collection_st) @@ -77,7 +77,7 @@ def test_add_large(api: ServerAPI, collection: strategies.Collection) -> None: ) coll = api.create_collection( name=collection.name, - metadata=collection.metadata, + metadata=collection.metadata, # type: ignore embedding_function=collection.embedding_function, ) normalized_record_set = invariants.wrap_all(record_set) @@ -107,7 +107,7 @@ def test_add_large_exceeding(api: ServerAPI, collection: strategies.Collection) ) coll = api.create_collection( name=collection.name, - metadata=collection.metadata, + metadata=collection.metadata, # type: ignore embedding_function=collection.embedding_function, ) normalized_record_set = invariants.wrap_all(record_set) @@ -157,7 +157,7 @@ def test_out_of_order_ids(api: ServerAPI) -> None: ] coll = api.create_collection( - "test", embedding_function=lambda texts: [[1, 2, 3] for _ in texts] # type: ignore + "test", embedding_function=lambda input: [[1, 2, 3] for _ in input] # type: ignore ) embeddings: Embeddings = [[1, 2, 3] for _ in ooo_ids] coll.add(ids=ooo_ids, embeddings=embeddings) @@ -174,7 +174,7 @@ def test_add_partial(api: ServerAPI) -> None: # TODO: We need to clean up the api types to support this typing coll.add( ids=["1", "2", "3"], - embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], + embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore metadatas=[{"a": 1}, None, {"a": 3}], # type: ignore documents=["a", "b", None], # type: ignore ) diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index b5320dfe7a9..82bfc5f7cda 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -204,8 +204,8 @@ def switch_to_version(version: str) -> ModuleType: return chromadb -class not_implemented_ef(EmbeddingFunction): - def __call__(self, texts: Documents) -> Embeddings: +class not_implemented_ef(EmbeddingFunction[Documents]): + def __call__(self, input: Documents) -> Embeddings: assert False, "Embedding function should not be called" @@ -314,7 +314,7 @@ def test_cycle_versions( system.start() coll = api.get_collection( name=collection_strategy.name, - embedding_function=not_implemented_ef(), + embedding_function=not_implemented_ef(), # type: ignore ) invariants.count(coll, embeddings_strategy) invariants.metadatas_match(coll, embeddings_strategy) diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index ed3c87ee682..d6d3e3c30a8 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -3,7 +3,7 @@ import chromadb from chromadb.api.fastapi import FastAPI -from chromadb.api.types import QueryResult +from chromadb.api.types import QueryResult, EmbeddingFunction, Document from chromadb.config import Settings import chromadb.server.fastapi import pytest @@ -91,14 +91,17 @@ def test_persist_index_loading(api_fixture, request): @pytest.mark.parametrize("api_fixture", [local_persist_api]) def test_persist_index_loading_embedding_function(api_fixture, request): - embedding_function = lambda x: [[1, 2, 3] for _ in range(len(x))] # noqa E731 + class TestEF(EmbeddingFunction[Document]): + def __call__(self, input): + return [[1, 2, 3] for _ in range(len(input))] + api = request.getfixturevalue("local_persist_api") api.reset() - collection = api.create_collection("test", embedding_function=embedding_function) + collection = api.create_collection("test", embedding_function=TestEF()) collection.add(ids="id1", documents="hello") api2 = request.getfixturevalue("local_persist_api_cache_bust") - collection = api2.get_collection("test", embedding_function=embedding_function) + collection = api2.get_collection("test", embedding_function=TestEF()) nn = collection.query( query_texts="hello", @@ -111,18 +114,17 @@ def test_persist_index_loading_embedding_function(api_fixture, request): @pytest.mark.parametrize("api_fixture", [local_persist_api]) def test_persist_index_get_or_create_embedding_function(api_fixture, request): - embedding_function = lambda x: [[1, 2, 3] for _ in range(len(x))] # noqa E731 + class TestEF(EmbeddingFunction[Document]): + def __call__(self, input): + return [[1, 2, 3] for _ in range(len(input))] + api = request.getfixturevalue("local_persist_api") api.reset() - collection = api.get_or_create_collection( - "test", embedding_function=embedding_function - ) + collection = api.get_or_create_collection("test", embedding_function=TestEF()) collection.add(ids="id1", documents="hello") api2 = request.getfixturevalue("local_persist_api_cache_bust") - collection = api2.get_or_create_collection( - "test", embedding_function=embedding_function - ) + collection = api2.get_or_create_collection("test", embedding_function=TestEF()) nn = collection.query( query_texts="hello", diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index aaef53c01e2..5e38936ef6c 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -1,11 +1,21 @@ import logging -from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +from chromadb.api.types import ( + Document, + Documents, + Embedding, + Image, + Images, + EmbeddingFunction, + Embeddings, + is_image, + is_document, +) from pathlib import Path import os import tarfile import requests -from typing import Any, Dict, List, cast +from typing import Any, Dict, List, Union, cast import numpy as np import numpy.typing as npt import importlib @@ -21,7 +31,7 @@ logger = logging.getLogger(__name__) -class SentenceTransformerEmbeddingFunction(EmbeddingFunction): +class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]): # Since we do dynamic imports we have to type this as Any models: Dict[str, Any] = {} @@ -44,15 +54,15 @@ def __init__( self._model = self.models[model_name] self._normalize_embeddings = normalize_embeddings - def __call__(self, texts: Documents) -> Embeddings: + def __call__(self, input: Documents) -> Embeddings: return self._model.encode( # type: ignore - list(texts), + list(input), convert_to_numpy=True, normalize_embeddings=self._normalize_embeddings, ).tolist() -class Text2VecEmbeddingFunction(EmbeddingFunction): +class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]): def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"): try: from text2vec import SentenceModel @@ -62,11 +72,11 @@ def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"): ) self._model = SentenceModel(model_name_or_path=model_name) - def __call__(self, texts: Documents) -> Embeddings: - return self._model.encode(list(texts), convert_to_numpy=True).tolist() # type: ignore # noqa E501 + def __call__(self, input: Documents) -> Embeddings: + return self._model.encode(list(input), convert_to_numpy=True).tolist() # type: ignore # noqa E501 -class OpenAIEmbeddingFunction(EmbeddingFunction): +class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]): def __init__( self, api_key: Optional[str] = None, @@ -125,12 +135,12 @@ def __init__( self._client = openai.Embedding self._model_name = model_name - def __call__(self, texts: Documents) -> Embeddings: + def __call__(self, input: Documents) -> Embeddings: # replace newlines, which can negatively affect performance. - texts = [t.replace("\n", " ") for t in texts] + input = [t.replace("\n", " ") for t in input] # Call the OpenAI Embedding API - embeddings = self._client.create(input=texts, engine=self._model_name)["data"] + embeddings = self._client.create(input=input, engine=self._model_name)["data"] # Sort resulting embeddings by index sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore @@ -139,7 +149,7 @@ def __call__(self, texts: Documents) -> Embeddings: return [result["embedding"] for result in sorted_embeddings] -class CohereEmbeddingFunction(EmbeddingFunction): +class CohereEmbeddingFunction(EmbeddingFunction[Documents]): def __init__(self, api_key: str, model_name: str = "large"): try: import cohere @@ -151,15 +161,15 @@ def __init__(self, api_key: str, model_name: str = "large"): self._client = cohere.Client(api_key) self._model_name = model_name - def __call__(self, texts: Documents) -> Embeddings: + def __call__(self, input: Documents) -> Embeddings: # Call Cohere Embedding API for each document. return [ embeddings - for embeddings in self._client.embed(texts=texts, model=self._model_name) + for embeddings in self._client.embed(texts=input, model=self._model_name) ] -class HuggingFaceEmbeddingFunction(EmbeddingFunction): +class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]): """ This class is used to get embeddings for a list of texts using the HuggingFace API. It requires an API key and a model name. The default model name is "sentence-transformers/all-MiniLM-L6-v2". @@ -185,7 +195,7 @@ def __init__( self._session = requests.Session() self._session.headers.update({"Authorization": f"Bearer {api_key}"}) - def __call__(self, texts: Documents) -> Embeddings: + def __call__(self, input: Documents) -> Embeddings: """ Get the embeddings for a list of texts. @@ -202,11 +212,11 @@ def __call__(self, texts: Documents) -> Embeddings: """ # Call HuggingFace Embedding API for each document return self._session.post( # type: ignore - self._api_url, json={"inputs": texts, "options": {"wait_for_model": True}} + self._api_url, json={"inputs": input, "options": {"wait_for_model": True}} ).json() -class InstructorEmbeddingFunction(EmbeddingFunction): +class InstructorEmbeddingFunction(EmbeddingFunction[Documents]): # If you have a GPU with at least 6GB try model_name = "hkunlp/instructor-xl" and device = "cuda" # for a full list of options: https://github.com/HKUNLP/instructor-embedding#model-list def __init__( @@ -224,11 +234,11 @@ def __init__( self._model = INSTRUCTOR(model_name, device=device) self._instruction = instruction - def __call__(self, texts: Documents) -> Embeddings: + def __call__(self, input: Documents) -> Embeddings: if self._instruction is None: - return self._model.encode(texts).tolist() # type: ignore + return self._model.encode(input).tolist() # type: ignore - texts_with_instructions = [[self._instruction, text] for text in texts] + texts_with_instructions = [[self._instruction, text] for text in input] return self._model.encode(texts_with_instructions).tolist() # type: ignore @@ -237,7 +247,7 @@ def __call__(self, texts: Documents) -> Embeddings: # implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers. # visit https://github.com/chroma-core/onnx-embedding for the source code to generate # and verify the ONNX model. -class ONNXMiniLM_L6_V2(EmbeddingFunction): +class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]): MODEL_NAME = "all-MiniLM-L6-v2" DOWNLOAD_PATH = Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME EXTRACTED_FOLDER_NAME = "onnx" @@ -374,11 +384,11 @@ def _init_model_and_tokenizer(self) -> None: providers=self._preferred_providers, ) - def __call__(self, texts: Documents) -> Embeddings: + def __call__(self, input: Documents) -> Embeddings: # Only download the model when it is actually used self._download_model_if_not_exists() self._init_model_and_tokenizer() - res = cast(Embeddings, self._forward(texts).tolist()) + res = cast(Embeddings, self._forward(input).tolist()) return res def _download_model_if_not_exists(self) -> None: @@ -413,14 +423,14 @@ def _download_model_if_not_exists(self) -> None: tar.extractall(path=self.DOWNLOAD_PATH) -def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction]: +def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: if is_thin_client: return None else: return ONNXMiniLM_L6_V2() -class GooglePalmEmbeddingFunction(EmbeddingFunction): +class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]): """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key.""" def __init__(self, api_key: str, model_name: str = "models/embedding-gecko-001"): @@ -441,16 +451,16 @@ def __init__(self, api_key: str, model_name: str = "models/embedding-gecko-001") self._palm = palm self._model_name = model_name - def __call__(self, texts: Documents) -> Embeddings: + def __call__(self, input: Documents) -> Embeddings: return [ self._palm.generate_embeddings(model=self._model_name, text=text)[ "embedding" ] - for text in texts + for text in input ] -class GoogleVertexEmbeddingFunction(EmbeddingFunction): +class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]): # Follow API Quickstart for Google Vertex AI # https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart # Information about the text embedding modules in Google Vertex AI @@ -466,9 +476,9 @@ def __init__( self._session = requests.Session() self._session.headers.update({"Authorization": f"Bearer {api_key}"}) - def __call__(self, texts: Documents) -> Embeddings: + def __call__(self, input: Documents) -> Embeddings: embeddings = [] - for text in texts: + for text in input: response = self._session.post( self._api_url, json={"instances": [{"content": text}]} ).json() @@ -479,6 +489,62 @@ def __call__(self, texts: Documents) -> Embeddings: return embeddings +class OpenCLIPEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): + def __init__( + self, model_name: str = "ViT-B-32", checkpoint: str = "laion2b_s34b_b79k" + ) -> None: + try: + import open_clip + except ImportError: + raise ValueError( + "The open_clip python package is not installed. Please install it with `pip install open-clip-torch`. https://github.com/mlfoundations/open_clip" + ) + try: + self._torch = importlib.import_module("torch") + except ImportError: + raise ValueError( + "The torch python package is not installed. Please install it with `pip install torch`" + ) + + try: + self._PILImage = importlib.import_module("PIL.Image") + except ImportError: + raise ValueError( + "The PIL python package is not installed. Please install it with `pip install pillow`" + ) + + model, _, preprocess = open_clip.create_model_and_transforms( + model_name=model_name, pretrained=checkpoint + ) + self._model = model + self._preprocess = preprocess + self._tokenizer = open_clip.get_tokenizer(model_name=model_name) + + def _encode_image(self, image: Image) -> Embedding: + pil_image = self._PILImage.fromarray(image) + with self._torch.no_grad(): + image_features = self._model.encode_image( + self._preprocess(pil_image).unsqueeze(0) + ) + image_features /= image_features.norm(dim=-1, keepdim=True) + return cast(Embedding, image_features.squeeze().tolist()) + + def _encode_text(self, text: Document) -> Embedding: + with self._torch.no_grad(): + text_features = self._model.encode_text(self._tokenizer(text)) + text_features /= text_features.norm(dim=-1, keepdim=True) + return cast(Embedding, text_features.squeeze().tolist()) + + def __call__(self, input: Union[Documents, Images]) -> Embeddings: + embeddings: Embeddings = [] + for item in input: + if is_image(item): + embeddings.append(self._encode_image(cast(Image, item))) + elif is_document(item): + embeddings.append(self._encode_text(cast(Document, item))) + return embeddings + + # List of all classes in this module _classes = [ name diff --git a/multimodal_ef_example.ipynb b/multimodal_ef_example.ipynb new file mode 100644 index 00000000000..879c04454a5 --- /dev/null +++ b/multimodal_ef_example.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import chromadb\n", + "\n", + "client = chromadb.Client()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from chromadb.api.types import Embeddings, Images\n", + "from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction\n", + "\n", + "embedding_function = OpenCLIPEmbeddingFunction()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "collection = client.create_collection('test', embedding_function=embedding_function)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from PIL import Image\n", + "\n", + "image = np.array(Image.open('test_img.jpeg'))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "collection.add(ids='a', images=image)\n", + "collection.add(ids='b', documents='hello world')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'ids': ['a', 'b'],\n", + " 'embeddings': None,\n", + " 'metadatas': None,\n", + " 'documents': [None, 'hello world']}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "collection.get(include=['documents'])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "chroma", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}