Skip to content

Commit

Permalink
[ENH] Multimodal Embeddings (#1293)
Browse files Browse the repository at this point in the history
## Description of changes

This PR introduces multimodal embeddings into Chroma.
- Adds the generic `EmbeddingFunction`, which can take various data
types. Existing functions take the `Documents` type.
- Adds `Images` as a type (numpy NDArray taking ints or floats).
- Adds `OpenCLIPEmbeddingFunction`, which is an
`EmbeddingFunction[Union[Documents, Images]]`.

## Test

Integration tests pass. 

A new test for multimodal embedding functions:
[chromadb/test/ef/test_multimodal_ef.py](https://github.com/chroma-core/chroma/blob/86a9e2620352ee0b2844bc3233f9e001cc4aa3d9/chromadb/test/ef/test_multimodal_ef.py)

## Documentation

See #1294

## TODOs
- [x] Tests
- [x] ~Wiring through FastAPI~ Nothing to wire through
- [x] Documentation
- [x] Telemetry
- [ ] ~JavaScript~
  • Loading branch information
atroyn authored Nov 7, 2023
1 parent ae0206b commit 4db9955
Show file tree
Hide file tree
Showing 12 changed files with 696 additions and 186 deletions.
29 changes: 21 additions & 8 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from chromadb.api.types import (
CollectionMetadata,
Documents,
Embeddable,
EmbeddingFunction,
Embeddings,
IDs,
Expand Down Expand Up @@ -58,7 +59,9 @@ def create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
get_or_create: bool = False,
) -> Collection:
"""Create a new collection with the given name and metadata.
Expand Down Expand Up @@ -90,9 +93,11 @@ def create_collection(
@abstractmethod
def get_collection(
self,
name: Optional[str] = None,
name: str,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
) -> Collection:
"""Get a collection with the given name.
Args:
Expand All @@ -119,7 +124,9 @@ def get_or_create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
) -> Collection:
"""Get or create a collection with the given name and metadata.
Args:
Expand Down Expand Up @@ -486,7 +493,9 @@ def create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
Expand All @@ -497,9 +506,11 @@ def create_collection(
@override
def get_collection(
self,
name: Optional[str] = None,
name: str,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
Expand All @@ -511,7 +522,9 @@ def get_or_create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
Expand Down
39 changes: 24 additions & 15 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
Documents,
Embeddable,
Embeddings,
EmbeddingFunction,
IDs,
Expand Down Expand Up @@ -219,7 +220,9 @@ def create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
Expand Down Expand Up @@ -250,9 +253,9 @@ def create_collection(
@override
def get_collection(
self,
name: Optional[str] = None,
name: str,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction[Embeddable]] = ef.DefaultEmbeddingFunction(), # type: ignore
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
Expand Down Expand Up @@ -284,17 +287,20 @@ def get_or_create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction[Embeddable]] = ef.DefaultEmbeddingFunction(), # type: ignore
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
return self.create_collection(
name,
metadata,
embedding_function,
get_or_create=True,
tenant=tenant,
database=database,
return cast(
Collection,
self.create_collection(
name,
metadata,
embedding_function,
get_or_create=True,
tenant=tenant,
database=database,
),
)

@trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION)
Expand Down Expand Up @@ -347,10 +353,13 @@ def _peek(
collection_id: UUID,
n: int = 10,
) -> GetResult:
return self._get(
collection_id,
limit=n,
include=["embeddings", "documents", "metadatas"],
return cast(
GetResult,
self._get(
collection_id,
limit=n,
include=["embeddings", "documents", "metadatas"],
),
)

@trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION)
Expand Down
Loading

0 comments on commit 4db9955

Please sign in to comment.