diff --git a/.vscode/settings.json b/.vscode/settings.json index 5f44b0983874..f62dcb24a741 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -35,5 +35,10 @@ "--no-pretty", "--strict", "--disable-error-code=type-abstract" - ] + ], + "protoc": { + "options": [ + "--proto_path=idl/", + ] + } } diff --git a/chromadb/config.py b/chromadb/config.py index 920c92d6a96e..a2af7bd32bc9 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -95,6 +95,7 @@ class Settings(BaseSettings): # type: ignore chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory" chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider" worker_memberlist_name: str = "worker-memberlist" + chroma_coordinator_host = "localhost" tenant_id: str = "default" topic_namespace: str = "default" diff --git a/chromadb/db/impl/grpc.py b/chromadb/db/impl/grpc.py deleted file mode 100644 index fc2b51685515..000000000000 --- a/chromadb/db/impl/grpc.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Optional, Sequence -from uuid import UUID -from overrides import overrides -from chromadb.config import System -from chromadb.db.system import SysDB -from chromadb.proto.coordinator_pb2_grpc import SysDBStub -from chromadb.types import ( - Collection, - OptionalArgument, - Segment, - SegmentScope, - Unspecified, - UpdateMetadata, -) -import grpc - - -class GrpcSysDB(SysDB): - """A gRPC implementation of the SysDB. In the distributed system, the SysDB is also - called the 'Coordinator'. This implementation is used by Chroma frontend servers - to call a remote SysDB (Coordinator) service.""" - - _sys_db_stub: SysDBStub - _coordinator_url: str - _coordinator_port: int - - def __init__(self, system: System): - self._coordinator_url = system.settings.require("coordinator_host") - # TODO: break out coordinator_port into a separate setting? - self._coordinator_port = system.settings.require("chroma_server_grpc_port") - - @overrides - def start(self) -> None: - channel = grpc.insecure_channel(self._coordinator_url) - self._sys_db_stub = SysDBStub(channel) # type: ignore - return super().start() - - @overrides - def stop(self) -> None: - return super().stop() - - @overrides - def reset_state(self) -> None: - # TODO - remote service should be able to reset state for testing - return super().reset_state() - - @overrides - def create_segment(self, segment: Segment) -> None: - return super().create_segment(segment) - - @overrides - def delete_segment(self, id: UUID) -> None: - raise NotImplementedError() - - @overrides - def get_segments( - self, - id: Optional[UUID] = None, - type: Optional[str] = None, - scope: Optional[SegmentScope] = None, - topic: Optional[str] = None, - collection: Optional[UUID] = None, - ) -> Sequence[Segment]: - raise NotImplementedError() - - @overrides - def update_segment( - self, - id: UUID, - topic: OptionalArgument[Optional[str]] = Unspecified(), - collection: OptionalArgument[Optional[UUID]] = Unspecified(), - metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(), - ) -> None: - raise NotImplementedError() - - @overrides - def create_collection(self, collection: Collection) -> None: - raise NotImplementedError() - - @overrides - def delete_collection(self, id: UUID) -> None: - raise NotImplementedError() - - @overrides - def get_collections( - self, - id: Optional[UUID] = None, - topic: Optional[str] = None, - name: Optional[str] = None, - ) -> Sequence[Collection]: - raise NotImplementedError() - - @overrides - def update_collection( - self, - id: UUID, - topic: OptionalArgument[str] = Unspecified(), - name: OptionalArgument[str] = Unspecified(), - dimension: OptionalArgument[Optional[int]] = Unspecified(), - metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(), - ) -> None: - raise NotImplementedError() diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py new file mode 100644 index 000000000000..5d7c6e888381 --- /dev/null +++ b/chromadb/db/impl/grpc/client.py @@ -0,0 +1,237 @@ +from typing import List, Optional, Sequence, Union, cast +from uuid import UUID +from overrides import overrides +from chromadb.config import System +from chromadb.db.base import NotFoundError, UniqueConstraintError +from chromadb.db.system import SysDB +from chromadb.proto.convert import ( + from_proto_collection, + from_proto_segment, + to_proto_collection, + to_proto_update_metadata, + to_proto_segment, + to_proto_segment_scope, +) +from chromadb.proto.coordinator_pb2 import ( + CreateCollectionRequest, + CreateSegmentRequest, + DeleteCollectionRequest, + DeleteSegmentRequest, + GetCollectionsRequest, + GetCollectionsResponse, + GetSegmentsRequest, + UpdateCollectionRequest, + UpdateSegmentRequest, +) +from chromadb.proto.coordinator_pb2_grpc import SysDBStub +from chromadb.types import ( + Collection, + OptionalArgument, + Segment, + SegmentScope, + Unspecified, + UpdateMetadata, +) +from google.protobuf.empty_pb2 import Empty +import grpc + + +class GrpcSysDB(SysDB): + """A gRPC implementation of the SysDB. In the distributed system, the SysDB is also + called the 'Coordinator'. This implementation is used by Chroma frontend servers + to call a remote SysDB (Coordinator) service.""" + + _sys_db_stub: SysDBStub + _channel: grpc.Channel + _coordinator_url: str + _coordinator_port: int + + def __init__(self, system: System): + self._coordinator_url = system.settings.require("chroma_coordinator_host") + # TODO: break out coordinator_port into a separate setting? + self._coordinator_port = system.settings.require("chroma_server_grpc_port") + return super().__init__(system) + + @overrides + def start(self) -> None: + self._channel = grpc.insecure_channel( + f"{self._coordinator_url}:{self._coordinator_port}" + ) + self._sys_db_stub = SysDBStub(self._channel) # type: ignore + return super().start() + + @overrides + def stop(self) -> None: + self._channel.close() + return super().stop() + + @overrides + def reset_state(self) -> None: + self._sys_db_stub.ResetState(Empty()) + return super().reset_state() + + @overrides + def create_segment(self, segment: Segment) -> None: + proto_segment = to_proto_segment(segment) + request = CreateSegmentRequest( + segment=proto_segment, + ) + response = self._sys_db_stub.CreateSegment(request) + if response.status.code == 409: + raise UniqueConstraintError() + + @overrides + def delete_segment(self, id: UUID) -> None: + request = DeleteSegmentRequest( + id=id.hex, + ) + response = self._sys_db_stub.DeleteSegment(request) + if response.status.code == 404: + raise NotFoundError() + + @overrides + def get_segments( + self, + id: Optional[UUID] = None, + type: Optional[str] = None, + scope: Optional[SegmentScope] = None, + topic: Optional[str] = None, + collection: Optional[UUID] = None, + ) -> Sequence[Segment]: + request = GetSegmentsRequest( + id=id.hex if id else None, + type=type, + scope=to_proto_segment_scope(scope) if scope else None, + topic=topic, + collection=collection.hex if collection else None, + ) + response = self._sys_db_stub.GetSegments(request) + results: List[Segment] = [] + for proto_segment in response.segments: + segment = from_proto_segment(proto_segment) + results.append(segment) + return results + + @overrides + def update_segment( + self, + id: UUID, + topic: OptionalArgument[Optional[str]] = Unspecified(), + collection: OptionalArgument[Optional[UUID]] = Unspecified(), + metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(), + ) -> None: + write_topic = None + if topic != Unspecified(): + write_topic = cast(Union[str, None], topic) + + write_collection = None + if collection != Unspecified(): + write_collection = cast(Union[UUID, None], collection) + + write_metadata = None + if metadata != Unspecified(): + write_metadata = cast(Union[UpdateMetadata, None], metadata) + + request = UpdateSegmentRequest( + id=id.hex, + topic=write_topic, + collection=write_collection.hex if write_collection else None, + metadata=to_proto_update_metadata(write_metadata) + if write_metadata + else None, + ) + + if topic is None: + request.ClearField("topic") + request.reset_topic = True + + if collection is None: + request.ClearField("collection") + request.reset_collection = True + + if metadata is None: + request.ClearField("metadata") + request.reset_metadata = True + + self._sys_db_stub.UpdateSegment(request) + + @overrides + def create_collection(self, collection: Collection) -> None: + # TODO: the get_or_create concept needs to be pushed down to the sysdb interface + request = CreateCollectionRequest( + collection=to_proto_collection(collection), + get_or_create=False, + ) + response = self._sys_db_stub.CreateCollection(request) + if response.status.code == 409: + raise UniqueConstraintError() + + @overrides + def delete_collection(self, id: UUID) -> None: + request = DeleteCollectionRequest( + id=id.hex, + ) + response = self._sys_db_stub.DeleteCollection(request) + if response.status.code == 404: + raise NotFoundError() + + @overrides + def get_collections( + self, + id: Optional[UUID] = None, + topic: Optional[str] = None, + name: Optional[str] = None, + ) -> Sequence[Collection]: + request = GetCollectionsRequest( + id=id.hex if id else None, + topic=topic, + name=name, + ) + response: GetCollectionsResponse = self._sys_db_stub.GetCollections(request) + results: List[Collection] = [] + for collection in response.collections: + results.append(from_proto_collection(collection)) + return results + + @overrides + def update_collection( + self, + id: UUID, + topic: OptionalArgument[str] = Unspecified(), + name: OptionalArgument[str] = Unspecified(), + dimension: OptionalArgument[Optional[int]] = Unspecified(), + metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(), + ) -> None: + write_topic = None + if topic != Unspecified(): + write_topic = cast(str, topic) + + write_name = None + if name != Unspecified(): + write_name = cast(str, name) + + write_dimension = None + if dimension != Unspecified(): + write_dimension = cast(Union[int, None], dimension) + + write_metadata = None + if metadata != Unspecified(): + write_metadata = cast(Union[UpdateMetadata, None], metadata) + + request = UpdateCollectionRequest( + id=id.hex, + topic=write_topic, + name=write_name, + dimension=write_dimension, + metadata=to_proto_update_metadata(write_metadata) + if write_metadata + else None, + ) + if metadata is None: + request.ClearField("metadata") + request.reset_metadata = True + + self._sys_db_stub.UpdateCollection(request) + + def reset_and_wait_for_ready(self) -> None: + self._sys_db_stub.ResetState(Empty(), wait_for_ready=True) diff --git a/chromadb/db/impl/grpc/server.py b/chromadb/db/impl/grpc/server.py new file mode 100644 index 000000000000..77405ea4ac91 --- /dev/null +++ b/chromadb/db/impl/grpc/server.py @@ -0,0 +1,273 @@ +from concurrent import futures +from typing import Any, Dict, cast +from uuid import UUID +from overrides import overrides +from chromadb.config import Component, System +from chromadb.proto.convert import ( + from_proto_collection, + from_proto_update_metadata, + from_proto_segment, + from_proto_segment_scope, + to_proto_collection, + to_proto_segment, +) +import chromadb.proto.chroma_pb2 as proto +from chromadb.proto.coordinator_pb2 import ( + CreateCollectionRequest, + CreateCollectionResponse, + CreateSegmentRequest, + DeleteCollectionRequest, + DeleteSegmentRequest, + GetCollectionsRequest, + GetCollectionsResponse, + GetSegmentsRequest, + GetSegmentsResponse, + UpdateCollectionRequest, + UpdateSegmentRequest, +) +from chromadb.proto.coordinator_pb2_grpc import ( + SysDBServicer, + add_SysDBServicer_to_server, +) +import grpc +from google.protobuf.empty_pb2 import Empty +from chromadb.types import Collection, Metadata, Segment + + +class GrpcMockSysDB(SysDBServicer, Component): + """A mock sysdb implementation that can be used for testing the grpc client. It stores + state in simple python data structures instead of a database.""" + + _server: grpc.Server + _segments: Dict[str, Segment] = {} + _collections: Dict[str, Collection] = {} + + def __init__(self, system: System): + return super().__init__(system) + + @overrides + def start(self) -> None: + self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + add_SysDBServicer_to_server(self, self._server) # type: ignore + self._server.add_insecure_port("[::]:50051") # TODO: make port configurable + self._server.start() + return super().start() + + @overrides + def stop(self) -> None: + self._server.stop(0) + return super().stop() + + @overrides + def reset_state(self) -> None: + self._segments = {} + self._collections = {} + return super().reset_state() + + # We are forced to use check_signature=False because the generated proto code + # does not have type annotations for the request and response objects. + # TODO: investigate generating types for the request and response objects + @overrides(check_signature=False) + def CreateSegment( + self, request: CreateSegmentRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + segment = from_proto_segment(request.segment) + if segment["id"].hex in self._segments: + return proto.ChromaResponse( + status=proto.Status( + code=409, reason=f"Segment {segment['id']} already exists" + ) + ) + self._segments[segment["id"].hex] = segment + return proto.ChromaResponse( + status=proto.Status(code=200) + ) # TODO: how are these codes used? + + @overrides(check_signature=False) + def DeleteSegment( + self, request: DeleteSegmentRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + id_to_delete = request.id + if id_to_delete in self._segments: + del self._segments[id_to_delete] + return proto.ChromaResponse(status=proto.Status(code=200)) + else: + return proto.ChromaResponse( + status=proto.Status( + code=404, reason=f"Segment {id_to_delete} not found" + ) + ) + + @overrides(check_signature=False) + def GetSegments( + self, request: GetSegmentsRequest, context: grpc.ServicerContext + ) -> GetSegmentsResponse: + target_id = UUID(hex=request.id) if request.HasField("id") else None + target_type = request.type if request.HasField("type") else None + target_scope = ( + from_proto_segment_scope(request.scope) + if request.HasField("scope") + else None + ) + target_topic = request.topic if request.HasField("topic") else None + target_collection = ( + UUID(hex=request.collection) if request.HasField("collection") else None + ) + + found_segments = [] + for segment in self._segments.values(): + if target_id and segment["id"] != target_id: + continue + if target_type and segment["type"] != target_type: + continue + if target_scope and segment["scope"] != target_scope: + continue + if target_topic and segment["topic"] != target_topic: + continue + if target_collection and segment["collection"] != target_collection: + continue + found_segments.append(segment) + return GetSegmentsResponse( + segments=[to_proto_segment(segment) for segment in found_segments] + ) + + @overrides(check_signature=False) + def UpdateSegment( + self, request: UpdateSegmentRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + id_to_update = UUID(request.id) + if id_to_update.hex not in self._segments: + return proto.ChromaResponse( + status=proto.Status( + code=404, reason=f"Segment {id_to_update} not found" + ) + ) + else: + segment = self._segments[id_to_update.hex] + if request.HasField("topic"): + segment["topic"] = request.topic + if request.HasField("reset_topic") and request.reset_topic: + segment["topic"] = None + if request.HasField("collection"): + segment["collection"] = UUID(hex=request.collection) + if request.HasField("reset_collection") and request.reset_collection: + segment["collection"] = None + if request.HasField("metadata"): + target = cast(Dict[str, Any], segment["metadata"]) + if segment["metadata"] is None: + segment["metadata"] = {} + self._merge_metadata(target, request.metadata) + if request.HasField("reset_metadata") and request.reset_metadata: + segment["metadata"] = {} + return proto.ChromaResponse(status=proto.Status(code=200)) + + @overrides(check_signature=False) + def CreateCollection( + self, request: CreateCollectionRequest, context: grpc.ServicerContext + ) -> CreateCollectionResponse: + collection = from_proto_collection(request.collection) + if collection["id"].hex in self._collections: + return CreateCollectionResponse( + status=proto.Status( + code=409, reason=f"Collection {collection['id']} already exists" + ) + ) + + self._collections[collection["id"].hex] = collection + return CreateCollectionResponse( + status=proto.Status(code=200), + collection=to_proto_collection(collection), + ) + + @overrides(check_signature=False) + def DeleteCollection( + self, request: DeleteCollectionRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + collection_id = request.id + if collection_id in self._collections: + del self._collections[collection_id] + return proto.ChromaResponse(status=proto.Status(code=200)) + else: + return proto.ChromaResponse( + status=proto.Status( + code=404, reason=f"Collection {collection_id} not found" + ) + ) + + @overrides(check_signature=False) + def GetCollections( + self, request: GetCollectionsRequest, context: grpc.ServicerContext + ) -> GetCollectionsResponse: + target_id = UUID(hex=request.id) if request.HasField("id") else None + target_topic = request.topic if request.HasField("topic") else None + target_name = request.name if request.HasField("name") else None + + found_collections = [] + for collection in self._collections.values(): + if target_id and collection["id"] != target_id: + continue + if target_topic and collection["topic"] != target_topic: + continue + if target_name and collection["name"] != target_name: + continue + found_collections.append(collection) + return GetCollectionsResponse( + collections=[ + to_proto_collection(collection) for collection in found_collections + ] + ) + + @overrides(check_signature=False) + def UpdateCollection( + self, request: UpdateCollectionRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + id_to_update = UUID(request.id) + if id_to_update.hex not in self._collections: + return proto.ChromaResponse( + status=proto.Status( + code=404, reason=f"Collection {id_to_update} not found" + ) + ) + else: + collection = self._collections[id_to_update.hex] + if request.HasField("topic"): + collection["topic"] = request.topic + if request.HasField("name"): + collection["name"] = request.name + if request.HasField("dimension"): + collection["dimension"] = request.dimension + if request.HasField("metadata"): + # TODO: IN SysDB SQlite we have technical debt where we + # replace the entire metadata dict with the new one. We should + # fix that by merging it. For now we just do the same thing here + + update_metadata = from_proto_update_metadata(request.metadata) + cleaned_metadata = None + if update_metadata is not None: + cleaned_metadata = {} + for key, value in update_metadata.items(): + if value is not None: + cleaned_metadata[key] = value + + collection["metadata"] = cleaned_metadata + elif request.HasField("reset_metadata"): + if request.reset_metadata: + collection["metadata"] = {} + + return proto.ChromaResponse(status=proto.Status(code=200)) + + @overrides(check_signature=False) + def ResetState( + self, request: Empty, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + self.reset_state() + return proto.ChromaResponse(status=proto.Status(code=200)) + + def _merge_metadata(self, target: Metadata, source: proto.UpdateMetadata) -> None: + target_metadata = cast(Dict[str, Any], target) + source_metadata = cast(Dict[str, Any], from_proto_update_metadata(source)) + target_metadata.update(source_metadata) + # If a key has a None value, remove it from the metadata + for key, value in source_metadata.items(): + if value is None and key in target: + del target_metadata[key] diff --git a/chromadb/db/system.py b/chromadb/db/system.py index 969b71afa3b6..23f068c3be37 100644 --- a/chromadb/db/system.py +++ b/chromadb/db/system.py @@ -53,7 +53,7 @@ def update_segment( @abstractmethod def create_collection(self, collection: Collection) -> None: - """Create a new topic""" + """Create a new collection any associated resources in the SysDB.""" pass @abstractmethod diff --git a/chromadb/proto/chroma_pb2.py b/chromadb/proto/chroma_pb2.py index d6b22217d7dc..bd069cc74f48 100644 --- a/chromadb/proto/chroma_pb2.py +++ b/chromadb/proto/chroma_pb2.py @@ -13,7 +13,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma\"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding\"\xca\x01\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\x05topic\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x08\n\x06_topicB\r\n\x0b_collectionB\x0b\n\t_metadata\"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value\"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01\"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata\"S\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector\"q\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x10\n\x08\x64istance\x18\x03 \x01(\x01\x12#\n\x06vector\x18\x04 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector\"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult\"(\n\x15SegmentServerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"4\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t\"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord\"\x86\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t\"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 \x03(\x0b\x32\x1a.chroma.VectorQueryResults*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*(\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x32\x94\x01\n\rSegmentServer\x12?\n\x0bLoadSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x12\x42\n\x0eReleaseSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse\"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma\"&\n\x06Status\x12\x0e\n\x06reason\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\x05\"0\n\x0e\x43hromaResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.chroma.Status\"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding\"\xca\x01\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\x05topic\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x08\n\x06_topicB\r\n\x0b_collectionB\x0b\n\t_metadata\"\x97\x01\n\nCollection\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05topic\x18\x03 \x01(\t\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x05 \x01(\x05H\x01\x88\x01\x01\x42\x0b\n\t_metadataB\x0c\n\n_dimension\"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value\"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01\"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata\"S\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector\"q\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x10\n\x08\x64istance\x18\x03 \x01(\x01\x12#\n\x06vector\x18\x04 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector\"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult\"(\n\x15SegmentServerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"4\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t\"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord\"\x86\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t\"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 \x03(\x0b\x32\x1a.chroma.VectorQueryResults*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*(\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x32\x94\x01\n\rSegmentServer\x12?\n\x0bLoadSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x12\x42\n\x0eReleaseSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse\"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -22,42 +22,48 @@ DESCRIPTOR._options = None _UPDATEMETADATA_METADATAENTRY._options = None _UPDATEMETADATA_METADATAENTRY._serialized_options = b'8\001' - _globals['_OPERATION']._serialized_start=1406 - _globals['_OPERATION']._serialized_end=1462 - _globals['_SCALARENCODING']._serialized_start=1464 - _globals['_SCALARENCODING']._serialized_end=1504 - _globals['_SEGMENTSCOPE']._serialized_start=1506 - _globals['_SEGMENTSCOPE']._serialized_end=1546 - _globals['_VECTOR']._serialized_start=39 - _globals['_VECTOR']._serialized_end=124 - _globals['_SEGMENT']._serialized_start=127 - _globals['_SEGMENT']._serialized_end=329 - _globals['_UPDATEMETADATAVALUE']._serialized_start=331 - _globals['_UPDATEMETADATAVALUE']._serialized_end=429 - _globals['_UPDATEMETADATA']._serialized_start=432 - _globals['_UPDATEMETADATA']._serialized_end=582 - _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_start=506 - _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_end=582 - _globals['_SUBMITEMBEDDINGRECORD']._serialized_start=585 - _globals['_SUBMITEMBEDDINGRECORD']._serialized_end=766 - _globals['_VECTOREMBEDDINGRECORD']._serialized_start=768 - _globals['_VECTOREMBEDDINGRECORD']._serialized_end=851 - _globals['_VECTORQUERYRESULT']._serialized_start=853 - _globals['_VECTORQUERYRESULT']._serialized_end=966 - _globals['_VECTORQUERYRESULTS']._serialized_start=968 - _globals['_VECTORQUERYRESULTS']._serialized_end=1032 - _globals['_SEGMENTSERVERRESPONSE']._serialized_start=1034 - _globals['_SEGMENTSERVERRESPONSE']._serialized_end=1074 - _globals['_GETVECTORSREQUEST']._serialized_start=1076 - _globals['_GETVECTORSREQUEST']._serialized_end=1128 - _globals['_GETVECTORSRESPONSE']._serialized_start=1130 - _globals['_GETVECTORSRESPONSE']._serialized_end=1198 - _globals['_QUERYVECTORSREQUEST']._serialized_start=1201 - _globals['_QUERYVECTORSREQUEST']._serialized_end=1335 - _globals['_QUERYVECTORSRESPONSE']._serialized_start=1337 - _globals['_QUERYVECTORSRESPONSE']._serialized_end=1404 - _globals['_SEGMENTSERVER']._serialized_start=1549 - _globals['_SEGMENTSERVER']._serialized_end=1697 - _globals['_VECTORREADER']._serialized_start=1700 - _globals['_VECTORREADER']._serialized_end=1862 + _globals['_OPERATION']._serialized_start=1650 + _globals['_OPERATION']._serialized_end=1706 + _globals['_SCALARENCODING']._serialized_start=1708 + _globals['_SCALARENCODING']._serialized_end=1748 + _globals['_SEGMENTSCOPE']._serialized_start=1750 + _globals['_SEGMENTSCOPE']._serialized_end=1790 + _globals['_STATUS']._serialized_start=39 + _globals['_STATUS']._serialized_end=77 + _globals['_CHROMARESPONSE']._serialized_start=79 + _globals['_CHROMARESPONSE']._serialized_end=127 + _globals['_VECTOR']._serialized_start=129 + _globals['_VECTOR']._serialized_end=214 + _globals['_SEGMENT']._serialized_start=217 + _globals['_SEGMENT']._serialized_end=419 + _globals['_COLLECTION']._serialized_start=422 + _globals['_COLLECTION']._serialized_end=573 + _globals['_UPDATEMETADATAVALUE']._serialized_start=575 + _globals['_UPDATEMETADATAVALUE']._serialized_end=673 + _globals['_UPDATEMETADATA']._serialized_start=676 + _globals['_UPDATEMETADATA']._serialized_end=826 + _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_start=750 + _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_end=826 + _globals['_SUBMITEMBEDDINGRECORD']._serialized_start=829 + _globals['_SUBMITEMBEDDINGRECORD']._serialized_end=1010 + _globals['_VECTOREMBEDDINGRECORD']._serialized_start=1012 + _globals['_VECTOREMBEDDINGRECORD']._serialized_end=1095 + _globals['_VECTORQUERYRESULT']._serialized_start=1097 + _globals['_VECTORQUERYRESULT']._serialized_end=1210 + _globals['_VECTORQUERYRESULTS']._serialized_start=1212 + _globals['_VECTORQUERYRESULTS']._serialized_end=1276 + _globals['_SEGMENTSERVERRESPONSE']._serialized_start=1278 + _globals['_SEGMENTSERVERRESPONSE']._serialized_end=1318 + _globals['_GETVECTORSREQUEST']._serialized_start=1320 + _globals['_GETVECTORSREQUEST']._serialized_end=1372 + _globals['_GETVECTORSRESPONSE']._serialized_start=1374 + _globals['_GETVECTORSRESPONSE']._serialized_end=1442 + _globals['_QUERYVECTORSREQUEST']._serialized_start=1445 + _globals['_QUERYVECTORSREQUEST']._serialized_end=1579 + _globals['_QUERYVECTORSRESPONSE']._serialized_start=1581 + _globals['_QUERYVECTORSRESPONSE']._serialized_end=1648 + _globals['_SEGMENTSERVER']._serialized_start=1793 + _globals['_SEGMENTSERVER']._serialized_end=1941 + _globals['_VECTORREADER']._serialized_start=1944 + _globals['_VECTORREADER']._serialized_end=2106 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/chroma_pb2.pyi b/chromadb/proto/chroma_pb2.pyi index 6d06e074c06b..733cae0a2738 100644 --- a/chromadb/proto/chroma_pb2.pyi +++ b/chromadb/proto/chroma_pb2.pyi @@ -31,6 +31,20 @@ INT32: ScalarEncoding VECTOR: SegmentScope METADATA: SegmentScope +class Status(_message.Message): + __slots__ = ["reason", "code"] + REASON_FIELD_NUMBER: _ClassVar[int] + CODE_FIELD_NUMBER: _ClassVar[int] + reason: str + code: int + def __init__(self, reason: _Optional[str] = ..., code: _Optional[int] = ...) -> None: ... + +class ChromaResponse(_message.Message): + __slots__ = ["status"] + STATUS_FIELD_NUMBER: _ClassVar[int] + status: Status + def __init__(self, status: _Optional[_Union[Status, _Mapping]] = ...) -> None: ... + class Vector(_message.Message): __slots__ = ["dimension", "vector", "encoding"] DIMENSION_FIELD_NUMBER: _ClassVar[int] @@ -57,6 +71,20 @@ class Segment(_message.Message): metadata: UpdateMetadata def __init__(self, id: _Optional[str] = ..., type: _Optional[str] = ..., scope: _Optional[_Union[SegmentScope, str]] = ..., topic: _Optional[str] = ..., collection: _Optional[str] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ...) -> None: ... +class Collection(_message.Message): + __slots__ = ["id", "name", "topic", "metadata", "dimension"] + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + TOPIC_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + DIMENSION_FIELD_NUMBER: _ClassVar[int] + id: str + name: str + topic: str + metadata: UpdateMetadata + dimension: int + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., topic: _Optional[str] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ...) -> None: ... + class UpdateMetadataValue(_message.Message): __slots__ = ["string_value", "int_value", "float_value"] STRING_VALUE_FIELD_NUMBER: _ClassVar[int] diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py index 5ff7bab085de..d46cad07710a 100644 --- a/chromadb/proto/convert.py +++ b/chromadb/proto/convert.py @@ -1,10 +1,11 @@ import array from uuid import UUID -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union, cast from chromadb.api.types import Embedding import chromadb.proto.chroma_pb2 as proto from chromadb.utils.messageid import bytes_to_int, int_to_bytes from chromadb.types import ( + Collection, EmbeddingRecord, Metadata, Operation, @@ -13,6 +14,7 @@ SegmentScope, SeqId, SubmitEmbeddingRecord, + UpdateMetadata, Vector, VectorEmbeddingRecord, VectorQueryResult, @@ -71,9 +73,23 @@ def from_proto_operation(operation: proto.Operation) -> Operation: def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]: + return cast(Optional[Metadata], _from_proto_metadata_handle_none(metadata, False)) + + +def from_proto_update_metadata( + metadata: proto.UpdateMetadata, +) -> Optional[UpdateMetadata]: + return cast( + Optional[UpdateMetadata], _from_proto_metadata_handle_none(metadata, True) + ) + + +def _from_proto_metadata_handle_none( + metadata: proto.UpdateMetadata, is_update: bool +) -> Optional[Union[UpdateMetadata, Metadata]]: if not metadata.metadata: return None - out_metadata: Dict[str, Union[str, int, float]] = {} + out_metadata: Dict[str, Union[str, int, float, None]] = {} for key, value in metadata.metadata.items(): if value.HasField("string_value"): out_metadata[key] = value.string_value @@ -81,11 +97,19 @@ def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]: out_metadata[key] = value.int_value elif value.HasField("float_value"): out_metadata[key] = value.float_value + elif is_update: + out_metadata[key] = None else: - raise RuntimeError(f"Unknown metadata value type {value}") + raise ValueError(f"Metadata key {key} value cannot be None") return out_metadata +def to_proto_update_metadata(metadata: UpdateMetadata) -> proto.UpdateMetadata: + return proto.UpdateMetadata( + metadata={k: to_proto_metadata_update_value(v) for k, v in metadata.items()} + ) + + def from_proto_submit( submit_embedding_record: proto.SubmitEmbeddingRecord, seq_id: SeqId ) -> EmbeddingRecord: @@ -95,7 +119,7 @@ def from_proto_submit( seq_id=seq_id, embedding=embedding, encoding=encoding, - metadata=from_proto_metadata(submit_embedding_record.metadata), + metadata=from_proto_update_metadata(submit_embedding_record.metadata), operation=from_proto_operation(submit_embedding_record.operation), ) return record @@ -106,11 +130,13 @@ def from_proto_segment(segment: proto.Segment) -> Segment: id=UUID(hex=segment.id), type=segment.type, scope=from_proto_segment_scope(segment.scope), - topic=segment.topic, + topic=segment.topic if segment.HasField("topic") else None, collection=None if not segment.HasField("collection") else UUID(hex=segment.collection), - metadata=from_proto_metadata(segment.metadata), + metadata=from_proto_metadata(segment.metadata) + if segment.HasField("metadata") + else None, ) @@ -123,9 +149,7 @@ def to_proto_segment(segment: Segment) -> proto.Segment: collection=None if segment["collection"] is None else segment["collection"].hex, metadata=None if segment["metadata"] is None - else { - k: to_proto_metadata_update_value(v) for k, v in segment["metadata"].items() - }, # TODO: refactor out to_proto_metadata + else to_proto_update_metadata(segment["metadata"]), ) @@ -165,6 +189,30 @@ def to_proto_metadata_update_value( ) +def from_proto_collection(collection: proto.Collection) -> Collection: + return Collection( + id=UUID(hex=collection.id), + name=collection.name, + topic=collection.topic, + metadata=from_proto_metadata(collection.metadata) + if collection.HasField("metadata") + else None, + dimension=collection.dimension if collection.HasField("dimension") else None, + ) + + +def to_proto_collection(collection: Collection) -> proto.Collection: + return proto.Collection( + id=collection["id"].hex, + name=collection["name"], + topic=collection["topic"], + metadata=None + if collection["metadata"] is None + else to_proto_update_metadata(collection["metadata"]), + dimension=collection["dimension"], + ) + + def to_proto_operation(operation: Operation) -> proto.Operation: if operation == Operation.ADD: return proto.Operation.ADD @@ -190,17 +238,12 @@ def to_proto_submit( metadata = None if submit_record["metadata"] is not None: - metadata = { - k: to_proto_metadata_update_value(v) - for k, v in submit_record["metadata"].items() - } + metadata = to_proto_update_metadata(submit_record["metadata"]) return proto.SubmitEmbeddingRecord( id=submit_record["id"], vector=vector, - metadata=proto.UpdateMetadata(metadata=metadata) - if metadata is not None - else None, + metadata=metadata, operation=to_proto_operation(submit_record["operation"]), ) diff --git a/chromadb/proto/coordinator_pb2.py b/chromadb/proto/coordinator_pb2.py index 91b350485019..29c7dd5dc518 100644 --- a/chromadb/proto/coordinator_pb2.py +++ b/chromadb/proto/coordinator_pb2.py @@ -6,38 +6,44 @@ from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() from chromadb.proto import chroma_pb2 as chromadb_dot_proto_dot_chroma__pb2 +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto"&\n\x06Status\x12\x0e\n\x06reason\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\x05"o\n\x17\x43reateCollectionRequest\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x1a\n\rget_or_create\x18\x02 \x01(\x08H\x00\x88\x01\x01\x42\x10\n\x0e_get_or_create"b\n\nCollection\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x42\x0b\n\t_metadata"b\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status"i\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status2\xb3\x01\n\x05SysDB\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse"\x00\x62\x06proto3' -) +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 \x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"o\n\x17\x43reateCollectionRequest\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x1a\n\rget_or_create\x18\x02 \x01(\x08H\x00\x88\x01\x01\x42\x10\n\x0e_get_or_create\"b\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"%\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"i\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 \x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\xb6\x05\n\x05SysDB\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages( - DESCRIPTOR, "chromadb.proto.coordinator_pb2", _globals -) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chromadb.proto.coordinator_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals["_STATUS"]._serialized_start = 73 - _globals["_STATUS"]._serialized_end = 111 - _globals["_CREATECOLLECTIONREQUEST"]._serialized_start = 113 - _globals["_CREATECOLLECTIONREQUEST"]._serialized_end = 224 - _globals["_COLLECTION"]._serialized_start = 226 - _globals["_COLLECTION"]._serialized_end = 324 - _globals["_CREATECOLLECTIONRESPONSE"]._serialized_start = 326 - _globals["_CREATECOLLECTIONRESPONSE"]._serialized_end = 424 - _globals["_GETCOLLECTIONSREQUEST"]._serialized_start = 426 - _globals["_GETCOLLECTIONSREQUEST"]._serialized_end = 531 - _globals["_GETCOLLECTIONSRESPONSE"]._serialized_start = 533 - _globals["_GETCOLLECTIONSRESPONSE"]._serialized_end = 630 - _globals["_SYSDB"]._serialized_start = 633 - _globals["_SYSDB"]._serialized_end = 812 + DESCRIPTOR._options = None + _globals['_CREATESEGMENTREQUEST']._serialized_start=102 + _globals['_CREATESEGMENTREQUEST']._serialized_end=158 + _globals['_DELETESEGMENTREQUEST']._serialized_start=160 + _globals['_DELETESEGMENTREQUEST']._serialized_end=194 + _globals['_GETSEGMENTSREQUEST']._serialized_start=197 + _globals['_GETSEGMENTSREQUEST']._serialized_end=391 + _globals['_GETSEGMENTSRESPONSE']._serialized_start=393 + _globals['_GETSEGMENTSRESPONSE']._serialized_end=481 + _globals['_UPDATESEGMENTREQUEST']._serialized_start=484 + _globals['_UPDATESEGMENTREQUEST']._serialized_end=734 + _globals['_CREATECOLLECTIONREQUEST']._serialized_start=736 + _globals['_CREATECOLLECTIONREQUEST']._serialized_end=847 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=849 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=947 + _globals['_DELETECOLLECTIONREQUEST']._serialized_start=949 + _globals['_DELETECOLLECTIONREQUEST']._serialized_end=986 + _globals['_GETCOLLECTIONSREQUEST']._serialized_start=988 + _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1093 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1095 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1192 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1195 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1417 + _globals['_SYSDB']._serialized_start=1420 + _globals['_SYSDB']._serialized_end=2114 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/coordinator_pb2.pyi b/chromadb/proto/coordinator_pb2.pyi index 69a76b675373..37736ff720e5 100644 --- a/chromadb/proto/coordinator_pb2.pyi +++ b/chromadb/proto/coordinator_pb2.pyi @@ -1,65 +1,85 @@ from chromadb.proto import chroma_pb2 as _chroma_pb2 +from google.protobuf import empty_pb2 as _empty_pb2 from google.protobuf.internal import containers as _containers from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message -from typing import ( - ClassVar as _ClassVar, - Iterable as _Iterable, - Mapping as _Mapping, - Optional as _Optional, - Union as _Union, -) +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor -class Status(_message.Message): - __slots__ = ["reason", "code"] - REASON_FIELD_NUMBER: _ClassVar[int] - CODE_FIELD_NUMBER: _ClassVar[int] - reason: str - code: int - def __init__( - self, reason: _Optional[str] = ..., code: _Optional[int] = ... - ) -> None: ... +class CreateSegmentRequest(_message.Message): + __slots__ = ["segment"] + SEGMENT_FIELD_NUMBER: _ClassVar[int] + segment: _chroma_pb2.Segment + def __init__(self, segment: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ...) -> None: ... -class CreateCollectionRequest(_message.Message): - __slots__ = ["collection", "get_or_create"] +class DeleteSegmentRequest(_message.Message): + __slots__ = ["id"] + ID_FIELD_NUMBER: _ClassVar[int] + id: str + def __init__(self, id: _Optional[str] = ...) -> None: ... + +class GetSegmentsRequest(_message.Message): + __slots__ = ["id", "type", "scope", "topic", "collection"] + ID_FIELD_NUMBER: _ClassVar[int] + TYPE_FIELD_NUMBER: _ClassVar[int] + SCOPE_FIELD_NUMBER: _ClassVar[int] + TOPIC_FIELD_NUMBER: _ClassVar[int] COLLECTION_FIELD_NUMBER: _ClassVar[int] - GET_OR_CREATE_FIELD_NUMBER: _ClassVar[int] - collection: Collection - get_or_create: bool - def __init__( - self, - collection: _Optional[_Union[Collection, _Mapping]] = ..., - get_or_create: bool = ..., - ) -> None: ... + id: str + type: str + scope: _chroma_pb2.SegmentScope + topic: str + collection: str + def __init__(self, id: _Optional[str] = ..., type: _Optional[str] = ..., scope: _Optional[_Union[_chroma_pb2.SegmentScope, str]] = ..., topic: _Optional[str] = ..., collection: _Optional[str] = ...) -> None: ... -class Collection(_message.Message): - __slots__ = ["id", "name", "metadata"] +class GetSegmentsResponse(_message.Message): + __slots__ = ["segments", "status"] + SEGMENTS_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + segments: _containers.RepeatedCompositeFieldContainer[_chroma_pb2.Segment] + status: _chroma_pb2.Status + def __init__(self, segments: _Optional[_Iterable[_Union[_chroma_pb2.Segment, _Mapping]]] = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... + +class UpdateSegmentRequest(_message.Message): + __slots__ = ["id", "topic", "reset_topic", "collection", "reset_collection", "metadata", "reset_metadata"] ID_FIELD_NUMBER: _ClassVar[int] - NAME_FIELD_NUMBER: _ClassVar[int] + TOPIC_FIELD_NUMBER: _ClassVar[int] + RESET_TOPIC_FIELD_NUMBER: _ClassVar[int] + COLLECTION_FIELD_NUMBER: _ClassVar[int] + RESET_COLLECTION_FIELD_NUMBER: _ClassVar[int] METADATA_FIELD_NUMBER: _ClassVar[int] + RESET_METADATA_FIELD_NUMBER: _ClassVar[int] id: str - name: str + topic: str + reset_topic: bool + collection: str + reset_collection: bool metadata: _chroma_pb2.UpdateMetadata - def __init__( - self, - id: _Optional[str] = ..., - name: _Optional[str] = ..., - metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., - ) -> None: ... + reset_metadata: bool + def __init__(self, id: _Optional[str] = ..., topic: _Optional[str] = ..., reset_topic: bool = ..., collection: _Optional[str] = ..., reset_collection: bool = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., reset_metadata: bool = ...) -> None: ... + +class CreateCollectionRequest(_message.Message): + __slots__ = ["collection", "get_or_create"] + COLLECTION_FIELD_NUMBER: _ClassVar[int] + GET_OR_CREATE_FIELD_NUMBER: _ClassVar[int] + collection: _chroma_pb2.Collection + get_or_create: bool + def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., get_or_create: bool = ...) -> None: ... class CreateCollectionResponse(_message.Message): __slots__ = ["collection", "status"] COLLECTION_FIELD_NUMBER: _ClassVar[int] STATUS_FIELD_NUMBER: _ClassVar[int] - collection: Collection - status: Status - def __init__( - self, - collection: _Optional[_Union[Collection, _Mapping]] = ..., - status: _Optional[_Union[Status, _Mapping]] = ..., - ) -> None: ... + collection: _chroma_pb2.Collection + status: _chroma_pb2.Status + def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... + +class DeleteCollectionRequest(_message.Message): + __slots__ = ["id"] + ID_FIELD_NUMBER: _ClassVar[int] + id: str + def __init__(self, id: _Optional[str] = ...) -> None: ... class GetCollectionsRequest(_message.Message): __slots__ = ["id", "name", "topic"] @@ -69,21 +89,28 @@ class GetCollectionsRequest(_message.Message): id: str name: str topic: str - def __init__( - self, - id: _Optional[str] = ..., - name: _Optional[str] = ..., - topic: _Optional[str] = ..., - ) -> None: ... + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., topic: _Optional[str] = ...) -> None: ... class GetCollectionsResponse(_message.Message): __slots__ = ["collections", "status"] COLLECTIONS_FIELD_NUMBER: _ClassVar[int] STATUS_FIELD_NUMBER: _ClassVar[int] - collections: _containers.RepeatedCompositeFieldContainer[Collection] - status: Status - def __init__( - self, - collections: _Optional[_Iterable[_Union[Collection, _Mapping]]] = ..., - status: _Optional[_Union[Status, _Mapping]] = ..., - ) -> None: ... + collections: _containers.RepeatedCompositeFieldContainer[_chroma_pb2.Collection] + status: _chroma_pb2.Status + def __init__(self, collections: _Optional[_Iterable[_Union[_chroma_pb2.Collection, _Mapping]]] = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... + +class UpdateCollectionRequest(_message.Message): + __slots__ = ["id", "topic", "name", "dimension", "metadata", "reset_metadata"] + ID_FIELD_NUMBER: _ClassVar[int] + TOPIC_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + DIMENSION_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + RESET_METADATA_FIELD_NUMBER: _ClassVar[int] + id: str + topic: str + name: str + dimension: int + metadata: _chroma_pb2.UpdateMetadata + reset_metadata: bool + def __init__(self, id: _Optional[str] = ..., topic: _Optional[str] = ..., name: _Optional[str] = ..., dimension: _Optional[int] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., reset_metadata: bool = ...) -> None: ... diff --git a/chromadb/proto/coordinator_pb2_grpc.py b/chromadb/proto/coordinator_pb2_grpc.py index 3da8aaba7928..a3a1e03227b3 100644 --- a/chromadb/proto/coordinator_pb2_grpc.py +++ b/chromadb/proto/coordinator_pb2_grpc.py @@ -2,7 +2,9 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc +from chromadb.proto import chroma_pb2 as chromadb_dot_proto_dot_chroma__pb2 from chromadb.proto import coordinator_pb2 as chromadb_dot_proto_dot_coordinator__pb2 +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 class SysDBStub(object): @@ -14,46 +16,158 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ + self.CreateSegment = channel.unary_unary( + "/chroma.SysDB/CreateSegment", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) + self.DeleteSegment = channel.unary_unary( + "/chroma.SysDB/DeleteSegment", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteSegmentRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) + self.GetSegments = channel.unary_unary( + "/chroma.SysDB/GetSegments", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsResponse.FromString, + ) + self.UpdateSegment = channel.unary_unary( + "/chroma.SysDB/UpdateSegment", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateSegmentRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) self.CreateCollection = channel.unary_unary( "/chroma.SysDB/CreateCollection", request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionRequest.SerializeToString, response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionResponse.FromString, ) + self.DeleteCollection = channel.unary_unary( + "/chroma.SysDB/DeleteCollection", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteCollectionRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) self.GetCollections = channel.unary_unary( "/chroma.SysDB/GetCollections", request_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsRequest.SerializeToString, response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsResponse.FromString, ) + self.UpdateCollection = channel.unary_unary( + "/chroma.SysDB/UpdateCollection", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateCollectionRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) + self.ResetState = channel.unary_unary( + "/chroma.SysDB/ResetState", + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) class SysDBServicer(object): """Missing associated documentation comment in .proto file.""" + def CreateSegment(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def DeleteSegment(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def GetSegments(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def UpdateSegment(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def CreateCollection(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def DeleteCollection(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def GetCollections(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def UpdateCollection(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def ResetState(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def add_SysDBServicer_to_server(servicer, server): rpc_method_handlers = { + "CreateSegment": grpc.unary_unary_rpc_method_handler( + servicer.CreateSegment, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), + "DeleteSegment": grpc.unary_unary_rpc_method_handler( + servicer.DeleteSegment, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteSegmentRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), + "GetSegments": grpc.unary_unary_rpc_method_handler( + servicer.GetSegments, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsRequest.FromString, + response_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsResponse.SerializeToString, + ), + "UpdateSegment": grpc.unary_unary_rpc_method_handler( + servicer.UpdateSegment, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateSegmentRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), "CreateCollection": grpc.unary_unary_rpc_method_handler( servicer.CreateCollection, request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionRequest.FromString, response_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionResponse.SerializeToString, ), + "DeleteCollection": grpc.unary_unary_rpc_method_handler( + servicer.DeleteCollection, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteCollectionRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), "GetCollections": grpc.unary_unary_rpc_method_handler( servicer.GetCollections, request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsRequest.FromString, response_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsResponse.SerializeToString, ), + "UpdateCollection": grpc.unary_unary_rpc_method_handler( + servicer.UpdateCollection, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateCollectionRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), + "ResetState": grpc.unary_unary_rpc_method_handler( + servicer.ResetState, + request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( "chroma.SysDB", rpc_method_handlers @@ -65,6 +179,122 @@ def add_SysDBServicer_to_server(servicer, server): class SysDB(object): """Missing associated documentation comment in .proto file.""" + @staticmethod + def CreateSegment( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/CreateSegment", + chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def DeleteSegment( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/DeleteSegment", + chromadb_dot_proto_dot_coordinator__pb2.DeleteSegmentRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def GetSegments( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/GetSegments", + chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsRequest.SerializeToString, + chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def UpdateSegment( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/UpdateSegment", + chromadb_dot_proto_dot_coordinator__pb2.UpdateSegmentRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + @staticmethod def CreateCollection( request, @@ -94,6 +324,35 @@ def CreateCollection( metadata, ) + @staticmethod + def DeleteCollection( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/DeleteCollection", + chromadb_dot_proto_dot_coordinator__pb2.DeleteCollectionRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + @staticmethod def GetCollections( request, @@ -122,3 +381,61 @@ def GetCollections( timeout, metadata, ) + + @staticmethod + def UpdateCollection( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/UpdateCollection", + chromadb_dot_proto_dot_coordinator__pb2.UpdateCollectionRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def ResetState( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/ResetState", + google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index 82c9b6a8f611..be4ae299358a 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -3,6 +3,8 @@ import tempfile import pytest from typing import Generator, List, Callable, Dict, Union +from chromadb.db.impl.grpc.client import GrpcSysDB +from chromadb.db.impl.grpc.server import GrpcMockSysDB from chromadb.types import Collection, Segment, SegmentScope from chromadb.db.impl.sqlite import SqliteDB from chromadb.config import System, Settings @@ -35,8 +37,19 @@ def sqlite_persistent() -> Generator[SysDB, None, None]: shutil.rmtree(save_path) +def grpc_with_mock_server() -> Generator[SysDB, None, None]: + """Fixture generator for sqlite DB that creates a mock grpc sysdb server + and a grpc client that connects to it.""" + system = System(Settings(allow_reset=True)) + system.instance(GrpcMockSysDB) + client = system.instance(GrpcSysDB) + system.start() + client.reset_and_wait_for_ready() + yield client + + def db_fixtures() -> List[Callable[[], Generator[SysDB, None, None]]]: - return [sqlite, sqlite_persistent] + return [sqlite, sqlite_persistent, grpc_with_mock_server] @pytest.fixture(scope="module", params=db_fixtures()) diff --git a/idl/chromadb/proto/chroma.proto b/idl/chromadb/proto/chroma.proto index ddc7f11bc269..7b1f10f18ea6 100644 --- a/idl/chromadb/proto/chroma.proto +++ b/idl/chromadb/proto/chroma.proto @@ -2,6 +2,17 @@ syntax = "proto3"; package chroma; +message Status { + string reason = 1; + int32 code = 2; // TODO: What is the enum of this code? +} + +message ChromaResponse { + Status status = 1; +} + +// Types here should mirror chromadb/types.py + enum Operation { ADD = 0; UPDATE = 1; @@ -36,6 +47,14 @@ message Segment { optional UpdateMetadata metadata = 6; } +message Collection { + string id = 1; + string name = 2; + string topic = 3; + optional UpdateMetadata metadata = 4; + optional int32 dimension = 5; +} + message UpdateMetadataValue { oneof value { string string_value = 1; diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index 1148c197328b..d4b95058451f 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -3,21 +3,49 @@ syntax = "proto3"; package chroma; import "chromadb/proto/chroma.proto"; +import "google/protobuf/empty.proto"; -message Status { - string reason = 1; - int32 code = 2; +message CreateSegmentRequest { + Segment segment = 1; } -message CreateCollectionRequest { - Collection collection = 1; - optional bool get_or_create = 2; +message DeleteSegmentRequest { + string id = 1; } -message Collection { +message GetSegmentsRequest { + optional string id = 1; + optional string type = 2; + optional SegmentScope scope = 3; + optional string topic = 4; + optional string collection = 5; +} + +message GetSegmentsResponse { + repeated Segment segments = 1; + Status status = 2; +} + + +message UpdateSegmentRequest { string id = 1; - string name = 2; - optional UpdateMetadata metadata = 3; + oneof topic_update { + string topic = 2; + bool reset_topic = 3; + } + oneof collection_update { + string collection = 4; + bool reset_collection = 5; + } + oneof metadata_update { + UpdateMetadata metadata = 6; + bool reset_metadata = 7; + } +} + +message CreateCollectionRequest { + Collection collection = 1; + optional bool get_or_create = 2; } message CreateCollectionResponse { @@ -25,6 +53,10 @@ message CreateCollectionResponse { Status status = 2; } +message DeleteCollectionRequest { + string id = 1; +} + message GetCollectionsRequest { optional string id = 1; optional string name = 2; @@ -36,7 +68,25 @@ message GetCollectionsResponse { Status status = 2; } +message UpdateCollectionRequest { + string id = 1; + optional string topic = 2; + optional string name = 3; + optional int32 dimension = 4; + oneof metadata_update { + UpdateMetadata metadata = 5; + bool reset_metadata = 6; + } +} + service SysDB { + rpc CreateSegment(CreateSegmentRequest) returns (ChromaResponse) {} + rpc DeleteSegment(DeleteSegmentRequest) returns (ChromaResponse) {} + rpc GetSegments(GetSegmentsRequest) returns (GetSegmentsResponse) {} + rpc UpdateSegment(UpdateSegmentRequest) returns (ChromaResponse) {} rpc CreateCollection(CreateCollectionRequest) returns (CreateCollectionResponse) {} + rpc DeleteCollection(DeleteCollectionRequest) returns (ChromaResponse) {} rpc GetCollections(GetCollectionsRequest) returns (GetCollectionsResponse) {} + rpc UpdateCollection(UpdateCollectionRequest) returns (ChromaResponse) {} + rpc ResetState(google.protobuf.Empty) returns (ChromaResponse) {} }