From da0858d63f7bef8e99953d7b4e467879be49afe2 Mon Sep 17 00:00:00 2001 From: Macronova <60079945+Sicheng-Pan@users.noreply.github.com> Date: Tue, 17 Dec 2024 14:54:34 -0800 Subject: [PATCH] [CLN] Cleanup frontend after query pushdown (#3291) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Remove `Grpc*Segment` implementation in frontend, since they are no longer used after query pushdown - Remove `test_version_mismatch.py` and corresponding exception handling in distributed executor, since we no longer have version mismatch after passing the full collection and segment information form frontend - Add a few tests in `test_protobuf_translation.py` to test translation for collection and segment information - New functionality - N/A ## Test plan *How are these changes tested?* - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* N/A --- chromadb/execution/executor/distributed.py | 15 - .../segment/impl/metadata/grpc_segment.py | 394 ------------------ chromadb/segment/impl/vector/grpc_segment.py | 145 ------- .../test/distributed/test_version_mismatch.py | 219 ---------- .../distributed/test_protobuf_translation.py | 230 +++++----- 5 files changed, 109 insertions(+), 894 deletions(-) delete mode 100644 chromadb/segment/impl/metadata/grpc_segment.py delete mode 100644 chromadb/segment/impl/vector/grpc_segment.py delete mode 100644 chromadb/test/distributed/test_version_mismatch.py diff --git a/chromadb/execution/executor/distributed.py b/chromadb/execution/executor/distributed.py index c41f52a5816..3cf5c591c77 100644 --- a/chromadb/execution/executor/distributed.py +++ b/chromadb/execution/executor/distributed.py @@ -57,11 +57,6 @@ def count(self, plan: CountPlan) -> int: try: count_result = executor.Count(convert.to_proto_count_plan(plan)) except grpc.RpcError as rpc_error: - if ( - rpc_error.code() == grpc.StatusCode.INTERNAL - and "version mismatch" in rpc_error.details() - ): - raise VersionMismatchError() raise rpc_error return convert.from_proto_count_result(count_result) @@ -71,11 +66,6 @@ def get(self, plan: GetPlan) -> GetResult: try: get_result = executor.Get(convert.to_proto_get_plan(plan)) except grpc.RpcError as rpc_error: - if ( - rpc_error.code() == grpc.StatusCode.INTERNAL - and "version mismatch" in rpc_error.details() - ): - raise VersionMismatchError() raise rpc_error records = convert.from_proto_get_result(get_result) @@ -118,11 +108,6 @@ def knn(self, plan: KNNPlan) -> QueryResult: try: knn_result = executor.KNN(convert.to_proto_knn_plan(plan)) except grpc.RpcError as rpc_error: - if ( - rpc_error.code() == grpc.StatusCode.INTERNAL - and "version mismatch" in rpc_error.details() - ): - raise VersionMismatchError() raise rpc_error results = convert.from_proto_knn_batch_result(knn_result) diff --git a/chromadb/segment/impl/metadata/grpc_segment.py b/chromadb/segment/impl/metadata/grpc_segment.py deleted file mode 100644 index 53ffdc72734..00000000000 --- a/chromadb/segment/impl/metadata/grpc_segment.py +++ /dev/null @@ -1,394 +0,0 @@ -from typing import Dict, List, Optional, Sequence -from chromadb.proto.convert import to_proto_request_version_context -from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor -from chromadb.segment import MetadataReader -from chromadb.config import System -from chromadb.errors import InvalidArgumentError, VersionMismatchError -from chromadb.types import Segment, RequestVersionContext -from overrides import override -from chromadb.telemetry.opentelemetry import ( - OpenTelemetryGranularity, - trace_method, -) -from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor -from chromadb.types import ( - Where, - WhereDocument, - MetadataEmbeddingRecord, -) -from chromadb.proto.chroma_pb2_grpc import MetadataReaderStub -import chromadb.proto.chroma_pb2 as pb -import grpc - - -class GrpcMetadataSegment(MetadataReader): - """Embedding Metadata segment interface""" - - _request_timeout_seconds: int - _metadata_reader_stub: MetadataReaderStub - _segment: Segment - - def __init__(self, system: System, segment: Segment) -> None: - super().__init__(system, segment) # type: ignore[safe-super] - if not segment["metadata"] or not segment["metadata"]["grpc_url"]: - raise Exception("Missing grpc_url in segment metadata") - - self._segment = segment - self._request_timeout_seconds = system.settings.require( - "chroma_query_request_timeout_seconds" - ) - - @override - def start(self) -> None: - if not self._segment["metadata"] or not self._segment["metadata"]["grpc_url"]: - raise Exception("Missing grpc_url in segment metadata") - - channel = grpc.insecure_channel(self._segment["metadata"]["grpc_url"]) - interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()] - channel = grpc.intercept_channel(channel, *interceptors) - self._metadata_reader_stub = MetadataReaderStub(channel) # type: ignore - - @override - def count(self, request_version_context: RequestVersionContext) -> int: - request: pb.CountRecordsRequest = pb.CountRecordsRequest( - segment_id=self._segment["id"].hex, - collection_id=self._segment["collection"].hex, - version_context=to_proto_request_version_context(request_version_context), - ) - - try: - response: pb.CountRecordsResponse = self._metadata_reader_stub.CountRecords( - request, - timeout=self._request_timeout_seconds, - ) - except grpc.RpcError as rpc_error: - message = rpc_error.details() - if "Collection version mismatch" in message: - raise VersionMismatchError() - raise rpc_error - - return response.count - - @override - def delete(self, where: Optional[Where] = None) -> None: - raise NotImplementedError() - - @override - def max_seqid(self) -> int: - raise NotImplementedError() - - @trace_method( - "GrpcMetadataSegment.get_metadata", - OpenTelemetryGranularity.ALL, - ) - @override - def get_metadata( - self, - request_version_context: RequestVersionContext, - where: Optional[Where] = None, - where_document: Optional[WhereDocument] = None, - ids: Optional[Sequence[str]] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - include_metadata: bool = True, - ) -> Sequence[MetadataEmbeddingRecord]: - """Query for embedding metadata.""" - - if limit is not None and limit < 0: - raise InvalidArgumentError(f"Limit cannot be negative: {limit}") - - if offset is not None and offset < 0: - raise InvalidArgumentError(f"Offset cannot be negative: {offset}") - - request: pb.QueryMetadataRequest = pb.QueryMetadataRequest( - segment_id=self._segment["id"].hex, - collection_id=self._segment["collection"].hex, - where=self._where_to_proto(where) - if where is not None and len(where) > 0 - else None, - where_document=( - self._where_document_to_proto(where_document) - if where_document is not None and len(where_document) > 0 - else None - ), - ids=pb.UserIds(ids=ids) if ids is not None else None, - limit=limit, - offset=offset, - include_metadata=include_metadata, - version_context=to_proto_request_version_context(request_version_context), - ) - - try: - response: pb.QueryMetadataResponse = ( - self._metadata_reader_stub.QueryMetadata( - request, - timeout=self._request_timeout_seconds, - ) - ) - except grpc.RpcError as rpc_error: - message = rpc_error.details() - if "Collection version mismatch" in message: - raise VersionMismatchError() - raise rpc_error - - results: List[MetadataEmbeddingRecord] = [] - for record in response.records: - result = self._from_proto(record) - results.append(result) - - return results - - def _where_to_proto(self, where: Optional[Where]) -> pb.Where: - response = pb.Where() - if where is None: - return response - if len(where) != 1: - raise ValueError( - f"Expected where to have exactly one operator, got {where}" - ) - - for key, value in where.items(): - if not isinstance(key, str): - raise ValueError(f"Expected where key to be a str, got {key}") - - if key == "$and" or key == "$or": - if not isinstance(value, list): - raise ValueError( - f"Expected where value for $and or $or to be a list of where expressions, got {value}" - ) - children: pb.WhereChildren = pb.WhereChildren( - children=[self._where_to_proto(w) for w in value] - ) - if key == "$and": - children.operator = pb.BooleanOperator.AND - else: - children.operator = pb.BooleanOperator.OR - - response.children.CopyFrom(children) - return response - - # At this point we know we're at a direct comparison. It can either - # be of the form {"key": "value"} or {"key": {"$operator": "value"}}. - - dc = pb.DirectComparison() - dc.key = key - - if not isinstance(value, dict): - # {'key': 'value'} case - if type(value) is str: - ssc = pb.SingleStringComparison() - ssc.value = value - ssc.comparator = pb.GenericComparator.EQ - dc.single_string_operand.CopyFrom(ssc) - elif type(value) is bool: - sbc = pb.SingleBoolComparison() - sbc.value = value - sbc.comparator = pb.GenericComparator.EQ - dc.single_bool_operand.CopyFrom(sbc) - elif type(value) is int: - sic = pb.SingleIntComparison() - sic.value = value - sic.generic_comparator = pb.GenericComparator.EQ - dc.single_int_operand.CopyFrom(sic) - elif type(value) is float: - sdc = pb.SingleDoubleComparison() - sdc.value = value - sdc.generic_comparator = pb.GenericComparator.EQ - dc.single_double_operand.CopyFrom(sdc) - else: - raise ValueError( - f"Expected where value to be a string, int, or float, got {value}" - ) - else: - for operator, operand in value.items(): - if operator in ["$in", "$nin"]: - if not isinstance(operand, list): - raise ValueError( - f"Expected where value for $in or $nin to be a list of values, got {value}" - ) - if len(operand) == 0 or not all( - isinstance(x, type(operand[0])) for x in operand - ): - raise ValueError( - f"Expected where operand value to be a non-empty list, and all values to be of the same type " - f"got {operand}" - ) - list_operator = None - if operator == "$in": - list_operator = pb.ListOperator.IN - else: - list_operator = pb.ListOperator.NIN - if type(operand[0]) is str: - slo = pb.StringListComparison() - for x in operand: - slo.values.extend([x]) # type: ignore - slo.list_operator = list_operator - dc.string_list_operand.CopyFrom(slo) - elif type(operand[0]) is bool: - blo = pb.BoolListComparison() - for x in operand: - blo.values.extend([x]) # type: ignore - blo.list_operator = list_operator - dc.bool_list_operand.CopyFrom(blo) - elif type(operand[0]) is int: - ilo = pb.IntListComparison() - for x in operand: - ilo.values.extend([x]) # type: ignore - ilo.list_operator = list_operator - dc.int_list_operand.CopyFrom(ilo) - elif type(operand[0]) is float: - dlo = pb.DoubleListComparison() - for x in operand: - dlo.values.extend([x]) # type: ignore - dlo.list_operator = list_operator - dc.double_list_operand.CopyFrom(dlo) - else: - raise ValueError( - f"Expected where operand value to be a list of strings, ints, or floats, got {operand}" - ) - elif operator in ["$eq", "$ne", "$gt", "$lt", "$gte", "$lte"]: - # Direct comparison to a single value. - if type(operand) is str: - ssc = pb.SingleStringComparison() - ssc.value = operand - if operator == "$eq": - ssc.comparator = pb.GenericComparator.EQ - elif operator == "$ne": - ssc.comparator = pb.GenericComparator.NE - else: - raise ValueError( - f"Expected where operator to be $eq or $ne, got {operator}" - ) - dc.single_string_operand.CopyFrom(ssc) - elif type(operand) is bool: - sbc = pb.SingleBoolComparison() - sbc.value = operand - if operator == "$eq": - sbc.comparator = pb.GenericComparator.EQ - elif operator == "$ne": - sbc.comparator = pb.GenericComparator.NE - else: - raise ValueError( - f"Expected where operator to be $eq or $ne, got {operator}" - ) - dc.single_bool_operand.CopyFrom(sbc) - elif type(operand) is int: - sic = pb.SingleIntComparison() - sic.value = operand - if operator == "$eq": - sic.generic_comparator = pb.GenericComparator.EQ - elif operator == "$ne": - sic.generic_comparator = pb.GenericComparator.NE - elif operator == "$gt": - sic.number_comparator = pb.NumberComparator.GT - elif operator == "$lt": - sic.number_comparator = pb.NumberComparator.LT - elif operator == "$gte": - sic.number_comparator = pb.NumberComparator.GTE - elif operator == "$lte": - sic.number_comparator = pb.NumberComparator.LTE - else: - raise ValueError( - f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}" - ) - dc.single_int_operand.CopyFrom(sic) - elif type(operand) is float: - sfc = pb.SingleDoubleComparison() - sfc.value = operand - if operator == "$eq": - sfc.generic_comparator = pb.GenericComparator.EQ - elif operator == "$ne": - sfc.generic_comparator = pb.GenericComparator.NE - elif operator == "$gt": - sfc.number_comparator = pb.NumberComparator.GT - elif operator == "$lt": - sfc.number_comparator = pb.NumberComparator.LT - elif operator == "$gte": - sfc.number_comparator = pb.NumberComparator.GTE - elif operator == "$lte": - sfc.number_comparator = pb.NumberComparator.LTE - else: - raise ValueError( - f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}" - ) - dc.single_double_operand.CopyFrom(sfc) - else: - raise ValueError( - f"Expected where operand value to be a string, int, or float, got {operand}" - ) - else: - # This case should never happen, as we've already - # handled the case for direct comparisons. - pass - - response.direct_comparison.CopyFrom(dc) - return response - - def _where_document_to_proto( - self, where_document: Optional[WhereDocument] - ) -> pb.WhereDocument: - response = pb.WhereDocument() - if where_document is None: - return response - if len(where_document) != 1: - raise ValueError( - f"Expected where_document to have exactly one operator, got {where_document}" - ) - - for operator, operand in where_document.items(): - if operator == "$and" or operator == "$or": - # Nested "$and" or "$or" expression. - if not isinstance(operand, list): - raise ValueError( - f"Expected where_document value for $and or $or to be a list of where_document expressions, got {operand}" - ) - children: pb.WhereDocumentChildren = pb.WhereDocumentChildren( - children=[self._where_document_to_proto(w) for w in operand] - ) - if operator == "$and": - children.operator = pb.BooleanOperator.AND - else: - children.operator = pb.BooleanOperator.OR - - response.children.CopyFrom(children) - else: - # Direct "$contains" or "$not_contains" comparison to a single - # value. - if not isinstance(operand, str): - raise ValueError( - f"Expected where_document operand to be a string, got {operand}" - ) - dwd = pb.DirectWhereDocument() - dwd.document = operand - if operator == "$contains": - dwd.operator = pb.WhereDocumentOperator.CONTAINS - elif operator == "$not_contains": - dwd.operator = pb.WhereDocumentOperator.NOT_CONTAINS - else: - raise ValueError( - f"Expected where_document operator to be one of $contains, $not_contains, got {operator}" - ) - response.direct.CopyFrom(dwd) - - return response - - def _from_proto( - self, record: pb.MetadataEmbeddingRecord - ) -> MetadataEmbeddingRecord: - translated_metadata: Dict[str, str | int | float | bool] = {} - record_metadata_map = record.metadata.metadata - for key, value in record_metadata_map.items(): - if value.HasField("bool_value"): - translated_metadata[key] = value.bool_value - elif value.HasField("string_value"): - translated_metadata[key] = value.string_value - elif value.HasField("int_value"): - translated_metadata[key] = value.int_value - elif value.HasField("float_value"): - translated_metadata[key] = value.float_value - else: - raise ValueError(f"Unknown metadata value type: {value}") - - mer = MetadataEmbeddingRecord(id=record.id, metadata=translated_metadata) - - return mer diff --git a/chromadb/segment/impl/vector/grpc_segment.py b/chromadb/segment/impl/vector/grpc_segment.py deleted file mode 100644 index a66de4a71cd..00000000000 --- a/chromadb/segment/impl/vector/grpc_segment.py +++ /dev/null @@ -1,145 +0,0 @@ -from overrides import EnforceOverrides, override -from typing import List, Optional, Sequence -from chromadb.config import System -from chromadb.proto.convert import ( - from_proto_vector_embedding_record, - from_proto_vector_query_result, - to_proto_request_version_context, - to_proto_vector, -) -from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor -from chromadb.segment import VectorReader -from chromadb.errors import VersionMismatchError -from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams -from chromadb.telemetry.opentelemetry import ( - OpenTelemetryGranularity, - trace_method, -) -from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor -from chromadb.types import ( - Metadata, - RequestVersionContext, - ScalarEncoding, - Segment, - VectorEmbeddingRecord, - VectorQuery, - VectorQueryResult, -) -from chromadb.proto.chroma_pb2_grpc import VectorReaderStub -from chromadb.proto.chroma_pb2 import ( - GetVectorsRequest, - GetVectorsResponse, - QueryVectorsRequest, - QueryVectorsResponse, -) -import grpc - - -class GrpcVectorSegment(VectorReader, EnforceOverrides): - _vector_reader_stub: VectorReaderStub - _segment: Segment - _request_timeout_seconds: int - - def __init__(self, system: System, segment: Segment): - # TODO: move to start() method - # TODO: close channel in stop() method - if segment["metadata"] is None or segment["metadata"]["grpc_url"] is None: - raise Exception("Missing grpc_url in segment metadata") - - channel = grpc.insecure_channel(segment["metadata"]["grpc_url"]) - interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()] - channel = grpc.intercept_channel(channel, *interceptors) - self._vector_reader_stub = VectorReaderStub(channel) # type: ignore - self._segment = segment - self._request_timeout_seconds = system.settings.require( - "chroma_query_request_timeout_seconds" - ) - - @trace_method("GrpcVectorSegment.get_vectors", OpenTelemetryGranularity.ALL) - @override - def get_vectors( - self, - request_version_context: RequestVersionContext, - ids: Optional[Sequence[str]] = None, - ) -> Sequence[VectorEmbeddingRecord]: - request = GetVectorsRequest( - ids=ids, - segment_id=self._segment["id"].hex, - collection_id=self._segment["collection"].hex, - version_context=to_proto_request_version_context(request_version_context), - ) - - try: - response: GetVectorsResponse = self._vector_reader_stub.GetVectors( - request, - timeout=self._request_timeout_seconds, - ) - except grpc.RpcError as rpc_error: - message = rpc_error.details() - if "Collection version mismatch" in message: - raise VersionMismatchError() - raise rpc_error - - results: List[VectorEmbeddingRecord] = [] - for vector in response.records: - result = from_proto_vector_embedding_record(vector) - results.append(result) - return results - - @trace_method("GrpcVectorSegment.query_vectors", OpenTelemetryGranularity.ALL) - @override - def query_vectors( - self, query: VectorQuery - ) -> Sequence[Sequence[VectorQueryResult]]: - request = QueryVectorsRequest( - vectors=[ - to_proto_vector(vector=v, encoding=ScalarEncoding.FLOAT32) - for v in query["vectors"] - ], - k=query["k"], - allowed_ids=query["allowed_ids"], - include_embeddings=query["include_embeddings"], - segment_id=self._segment["id"].hex, - collection_id=self._segment["collection"].hex, - version_context=to_proto_request_version_context( - query["request_version_context"] - ), - ) - - try: - response: QueryVectorsResponse = self._vector_reader_stub.QueryVectors( - request, - timeout=self._request_timeout_seconds, - ) - except grpc.RpcError as rpc_error: - message = rpc_error.details() - if "Collection version mismatch" in message: - raise VersionMismatchError() - raise rpc_error - - results: List[List[VectorQueryResult]] = [] - for result in response.results: - curr_result: List[VectorQueryResult] = [] - for r in result.results: - curr_result.append(from_proto_vector_query_result(r)) - results.append(curr_result) - return results - - @override - def count(self, request_version_context: RequestVersionContext) -> int: - raise NotImplementedError() - - @override - def max_seqid(self) -> int: - return 0 - - @staticmethod - @override - def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]: - # Great example of why language sharing is nice. - segment_metadata = PersistentHnswParams.extract(metadata) - return segment_metadata - - @override - def delete(self) -> None: - raise NotImplementedError() diff --git a/chromadb/test/distributed/test_version_mismatch.py b/chromadb/test/distributed/test_version_mismatch.py deleted file mode 100644 index 817abf5c0e4..00000000000 --- a/chromadb/test/distributed/test_version_mismatch.py +++ /dev/null @@ -1,219 +0,0 @@ -import random -from typing import List, Tuple -import uuid -from chromadb.api.models.Collection import Collection -from chromadb.config import Settings, System -from chromadb.db.impl.grpc.client import GrpcSysDB -from chromadb.db.system import SysDB -from chromadb.errors import VersionMismatchError -from chromadb.segment import MetadataReader, VectorReader -from chromadb.segment.impl.metadata.grpc_segment import GrpcMetadataSegment -from chromadb.segment.impl.vector.grpc_segment import GrpcVectorSegment -from chromadb.test.conftest import reset, skip_if_not_cluster -from chromadb.api import ClientAPI -from chromadb.test.utils.wait_for_version_increase import wait_for_version_increase -from chromadb.types import RequestVersionContext, SegmentScope, VectorQuery - - -# Helpers -def create_test_collection(client: ClientAPI, name: str) -> Collection: - return client.create_collection( - name=name, - metadata={"hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128}, - ) - - -def add_random_records_and_wait_for_compaction( - client: ClientAPI, collection: Collection, n: int -) -> Tuple[List[str], List[List[float]], int]: - ids = [] - embeddings = [] - for i in range(n): - ids.append(str(i)) - embeddings.append([random.random(), random.random(), random.random()]) - collection.add( - ids=[str(i)], - embeddings=[embeddings[-1]], # type: ignore - ) - final_version = wait_for_version_increase( - client=client, collection_name=collection.name, initial_version=0 - ) - return ids, embeddings, final_version - - -def get_mock_frontend_system() -> System: - settings = Settings( - chroma_coordinator_host="localhost", chroma_server_grpc_port=50051 - ) - return System(settings) - - -def get_vector_segment( - system: System, sysdb: SysDB, collection: uuid.UUID -) -> GrpcVectorSegment: - segment = sysdb.get_segments(collection=collection, scope=SegmentScope.VECTOR)[0] - if segment["metadata"] is None: - segment["metadata"] = {} - # Inject the url, replicating the behavior of the segment manager, we use the tilt grpc server url - segment["metadata"]["grpc_url"] = "localhost:50053" # type: ignore - ret_segment = GrpcVectorSegment(system, segment) - ret_segment.start() - return ret_segment - - -def get_metadata_segment( - system: System, sysdb: SysDB, collection: uuid.UUID -) -> GrpcMetadataSegment: - segment = sysdb.get_segments(collection=collection, scope=SegmentScope.METADATA)[0] - if segment["metadata"] is None: - segment["metadata"] = {} - # Inject the url, replicating the behavior of the segment manager, we use the tilt grpc server url - segment["metadata"]["grpc_url"] = "localhost:50053" # type: ignore - ret_segment = GrpcMetadataSegment(system, segment) - ret_segment.start() - return ret_segment - - -def setup_vector_test( - client: ClientAPI, n: int -) -> Tuple[VectorReader, List[str], List[List[float]], int, int]: - reset(client) - collection = create_test_collection(client=client, name="test_version_mismatch") - ids, embeddings, version = add_random_records_and_wait_for_compaction( - client=client, collection=collection, n=n - ) - log_position = client.get_collection(collection.name)._model.log_position - - fe_system = get_mock_frontend_system() - sysdb = GrpcSysDB(fe_system) - sysdb.start() - - return ( - get_vector_segment(system=fe_system, sysdb=sysdb, collection=collection.id), - ids, - embeddings, - version, - log_position, - ) - - -def setup_metadata_test( - client: ClientAPI, n: int -) -> Tuple[MetadataReader, List[str], List[List[float]], int, int]: - reset(client) - collection = create_test_collection(client=client, name="test_version_mismatch") - ids, embeddings, version = add_random_records_and_wait_for_compaction( - client=client, collection=collection, n=n - ) - log_position = client.get_collection(collection.name)._model.log_position - - fe_system = get_mock_frontend_system() - sysdb = GrpcSysDB(fe_system) - sysdb.start() - - return ( - get_metadata_segment(system=fe_system, sysdb=sysdb, collection=collection.id), - ids, - embeddings, - version, - log_position, - ) - - -@skip_if_not_cluster() -def test_version_mistmatch_query_vectors( - client: ClientAPI, -) -> None: - N = 100 - reader, _, embeddings, compacted_version, log_position = setup_vector_test( - client=client, n=N - ) - request = VectorQuery( - vectors=[embeddings[0]], - request_version_context=RequestVersionContext( - collection_version=compacted_version, log_position=log_position - ), - k=10, - include_embeddings=False, - allowed_ids=None, - options=None, - ) - - reader.query_vectors(query=request) - # Now change the collection version to > N, which should cause a version mismatch - request["request_version_context"]["collection_version"] = N + 1 - try: - reader.query_vectors(request) - except VersionMismatchError: - pass - except Exception as e: - assert False, f"Unexpected exception {e}" - - -@skip_if_not_cluster() -def test_version_mistmatch_get_vectors( - client: ClientAPI, -) -> None: - N = 100 - reader, _, _, compacted_version, log_position = setup_vector_test( - client=client, n=N - ) - request_version_context = RequestVersionContext( - collection_version=compacted_version, log_position=log_position - ) - - reader.get_vectors(ids=None, request_version_context=request_version_context) - # Now change the collection version to > N, which should cause a version mismatch - request_version_context["collection_version"] = N + 1 - try: - reader.get_vectors(request_version_context) - except VersionMismatchError: - pass - except Exception as e: - assert False, f"Unexpected exception {e}" - - -@skip_if_not_cluster() -def test_version_mismatch_metadata_get( - client: ClientAPI, -) -> None: - N = 100 - reader, _, _, compacted_version, log_position = setup_metadata_test( - client=client, n=N - ) - request_version_context = RequestVersionContext( - collection_version=compacted_version, log_position=log_position - ) - - reader.get_metadata(request_version_context=request_version_context) - # Now change the collection version to > N, which should cause a version mismatch - request_version_context["collection_version"] = N + 1 - try: - reader.get_metadata(request_version_context) - except VersionMismatchError: - pass - except Exception as e: - assert False, f"Unexpected exception {e}" - - -@skip_if_not_cluster() -def test_version_mismatch_metadata_count( - client: ClientAPI, -) -> None: - N = 100 - reader, _, _, compacted_version, log_position = setup_metadata_test( - client=client, n=N - ) - request_version_context = RequestVersionContext( - collection_version=compacted_version, log_position=log_position - ) - - reader.count(request_version_context) - # Now change the collection version to > N, which should cause a version mismatch - request_version_context["collection_version"] = N + 1 - try: - reader.count(request_version_context) - except VersionMismatchError: - pass - except Exception as e: - assert False, f"Unexpected exception {e}" diff --git a/chromadb/test/segment/distributed/test_protobuf_translation.py b/chromadb/test/segment/distributed/test_protobuf_translation.py index 7cfc8b26f3f..d29fcad0365 100644 --- a/chromadb/test/segment/distributed/test_protobuf_translation.py +++ b/chromadb/test/segment/distributed/test_protobuf_translation.py @@ -1,69 +1,126 @@ import uuid - -from chromadb.config import Settings, System -from chromadb.segment.impl.metadata.grpc_segment import GrpcMetadataSegment +from chromadb.proto import convert +from chromadb.segment import SegmentType from chromadb.types import ( + Collection, + CollectionConfigurationInternal, Segment, SegmentScope, Where, WhereDocument, - MetadataEmbeddingRecord, ) import chromadb.proto.chroma_pb2 as pb +import chromadb.proto.query_executor_pb2 as query_pb +def test_collection_to_proto() -> None: + collection = Collection( + id=uuid.uuid4(), + name="test_collection", + configuration=CollectionConfigurationInternal(), + metadata={"hnsw_m": 128}, + dimension=512, + tenant="test_tenant", + database="test_database", + version=1, + log_position=42, + ) + + assert convert.to_proto_collection(collection) == pb.Collection( + id=collection.id.hex, + name="test_collection", + configuration_json_str=CollectionConfigurationInternal().to_json_str(), + metadata=pb.UpdateMetadata(metadata={"hnsw_m": pb.UpdateMetadataValue(int_value=128)}), + dimension=512, + tenant="test_tenant", + database="test_database", + version=1, + log_position=42, + ) -# Note: trying to start() this segment will cause it to error since it doesn't -# have a remote server to talk to. This is only suitable for testing the -# python <-> proto translation logic. -def unstarted_grpc_metadata_segment() -> GrpcMetadataSegment: - settings = Settings( - allow_reset=True, +def test_collection_from_proto() -> None: + proto = pb.Collection( + id=uuid.uuid4().hex, + name="test_collection", + configuration_json_str=CollectionConfigurationInternal().to_json_str(), + metadata=pb.UpdateMetadata(metadata={"hnsw_m": pb.UpdateMetadataValue(int_value=128)}), + dimension=512, + tenant="test_tenant", + database="test_database", + version=1, + log_position=42, ) - system = System(settings) + assert convert.from_proto_collection(proto) == Collection( + id=uuid.UUID(proto.id), + name="test_collection", + configuration=CollectionConfigurationInternal(), + metadata={"hnsw_m": 128}, + dimension=512, + tenant="test_tenant", + database="test_database", + version=1, + log_position=42, + ) + +def test_segment_to_proto() -> None: segment = Segment( id=uuid.uuid4(), - type="test", - scope=SegmentScope.METADATA, + type=SegmentType.HNSW_DISTRIBUTED.value, + scope=SegmentScope.VECTOR, collection=uuid.uuid4(), - metadata={ - "grpc_url": "test", - }, - file_paths={}, + metadata={"hnsw_m": 128}, + file_paths={"name": ["path_0", "path_1"]}, ) - grpc_metadata_segment = GrpcMetadataSegment( - system=system, - segment=segment, + assert convert.to_proto_segment(segment) == pb.Segment( + id=segment["id"].hex, + type=SegmentType.HNSW_DISTRIBUTED.value, + scope=pb.SegmentScope.VECTOR, + collection=segment["collection"].hex, + metadata=pb.UpdateMetadata(metadata={"hnsw_m": pb.UpdateMetadataValue(int_value=128)}), + file_paths={"name": pb.FilePaths(paths=["path_0", "path_1"])}, ) - return grpc_metadata_segment +def test_segment_from_proto() -> None: + proto = pb.Segment( + id=uuid.uuid4().hex, + type=SegmentType.HNSW_DISTRIBUTED.value, + scope=pb.SegmentScope.VECTOR, + collection=uuid.uuid4().hex, + metadata=pb.UpdateMetadata(metadata={"hnsw_m": pb.UpdateMetadataValue(int_value=128)}), + file_paths={"name": pb.FilePaths(paths=["path_0", "path_1"])}, + ) + assert convert.from_proto_segment(proto) == Segment( + id=uuid.UUID(proto.id), + type=SegmentType.HNSW_DISTRIBUTED.value, + scope=SegmentScope.VECTOR, + collection=uuid.UUID(proto.collection), + metadata={"hnsw_m": 128}, + file_paths={"name": ["path_0", "path_1"]}, + ) def test_where_document_to_proto_not_contains() -> None: - md_segment = unstarted_grpc_metadata_segment() where_document: WhereDocument = {"$not_contains": "test"} - proto = md_segment._where_document_to_proto(where_document) + proto = convert.to_proto_where_document(where_document) assert proto.HasField("direct") assert proto.direct.document == "test" assert proto.direct.operator == pb.WhereDocumentOperator.NOT_CONTAINS def test_where_document_to_proto_contains_to_proto() -> None: - md_segment = unstarted_grpc_metadata_segment() where_document: WhereDocument = {"$contains": "test"} - proto = md_segment._where_document_to_proto(where_document) + proto = convert.to_proto_where_document(where_document) assert proto.HasField("direct") assert proto.direct.document == "test" assert proto.direct.operator == pb.WhereDocumentOperator.CONTAINS def test_where_document_to_proto_and() -> None: - md_segment = unstarted_grpc_metadata_segment() where_document: WhereDocument = { "$and": [ {"$contains": "test"}, {"$not_contains": "test"}, ] } - proto = md_segment._where_document_to_proto(where_document) + proto = convert.to_proto_where_document(where_document) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.AND @@ -79,14 +136,13 @@ def test_where_document_to_proto_and() -> None: def test_where_document_to_proto_or() -> None: - md_segment = unstarted_grpc_metadata_segment() where_document: WhereDocument = { "$or": [ {"$contains": "test"}, {"$not_contains": "test"}, ] } - proto = md_segment._where_document_to_proto(where_document) + proto = convert.to_proto_where_document(where_document) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.OR @@ -102,7 +158,6 @@ def test_where_document_to_proto_or() -> None: def test_where_document_to_proto_nested_boolean_operators() -> None: - md_segment = unstarted_grpc_metadata_segment() where_document: WhereDocument = { "$and": [ { @@ -119,7 +174,7 @@ def test_where_document_to_proto_nested_boolean_operators() -> None: }, ] } - proto = md_segment._where_document_to_proto(where_document) + proto = convert.to_proto_where_document(where_document) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.AND @@ -142,11 +197,10 @@ def test_where_document_to_proto_nested_boolean_operators() -> None: def test_where_to_proto_string_value() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "test": "value", } - proto: pb.Where = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("direct_comparison") d = proto.direct_comparison assert d.key == "test" @@ -155,11 +209,10 @@ def test_where_to_proto_string_value() -> None: def test_where_to_proto_int_value() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "test": 1, } - proto = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("direct_comparison") d = proto.direct_comparison assert d.key == "test" @@ -168,11 +221,10 @@ def test_where_to_proto_int_value() -> None: def test_where_to_proto_double_value() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "test": 1.0, } - proto = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("direct_comparison") d = proto.direct_comparison assert d.key == "test" @@ -181,14 +233,13 @@ def test_where_to_proto_double_value() -> None: def test_where_to_proto_and() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "$and": [ {"test": 1}, {"test": "value"}, ] } - proto = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.AND @@ -206,14 +257,13 @@ def test_where_to_proto_and() -> None: def test_where_to_proto_or() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "$or": [ {"test": 1}, {"test": "value"}, ] } - proto = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.OR @@ -231,7 +281,6 @@ def test_where_to_proto_or() -> None: def test_where_to_proto_nested_boolean_operators() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "$and": [ { @@ -248,7 +297,7 @@ def test_where_to_proto_nested_boolean_operators() -> None: }, ] } - proto = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.AND @@ -273,14 +322,13 @@ def test_where_to_proto_nested_boolean_operators() -> None: def test_where_to_proto_float_operator() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "$and": [ {"test1": 1.0}, {"test2": 2.0}, ] } - proto = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.AND @@ -300,89 +348,29 @@ def test_where_to_proto_float_operator() -> None: assert child_1.direct_comparison.single_double_operand.value == 2.0 -def test_metadata_embedding_record_string_from_proto() -> None: - md_segment = unstarted_grpc_metadata_segment() - val: pb.UpdateMetadataValue = pb.UpdateMetadataValue( - string_value="test_value", - ) - update: pb.UpdateMetadata = pb.UpdateMetadata( - metadata={"test_key": val}, - ) - record: pb.MetadataEmbeddingRecord = pb.MetadataEmbeddingRecord( - id="test_id", - metadata=update, - ) - - mdr: MetadataEmbeddingRecord = md_segment._from_proto(record) - assert mdr["id"] == "test_id" - assert mdr["metadata"] - assert mdr["metadata"]["test_key"] == "test_value" - - -def test_metadata_embedding_record_int_from_proto() -> None: - md_segment = unstarted_grpc_metadata_segment() - val: pb.UpdateMetadataValue = pb.UpdateMetadataValue( - int_value=1, - ) - update: pb.UpdateMetadata = pb.UpdateMetadata( - metadata={"test_key": val}, - ) - record: pb.MetadataEmbeddingRecord = pb.MetadataEmbeddingRecord( - id="test_id", - metadata=update, - ) - - mdr: MetadataEmbeddingRecord = md_segment._from_proto(record) - assert mdr["id"] == "test_id" - assert mdr["metadata"] - assert mdr["metadata"]["test_key"] == 1 - - -def test_metadata_embedding_record_double_from_proto() -> None: - md_segment = unstarted_grpc_metadata_segment() - val: pb.UpdateMetadataValue = pb.UpdateMetadataValue( +def test_projection_record_from_proto() -> None: + float_val: pb.UpdateMetadataValue = pb.UpdateMetadataValue( float_value=1.0, ) - update: pb.UpdateMetadata = pb.UpdateMetadata( - metadata={"test_key": val}, + int_val: pb.UpdateMetadataValue = pb.UpdateMetadataValue( + int_value=2, ) - record: pb.MetadataEmbeddingRecord = pb.MetadataEmbeddingRecord( - id="test_id", - metadata=update, - ) - - mdr: MetadataEmbeddingRecord = md_segment._from_proto(record) - assert mdr["id"] == "test_id" - assert mdr["metadata"] - assert mdr["metadata"]["test_key"] == 1.0 - - -def test_metadata_embedding_record_heterogeneous_from_proto() -> None: - md_segment = unstarted_grpc_metadata_segment() - val1: pb.UpdateMetadataValue = pb.UpdateMetadataValue( - string_value="test_value", - ) - val2: pb.UpdateMetadataValue = pb.UpdateMetadataValue( - int_value=1, - ) - val3: pb.UpdateMetadataValue = pb.UpdateMetadataValue( - float_value=1.0, + str_val: pb.UpdateMetadataValue = pb.UpdateMetadataValue( + string_value="three", ) update: pb.UpdateMetadata = pb.UpdateMetadata( - metadata={ - "test_key1": val1, - "test_key2": val2, - "test_key3": val3, - }, + metadata={"float_key": float_val, "int_key": int_val, "str_key": str_val}, ) - record: pb.MetadataEmbeddingRecord = pb.MetadataEmbeddingRecord( + record: query_pb.ProjectionRecord = query_pb.ProjectionRecord( id="test_id", + document="document", metadata=update, ) - mdr: MetadataEmbeddingRecord = md_segment._from_proto(record) - assert mdr["id"] == "test_id" - assert mdr["metadata"] - assert mdr["metadata"]["test_key1"] == "test_value" - assert mdr["metadata"]["test_key2"] == 1 - assert mdr["metadata"]["test_key3"] == 1.0 + projection_record = convert.from_proto_projection_record(record) + + assert projection_record["id"] == "test_id" + assert projection_record["metadata"] + assert projection_record["metadata"]["float_key"] == 1.0 + assert projection_record["metadata"]["int_key"] == 2 + assert projection_record["metadata"]["str_key"] == "three"