Skip to content

Commit

Permalink
Move get_or_create into sysdb
Browse files Browse the repository at this point in the history
  • Loading branch information
HammadB committed Oct 13, 2023
1 parent fa502f1 commit 3e01da7
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 63 deletions.
23 changes: 5 additions & 18 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,33 +105,20 @@ def create_collection(
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
get_or_create: bool = False,
) -> Collection:
existing = self._sysdb.get_collections(name=name)

if metadata is not None:
validate_metadata(metadata)

if existing:
if get_or_create:
if metadata and existing[0]["metadata"] != metadata:
self._modify(id=existing[0]["id"], new_metadata=metadata)
existing = self._sysdb.get_collections(id=existing[0]["id"])
return Collection(
client=self,
id=existing[0]["id"],
name=existing[0]["name"],
metadata=existing[0]["metadata"], # type: ignore
embedding_function=embedding_function,
)
else:
raise ValueError(f"Collection {name} already exists.")

# TODO: remove backwards compatibility in naming requirements
check_index_name(name)

id = uuid4()

coll = self._sysdb.create_collection(
id=id, name=name, metadata=metadata, dimension=None
id=id,
name=name,
metadata=metadata,
dimension=None,
get_or_create=get_or_create,
)
segments = self._manager.create_segments(coll)

