From e81cc9f361e5aa072534a1fbbc483da406b54848 Mon Sep 17 00:00:00 2001 From: Ben Eggers <64657842+beggers@users.noreply.github.com> Date: Mon, 9 Oct 2023 12:28:59 -0700 Subject: [PATCH 01/14] [RELEASE] 0.4.14 (#1221) Release 0.4.14 --- chromadb/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index ffc32392e07..9c0b8000a14 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -43,7 +43,7 @@ __settings = Settings() -__version__ = "0.4.13" +__version__ = "0.4.14" # Workaround to deal with Colab's old sqlite3 version try: From 764ffe259c2c21e7d90ad768a1919455bec91dbb Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Tue, 10 Oct 2023 10:32:09 -0700 Subject: [PATCH 02/14] [ENH] Add CRD backed SegmentDirectory. (#1207) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Move from docker-compose to k8s manifests. These are lightweight for now; we will replace them with Helm charts. - Modify the workflow action for distributed mode to use minikube - Add tenacity for retry/backoff logic - Move segment_directory to /distributed since it is an interface specific to distributed segment types. - Remove the xfail from the workflow. Only run passing tests. - New functionality - Adds a custom-resource-backed memberlist implementation that fetches and watches. We watch in another thread. - Plumbs this memberlist into the Rendezvous hashing segment directory. For now, this is still hardcoded to return the segment-server service. ## Test plan *How are these changes tested?* Tests were added that: 1. Test fetch 2. Test streaming update 3. Test killing the watcher You can run `bin/cluster-test.sh chromadb/test/segment/distributed/test_memberlist_provider.py` to run them yourself (or run minikube, apply the manifests, and then run `pytest chromadb/test/segment/distributed/test_memberlist_provider.py`) ## Documentation Changes None required. I added a WARNING.md in k8s to make it clear that these manifests are not yet ready for use. 
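For context, the `RendezvousHashSegmentDirectory` introduced in this PR still returns a hardcoded endpoint (see the TODO in `get_segment_endpoint`). A minimal sketch of the intended rendezvous (highest-random-weight) hashing follows; the `rendezvous_hash` helper name, the SHA-256 scoring, and the worker hostnames are illustrative assumptions, not part of this patch:

```python
import hashlib
from typing import Sequence


def rendezvous_hash(key: str, members: Sequence[str]) -> str:
    """Return the member with the highest hash score for `key`.

    Each (key, member) pair is scored independently, so adding or removing a
    member only remaps the keys that were assigned to that member.
    """
    if not members:
        raise ValueError("memberlist is empty")
    return max(
        members,
        key=lambda member: hashlib.sha256(f"{key}:{member}".encode("utf-8")).digest(),
    )


# Hypothetical usage: map a segment ID onto the current worker memberlist.
memberlist = ["worker-0.chroma:50051", "worker-1.chroma:50051", "worker-2.chroma:50051"]
print(rendezvous_hash("0e7aa4d0-1b3e-4c3a-9f66-example-segment-id", memberlist))
```

Unlike a naive `hash(key) % len(members)` scheme, a membership change only remaps the keys that were assigned to the member that joined or left, which is why rendezvous hashing fits a memberlist that changes as workers come and go.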
--- .github/workflows/chroma-cluster-test.yml | 18 +- .pre-commit-config.yaml | 2 +- bin/cluster-test.sh | 40 ++- chromadb/config.py | 15 +- chromadb/proto/chroma_pb2.py | 95 ++++--- chromadb/proto/chroma_pb2.pyi | 95 ++++++- chromadb/proto/chroma_pb2_grpc.py | 261 ++++++++++-------- chromadb/segment/__init__.py | 18 +- chromadb/segment/distributed/__init__.py | 70 +++++ .../impl/distributed/segment_directory.py | 226 +++++++++++++++ chromadb/segment/impl/manager/distributed.py | 9 +- .../segment/impl/manager/segment_directory.py | 36 --- chromadb/test/conftest.py | 1 - .../test/ingest/test_producer_consumer.py | 6 +- .../distributed/test_memberlist_provider.py | 122 ++++++++ docker-compose.cluster.test.yml | 96 ------- docker-compose.cluster.yml | 92 ------ k8s/WARNING.md | 3 + k8s/cr/worker_memberlist_cr.yaml | 43 +++ k8s/crd/memberlist_crd.yaml | 36 +++ k8s/deployment/kubernetes.yaml | 240 ++++++++++++++++ k8s/test/pulsar_service.yaml | 20 ++ pyproject.toml | 2 + requirements.txt | 2 + 24 files changed, 1113 insertions(+), 435 deletions(-) create mode 100644 chromadb/segment/distributed/__init__.py create mode 100644 chromadb/segment/impl/distributed/segment_directory.py delete mode 100644 chromadb/segment/impl/manager/segment_directory.py create mode 100644 chromadb/test/segment/distributed/test_memberlist_provider.py delete mode 100644 docker-compose.cluster.test.yml delete mode 100644 docker-compose.cluster.yml create mode 100644 k8s/WARNING.md create mode 100644 k8s/cr/worker_memberlist_cr.yaml create mode 100644 k8s/crd/memberlist_crd.yaml create mode 100644 k8s/deployment/kubernetes.yaml create mode 100644 k8s/test/pulsar_service.yaml diff --git a/.github/workflows/chroma-cluster-test.yml b/.github/workflows/chroma-cluster-test.yml index 25287dbf0cd..fc8e514f323 100644 --- a/.github/workflows/chroma-cluster-test.yml +++ b/.github/workflows/chroma-cluster-test.yml @@ -16,12 +16,8 @@ jobs: matrix: python: ['3.7'] platform: [ubuntu-latest] - testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore='chromadb/test/test_cli.py'", - "chromadb/test/property/test_add.py", - "chromadb/test/property/test_collections.py", - "chromadb/test/property/test_embeddings.py", - "chromadb/test/property/test_filtering.py", - "chromadb/test/property/test_persist.py"] + testfile: ["chromadb/test/ingest/test_producer_consumer.py", + "chromadb/test/segment/distributed/test_memberlist_provider.py",] runs-on: ${{ matrix.platform }} steps: - name: Checkout @@ -32,6 +28,14 @@ jobs: python-version: ${{ matrix.python }} - name: Install test dependencies run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt + - name: Start minikube + id: minikube + uses: medyagh/setup-minikube@latest + with: + minikube-version: latest + kubernetes-version: latest + driver: docker + addons: ingress, ingress-dns + start-args: '--profile chroma-test' - name: Integration Test run: bin/cluster-test.sh ${{ matrix.testfile }} - continue-on-error: true # Mark the job as successful even if the tests fail for now (Xfail) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6bf67a64ead..3f0065bb133 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,4 +33,4 @@ repos: hooks: - id: mypy args: [--strict, --ignore-missing-imports, --follow-imports=silent, --disable-error-code=type-abstract] - additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy", "types-protobuf"] + additional_dependencies: 
["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy", "types-protobuf", "kubernetes"] diff --git a/bin/cluster-test.sh b/bin/cluster-test.sh index 6eccc5f3158..71c6eea47bf 100755 --- a/bin/cluster-test.sh +++ b/bin/cluster-test.sh @@ -3,14 +3,50 @@ set -e function cleanup { - docker compose -f docker-compose.cluster.test.yml down --rmi local --volumes + # Restore the previous kube context + kubectl config use-context $PREV_CHROMA_KUBE_CONTEXT + # Kill the tunnel process + kill $TUNNEL_PID + minikube delete -p chroma-test } trap cleanup EXIT -docker compose -f docker-compose.cluster.test.yml up -d --wait +# Save the current kube context into a variable +export PREV_CHROMA_KUBE_CONTEXT=$(kubectl config current-context) + +# Create a new minikube cluster for the test +minikube start -p chroma-test + +# Add the ingress addon to the cluster +minikube addons enable ingress -p chroma-test +minikube addons enable ingress-dns -p chroma-test + +# Setup docker to build inside the minikube cluster and build the image +eval $(minikube -p chroma-test docker-env) +docker build -t server:latest -f Dockerfile . + +# Apply the kubernetes manifests +kubectl apply -f k8s/deployment +kubectl apply -f k8s/crd +kubectl apply -f k8s/cr +kubectl apply -f k8s/test + +# Wait for the pods in the chroma namespace to be ready +kubectl wait --namespace chroma --for=condition=Ready pods --all --timeout=300s + +# Run mini kube tunnel in the background to expose the service +minikube tunnel -p chroma-test & +TUNNEL_PID=$! + +# Wait for the tunnel to be ready. There isn't an easy way to check if the tunnel is ready. So we just wait for 10 seconds +sleep 10 export CHROMA_CLUSTER_TEST_ONLY=1 +export CHROMA_SERVER_HOST=$(kubectl get svc server -n chroma -o=jsonpath='{.status.loadBalancer.ingress[0].ip}') +export PULSAR_BROKER_URL=$(kubectl get svc pulsar -n chroma -o=jsonpath='{.status.loadBalancer.ingress[0].ip}') +echo "Chroma Server is running at port $CHROMA_SERVER_HOST" +echo "Pulsar Broker is running at port $PULSAR_BROKER_URL" echo testing: python -m pytest "$@" python -m pytest "$@" diff --git a/chromadb/config.py b/chromadb/config.py index 1ecf7d04254..920c92d6a96 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -69,7 +69,8 @@ "chromadb.ingest.Consumer": "chroma_consumer_impl", "chromadb.db.system.SysDB": "chroma_sysdb_impl", "chromadb.segment.SegmentManager": "chroma_segment_manager_impl", - "chromadb.segment.SegmentDirectory": "chroma_segment_directory_impl", + "chromadb.segment.distributed.SegmentDirectory": "chroma_segment_directory_impl", + "chromadb.segment.distributed.MemberlistProvider": "chroma_memberlist_provider_impl", } @@ -89,9 +90,11 @@ class Settings(BaseSettings): # type: ignore chroma_segment_manager_impl: str = ( "chromadb.segment.impl.manager.local.LocalSegmentManager" ) - chroma_segment_directory_impl: str = ( - "chromadb.segment.impl.manager.segment_directory.DockerComposeSegmentDirectory" - ) + + # Distributed architecture specific components + chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory" + chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider" + worker_memberlist_name: str = "worker-memberlist" tenant_id: str = "default" topic_namespace: str = "default" @@ -108,8 +111,8 @@ class Settings(BaseSettings): # type: ignore chroma_server_cors_allow_origins: List[str] = [] # eg ["http://localhost:3000"] 
pulsar_broker_url: Optional[str] = None - pulsar_admin_port: Optional[str] = None - pulsar_broker_port: Optional[str] = None + pulsar_admin_port: Optional[str] = "8080" + pulsar_broker_port: Optional[str] = "6650" chroma_server_auth_provider: Optional[str] = None diff --git a/chromadb/proto/chroma_pb2.py b/chromadb/proto/chroma_pb2.py index 2a302c67154..4e8c62576f0 100644 --- a/chromadb/proto/chroma_pb2.py +++ b/chromadb/proto/chroma_pb2.py @@ -1,3 +1,4 @@ +# type: ignore # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: chromadb/proto/chroma.proto @@ -6,59 +7,61 @@ from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma\"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding\"\xca\x01\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\x05topic\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x08\n\x06_topicB\r\n\x0b_collectionB\x0b\n\t_metadata\"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value\"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01\"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata\"S\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector\"q\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x10\n\x08\x64istance\x18\x03 \x01(\x01\x12#\n\x06vector\x18\x04 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector\"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult\"(\n\x15SegmentServerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"4\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t\"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord\"\x86\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t\"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 
\x03(\x0b\x32\x1a.chroma.VectorQueryResults*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*(\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x32\x94\x01\n\rSegmentServer\x12?\n\x0bLoadSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x12\x42\n\x0eReleaseSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse\"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding"\xca\x01\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\x05topic\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x08\n\x06_topicB\r\n\x0b_collectionB\x0b\n\t_metadata"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata"S\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector"q\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x10\n\x08\x64istance\x18\x03 \x01(\x01\x12#\n\x06vector\x18\x04 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult"(\n\x15SegmentServerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08"4\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord"\x86\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 
\x03(\x0b\x32\x1a.chroma.VectorQueryResults*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*(\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x32\x94\x01\n\rSegmentServer\x12?\n\x0bLoadSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse"\x00\x12\x42\n\x0eReleaseSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse"\x00\x62\x06proto3' +) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chromadb.proto.chroma_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "chromadb.proto.chroma_pb2", _globals +) if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _UPDATEMETADATA_METADATAENTRY._options = None - _UPDATEMETADATA_METADATAENTRY._serialized_options = b'8\001' - _globals['_OPERATION']._serialized_start=1406 - _globals['_OPERATION']._serialized_end=1462 - _globals['_SCALARENCODING']._serialized_start=1464 - _globals['_SCALARENCODING']._serialized_end=1504 - _globals['_SEGMENTSCOPE']._serialized_start=1506 - _globals['_SEGMENTSCOPE']._serialized_end=1546 - _globals['_VECTOR']._serialized_start=39 - _globals['_VECTOR']._serialized_end=124 - _globals['_SEGMENT']._serialized_start=127 - _globals['_SEGMENT']._serialized_end=329 - _globals['_UPDATEMETADATAVALUE']._serialized_start=331 - _globals['_UPDATEMETADATAVALUE']._serialized_end=429 - _globals['_UPDATEMETADATA']._serialized_start=432 - _globals['_UPDATEMETADATA']._serialized_end=582 - _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_start=506 - _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_end=582 - _globals['_SUBMITEMBEDDINGRECORD']._serialized_start=585 - _globals['_SUBMITEMBEDDINGRECORD']._serialized_end=766 - _globals['_VECTOREMBEDDINGRECORD']._serialized_start=768 - _globals['_VECTOREMBEDDINGRECORD']._serialized_end=851 - _globals['_VECTORQUERYRESULT']._serialized_start=853 - _globals['_VECTORQUERYRESULT']._serialized_end=966 - _globals['_VECTORQUERYRESULTS']._serialized_start=968 - _globals['_VECTORQUERYRESULTS']._serialized_end=1032 - _globals['_SEGMENTSERVERRESPONSE']._serialized_start=1034 - _globals['_SEGMENTSERVERRESPONSE']._serialized_end=1074 - _globals['_GETVECTORSREQUEST']._serialized_start=1076 - _globals['_GETVECTORSREQUEST']._serialized_end=1128 - _globals['_GETVECTORSRESPONSE']._serialized_start=1130 - _globals['_GETVECTORSRESPONSE']._serialized_end=1198 - _globals['_QUERYVECTORSREQUEST']._serialized_start=1201 - _globals['_QUERYVECTORSREQUEST']._serialized_end=1335 - _globals['_QUERYVECTORSRESPONSE']._serialized_start=1337 - _globals['_QUERYVECTORSRESPONSE']._serialized_end=1404 - _globals['_SEGMENTSERVER']._serialized_start=1549 - _globals['_SEGMENTSERVER']._serialized_end=1697 - _globals['_VECTORREADER']._serialized_start=1700 - _globals['_VECTORREADER']._serialized_end=1862 + DESCRIPTOR._options = None + _UPDATEMETADATA_METADATAENTRY._options = None + _UPDATEMETADATA_METADATAENTRY._serialized_options = b"8\001" + _globals["_OPERATION"]._serialized_start = 1406 + _globals["_OPERATION"]._serialized_end = 1462 + 
_globals["_SCALARENCODING"]._serialized_start = 1464 + _globals["_SCALARENCODING"]._serialized_end = 1504 + _globals["_SEGMENTSCOPE"]._serialized_start = 1506 + _globals["_SEGMENTSCOPE"]._serialized_end = 1546 + _globals["_VECTOR"]._serialized_start = 39 + _globals["_VECTOR"]._serialized_end = 124 + _globals["_SEGMENT"]._serialized_start = 127 + _globals["_SEGMENT"]._serialized_end = 329 + _globals["_UPDATEMETADATAVALUE"]._serialized_start = 331 + _globals["_UPDATEMETADATAVALUE"]._serialized_end = 429 + _globals["_UPDATEMETADATA"]._serialized_start = 432 + _globals["_UPDATEMETADATA"]._serialized_end = 582 + _globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_start = 506 + _globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_end = 582 + _globals["_SUBMITEMBEDDINGRECORD"]._serialized_start = 585 + _globals["_SUBMITEMBEDDINGRECORD"]._serialized_end = 766 + _globals["_VECTOREMBEDDINGRECORD"]._serialized_start = 768 + _globals["_VECTOREMBEDDINGRECORD"]._serialized_end = 851 + _globals["_VECTORQUERYRESULT"]._serialized_start = 853 + _globals["_VECTORQUERYRESULT"]._serialized_end = 966 + _globals["_VECTORQUERYRESULTS"]._serialized_start = 968 + _globals["_VECTORQUERYRESULTS"]._serialized_end = 1032 + _globals["_SEGMENTSERVERRESPONSE"]._serialized_start = 1034 + _globals["_SEGMENTSERVERRESPONSE"]._serialized_end = 1074 + _globals["_GETVECTORSREQUEST"]._serialized_start = 1076 + _globals["_GETVECTORSREQUEST"]._serialized_end = 1128 + _globals["_GETVECTORSRESPONSE"]._serialized_start = 1130 + _globals["_GETVECTORSRESPONSE"]._serialized_end = 1198 + _globals["_QUERYVECTORSREQUEST"]._serialized_start = 1201 + _globals["_QUERYVECTORSREQUEST"]._serialized_end = 1335 + _globals["_QUERYVECTORSRESPONSE"]._serialized_start = 1337 + _globals["_QUERYVECTORSRESPONSE"]._serialized_end = 1404 + _globals["_SEGMENTSERVER"]._serialized_start = 1549 + _globals["_SEGMENTSERVER"]._serialized_end = 1697 + _globals["_VECTORREADER"]._serialized_start = 1700 + _globals["_VECTORREADER"]._serialized_end = 1862 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/chroma_pb2.pyi b/chromadb/proto/chroma_pb2.pyi index 6d06e074c06..0b52141e64a 100644 --- a/chromadb/proto/chroma_pb2.pyi +++ b/chromadb/proto/chroma_pb2.pyi @@ -1,8 +1,16 @@ +# type: ignore + from google.protobuf.internal import containers as _containers from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message -from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union +from typing import ( + ClassVar as _ClassVar, + Iterable as _Iterable, + Mapping as _Mapping, + Optional as _Optional, + Union as _Union, +) DESCRIPTOR: _descriptor.FileDescriptor @@ -22,6 +30,7 @@ class SegmentScope(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): __slots__ = [] VECTOR: _ClassVar[SegmentScope] METADATA: _ClassVar[SegmentScope] + ADD: Operation UPDATE: Operation UPSERT: Operation @@ -39,7 +48,12 @@ class Vector(_message.Message): dimension: int vector: bytes encoding: ScalarEncoding - def __init__(self, dimension: _Optional[int] = ..., vector: _Optional[bytes] = ..., encoding: _Optional[_Union[ScalarEncoding, str]] = ...) -> None: ... + def __init__( + self, + dimension: _Optional[int] = ..., + vector: _Optional[bytes] = ..., + encoding: _Optional[_Union[ScalarEncoding, str]] = ..., + ) -> None: ... 
class Segment(_message.Message): __slots__ = ["id", "type", "scope", "topic", "collection", "metadata"] @@ -55,7 +69,15 @@ class Segment(_message.Message): topic: str collection: str metadata: UpdateMetadata - def __init__(self, id: _Optional[str] = ..., type: _Optional[str] = ..., scope: _Optional[_Union[SegmentScope, str]] = ..., topic: _Optional[str] = ..., collection: _Optional[str] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ...) -> None: ... + def __init__( + self, + id: _Optional[str] = ..., + type: _Optional[str] = ..., + scope: _Optional[_Union[SegmentScope, str]] = ..., + topic: _Optional[str] = ..., + collection: _Optional[str] = ..., + metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., + ) -> None: ... class UpdateMetadataValue(_message.Message): __slots__ = ["string_value", "int_value", "float_value"] @@ -65,20 +87,32 @@ class UpdateMetadataValue(_message.Message): string_value: str int_value: int float_value: float - def __init__(self, string_value: _Optional[str] = ..., int_value: _Optional[int] = ..., float_value: _Optional[float] = ...) -> None: ... + def __init__( + self, + string_value: _Optional[str] = ..., + int_value: _Optional[int] = ..., + float_value: _Optional[float] = ..., + ) -> None: ... class UpdateMetadata(_message.Message): __slots__ = ["metadata"] + class MetadataEntry(_message.Message): __slots__ = ["key", "value"] KEY_FIELD_NUMBER: _ClassVar[int] VALUE_FIELD_NUMBER: _ClassVar[int] key: str value: UpdateMetadataValue - def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[UpdateMetadataValue, _Mapping]] = ...) -> None: ... + def __init__( + self, + key: _Optional[str] = ..., + value: _Optional[_Union[UpdateMetadataValue, _Mapping]] = ..., + ) -> None: ... METADATA_FIELD_NUMBER: _ClassVar[int] metadata: _containers.MessageMap[str, UpdateMetadataValue] - def __init__(self, metadata: _Optional[_Mapping[str, UpdateMetadataValue]] = ...) -> None: ... + def __init__( + self, metadata: _Optional[_Mapping[str, UpdateMetadataValue]] = ... + ) -> None: ... class SubmitEmbeddingRecord(_message.Message): __slots__ = ["id", "vector", "metadata", "operation"] @@ -90,7 +124,13 @@ class SubmitEmbeddingRecord(_message.Message): vector: Vector metadata: UpdateMetadata operation: Operation - def __init__(self, id: _Optional[str] = ..., vector: _Optional[_Union[Vector, _Mapping]] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., operation: _Optional[_Union[Operation, str]] = ...) -> None: ... + def __init__( + self, + id: _Optional[str] = ..., + vector: _Optional[_Union[Vector, _Mapping]] = ..., + metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., + operation: _Optional[_Union[Operation, str]] = ..., + ) -> None: ... class VectorEmbeddingRecord(_message.Message): __slots__ = ["id", "seq_id", "vector"] @@ -100,7 +140,12 @@ class VectorEmbeddingRecord(_message.Message): id: str seq_id: bytes vector: Vector - def __init__(self, id: _Optional[str] = ..., seq_id: _Optional[bytes] = ..., vector: _Optional[_Union[Vector, _Mapping]] = ...) -> None: ... + def __init__( + self, + id: _Optional[str] = ..., + seq_id: _Optional[bytes] = ..., + vector: _Optional[_Union[Vector, _Mapping]] = ..., + ) -> None: ... 
class VectorQueryResult(_message.Message): __slots__ = ["id", "seq_id", "distance", "vector"] @@ -112,13 +157,21 @@ class VectorQueryResult(_message.Message): seq_id: bytes distance: float vector: Vector - def __init__(self, id: _Optional[str] = ..., seq_id: _Optional[bytes] = ..., distance: _Optional[float] = ..., vector: _Optional[_Union[Vector, _Mapping]] = ...) -> None: ... + def __init__( + self, + id: _Optional[str] = ..., + seq_id: _Optional[bytes] = ..., + distance: _Optional[float] = ..., + vector: _Optional[_Union[Vector, _Mapping]] = ..., + ) -> None: ... class VectorQueryResults(_message.Message): __slots__ = ["results"] RESULTS_FIELD_NUMBER: _ClassVar[int] results: _containers.RepeatedCompositeFieldContainer[VectorQueryResult] - def __init__(self, results: _Optional[_Iterable[_Union[VectorQueryResult, _Mapping]]] = ...) -> None: ... + def __init__( + self, results: _Optional[_Iterable[_Union[VectorQueryResult, _Mapping]]] = ... + ) -> None: ... class SegmentServerResponse(_message.Message): __slots__ = ["success"] @@ -132,13 +185,18 @@ class GetVectorsRequest(_message.Message): SEGMENT_ID_FIELD_NUMBER: _ClassVar[int] ids: _containers.RepeatedScalarFieldContainer[str] segment_id: str - def __init__(self, ids: _Optional[_Iterable[str]] = ..., segment_id: _Optional[str] = ...) -> None: ... + def __init__( + self, ids: _Optional[_Iterable[str]] = ..., segment_id: _Optional[str] = ... + ) -> None: ... class GetVectorsResponse(_message.Message): __slots__ = ["records"] RECORDS_FIELD_NUMBER: _ClassVar[int] records: _containers.RepeatedCompositeFieldContainer[VectorEmbeddingRecord] - def __init__(self, records: _Optional[_Iterable[_Union[VectorEmbeddingRecord, _Mapping]]] = ...) -> None: ... + def __init__( + self, + records: _Optional[_Iterable[_Union[VectorEmbeddingRecord, _Mapping]]] = ..., + ) -> None: ... class QueryVectorsRequest(_message.Message): __slots__ = ["vectors", "k", "allowed_ids", "include_embeddings", "segment_id"] @@ -152,10 +210,19 @@ class QueryVectorsRequest(_message.Message): allowed_ids: _containers.RepeatedScalarFieldContainer[str] include_embeddings: bool segment_id: str - def __init__(self, vectors: _Optional[_Iterable[_Union[Vector, _Mapping]]] = ..., k: _Optional[int] = ..., allowed_ids: _Optional[_Iterable[str]] = ..., include_embeddings: bool = ..., segment_id: _Optional[str] = ...) -> None: ... + def __init__( + self, + vectors: _Optional[_Iterable[_Union[Vector, _Mapping]]] = ..., + k: _Optional[int] = ..., + allowed_ids: _Optional[_Iterable[str]] = ..., + include_embeddings: bool = ..., + segment_id: _Optional[str] = ..., + ) -> None: ... class QueryVectorsResponse(_message.Message): __slots__ = ["results"] RESULTS_FIELD_NUMBER: _ClassVar[int] results: _containers.RepeatedCompositeFieldContainer[VectorQueryResults] - def __init__(self, results: _Optional[_Iterable[_Union[VectorQueryResults, _Mapping]]] = ...) -> None: ... + def __init__( + self, results: _Optional[_Iterable[_Union[VectorQueryResults, _Mapping]]] = ... + ) -> None: ... diff --git a/chromadb/proto/chroma_pb2_grpc.py b/chromadb/proto/chroma_pb2_grpc.py index f5cc85a36bd..af3c29b622d 100644 --- a/chromadb/proto/chroma_pb2_grpc.py +++ b/chromadb/proto/chroma_pb2_grpc.py @@ -1,3 +1,4 @@ +# type: ignore # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
"""Client and server classes corresponding to protobuf-defined services.""" import grpc @@ -6,7 +7,7 @@ class SegmentServerStub(object): - """Segment Server Interface + """Segment Server Interface TODO: figure out subpackaging, ideally this file is colocated with the segment server implementation """ @@ -18,19 +19,19 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.LoadSegment = channel.unary_unary( - '/chroma.SegmentServer/LoadSegment', - request_serializer=chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, - ) + "/chroma.SegmentServer/LoadSegment", + request_serializer=chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, + ) self.ReleaseSegment = channel.unary_unary( - '/chroma.SegmentServer/ReleaseSegment', - request_serializer=chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, - ) + "/chroma.SegmentServer/ReleaseSegment", + request_serializer=chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, + ) class SegmentServerServicer(object): - """Segment Server Interface + """Segment Server Interface TODO: figure out subpackaging, ideally this file is colocated with the segment server implementation """ @@ -38,80 +39,103 @@ class SegmentServerServicer(object): def LoadSegment(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def ReleaseSegment(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def add_SegmentServerServicer_to_server(servicer, server): rpc_method_handlers = { - 'LoadSegment': grpc.unary_unary_rpc_method_handler( - servicer.LoadSegment, - request_deserializer=chromadb_dot_proto_dot_chroma__pb2.Segment.FromString, - response_serializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.SerializeToString, - ), - 'ReleaseSegment': grpc.unary_unary_rpc_method_handler( - servicer.ReleaseSegment, - request_deserializer=chromadb_dot_proto_dot_chroma__pb2.Segment.FromString, - response_serializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.SerializeToString, - ), + "LoadSegment": grpc.unary_unary_rpc_method_handler( + servicer.LoadSegment, + request_deserializer=chromadb_dot_proto_dot_chroma__pb2.Segment.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.SerializeToString, + ), + "ReleaseSegment": grpc.unary_unary_rpc_method_handler( + servicer.ReleaseSegment, + request_deserializer=chromadb_dot_proto_dot_chroma__pb2.Segment.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'chroma.SegmentServer', 
rpc_method_handlers) + "chroma.SegmentServer", rpc_method_handlers + ) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. class SegmentServer(object): - """Segment Server Interface + """Segment Server Interface TODO: figure out subpackaging, ideally this file is colocated with the segment server implementation """ @staticmethod - def LoadSegment(request, + def LoadSegment( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SegmentServer/LoadSegment', + "/chroma.SegmentServer/LoadSegment", chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def ReleaseSegment(request, + def ReleaseSegment( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SegmentServer/ReleaseSegment', + "/chroma.SegmentServer/ReleaseSegment", chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) class VectorReaderStub(object): - """Vector Reader Interface - - """ + """Vector Reader Interface""" def __init__(self, channel): """Constructor. @@ -120,89 +144,110 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.GetVectors = channel.unary_unary( - '/chroma.VectorReader/GetVectors', - request_serializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsResponse.FromString, - ) + "/chroma.VectorReader/GetVectors", + request_serializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsResponse.FromString, + ) self.QueryVectors = channel.unary_unary( - '/chroma.VectorReader/QueryVectors', - request_serializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsResponse.FromString, - ) + "/chroma.VectorReader/QueryVectors", + request_serializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsResponse.FromString, + ) class VectorReaderServicer(object): - """Vector Reader Interface - - """ + """Vector Reader Interface""" def GetVectors(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def QueryVectors(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def add_VectorReaderServicer_to_server(servicer, server): rpc_method_handlers = { - 'GetVectors': grpc.unary_unary_rpc_method_handler( - servicer.GetVectors, - request_deserializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsRequest.FromString, - response_serializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsResponse.SerializeToString, - ), - 'QueryVectors': grpc.unary_unary_rpc_method_handler( - servicer.QueryVectors, - request_deserializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsRequest.FromString, - response_serializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsResponse.SerializeToString, - ), + "GetVectors": grpc.unary_unary_rpc_method_handler( + servicer.GetVectors, + request_deserializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsResponse.SerializeToString, + ), + "QueryVectors": grpc.unary_unary_rpc_method_handler( + servicer.QueryVectors, + request_deserializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'chroma.VectorReader', rpc_method_handlers) + "chroma.VectorReader", rpc_method_handlers + ) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. 
class VectorReader(object): - """Vector Reader Interface - - """ + """Vector Reader Interface""" @staticmethod - def GetVectors(request, + def GetVectors( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.VectorReader/GetVectors', + "/chroma.VectorReader/GetVectors", chromadb_dot_proto_dot_chroma__pb2.GetVectorsRequest.SerializeToString, chromadb_dot_proto_dot_chroma__pb2.GetVectorsResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def QueryVectors(request, + def QueryVectors( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.VectorReader/QueryVectors', + "/chroma.VectorReader/QueryVectors", chromadb_dot_proto_dot_chroma__pb2.QueryVectorsRequest.SerializeToString, chromadb_dot_proto_dot_chroma__pb2.QueryVectorsResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/chromadb/segment/__init__.py b/chromadb/segment/__init__.py index 2c2570796fc..f9e5afa7903 100644 --- a/chromadb/segment/__init__.py +++ b/chromadb/segment/__init__.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence, TypeVar, Type +from typing import Optional, Sequence, TypeVar, Type from abc import abstractmethod from chromadb.types import ( Collection, @@ -126,19 +126,3 @@ def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None it can preload segments as needed. This is only a hint, and implementations are free to ignore it.""" pass - - -class SegmentDirectory(Component): - """A segment directory is a data interface that manages the location of segments. 
Concretely, this - means that for clustered chroma, it provides the grpc endpoint for a segment.""" - - @abstractmethod - def get_segment_endpoint(self, segment: Segment) -> str: - """Return the segment residence for a given segment ID""" - - @abstractmethod - def register_updated_segment_callback( - self, callback: Callable[[Segment], None] - ) -> None: - """Register a callback that will be called when a segment is updated""" - pass diff --git a/chromadb/segment/distributed/__init__.py b/chromadb/segment/distributed/__init__.py new file mode 100644 index 00000000000..08efdafd18c --- /dev/null +++ b/chromadb/segment/distributed/__init__.py @@ -0,0 +1,70 @@ +from abc import abstractmethod +from typing import Any, Callable, List + +from overrides import EnforceOverrides, overrides +from chromadb.config import Component, System +from chromadb.types import Segment + + +class SegmentDirectory(Component): + """A segment directory is a data interface that manages the location of segments. Concretely, this + means that for clustered chroma, it provides the grpc endpoint for a segment.""" + + @abstractmethod + def get_segment_endpoint(self, segment: Segment) -> str: + """Return the segment residence for a given segment ID""" + + @abstractmethod + def register_updated_segment_callback( + self, callback: Callable[[Segment], None] + ) -> None: + """Register a callback that will be called when a segment is updated""" + pass + + +Memberlist = List[str] + + +class MemberlistProvider(Component, EnforceOverrides): + """Returns the latest memberlist and provides a callback for when it changes. This + callback may be called from a different thread than the one that registered it. Callers should ensure + that they are thread-safe.""" + + callbacks: List[Callable[[Memberlist], Any]] + + def __init__(self, system: System): + self.callbacks = [] + super().__init__(system) + + @abstractmethod + def get_memberlist(self) -> Memberlist: + """Returns the latest memberlist""" + pass + + @abstractmethod + def set_memberlist_name(self, memberlist: str) -> None: + """Sets the memberlist that this provider will watch""" + pass + + @overrides + def stop(self) -> None: + """Stops watching the memberlist""" + self.callbacks = [] + + def register_updated_memberlist_callback( + self, callback: Callable[[Memberlist], Any] + ) -> None: + """Registers a callback that will be called when the memberlist changes. May be called many times + with the same memberlist, so callers should be idempotent. May be called from a different thread. + """ + self.callbacks.append(callback) + + def unregister_updated_memberlist_callback( + self, callback: Callable[[Memberlist], Any] + ) -> bool: + """Unregisters a callback that was previously registered. 
Returns True if the callback was + successfully unregistered, False if it was not ever registered.""" + if callback in self.callbacks: + self.callbacks.remove(callback) + return True + return False diff --git a/chromadb/segment/impl/distributed/segment_directory.py b/chromadb/segment/impl/distributed/segment_directory.py new file mode 100644 index 00000000000..9068f5ce645 --- /dev/null +++ b/chromadb/segment/impl/distributed/segment_directory.py @@ -0,0 +1,226 @@ +from typing import Any, Callable, Dict, Optional, cast +from overrides import EnforceOverrides, override +from chromadb.config import System +from chromadb.segment.distributed import ( + Memberlist, + MemberlistProvider, + SegmentDirectory, +) +from chromadb.types import Segment +from kubernetes import client, config, watch +from kubernetes.client.rest import ApiException +import threading + +# These could go in config but given that they will rarely change, they are here for now to avoid +# polluting the config file further. +WATCH_TIMEOUT_SECONDS = 10 +KUBERNETES_NAMESPACE = "chroma" +KUBERNETES_GROUP = "chroma.cluster" + + +class MockMemberlistProvider(MemberlistProvider, EnforceOverrides): + """A mock memberlist provider for testing""" + + _memberlist: Memberlist + + def __init__(self, system: System): + super().__init__(system) + self._memberlist = ["a", "b", "c"] + + @override + def get_memberlist(self) -> Memberlist: + return self._memberlist + + @override + def set_memberlist_name(self, memberlist: str) -> None: + pass # The mock provider does not need to set the memberlist name + + def update_memberlist(self, memberlist: Memberlist) -> None: + """Updates the memberlist and calls all registered callbacks. This mocks an update from a k8s CR""" + self._memberlist = memberlist + for callback in self.callbacks: + callback(memberlist) + + +class CustomResourceMemberlistProvider(MemberlistProvider, EnforceOverrides): + """A memberlist provider that uses a k8s custom resource to store the memberlist""" + + _kubernetes_api: client.CustomObjectsApi + _memberlist_name: Optional[str] + _curr_memberlist: Optional[Memberlist] + _curr_memberlist_mutex: threading.Lock + _watch_thread: Optional[threading.Thread] + _kill_watch_thread: threading.Event + + def __init__(self, system: System): + super().__init__(system) + config.load_config() + self._kubernetes_api = client.CustomObjectsApi() + self._watch_thread = None + self._memberlist_name = None + self._curr_memberlist = None + self._curr_memberlist_mutex = threading.Lock() + self._kill_watch_thread = threading.Event() + + @override + def start(self) -> None: + if self._memberlist_name is None: + raise ValueError("Memberlist name must be set before starting") + self.get_memberlist() + self._watch_worker_memberlist() + return super().start() + + @override + def stop(self) -> None: + self._curr_memberlist = None + self._memberlist_name = None + + # Stop the watch thread + self._kill_watch_thread.set() + if self._watch_thread is not None: + self._watch_thread.join() + self._watch_thread = None + self._kill_watch_thread.clear() + return super().stop() + + @override + def reset_state(self) -> None: + if not self._system.settings.require("allow_reset"): + raise ValueError( + "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted." 
+ ) + if self._memberlist_name: + self._kubernetes_api.patch_namespaced_custom_object( + group=KUBERNETES_GROUP, + version="v1", + namespace=KUBERNETES_NAMESPACE, + plural="memberlists", + name=self._memberlist_name, + body={ + "kind": "MemberList", + "spec": {"members": []}, + }, + ) + + @override + def get_memberlist(self) -> Memberlist: + if self._curr_memberlist is None: + self._curr_memberlist = self._fetch_memberlist() + return self._curr_memberlist + + @override + def set_memberlist_name(self, memberlist: str) -> None: + self._memberlist_name = memberlist + + def _fetch_memberlist(self) -> Memberlist: + api_response = self._kubernetes_api.get_namespaced_custom_object( + group=KUBERNETES_GROUP, + version="v1", + namespace=KUBERNETES_NAMESPACE, + plural="memberlists", + name=f"{self._memberlist_name}", + ) + api_response = cast(Dict[str, Any], api_response) + if "spec" not in api_response: + return [] + response_spec = cast(Dict[str, Any], api_response["spec"]) + return self._parse_response_memberlist(response_spec) + + def _watch_worker_memberlist(self) -> None: + # TODO: We may want to make this watch function a library function that can be used by other + # components that need to watch k8s custom resources. + def run_watch() -> None: + w = watch.Watch() + + def do_watch() -> None: + for event in w.stream( + self._kubernetes_api.list_namespaced_custom_object, + group=KUBERNETES_GROUP, + version="v1", + namespace=KUBERNETES_NAMESPACE, + plural="memberlists", + field_selector=f"metadata.name={self._memberlist_name}", + timeout_seconds=WATCH_TIMEOUT_SECONDS, + ): + event = cast(Dict[str, Any], event) + response_spec = event["object"]["spec"] + response_spec = cast(Dict[str, Any], response_spec) + with self._curr_memberlist_mutex: + self._curr_memberlist = self._parse_response_memberlist( + response_spec + ) + self._notify(self._curr_memberlist) + + # Watch the custom resource for changes + # Watch with a timeout and retry so we can gracefully stop this if needed + while not self._kill_watch_thread.is_set(): + try: + do_watch() + except ApiException as e: + # If status code is 410, the watch has expired and we need to start a new one. 
+ if e.status == 410: + pass + return + + if self._watch_thread is None: + thread = threading.Thread(target=run_watch, daemon=True) + thread.start() + self._watch_thread = thread + else: + raise Exception("A watch thread is already running.") + + def _parse_response_memberlist( + self, api_response_spec: Dict[str, Any] + ) -> Memberlist: + if "members" not in api_response_spec: + return [] + return [m["url"] for m in api_response_spec["members"]] + + def _notify(self, memberlist: Memberlist) -> None: + for callback in self.callbacks: + callback(memberlist) + + +class RendezvousHashSegmentDirectory(SegmentDirectory, EnforceOverrides): + _memberlist_provider: MemberlistProvider + _curr_memberlist_mutex: threading.Lock + _curr_memberlist: Optional[Memberlist] + + def __init__(self, system: System): + super().__init__(system) + self._memberlist_provider = self.require(MemberlistProvider) + memberlist_name = system.settings.require("worker_memberlist_name") + self._memberlist_provider.set_memberlist_name(memberlist_name) + + self._curr_memberlist = None + self._curr_memberlist_mutex = threading.Lock() + + @override + def start(self) -> None: + self._curr_memberlist = self._memberlist_provider.get_memberlist() + self._memberlist_provider.register_updated_memberlist_callback( + self._update_memberlist + ) + return super().start() + + @override + def stop(self) -> None: + self._memberlist_provider.unregister_updated_memberlist_callback( + self._update_memberlist + ) + return super().stop() + + @override + def get_segment_endpoint(self, segment: Segment) -> str: + # TODO: This should rendezvous hash the segment ID to a worker given the current memberlist + return "segment-worker.chroma:50051" + + @override + def register_updated_segment_callback( + self, callback: Callable[[Segment], None] + ) -> None: + raise NotImplementedError() + + def _update_memberlist(self, memberlist: Memberlist) -> None: + with self._curr_memberlist_mutex: + self._curr_memberlist = memberlist diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py index ea6c2f0267f..a7c673920a8 100644 --- a/chromadb/segment/impl/manager/distributed.py +++ b/chromadb/segment/impl/manager/distributed.py @@ -1,7 +1,7 @@ from threading import Lock import grpc -from chromadb.proto.chroma_pb2_grpc import SegmentServerStub +from chromadb.proto.chroma_pb2_grpc import SegmentServerStub # type: ignore from chromadb.proto.convert import to_proto_segment from chromadb.segment import ( SegmentImplementation, @@ -14,10 +14,9 @@ from chromadb.config import System, get_class from chromadb.db.system import SysDB from overrides import override -from enum import Enum -from chromadb.segment import SegmentDirectory +from chromadb.segment.distributed import SegmentDirectory from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata -from typing import Dict, List, Type, Sequence, Optional, cast +from typing import Dict, Type, Sequence, Optional, cast from uuid import UUID, uuid4 from collections import defaultdict @@ -117,7 +116,7 @@ def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None if grpc_url not in self._segment_server_stubs: channel = grpc.insecure_channel(grpc_url) - self._segment_server_stubs[grpc_url] = SegmentServerStub(channel) # type: ignore + self._segment_server_stubs[grpc_url] = SegmentServerStub(channel) self._segment_server_stubs[grpc_url].LoadSegment( to_proto_segment(segment) diff --git a/chromadb/segment/impl/manager/segment_directory.py 
b/chromadb/segment/impl/manager/segment_directory.py deleted file mode 100644 index 7e086e26c5e..00000000000 --- a/chromadb/segment/impl/manager/segment_directory.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Callable -from overrides import EnforceOverrides, override - -from chromadb.segment import SegmentDirectory -from chromadb.types import Segment - - -class DockerComposeSegmentDirectory(SegmentDirectory, EnforceOverrides): - """A segment directory that uses docker-compose to manage segment endpoints""" - - @override - def get_segment_endpoint(self, segment: Segment) -> str: - # This is just a stub for now, as we don't have a real coordinator to assign and manage this - return "segment-worker:50051" - - @override - def register_updated_segment_callback( - self, callback: Callable[[Segment], None] - ) -> None: - # Updates are not supported for docker-compose yet, as there is only a single, static - # indexing node - pass - - -class KubernetesSegmentDirectory(SegmentDirectory, EnforceOverrides): - @override - def get_segment_endpoint(self, segment: Segment) -> str: - return "segment-worker.chroma:50051" - - @override - def register_updated_segment_callback( - self, callback: Callable[[Segment], None] - ) -> None: - # Updates are not supported for docker-compose yet, as there is only a single, static - # indexing node - pass diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index aa4e530e384..af66ef2513f 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -194,7 +194,6 @@ def fastapi_persistent() -> Generator[System, None, None]: def basic_http_client() -> Generator[System, None, None]: settings = Settings( chroma_api_impl="chromadb.api.fastapi.FastAPI", - chroma_server_host="localhost", chroma_server_http_port="8000", allow_reset=True, ) diff --git a/chromadb/test/ingest/test_producer_consumer.py b/chromadb/test/ingest/test_producer_consumer.py index de2ed592d07..399b945bf5f 100644 --- a/chromadb/test/ingest/test_producer_consumer.py +++ b/chromadb/test/ingest/test_producer_consumer.py @@ -55,16 +55,14 @@ def sqlite_persistent() -> Generator[Tuple[Producer, Consumer], None, None]: def pulsar() -> Generator[Tuple[Producer, Consumer], None, None]: """Fixture generator for pulsar Producer + Consumer. This fixture requires a running - pulsar cluster. You can use bin/cluster-test.sh to start a standalone pulsar and run this test + pulsar cluster. You can use bin/cluster-test.sh to start a standalone pulsar and run this test. + Assumes pulsar_broker_url etc is set from the environment variables like PULSAR_BROKER_URL. 
""" system = System( Settings( allow_reset=True, chroma_producer_impl="chromadb.ingest.impl.pulsar.PulsarProducer", chroma_consumer_impl="chromadb.ingest.impl.pulsar.PulsarConsumer", - pulsar_broker_url="localhost", - pulsar_admin_port="8080", - pulsar_broker_port="6650", ) ) producer = system.require(Producer) diff --git a/chromadb/test/segment/distributed/test_memberlist_provider.py b/chromadb/test/segment/distributed/test_memberlist_provider.py new file mode 100644 index 00000000000..8fda6998023 --- /dev/null +++ b/chromadb/test/segment/distributed/test_memberlist_provider.py @@ -0,0 +1,122 @@ +# Tests the CustomResourceMemberlist provider +import threading +from kubernetes import client, config +import pytest +import os +from chromadb.config import System, Settings +from chromadb.segment.distributed import Memberlist +from chromadb.segment.impl.distributed.segment_directory import ( + CustomResourceMemberlistProvider, + KUBERNETES_GROUP, + KUBERNETES_NAMESPACE, +) +import time + +NOT_CLUSTER_ONLY = os.getenv("CHROMA_CLUSTER_TEST_ONLY") != "1" + + +def skip_if_not_cluster() -> pytest.MarkDecorator: + return pytest.mark.skipif( + NOT_CLUSTER_ONLY, + reason="Requires Kubernetes to be running with a valid config", + ) + + +# Used for testing to update the memberlist CRD +def update_memberlist(n: int, memberlist_name: str = "worker-memberlist") -> Memberlist: + config.load_config() + api_instance = client.CustomObjectsApi() + + members = [{"url": f"ip.{i}.com"} for i in range(1, n + 1)] + + body = { + "kind": "MemberList", + "metadata": {"name": memberlist_name}, + "spec": {"members": members}, + } + + _ = api_instance.patch_namespaced_custom_object( + group=KUBERNETES_GROUP, + version="v1", + namespace=KUBERNETES_NAMESPACE, + plural="memberlists", + name=memberlist_name, + body=body, + ) + + return [m["url"] for m in members] + + +def compare_memberlists(m1: Memberlist, m2: Memberlist) -> bool: + return sorted(m1) == sorted(m2) + + +@skip_if_not_cluster() +def test_can_get_memberlist() -> None: + # This test assumes that the memberlist CRD is already created with the name "worker-memberlist" + system = System(Settings(allow_reset=True)) + provider = system.instance(CustomResourceMemberlistProvider) + provider.set_memberlist_name("worker-memberlist") + system.reset_state() + system.start() + + # Update the memberlist + members = update_memberlist(3) + + # Check that the memberlist is updated after a short delay + time.sleep(2) + assert compare_memberlists(provider.get_memberlist(), members) + + system.stop() + + +@skip_if_not_cluster() +def test_can_update_memberlist_multiple_times() -> None: + # This test assumes that the memberlist CRD is already created with the name "worker-memberlist" + system = System(Settings(allow_reset=True)) + provider = system.instance(CustomResourceMemberlistProvider) + provider.set_memberlist_name("worker-memberlist") + system.reset_state() + system.start() + + # Update the memberlist + members = update_memberlist(3) + + # Check that the memberlist is updated after a short delay + time.sleep(2) + assert compare_memberlists(provider.get_memberlist(), members) + + # Update the memberlist again + members = update_memberlist(5) + + # Check that the memberlist is updated after a short delay + time.sleep(2) + assert compare_memberlists(provider.get_memberlist(), members) + + system.stop() + + +@skip_if_not_cluster() +def test_stop_memberlist_kills_thread() -> None: + # This test assumes that the memberlist CRD is already created with the name "worker-memberlist" + 
system = System(Settings(allow_reset=True)) + provider = system.instance(CustomResourceMemberlistProvider) + provider.set_memberlist_name("worker-memberlist") + system.reset_state() + system.start() + + # Make sure a background thread is running + assert len(threading.enumerate()) == 2 + + # Update the memberlist + members = update_memberlist(3) + + # Check that the memberlist is updated after a short delay + time.sleep(2) + assert compare_memberlists(provider.get_memberlist(), members) + + # Stop the system + system.stop() + + # Check to make sure only one thread is running + assert len(threading.enumerate()) == 1 diff --git a/docker-compose.cluster.test.yml b/docker-compose.cluster.test.yml deleted file mode 100644 index 8b3f83eda7f..00000000000 --- a/docker-compose.cluster.test.yml +++ /dev/null @@ -1,96 +0,0 @@ -# This docker compose file is not meant to be used. It is a work in progress -# for the distributed version of Chroma. It is not yet functional. - -version: '3.9' - -networks: - net: - driver: bridge - -services: - server: - image: server - build: - context: . - dockerfile: Dockerfile - volumes: - - ./:/chroma - - index_data:/index_data - command: uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 8000 --log-config chromadb/log_config.yml - environment: - - IS_PERSISTENT=TRUE - - CHROMA_PRODUCER_IMPL=chromadb.ingest.impl.pulsar.PulsarProducer - - CHROMA_CONSUMER_IMPL=chromadb.ingest.impl.pulsar.PulsarConsumer - - CHROMA_SEGMENT_MANAGER_IMPL=chromadb.segment.impl.manager.distributed.DistributedSegmentManager - - PULSAR_BROKER_URL=pulsar - - PULSAR_BROKER_PORT=6650 - - PULSAR_ADMIN_PORT=8080 - - ANONYMIZED_TELEMETRY=False - - ALLOW_RESET=True - ports: - - 8000:8000 - depends_on: - pulsar: - condition: service_healthy - networks: - - net - - segment-worker: - image: segment-worker - build: - context: . - dockerfile: Dockerfile - volumes: - - ./:/chroma - - index_data:/index_data - command: python -m chromadb.segment.impl.distributed.server - environment: - - IS_PERSISTENT=TRUE - - CHROMA_PRODUCER_IMPL=chromadb.ingest.impl.pulsar.PulsarProducer - - CHROMA_CONSUMER_IMPL=chromadb.ingest.impl.pulsar.PulsarConsumer - - PULSAR_BROKER_URL=pulsar - - PULSAR_BROKER_PORT=6650 - - PULSAR_ADMIN_PORT=8080 - - CHROMA_SERVER_GRPC_PORT=50051 - - ANONYMIZED_TELEMETRY=False - - ALLOW_RESET=True - ports: - - 50051:50051 - depends_on: - pulsar: - condition: service_healthy - networks: - - net - - pulsar: - image: apachepulsar/pulsar - volumes: - - pulsardata:/pulsar/data - - pulsarconf:/pulsar/conf - command: bin/pulsar standalone - ports: - - 6650:6650 - - 8080:8080 - networks: - - net - healthcheck: - test: - [ - "CMD", - "curl", - "-f", - "localhost:8080/admin/v2/brokers/health" - ] - interval: 3s - timeout: 1m - retries: 10 - -volumes: - index_data: - driver: local - backups: - driver: local - pulsardata: - driver: local - pulsarconf: - driver: local diff --git a/docker-compose.cluster.yml b/docker-compose.cluster.yml deleted file mode 100644 index 2aaee58c566..00000000000 --- a/docker-compose.cluster.yml +++ /dev/null @@ -1,92 +0,0 @@ -# This docker compose file is not meant to be used. It is a work in progress -# for the distributed version of Chroma. It is not yet functional. - -version: '3.9' - -networks: - net: - driver: bridge - -services: - server: - image: server - build: - context: . 
- dockerfile: Dockerfile - volumes: - - ./:/chroma - - index_data:/index_data - command: uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 8000 --log-config chromadb/log_config.yml - environment: - - IS_PERSISTENT=TRUE - - CHROMA_PRODUCER_IMPL=chromadb.ingest.impl.pulsar.PulsarProducer - - CHROMA_CONSUMER_IMPL=chromadb.ingest.impl.pulsar.PulsarConsumer - - CHROMA_SEGMENT_MANAGER_IMPL=chromadb.segment.impl.manager.distributed.DistributedSegmentManager - - PULSAR_BROKER_URL=pulsar - - PULSAR_BROKER_PORT=6650 - - PULSAR_ADMIN_PORT=8080 - ports: - - 8000:8000 - depends_on: - pulsar: - condition: service_healthy - networks: - - net - - segment-worker: - image: segment-worker - build: - context: . - dockerfile: Dockerfile - volumes: - - ./:/chroma - - index_data:/index_data - command: python -m chromadb.segment.impl.distributed.server - environment: - - IS_PERSISTENT=TRUE - - CHROMA_PRODUCER_IMPL=chromadb.ingest.impl.pulsar.PulsarProducer - - CHROMA_CONSUMER_IMPL=chromadb.ingest.impl.pulsar.PulsarConsumer - - PULSAR_BROKER_URL=pulsar - - PULSAR_BROKER_PORT=6650 - - PULSAR_ADMIN_PORT=8080 - - CHROMA_SERVER_GRPC_PORT=50051 - ports: - - 50051:50051 - depends_on: - pulsar: - condition: service_healthy - networks: - - net - - pulsar: - image: apachepulsar/pulsar - volumes: - - pulsardata:/pulsar/data - - pulsarconf:/pulsar/conf - command: bin/pulsar standalone - ports: - - 6650:6650 - - 8080:8080 - networks: - - net - healthcheck: - test: - [ - "CMD", - "curl", - "-f", - "localhost:8080/admin/v2/brokers/health" - ] - interval: 3s - timeout: 1m - retries: 10 - -volumes: - index_data: - driver: local - backups: - driver: local - pulsardata: - driver: local - pulsarconf: - driver: local diff --git a/k8s/WARNING.md b/k8s/WARNING.md new file mode 100644 index 00000000000..7933f8a712a --- /dev/null +++ b/k8s/WARNING.md @@ -0,0 +1,3 @@ +# These kubernetes manifests are UNDER ACTIVE DEVELOPMENT and are not yet ready for production use. +# They will be used for the upcoming distributed version of chroma. They are not even ready +# for testing yet. Please do not use them unless you are working on the distributed version of chroma. diff --git a/k8s/cr/worker_memberlist_cr.yaml b/k8s/cr/worker_memberlist_cr.yaml new file mode 100644 index 00000000000..a56c3ef1525 --- /dev/null +++ b/k8s/cr/worker_memberlist_cr.yaml @@ -0,0 +1,43 @@ +# These kubernetes manifests are UNDER ACTIVE DEVELOPMENT and are not yet ready for production use. +# They will be used for the upcoming distributed version of chroma. They are not even ready +# for testing yet. Please do not use them unless you are working on the distributed version of chroma. 
+ +# Create a memberlist called worker-memberlist +apiVersion: chroma.cluster/v1 +kind: MemberList +metadata: + name: worker-memberlist + namespace: chroma +spec: + members: + +--- + +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: worker-memberlist-reader +rules: +- apiGroups: + - chroma.cluster + resources: + - memberlists + verbs: + - get + - list + - watch + +--- + +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: worker-memberlist-reader-binding +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: worker-memberlist-reader +subjects: +- kind: ServiceAccount + name: default + namespace: chroma diff --git a/k8s/crd/memberlist_crd.yaml b/k8s/crd/memberlist_crd.yaml new file mode 100644 index 00000000000..96be7388d01 --- /dev/null +++ b/k8s/crd/memberlist_crd.yaml @@ -0,0 +1,36 @@ +# These kubernetes manifests are UNDER ACTIVE DEVELOPMENT and are not yet ready for production use. +# They will be used for the upcoming distributed version of chroma. They are not even ready +# for testing yet. Please do not use them unless you are working on the distributed version of chroma. + +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + name: memberlists.chroma.cluster +spec: + group: chroma.cluster + versions: + - name: v1 + served: true + storage: true + schema: + openAPIV3Schema: + type: object + properties: + spec: + type: object + properties: + members: + type: array + items: + type: object + properties: + url: + type: string + pattern: '^(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?$' + scope: Namespaced + names: + plural: memberlists + singular: memberlist + kind: MemberList + shortNames: + - ml diff --git a/k8s/deployment/kubernetes.yaml b/k8s/deployment/kubernetes.yaml new file mode 100644 index 00000000000..1d7f5330ce6 --- /dev/null +++ b/k8s/deployment/kubernetes.yaml @@ -0,0 +1,240 @@ +# These kubernetes manifests are UNDER ACTIVE DEVELOPMENT and are not yet ready for production use. +# They will be used for the upcoming distributed version of chroma. They are not even ready +# for testing yet. Please do not use them unless you are working on the distributed version of chroma. + +apiVersion: v1 +kind: Namespace +metadata: + name: chroma + +--- + +apiVersion: v1 +kind: Service +metadata: + name: pulsar + namespace: chroma +spec: + ports: + - name: pulsar-port + port: 6650 + targetPort: 6650 + - name: admin-port + port: 8080 + targetPort: 8080 + selector: + app: pulsar + type: ClusterIP + +--- + +# TODO: Should be stateful set locally or managed via terraform into streamnative for cloud deployment +apiVersion: apps/v1 +kind: Deployment +metadata: + name: pulsar + namespace: chroma +spec: + replicas: 1 + selector: + matchLabels: + app: pulsar + template: + metadata: + labels: + app: pulsar + spec: + containers: + - name: pulsar + image: apachepulsar/pulsar + command: [ "/pulsar/bin/pulsar", "standalone" ] + env: + # This is needed by github actions. We force this to be lower everywehre for now. + # Since real deployments will configure/use pulsar this way. 
+ - name: PULSAR_MEM + value: "-Xms128m -Xmx512m" + ports: + - containerPort: 6650 + - containerPort: 8080 + volumeMounts: + - name: pulsardata + mountPath: /pulsar/data + # readinessProbe: + # httpGet: + # path: /admin/v2/brokers/health + # port: 8080 + # initialDelaySeconds: 10 + # periodSeconds: 5 + # livenessProbe: + # httpGet: + # path: /admin/v2/brokers/health + # port: 8080 + # initialDelaySeconds: 20 + # periodSeconds: 10 + volumes: + - name: pulsardata + emptyDir: {} + +--- + +apiVersion: v1 +kind: Service +metadata: + name: server + namespace: chroma +spec: + ports: + - name: server + port: 8000 + targetPort: 8000 + selector: + app: server + type: LoadBalancer + +--- + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: server + namespace: chroma +spec: + replicas: 1 + selector: + matchLabels: + app: server + template: + metadata: + labels: + app: server + spec: + containers: + - name: server + image: server + imagePullPolicy: IfNotPresent + ports: + - containerPort: 8000 + volumeMounts: + - name: chroma + mountPath: /test + env: + - name: IS_PERSISTENT + value: "TRUE" + - name: CHROMA_PRODUCER_IMPL + value: "chromadb.ingest.impl.pulsar.PulsarProducer" + - name: CHROMA_CONSUMER_IMPL + value: "chromadb.ingest.impl.pulsar.PulsarConsumer" + - name: CHROMA_SEGMENT_MANAGER_IMPL + value: "chromadb.segment.impl.manager.distributed.DistributedSegmentManager" + - name: PULSAR_BROKER_URL + value: "pulsar.chroma" + - name: PULSAR_BROKER_PORT + value: "6650" + - name: PULSAR_ADMIN_PORT + value: "8080" + - name: ALLOW_RESET + value: "TRUE" + readinessProbe: + httpGet: + path: /api/v1/heartbeat + port: 8000 + initialDelaySeconds: 10 + periodSeconds: 5 + # livenessProbe: + # httpGet: + # path: /healthz + # port: 8000 + # initialDelaySeconds: 20 + # periodSeconds: 10 + # Ephemeral for now + volumes: + - name: chroma + emptyDir: {} + +--- + +apiVersion: v1 +kind: Service +metadata: + name: segment-server + namespace: chroma +spec: + ports: + - name: segment-server-port + port: 50051 + targetPort: 50051 + selector: + app: segment-server + type: ClusterIP + +--- + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: segment-server + namespace: chroma +spec: + replicas: 1 + selector: + matchLabels: + app: segment-server + template: + metadata: + labels: + app: segment-server + spec: + containers: + - name: segment-server + image: server + imagePullPolicy: IfNotPresent + command: ["python", "-m", "chromadb.segment.impl.distributed.server"] + ports: + - containerPort: 50051 + volumeMounts: + - name: chroma + mountPath: /index_data + env: + - name: IS_PERSISTENT + value: "TRUE" + - name: CHROMA_PRODUCER_IMPL + value: "chromadb.ingest.impl.pulsar.PulsarProducer" + - name: CHROMA_CONSUMER_IMPL + value: "chromadb.ingest.impl.pulsar.PulsarConsumer" + - name: PULSAR_BROKER_URL + value: "pulsar.chroma" + - name: PULSAR_BROKER_PORT + value: "6650" + - name: PULSAR_ADMIN_PORT + value: "8080" + - name: CHROMA_SERVER_GRPC_PORT + value: "50051" + # readinessProbe: + # httpGet: + # path: /healthz + # port: 50051 + # initialDelaySeconds: 10 + # periodSeconds: 5 + # livenessProbe: + # httpGet: + # path: /healthz + # port: 50051 + # initialDelaySeconds: 20 + # periodSeconds: 10 + volumes: + - name: chroma + emptyDir: {} + +# --- + +# apiVersion: v1 +# kind: PersistentVolumeClaim +# metadata: +# name: index-data +# namespace: chroma +# spec: +# accessModes: +# - ReadWriteOnce +# resources: +# requests: +# storage: 1Gi diff --git a/k8s/test/pulsar_service.yaml b/k8s/test/pulsar_service.yaml new file mode 
100644 index 00000000000..1053c709afa --- /dev/null +++ b/k8s/test/pulsar_service.yaml @@ -0,0 +1,20 @@ +# These kubernetes manifests are UNDER ACTIVE DEVELOPMENT and are not yet ready for production use. +# They will be used for the upcoming distributed version of chroma. They are not even ready +# for testing yet. Please do not use them unless you are working on the distributed version of chroma. + +apiVersion: v1 +kind: Service +metadata: + name: pulsar + namespace: chroma +spec: + ports: + - name: pulsar-port + port: 6650 + targetPort: 6650 + - name: admin-port + port: 8080 + targetPort: 8080 + selector: + app: pulsar + type: LoadBalancer diff --git a/pyproject.toml b/pyproject.toml index f69ae22140c..7dd144cf3ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,8 @@ dependencies = [ 'grpcio >= 1.58.0', 'bcrypt >= 4.0.1', 'typer >= 0.9.0', + 'kubernetes>=28.1.0', + 'tenacity>=8.2.3', ] [tool.black] diff --git a/requirements.txt b/requirements.txt index 6501c7d9d2c..78af541f009 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ fastapi>=0.95.2 graphlib_backport==1.0.3; python_version < '3.9' grpcio==1.58.0 importlib-resources +kubernetes>=28.1.0 numpy==1.21.6; python_version < '3.8' numpy>=1.22.4; python_version >= '3.8' onnxruntime>=1.14.1 @@ -13,6 +14,7 @@ pulsar-client==3.1.0 pydantic>=1.9 pypika==0.48.9 requests==2.28.1 +tenacity>=8.2.3 tokenizers==0.13.2 tqdm==4.65.0 typer>=0.9.0 From c0e307e33fa95766d7157be555cdddd7c9962228 Mon Sep 17 00:00:00 2001 From: Kevin Ji <1146876+kevinji@users.noreply.github.com> Date: Tue, 10 Oct 2023 13:33:07 -0400 Subject: [PATCH 03/14] README: Fix link to sentence transformers (#1224) ## Description of changes Fix a link to sentence transformers in the README. ## Test plan N/A ## Documentation Changes N/A --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 25db53b73d8..139b66583f4 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ What are embeddings? - __Technical__: An embedding is the latent-space position of a document at a layer of a deep neural network. For models trained specifically to embed data, this is the last layer. - __A small example__: If you search your photos for "famous bridge in San Francisco". By embedding this query and comparing it to the embeddings of your photos and their metadata - it should return photos of the Golden Gate Bridge. -Embeddings databases (also known as **vector databases**) store embeddings and allow you to search by nearest neighbors rather than by substrings like a traditional database. By default, Chroma uses [Sentence Transformers](https://docs.trychroma.com/embeddings#default-sentence-transformers) to embed for you but you can also use OpenAI embeddings, Cohere (multilingual) embeddings, or your own. +Embeddings databases (also known as **vector databases**) store embeddings and allow you to search by nearest neighbors rather than by substrings like a traditional database. By default, Chroma uses [Sentence Transformers](https://docs.trychroma.com/embeddings#sentence-transformers) to embed for you but you can also use OpenAI embeddings, Cohere (multilingual) embeddings, or your own. 
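
For example, a custom embedding function only needs to be a callable that turns a list of documents into a list of vectors. A minimal sketch (the toy vectors below are placeholders, not a real model):

```python
import chromadb


class MyEmbeddingFunction:
    """Toy embedding function: replace the body with calls to your own model."""

    def __call__(self, texts):
        # Return one fixed-length vector per input document (placeholder values).
        return [[float(len(t)), 0.0, 1.0] for t in texts]


client = chromadb.Client()
collection = client.create_collection(
    name="my_collection",
    embedding_function=MyEmbeddingFunction(),  # omit to use the default model
)
collection.add(
    documents=["This is a document", "This is another document"],
    ids=["id1", "id2"],
)
print(collection.query(query_texts=["a query document"], n_results=1))
```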
## Get involved From b20db0b19744c74dbbd318b7900909877c9093c9 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Thu, 12 Oct 2023 20:15:21 +0300 Subject: [PATCH 04/14] [BUG]: Unpinned tqdm version in requirements.txt (#1236) #1235 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - TQDM version pinning was causing issues with other deps that required a higher version. ## Test plan *How are these changes tested?* Local pip install tests. ## Documentation Changes N/A --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 78af541f009..7b60e6101bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ pypika==0.48.9 requests==2.28.1 tenacity>=8.2.3 tokenizers==0.13.2 -tqdm==4.65.0 +tqdm>=4.65.0 typer>=0.9.0 typing_extensions>=4.5.0 uvicorn[standard]==0.18.3 From 85f125ad44d97c12590ded741ec809dbdadbac94 Mon Sep 17 00:00:00 2001 From: Ben Eggers <64657842+beggers@users.noreply.github.com> Date: Fri, 13 Oct 2023 08:26:35 -0700 Subject: [PATCH 05/14] Observability cip (#1219) ## Description of changes - New functionality - A CIP proposing adding OpenTelemetry observability to Chroma. ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python, `yarn test` for js --- docs/CIP_6_OpenTelemetry_Monitoring.md | 41 ++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 docs/CIP_6_OpenTelemetry_Monitoring.md diff --git a/docs/CIP_6_OpenTelemetry_Monitoring.md b/docs/CIP_6_OpenTelemetry_Monitoring.md new file mode 100644 index 00000000000..a212b8e6c49 --- /dev/null +++ b/docs/CIP_6_OpenTelemetry_Monitoring.md @@ -0,0 +1,41 @@ +# CIP 6: OpenTelemetry Monitoring + +## **Status** + +Current status: `Under Discussion` + +## **Motivation** + +Chroma currently has very little observability, only offering basic logging. Using Chroma in a high-performance production context requires the ability to understand how Chroma is behaving and responding to requests. + +## **Public Interfaces** + +The changes will affect the following: + +- Logging output +- Several new CLI flags + +## **Proposed Changes** + +We propose to instrument Chroma with [OpenTelemetry](https://opentelemetry.io/docs/instrumentation/python/) (OTel), the most prevalent open-source observability standard. OTel's Python libraries are considered stable for traces and metrics. We will create several layers of observability, configurable with command-line flags. + +- Chroma's default behavior will remain the same: events will be logged to the console with configurable severity levels. +- We will add a flag, `--opentelemetry-mode={api, sdk}` to instruct Chroma to export OTel data in either [API or SDK mode](https://stackoverflow.com/questions/72963553/opentelemetry-api-vs-sdk). +- We will add another flag, `--opentelemtry-detail={partial, full}`, to specify the level of detail desired from OTel. + - With `partial` detail, Chroma's top-level API calls will produce a single span. This mode is suitable for end-users of Chroma who are not intimately familiar with its operation but use it as part of their production system. + - `full` detail will emit spans for Chroma's sub-operations, enabling Chroma maintainers to monitor performance and diagnose issues. +- For now Chroma's OTel integrations will need to be specified with environment variables. 
As the [OTel file configuration project](https://github.com/MrAlias/otel-schema/pull/44) matures we will integrate support for file-based OTel configuration. + +## **Compatibility, Deprecation, and Migration Plan** + +This change adds no new default-on functionality. + +## **Test Plan** + +Observability logic and output will be tested on both single-node and distributed Chroma to confirm that metrics are exported properly and traces correctly identify parent spans across function and service boundaries. + +## **Rejected Alternatives** + +### Prometheus metrics + +Prometheus metrics offer similar OSS functionality to OTel. However the Prometheus standard is older and belongs to a single open-source project; OTel is designed for long-term cross-compatibility between *all* observability backends. As such, OTel output can easily be ingested by Prometheus users so there is no loss of functionality or compatibility. \ No newline at end of file From 734b133909f4d2e0e159c02c9447efbd627facbd Mon Sep 17 00:00:00 2001 From: Ben Eggers <64657842+beggers@users.noreply.github.com> Date: Fri, 13 Oct 2023 08:32:50 -0700 Subject: [PATCH 06/14] Further posthog improvements (and a little .gitignore) (#1222) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Increase add() batch size. - Flatten server context so we can do things like group by version. - Add a few things to `.gitignore` which seem to be created by test runs. - New functionality - Batch query() calls (batch size = 20, where we started for add()). - Change `collection.get()` so its fields are `int` instead of `bool` since I imagine we'll eventually batch it as well. ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python, `yarn test` for js - [x] Tested locally by printing posthog events: ```python >>> import chromadb >>> chroma_client = chromadb.Client() bf9b885c-b86e-4194-97b4-d9701d293cce ClientStartEvent {'batch_size': 1, 'chroma_version': '0.4.14', 'server_context': 'None', 'chroma_api_impl': 'chromadb.api.segment.SegmentAPI', 'is_persistent': False, 'chroma_server_ssl_enabled': False} >>> collection = chroma_client.create_collection(name="my_collection") bf9b885c-b86e-4194-97b4-d9701d293cce ClientCreateCollectionEvent {'batch_size': 1, 'collection_uuid': '50de2cd2-06ba-442d-82c1-ae0d94b620e4', 'embedding_function': 'ONNXMiniLM_L6_V2', 'chroma_version': '0.4.14', 'server_context': 'None', 'chroma_api_impl': 'chromadb.api.segment.SegmentAPI', 'is_persistent': False, 'chroma_server_ssl_enabled': False} >>> collection.add( ... documents=["This is a document", "This is another document"], ... metadatas=[{"source": "my_source"}, {"source": "my_source"}], ... ids=["id1", "id2"] ... ) bf9b885c-b86e-4194-97b4-d9701d293cce CollectionAddEvent {'batch_size': 1, 'collection_uuid': '50de2cd2-06ba-442d-82c1-ae0d94b620e4', 'add_amount': 2, 'with_documents': 2, 'with_metadata': 2, 'chroma_version': '0.4.14', 'server_context': 'None', 'chroma_api_impl': 'chromadb.api.segment.SegmentAPI', 'is_persistent': False, 'chroma_server_ssl_enabled': False} >>> for i in range(41): ... results = collection.query( ... query_texts=["This is a query document"], ... n_results=2 ... ) ... 
bf9b885c-b86e-4194-97b4-d9701d293cce CollectionQueryEvent {'batch_size': 1, 'collection_uuid': '50de2cd2-06ba-442d-82c1-ae0d94b620e4', 'query_amount': 1, 'with_metadata_filter': 1, 'with_document_filter': 1, 'n_results': 2, 'include_metadatas': 1, 'include_documents': 1, 'include_distances': 1, 'chroma_version': '0.4.14', 'server_context': 'None', 'chroma_api_impl': 'chromadb.api.segment.SegmentAPI', 'is_persistent': False, 'chroma_server_ssl_enabled': False} bf9b885c-b86e-4194-97b4-d9701d293cce CollectionQueryEvent {'batch_size': 20, 'collection_uuid': '50de2cd2-06ba-442d-82c1-ae0d94b620e4', 'query_amount': 20, 'with_metadata_filter': 20, 'with_document_filter': 20, 'n_results': 40, 'include_metadatas': 20, 'include_documents': 20, 'include_distances': 20, 'chroma_version': '0.4.14', 'server_context': 'None', 'chroma_api_impl': 'chromadb.api.segment.SegmentAPI', 'is_persistent': False, 'chroma_server_ssl_enabled': False} bf9b885c-b86e-4194-97b4-d9701d293cce CollectionQueryEvent {'batch_size': 20, 'collection_uuid': '50de2cd2-06ba-442d-82c1-ae0d94b620e4', 'query_amount': 20, 'with_metadata_filter': 20, 'with_document_filter': 20, 'n_results': 40, 'include_metadatas': 20, 'include_documents': 20, 'include_distances': 20, 'chroma_version': '0.4.14', 'server_context': 'None', 'chroma_api_impl': 'chromadb.api.segment.SegmentAPI', 'is_persistent': False, 'chroma_server_ssl_enabled': False} >>> for i in range(100): ... collection.add(documents=[str(i)], ids=[str(i)]) ... bf9b885c-b86e-4194-97b4-d9701d293cce CollectionAddEvent {'batch_size': 100, 'collection_uuid': '50de2cd2-06ba-442d-82c1-ae0d94b620e4', 'add_amount': 100, 'with_documents': 100, 'with_metadata': 0, 'chroma_version': '0.4.14', 'server_context': 'None', 'chroma_api_impl': 'chromadb.api.segment.SegmentAPI', 'is_persistent': False, 'chroma_server_ssl_enabled': False} ``` Also confirmed that `collection.get()` spits out an event every time it's called. It also spits out a bunch of data so I'll elide it from this PR description. ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* No docs change needed -- we're not collecting anything new or changing anything significant about how we collect. 
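
The batching behaviour above boils down to summing per-event counters until the batch size cap is hit, keyed by collection so events from different collections are never merged. A minimal standalone sketch of that folding logic (hypothetical `QueryEvent` class and `MAX_BATCH_SIZE` constant for illustration only, not the actual `chromadb.telemetry` classes):

```python
# Minimal sketch of the query-event batching idea.
from dataclasses import dataclass

MAX_BATCH_SIZE = 20  # query() events are flushed once this many are folded together


@dataclass
class QueryEvent:
    collection_uuid: str
    query_amount: int = 1
    n_results: int = 0
    batch_size: int = 1

    @property
    def batch_key(self) -> str:
        # Only events for the same collection (and same event type) are merged.
        return self.collection_uuid + "QueryEvent"

    def batch(self, other: "QueryEvent") -> "QueryEvent":
        assert self.batch_key == other.batch_key
        # Counters are summed so the combined event still reports totals.
        return QueryEvent(
            collection_uuid=self.collection_uuid,
            query_amount=self.query_amount + other.query_amount,
            n_results=self.n_results + other.n_results,
            batch_size=self.batch_size + other.batch_size,
        )


a = QueryEvent("50de2cd2", n_results=2)
b = QueryEvent("50de2cd2", n_results=2)
combined = a.batch(b)
assert combined.query_amount == 2 and combined.batch_size == 2
if combined.batch_size >= MAX_BATCH_SIZE:
    pass  # a real client would send the combined event here and start a new batch
```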
--- .gitignore | 3 ++ chromadb/api/segment.py | 20 +++++++------ chromadb/telemetry/events.py | 56 +++++++++++++++++++++++++---------- chromadb/telemetry/posthog.py | 2 +- 4 files changed, 56 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index fd4f8aa8a97..316c32cb664 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,10 @@ index_data # Default configuration for persist_directory in chromadb/config.py # Currently it's located in "./chroma/" chroma/ +chroma_test_data +server.htpasswd +.venv venv .env .chroma diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index d23139759d9..85dca1d8532 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -408,13 +408,14 @@ def _get( if "documents" in include: documents = [_doc(m) for m in metadatas] + ids_amount = len(ids) if ids else 0 self._telemetry_client.capture( CollectionGetEvent( collection_uuid=str(collection_id), - ids_count=len(ids) if ids else 0, + ids_count=ids_amount, limit=limit if limit else 0, - include_metadata="metadatas" in include, - include_documents="documents" in include, + include_metadata=ids_amount if "metadatas" in include else 0, + include_documents=ids_amount if "documents" in include else 0, ) ) @@ -571,16 +572,17 @@ def _query( doc_list = [_doc(m) for m in metadata_list] documents.append(doc_list) # type: ignore + query_amount = len(query_embeddings) self._telemetry_client.capture( CollectionQueryEvent( collection_uuid=str(collection_id), - query_amount=len(query_embeddings), + query_amount=query_amount, n_results=n_results, - with_metadata_filter=where is not None, - with_document_filter=where_document is not None, - include_metadatas="metadatas" in include, - include_documents="documents" in include, - include_distances="distances" in include, + with_metadata_filter=query_amount if where is not None else 0, + with_document_filter=query_amount if where_document is not None else 0, + include_metadatas=query_amount if "metadatas" in include else 0, + include_documents=query_amount if "documents" in include else 0, + include_distances=query_amount if "distances" in include else 0, ) ) diff --git a/chromadb/telemetry/events.py b/chromadb/telemetry/events.py index 34c6264fcc9..e662cd85fa7 100644 --- a/chromadb/telemetry/events.py +++ b/chromadb/telemetry/events.py @@ -26,7 +26,8 @@ def __init__(self, collection_uuid: str, embedding_function: str): class CollectionAddEvent(TelemetryEvent): - max_batch_size: ClassVar[int] = 20 + max_batch_size: ClassVar[int] = 100 + batch_size: int collection_uuid: str add_amount: int with_documents: int @@ -89,25 +90,28 @@ def __init__( class CollectionQueryEvent(TelemetryEvent): + max_batch_size: ClassVar[int] = 20 + batch_size: int collection_uuid: str query_amount: int - with_metadata_filter: bool - with_document_filter: bool + with_metadata_filter: int + with_document_filter: int n_results: int - include_metadatas: bool - include_documents: bool - include_distances: bool + include_metadatas: int + include_documents: int + include_distances: int def __init__( self, collection_uuid: str, query_amount: int, - with_metadata_filter: bool, - with_document_filter: bool, + with_metadata_filter: int, + with_document_filter: int, n_results: int, - include_metadatas: bool, - include_documents: bool, - include_distances: bool, + include_metadatas: int, + include_documents: int, + include_distances: int, + batch_size: int = 1, ): super().__init__() self.collection_uuid = collection_uuid @@ -118,22 +122,44 @@ def __init__( self.include_metadatas = 
include_metadatas self.include_documents = include_documents self.include_distances = include_distances + self.batch_size = batch_size + + @property + def batch_key(self) -> str: + return self.collection_uuid + self.name + + def batch(self, other: "TelemetryEvent") -> "CollectionQueryEvent": + if not self.batch_key == other.batch_key: + raise ValueError("Cannot batch events") + other = cast(CollectionQueryEvent, other) + total_amount = self.query_amount + other.query_amount + return CollectionQueryEvent( + collection_uuid=self.collection_uuid, + query_amount=total_amount, + with_metadata_filter=self.with_metadata_filter + other.with_metadata_filter, + with_document_filter=self.with_document_filter + other.with_document_filter, + n_results=self.n_results + other.n_results, + include_metadatas=self.include_metadatas + other.include_metadatas, + include_documents=self.include_documents + other.include_documents, + include_distances=self.include_distances + other.include_distances, + batch_size=self.batch_size + other.batch_size, + ) class CollectionGetEvent(TelemetryEvent): collection_uuid: str ids_count: int limit: int - include_metadata: bool - include_documents: bool + include_metadata: int + include_documents: int def __init__( self, collection_uuid: str, ids_count: int, limit: int, - include_metadata: bool, - include_documents: bool, + include_metadata: int, + include_documents: int, ): super().__init__() self.collection_uuid = collection_uuid diff --git a/chromadb/telemetry/posthog.py b/chromadb/telemetry/posthog.py index 184904531ef..21676b9fbe7 100644 --- a/chromadb/telemetry/posthog.py +++ b/chromadb/telemetry/posthog.py @@ -49,7 +49,7 @@ def _direct_capture(self, event: TelemetryEvent) -> None: posthog.capture( self.user_id, event.name, - {**(event.properties), "chroma_context": self.context}, + {**event.properties, **self.context}, ) except Exception as e: logger.error(f"Failed to send telemetry event {event.name}: {e}") From 7fd35dfa273e28b00a90020f7574d2c8dfdbb314 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Fri, 13 Oct 2023 13:58:49 -0700 Subject: [PATCH 07/14] [CLN] Move protos into IDL folder. (#1228) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Move protobufs into a top level folder so they can be shared. 
- New functionality - N/A ## Test plan *How are these changes tested?* Existing Tests ## Documentation Changes None --- .pre-commit-config.yaml | 2 +- chromadb/proto/chroma.proto | 40 -------- chromadb/proto/chroma_pb2.py | 94 +++++++++--------- chromadb/proto/chroma_pb2.pyi | 95 +++---------------- chromadb/proto/chroma_pb2_grpc.py | 1 - .../proto => idl}/chromadb/proto/chroma.proto | 0 idl/makefile | 8 ++ mypy.ini | 2 + 8 files changed, 70 insertions(+), 172 deletions(-) delete mode 100644 chromadb/proto/chroma.proto rename {chromadb/proto => idl}/chromadb/proto/chroma.proto (100%) create mode 100644 idl/makefile create mode 100644 mypy.ini diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3f0065bb133..750bab0d304 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,5 +32,5 @@ repos: rev: "v1.2.0" hooks: - id: mypy - args: [--strict, --ignore-missing-imports, --follow-imports=silent, --disable-error-code=type-abstract] + args: [--strict, --ignore-missing-imports, --follow-imports=silent, --disable-error-code=type-abstract, --config-file=./mypy.ini] additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy", "types-protobuf", "kubernetes"] diff --git a/chromadb/proto/chroma.proto b/chromadb/proto/chroma.proto deleted file mode 100644 index 7eefed74e12..00000000000 --- a/chromadb/proto/chroma.proto +++ /dev/null @@ -1,40 +0,0 @@ -syntax = "proto3"; - -package chroma; - -enum Operation { - ADD = 0; - UPDATE = 1; - UPSERT = 2; - DELETE = 3; -} - -enum ScalarEncoding { - FLOAT32 = 0; - INT32 = 1; -} - -message Vector { - int32 dimension = 1; - bytes vector = 2; - ScalarEncoding encoding = 3; -} - -message UpdateMetadataValue { - oneof value { - string string_value = 1; - int64 int_value = 2; - double float_value = 3; - } -} - -message UpdateMetadata { - map metadata = 1; -} - -message SubmitEmbeddingRecord { - string id = 1; - optional Vector vector = 2; - optional UpdateMetadata metadata = 3; - Operation operation = 4; -} diff --git a/chromadb/proto/chroma_pb2.py b/chromadb/proto/chroma_pb2.py index 4e8c62576f0..d6b22217d7d 100644 --- a/chromadb/proto/chroma_pb2.py +++ b/chromadb/proto/chroma_pb2.py @@ -1,4 +1,3 @@ -# type: ignore # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: chromadb/proto/chroma.proto @@ -7,61 +6,58 @@ from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding"\xca\x01\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\x05topic\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x08\n\x06_topicB\r\n\x0b_collectionB\x0b\n\t_metadata"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata"S\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector"q\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x10\n\x08\x64istance\x18\x03 \x01(\x01\x12#\n\x06vector\x18\x04 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult"(\n\x15SegmentServerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08"4\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord"\x86\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 
\x03(\x0b\x32\x1a.chroma.VectorQueryResults*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*(\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x32\x94\x01\n\rSegmentServer\x12?\n\x0bLoadSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse"\x00\x12\x42\n\x0eReleaseSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse"\x00\x62\x06proto3' -) + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma\"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding\"\xca\x01\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\x05topic\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x08\n\x06_topicB\r\n\x0b_collectionB\x0b\n\t_metadata\"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value\"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01\"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata\"S\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector\"q\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x10\n\x08\x64istance\x18\x03 \x01(\x01\x12#\n\x06vector\x18\x04 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector\"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult\"(\n\x15SegmentServerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"4\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t\"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord\"\x86\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t\"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 
\x03(\x0b\x32\x1a.chroma.VectorQueryResults*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*(\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x32\x94\x01\n\rSegmentServer\x12?\n\x0bLoadSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x12\x42\n\x0eReleaseSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse\"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages( - DESCRIPTOR, "chromadb.proto.chroma_pb2", _globals -) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chromadb.proto.chroma_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _UPDATEMETADATA_METADATAENTRY._options = None - _UPDATEMETADATA_METADATAENTRY._serialized_options = b"8\001" - _globals["_OPERATION"]._serialized_start = 1406 - _globals["_OPERATION"]._serialized_end = 1462 - _globals["_SCALARENCODING"]._serialized_start = 1464 - _globals["_SCALARENCODING"]._serialized_end = 1504 - _globals["_SEGMENTSCOPE"]._serialized_start = 1506 - _globals["_SEGMENTSCOPE"]._serialized_end = 1546 - _globals["_VECTOR"]._serialized_start = 39 - _globals["_VECTOR"]._serialized_end = 124 - _globals["_SEGMENT"]._serialized_start = 127 - _globals["_SEGMENT"]._serialized_end = 329 - _globals["_UPDATEMETADATAVALUE"]._serialized_start = 331 - _globals["_UPDATEMETADATAVALUE"]._serialized_end = 429 - _globals["_UPDATEMETADATA"]._serialized_start = 432 - _globals["_UPDATEMETADATA"]._serialized_end = 582 - _globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_start = 506 - _globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_end = 582 - _globals["_SUBMITEMBEDDINGRECORD"]._serialized_start = 585 - _globals["_SUBMITEMBEDDINGRECORD"]._serialized_end = 766 - _globals["_VECTOREMBEDDINGRECORD"]._serialized_start = 768 - _globals["_VECTOREMBEDDINGRECORD"]._serialized_end = 851 - _globals["_VECTORQUERYRESULT"]._serialized_start = 853 - _globals["_VECTORQUERYRESULT"]._serialized_end = 966 - _globals["_VECTORQUERYRESULTS"]._serialized_start = 968 - _globals["_VECTORQUERYRESULTS"]._serialized_end = 1032 - _globals["_SEGMENTSERVERRESPONSE"]._serialized_start = 1034 - _globals["_SEGMENTSERVERRESPONSE"]._serialized_end = 1074 - _globals["_GETVECTORSREQUEST"]._serialized_start = 1076 - _globals["_GETVECTORSREQUEST"]._serialized_end = 1128 - _globals["_GETVECTORSRESPONSE"]._serialized_start = 1130 - _globals["_GETVECTORSRESPONSE"]._serialized_end = 1198 - _globals["_QUERYVECTORSREQUEST"]._serialized_start = 1201 - _globals["_QUERYVECTORSREQUEST"]._serialized_end = 1335 - _globals["_QUERYVECTORSRESPONSE"]._serialized_start = 1337 - _globals["_QUERYVECTORSRESPONSE"]._serialized_end = 1404 - _globals["_SEGMENTSERVER"]._serialized_start = 1549 - _globals["_SEGMENTSERVER"]._serialized_end = 1697 - _globals["_VECTORREADER"]._serialized_start = 1700 - _globals["_VECTORREADER"]._serialized_end = 1862 + DESCRIPTOR._options = None + _UPDATEMETADATA_METADATAENTRY._options = None + _UPDATEMETADATA_METADATAENTRY._serialized_options = b'8\001' + _globals['_OPERATION']._serialized_start=1406 + 
_globals['_OPERATION']._serialized_end=1462 + _globals['_SCALARENCODING']._serialized_start=1464 + _globals['_SCALARENCODING']._serialized_end=1504 + _globals['_SEGMENTSCOPE']._serialized_start=1506 + _globals['_SEGMENTSCOPE']._serialized_end=1546 + _globals['_VECTOR']._serialized_start=39 + _globals['_VECTOR']._serialized_end=124 + _globals['_SEGMENT']._serialized_start=127 + _globals['_SEGMENT']._serialized_end=329 + _globals['_UPDATEMETADATAVALUE']._serialized_start=331 + _globals['_UPDATEMETADATAVALUE']._serialized_end=429 + _globals['_UPDATEMETADATA']._serialized_start=432 + _globals['_UPDATEMETADATA']._serialized_end=582 + _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_start=506 + _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_end=582 + _globals['_SUBMITEMBEDDINGRECORD']._serialized_start=585 + _globals['_SUBMITEMBEDDINGRECORD']._serialized_end=766 + _globals['_VECTOREMBEDDINGRECORD']._serialized_start=768 + _globals['_VECTOREMBEDDINGRECORD']._serialized_end=851 + _globals['_VECTORQUERYRESULT']._serialized_start=853 + _globals['_VECTORQUERYRESULT']._serialized_end=966 + _globals['_VECTORQUERYRESULTS']._serialized_start=968 + _globals['_VECTORQUERYRESULTS']._serialized_end=1032 + _globals['_SEGMENTSERVERRESPONSE']._serialized_start=1034 + _globals['_SEGMENTSERVERRESPONSE']._serialized_end=1074 + _globals['_GETVECTORSREQUEST']._serialized_start=1076 + _globals['_GETVECTORSREQUEST']._serialized_end=1128 + _globals['_GETVECTORSRESPONSE']._serialized_start=1130 + _globals['_GETVECTORSRESPONSE']._serialized_end=1198 + _globals['_QUERYVECTORSREQUEST']._serialized_start=1201 + _globals['_QUERYVECTORSREQUEST']._serialized_end=1335 + _globals['_QUERYVECTORSRESPONSE']._serialized_start=1337 + _globals['_QUERYVECTORSRESPONSE']._serialized_end=1404 + _globals['_SEGMENTSERVER']._serialized_start=1549 + _globals['_SEGMENTSERVER']._serialized_end=1697 + _globals['_VECTORREADER']._serialized_start=1700 + _globals['_VECTORREADER']._serialized_end=1862 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/chroma_pb2.pyi b/chromadb/proto/chroma_pb2.pyi index 0b52141e64a..6d06e074c06 100644 --- a/chromadb/proto/chroma_pb2.pyi +++ b/chromadb/proto/chroma_pb2.pyi @@ -1,16 +1,8 @@ -# type: ignore - from google.protobuf.internal import containers as _containers from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message -from typing import ( - ClassVar as _ClassVar, - Iterable as _Iterable, - Mapping as _Mapping, - Optional as _Optional, - Union as _Union, -) +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor @@ -30,7 +22,6 @@ class SegmentScope(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): __slots__ = [] VECTOR: _ClassVar[SegmentScope] METADATA: _ClassVar[SegmentScope] - ADD: Operation UPDATE: Operation UPSERT: Operation @@ -48,12 +39,7 @@ class Vector(_message.Message): dimension: int vector: bytes encoding: ScalarEncoding - def __init__( - self, - dimension: _Optional[int] = ..., - vector: _Optional[bytes] = ..., - encoding: _Optional[_Union[ScalarEncoding, str]] = ..., - ) -> None: ... + def __init__(self, dimension: _Optional[int] = ..., vector: _Optional[bytes] = ..., encoding: _Optional[_Union[ScalarEncoding, str]] = ...) -> None: ... 
class Segment(_message.Message): __slots__ = ["id", "type", "scope", "topic", "collection", "metadata"] @@ -69,15 +55,7 @@ class Segment(_message.Message): topic: str collection: str metadata: UpdateMetadata - def __init__( - self, - id: _Optional[str] = ..., - type: _Optional[str] = ..., - scope: _Optional[_Union[SegmentScope, str]] = ..., - topic: _Optional[str] = ..., - collection: _Optional[str] = ..., - metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., - ) -> None: ... + def __init__(self, id: _Optional[str] = ..., type: _Optional[str] = ..., scope: _Optional[_Union[SegmentScope, str]] = ..., topic: _Optional[str] = ..., collection: _Optional[str] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ...) -> None: ... class UpdateMetadataValue(_message.Message): __slots__ = ["string_value", "int_value", "float_value"] @@ -87,32 +65,20 @@ class UpdateMetadataValue(_message.Message): string_value: str int_value: int float_value: float - def __init__( - self, - string_value: _Optional[str] = ..., - int_value: _Optional[int] = ..., - float_value: _Optional[float] = ..., - ) -> None: ... + def __init__(self, string_value: _Optional[str] = ..., int_value: _Optional[int] = ..., float_value: _Optional[float] = ...) -> None: ... class UpdateMetadata(_message.Message): __slots__ = ["metadata"] - class MetadataEntry(_message.Message): __slots__ = ["key", "value"] KEY_FIELD_NUMBER: _ClassVar[int] VALUE_FIELD_NUMBER: _ClassVar[int] key: str value: UpdateMetadataValue - def __init__( - self, - key: _Optional[str] = ..., - value: _Optional[_Union[UpdateMetadataValue, _Mapping]] = ..., - ) -> None: ... + def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[UpdateMetadataValue, _Mapping]] = ...) -> None: ... METADATA_FIELD_NUMBER: _ClassVar[int] metadata: _containers.MessageMap[str, UpdateMetadataValue] - def __init__( - self, metadata: _Optional[_Mapping[str, UpdateMetadataValue]] = ... - ) -> None: ... + def __init__(self, metadata: _Optional[_Mapping[str, UpdateMetadataValue]] = ...) -> None: ... class SubmitEmbeddingRecord(_message.Message): __slots__ = ["id", "vector", "metadata", "operation"] @@ -124,13 +90,7 @@ class SubmitEmbeddingRecord(_message.Message): vector: Vector metadata: UpdateMetadata operation: Operation - def __init__( - self, - id: _Optional[str] = ..., - vector: _Optional[_Union[Vector, _Mapping]] = ..., - metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., - operation: _Optional[_Union[Operation, str]] = ..., - ) -> None: ... + def __init__(self, id: _Optional[str] = ..., vector: _Optional[_Union[Vector, _Mapping]] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., operation: _Optional[_Union[Operation, str]] = ...) -> None: ... class VectorEmbeddingRecord(_message.Message): __slots__ = ["id", "seq_id", "vector"] @@ -140,12 +100,7 @@ class VectorEmbeddingRecord(_message.Message): id: str seq_id: bytes vector: Vector - def __init__( - self, - id: _Optional[str] = ..., - seq_id: _Optional[bytes] = ..., - vector: _Optional[_Union[Vector, _Mapping]] = ..., - ) -> None: ... + def __init__(self, id: _Optional[str] = ..., seq_id: _Optional[bytes] = ..., vector: _Optional[_Union[Vector, _Mapping]] = ...) -> None: ... 
class VectorQueryResult(_message.Message): __slots__ = ["id", "seq_id", "distance", "vector"] @@ -157,21 +112,13 @@ class VectorQueryResult(_message.Message): seq_id: bytes distance: float vector: Vector - def __init__( - self, - id: _Optional[str] = ..., - seq_id: _Optional[bytes] = ..., - distance: _Optional[float] = ..., - vector: _Optional[_Union[Vector, _Mapping]] = ..., - ) -> None: ... + def __init__(self, id: _Optional[str] = ..., seq_id: _Optional[bytes] = ..., distance: _Optional[float] = ..., vector: _Optional[_Union[Vector, _Mapping]] = ...) -> None: ... class VectorQueryResults(_message.Message): __slots__ = ["results"] RESULTS_FIELD_NUMBER: _ClassVar[int] results: _containers.RepeatedCompositeFieldContainer[VectorQueryResult] - def __init__( - self, results: _Optional[_Iterable[_Union[VectorQueryResult, _Mapping]]] = ... - ) -> None: ... + def __init__(self, results: _Optional[_Iterable[_Union[VectorQueryResult, _Mapping]]] = ...) -> None: ... class SegmentServerResponse(_message.Message): __slots__ = ["success"] @@ -185,18 +132,13 @@ class GetVectorsRequest(_message.Message): SEGMENT_ID_FIELD_NUMBER: _ClassVar[int] ids: _containers.RepeatedScalarFieldContainer[str] segment_id: str - def __init__( - self, ids: _Optional[_Iterable[str]] = ..., segment_id: _Optional[str] = ... - ) -> None: ... + def __init__(self, ids: _Optional[_Iterable[str]] = ..., segment_id: _Optional[str] = ...) -> None: ... class GetVectorsResponse(_message.Message): __slots__ = ["records"] RECORDS_FIELD_NUMBER: _ClassVar[int] records: _containers.RepeatedCompositeFieldContainer[VectorEmbeddingRecord] - def __init__( - self, - records: _Optional[_Iterable[_Union[VectorEmbeddingRecord, _Mapping]]] = ..., - ) -> None: ... + def __init__(self, records: _Optional[_Iterable[_Union[VectorEmbeddingRecord, _Mapping]]] = ...) -> None: ... class QueryVectorsRequest(_message.Message): __slots__ = ["vectors", "k", "allowed_ids", "include_embeddings", "segment_id"] @@ -210,19 +152,10 @@ class QueryVectorsRequest(_message.Message): allowed_ids: _containers.RepeatedScalarFieldContainer[str] include_embeddings: bool segment_id: str - def __init__( - self, - vectors: _Optional[_Iterable[_Union[Vector, _Mapping]]] = ..., - k: _Optional[int] = ..., - allowed_ids: _Optional[_Iterable[str]] = ..., - include_embeddings: bool = ..., - segment_id: _Optional[str] = ..., - ) -> None: ... + def __init__(self, vectors: _Optional[_Iterable[_Union[Vector, _Mapping]]] = ..., k: _Optional[int] = ..., allowed_ids: _Optional[_Iterable[str]] = ..., include_embeddings: bool = ..., segment_id: _Optional[str] = ...) -> None: ... class QueryVectorsResponse(_message.Message): __slots__ = ["results"] RESULTS_FIELD_NUMBER: _ClassVar[int] results: _containers.RepeatedCompositeFieldContainer[VectorQueryResults] - def __init__( - self, results: _Optional[_Iterable[_Union[VectorQueryResults, _Mapping]]] = ... - ) -> None: ... + def __init__(self, results: _Optional[_Iterable[_Union[VectorQueryResults, _Mapping]]] = ...) -> None: ... diff --git a/chromadb/proto/chroma_pb2_grpc.py b/chromadb/proto/chroma_pb2_grpc.py index af3c29b622d..6d98cc34681 100644 --- a/chromadb/proto/chroma_pb2_grpc.py +++ b/chromadb/proto/chroma_pb2_grpc.py @@ -1,4 +1,3 @@ -# type: ignore # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
"""Client and server classes corresponding to protobuf-defined services.""" import grpc diff --git a/chromadb/proto/chromadb/proto/chroma.proto b/idl/chromadb/proto/chroma.proto similarity index 100% rename from chromadb/proto/chromadb/proto/chroma.proto rename to idl/chromadb/proto/chroma.proto diff --git a/idl/makefile b/idl/makefile new file mode 100644 index 00000000000..00a2f7d64a6 --- /dev/null +++ b/idl/makefile @@ -0,0 +1,8 @@ +.PHONY: proto + +proto: + @echo "Generating gRPC code..." + @python -m grpc_tools.protoc -I ./ --python_out=. --pyi_out=. --grpc_python_out=. ./chromadb/proto/*.proto + @mv chromadb/proto/*.py ../chromadb/proto/ + @mv chromadb/proto/*.pyi ../chromadb/proto/ + @echo "Done" diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000000..bcbf5f20f72 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy-chromadb.proto.*] +ignore_errors = True From f8805e6c66c7e7c7e2ff331e8ba845fbf98f2673 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Fri, 13 Oct 2023 14:30:04 -0700 Subject: [PATCH 08/14] [CLN] Add gitattributes entry to mark generated files (#1231) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Edit .gitattributes to collapse generated files (https://thoughtbot.com/blog/github-diff-supression) - New functionality - ... ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python, `yarn test` for js ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)? None required* --- .gitattributes | 1 + 1 file changed, 1 insertion(+) create mode 100644 .gitattributes diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000000..ff6c194874c --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*_pb2.py* linguist-generated From a9d654bf375bec775923f8f7ead27197d978f99e Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Sun, 15 Oct 2023 19:58:03 -0700 Subject: [PATCH 09/14] [STACKED #1228] [ENH] Grpc Coordinator/SysDB (#1229) ## Description of changes This PR is stacked on #1228, please TAL at that before this. *Summarize the changes made by this PR.* - Improvements & Bug fixes - n/a - New functionality - Introduced a grpc sysdb implementation. - Migrate the protobufs to support Null vs Presence where needed. - Added minimal protobuf definitions to implement the grpc coordinator - Add a mock grpc sysdb server that reimplements the core functionality in memory for testing. Will address the following open questions in stacked PRs! Open questions/issues: - [x] Collection <> Topic Mapping _This requires some thought if we use hashing then we can just map the collection onto a topic and store it. Will we ever want to shard a collection across topics?_ I don't think this is needed at this moment! - [x] Topic management moving to sysdb interface _This is conceptually simple, we make topic creation implicitly part of sysdb create collection_ I address this in #1237 - [x] We need to move the get_or_create concept into sysdb _This is conceptually simple, we just move this logic into sysdb and make it aware of it_ I address this in #1242 ## Test plan *How are these changes tested?* I added a mock gRPC server that implements the basic functionality of sysdb with in memory data structures. 
We will run the coordinator tests with this impl in the bin/cluster test once the go coordinator service is ready from @Ishiihara ## Documentation Changes None required. --- .pre-commit-config.yaml | 2 +- .vscode/settings.json | 7 +- chromadb/config.py | 1 + chromadb/db/impl/grpc/client.py | 237 +++++++++++++ chromadb/db/impl/grpc/server.py | 275 +++++++++++++++ chromadb/db/system.py | 2 +- chromadb/proto/chroma_pb2.py | 84 ++--- chromadb/proto/chroma_pb2.pyi | 28 ++ chromadb/proto/convert.py | 75 ++++- chromadb/proto/coordinator_pb2.py | 49 +++ chromadb/proto/coordinator_pb2.pyi | 116 +++++++ chromadb/proto/coordinator_pb2_grpc.py | 441 +++++++++++++++++++++++++ chromadb/test/db/test_system.py | 15 +- idl/chromadb/proto/chroma.proto | 19 ++ idl/chromadb/proto/coordinator.proto | 92 ++++++ 15 files changed, 1384 insertions(+), 59 deletions(-) create mode 100644 chromadb/db/impl/grpc/client.py create mode 100644 chromadb/db/impl/grpc/server.py create mode 100644 chromadb/proto/coordinator_pb2.py create mode 100644 chromadb/proto/coordinator_pb2.pyi create mode 100644 chromadb/proto/coordinator_pb2_grpc.py create mode 100644 idl/chromadb/proto/coordinator.proto diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 750bab0d304..97763ef5201 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: 'chromadb/proto/chroma_pb2\.(py|pyi|py_grpc\.py)' # Generated files +exclude: 'chromadb/proto/(chroma_pb2|coordinator_pb2)\.(py|pyi|py_grpc\.py)' # Generated files repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 diff --git a/.vscode/settings.json b/.vscode/settings.json index 5f44b098387..f62dcb24a74 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -35,5 +35,10 @@ "--no-pretty", "--strict", "--disable-error-code=type-abstract" - ] + ], + "protoc": { + "options": [ + "--proto_path=idl/", + ] + } } diff --git a/chromadb/config.py b/chromadb/config.py index 920c92d6a96..a2af7bd32bc 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -95,6 +95,7 @@ class Settings(BaseSettings): # type: ignore chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory" chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider" worker_memberlist_name: str = "worker-memberlist" + chroma_coordinator_host = "localhost" tenant_id: str = "default" topic_namespace: str = "default" diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py new file mode 100644 index 00000000000..5d7c6e88838 --- /dev/null +++ b/chromadb/db/impl/grpc/client.py @@ -0,0 +1,237 @@ +from typing import List, Optional, Sequence, Union, cast +from uuid import UUID +from overrides import overrides +from chromadb.config import System +from chromadb.db.base import NotFoundError, UniqueConstraintError +from chromadb.db.system import SysDB +from chromadb.proto.convert import ( + from_proto_collection, + from_proto_segment, + to_proto_collection, + to_proto_update_metadata, + to_proto_segment, + to_proto_segment_scope, +) +from chromadb.proto.coordinator_pb2 import ( + CreateCollectionRequest, + CreateSegmentRequest, + DeleteCollectionRequest, + DeleteSegmentRequest, + GetCollectionsRequest, + GetCollectionsResponse, + GetSegmentsRequest, + UpdateCollectionRequest, + UpdateSegmentRequest, +) +from chromadb.proto.coordinator_pb2_grpc import SysDBStub +from chromadb.types import ( + Collection, + OptionalArgument, + 
Segment, + SegmentScope, + Unspecified, + UpdateMetadata, +) +from google.protobuf.empty_pb2 import Empty +import grpc + + +class GrpcSysDB(SysDB): + """A gRPC implementation of the SysDB. In the distributed system, the SysDB is also + called the 'Coordinator'. This implementation is used by Chroma frontend servers + to call a remote SysDB (Coordinator) service.""" + + _sys_db_stub: SysDBStub + _channel: grpc.Channel + _coordinator_url: str + _coordinator_port: int + + def __init__(self, system: System): + self._coordinator_url = system.settings.require("chroma_coordinator_host") + # TODO: break out coordinator_port into a separate setting? + self._coordinator_port = system.settings.require("chroma_server_grpc_port") + return super().__init__(system) + + @overrides + def start(self) -> None: + self._channel = grpc.insecure_channel( + f"{self._coordinator_url}:{self._coordinator_port}" + ) + self._sys_db_stub = SysDBStub(self._channel) # type: ignore + return super().start() + + @overrides + def stop(self) -> None: + self._channel.close() + return super().stop() + + @overrides + def reset_state(self) -> None: + self._sys_db_stub.ResetState(Empty()) + return super().reset_state() + + @overrides + def create_segment(self, segment: Segment) -> None: + proto_segment = to_proto_segment(segment) + request = CreateSegmentRequest( + segment=proto_segment, + ) + response = self._sys_db_stub.CreateSegment(request) + if response.status.code == 409: + raise UniqueConstraintError() + + @overrides + def delete_segment(self, id: UUID) -> None: + request = DeleteSegmentRequest( + id=id.hex, + ) + response = self._sys_db_stub.DeleteSegment(request) + if response.status.code == 404: + raise NotFoundError() + + @overrides + def get_segments( + self, + id: Optional[UUID] = None, + type: Optional[str] = None, + scope: Optional[SegmentScope] = None, + topic: Optional[str] = None, + collection: Optional[UUID] = None, + ) -> Sequence[Segment]: + request = GetSegmentsRequest( + id=id.hex if id else None, + type=type, + scope=to_proto_segment_scope(scope) if scope else None, + topic=topic, + collection=collection.hex if collection else None, + ) + response = self._sys_db_stub.GetSegments(request) + results: List[Segment] = [] + for proto_segment in response.segments: + segment = from_proto_segment(proto_segment) + results.append(segment) + return results + + @overrides + def update_segment( + self, + id: UUID, + topic: OptionalArgument[Optional[str]] = Unspecified(), + collection: OptionalArgument[Optional[UUID]] = Unspecified(), + metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(), + ) -> None: + write_topic = None + if topic != Unspecified(): + write_topic = cast(Union[str, None], topic) + + write_collection = None + if collection != Unspecified(): + write_collection = cast(Union[UUID, None], collection) + + write_metadata = None + if metadata != Unspecified(): + write_metadata = cast(Union[UpdateMetadata, None], metadata) + + request = UpdateSegmentRequest( + id=id.hex, + topic=write_topic, + collection=write_collection.hex if write_collection else None, + metadata=to_proto_update_metadata(write_metadata) + if write_metadata + else None, + ) + + if topic is None: + request.ClearField("topic") + request.reset_topic = True + + if collection is None: + request.ClearField("collection") + request.reset_collection = True + + if metadata is None: + request.ClearField("metadata") + request.reset_metadata = True + + self._sys_db_stub.UpdateSegment(request) + + @overrides + def create_collection(self, 
collection: Collection) -> None: + # TODO: the get_or_create concept needs to be pushed down to the sysdb interface + request = CreateCollectionRequest( + collection=to_proto_collection(collection), + get_or_create=False, + ) + response = self._sys_db_stub.CreateCollection(request) + if response.status.code == 409: + raise UniqueConstraintError() + + @overrides + def delete_collection(self, id: UUID) -> None: + request = DeleteCollectionRequest( + id=id.hex, + ) + response = self._sys_db_stub.DeleteCollection(request) + if response.status.code == 404: + raise NotFoundError() + + @overrides + def get_collections( + self, + id: Optional[UUID] = None, + topic: Optional[str] = None, + name: Optional[str] = None, + ) -> Sequence[Collection]: + request = GetCollectionsRequest( + id=id.hex if id else None, + topic=topic, + name=name, + ) + response: GetCollectionsResponse = self._sys_db_stub.GetCollections(request) + results: List[Collection] = [] + for collection in response.collections: + results.append(from_proto_collection(collection)) + return results + + @overrides + def update_collection( + self, + id: UUID, + topic: OptionalArgument[str] = Unspecified(), + name: OptionalArgument[str] = Unspecified(), + dimension: OptionalArgument[Optional[int]] = Unspecified(), + metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(), + ) -> None: + write_topic = None + if topic != Unspecified(): + write_topic = cast(str, topic) + + write_name = None + if name != Unspecified(): + write_name = cast(str, name) + + write_dimension = None + if dimension != Unspecified(): + write_dimension = cast(Union[int, None], dimension) + + write_metadata = None + if metadata != Unspecified(): + write_metadata = cast(Union[UpdateMetadata, None], metadata) + + request = UpdateCollectionRequest( + id=id.hex, + topic=write_topic, + name=write_name, + dimension=write_dimension, + metadata=to_proto_update_metadata(write_metadata) + if write_metadata + else None, + ) + if metadata is None: + request.ClearField("metadata") + request.reset_metadata = True + + self._sys_db_stub.UpdateCollection(request) + + def reset_and_wait_for_ready(self) -> None: + self._sys_db_stub.ResetState(Empty(), wait_for_ready=True) diff --git a/chromadb/db/impl/grpc/server.py b/chromadb/db/impl/grpc/server.py new file mode 100644 index 00000000000..f1b69460492 --- /dev/null +++ b/chromadb/db/impl/grpc/server.py @@ -0,0 +1,275 @@ +from concurrent import futures +from typing import Any, Dict, cast +from uuid import UUID +from overrides import overrides +from chromadb.config import Component, System +from chromadb.proto.convert import ( + from_proto_collection, + from_proto_update_metadata, + from_proto_segment, + from_proto_segment_scope, + to_proto_collection, + to_proto_segment, +) +import chromadb.proto.chroma_pb2 as proto +from chromadb.proto.coordinator_pb2 import ( + CreateCollectionRequest, + CreateCollectionResponse, + CreateSegmentRequest, + DeleteCollectionRequest, + DeleteSegmentRequest, + GetCollectionsRequest, + GetCollectionsResponse, + GetSegmentsRequest, + GetSegmentsResponse, + UpdateCollectionRequest, + UpdateSegmentRequest, +) +from chromadb.proto.coordinator_pb2_grpc import ( + SysDBServicer, + add_SysDBServicer_to_server, +) +import grpc +from google.protobuf.empty_pb2 import Empty +from chromadb.types import Collection, Metadata, Segment + + +class GrpcMockSysDB(SysDBServicer, Component): + """A mock sysdb implementation that can be used for testing the grpc client. 
It stores + state in simple python data structures instead of a database.""" + + _server: grpc.Server + _server_port: int + _segments: Dict[str, Segment] = {} + _collections: Dict[str, Collection] = {} + + def __init__(self, system: System): + self._server_port = system.settings.require("chroma_server_grpc_port") + return super().__init__(system) + + @overrides + def start(self) -> None: + self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + add_SysDBServicer_to_server(self, self._server) # type: ignore + self._server.add_insecure_port(f"[::]:{self._server_port}") + self._server.start() + return super().start() + + @overrides + def stop(self) -> None: + self._server.stop(0) + return super().stop() + + @overrides + def reset_state(self) -> None: + self._segments = {} + self._collections = {} + return super().reset_state() + + # We are forced to use check_signature=False because the generated proto code + # does not have type annotations for the request and response objects. + # TODO: investigate generating types for the request and response objects + @overrides(check_signature=False) + def CreateSegment( + self, request: CreateSegmentRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + segment = from_proto_segment(request.segment) + if segment["id"].hex in self._segments: + return proto.ChromaResponse( + status=proto.Status( + code=409, reason=f"Segment {segment['id']} already exists" + ) + ) + self._segments[segment["id"].hex] = segment + return proto.ChromaResponse( + status=proto.Status(code=200) + ) # TODO: how are these codes used? Need to determine the standards for the code and reason. + + @overrides(check_signature=False) + def DeleteSegment( + self, request: DeleteSegmentRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + id_to_delete = request.id + if id_to_delete in self._segments: + del self._segments[id_to_delete] + return proto.ChromaResponse(status=proto.Status(code=200)) + else: + return proto.ChromaResponse( + status=proto.Status( + code=404, reason=f"Segment {id_to_delete} not found" + ) + ) + + @overrides(check_signature=False) + def GetSegments( + self, request: GetSegmentsRequest, context: grpc.ServicerContext + ) -> GetSegmentsResponse: + target_id = UUID(hex=request.id) if request.HasField("id") else None + target_type = request.type if request.HasField("type") else None + target_scope = ( + from_proto_segment_scope(request.scope) + if request.HasField("scope") + else None + ) + target_topic = request.topic if request.HasField("topic") else None + target_collection = ( + UUID(hex=request.collection) if request.HasField("collection") else None + ) + + found_segments = [] + for segment in self._segments.values(): + if target_id and segment["id"] != target_id: + continue + if target_type and segment["type"] != target_type: + continue + if target_scope and segment["scope"] != target_scope: + continue + if target_topic and segment["topic"] != target_topic: + continue + if target_collection and segment["collection"] != target_collection: + continue + found_segments.append(segment) + return GetSegmentsResponse( + segments=[to_proto_segment(segment) for segment in found_segments] + ) + + @overrides(check_signature=False) + def UpdateSegment( + self, request: UpdateSegmentRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + id_to_update = UUID(request.id) + if id_to_update.hex not in self._segments: + return proto.ChromaResponse( + status=proto.Status( + code=404, reason=f"Segment {id_to_update} not found" 
+ ) + ) + else: + segment = self._segments[id_to_update.hex] + if request.HasField("topic"): + segment["topic"] = request.topic + if request.HasField("reset_topic") and request.reset_topic: + segment["topic"] = None + if request.HasField("collection"): + segment["collection"] = UUID(hex=request.collection) + if request.HasField("reset_collection") and request.reset_collection: + segment["collection"] = None + if request.HasField("metadata"): + target = cast(Dict[str, Any], segment["metadata"]) + if segment["metadata"] is None: + segment["metadata"] = {} + self._merge_metadata(target, request.metadata) + if request.HasField("reset_metadata") and request.reset_metadata: + segment["metadata"] = {} + return proto.ChromaResponse(status=proto.Status(code=200)) + + @overrides(check_signature=False) + def CreateCollection( + self, request: CreateCollectionRequest, context: grpc.ServicerContext + ) -> CreateCollectionResponse: + collection = from_proto_collection(request.collection) + if collection["id"].hex in self._collections: + return CreateCollectionResponse( + status=proto.Status( + code=409, reason=f"Collection {collection['id']} already exists" + ) + ) + + self._collections[collection["id"].hex] = collection + return CreateCollectionResponse( + status=proto.Status(code=200), + collection=to_proto_collection(collection), + ) + + @overrides(check_signature=False) + def DeleteCollection( + self, request: DeleteCollectionRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + collection_id = request.id + if collection_id in self._collections: + del self._collections[collection_id] + return proto.ChromaResponse(status=proto.Status(code=200)) + else: + return proto.ChromaResponse( + status=proto.Status( + code=404, reason=f"Collection {collection_id} not found" + ) + ) + + @overrides(check_signature=False) + def GetCollections( + self, request: GetCollectionsRequest, context: grpc.ServicerContext + ) -> GetCollectionsResponse: + target_id = UUID(hex=request.id) if request.HasField("id") else None + target_topic = request.topic if request.HasField("topic") else None + target_name = request.name if request.HasField("name") else None + + found_collections = [] + for collection in self._collections.values(): + if target_id and collection["id"] != target_id: + continue + if target_topic and collection["topic"] != target_topic: + continue + if target_name and collection["name"] != target_name: + continue + found_collections.append(collection) + return GetCollectionsResponse( + collections=[ + to_proto_collection(collection) for collection in found_collections + ] + ) + + @overrides(check_signature=False) + def UpdateCollection( + self, request: UpdateCollectionRequest, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + id_to_update = UUID(request.id) + if id_to_update.hex not in self._collections: + return proto.ChromaResponse( + status=proto.Status( + code=404, reason=f"Collection {id_to_update} not found" + ) + ) + else: + collection = self._collections[id_to_update.hex] + if request.HasField("topic"): + collection["topic"] = request.topic + if request.HasField("name"): + collection["name"] = request.name + if request.HasField("dimension"): + collection["dimension"] = request.dimension + if request.HasField("metadata"): + # TODO: IN SysDB SQlite we have technical debt where we + # replace the entire metadata dict with the new one. We should + # fix that by merging it. 
For now we just do the same thing here + + update_metadata = from_proto_update_metadata(request.metadata) + cleaned_metadata = None + if update_metadata is not None: + cleaned_metadata = {} + for key, value in update_metadata.items(): + if value is not None: + cleaned_metadata[key] = value + + collection["metadata"] = cleaned_metadata + elif request.HasField("reset_metadata"): + if request.reset_metadata: + collection["metadata"] = {} + + return proto.ChromaResponse(status=proto.Status(code=200)) + + @overrides(check_signature=False) + def ResetState( + self, request: Empty, context: grpc.ServicerContext + ) -> proto.ChromaResponse: + self.reset_state() + return proto.ChromaResponse(status=proto.Status(code=200)) + + def _merge_metadata(self, target: Metadata, source: proto.UpdateMetadata) -> None: + target_metadata = cast(Dict[str, Any], target) + source_metadata = cast(Dict[str, Any], from_proto_update_metadata(source)) + target_metadata.update(source_metadata) + # If a key has a None value, remove it from the metadata + for key, value in source_metadata.items(): + if value is None and key in target: + del target_metadata[key] diff --git a/chromadb/db/system.py b/chromadb/db/system.py index 969b71afa3b..23f068c3be3 100644 --- a/chromadb/db/system.py +++ b/chromadb/db/system.py @@ -53,7 +53,7 @@ def update_segment( @abstractmethod def create_collection(self, collection: Collection) -> None: - """Create a new topic""" + """Create a new collection any associated resources in the SysDB.""" pass @abstractmethod diff --git a/chromadb/proto/chroma_pb2.py b/chromadb/proto/chroma_pb2.py index d6b22217d7d..bd069cc74f4 100644 --- a/chromadb/proto/chroma_pb2.py +++ b/chromadb/proto/chroma_pb2.py @@ -13,7 +13,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma\"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding\"\xca\x01\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\x05topic\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x08\n\x06_topicB\r\n\x0b_collectionB\x0b\n\t_metadata\"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value\"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01\"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata\"S\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector\"q\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x10\n\x08\x64istance\x18\x03 \x01(\x01\x12#\n\x06vector\x18\x04 
\x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector\"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult\"(\n\x15SegmentServerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"4\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t\"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord\"\x86\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t\"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 \x03(\x0b\x32\x1a.chroma.VectorQueryResults*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*(\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x32\x94\x01\n\rSegmentServer\x12?\n\x0bLoadSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x12\x42\n\x0eReleaseSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse\"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma\"&\n\x06Status\x12\x0e\n\x06reason\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\x05\"0\n\x0e\x43hromaResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.chroma.Status\"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding\"\xca\x01\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\x05topic\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x08\n\x06_topicB\r\n\x0b_collectionB\x0b\n\t_metadata\"\x97\x01\n\nCollection\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05topic\x18\x03 \x01(\t\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x05 \x01(\x05H\x01\x88\x01\x01\x42\x0b\n\t_metadataB\x0c\n\n_dimension\"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value\"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01\"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata\"S\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 
\x01(\x0c\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector\"q\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x10\n\x08\x64istance\x18\x03 \x01(\x01\x12#\n\x06vector\x18\x04 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector\"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult\"(\n\x15SegmentServerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"4\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t\"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord\"\x86\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t\"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 \x03(\x0b\x32\x1a.chroma.VectorQueryResults*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*(\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x32\x94\x01\n\rSegmentServer\x12?\n\x0bLoadSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x12\x42\n\x0eReleaseSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse\"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -22,42 +22,48 @@ DESCRIPTOR._options = None _UPDATEMETADATA_METADATAENTRY._options = None _UPDATEMETADATA_METADATAENTRY._serialized_options = b'8\001' - _globals['_OPERATION']._serialized_start=1406 - _globals['_OPERATION']._serialized_end=1462 - _globals['_SCALARENCODING']._serialized_start=1464 - _globals['_SCALARENCODING']._serialized_end=1504 - _globals['_SEGMENTSCOPE']._serialized_start=1506 - _globals['_SEGMENTSCOPE']._serialized_end=1546 - _globals['_VECTOR']._serialized_start=39 - _globals['_VECTOR']._serialized_end=124 - _globals['_SEGMENT']._serialized_start=127 - _globals['_SEGMENT']._serialized_end=329 - _globals['_UPDATEMETADATAVALUE']._serialized_start=331 - _globals['_UPDATEMETADATAVALUE']._serialized_end=429 - _globals['_UPDATEMETADATA']._serialized_start=432 - _globals['_UPDATEMETADATA']._serialized_end=582 - _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_start=506 - _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_end=582 - _globals['_SUBMITEMBEDDINGRECORD']._serialized_start=585 - _globals['_SUBMITEMBEDDINGRECORD']._serialized_end=766 - _globals['_VECTOREMBEDDINGRECORD']._serialized_start=768 - _globals['_VECTOREMBEDDINGRECORD']._serialized_end=851 - _globals['_VECTORQUERYRESULT']._serialized_start=853 - _globals['_VECTORQUERYRESULT']._serialized_end=966 - _globals['_VECTORQUERYRESULTS']._serialized_start=968 - _globals['_VECTORQUERYRESULTS']._serialized_end=1032 - _globals['_SEGMENTSERVERRESPONSE']._serialized_start=1034 - _globals['_SEGMENTSERVERRESPONSE']._serialized_end=1074 - _globals['_GETVECTORSREQUEST']._serialized_start=1076 - _globals['_GETVECTORSREQUEST']._serialized_end=1128 - _globals['_GETVECTORSRESPONSE']._serialized_start=1130 - 
_globals['_GETVECTORSRESPONSE']._serialized_end=1198 - _globals['_QUERYVECTORSREQUEST']._serialized_start=1201 - _globals['_QUERYVECTORSREQUEST']._serialized_end=1335 - _globals['_QUERYVECTORSRESPONSE']._serialized_start=1337 - _globals['_QUERYVECTORSRESPONSE']._serialized_end=1404 - _globals['_SEGMENTSERVER']._serialized_start=1549 - _globals['_SEGMENTSERVER']._serialized_end=1697 - _globals['_VECTORREADER']._serialized_start=1700 - _globals['_VECTORREADER']._serialized_end=1862 + _globals['_OPERATION']._serialized_start=1650 + _globals['_OPERATION']._serialized_end=1706 + _globals['_SCALARENCODING']._serialized_start=1708 + _globals['_SCALARENCODING']._serialized_end=1748 + _globals['_SEGMENTSCOPE']._serialized_start=1750 + _globals['_SEGMENTSCOPE']._serialized_end=1790 + _globals['_STATUS']._serialized_start=39 + _globals['_STATUS']._serialized_end=77 + _globals['_CHROMARESPONSE']._serialized_start=79 + _globals['_CHROMARESPONSE']._serialized_end=127 + _globals['_VECTOR']._serialized_start=129 + _globals['_VECTOR']._serialized_end=214 + _globals['_SEGMENT']._serialized_start=217 + _globals['_SEGMENT']._serialized_end=419 + _globals['_COLLECTION']._serialized_start=422 + _globals['_COLLECTION']._serialized_end=573 + _globals['_UPDATEMETADATAVALUE']._serialized_start=575 + _globals['_UPDATEMETADATAVALUE']._serialized_end=673 + _globals['_UPDATEMETADATA']._serialized_start=676 + _globals['_UPDATEMETADATA']._serialized_end=826 + _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_start=750 + _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_end=826 + _globals['_SUBMITEMBEDDINGRECORD']._serialized_start=829 + _globals['_SUBMITEMBEDDINGRECORD']._serialized_end=1010 + _globals['_VECTOREMBEDDINGRECORD']._serialized_start=1012 + _globals['_VECTOREMBEDDINGRECORD']._serialized_end=1095 + _globals['_VECTORQUERYRESULT']._serialized_start=1097 + _globals['_VECTORQUERYRESULT']._serialized_end=1210 + _globals['_VECTORQUERYRESULTS']._serialized_start=1212 + _globals['_VECTORQUERYRESULTS']._serialized_end=1276 + _globals['_SEGMENTSERVERRESPONSE']._serialized_start=1278 + _globals['_SEGMENTSERVERRESPONSE']._serialized_end=1318 + _globals['_GETVECTORSREQUEST']._serialized_start=1320 + _globals['_GETVECTORSREQUEST']._serialized_end=1372 + _globals['_GETVECTORSRESPONSE']._serialized_start=1374 + _globals['_GETVECTORSRESPONSE']._serialized_end=1442 + _globals['_QUERYVECTORSREQUEST']._serialized_start=1445 + _globals['_QUERYVECTORSREQUEST']._serialized_end=1579 + _globals['_QUERYVECTORSRESPONSE']._serialized_start=1581 + _globals['_QUERYVECTORSRESPONSE']._serialized_end=1648 + _globals['_SEGMENTSERVER']._serialized_start=1793 + _globals['_SEGMENTSERVER']._serialized_end=1941 + _globals['_VECTORREADER']._serialized_start=1944 + _globals['_VECTORREADER']._serialized_end=2106 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/chroma_pb2.pyi b/chromadb/proto/chroma_pb2.pyi index 6d06e074c06..733cae0a273 100644 --- a/chromadb/proto/chroma_pb2.pyi +++ b/chromadb/proto/chroma_pb2.pyi @@ -31,6 +31,20 @@ INT32: ScalarEncoding VECTOR: SegmentScope METADATA: SegmentScope +class Status(_message.Message): + __slots__ = ["reason", "code"] + REASON_FIELD_NUMBER: _ClassVar[int] + CODE_FIELD_NUMBER: _ClassVar[int] + reason: str + code: int + def __init__(self, reason: _Optional[str] = ..., code: _Optional[int] = ...) -> None: ... 
+ +class ChromaResponse(_message.Message): + __slots__ = ["status"] + STATUS_FIELD_NUMBER: _ClassVar[int] + status: Status + def __init__(self, status: _Optional[_Union[Status, _Mapping]] = ...) -> None: ... + class Vector(_message.Message): __slots__ = ["dimension", "vector", "encoding"] DIMENSION_FIELD_NUMBER: _ClassVar[int] @@ -57,6 +71,20 @@ class Segment(_message.Message): metadata: UpdateMetadata def __init__(self, id: _Optional[str] = ..., type: _Optional[str] = ..., scope: _Optional[_Union[SegmentScope, str]] = ..., topic: _Optional[str] = ..., collection: _Optional[str] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ...) -> None: ... +class Collection(_message.Message): + __slots__ = ["id", "name", "topic", "metadata", "dimension"] + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + TOPIC_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + DIMENSION_FIELD_NUMBER: _ClassVar[int] + id: str + name: str + topic: str + metadata: UpdateMetadata + dimension: int + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., topic: _Optional[str] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ...) -> None: ... + class UpdateMetadataValue(_message.Message): __slots__ = ["string_value", "int_value", "float_value"] STRING_VALUE_FIELD_NUMBER: _ClassVar[int] diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py index 5ff7bab085d..d46cad07710 100644 --- a/chromadb/proto/convert.py +++ b/chromadb/proto/convert.py @@ -1,10 +1,11 @@ import array from uuid import UUID -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union, cast from chromadb.api.types import Embedding import chromadb.proto.chroma_pb2 as proto from chromadb.utils.messageid import bytes_to_int, int_to_bytes from chromadb.types import ( + Collection, EmbeddingRecord, Metadata, Operation, @@ -13,6 +14,7 @@ SegmentScope, SeqId, SubmitEmbeddingRecord, + UpdateMetadata, Vector, VectorEmbeddingRecord, VectorQueryResult, @@ -71,9 +73,23 @@ def from_proto_operation(operation: proto.Operation) -> Operation: def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]: + return cast(Optional[Metadata], _from_proto_metadata_handle_none(metadata, False)) + + +def from_proto_update_metadata( + metadata: proto.UpdateMetadata, +) -> Optional[UpdateMetadata]: + return cast( + Optional[UpdateMetadata], _from_proto_metadata_handle_none(metadata, True) + ) + + +def _from_proto_metadata_handle_none( + metadata: proto.UpdateMetadata, is_update: bool +) -> Optional[Union[UpdateMetadata, Metadata]]: if not metadata.metadata: return None - out_metadata: Dict[str, Union[str, int, float]] = {} + out_metadata: Dict[str, Union[str, int, float, None]] = {} for key, value in metadata.metadata.items(): if value.HasField("string_value"): out_metadata[key] = value.string_value @@ -81,11 +97,19 @@ def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]: out_metadata[key] = value.int_value elif value.HasField("float_value"): out_metadata[key] = value.float_value + elif is_update: + out_metadata[key] = None else: - raise RuntimeError(f"Unknown metadata value type {value}") + raise ValueError(f"Metadata key {key} value cannot be None") return out_metadata +def to_proto_update_metadata(metadata: UpdateMetadata) -> proto.UpdateMetadata: + return proto.UpdateMetadata( + metadata={k: to_proto_metadata_update_value(v) for k, v in metadata.items()} + ) + + def 
from_proto_submit( submit_embedding_record: proto.SubmitEmbeddingRecord, seq_id: SeqId ) -> EmbeddingRecord: @@ -95,7 +119,7 @@ def from_proto_submit( seq_id=seq_id, embedding=embedding, encoding=encoding, - metadata=from_proto_metadata(submit_embedding_record.metadata), + metadata=from_proto_update_metadata(submit_embedding_record.metadata), operation=from_proto_operation(submit_embedding_record.operation), ) return record @@ -106,11 +130,13 @@ def from_proto_segment(segment: proto.Segment) -> Segment: id=UUID(hex=segment.id), type=segment.type, scope=from_proto_segment_scope(segment.scope), - topic=segment.topic, + topic=segment.topic if segment.HasField("topic") else None, collection=None if not segment.HasField("collection") else UUID(hex=segment.collection), - metadata=from_proto_metadata(segment.metadata), + metadata=from_proto_metadata(segment.metadata) + if segment.HasField("metadata") + else None, ) @@ -123,9 +149,7 @@ def to_proto_segment(segment: Segment) -> proto.Segment: collection=None if segment["collection"] is None else segment["collection"].hex, metadata=None if segment["metadata"] is None - else { - k: to_proto_metadata_update_value(v) for k, v in segment["metadata"].items() - }, # TODO: refactor out to_proto_metadata + else to_proto_update_metadata(segment["metadata"]), ) @@ -165,6 +189,30 @@ def to_proto_metadata_update_value( ) +def from_proto_collection(collection: proto.Collection) -> Collection: + return Collection( + id=UUID(hex=collection.id), + name=collection.name, + topic=collection.topic, + metadata=from_proto_metadata(collection.metadata) + if collection.HasField("metadata") + else None, + dimension=collection.dimension if collection.HasField("dimension") else None, + ) + + +def to_proto_collection(collection: Collection) -> proto.Collection: + return proto.Collection( + id=collection["id"].hex, + name=collection["name"], + topic=collection["topic"], + metadata=None + if collection["metadata"] is None + else to_proto_update_metadata(collection["metadata"]), + dimension=collection["dimension"], + ) + + def to_proto_operation(operation: Operation) -> proto.Operation: if operation == Operation.ADD: return proto.Operation.ADD @@ -190,17 +238,12 @@ def to_proto_submit( metadata = None if submit_record["metadata"] is not None: - metadata = { - k: to_proto_metadata_update_value(v) - for k, v in submit_record["metadata"].items() - } + metadata = to_proto_update_metadata(submit_record["metadata"]) return proto.SubmitEmbeddingRecord( id=submit_record["id"], vector=vector, - metadata=proto.UpdateMetadata(metadata=metadata) - if metadata is not None - else None, + metadata=metadata, operation=to_proto_operation(submit_record["operation"]), ) diff --git a/chromadb/proto/coordinator_pb2.py b/chromadb/proto/coordinator_pb2.py new file mode 100644 index 00000000000..29c7dd5dc51 --- /dev/null +++ b/chromadb/proto/coordinator_pb2.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: chromadb/proto/coordinator.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from chromadb.proto import chroma_pb2 as chromadb_dot_proto_dot_chroma__pb2 +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 \x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"o\n\x17\x43reateCollectionRequest\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x1a\n\rget_or_create\x18\x02 \x01(\x08H\x00\x88\x01\x01\x42\x10\n\x0e_get_or_create\"b\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"%\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"i\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 
\x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\xb6\x05\n\x05SysDB\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chromadb.proto.coordinator_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_CREATESEGMENTREQUEST']._serialized_start=102 + _globals['_CREATESEGMENTREQUEST']._serialized_end=158 + _globals['_DELETESEGMENTREQUEST']._serialized_start=160 + _globals['_DELETESEGMENTREQUEST']._serialized_end=194 + _globals['_GETSEGMENTSREQUEST']._serialized_start=197 + _globals['_GETSEGMENTSREQUEST']._serialized_end=391 + _globals['_GETSEGMENTSRESPONSE']._serialized_start=393 + _globals['_GETSEGMENTSRESPONSE']._serialized_end=481 + _globals['_UPDATESEGMENTREQUEST']._serialized_start=484 + _globals['_UPDATESEGMENTREQUEST']._serialized_end=734 + _globals['_CREATECOLLECTIONREQUEST']._serialized_start=736 + _globals['_CREATECOLLECTIONREQUEST']._serialized_end=847 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=849 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=947 + _globals['_DELETECOLLECTIONREQUEST']._serialized_start=949 + _globals['_DELETECOLLECTIONREQUEST']._serialized_end=986 + _globals['_GETCOLLECTIONSREQUEST']._serialized_start=988 + _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1093 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1095 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1192 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1195 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1417 + _globals['_SYSDB']._serialized_start=1420 + _globals['_SYSDB']._serialized_end=2114 +# @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/coordinator_pb2.pyi b/chromadb/proto/coordinator_pb2.pyi new file mode 100644 index 00000000000..37736ff720e --- /dev/null +++ b/chromadb/proto/coordinator_pb2.pyi @@ -0,0 +1,116 @@ +from chromadb.proto import chroma_pb2 as _chroma_pb2 +from google.protobuf import empty_pb2 as _empty_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class CreateSegmentRequest(_message.Message): + __slots__ = ["segment"] + SEGMENT_FIELD_NUMBER: _ClassVar[int] + segment: _chroma_pb2.Segment + def __init__(self, segment: 
_Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ...) -> None: ... + +class DeleteSegmentRequest(_message.Message): + __slots__ = ["id"] + ID_FIELD_NUMBER: _ClassVar[int] + id: str + def __init__(self, id: _Optional[str] = ...) -> None: ... + +class GetSegmentsRequest(_message.Message): + __slots__ = ["id", "type", "scope", "topic", "collection"] + ID_FIELD_NUMBER: _ClassVar[int] + TYPE_FIELD_NUMBER: _ClassVar[int] + SCOPE_FIELD_NUMBER: _ClassVar[int] + TOPIC_FIELD_NUMBER: _ClassVar[int] + COLLECTION_FIELD_NUMBER: _ClassVar[int] + id: str + type: str + scope: _chroma_pb2.SegmentScope + topic: str + collection: str + def __init__(self, id: _Optional[str] = ..., type: _Optional[str] = ..., scope: _Optional[_Union[_chroma_pb2.SegmentScope, str]] = ..., topic: _Optional[str] = ..., collection: _Optional[str] = ...) -> None: ... + +class GetSegmentsResponse(_message.Message): + __slots__ = ["segments", "status"] + SEGMENTS_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + segments: _containers.RepeatedCompositeFieldContainer[_chroma_pb2.Segment] + status: _chroma_pb2.Status + def __init__(self, segments: _Optional[_Iterable[_Union[_chroma_pb2.Segment, _Mapping]]] = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... + +class UpdateSegmentRequest(_message.Message): + __slots__ = ["id", "topic", "reset_topic", "collection", "reset_collection", "metadata", "reset_metadata"] + ID_FIELD_NUMBER: _ClassVar[int] + TOPIC_FIELD_NUMBER: _ClassVar[int] + RESET_TOPIC_FIELD_NUMBER: _ClassVar[int] + COLLECTION_FIELD_NUMBER: _ClassVar[int] + RESET_COLLECTION_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + RESET_METADATA_FIELD_NUMBER: _ClassVar[int] + id: str + topic: str + reset_topic: bool + collection: str + reset_collection: bool + metadata: _chroma_pb2.UpdateMetadata + reset_metadata: bool + def __init__(self, id: _Optional[str] = ..., topic: _Optional[str] = ..., reset_topic: bool = ..., collection: _Optional[str] = ..., reset_collection: bool = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., reset_metadata: bool = ...) -> None: ... + +class CreateCollectionRequest(_message.Message): + __slots__ = ["collection", "get_or_create"] + COLLECTION_FIELD_NUMBER: _ClassVar[int] + GET_OR_CREATE_FIELD_NUMBER: _ClassVar[int] + collection: _chroma_pb2.Collection + get_or_create: bool + def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., get_or_create: bool = ...) -> None: ... + +class CreateCollectionResponse(_message.Message): + __slots__ = ["collection", "status"] + COLLECTION_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + collection: _chroma_pb2.Collection + status: _chroma_pb2.Status + def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... + +class DeleteCollectionRequest(_message.Message): + __slots__ = ["id"] + ID_FIELD_NUMBER: _ClassVar[int] + id: str + def __init__(self, id: _Optional[str] = ...) -> None: ... + +class GetCollectionsRequest(_message.Message): + __slots__ = ["id", "name", "topic"] + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + TOPIC_FIELD_NUMBER: _ClassVar[int] + id: str + name: str + topic: str + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., topic: _Optional[str] = ...) -> None: ... 
+ +class GetCollectionsResponse(_message.Message): + __slots__ = ["collections", "status"] + COLLECTIONS_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + collections: _containers.RepeatedCompositeFieldContainer[_chroma_pb2.Collection] + status: _chroma_pb2.Status + def __init__(self, collections: _Optional[_Iterable[_Union[_chroma_pb2.Collection, _Mapping]]] = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... + +class UpdateCollectionRequest(_message.Message): + __slots__ = ["id", "topic", "name", "dimension", "metadata", "reset_metadata"] + ID_FIELD_NUMBER: _ClassVar[int] + TOPIC_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + DIMENSION_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + RESET_METADATA_FIELD_NUMBER: _ClassVar[int] + id: str + topic: str + name: str + dimension: int + metadata: _chroma_pb2.UpdateMetadata + reset_metadata: bool + def __init__(self, id: _Optional[str] = ..., topic: _Optional[str] = ..., name: _Optional[str] = ..., dimension: _Optional[int] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., reset_metadata: bool = ...) -> None: ... diff --git a/chromadb/proto/coordinator_pb2_grpc.py b/chromadb/proto/coordinator_pb2_grpc.py new file mode 100644 index 00000000000..a3a1e03227b --- /dev/null +++ b/chromadb/proto/coordinator_pb2_grpc.py @@ -0,0 +1,441 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from chromadb.proto import chroma_pb2 as chromadb_dot_proto_dot_chroma__pb2 +from chromadb.proto import coordinator_pb2 as chromadb_dot_proto_dot_coordinator__pb2 +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 + + +class SysDBStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.CreateSegment = channel.unary_unary( + "/chroma.SysDB/CreateSegment", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) + self.DeleteSegment = channel.unary_unary( + "/chroma.SysDB/DeleteSegment", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteSegmentRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) + self.GetSegments = channel.unary_unary( + "/chroma.SysDB/GetSegments", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsResponse.FromString, + ) + self.UpdateSegment = channel.unary_unary( + "/chroma.SysDB/UpdateSegment", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateSegmentRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) + self.CreateCollection = channel.unary_unary( + "/chroma.SysDB/CreateCollection", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionResponse.FromString, + ) + self.DeleteCollection = channel.unary_unary( + "/chroma.SysDB/DeleteCollection", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteCollectionRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) + self.GetCollections = channel.unary_unary( + "/chroma.SysDB/GetCollections", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsResponse.FromString, + ) + self.UpdateCollection = channel.unary_unary( + "/chroma.SysDB/UpdateCollection", + request_serializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateCollectionRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) + self.ResetState = channel.unary_unary( + "/chroma.SysDB/ResetState", + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + ) + + +class SysDBServicer(object): + """Missing associated documentation comment in .proto file.""" + + def CreateSegment(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def DeleteSegment(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def GetSegments(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def UpdateSegment(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not 
implemented!") + raise NotImplementedError("Method not implemented!") + + def CreateCollection(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def DeleteCollection(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def GetCollections(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def UpdateCollection(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def ResetState(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_SysDBServicer_to_server(servicer, server): + rpc_method_handlers = { + "CreateSegment": grpc.unary_unary_rpc_method_handler( + servicer.CreateSegment, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), + "DeleteSegment": grpc.unary_unary_rpc_method_handler( + servicer.DeleteSegment, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteSegmentRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), + "GetSegments": grpc.unary_unary_rpc_method_handler( + servicer.GetSegments, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsRequest.FromString, + response_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsResponse.SerializeToString, + ), + "UpdateSegment": grpc.unary_unary_rpc_method_handler( + servicer.UpdateSegment, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateSegmentRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), + "CreateCollection": grpc.unary_unary_rpc_method_handler( + servicer.CreateCollection, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionRequest.FromString, + response_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionResponse.SerializeToString, + ), + "DeleteCollection": grpc.unary_unary_rpc_method_handler( + servicer.DeleteCollection, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteCollectionRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), + "GetCollections": grpc.unary_unary_rpc_method_handler( + servicer.GetCollections, + request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsRequest.FromString, + response_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsResponse.SerializeToString, + ), + "UpdateCollection": grpc.unary_unary_rpc_method_handler( + servicer.UpdateCollection, + 
request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateCollectionRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), + "ResetState": grpc.unary_unary_rpc_method_handler( + servicer.ResetState, + request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "chroma.SysDB", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. +class SysDB(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def CreateSegment( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/CreateSegment", + chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def DeleteSegment( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/DeleteSegment", + chromadb_dot_proto_dot_coordinator__pb2.DeleteSegmentRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def GetSegments( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/GetSegments", + chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsRequest.SerializeToString, + chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def UpdateSegment( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/UpdateSegment", + chromadb_dot_proto_dot_coordinator__pb2.UpdateSegmentRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def CreateCollection( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/CreateCollection", + chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionRequest.SerializeToString, + 
chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def DeleteCollection( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/DeleteCollection", + chromadb_dot_proto_dot_coordinator__pb2.DeleteCollectionRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def GetCollections( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/GetCollections", + chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsRequest.SerializeToString, + chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def UpdateCollection( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/UpdateCollection", + chromadb_dot_proto_dot_coordinator__pb2.UpdateCollectionRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def ResetState( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/chroma.SysDB/ResetState", + google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.ChromaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index 82c9b6a8f61..f67a3d9d390 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -3,6 +3,8 @@ import tempfile import pytest from typing import Generator, List, Callable, Dict, Union +from chromadb.db.impl.grpc.client import GrpcSysDB +from chromadb.db.impl.grpc.server import GrpcMockSysDB from chromadb.types import Collection, Segment, SegmentScope from chromadb.db.impl.sqlite import SqliteDB from chromadb.config import System, Settings @@ -35,8 +37,19 @@ def sqlite_persistent() -> Generator[SysDB, None, None]: shutil.rmtree(save_path) +def grpc_with_mock_server() -> Generator[SysDB, None, None]: + """Fixture generator for sqlite DB that creates a mock grpc sysdb server + and a grpc client that connects to it.""" + system = System(Settings(allow_reset=True, chroma_server_grpc_port=50051)) + system.instance(GrpcMockSysDB) + client = 
system.instance(GrpcSysDB) + system.start() + client.reset_and_wait_for_ready() + yield client + + def db_fixtures() -> List[Callable[[], Generator[SysDB, None, None]]]: - return [sqlite, sqlite_persistent] + return [sqlite, sqlite_persistent, grpc_with_mock_server] @pytest.fixture(scope="module", params=db_fixtures()) diff --git a/idl/chromadb/proto/chroma.proto b/idl/chromadb/proto/chroma.proto index ddc7f11bc26..7b1f10f18ea 100644 --- a/idl/chromadb/proto/chroma.proto +++ b/idl/chromadb/proto/chroma.proto @@ -2,6 +2,17 @@ syntax = "proto3"; package chroma; +message Status { + string reason = 1; + int32 code = 2; // TODO: What is the enum of this code? +} + +message ChromaResponse { + Status status = 1; +} + +// Types here should mirror chromadb/types.py + enum Operation { ADD = 0; UPDATE = 1; @@ -36,6 +47,14 @@ message Segment { optional UpdateMetadata metadata = 6; } +message Collection { + string id = 1; + string name = 2; + string topic = 3; + optional UpdateMetadata metadata = 4; + optional int32 dimension = 5; +} + message UpdateMetadataValue { oneof value { string string_value = 1; diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto new file mode 100644 index 00000000000..d4b95058451 --- /dev/null +++ b/idl/chromadb/proto/coordinator.proto @@ -0,0 +1,92 @@ +syntax = "proto3"; + +package chroma; + +import "chromadb/proto/chroma.proto"; +import "google/protobuf/empty.proto"; + +message CreateSegmentRequest { + Segment segment = 1; +} + +message DeleteSegmentRequest { + string id = 1; +} + +message GetSegmentsRequest { + optional string id = 1; + optional string type = 2; + optional SegmentScope scope = 3; + optional string topic = 4; + optional string collection = 5; +} + +message GetSegmentsResponse { + repeated Segment segments = 1; + Status status = 2; +} + + +message UpdateSegmentRequest { + string id = 1; + oneof topic_update { + string topic = 2; + bool reset_topic = 3; + } + oneof collection_update { + string collection = 4; + bool reset_collection = 5; + } + oneof metadata_update { + UpdateMetadata metadata = 6; + bool reset_metadata = 7; + } +} + +message CreateCollectionRequest { + Collection collection = 1; + optional bool get_or_create = 2; +} + +message CreateCollectionResponse { + Collection collection = 1; + Status status = 2; +} + +message DeleteCollectionRequest { + string id = 1; +} + +message GetCollectionsRequest { + optional string id = 1; + optional string name = 2; + optional string topic = 3; +} + +message GetCollectionsResponse { + repeated Collection collections = 1; + Status status = 2; +} + +message UpdateCollectionRequest { + string id = 1; + optional string topic = 2; + optional string name = 3; + optional int32 dimension = 4; + oneof metadata_update { + UpdateMetadata metadata = 5; + bool reset_metadata = 6; + } +} + +service SysDB { + rpc CreateSegment(CreateSegmentRequest) returns (ChromaResponse) {} + rpc DeleteSegment(DeleteSegmentRequest) returns (ChromaResponse) {} + rpc GetSegments(GetSegmentsRequest) returns (GetSegmentsResponse) {} + rpc UpdateSegment(UpdateSegmentRequest) returns (ChromaResponse) {} + rpc CreateCollection(CreateCollectionRequest) returns (CreateCollectionResponse) {} + rpc DeleteCollection(DeleteCollectionRequest) returns (ChromaResponse) {} + rpc GetCollections(GetCollectionsRequest) returns (GetCollectionsResponse) {} + rpc UpdateCollection(UpdateCollectionRequest) returns (ChromaResponse) {} + rpc ResetState(google.protobuf.Empty) returns (ChromaResponse) {} +} From 
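For reference, a minimal sketch of exercising this `SysDB` service from Python through the generated stub. It assumes the `chromadb.proto.coordinator_pb2*` modules generated from the file above and a coordinator (or the mock server used in tests) listening on `localhost:50051`; the address is an assumption about the local setup, not part of this patch.

```python
# Sketch only: drive the SysDB gRPC service via the generated stub.
# Assumes a coordinator (or GrpcMockSysDB) is serving on localhost:50051;
# adjust the address for a real deployment.
import grpc
from google.protobuf import empty_pb2
from chromadb.proto import coordinator_pb2
from chromadb.proto.coordinator_pb2_grpc import SysDBStub

channel = grpc.insecure_channel("localhost:50051")
stub = SysDBStub(channel)

# List the collections and segments the coordinator knows about.
collections = stub.GetCollections(coordinator_pb2.GetCollectionsRequest())
segments = stub.GetSegments(coordinator_pb2.GetSegmentsRequest())
print(collections.status.code, len(collections.collections), len(segments.segments))

# Wipe all state (only meaningful when the server allows resets).
stub.ResetState(empty_pb2.Empty())
```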
b042331234268c6b380efd95f28b558f33d40221 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Sun, 15 Oct 2023 22:16:21 -0700 Subject: [PATCH 10/14] [STACKED #1229] [ENH] Add a CollectionAssignmentPolicy and move topic creation into SysDB using the policy (#1237) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Adds a CollectionAssignmentPolicy and moves topic creation down into SysDB - Changes SysDB to not accept a Collection Object in create_collection() but params instead - New functionality - None TODO: - [ ] delete topic should also live in the sysdb now. ## Test plan *How are these changes tested?* Existing tests were modified with the refactor and ensured to pass. ## Documentation Changes None required --- chromadb/api/segment.py | 23 +---- chromadb/config.py | 7 +- chromadb/db/impl/grpc/client.py | 21 ++++- chromadb/db/mixins/sysdb.py | 32 ++++++- chromadb/db/system.py | 15 ++- chromadb/ingest/__init__.py | 9 ++ chromadb/ingest/impl/simple_policy.py | 25 +++++ chromadb/test/db/test_system.py | 127 +++++++++++++++++--------- 8 files changed, 188 insertions(+), 71 deletions(-) create mode 100644 chromadb/ingest/impl/simple_policy.py diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 85dca1d8532..73ed7329200 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -1,7 +1,6 @@ from chromadb.api import API from chromadb.config import Settings, System from chromadb.db.system import SysDB -from chromadb.ingest.impl.utils import create_topic_name from chromadb.segment import SegmentManager, MetadataReader, VectorReader from chromadb.telemetry import Telemetry from chromadb.ingest import Producer @@ -79,10 +78,7 @@ class SegmentAPI(API): _sysdb: SysDB _manager: SegmentManager _producer: Producer - # TODO: fire telemetry events _telemetry_client: Telemetry - _tenant_id: str - _topic_ns: str _collection_cache: Dict[UUID, t.Collection] def __init__(self, system: System): @@ -92,8 +88,6 @@ def __init__(self, system: System): self._manager = self.require(SegmentManager) self._telemetry_client = self.require(Telemetry) self._producer = self.require(Producer) - self._tenant_id = system.settings.tenant_id - self._topic_ns = system.settings.topic_namespace self._collection_cache = {} @override @@ -135,15 +129,12 @@ def create_collection( check_index_name(name) id = uuid4() - coll = t.Collection( - id=id, name=name, metadata=metadata, topic=self._topic(id), dimension=None + + coll = self._sysdb.create_collection( + id=id, name=name, metadata=metadata, dimension=None ) - # TODO: Topic creation right now lives in the producer but it should be moved to the coordinator, - # and the producer should just be responsible for publishing messages. Coordinator should - # be responsible for all management of topics. 
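The comment removed above is exactly what the new `CollectionAssignmentPolicy` addresses: topic assignment now happens inside SysDB through a pluggable component selected by `chroma_collection_assignment_policy_impl`. To illustrate why this is an interface rather than a hard-coded rule, here is a hypothetical alternative policy; it is not part of this PR, and the class name, pool size, and topic naming scheme are made up for illustration.

```python
# Hypothetical sketch, not included in this change: a policy that maps
# collections onto a fixed pool of shared topics instead of one topic per
# collection. Names and the topic format are illustrative only.
from uuid import UUID

from overrides import overrides

from chromadb.config import System
from chromadb.ingest import CollectionAssignmentPolicy


class HashedPoolAssignmentPolicy(CollectionAssignmentPolicy):
    _num_topics: int = 16

    def __init__(self, system: System):
        super().__init__(system)

    @overrides
    def assign_collection(self, collection_id: UUID) -> str:
        # Stable assignment: the same collection always lands on the same topic.
        return f"shared-topic-{collection_id.int % self._num_topics}"
```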
- self._producer.create_topic(coll["topic"]) segments = self._manager.create_segments(coll) - self._sysdb.create_collection(coll) + for segment in segments: self._sysdb.create_segment(segment) @@ -244,7 +235,6 @@ def delete_collection(self, name: str) -> None: self._sysdb.delete_collection(existing[0]["id"]) for s in self._manager.delete_segments(existing[0]["id"]): self._sysdb.delete_segment(s) - self._producer.delete_topic(existing[0]["topic"]) if existing and existing[0]["id"] in self._collection_cache: del self._collection_cache[existing[0]["id"]] else: @@ -620,13 +610,8 @@ def get_settings(self) -> Settings: def max_batch_size(self) -> int: return self._producer.max_batch_size - def _topic(self, collection_id: UUID) -> str: - return create_topic_name(self._tenant_id, self._topic_ns, str(collection_id)) - # TODO: This could potentially cause race conditions in a distributed version of the # system, since the cache is only local. - # TODO: promote collection -> topic to a base class method so that it can be - # used for channel assignment in the distributed version of the system. def _validate_embedding_record( self, collection: t.Collection, record: t.SubmitEmbeddingRecord ) -> None: diff --git a/chromadb/config.py b/chromadb/config.py index a2af7bd32bc..eb7bca93ef5 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -67,6 +67,7 @@ "chromadb.telemetry.Telemetry": "chroma_telemetry_impl", "chromadb.ingest.Producer": "chroma_producer_impl", "chromadb.ingest.Consumer": "chroma_consumer_impl", + "chromadb.ingest.CollectionAssignmentPolicy": "chroma_collection_assignment_policy_impl", # noqa "chromadb.db.system.SysDB": "chroma_sysdb_impl", "chromadb.segment.SegmentManager": "chroma_segment_manager_impl", "chromadb.segment.distributed.SegmentDirectory": "chroma_segment_directory_impl", @@ -77,7 +78,8 @@ class Settings(BaseSettings): # type: ignore environment: str = "" - # Legacy config has to be kept around because pydantic will error on nonexisting keys + # Legacy config has to be kept around because pydantic will error + # on nonexisting keys chroma_db_impl: Optional[str] = None chroma_api_impl: str = "chromadb.api.segment.SegmentAPI" # Can be "chromadb.api.segment.SegmentAPI" or "chromadb.api.fastapi.FastAPI" @@ -94,6 +96,9 @@ class Settings(BaseSettings): # type: ignore # Distributed architecture specific components chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory" chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider" + chroma_collection_assignment_policy_impl: str = ( + "chromadb.ingest.impl.simple_policy.SimpleAssignmentPolicy" + ) worker_memberlist_name: str = "worker-memberlist" chroma_coordinator_host = "localhost" diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index 5d7c6e88838..49e231b0086 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -4,6 +4,7 @@ from chromadb.config import System from chromadb.db.base import NotFoundError, UniqueConstraintError from chromadb.db.system import SysDB +from chromadb.ingest import CollectionAssignmentPolicy from chromadb.proto.convert import ( from_proto_collection, from_proto_segment, @@ -26,6 +27,7 @@ from chromadb.proto.coordinator_pb2_grpc import SysDBStub from chromadb.types import ( Collection, + Metadata, OptionalArgument, Segment, SegmentScope, @@ -42,6 +44,7 @@ class GrpcSysDB(SysDB): to call a remote SysDB (Coordinator) 
service.""" _sys_db_stub: SysDBStub + _assignment_policy: CollectionAssignmentPolicy _channel: grpc.Channel _coordinator_url: str _coordinator_port: int @@ -50,6 +53,7 @@ def __init__(self, system: System): self._coordinator_url = system.settings.require("chroma_coordinator_host") # TODO: break out coordinator_port into a separate setting? self._coordinator_port = system.settings.require("chroma_server_grpc_port") + self._assignment_policy = system.instance(CollectionAssignmentPolicy) return super().__init__(system) @overrides @@ -156,8 +160,22 @@ def update_segment( self._sys_db_stub.UpdateSegment(request) @overrides - def create_collection(self, collection: Collection) -> None: + def create_collection( + self, + id: UUID, + name: str, + metadata: Optional[Metadata] = None, + dimension: Optional[int] = None, + ) -> Collection: # TODO: the get_or_create concept needs to be pushed down to the sysdb interface + topic = self._assignment_policy.assign_collection(id) + collection = Collection( + id=id, + name=name, + topic=topic, + metadata=metadata, + dimension=dimension, + ) request = CreateCollectionRequest( collection=to_proto_collection(collection), get_or_create=False, @@ -165,6 +183,7 @@ def create_collection(self, collection: Collection) -> None: response = self._sys_db_stub.CreateCollection(request) if response.status.code == 409: raise UniqueConstraintError() + return collection @overrides def delete_collection(self, id: UUID) -> None: diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index 58ee4488b64..afd3f9917c0 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -14,6 +14,7 @@ UniqueConstraintError, ) from chromadb.db.system import SysDB +from chromadb.ingest import CollectionAssignmentPolicy, Producer from chromadb.types import ( OptionalArgument, Segment, @@ -26,9 +27,20 @@ class SqlSysDB(SqlDB, SysDB): + _assignment_policy: CollectionAssignmentPolicy + # Used only to delete topics on collection deletion. 
+ # TODO: refactor to remove this dependency into a separate interface + _producer: Producer + def __init__(self, system: System): + self._assignment_policy = system.instance(CollectionAssignmentPolicy) super().__init__(system) + @override + def start(self) -> None: + super().start() + self._producer = self._system.instance(Producer) + @override def create_segment(self, segment: Segment) -> None: with self.tx() as cur: @@ -69,8 +81,20 @@ def create_segment(self, segment: Segment) -> None: ) @override - def create_collection(self, collection: Collection) -> None: - """Create a new collection""" + def create_collection( + self, + id: UUID, + name: str, + metadata: Optional[Metadata] = None, + dimension: Optional[int] = None, + ) -> Collection: + """Create a new collection and the associate topic""" + + topic = self._assignment_policy.assign_collection(id) + collection = Collection( + id=id, topic=topic, name=name, metadata=metadata, dimension=dimension + ) + with self.tx() as cur: collections = Table("collections") insert_collection = ( @@ -105,6 +129,7 @@ def create_collection(self, collection: Collection) -> None: collection["id"], collection["metadata"], ) + return collection @override def get_segments( @@ -263,10 +288,11 @@ def delete_collection(self, id: UUID) -> None: with self.tx() as cur: # no need for explicit del from metadata table because of ON DELETE CASCADE sql, params = get_sql(q, self.parameter_format()) - sql = sql + " RETURNING id" + sql = sql + " RETURNING id, topic" result = cur.execute(sql, params).fetchone() if not result: raise NotFoundError(f"Collection {id} not found") + self._producer.delete_topic(result[1]) @override def update_segment( diff --git a/chromadb/db/system.py b/chromadb/db/system.py index 23f068c3be3..a1bfbbffea6 100644 --- a/chromadb/db/system.py +++ b/chromadb/db/system.py @@ -3,6 +3,7 @@ from uuid import UUID from chromadb.types import ( Collection, + Metadata, Segment, SegmentScope, OptionalArgument, @@ -52,13 +53,21 @@ def update_segment( pass @abstractmethod - def create_collection(self, collection: Collection) -> None: - """Create a new collection any associated resources in the SysDB.""" + def create_collection( + self, + id: UUID, + name: str, + metadata: Optional[Metadata] = None, + dimension: Optional[int] = None, + ) -> Collection: + """Create a new collection any associated resources + (Such as the necessary topics) in the SysDB.""" pass @abstractmethod def delete_collection(self, id: UUID) -> None: - """Delete a topic and all associated segments from the SysDB""" + """Delete a collection, topic, all associated segments and any associate resources + from the SysDB and the system at large.""" pass @abstractmethod diff --git a/chromadb/ingest/__init__.py b/chromadb/ingest/__init__.py index 56863e8914d..5a5abf1c99b 100644 --- a/chromadb/ingest/__init__.py +++ b/chromadb/ingest/__init__.py @@ -118,3 +118,12 @@ def min_seqid(self) -> SeqId: def max_seqid(self) -> SeqId: """Return the maximum possible SeqID in this implementation.""" pass + + +class CollectionAssignmentPolicy(Component): + """Interface for assigning collections to topics""" + + @abstractmethod + def assign_collection(self, collection_id: UUID) -> str: + """Return the topic that should be used for the given collection""" + pass diff --git a/chromadb/ingest/impl/simple_policy.py b/chromadb/ingest/impl/simple_policy.py new file mode 100644 index 00000000000..06ee2e001e0 --- /dev/null +++ b/chromadb/ingest/impl/simple_policy.py @@ -0,0 +1,25 @@ +from uuid import UUID +from overrides 
import overrides +from chromadb.config import System +from chromadb.ingest import CollectionAssignmentPolicy +from chromadb.ingest.impl.utils import create_topic_name + + +class SimpleAssignmentPolicy(CollectionAssignmentPolicy): + """Simple assignment policy that assigns a 1 collection to 1 topic based on the + id of the collection.""" + + _tenant_id: str + _topic_ns: str + + def __init__(self, system: System): + self._tenant_id = system.settings.tenant_id + self._topic_ns = system.settings.topic_namespace + super().__init__(system) + + def _topic(self, collection_id: UUID) -> str: + return create_topic_name(self._tenant_id, self._topic_ns, str(collection_id)) + + @overrides + def assign_collection(self, collection_id: UUID) -> str: + return self._topic(collection_id) diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index f67a3d9d390..8127c27c5d9 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -7,16 +7,55 @@ from chromadb.db.impl.grpc.server import GrpcMockSysDB from chromadb.types import Collection, Segment, SegmentScope from chromadb.db.impl.sqlite import SqliteDB -from chromadb.config import System, Settings +from chromadb.config import Component, System, Settings from chromadb.db.system import SysDB from chromadb.db.base import NotFoundError, UniqueConstraintError from pytest import FixtureRequest import uuid +sample_collections = [ + Collection( + id=uuid.UUID("93ffe3ec-0107-48d4-8695-51f978c509dc"), + name="test_collection_1", + topic="test_topic_1", + metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3}, + dimension=128, + ), + Collection( + id=uuid.UUID("f444f1d7-d06c-4357-ac22-5a4a1f92d761"), + name="test_collection_2", + topic="test_topic_2", + metadata={"test_str": "str2", "test_int": 2, "test_float": 2.3}, + dimension=None, + ), + Collection( + id=uuid.UUID("43babc1a-e403-4a50-91a9-16621ba29ab0"), + name="test_collection_3", + topic="test_topic_3", + metadata={"test_str": "str3", "test_int": 3, "test_float": 3.3}, + dimension=None, + ), +] + + +class MockAssignmentPolicy(Component): + def assign_collection(self, collection_id: uuid.UUID) -> str: + for collection in sample_collections: + if collection["id"] == collection_id: + return collection["topic"] + raise ValueError(f"Unknown collection ID: {collection_id}") + def sqlite() -> Generator[SysDB, None, None]: """Fixture generator for sqlite DB""" - db = SqliteDB(System(Settings(allow_reset=True))) + db = SqliteDB( + System( + Settings( + allow_reset=True, + chroma_collection_assignment_policy_impl="chromadb.test.db.test_system.MockAssignmentPolicy", + ) + ) + ) db.start() yield db db.stop() @@ -27,7 +66,12 @@ def sqlite_persistent() -> Generator[SysDB, None, None]: save_path = tempfile.mkdtemp() db = SqliteDB( System( - Settings(allow_reset=True, is_persistent=True, persist_directory=save_path) + Settings( + allow_reset=True, + is_persistent=True, + persist_directory=save_path, + chroma_collection_assignment_policy_impl="chromadb.test.db.test_system.MockAssignmentPolicy", + ) ) ) db.start() @@ -40,7 +84,13 @@ def sqlite_persistent() -> Generator[SysDB, None, None]: def grpc_with_mock_server() -> Generator[SysDB, None, None]: """Fixture generator for sqlite DB that creates a mock grpc sysdb server and a grpc client that connects to it.""" - system = System(Settings(allow_reset=True, chroma_server_grpc_port=50051)) + system = System( + Settings( + allow_reset=True, + 
chroma_collection_assignment_policy_impl="chromadb.test.db.test_system.MockAssignmentPolicy", + chroma_server_grpc_port=50051, + ) + ) system.instance(GrpcMockSysDB) client = system.instance(GrpcSysDB) system.start() @@ -57,36 +107,16 @@ def sysdb(request: FixtureRequest) -> Generator[SysDB, None, None]: yield next(request.param()) -sample_collections = [ - Collection( - id=uuid.uuid4(), - name="test_collection_1", - topic="test_topic_1", - metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3}, - dimension=128, - ), - Collection( - id=uuid.uuid4(), - name="test_collection_2", - topic="test_topic_2", - metadata={"test_str": "str2", "test_int": 2, "test_float": 2.3}, - dimension=None, - ), - Collection( - id=uuid.uuid4(), - name="test_collection_3", - topic="test_topic_3", - metadata={"test_str": "str3", "test_int": 3, "test_float": 3.3}, - dimension=None, - ), -] - - def test_create_get_delete_collections(sysdb: SysDB) -> None: sysdb.reset_state() for collection in sample_collections: - sysdb.create_collection(collection) + sysdb.create_collection( + id=collection["id"], + name=collection["name"], + metadata=collection["metadata"], + dimension=collection["dimension"], + ) results = sysdb.get_collections() results = sorted(results, key=lambda c: c["name"]) @@ -95,7 +125,9 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None: # Duplicate create fails with pytest.raises(UniqueConstraintError): - sysdb.create_collection(sample_collections[0]) + sysdb.create_collection( + name=sample_collections[0]["name"], id=sample_collections[0]["id"] + ) # Find by name for collection in sample_collections: @@ -140,22 +172,22 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None: def test_update_collections(sysdb: SysDB) -> None: - metadata: Dict[str, Union[str, int, float]] = { - "test_str": "str1", - "test_int": 1, - "test_float": 1.3, - } coll = Collection( - id=uuid.uuid4(), - name="test_collection_1", - topic="test_topic_1", - metadata=metadata, - dimension=None, + name=sample_collections[0]["name"], + id=sample_collections[0]["id"], + topic=sample_collections[0]["topic"], + metadata=sample_collections[0]["metadata"], + dimension=sample_collections[0]["dimension"], ) sysdb.reset_state() - sysdb.create_collection(coll) + sysdb.create_collection( + id=coll["id"], + name=coll["name"], + metadata=coll["metadata"], + dimension=coll["dimension"], + ) # Update name coll["name"] = "new_name" @@ -220,7 +252,12 @@ def test_create_get_delete_segments(sysdb: SysDB) -> None: sysdb.reset_state() for collection in sample_collections: - sysdb.create_collection(collection) + sysdb.create_collection( + id=collection["id"], + name=collection["name"], + metadata=collection["metadata"], + dimension=collection["dimension"], + ) for segment in sample_segments: sysdb.create_segment(segment) @@ -293,7 +330,9 @@ def test_update_segment(sysdb: SysDB) -> None: sysdb.reset_state() for c in sample_collections: - sysdb.create_collection(c) + sysdb.create_collection( + id=c["id"], name=c["name"], metadata=c["metadata"], dimension=c["dimension"] + ) sysdb.create_segment(segment) From 9d89b9621a6f3a726ce601c6a14266146bf6506d Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Sun, 15 Oct 2023 23:44:27 -0700 Subject: [PATCH 11/14] [STACKED #1237] [ENH] Move get_or_create into sysdb (#1242) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Move get_or_create logic into the sys db interface - New functionality - None added ## Test plan *How are these changes 
tested?* Existing tests as well as a new test at the sysdb level. - [x] Tests pass locally with `pytest` for python, `yarn test` for js ## Documentation Changes None, all changes are to internal interfaces. --- chromadb/api/segment.py | 37 +++++--------- chromadb/db/impl/grpc/client.py | 27 ++++------ chromadb/db/impl/grpc/server.py | 39 +++++++++++--- chromadb/db/mixins/sysdb.py | 21 ++++++-- chromadb/db/system.py | 14 +++-- chromadb/proto/convert.py | 4 +- chromadb/proto/coordinator_pb2.py | 30 +++++------ chromadb/proto/coordinator_pb2.pyi | 20 +++++--- chromadb/test/db/test_system.py | 76 ++++++++++++++++++++++++++++ idl/chromadb/proto/coordinator.proto | 10 ++-- 10 files changed, 200 insertions(+), 78 deletions(-) diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 73ed7329200..cfe1300e76e 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -105,39 +105,28 @@ def create_collection( embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), get_or_create: bool = False, ) -> Collection: - existing = self._sysdb.get_collections(name=name) - if metadata is not None: validate_metadata(metadata) - if existing: - if get_or_create: - if metadata and existing[0]["metadata"] != metadata: - self._modify(id=existing[0]["id"], new_metadata=metadata) - existing = self._sysdb.get_collections(id=existing[0]["id"]) - return Collection( - client=self, - id=existing[0]["id"], - name=existing[0]["name"], - metadata=existing[0]["metadata"], # type: ignore - embedding_function=embedding_function, - ) - else: - raise ValueError(f"Collection {name} already exists.") - # TODO: remove backwards compatibility in naming requirements check_index_name(name) id = uuid4() - coll = self._sysdb.create_collection( - id=id, name=name, metadata=metadata, dimension=None + coll, created = self._sysdb.create_collection( + id=id, + name=name, + metadata=metadata, + dimension=None, + get_or_create=get_or_create, ) - segments = self._manager.create_segments(coll) - for segment in segments: - self._sysdb.create_segment(segment) + if created: + segments = self._manager.create_segments(coll) + for segment in segments: + self._sysdb.create_segment(segment) + # TODO: This event doesn't capture the get_or_create case appropriately self._telemetry_client.capture( ClientCreateCollectionEvent( collection_uuid=str(id), @@ -147,9 +136,9 @@ def create_collection( return Collection( client=self, - id=id, + id=coll["id"], name=name, - metadata=metadata, + metadata=coll["metadata"], # type: ignore embedding_function=embedding_function, ) diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index 49e231b0086..04d4302062a 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -1,14 +1,12 @@ -from typing import List, Optional, Sequence, Union, cast +from typing import List, Optional, Sequence, Tuple, Union, cast from uuid import UUID from overrides import overrides from chromadb.config import System from chromadb.db.base import NotFoundError, UniqueConstraintError from chromadb.db.system import SysDB -from chromadb.ingest import CollectionAssignmentPolicy from chromadb.proto.convert import ( from_proto_collection, from_proto_segment, - to_proto_collection, to_proto_update_metadata, to_proto_segment, to_proto_segment_scope, @@ -44,7 +42,6 @@ class GrpcSysDB(SysDB): to call a remote SysDB (Coordinator) service.""" _sys_db_stub: SysDBStub - _assignment_policy: CollectionAssignmentPolicy _channel: grpc.Channel _coordinator_url: str 
_coordinator_port: int @@ -53,11 +50,11 @@ def __init__(self, system: System): self._coordinator_url = system.settings.require("chroma_coordinator_host") # TODO: break out coordinator_port into a separate setting? self._coordinator_port = system.settings.require("chroma_server_grpc_port") - self._assignment_policy = system.instance(CollectionAssignmentPolicy) return super().__init__(system) @overrides def start(self) -> None: + # TODO: add retry policy here self._channel = grpc.insecure_channel( f"{self._coordinator_url}:{self._coordinator_port}" ) @@ -166,24 +163,20 @@ def create_collection( name: str, metadata: Optional[Metadata] = None, dimension: Optional[int] = None, - ) -> Collection: - # TODO: the get_or_create concept needs to be pushed down to the sysdb interface - topic = self._assignment_policy.assign_collection(id) - collection = Collection( - id=id, + get_or_create: bool = False, + ) -> Tuple[Collection, bool]: + request = CreateCollectionRequest( + id=id.hex, name=name, - topic=topic, - metadata=metadata, + metadata=to_proto_update_metadata(metadata) if metadata else None, dimension=dimension, - ) - request = CreateCollectionRequest( - collection=to_proto_collection(collection), - get_or_create=False, + get_or_create=get_or_create, ) response = self._sys_db_stub.CreateCollection(request) if response.status.code == 409: raise UniqueConstraintError() - return collection + collection = from_proto_collection(response.collection) + return collection, response.created @overrides def delete_collection(self, id: UUID) -> None: diff --git a/chromadb/db/impl/grpc/server.py b/chromadb/db/impl/grpc/server.py index f1b69460492..436a57fc167 100644 --- a/chromadb/db/impl/grpc/server.py +++ b/chromadb/db/impl/grpc/server.py @@ -2,9 +2,10 @@ from typing import Any, Dict, cast from uuid import UUID from overrides import overrides +from chromadb.ingest import CollectionAssignmentPolicy from chromadb.config import Component, System from chromadb.proto.convert import ( - from_proto_collection, + from_proto_metadata, from_proto_update_metadata, from_proto_segment, from_proto_segment_scope, @@ -40,11 +41,13 @@ class GrpcMockSysDB(SysDBServicer, Component): _server: grpc.Server _server_port: int + _assignment_policy: CollectionAssignmentPolicy _segments: Dict[str, Segment] = {} _collections: Dict[str, Collection] = {} def __init__(self, system: System): self._server_port = system.settings.require("chroma_server_grpc_port") + self._assignment_policy = system.instance(CollectionAssignmentPolicy) return super().__init__(system) @overrides @@ -167,18 +170,42 @@ def UpdateSegment( def CreateCollection( self, request: CreateCollectionRequest, context: grpc.ServicerContext ) -> CreateCollectionResponse: - collection = from_proto_collection(request.collection) - if collection["id"].hex in self._collections: + collection_name = request.name + matches = [ + c for c in self._collections.values() if c["name"] == collection_name + ] + assert len(matches) <= 1 + if len(matches) > 0: + if request.get_or_create: + existing_collection = matches[0] + if request.HasField("metadata"): + existing_collection["metadata"] = from_proto_metadata( + request.metadata + ) + return CreateCollectionResponse( + status=proto.Status(code=200), + collection=to_proto_collection(existing_collection), + created=False, + ) return CreateCollectionResponse( status=proto.Status( - code=409, reason=f"Collection {collection['id']} already exists" + code=409, reason=f"Collection {request.name} already exists" ) ) - 
self._collections[collection["id"].hex] = collection + id = UUID(hex=request.id) + new_collection = Collection( + id=id, + name=request.name, + metadata=from_proto_metadata(request.metadata), + dimension=request.dimension, + topic=self._assignment_policy.assign_collection(id), + ) + self._collections[request.id] = new_collection return CreateCollectionResponse( status=proto.Status(code=200), - collection=to_proto_collection(collection), + collection=to_proto_collection(new_collection), + created=True, ) @overrides(check_signature=False) diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index afd3f9917c0..d105918e700 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -87,8 +87,23 @@ def create_collection( name: str, metadata: Optional[Metadata] = None, dimension: Optional[int] = None, - ) -> Collection: - """Create a new collection and the associate topic""" + get_or_create: bool = False, + ) -> Tuple[Collection, bool]: + if id is None and not get_or_create: + raise ValueError("id must be specified if get_or_create is False") + + existing = self.get_collections(name=name) + if existing: + if get_or_create: + collection = existing[0] + if metadata is not None and collection["metadata"] != metadata: + self.update_collection( + collection["id"], + metadata=metadata, + ) + return self.get_collections(id=collection["id"])[0], False + else: + raise UniqueConstraintError(f"Collection {name} already exists") topic = self._assignment_policy.assign_collection(id) collection = Collection( @@ -129,7 +144,7 @@ def create_collection( collection["id"], collection["metadata"], ) - return collection + return collection, True @override def get_segments( diff --git a/chromadb/db/system.py b/chromadb/db/system.py index a1bfbbffea6..b4975339fbc 100644 --- a/chromadb/db/system.py +++ b/chromadb/db/system.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Optional, Sequence +from typing import Optional, Sequence, Tuple from uuid import UUID from chromadb.types import ( Collection, @@ -59,9 +59,17 @@ def create_collection( name: str, metadata: Optional[Metadata] = None, dimension: Optional[int] = None, - ) -> Collection: + get_or_create: bool = False, + ) -> Tuple[Collection, bool]: """Create a new collection any associated resources - (Such as the necessary topics) in the SysDB.""" + (Such as the necessary topics) in the SysDB. If get_or_create is True, the + collectionwill be created if one with the same name does not exist. + The metadata will be updated using the same protocol as update_collection. If get_or_create + is False and the collection already exists, a error will be raised. + + Returns a tuple of the created collection and a boolean indicating whether the + collection was created or not. 
+ """ pass @abstractmethod diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py index d46cad07710..129d7e3ff2a 100644 --- a/chromadb/proto/convert.py +++ b/chromadb/proto/convert.py @@ -197,7 +197,9 @@ def from_proto_collection(collection: proto.Collection) -> Collection: metadata=from_proto_metadata(collection.metadata) if collection.HasField("metadata") else None, - dimension=collection.dimension if collection.HasField("dimension") else None, + dimension=collection.dimension + if collection.HasField("dimension") and collection.dimension + else None, ) diff --git a/chromadb/proto/coordinator_pb2.py b/chromadb/proto/coordinator_pb2.py index 29c7dd5dc51..118405d423a 100644 --- a/chromadb/proto/coordinator_pb2.py +++ b/chromadb/proto/coordinator_pb2.py @@ -15,7 +15,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 \x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"o\n\x17\x43reateCollectionRequest\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x1a\n\rget_or_create\x18\x02 \x01(\x08H\x00\x88\x01\x01\x42\x10\n\x0e_get_or_create\"b\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"%\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"i\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 
\x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\xb6\x05\n\x05SysDB\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"\"\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\xc2\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\x05topic\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x04\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scopeB\x08\n\x06_topicB\r\n\x0b_collection\"X\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.Segment\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xfa\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x05topic\x18\x02 \x01(\tH\x00\x12\x15\n\x0breset_topic\x18\x03 \x01(\x08H\x00\x12\x14\n\ncollection\x18\x04 \x01(\tH\x01\x12\x1a\n\x10reset_collection\x18\x05 \x01(\x08H\x01\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x02\x42\x0e\n\x0ctopic_updateB\x13\n\x11\x63ollection_updateB\x11\n\x0fmetadata_update\"\xc3\x01\n\x17\x43reateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rget_or_create\x18\x05 \x01(\x08H\x02\x88\x01\x01\x42\x0b\n\t_metadataB\x0c\n\n_dimensionB\x10\n\x0e_get_or_create\"s\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0f\n\x07\x63reated\x18\x02 \x01(\x08\x12\x1e\n\x06status\x18\x03 \x01(\x0b\x32\x0e.chroma.Status\"%\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"i\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05topic\x18\x03 \x01(\tH\x02\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_topic\"a\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.Collection\x12\x1e\n\x06status\x18\x02 \x01(\x0b\x32\x0e.chroma.Status\"\xde\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\x05topic\x18\x02 
\x01(\tH\x01\x88\x01\x01\x12\x11\n\x04name\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 \x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x08\n\x06_topicB\x07\n\x05_nameB\x0c\n\n_dimension2\xb6\x05\n\x05SysDB\x12G\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12G\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12G\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a .chroma.CreateCollectionResponse\"\x00\x12M\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12M\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a\x16.chroma.ChromaResponse\"\x00\x12>\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x16.chroma.ChromaResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -32,18 +32,18 @@ _globals['_GETSEGMENTSRESPONSE']._serialized_end=481 _globals['_UPDATESEGMENTREQUEST']._serialized_start=484 _globals['_UPDATESEGMENTREQUEST']._serialized_end=734 - _globals['_CREATECOLLECTIONREQUEST']._serialized_start=736 - _globals['_CREATECOLLECTIONREQUEST']._serialized_end=847 - _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=849 - _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=947 - _globals['_DELETECOLLECTIONREQUEST']._serialized_start=949 - _globals['_DELETECOLLECTIONREQUEST']._serialized_end=986 - _globals['_GETCOLLECTIONSREQUEST']._serialized_start=988 - _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1093 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1095 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1192 - _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1195 - _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1417 - _globals['_SYSDB']._serialized_start=1420 - _globals['_SYSDB']._serialized_end=2114 + _globals['_CREATECOLLECTIONREQUEST']._serialized_start=737 + _globals['_CREATECOLLECTIONREQUEST']._serialized_end=932 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_start=934 + _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=1049 + _globals['_DELETECOLLECTIONREQUEST']._serialized_start=1051 + _globals['_DELETECOLLECTIONREQUEST']._serialized_end=1088 + _globals['_GETCOLLECTIONSREQUEST']._serialized_start=1090 + _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1195 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1197 + _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1294 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1297 + _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=1519 + _globals['_SYSDB']._serialized_start=1522 + _globals['_SYSDB']._serialized_end=2216 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/coordinator_pb2.pyi b/chromadb/proto/coordinator_pb2.pyi index 37736ff720e..6b9c974e424 100644 --- a/chromadb/proto/coordinator_pb2.pyi +++ b/chromadb/proto/coordinator_pb2.pyi @@ -60,20 +60,28 @@ class UpdateSegmentRequest(_message.Message): def __init__(self, id: _Optional[str] = ..., topic: _Optional[str] = ..., reset_topic: bool = 
..., collection: _Optional[str] = ..., reset_collection: bool = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., reset_metadata: bool = ...) -> None: ... class CreateCollectionRequest(_message.Message): - __slots__ = ["collection", "get_or_create"] - COLLECTION_FIELD_NUMBER: _ClassVar[int] + __slots__ = ["id", "name", "metadata", "dimension", "get_or_create"] + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + DIMENSION_FIELD_NUMBER: _ClassVar[int] GET_OR_CREATE_FIELD_NUMBER: _ClassVar[int] - collection: _chroma_pb2.Collection + id: str + name: str + metadata: _chroma_pb2.UpdateMetadata + dimension: int get_or_create: bool - def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., get_or_create: bool = ...) -> None: ... + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ..., get_or_create: bool = ...) -> None: ... class CreateCollectionResponse(_message.Message): - __slots__ = ["collection", "status"] + __slots__ = ["collection", "created", "status"] COLLECTION_FIELD_NUMBER: _ClassVar[int] + CREATED_FIELD_NUMBER: _ClassVar[int] STATUS_FIELD_NUMBER: _ClassVar[int] collection: _chroma_pb2.Collection + created: bool status: _chroma_pb2.Status - def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... + def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., created: bool = ..., status: _Optional[_Union[_chroma_pb2.Status, _Mapping]] = ...) -> None: ... class DeleteCollectionRequest(_message.Message): __slots__ = ["id"] diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index 8127c27c5d9..541643a2ff6 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -220,6 +220,82 @@ def test_update_collections(sysdb: SysDB) -> None: assert result == [coll] +def test_get_or_create_collection(sysdb: SysDB) -> None: + sysdb.reset_state() + + # get_or_create = True returns existing collection + collection = sample_collections[0] + + sysdb.create_collection( + id=collection["id"], + name=collection["name"], + metadata=collection["metadata"], + dimension=collection["dimension"], + ) + + result, created = sysdb.create_collection( + name=collection["name"], + id=uuid.uuid4(), + get_or_create=True, + metadata=collection["metadata"], + ) + assert result == collection + + # Only one collection with the same name exists + get_result = sysdb.get_collections(name=collection["name"]) + assert get_result == [collection] + + # get_or_create = True creates new collection + result, created = sysdb.create_collection( + name=sample_collections[1]["name"], + id=sample_collections[1]["id"], + get_or_create=True, + metadata=sample_collections[1]["metadata"], + ) + assert result == sample_collections[1] + + # get_or_create = False creates new collection + result, created = sysdb.create_collection( + name=sample_collections[2]["name"], + id=sample_collections[2]["id"], + get_or_create=False, + metadata=sample_collections[2]["metadata"], + ) + assert result == sample_collections[2] + + # get_or_create = False fails if collection already exists + with pytest.raises(UniqueConstraintError): + sysdb.create_collection( + name=sample_collections[2]["name"], + id=sample_collections[2]["id"], + 
get_or_create=False, + metadata=collection["metadata"], + ) + + # get_or_create = True overwrites metadata + overlayed_metadata: Dict[str, Union[str, int, float]] = { + "test_new_str": "new_str", + "test_int": 1, + } + result, created = sysdb.create_collection( + name=sample_collections[2]["name"], + id=sample_collections[2]["id"], + get_or_create=True, + metadata=overlayed_metadata, + ) + + assert result["metadata"] == overlayed_metadata + + # get_or_create = False with None metadata does not overwrite metadata + result, created = sysdb.create_collection( + name=sample_collections[2]["name"], + id=sample_collections[2]["id"], + get_or_create=True, + metadata=None, + ) + assert result["metadata"] == overlayed_metadata + + sample_segments = [ Segment( id=uuid.UUID("00000000-d7d7-413b-92e1-731098a6e492"), diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index d4b95058451..2a557f99613 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -44,13 +44,17 @@ message UpdateSegmentRequest { } message CreateCollectionRequest { - Collection collection = 1; - optional bool get_or_create = 2; + string id = 1; + string name = 2; + optional UpdateMetadata metadata = 3; + optional int32 dimension = 4; + optional bool get_or_create = 5; } message CreateCollectionResponse { Collection collection = 1; - Status status = 2; + bool created = 2; + Status status = 3; } message DeleteCollectionRequest { From 99c8a99a894c3b8908b7ac0cb0406fdff6c87de6 Mon Sep 17 00:00:00 2001 From: Ben Eggers <64657842+beggers@users.noreply.github.com> Date: Wed, 18 Oct 2023 10:45:37 -0700 Subject: [PATCH 12/14] [ENH] OTel tracing throughout the codebase (#1238) ## Description of changes This PR adds OpenTelemetry tracing to ~all major methods throughout our codebase. It also adds configuration to specify where these traces should be sent. I focused more on laying the groundwork for tracing than on collecting all the data we need everywhere. Default behavior is unchanged: no tracing, no printing. *Summarize the changes made by this PR.* - New functionality - OpenTelemetryClient with relevant config. - Wrap most of our code in tracing. The only major design decision I made was to fully separate OpenTelemetry stuff and Posthog (product telemetry) stuff. Justification: It's tempting to combine OTel and product telemetry behind a single internal interface. I don't think this coupling is worth it. Product telemetry cares about a small and relatively static set of uses, whereas tracing by nature should be very deep in our codebase. I see two ways to couple them and problems with each: - Have a unified Telemetry interface only for the events our product telemetry cares about and use raw OTel for the rest. In other words, use this telemetry interface only for `collection.add()`s, `collection.delete()`s, etc. This seems weird to me: tracing code would be implicit in some cases but explicit in others, making the codebase less easily comprehensible. Also if an engineer later decides to add product telemetry to a codepath that already has tracing, they need to know to remove existing tracing. This increases the cognitive overhead required to work on Chroma, reducing the readability and maintainability of our codebase. - Have a unified Telemetry interface which does everything. In this case, it has the above behavior but also wraps all other OTel behavior we want. 
This seems weird to me because we're basically writing a wrapper around the complete set of OpenTelemetry behavior which only modifies a small part of it. This increases our maintenance burden for very little value. Instead we have two well-encapsulated telemetry modules which we can modify and use without worrying about the other telemetry. The OTel module provides some lightweight helpers to make OTel a little easier to use, but we can put raw OTel code throughout our codebase and it'll play nicely. ## Test plan *How are these changes tested?* - [ ] Tests pass locally with `pytest` for python, `yarn test` for js Manual testing: - Set environment variables to export traces to Honeycomb at various granularities. - Went through various basic Chroma flows and checked that traces show up in Honeycomb as expected. ![Screenshot 2023-10-12 at 10 39 11 AM](https://github.com/chroma-core/chroma/assets/64657842/49c95054-ef7f-42b1-bb14-4b372edf9343) ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* Docs PR to land before this does. --- .gitignore | 2 +- chroma_data/chroma.sqlite3 | Bin 0 -> 126976 bytes chromadb/__init__.py | 18 ++- chromadb/api/fastapi.py | 34 ++++- chromadb/api/segment.py | 81 +++++++++-- chromadb/auth/basic/__init__.py | 6 + chromadb/auth/fastapi.py | 8 ++ chromadb/auth/providers.py | 9 ++ chromadb/auth/token/__init__.py | 11 ++ chromadb/config.py | 11 +- chromadb/db/impl/sqlite.py | 38 ++++-- chromadb/db/migrations.py | 8 ++ chromadb/db/mixins/embeddings_queue.py | 30 +++- chromadb/db/mixins/sysdb.py | 86 +++++++++++- chromadb/ingest/impl/pulsar.py | 21 ++- chromadb/segment/impl/distributed/server.py | 14 +- chromadb/segment/impl/manager/distributed.py | 26 ++++ chromadb/segment/impl/manager/local.py | 23 ++++ chromadb/segment/impl/metadata/sqlite.py | 36 ++++- chromadb/segment/impl/vector/grpc_segment.py | 9 ++ chromadb/segment/impl/vector/local_hnsw.py | 16 +++ .../impl/vector/local_persistent_hnsw.py | 29 ++++ chromadb/server/fastapi/__init__.py | 22 ++- chromadb/telemetry/README.md | 10 ++ chromadb/telemetry/__init__.py | 122 ----------------- chromadb/telemetry/opentelemetry/__init__.py | 128 ++++++++++++++++++ chromadb/telemetry/product/__init__.py | 93 +++++++++++++ chromadb/telemetry/{ => product}/events.py | 20 +-- chromadb/telemetry/{ => product}/posthog.py | 16 ++- .../property/test_cross_version_persist.py | 26 +++- docker-compose.yml | 4 + requirements.txt | 3 + server.htpasswd | 1 + 33 files changed, 766 insertions(+), 195 deletions(-) create mode 100644 chroma_data/chroma.sqlite3 create mode 100644 chromadb/telemetry/README.md create mode 100644 chromadb/telemetry/opentelemetry/__init__.py create mode 100644 chromadb/telemetry/product/__init__.py rename chromadb/telemetry/{ => product}/events.py (89%) rename chromadb/telemetry/{ => product}/posthog.py (77%) create mode 100644 server.htpasswd diff --git a/.gitignore b/.gitignore index 316c32cb664..0ee3678ced8 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ index_data # Default configuration for persist_directory in chromadb/config.py # Currently it's located in "./chroma/" chroma/ -chroma_test_data +chroma_test_data/ server.htpasswd .venv diff --git a/chroma_data/chroma.sqlite3 b/chroma_data/chroma.sqlite3 new file mode 100644 index 0000000000000000000000000000000000000000..5885d1523cf4dc3a4ca02bdec0c439d63d523d39 GIT binary patch literal 126976 
zQlf=cuJ*DRD_1s)pHvD&>qN|28?{T7hC(UtP6xT|m|xAbR7WTs}k>kd!bFgJl=3@bJ9pKtKpqPxk7rqyOQo2T zN4kSY7LP>}E**~yW@U1uSiQFA5xw~u<%-nm>oUP0SG+x(Sg9uM^%XWVe%93zj@5$Q z#W9wYJB=4diESS?d;joE>E=hXh&r~A3ZKH`q}fj ztUKE&_OlU}t-ah-)rL>5huG*C9gO>kTzRw?=DX#sYG79pon;_Fxtq4&_4&Jg5L$g|2gx`%x`DLE_^@v*U6oUKTb@rQt&(H zKl^$c&maH+2tWV=Z-+qlgP*b7#s<^= 3.35.0.\033[0m\n" - "\033[94mPlease visit https://docs.trychroma.com/troubleshooting#sqlite to learn how to upgrade.\033[0m" + "\033[91mYour system has an unsupported version of sqlite3. Chroma \ + requires sqlite3 >= 3.35.0.\033[0m\n" + "\033[94mPlease visit \ + https://docs.trychroma.com/troubleshooting#sqlite to learn how \ + to upgrade.\033[0m" ) @@ -147,12 +152,11 @@ def Client(settings: Settings = __settings) -> API: system = System(settings) - telemetry_client = system.instance(Telemetry) + product_telemetry_client = system.instance(ProductTelemetryClient) api = system.instance(API) system.start() - # Submit event for client start - telemetry_client.capture(ClientStartEvent()) + product_telemetry_client.capture(ClientStartEvent()) return api diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 2ddd537ebff..8db5bf889f7 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -31,7 +31,12 @@ from chromadb.auth.providers import RequestsClientAuthProtocolAdapter from chromadb.auth.registry import resolve_provider from chromadb.config import Settings, System -from chromadb.telemetry import Telemetry +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) +from chromadb.telemetry.product import ProductTelemetryClient from urllib.parse import urlparse, urlunparse, quote logger = logging.getLogger(__name__) @@ -51,7 +56,8 @@ def _validate_host(host: str) -> None: if "/" in host and (not host.startswith("http")): raise ValueError( "Invalid URL. " - "Seems that you are trying to pass URL as a host but without specifying the protocol. " + "Seems that you are trying to pass URL as a host but without \ + specifying the protocol. " "Please add http:// or https:// to the host." 
) @@ -92,7 +98,8 @@ def __init__(self, system: System): system.settings.require("chroma_server_host") system.settings.require("chroma_server_http_port") - self._telemetry_client = self.require(Telemetry) + self._opentelemetry_client = self.require(OpenTelemetryClient) + self._product_telemetry_client = self.require(ProductTelemetryClient) self._settings = system.settings self._api_url = FastAPI.resolve_url( @@ -127,6 +134,7 @@ def __init__(self, system: System): if self._header is not None: self._session.headers.update(self._header) + @trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION) @override def heartbeat(self) -> int: """Returns the current server time in nanoseconds to check if the server is alive""" @@ -134,6 +142,7 @@ def heartbeat(self) -> int: raise_chroma_error(resp) return int(resp.json()["nanosecond heartbeat"]) + @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION) @override def list_collections(self) -> Sequence[Collection]: """Returns a list of all collections""" @@ -146,6 +155,7 @@ def list_collections(self) -> Sequence[Collection]: return collections + @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION) @override def create_collection( self, @@ -171,6 +181,7 @@ def create_collection( metadata=resp_json["metadata"], ) + @trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION) @override def get_collection( self, @@ -189,6 +200,9 @@ def get_collection( metadata=resp_json["metadata"], ) + @trace_method( + "FastAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION + ) @override def get_or_create_collection( self, @@ -200,6 +214,7 @@ def get_or_create_collection( name, metadata, embedding_function, get_or_create=True ) + @trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION) @override def _modify( self, @@ -214,12 +229,14 @@ def _modify( ) raise_chroma_error(resp) + @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION) @override def delete_collection(self, name: str) -> None: """Deletes a collection""" resp = self._session.delete(self._api_url + "/collections/" + name) raise_chroma_error(resp) + @trace_method("FastAPI._count", OpenTelemetryGranularity.OPERATION) @override def _count(self, collection_id: UUID) -> int: """Returns the number of embeddings in the database""" @@ -229,6 +246,7 @@ def _count(self, collection_id: UUID) -> int: raise_chroma_error(resp) return cast(int, resp.json()) + @trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION) @override def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: return self._get( @@ -237,6 +255,7 @@ def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: include=["embeddings", "documents", "metadatas"], ) + @trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION) @override def _get( self, @@ -279,6 +298,7 @@ def _get( documents=body.get("documents", None), ) + @trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION) @override def _delete( self, @@ -298,6 +318,7 @@ def _delete( raise_chroma_error(resp) return cast(IDs, resp.json()) + @trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL) def _submit_batch( self, batch: Tuple[ @@ -321,6 +342,7 @@ def _submit_batch( ) return resp + @trace_method("FastAPI._add", OpenTelemetryGranularity.ALL) @override def _add( self, @@ -340,6 +362,7 @@ def _add( raise_chroma_error(resp) return True + @trace_method("FastAPI._update", OpenTelemetryGranularity.ALL) @override def _update( self, @@ 
-361,6 +384,7 @@ def _update( resp.raise_for_status() return True + @trace_method("FastAPI._upsert", OpenTelemetryGranularity.ALL) @override def _upsert( self, @@ -382,6 +406,7 @@ def _upsert( resp.raise_for_status() return True + @trace_method("FastAPI._query", OpenTelemetryGranularity.ALL) @override def _query( self, @@ -417,6 +442,7 @@ def _query( documents=body.get("documents", None), ) + @trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL) @override def reset(self) -> bool: """Resets the database""" @@ -424,6 +450,7 @@ def reset(self) -> bool: raise_chroma_error(resp) return cast(bool, resp.json()) + @trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION) @override def get_version(self) -> str: """Returns the version of the server""" @@ -437,6 +464,7 @@ def get_settings(self) -> Settings: return self._settings @property + @trace_method("FastAPI.max_batch_size", OpenTelemetryGranularity.OPERATION) @override def max_batch_size(self) -> int: if self._max_batch_size == -1: diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index cfe1300e76e..45dcefc6697 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -2,7 +2,13 @@ from chromadb.config import Settings, System from chromadb.db.system import SysDB from chromadb.segment import SegmentManager, MetadataReader, VectorReader -from chromadb.telemetry import Telemetry +from chromadb.telemetry.opentelemetry import ( + add_attributes_to_current_span, + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) +from chromadb.telemetry.product import ProductTelemetryClient from chromadb.ingest import Producer from chromadb.api.models.Collection import Collection from chromadb import __version__ @@ -28,7 +34,7 @@ validate_where_document, validate_batch, ) -from chromadb.telemetry.events import ( +from chromadb.telemetry.product.events import ( CollectionAddEvent, CollectionDeleteEvent, CollectionGetEvent, @@ -78,7 +84,10 @@ class SegmentAPI(API): _sysdb: SysDB _manager: SegmentManager _producer: Producer - _telemetry_client: Telemetry + _product_telemetry_client: ProductTelemetryClient + _opentelemetry_client: OpenTelemetryClient + _tenant_id: str + _topic_ns: str _collection_cache: Dict[UUID, t.Collection] def __init__(self, system: System): @@ -86,7 +95,8 @@ def __init__(self, system: System): self._settings = system.settings self._sysdb = self.require(SysDB) self._manager = self.require(SegmentManager) - self._telemetry_client = self.require(Telemetry) + self._product_telemetry_client = self.require(ProductTelemetryClient) + self._opentelemetry_client = self.require(OpenTelemetryClient) self._producer = self.require(Producer) self._collection_cache = {} @@ -97,6 +107,7 @@ def heartbeat(self) -> int: # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is # necessary because changing the value type from `Any` to`` `Union[str, int, float]` # causes the system to somehow convert all values to strings. 
+ @trace_method("SegmentAPI.create_collection", OpenTelemetryGranularity.OPERATION) @override def create_collection( self, @@ -127,12 +138,13 @@ def create_collection( self._sysdb.create_segment(segment) # TODO: This event doesn't capture the get_or_create case appropriately - self._telemetry_client.capture( + self._product_telemetry_client.capture( ClientCreateCollectionEvent( collection_uuid=str(id), embedding_function=embedding_function.__class__.__name__, ) ) + add_attributes_to_current_span({"collection_uuid": str(id)}) return Collection( client=self, @@ -142,6 +154,9 @@ def create_collection( embedding_function=embedding_function, ) + @trace_method( + "SegmentAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION + ) @override def get_or_create_collection( self, @@ -149,7 +164,7 @@ def get_or_create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), ) -> Collection: - return self.create_collection( + return self.create_collection( # type: ignore name=name, metadata=metadata, embedding_function=embedding_function, @@ -159,6 +174,7 @@ def get_or_create_collection( # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is # necessary because changing the value type from `Any` to`` `Union[str, int, float]` # causes the system to somehow convert all values to strings + @trace_method("SegmentAPI.get_collection", OpenTelemetryGranularity.OPERATION) @override def get_collection( self, @@ -178,6 +194,7 @@ def get_collection( else: raise ValueError(f"Collection {name} does not exist.") + @trace_method("SegmentAPI.list_collection", OpenTelemetryGranularity.OPERATION) @override def list_collections(self) -> Sequence[Collection]: collections = [] @@ -193,6 +210,7 @@ def list_collections(self) -> Sequence[Collection]: ) return collections + @trace_method("SegmentAPI._modify", OpenTelemetryGranularity.OPERATION) @override def _modify( self, @@ -216,6 +234,7 @@ def _modify( elif new_metadata: self._sysdb.update_collection(id, metadata=new_metadata) + @trace_method("SegmentAPI.delete_collection", OpenTelemetryGranularity.OPERATION) @override def delete_collection(self, name: str) -> None: existing = self._sysdb.get_collections(name=name) @@ -229,6 +248,7 @@ def delete_collection(self, name: str) -> None: else: raise ValueError(f"Collection {name} does not exist.") + @trace_method("SegmentAPI._add", OpenTelemetryGranularity.OPERATION) @override def _add( self, @@ -256,7 +276,7 @@ def _add( records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) - self._telemetry_client.capture( + self._product_telemetry_client.capture( CollectionAddEvent( collection_uuid=str(collection_id), add_amount=len(ids), @@ -266,6 +286,7 @@ def _add( ) return True + @trace_method("SegmentAPI._update", OpenTelemetryGranularity.OPERATION) @override def _update( self, @@ -293,7 +314,7 @@ def _update( records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) - self._telemetry_client.capture( + self._product_telemetry_client.capture( CollectionUpdateEvent( collection_uuid=str(collection_id), update_amount=len(ids), @@ -305,6 +326,7 @@ def _update( return True + @trace_method("SegmentAPI._upsert", OpenTelemetryGranularity.OPERATION) @override def _upsert( self, @@ -334,6 +356,7 @@ def _upsert( return True + @trace_method("SegmentAPI._get", OpenTelemetryGranularity.OPERATION) @override def _get( self, @@ -348,6 +371,13 @@ def _get( 
where_document: Optional[WhereDocument] = {}, include: Include = ["embeddings", "metadatas", "documents"], ) -> GetResult: + add_attributes_to_current_span( + { + "collection_id": str(collection_id), + "ids_count": len(ids) if ids else 0, + } + ) + where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( validate_where_document(where_document) @@ -388,7 +418,7 @@ def _get( documents = [_doc(m) for m in metadatas] ids_amount = len(ids) if ids else 0 - self._telemetry_client.capture( + self._product_telemetry_client.capture( CollectionGetEvent( collection_uuid=str(collection_id), ids_count=ids_amount, @@ -407,6 +437,7 @@ def _get( documents=documents if "documents" in include else None, # type: ignore ) + @trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION) @override def _delete( self, @@ -415,6 +446,13 @@ def _delete( where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, ) -> IDs: + add_attributes_to_current_span( + { + "collection_id": str(collection_id), + "ids_count": len(ids) if ids else 0, + } + ) + where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( validate_where_document(where_document) @@ -461,18 +499,21 @@ def _delete( records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) - self._telemetry_client.capture( + self._product_telemetry_client.capture( CollectionDeleteEvent( collection_uuid=str(collection_id), delete_amount=len(ids_to_delete) ) ) return ids_to_delete + @trace_method("SegmentAPI._count", OpenTelemetryGranularity.OPERATION) @override def _count(self, collection_id: UUID) -> int: + add_attributes_to_current_span({"collection_id": str(collection_id)}) metadata_segment = self._manager.get_segment(collection_id, MetadataReader) return metadata_segment.count() + @trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION) @override def _query( self, @@ -483,6 +524,13 @@ def _query( where_document: WhereDocument = {}, include: Include = ["documents", "metadatas", "distances"], ) -> QueryResult: + add_attributes_to_current_span( + { + "collection_id": str(collection_id), + "n_results": n_results, + "where": str(where), + } + ) where = validate_where(where) if where is not None and len(where) > 0 else where where_document = ( validate_where_document(where_document) @@ -552,7 +600,7 @@ def _query( documents.append(doc_list) # type: ignore query_amount = len(query_embeddings) - self._telemetry_client.capture( + self._product_telemetry_client.capture( CollectionQueryEvent( collection_uuid=str(collection_id), query_amount=query_amount, @@ -573,9 +621,11 @@ def _query( documents=documents if documents else None, ) + @trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION) @override def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: - return self._get(collection_id, limit=n) + add_attributes_to_current_span({"collection_id": str(collection_id)}) + return self._get(collection_id, limit=n) # type: ignore @override def get_version(self) -> str: @@ -601,20 +651,24 @@ def max_batch_size(self) -> int: # TODO: This could potentially cause race conditions in a distributed version of the # system, since the cache is only local. + # TODO: promote collection -> topic to a base class method so that it can be + # used for channel assignment in the distributed version of the system. 
+ @trace_method("SegmentAPI._validate_embedding_record", OpenTelemetryGranularity.ALL) def _validate_embedding_record( self, collection: t.Collection, record: t.SubmitEmbeddingRecord ) -> None: """Validate the dimension of an embedding record before submitting it to the system.""" + add_attributes_to_current_span({"collection_id": str(collection["id"])}) if record["embedding"]: self._validate_dimension(collection, len(record["embedding"]), update=True) + @trace_method("SegmentAPI._validate_dimension", OpenTelemetryGranularity.ALL) def _validate_dimension( self, collection: t.Collection, dim: int, update: bool ) -> None: """Validate that a collection supports records of the given dimension. If update is true, update the collection if the collection doesn't already have a dimension.""" - if collection["dimension"] is None: if update: id = collection["id"] @@ -627,6 +681,7 @@ def _validate_dimension( else: return # all is well + @trace_method("SegmentAPI._get_collection", OpenTelemetryGranularity.ALL) def _get_collection(self, collection_id: UUID) -> t.Collection: """Read-through cache for collection data""" if collection_id not in self._collection_cache: diff --git a/chromadb/auth/basic/__init__.py b/chromadb/auth/basic/__init__.py index a03d195e8ae..a9888598a22 100644 --- a/chromadb/auth/basic/__init__.py +++ b/chromadb/auth/basic/__init__.py @@ -17,6 +17,11 @@ ) from chromadb.auth.registry import register_provider, resolve_provider from chromadb.config import System +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.utils import get_class logger = logging.getLogger(__name__) @@ -84,6 +89,7 @@ def __init__(self, system: System) -> None: ), ) + @trace_method("BasicAuthServerProvider.authenticate", OpenTelemetryGranularity.ALL) @override def authenticate(self, request: ServerAuthenticationRequest[Any]) -> bool: try: diff --git a/chromadb/auth/fastapi.py b/chromadb/auth/fastapi.py index a488ef5f2b3..14b531e48e8 100644 --- a/chromadb/auth/fastapi.py +++ b/chromadb/auth/fastapi.py @@ -17,6 +17,11 @@ ChromaAuthMiddleware, ) from chromadb.auth.registry import resolve_provider +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) logger = logging.getLogger(__name__) @@ -72,6 +77,9 @@ def __init__(self, system: System) -> None: ) self._auth_provider = cast(ServerAuthProvider, self.require(_cls)) + @trace_method( + "FastAPIChromaAuthMiddleware.authenticate", OpenTelemetryGranularity.ALL + ) @override def authenticate( self, request: ServerAuthenticationRequest[Any] diff --git a/chromadb/auth/providers.py b/chromadb/auth/providers.py index eceee3bc2ab..2982b9e15a6 100644 --- a/chromadb/auth/providers.py +++ b/chromadb/auth/providers.py @@ -15,6 +15,11 @@ ) from chromadb.auth.registry import register_provider, resolve_provider from chromadb.config import System +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) T = TypeVar("T") @@ -34,6 +39,10 @@ def __init__(self, system: System) -> None: "The bcrypt python package is not installed. 
Please install it with `pip install bcrypt`" ) + @trace_method( + "HtpasswdServerAuthCredentialsProvider.validate_credentials", + OpenTelemetryGranularity.ALL, + ) @override def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool: _creds = cast(Dict[str, SecretStr], credentials.get_credentials()) diff --git a/chromadb/auth/token/__init__.py b/chromadb/auth/token/__init__.py index 5132fa35798..6dfa8635942 100644 --- a/chromadb/auth/token/__init__.py +++ b/chromadb/auth/token/__init__.py @@ -19,6 +19,11 @@ ) from chromadb.auth.registry import register_provider, resolve_provider from chromadb.config import System +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.utils import get_class T = TypeVar("T") @@ -86,6 +91,10 @@ def __init__(self, system: System) -> None: check_token(token_str) self._token = SecretStr(token_str) + @trace_method( + "TokenConfigServerAuthCredentialsProvider.validate_credentials", + OpenTelemetryGranularity.ALL, + ) @override def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool: _creds = cast(Dict[str, SecretStr], credentials.get_credentials()) @@ -150,6 +159,7 @@ def __init__(self, system: System) -> None: str(system.settings.chroma_server_auth_token_transport_header) ] + @trace_method("TokenAuthServerProvider.authenticate", OpenTelemetryGranularity.ALL) @override def authenticate(self, request: ServerAuthenticationRequest[Any]) -> bool: try: @@ -189,6 +199,7 @@ def __init__(self, system: System) -> None: str(system.settings.chroma_client_auth_token_transport_header) ] + @trace_method("TokenAuthClientProvider.authenticate", OpenTelemetryGranularity.ALL) @override def authenticate(self) -> ClientAuthResponse: _token = self._credentials_provider.get_credentials() diff --git a/chromadb/config.py b/chromadb/config.py index eb7bca93ef5..6731db255b7 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -64,7 +64,7 @@ # TODO: Don't use concrete types here to avoid circular deps. Strings are fine for right here! 
_abstract_type_keys: Dict[str, str] = { "chromadb.api.API": "chroma_api_impl", - "chromadb.telemetry.Telemetry": "chroma_telemetry_impl", + "chromadb.telemetry.product.ProductTelemetryClient": "chroma_product_telemetry_impl", "chromadb.ingest.Producer": "chroma_producer_impl", "chromadb.ingest.Consumer": "chroma_consumer_impl", "chromadb.ingest.CollectionAssignmentPolicy": "chroma_collection_assignment_policy_impl", # noqa @@ -83,7 +83,9 @@ class Settings(BaseSettings): # type: ignore chroma_db_impl: Optional[str] = None chroma_api_impl: str = "chromadb.api.segment.SegmentAPI" # Can be "chromadb.api.segment.SegmentAPI" or "chromadb.api.fastapi.FastAPI" - chroma_telemetry_impl: str = "chromadb.telemetry.posthog.Posthog" + chroma_product_telemetry_impl: str = "chromadb.telemetry.product.posthog.Posthog" + # Required for backwards compatibility + chroma_telemetry_impl: str = chroma_product_telemetry_impl # New architecture components chroma_sysdb_impl: str = "chromadb.db.impl.sqlite.SqliteDB" @@ -174,6 +176,11 @@ def chroma_server_auth_credentials_file_non_empty_file_exists( anonymized_telemetry: bool = True + chroma_otel_collection_endpoint: Optional[str] = "" + chroma_otel_service_name: Optional[str] = "chromadb" + chroma_otel_collection_headers: Dict[str, str] = {} + chroma_otel_granularity: Optional[str] = "none" + allow_reset: bool = False migrations: Literal["none", "validate", "apply"] = "apply" diff --git a/chromadb/db/impl/sqlite.py b/chromadb/db/impl/sqlite.py index aed14deb8e2..6652d21333a 100644 --- a/chromadb/db/impl/sqlite.py +++ b/chromadb/db/impl/sqlite.py @@ -4,6 +4,11 @@ import chromadb.db.base as base from chromadb.db.mixins.embeddings_queue import SqlEmbeddingsQueue from chromadb.db.mixins.sysdb import SqlSysDB +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.utils.delete_file import delete_file import sqlite3 from overrides import override @@ -67,6 +72,7 @@ def __init__(self, system: System): files("chromadb.migrations.metadb"), ] self._is_persistent = self._settings.require("is_persistent") + self._opentelemetry_client = system.require(OpenTelemetryClient) if not self._is_persistent: # In order to allow sqlite to be shared between multiple threads, we need to use a # URI connection string with shared cache. 
@@ -84,6 +90,7 @@ def __init__(self, system: System): self._tx_stack = local() super().__init__(system) + @trace_method("SqliteDB.start", OpenTelemetryGranularity.ALL) @override def start(self) -> None: super().start() @@ -92,6 +99,7 @@ def start(self) -> None: cur.execute("PRAGMA case_sensitive_like = ON") self.initialize_migrations() + @trace_method("SqliteDB.stop", OpenTelemetryGranularity.ALL) @override def stop(self) -> None: super().stop() @@ -122,6 +130,7 @@ def tx(self) -> TxWrapper: self._tx_stack.stack = [] return TxWrapper(self._conn_pool, stack=self._tx_stack) + @trace_method("SqliteDB.reset_state", OpenTelemetryGranularity.ALL) @override def reset_state(self) -> None: if not self._settings.require("allow_reset"): @@ -132,9 +141,9 @@ def reset_state(self) -> None: # Drop all tables cur.execute( """ - SELECT name FROM sqlite_master - WHERE type='table' - """ + SELECT name FROM sqlite_master + WHERE type='table' + """ ) for row in cur.fetchall(): cur.execute(f"DROP TABLE IF EXISTS {row[0]}") @@ -144,28 +153,30 @@ def reset_state(self) -> None: self.start() super().reset_state() + @trace_method("SqliteDB.setup_migrations", OpenTelemetryGranularity.ALL) @override def setup_migrations(self) -> None: with self.tx() as cur: cur.execute( """ - CREATE TABLE IF NOT EXISTS migrations ( - dir TEXT NOT NULL, - version INTEGER NOT NULL, - filename TEXT NOT NULL, - sql TEXT NOT NULL, - hash TEXT NOT NULL, - PRIMARY KEY (dir, version) - ) - """ + CREATE TABLE IF NOT EXISTS migrations ( + dir TEXT NOT NULL, + version INTEGER NOT NULL, + filename TEXT NOT NULL, + sql TEXT NOT NULL, + hash TEXT NOT NULL, + PRIMARY KEY (dir, version) + ) + """ ) + @trace_method("SqliteDB.migrations_initialized", OpenTelemetryGranularity.ALL) @override def migrations_initialized(self) -> bool: with self.tx() as cur: cur.execute( """SELECT count(*) FROM sqlite_master - WHERE type='table' AND name='migrations'""" + WHERE type='table' AND name='migrations'""" ) if cur.fetchone()[0] == 0: @@ -173,6 +184,7 @@ def migrations_initialized(self) -> bool: else: return True + @trace_method("SqliteDB.db_migrations", OpenTelemetryGranularity.ALL) @override def db_migrations(self, dir: Traversable) -> Sequence[Migration]: with self.tx() as cur: diff --git a/chromadb/db/migrations.py b/chromadb/db/migrations.py index af2ecce4375..76476502d1b 100644 --- a/chromadb/db/migrations.py +++ b/chromadb/db/migrations.py @@ -6,6 +6,11 @@ from chromadb.db.base import SqlDB, Cursor from abc import abstractmethod from chromadb.config import System, Settings +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) class MigrationFile(TypedDict): @@ -82,6 +87,7 @@ class MigratableDB(SqlDB): def __init__(self, system: System) -> None: self._settings = system.settings + self._opentelemetry_client = system.require(OpenTelemetryClient) super().__init__(system) @staticmethod @@ -127,6 +133,7 @@ def initialize_migrations(self) -> None: if migrate == "apply": self.apply_migrations() + @trace_method("MigratableDB.validate_migrations", OpenTelemetryGranularity.ALL) def validate_migrations(self) -> None: """Validate all migrations and throw an exception if there are any unapplied migrations in the source repo.""" @@ -142,6 +149,7 @@ def validate_migrations(self) -> None: version = unapplied_migrations[0]["version"] raise UnappliedMigrationsError(dir=dir.name, version=version) + @trace_method("MigratableDB.apply_migrations", OpenTelemetryGranularity.ALL) def apply_migrations(self) -> None: 
"""Validate existing migrations, and apply all new ones.""" self.setup_migrations() diff --git a/chromadb/db/mixins/embeddings_queue.py b/chromadb/db/mixins/embeddings_queue.py index 472e0254283..f926d608e05 100644 --- a/chromadb/db/mixins/embeddings_queue.py +++ b/chromadb/db/mixins/embeddings_queue.py @@ -14,6 +14,11 @@ Operation, ) from chromadb.config import System +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from overrides import override from collections import defaultdict from typing import Sequence, Tuple, Optional, Dict, Set, cast @@ -79,8 +84,10 @@ def __init__( def __init__(self, system: System): self._subscriptions = defaultdict(set) self._max_batch_size = None + self._opentelemetry_client = system.require(OpenTelemetryClient) super().__init__(system) + @trace_method("SqlEmbeddingsQueue.reset_state", OpenTelemetryGranularity.ALL) @override def reset_state(self) -> None: super().reset_state() @@ -91,6 +98,7 @@ def create_topic(self, topic_name: str) -> None: # Topic creation is implicit for this impl pass + @trace_method("SqlEmbeddingsQueue.delete_topic", OpenTelemetryGranularity.ALL) @override def delete_topic(self, topic_name: str) -> None: t = Table("embeddings_queue") @@ -104,6 +112,7 @@ def delete_topic(self, topic_name: str) -> None: sql, params = get_sql(q, self.parameter_format()) cur.execute(sql, params) + @trace_method("SqlEmbeddingsQueue.submit_embedding", OpenTelemetryGranularity.ALL) @override def submit_embedding( self, topic_name: str, embedding: SubmitEmbeddingRecord @@ -113,6 +122,7 @@ def submit_embedding( return self.submit_embeddings(topic_name, [embedding])[0] + @trace_method("SqlEmbeddingsQueue.submit_embeddings", OpenTelemetryGranularity.ALL) @override def submit_embeddings( self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord] @@ -126,10 +136,10 @@ def submit_embeddings( if len(embeddings) > self.max_batch_size: raise ValueError( f""" - Cannot submit more than {self.max_batch_size:,} embeddings at once. - Please submit your embeddings in batches of size - {self.max_batch_size:,} or less. - """ + Cannot submit more than {self.max_batch_size:,} embeddings at once. + Please submit your embeddings in batches of size + {self.max_batch_size:,} or less. 
+ """ ) t = Table("embeddings_queue") @@ -182,6 +192,7 @@ def submit_embeddings( self._notify_all(topic_name, embedding_records) return seq_ids + @trace_method("SqlEmbeddingsQueue.subscribe", OpenTelemetryGranularity.ALL) @override def subscribe( self, @@ -207,6 +218,7 @@ def subscribe( return subscription_id + @trace_method("SqlEmbeddingsQueue.unsubscribe", OpenTelemetryGranularity.ALL) @override def unsubscribe(self, subscription_id: UUID) -> None: for topic_name, subscriptions in self._subscriptions.items(): @@ -226,6 +238,7 @@ def max_seqid(self) -> SeqId: return 2**63 - 1 @property + @trace_method("SqlEmbeddingsQueue.max_batch_size", OpenTelemetryGranularity.ALL) @override def max_batch_size(self) -> int: if self._max_batch_size is None: @@ -247,6 +260,10 @@ def max_batch_size(self) -> int: self._max_batch_size = 999 // self.VARIABLES_PER_RECORD return self._max_batch_size + @trace_method( + "SqlEmbeddingsQueue._prepare_vector_encoding_metadata", + OpenTelemetryGranularity.ALL, + ) def _prepare_vector_encoding_metadata( self, embedding: SubmitEmbeddingRecord ) -> Tuple[Optional[bytes], Optional[str], Optional[str]]: @@ -260,6 +277,7 @@ def _prepare_vector_encoding_metadata( metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None return embedding_bytes, encoding, metadata + @trace_method("SqlEmbeddingsQueue._backfill", OpenTelemetryGranularity.ALL) def _backfill(self, subscription: Subscription) -> None: """Backfill the given subscription with any currently matching records in the DB""" @@ -298,6 +316,7 @@ def _backfill(self, subscription: Subscription) -> None: ], ) + @trace_method("SqlEmbeddingsQueue._validate_range", OpenTelemetryGranularity.ALL) def _validate_range( self, start: Optional[SeqId], end: Optional[SeqId] ) -> Tuple[int, int]: @@ -311,6 +330,7 @@ def _validate_range( raise ValueError(f"Invalid SeqID range: {start} to {end}") return start, end + @trace_method("SqlEmbeddingsQueue._next_seq_id", OpenTelemetryGranularity.ALL) def _next_seq_id(self) -> int: """Get the next SeqID for this database.""" t = Table("embeddings_queue") @@ -319,12 +339,14 @@ def _next_seq_id(self) -> int: cur.execute(q.get_sql()) return int(cur.fetchone()[0]) + 1 + @trace_method("SqlEmbeddingsQueue._notify_all", OpenTelemetryGranularity.ALL) def _notify_all(self, topic: str, embeddings: Sequence[EmbeddingRecord]) -> None: """Send a notification to each subscriber of the given topic.""" if self._running: for sub in self._subscriptions[topic]: self._notify_one(sub, embeddings) + @trace_method("SqlEmbeddingsQueue._notify_one", OpenTelemetryGranularity.ALL) def _notify_one( self, sub: Subscription, embeddings: Sequence[EmbeddingRecord] ) -> None: diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index d105918e700..d9deb144f66 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -14,6 +14,12 @@ UniqueConstraintError, ) from chromadb.db.system import SysDB +from chromadb.telemetry.opentelemetry import ( + add_attributes_to_current_span, + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.ingest import CollectionAssignmentPolicy, Producer from chromadb.types import ( OptionalArgument, @@ -35,7 +41,9 @@ class SqlSysDB(SqlDB, SysDB): def __init__(self, system: System): self._assignment_policy = system.instance(CollectionAssignmentPolicy) super().__init__(system) + self._opentelemetry_client = system.require(OpenTelemetryClient) + @trace_method("SqlSysDB.create_segment", OpenTelemetryGranularity.ALL) 
@override def start(self) -> None: super().start() @@ -43,6 +51,15 @@ def start(self) -> None: @override def create_segment(self, segment: Segment) -> None: + add_attributes_to_current_span( + { + "segment_id": str(segment["id"]), + "segment_type": segment["type"], + "segment_scope": segment["scope"].value, + "segment_topic": str(segment["topic"]), + "collection": str(segment["collection"]), + } + ) with self.tx() as cur: segments = Table("segments") insert_segment = ( @@ -80,6 +97,7 @@ def create_segment(self, segment: Segment) -> None: segment["metadata"], ) + @trace_method("SqlSysDB.create_collection", OpenTelemetryGranularity.ALL) @override def create_collection( self, @@ -92,6 +110,13 @@ def create_collection( if id is None and not get_or_create: raise ValueError("id must be specified if get_or_create is False") + add_attributes_to_current_span( + { + "collection_id": str(id), + "collection_name": name, + } + ) + existing = self.get_collections(name=name) if existing: if get_or_create: @@ -146,6 +171,7 @@ def create_collection( ) return collection, True + @trace_method("SqlSysDB.get_segments", OpenTelemetryGranularity.ALL) @override def get_segments( self, @@ -155,6 +181,15 @@ def get_segments( topic: Optional[str] = None, collection: Optional[UUID] = None, ) -> Sequence[Segment]: + add_attributes_to_current_span( + { + "segment_id": str(id), + "segment_type": type if type else "", + "segment_scope": scope.value if scope else "", + "segment_topic": topic if topic else "", + "collection": str(collection), + } + ) segments_t = Table("segments") metadata_t = Table("segment_metadata") q = ( @@ -214,6 +249,7 @@ def get_segments( return segments + @trace_method("SqlSysDB.get_collections", OpenTelemetryGranularity.ALL) @override def get_collections( self, @@ -222,6 +258,13 @@ def get_collections( name: Optional[str] = None, ) -> Sequence[Collection]: """Get collections by name, embedding function and/or metadata""" + add_attributes_to_current_span( + { + "collection_id": str(id), + "collection_topic": topic if topic else "", + "collection_name": name if name else "", + } + ) collections_t = Table("collections") metadata_t = Table("collection_metadata") q = ( @@ -272,9 +315,15 @@ def get_collections( return collections + @trace_method("SqlSysDB.delete_segment", OpenTelemetryGranularity.ALL) @override def delete_segment(self, id: UUID) -> None: """Delete a segment from the SysDB""" + add_attributes_to_current_span( + { + "segment_id": str(id), + } + ) t = Table("segments") q = ( self.querybuilder() @@ -290,9 +339,15 @@ def delete_segment(self, id: UUID) -> None: if not result: raise NotFoundError(f"Segment {id} not found") + @trace_method("SqlSysDB.delete_collection", OpenTelemetryGranularity.ALL) @override def delete_collection(self, id: UUID) -> None: """Delete a topic and all associated segments from the SysDB""" + add_attributes_to_current_span( + { + "collection_id": str(id), + } + ) t = Table("collections") q = ( self.querybuilder() @@ -309,6 +364,7 @@ def delete_collection(self, id: UUID) -> None: raise NotFoundError(f"Collection {id} not found") self._producer.delete_topic(result[1]) + @trace_method("SqlSysDB.update_segment", OpenTelemetryGranularity.ALL) @override def update_segment( self, @@ -317,6 +373,12 @@ def update_segment( collection: OptionalArgument[Optional[UUID]] = Unspecified(), metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(), ) -> None: + add_attributes_to_current_span( + { + "segment_id": str(id), + "collection": str(collection), + } + ) segments_t = 
Table("segments") metadata_t = Table("segment_metadata") @@ -361,6 +423,7 @@ def update_segment( set(metadata.keys()), ) + @trace_method("SqlSysDB.update_collection", OpenTelemetryGranularity.ALL) @override def update_collection( self, @@ -370,6 +433,11 @@ def update_collection( dimension: OptionalArgument[Optional[int]] = Unspecified(), metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(), ) -> None: + add_attributes_to_current_span( + { + "collection_id": str(id), + } + ) collections_t = Table("collections") metadata_t = Table("collection_metadata") @@ -419,11 +487,17 @@ def update_collection( set(metadata.keys()), ) + @trace_method("SqlSysDB._metadata_from_rows", OpenTelemetryGranularity.ALL) def _metadata_from_rows( self, rows: Sequence[Tuple[Any, ...]] ) -> Optional[Metadata]: """Given SQL rows, return a metadata map (assuming that the last four columns are the key, str_value, int_value & float_value)""" + add_attributes_to_current_span( + { + "num_rows": len(rows), + } + ) metadata: Dict[str, Union[str, int, float]] = {} for row in rows: key = str(row[-4]) @@ -435,6 +509,7 @@ def _metadata_from_rows( metadata[key] = float(row[-1]) return metadata or None + @trace_method("SqlSysDB._insert_metadata", OpenTelemetryGranularity.ALL) def _insert_metadata( self, cur: Cursor, @@ -447,6 +522,11 @@ def _insert_metadata( # It would be cleaner to use something like ON CONFLICT UPDATE here But that is # very difficult to do in a portable way (e.g sqlite and postgres have # completely different sytnax) + add_attributes_to_current_span( + { + "num_keys": len(metadata), + } + ) if clear_keys: q = ( self.querybuilder() @@ -462,7 +542,11 @@ def _insert_metadata( self.querybuilder() .into(table) .columns( - id_col, table.key, table.str_value, table.int_value, table.float_value + id_col, + table.key, + table.str_value, + table.int_value, + table.float_value, ) ) sql_id = self.uuid_to_db(id) diff --git a/chromadb/ingest/impl/pulsar.py b/chromadb/ingest/impl/pulsar.py index 3f293c90580..3f71a1db36a 100644 --- a/chromadb/ingest/impl/pulsar.py +++ b/chromadb/ingest/impl/pulsar.py @@ -10,6 +10,11 @@ from chromadb.ingest.impl.utils import create_pulsar_connection_str from chromadb.proto.convert import from_proto_submit, to_proto_submit import chromadb.proto.chroma_pb2 as proto +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import SeqId, SubmitEmbeddingRecord import pulsar from concurrent.futures import wait, Future @@ -18,8 +23,10 @@ class PulsarProducer(Producer, EnforceOverrides): + # TODO: ensure trace context propagates _connection_str: str _topic_to_producer: Dict[str, pulsar.Producer] + _opentelemetry_client: OpenTelemetryClient _client: pulsar.Client _admin: PulsarAdmin _settings: Settings @@ -31,6 +38,7 @@ def __init__(self, system: System) -> None: self._topic_to_producer = {} self._settings = system.settings self._admin = PulsarAdmin(system) + self._opentelemetry_client = system.require(OpenTelemetryClient) super().__init__(system) @overrides @@ -51,6 +59,7 @@ def create_topic(self, topic_name: str) -> None: def delete_topic(self, topic_name: str) -> None: self._admin.delete_topic(topic_name) + @trace_method("PulsarProducer.submit_embedding", OpenTelemetryGranularity.ALL) @overrides def submit_embedding( self, topic_name: str, embedding: SubmitEmbeddingRecord @@ -62,6 +71,7 @@ def submit_embedding( msg_id: pulsar.MessageId = producer.send(proto_submit.SerializeToString()) return 
pulsar_to_int(msg_id) + @trace_method("PulsarProducer.submit_embeddings", OpenTelemetryGranularity.ALL) @overrides def submit_embeddings( self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord] @@ -75,10 +85,10 @@ def submit_embeddings( if len(embeddings) > self.max_batch_size: raise ValueError( f""" - Cannot submit more than {self.max_batch_size:,} embeddings at once. - Please submit your embeddings in batches of size - {self.max_batch_size:,} or less. - """ + Cannot submit more than {self.max_batch_size:,} embeddings at once. + Please submit your embeddings in batches of size + {self.max_batch_size:,} or less. + """ ) producer = self._get_or_create_producer(topic_name) @@ -171,6 +181,7 @@ def __init__( _connection_str: str _client: pulsar.Client + _opentelemetry_client: OpenTelemetryClient _subscriptions: Dict[str, Set[PulsarSubscription]] _settings: Settings @@ -180,6 +191,7 @@ def __init__(self, system: System) -> None: self._connection_str = create_pulsar_connection_str(pulsar_host, pulsar_port) self._subscriptions = defaultdict(set) self._settings = system.settings + self._opentelemetry_client = system.require(OpenTelemetryClient) super().__init__(system) @overrides @@ -192,6 +204,7 @@ def stop(self) -> None: self._client.close() super().stop() + @trace_method("PulsarConsumer.subscribe", OpenTelemetryGranularity.ALL) @overrides def subscribe( self, diff --git a/chromadb/segment/impl/distributed/server.py b/chromadb/segment/impl/distributed/server.py index f7ea2f2ecaf..9b56ed4d18a 100644 --- a/chromadb/segment/impl/distributed/server.py +++ b/chromadb/segment/impl/distributed/server.py @@ -17,7 +17,11 @@ to_proto_vector_embedding_record, ) from chromadb.segment import SegmentImplementation, SegmentType, VectorReader -from chromadb.config import System +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import ScalarEncoding, Segment, SegmentScope import logging @@ -38,11 +42,16 @@ class SegmentServer(SegmentServerServicer, VectorReaderServicer): _segment_cache: Dict[UUID, SegmentImplementation] = {} _system: System + _opentelemetry_client: OpenTelemetryClient def __init__(self, system: System) -> None: super().__init__() self._system = system + self._opentelemetry_client = system.require(OpenTelemetryClient) + @trace_method( + "SegmentServer.LoadSegment", OpenTelemetryGranularity.OPERATION_AND_SEGMENT + ) def LoadSegment( self, request: proto.Segment, context: Any ) -> proto.SegmentServerResponse: @@ -85,6 +94,9 @@ def QueryVectors( context.set_details("Query segment not implemented yet") return proto.QueryVectorsResponse() + @trace_method( + "SegmentServer.GetVectors", OpenTelemetryGranularity.OPERATION_AND_SEGMENT + ) def GetVectors( self, request: proto.GetVectorsRequest, context: Any ) -> proto.GetVectorsResponse: diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py index a7c673920a8..e03b58db224 100644 --- a/chromadb/segment/impl/manager/distributed.py +++ b/chromadb/segment/impl/manager/distributed.py @@ -15,6 +15,11 @@ from chromadb.db.system import SysDB from overrides import override from chromadb.segment.distributed import SegmentDirectory +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata from typing import Dict, Type, Sequence, Optional, cast from uuid import UUID, uuid4 @@ -35,6 
+40,7 @@ class DistributedSegmentManager(SegmentManager): _sysdb: SysDB _system: System + _opentelemetry_client: OpenTelemetryClient _instances: Dict[UUID, SegmentImplementation] _segment_cache: Dict[ UUID, Dict[SegmentScope, Segment] @@ -48,11 +54,16 @@ def __init__(self, system: System): self._sysdb = self.require(SysDB) self._segment_directory = self.require(SegmentDirectory) self._system = system + self._opentelemetry_client = system.require(OpenTelemetryClient) self._instances = {} self._segment_cache = defaultdict(dict) self._segment_server_stubs = {} self._lock = Lock() + @trace_method( + "DistributedSegmentManager.create_segments", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def create_segments(self, collection: Collection) -> Sequence[Segment]: vector_segment = _segment( @@ -67,6 +78,10 @@ def create_segments(self, collection: Collection) -> Sequence[Segment]: def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: raise NotImplementedError() + @trace_method( + "DistributedSegmentManager.get_segment", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def get_segment(self, collection_id: UUID, type: type[S]) -> S: if type == MetadataReader: @@ -96,6 +111,10 @@ def get_segment(self, collection_id: UUID, type: type[S]) -> S: instance = self._instance(self._segment_cache[collection_id][scope]) return cast(S, instance) + @trace_method( + "DistributedSegmentManager.hint_use_collection", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None: # TODO: this should call load/release on the target node, node should be stored in metadata @@ -114,6 +133,13 @@ def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None segment = next(filter(lambda s: s["type"] in known_types, segments)) grpc_url = self._segment_directory.get_segment_endpoint(segment) + if grpc_url not in self._segment_server_stubs: + channel = grpc.insecure_channel(grpc_url) + self._segment_server_stubs[grpc_url] = SegmentServerStub(channel) # type: ignore + + self._segment_server_stubs[grpc_url].LoadSegment( + to_proto_segment(segment) + ) if grpc_url not in self._segment_server_stubs: channel = grpc.insecure_channel(grpc_url) self._segment_server_stubs[grpc_url] = SegmentServerStub(channel) diff --git a/chromadb/segment/impl/manager/local.py b/chromadb/segment/impl/manager/local.py index a5b797e31c6..5e7e8b53784 100644 --- a/chromadb/segment/impl/manager/local.py +++ b/chromadb/segment/impl/manager/local.py @@ -13,6 +13,11 @@ from chromadb.segment.impl.vector.local_persistent_hnsw import ( PersistentLocalHnswSegment, ) +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata from typing import Dict, Type, Sequence, Optional, cast from uuid import UUID, uuid4 @@ -37,6 +42,7 @@ class LocalSegmentManager(SegmentManager): _sysdb: SysDB _system: System + _opentelemetry_client: OpenTelemetryClient _instances: Dict[UUID, SegmentImplementation] _vector_instances_file_handle_cache: LRUCache[ UUID, PersistentLocalHnswSegment @@ -52,6 +58,7 @@ def __init__(self, system: System): super().__init__(system) self._sysdb = self.require(SysDB) self._system = system + self._opentelemetry_client = system.require(OpenTelemetryClient) self._instances = {} self._segment_cache = defaultdict(dict) self._lock = Lock() @@ -93,6 +100,10 @@ def reset_state(self) 
-> None: self._segment_cache = defaultdict(dict) super().reset_state() + @trace_method( + "LocalSegmentManager.create_segments", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def create_segments(self, collection: Collection) -> Sequence[Segment]: vector_segment = _segment( @@ -103,6 +114,10 @@ def create_segments(self, collection: Collection) -> Sequence[Segment]: ) return [vector_segment, metadata_segment] + @trace_method( + "LocalSegmentManager.delete_segments", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: segments = self._sysdb.get_segments(collection=collection_id) @@ -118,6 +133,10 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: del self._segment_cache[collection_id] return [s["id"] for s in segments] + @trace_method( + "LocalSegmentManager.get_segment", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def get_segment(self, collection_id: UUID, type: Type[S]) -> S: if type == MetadataReader: @@ -140,6 +159,10 @@ def get_segment(self, collection_id: UUID, type: Type[S]) -> S: instance = self._instance(self._segment_cache[collection_id][scope]) return cast(S, instance) + @trace_method( + "LocalSegmentManager.hint_use_collection", + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + ) @override def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None: # The local segment manager responds to hints by pre-loading both the metadata and vector diff --git a/chromadb/segment/impl/metadata/sqlite.py b/chromadb/segment/impl/metadata/sqlite.py index a7098d7808b..1bdb4eea63c 100644 --- a/chromadb/segment/impl/metadata/sqlite.py +++ b/chromadb/segment/impl/metadata/sqlite.py @@ -10,6 +10,11 @@ ParameterValue, get_sql, ) +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import ( Where, WhereDocument, @@ -39,6 +44,7 @@ class SqliteMetadataSegment(MetadataReader): _consumer: Consumer _db: SqliteDB _id: UUID + _opentelemetry_client: OpenTelemetryClient _topic: Optional[str] _subscription: Optional[UUID] @@ -46,8 +52,10 @@ def __init__(self, system: System, segment: Segment): self._db = system.instance(SqliteDB) self._consumer = system.instance(Consumer) self._id = segment["id"] + self._opentelemetry_client = system.require(OpenTelemetryClient) self._topic = segment["topic"] + @trace_method("SqliteMetadataSegment.start", OpenTelemetryGranularity.ALL) @override def start(self) -> None: if self._topic: @@ -56,11 +64,13 @@ def start(self) -> None: self._topic, self._write_metadata, start=seq_id ) + @trace_method("SqliteMetadataSegment.stop", OpenTelemetryGranularity.ALL) @override def stop(self) -> None: if self._subscription: self._consumer.unsubscribe(self._subscription) + @trace_method("SqliteMetadataSegment.max_seqid", OpenTelemetryGranularity.ALL) @override def max_seqid(self) -> SeqId: t = Table("max_seq_id") @@ -79,6 +89,7 @@ def max_seqid(self) -> SeqId: else: return _decode_seq_id(result[0]) + @trace_method("SqliteMetadataSegment.count", OpenTelemetryGranularity.ALL) @override def count(self) -> int: embeddings_t = Table("embeddings") @@ -95,6 +106,7 @@ def count(self) -> int: result = cur.execute(sql, params).fetchone()[0] return cast(int, result) + @trace_method("SqliteMetadataSegment.get_metadata", OpenTelemetryGranularity.ALL) @override def get_metadata( self, @@ -162,6 +174,7 @@ def _records( for _, group in group_iterator: yield 
self._record(list(group)) + @trace_method("SqliteMetadataSegment._record", OpenTelemetryGranularity.ALL) def _record(self, rows: Sequence[Tuple[Any, ...]]) -> MetadataEmbeddingRecord: """Given a list of DB rows with the same ID, construct a MetadataEmbeddingRecord""" @@ -187,6 +200,7 @@ def _record(self, rows: Sequence[Tuple[Any, ...]]) -> MetadataEmbeddingRecord: metadata=metadata or None, ) + @trace_method("SqliteMetadataSegment._insert_record", OpenTelemetryGranularity.ALL) def _insert_record( self, cur: Cursor, record: EmbeddingRecord, upsert: bool ) -> None: @@ -221,6 +235,9 @@ def _insert_record( if record["metadata"]: self._update_metadata(cur, id, record["metadata"]) + @trace_method( + "SqliteMetadataSegment._update_metadata", OpenTelemetryGranularity.ALL + ) def _update_metadata(self, cur: Cursor, id: int, metadata: UpdateMetadata) -> None: """Update the metadata for a single EmbeddingRecord""" t = Table("embedding_metadata") @@ -238,6 +255,9 @@ def _update_metadata(self, cur: Cursor, id: int, metadata: UpdateMetadata) -> No self._insert_metadata(cur, id, metadata) + @trace_method( + "SqliteMetadataSegment._insert_metadata", OpenTelemetryGranularity.ALL + ) def _insert_metadata(self, cur: Cursor, id: int, metadata: UpdateMetadata) -> None: """Insert or update each metadata row for a single embedding record""" t = Table("embedding_metadata") @@ -245,7 +265,12 @@ def _insert_metadata(self, cur: Cursor, id: int, metadata: UpdateMetadata) -> No self._db.querybuilder() .into(t) .columns( - t.id, t.key, t.string_value, t.int_value, t.float_value, t.bool_value + t.id, + t.key, + t.string_value, + t.int_value, + t.float_value, + t.bool_value, ) ) for key, value in metadata.items(): @@ -321,6 +346,7 @@ def insert_into_fulltext_search() -> None: cur.execute(sql, params) insert_into_fulltext_search() + @trace_method("SqliteMetadataSegment._delete_record", OpenTelemetryGranularity.ALL) def _delete_record(self, cur: Cursor, record: EmbeddingRecord) -> None: """Delete a single EmbeddingRecord from the DB""" t = Table("embeddings") @@ -351,6 +377,7 @@ def _delete_record(self, cur: Cursor, record: EmbeddingRecord) -> None: sql, params = get_sql(q) cur.execute(sql, params) + @trace_method("SqliteMetadataSegment._update_record", OpenTelemetryGranularity.ALL) def _update_record(self, cur: Cursor, record: EmbeddingRecord) -> None: """Update a single EmbeddingRecord in the DB""" t = Table("embeddings") @@ -371,6 +398,7 @@ def _update_record(self, cur: Cursor, record: EmbeddingRecord) -> None: if record["metadata"]: self._update_metadata(cur, id, record["metadata"]) + @trace_method("SqliteMetadataSegment._write_metadata", OpenTelemetryGranularity.ALL) def _write_metadata(self, records: Sequence[EmbeddingRecord]) -> None: """Write embedding metadata to the database. 
Care should be taken to ensure records are append-only (that is, that seq-ids should increase monotonically)""" @@ -398,6 +426,9 @@ def _write_metadata(self, records: Sequence[EmbeddingRecord]) -> None: elif record["operation"] == Operation.UPDATE: self._update_record(cur, record) + @trace_method( + "SqliteMetadataSegment._where_map_criterion", OpenTelemetryGranularity.ALL + ) def _where_map_criterion( self, q: QueryBuilder, where: Where, embeddings_t: Table, metadata_t: Table ) -> Criterion: @@ -427,6 +458,9 @@ def _where_map_criterion( clause.append(embeddings_t.id.isin(sq)) return reduce(lambda x, y: x & y, clause) + @trace_method( + "SqliteMetadataSegment._where_doc_criterion", OpenTelemetryGranularity.ALL + ) def _where_doc_criterion( self, q: QueryBuilder, diff --git a/chromadb/segment/impl/vector/grpc_segment.py b/chromadb/segment/impl/vector/grpc_segment.py index 0aac3baa253..89cc1b814f0 100644 --- a/chromadb/segment/impl/vector/grpc_segment.py +++ b/chromadb/segment/impl/vector/grpc_segment.py @@ -9,6 +9,11 @@ ) from chromadb.segment import MetadataReader, VectorReader from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import ( Metadata, ScalarEncoding, @@ -30,6 +35,7 @@ class GrpcVectorSegment(VectorReader, EnforceOverrides): _vector_reader_stub: VectorReaderStub _segment: Segment + _opentelemetry_client: OpenTelemetryClient def __init__(self, system: System, segment: Segment): # TODO: move to start() method @@ -40,7 +46,9 @@ def __init__(self, system: System, segment: Segment): channel = grpc.insecure_channel(segment["metadata"]["grpc_url"]) self._vector_reader_stub = VectorReaderStub(channel) # type: ignore self._segment = segment + self._opentelemetry_client = system.require(OpenTelemetryClient) + @trace_method("GrpcVectorSegment.get_vectors", OpenTelemetryGranularity.ALL) @override def get_vectors( self, ids: Optional[Sequence[str]] = None @@ -53,6 +61,7 @@ def get_vectors( results.append(result) return results + @trace_method("GrpcVectorSegment.query_vectors", OpenTelemetryGranularity.ALL) @override def query_vectors( self, query: VectorQuery diff --git a/chromadb/segment/impl/vector/local_hnsw.py b/chromadb/segment/impl/vector/local_hnsw.py index c45af628d2f..e4437881b2a 100644 --- a/chromadb/segment/impl/vector/local_hnsw.py +++ b/chromadb/segment/impl/vector/local_hnsw.py @@ -6,6 +6,11 @@ from chromadb.config import System, Settings from chromadb.segment.impl.vector.batch import Batch from chromadb.segment.impl.vector.hnsw_params import HnswParams +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import ( EmbeddingRecord, VectorEmbeddingRecord, @@ -46,6 +51,8 @@ class LocalHnswSegment(VectorReader): _label_to_id: Dict[int, str] _id_to_seq_id: Dict[str, SeqId] + _opentelemtry_client: OpenTelemetryClient + def __init__(self, system: System, segment: Segment): self._consumer = system.instance(Consumer) self._id = segment["id"] @@ -63,6 +70,7 @@ def __init__(self, system: System, segment: Segment): self._label_to_id = {} self._lock = ReadWriteLock() + self._opentelemtry_client = system.require(OpenTelemetryClient) super().__init__(system, segment) @staticmethod @@ -72,6 +80,7 @@ def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]: segment_metadata = HnswParams.extract(metadata) return segment_metadata + 
@trace_method("LocalHnswSegment.start", OpenTelemetryGranularity.ALL) @override def start(self) -> None: super().start() @@ -81,12 +90,14 @@ def start(self) -> None: self._topic, self._write_records, start=seq_id ) + @trace_method("LocalHnswSegment.stop", OpenTelemetryGranularity.ALL) @override def stop(self) -> None: super().stop() if self._subscription: self._consumer.unsubscribe(self._subscription) + @trace_method("LocalHnswSegment.get_vectors", OpenTelemetryGranularity.ALL) @override def get_vectors( self, ids: Optional[Sequence[str]] = None @@ -112,6 +123,7 @@ def get_vectors( return results + @trace_method("LocalHnswSegment.query_vectors", OpenTelemetryGranularity.ALL) @override def query_vectors( self, query: VectorQuery @@ -181,6 +193,7 @@ def max_seqid(self) -> SeqId: def count(self) -> int: return len(self._id_to_label) + @trace_method("LocalHnswSegment._init_index", OpenTelemetryGranularity.ALL) def _init_index(self, dimensionality: int) -> None: # more comments available at the source: https://github.com/nmslib/hnswlib @@ -198,6 +211,7 @@ def _init_index(self, dimensionality: int) -> None: self._index = index self._dimensionality = dimensionality + @trace_method("LocalHnswSegment._ensure_index", OpenTelemetryGranularity.ALL) def _ensure_index(self, n: int, dim: int) -> None: """Create or resize the index as necessary to accomodate N new records""" if not self._index: @@ -218,6 +232,7 @@ def _ensure_index(self, n: int, dim: int) -> None: ) index.resize_index(max(new_size, DEFAULT_CAPACITY)) + @trace_method("LocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL) def _apply_batch(self, batch: Batch) -> None: """Apply a batch of changes, as atomically as possible.""" deleted_ids = batch.get_deleted_ids() @@ -267,6 +282,7 @@ def _apply_batch(self, batch: Batch) -> None: # If that succeeds, finally the seq ID self._max_seq_id = batch.max_seq_id + @trace_method("LocalHnswSegment._write_records", OpenTelemetryGranularity.ALL) def _write_records(self, records: Sequence[EmbeddingRecord]) -> None: """Add a batch of embeddings to the index""" if not self._running: diff --git a/chromadb/segment/impl/vector/local_persistent_hnsw.py b/chromadb/segment/impl/vector/local_persistent_hnsw.py index f8c74bd0fe7..4ab60a1725d 100644 --- a/chromadb/segment/impl/vector/local_persistent_hnsw.py +++ b/chromadb/segment/impl/vector/local_persistent_hnsw.py @@ -11,6 +11,11 @@ LocalHnswSegment, ) from chromadb.segment.impl.vector.brute_force_index import BruteForceIndex +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) from chromadb.types import ( EmbeddingRecord, Metadata, @@ -81,9 +86,13 @@ class PersistentLocalHnswSegment(LocalHnswSegment): _persist_directory: str _allow_reset: bool + _opentelemtry_client: OpenTelemetryClient + def __init__(self, system: System, segment: Segment): super().__init__(system, segment) + self._opentelemtry_client = system.require(OpenTelemetryClient) + self._params = PersistentHnswParams(segment["metadata"] or {}) self._batch_size = self._params.batch_size self._sync_threshold = self._params.sync_threshold @@ -138,6 +147,9 @@ def _get_storage_folder(self) -> str: folder = os.path.join(self._persist_directory, str(self._id)) return folder + @trace_method( + "PersistentLocalHnswSegment._init_index", OpenTelemetryGranularity.ALL + ) @override def _init_index(self, dimensionality: int) -> None: index = hnswlib.Index(space=self._params.space, dim=dimensionality) @@ -172,6 +184,7 @@ def _init_index(self, 
dimensionality: int) -> None: self._dimensionality = dimensionality self._index_initialized = True + @trace_method("PersistentLocalHnswSegment._persist", OpenTelemetryGranularity.ALL) def _persist(self) -> None: """Persist the index and data to disk""" index = cast(hnswlib.Index, self._index) @@ -193,6 +206,9 @@ def _persist(self) -> None: with open(self._get_metadata_file(), "wb") as metadata_file: pickle.dump(self._persist_data, metadata_file, pickle.HIGHEST_PROTOCOL) + @trace_method( + "PersistentLocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL + ) @override def _apply_batch(self, batch: Batch) -> None: super()._apply_batch(batch) @@ -202,6 +218,9 @@ def _apply_batch(self, batch: Batch) -> None: ): self._persist() + @trace_method( + "PersistentLocalHnswSegment._write_records", OpenTelemetryGranularity.ALL + ) @override def _write_records(self, records: Sequence[EmbeddingRecord]) -> None: """Add a batch of embeddings to the index""" @@ -267,6 +286,9 @@ def count(self) -> int: - self._curr_batch.delete_count ) + @trace_method( + "PersistentLocalHnswSegment.get_vectors", OpenTelemetryGranularity.ALL + ) @override def get_vectors( self, ids: Optional[Sequence[str]] = None @@ -310,6 +332,9 @@ def get_vectors( return results # type: ignore ## Python can't cast List with Optional to List with VectorEmbeddingRecord + @trace_method( + "PersistentLocalHnswSegment.query_vectors", OpenTelemetryGranularity.ALL + ) @override def query_vectors( self, query: VectorQuery @@ -395,6 +420,9 @@ def query_vectors( results.append(curr_results) return results + @trace_method( + "PersistentLocalHnswSegment.reset_state", OpenTelemetryGranularity.ALL + ) @override def reset_state(self) -> None: if self._allow_reset: @@ -403,6 +431,7 @@ def reset_state(self) -> None: self.close_persistent_index() shutil.rmtree(data_path, ignore_errors=True) + @trace_method("PersistentLocalHnswSegment.delete", OpenTelemetryGranularity.ALL) @override def delete(self) -> None: data_path = self._get_storage_folder() diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index e92d16d63ba..4921392d3ee 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -35,7 +35,12 @@ from starlette.requests import Request import logging -from chromadb.telemetry import ServerContext, Telemetry +from chromadb.telemetry.product import ServerContext, ProductTelemetryClient +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryClient, + OpenTelemetryGranularity, + trace_method, +) logger = logging.getLogger(__name__) @@ -102,9 +107,10 @@ def include_in_schema(path: str) -> bool: class FastAPI(chromadb.server.Server): def __init__(self, settings: Settings): super().__init__(settings) - Telemetry.SERVER_CONTEXT = ServerContext.FASTAPI + ProductTelemetryClient.SERVER_CONTEXT = ServerContext.FASTAPI self._app = fastapi.FastAPI(debug=True) self._api: chromadb.api.API = chromadb.Client(settings) + self._opentelemetry_client = self._api.require(OpenTelemetryClient) self._app.middleware("http")(catch_exceptions_middleware) self._app.add_middleware( @@ -221,9 +227,11 @@ def heartbeat(self) -> Dict[str, int]: def version(self) -> str: return self._api.get_version() + @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION) def list_collections(self) -> Sequence[Collection]: return self._api.list_collections() + @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION) def create_collection(self, collection: CreateCollection) -> 
Collection: return self._api.create_collection( name=collection.name, @@ -231,9 +239,11 @@ def create_collection(self, collection: CreateCollection) -> Collection: get_or_create=collection.get_or_create, ) + @trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION) def get_collection(self, collection_name: str) -> Collection: return self._api.get_collection(collection_name) + @trace_method("FastAPI.update_collection", OpenTelemetryGranularity.OPERATION) def update_collection( self, collection_id: str, collection: UpdateCollection ) -> None: @@ -243,9 +253,11 @@ def update_collection( new_metadata=collection.new_metadata, ) + @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION) def delete_collection(self, collection_name: str) -> None: return self._api.delete_collection(collection_name) + @trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION) def add(self, collection_id: str, add: AddEmbedding) -> None: try: result = self._api._add( @@ -259,6 +271,7 @@ def add(self, collection_id: str, add: AddEmbedding) -> None: raise HTTPException(status_code=500, detail=str(e)) return result + @trace_method("FastAPI.update", OpenTelemetryGranularity.OPERATION) def update(self, collection_id: str, add: UpdateEmbedding) -> None: return self._api._update( ids=add.ids, @@ -268,6 +281,7 @@ def update(self, collection_id: str, add: UpdateEmbedding) -> None: metadatas=add.metadatas, ) + @trace_method("FastAPI.upsert", OpenTelemetryGranularity.OPERATION) def upsert(self, collection_id: str, upsert: AddEmbedding) -> None: return self._api._upsert( collection_id=_uuid(collection_id), @@ -277,6 +291,7 @@ def upsert(self, collection_id: str, upsert: AddEmbedding) -> None: metadatas=upsert.metadatas, ) + @trace_method("FastAPI.get", OpenTelemetryGranularity.OPERATION) def get(self, collection_id: str, get: GetEmbedding) -> GetResult: return self._api._get( collection_id=_uuid(collection_id), @@ -289,6 +304,7 @@ def get(self, collection_id: str, get: GetEmbedding) -> GetResult: include=get.include, ) + @trace_method("FastAPI.delete", OpenTelemetryGranularity.OPERATION) def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]: return self._api._delete( where=delete.where, @@ -297,12 +313,14 @@ def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]: where_document=delete.where_document, ) + @trace_method("FastAPI.count", OpenTelemetryGranularity.OPERATION) def count(self, collection_id: str) -> int: return self._api._count(_uuid(collection_id)) def reset(self) -> bool: return self._api.reset() + @trace_method("FastAPI.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION) def get_nearest_neighbors( self, collection_id: str, query: QueryEmbedding ) -> QueryResult: diff --git a/chromadb/telemetry/README.md b/chromadb/telemetry/README.md new file mode 100644 index 00000000000..c406074e41e --- /dev/null +++ b/chromadb/telemetry/README.md @@ -0,0 +1,10 @@ +# Telemetry + +This directory holds all the telemetry for Chroma. + +- `product/` contains anonymized product telemetry which we, Chroma, collect so we can + understand usage patterns. For more information, see https://docs.trychroma.com/telemetry. +- `opentelemetry/` contains all of the config for Chroma's [OpenTelemetry](https://opentelemetry.io/docs/instrumentation/python/getting-started/) + setup. 
These metrics are *not* sent back to Chroma -- anyone operating a Chroma instance + can use the OpenTelemetry metrics and traces to understand how their instance of Chroma + is behaving. \ No newline at end of file diff --git a/chromadb/telemetry/__init__.py b/chromadb/telemetry/__init__.py index d20b8e5d71c..e69de29bb2d 100644 --- a/chromadb/telemetry/__init__.py +++ b/chromadb/telemetry/__init__.py @@ -1,122 +0,0 @@ -from abc import abstractmethod -import os -from typing import Callable, ClassVar, Dict, Any -import uuid -import time -from threading import Event, Thread -import chromadb -from chromadb.config import Component -from pathlib import Path -from enum import Enum - -TELEMETRY_WHITELISTED_SETTINGS = [ - "chroma_api_impl", - "is_persistent", - "chroma_server_ssl_enabled", -] - - -class ServerContext(Enum): - NONE = "None" - FASTAPI = "FastAPI" - - -class TelemetryEvent: - max_batch_size: ClassVar[int] = 1 - batch_size: int - - def __init__(self, batch_size: int = 1): - self.batch_size = batch_size - - @property - def properties(self) -> Dict[str, Any]: - return self.__dict__ - - @property - def name(self) -> str: - return self.__class__.__name__ - - # A batch key is used to determine whether two events can be batched together. - # If a TelemetryEvent's max_batch_size > 1, batch_key() and batch() MUST be implemented. - # Otherwise they are ignored. - @property - def batch_key(self) -> str: - return self.name - - def batch(self, other: "TelemetryEvent") -> "TelemetryEvent": - raise NotImplementedError - - -class RepeatedTelemetry: - def __init__(self, interval: int, function: Callable[[], None]): - self.interval = interval - self.function = function - self.start = time.time() - self.event = Event() - self.thread = Thread(target=self._target) - self.thread.daemon = True - self.thread.start() - - def _target(self) -> None: - while not self.event.wait(self._time): - self.function() - - @property - def _time(self) -> float: - return self.interval - ((time.time() - self.start) % self.interval) - - def stop(self) -> None: - self.event.set() - self.thread.join() - - -class Telemetry(Component): - USER_ID_PATH = str(Path.home() / ".cache" / "chroma" / "telemetry_user_id") - UNKNOWN_USER_ID = "UNKNOWN" - SERVER_CONTEXT: ServerContext = ServerContext.NONE - _curr_user_id = None - - @abstractmethod - def capture(self, event: TelemetryEvent) -> None: - pass - - # Schedule a function that creates a TelemetryEvent to be called every `every_seconds` seconds. - def schedule_event_function( - self, event_function: Callable[..., TelemetryEvent], every_seconds: int - ) -> None: - RepeatedTelemetry(every_seconds, lambda: self.capture(event_function())) - - @property - def context(self) -> Dict[str, Any]: - chroma_version = chromadb.__version__ - settings = chromadb.get_settings() - telemetry_settings = {} - for whitelisted in TELEMETRY_WHITELISTED_SETTINGS: - telemetry_settings[whitelisted] = settings[whitelisted] - - self._context = { - "chroma_version": chroma_version, - "server_context": self.SERVER_CONTEXT.value, - **telemetry_settings, - } - return self._context - - @property - def user_id(self) -> str: - if self._curr_user_id: - return self._curr_user_id - - # File access may fail due to permissions or other reasons. We don't want to crash so we catch all exceptions. 
- try: - if not os.path.exists(self.USER_ID_PATH): - os.makedirs(os.path.dirname(self.USER_ID_PATH), exist_ok=True) - with open(self.USER_ID_PATH, "w") as f: - new_user_id = str(uuid.uuid4()) - f.write(new_user_id) - self._curr_user_id = new_user_id - else: - with open(self.USER_ID_PATH, "r") as f: - self._curr_user_id = f.read() - except Exception: - self._curr_user_id = self.UNKNOWN_USER_ID - return self._curr_user_id diff --git a/chromadb/telemetry/opentelemetry/__init__.py b/chromadb/telemetry/opentelemetry/__init__.py new file mode 100644 index 00000000000..0840871bcae --- /dev/null +++ b/chromadb/telemetry/opentelemetry/__init__.py @@ -0,0 +1,128 @@ +from functools import wraps +from enum import Enum +from typing import Any, Callable, Dict, Optional, Union + +from opentelemetry import trace +from opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import ( + BatchSpanProcessor, +) +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + +from chromadb.config import Component +from chromadb.config import System + + +class OpenTelemetryGranularity(Enum): + """The granularity of the OpenTelemetry spans.""" + + NONE = "none" + """No spans are emitted.""" + + OPERATION = "operation" + """Spans are emitted for each operation.""" + + OPERATION_AND_SEGMENT = "operation_and_segment" + """Spans are emitted for each operation and segment.""" + + ALL = "all" + """Spans are emitted for almost every method call.""" + + # Greater is more restrictive. So "all" < "operation" (and everything else), + # "none" > everything. + def __lt__(self, other: Any) -> bool: + """Compare two granularities.""" + order = [ + OpenTelemetryGranularity.ALL, + OpenTelemetryGranularity.OPERATION_AND_SEGMENT, + OpenTelemetryGranularity.OPERATION, + OpenTelemetryGranularity.NONE, + ] + return order.index(self) < order.index(other) + + +class OpenTelemetryClient(Component): + def __init__(self, system: System): + super().__init__(system) + otel_init( + system.settings.chroma_otel_service_name, + system.settings.chroma_otel_collection_endpoint, + system.settings.chroma_otel_collection_headers, + OpenTelemetryGranularity(system.settings.chroma_otel_granularity), + ) + + +tracer: Optional[trace.Tracer] = None +granularity: OpenTelemetryGranularity = OpenTelemetryGranularity("none") + + +def otel_init( + otel_service_name: Optional[str], + otel_collection_endpoint: Optional[str], + otel_collection_headers: Optional[Dict[str, str]], + otel_granularity: OpenTelemetryGranularity, +) -> None: + """Initializes module-level state for OpenTelemetry. + + Parameters match the environment variables which configure OTel as documented + at https://docs.trychroma.com/observability. + - otel_service_name: The name of the service for OTel tagging and aggregation. + - otel_collection_endpoint: The endpoint to which OTel spans are sent (e.g. api.honeycomb.com). + - otel_collection_headers: The headers to send with OTel spans (e.g. {"x-honeycomb-team": "abc123"}). + - otel_granularity: The granularity of the spans to emit. + """ + if otel_granularity == OpenTelemetryGranularity.NONE: + return + resource = Resource(attributes={SERVICE_NAME: str(otel_service_name)}) + provider = TracerProvider(resource=resource) + provider.add_span_processor( + BatchSpanProcessor( + # TODO: we may eventually want to make this configurable. 
+ OTLPSpanExporter( + endpoint=str(otel_collection_endpoint), + headers=otel_collection_headers, + ) + ) + ) + trace.set_tracer_provider(provider) + + global tracer, granularity + tracer = trace.get_tracer(__name__) + granularity = otel_granularity + + +def trace_method( + trace_name: str, + trace_granularity: OpenTelemetryGranularity, + attributes: Dict[str, Union[str, bool, float, int]] = {}, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """A decorator that traces a method.""" + + def decorator(f: Callable[..., Any]) -> Callable[..., Any]: + @wraps(f) + def wrapper(*args: Any, **kwargs: Dict[Any, Any]) -> Any: + global tracer, granularity, _transform_attributes + if trace_granularity < granularity: + return f(*args, **kwargs) + if not tracer: + return + with tracer.start_as_current_span(trace_name, attributes=attributes): + return f(*args, **kwargs) + + return wrapper + + return decorator + + +def add_attributes_to_current_span( + attributes: Dict[str, Union[str, bool, float, int]] +) -> None: + """Add attributes to the current span.""" + global tracer, granularity, _transform_attributes + if granularity == OpenTelemetryGranularity.NONE: + return + if not tracer: + return + span = trace.get_current_span() + span.set_attributes(_transform_attributes(attributes)) # type: ignore diff --git a/chromadb/telemetry/product/__init__.py b/chromadb/telemetry/product/__init__.py new file mode 100644 index 00000000000..a6fd0d7ad87 --- /dev/null +++ b/chromadb/telemetry/product/__init__.py @@ -0,0 +1,93 @@ +from abc import abstractmethod +import os +from typing import ClassVar, Dict, Any +import uuid +import chromadb +from chromadb.config import Component +from pathlib import Path +from enum import Enum + +TELEMETRY_WHITELISTED_SETTINGS = [ + "chroma_api_impl", + "is_persistent", + "chroma_server_ssl_enabled", +] + + +class ServerContext(Enum): + NONE = "None" + FASTAPI = "FastAPI" + + +class ProductTelemetryEvent: + max_batch_size: ClassVar[int] = 1 + batch_size: int + + def __init__(self, batch_size: int = 1): + self.batch_size = batch_size + + @property + def properties(self) -> Dict[str, Any]: + return self.__dict__ + + @property + def name(self) -> str: + return self.__class__.__name__ + + # A batch key is used to determine whether two events can be batched together. + # If a TelemetryEvent's max_batch_size > 1, batch_key() and batch() MUST be + # implemented. + # Otherwise they are ignored. + @property + def batch_key(self) -> str: + return self.name + + def batch(self, other: "ProductTelemetryEvent") -> "ProductTelemetryEvent": + raise NotImplementedError + + +class ProductTelemetryClient(Component): + USER_ID_PATH = str(Path.home() / ".cache" / "chroma" / "telemetry_user_id") + UNKNOWN_USER_ID = "UNKNOWN" + SERVER_CONTEXT: ServerContext = ServerContext.NONE + _curr_user_id = None + + @abstractmethod + def capture(self, event: ProductTelemetryEvent) -> None: + pass + + @property + def context(self) -> Dict[str, Any]: + chroma_version = chromadb.__version__ + settings = chromadb.get_settings() + telemetry_settings = {} + for whitelisted in TELEMETRY_WHITELISTED_SETTINGS: + telemetry_settings[whitelisted] = settings[whitelisted] + + self._context = { + "chroma_version": chroma_version, + "server_context": self.SERVER_CONTEXT.value, + **telemetry_settings, + } + return self._context + + @property + def user_id(self) -> str: + if self._curr_user_id: + return self._curr_user_id + + # File access may fail due to permissions or other reasons. 
We don't want to + # crash so we catch all exceptions. + try: + if not os.path.exists(self.USER_ID_PATH): + os.makedirs(os.path.dirname(self.USER_ID_PATH), exist_ok=True) + with open(self.USER_ID_PATH, "w") as f: + new_user_id = str(uuid.uuid4()) + f.write(new_user_id) + self._curr_user_id = new_user_id + else: + with open(self.USER_ID_PATH, "r") as f: + self._curr_user_id = f.read() + except Exception: + self._curr_user_id = self.UNKNOWN_USER_ID + return self._curr_user_id diff --git a/chromadb/telemetry/events.py b/chromadb/telemetry/product/events.py similarity index 89% rename from chromadb/telemetry/events.py rename to chromadb/telemetry/product/events.py index e662cd85fa7..e5f6bc688c1 100644 --- a/chromadb/telemetry/events.py +++ b/chromadb/telemetry/product/events.py @@ -1,14 +1,14 @@ from typing import cast, ClassVar -from chromadb.telemetry import TelemetryEvent +from chromadb.telemetry.product import ProductTelemetryEvent from chromadb.utils.embedding_functions import get_builtins -class ClientStartEvent(TelemetryEvent): +class ClientStartEvent(ProductTelemetryEvent): def __init__(self) -> None: super().__init__() -class ClientCreateCollectionEvent(TelemetryEvent): +class ClientCreateCollectionEvent(ProductTelemetryEvent): collection_uuid: str embedding_function: str @@ -25,7 +25,7 @@ def __init__(self, collection_uuid: str, embedding_function: str): ) -class CollectionAddEvent(TelemetryEvent): +class CollectionAddEvent(ProductTelemetryEvent): max_batch_size: ClassVar[int] = 100 batch_size: int collection_uuid: str @@ -52,7 +52,7 @@ def __init__( def batch_key(self) -> str: return self.collection_uuid + self.name - def batch(self, other: "TelemetryEvent") -> "CollectionAddEvent": + def batch(self, other: "ProductTelemetryEvent") -> "CollectionAddEvent": if not self.batch_key == other.batch_key: raise ValueError("Cannot batch events") other = cast(CollectionAddEvent, other) @@ -66,7 +66,7 @@ def batch(self, other: "TelemetryEvent") -> "CollectionAddEvent": ) -class CollectionUpdateEvent(TelemetryEvent): +class CollectionUpdateEvent(ProductTelemetryEvent): collection_uuid: str update_amount: int with_embeddings: int @@ -89,7 +89,7 @@ def __init__( self.with_documents = with_documents -class CollectionQueryEvent(TelemetryEvent): +class CollectionQueryEvent(ProductTelemetryEvent): max_batch_size: ClassVar[int] = 20 batch_size: int collection_uuid: str @@ -128,7 +128,7 @@ def __init__( def batch_key(self) -> str: return self.collection_uuid + self.name - def batch(self, other: "TelemetryEvent") -> "CollectionQueryEvent": + def batch(self, other: "ProductTelemetryEvent") -> "CollectionQueryEvent": if not self.batch_key == other.batch_key: raise ValueError("Cannot batch events") other = cast(CollectionQueryEvent, other) @@ -146,7 +146,7 @@ def batch(self, other: "TelemetryEvent") -> "CollectionQueryEvent": ) -class CollectionGetEvent(TelemetryEvent): +class CollectionGetEvent(ProductTelemetryEvent): collection_uuid: str ids_count: int limit: int @@ -169,7 +169,7 @@ def __init__( self.include_documents = include_documents -class CollectionDeleteEvent(TelemetryEvent): +class CollectionDeleteEvent(ProductTelemetryEvent): collection_uuid: str delete_amount: int diff --git a/chromadb/telemetry/posthog.py b/chromadb/telemetry/product/posthog.py similarity index 77% rename from chromadb/telemetry/posthog.py rename to chromadb/telemetry/product/posthog.py index 21676b9fbe7..05c46b07256 100644 --- a/chromadb/telemetry/posthog.py +++ b/chromadb/telemetry/product/posthog.py @@ -3,19 +3,23 @@ 
import sys from typing import Any, Dict, Set from chromadb.config import System -from chromadb.telemetry import Telemetry, TelemetryEvent +from chromadb.telemetry.product import ( + ProductTelemetryClient, + ProductTelemetryEvent, +) from overrides import override logger = logging.getLogger(__name__) -class Posthog(Telemetry): +class Posthog(ProductTelemetryClient): def __init__(self, system: System): if not system.settings.anonymized_telemetry or "pytest" in sys.modules: posthog.disabled = True else: logger.info( - "Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information." + "Anonymized telemetry enabled. See \ + https://docs.trychroma.com/telemetry for more information." ) posthog.project_api_key = "phc_YeUxaojbKk5KPi8hNlx1bBKHzuZ4FDtl67kH1blv8Bh" @@ -23,13 +27,13 @@ def __init__(self, system: System): # Silence posthog's logging posthog_logger.disabled = True - self.batched_events: Dict[str, TelemetryEvent] = {} + self.batched_events: Dict[str, ProductTelemetryEvent] = {} self.seen_event_types: Set[Any] = set() super().__init__(system) @override - def capture(self, event: TelemetryEvent) -> None: + def capture(self, event: ProductTelemetryEvent) -> None: if event.max_batch_size == 1 or event.batch_key not in self.seen_event_types: self.seen_event_types.add(event.batch_key) self._direct_capture(event) @@ -44,7 +48,7 @@ def capture(self, event: TelemetryEvent) -> None: self._direct_capture(batched_event) del self.batched_events[batch_key] - def _direct_capture(self, event: TelemetryEvent) -> None: + def _direct_capture(self, event: ProductTelemetryEvent) -> None: try: posthog.capture( self.user_id, diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index 529fe02dda7..3bd83231b32 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -46,7 +46,9 @@ def _bool_to_int(metadata: Dict[str, Any]) -> Dict[str, Any]: def _patch_boolean_metadata( - collection: strategies.Collection, embeddings: strategies.RecordSet + collection: strategies.Collection, + embeddings: strategies.RecordSet, + settings: Settings, ) -> None: # Since the old version does not support boolean value metadata, we will convert # boolean value metadata to int @@ -64,15 +66,29 @@ def _patch_boolean_metadata( _bool_to_int(metadata) +def _patch_telemetry_client( + collection: strategies.Collection, + embeddings: strategies.RecordSet, + settings: Settings, +) -> None: + # chroma 0.4.14 added OpenTelemetry, distinct from ProductTelemetry. Before 0.4.14 + # ProductTelemetry was simply called Telemetry. 
+ settings.chroma_telemetry_impl = "chromadb.telemetry.posthog.Posthog" + + version_patches: List[ - Tuple[str, Callable[[strategies.Collection, strategies.RecordSet], None]] + Tuple[str, Callable[[strategies.Collection, strategies.RecordSet, Settings], None]] ] = [ ("0.4.3", _patch_boolean_metadata), + ("0.4.14", _patch_telemetry_client), ] def patch_for_version( - version: str, collection: strategies.Collection, embeddings: strategies.RecordSet + version: str, + collection: strategies.Collection, + embeddings: strategies.RecordSet, + settings: Settings, ) -> None: """Override aspects of the collection and embeddings, before testing, to account for breaking changes in old versions.""" @@ -81,7 +97,7 @@ def patch_for_version( if packaging_version.Version(version) <= packaging_version.Version( patch_version ): - patch(collection, embeddings) + patch(collection, embeddings, settings) def configurations(versions: List[str]) -> List[Tuple[str, Settings]]: @@ -261,7 +277,7 @@ def test_cycle_versions( for m in embeddings_strategy["metadatas"] ] - patch_for_version(version, collection_strategy, embeddings_strategy) + patch_for_version(version, collection_strategy, embeddings_strategy, settings) # Can't pickle a function, and we won't need them collection_strategy.embedding_function = None diff --git a/docker-compose.yml b/docker-compose.yml index 93581dd23c7..3bc5d5a9404 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -22,6 +22,10 @@ services: - CHROMA_SERVER_AUTH_CREDENTIALS=${CHROMA_SERVER_AUTH_CREDENTIALS} - CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER=${CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER} - PERSIST_DIRECTORY=${PERSIST_DIRECTORY:-/chroma/chroma} + - CHROMA_OTEL_EXPORTER_ENDPOINT=${CHROMA_OTEL_EXPORTER_ENDPOINT} + - CHROMA_OTEL_EXPORTER_HEADERS=${CHROMA_OTEL_EXPORTER_HEADERS} + - CHROMA_OTEL_SERVICE_NAME=${CHROMA_OTEL_SERVICE_NAME} + - CHROMA_OTEL_GRANULARITY=${CHROMA_OTEL_GRANULARITY} ports: - 8000:8000 networks: diff --git a/requirements.txt b/requirements.txt index 7b60e6101bb..f3093341f14 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,9 @@ kubernetes>=28.1.0 numpy==1.21.6; python_version < '3.8' numpy>=1.22.4; python_version >= '3.8' onnxruntime>=1.14.1 +opentelemetry-api>=1.2.0 +opentelemetry-exporter-otlp-proto-grpc>=1.2.0 +opentelemetry-sdk>=1.2.0 overrides==7.3.1 posthog==2.4.0 pulsar-client==3.1.0 diff --git a/server.htpasswd b/server.htpasswd new file mode 100644 index 00000000000..77f277a399b --- /dev/null +++ b/server.htpasswd @@ -0,0 +1 @@ +admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS From ac644a82249f0c44fd07560066b3a57febdd99ac Mon Sep 17 00:00:00 2001 From: Ben Eggers <64657842+beggers@users.noreply.github.com> Date: Wed, 18 Oct 2023 13:39:35 -0700 Subject: [PATCH 13/14] Update pyproject.toml (#1256) Add OTel dependencies to our toml so the release works ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - In my recent OTel PR I didn't add the requirements to `pyproject.toml` so the release pipeline is broken. This fixes that. - New functionality - None. ## Test plan *How are these changes tested?* - CI ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? 
Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*

---
 pyproject.toml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index 7dd144cf3ab..0c5fa10307c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,9 @@ dependencies = [
     'typing_extensions >= 4.5.0',
     'pulsar-client>=3.1.0',
     'onnxruntime >= 1.14.1',
+    'opentelemetry-api>=1.2.0',
+    'opentelemetry-exporter-otlp-proto-grpc>=1.2.0',
+    'opentelemetry-sdk>=1.2.0',
     'tokenizers >= 0.13.2',
     'pypika >= 0.48.9',
     'tqdm >= 4.65.0',

From 019b954e571fe362a420ec27b2add9bac94c6334 Mon Sep 17 00:00:00 2001
From: Ben Eggers <64657842+beggers@users.noreply.github.com>
Date: Thu, 19 Oct 2023 10:31:02 -0700
Subject: [PATCH 14/14] Fix otel startup bug (#1263)

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
   - Fixes https://github.com/chroma-core/chroma/issues/1260#issuecomment-1771207121
   - TL;DR: the default `chroma_otel_granularity` ends up set to `""`, which isn't a valid enum value.
 - New functionality
   - None

## Test plan

*How are these changes tested?*
- CI
- The `docker-compose` startup flow works now

## Documentation Changes

*Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*
---
 chromadb/config.py                           | 2 +-
 chromadb/telemetry/opentelemetry/__init__.py | 6 +++++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/chromadb/config.py b/chromadb/config.py
index 6731db255b7..0a3e4864673 100644
--- a/chromadb/config.py
+++ b/chromadb/config.py
@@ -179,7 +179,7 @@ def chroma_server_auth_credentials_file_non_empty_file_exists(
     chroma_otel_collection_endpoint: Optional[str] = ""
     chroma_otel_service_name: Optional[str] = "chromadb"
     chroma_otel_collection_headers: Dict[str, str] = {}
-    chroma_otel_granularity: Optional[str] = "none"
+    chroma_otel_granularity: Optional[str] = None

     allow_reset: bool = False

diff --git a/chromadb/telemetry/opentelemetry/__init__.py b/chromadb/telemetry/opentelemetry/__init__.py
index 0840871bcae..a713ed11ab4 100644
--- a/chromadb/telemetry/opentelemetry/__init__.py
+++ b/chromadb/telemetry/opentelemetry/__init__.py
@@ -49,7 +49,11 @@ def __init__(self, system: System):
             system.settings.chroma_otel_service_name,
             system.settings.chroma_otel_collection_endpoint,
             system.settings.chroma_otel_collection_headers,
-            OpenTelemetryGranularity(system.settings.chroma_otel_granularity),
+            OpenTelemetryGranularity(
+                system.settings.chroma_otel_granularity
+                if system.settings.chroma_otel_granularity
+                else "none"
+            ),
         )
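
For context on how the pieces introduced in this patch series fit together, here is a minimal sketch that assumes only the `otel_init` and `trace_method` signatures added in `chromadb/telemetry/opentelemetry/__init__.py`. The collector endpoint and header values are placeholders, not real credentials, and calling `otel_init` directly is only for illustration; a deployment would normally set the `chroma_otel_*` Settings fields (or the corresponding environment variables added to `docker-compose.yml`).

```python
# Minimal sketch of the tracing utilities added in this patch series.
# The endpoint and header values below are placeholders.
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryGranularity,
    otel_init,
    trace_method,
)

# Initialize module-level tracer state, mirroring what OpenTelemetryClient
# does from Settings. With granularity "none" this would be a no-op.
otel_init(
    otel_service_name="chromadb",
    otel_collection_endpoint="api.honeycomb.com",  # placeholder endpoint
    otel_collection_headers={"x-honeycomb-team": "placeholder-key"},
    otel_granularity=OpenTelemetryGranularity.OPERATION,
)


# A span is emitted only when the declared granularity is at least as coarse
# as the configured one (ordering: ALL < OPERATION_AND_SEGMENT < OPERATION
# < NONE); here OPERATION matches the configured OPERATION, so the call is
# traced.
@trace_method("example.do_work", OpenTelemetryGranularity.OPERATION)
def do_work(n: int) -> int:
    return n * 2


if __name__ == "__main__":
    print(do_work(21))  # 42, with a span exported in the background
```

With the fix in [PATCH 14/14], leaving `chroma_otel_granularity` unset makes `OpenTelemetryClient` fall back to `"none"`, so no spans are emitted unless tracing is explicitly configured.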