From 2cd4f74f32cc57985f352a8e27d5335f67b2919a Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Fri, 13 Dec 2024 18:29:20 +0200 Subject: [PATCH] fix: delete_collection resource leak --- chromadb/api/segment.py | 5 +- chromadb/segment/__init__.py | 6 +- chromadb/segment/impl/manager/distributed.py | 3 +- chromadb/segment/impl/manager/local.py | 8 ++- chromadb/test/property/invariants.py | 58 +++++++++++++++++++- chromadb/test/property/test_collections.py | 6 +- 6 files changed, 74 insertions(+), 12 deletions(-) diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 045f16507f6..906faf73a6d 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -227,7 +227,7 @@ def create_collection( id=model.id, name=model.name, configuration=model.get_configuration(), - segments=[], # Passing empty till backend changes are deployed. + segments=[], # Passing empty till backend changes are deployed. metadata=model.metadata, dimension=None, # This is lazily populated on the first add get_or_create=get_or_create, @@ -384,10 +384,11 @@ def delete_collection( ) if existing: + segments = self._sysdb.get_segments(collection=existing[0].id) self._sysdb.delete_collection( existing[0].id, tenant=tenant, database=database ) - self._manager.delete_segments(existing[0].id) + self._manager.delete_segments(segments) else: raise ValueError(f"Collection {name} does not exist.") diff --git a/chromadb/segment/__init__.py b/chromadb/segment/__init__.py index d1e440f17c5..b41ecb32bbf 100644 --- a/chromadb/segment/__init__.py +++ b/chromadb/segment/__init__.py @@ -104,13 +104,15 @@ class SegmentManager(Component): segments as required""" @abstractmethod - def prepare_segments_for_new_collection(self, collection: Collection) -> Sequence[Segment]: + def prepare_segments_for_new_collection( + self, collection: Collection + ) -> Sequence[Segment]: """Return the segments required for a new collection. Returns only segment data, does not persist to the SysDB""" pass @abstractmethod - def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: + def delete_segments(self, segments: Sequence[Segment]) -> Sequence[UUID]: """Delete any local state for all the segments associated with a collection, and returns a sequence of their IDs. Does not update the SysDB.""" pass diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py index 033cf375e9c..218efe341fe 100644 --- a/chromadb/segment/impl/manager/distributed.py +++ b/chromadb/segment/impl/manager/distributed.py @@ -78,8 +78,7 @@ def prepare_segments_for_new_collection( return [vector_segment, record_segment, metadata_segment] @override - def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: - segments = self._sysdb.get_segments(collection=collection_id) + def delete_segments(self, segments: Sequence[Segment]) -> Sequence[UUID]: return [s["id"] for s in segments] @trace_method( diff --git a/chromadb/segment/impl/manager/local.py b/chromadb/segment/impl/manager/local.py index 296ace7f9e7..4d4e24b7d0e 100644 --- a/chromadb/segment/impl/manager/local.py +++ b/chromadb/segment/impl/manager/local.py @@ -137,7 +137,9 @@ def reset_state(self) -> None: OpenTelemetryGranularity.OPERATION_AND_SEGMENT, ) @override - def prepare_segments_for_new_collection(self, collection: Collection) -> Sequence[Segment]: + def prepare_segments_for_new_collection( + self, collection: Collection + ) -> Sequence[Segment]: vector_segment = _segment( self._vector_segment_type, SegmentScope.VECTOR, collection ) @@ -151,9 +153,9 @@ def prepare_segments_for_new_collection(self, collection: Collection) -> Sequenc OpenTelemetryGranularity.OPERATION_AND_SEGMENT, ) @override - def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: - segments = self._sysdb.get_segments(collection=collection_id) + def delete_segments(self, segments: Sequence[Segment]) -> Sequence[UUID]: for segment in segments: + collection_id = segment["collection"] if segment["id"] in self._instances: if segment["type"] == SegmentType.HNSW_LOCAL_PERSISTED.value: instance = self.get_segment(collection_id, VectorReader) diff --git a/chromadb/test/property/invariants.py b/chromadb/test/property/invariants.py index 8dba12ffb14..a9f121d64f2 100644 --- a/chromadb/test/property/invariants.py +++ b/chromadb/test/property/invariants.py @@ -1,7 +1,11 @@ import gc import math +import os.path from uuid import UUID +from contextlib import contextmanager +from chromadb.api.segment import SegmentAPI +from chromadb.db.system import SysDB from chromadb.ingest.impl.utils import create_topic_name from chromadb.config import System @@ -9,12 +13,14 @@ from chromadb.db.impl.sqlite import SqliteDB from time import sleep import psutil + +from chromadb.segment import SegmentType from chromadb.test.property.strategies import NormalizedRecordSet, RecordSet from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast, Any, Dict from typing_extensions import Literal import numpy as np import numpy.typing as npt -from chromadb.api import types +from chromadb.api import types, ClientAPI from chromadb.api.models.Collection import Collection from hypothesis import note from hypothesis.errors import InvalidArgument @@ -457,3 +463,53 @@ def log_size_for_collections_match_expected( else: assert _total_embedding_queue_log_size(sqlite) == 0 + + +@contextmanager +def collection_deleted(client: ClientAPI, collection_name: str): + # Invariant checks before deletion + assert collection_name in [c.name for c in client.list_collections()] + collection = client.get_collection(collection_name) + segments = [] + if isinstance(client._server, SegmentAPI): # type: ignore + sysdb: SysDB = client._server._sysdb # type: ignore + segments = sysdb.get_segments(collection=collection.id) + segment_types = {} + should_have_hnsw = False + for segment in segments: + segment_types[segment["type"]] = True + if segment["type"] == SegmentType.HNSW_LOCAL_PERSISTED.value: + sync_threshold = ( + collection.metadata["hnsw:sync_threshold"] + if collection.metadata is not None + and "hnsw:sync_threshold" in collection.metadata + else 1000 + ) + if ( + collection.count() > sync_threshold + ): # we only check if vector segment dir exists if we've synced at least once + should_have_hnsw = True + assert os.path.exists( + os.path.join( + client.get_settings().persist_directory, str(segment["id"]) + ) + ) + if should_have_hnsw: + assert segment_types[SegmentType.HNSW_LOCAL_PERSISTED.value] + assert segment_types[SegmentType.SQLITE.value] + + yield + + # Invariant checks after deletion + assert collection_name not in [c.name for c in client.list_collections()] + if len(segments) > 0: + sysdb: SysDB = client._server._sysdb # type: ignore + segments_after = sysdb.get_segments(collection=collection.id) + assert len(segments_after) == 0 + for segment in segments: + if segment["type"] == SegmentType.HNSW_LOCAL_PERSISTED.value: + assert not os.path.exists( + os.path.join( + client.get_settings().persist_directory, str(segment["id"]) + ) + ) diff --git a/chromadb/test/property/test_collections.py b/chromadb/test/property/test_collections.py index 8dd10837ce1..6a432ca66c1 100644 --- a/chromadb/test/property/test_collections.py +++ b/chromadb/test/property/test_collections.py @@ -14,6 +14,7 @@ run_state_machine_as_test, MultipleResults, ) +import chromadb.test.property.invariants as invariants from typing import Any, Dict, Mapping, Optional import numpy from chromadb.test.property.strategies import hashing_embedding_function @@ -75,8 +76,9 @@ def get_coll(self, coll: strategies.ExternalCollection) -> None: @rule(coll=consumes(collections)) def delete_coll(self, coll: strategies.ExternalCollection) -> None: if coll.name in self.model: - self.client.delete_collection(name=coll.name) - self.delete_from_model(coll.name) + with invariants.collection_deleted(self.client, coll.name): + self.client.delete_collection(name=coll.name) + self.delete_from_model(coll.name) else: with pytest.raises(Exception): self.client.delete_collection(name=coll.name)