diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 045f16507f63..852a128b0cad 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -580,7 +580,7 @@ def _get( } ) - coll = self._get_collection(collection_id) + scan = self._scan(collection_id) # TODO: Replace with unified validation if where is not None: @@ -619,7 +619,7 @@ def _get( return self._executor.get( GetPlan( - Scan(coll), + scan, Filter(ids, where, where_document), Limit(offset or 0, limit), Projection( @@ -676,7 +676,7 @@ def _delete( """ ) - coll = self._get_collection(collection_id) + scan = self._scan(collection_id) self._quota_enforcer.enforce( action=Action.DELETE, @@ -690,7 +690,7 @@ def _delete( if (where or where_document) or not ids: ids_to_delete = self._executor.get( - GetPlan(Scan(coll), Filter(ids, where, where_document)) + GetPlan(scan, Filter(ids, where, where_document)) )["ids"] else: ids_to_delete = ids @@ -701,7 +701,7 @@ def _delete( records_to_submit = list( _records(operation=t.Operation.DELETE, ids=ids_to_delete) ) - self._validate_embedding_record_set(coll, records_to_submit) + self._validate_embedding_record_set(scan.collection, records_to_submit) self._producer.submit_embeddings(collection_id, records_to_submit) self._product_telemetry_client.capture( @@ -726,8 +726,7 @@ def _count( database: str = DEFAULT_DATABASE, ) -> int: add_attributes_to_current_span({"collection_id": str(collection_id)}) - coll = self._get_collection(collection_id) - return self._executor.count(CountPlan(Scan(coll))) + return self._executor.count(CountPlan(self._scan(collection_id))) @trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION) # We retry on version mismatch errors because the version of the collection @@ -785,9 +784,9 @@ def _query( if where_document is not None: validate_where_document(where_document) - coll = self._get_collection(collection_id) + scan = self._scan(collection_id) for embedding in query_embeddings: - 
self._validate_dimension(coll, len(embedding), update=False) + self._validate_dimension(scan.collection, len(embedding), update=False) self._quota_enforcer.enforce( action=Action.QUERY, @@ -800,7 +799,7 @@ def _query( return self._executor.knn( KNNPlan( - Scan(coll), + scan, KNN(query_embeddings, n_results), Filter(None, where, where_document), Projection( @@ -893,6 +892,18 @@ def _get_collection(self, collection_id: UUID) -> t.Collection: ) return collections[0] + @trace_method("SegmentAPI._scan", OpenTelemetryGranularity.ALL) + def _scan(self, collection_id: UUID) -> Scan: + collection_segments = self._sysdb.get_collection_with_segments(collection_id) + scope_to_segment = {segment["scope"]: segment for segment in collection_segments["segments"]} + return Scan( + collection=collection_segments["collection"], + knn=scope_to_segment[t.SegmentScope.VECTOR], + metadata=scope_to_segment[t.SegmentScope.METADATA], + # Local Chroma does not have a record segment + record=scope_to_segment.get(t.SegmentScope.RECORD, None), # type: ignore[arg-type] + ) + def _records( operation: t.Operation, diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index bb5e9d8fdb01..e01bfd68bffb 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -376,6 +376,8 @@ def get_collection_with_segments(self, collection_id: UUID) -> CollectionSegment segments=[from_proto_segment(segment) for segment in response.segments] ) except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + raise NotFoundError() logger.error( f"Failed to get collection {collection_id} and its segments due to error: {e}" ) diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index 75adfea14f93..303d58488db1 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -15,7 +15,7 @@ from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System from chromadb.db.base import Cursor, SqlDB, ParameterValue, get_sql from 
chromadb.db.system import SysDB -from chromadb.errors import NotFoundError, UniqueConstraintError +from chromadb.errors import InvalidCollectionException, NotFoundError, UniqueConstraintError from chromadb.telemetry.opentelemetry import ( add_attributes_to_current_span, OpenTelemetryClient, @@ -491,8 +491,13 @@ def get_collections( @override def get_collection_with_segments(self, collection_id: UUID) -> CollectionSegments: + collections = self.get_collections(id=collection_id) + if len(collections) == 0: + raise InvalidCollectionException( + f"Collection {collection_id} does not exist." + ) return CollectionSegments( - collection=self.get_collections(id=collection_id)[0], + collection=collections[0], segments=self.get_segments(collection=collection_id), ) diff --git a/chromadb/execution/executor/distributed.py b/chromadb/execution/executor/distributed.py index 5f1a34ab64e4..c41f52a58161 100644 --- a/chromadb/execution/executor/distributed.py +++ b/chromadb/execution/executor/distributed.py @@ -7,7 +7,7 @@ from chromadb.config import System from chromadb.errors import VersionMismatchError from chromadb.execution.executor.abstract import Executor -from chromadb.execution.expression.operator import Scan, SegmentScan +from chromadb.execution.expression.operator import Scan from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan from chromadb.proto import convert @@ -15,7 +15,6 @@ from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor from chromadb.segment.impl.manager.distributed import DistributedSegmentManager from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor -from chromadb.types import SegmentScope def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]: @@ -55,7 +54,6 @@ def __init__(self, system: System): @overrides def count(self, plan: CountPlan) -> int: executor = self._grpc_executuor_stub(plan.scan) - plan.scan = self._segment_scan(plan.scan) try: count_result = 
executor.Count(convert.to_proto_count_plan(plan)) except grpc.RpcError as rpc_error: @@ -70,7 +68,6 @@ def count(self, plan: CountPlan) -> int: @overrides def get(self, plan: GetPlan) -> GetResult: executor = self._grpc_executuor_stub(plan.scan) - plan.scan = self._segment_scan(plan.scan) try: get_result = executor.Get(convert.to_proto_get_plan(plan)) except grpc.RpcError as rpc_error: @@ -118,7 +115,6 @@ def get(self, plan: GetPlan) -> GetResult: @overrides def knn(self, plan: KNNPlan) -> QueryResult: executor = self._grpc_executuor_stub(plan.scan) - plan.scan = self._segment_scan(plan.scan) try: knn_result = executor.KNN(convert.to_proto_knn_plan(plan)) except grpc.RpcError as rpc_error: @@ -181,18 +177,10 @@ def knn(self, plan: KNNPlan) -> QueryResult: included=plan.projection.included, ) - def _segment_scan(self, scan: Scan) -> SegmentScan: - collection_segments = self._manager.get_collection_segments(scan.collection.id) - scope_to_segment = {segment["scope"]: segment["id"] for segment in collection_segments["segments"]} - return SegmentScan( - collection=collection_segments["collection"], - knn_id=scope_to_segment[SegmentScope.VECTOR], - metadata_id=scope_to_segment[SegmentScope.METADATA], - record_id=scope_to_segment[SegmentScope.RECORD], - ) - def _grpc_executuor_stub(self, scan: Scan) -> QueryExecutorStub: - grpc_url = self._manager.get_endpoint(scan.collection.id) + # Since the grpc endpoint is determined by collection uuid, + # the endpoint should be the same for all segments of the same collection + grpc_url = self._manager.get_endpoint(scan.record) if grpc_url not in self._grpc_stub_pool: channel = grpc.insecure_channel(grpc_url) interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()] diff --git a/chromadb/execution/expression/operator.py b/chromadb/execution/expression/operator.py index 2f56df78041c..01dff0bb84fc 100644 --- a/chromadb/execution/expression/operator.py +++ b/chromadb/execution/expression/operator.py @@ -1,14 +1,16 
@@ from dataclasses import dataclass from typing import Optional -from uuid import UUID from chromadb.api.types import Embeddings, IDs, Include, IncludeEnum -from chromadb.types import Collection, RequestVersionContext, Where, WhereDocument +from chromadb.types import Collection, RequestVersionContext, Segment, Where, WhereDocument @dataclass class Scan: collection: Collection + knn: Segment + metadata: Segment + record: Segment @property def version(self) -> RequestVersionContext: @@ -17,14 +19,6 @@ def version(self) -> RequestVersionContext: log_position=self.collection.log_position, ) - -@dataclass -class SegmentScan(Scan): - knn_id: UUID - metadata_id: UUID - record_id: UUID - - @dataclass class Filter: user_ids: Optional[IDs] = None diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py index 47f5e0e08b9c..75f73769c8a7 100644 --- a/chromadb/proto/convert.py +++ b/chromadb/proto/convert.py @@ -13,7 +13,7 @@ Filter, Limit, Projection, - SegmentScan, + Scan, ) from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan from chromadb.types import ( @@ -568,12 +568,12 @@ def to_proto_where_document(where_document: WhereDocument) -> chroma_pb.WhereDoc return response -def to_proto_scan(scan: SegmentScan) -> query_pb.ScanOperator: +def to_proto_scan(scan: Scan) -> query_pb.ScanOperator: return query_pb.ScanOperator( collection=to_proto_collection(scan.collection), - knn_id=scan.knn_id.hex, - metadata_id=scan.metadata_id.hex, - record_id=scan.record_id.hex, + knn_id=scan.knn["id"].hex, + metadata_id=scan.metadata["id"].hex, + record_id=scan.record["id"].hex, ) diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py index ebed75b7261d..0f92e0c00078 100644 --- a/chromadb/segment/impl/manager/distributed.py +++ b/chromadb/segment/impl/manager/distributed.py @@ -26,9 +26,6 @@ class DistributedSegmentManager(SegmentManager): _system: System _opentelemetry_client: OpenTelemetryClient _instances: 
Dict[UUID, SegmentImplementation] - _collection_segment_cache: Dict[ - UUID, CollectionSegments - ] # collection_id -> (collection_with_version, segments) _segment_directory: SegmentDirectory _lock: Lock # _segment_server_stubs: Dict[str, SegmentServerStub] # grpc_url -> grpc stub @@ -40,7 +37,6 @@ def __init__(self, system: System): self._system = system self._opentelemetry_client = system.require(OpenTelemetryClient) self._instances = {} - self._collection_segment_cache = {} self._lock = Lock() @trace_method( @@ -81,23 +77,11 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: segments = self._sysdb.get_segments(collection=collection_id) return [s["id"] for s in segments] - @trace_method( - "DistributedSegmentManager.get_collection_segments", - OpenTelemetryGranularity.OPERATION_AND_SEGMENT, - ) - def get_collection_segments(self, collection_id: UUID) -> CollectionSegments: - if collection_id not in self._collection_segment_cache: - self._collection_segment_cache[collection_id] = self._sysdb.get_collection_with_segments(collection_id) - return self._collection_segment_cache[collection_id] - @trace_method( "DistributedSegmentManager.get_endpoint", OpenTelemetryGranularity.OPERATION_AND_SEGMENT, ) - def get_endpoint(self, collection_id: UUID) -> str: - # Since grpc endpoint is endpoint is determined by collection uuid, - # the endpoint should be the same for all segments of the same collection - segment = self.get_collection_segments(collection_id)["segments"][0] + def get_endpoint(self, segment: Segment) -> str: return self._segment_directory.get_segment_endpoint(segment) @trace_method(