Skip to content

Commit

Permalink
feat: CIP - Expose Collection Dimensionality
Browse files Browse the repository at this point in the history
A new CIP to add the ability for clients to view a collection's dimensionality
  • Loading branch information
tazarov committed Oct 8, 2023
1 parent e357ef3 commit c23952a
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 65 deletions.
28 changes: 24 additions & 4 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
get_or_create: bool = False,
) -> Collection:
"""Create a new collection with the given name and metadata.
Expand Down Expand Up @@ -87,7 +88,8 @@ def create_collection(
def get_collection(
self,
name: str,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
) -> Collection:
"""Get a collection with the given name.
Args:
Expand All @@ -114,7 +116,8 @@ def get_or_create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
) -> Collection:
"""Get or create a collection with the given name and metadata.
Args:
Expand Down Expand Up @@ -255,6 +258,22 @@ def _count(self, collection_id: UUID) -> int:
"""
pass

@abstractmethod
def _dimensions(self, collection_id: UUID) -> int:
"""[Internal] Returns the number of dimensions of the embeddings in a collection
specified by UUID.
Args:
collection_id: The UUID of the collection to get the dimensions of the
embeddings in.
Returns:
int: The number of dimensions of the embeddings in the collection. If the
collection has no embeddings, returns -1.
"""
pass

@abstractmethod
def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
"""[Internal] Returns the first n entries in a collection specified by UUID.
Expand Down Expand Up @@ -332,7 +351,8 @@ def _query(
n_results: int = 10,
where: Where = {},
where_document: WhereDocument = {},
include: Include = ["embeddings", "metadatas", "documents", "distances"],
include: Include = ["embeddings",
"metadatas", "documents", "distances"],
) -> QueryResult:
"""[Internal] Performs a nearest neighbors query on a collection specified by UUID.
Expand Down
40 changes: 30 additions & 10 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def resolve_url(
scheme = "https" if chroma_server_ssl_enabled else parsed.scheme or "http"
net_loc = parsed.netloc or parsed.hostname or chroma_server_host
port = (
":" + str(parsed.port or chroma_server_http_port) if not _skip_port else ""
":" +
str(parsed.port or chroma_server_http_port) if not _skip_port else ""
)
path = parsed.path or default_api_path

Expand All @@ -82,7 +83,8 @@ def resolve_url(
if not path.endswith(default_api_path or ""):
path = path + default_api_path if default_api_path else ""
full_url = urlunparse(
(scheme, f"{net_loc}{port}", quote(path.replace("//", "/")), "", "", "")
(scheme, f"{net_loc}{port}", quote(
path.replace("//", "/")), "", "", "")
)

return full_url
Expand All @@ -97,7 +99,8 @@ def __init__(self, system: System):

self._api_url = FastAPI.resolve_url(
chroma_server_host=str(system.settings.chroma_server_host),
chroma_server_http_port=int(str(system.settings.chroma_server_http_port)),
chroma_server_http_port=int(
str(system.settings.chroma_server_http_port)),
chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
default_api_path=system.settings.chroma_server_api_default_path,
)
Expand Down Expand Up @@ -151,14 +154,16 @@ def create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
get_or_create: bool = False,
) -> Collection:
"""Creates a collection"""
resp = self._session.post(
self._api_url + "/collections",
data=json.dumps(
{"name": name, "metadata": metadata, "get_or_create": get_or_create}
{"name": name, "metadata": metadata,
"get_or_create": get_or_create}
),
)
raise_chroma_error(resp)
Expand All @@ -175,7 +180,8 @@ def create_collection(
def get_collection(
self,
name: str,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
) -> Collection:
"""Returns a collection"""
resp = self._session.get(self._api_url + "/collections/" + name)
Expand All @@ -194,7 +200,8 @@ def get_or_create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
) -> Collection:
return self.create_collection(
name, metadata, embedding_function, get_or_create=True
Expand All @@ -210,7 +217,8 @@ def _modify(
"""Updates a collection"""
resp = self._session.put(
self._api_url + "/collections/" + str(id),
data=json.dumps({"new_metadata": new_metadata, "new_name": new_name}),
data=json.dumps(
{"new_metadata": new_metadata, "new_name": new_name}),
)
raise_chroma_error(resp)

Expand All @@ -229,6 +237,16 @@ def _count(self, collection_id: UUID) -> int:
raise_chroma_error(resp)
return cast(int, resp.json())

@override
def _dimensions(self, collection_id: UUID) -> int:
    """Fetch the embedding dimensionality of a collection over HTTP.

    Issues a GET against the collection's ``/dimensions`` endpoint, lets
    ``raise_chroma_error`` inspect the response, and returns the decoded
    JSON body typed as an ``int``.
    """
    endpoint = "/collections/" + str(collection_id) + "/dimensions"
    response = self._session.get(self._api_url + endpoint)
    raise_chroma_error(response)
    payload = response.json()
    # cast() only informs the type checker; the value is passed through as-is.
    return cast(int, payload)

@override
def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
return self._get(
Expand Down Expand Up @@ -336,7 +354,8 @@ def _add(
"""
batch = (ids, embeddings, metadatas, documents)
validate_batch(batch, {"max_batch_size": self.max_batch_size})
resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
resp = self._submit_batch(
batch, "/collections/" + str(collection_id) + "/add")
raise_chroma_error(resp)
return True

Expand Down Expand Up @@ -456,7 +475,8 @@ def raise_chroma_error(resp: requests.Response) -> None:
body = resp.json()
if "error" in body:
if body["error"] in errors.error_types:
chroma_error = errors.error_types[body["error"]](body["message"])
chroma_error = errors.error_types[body["error"]](
body["message"])

except BaseException:
pass
Expand Down
21 changes: 17 additions & 4 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def __init__(
client: "API",
name: str,
id: UUID,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
metadata: Optional[CollectionMetadata] = None,
):
super().__init__(name=name, metadata=metadata, id=id)
Expand All @@ -65,7 +66,17 @@ def count(self) -> int:
int: The total number of embeddings added to the database
"""
return self._client._count(collection_id=self.id)
return self._client._count(collection_id=self.id)\


def dimensions(self) -> int:
    """Return the dimensionality of this collection's embeddings.

    Delegates to the client's ``_dimensions`` call, keyed by this
    collection's id.

    Returns:
        int: The number of dimensions reported by the client.
    """
    client = self._client
    dimension_count = client._dimensions(collection_id=self.id)
    return dimension_count

def add(
self,
Expand Down Expand Up @@ -188,7 +199,8 @@ def query(
else None
)
query_texts = (
maybe_cast_one_to_many(query_texts) if query_texts is not None else None
maybe_cast_one_to_many(
query_texts) if query_texts is not None else None
)
include = validate_include(include, allow_distances=True)
n_results = validate_n_results(n_results)
Expand Down Expand Up @@ -355,7 +367,8 @@ def _validate_embedding_set(
if metadatas is not None
else None
)
documents = maybe_cast_one_to_many(documents) if documents is not None else None
documents = maybe_cast_one_to_many(
documents) if documents is not None else None

# Check that one of embeddings or documents is provided
if require_embeddings_or_documents:
Expand Down
Loading

0 comments on commit c23952a

Please sign in to comment.