-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add coordinator impl and mock server to test it. Update protobufs and…
… convert to handle null + update
- Loading branch information
Showing
15 changed files
with
1,170 additions
and
247 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
from typing import List, Optional, Sequence, Union, cast | ||
from uuid import UUID | ||
from overrides import overrides | ||
from chromadb.config import System | ||
from chromadb.db.base import NotFoundError, UniqueConstraintError | ||
from chromadb.db.system import SysDB | ||
from chromadb.proto.convert import ( | ||
from_proto_collection, | ||
from_proto_segment, | ||
to_proto_collection, | ||
to_proto_update_metadata, | ||
to_proto_segment, | ||
to_proto_segment_scope, | ||
) | ||
from chromadb.proto.coordinator_pb2 import ( | ||
CreateCollectionRequest, | ||
CreateSegmentRequest, | ||
DeleteCollectionRequest, | ||
DeleteSegmentRequest, | ||
GetCollectionsRequest, | ||
GetCollectionsResponse, | ||
GetSegmentsRequest, | ||
UpdateCollectionRequest, | ||
UpdateSegmentRequest, | ||
) | ||
from chromadb.proto.coordinator_pb2_grpc import SysDBStub | ||
from chromadb.types import ( | ||
Collection, | ||
OptionalArgument, | ||
Segment, | ||
SegmentScope, | ||
Unspecified, | ||
UpdateMetadata, | ||
) | ||
from google.protobuf.empty_pb2 import Empty | ||
import grpc | ||
|
||
|
||
class GrpcSysDB(SysDB): | ||
"""A gRPC implementation of the SysDB. In the distributed system, the SysDB is also | ||
called the 'Coordinator'. This implementation is used by Chroma frontend servers | ||
to call a remote SysDB (Coordinator) service.""" | ||
|
||
_sys_db_stub: SysDBStub | ||
_channel: grpc.Channel | ||
_coordinator_url: str | ||
_coordinator_port: int | ||
|
||
def __init__(self, system: System): | ||
self._coordinator_url = system.settings.require("chroma_coordinator_host") | ||
# TODO: break out coordinator_port into a separate setting? | ||
self._coordinator_port = system.settings.require("chroma_server_grpc_port") | ||
return super().__init__(system) | ||
|
||
@overrides | ||
def start(self) -> None: | ||
self._channel = grpc.insecure_channel( | ||
f"{self._coordinator_url}:{self._coordinator_port}" | ||
) | ||
self._sys_db_stub = SysDBStub(self._channel) # type: ignore | ||
return super().start() | ||
|
||
@overrides | ||
def stop(self) -> None: | ||
self._channel.close() | ||
return super().stop() | ||
|
||
@overrides | ||
def reset_state(self) -> None: | ||
self._sys_db_stub.ResetState(Empty()) | ||
return super().reset_state() | ||
|
||
@overrides | ||
def create_segment(self, segment: Segment) -> None: | ||
proto_segment = to_proto_segment(segment) | ||
request = CreateSegmentRequest( | ||
segment=proto_segment, | ||
) | ||
response = self._sys_db_stub.CreateSegment(request) | ||
if response.status.code == 409: | ||
raise UniqueConstraintError() | ||
|
||
@overrides | ||
def delete_segment(self, id: UUID) -> None: | ||
request = DeleteSegmentRequest( | ||
id=id.hex, | ||
) | ||
response = self._sys_db_stub.DeleteSegment(request) | ||
if response.status.code == 404: | ||
raise NotFoundError() | ||
|
||
@overrides | ||
def get_segments( | ||
self, | ||
id: Optional[UUID] = None, | ||
type: Optional[str] = None, | ||
scope: Optional[SegmentScope] = None, | ||
topic: Optional[str] = None, | ||
collection: Optional[UUID] = None, | ||
) -> Sequence[Segment]: | ||
request = GetSegmentsRequest( | ||
id=id.hex if id else None, | ||
type=type, | ||
scope=to_proto_segment_scope(scope) if scope else None, | ||
topic=topic, | ||
collection=collection.hex if collection else None, | ||
) | ||
response = self._sys_db_stub.GetSegments(request) | ||
results: List[Segment] = [] | ||
for proto_segment in response.segments: | ||
segment = from_proto_segment(proto_segment) | ||
results.append(segment) | ||
return results | ||
|
||
@overrides | ||
def update_segment( | ||
self, | ||
id: UUID, | ||
topic: OptionalArgument[Optional[str]] = Unspecified(), | ||
collection: OptionalArgument[Optional[UUID]] = Unspecified(), | ||
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(), | ||
) -> None: | ||
write_topic = None | ||
if topic != Unspecified(): | ||
write_topic = cast(Union[str, None], topic) | ||
|
||
write_collection = None | ||
if collection != Unspecified(): | ||
write_collection = cast(Union[UUID, None], collection) | ||
|
||
write_metadata = None | ||
if metadata != Unspecified(): | ||
write_metadata = cast(Union[UpdateMetadata, None], metadata) | ||
|
||
request = UpdateSegmentRequest( | ||
id=id.hex, | ||
topic=write_topic, | ||
collection=write_collection.hex if write_collection else None, | ||
metadata=to_proto_update_metadata(write_metadata) | ||
if write_metadata | ||
else None, | ||
) | ||
|
||
if topic is None: | ||
request.ClearField("topic") | ||
request.reset_topic = True | ||
|
||
if collection is None: | ||
request.ClearField("collection") | ||
request.reset_collection = True | ||
|
||
if metadata is None: | ||
request.ClearField("metadata") | ||
request.reset_metadata = True | ||
|
||
self._sys_db_stub.UpdateSegment(request) | ||
|
||
@overrides | ||
def create_collection(self, collection: Collection) -> None: | ||
# TODO: the get_or_create concept needs to be pushed down to the sysdb interface | ||
request = CreateCollectionRequest( | ||
collection=to_proto_collection(collection), | ||
get_or_create=False, | ||
) | ||
response = self._sys_db_stub.CreateCollection(request) | ||
if response.status.code == 409: | ||
raise UniqueConstraintError() | ||
|
||
@overrides | ||
def delete_collection(self, id: UUID) -> None: | ||
request = DeleteCollectionRequest( | ||
id=id.hex, | ||
) | ||
response = self._sys_db_stub.DeleteCollection(request) | ||
if response.status.code == 404: | ||
raise NotFoundError() | ||
|
||
@overrides | ||
def get_collections( | ||
self, | ||
id: Optional[UUID] = None, | ||
topic: Optional[str] = None, | ||
name: Optional[str] = None, | ||
) -> Sequence[Collection]: | ||
request = GetCollectionsRequest( | ||
id=id.hex if id else None, | ||
topic=topic, | ||
name=name, | ||
) | ||
response: GetCollectionsResponse = self._sys_db_stub.GetCollections(request) | ||
results: List[Collection] = [] | ||
for collection in response.collections: | ||
results.append(from_proto_collection(collection)) | ||
return results | ||
|
||
@overrides | ||
def update_collection( | ||
self, | ||
id: UUID, | ||
topic: OptionalArgument[str] = Unspecified(), | ||
name: OptionalArgument[str] = Unspecified(), | ||
dimension: OptionalArgument[Optional[int]] = Unspecified(), | ||
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(), | ||
) -> None: | ||
write_topic = None | ||
if topic != Unspecified(): | ||
write_topic = cast(str, topic) | ||
|
||
write_name = None | ||
if name != Unspecified(): | ||
write_name = cast(str, name) | ||
|
||
write_dimension = None | ||
if dimension != Unspecified(): | ||
write_dimension = cast(Union[int, None], dimension) | ||
|
||
write_metadata = None | ||
if metadata != Unspecified(): | ||
write_metadata = cast(Union[UpdateMetadata, None], metadata) | ||
|
||
request = UpdateCollectionRequest( | ||
id=id.hex, | ||
topic=write_topic, | ||
name=write_name, | ||
dimension=write_dimension, | ||
metadata=to_proto_update_metadata(write_metadata) | ||
if write_metadata | ||
else None, | ||
) | ||
if metadata is None: | ||
request.ClearField("metadata") | ||
request.reset_metadata = True | ||
|
||
self._sys_db_stub.UpdateCollection(request) | ||
|
||
def reset_and_wait_for_ready(self) -> None: | ||
self._sys_db_stub.ResetState(Empty(), wait_for_ready=True) |
Oops, something went wrong.