Expand Down
21 changes: 7 additions & 14 deletions chromadb/db/impl/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
from chromadb.config import System
from chromadb.db.base import NotFoundError, UniqueConstraintError
from chromadb.db.system import SysDB
from chromadb.ingest import CollectionAssignmentPolicy
from chromadb.proto.convert import (
from_proto_collection,
from_proto_segment,
to_proto_collection,
to_proto_update_metadata,
to_proto_segment,
to_proto_segment_scope,
Expand Down Expand Up @@ -44,7 +42,6 @@ class GrpcSysDB(SysDB):
to call a remote SysDB (Coordinator) service."""

_sys_db_stub: SysDBStub
_assignment_policy: CollectionAssignmentPolicy
_channel: grpc.Channel
_coordinator_url: str
_coordinator_port: int
Expand All @@ -53,11 +50,11 @@ def __init__(self, system: System):
self._coordinator_url = system.settings.require("chroma_coordinator_host")
# TODO: break out coordinator_port into a separate setting?
self._coordinator_port = system.settings.require("chroma_server_grpc_port")
self._assignment_policy = system.instance(CollectionAssignmentPolicy)
return super().__init__(system)

@overrides
def start(self) -> None:
# TODO: add retry policy here
self._channel = grpc.insecure_channel(
f"{self._coordinator_url}:{self._coordinator_port}"
)
Expand Down Expand Up @@ -166,23 +163,19 @@ def create_collection(
name: str,
metadata: Optional[Metadata] = None,
dimension: Optional[int] = None,
get_or_create: bool = False,
) -> Collection:
# TODO: the get_or_create concept needs to be pushed down to the sysdb interface
topic = self._assignment_policy.assign_collection(id)
collection = Collection(
id=id,
request = CreateCollectionRequest(
id=id.hex,
name=name,
topic=topic,
metadata=metadata,
metadata=to_proto_update_metadata(metadata) if metadata else None,
dimension=dimension,
)
request = CreateCollectionRequest(
collection=to_proto_collection(collection),
get_or_create=False,
get_or_create=get_or_create,
)
response = self._sys_db_stub.CreateCollection(request)
if response.status.code == 409:
raise UniqueConstraintError()
collection = from_proto_collection(response.collection)
return collection

@overrides
Expand Down
37 changes: 31 additions & 6 deletions chromadb/db/impl/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import Any, Dict, cast
from uuid import UUID
from overrides import overrides
from chromadb.ingest import CollectionAssignmentPolicy
from chromadb.config import Component, System
from chromadb.proto.convert import (
from_proto_collection,
from_proto_metadata,
from_proto_update_metadata,
from_proto_segment,
from_proto_segment_scope,
Expand Down Expand Up @@ -39,10 +40,12 @@ class GrpcMockSysDB(SysDBServicer, Component):
state in simple python data structures instead of a database."""

_server: grpc.Server
_assignment_policy: CollectionAssignmentPolicy
_segments: Dict[str, Segment] = {}
_collections: Dict[str, Collection] = {}

def __init__(self, system: System):
self._assignment_policy = system.instance(CollectionAssignmentPolicy)
return super().__init__(system)

@overrides
Expand Down Expand Up @@ -165,18 +168,40 @@ def UpdateSegment(
def CreateCollection(
self, request: CreateCollectionRequest, context: grpc.ServicerContext
) -> CreateCollectionResponse:
collection = from_proto_collection(request.collection)
if collection["id"].hex in self._collections:
collection_name = request.name
matches = [
c for c in self._collections.values() if c["name"] == collection_name
]
assert len(matches) <= 1
if len(matches) > 0:
if request.get_or_create:
existing_collection = matches[0]
self._merge_metadata(
cast(Dict[str, Any], existing_collection["metadata"]),
request.metadata,
)
return CreateCollectionResponse(
status=proto.Status(code=200),
collection=to_proto_collection(existing_collection),
)
return CreateCollectionResponse(
status=proto.Status(
code=409, reason=f"Collection {collection['id']} already exists"
code=409, reason=f"Collection {request.name} already exists"
)
)

self._collections[collection["id"].hex] = collection
id = UUID(hex=request.id)
new_collection = Collection(
id=id,
name=request.name,
metadata=from_proto_metadata(request.metadata),
dimension=request.dimension,
topic=self._assignment_policy.assign_collection(id),
)
self._collections[request.id] = new_collection
return CreateCollectionResponse(
status=proto.Status(code=200),
collection=to_proto_collection(collection),
collection=to_proto_collection(new_collection),
)

@overrides(check_signature=False)
Expand Down
17 changes: 16 additions & 1 deletion chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,23 @@ def create_collection(
name: str,
metadata: Optional[Metadata] = None,
dimension: Optional[int] = None,
get_or_create: bool = False,
) -> Collection:
"""Create a new collection and the associated topic"""
if id is None and not get_or_create:
raise ValueError("id must be specified if get_or_create is False")

existing = self.get_collections(name=name)
if existing:
if get_or_create:
collection = existing[0]
if metadata and collection["metadata"] != metadata:
self.update_collection(
collection["id"],
metadata=collection["metadata"],
)
return self.get_collections(id=collection["id"])[0]
else:
raise UniqueConstraintError(f"Collection {name} already exists")

topic = self._assignment_policy.assign_collection(id)
collection = Collection(
Expand Down
6 changes: 5 additions & 1 deletion chromadb/db/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,13 @@ def create_collection(
name: str,
metadata: Optional[Metadata] = None,
dimension: Optional[int] = None,
get_or_create: bool = False,
) -> Collection:
"""Create a new collection and any associated resources
(Such as the necessary topics) in the SysDB."""
(Such as the necessary topics) in the SysDB. If get_or_create is True, the
collection will be created if one with the same name does not exist. The metadata
will be merged with any existing metadata if the collection already exists. If
get_or_create is False and the collection already exists, an error will be raised."""
pass

@abstractmethod
Expand Down
4 changes: 3 additions & 1 deletion chromadb/proto/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def from_proto_collection(collection: proto.Collection) -> Collection:
metadata=from_proto_metadata(collection.metadata)
if collection.HasField("metadata")
else None,
dimension=collection.dimension if collection.HasField("dimension") else None,
dimension=collection.dimension
if collection.HasField("dimension") and collection.dimension
else None,
)


Expand Down
30 changes: 15 additions & 15 deletions chromadb/proto/coordinator_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 10 additions & 4 deletions chromadb/proto/coordinator_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,18 @@ class UpdateSegmentRequest(_message.Message):
def __init__(self, id: _Optional[str] = ..., topic: _Optional[str] = ..., reset_topic: bool = ..., collection: _Optional[str] = ..., reset_collection: bool = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., reset_metadata: bool = ...) -> None: ...

class CreateCollectionRequest(_message.Message):
__slots__ = ["collection", "get_or_create"]
COLLECTION_FIELD_NUMBER: _ClassVar[int]
__slots__ = ["id", "name", "metadata", "dimension", "get_or_create"]
ID_FIELD_NUMBER: _ClassVar[int]
NAME_FIELD_NUMBER: _ClassVar[int]
METADATA_FIELD_NUMBER: _ClassVar[int]
DIMENSION_FIELD_NUMBER: _ClassVar[int]
GET_OR_CREATE_FIELD_NUMBER: _ClassVar[int]
collection: _chroma_pb2.Collection
id: str
name: str
metadata: _chroma_pb2.UpdateMetadata
dimension: int
get_or_create: bool
def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., get_or_create: bool = ...) -> None: ...
def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ..., get_or_create: bool = ...) -> None: ...

class CreateCollectionResponse(_message.Message):
__slots__ = ["collection", "status"]
Expand Down
Loading

0 comments on commit 3e01da7

Please sign in to comment.