Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
HammadB committed Oct 16, 2023
1 parent b042331 commit ee2849b
Show file tree
Hide file tree
Showing 10 changed files with 200 additions and 78 deletions.
37 changes: 13 additions & 24 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,39 +105,28 @@ def create_collection(
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
get_or_create: bool = False,
) -> Collection:
existing = self._sysdb.get_collections(name=name)

if metadata is not None:
validate_metadata(metadata)

if existing:
if get_or_create:
if metadata and existing[0]["metadata"] != metadata:
self._modify(id=existing[0]["id"], new_metadata=metadata)
existing = self._sysdb.get_collections(id=existing[0]["id"])
return Collection(
client=self,
id=existing[0]["id"],
name=existing[0]["name"],
metadata=existing[0]["metadata"], # type: ignore
embedding_function=embedding_function,
)
else:
raise ValueError(f"Collection {name} already exists.")

# TODO: remove backwards compatibility in naming requirements
check_index_name(name)

id = uuid4()

coll = self._sysdb.create_collection(
id=id, name=name, metadata=metadata, dimension=None
coll, created = self._sysdb.create_collection(
id=id,
name=name,
metadata=metadata,
dimension=None,
get_or_create=get_or_create,
)
segments = self._manager.create_segments(coll)

for segment in segments:
self._sysdb.create_segment(segment)
if created:
segments = self._manager.create_segments(coll)
for segment in segments:
self._sysdb.create_segment(segment)

# TODO: This event doesn't capture the get_or_create case appropriately
self._telemetry_client.capture(
ClientCreateCollectionEvent(
collection_uuid=str(id),
Expand All @@ -147,9 +136,9 @@ def create_collection(

return Collection(
client=self,
id=id,
id=coll["id"],
name=name,
metadata=metadata,
metadata=coll["metadata"], # type: ignore
embedding_function=embedding_function,
)

Expand Down
27 changes: 10 additions & 17 deletions chromadb/db/impl/grpc/client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from typing import List, Optional, Sequence, Union, cast
from typing import List, Optional, Sequence, Tuple, Union, cast
from uuid import UUID
from overrides import overrides
from chromadb.config import System
from chromadb.db.base import NotFoundError, UniqueConstraintError
from chromadb.db.system import SysDB
from chromadb.ingest import CollectionAssignmentPolicy
from chromadb.proto.convert import (
from_proto_collection,
from_proto_segment,
to_proto_collection,
to_proto_update_metadata,
to_proto_segment,
to_proto_segment_scope,
Expand Down Expand Up @@ -44,7 +42,6 @@ class GrpcSysDB(SysDB):
to call a remote SysDB (Coordinator) service."""

_sys_db_stub: SysDBStub
_assignment_policy: CollectionAssignmentPolicy
_channel: grpc.Channel
_coordinator_url: str
_coordinator_port: int
Expand All @@ -53,11 +50,11 @@ def __init__(self, system: System):
self._coordinator_url = system.settings.require("chroma_coordinator_host")
# TODO: break out coordinator_port into a separate setting?
self._coordinator_port = system.settings.require("chroma_server_grpc_port")
self._assignment_policy = system.instance(CollectionAssignmentPolicy)
return super().__init__(system)

@overrides
def start(self) -> None:
# TODO: add retry policy here
self._channel = grpc.insecure_channel(
f"{self._coordinator_url}:{self._coordinator_port}"
)
Expand Down Expand Up @@ -166,24 +163,20 @@ def create_collection(
name: str,
metadata: Optional[Metadata] = None,
dimension: Optional[int] = None,
) -> Collection:
# TODO: the get_or_create concept needs to be pushed down to the sysdb interface
topic = self._assignment_policy.assign_collection(id)
collection = Collection(
id=id,
get_or_create: bool = False,
) -> Tuple[Collection, bool]:
request = CreateCollectionRequest(
id=id.hex,
name=name,
topic=topic,
metadata=metadata,
metadata=to_proto_update_metadata(metadata) if metadata else None,
dimension=dimension,
)
request = CreateCollectionRequest(
collection=to_proto_collection(collection),
get_or_create=False,
get_or_create=get_or_create,
)
response = self._sys_db_stub.CreateCollection(request)
if response.status.code == 409:
raise UniqueConstraintError()
return collection
collection = from_proto_collection(response.collection)
return collection, response.created

@overrides
def delete_collection(self, id: UUID) -> None:
Expand Down
39 changes: 33 additions & 6 deletions chromadb/db/impl/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import Any, Dict, cast
from uuid import UUID
from overrides import overrides
from chromadb.ingest import CollectionAssignmentPolicy
from chromadb.config import Component, System
from chromadb.proto.convert import (
from_proto_collection,
from_proto_metadata,
from_proto_update_metadata,
from_proto_segment,
from_proto_segment_scope,
Expand Down Expand Up @@ -40,11 +41,13 @@ class GrpcMockSysDB(SysDBServicer, Component):

_server: grpc.Server
_server_port: int
_assignment_policy: CollectionAssignmentPolicy
_segments: Dict[str, Segment] = {}
_collections: Dict[str, Collection] = {}

def __init__(self, system: System):
self._server_port = system.settings.require("chroma_server_grpc_port")
self._assignment_policy = system.instance(CollectionAssignmentPolicy)
return super().__init__(system)

@overrides
Expand Down Expand Up @@ -167,18 +170,42 @@ def UpdateSegment(
def CreateCollection(
self, request: CreateCollectionRequest, context: grpc.ServicerContext
) -> CreateCollectionResponse:
collection = from_proto_collection(request.collection)
if collection["id"].hex in self._collections:
collection_name = request.name
matches = [
c for c in self._collections.values() if c["name"] == collection_name
]
assert len(matches) <= 1
if len(matches) > 0:
if request.get_or_create:
existing_collection = matches[0]
if request.HasField("metadata"):
existing_collection["metadata"] = from_proto_metadata(
request.metadata
)
return CreateCollectionResponse(
status=proto.Status(code=200),
collection=to_proto_collection(existing_collection),
created=False,
)
return CreateCollectionResponse(
status=proto.Status(
code=409, reason=f"Collection {collection['id']} already exists"
code=409, reason=f"Collection {request.name} already exists"
)
)

self._collections[collection["id"].hex] = collection
id = UUID(hex=request.id)
new_collection = Collection(
id=id,
name=request.name,
metadata=from_proto_metadata(request.metadata),
dimension=request.dimension,
topic=self._assignment_policy.assign_collection(id),
)
self._collections[request.id] = new_collection
return CreateCollectionResponse(
status=proto.Status(code=200),
collection=to_proto_collection(collection),
collection=to_proto_collection(new_collection),
created=True,
)

@overrides(check_signature=False)
Expand Down
21 changes: 18 additions & 3 deletions chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,23 @@ def create_collection(
name: str,
metadata: Optional[Metadata] = None,
dimension: Optional[int] = None,
) -> Collection:
"""Create a new collection and the associate topic"""
get_or_create: bool = False,
) -> Tuple[Collection, bool]:
if id is None and not get_or_create:
raise ValueError("id must be specified if get_or_create is False")

existing = self.get_collections(name=name)
if existing:
if get_or_create:
collection = existing[0]
if metadata is not None and collection["metadata"] != metadata:
self.update_collection(
collection["id"],
metadata=metadata,
)
return self.get_collections(id=collection["id"])[0], False
else:
raise UniqueConstraintError(f"Collection {name} already exists")

topic = self._assignment_policy.assign_collection(id)
collection = Collection(
Expand Down Expand Up @@ -129,7 +144,7 @@ def create_collection(
collection["id"],
collection["metadata"],
)
return collection
return collection, True

@override
def get_segments(
Expand Down
14 changes: 11 additions & 3 deletions chromadb/db/system.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple
from uuid import UUID
from chromadb.types import (
Collection,
Expand Down Expand Up @@ -59,9 +59,17 @@ def create_collection(
name: str,
metadata: Optional[Metadata] = None,
dimension: Optional[int] = None,
) -> Collection:
get_or_create: bool = False,
) -> Tuple[Collection, bool]:
"""Create a new collection any associated resources
(Such as the necessary topics) in the SysDB."""
(Such as the necessary topics) in the SysDB. If get_or_create is True, the
collectionwill be created if one with the same name does not exist.
The metadata will be updated using the same protocol as update_collection. If get_or_create
is False and the collection already exists, a error will be raised.
Returns a tuple of the created collection and a boolean indicating whether the
collection was created or not.
"""
pass

@abstractmethod
Expand Down
4 changes: 3 additions & 1 deletion chromadb/proto/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def from_proto_collection(collection: proto.Collection) -> Collection:
metadata=from_proto_metadata(collection.metadata)
if collection.HasField("metadata")
else None,
dimension=collection.dimension if collection.HasField("dimension") else None,
dimension=collection.dimension
if collection.HasField("dimension") and collection.dimension
else None,
)


Expand Down
30 changes: 15 additions & 15 deletions chromadb/proto/coordinator_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ee2849b

Please sign in to comment.