diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index d23139759d9..8eaa32adec6 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -1,7 +1,6 @@ from chromadb.api import API from chromadb.config import Settings, System from chromadb.db.system import SysDB -from chromadb.ingest.impl.utils import create_topic_name from chromadb.segment import SegmentManager, MetadataReader, VectorReader from chromadb.telemetry import Telemetry from chromadb.ingest import Producer @@ -79,10 +78,7 @@ class SegmentAPI(API): _sysdb: SysDB _manager: SegmentManager _producer: Producer - # TODO: fire telemetry events _telemetry_client: Telemetry - _tenant_id: str - _topic_ns: str _collection_cache: Dict[UUID, t.Collection] def __init__(self, system: System): @@ -92,8 +88,6 @@ def __init__(self, system: System): self._manager = self.require(SegmentManager) self._telemetry_client = self.require(Telemetry) self._producer = self.require(Producer) - self._tenant_id = system.settings.tenant_id - self._topic_ns = system.settings.topic_namespace self._collection_cache = {} @override @@ -135,15 +129,12 @@ def create_collection( check_index_name(name) id = uuid4() - coll = t.Collection( - id=id, name=name, metadata=metadata, topic=self._topic(id), dimension=None + + coll = self._sysdb.create_collection( + id=id, name=name, metadata=metadata, dimension=None ) - # TODO: Topic creation right now lives in the producer but it should be moved to the coordinator, - # and the producer should just be responsible for publishing messages. Coordinator should - # be responsible for all management of topics. - self._producer.create_topic(coll["topic"]) segments = self._manager.create_segments(coll) - self._sysdb.create_collection(coll) + for segment in segments: self._sysdb.create_segment(segment) @@ -244,6 +235,7 @@ def delete_collection(self, name: str) -> None: self._sysdb.delete_collection(existing[0]["id"]) for s in self._manager.delete_segments(existing[0]["id"]): self._sysdb.delete_segment(s) + # TODO: Move topic deletion into sysdb as well self._producer.delete_topic(existing[0]["topic"]) if existing and existing[0]["id"] in self._collection_cache: del self._collection_cache[existing[0]["id"]] @@ -618,9 +610,6 @@ def get_settings(self) -> Settings: def max_batch_size(self) -> int: return self._producer.max_batch_size - def _topic(self, collection_id: UUID) -> str: - return create_topic_name(self._tenant_id, self._topic_ns, str(collection_id)) - # TODO: This could potentially cause race conditions in a distributed version of the # system, since the cache is only local. # TODO: promote collection -> topic to a base class method so that it can be diff --git a/chromadb/config.py b/chromadb/config.py index a2af7bd32bc..eb7bca93ef5 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -67,6 +67,7 @@ "chromadb.telemetry.Telemetry": "chroma_telemetry_impl", "chromadb.ingest.Producer": "chroma_producer_impl", "chromadb.ingest.Consumer": "chroma_consumer_impl", + "chromadb.ingest.CollectionAssignmentPolicy": "chroma_collection_assignment_policy_impl", # noqa "chromadb.db.system.SysDB": "chroma_sysdb_impl", "chromadb.segment.SegmentManager": "chroma_segment_manager_impl", "chromadb.segment.distributed.SegmentDirectory": "chroma_segment_directory_impl", @@ -77,7 +78,8 @@ class Settings(BaseSettings): # type: ignore environment: str = "" - # Legacy config has to be kept around because pydantic will error on nonexisting keys + # Legacy config has to be kept around because pydantic will error + # on nonexisting keys chroma_db_impl: Optional[str] = None chroma_api_impl: str = "chromadb.api.segment.SegmentAPI" # Can be "chromadb.api.segment.SegmentAPI" or "chromadb.api.fastapi.FastAPI" @@ -94,6 +96,9 @@ class Settings(BaseSettings): # type: ignore # Distributed architecture specific components chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory" chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider" + chroma_collection_assignment_policy_impl: str = ( + "chromadb.ingest.impl.simple_policy.SimpleAssignmentPolicy" + ) worker_memberlist_name: str = "worker-memberlist" chroma_coordinator_host = "localhost" diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index 5d7c6e88838..49e231b0086 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -4,6 +4,7 @@ from chromadb.config import System from chromadb.db.base import NotFoundError, UniqueConstraintError from chromadb.db.system import SysDB +from chromadb.ingest import CollectionAssignmentPolicy from chromadb.proto.convert import ( from_proto_collection, from_proto_segment, @@ -26,6 +27,7 @@ from chromadb.proto.coordinator_pb2_grpc import SysDBStub from chromadb.types import ( Collection, + Metadata, OptionalArgument, Segment, SegmentScope, @@ -42,6 +44,7 @@ class GrpcSysDB(SysDB): to call a remote SysDB (Coordinator) service.""" _sys_db_stub: SysDBStub + _assignment_policy: CollectionAssignmentPolicy _channel: grpc.Channel _coordinator_url: str _coordinator_port: int @@ -50,6 +53,7 @@ def __init__(self, system: System): self._coordinator_url = system.settings.require("chroma_coordinator_host") # TODO: break out coordinator_port into a separate setting? self._coordinator_port = system.settings.require("chroma_server_grpc_port") + self._assignment_policy = system.instance(CollectionAssignmentPolicy) return super().__init__(system) @overrides @@ -156,8 +160,22 @@ def update_segment( self._sys_db_stub.UpdateSegment(request) @overrides - def create_collection(self, collection: Collection) -> None: + def create_collection( + self, + id: UUID, + name: str, + metadata: Optional[Metadata] = None, + dimension: Optional[int] = None, + ) -> Collection: # TODO: the get_or_create concept needs to be pushed down to the sysdb interface + topic = self._assignment_policy.assign_collection(id) + collection = Collection( + id=id, + name=name, + topic=topic, + metadata=metadata, + dimension=dimension, + ) request = CreateCollectionRequest( collection=to_proto_collection(collection), get_or_create=False, @@ -165,6 +183,7 @@ def create_collection(self, collection: Collection) -> None: response = self._sys_db_stub.CreateCollection(request) if response.status.code == 409: raise UniqueConstraintError() + return collection @overrides def delete_collection(self, id: UUID) -> None: diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index 58ee4488b64..de276b44b16 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -14,6 +14,7 @@ UniqueConstraintError, ) from chromadb.db.system import SysDB +from chromadb.ingest import CollectionAssignmentPolicy from chromadb.types import ( OptionalArgument, Segment, @@ -26,7 +27,11 @@ class SqlSysDB(SqlDB, SysDB): + _assignment_policy: CollectionAssignmentPolicy + def __init__(self, system: System): + self._assignment_policy = system.instance(CollectionAssignmentPolicy) + super().__init__(system) @override @@ -69,8 +74,20 @@ def create_segment(self, segment: Segment) -> None: ) @override - def create_collection(self, collection: Collection) -> None: - """Create a new collection""" + def create_collection( + self, + id: UUID, + name: str, + metadata: Optional[Metadata] = None, + dimension: Optional[int] = None, + ) -> Collection: + """Create a new collection and the associate topic""" + + topic = self._assignment_policy.assign_collection(id) + collection = Collection( + id=id, topic=topic, name=name, metadata=metadata, dimension=dimension + ) + with self.tx() as cur: collections = Table("collections") insert_collection = ( @@ -105,6 +122,7 @@ def create_collection(self, collection: Collection) -> None: collection["id"], collection["metadata"], ) + return collection @override def get_segments( diff --git a/chromadb/db/system.py b/chromadb/db/system.py index 23f068c3be3..ac9488d3061 100644 --- a/chromadb/db/system.py +++ b/chromadb/db/system.py @@ -3,6 +3,7 @@ from uuid import UUID from chromadb.types import ( Collection, + Metadata, Segment, SegmentScope, OptionalArgument, @@ -52,8 +53,15 @@ def update_segment( pass @abstractmethod - def create_collection(self, collection: Collection) -> None: - """Create a new collection any associated resources in the SysDB.""" + def create_collection( + self, + id: UUID, + name: str, + metadata: Optional[Metadata] = None, + dimension: Optional[int] = None, + ) -> Collection: + """Create a new collection any associated resources + (Such as the necessary topics) in the SysDB.""" pass @abstractmethod diff --git a/chromadb/ingest/__init__.py b/chromadb/ingest/__init__.py index 56863e8914d..5a5abf1c99b 100644 --- a/chromadb/ingest/__init__.py +++ b/chromadb/ingest/__init__.py @@ -118,3 +118,12 @@ def min_seqid(self) -> SeqId: def max_seqid(self) -> SeqId: """Return the maximum possible SeqID in this implementation.""" pass + + +class CollectionAssignmentPolicy(Component): + """Interface for assigning collections to topics""" + + @abstractmethod + def assign_collection(self, collection_id: UUID) -> str: + """Return the topic that should be used for the given collection""" + pass diff --git a/chromadb/ingest/impl/simple_policy.py b/chromadb/ingest/impl/simple_policy.py new file mode 100644 index 00000000000..06ee2e001e0 --- /dev/null +++ b/chromadb/ingest/impl/simple_policy.py @@ -0,0 +1,25 @@ +from uuid import UUID +from overrides import overrides +from chromadb.config import System +from chromadb.ingest import CollectionAssignmentPolicy +from chromadb.ingest.impl.utils import create_topic_name + + +class SimpleAssignmentPolicy(CollectionAssignmentPolicy): + """Simple assignment policy that assigns a 1 collection to 1 topic based on the + id of the collection.""" + + _tenant_id: str + _topic_ns: str + + def __init__(self, system: System): + self._tenant_id = system.settings.tenant_id + self._topic_ns = system.settings.topic_namespace + super().__init__(system) + + def _topic(self, collection_id: UUID) -> str: + return create_topic_name(self._tenant_id, self._topic_ns, str(collection_id)) + + @overrides + def assign_collection(self, collection_id: UUID) -> str: + return self._topic(collection_id) diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index f67a3d9d390..8127c27c5d9 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -7,16 +7,55 @@ from chromadb.db.impl.grpc.server import GrpcMockSysDB from chromadb.types import Collection, Segment, SegmentScope from chromadb.db.impl.sqlite import SqliteDB -from chromadb.config import System, Settings +from chromadb.config import Component, System, Settings from chromadb.db.system import SysDB from chromadb.db.base import NotFoundError, UniqueConstraintError from pytest import FixtureRequest import uuid +sample_collections = [ + Collection( + id=uuid.UUID("93ffe3ec-0107-48d4-8695-51f978c509dc"), + name="test_collection_1", + topic="test_topic_1", + metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3}, + dimension=128, + ), + Collection( + id=uuid.UUID("f444f1d7-d06c-4357-ac22-5a4a1f92d761"), + name="test_collection_2", + topic="test_topic_2", + metadata={"test_str": "str2", "test_int": 2, "test_float": 2.3}, + dimension=None, + ), + Collection( + id=uuid.UUID("43babc1a-e403-4a50-91a9-16621ba29ab0"), + name="test_collection_3", + topic="test_topic_3", + metadata={"test_str": "str3", "test_int": 3, "test_float": 3.3}, + dimension=None, + ), +] + + +class MockAssignmentPolicy(Component): + def assign_collection(self, collection_id: uuid.UUID) -> str: + for collection in sample_collections: + if collection["id"] == collection_id: + return collection["topic"] + raise ValueError(f"Unknown collection ID: {collection_id}") + def sqlite() -> Generator[SysDB, None, None]: """Fixture generator for sqlite DB""" - db = SqliteDB(System(Settings(allow_reset=True))) + db = SqliteDB( + System( + Settings( + allow_reset=True, + chroma_collection_assignment_policy_impl="chromadb.test.db.test_system.MockAssignmentPolicy", + ) + ) + ) db.start() yield db db.stop() @@ -27,7 +66,12 @@ def sqlite_persistent() -> Generator[SysDB, None, None]: save_path = tempfile.mkdtemp() db = SqliteDB( System( - Settings(allow_reset=True, is_persistent=True, persist_directory=save_path) + Settings( + allow_reset=True, + is_persistent=True, + persist_directory=save_path, + chroma_collection_assignment_policy_impl="chromadb.test.db.test_system.MockAssignmentPolicy", + ) ) ) db.start() @@ -40,7 +84,13 @@ def sqlite_persistent() -> Generator[SysDB, None, None]: def grpc_with_mock_server() -> Generator[SysDB, None, None]: """Fixture generator for sqlite DB that creates a mock grpc sysdb server and a grpc client that connects to it.""" - system = System(Settings(allow_reset=True, chroma_server_grpc_port=50051)) + system = System( + Settings( + allow_reset=True, + chroma_collection_assignment_policy_impl="chromadb.test.db.test_system.MockAssignmentPolicy", + chroma_server_grpc_port=50051, + ) + ) system.instance(GrpcMockSysDB) client = system.instance(GrpcSysDB) system.start() @@ -57,36 +107,16 @@ def sysdb(request: FixtureRequest) -> Generator[SysDB, None, None]: yield next(request.param()) -sample_collections = [ - Collection( - id=uuid.uuid4(), - name="test_collection_1", - topic="test_topic_1", - metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3}, - dimension=128, - ), - Collection( - id=uuid.uuid4(), - name="test_collection_2", - topic="test_topic_2", - metadata={"test_str": "str2", "test_int": 2, "test_float": 2.3}, - dimension=None, - ), - Collection( - id=uuid.uuid4(), - name="test_collection_3", - topic="test_topic_3", - metadata={"test_str": "str3", "test_int": 3, "test_float": 3.3}, - dimension=None, - ), -] - - def test_create_get_delete_collections(sysdb: SysDB) -> None: sysdb.reset_state() for collection in sample_collections: - sysdb.create_collection(collection) + sysdb.create_collection( + id=collection["id"], + name=collection["name"], + metadata=collection["metadata"], + dimension=collection["dimension"], + ) results = sysdb.get_collections() results = sorted(results, key=lambda c: c["name"]) @@ -95,7 +125,9 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None: # Duplicate create fails with pytest.raises(UniqueConstraintError): - sysdb.create_collection(sample_collections[0]) + sysdb.create_collection( + name=sample_collections[0]["name"], id=sample_collections[0]["id"] + ) # Find by name for collection in sample_collections: @@ -140,22 +172,22 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None: def test_update_collections(sysdb: SysDB) -> None: - metadata: Dict[str, Union[str, int, float]] = { - "test_str": "str1", - "test_int": 1, - "test_float": 1.3, - } coll = Collection( - id=uuid.uuid4(), - name="test_collection_1", - topic="test_topic_1", - metadata=metadata, - dimension=None, + name=sample_collections[0]["name"], + id=sample_collections[0]["id"], + topic=sample_collections[0]["topic"], + metadata=sample_collections[0]["metadata"], + dimension=sample_collections[0]["dimension"], ) sysdb.reset_state() - sysdb.create_collection(coll) + sysdb.create_collection( + id=coll["id"], + name=coll["name"], + metadata=coll["metadata"], + dimension=coll["dimension"], + ) # Update name coll["name"] = "new_name" @@ -220,7 +252,12 @@ def test_create_get_delete_segments(sysdb: SysDB) -> None: sysdb.reset_state() for collection in sample_collections: - sysdb.create_collection(collection) + sysdb.create_collection( + id=collection["id"], + name=collection["name"], + metadata=collection["metadata"], + dimension=collection["dimension"], + ) for segment in sample_segments: sysdb.create_segment(segment) @@ -293,7 +330,9 @@ def test_update_segment(sysdb: SysDB) -> None: sysdb.reset_state() for c in sample_collections: - sysdb.create_collection(c) + sysdb.create_collection( + id=c["id"], name=c["name"], metadata=c["metadata"], dimension=c["dimension"] + ) sysdb.create_segment(segment)