Skip to content

Commit

Permalink
Always fetch latest collection information
Browse files Browse the repository at this point in the history
  • Loading branch information
Sicheng Pan committed Dec 6, 2024
1 parent 1b3cd98 commit 4575446
Showing 7 changed files with 44 additions and 60 deletions.
31 changes: 21 additions & 10 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
@@ -580,7 +580,7 @@ def _get(
}
)

coll = self._get_collection(collection_id)
scan = self._scan(collection_id)

# TODO: Replace with unified validation
if where is not None:
@@ -619,7 +619,7 @@ def _get(

return self._executor.get(
GetPlan(
Scan(coll),
scan,
Filter(ids, where, where_document),
Limit(offset or 0, limit),
Projection(
@@ -676,7 +676,7 @@ def _delete(
"""
)

coll = self._get_collection(collection_id)
scan = self._scan(collection_id)

self._quota_enforcer.enforce(
action=Action.DELETE,
@@ -690,7 +690,7 @@ def _delete(

if (where or where_document) or not ids:
ids_to_delete = self._executor.get(
GetPlan(Scan(coll), Filter(ids, where, where_document))
GetPlan(scan, Filter(ids, where, where_document))
)["ids"]
else:
ids_to_delete = ids
@@ -701,7 +701,7 @@ def _delete(
records_to_submit = list(
_records(operation=t.Operation.DELETE, ids=ids_to_delete)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._validate_embedding_record_set(scan.collection, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
@@ -726,8 +726,7 @@ def _count(
database: str = DEFAULT_DATABASE,
) -> int:
add_attributes_to_current_span({"collection_id": str(collection_id)})
coll = self._get_collection(collection_id)
return self._executor.count(CountPlan(Scan(coll)))
return self._executor.count(CountPlan(self._scan(collection_id)))

@trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION)
# We retry on version mismatch errors because the version of the collection
@@ -785,9 +784,9 @@ def _query(
if where_document is not None:
validate_where_document(where_document)

coll = self._get_collection(collection_id)
scan = self._scan(collection_id)
for embedding in query_embeddings:
self._validate_dimension(coll, len(embedding), update=False)
self._validate_dimension(scan.collection, len(embedding), update=False)

self._quota_enforcer.enforce(
action=Action.QUERY,
@@ -800,7 +799,7 @@ def _query(

return self._executor.knn(
KNNPlan(
Scan(coll),
scan,
KNN(query_embeddings, n_results),
Filter(None, where, where_document),
Projection(
@@ -893,6 +892,18 @@ def _get_collection(self, collection_id: UUID) -> t.Collection:
)
return collections[0]

@trace_method("SegmentAPI._scan", OpenTelemetryGranularity.ALL)
def _scan(self, collection_id: UUID) -> Scan:
    """Build the ``Scan`` operator for a collection.

    Calls ``self._sysdb.get_collection_with_segments`` on every invocation
    (nothing is cached at this layer) and indexes the returned segments by
    their ``scope`` so each one can be placed in the matching ``Scan`` field.
    If several segments share a scope, the last one listed wins.

    Args:
        collection_id: Identifier of the collection to scan.

    Returns:
        A ``Scan`` holding the collection plus its vector (``knn``),
        metadata and record segments. ``record`` is ``None`` when no
        record-scoped segment is present.

    Raises:
        KeyError: If no vector-scoped or metadata-scoped segment is present.
    """
    fetched = self._sysdb.get_collection_with_segments(collection_id)
    segment_by_scope = {seg["scope"]: seg for seg in fetched["segments"]}

    collection = fetched["collection"]
    vector_segment = segment_by_scope[t.SegmentScope.VECTOR]
    metadata_segment = segment_by_scope[t.SegmentScope.METADATA]
    # Local Chroma has no record segment, so this lookup may yield None.
    record_segment = segment_by_scope.get(t.SegmentScope.RECORD)

    return Scan(
        collection=collection,
        knn=vector_segment,
        metadata=metadata_segment,
        record=record_segment,  # type: ignore[arg-type]
    )


def _records(
operation: t.Operation,
2 changes: 2 additions & 0 deletions chromadb/db/impl/grpc/client.py
Original file line number Diff line number Diff line change
@@ -376,6 +376,8 @@ def get_collection_with_segments(self, collection_id: UUID) -> CollectionSegment
segments=[from_proto_segment(segment) for segment in response.segments]
)
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.NOT_FOUND:
raise NotFoundError()
logger.error(
f"Failed to get collection {collection_id} and its segments due to error: {e}"
)
9 changes: 7 additions & 2 deletions chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System
from chromadb.db.base import Cursor, SqlDB, ParameterValue, get_sql
from chromadb.db.system import SysDB
from chromadb.errors import NotFoundError, UniqueConstraintError
from chromadb.errors import InvalidCollectionException, NotFoundError, UniqueConstraintError
from chromadb.telemetry.opentelemetry import (
add_attributes_to_current_span,
OpenTelemetryClient,
@@ -491,8 +491,13 @@ def get_collections(

@override
def get_collection_with_segments(self, collection_id: UUID) -> CollectionSegments:
collections = self.get_collections(id=collection_id)
if len(collections) == 0:
raise InvalidCollectionException(
f"Collection {collection_id} does not exist."
)
return CollectionSegments(
collection=self.get_collections(id=collection_id)[0],
collection=collections[0],
segments=self.get_segments(collection=collection_id),
)

20 changes: 4 additions & 16 deletions chromadb/execution/executor/distributed.py
Original file line number Diff line number Diff line change
@@ -7,15 +7,14 @@
from chromadb.config import System
from chromadb.errors import VersionMismatchError
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.operator import Scan, SegmentScan
from chromadb.execution.expression.operator import Scan
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.proto import convert

from chromadb.proto.query_executor_pb2_grpc import QueryExecutorStub
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
from chromadb.segment.impl.manager.distributed import DistributedSegmentManager
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from chromadb.types import SegmentScope


def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]:
@@ -55,7 +54,6 @@ def __init__(self, system: System):
@overrides
def count(self, plan: CountPlan) -> int:
executor = self._grpc_executuor_stub(plan.scan)
plan.scan = self._segment_scan(plan.scan)
try:
count_result = executor.Count(convert.to_proto_count_plan(plan))
except grpc.RpcError as rpc_error:
@@ -70,7 +68,6 @@ def count(self, plan: CountPlan) -> int:
@overrides
def get(self, plan: GetPlan) -> GetResult:
executor = self._grpc_executuor_stub(plan.scan)
plan.scan = self._segment_scan(plan.scan)
try:
get_result = executor.Get(convert.to_proto_get_plan(plan))
except grpc.RpcError as rpc_error:
@@ -118,7 +115,6 @@ def get(self, plan: GetPlan) -> GetResult:
@overrides
def knn(self, plan: KNNPlan) -> QueryResult:
executor = self._grpc_executuor_stub(plan.scan)
plan.scan = self._segment_scan(plan.scan)
try:
knn_result = executor.KNN(convert.to_proto_knn_plan(plan))
except grpc.RpcError as rpc_error:
@@ -181,18 +177,10 @@ def knn(self, plan: KNNPlan) -> QueryResult:
included=plan.projection.included,
)

def _segment_scan(self, scan: Scan) -> SegmentScan:
collection_segments = self._manager.get_collection_segments(scan.collection.id)
scope_to_segment = {segment["scope"]: segment["id"] for segment in collection_segments["segments"]}
return SegmentScan(
collection=collection_segments["collection"],
knn_id=scope_to_segment[SegmentScope.VECTOR],
metadata_id=scope_to_segment[SegmentScope.METADATA],
record_id=scope_to_segment[SegmentScope.RECORD],
)

def _grpc_executuor_stub(self, scan: Scan) -> QueryExecutorStub:
grpc_url = self._manager.get_endpoint(scan.collection.id)
# Since the grpc endpoint is determined by the collection uuid,
# the endpoint should be the same for all segments of the same collection
grpc_url = self._manager.get_endpoint(scan.record)
if grpc_url not in self._grpc_stub_pool:
channel = grpc.insecure_channel(grpc_url)
interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
14 changes: 4 additions & 10 deletions chromadb/execution/expression/operator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from dataclasses import dataclass
from typing import Optional
from uuid import UUID

from chromadb.api.types import Embeddings, IDs, Include, IncludeEnum
from chromadb.types import Collection, RequestVersionContext, Where, WhereDocument
from chromadb.types import Collection, RequestVersionContext, Segment, Where, WhereDocument


@dataclass
class Scan:
collection: Collection
knn: Segment
metadata: Segment
record: Segment

@property
def version(self) -> RequestVersionContext:
@@ -17,14 +19,6 @@ def version(self) -> RequestVersionContext:
log_position=self.collection.log_position,
)


@dataclass
class SegmentScan(Scan):
knn_id: UUID
metadata_id: UUID
record_id: UUID


@dataclass
class Filter:
user_ids: Optional[IDs] = None
10 changes: 5 additions & 5 deletions chromadb/proto/convert.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
Filter,
Limit,
Projection,
SegmentScan,
Scan,
)
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.types import (
@@ -568,12 +568,12 @@ def to_proto_where_document(where_document: WhereDocument) -> chroma_pb.WhereDoc
return response


def to_proto_scan(scan: SegmentScan) -> query_pb.ScanOperator:
def to_proto_scan(scan: Scan) -> query_pb.ScanOperator:
return query_pb.ScanOperator(
collection=to_proto_collection(scan.collection),
knn_id=scan.knn_id.hex,
metadata_id=scan.metadata_id.hex,
record_id=scan.record_id.hex,
knn_id=scan.knn["id"].hex,
metadata_id=scan.metadata["id"].hex,
record_id=scan.record["id"].hex,
)


18 changes: 1 addition & 17 deletions chromadb/segment/impl/manager/distributed.py
Original file line number Diff line number Diff line change
@@ -26,9 +26,6 @@ class DistributedSegmentManager(SegmentManager):
_system: System
_opentelemetry_client: OpenTelemetryClient
_instances: Dict[UUID, SegmentImplementation]
_collection_segment_cache: Dict[
UUID, CollectionSegments
] # collection_id -> (collection_with_version, segments)
_segment_directory: SegmentDirectory
_lock: Lock
# _segment_server_stubs: Dict[str, SegmentServerStub] # grpc_url -> grpc stub
@@ -40,7 +37,6 @@ def __init__(self, system: System):
self._system = system
self._opentelemetry_client = system.require(OpenTelemetryClient)
self._instances = {}
self._collection_segment_cache = {}
self._lock = Lock()

@trace_method(
@@ -81,23 +77,11 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
segments = self._sysdb.get_segments(collection=collection_id)
return [s["id"] for s in segments]

@trace_method(
"DistributedSegmentManager.get_collection_segments",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
def get_collection_segments(self, collection_id: UUID) -> CollectionSegments:
if collection_id not in self._collection_segment_cache:
self._collection_segment_cache[collection_id] = self._sysdb.get_collection_with_segments(collection_id)
return self._collection_segment_cache[collection_id]

@trace_method(
"DistributedSegmentManager.get_endpoint",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
def get_endpoint(self, collection_id: UUID) -> str:
# Since grpc endpoint is endpoint is determined by collection uuid,
# the endpoint should be the same for all segments of the same collection
segment = self.get_collection_segments(collection_id)["segments"][0]
def get_endpoint(self, segment: Segment) -> str:
return self._segment_directory.get_segment_endpoint(segment)

@trace_method(

0 comments on commit 4575446

Please sign in to comment.