From 3241de7a6fc20fe178dbe93090ee89f57a5625b6 Mon Sep 17 00:00:00 2001 From: Jeff Huber Date: Wed, 6 Sep 2023 22:19:41 -0700 Subject: [PATCH 01/39] update JS instructions (#960) Improve develop instructions for the JS client --------- Co-authored-by: Pascal M <11357019+perzeuss@users.noreply.github.com> Co-authored-by: Hammad Bashir --- clients/js/DEVELOP.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/clients/js/DEVELOP.md b/clients/js/DEVELOP.md index 030f3ef455c..c82fd15327e 100644 --- a/clients/js/DEVELOP.md +++ b/clients/js/DEVELOP.md @@ -5,18 +5,24 @@ This readme is helpful for local dev. ### Prereqs: - Make sure you have Java installed (for the generator). You can download it from [java.com](https://java.com) +- Make sure you set ALLOW_RESET=True for your Docker Container. If you don't do this, tests won't pass. +``` +environment: + - IS_PERSISTENT=TRUE + - ALLOW_RESET=True +``` - Make sure you are running the docker backend at localhost:8000 (\*there is probably a way to stand up the fastapi server by itself and programmatically in the loop of generating this, but not prioritizing it for now. It may be important for the release) ### Generating 1. `yarn` to install deps -2. `yarn genapi-zsh` if you have zsh +2. `yarn genapi` 3. Examples are in the `examples` folder. There is one for the browser and one for node. Run them with `yarn dev`, eg `cd examples/browser && yarn dev` ### Running test -`yarn test` will launch a test docker backend. -`yarn test:run` will run against the docker backend you have running. But CAUTION, it will delete data. +`yarn test` will launch a test docker backend, run a db cleanup and run tests. +`yarn test:run` will run against the docker backend you have running. But CAUTION, it will delete data. This is the easiest and fastest way to run tests. ### Pushing to npm From 237b3e3c96c0e25d87dec4c5b2b21085c040dfa3 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Wed, 6 Sep 2023 23:13:29 -0700 Subject: [PATCH 02/39] [BLD] Add dockerhub support (#1112) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Pushes images to dockerhub ## Test plan *How are these changes tested?* Will have to be tested on main as part of CI/CD - [x] Tests pass locally with `pytest` for python, `yarn test` for js ## Documentation Changes None required. 
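For reference, a rough sketch of how a user would consume the published image once this lands (the image name and the container data path are taken from `DOCKERHUB_IMAGE_NAME` in the workflow and from the compose files elsewhere in this repo; the `latest` tag is only pushed for tagged releases, so treat this as illustrative rather than exact):

```bash
# Pull the image this workflow publishes to Docker Hub
docker pull chromadb/chroma:latest

# Run the server, exposing port 8000 and persisting data to a host directory
docker run -d -p 8000:8000 -v "$(pwd)/chroma-data:/chroma/chroma" chromadb/chroma:latest
```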
--- .github/workflows/chroma-release.yml | 53 +++++++++++++--------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/.github/workflows/chroma-release.yml b/.github/workflows/chroma-release.yml index ca394a24594..a2f0a988a43 100644 --- a/.github/workflows/chroma-release.yml +++ b/.github/workflows/chroma-release.yml @@ -8,8 +8,8 @@ on: - main env: - REGISTRY: ghcr.io - IMAGE_NAME: "ghcr.io/chroma-core/chroma" + GHCR_IMAGE_NAME: "ghcr.io/chroma-core/chroma" + DOCKERHUB_IMAGE_NAME: "chromadb/chroma" PLATFORMS: linux/amd64,linux/arm64 #linux/riscv64, linux/arm/v7 jobs: @@ -27,14 +27,7 @@ jobs: build-and-release: runs-on: ubuntu-latest needs: check_tag - if: needs.check_tag.outputs.tag_matches == 'true' permissions: write-all -# id-token: write -# contents: read -# deployments: write -# packages: write -# pull-requests: read -# statuses: write steps: - name: Checkout uses: actions/checkout@v3 @@ -57,36 +50,38 @@ jobs: run: python -m build - name: Test Client Package run: bin/test-package.sh dist/*.tar.gz - - name: Log in to the Container registry + - name: Log in to the Github Container registry uses: docker/login-action@v2.1.0 with: - registry: ${{ env.REGISTRY }} + registry: ghcr.io username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to DockerHub + uses: docker/login-action@v2.1.0 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Install setuptools_scm run: python -m pip install setuptools_scm - - name: Get Docker Tag - id: tag - run: echo "tag_name=$IMAGE_NAME:$(bin/version)" >> $GITHUB_OUTPUT + - name: Get Release Version + id: version + run: echo "version=$(python -m setuptools_scm)" >> $GITHUB_OUTPUT - name: Build and push prerelease Docker image - if: "!startsWith(github.ref, 'refs/tags/')" + if: "needs.check_tag.outputs.tag_matches != 'true'" uses: docker/build-push-action@v3.2.0 with: context: . platforms: ${{ env.PLATFORMS }} push: true - tags: ${{ steps.tag.outputs.tag_name}} + tags: "${{ env.GHCR_IMAGE_NAME }}:${{ steps.version.outputs.version }},${{ env.DOCKERHUB_IMAGE_NAME }}:${{ steps.version.outputs.version }}" - name: Build and push release Docker image - if: "startsWith(github.ref, 'refs/tags/')" + if: "needs.check_tag.outputs.tag_matches == 'true'" uses: docker/build-push-action@v3.2.0 with: context: . 
platforms: ${{ env.PLATFORMS }} push: true - tags: "${{ steps.tag.outputs.tag_name }},${{ env.IMAGE_NAME }}:latest" - - name: Get Release Version - id: version - run: echo "version=$(python -m setuptools_scm)" >> $GITHUB_OUTPUT + tags: "${{ env.GHCR_IMAGE_NAME }}:${{ steps.version.outputs.version }},${{ env.DOCKERHUB_IMAGE_NAME }}:${{ steps.version.outputs.version }},${{ env.GHCR_IMAGE_NAME }}:latest,${{ env.DOCKERHUB_IMAGE_NAME }}:latest" - name: Get current date id: builddate run: echo "builddate=$(date +'%Y-%m-%dT%H:%M')" >> $GITHUB_OUTPUT @@ -96,7 +91,7 @@ jobs: password: ${{ secrets.TEST_PYPI_API_TOKEN }} repository_url: https://test.pypi.org/legacy/ - name: Publish to PyPI - if: startsWith(github.ref, 'refs/tags') + if: "needs.check_tag.outputs.tag_matches == 'true'" uses: pypa/gh-action-pypi-publish@release/v1 with: password: ${{ secrets.PYPI_API_TOKEN }} @@ -107,31 +102,32 @@ jobs: aws-region: us-east-1 - name: Generate CloudFormation template id: generate-cf - if: "startsWith(github.ref, 'refs/tags/')" + if: "needs.check_tag.outputs.tag_matches == 'true'" run: "pip install boto3 && python bin/generate_cloudformation.py" - name: Release Tagged Version uses: ncipollo/release-action@v1.11.1 - if: "startsWith(github.ref, 'refs/tags/')" + if: "needs.check_tag.outputs.tag_matches == 'true'" with: body: | Version: `${{steps.version.outputs.version}}` Git ref: `${{github.ref}}` Build Date: `${{steps.builddate.outputs.builddate}}` PIP Package: `chroma-${{steps.version.outputs.version}}.tar.gz` - Docker Image: `${{steps.tag.outputs.tag_name}}` + Github Container Registry Image: `${{ env.GHCR_IMAGE_NAME }}:${{ steps.version.outputs.version }}` + DockerHub Image: `${{ env.DOCKERHUB_IMAGE_NAME }}:${{ steps.version.outputs.version }}` artifacts: "dist/chroma-${{steps.version.outputs.version}}.tar.gz" prerelease: true generateReleaseNotes: true - name: Update Tag uses: richardsimko/update-tag@v1.0.5 - if: "!startsWith(github.ref, 'refs/tags/')" + if: "needs.check_tag.outputs.tag_matches != 'true'" with: tag_name: latest env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Release Latest uses: ncipollo/release-action@v1.11.1 - if: "!startsWith(github.ref, 'refs/tags/')" + if: "needs.check_tag.outputs.tag_matches != 'true'" with: tag: "latest" name: "Latest" @@ -140,7 +136,8 @@ jobs: Git ref: `${{github.ref}}` Build Date: `${{steps.builddate.outputs.builddate}}` PIP Package: `chroma-${{steps.version.outputs.version}}.tar.gz` - Docker Image: `${{steps.tag.outputs.tag_name}}` + Github Container Registry Image: `${{ env.GHCR_IMAGE_NAME }}:${{ steps.version.outputs.version }}` + DockerHub Image: `${{ env.DOCKERHUB_IMAGE_NAME }}:${{ steps.version.outputs.version }}` artifacts: "dist/chroma-${{steps.version.outputs.version}}.tar.gz" allowUpdates: true prerelease: true From ea73f05bdf91ff771388d6c573af9ddd8ca37cbe Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Thu, 7 Sep 2023 23:24:41 +0300 Subject: [PATCH 03/39] [BUG]: Issue where In/Nin list values (#1111) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Fixed an issue where list values for In/Nin that are not wrapped with pypika ParameterValue will result in floating point comparisons failure after a certain precision threshold. 
## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python ## Documentation Changes N/A --- chromadb/segment/impl/metadata/sqlite.py | 52 ++++++++---------------- 1 file changed, 16 insertions(+), 36 deletions(-) diff --git a/chromadb/segment/impl/metadata/sqlite.py b/chromadb/segment/impl/metadata/sqlite.py index 781aed00ba4..8e9649a627d 100644 --- a/chromadb/segment/impl/metadata/sqlite.py +++ b/chromadb/segment/impl/metadata/sqlite.py @@ -105,7 +105,6 @@ def get_metadata( offset: Optional[int] = None, ) -> Sequence[MetadataEmbeddingRecord]: """Query for embedding metadata.""" - embeddings_t, metadata_t, fulltext_t = Tables( "embeddings", "embedding_metadata", "embedding_fulltext_search" ) @@ -135,7 +134,6 @@ def get_metadata( if where: q = q.where(self._where_map_criterion(q, where, embeddings_t, metadata_t)) - if where_document: q = q.where( self._where_doc_criterion(q, where_document, embeddings_t, fulltext_t) @@ -417,32 +415,8 @@ def _where_map_criterion( for w in cast(Sequence[Where], v) ] clause.append(reduce(lambda x, y: x | y, criteria)) - elif k == "$in": - expr = cast( - Dict[InclusionExclusionOperator, List[LiteralValue]], {k: v} - ) - sq = ( - self._db.querybuilder() - .from_(metadata_t) - .select(metadata_t.id) - .where(metadata_t.key.isin(ParameterValue(k))) - .where(_where_clause(expr, metadata_t)) - ) - clause.append(embeddings_t.id.isin(sq)) - elif k == "$nin": - expr = cast( - Dict[InclusionExclusionOperator, List[LiteralValue]], {k: v} - ) - sq = ( - self._db.querybuilder() - .from_(metadata_t) - .select(metadata_t.id) - .where(metadata_t.key.notin(ParameterValue(k))) - .where(_where_clause(expr, metadata_t)) - ) - clause.append(embeddings_t.id.notin(sq)) else: - expr = cast(Union[LiteralValue, Dict[WhereOperator, LiteralValue]], v) # type: ignore + expr = cast(Union[LiteralValue, Dict[WhereOperator, LiteralValue]], v) sq = ( self._db.querybuilder() .from_(metadata_t) @@ -554,30 +528,36 @@ def _value_criterion( raise ValueError(f"Empty list for {op} operator") if isinstance(value[0], str): col_exprs = [ - table.string_value.isin(_v) + table.string_value.isin(ParameterValue(_v)) if op == "$in" - else table.str_value.notin(_v) + else table.str_value.notin(ParameterValue(_v)) ] elif isinstance(value[0], bool): col_exprs = [ - table.bool_value.isin(_v) if op == "$in" else table.bool_value.notin(_v) + table.bool_value.isin(ParameterValue(_v)) + if op == "$in" + else table.bool_value.notin(ParameterValue(_v)) ] elif isinstance(value[0], int): col_exprs = [ - table.int_value.isin(_v) if op == "$in" else table.int_value.notin(_v) + table.int_value.isin(ParameterValue(_v)) + if op == "$in" + else table.int_value.notin(ParameterValue(_v)) ] elif isinstance(value[0], float): col_exprs = [ - table.float_value.isin(_v) + table.float_value.isin(ParameterValue(_v)) if op == "$in" - else table.float_value.notin(_v) + else table.float_value.notin(ParameterValue(_v)) ] elif isinstance(value, list) and op in ("$in", "$nin"): col_exprs = [ - table.int_value.isin(value), - table.float_value.isin(value) + table.int_value.isin(ParameterValue(value)) + if op == "$in" + else table.int_value.notin(ParameterValue(value)), + table.float_value.isin(ParameterValue(value)) if op == "$in" - else table.float_value.notin(value), + else table.float_value.notin(ParameterValue(value)), ] else: cols = [table.int_value, table.float_value] From 9c1979c9311acb8662fc22a7deac17907a9b128b Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Mon, 11 Sep 2023 17:58:20 +0300 
Subject: [PATCH 04/39] [BUG]: URL Parsing And Validation (#1118) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Added additional validations to URLs - URLs like api-gw.aws.com/dev will now trigger an error asking the user to correctly specify the URL with http or https - When the full URL (http(s)://example.com) is provided by the user, the port parameter is ignored (debug message is logged). An assumption is made that the URL is entirely defined, thus not requiring additional alterations such as injecting the port. - Added negative test cases for invalid URLs ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python ## Documentation Changes TBD --- chromadb/api/fastapi.py | 37 ++++++++++++++--- chromadb/test/property/test_client_url.py | 48 +++++++++++++++++++++-- 2 files changed, 76 insertions(+), 9 deletions(-) diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 8498f9ec110..c08458a2fcb 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -1,4 +1,5 @@ import json +import logging from typing import Optional, cast from typing import Sequence from uuid import UUID @@ -32,28 +33,54 @@ from chromadb.telemetry import Telemetry from urllib.parse import urlparse, urlunparse, quote +logger = logging.getLogger(__name__) + class FastAPI(API): _settings: Settings + @staticmethod + def _validate_host(host: str) -> None: + parsed = urlparse(host) + if "/" in host and parsed.scheme not in {"http", "https"}: + raise ValueError( + "Invalid URL. " f"Unrecognized protocol - {parsed.scheme}." + ) + if "/" in host and (not host.startswith("http")): + raise ValueError( + "Invalid URL. " + "Seems that you are trying to pass URL as a host but without specifying the protocol. " + "Please add http:// or https:// to the host." 
+ ) + @staticmethod def resolve_url( chroma_server_host: str, chroma_server_ssl_enabled: Optional[bool] = False, default_api_path: Optional[str] = "", - chroma_server_http_port: int = 8000, + chroma_server_http_port: Optional[int] = 8000, ) -> str: - parsed = urlparse(chroma_server_host) + _skip_port = False + _chroma_server_host = chroma_server_host + FastAPI._validate_host(_chroma_server_host) + if _chroma_server_host.startswith("http"): + logger.debug("Skipping port as the user is passing a full URL") + _skip_port = True + parsed = urlparse(_chroma_server_host) scheme = "https" if chroma_server_ssl_enabled else parsed.scheme or "http" net_loc = parsed.netloc or parsed.hostname or chroma_server_host - port = parsed.port or chroma_server_http_port + port = ( + ":" + str(parsed.port or chroma_server_http_port) if not _skip_port else "" + ) path = parsed.path or default_api_path - if not path or path == net_loc or not path.endswith(default_api_path or ""): + if not path or path == net_loc: path = default_api_path if default_api_path else "" + if not path.endswith(default_api_path or ""): + path = path + default_api_path if default_api_path else "" full_url = urlunparse( - (scheme, f"{net_loc}:{port}", quote(path.replace("//", "/")), "", "", "") + (scheme, f"{net_loc}{port}", quote(path.replace("//", "/")), "", "", "") ) return full_url diff --git a/chromadb/test/property/test_client_url.py b/chromadb/test/property/test_client_url.py index 992af981399..cc5df1e0514 100644 --- a/chromadb/test/property/test_client_url.py +++ b/chromadb/test/property/test_client_url.py @@ -1,6 +1,7 @@ from typing import Optional from urllib.parse import urlparse +import pytest from hypothesis import given, strategies as st from chromadb.api.fastapi import FastAPI @@ -28,7 +29,7 @@ def domain_strategy() -> st.SearchStrategy[str]: return st.tuples(label, tld).map(".".join) -port_strategy = st.integers(min_value=1, max_value=65535) +port_strategy = st.one_of(st.integers(min_value=1, max_value=65535), st.none()) ssl_enabled_strategy = st.booleans() @@ -56,8 +57,21 @@ def is_valid_url(url: str) -> bool: def generate_valid_domain_url() -> st.SearchStrategy[str]: return st.builds( - lambda url_scheme, hostname, url_path: f"{url_scheme}://{hostname}{url_path}", - url_scheme=st.sampled_from(["http", "https"]), + lambda url_scheme, hostname, url_path: f"{url_scheme}{hostname}{url_path}", + url_scheme=st.sampled_from(["http://", "https://"]), + hostname=domain_strategy(), + url_path=url_path_strategy(), + ) + + +def generate_invalid_domain_url() -> st.SearchStrategy[str]: + return st.builds( + lambda url_scheme, hostname, url_path: f"{url_scheme}{hostname}{url_path}", + url_scheme=st.builds( + lambda scheme, suffix: f"{scheme}{suffix}", + scheme=st.text(max_size=10), + suffix=st.sampled_from(["://", ":///", ":////", ""]), + ), hostname=domain_strategy(), url_path=url_path_strategy(), ) @@ -76,7 +90,7 @@ def generate_valid_domain_url() -> st.SearchStrategy[str]: ) def test_url_resolve( hostname: str, - port: int, + port: Optional[int], ssl_enabled: bool, default_api_path: Optional[str], ) -> None: @@ -90,5 +104,31 @@ def test_url_resolve( assert ( _url.startswith("https") if ssl_enabled else _url.startswith("http") ), f"Invalid URL: {_url} - SSL Enabled: {ssl_enabled}" + if hostname.startswith("http"): + assert ":" + str(port) not in _url, f"Port in URL not expected: {_url}" + else: + assert ":" + str(port) in _url, f"Port in URL expected: {_url}" if default_api_path: assert _url.endswith(default_api_path), f"Invalid URL: 
{_url}" + + +@given( + hostname=generate_invalid_domain_url(), + port=port_strategy, + ssl_enabled=ssl_enabled_strategy, + default_api_path=st.sampled_from(["/api/v1", "/api/v2", None]), +) +def test_resolve_invalid( + hostname: str, + port: Optional[int], + ssl_enabled: bool, + default_api_path: Optional[str], +) -> None: + with pytest.raises(ValueError) as e: + FastAPI.resolve_url( + chroma_server_host=hostname, + chroma_server_http_port=port, + chroma_server_ssl_enabled=ssl_enabled, + default_api_path=default_api_path, + ) + assert "Invalid URL" in str(e.value) From 2dd5a1552687043889a2dafe27de56577ad278a7 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Mon, 11 Sep 2023 18:19:57 +0300 Subject: [PATCH 05/39] [ENH] Added auth and external volume support for GCP (#1107) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Added external volume for Chroma data - Bumped to the latest version (0.4.9) - Added auth by default - Made the template more configurable via variables with sensible defaults ## Test plan *How are these changes tested?* - Tested with terraform ## Documentation Changes The update contains README with docs. --- .../google-cloud-compute/README.md | 98 +++++++++++++- .../google-cloud-compute/chroma.tf | 105 ++++++++++++-- .../deployments/google-cloud-compute/main.tf | 8 -- .../google-cloud-compute/startup.sh | 39 +++++- .../google-cloud-compute/variables.tf | 128 +++++++++++++++++- 5 files changed, 348 insertions(+), 30 deletions(-) diff --git a/examples/deployments/google-cloud-compute/README.md b/examples/deployments/google-cloud-compute/README.md index 8da1af830de..ea25613baf4 100644 --- a/examples/deployments/google-cloud-compute/README.md +++ b/examples/deployments/google-cloud-compute/README.md @@ -3,43 +3,135 @@ This is an example deployment to Google Cloud Compute using [terraform](https://www.terraform.io/) ## Requirements + - [gcloud CLI](https://cloud.google.com/sdk/gcloud) - [Terraform CLI v1.3.4+](https://developer.hashicorp.com/terraform/tutorials/gcp-get-started/install-cli) +- [Terraform GCP provider](https://registry.terraform.io/providers/hashicorp/google/latest/docs) ## Deployment with terraform ### 1. Auth to your Google Cloud project + ```bash gcloud auth application-default login ``` ### 2. Init your terraform state + ```bash terraform init ``` ### 3. Deploy your application + +> **WARNING**: GCP Terraform provider does not allow use of variables in the lifecycle of the volume. By default, the +> template does not prevent deletion of the volume however if you plan to use this template for production deployment you +> may consider change the value of `prevent_destroy` to `true` in `chroma.tf` file. + +Generate SSH key to use with your chroma instance (so you can SSH to the GCP VM): + +> Note: This is optional. You can use your own existing SSH key if you prefer. 
+ +```bash +ssh-keygen -t RSA -b 4096 -C "Chroma AWS Key" -N "" -f ./chroma-aws && chmod 400 ./chroma-aws +``` + ```bash export TF_VAR_project_id= #take note of this as it must be present in all of the subsequent steps -export TF_VAR_chroma_release=0.4.5 #set the chroma release to deploy +export TF_ssh_public_key="./chroma-aws.pub" #path to the public key you generated above (or can be different if you want to use your own key) +export TF_ssh_private_key="./chroma-aws" #path to the private key you generated above (or can be different if you want to use your own key) - used for formatting the Chroma data volume +export TF_VAR_chroma_release="0.4.9" #set the chroma release to deploy +export TF_VAR_zone="us-central1-a" # AWS region to deploy the chroma instance to +export TF_VAR_public_access="true" #enable public access to the chroma instance on port 8000 +export TF_VAR_enable_auth="true" #enable basic auth for the chroma instance +export TF_VAR_auth_type="token" #The auth type to use for the chroma instance (token or basic) terraform apply -auto-approve ``` ### 4. Check your public IP and that Chroma is running -Get the public IP of your instance +> Note: Depending on your instance type it might take a few minutes for the instance to be ready + +Get the public IP of your instance (it should also be printed out after successful `terraform apply`): ```bash terraform output instance_public_ip ``` -Check that chroma is running +Check that chroma is running: + ```bash export instance_public_ip=$(terraform output instance_public_ip | sed 's/"//g') curl -v http://$instance_public_ip:8000/api/v1/heartbeat ``` +#### 4.1 Checking Auth + +##### Token + +When token auth is enabled (this is the default option) you can check the get the credentials from Terraform state by +running: + +```bash +terraform output chroma_auth_token +``` + +You should see something of the form: + +```bash +PVcQ4qUUnmahXwUgAf3UuYZoMlos6MnF +``` + +You can then export these credentials: + +```bash +export CHROMA_AUTH=$(terraform output chroma_auth_token | sed 's/"//g') +``` + +Using the credentials: + +```bash +curl -v http://$instance_public_ip:8000/api/v1/collections -H "Authorization: Bearer ${CHROMA_AUTH}" +``` + +##### Basic + +When basic auth is enabled you can check the get the credentials from Terraform state by running: + +```bash +terraform output chroma_auth_basic +``` + +You should see something of the form: + +```bash +chroma:VuA8I}QyNrm0@QLq +``` + +You can then export these credentials: + +```bash +export CHROMA_AUTH=$(terraform output chroma_auth_basic | sed 's/"//g') +``` + +Using the credentials: + +```bash +curl -v http://$instance_public_ip:8000/api/v1/collections -u "${CHROMA_AUTH}" +``` + +> Note: Without `-u` you should be getting 401 Unauthorized response + +#### 4.2 SSH to your instance + +To SSH to your instance: + +```bash +ssh -i ./chroma-aws debian@$instance_public_ip +``` + ### 5. 
Destroy your application + ```bash terraform destroy -auto-approve ``` diff --git a/examples/deployments/google-cloud-compute/chroma.tf b/examples/deployments/google-cloud-compute/chroma.tf index 7ed41b417a2..f49fc59cfe3 100644 --- a/examples/deployments/google-cloud-compute/chroma.tf +++ b/examples/deployments/google-cloud-compute/chroma.tf @@ -1,18 +1,36 @@ -resource "google_compute_instance" "chroma1" { +terraform { + required_providers { + google = { + source = "hashicorp/google" + version = "~> 4.80.0" + } + } +} + +resource "google_compute_instance" "chroma" { project = var.project_id name = "chroma-1" machine_type = var.machine_type zone = var.zone - tags = ["chroma"] + tags = local.tags + + labels = var.labels + boot_disk { initialize_params { - image = "debian-cloud/debian-11" - size = 20 + image = var.image + size = var.chroma_instance_volume_size #size in GB } } + attached_disk { + source = google_compute_disk.chroma.id + device_name = var.chroma_data_volume_device_name + mode = "READ_WRITE" + } + network_interface { network = "default" @@ -21,30 +39,91 @@ resource "google_compute_instance" "chroma1" { } } - metadata_startup_script = templatefile("${path.module}/startup.sh", { chroma_release = var.chroma_release }) + metadata = { + ssh-keys = "${var.vm_user}:${file(var.ssh_public_key)}" + } + + metadata_startup_script = templatefile("${path.module}/startup.sh", { + chroma_release = var.chroma_release, + enable_auth = var.enable_auth, + auth_type = var.auth_type, + basic_auth_credentials = "${local.basic_auth_credentials.username}:${local.basic_auth_credentials.password}", + token_auth_credentials = random_password.chroma_token.result, + }) + + provisioner "remote-exec" { + inline = [ + "export VOLUME_ID=${var.chroma_data_volume_device_name} && sudo mkfs -t ext4 /dev/$(lsblk -o +SERIAL | grep $VOLUME_ID | awk '{print $1}')", + "sudo mkdir /chroma-data", + "export VOLUME_ID=${var.chroma_data_volume_device_name} && sudo mount /dev/$(lsblk -o +SERIAL | grep $VOLUME_ID | awk '{print $1}') /chroma-data" + ] + + connection { + host = google_compute_instance.chroma.network_interface[0].access_config[0].nat_ip + type = "ssh" + user = var.vm_user + private_key = file(var.ssh_private_key) + } + } } + +resource "google_compute_disk" "chroma" { + project = var.project_id + name = "chroma-data" + type = var.disk_type + zone = var.zone + labels = var.labels + size = var.chroma_data_volume_size #size in GB + + lifecycle { + prevent_destroy = false #WARNING: You need to configure this manually as the provider does not support it yet + } +} + +#resource "google_compute_attached_disk" "vm_attached_disk" { +# disk = google_compute_disk.chroma.id +# instance = google_compute_instance.chroma.self_link +# +#} + + + resource "google_compute_firewall" "default" { project = var.project_id name = "chroma-firewall" network = "default" allow { - protocol = "icmp" + protocol = "icmp" #allow ping } - allow { - protocol = "tcp" - ports = ["8000"] + dynamic "allow" { + for_each = var.public_access ? [1] : [] + content { + protocol = "tcp" + ports = [var.chroma_port] + } } - source_ranges = ["0.0.0.0/0"] + source_ranges = var.source_ranges - target_tags = ["chroma"] + target_tags = local.tags } output "instance_public_ip" { description = "The public IP address of the instance." 
- value = google_compute_instance.chroma1.network_interface[0].access_config[0].nat_ip -} \ No newline at end of file + value = google_compute_instance.chroma.network_interface[0].access_config[0].nat_ip +} + +output "chroma_auth_token" { + value = random_password.chroma_token.result + sensitive = true +} + + +output "chroma_auth_basic" { + value = "${local.basic_auth_credentials.username}:${local.basic_auth_credentials.password}" + sensitive = true +} diff --git a/examples/deployments/google-cloud-compute/main.tf b/examples/deployments/google-cloud-compute/main.tf index a73c3776542..e69de29bb2d 100644 --- a/examples/deployments/google-cloud-compute/main.tf +++ b/examples/deployments/google-cloud-compute/main.tf @@ -1,8 +0,0 @@ -terraform { - required_providers { - google = { - source = "hashicorp/google" - version = "~> 4.47.0" - } - } -} diff --git a/examples/deployments/google-cloud-compute/startup.sh b/examples/deployments/google-cloud-compute/startup.sh index d0140670cb4..1d93e46c3e2 100644 --- a/examples/deployments/google-cloud-compute/startup.sh +++ b/examples/deployments/google-cloud-compute/startup.sh @@ -1,7 +1,12 @@ #! /bin/bash -cd ~ +# Note: This is run as root +cd ~ +export enable_auth="${enable_auth}" +export basic_auth_credentials="${basic_auth_credentials}" +export auth_type="${auth_type}" +export token_auth_credentials="${token_auth_credentials}" apt-get update -y apt-get install -y ca-certificates curl gnupg lsb-release mkdir -m 0755 -p /etc/apt/keyrings @@ -13,10 +18,36 @@ apt-get update -y chmod a+r /etc/apt/keyrings/docker.gpg apt-get update -y apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin git - -git clone https://github.com/chroma-core/chroma.git -cd chroma +usermod -aG docker debian +git clone https://github.com/chroma-core/chroma.git && cd chroma git fetch --tags git checkout tags/${chroma_release} +if [ "$${enable_auth}" = "true" ] && [ "$${auth_type}" = "basic" ] && [ ! -z "$${basic_auth_credentials}" ]; then + username=$(echo $basic_auth_credentials | cut -d: -f1) + password=$(echo $basic_auth_credentials | cut -d: -f2) + docker run --rm --entrypoint htpasswd httpd:2 -Bbn $username $password > server.htpasswd + cat < .env +CHROMA_SERVER_AUTH_CREDENTIALS_FILE="/chroma/server.htpasswd" +CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER='chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider' +CHROMA_SERVER_AUTH_PROVIDER='chromadb.auth.basic.BasicAuthServerProvider' +EOF +fi + +if [ "$${enable_auth}" = "true" ] && [ "$${auth_type}" = "token" ] && [ ! 
-z "$${token_auth_credentials}" ]; then + cat < .env +CHROMA_SERVER_AUTH_CREDENTIALS="$${token_auth_credentials}" \ +CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER='chromadb.auth.token.TokenConfigServerAuthCredentialsProvider' +CHROMA_SERVER_AUTH_PROVIDER='chromadb.auth.token.TokenAuthServerProvider' +EOF +fi + +cat < docker-compose.override.yaml +version: '3.8' +services: + server: + volumes: + - /chroma-data:/chroma/chroma +EOF + COMPOSE_PROJECT_NAME=chroma docker compose up -d --build diff --git a/examples/deployments/google-cloud-compute/variables.tf b/examples/deployments/google-cloud-compute/variables.tf index a3bca24654b..0147ce49aa4 100644 --- a/examples/deployments/google-cloud-compute/variables.tf +++ b/examples/deployments/google-cloud-compute/variables.tf @@ -1,10 +1,11 @@ variable "project_id" { - type = string + type = string + description = "The project id to deploy to" } variable "chroma_release" { description = "The chroma release to deploy" type = string - default = "0.4.5" + default = "0.4.9" } variable "zone" { @@ -12,7 +13,130 @@ variable "zone" { default = "us-central1-a" } +variable "image" { + default = "debian-cloud/debian-11" + description = "The image to use for the instance" + type = string +} + +variable "vm_user" { + default = "debian" + description = "The user to use for connecting to the instance. This is usually the default image user" + type = string +} + variable "machine_type" { type = string default = "e2-small" } + +variable "public_access" { + description = "Enable public ingress on port 8000" + type = bool + default = true // or true depending on your needs +} + +variable "enable_auth" { + description = "Enable authentication" + type = bool + default = true // or false depending on your needs +} + +variable "auth_type" { + description = "Authentication type" + type = string + default = "token" // or token depending on your needs + validation { + condition = contains(["basic", "token"], var.auth_type) + error_message = "The auth type must be either basic or token" + } +} + +resource "random_password" "chroma_password" { + length = 16 + special = true + lower = true + upper = true +} + +resource "random_password" "chroma_token" { + length = 32 + special = false + lower = true + upper = true +} + + +locals { + basic_auth_credentials = { + username = "chroma" + password = random_password.chroma_password.result + } + token_auth_credentials = { + token = random_password.chroma_token.result + } + tags = [ + "chroma", + "release-${replace(var.chroma_release, ".", "")}", + ] +} + +variable "ssh_public_key" { + description = "SSH Public Key" + type = string + default = "./chroma-aws.pub" +} +variable "ssh_private_key" { + description = "SSH Private Key" + type = string + default = "./chroma-aws" +} + +variable "chroma_instance_volume_size" { + description = "The size of the instance volume - the root volume" + type = number + default = 30 +} + +variable "chroma_data_volume_size" { + description = "Volume Size of the attached data volume where your chroma data is stored" + type = number + default = 20 +} + +variable "chroma_data_volume_device_name" { + default = "chroma-disk-0" + description = "The device name of the chroma data volume" + type = string +} + +variable "prevent_chroma_data_volume_delete" { + description = "Prevent the chroma data volume from being deleted when the instance is terminated" + type = bool + default = false +} + +variable "disk_type" { + default = "pd-ssd" + description = "The type of disk to use for the instance. 
Can be either pd-standard or pd-ssd" +} + +variable "labels" { + default = { + environment = "dev" + } + description = "Labels to apply to all resources in this example" + type = map(string) +} + +variable "chroma_port" { + default = "8000" + description = "The port that chroma listens on" + type = string +} + +variable "source_ranges" { + default = ["0.0.0.0/0"] + type = list(string) + description = "List of CIDR ranges to allow through the firewall" +} From 9db68045e6b57fdaa418ad723f72a3121be297ac Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Mon, 11 Sep 2023 09:49:36 -0700 Subject: [PATCH 06/39] [CHORE] Bump HNSWlib to latest version that has precompiled binaries (#1109) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Bump HNSWlib to latest version that has precompiled binaries. Use alpha release for CI tests before releasing ## Test plan Existing tests should over functionality. Build compatibility of the binaries was manually verified. - [x] Tests pass locally with `pytest` for python, `yarn test` for js ## Documentation Changes We should add how to force recompiling with AVX to the docs. --- pyproject.toml | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3d62663ca86..926048d38ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ dependencies = [ 'requests >= 2.28', 'pydantic>=1.9,<2.0', - 'chroma-hnswlib==0.7.2', + 'chroma-hnswlib==0.7.3', 'fastapi>=0.95.2, <0.100.0', 'uvicorn[standard] >= 0.18.3', 'numpy == 1.21.6; python_version < "3.8"', diff --git a/requirements.txt b/requirements.txt index abf333a95c0..e94f7caf91f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ bcrypt==4.0.1 -chroma-hnswlib==0.7.2 +chroma-hnswlib==0.7.3 fastapi>=0.95.2, <0.100.0 graphlib_backport==1.0.3; python_version < '3.9' importlib-resources From 8e967304d6c1a18d88b63c1fdb3e7a638a678f01 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Mon, 11 Sep 2023 11:09:02 -0700 Subject: [PATCH 07/39] [RELEASE] 0.4.10 (#1132) Release 0.4.10 --- chromadb/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 62cb174f7b0..804ff93404c 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -44,7 +44,7 @@ __settings = Settings() -__version__ = "0.4.9" +__version__ = "0.4.10" # Workaround to deal with Colab's old sqlite3 version try: From 7d412aef8c8b8e7623e8d394763ced4373e24898 Mon Sep 17 00:00:00 2001 From: Jeff Huber Date: Mon, 11 Sep 2023 20:49:25 -0700 Subject: [PATCH 08/39] [ENH] initial CLI (#1032) This proposes an initial CLI The CLI is installed when you installed `pip install chromadb`. You then call the CLI with `chroma run --path --port ` where path and port are optional. This also adds `chroma help` and `chroma docs` as convenience links - but I'm open to removing those. To make this easy - I added `typer` (by the author of FastAPI). I'm not sure this is the tool that we want to commit to for a fuller featured CLI, but given the extremely minimal footprint of this - I don't think it's a one way door. 
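For context, the end-to-end flow for a user would look roughly like this (a sketch: the defaults for `--path` and `--port` come from the CLI options added here, and the heartbeat call is just one way to sanity-check that the server is up):

```bash
# Install chromadb; the `chroma` entry point ships with the package
pip install chromadb

# Start a persistent server (both options are optional; defaults are ./chroma_data and 8000)
chroma run --path ./chroma_data --port 8000

# From another shell, confirm the server is responding
curl http://localhost:8000/api/v1/heartbeat
```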
Screenshot 2023-08-23 at 4 59 54 PM *** #### TODO - [x] test in fresh env - i think i need to add `typer` as a req - [ ] consider expanding the test to make sure the service is actually running - [x] hide the test option from the typer UI - [x] linking to a getting started guide could be interesting at the top of the logs --- .github/workflows/chroma-integration-test.yml | 3 +- DEVELOP.md | 2 +- README.md | 2 +- chromadb/cli/__init__.py | 0 chromadb/cli/cli.py | 87 +++++++++++++++++++ chromadb/log_config.yml | 37 ++++++++ chromadb/test/test_cli.py | 21 +++++ clients/js/package.json | 5 +- docker-compose.test-auth.yml | 2 +- docker-compose.test.yml | 2 +- docker-compose.yml | 2 +- examples/use_with/cohere/cohere_js.js | 20 ++--- log_config.yml | 23 ----- pyproject.toml | 9 +- requirements.txt | 1 + 15 files changed, 175 insertions(+), 41 deletions(-) create mode 100644 chromadb/cli/__init__.py create mode 100644 chromadb/cli/cli.py create mode 100644 chromadb/log_config.yml create mode 100644 chromadb/test/test_cli.py delete mode 100644 log_config.yml diff --git a/.github/workflows/chroma-integration-test.yml b/.github/workflows/chroma-integration-test.yml index e4a2b7517f5..91628d6f545 100644 --- a/.github/workflows/chroma-integration-test.yml +++ b/.github/workflows/chroma-integration-test.yml @@ -16,8 +16,9 @@ jobs: matrix: python: ['3.7'] platform: [ubuntu-latest, windows-latest] - testfile: ["--ignore-glob 'chromadb/test/property/*'", + testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore='chromadb/test/test_cli.py'", "chromadb/test/property/test_add.py", + "chromadb/test/test_cli.py", "chromadb/test/property/test_collections.py", "chromadb/test/property/test_cross_version_persist.py", "chromadb/test/property/test_embeddings.py", diff --git a/DEVELOP.md b/DEVELOP.md index 29f36abb1ff..f034e07bed3 100644 --- a/DEVELOP.md +++ b/DEVELOP.md @@ -43,7 +43,7 @@ print(api.heartbeat()) 3. With a persistent backend and a small frontend client -Run `docker-compose up -d --build` +Run `chroma run --path /chroma_db_path` ```python import chromadb api = chromadb.HttpClient(host="localhost", port="8000") diff --git a/README.md b/README.md index 7354f7ed991..25db53b73d8 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ ```bash pip install chromadb # python client # for javascript, npm install chromadb! 
-# for client-server mode, docker-compose up -d --build +# for client-server mode, chroma run --path /chroma_db_path ``` The core API is only 4 functions (run our [💡 Google Colab](https://colab.research.google.com/drive/1QEzFyqnoFxq7LUGyP1vzR4iLt9PpCDXv?usp=sharing) or [Replit template](https://replit.com/@swyx/BasicChromaStarter?v=1)): diff --git a/chromadb/cli/__init__.py b/chromadb/cli/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/chromadb/cli/cli.py b/chromadb/cli/cli.py new file mode 100644 index 00000000000..d60a80b8a48 --- /dev/null +++ b/chromadb/cli/cli.py @@ -0,0 +1,87 @@ +import typer +import uvicorn +import os +import webbrowser + +app = typer.Typer() + +_logo = """ + \033[38;5;069m((((((((( \033[38;5;203m(((((\033[38;5;220m#### + \033[38;5;069m(((((((((((((\033[38;5;203m(((((((((\033[38;5;220m######### + \033[38;5;069m(((((((((((((\033[38;5;203m(((((((((((\033[38;5;220m########### + \033[38;5;069m((((((((((((((\033[38;5;203m((((((((((((\033[38;5;220m############ + \033[38;5;069m(((((((((((((\033[38;5;203m((((((((((((((\033[38;5;220m############# + \033[38;5;069m(((((((((((((\033[38;5;203m((((((((((((((\033[38;5;220m############# + \033[38;5;069m((((((((((((\033[38;5;203m(((((((((((((\033[38;5;220m############## + \033[38;5;069m((((((((((((\033[38;5;203m((((((((((((\033[38;5;220m############## + \033[38;5;069m((((((((((\033[38;5;203m(((((((((((\033[38;5;220m############# + \033[38;5;069m((((((((\033[38;5;203m((((((((\033[38;5;220m############## + \033[38;5;069m(((((\033[38;5;203m(((( \033[38;5;220m#########\033[0m + + """ + + +@app.command() # type: ignore +def run( + path: str = typer.Option( + "./chroma_data", help="The path to the file or directory." + ), + port: int = typer.Option(8000, help="The port to run the server on."), + test: bool = typer.Option(False, help="Test mode.", show_envvar=False, hidden=True), +) -> None: + """Run a chroma server""" + + print("\033[1m") # Bold logo + print(_logo) + print("\033[1m") # Bold + print("Running Chroma") + print("\033[0m") # Reset + + typer.echo(f"\033[1mSaving data to\033[0m: \033[32m{path}\033[0m") + typer.echo( + f"\033[1mConnect to chroma at\033[0m: \033[32mhttp://localhost:{port}\033[0m" + ) + typer.echo( + "\033[1mGetting started guide\033[0m: https://docs.trychroma.com/getting-started\n\n" + ) + + # set ENV variable for PERSIST_DIRECTORY to path + os.environ["IS_PERSISTENT"] = "True" + os.environ["PERSIST_DIRECTORY"] = path + + # get the path where chromadb is installed + chromadb_path = os.path.dirname(os.path.realpath(__file__)) + + # this is the path of the CLI, we want to move up one directory + chromadb_path = os.path.dirname(chromadb_path) + + config = { + "app": "chromadb.app:app", + "host": "0.0.0.0", + "port": port, + "workers": 1, + "log_config": f"{chromadb_path}/log_config.yml", + } + + if test: + return + + uvicorn.run(**config) + + +@app.command() # type: ignore +def help() -> None: + """Opens help url in your browser""" + + webbrowser.open("https://discord.gg/MMeYNTmh3x") + + +@app.command() # type: ignore +def docs() -> None: + """Opens docs url in your browser""" + + webbrowser.open("https://docs.trychroma.com") + + +if __name__ == "__main__": + app() diff --git a/chromadb/log_config.yml b/chromadb/log_config.yml new file mode 100644 index 00000000000..80e62479917 --- /dev/null +++ b/chromadb/log_config.yml @@ -0,0 +1,37 @@ +version: 1 +disable_existing_loggers: False +formatters: + default: + "()": uvicorn.logging.DefaultFormatter + format: '%(levelprefix)s [%(asctime)s] 
%(message)s' + use_colors: null + datefmt: '%d-%m-%Y %H:%M:%S' + access: + "()": uvicorn.logging.AccessFormatter + format: '%(levelprefix)s [%(asctime)s] %(client_addr)s - "%(request_line)s" %(status_code)s' + datefmt: '%d-%m-%Y %H:%M:%S' +handlers: + default: + formatter: default + class: logging.StreamHandler + stream: ext://sys.stderr + access: + formatter: access + class: logging.StreamHandler + stream: ext://sys.stdout + console: + class: logging.StreamHandler + stream: ext://sys.stdout + formatter: default + file: + class : logging.handlers.RotatingFileHandler + filename: chroma.log + formatter: default +loggers: + root: + level: WARN + handlers: [console, file] + chromadb: + level: DEBUG + uvicorn: + level: INFO diff --git a/chromadb/test/test_cli.py b/chromadb/test/test_cli.py new file mode 100644 index 00000000000..231877341f5 --- /dev/null +++ b/chromadb/test/test_cli.py @@ -0,0 +1,21 @@ +from typer.testing import CliRunner + +from chromadb.cli.cli import app + +runner = CliRunner() + + +def test_app() -> None: + result = runner.invoke( + app, + [ + "run", + "--path", + "chroma_test_data", + "--port", + "8001", + "--test", + ], + ) + assert "chroma_test_data" in result.stdout + assert "8001" in result.stdout diff --git a/clients/js/package.json b/clients/js/package.json index 53dbaf0e9c6..13014226bed 100644 --- a/clients/js/package.json +++ b/clients/js/package.json @@ -29,7 +29,9 @@ "dist" ], "scripts": { - "test": "run-s db:clean db:run test:runfull db:clean db:run-auth test:runfull-authonly db:clean", + "test": "run-s db:clean db:cleanauth db:run test:runfull db:clean db:run-auth test:runfull-authonly db:cleanauth", + "testnoauth": "run-s db:clean db:run test:runfull db:clean", + "testauth": "run-s db:cleanauth db:run-auth test:runfull-authonly db:cleanauth", "test:set-port": "cross-env URL=localhost:8001", "test:run": "jest --runInBand --testPathIgnorePatterns=test/auth.basic.test.ts", "test:run-auth": "jest --runInBand --testPathPattern=test/auth.basic.test.ts", @@ -37,6 +39,7 @@ "test:runfull-authonly": "PORT=8001 jest --runInBand --testPathPattern=test/auth.basic.test.ts", "test:update": "run-s db:clean db:run && jest --runInBand --updateSnapshot && run-s db:clean", "db:clean": "cd ../.. && CHROMA_PORT=8001 docker-compose -f docker-compose.test.yml down --volumes", + "db:cleanauth": "cd ../.. && CHROMA_PORT=8001 docker-compose -f docker-compose.test-auth.yml down --volumes", "db:run": "cd ../.. && CHROMA_PORT=8001 docker-compose -f docker-compose.test.yml up --detach && sleep 5", "db:run-auth": "cd ../.. 
&& CHROMA_PORT=8001 docker-compose -f docker-compose.test-auth.yml up --detach && sleep 5", "clean": "rimraf dist", diff --git a/docker-compose.test-auth.yml b/docker-compose.test-auth.yml index 945739782a1..921f0749f06 100644 --- a/docker-compose.test-auth.yml +++ b/docker-compose.test-auth.yml @@ -12,7 +12,7 @@ services: volumes: - ./:/chroma - test_index_data:/index_data - command: uvicorn chromadb.app:app --workers 1 --host 0.0.0.0 --port 8000 --log-config log_config.yml + command: uvicorn chromadb.app:app --workers 1 --host 0.0.0.0 --port 8000 --log-config chromadb/log_config.yml environment: - ANONYMIZED_TELEMETRY=False - ALLOW_RESET=True diff --git a/docker-compose.test.yml b/docker-compose.test.yml index eb65303ebd6..51404725730 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -12,7 +12,7 @@ services: volumes: - ./:/chroma - test_index_data:/index_data - command: uvicorn chromadb.app:app --workers 1 --host 0.0.0.0 --port 8000 --log-config log_config.yml + command: uvicorn chromadb.app:app --workers 1 --host 0.0.0.0 --port 8000 --log-config chromadb/log_config.yml environment: - ANONYMIZED_TELEMETRY=False - ALLOW_RESET=True diff --git a/docker-compose.yml b/docker-compose.yml index 2119eba7e00..93581dd23c7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -14,7 +14,7 @@ services: - ./:/chroma # Be aware that indexed data are located in "/chroma/chroma/" # Default configuration for persist_directory in chromadb/config.py - command: uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 8000 --log-config log_config.yml + command: uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 8000 --log-config chromadb/log_config.yml environment: - IS_PERSISTENT=TRUE - CHROMA_SERVER_AUTH_PROVIDER=${CHROMA_SERVER_AUTH_PROVIDER} diff --git a/examples/use_with/cohere/cohere_js.js b/examples/use_with/cohere/cohere_js.js index 585e383c371..7fe86691602 100644 --- a/examples/use_with/cohere/cohere_js.js +++ b/examples/use_with/cohere/cohere_js.js @@ -7,7 +7,7 @@ First run Chroma ``` git clone git@github.com:chroma-core/chroma.git cd chroma -docker-compose up -d --build +chroma run --path /chroma_db_path ``` Then install chroma and cohere @@ -61,20 +61,20 @@ const main = async () => { }); // # 나는 오렌지를 좋아한다 is "I like oranges" in Korean - multilingual_texts = [ 'Hello from Cohere!', 'مرحبًا من كوهير!', - 'Hallo von Cohere!', 'Bonjour de Cohere!', - '¡Hola desde Cohere!', 'Olá do Cohere!', - 'Ciao da Cohere!', '您好,来自 Cohere!', - 'कोहेरे से नमस्ते!', '나는 오렌지를 좋아한다' ] + multilingual_texts = ['Hello from Cohere!', 'مرحبًا من كوهير!', + 'Hallo von Cohere!', 'Bonjour de Cohere!', + '¡Hola desde Cohere!', 'Olá do Cohere!', + 'Ciao da Cohere!', '您好,来自 Cohere!', + 'कोहेरे से नमस्ते!', '나는 오렌지를 좋아한다'] let ids = Array.from({ length: multilingual_texts.length }, (_, i) => String(i)); await collection.add({ - ids:ids, - documents:multilingual_texts -}) + ids: ids, + documents: multilingual_texts + }) - console.log(await collection.query({queryTexts:["citrus"], nResults:1})) + console.log(await collection.query({ queryTexts: ["citrus"], nResults: 1 })) } diff --git a/log_config.yml b/log_config.yml deleted file mode 100644 index e8da3c2c7de..00000000000 --- a/log_config.yml +++ /dev/null @@ -1,23 +0,0 @@ -version: 1 -disable_existing_loggers: False -formatters: - default: - format: '%(asctime)s %(levelname)-8s %(name)-15s %(message)s' - datefmt: '%Y-%m-%d %H:%M:%S' -handlers: - console: - class: logging.StreamHandler - stream: ext://sys.stdout - formatter: 
default - file: - class : logging.handlers.RotatingFileHandler - filename: chroma.log - formatter: default -loggers: - root: - level: WARN - handlers: [console, file] - chromadb: - level: DEBUG - uvicorn: - level: INFO diff --git a/pyproject.toml b/pyproject.toml index 926048d38ff..8fc60673607 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,8 @@ dependencies = [ 'overrides >= 7.3.1', 'importlib-resources', 'graphlib_backport >= 1.0.3; python_version < "3.9"', - 'bcrypt >= 4.0.1' + 'bcrypt >= 4.0.1', + 'typer >= 0.9.0', ] [tool.black] @@ -43,6 +44,9 @@ target-version = ['py36', 'py37', 'py38', 'py39', 'py310'] [tool.pytest.ini_options] pythonpath = ["."] +[project.scripts] +chroma = "chromadb.cli.cli:app" + [project.urls] "Homepage" = "https://github.com/chroma-core/chroma" "Bug Tracker" = "https://github.com/chroma-core/chroma/issues" @@ -56,3 +60,6 @@ local_scheme="no-local-version" [tool.setuptools] packages = ["chromadb"] + +[tool.setuptools.package-data] +chromadb = ["*.yml"] diff --git a/requirements.txt b/requirements.txt index e94f7caf91f..9a9fdcc295c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,5 +14,6 @@ pypika==0.48.9 requests==2.28.1 tokenizers==0.13.2 tqdm==4.65.0 +typer>=0.9.0 typing_extensions==4.5.0 uvicorn[standard]==0.18.3 From 831c027f5cfb27cf70d846a49315070ff26f3a3c Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 12 Sep 2023 06:49:55 +0300 Subject: [PATCH 09/39] [SEC]: Bandit Scan (#1113) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Added bandit scanning for all pushes to repo ## Test plan *How are these changes tested?* Manual testing of the workflow ## Documentation Changes N/A - unless we want to start a separate security section in the main docs repo. --------- Co-authored-by: Hammad Bashir --- .github/actions/bandit-scan/Dockerfile | 7 ++++++ .github/actions/bandit-scan/action.yaml | 26 +++++++++++++++++++++ .github/actions/bandit-scan/entrypoint.sh | 13 +++++++++++ .github/workflows/python-vuln.yaml | 28 +++++++++++++++++++++++ bandit.yaml | 4 ++++ 5 files changed, 78 insertions(+) create mode 100644 .github/actions/bandit-scan/Dockerfile create mode 100644 .github/actions/bandit-scan/action.yaml create mode 100755 .github/actions/bandit-scan/entrypoint.sh create mode 100644 .github/workflows/python-vuln.yaml create mode 100644 bandit.yaml diff --git a/.github/actions/bandit-scan/Dockerfile b/.github/actions/bandit-scan/Dockerfile new file mode 100644 index 00000000000..943f04fc8f3 --- /dev/null +++ b/.github/actions/bandit-scan/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.10-alpine AS base-action + +RUN pip3 install -U setuptools pip bandit + +COPY entrypoint.sh /entrypoint.sh +RUN chmod +x /entrypoint.sh +ENTRYPOINT ["sh","/entrypoint.sh"] diff --git a/.github/actions/bandit-scan/action.yaml b/.github/actions/bandit-scan/action.yaml new file mode 100644 index 00000000000..e0735450f57 --- /dev/null +++ b/.github/actions/bandit-scan/action.yaml @@ -0,0 +1,26 @@ +name: 'Bandit Scan' +description: 'This action performs a security vulnerability scan of python code using bandit library.' +inputs: + bandit-config: + description: 'Bandit configuration file' + required: false + input-dir: + description: 'Directory to scan' + required: false + default: '.' + format: + description: 'Output format (txt, csv, json, xml, yaml). Default: json' + required: false + default: 'json' + output-file: + description: "The report file to produce. 
Make sure to align your format with the file extension to avoid confusion." + required: false + default: "bandit-scan.json" +runs: + using: 'docker' + image: 'Dockerfile' + args: + - ${{ inputs.format }} + - ${{ inputs.bandit-config }} + - ${{ inputs.input-dir }} + - ${{ inputs.output-file }} diff --git a/.github/actions/bandit-scan/entrypoint.sh b/.github/actions/bandit-scan/entrypoint.sh new file mode 100755 index 00000000000..f52daddd781 --- /dev/null +++ b/.github/actions/bandit-scan/entrypoint.sh @@ -0,0 +1,13 @@ +#!/bin/bash +CFG="-c $2" +if [ -z "$1" ]; then + echo "No path to scan provided" + exit 1 +fi + +if [ -z "$2" ]; then + CFG = "" +fi + +bandit -f "$1" ${CFG} -r "$3" -o "$4" +exit 0 #we want to ignore the exit code of bandit (for now) diff --git a/.github/workflows/python-vuln.yaml b/.github/workflows/python-vuln.yaml new file mode 100644 index 00000000000..8e6c33a255c --- /dev/null +++ b/.github/workflows/python-vuln.yaml @@ -0,0 +1,28 @@ +name: Python Vulnerability Scan +on: + push: + branches: + - '*' + - '*/**' + paths: + - chromadb/** + - clients/python/** + workflow_dispatch: +jobs: + bandit-scan: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - uses: ./.github/actions/bandit-scan/ + with: + input-dir: '.' + format: 'json' + bandit-config: 'bandit.yaml' + output-file: 'bandit-report.json' + - name: Upload Bandit Report + uses: actions/upload-artifact@v3 + with: + name: bandit-artifact + path: | + bandit-report.json diff --git a/bandit.yaml b/bandit.yaml new file mode 100644 index 00000000000..9a93633ea12 --- /dev/null +++ b/bandit.yaml @@ -0,0 +1,4 @@ +# FILE: bandit.yaml +exclude_dirs: [ 'chromadb/test', 'bin', 'build', 'build', '.git', '.venv', 'venv', 'env','.github','examples','clients/js','.vscode' ] +tests: [ ] +skips: [ ] From 6681df91bfb96191c6bbea6dc7be8d48994c66d3 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Dash <47926185+sunilkumardash9@users.noreply.github.com> Date: Thu, 14 Sep 2023 10:43:18 +0530 Subject: [PATCH 10/39] Enable manual workflow trigger (#1036) ## Description of changes *Summarize the changes made by this PR.* - Added a workflow_dispatch to manually trigger test workflows - will be good for development experience --------- Signed-off-by: sunilkumardash9 --- .github/workflows/chroma-client-integration-test.yml | 3 ++- .github/workflows/chroma-integration-test.yml | 1 + .github/workflows/chroma-test.yml | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/chroma-client-integration-test.yml b/.github/workflows/chroma-client-integration-test.yml index a4e70d13baf..25788090ef2 100644 --- a/.github/workflows/chroma-client-integration-test.yml +++ b/.github/workflows/chroma-client-integration-test.yml @@ -8,7 +8,8 @@ on: branches: - main - '**' - + workflow_dispatch: + jobs: test: timeout-minutes: 90 diff --git a/.github/workflows/chroma-integration-test.yml b/.github/workflows/chroma-integration-test.yml index 91628d6f545..963a7b6ed63 100644 --- a/.github/workflows/chroma-integration-test.yml +++ b/.github/workflows/chroma-integration-test.yml @@ -9,6 +9,7 @@ on: branches: - main - '**' + workflow_dispatch: jobs: test: diff --git a/.github/workflows/chroma-test.yml b/.github/workflows/chroma-test.yml index e0d44c2d647..90ff2b66940 100644 --- a/.github/workflows/chroma-test.yml +++ b/.github/workflows/chroma-test.yml @@ -9,6 +9,7 @@ on: branches: - main - '**' + workflow_dispatch: jobs: test: From d090ca6f6fb646972c3d3b6325a738a07890de51 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: 
Fri, 15 Sep 2023 10:13:30 -0700 Subject: [PATCH 11/39] Fix broken peer OpenAI dep dependency range (#1142) https://semver.npmjs.com/ Screenshot 2023-09-13 at 3 51 28 PM Screenshot 2023-09-13 at 3 51 16 PM npm strictly checks peer dep ranges, which means the `npm install` of anything with a peer dep on Chroma was affected by this. --- clients/js/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/js/package.json b/clients/js/package.json index 13014226bed..cd898e5d9c3 100644 --- a/clients/js/package.json +++ b/clients/js/package.json @@ -62,7 +62,7 @@ "@visheratin/web-ai-node": "^1.0.0", "@xenova/transformers": "^2.0.0", "cohere-ai": "^6.0.0", - "openai": "^3.0.0 | ^4.0.0" + "openai": "^3.0.0 || ^4.0.0" }, "peerDependenciesMeta": { "@visheratin/web-ai": { From a0a3c35217f2252870331dfd7b91590131fdad10 Mon Sep 17 00:00:00 2001 From: Jeff Huber Date: Fri, 15 Sep 2023 13:32:41 -0700 Subject: [PATCH 12/39] bump JS to 1.5.9 (#1145) Bump to release to 1.5.9 to release https://github.com/chroma-core/chroma/pull/1142 --- clients/js/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/js/package.json b/clients/js/package.json index cd898e5d9c3..76077d77e28 100644 --- a/clients/js/package.json +++ b/clients/js/package.json @@ -1,6 +1,6 @@ { "name": "chromadb", - "version": "1.5.8", + "version": "1.5.9", "description": "A JavaScript interface for chroma", "keywords": [], "author": "", From dac67e7ca82589f1a0cb484ebc5668b379339951 Mon Sep 17 00:00:00 2001 From: Leonid Ganeline Date: Mon, 18 Sep 2023 10:45:34 -0700 Subject: [PATCH 13/39] simplified ut-s (#1071) ## Description of changes - Improvements - simplified ut-s - cleaned up a typing import ## Test plan - [+] Tests passed successfully locally with `pytest` for python, `yarn test` for js ## Documentation Changes N/A --- chromadb/auth/providers.py | 2 +- chromadb/test/property/test_embeddings.py | 66 ++++++++++++----------- 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/chromadb/auth/providers.py b/chromadb/auth/providers.py index 3123c52a54c..a3bb23616e2 100644 --- a/chromadb/auth/providers.py +++ b/chromadb/auth/providers.py @@ -1,6 +1,6 @@ import importlib import logging -from typing import cast, Dict, TypeVar, Any, Optional +from typing import cast, Dict, TypeVar, Any import requests from overrides import override diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py index 5b11f378b8d..0e402cca1a8 100644 --- a/chromadb/test/property/test_embeddings.py +++ b/chromadb/test/property/test_embeddings.py @@ -365,39 +365,41 @@ def test_escape_chars_in_ids(api: API) -> None: assert coll.count() == 0 -def test_delete_empty_fails(api: API): +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"ids": []}, + {"where": {}}, + {"where_document": {}}, + {"where_document": {}, "where": {}}, + ], +) +def test_delete_empty_fails(api: API, kwargs: dict): api.reset() coll = api.create_collection(name="foo") - - error_valid = ( - lambda e: "You must provide either ids, where, or where_document to delete." 
- in e - ) - with pytest.raises(Exception) as e: - coll.delete() - assert error_valid(str(e)) - - with pytest.raises(Exception): - coll.delete(ids=[]) - assert error_valid(str(e)) - - with pytest.raises(Exception): - coll.delete(where={}) - assert error_valid(str(e)) - - with pytest.raises(Exception): - coll.delete(where_document={}) - assert error_valid(str(e)) - - with pytest.raises(Exception): - coll.delete(where_document={}, where={}) - assert error_valid(str(e)) - + coll.delete(**kwargs) + assert "You must provide either ids, where, or where_document to delete." in str(e) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"ids": ["foo"]}, + {"where": {"foo": "bar"}}, + {"where_document": {"$contains": "bar"}}, + {"ids": ["foo"], "where": {"foo": "bar"}}, + {"ids": ["foo"], "where_document": {"$contains": "bar"}}, + { + "ids": ["foo"], + "where": {"foo": "bar"}, + "where_document": {"$contains": "bar"}, + }, + ], +) +def test_delete_success(api: API, kwargs: dict): + api.reset() + coll = api.create_collection(name="foo") # Should not raise - coll.delete(where_document={"$contains": "bar"}) - coll.delete(where={"foo": "bar"}) - coll.delete(ids=["foo"]) - coll.delete(ids=["foo"], where={"foo": "bar"}) - coll.delete(ids=["foo"], where_document={"$contains": "bar"}) - coll.delete(ids=["foo"], where_document={"$contains": "bar"}, where={"foo": "bar"}) + coll.delete(**kwargs) From 2b434b826642a4095e43367abd0a90b6c2e0b0e5 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Mon, 18 Sep 2023 21:30:25 +0300 Subject: [PATCH 14/39] [ENH]: JS Client Static Token support (#1114) Refs: #1083 ## Description of changes *Summarize the changes made by this PR.* - New functionality - JS Client now supports Authorization, and X-Chroma-Token auths supported - Tests and integration tests updated ## Test plan *How are these changes tested?* - [x] Tests pass locally `yarn test` for js ## Documentation Changes TBD --- bin/integration-test | 59 +++++++++++----- clients/js/package.json | 17 +++-- clients/js/src/auth.ts | 97 ++++++++++++++++++++++++++- clients/js/test/auth.basic.test.ts | 10 +-- clients/js/test/auth.token.test.ts | 59 ++++++++++++++++ clients/js/test/initClientWithAuth.ts | 13 +++- docker-compose.test-auth.yml | 8 ++- 7 files changed, 228 insertions(+), 35 deletions(-) create mode 100644 clients/js/test/auth.token.test.ts diff --git a/bin/integration-test b/bin/integration-test index e91e45e93c5..54b4e387e08 100755 --- a/bin/integration-test +++ b/bin/integration-test @@ -9,15 +9,38 @@ function cleanup { rm server.htpasswd .chroma_env } -function setup_basic_auth { - # Generate htpasswd file - docker run --rm --entrypoint htpasswd httpd:2 -Bbn admin admin > server.htpasswd - # Create .chroma_env file - cat < .chroma_env +function setup_auth { + local auth_type="$1" + case "$auth_type" in + basic) + docker run --rm --entrypoint htpasswd httpd:2 -Bbn admin admin > server.htpasswd + cat < .chroma_env CHROMA_SERVER_AUTH_CREDENTIALS_FILE="/chroma/server.htpasswd" -CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER='chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider' -CHROMA_SERVER_AUTH_PROVIDER='chromadb.auth.basic.BasicAuthServerProvider' +CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider" +CHROMA_SERVER_AUTH_PROVIDER="chromadb.auth.basic.BasicAuthServerProvider" EOF + ;; + token) + cat < .chroma_env +CHROMA_SERVER_AUTH_CREDENTIALS="test-token" +CHROMA_SERVER_AUTH_TOKEN_TRANSPORT_HEADER="AUTHORIZATION" 
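+# With AUTHORIZATION, the server reads the token from the "Authorization: Bearer <token>" header;
+# the xtoken case below switches the transport header to X-Chroma-Token instead.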
+CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER="chromadb.auth.token.TokenConfigServerAuthCredentialsProvider" +CHROMA_SERVER_AUTH_PROVIDER="chromadb.auth.token.TokenAuthServerProvider" +EOF + ;; + xtoken) + cat < .chroma_env +CHROMA_SERVER_AUTH_CREDENTIALS="test-token" +CHROMA_SERVER_AUTH_TOKEN_TRANSPORT_HEADER="X_CHROMA_TOKEN" +CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER="chromadb.auth.token.TokenConfigServerAuthCredentialsProvider" +CHROMA_SERVER_AUTH_PROVIDER="chromadb.auth.token.TokenAuthServerProvider" +EOF + ;; + *) + echo "Unknown auth type: $auth_type" + exit 1 + ;; + esac } trap cleanup EXIT @@ -28,19 +51,21 @@ export CHROMA_INTEGRATION_TEST_ONLY=1 export CHROMA_API_IMPL=chromadb.api.fastapi.FastAPI export CHROMA_SERVER_HOST=localhost export CHROMA_SERVER_HTTP_PORT=8000 - -echo testing: python -m pytest "$@" -python -m pytest "$@" +# +#echo testing: python -m pytest "$@" +#python -m pytest "$@" cd clients/js yarn yarn test:run docker compose down cd ../.. -echo "Testing auth" -setup_basic_auth #this is specific to the auth type, later on we'll have other auth types -cd clients/js -# Start docker compose - this should be auth agnostic -docker compose --env-file ../../.chroma_env -f ../../docker-compose.test-auth.yml up --build -d -yarn test:run-auth -cd ../.. +for auth_type in basic token xtoken; do + echo "Testing $auth_type auth" + setup_auth "$auth_type" + cd clients/js + docker compose --env-file ../../.chroma_env -f ../../docker-compose.test-auth.yml up --build -d + yarn test:run-auth-"$auth_type" + cd ../.. + docker compose down +done diff --git a/clients/js/package.json b/clients/js/package.json index 76077d77e28..25d3d90b94c 100644 --- a/clients/js/package.json +++ b/clients/js/package.json @@ -33,15 +33,22 @@ "testnoauth": "run-s db:clean db:run test:runfull db:clean", "testauth": "run-s db:cleanauth db:run-auth test:runfull-authonly db:cleanauth", "test:set-port": "cross-env URL=localhost:8001", - "test:run": "jest --runInBand --testPathIgnorePatterns=test/auth.basic.test.ts", - "test:run-auth": "jest --runInBand --testPathPattern=test/auth.basic.test.ts", - "test:runfull": "PORT=8001 jest --runInBand --testPathIgnorePatterns=test/auth.basic.test.ts", - "test:runfull-authonly": "PORT=8001 jest --runInBand --testPathPattern=test/auth.basic.test.ts", + "test:run": "jest --runInBand --testPathIgnorePatterns=test/auth.*.test.ts", + "test:run-auth-basic": "jest --runInBand --testPathPattern=test/auth.basic.test.ts", + "test:run-auth-token": "jest --runInBand --testPathPattern=test/auth.token.test.ts", + "test:run-auth-xtoken": "XTOKEN_TEST=true jest --runInBand --testPathPattern=test/auth.token.test.ts", + "test:runfull": "PORT=8001 jest --runInBand --testPathIgnorePatterns=test/auth.*.test.ts", + "test:runfull-authonly": "run-s db:run-auth-basic test:runfull-authonly-basic db:clean db:run-auth-token test:runfull-authonly-token db:clean db:run-auth-xtoken test:runfull-authonly-xtoken db:clean", + "test:runfull-authonly-basic": "PORT=8001 jest --runInBand --testPathPattern=test/auth.basic.test.ts", + "test:runfull-authonly-token": "PORT=8001 jest --runInBand --testPathPattern=test/auth.token.test.ts", + "test:runfull-authonly-xtoken": "PORT=8001 XTOKEN_TEST=true jest --runInBand --testPathPattern=test/auth.token.test.ts", "test:update": "run-s db:clean db:run && jest --runInBand --updateSnapshot && run-s db:clean", "db:clean": "cd ../.. && CHROMA_PORT=8001 docker-compose -f docker-compose.test.yml down --volumes", "db:cleanauth": "cd ../.. 
&& CHROMA_PORT=8001 docker-compose -f docker-compose.test-auth.yml down --volumes", "db:run": "cd ../.. && CHROMA_PORT=8001 docker-compose -f docker-compose.test.yml up --detach && sleep 5", - "db:run-auth": "cd ../.. && CHROMA_PORT=8001 docker-compose -f docker-compose.test-auth.yml up --detach && sleep 5", + "db:run-auth-basic": "cd ../.. && docker run --rm --entrypoint htpasswd httpd:2 -Bbn admin admin > server.htpasswd && echo \"CHROMA_SERVER_AUTH_CREDENTIALS_FILE=/chroma/server.htpasswd\\nCHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER=chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider\\nCHROMA_SERVER_AUTH_PROVIDER=chromadb.auth.basic.BasicAuthServerProvider\\nCHROMA_PORT=8001\" > .chroma_env && docker-compose -f docker-compose.test-auth.yml --env-file ./.chroma_env up --detach && sleep 5", + "db:run-auth-token": "cd ../.. && echo \"CHROMA_SERVER_AUTH_CREDENTIALS=test-token\nCHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER=chromadb.auth.token.TokenConfigServerAuthCredentialsProvider\nCHROMA_SERVER_AUTH_PROVIDER=chromadb.auth.token.TokenAuthServerProvider\\nCHROMA_PORT=8001\" > .chroma_env && docker-compose -f docker-compose.test-auth.yml --env-file ./.chroma_env up --detach && sleep 5", + "db:run-auth-xtoken": "cd ../.. && echo \"CHROMA_SERVER_AUTH_TOKEN_TRANSPORT_HEADER=X_CHROMA_TOKEN\nCHROMA_SERVER_AUTH_CREDENTIALS=test-token\nCHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER=chromadb.auth.token.TokenConfigServerAuthCredentialsProvider\nCHROMA_SERVER_AUTH_PROVIDER=chromadb.auth.token.TokenAuthServerProvider\\nCHROMA_PORT=8001\" > .chroma_env && docker-compose -f docker-compose.test-auth.yml --env-file ./.chroma_env up --detach && sleep 5", "clean": "rimraf dist", "build": "run-s clean build:*", "build:main": "tsc -p tsconfig.json", diff --git a/clients/js/src/auth.ts b/clients/js/src/auth.ts index e7626988f5d..4f833f97d61 100644 --- a/clients/js/src/auth.ts +++ b/clients/js/src/auth.ts @@ -118,7 +118,10 @@ class BasicAuthClientAuthProvider implements ClientAuthProvider { * @throws {Error} If neither credentials provider or text credentials are supplied. */ - constructor(options: { textCredentials: any; credentialsProvider: ClientAuthCredentialsProvider | undefined }) { + constructor(options: { + textCredentials: any; + credentialsProvider: ClientAuthCredentialsProvider | undefined + }) { if (!options.credentialsProvider && !options.textCredentials) { throw new Error("Either credentials provider or text credentials must be supplied."); } @@ -130,6 +133,85 @@ class BasicAuthClientAuthProvider implements ClientAuthProvider { } } +class TokenAuthCredentials implements AbstractCredentials { + private readonly credentials: SecretStr; + + constructor(_creds: string) { + this.credentials = new SecretStr(_creds) + } + + getCredentials(): SecretStr { + return this.credentials; + } +} + +export class TokenCredentialsProvider implements ClientAuthCredentialsProvider { + private readonly credentials: TokenAuthCredentials; + + constructor(_creds: string | undefined) { + if (_creds === undefined && !process.env.CHROMA_CLIENT_AUTH_CREDENTIALS) throw new Error("Credentials must be supplied via environment variable (CHROMA_CLIENT_AUTH_CREDENTIALS) or passed in as configuration."); + this.credentials = new TokenAuthCredentials((_creds ?? 
process.env.CHROMA_CLIENT_AUTH_CREDENTIALS) as string); + } + + getCredentials(): TokenAuthCredentials { + return this.credentials; + } +} + +export class TokenClientAuthProvider implements ClientAuthProvider { + private readonly credentialsProvider: ClientAuthCredentialsProvider; + private readonly providerOptions: { headerType: TokenHeaderType }; + + constructor(options: { + textCredentials: any; + credentialsProvider: ClientAuthCredentialsProvider | undefined, + providerOptions?: { headerType: TokenHeaderType } + }) { + if (!options.credentialsProvider && !options.textCredentials) { + throw new Error("Either credentials provider or text credentials must be supplied."); + } + if (options.providerOptions === undefined || !options.providerOptions.hasOwnProperty("headerType")) { + this.providerOptions = {headerType: "AUTHORIZATION"}; + } else { + this.providerOptions = {headerType: options.providerOptions.headerType}; + } + this.credentialsProvider = options.credentialsProvider || new TokenCredentialsProvider(options.textCredentials); + } + + authenticate(): ClientAuthResponse { + return new TokenClientAuthResponse(this.credentialsProvider.getCredentials(), this.providerOptions.headerType); + } + +} + + +type TokenHeaderType = 'AUTHORIZATION' | 'X_CHROMA_TOKEN'; + +const TokenHeader: Record { key: string; value: string; }> = { + AUTHORIZATION: (value: string) => ({key: "Authorization", value: `Bearer ${value}`}), + X_CHROMA_TOKEN: (value: string) => ({key: "X-Chroma-Token", value: value}) +} + +class TokenClientAuthResponse implements ClientAuthResponse { + constructor(private readonly credentials: TokenAuthCredentials, private readonly headerType: TokenHeaderType = 'AUTHORIZATION') { + } + + getAuthInfo(): { key: string; value: string } { + if (this.headerType === 'AUTHORIZATION') { + return TokenHeader.AUTHORIZATION(this.credentials.getCredentials().getSecret()); + } else if (this.headerType === 'X_CHROMA_TOKEN') { + return TokenHeader.X_CHROMA_TOKEN(this.credentials.getCredentials().getSecret()); + } else { + throw new Error("Invalid header type: " + this.headerType + ". 
Valid types are: " + Object.keys(TokenHeader).join(", ")); + } + } + + getAuthInfoType(): AuthInfoType { + return AuthInfoType.HEADER; + } +} + + export class IsomorphicFetchClientAuthProtocolAdapter implements ClientAuthProtocolAdapter { authProvider: ClientAuthProvider | undefined; wrapperApi: DefaultApi | undefined; @@ -144,7 +226,17 @@ export class IsomorphicFetchClientAuthProtocolAdapter implements ClientAuthProto switch (authConfiguration.provider) { case "basic": - this.authProvider = new BasicAuthClientAuthProvider({textCredentials: authConfiguration.credentials, credentialsProvider: authConfiguration.credentialsProvider}); + this.authProvider = new BasicAuthClientAuthProvider({ + textCredentials: authConfiguration.credentials, + credentialsProvider: authConfiguration.credentialsProvider + }); + break; + case "token": + this.authProvider = new TokenClientAuthProvider({ + textCredentials: authConfiguration.credentials, + credentialsProvider: authConfiguration.credentialsProvider, + providerOptions: authConfiguration.providerOptions + }); break; default: this.authProvider = undefined; @@ -225,4 +317,5 @@ export type AuthOptions = { credentialsProvider?: ClientAuthCredentialsProvider | undefined, configProvider?: ClientAuthConfigurationProvider | undefined, credentials?: any | undefined, + providerOptions?: any | undefined } diff --git a/clients/js/test/auth.basic.test.ts b/clients/js/test/auth.basic.test.ts index a698e02c2da..6bbcf230087 100644 --- a/clients/js/test/auth.basic.test.ts +++ b/clients/js/test/auth.basic.test.ts @@ -1,6 +1,6 @@ import {expect, test} from "@jest/globals"; import {ChromaClient} from "../src/ChromaClient"; -import chroma from "./initClientWithAuth"; +import {chromaBasic} from "./initClientWithAuth"; import chromaNoAuth from "./initClient"; test("it should get the version without auth needed", async () => { @@ -22,12 +22,12 @@ test("it should raise error when non authenticated", async () => { }); test('it should list collections', async () => { - await chroma.reset() - let collections = await chroma.listCollections() + await chromaBasic.reset() + let collections = await chromaBasic.listCollections() expect(collections).toBeDefined() expect(collections).toBeInstanceOf(Array) expect(collections.length).toBe(0) - const collection = await chroma.createCollection({name: "test"}); - collections = await chroma.listCollections() + await chromaBasic.createCollection({name: "test"}); + collections = await chromaBasic.listCollections() expect(collections.length).toBe(1) }) diff --git a/clients/js/test/auth.token.test.ts b/clients/js/test/auth.token.test.ts new file mode 100644 index 00000000000..96612480ac8 --- /dev/null +++ b/clients/js/test/auth.token.test.ts @@ -0,0 +1,59 @@ +import {expect, test} from "@jest/globals"; +import {ChromaClient} from "../src/ChromaClient"; +import {chromaTokenDefault, chromaTokenBearer, chromaTokenXToken} from "./initClientWithAuth"; +import chromaNoAuth from "./initClient"; + +test("it should get the version without auth needed", async () => { + const version = await chromaNoAuth.version(); + expect(version).toBeDefined(); + expect(version).toMatch(/^[0-9]+\.[0-9]+\.[0-9]+$/); +}); + +test("it should get the heartbeat without auth needed", async () => { + const heartbeat = await chromaNoAuth.heartbeat(); + expect(heartbeat).toBeDefined(); + expect(heartbeat).toBeGreaterThan(0); +}); + +test("it should raise error when non authenticated", async () => { + await expect(chromaNoAuth.listCollections()).rejects.toMatchObject({ + status: 401 
+ }); +}); + +if (!process.env.XTOKEN_TEST) { + test('it should list collections with default token config', async () => { + await chromaTokenDefault.reset() + let collections = await chromaTokenDefault.listCollections() + expect(collections).toBeDefined() + expect(collections).toBeInstanceOf(Array) + expect(collections.length).toBe(0) + const collection = await chromaTokenDefault.createCollection({name: "test"}); + collections = await chromaTokenDefault.listCollections() + expect(collections.length).toBe(1) + }) + + test('it should list collections with explicit bearer token config', async () => { + await chromaTokenBearer.reset() + let collections = await chromaTokenBearer.listCollections() + expect(collections).toBeDefined() + expect(collections).toBeInstanceOf(Array) + expect(collections.length).toBe(0) + const collection = await chromaTokenBearer.createCollection({name: "test"}); + collections = await chromaTokenBearer.listCollections() + expect(collections.length).toBe(1) + }) +} else { + + test('it should list collections with explicit x-token token config', async () => { + await chromaTokenXToken.reset() + let collections = await chromaTokenXToken.listCollections() + expect(collections).toBeDefined() + expect(collections).toBeInstanceOf(Array) + expect(collections.length).toBe(0) + const collection = await chromaTokenXToken.createCollection({name: "test"}); + collections = await chromaTokenXToken.listCollections() + expect(collections.length).toBe(1) + }) + +} diff --git a/clients/js/test/initClientWithAuth.ts b/clients/js/test/initClientWithAuth.ts index b24c9a48d1d..4c061d089d5 100644 --- a/clients/js/test/initClientWithAuth.ts +++ b/clients/js/test/initClientWithAuth.ts @@ -2,6 +2,13 @@ import {ChromaClient} from "../src/ChromaClient"; const PORT = process.env.PORT || "8000"; const URL = "http://localhost:" + PORT; -const chroma = new ChromaClient({path: URL, auth: {provider: "basic", credentials: "admin:admin"}}); - -export default chroma; +export const chromaBasic = new ChromaClient({path: URL, auth: {provider: "basic", credentials: "admin:admin"}}); +export const chromaTokenDefault = new ChromaClient({path: URL, auth: {provider: "token", credentials: "test-token"}}); +export const chromaTokenBearer = new ChromaClient({ + path: URL, + auth: {provider: "token", credentials: "test-token", providerOptions: {headerType: "AUTHORIZATION"}} +}); +export const chromaTokenXToken = new ChromaClient({ + path: URL, + auth: {provider: "token", credentials: "test-token", providerOptions: {headerType: "X_CHROMA_TOKEN"}} +}); diff --git a/docker-compose.test-auth.yml b/docker-compose.test-auth.yml index 921f0749f06..c66cfc8202b 100644 --- a/docker-compose.test-auth.yml +++ b/docker-compose.test-auth.yml @@ -17,9 +17,11 @@ services: - ANONYMIZED_TELEMETRY=False - ALLOW_RESET=True - IS_PERSISTENT=TRUE - - CHROMA_SERVER_AUTH_CREDENTIALS_FILE=/chroma/server.htpasswd - - CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER=chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider - - CHROMA_SERVER_AUTH_PROVIDER=chromadb.auth.basic.BasicAuthServerProvider + - CHROMA_SERVER_AUTH_CREDENTIALS_FILE=${CHROMA_SERVER_AUTH_CREDENTIALS_FILE} + - CHROMA_SERVER_AUTH_CREDENTIALS=${CHROMA_SERVER_AUTH_CREDENTIALS} + - CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER=${CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER} + - CHROMA_SERVER_AUTH_PROVIDER=${CHROMA_SERVER_AUTH_PROVIDER} + - CHROMA_SERVER_AUTH_TOKEN_TRANSPORT_HEADER=${CHROMA_SERVER_AUTH_TOKEN_TRANSPORT_HEADER} ports: - ${CHROMA_PORT}:8000 networks: From 
82b9c830f70e211247da03ec82bfcabaf36154a6 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Mon, 18 Sep 2023 23:00:57 +0300 Subject: [PATCH 15/39] [ENH]: CIP-5: Large Batch Handling Improvements Proposal (#1077) - Including only CIP for review. Refs: #1049 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - New proposal to handle large batches of embeddings gracefully ## Test plan *How are these changes tested?* - [ ] Tests pass locally with `pytest` for python, `yarn test` for js ## Documentation Changes TBD --------- Signed-off-by: sunilkumardash9 Co-authored-by: Sunil Kumar Dash <47926185+sunilkumardash9@users.noreply.github.com> --- chromadb/api/__init__.py | 7 ++ chromadb/api/fastapi.py | 81 +++++++++++-------- chromadb/api/segment.py | 46 +++++++++-- chromadb/api/types.py | 12 ++- chromadb/server/fastapi/__init__.py | 8 ++ chromadb/test/property/test_add.py | 81 ++++++++++++++++++- chromadb/test/test_api.py | 18 +++++ chromadb/utils/batch_utils.py | 34 ++++++++ ...CIP_5_Large_Batch_Handling_Improvements.md | 59 ++++++++++++++ 9 files changed, 302 insertions(+), 44 deletions(-) create mode 100644 chromadb/utils/batch_utils.py create mode 100644 docs/CIP_5_Large_Batch_Handling_Improvements.md diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index c1c83580e9e..50f2ff1ecef 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -378,3 +378,10 @@ def get_settings(self) -> Settings: """ pass + + @property + @abstractmethod + def max_batch_size(self) -> int: + """Return the maximum number of records that can be submitted in a single call + to submit_embeddings.""" + pass diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index c08458a2fcb..2ddd537ebff 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -1,6 +1,6 @@ import json import logging -from typing import Optional, cast +from typing import Optional, cast, Tuple from typing import Sequence from uuid import UUID @@ -23,6 +23,7 @@ GetResult, QueryResult, CollectionMetadata, + validate_batch, ) from chromadb.auth import ( ClientAuthProvider, @@ -38,6 +39,7 @@ class FastAPI(API): _settings: Settings + _max_batch_size: int = -1 @staticmethod def _validate_host(host: str) -> None: @@ -296,6 +298,29 @@ def _delete( raise_chroma_error(resp) return cast(IDs, resp.json()) + def _submit_batch( + self, + batch: Tuple[ + IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents] + ], + url: str, + ) -> requests.Response: + """ + Submits a batch of embeddings to the database + """ + resp = self._session.post( + self._api_url + url, + data=json.dumps( + { + "ids": batch[0], + "embeddings": batch[1], + "metadatas": batch[2], + "documents": batch[3], + } + ), + ) + return resp + @override def _add( self, @@ -309,18 +334,9 @@ def _add( Adds a batch of embeddings to the database - pass in column oriented data lists """ - resp = self._session.post( - self._api_url + "/collections/" + str(collection_id) + "/add", - data=json.dumps( - { - "ids": ids, - "embeddings": embeddings, - "metadatas": metadatas, - "documents": documents, - } - ), - ) - + batch = (ids, embeddings, metadatas, documents) + validate_batch(batch, {"max_batch_size": self.max_batch_size}) + resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add") raise_chroma_error(resp) return True @@ -337,18 +353,11 @@ def _update( Updates a batch of embeddings in the database - pass in column oriented data lists """ - resp = self._session.post( - self._api_url + 
"/collections/" + str(collection_id) + "/update", - data=json.dumps( - { - "ids": ids, - "embeddings": embeddings, - "metadatas": metadatas, - "documents": documents, - } - ), + batch = (ids, embeddings, metadatas, documents) + validate_batch(batch, {"max_batch_size": self.max_batch_size}) + resp = self._submit_batch( + batch, "/collections/" + str(collection_id) + "/update" ) - resp.raise_for_status() return True @@ -365,18 +374,11 @@ def _upsert( Upserts a batch of embeddings in the database - pass in column oriented data lists """ - resp = self._session.post( - self._api_url + "/collections/" + str(collection_id) + "/upsert", - data=json.dumps( - { - "ids": ids, - "embeddings": embeddings, - "metadatas": metadatas, - "documents": documents, - } - ), + batch = (ids, embeddings, metadatas, documents) + validate_batch(batch, {"max_batch_size": self.max_batch_size}) + resp = self._submit_batch( + batch, "/collections/" + str(collection_id) + "/upsert" ) - resp.raise_for_status() return True @@ -434,6 +436,15 @@ def get_settings(self) -> Settings: """Returns the settings of the client""" return self._settings + @property + @override + def max_batch_size(self) -> int: + if self._max_batch_size == -1: + resp = self._session.get(self._api_url + "/pre-flight-checks") + raise_chroma_error(resp) + self._max_batch_size = cast(int, resp.json()["max_batch_size"]) + return self._max_batch_size + def raise_chroma_error(resp: requests.Response) -> None: """Raises an error if the response is not ok, using a ChromaError if possible""" diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 7f7712922fa..dd846891b28 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -26,6 +26,7 @@ validate_update_metadata, validate_where, validate_where_document, + validate_batch, ) from chromadb.telemetry.events import CollectionAddEvent, CollectionDeleteEvent @@ -38,6 +39,7 @@ import logging import re + logger = logging.getLogger(__name__) @@ -241,9 +243,18 @@ def _add( ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.ADD) - + validate_batch( + (ids, embeddings, metadatas, documents), + {"max_batch_size": self.max_batch_size}, + ) records_to_submit = [] - for r in _records(t.Operation.ADD, ids, embeddings, metadatas, documents): + for r in _records( + t.Operation.ADD, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + ): self._validate_embedding_record(coll, r) records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) @@ -262,9 +273,18 @@ def _update( ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPDATE) - + validate_batch( + (ids, embeddings, metadatas, documents), + {"max_batch_size": self.max_batch_size}, + ) records_to_submit = [] - for r in _records(t.Operation.UPDATE, ids, embeddings, metadatas, documents): + for r in _records( + t.Operation.UPDATE, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + ): self._validate_embedding_record(coll, r) records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) @@ -282,9 +302,18 @@ def _upsert( ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPSERT) - + validate_batch( + (ids, embeddings, metadatas, documents), + {"max_batch_size": self.max_batch_size}, + ) records_to_submit = [] - for r in 
_records(t.Operation.UPSERT, ids, embeddings, metadatas, documents): + for r in _records( + t.Operation.UPSERT, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + ): self._validate_embedding_record(coll, r) records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) @@ -524,6 +553,11 @@ def reset(self) -> bool: def get_settings(self) -> Settings: return self._settings + @property + @override + def max_batch_size(self) -> int: + return self._producer.max_batch_size + def _topic(self, collection_id: UUID) -> str: return f"persistent://{self._tenant_id}/{self._topic_ns}/{collection_id}" diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 7979dba624e..017e356ffac 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, Sequence, TypeVar, List, Dict, Any +from typing import Optional, Union, Sequence, TypeVar, List, Dict, Any, Tuple from typing_extensions import Literal, TypedDict, Protocol import chromadb.errors as errors from chromadb.types import ( @@ -367,3 +367,13 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings: f"Expected each value in the embedding to be a int or float, got {embeddings}" ) return embeddings + + +def validate_batch( + batch: Tuple[IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents]], + limits: Dict[str, Any], +) -> None: + if len(batch[0]) > limits["max_batch_size"]: + raise ValueError( + f"Batch size {len(batch[0])} exceeds maximum batch size {limits['max_batch_size']}" + ) diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index d8e43c51081..e92d16d63ba 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -126,6 +126,9 @@ def __init__(self, settings: Settings): self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"]) self.router.add_api_route("/api/v1/version", self.version, methods=["GET"]) self.router.add_api_route("/api/v1/heartbeat", self.heartbeat, methods=["GET"]) + self.router.add_api_route( + "/api/v1/pre-flight-checks", self.pre_flight_checks, methods=["GET"] + ) self.router.add_api_route( "/api/v1/collections", @@ -312,3 +315,8 @@ def get_nearest_neighbors( include=query.include, ) return nnresult + + def pre_flight_checks(self) -> Dict[str, Any]: + return { + "max_batch_size": self._api.max_batch_size, + } diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py index 602df2fa81b..1980ed2a9d9 100644 --- a/chromadb/test/property/test_add.py +++ b/chromadb/test/property/test_add.py @@ -1,11 +1,15 @@ -from typing import cast +import random +import uuid +from random import randint +from typing import cast, List, Any, Dict import pytest import hypothesis.strategies as st from hypothesis import given, settings from chromadb.api import API -from chromadb.api.types import Embeddings +from chromadb.api.types import Embeddings, Metadatas import chromadb.test.property.strategies as strategies import chromadb.test.property.invariants as invariants +from chromadb.utils.batch_utils import create_batches collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll") @@ -44,6 +48,79 @@ def test_add( ) +def create_large_recordset( + min_size: int = 45000, + max_size: int = 50000, +) -> strategies.RecordSet: + size = randint(min_size, max_size) + + ids = [str(uuid.uuid4()) for _ in range(size)] + metadatas = [{"some_key": f"{i}"} for i in range(size)] + documents = 
[f"Document {i}" for i in range(size)] + embeddings = [[1, 2, 3] for _ in range(size)] + record_set: Dict[str, List[Any]] = { + "ids": ids, + "embeddings": cast(Embeddings, embeddings), + "metadatas": metadatas, + "documents": documents, + } + return record_set + + +@given(collection=collection_st) +@settings(deadline=None, max_examples=1) +def test_add_large(api: API, collection: strategies.Collection) -> None: + api.reset() + record_set = create_large_recordset( + min_size=api.max_batch_size, + max_size=api.max_batch_size + int(api.max_batch_size * random.random()), + ) + coll = api.create_collection( + name=collection.name, + metadata=collection.metadata, + embedding_function=collection.embedding_function, + ) + normalized_record_set = invariants.wrap_all(record_set) + + if not invariants.is_metadata_valid(normalized_record_set): + with pytest.raises(Exception): + coll.add(**normalized_record_set) + return + for batch in create_batches( + api=api, + ids=cast(List[str], record_set["ids"]), + embeddings=cast(Embeddings, record_set["embeddings"]), + metadatas=cast(Metadatas, record_set["metadatas"]), + documents=cast(List[str], record_set["documents"]), + ): + coll.add(*batch) + invariants.count(coll, cast(strategies.RecordSet, normalized_record_set)) + + +@given(collection=collection_st) +@settings(deadline=None, max_examples=1) +def test_add_large_exceeding(api: API, collection: strategies.Collection) -> None: + api.reset() + record_set = create_large_recordset( + min_size=api.max_batch_size, + max_size=api.max_batch_size + int(api.max_batch_size * random.random()), + ) + coll = api.create_collection( + name=collection.name, + metadata=collection.metadata, + embedding_function=collection.embedding_function, + ) + normalized_record_set = invariants.wrap_all(record_set) + + if not invariants.is_metadata_valid(normalized_record_set): + with pytest.raises(Exception): + coll.add(**normalized_record_set) + return + with pytest.raises(Exception) as e: + coll.add(**record_set) + assert "exceeds maximum batch size" in str(e.value) + + # TODO: This test fails right now because the ids are not sorted by the input order @pytest.mark.xfail( reason="This is expected to fail right now. 
We should change the API to sort the \ diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 0583d6eede7..8a12a1d9735 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -1,6 +1,8 @@ # type: ignore +import requests import chromadb +from chromadb.api.fastapi import FastAPI from chromadb.api.types import QueryResult from chromadb.config import Settings import chromadb.server.fastapi @@ -164,6 +166,22 @@ def test_heartbeat(api): assert heartbeat > datetime.now() - timedelta(seconds=10) +def test_max_batch_size(api): + print(api) + batch_size = api.max_batch_size + assert batch_size > 0 + + +def test_pre_flight_checks(api): + if not isinstance(api, FastAPI): + pytest.skip("Not a FastAPI instance") + + resp = requests.get(f"{api._api_url}/pre-flight-checks") + assert resp.status_code == 200 + assert resp.json() is not None + assert "max_batch_size" in resp.json().keys() + + batch_records = { "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], "ids": ["https://example.com/1", "https://example.com/2"], diff --git a/chromadb/utils/batch_utils.py b/chromadb/utils/batch_utils.py new file mode 100644 index 00000000000..c8c1ac1e476 --- /dev/null +++ b/chromadb/utils/batch_utils.py @@ -0,0 +1,34 @@ +from typing import Optional, Tuple, List +from chromadb.api import API +from chromadb.api.types import ( + Documents, + Embeddings, + IDs, + Metadatas, +) + + +def create_batches( + api: API, + ids: IDs, + embeddings: Optional[Embeddings] = None, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, +) -> List[Tuple[IDs, Embeddings, Optional[Metadatas], Optional[Documents]]]: + _batches: List[ + Tuple[IDs, Embeddings, Optional[Metadatas], Optional[Documents]] + ] = [] + if len(ids) > api.max_batch_size: + # create split batches + for i in range(0, len(ids), api.max_batch_size): + _batches.append( + ( # type: ignore + ids[i : i + api.max_batch_size], + embeddings[i : i + api.max_batch_size] if embeddings else None, + metadatas[i : i + api.max_batch_size] if metadatas else None, + documents[i : i + api.max_batch_size] if documents else None, + ) + ) + else: + _batches.append((ids, embeddings, metadatas, documents)) # type: ignore + return _batches diff --git a/docs/CIP_5_Large_Batch_Handling_Improvements.md b/docs/CIP_5_Large_Batch_Handling_Improvements.md new file mode 100644 index 00000000000..9b03d080f0f --- /dev/null +++ b/docs/CIP_5_Large_Batch_Handling_Improvements.md @@ -0,0 +1,59 @@ +# CIP-5: Large Batch Handling Improvements Proposal + +## Status + +Current Status: `Under Discussion` + +## **Motivation** + +As users start putting Chroma in its paces and storing ever-increasing datasets, we must ensure that errors +related to significant and potentially expensive batches are handled gracefully. This CIP proposes to add a new +setting, `max_batch_size` API, on the local segment API and use it to split large batches into smaller ones. + +## **Public Interfaces** + +The following interfaces are impacted: + +- New Server API endpoint - `/pre-flight-checks` +- New `max_batch_size` property on the `API` interface +- Updated `_add`, `_update` and `_upsert` methods on `chromadb.api.segment.SegmentAPI` +- Updated `_add`, `_update` and `_upsert` methods on `chromadb.api.fastapi.FastAPI` +- New utility library `batch_utils.py` +- New exception raised when batch size exceeds `max_batch_size` + +## **Proposed Changes** + +We propose the following changes: + +- The new `max_batch_size` property is now available in the `API` interface. 
The property relies on the + underlying `Producer` class + to fetch the actual value. The property will be implemented by both `chromadb.api.segment.SegmentAPI` + and `chromadb.api.fastapi.FastAPI` +- `chromadb.api.segment.SegmentAPI` will implement the `max_batch_size` property by fetching the value from the + `Producer` class. +- `chromadb.api.fastapi.FastAPI` will implement the `max_batch_size` by fetching it from a new `/pre-flight-checks` + endpoint on the Server. +- New `/pre-flight-checks` endpoint on the Server will return a dictionary with pre-flight checks the client must + fulfil to integrate with the server side. For now, we propose using this only for `max_batch_size`, but we can + add more checks in the future. The pre-flight checks will be only fetched once per client and cached for the duration + of the client's lifetime. +- Updated `_add`, `_update` and `_upsert` method on `chromadb.api.segment.SegmentAPI` to validate batch size. +- Updated `_add`, `_update` and `_upsert` method on `chromadb.api.fastapi.FastAPI` to validate batch size (client-side + validation) +- New utility library `batch_utils.py` will contain the logic for splitting batches into smaller ones. + +## **Compatibility, Deprecation, and Migration Plan** + +The change will be fully compatible with existing implementations. The changes will be transparent to the user. + +## **Test Plan** + +New tests: + +- Batch splitting tests for `chromadb.api.segment.SegmentAPI` +- Batch splitting tests for `chromadb.api.fastapi.FastAPI` +- Tests for `/pre-flight-checks` endpoint + +## **Rejected Alternatives** + +N/A From 9e05cb372a05b87a6c890e43ee5ebe05c0f20fe6 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 19 Sep 2023 01:56:53 +0300 Subject: [PATCH 16/39] [BUG]: Fixing broken peer deps (#1153) Refs: #1104 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Removed transformers and web-ai peer dependencies. ## Test plan *How are these changes tested?* - [ ] Manual testing - `mkdir testproject && cd testproject && npm init -y && npm link chromadb && npm add langchain` ## Documentation Changes N/A --- clients/js/package.json | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/clients/js/package.json b/clients/js/package.json index 25d3d90b94c..d72c655e3f1 100644 --- a/clients/js/package.json +++ b/clients/js/package.json @@ -65,22 +65,10 @@ "isomorphic-fetch": "^3.0.0" }, "peerDependencies": { - "@visheratin/web-ai": "^1.0.0", - "@visheratin/web-ai-node": "^1.0.0", - "@xenova/transformers": "^2.0.0", - "cohere-ai": "^6.0.0", + "cohere-ai": "^5.1.0", "openai": "^3.0.0 || ^4.0.0" }, "peerDependenciesMeta": { - "@visheratin/web-ai": { - "optional": true - }, - "@visheratin/web-ai-node": { - "optional": true - }, - "@xenova/transformers": { - "optional": true - }, "cohere-ai": { "optional": true }, From 3aed7b78b2ccd54f0f9306f9ef1acad174d12b3c Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 19 Sep 2023 03:43:13 +0300 Subject: [PATCH 17/39] [BUG]: Fixed BF index overflow issue with subsequent delete (#1150) Refs: #989 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - When the BF index overflows (batch_size upon insertion of large batch it is cleared, if a subsequent delete request comes to delete Ids which were in the cleared BF index a warning is raised for non-existent embedding. 
The issue was resolved by separately checking if BF the record exists in the BF index and conditionally execute the BF removal ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python ## Documentation Changes N/A --- chromadb/segment/impl/vector/local_persistent_hnsw.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/chromadb/segment/impl/vector/local_persistent_hnsw.py b/chromadb/segment/impl/vector/local_persistent_hnsw.py index a0b52acd07a..6e1df7b1f1f 100644 --- a/chromadb/segment/impl/vector/local_persistent_hnsw.py +++ b/chromadb/segment/impl/vector/local_persistent_hnsw.py @@ -225,11 +225,13 @@ def _write_records(self, records: Sequence[EmbeddingRecord]) -> None: exists_in_index = self._id_to_label.get( id, None ) is not None or self._brute_force_index.has_id(id) + exists_in_bf_index = self._brute_force_index.has_id(id) if op == Operation.DELETE: if exists_in_index: self._curr_batch.apply(record) - self._brute_force_index.delete([record]) + if exists_in_bf_index: + self._brute_force_index.delete([record]) else: logger.warning(f"Delete of nonexisting embedding ID: {id}") From b930a862bad2cdaa25a8c888d04e97f33351d412 Mon Sep 17 00:00:00 2001 From: Jeff Huber Date: Mon, 18 Sep 2023 21:18:25 -0700 Subject: [PATCH 18/39] js 1.5.10 (#1155) Release https://github.com/chroma-core/chroma/pull/1153 --- clients/js/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/js/package.json b/clients/js/package.json index d72c655e3f1..55ca7a6c7aa 100644 --- a/clients/js/package.json +++ b/clients/js/package.json @@ -1,6 +1,6 @@ { "name": "chromadb", - "version": "1.5.9", + "version": "1.5.10", "description": "A JavaScript interface for chroma", "keywords": [], "author": "", From aa0387a6d7ba8b42009173d9b19a270a889eae22 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 19 Sep 2023 19:14:16 +0300 Subject: [PATCH 19/39] [BUG]: Added cohere version 6.x support in peer dependencies (#1156) Refs: #1104 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Expanding Cohere version also to support 6.x This plays nice with the rest of the ecosystem ## Test plan *How are these changes tested?* - [x] `yarn test` for js ## Documentation Changes N/A --- clients/js/package.json | 2 +- clients/js/test/add.collections.test.ts | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/clients/js/package.json b/clients/js/package.json index 55ca7a6c7aa..678353984f1 100644 --- a/clients/js/package.json +++ b/clients/js/package.json @@ -65,7 +65,7 @@ "isomorphic-fetch": "^3.0.0" }, "peerDependencies": { - "cohere-ai": "^5.1.0", + "cohere-ai": "^5.0.0 || ^6.0.0", "openai": "^3.0.0 || ^4.0.0" }, "peerDependenciesMeta": { diff --git a/clients/js/test/add.collections.test.ts b/clients/js/test/add.collections.test.ts index 32f76c89a30..cb89fa8dbe0 100644 --- a/clients/js/test/add.collections.test.ts +++ b/clients/js/test/add.collections.test.ts @@ -4,6 +4,7 @@ import { DOCUMENTS, EMBEDDINGS, IDS } from './data'; import { METADATAS } from './data'; import { IncludeEnum } from "../src/types"; import {OpenAIEmbeddingFunction} from "../src/embeddings/OpenAIEmbeddingFunction"; +import {CohereEmbeddingFunction} from "../src/embeddings/CohereEmbeddingFunction"; test("it should add single embeddings to a collection", async () => { await chroma.reset(); const collection = await chroma.createCollection({ name: "test" }); @@ -57,6 +58,27 @@ if 
(!process.env.OPENAI_API_KEY) { }); } +if (!process.env.COHERE_API_KEY) { + test.skip("it should add Cohere embeddings", async () => { + }); +} else { + test("it should add Cohere embeddings", async () => { + await chroma.reset(); + const embedder = new CohereEmbeddingFunction({ cohere_api_key: process.env.COHERE_API_KEY || "" }) + const collection = await chroma.createCollection({ name: "test" ,embeddingFunction: embedder}); + const embeddings = await embedder.generate(DOCUMENTS); + await collection.add({ ids: IDS, embeddings: embeddings }); + const count = await collection.count(); + expect(count).toBe(3); + var res = await collection.get({ + ids: IDS, include: [ + IncludeEnum.Embeddings, + ] + }); + expect(res.embeddings).toEqual(embeddings); // reverse because of the order of the ids + }); +} + test("add documents", async () => { await chroma.reset(); const collection = await chroma.createCollection({ name: "test" }); From f3284f62b9c65334055149b46f8b7c545c6fd637 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Tue, 19 Sep 2023 09:34:03 -0700 Subject: [PATCH 20/39] [RELEASE] JS 1.5.11 (#1161) Releases Js 1.5.11 --- clients/js/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/js/package.json b/clients/js/package.json index 678353984f1..b48269e6f64 100644 --- a/clients/js/package.json +++ b/clients/js/package.json @@ -1,6 +1,6 @@ { "name": "chromadb", - "version": "1.5.10", + "version": "1.5.11", "description": "A JavaScript interface for chroma", "keywords": [], "author": "", From 7d2dd011cf96b7993ad5290a99a64dcd7e1797ef Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Tue, 19 Sep 2023 09:44:15 -0700 Subject: [PATCH 21/39] Release 0.4.11 (#1162) Release 0.4.11 --- chromadb/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 804ff93404c..12e189c4725 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -44,7 +44,7 @@ __settings = Settings() -__version__ = "0.4.10" +__version__ = "0.4.11" # Workaround to deal with Colab's old sqlite3 version try: From dffc8067dba7d9de5d89aaf282ff8c3810f1007d Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Wed, 20 Sep 2023 10:47:16 +0300 Subject: [PATCH 22/39] [BUG]: Docker entrypoint logging path (#1159) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Initial CLI PR (https://github.com/chroma-core/chroma/pull/1032) moved the logging config inside chromadb. If image is built with the current setup it will result in Error: Invalid value for '--log-config': Path 'log_config.yml' does not exist. ## Test plan *How are these changes tested?* Steps to reproduce (prior to this PR): - `docker build -t chroma:canary .` - `docker run --rm -it chroma:canary` ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? 
Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* --- bin/docker_entrypoint.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/docker_entrypoint.sh b/bin/docker_entrypoint.sh index 3b0d146c70b..ce500ee80b9 100755 --- a/bin/docker_entrypoint.sh +++ b/bin/docker_entrypoint.sh @@ -3,4 +3,4 @@ echo "Rebuilding hnsw to ensure architecture compatibility" pip install --force-reinstall --no-cache-dir chroma-hnswlib export IS_PERSISTENT=1 -uvicorn chromadb.app:app --workers 1 --host 0.0.0.0 --port 8000 --proxy-headers --log-config log_config.yml +uvicorn chromadb.app:app --workers 1 --host 0.0.0.0 --port 8000 --proxy-headers --log-config chromadb/log_config.yml From 020950470cb75a54e761a7f3ba6de2738a7ddc9c Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Wed, 20 Sep 2023 00:50:31 -0700 Subject: [PATCH 23/39] [RELEASE] 0.4.12 to fix Dockerfile log issue (#1165) Releasing a hotfix for #1159 which addresses #1164 which breaks the docker image --- chromadb/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 12e189c4725..0ff5244a80f 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -44,7 +44,7 @@ __settings = Settings() -__version__ = "0.4.11" +__version__ = "0.4.12" # Workaround to deal with Colab's old sqlite3 version try: From 896822231e9444ebe41f7bc04047a686279ffa03 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Wed, 20 Sep 2023 02:03:07 -0700 Subject: [PATCH 24/39] [ENH] Pulsar Producer & Consumer (#921) ## Description of changes *Summarize the changes made by this PR.* - New functionality - Adds a basic pulsar producer, consumer and associated tests. As well as a docker compose for the distributed version of chroma. ## Test plan We added bin/cluster-test.sh, which starts pulsar and allows test_producer_consumer to run the pulsar fixture. ## Documentation Changes None required. 
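For reviewers, a minimal sketch of how the new producer/consumer pair is meant to be driven end to end: start both components, subscribe, publish, and receive records through the listener callback. The `Settings` values, topic name, and `SubmitEmbeddingRecord` fields below are illustrative assumptions rather than part of this patch; the authoritative usage is `chromadb/test/ingest/test_producer_consumer.py`.

```python
# Sketch only. Assumes a local standalone Pulsar broker (e.g. the one from
# docker-compose.cluster.yml) on the standard 6650/8080 ports, and the
# SubmitEmbeddingRecord fields (id/embedding/encoding/metadata/operation)
# from chromadb.types at the time of this patch.
import time

from chromadb.config import Settings, System
from chromadb.ingest.impl.pulsar import PulsarConsumer, PulsarProducer
from chromadb.types import Operation, ScalarEncoding, SubmitEmbeddingRecord

system = System(Settings(
    pulsar_broker_url="localhost",
    pulsar_broker_port="6650",
    pulsar_admin_port="8080",
))
producer = PulsarProducer(system)
consumer = PulsarConsumer(system)
producer.start()
consumer.start()

topic = "persistent://default/default/example"
producer.create_topic(topic)

received = []
# Without an explicit `start`, the subscription only sees records published after it.
consumer.subscribe(topic, lambda records: received.extend(records))

producer.submit_embedding(topic, SubmitEmbeddingRecord(
    id="id-0",
    embedding=[0.25, 0.5, 0.75],
    encoding=ScalarEncoding.FLOAT32,
    metadata=None,
    operation=Operation.ADD,
))

time.sleep(2)  # the listener callback fires on a Pulsar client thread
assert len(received) == 1
```

A real test would synchronize on the callback rather than sleeping; the sleep only keeps the sketch short.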
--- .github/workflows/chroma-cluster-test.yml | 31 ++ .pre-commit-config.yaml | 2 +- bin/cluster-test.sh | 16 + chromadb/api/segment.py | 6 +- chromadb/config.py | 4 + chromadb/ingest/impl/pulsar.py | 304 ++++++++++++++++++ chromadb/ingest/impl/pulsar_admin.py | 81 +++++ chromadb/ingest/impl/utils.py | 20 ++ chromadb/proto/chroma.proto | 40 +++ chromadb/proto/chroma_pb2.py | 42 +++ chromadb/proto/chroma_pb2.pyi | 247 ++++++++++++++ chromadb/proto/convert.py | 150 +++++++++ chromadb/segment/impl/metadata/sqlite.py | 4 +- .../impl/vector/local_persistent_hnsw.py | 1 - .../test/ingest/test_producer_consumer.py | 132 +++++--- docker-compose.cluster.yml | 66 ++++ requirements_dev.txt | 2 + 17 files changed, 1105 insertions(+), 43 deletions(-) create mode 100644 .github/workflows/chroma-cluster-test.yml create mode 100755 bin/cluster-test.sh create mode 100644 chromadb/ingest/impl/pulsar.py create mode 100644 chromadb/ingest/impl/pulsar_admin.py create mode 100644 chromadb/ingest/impl/utils.py create mode 100644 chromadb/proto/chroma.proto create mode 100644 chromadb/proto/chroma_pb2.py create mode 100644 chromadb/proto/chroma_pb2.pyi create mode 100644 chromadb/proto/convert.py create mode 100644 docker-compose.cluster.yml diff --git a/.github/workflows/chroma-cluster-test.yml b/.github/workflows/chroma-cluster-test.yml new file mode 100644 index 00000000000..5ae873aa198 --- /dev/null +++ b/.github/workflows/chroma-cluster-test.yml @@ -0,0 +1,31 @@ +name: Chroma Cluster Tests + +on: + push: + branches: + - main + pull_request: + branches: + - main + - '**' + workflow_dispatch: + +jobs: + test: + strategy: + matrix: + python: ['3.7'] + platform: [ubuntu-latest] + testfile: ["chromadb/test/ingest/test_producer_consumer.py"] # Just this one test for now + runs-on: ${{ matrix.platform }} + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python }} + - name: Install test dependencies + run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt + - name: Integration Test + run: bin/cluster-test.sh ${{ matrix.testfile }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6b8fbca9079..5b2ed56635e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,4 +32,4 @@ repos: hooks: - id: mypy args: [--strict, --ignore-missing-imports, --follow-imports=silent, --disable-error-code=type-abstract] - additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy"] \ No newline at end of file + additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy", "types-protobuf"] diff --git a/bin/cluster-test.sh b/bin/cluster-test.sh new file mode 100755 index 00000000000..b7255eae60a --- /dev/null +++ b/bin/cluster-test.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +set -e + +function cleanup { + docker compose -f docker-compose.cluster.yml down --rmi local --volumes +} + +trap cleanup EXIT + +docker compose -f docker-compose.cluster.yml up -d --wait pulsar + +export CHROMA_CLUSTER_TEST_ONLY=1 + +echo testing: python -m pytest "$@" +python -m pytest "$@" diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index dd846891b28..00002f46d27 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -1,6 +1,7 @@ from chromadb.api import API from chromadb.config import Settings, System from chromadb.db.system import SysDB 
+from chromadb.ingest.impl.utils import create_topic_name from chromadb.segment import SegmentManager, MetadataReader, VectorReader from chromadb.telemetry import Telemetry from chromadb.ingest import Producer @@ -130,6 +131,9 @@ def create_collection( coll = t.Collection( id=id, name=name, metadata=metadata, topic=self._topic(id), dimension=None ) + # TODO: Topic creation right now lives in the producer but it should be moved to the coordinator, + # and the producer should just be responsible for publishing messages. Coordinator should + # be responsible for all management of topics. self._producer.create_topic(coll["topic"]) segments = self._manager.create_segments(coll) self._sysdb.create_collection(coll) @@ -559,7 +563,7 @@ def max_batch_size(self) -> int: return self._producer.max_batch_size def _topic(self, collection_id: UUID) -> str: - return f"persistent://{self._tenant_id}/{self._topic_ns}/{collection_id}" + return create_topic_name(self._tenant_id, self._topic_ns, str(collection_id)) # TODO: This could potentially cause race conditions in a distributed version of the # system, since the cache is only local. diff --git a/chromadb/config.py b/chromadb/config.py index 4cb7eac1aae..6167193acd2 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -92,6 +92,10 @@ class Settings(BaseSettings): # type: ignore chroma_server_grpc_port: Optional[str] = None chroma_server_cors_allow_origins: List[str] = [] # eg ["http://localhost:3000"] + pulsar_broker_url: Optional[str] = None + pulsar_admin_port: Optional[str] = None + pulsar_broker_port: Optional[str] = None + chroma_server_auth_provider: Optional[str] = None @validator("chroma_server_auth_provider", pre=True, always=True, allow_reuse=True) diff --git a/chromadb/ingest/impl/pulsar.py b/chromadb/ingest/impl/pulsar.py new file mode 100644 index 00000000000..3f293c90580 --- /dev/null +++ b/chromadb/ingest/impl/pulsar.py @@ -0,0 +1,304 @@ +from __future__ import annotations +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple +import uuid +from chromadb.config import Settings, System +from chromadb.ingest import Consumer, ConsumerCallbackFn, Producer +from overrides import overrides, EnforceOverrides +from uuid import UUID +from chromadb.ingest.impl.pulsar_admin import PulsarAdmin +from chromadb.ingest.impl.utils import create_pulsar_connection_str +from chromadb.proto.convert import from_proto_submit, to_proto_submit +import chromadb.proto.chroma_pb2 as proto +from chromadb.types import SeqId, SubmitEmbeddingRecord +import pulsar +from concurrent.futures import wait, Future + +from chromadb.utils.messageid import int_to_pulsar, pulsar_to_int + + +class PulsarProducer(Producer, EnforceOverrides): + _connection_str: str + _topic_to_producer: Dict[str, pulsar.Producer] + _client: pulsar.Client + _admin: PulsarAdmin + _settings: Settings + + def __init__(self, system: System) -> None: + pulsar_host = system.settings.require("pulsar_broker_url") + pulsar_port = system.settings.require("pulsar_broker_port") + self._connection_str = create_pulsar_connection_str(pulsar_host, pulsar_port) + self._topic_to_producer = {} + self._settings = system.settings + self._admin = PulsarAdmin(system) + super().__init__(system) + + @overrides + def start(self) -> None: + self._client = pulsar.Client(self._connection_str) + super().start() + + @overrides + def stop(self) -> None: + self._client.close() + super().stop() + + @overrides + def create_topic(self, topic_name: str) -> None: + 
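+        # Topic management goes through the Pulsar admin REST API for now; per the TODO in
+        # SegmentAPI.create_collection, this responsibility is expected to move to the coordinator.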
self._admin.create_topic(topic_name) + + @overrides + def delete_topic(self, topic_name: str) -> None: + self._admin.delete_topic(topic_name) + + @overrides + def submit_embedding( + self, topic_name: str, embedding: SubmitEmbeddingRecord + ) -> SeqId: + """Add an embedding record to the given topic. Returns the SeqID of the record.""" + producer = self._get_or_create_producer(topic_name) + proto_submit: proto.SubmitEmbeddingRecord = to_proto_submit(embedding) + # TODO: batch performance / async + msg_id: pulsar.MessageId = producer.send(proto_submit.SerializeToString()) + return pulsar_to_int(msg_id) + + @overrides + def submit_embeddings( + self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord] + ) -> Sequence[SeqId]: + if not self._running: + raise RuntimeError("Component not running") + + if len(embeddings) == 0: + return [] + + if len(embeddings) > self.max_batch_size: + raise ValueError( + f""" + Cannot submit more than {self.max_batch_size:,} embeddings at once. + Please submit your embeddings in batches of size + {self.max_batch_size:,} or less. + """ + ) + + producer = self._get_or_create_producer(topic_name) + protos_to_submit = [to_proto_submit(embedding) for embedding in embeddings] + + def create_producer_callback( + future: Future[int], + ) -> Callable[[Any, pulsar.MessageId], None]: + def producer_callback(res: Any, msg_id: pulsar.MessageId) -> None: + if msg_id: + future.set_result(pulsar_to_int(msg_id)) + else: + future.set_exception( + Exception( + "Unknown error while submitting embedding in producer_callback" + ) + ) + + return producer_callback + + futures = [] + for proto_to_submit in protos_to_submit: + future: Future[int] = Future() + producer.send_async( + proto_to_submit.SerializeToString(), + callback=create_producer_callback(future), + ) + futures.append(future) + + wait(futures) + + results: List[SeqId] = [] + for future in futures: + exception = future.exception() + if exception is not None: + raise exception + results.append(future.result()) + + return results + + @property + @overrides + def max_batch_size(self) -> int: + # For now, we use 1,000 + # TODO: tune this to a reasonable value by default + return 1000 + + def _get_or_create_producer(self, topic_name: str) -> pulsar.Producer: + if topic_name not in self._topic_to_producer: + producer = self._client.create_producer(topic_name) + self._topic_to_producer[topic_name] = producer + return self._topic_to_producer[topic_name] + + @overrides + def reset_state(self) -> None: + if not self._settings.require("allow_reset"): + raise ValueError( + "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted." 
+ ) + for topic_name in self._topic_to_producer: + self._admin.delete_topic(topic_name) + self._topic_to_producer = {} + super().reset_state() + + +class PulsarConsumer(Consumer, EnforceOverrides): + class PulsarSubscription: + id: UUID + topic_name: str + start: int + end: int + callback: ConsumerCallbackFn + consumer: pulsar.Consumer + + def __init__( + self, + id: UUID, + topic_name: str, + start: int, + end: int, + callback: ConsumerCallbackFn, + consumer: pulsar.Consumer, + ): + self.id = id + self.topic_name = topic_name + self.start = start + self.end = end + self.callback = callback + self.consumer = consumer + + _connection_str: str + _client: pulsar.Client + _subscriptions: Dict[str, Set[PulsarSubscription]] + _settings: Settings + + def __init__(self, system: System) -> None: + pulsar_host = system.settings.require("pulsar_broker_url") + pulsar_port = system.settings.require("pulsar_broker_port") + self._connection_str = create_pulsar_connection_str(pulsar_host, pulsar_port) + self._subscriptions = defaultdict(set) + self._settings = system.settings + super().__init__(system) + + @overrides + def start(self) -> None: + self._client = pulsar.Client(self._connection_str) + super().start() + + @overrides + def stop(self) -> None: + self._client.close() + super().stop() + + @overrides + def subscribe( + self, + topic_name: str, + consume_fn: ConsumerCallbackFn, + start: Optional[SeqId] = None, + end: Optional[SeqId] = None, + id: Optional[UUID] = None, + ) -> UUID: + """Register a function that will be called to recieve embeddings for a given + topic. The given function may be called any number of times, with any number of + records, and may be called concurrently. + + Only records between start (exclusive) and end (inclusive) SeqIDs will be + returned. If start is None, the first record returned will be the next record + generated, not including those generated before creating the subscription. If + end is None, the consumer will consume indefinitely, otherwise it will + automatically be unsubscribed when the end SeqID is reached. + + If the function throws an exception, the function may be called again with the + same or different records. + + Takes an optional UUID as a unique subscription ID. If no ID is provided, a new + ID will be generated and returned.""" + if not self._running: + raise RuntimeError("Consumer must be started before subscribing") + + subscription_id = ( + id or uuid.uuid4() + ) # TODO: this should really be created by the coordinator and stored in sysdb + + start, end = self._validate_range(start, end) + + def wrap_callback(consumer: pulsar.Consumer, message: pulsar.Message) -> None: + msg_data = message.data() + msg_id = pulsar_to_int(message.message_id()) + submit_embedding_record = proto.SubmitEmbeddingRecord() + proto.SubmitEmbeddingRecord.ParseFromString( + submit_embedding_record, msg_data + ) + embedding_record = from_proto_submit(submit_embedding_record, msg_id) + consume_fn([embedding_record]) + consumer.acknowledge(message) + if msg_id == end: + self.unsubscribe(subscription_id) + + consumer = self._client.subscribe( + topic_name, + subscription_id.hex, + message_listener=wrap_callback, + ) + + subscription = self.PulsarSubscription( + subscription_id, topic_name, start, end, consume_fn, consumer + ) + self._subscriptions[topic_name].add(subscription) + + # NOTE: For some reason the seek() method expects a shadowed MessageId type + # which resides in _msg_id. 
+ consumer.seek(int_to_pulsar(start)._msg_id) + + return subscription_id + + def _validate_range( + self, start: Optional[SeqId], end: Optional[SeqId] + ) -> Tuple[int, int]: + """Validate and normalize the start and end SeqIDs for a subscription using this + impl.""" + start = start or pulsar_to_int(pulsar.MessageId.latest) + end = end or self.max_seqid() + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("SeqIDs must be integers") + if start >= end: + raise ValueError(f"Invalid SeqID range: {start} to {end}") + return start, end + + @overrides + def unsubscribe(self, subscription_id: UUID) -> None: + """Unregister a subscription. The consume function will no longer be invoked, + and resources associated with the subscription will be released.""" + for topic_name, subscriptions in self._subscriptions.items(): + for subscription in subscriptions: + if subscription.id == subscription_id: + subscription.consumer.close() + subscriptions.remove(subscription) + if len(subscriptions) == 0: + del self._subscriptions[topic_name] + return + + @overrides + def min_seqid(self) -> SeqId: + """Return the minimum possible SeqID in this implementation.""" + return pulsar_to_int(pulsar.MessageId.earliest) + + @overrides + def max_seqid(self) -> SeqId: + """Return the maximum possible SeqID in this implementation.""" + return 2**192 - 1 + + @overrides + def reset_state(self) -> None: + if not self._settings.require("allow_reset"): + raise ValueError( + "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted." + ) + for topic_name, subscriptions in self._subscriptions.items(): + for subscription in subscriptions: + subscription.consumer.close() + self._subscriptions = defaultdict(set) + super().reset_state() diff --git a/chromadb/ingest/impl/pulsar_admin.py b/chromadb/ingest/impl/pulsar_admin.py new file mode 100644 index 00000000000..e031e4a238b --- /dev/null +++ b/chromadb/ingest/impl/pulsar_admin.py @@ -0,0 +1,81 @@ +# A thin wrapper around the pulsar admin api +import requests +from chromadb.config import System +from chromadb.ingest.impl.utils import parse_topic_name + + +class PulsarAdmin: + """A thin wrapper around the pulsar admin api, only used for interim development towards distributed chroma. + This functionality will be moved to the chroma coordinator.""" + + _connection_str: str + + def __init__(self, system: System): + pulsar_host = system.settings.require("pulsar_broker_url") + pulsar_port = system.settings.require("pulsar_admin_port") + self._connection_str = f"http://{pulsar_host}:{pulsar_port}" + + # Create the default tenant and namespace + # This is a temporary workaround until we have a proper tenant/namespace management system + self.create_tenant("default") + self.create_namespace("default", "default") + + def create_tenant(self, tenant: str) -> None: + """Make a PUT request to the admin api to create the tenant""" + + path = f"/admin/v2/tenants/{tenant}" + url = self._connection_str + path + response = requests.put( + url, json={"allowedClusters": ["standalone"], "adminRoles": []} + ) # TODO: how to manage clusters? 
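+        # A 204 (created) or 409 (already exists) response is treated as success by
+        # the check below, which makes create_tenant effectively idempotent.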
+
+        if response.status_code != 204 and response.status_code != 409:
+            raise RuntimeError(f"Failed to create tenant {tenant}")
+
+    def create_namespace(self, tenant: str, namespace: str) -> None:
+        """Make a PUT request to the admin api to create the namespace"""
+
+        path = f"/admin/v2/namespaces/{tenant}/{namespace}"
+        url = self._connection_str + path
+        response = requests.put(url)
+
+        if response.status_code != 204 and response.status_code != 409:
+            raise RuntimeError(f"Failed to create namespace {namespace}")
+
+    def create_topic(self, topic: str) -> None:
+        # TODO: support non-persistent topics?
+        tenant, namespace, topic_name = parse_topic_name(topic)
+
+        if tenant != "default":
+            raise ValueError(f"Only the default tenant is supported, got {tenant}")
+        if namespace != "default":
+            raise ValueError(
+                f"Only the default namespace is supported, got {namespace}"
+            )
+
+        # Make a PUT request to the admin api to create the topic
+        path = f"/admin/v2/persistent/{tenant}/{namespace}/{topic_name}"
+        url = self._connection_str + path
+        response = requests.put(url)
+
+        if response.status_code != 204 and response.status_code != 409:
+            raise RuntimeError(f"Failed to create topic {topic_name}")
+
+    def delete_topic(self, topic: str) -> None:
+        tenant, namespace, topic_name = parse_topic_name(topic)
+
+        if tenant != "default":
+            raise ValueError(f"Only the default tenant is supported, got {tenant}")
+        if namespace != "default":
+            raise ValueError(
+                f"Only the default namespace is supported, got {namespace}"
+            )
+
+        # Make a DELETE request to the admin api to delete the topic
+        path = f"/admin/v2/persistent/{tenant}/{namespace}/{topic_name}"
+        # Force delete the topic
+        path += "?force=true"
+        url = self._connection_str + path
+        response = requests.delete(url)
+        if response.status_code != 204 and response.status_code != 409:
+            raise RuntimeError(f"Failed to delete topic {topic_name}")
diff --git a/chromadb/ingest/impl/utils.py b/chromadb/ingest/impl/utils.py
new file mode 100644
index 00000000000..144384d75db
--- /dev/null
+++ b/chromadb/ingest/impl/utils.py
@@ -0,0 +1,20 @@
+import re
+from typing import Tuple
+
+topic_regex = r"persistent:\/\/(?P<tenant>.+)\/(?P<namespace>.+)\/(?P<topic>.+)"
+
+
+def parse_topic_name(topic_name: str) -> Tuple[str, str, str]:
+    """Parse the topic name into the tenant, namespace and topic name"""
+    match = re.match(topic_regex, topic_name)
+    if not match:
+        raise ValueError(f"Invalid topic name: {topic_name}")
+    return match.group("tenant"), match.group("namespace"), match.group("topic")
+
+
+def create_pulsar_connection_str(host: str, port: str) -> str:
+    return f"pulsar://{host}:{port}"
+
+
+def create_topic_name(tenant: str, namespace: str, topic: str) -> str:
+    return f"persistent://{tenant}/{namespace}/{topic}"
diff --git a/chromadb/proto/chroma.proto b/chromadb/proto/chroma.proto
new file mode 100644
index 00000000000..7eefed74e12
--- /dev/null
+++ b/chromadb/proto/chroma.proto
@@ -0,0 +1,40 @@
+syntax = "proto3";
+
+package chroma;
+
+enum Operation {
+    ADD = 0;
+    UPDATE = 1;
+    UPSERT = 2;
+    DELETE = 3;
+}
+
+enum ScalarEncoding {
+    FLOAT32 = 0;
+    INT32 = 1;
+}
+
+message Vector {
+    int32 dimension = 1;
+    bytes vector = 2;
+    ScalarEncoding encoding = 3;
+}
+
+message UpdateMetadataValue {
+    oneof value {
+        string string_value = 1;
+        int64 int_value = 2;
+        double float_value = 3;
+    }
+}
+
+message UpdateMetadata {
+    map<string, UpdateMetadataValue> metadata = 1;
+}
+
+message SubmitEmbeddingRecord {
+    string id = 1;
+    optional Vector vector = 2;
+    optional UpdateMetadata metadata = 3;
+    Operation operation = 4;
+}
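
To make the relationship between the pieces above concrete, here is a small illustrative sketch (not part of the patch) of how the topic-name helpers and the `chroma.proto` messages fit together. It assumes the generated `chromadb.proto.chroma_pb2` module from the next hunk is available:

```python
# Illustrative only -- a sketch of how the producer path uses these helpers,
# assuming chromadb and the generated chroma_pb2 module are importable.
import array

import chromadb.proto.chroma_pb2 as proto
from chromadb.ingest.impl.utils import create_topic_name, parse_topic_name

# Topic names round-trip through the same helpers used by PulsarAdmin above.
topic = create_topic_name("default", "default", "my-collection-uuid")
assert topic == "persistent://default/default/my-collection-uuid"
assert parse_topic_name(topic) == ("default", "default", "my-collection-uuid")

# This is the message a producer serializes onto that topic; the float32 packing
# mirrors to_proto_vector in chromadb/proto/convert.py (added later in this patch).
embedding = [0.1, 0.2, 0.3]
record = proto.SubmitEmbeddingRecord(
    id="doc1",
    vector=proto.Vector(
        dimension=len(embedding),
        vector=array.array("f", embedding).tobytes(),
        encoding=proto.ScalarEncoding.FLOAT32,
    ),
    operation=proto.Operation.ADD,
)
payload = record.SerializeToString()  # bytes handed to pulsar's producer.send()
```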
diff --git a/chromadb/proto/chroma_pb2.py b/chromadb/proto/chroma_pb2.py new file mode 100644 index 00000000000..ca8952697af --- /dev/null +++ b/chromadb/proto/chroma_pb2.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: chromadb/proto/chroma.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01\x62\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "chromadb.proto.chroma_pb2", _globals +) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _UPDATEMETADATA_METADATAENTRY._options = None + _UPDATEMETADATA_METADATAENTRY._serialized_options = b"8\001" + _globals["_OPERATION"]._serialized_start = 563 + _globals["_OPERATION"]._serialized_end = 619 + _globals["_SCALARENCODING"]._serialized_start = 621 + _globals["_SCALARENCODING"]._serialized_end = 661 + _globals["_VECTOR"]._serialized_start = 39 + _globals["_VECTOR"]._serialized_end = 124 + _globals["_UPDATEMETADATAVALUE"]._serialized_start = 126 + _globals["_UPDATEMETADATAVALUE"]._serialized_end = 224 + _globals["_UPDATEMETADATA"]._serialized_start = 227 + _globals["_UPDATEMETADATA"]._serialized_end = 377 + _globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_start = 301 + _globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_end = 377 + _globals["_SUBMITEMBEDDINGRECORD"]._serialized_start = 380 + _globals["_SUBMITEMBEDDINGRECORD"]._serialized_end = 561 +# @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/chroma_pb2.pyi b/chromadb/proto/chroma_pb2.pyi new file mode 100644 index 00000000000..b13327e982f --- /dev/null +++ b/chromadb/proto/chroma_pb2.pyi @@ -0,0 +1,247 @@ +""" +@generated by mypy-protobuf. Do not edit manually! 
+isort:skip_file +""" +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.internal.enum_type_wrapper +import google.protobuf.message +import sys +import typing + +if sys.version_info >= (3, 10): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class _Operation: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _OperationEnumTypeWrapper( + google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_Operation.ValueType], + builtins.type, +): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + ADD: _Operation.ValueType # 0 + UPDATE: _Operation.ValueType # 1 + UPSERT: _Operation.ValueType # 2 + DELETE: _Operation.ValueType # 3 + +class Operation(_Operation, metaclass=_OperationEnumTypeWrapper): ... + +ADD: Operation.ValueType # 0 +UPDATE: Operation.ValueType # 1 +UPSERT: Operation.ValueType # 2 +DELETE: Operation.ValueType # 3 +global___Operation = Operation + +class _ScalarEncoding: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _ScalarEncodingEnumTypeWrapper( + google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[ + _ScalarEncoding.ValueType + ], + builtins.type, +): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + FLOAT32: _ScalarEncoding.ValueType # 0 + INT32: _ScalarEncoding.ValueType # 1 + +class ScalarEncoding(_ScalarEncoding, metaclass=_ScalarEncodingEnumTypeWrapper): ... + +FLOAT32: ScalarEncoding.ValueType # 0 +INT32: ScalarEncoding.ValueType # 1 +global___ScalarEncoding = ScalarEncoding + +@typing_extensions.final +class Vector(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DIMENSION_FIELD_NUMBER: builtins.int + VECTOR_FIELD_NUMBER: builtins.int + ENCODING_FIELD_NUMBER: builtins.int + dimension: builtins.int + vector: builtins.bytes + encoding: global___ScalarEncoding.ValueType + def __init__( + self, + *, + dimension: builtins.int = ..., + vector: builtins.bytes = ..., + encoding: global___ScalarEncoding.ValueType = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "dimension", b"dimension", "encoding", b"encoding", "vector", b"vector" + ], + ) -> None: ... + +global___Vector = Vector + +@typing_extensions.final +class UpdateMetadataValue(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STRING_VALUE_FIELD_NUMBER: builtins.int + INT_VALUE_FIELD_NUMBER: builtins.int + FLOAT_VALUE_FIELD_NUMBER: builtins.int + string_value: builtins.str + int_value: builtins.int + float_value: builtins.float + def __init__( + self, + *, + string_value: builtins.str = ..., + int_value: builtins.int = ..., + float_value: builtins.float = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "float_value", + b"float_value", + "int_value", + b"int_value", + "string_value", + b"string_value", + "value", + b"value", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "float_value", + b"float_value", + "int_value", + b"int_value", + "string_value", + b"string_value", + "value", + b"value", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["value", b"value"] + ) -> ( + typing_extensions.Literal["string_value", "int_value", "float_value"] | None + ): ... 
+ +global___UpdateMetadataValue = UpdateMetadataValue + +@typing_extensions.final +class UpdateMetadata(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing_extensions.final + class MetadataEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + @property + def value(self) -> global___UpdateMetadataValue: ... + def __init__( + self, + *, + key: builtins.str = ..., + value: global___UpdateMetadataValue | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["value", b"value"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal["key", b"key", "value", b"value"], + ) -> None: ... + + METADATA_FIELD_NUMBER: builtins.int + @property + def metadata( + self, + ) -> google.protobuf.internal.containers.MessageMap[ + builtins.str, global___UpdateMetadataValue + ]: ... + def __init__( + self, + *, + metadata: collections.abc.Mapping[builtins.str, global___UpdateMetadataValue] + | None = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["metadata", b"metadata"] + ) -> None: ... + +global___UpdateMetadata = UpdateMetadata + +@typing_extensions.final +class SubmitEmbeddingRecord(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ID_FIELD_NUMBER: builtins.int + VECTOR_FIELD_NUMBER: builtins.int + METADATA_FIELD_NUMBER: builtins.int + OPERATION_FIELD_NUMBER: builtins.int + id: builtins.str + @property + def vector(self) -> global___Vector: ... + @property + def metadata(self) -> global___UpdateMetadata: ... + operation: global___Operation.ValueType + def __init__( + self, + *, + id: builtins.str = ..., + vector: global___Vector | None = ..., + metadata: global___UpdateMetadata | None = ..., + operation: global___Operation.ValueType = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_metadata", + b"_metadata", + "_vector", + b"_vector", + "metadata", + b"metadata", + "vector", + b"vector", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_metadata", + b"_metadata", + "_vector", + b"_vector", + "id", + b"id", + "metadata", + b"metadata", + "operation", + b"operation", + "vector", + b"vector", + ], + ) -> None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_metadata", b"_metadata"] + ) -> typing_extensions.Literal["metadata"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_vector", b"_vector"] + ) -> typing_extensions.Literal["vector"] | None: ... 
+ +global___SubmitEmbeddingRecord = SubmitEmbeddingRecord diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py new file mode 100644 index 00000000000..15d1363b05c --- /dev/null +++ b/chromadb/proto/convert.py @@ -0,0 +1,150 @@ +import array +from typing import Optional, Tuple, Union +from chromadb.api.types import Embedding +import chromadb.proto.chroma_pb2 as proto +from chromadb.types import ( + EmbeddingRecord, + Metadata, + Operation, + ScalarEncoding, + SeqId, + SubmitEmbeddingRecord, + Vector, +) + + +def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> proto.Vector: + if encoding == ScalarEncoding.FLOAT32: + as_bytes = array.array("f", vector).tobytes() + proto_encoding = proto.ScalarEncoding.FLOAT32 + elif encoding == ScalarEncoding.INT32: + as_bytes = array.array("i", vector).tobytes() + proto_encoding = proto.ScalarEncoding.INT32 + else: + raise ValueError( + f"Unknown encoding {encoding}, expected one of {ScalarEncoding.FLOAT32} \ + or {ScalarEncoding.INT32}" + ) + + return proto.Vector(dimension=len(vector), vector=as_bytes, encoding=proto_encoding) + + +def from_proto_vector(vector: proto.Vector) -> Tuple[Embedding, ScalarEncoding]: + encoding = vector.encoding + as_array: array.array[float] | array.array[int] + if encoding == proto.ScalarEncoding.FLOAT32: + as_array = array.array("f") + out_encoding = ScalarEncoding.FLOAT32 + elif encoding == proto.ScalarEncoding.INT32: + as_array = array.array("i") + out_encoding = ScalarEncoding.INT32 + else: + raise ValueError( + f"Unknown encoding {encoding}, expected one of \ + {proto.ScalarEncoding.FLOAT32} or {proto.ScalarEncoding.INT32}" + ) + + as_array.frombytes(vector.vector) + return (as_array.tolist(), out_encoding) + + +def from_proto_operation(operation: proto.Operation.ValueType) -> Operation: + if operation == proto.Operation.ADD: + return Operation.ADD + elif operation == proto.Operation.UPDATE: + return Operation.UPDATE + elif operation == proto.Operation.UPSERT: + return Operation.UPSERT + elif operation == proto.Operation.DELETE: + return Operation.DELETE + else: + raise RuntimeError(f"Unknown operation {operation}") # TODO: full error + + +def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]: + if not metadata.metadata: + return None + out_metadata = {} + for key, value in metadata.metadata.items(): + if value.HasField("string_value"): + out_metadata[key] = value.string_value + elif value.HasField("int_value"): + out_metadata[key] = value.int_value + elif value.HasField("float_value"): + out_metadata[key] = value.float_value + else: + raise RuntimeError(f"Unknown metadata value type {value}") + return out_metadata + + +def from_proto_submit( + submit_embedding_record: proto.SubmitEmbeddingRecord, seq_id: SeqId +) -> EmbeddingRecord: + embedding, encoding = from_proto_vector(submit_embedding_record.vector) + record = EmbeddingRecord( + id=submit_embedding_record.id, + seq_id=seq_id, + embedding=embedding, + encoding=encoding, + metadata=from_proto_metadata(submit_embedding_record.metadata), + operation=from_proto_operation(submit_embedding_record.operation), + ) + return record + + +def to_proto_metadata_update_value( + value: Union[str, int, float, None] +) -> proto.UpdateMetadataValue: + if isinstance(value, str): + return proto.UpdateMetadataValue(string_value=value) + elif isinstance(value, int): + return proto.UpdateMetadataValue(int_value=value) + elif isinstance(value, float): + return proto.UpdateMetadataValue(float_value=value) + elif value is None: + return 
proto.UpdateMetadataValue() + else: + raise ValueError( + f"Unknown metadata value type {type(value)}, expected one of str, int, \ + float, or None" + ) + + +def to_proto_operation(operation: Operation) -> proto.Operation.ValueType: + if operation == Operation.ADD: + return proto.Operation.ADD + elif operation == Operation.UPDATE: + return proto.Operation.UPDATE + elif operation == Operation.UPSERT: + return proto.Operation.UPSERT + elif operation == Operation.DELETE: + return proto.Operation.DELETE + else: + raise ValueError( + f"Unknown operation {operation}, expected one of {Operation.ADD}, \ + {Operation.UPDATE}, {Operation.UPDATE}, or {Operation.DELETE}" + ) + + +def to_proto_submit( + submit_record: SubmitEmbeddingRecord, +) -> proto.SubmitEmbeddingRecord: + vector = None + if submit_record["embedding"] is not None and submit_record["encoding"] is not None: + vector = to_proto_vector(submit_record["embedding"], submit_record["encoding"]) + + metadata = None + if submit_record["metadata"] is not None: + metadata = { + k: to_proto_metadata_update_value(v) + for k, v in submit_record["metadata"].items() + } + + return proto.SubmitEmbeddingRecord( + id=submit_record["id"], + vector=vector, + metadata=proto.UpdateMetadata(metadata=metadata) + if metadata is not None + else None, + operation=to_proto_operation(submit_record["operation"]), + ) diff --git a/chromadb/segment/impl/metadata/sqlite.py b/chromadb/segment/impl/metadata/sqlite.py index 8e9649a627d..a7098d7808b 100644 --- a/chromadb/segment/impl/metadata/sqlite.py +++ b/chromadb/segment/impl/metadata/sqlite.py @@ -469,9 +469,9 @@ def delete(self) -> None: def _encode_seq_id(seq_id: SeqId) -> bytes: """Encode a SeqID into a byte array""" - if seq_id.bit_length() < 64: + if seq_id.bit_length() <= 64: return int.to_bytes(seq_id, 8, "big") - elif seq_id.bit_length() < 192: + elif seq_id.bit_length() <= 192: return int.to_bytes(seq_id, 24, "big") else: raise ValueError(f"Unsupported SeqID: {seq_id}") diff --git a/chromadb/segment/impl/vector/local_persistent_hnsw.py b/chromadb/segment/impl/vector/local_persistent_hnsw.py index 6e1df7b1f1f..f8c74bd0fe7 100644 --- a/chromadb/segment/impl/vector/local_persistent_hnsw.py +++ b/chromadb/segment/impl/vector/local_persistent_hnsw.py @@ -207,7 +207,6 @@ def _write_records(self, records: Sequence[EmbeddingRecord]) -> None: """Add a batch of embeddings to the index""" if not self._running: raise RuntimeError("Cannot add embeddings to stopped component") - with WriteRWLock(self._lock): for record in records: if record["embedding"] is not None: diff --git a/chromadb/test/ingest/test_producer_consumer.py b/chromadb/test/ingest/test_producer_consumer.py index 84aa69ffd07..1163889a246 100644 --- a/chromadb/test/ingest/test_producer_consumer.py +++ b/chromadb/test/ingest/test_producer_consumer.py @@ -1,3 +1,4 @@ +import asyncio import os import shutil import tempfile @@ -16,6 +17,7 @@ ) from chromadb.ingest import Producer, Consumer from chromadb.db.impl.sqlite import SqliteDB +from chromadb.ingest.impl.utils import create_topic_name from chromadb.test.conftest import ProducerFn from chromadb.types import ( SubmitEmbeddingRecord, @@ -51,8 +53,33 @@ def sqlite_persistent() -> Generator[Tuple[Producer, Consumer], None, None]: shutil.rmtree(save_path) +def pulsar() -> Generator[Tuple[Producer, Consumer], None, None]: + """Fixture generator for pulsar Producer + Consumer. This fixture requires a running + pulsar cluster. 
You can use bin/cluster-test.sh to start a standalone pulsar and run this test + """ + system = System( + Settings( + allow_reset=True, + chroma_producer_impl="chromadb.ingest.impl.pulsar.PulsarProducer", + chroma_consumer_impl="chromadb.ingest.impl.pulsar.PulsarConsumer", + pulsar_broker_url="localhost", + pulsar_admin_port="8080", + pulsar_broker_port="6650", + ) + ) + producer = system.require(Producer) + consumer = system.require(Consumer) + system.start() + yield producer, consumer + system.stop() + + def fixtures() -> List[Callable[[], Generator[Tuple[Producer, Consumer], None, None]]]: - return [sqlite, sqlite_persistent] + fixtures = [sqlite, sqlite_persistent] + if "CHROMA_CLUSTER_TEST_ONLY" in os.environ: + fixtures = [pulsar] + + return fixtures @pytest.fixture(scope="module", params=fixtures()) @@ -89,14 +116,20 @@ class CapturingConsumeFn: waiters: List[Tuple[int, Event]] def __init__(self) -> None: + """A function that captures embeddings and allows you to wait for a certain + number of embeddings to be available. It must be constructed in the thread with + the main event loop + """ self.embeddings = [] self.waiters = [] + self._loop = asyncio.get_event_loop() def __call__(self, embeddings: Sequence[EmbeddingRecord]) -> None: self.embeddings.extend(embeddings) for n, event in self.waiters: if len(self.embeddings) >= n: - event.set() + # event.set() is not thread safe, so we need to call it in the main event loop + self._loop.call_soon_threadsafe(event.set) async def get(self, n: int, timeout_secs: int = 10) -> Sequence[EmbeddingRecord]: "Wait until at least N embeddings are available, then return all embeddings" @@ -132,6 +165,10 @@ def assert_records_match( assert_approx_equal(inserted["embedding"], consumed["embedding"]) +def full_topic_name(topic_name: str) -> str: + return create_topic_name("default", "default", topic_name) + + @pytest.mark.asyncio async def test_backfill( producer_consumer: Tuple[Producer, Consumer], @@ -140,12 +177,14 @@ async def test_backfill( ) -> None: producer, consumer = producer_consumer producer.reset_state() + consumer.reset_state() - producer.create_topic("test_topic") - embeddings = produce_fns(producer, "test_topic", sample_embeddings, 3)[0] + topic_name = full_topic_name("test_topic") + producer.create_topic(topic_name) + embeddings = produce_fns(producer, topic_name, sample_embeddings, 3)[0] consume_fn = CapturingConsumeFn() - consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid()) + consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid()) recieved = await consume_fn.get(3) assert_records_match(embeddings, recieved) @@ -158,18 +197,21 @@ async def test_notifications( ) -> None: producer, consumer = producer_consumer producer.reset_state() - producer.create_topic("test_topic") + consumer.reset_state() + topic_name = full_topic_name("test_topic") + + producer.create_topic(topic_name) embeddings: List[SubmitEmbeddingRecord] = [] consume_fn = CapturingConsumeFn() - consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid()) + consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid()) for i in range(10): e = next(sample_embeddings) embeddings.append(e) - producer.submit_embedding("test_topic", e) + producer.submit_embedding(topic_name, e) received = await consume_fn.get(i + 1) assert_records_match(embeddings, received) @@ -181,8 +223,11 @@ async def test_multiple_topics( ) -> None: producer, consumer = producer_consumer producer.reset_state() - producer.create_topic("test_topic_1") - 
producer.create_topic("test_topic_2") + consumer.reset_state() + topic_name_1 = full_topic_name("test_topic_1") + topic_name_2 = full_topic_name("test_topic_2") + producer.create_topic(topic_name_1) + producer.create_topic(topic_name_2) embeddings_1: List[SubmitEmbeddingRecord] = [] embeddings_2: List[SubmitEmbeddingRecord] = [] @@ -190,19 +235,19 @@ async def test_multiple_topics( consume_fn_1 = CapturingConsumeFn() consume_fn_2 = CapturingConsumeFn() - consumer.subscribe("test_topic_1", consume_fn_1, start=consumer.min_seqid()) - consumer.subscribe("test_topic_2", consume_fn_2, start=consumer.min_seqid()) + consumer.subscribe(topic_name_1, consume_fn_1, start=consumer.min_seqid()) + consumer.subscribe(topic_name_2, consume_fn_2, start=consumer.min_seqid()) for i in range(10): e_1 = next(sample_embeddings) embeddings_1.append(e_1) - producer.submit_embedding("test_topic_1", e_1) + producer.submit_embedding(topic_name_1, e_1) results_2 = await consume_fn_1.get(i + 1) assert_records_match(embeddings_1, results_2) e_2 = next(sample_embeddings) embeddings_2.append(e_2) - producer.submit_embedding("test_topic_2", e_2) + producer.submit_embedding(topic_name_2, e_2) results_2 = await consume_fn_2.get(i + 1) assert_records_match(embeddings_2, results_2) @@ -215,21 +260,23 @@ async def test_start_seq_id( ) -> None: producer, consumer = producer_consumer producer.reset_state() - producer.create_topic("test_topic") + consumer.reset_state() + topic_name = full_topic_name("test_topic") + producer.create_topic(topic_name) consume_fn_1 = CapturingConsumeFn() consume_fn_2 = CapturingConsumeFn() - consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid()) + consumer.subscribe(topic_name, consume_fn_1, start=consumer.min_seqid()) - embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0] + embeddings = produce_fns(producer, topic_name, sample_embeddings, 5)[0] results_1 = await consume_fn_1.get(5) assert_records_match(embeddings, results_1) start = consume_fn_1.embeddings[-1]["seq_id"] - consumer.subscribe("test_topic", consume_fn_2, start=start) - second_embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0] + consumer.subscribe(topic_name, consume_fn_2, start=start) + second_embeddings = produce_fns(producer, topic_name, sample_embeddings, 5)[0] assert isinstance(embeddings, list) embeddings.extend(second_embeddings) results_2 = await consume_fn_2.get(5) @@ -244,20 +291,22 @@ async def test_end_seq_id( ) -> None: producer, consumer = producer_consumer producer.reset_state() - producer.create_topic("test_topic") + consumer.reset_state() + topic_name = full_topic_name("test_topic") + producer.create_topic(topic_name) consume_fn_1 = CapturingConsumeFn() consume_fn_2 = CapturingConsumeFn() - consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid()) + consumer.subscribe(topic_name, consume_fn_1, start=consumer.min_seqid()) - embeddings = produce_fns(producer, "test_topic", sample_embeddings, 10)[0] + embeddings = produce_fns(producer, topic_name, sample_embeddings, 10)[0] results_1 = await consume_fn_1.get(10) assert_records_match(embeddings, results_1) end = consume_fn_1.embeddings[-5]["seq_id"] - consumer.subscribe("test_topic", consume_fn_2, start=consumer.min_seqid(), end=end) + consumer.subscribe(topic_name, consume_fn_2, start=consumer.min_seqid(), end=end) results_2 = await consume_fn_2.get(6) assert_records_match(embeddings[:6], results_2) @@ -274,14 +323,16 @@ async def test_submit_batch( ) -> None: producer, consumer = 
producer_consumer producer.reset_state() + consumer.reset_state() + topic_name = full_topic_name("test_topic") embeddings = [next(sample_embeddings) for _ in range(100)] - producer.create_topic("test_topic") - producer.submit_embeddings("test_topic", embeddings=embeddings) + producer.create_topic(topic_name) + producer.submit_embeddings(topic_name, embeddings=embeddings) consume_fn = CapturingConsumeFn() - consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid()) + consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid()) recieved = await consume_fn.get(100) assert_records_match(embeddings, recieved) @@ -295,13 +346,16 @@ async def test_multiple_topics_batch( ) -> None: producer, consumer = producer_consumer producer.reset_state() + consumer.reset_state() - N_TOPICS = 100 + N_TOPICS = 2 consume_fns = [CapturingConsumeFn() for _ in range(N_TOPICS)] for i in range(N_TOPICS): - producer.create_topic(f"test_topic_{i}") + producer.create_topic(full_topic_name(f"test_topic_{i}")) consumer.subscribe( - f"test_topic_{i}", consume_fns[i], start=consumer.min_seqid() + full_topic_name(f"test_topic_{i}"), + consume_fns[i], + start=consumer.min_seqid(), ) embeddings_n: List[List[SubmitEmbeddingRecord]] = [[] for _ in range(N_TOPICS)] @@ -310,17 +364,17 @@ async def test_multiple_topics_batch( N_TO_PRODUCE = 100 total_produced = 0 for i in range(N_TO_PRODUCE // PRODUCE_BATCH_SIZE): - for i in range(N_TOPICS): - embeddings_n[i].extend( + for n in range(N_TOPICS): + embeddings_n[n].extend( produce_fns( producer, - f"test_topic_{i}", + full_topic_name(f"test_topic_{n}"), sample_embeddings, PRODUCE_BATCH_SIZE, )[0] ) - recieved = await consume_fns[i].get(total_produced + PRODUCE_BATCH_SIZE) - assert_records_match(embeddings_n[i], recieved) + recieved = await consume_fns[n].get(total_produced + PRODUCE_BATCH_SIZE) + assert_records_match(embeddings_n[n], recieved) total_produced += PRODUCE_BATCH_SIZE @@ -331,19 +385,21 @@ async def test_max_batch_size( ) -> None: producer, consumer = producer_consumer producer.reset_state() - max_batch_size = producer_consumer[0].max_batch_size + consumer.reset_state() + topic_name = full_topic_name("test_topic") + max_batch_size = producer.max_batch_size assert max_batch_size > 0 # Make sure that we can produce a batch of size max_batch_size embeddings = [next(sample_embeddings) for _ in range(max_batch_size)] consume_fn = CapturingConsumeFn() - consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid()) - producer.submit_embeddings("test_topic", embeddings=embeddings) + consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid()) + producer.submit_embeddings(topic_name, embeddings=embeddings) received = await consume_fn.get(max_batch_size, timeout_secs=120) assert_records_match(embeddings, received) embeddings = [next(sample_embeddings) for _ in range(max_batch_size + 1)] # Make sure that we can't produce a batch of size > max_batch_size with pytest.raises(ValueError) as e: - producer.submit_embeddings("test_topic", embeddings=embeddings) + producer.submit_embeddings(topic_name, embeddings=embeddings) assert "Cannot submit more than" in str(e.value) diff --git a/docker-compose.cluster.yml b/docker-compose.cluster.yml new file mode 100644 index 00000000000..d36ed16906f --- /dev/null +++ b/docker-compose.cluster.yml @@ -0,0 +1,66 @@ +# This docker compose file is not meant to be used. It is a work in progress +# for the distributed version of Chroma. It is not yet functional. 
+ +version: '3.9' + +networks: + net: + driver: bridge + +services: + server: + image: server + build: + context: . + dockerfile: Dockerfile + volumes: + - ./:/chroma + - index_data:/index_data + command: uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 8000 --log-config log_config.yml + environment: + - IS_PERSISTENT=TRUE + - CHROMA_PRODUCER_IMPL=chromadb.ingest.impl.pulsar.PulsarProducer + - CHROMA_CONSUMER_IMPL=chromadb.ingest.impl.pulsar.PulsarConsumer + - PULSAR_BROKER_URL=pulsar + - PULSAR_BROKER_PORT=6650 + - PULSAR_ADMIN_PORT=8080 + ports: + - 8000:8000 + depends_on: + pulsar: + condition: service_healthy + networks: + - net + + pulsar: + image: apachepulsar/pulsar + volumes: + - pulsardata:/pulsar/data + - pulsarconf:/pulsar/conf + command: bin/pulsar standalone + ports: + - 6650:6650 + - 8080:8080 + networks: + - net + healthcheck: + test: + [ + "CMD", + "curl", + "-f", + "localhost:8080/admin/v2/brokers/health" + ] + interval: 3s + timeout: 1m + retries: 10 + +volumes: + index_data: + driver: local + backups: + driver: local + pulsardata: + driver: local + pulsarconf: + driver: local diff --git a/requirements_dev.txt b/requirements_dev.txt index 68546b27796..9354d39b725 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -3,8 +3,10 @@ build httpx hypothesis hypothesis[numpy] +mypy-protobuf pre-commit pytest pytest-asyncio setuptools_scm +types-protobuf types-requests==2.30.0.0 From 5436bd5de1930256c98cd5f2eb3cbcd0a6e73f2b Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Wed, 20 Sep 2023 20:28:38 +0300 Subject: [PATCH 25/39] [ENH]: Support for $in and $nin metadata filters (#1151) Refs: #1105 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - JS Client support for $in and $nin ## Test plan *How are these changes tested?* - [x] Tests pass locally `yarn test` for js ## Documentation Changes TBD --- clients/js/src/types.ts | 8 +-- clients/js/test/query.collection.test.ts | 68 ++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/clients/js/src/types.ts b/clients/js/src/types.ts index 8787d5e5659..1f1dd04c4c8 100644 --- a/clients/js/src/types.ts +++ b/clients/js/src/types.ts @@ -20,13 +20,15 @@ export type IDs = ID[]; export type PositiveInteger = number; -type LiteralValue = string | number; +type LiteralValue = string | number | boolean; +type ListLiteralValue = LiteralValue[]; type LiteralNumber = number; type LogicalOperator = "$and" | "$or"; +type InclusionOperator = "$in" | "$nin"; type WhereOperator = "$gt" | "$gte" | "$lt" | "$lte" | "$ne" | "$eq"; type OperatorExpression = { - [key in WhereOperator | LogicalOperator]?: LiteralValue | LiteralNumber; + [key in WhereOperator | InclusionOperator | LogicalOperator ]?: LiteralValue | ListLiteralValue; }; type BaseWhere = { @@ -77,4 +79,4 @@ export type CollectionMetadata = Record; // see all options here: https://www.jsdocs.io/package/@types/node-fetch#RequestInit export type ConfigOptions = { options?: RequestInit; -}; \ No newline at end of file +}; diff --git a/clients/js/test/query.collection.test.ts b/clients/js/test/query.collection.test.ts index 05125a27ffa..878ed0a71df 100644 --- a/clients/js/test/query.collection.test.ts +++ b/clients/js/test/query.collection.test.ts @@ -86,3 +86,71 @@ test("it should query a collection with text", async () => { expect.arrayContaining(results.documents[0]) ); }) + + +test("it should query a collection with text and where", async () => { + await chroma.reset(); + let 
embeddingFunction = new TestEmbeddingFunction(); + const collection = await chroma.createCollection({ name: "test", embeddingFunction: embeddingFunction }); + await collection.add({ ids: IDS, embeddings: EMBEDDINGS, metadatas: METADATAS, documents: DOCUMENTS }); + + const results = await collection.query({ + queryTexts: ["test"], + nResults: 3, + where: { "float_value" : 2 } + }); + + expect(results).toBeDefined(); + expect(results).toBeInstanceOf(Object); + expect(results.ids.length).toBe(1); + expect(["test3"]).toEqual(expect.arrayContaining(results.ids[0])); + expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids[0])); + expect(["This is a third test"]).toEqual( + expect.arrayContaining(results.documents[0]) + ); +}) + + +test("it should query a collection with text and where in", async () => { + await chroma.reset(); + let embeddingFunction = new TestEmbeddingFunction(); + const collection = await chroma.createCollection({ name: "test", embeddingFunction: embeddingFunction }); + await collection.add({ ids: IDS, embeddings: EMBEDDINGS, metadatas: METADATAS, documents: DOCUMENTS }); + + const results = await collection.query({ + queryTexts: ["test"], + nResults: 3, + where: { "float_value" : { '$in': [2,5,10] }} + }); + + expect(results).toBeDefined(); + expect(results).toBeInstanceOf(Object); + expect(results.ids.length).toBe(1); + expect(["test3"]).toEqual(expect.arrayContaining(results.ids[0])); + expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids[0])); + expect(["This is a third test"]).toEqual( + expect.arrayContaining(results.documents[0]) + ); +}) + +test("it should query a collection with text and where nin", async () => { + await chroma.reset(); + let embeddingFunction = new TestEmbeddingFunction(); + const collection = await chroma.createCollection({ name: "test", embeddingFunction: embeddingFunction }); + await collection.add({ ids: IDS, embeddings: EMBEDDINGS, metadatas: METADATAS, documents: DOCUMENTS }); + + const results = await collection.query({ + queryTexts: ["test"], + nResults: 3, + where: { "float_value" : { '$nin': [-2,0] }} + }); + + expect(results).toBeDefined(); + expect(results).toBeInstanceOf(Object); + expect(results.ids.length).toBe(1); + expect(["test3"]).toEqual(expect.arrayContaining(results.ids[0])); + expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids[0])); + expect(["This is a third test"]).toEqual( + expect.arrayContaining(results.documents[0]) + ); +}) From 35991bfe87b036a105e0609a9061fde58aee73d4 Mon Sep 17 00:00:00 2001 From: calvintwr Date: Fri, 22 Sep 2023 01:30:58 +0800 Subject: [PATCH 26/39] export IncludeEnum as it is required by #get and #query (#1167) ## Description of changes The `IncludeEnum` enum is not exported, cause lint errors when using `.get` or `.query`, as follows: ```js const result = await collection.query({ queryTexts: [query], // THIS LINE WILL PRODUCE LINT ERROR as it needs IncludeEnum.Distances etc. 
include: ['distances', 'documents', 'metadatas'], nResults: 2, }) ``` ## Test plan Nil ## Documentation Changes Nil --- clients/js/src/index.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/clients/js/src/index.ts b/clients/js/src/index.ts index 192f39c9477..97c119dfcb1 100644 --- a/clients/js/src/index.ts +++ b/clients/js/src/index.ts @@ -5,3 +5,4 @@ export { OpenAIEmbeddingFunction } from './embeddings/OpenAIEmbeddingFunction'; export { CohereEmbeddingFunction } from './embeddings/CohereEmbeddingFunction'; export { WebAIEmbeddingFunction } from './embeddings/WebAIEmbeddingFunction'; export { TransformersEmbeddingFunction } from './embeddings/TransformersEmbeddingFunction'; +export { IncludeEnum } from './types'; \ No newline at end of file From 317d547d7a5065272e8623ab727fb857ce750946 Mon Sep 17 00:00:00 2001 From: Ben Eggers <64657842+beggers@users.noreply.github.com> Date: Fri, 22 Sep 2023 15:49:29 -0700 Subject: [PATCH 27/39] Fix failing example terraform (#1175) ## Description of changes `lifecycle` blocks don't allow variables. Right now our example deployment for AWS doesn't work. @tazarov has a fix for this and a few other things in https://github.com/chroma-core/chroma/pull/1173 but I'd like to get the basic fix out before the weekend. ## Test plan local `terraform init` failed before, now works. ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* --- examples/deployments/aws-terraform/README.md | 3 +++ examples/deployments/aws-terraform/chroma.tf | 2 +- examples/deployments/aws-terraform/variables.tf | 6 ------ 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/examples/deployments/aws-terraform/README.md b/examples/deployments/aws-terraform/README.md index dc47929821a..d48ef3227e9 100644 --- a/examples/deployments/aws-terraform/README.md +++ b/examples/deployments/aws-terraform/README.md @@ -140,6 +140,9 @@ ssh -i ./chroma-aws ubuntu@$instance_public_ip ``` ### 5. Destroy your Chroma instance + +You will need to change `prevent_destroy` to `false` in the `aws_ebs_volume` in `chroma.tf`. 
+ ```bash terraform destroy -auto-approve ``` diff --git a/examples/deployments/aws-terraform/chroma.tf b/examples/deployments/aws-terraform/chroma.tf index 5f0db03ed51..fca553ae90d 100644 --- a/examples/deployments/aws-terraform/chroma.tf +++ b/examples/deployments/aws-terraform/chroma.tf @@ -111,7 +111,7 @@ resource "aws_ebs_volume" "chroma-volume" { } lifecycle { - prevent_destroy = var.prevent_chroma_data_volume_delete # size in GBs + prevent_destroy = true } } diff --git a/examples/deployments/aws-terraform/variables.tf b/examples/deployments/aws-terraform/variables.tf index 84e086116b5..15e0cc81e40 100644 --- a/examples/deployments/aws-terraform/variables.tf +++ b/examples/deployments/aws-terraform/variables.tf @@ -86,9 +86,3 @@ variable "chroma_data_volume_size" { type = number default = 20 } - -variable "prevent_chroma_data_volume_delete" { - description = "Prevent the chroma data volume from being deleted when the instance is terminated" - type = bool - default = false -} From c7a0414ea7d95a5217c82f828ac6c95e31561791 Mon Sep 17 00:00:00 2001 From: Ben Eggers <64657842+beggers@users.noreply.github.com> Date: Fri, 22 Sep 2023 15:49:58 -0700 Subject: [PATCH 28/39] [ENH] Metric batching and more metrics (#1163) ## Description of changes This PR accomplishes two things: - Adds batching to metrics to decrease load to Posthog - Adds more metric instrumentation Each `TelemetryEvent` type now has a `batch_size` member defining how many of that Event to include in a batch. `TelemetryEvent`s with `batch_size > 1` must also define `can_batch()` and `batch()` methods to do the actual batching -- our posthog client can't do this itself since different `TelemetryEvent`s use different count fields. The Posthog client combines events until they hit their `batch_size` then fires them off as one event. NB: this means we can drop up to `batch_size` events -- since we only batch `add()` calls right now this seems fine, though we may want to address it in the future. As for the additional telemetry, I pretty much copied Anton's draft https://github.com/chroma-core/chroma/pull/859 with some minor changes. Other considerations: Maybe we should implement `can_batch()` and `batch()` on all events, even those which don't currently use them? I'd prefer not to leave dead code hanging around but happy to go either way. I created a ticket for the type ignores: https://github.com/chroma-core/chroma/issues/1169 ## Test plan pytest passes modulo a couple unrelated failures With `print(event.properties)` in posthog client's `_direct_capture()`: ``` >>> import chromadb >>> client = chromadb.Client() {'batch_size': 1} >>> collection = client.create_collection("sample_collection") {'batch_size': 1, 'collection_uuid': 'bb19d790-4ec7-436c-b781-46dab047625d', 'embedding_function': 'ONNXMiniLM_L6_V2'} >>> collection.add( ... documents=["This is document1", "This is document2"], # we embed for you, or bring your own ... metadatas=[{"source": "notion"}, {"source": "google-docs"}], # filter on arbitrary metadata! ... ids=["doc1", "doc2"], # must be unique for each doc ... ) {'batch_size': 1, 'collection_uuid': 'bb19d790-4ec7-436c-b781-46dab047625d', 'add_amount': 2, 'with_documents': 2, 'with_metadata': 2} >>> for i in range(50): ... collection.add(documents=[str(i)], ids=[str(i)]) ... 
{'batch_size': 20, 'collection_uuid': 'bb19d790-4ec7-436c-b781-46dab047625d', 'add_amount': 20, 'with_documents': 20, 'with_metadata': 0} {'batch_size': 20, 'collection_uuid': 'bb19d790-4ec7-436c-b781-46dab047625d', 'add_amount': 20, 'with_documents': 20, 'with_metadata': 0} >>> for i in range(50): ... collection.add(documents=[str(i) + ' ' + str(n) for n in range(20)], ids=[str(i) + ' ' + str(n) for n in range(20)]) ... {'batch_size': 20, 'collection_uuid': 'bb19d790-4ec7-436c-b781-46dab047625d', 'add_amount': 210, 'with_documents': 210, 'with_metadata': 0} {'batch_size': 20, 'collection_uuid': 'bb19d790-4ec7-436c-b781-46dab047625d', 'add_amount': 400, 'with_documents': 400, 'with_metadata': 0} {'batch_size': 20, 'collection_uuid': 'bb19d790-4ec7-436c-b781-46dab047625d', 'add_amount': 400, 'with_documents': 400, 'with_metadata': 0} ``` ## Documentation Changes https://github.com/chroma-core/docs/pull/139 https://github.com/chroma-core/docs/commit/a4fd57d4d2cc3cae00cbb4a9245b938e2f0d1842 --- chromadb/__init__.py | 4 +- chromadb/api/segment.py | 62 ++++++++++- chromadb/telemetry/__init__.py | 24 ++++- chromadb/telemetry/events.py | 148 ++++++++++++++++++++++++-- chromadb/telemetry/posthog.py | 20 ++++ chromadb/utils/embedding_functions.py | 26 +++-- 6 files changed, 258 insertions(+), 26 deletions(-) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 0ff5244a80f..ad7d3d4f70b 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -2,8 +2,6 @@ import logging import sqlite3 import chromadb.config -from chromadb.telemetry.events import ClientStartEvent -from chromadb.telemetry import Telemetry from chromadb.config import Settings, System from chromadb.api import API from chromadb.api.models.Collection import Collection @@ -38,6 +36,8 @@ "QueryResult", "GetResult", ] +from chromadb.telemetry.events import ClientStartEvent +from chromadb.telemetry import Telemetry logger = logging.getLogger(__name__) diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 00002f46d27..fd2f08ec63b 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -29,7 +29,14 @@ validate_where_document, validate_batch, ) -from chromadb.telemetry.events import CollectionAddEvent, CollectionDeleteEvent +from chromadb.telemetry.events import ( + CollectionAddEvent, + CollectionDeleteEvent, + CollectionGetEvent, + CollectionUpdateEvent, + CollectionQueryEvent, + ClientCreateCollectionEvent, +) import chromadb.types as t @@ -140,6 +147,13 @@ def create_collection( for segment in segments: self._sysdb.create_segment(segment) + self._telemetry_client.capture( + ClientCreateCollectionEvent( + collection_uuid=str(id), + embedding_function=embedding_function.__class__.__name__, + ) + ) + return Collection( client=self, id=id, @@ -263,7 +277,14 @@ def _add( records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) - self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids))) + self._telemetry_client.capture( + CollectionAddEvent( + collection_uuid=str(collection_id), + add_amount=len(ids), + with_metadata=len(ids) if metadatas is not None else 0, + with_documents=len(ids) if documents is not None else 0, + ) + ) return True @override @@ -293,6 +314,16 @@ def _update( records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) + self._telemetry_client.capture( + CollectionUpdateEvent( + collection_uuid=str(collection_id), + update_amount=len(ids), + with_embeddings=len(embeddings) if embeddings 
else 0, + with_metadata=len(metadatas) if metadatas else 0, + with_documents=len(documents) if documents else 0, + ) + ) + return True @override @@ -377,6 +408,16 @@ def _get( if "documents" in include: documents = [_doc(m) for m in metadatas] + self._telemetry_client.capture( + CollectionGetEvent( + collection_uuid=str(collection_id), + ids_count=len(ids) if ids else 0, + limit=limit if limit else 0, + include_metadata="metadatas" in include, + include_documents="documents" in include, + ) + ) + return GetResult( ids=[r["id"] for r in records], embeddings=[r["embedding"] for r in vectors] @@ -441,7 +482,9 @@ def _delete( self._producer.submit_embeddings(coll["topic"], records_to_submit) self._telemetry_client.capture( - CollectionDeleteEvent(str(collection_id), len(ids_to_delete)) + CollectionDeleteEvent( + collection_uuid=str(collection_id), delete_amount=len(ids_to_delete) + ) ) return ids_to_delete @@ -528,6 +571,19 @@ def _query( doc_list = [_doc(m) for m in metadata_list] documents.append(doc_list) # type: ignore + self._telemetry_client.capture( + CollectionQueryEvent( + collection_uuid=str(collection_id), + query_amount=len(query_embeddings), + n_results=n_results, + with_metadata_filter=where is not None, + with_document_filter=where_document is not None, + include_metadatas="metadatas" in include, + include_documents="documents" in include, + include_distances="distances" in include, + ) + ) + return QueryResult( ids=ids, distances=distances if distances else None, diff --git a/chromadb/telemetry/__init__.py b/chromadb/telemetry/__init__.py index db962549267..d20b8e5d71c 100644 --- a/chromadb/telemetry/__init__.py +++ b/chromadb/telemetry/__init__.py @@ -1,5 +1,4 @@ from abc import abstractmethod -from dataclasses import asdict, dataclass import os from typing import Callable, ClassVar, Dict, Any import uuid @@ -22,13 +21,30 @@ class ServerContext(Enum): FASTAPI = "FastAPI" -@dataclass class TelemetryEvent: - name: ClassVar[str] + max_batch_size: ClassVar[int] = 1 + batch_size: int + + def __init__(self, batch_size: int = 1): + self.batch_size = batch_size @property def properties(self) -> Dict[str, Any]: - return asdict(self) + return self.__dict__ + + @property + def name(self) -> str: + return self.__class__.__name__ + + # A batch key is used to determine whether two events can be batched together. + # If a TelemetryEvent's max_batch_size > 1, batch_key() and batch() MUST be implemented. + # Otherwise they are ignored. 
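+    # For example, CollectionAddEvent (defined in events.py below) keys its batches on
+    # the collection UUID, so successive add() events for one collection are merged by
+    # batch() until max_batch_size events have been folded in and the batch is sent.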
+ @property + def batch_key(self) -> str: + return self.name + + def batch(self, other: "TelemetryEvent") -> "TelemetryEvent": + raise NotImplementedError class RepeatedTelemetry: diff --git a/chromadb/telemetry/events.py b/chromadb/telemetry/events.py index 64c77574f9f..34c6264fcc9 100644 --- a/chromadb/telemetry/events.py +++ b/chromadb/telemetry/events.py @@ -1,27 +1,153 @@ -from dataclasses import dataclass -from typing import ClassVar +from typing import cast, ClassVar from chromadb.telemetry import TelemetryEvent +from chromadb.utils.embedding_functions import get_builtins -@dataclass class ClientStartEvent(TelemetryEvent): - name: ClassVar[str] = "client_start" + def __init__(self) -> None: + super().__init__() -@dataclass -class ServerStartEvent(TelemetryEvent): - name: ClassVar[str] = "server_start" +class ClientCreateCollectionEvent(TelemetryEvent): + collection_uuid: str + embedding_function: str + + def __init__(self, collection_uuid: str, embedding_function: str): + super().__init__() + self.collection_uuid = collection_uuid + + embedding_function_names = get_builtins() + + self.embedding_function = ( + embedding_function + if embedding_function in embedding_function_names + else "custom" + ) -@dataclass class CollectionAddEvent(TelemetryEvent): - name: ClassVar[str] = "collection_add" + max_batch_size: ClassVar[int] = 20 collection_uuid: str add_amount: int + with_documents: int + with_metadata: int + + def __init__( + self, + collection_uuid: str, + add_amount: int, + with_documents: int, + with_metadata: int, + batch_size: int = 1, + ): + super().__init__() + self.collection_uuid = collection_uuid + self.add_amount = add_amount + self.with_documents = with_documents + self.with_metadata = with_metadata + self.batch_size = batch_size + + @property + def batch_key(self) -> str: + return self.collection_uuid + self.name + + def batch(self, other: "TelemetryEvent") -> "CollectionAddEvent": + if not self.batch_key == other.batch_key: + raise ValueError("Cannot batch events") + other = cast(CollectionAddEvent, other) + total_amount = self.add_amount + other.add_amount + return CollectionAddEvent( + collection_uuid=self.collection_uuid, + add_amount=total_amount, + with_documents=self.with_documents + other.with_documents, + with_metadata=self.with_metadata + other.with_metadata, + batch_size=self.batch_size + other.batch_size, + ) + + +class CollectionUpdateEvent(TelemetryEvent): + collection_uuid: str + update_amount: int + with_embeddings: int + with_metadata: int + with_documents: int + + def __init__( + self, + collection_uuid: str, + update_amount: int, + with_embeddings: int, + with_metadata: int, + with_documents: int, + ): + super().__init__() + self.collection_uuid = collection_uuid + self.update_amount = update_amount + self.with_embeddings = with_embeddings + self.with_metadata = with_metadata + self.with_documents = with_documents + + +class CollectionQueryEvent(TelemetryEvent): + collection_uuid: str + query_amount: int + with_metadata_filter: bool + with_document_filter: bool + n_results: int + include_metadatas: bool + include_documents: bool + include_distances: bool + + def __init__( + self, + collection_uuid: str, + query_amount: int, + with_metadata_filter: bool, + with_document_filter: bool, + n_results: int, + include_metadatas: bool, + include_documents: bool, + include_distances: bool, + ): + super().__init__() + self.collection_uuid = collection_uuid + self.query_amount = query_amount + self.with_metadata_filter = with_metadata_filter + 
self.with_document_filter = with_document_filter + self.n_results = n_results + self.include_metadatas = include_metadatas + self.include_documents = include_documents + self.include_distances = include_distances + + +class CollectionGetEvent(TelemetryEvent): + collection_uuid: str + ids_count: int + limit: int + include_metadata: bool + include_documents: bool + + def __init__( + self, + collection_uuid: str, + ids_count: int, + limit: int, + include_metadata: bool, + include_documents: bool, + ): + super().__init__() + self.collection_uuid = collection_uuid + self.ids_count = ids_count + self.limit = limit + self.include_metadata = include_metadata + self.include_documents = include_documents -@dataclass class CollectionDeleteEvent(TelemetryEvent): - name: ClassVar[str] = "collection_delete" collection_uuid: str delete_amount: int + + def __init__(self, collection_uuid: str, delete_amount: int): + super().__init__() + self.collection_uuid = collection_uuid + self.delete_amount = delete_amount diff --git a/chromadb/telemetry/posthog.py b/chromadb/telemetry/posthog.py index a20e20dd257..184904531ef 100644 --- a/chromadb/telemetry/posthog.py +++ b/chromadb/telemetry/posthog.py @@ -1,6 +1,7 @@ import posthog import logging import sys +from typing import Any, Dict, Set from chromadb.config import System from chromadb.telemetry import Telemetry, TelemetryEvent from overrides import override @@ -21,10 +22,29 @@ def __init__(self, system: System): posthog_logger = logging.getLogger("posthog") # Silence posthog's logging posthog_logger.disabled = True + + self.batched_events: Dict[str, TelemetryEvent] = {} + self.seen_event_types: Set[Any] = set() + super().__init__(system) @override def capture(self, event: TelemetryEvent) -> None: + if event.max_batch_size == 1 or event.batch_key not in self.seen_event_types: + self.seen_event_types.add(event.batch_key) + self._direct_capture(event) + return + batch_key = event.batch_key + if batch_key not in self.batched_events: + self.batched_events[batch_key] = event + return + batched_event = self.batched_events[batch_key].batch(event) + self.batched_events[batch_key] = batched_event + if batched_event.batch_size >= batched_event.max_batch_size: + self._direct_capture(batched_event) + del self.batched_events[batch_key] + + def _direct_capture(self, event: TelemetryEvent) -> None: try: posthog.capture( self.user_id, diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 124213c365b..aaef53c01e2 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -9,6 +9,8 @@ import numpy as np import numpy.typing as npt import importlib +import inspect +import sys from typing import Optional try: @@ -43,7 +45,7 @@ def __init__( self._normalize_embeddings = normalize_embeddings def __call__(self, texts: Documents) -> Embeddings: - return self._model.encode( + return self._model.encode( # type: ignore list(texts), convert_to_numpy=True, normalize_embeddings=self._normalize_embeddings, @@ -224,10 +226,10 @@ def __init__( def __call__(self, texts: Documents) -> Embeddings: if self._instruction is None: - return self._model.encode(texts).tolist() + return self._model.encode(texts).tolist() # type: ignore texts_with_instructions = [[self._instruction, text] for text in texts] - return self._model.encode(texts_with_instructions).tolist() + return self._model.encode(texts_with_instructions).tolist() # type: ignore # In order to remove dependencies on sentence-transformers, which in turn depends 
on @@ -302,12 +304,12 @@ def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None: # Use pytorches default epsilon for division by zero # https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html - def _normalize(self, v: npt.NDArray) -> npt.NDArray: + def _normalize(self, v: npt.NDArray) -> npt.NDArray: # type: ignore norm = np.linalg.norm(v, axis=1) norm[norm == 0] = 1e-12 - return v / norm[:, np.newaxis] + return v / norm[:, np.newaxis] # type: ignore - def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray: + def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray: # type: ignore # We need to cast to the correct type because the type checker doesn't know that init_model_and_tokenizer will set the values self.tokenizer = cast(self.Tokenizer, self.tokenizer) # type: ignore self.model = cast(self.ort.InferenceSession, self.model) # type: ignore @@ -475,3 +477,15 @@ def __call__(self, texts: Documents) -> Embeddings: embeddings.append(response["predictions"]["embeddings"]["values"]) return embeddings + + +# List of all classes in this module +_classes = [ + name + for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass) + if obj.__module__ == __name__ +] + + +def get_builtins() -> List[str]: + return _classes From 8a6ad071277feca0c94368e1ac09a46faea65401 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Mon, 25 Sep 2023 09:25:39 -0700 Subject: [PATCH 29/39] [CHORE] Add support for pydantic v2 (#1174) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description of changes Closes #893 *Summarize the changes made by this PR.* - Improvements & Bug fixes - Adds support for pydantic v2.0 by changing how Collection model inits - this simple change fixes pydantic v2 - Fixes the cross version tests to handle pydantic specifically - Conditionally imports pydantic-settings based on what is available. In v2 BaseSettings was moved to a new package. - New functionality - N/A ## Test plan Existing tests were run with the following configs 1. Fastapi < 0.100, Pydantic >= 2.0 - Unsupported as the fastapi dependencies will not allow it. They likely should, as pydantic.v1 imports would support this, but this is a downstream issue. 2. Fastapi >= 0.100, Pydantic >= 2.0, Supported via normal imports ✅ (Tested with fastapi==0.103.1, pydantic==2.3.0) 3. Fastapi < 0.100 Pydantic < 2.0, Supported via normal imports ✅ (Tested with fastapi==0.95.2, pydantic==1.9.2) 4. Fastapi >= 0.100, Pydantic < 2.0, Supported via normal imports ✅ (Tested with latest fastapi, pydantic==1.9.2) - [x] Tests pass locally with `pytest` for python, `yarn test` for js ## Documentation Changes None required. 
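A minimal sketch of the version gate this patch relies on, mirroring the `chromadb/config.py` hunk that follows (the trailing `Settings` field is illustrative only, not part of the patch): when pydantic v2 is installed, `BaseSettings` has moved out of the top-level package, so the v1-style import is attempted first and the `pydantic.v1` compatibility namespace is used as the fallback.

```python
# Sketch of the conditional import used to support both pydantic major versions.
in_pydantic_v2 = False
try:
    # Succeeds on pydantic < 2.0.
    from pydantic import BaseSettings
except ImportError:
    # On pydantic >= 2.0 the v1 API (including BaseSettings) lives under pydantic.v1.
    in_pydantic_v2 = True
    from pydantic.v1 import BaseSettings
    from pydantic.v1 import validator

if not in_pydantic_v2:
    from pydantic import validator  # type: ignore # noqa


class Settings(BaseSettings):  # same subclassing works under either major version
    my_flag: bool = False  # illustrative field, not from the patch
```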
--- chromadb/api/models/Collection.py | 3 ++- chromadb/auth/providers.py | 1 - chromadb/config.py | 13 ++++++++++++- .../test/property/test_cross_version_persist.py | 12 +++++++++--- pyproject.toml | 4 ++-- requirements.txt | 6 +++--- 6 files changed, 28 insertions(+), 11 deletions(-) diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index 6b4f7f18bd9..c11a04b1fa4 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Optional, Tuple, cast, List from pydantic import BaseModel, PrivateAttr + from uuid import UUID import chromadb.utils.embedding_functions as ef @@ -50,9 +51,9 @@ def __init__( embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), metadata: Optional[CollectionMetadata] = None, ): + super().__init__(name=name, metadata=metadata, id=id) self._client = client self._embedding_function = embedding_function - super().__init__(name=name, metadata=metadata, id=id) def __repr__(self) -> str: return f"Collection(name={self.name})" diff --git a/chromadb/auth/providers.py b/chromadb/auth/providers.py index a3bb23616e2..eceee3bc2ab 100644 --- a/chromadb/auth/providers.py +++ b/chromadb/auth/providers.py @@ -5,7 +5,6 @@ import requests from overrides import override from pydantic import SecretStr - from chromadb.auth import ( ServerAuthCredentialsProvider, AbstractCredentials, diff --git a/chromadb/config.py b/chromadb/config.py index 6167193acd2..0e9c87e5572 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -9,9 +9,20 @@ from overrides import EnforceOverrides from overrides import override -from pydantic import BaseSettings, validator from typing_extensions import Literal + +in_pydantic_v2 = False +try: + from pydantic import BaseSettings +except ImportError: + in_pydantic_v2 = True + from pydantic.v1 import BaseSettings + from pydantic.v1 import validator + +if not in_pydantic_v2: + from pydantic import validator # type: ignore # noqa + # The thin client will have a flag to control which implementations to use is_thin_client = False try: diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index d1785b84140..529fe02dda7 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -24,6 +24,9 @@ MINIMUM_VERSION = "0.4.1" version_re = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$") +# Some modules do not work across versions, since we upgrade our support for them, and should be explicitly reimported in the subprocess +VERSIONED_MODULES = ["pydantic"] + def versions() -> List[str]: """Returns the pinned minimum version and the latest version of chromadb.""" @@ -49,7 +52,7 @@ def _patch_boolean_metadata( # boolean value metadata to int collection_metadata = collection.metadata if collection_metadata is not None: - _bool_to_int(collection_metadata) + _bool_to_int(collection_metadata) # type: ignore if embeddings["metadatas"] is not None: if isinstance(embeddings["metadatas"], list): @@ -162,7 +165,10 @@ def switch_to_version(version: str) -> ModuleType: old_modules = { n: m for n, m in sys.modules.items() - if n == module_name or (n.startswith(module_name + ".")) + if n == module_name + or (n.startswith(module_name + ".")) + or n in VERSIONED_MODULES + or (any(n.startswith(m + ".") for m in VERSIONED_MODULES)) } for n in old_modules: del sys.modules[n] @@ -197,7 +203,7 @@ def persist_generated_data_with_old_version( 
api.reset() coll = api.create_collection( name=collection_strategy.name, - metadata=collection_strategy.metadata, + metadata=collection_strategy.metadata, # type: ignore # In order to test old versions, we can't rely on the not_implemented function embedding_function=not_implemented_ef(), ) diff --git a/pyproject.toml b/pyproject.toml index 8fc60673607..7db0fe821ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,9 +16,9 @@ classifiers = [ ] dependencies = [ 'requests >= 2.28', - 'pydantic>=1.9,<2.0', + 'pydantic >= 1.9', 'chroma-hnswlib==0.7.3', - 'fastapi>=0.95.2, <0.100.0', + 'fastapi >= 0.95.2', 'uvicorn[standard] >= 0.18.3', 'numpy == 1.21.6; python_version < "3.8"', 'numpy >= 1.22.5; python_version >= "3.8"', diff --git a/requirements.txt b/requirements.txt index 9a9fdcc295c..80f4d9be904 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ bcrypt==4.0.1 chroma-hnswlib==0.7.3 -fastapi>=0.95.2, <0.100.0 +fastapi>=0.95.2 graphlib_backport==1.0.3; python_version < '3.9' importlib-resources numpy==1.21.6; python_version < '3.8' @@ -9,11 +9,11 @@ onnxruntime==1.14.1 overrides==7.3.1 posthog==2.4.0 pulsar-client==3.1.0 -pydantic>=1.9,<2.0 +pydantic>=1.9 pypika==0.48.9 requests==2.28.1 tokenizers==0.13.2 tqdm==4.65.0 typer>=0.9.0 -typing_extensions==4.5.0 +typing_extensions>=4.5.0 uvicorn[standard]==0.18.3 From c41bfa4ffca5a4341c640b595ea85512617fb0d5 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Mon, 25 Sep 2023 09:30:43 -0700 Subject: [PATCH 30/39] [RELEASE] 0.4.13 (#1180) Release 0.4.13 --- chromadb/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index ad7d3d4f70b..aa5a3edd7ea 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -44,7 +44,7 @@ __settings = Settings() -__version__ = "0.4.12" +__version__ = "0.4.13" # Workaround to deal with Colab's old sqlite3 version try: From 5cfaa885bc33e5f668c67ba587218fc789395aaf Mon Sep 17 00:00:00 2001 From: Thomas Betous Date: Mon, 2 Oct 2023 22:39:30 +0200 Subject: [PATCH 31/39] [Bug]: Fix notifyAll deprecation (#1199) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Fixes #1198 about notifyAll deprecation --- chromadb/utils/read_write_lock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chromadb/utils/read_write_lock.py b/chromadb/utils/read_write_lock.py index 16c60ca2a3e..c6863049bd6 100644 --- a/chromadb/utils/read_write_lock.py +++ b/chromadb/utils/read_write_lock.py @@ -26,7 +26,7 @@ def release_read(self) -> None: try: self._readers -= 1 if not self._readers: - self._read_ready.notifyAll() + self._read_ready.notify_all() finally: self._read_ready.release() From 44d04b63353ba8fc4661b108c79d340918bc2d73 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Mon, 2 Oct 2023 14:41:54 -0700 Subject: [PATCH 32/39] [BUG] Update requirements for 3.11 support (#1201) ## Description of changes Our pyproject.toml is less restrictive than requirements.txt, which makes the published package support 3.11, but not in local dev. *Summarize the changes made by this PR.* - Improvements & Bug fixes - Updates requirements.txt to resolve for 3.11 - New functionality - None ## Test plan *How are these changes tested?* Pyenv was updated to 3.11.4 and tested locally. - [x] Tests pass locally with `pytest` for python, `yarn test` for js - **Yes, tested on 3.11.4** ## Documentation Changes We can update docs to remove any verbiage restricting 3.11. 
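The 3.11 check described in the test plan above can be reproduced locally along these lines; this is a sketch under the assumption that `pyenv` is available and the commands are run from the repository root (none of these commands are part of the patch itself):

```bash
# Hypothetical local reproduction of the 3.11 test run mentioned above.
pyenv install 3.11.4
pyenv local 3.11.4
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt -r requirements_dev.txt  # should now resolve on 3.11
python -m pytest                                          # python test suite
```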
--- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 80f4d9be904..8d00aa134cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,8 +4,8 @@ fastapi>=0.95.2 graphlib_backport==1.0.3; python_version < '3.9' importlib-resources numpy==1.21.6; python_version < '3.8' -numpy==1.22.4; python_version >= '3.8' -onnxruntime==1.14.1 +numpy>=1.22.4; python_version >= '3.8' +onnxruntime>=1.14.1 overrides==7.3.1 posthog==2.4.0 pulsar-client==3.1.0 From 5b0ff2e4fa9e75073f0984e87d26bb2c08ea09fa Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 3 Oct 2023 01:18:34 +0300 Subject: [PATCH 33/39] [ENH] Render.com Terraform Blueprint (#1176) ## Description of changes *Summarize the changes made by this PR.* - New functionality - New blueprint for render.com ## Test plan *How are these changes tested?* Manual testing ## Documentation Changes README part of the deployment --- .../deployments/render-terraform/README.md | 118 ++++++++++++++++++ .../deployments/render-terraform/chroma.tf | 88 +++++++++++++ .../render-terraform/sqlite_version.patch | 29 +++++ .../deployments/render-terraform/variables.tf | 70 +++++++++++ 4 files changed, 305 insertions(+) create mode 100644 examples/deployments/render-terraform/README.md create mode 100644 examples/deployments/render-terraform/chroma.tf create mode 100644 examples/deployments/render-terraform/sqlite_version.patch create mode 100644 examples/deployments/render-terraform/variables.tf diff --git a/examples/deployments/render-terraform/README.md b/examples/deployments/render-terraform/README.md new file mode 100644 index 00000000000..eab333cbeea --- /dev/null +++ b/examples/deployments/render-terraform/README.md @@ -0,0 +1,118 @@ +# Render.com Deployment + +This is an example deployment to Render.com using [terraform](https://www.terraform.io/) + +## Requirements + +- [Terraform CLI v1.3.4+](https://developer.hashicorp.com/terraform/tutorials/gcp-get-started/install-cli) +- [Terraform Render provider](https://registry.terraform.io/providers/jackall3n/render/latest/docs) + +## Deployment with terraform + +### 1. Init your terraform state + +```bash +terraform init +``` + +### 3. Deploy your application + +```bash +# Your Render.com API token. IMPORTANT: The API does not work with Free plan. +export TF_VAR_render_api_token= +# Your Render.com user email +export TF_VAR_render_user_email= +#set the chroma release to deploy +export TF_VAR_chroma_release="0.4.13" +# the region to deploy to. At the time of writing only oregon and frankfurt are available +export TF_VAR_region="oregon" +#enable basic auth for the chroma instance +export TF_VAR_enable_auth="true" +#The auth type to use for the chroma instance (token or basic) +export TF_VAR_auth_type="token" +terraform apply -auto-approve +``` + +### 4. 
Check your public IP and that Chroma is running + +> Note: It might take couple minutes for the instance to boot up + +Get the public IP of your instance (it should also be printed out after successful `terraform apply`): + +```bash +terraform output instance_url +``` + +Check that chroma is running: + +```bash +export instance_public_ip=$(terraform output instance_url | sed 's/"//g') +curl -v $instance_public_ip/api/v1/heartbeat +``` + +#### 4.1 Checking Auth + +##### Token + +When token auth is enabled (this is the default option) you can check the get the credentials from Terraform state by +running: + +```bash +terraform output chroma_auth_token +``` + +You should see something of the form: + +```bash +PVcQ4qUUnmahXwUgAf3UuYZoMlos6MnF +``` + +You can then export these credentials: + +```bash +export CHROMA_AUTH=$(terraform output chroma_auth_token | sed 's/"//g') +``` + +Using the credentials: + +```bash +curl -v $instance_public_ip/api/v1/collections -H "Authorization: Bearer ${CHROMA_AUTH}" +``` + +##### Basic + +When basic auth is enabled you can check the get the credentials from Terraform state by running: + +```bash +terraform output chroma_auth_basic +``` + +You should see something of the form: + +```bash +chroma:VuA8I}QyNrm0@QLq +``` + +You can then export these credentials: + +```bash +export CHROMA_AUTH=$(terraform output chroma_auth_basic | sed 's/"//g') +``` + +Using the credentials: + +```bash +curl -v https://$instance_public_ip:8000/api/v1/collections -u "${CHROMA_AUTH}" +``` + +> Note: Without `-u` you should be getting 401 Unauthorized response + +#### 4.2 SSH to your instance + +To connect to your instance via SSH you need to go to Render.com service dashboard. + +### 5. Destroy your application + +```bash +terraform destroy +``` diff --git a/examples/deployments/render-terraform/chroma.tf b/examples/deployments/render-terraform/chroma.tf new file mode 100644 index 00000000000..a6ef69113a5 --- /dev/null +++ b/examples/deployments/render-terraform/chroma.tf @@ -0,0 +1,88 @@ +terraform { + required_providers { + render = { + source = "jackall3n/render" + version = "~> 1.3.0" + } + } +} + +variable "render_api_token" { + sensitive = true +} + +variable "render_user_email" { + sensitive = true +} + +provider "render" { + api_key = var.render_api_token +} + +data "render_owner" "render_owner" { + email = var.render_user_email +} + +resource "render_service" "chroma" { + name = "chroma" + owner = data.render_owner.render_owner.id + type = "web_service" + auto_deploy = true + + env_vars = concat([{ + key = "IS_PERSISTENT" + value = "1" + }, + { + key = "PERSIST_DIRECTORY" + value = var.chroma_data_volume_mount_path + }, + ], + var.enable_auth? 
[{ + key = "CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER" + value = "chromadb.auth.token.TokenConfigServerAuthCredentialsProvider" + }, + { + key = "CHROMA_SERVER_AUTH_CREDENTIALS" + value = "${local.token_auth_credentials.token}" + }, + { + key = "CHROMA_SERVER_AUTH_PROVIDER" + value = var.auth_type + }] : [] + ) + + image = { + owner_id = data.render_owner.render_owner.id + image_path = "${var.chroma_image_reg_url}:${var.chroma_release}" + } + + web_service_details = { + env = "image" + plan = var.render_plan + region = var.region + health_check_path = "/api/v1/heartbeat" + disk = { + name = var.chroma_data_volume_device_name + mount_path = var.chroma_data_volume_mount_path + size_gb = var.chroma_data_volume_size + } + docker = { + command = "uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 80 --log-config chromadb/log_config.yml" + path = "./Dockerfile" + } + } +} + +output "service_id" { + value = render_service.chroma.id +} + +output "instance_url" { + value = render_service.chroma.web_service_details.url +} + +output "chroma_auth_token" { + value = random_password.chroma_token.result + sensitive = true +} diff --git a/examples/deployments/render-terraform/sqlite_version.patch b/examples/deployments/render-terraform/sqlite_version.patch new file mode 100644 index 00000000000..aa19837a916 --- /dev/null +++ b/examples/deployments/render-terraform/sqlite_version.patch @@ -0,0 +1,29 @@ +diff --git a/chromadb/__init__.py b/chromadb/__init__.py +index 0ff5244a..450aaf0d 100644 +--- a/chromadb/__init__.py ++++ b/chromadb/__init__.py +@@ -55,21 +55,9 @@ except ImportError: + IN_COLAB = False + + if sqlite3.sqlite_version_info < (3, 35, 0): +- if IN_COLAB: +- # In Colab, hotswap to pysqlite-binary if it's too old +- import subprocess +- import sys +- +- subprocess.check_call( +- [sys.executable, "-m", "pip", "install", "pysqlite3-binary"] +- ) +- __import__("pysqlite3") +- sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") +- else: +- raise RuntimeError( +- "\033[91mYour system has an unsupported version of sqlite3. Chroma requires sqlite3 >= 3.35.0.\033[0m\n" +- "\033[94mPlease visit https://docs.trychroma.com/troubleshooting#sqlite to learn how to upgrade.\033[0m" +- ) ++ __import__('pysqlite3') ++ import sys ++ sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') + + + def configure(**kwargs) -> None: # type: ignore diff --git a/examples/deployments/render-terraform/variables.tf b/examples/deployments/render-terraform/variables.tf new file mode 100644 index 00000000000..2acde15274a --- /dev/null +++ b/examples/deployments/render-terraform/variables.tf @@ -0,0 +1,70 @@ +variable "chroma_image_reg_url" { + description = "The URL of the chroma-core image registry (e.g. docker.io/chromadb/chroma). The URL must also include the image itself without the tag." + type = string + default = "docker.io/chromadb/chroma" +} + +variable "chroma_release" { + description = "The chroma release to deploy" + type = string + default = "0.4.13" +} + +variable "region" { + type = string + default = "oregon" +} + +variable "render_plan" { + default = "starter" + description = "The Render plan to use. This determines the size of the machine. NOTE: Terraform Render provider uses Render's API which requires at least starter plan." 
+ type = string +} + +variable "enable_auth" { + description = "Enable authentication" + type = bool + default = true // or false depending on your needs +} + +variable "auth_type" { + description = "Authentication type" + type = string + default = "token" // or token depending on your needs + validation { + condition = contains([ "token"], var.auth_type) + error_message = "Only token is supported as auth type" + } +} + +resource "random_password" "chroma_token" { + length = 32 + special = false + lower = true + upper = true +} + + +locals { + token_auth_credentials = { + token = random_password.chroma_token.result + } +} + +variable "chroma_data_volume_size" { + description = "The size of the attached data volume in GB." + type = number + default = 20 +} + +variable "chroma_data_volume_device_name" { + default = "chroma-disk-0" + description = "The device name of the chroma data volume" + type = string +} + +variable "chroma_data_volume_mount_path" { + default = "/chroma-data" + description = "The mount path of the chroma data volume" + type = string +} From 2dcffca30660d69a0427ba8e6503feb9099560ee Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 3 Oct 2023 01:18:41 +0300 Subject: [PATCH 34/39] [ENH]: AWS Terraform Blueprint Improvements (#1173) Refs: #1172 and #1135 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Added support for restoring from snapshot - Added snapshotting volume before destroy - Bumped default Chroma version to 0.4.12 - Added support for persistent volume mounts (in fstab) - Improved tagging - Made source ranges for the Security Group configurable - Made external Chroma port configurable - Switched to use a GH startup script. ## Test plan *How are these changes tested?* Tests not yet executed. 
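The "persistent volume mounts (in fstab)" item in the change list above amounts to writing an `/etc/fstab` entry for the data volume in addition to the one-off `mount`, so the mount survives instance reboots. A rough sketch of the idea; the device name here is a placeholder (the blueprint resolves the real device via `lsblk` and the volume ID, as the `chroma.tf` provisioner further below shows), and `/chroma-data` is the mount path used throughout these blueprints:

```bash
# Sketch: mount the attached data volume and persist the mount across reboots.
sudo mkdir -p /chroma-data
sudo mount /dev/xvdh /chroma-data   # /dev/xvdh is an illustrative device name
echo '/dev/xvdh /chroma-data ext4 defaults,nofail,discard 0 0' | sudo tee -a /etc/fstab
```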
## Documentation Changes Updated README.md of the blueprint --- examples/deployments/aws-terraform/README.md | 62 +++++++++++-------- examples/deployments/aws-terraform/chroma.tf | 39 ++++++------ .../deployments/aws-terraform/variables.tf | 53 +++++++++++++++- 3 files changed, 107 insertions(+), 47 deletions(-) diff --git a/examples/deployments/aws-terraform/README.md b/examples/deployments/aws-terraform/README.md index d48ef3227e9..332cfd7265c 100644 --- a/examples/deployments/aws-terraform/README.md +++ b/examples/deployments/aws-terraform/README.md @@ -47,15 +47,28 @@ ssh-keygen -t RSA -b 4096 -C "Chroma AWS Key" -N "" -f ./chroma-aws && chmod 400 Set up your Terraform variables and deploy your instance: ```bash -export TF_VAR_AWS_ACCESS_KEY= #take note of this as it must be present in all of the subsequent steps -export TF_VAR_AWS_SECRET_ACCESS_KEY= #take note of this as it must be present in all of the subsequent steps -export TF_ssh_public_key="./chroma-aws.pub" #path to the public key you generated above (or can be different if you want to use your own key) -export TF_ssh_private_key="./chroma-aws" #path to the private key you generated above (or can be different if you want to use your own key) - used for formatting the Chroma data volume -export TF_VAR_chroma_release=0.4.8 #set the chroma release to deploy -export TF_VAR_region="us-west-1" # AWS region to deploy the chroma instance to -export TF_VAR_public_access="true" #enable public access to the chroma instance on port 8000 -export TF_VAR_enable_auth="true" #enable basic auth for the chroma instance -export TF_VAR_auth_type="token" #The auth type to use for the chroma instance (token or basic) +#AWS access key +export TF_VAR_AWS_ACCESS_KEY= +#AWS secret access key +export TF_VAR_AWS_SECRET_ACCESS_KEY= +#path to the public key you generated above (or can be different if you want to use your own key) +export TF_ssh_public_key="./chroma-aws.pub" +#path to the private key you generated above (or can be different if you want to use your own key) - used for formatting the Chroma data volume +export TF_ssh_private_key="./chroma-aws" +#set the chroma release to deploy +export TF_VAR_chroma_release=0.4.12 +# AWS region to deploy the chroma instance to +export TF_VAR_region="us-west-1" +#enable public access to the chroma instance on port 8000 +export TF_VAR_public_access="true" +#enable basic auth for the chroma instance +export TF_VAR_enable_auth="true" +#The auth type to use for the chroma instance (token or basic) +export TF_VAR_auth_type="token" +#optional - if you want to restore from a snapshot +export TF_VAR_chroma_data_restore_from_snapshot_id="" +#optional - if you want to snapshot the data volume before destroying the instance +export TF_VAR_chroma_data_volume_snapshot_before_destroy="true" terraform apply -auto-approve ``` > Note: Basic Auth is supported by Chroma v0.4.7+ @@ -77,60 +90,59 @@ curl -v http://$instance_public_ip:8000/api/v1/heartbeat #### 4.1 Checking Auth -##### Basic -When basic auth is enabled you can check the get the credentials from Terraform state by running: +##### Token +When token auth is enabled you can check the get the credentials from Terraform state by running: ```bash -terraform output chroma_auth_basic +terraform output chroma_auth_token ``` You should see something of the form: ```bash -chroma:VuA8I}QyNrm0@QLq +PVcQ4qUUnmahXwUgAf3UuYZoMlos6MnF ``` You can then export these credentials: ```bash -export CHROMA_AUTH=$(terraform output chroma_auth_basic | sed 's/"//g') +export 
CHROMA_AUTH=$(terraform output chroma_auth_token | sed 's/"//g') ``` Using the credentials: ```bash -curl -v http://$instance_public_ip:8000/api/v1/collections -u "${CHROMA_AUTH}" +curl -v http://$instance_public_ip:8000/api/v1/collections -H "Authorization: Bearer ${CHROMA_AUTH}" ``` -> Note: Without `-u` you should be getting 401 Unauthorized response - - -##### Token -When token auth is enabled you can check the get the credentials from Terraform state by running: +##### Basic +When basic auth is enabled you can check the get the credentials from Terraform state by running: ```bash -terraform output chroma_auth_token +terraform output chroma_auth_basic ``` You should see something of the form: ```bash -PVcQ4qUUnmahXwUgAf3UuYZoMlos6MnF +chroma:VuA8I}QyNrm0@QLq ``` You can then export these credentials: ```bash -export CHROMA_AUTH=$(terraform output chroma_auth_token | sed 's/"//g') +export CHROMA_AUTH=$(terraform output chroma_auth_basic | sed 's/"//g') ``` Using the credentials: ```bash -curl -v http://$instance_public_ip:8000/api/v1/collections -H "Authorization: Bearer ${CHROMA_AUTH}" +curl -v http://$instance_public_ip:8000/api/v1/collections -u "${CHROMA_AUTH}" ``` -#### 4.2 SSH to your instance +> Note: Without `-u` you should be getting 401 Unauthorized response + +#### 4.2 Connect (ssh) to your instance To SSH to your instance: diff --git a/examples/deployments/aws-terraform/chroma.tf b/examples/deployments/aws-terraform/chroma.tf index fca553ae90d..bd44c62e319 100644 --- a/examples/deployments/aws-terraform/chroma.tf +++ b/examples/deployments/aws-terraform/chroma.tf @@ -26,16 +26,16 @@ resource "aws_security_group" "chroma_sg" { from_port = 22 to_port = 22 protocol = "tcp" - cidr_blocks = ["0.0.0.0/0"] + cidr_blocks = var.mgmt_source_ranges } dynamic "ingress" { for_each = var.public_access ? [1] : [] content { - from_port = 8000 + from_port = var.chroma_port to_port = 8000 protocol = "tcp" - cidr_blocks = ["0.0.0.0/0"] + cidr_blocks = var.source_ranges } } @@ -47,9 +47,7 @@ resource "aws_security_group" "chroma_sg" { ipv6_cidr_blocks = ["::/0"] } - tags = { - Name = "chroma" - } + tags = local.tags } resource "aws_key_pair" "chroma-keypair" { @@ -83,17 +81,9 @@ resource "aws_instance" "chroma_instance" { key_name = "chroma-keypair" security_groups = [aws_security_group.chroma_sg.name] - user_data = templatefile("${path.module}/startup.sh", { - chroma_release = var.chroma_release, - enable_auth = var.enable_auth, - auth_type = var.auth_type, - basic_auth_credentials = "${local.basic_auth_credentials.username}:${local.basic_auth_credentials.password}", - token_auth_credentials = random_password.chroma_token.result, - }) + user_data = data.template_file.user_data.rendered - tags = { - Name = "chroma" - } + tags = local.tags ebs_block_device { device_name = "/dev/sda1" @@ -105,10 +95,10 @@ resource "aws_instance" "chroma_instance" { resource "aws_ebs_volume" "chroma-volume" { availability_zone = aws_instance.chroma_instance.availability_zone size = var.chroma_data_volume_size + final_snapshot = var.chroma_data_volume_snapshot_before_destroy + snapshot_id = var.chroma_data_restore_from_snapshot_id - tags = { - Name = "chroma" - } + tags = local.tags lifecycle { prevent_destroy = true @@ -119,15 +109,22 @@ locals { cleaned_volume_id = replace(aws_ebs_volume.chroma-volume.id, "-", "") } +locals { + restore_from_snapshot = length(var.chroma_data_restore_from_snapshot_id) == 0 ? 
false : true +} + resource "aws_volume_attachment" "chroma_volume_attachment" { device_name = "/dev/sdh" volume_id = aws_ebs_volume.chroma-volume.id instance_id = aws_instance.chroma_instance.id provisioner "remote-exec" { inline = [ - "export VOLUME_ID=${local.cleaned_volume_id} && sudo mkfs -t ext4 /dev/$(lsblk -o +SERIAL | grep $VOLUME_ID | awk '{print $1}')", + "if [ -z \"${local.restore_from_snapshot}\" ]; then export VOLUME_ID=${local.cleaned_volume_id} && sudo mkfs -t ext4 /dev/$(lsblk -o +SERIAL | grep $VOLUME_ID | awk '{print $1}'); fi", "sudo mkdir /chroma-data", - "export VOLUME_ID=${local.cleaned_volume_id} && sudo mount /dev/$(lsblk -o +SERIAL | grep $VOLUME_ID | awk '{print $1}') /chroma-data" + "export VOLUME_ID=${local.cleaned_volume_id} && sudo mount /dev/$(lsblk -o +SERIAL | grep $VOLUME_ID | awk '{print $1}') /chroma-data", + "export VOLUME_ID=${local.cleaned_volume_id} && cat <> /dev/null", + "/dev/$(lsblk -o +SERIAL | grep $VOLUME_ID | awk '{print $1}') /chroma-data ext4 defaults,nofail,discard 0 0", + "EOF", ] connection { diff --git a/examples/deployments/aws-terraform/variables.tf b/examples/deployments/aws-terraform/variables.tf index 15e0cc81e40..e7b7cd9b670 100644 --- a/examples/deployments/aws-terraform/variables.tf +++ b/examples/deployments/aws-terraform/variables.tf @@ -1,7 +1,24 @@ variable "chroma_release" { description = "The chroma release to deploy" type = string - default = "0.4.8" + default = "0.4.12" +} + +#TODO this should be updated to point to https://raw.githubusercontent.com/chroma-core/chroma/main/examples/deployments/common/startup.sh in the repo +data "http" "startup_script_remote" { + url = "https://raw.githubusercontent.com/chroma-core/chroma/main/examples/deployments/aws-terraform/startup.sh" +} + +data "template_file" "user_data" { + template = data.http.startup_script_remote.response_body + + vars = { + chroma_release = var.chroma_release + enable_auth = var.enable_auth + auth_type = var.auth_type + basic_auth_credentials = "${local.basic_auth_credentials.username}:${local.basic_auth_credentials.password}" + token_auth_credentials = random_password.chroma_token.result + } } variable "region" { @@ -62,6 +79,10 @@ locals { token_auth_credentials = { token = random_password.chroma_token.result } + tags = [ + "chroma", + "release-${replace(var.chroma_release, ".", "")}", + ] } variable "ssh_public_key" { @@ -86,3 +107,33 @@ variable "chroma_data_volume_size" { type = number default = 20 } + +variable "chroma_data_volume_snapshot_before_destroy" { + description = "Take a snapshot of the chroma data volume before destroying it" + type = bool + default = false +} + +variable "chroma_data_restore_from_snapshot_id" { + description = "Restore the chroma data volume from a snapshot" + type = string + default = null +} + +variable "chroma_port" { + default = "8000" + description = "The port that chroma listens on" + type = string +} + +variable "source_ranges" { + default = ["0.0.0.0/0", "::/0"] + type = list(string) + description = "List of CIDR ranges to allow through the firewall" +} + +variable "mgmt_source_ranges" { + default = ["0.0.0.0/0", "::/0"] + type = list(string) + description = "List of CIDR ranges to allow for management of the Chroma instance. 
This is used for SSH incoming traffic filtering" +} From 0d675094035cf87904c77404f0a94a3137a6bc27 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 3 Oct 2023 01:18:51 +0300 Subject: [PATCH 35/39] [ENH]: DO deployment blueprint (#1171) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Full deployment of a Chroma on a DO droplet with attached volume ## Test plan *How are these changes tested?* Manually tested. ## Documentation Changes Docs are in the deployment README.md --- examples/deployments/common/startup.sh | 53 ++++++ examples/deployments/do-terraform/README.md | 163 ++++++++++++++++++ examples/deployments/do-terraform/chroma.tf | 133 ++++++++++++++ .../deployments/do-terraform/variables.tf | 126 ++++++++++++++ 4 files changed, 475 insertions(+) create mode 100644 examples/deployments/common/startup.sh create mode 100644 examples/deployments/do-terraform/README.md create mode 100644 examples/deployments/do-terraform/chroma.tf create mode 100644 examples/deployments/do-terraform/variables.tf diff --git a/examples/deployments/common/startup.sh b/examples/deployments/common/startup.sh new file mode 100644 index 00000000000..a6e5b3134f3 --- /dev/null +++ b/examples/deployments/common/startup.sh @@ -0,0 +1,53 @@ +#! /bin/bash + +# Note: This is run as root + +cd ~ +export enable_auth="${enable_auth}" +export basic_auth_credentials="${basic_auth_credentials}" +export auth_type="${auth_type}" +export token_auth_credentials="${token_auth_credentials}" +apt-get update -y +apt-get install -y ca-certificates curl gnupg lsb-release +mkdir -m 0755 -p /etc/apt/keyrings +curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /etc/apt/keyrings/docker.gpg +echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null +apt-get update -y +chmod a+r /etc/apt/keyrings/docker.gpg +apt-get update -y +apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin git +usermod -aG docker ubuntu +git clone https://github.com/chroma-core/chroma.git && cd chroma +git fetch --tags +git checkout tags/${chroma_release} + +if [ "$${enable_auth}" = "true" ] && [ "$${auth_type}" = "basic" ] && [ ! -z "$${basic_auth_credentials}" ]; then + username=$(echo $basic_auth_credentials | cut -d: -f1) + password=$(echo $basic_auth_credentials | cut -d: -f2) + docker run --rm --entrypoint htpasswd httpd:2 -Bbn $username $password > server.htpasswd + cat < .env +CHROMA_SERVER_AUTH_CREDENTIALS_FILE="/chroma/server.htpasswd" +CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER='chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider' +CHROMA_SERVER_AUTH_PROVIDER='chromadb.auth.basic.BasicAuthServerProvider' +EOF +fi + +if [ "$${enable_auth}" = "true" ] && [ "$${auth_type}" = "token" ] && [ ! 
-z "$${token_auth_credentials}" ]; then + cat < .env +CHROMA_SERVER_AUTH_CREDENTIALS="$${token_auth_credentials}" \ +CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER='chromadb.auth.token.TokenConfigServerAuthCredentialsProvider' +CHROMA_SERVER_AUTH_PROVIDER='chromadb.auth.token.TokenAuthServerProvider' +EOF +fi + +cat < docker-compose.override.yaml +version: '3.8' +services: + server: + volumes: + - /chroma-data:/chroma/chroma +EOF + +COMPOSE_PROJECT_NAME=chroma docker compose up -d --build diff --git a/examples/deployments/do-terraform/README.md b/examples/deployments/do-terraform/README.md new file mode 100644 index 00000000000..80957bdd810 --- /dev/null +++ b/examples/deployments/do-terraform/README.md @@ -0,0 +1,163 @@ +# Digital Ocean Droplet Deployment + +This is an example deployment using Digital Ocean Droplet using [terraform](https://www.terraform.io/). + +This deployment will do the following: + +- 🔥 Create a firewall with required ports open (22 and 8000) +- 🐳 Create Droplet with Ubuntu 22 and deploy Chroma using docker compose +- 💿 Create a data volume for Chroma data +- 🗻 Mount the data volume to the Droplet instance +- ✏️ Format the data volume with ext4 +- 🏃‍ Start Chroma + +## Requirements + +- [Terraform CLI v1.3.4+](https://developer.hashicorp.com/terraform/tutorials/gcp-get-started/install-cli) + +## Deployment with terraform + +This deployment uses Ubuntu 22 as foundation, but you'd like to use a different image for your Droplet ( +see https://slugs.do-api.dev/ for a list of available images) + +### Configuration Options + + +### 1. Init your terraform state + +```bash +terraform init +``` + +### 2. Deploy your application + +Generate SSH key to use with your chroma instance (so you can log in to the Droplet): + +> Note: This is optional. You can use your own existing SSH key if you prefer. + +```bash +ssh-keygen -t RSA -b 4096 -C "Chroma DO Key" -N "" -f ./chroma-do && chmod 400 ./chroma-do +``` + +Set up your Terraform variables and deploy your instance: + +```bash +#take note of this as it must be present in all of the subsequent steps +export TF_VAR_do_token= +#path to the public key you generated above (or can be different if you want to use your own key) +export TF_ssh_public_key="./chroma-do.pub" +#path to the private key you generated above (or can be different if you want to use your own key) - used for formatting the Chroma data volume +export TF_ssh_private_key="./chroma-do" +#set the chroma release to deploy +export TF_VAR_chroma_release="0.4.12" +# DO region to deploy the chroma instance to +export TF_VAR_region="ams2" +#enable public access to the chroma instance on port 8000 +export TF_VAR_public_access="true" +#enable basic auth for the chroma instance +export TF_VAR_enable_auth="true" +#The auth type to use for the chroma instance (token or basic) +export TF_VAR_auth_type="token" +terraform apply -auto-approve +``` + +> Note: Basic Auth is supported by Chroma v0.4.7+ + +### 4. 
Check your public IP and that Chroma is running + +Get the public IP of your instance + +```bash +terraform output instance_public_ip +``` + +Check that chroma is running (It should take up several minutes for the instance to be ready) + +```bash +export instance_public_ip=$(terraform output instance_public_ip | sed 's/"//g') +curl -v http://$instance_public_ip:8000/api/v1/heartbeat +``` + +#### 4.1 Checking Auth + +##### Token + +When token auth is enabled you can check the get the credentials from Terraform state by running: + +```bash +terraform output chroma_auth_token +``` + +You should see something of the form: + +```bash +PVcQ4qUUnmahXwUgAf3UuYZoMlos6MnF +``` + +You can then export these credentials: + +```bash +export CHROMA_AUTH=$(terraform output chroma_auth_token | sed 's/"//g') +``` + +Using the credentials: + +```bash +curl -v http://$instance_public_ip:8000/api/v1/collections -H "Authorization: Bearer ${CHROMA_AUTH}" +``` + +##### Basic + +When basic auth is enabled you can check the get the credentials from Terraform state by running: + +```bash +terraform output chroma_auth_basic +``` + +You should see something of the form: + +```bash +chroma:VuA8I}QyNrm0@QLq +``` + +You can then export these credentials: + +```bash +export CHROMA_AUTH=$(terraform output chroma_auth_basic | sed 's/"//g') +``` + +Using the credentials: + +```bash +curl -v http://$instance_public_ip:8000/api/v1/collections -u "${CHROMA_AUTH}" +``` + +> Note: Without `-u` you should be getting 401 Unauthorized response + +#### 4.2 SSH to your instance + +To SSH to your instance: + +```bash +ssh -i ./chroma-do root@$instance_public_ip +``` + +### 5. Destroy your Chroma instance + +```bash +terraform destroy -auto-approve +``` + +## Extras + +You can visualize your infrastructure with: + +```bash +terraform graph | dot -Tsvg > graph.svg +``` + +> Note: You will need graphviz installed for this to work + +### Digital Ocean Resource Types + +Refs: https://slugs.do-api.dev/ diff --git a/examples/deployments/do-terraform/chroma.tf b/examples/deployments/do-terraform/chroma.tf new file mode 100644 index 00000000000..79960c80fe9 --- /dev/null +++ b/examples/deployments/do-terraform/chroma.tf @@ -0,0 +1,133 @@ +terraform { + required_providers { + digitalocean = { + source = "digitalocean/digitalocean" + version = "~> 2.0" + } + } +} + +# Define provider +variable "do_token" {} + +# Configure the DigitalOcean Provider +provider "digitalocean" { + token = var.do_token +} + + +resource "digitalocean_firewall" "chroma_firewall" { + name = "chroma-firewall" + + droplet_ids = [digitalocean_droplet.chroma_instance.id] + + inbound_rule { + protocol = "tcp" + port_range = "22" + source_addresses = var.mgmt_source_ranges + } + + dynamic "inbound_rule" { + for_each = var.public_access ? 
[1] : [] + content { + protocol = "tcp" + port_range = var.chroma_port + source_addresses = var.source_ranges + } + } + + outbound_rule { + protocol = "tcp" + port_range = "1-65535" + destination_addresses = ["0.0.0.0/0", "::/0"] + } + + outbound_rule { + protocol = "icmp" + port_range = "1-65535" + destination_addresses = ["0.0.0.0/0", "::/0"] + } + + outbound_rule { + protocol = "udp" + port_range = "1-65535" + destination_addresses = ["0.0.0.0/0", "::/0"] + } + + tags = local.tags + +} + +resource "digitalocean_ssh_key" "chroma_keypair" { + name = "chroma_keypair" + public_key = file(var.ssh_public_key) +} + + +#Create Droplet +resource "digitalocean_droplet" "chroma_instance" { + image = var.instance_image + name = "chroma" + region = var.region + size = var.instance_type + ssh_keys = [digitalocean_ssh_key.chroma_keypair.fingerprint] + + user_data = data.template_file.user_data.rendered + + tags = local.tags +} + + +resource "digitalocean_volume" "chroma_volume" { + region = digitalocean_droplet.chroma_instance.region + name = "chroma-volume" + size = var.chroma_data_volume_size + description = "Chroma data volume" + tags = local.tags +} + +resource "digitalocean_volume_attachment" "chroma_data_volume_attachment" { + droplet_id = digitalocean_droplet.chroma_instance.id + volume_id = digitalocean_volume.chroma_volume.id + + provisioner "remote-exec" { + inline = [ + "export VOLUME_ID=${digitalocean_volume.chroma_volume.name} && sudo mkfs -t ext4 /dev/$(lsblk -o +SERIAL | grep $VOLUME_ID | awk '{print $1}')", + "sudo mkdir /chroma-data", + "export VOLUME_ID=${digitalocean_volume.chroma_volume.name} && sudo mount /dev/$(lsblk -o +SERIAL | grep $VOLUME_ID | awk '{print $1}') /chroma-data", + "cat <> /dev/null", + "/dev/disk/by-id/scsi-0DO_Volume_${digitalocean_volume.chroma_volume.name} /chroma-data ext4 defaults,nofail,discard 0 0", + "EOF", + ] + + connection { + host = digitalocean_droplet.chroma_instance.ipv4_address + type = "ssh" + user = "root" + private_key = file(var.ssh_private_key) + } + } +} + + +output "instance_public_ip" { + value = digitalocean_droplet.chroma_instance.ipv4_address + description = "The public IP address of the Chroma instance" +} + +output "instance_private_ip" { + value = digitalocean_droplet.chroma_instance.ipv4_address_private + description = "The private IP address of the Chroma instance" +} + +output "chroma_auth_token" { + description = "The Chroma static auth token" + value = random_password.chroma_token.result + sensitive = true +} + +output "chroma_auth_basic" { + description = "The Chroma basic auth credentials" + value = "${local.basic_auth_credentials.username}:${local.basic_auth_credentials.password}" + sensitive = true +} diff --git a/examples/deployments/do-terraform/variables.tf b/examples/deployments/do-terraform/variables.tf new file mode 100644 index 00000000000..75ce6dc9a37 --- /dev/null +++ b/examples/deployments/do-terraform/variables.tf @@ -0,0 +1,126 @@ +variable "instance_image" { + description = "The image to use for the instance" + type = string + default = "ubuntu-22-04-x64" +} +variable "chroma_release" { + description = "The chroma release to deploy" + type = string + default = "0.4.12" +} + +data "http" "startup_script_remote" { + url = "https://raw.githubusercontent.com/chroma-core/chroma/main/examples/deployments/common/startup.sh" +} + +data "template_file" "user_data" { + template = data.http.startup_script_remote.response_body + + vars = { + chroma_release = var.chroma_release + enable_auth = var.enable_auth + auth_type = 
var.auth_type + basic_auth_credentials = "${local.basic_auth_credentials.username}:${local.basic_auth_credentials.password}" + token_auth_credentials = random_password.chroma_token.result + } +} + +variable "region" { + description = "DO Region" + type = string + default = "nyc2" +} + +variable "instance_type" { + description = "Droplet size" + type = string + default = "s-2vcpu-4gb" +} + + +variable "public_access" { + description = "Enable public ingress on port 8000" + type = bool + default = true // or false depending on your needs +} + +variable "enable_auth" { + description = "Enable authentication" + type = bool + default = true // or false depending on your needs +} + +variable "auth_type" { + description = "Authentication type" + type = string + default = "token" // or basic depending on your needs + validation { + condition = contains(["basic", "token"], var.auth_type) + error_message = "The auth type must be either basic or token" + } +} + +resource "random_password" "chroma_password" { + length = 16 + special = true + lower = true + upper = true +} + +resource "random_password" "chroma_token" { + length = 32 + special = false + lower = true + upper = true +} + + +locals { + basic_auth_credentials = { + username = "chroma" + password = random_password.chroma_password.result + } + token_auth_credentials = { + token = random_password.chroma_token.result + } + tags = [ + "chroma", + "release-${replace(var.chroma_release, ".", "")}", + ] +} + +variable "ssh_public_key" { + description = "SSH Public Key" + type = string + default = "./chroma-do.pub" +} +variable "ssh_private_key" { + description = "SSH Private Key" + type = string + default = "./chroma-do" +} + +variable "chroma_data_volume_size" { + description = "EBS Volume Size of the attached data volume where your chroma data is stored" + type = number + default = 20 +} + + +variable "chroma_port" { + default = "8000" + description = "The port that chroma listens on" + type = string +} + +variable "source_ranges" { + default = ["0.0.0.0/0", "::/0"] + type = list(string) + description = "List of CIDR ranges to allow through the firewall" +} + +variable "mgmt_source_ranges" { + default = ["0.0.0.0/0", "::/0"] + type = list(string) + description = "List of CIDR ranges to allow for management of the Chroma instance. This is used for SSH incoming traffic filtering" +} From de2f05a9dd2643efe37fd462e7448e4169279747 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Tue, 3 Oct 2023 15:52:45 -0700 Subject: [PATCH 36/39] Grpc Segments + Distributed Segment Manager (#952) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - N/A - New functionality - Adds gRPC based segments backed by local segment impls - Add protobuf service defs for the distributed index type - Extend convert.py with the needed proto conversions - Adds crude distributed segment manager - For now, only get() is supported, query is not - For now only vector segments are supported ## Test plan This PR adds a test harness to run the tests against the distributed impl. However the tests assume synchronous behavior in many places, and need to be overhauled to support the new decoupled read/write path. Rather than bloat this PR, we add the tests as a new GH workflow (Chroma Cluster Tests), expect them to fail, and will add a subsequent PR to both add the NotImplemented() functionality as well as fix the tests gradually over time. 
(The tests are green but if you drill in they are failing, this is the closest github actions gets to xfail) As we develop this, the goal is to make the tests progressively pass, rather than bloat PRs until we are at feature parity. How to test locally: Run `bin/cluster-test.sh` ## Documentation Changes N/A for now --- .github/workflows/chroma-cluster-test.yml | 8 +- .pre-commit-config.yaml | 1 + bin/cluster-test.sh | 4 +- chromadb/api/segment.py | 2 + chromadb/config.py | 4 + chromadb/proto/__init__.py | 0 chromadb/proto/chroma_pb2.py | 70 +-- chromadb/proto/chroma_pb2.pyi | 408 +++++++----------- chromadb/proto/chroma_pb2_grpc.py | 208 +++++++++ chromadb/proto/chromadb/proto/chroma.proto | 115 +++++ chromadb/proto/convert.py | 103 ++++- chromadb/segment/__init__.py | 26 +- chromadb/segment/impl/distributed/server.py | 135 ++++++ chromadb/segment/impl/manager/distributed.py | 157 +++++++ chromadb/segment/impl/manager/local.py | 10 +- .../segment/impl/manager/segment_directory.py | 36 ++ chromadb/segment/impl/vector/grpc_segment.py | 96 +++++ chromadb/test/conftest.py | 19 +- .../test/ingest/test_producer_consumer.py | 1 - docker-compose.cluster.test.yml | 96 +++++ docker-compose.cluster.yml | 28 +- pyproject.toml | 1 + requirements.txt | 1 + requirements_dev.txt | 1 + 24 files changed, 1241 insertions(+), 289 deletions(-) create mode 100644 chromadb/proto/__init__.py create mode 100644 chromadb/proto/chroma_pb2_grpc.py create mode 100644 chromadb/proto/chromadb/proto/chroma.proto create mode 100644 chromadb/segment/impl/distributed/server.py create mode 100644 chromadb/segment/impl/manager/distributed.py create mode 100644 chromadb/segment/impl/manager/segment_directory.py create mode 100644 chromadb/segment/impl/vector/grpc_segment.py create mode 100644 docker-compose.cluster.test.yml diff --git a/.github/workflows/chroma-cluster-test.yml b/.github/workflows/chroma-cluster-test.yml index 5ae873aa198..25287dbf0cd 100644 --- a/.github/workflows/chroma-cluster-test.yml +++ b/.github/workflows/chroma-cluster-test.yml @@ -16,7 +16,12 @@ jobs: matrix: python: ['3.7'] platform: [ubuntu-latest] - testfile: ["chromadb/test/ingest/test_producer_consumer.py"] # Just this one test for now + testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore='chromadb/test/test_cli.py'", + "chromadb/test/property/test_add.py", + "chromadb/test/property/test_collections.py", + "chromadb/test/property/test_embeddings.py", + "chromadb/test/property/test_filtering.py", + "chromadb/test/property/test_persist.py"] runs-on: ${{ matrix.platform }} steps: - name: Checkout @@ -29,3 +34,4 @@ jobs: run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt - name: Integration Test run: bin/cluster-test.sh ${{ matrix.testfile }} + continue-on-error: true # Mark the job as successful even if the tests fail for now (Xfail) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5b2ed56635e..6bf67a64ead 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,4 @@ +exclude: 'chromadb/proto/chroma_pb2\.(py|pyi|py_grpc\.py)' # Generated files repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 diff --git a/bin/cluster-test.sh b/bin/cluster-test.sh index b7255eae60a..6eccc5f3158 100755 --- a/bin/cluster-test.sh +++ b/bin/cluster-test.sh @@ -3,12 +3,12 @@ set -e function cleanup { - docker compose -f docker-compose.cluster.yml down --rmi local --volumes + docker compose -f docker-compose.cluster.test.yml down --rmi local --volumes } 
trap cleanup EXIT -docker compose -f docker-compose.cluster.yml up -d --wait pulsar +docker compose -f docker-compose.cluster.test.yml up -d --wait export CHROMA_CLUSTER_TEST_ONLY=1 diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index fd2f08ec63b..d23139759d9 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -623,6 +623,8 @@ def _topic(self, collection_id: UUID) -> str: # TODO: This could potentially cause race conditions in a distributed version of the # system, since the cache is only local. + # TODO: promote collection -> topic to a base class method so that it can be + # used for channel assignment in the distributed version of the system. def _validate_embedding_record( self, collection: t.Collection, record: t.SubmitEmbeddingRecord ) -> None: diff --git a/chromadb/config.py b/chromadb/config.py index 0e9c87e5572..1ecf7d04254 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -69,6 +69,7 @@ "chromadb.ingest.Consumer": "chroma_consumer_impl", "chromadb.db.system.SysDB": "chroma_sysdb_impl", "chromadb.segment.SegmentManager": "chroma_segment_manager_impl", + "chromadb.segment.SegmentDirectory": "chroma_segment_directory_impl", } @@ -88,6 +89,9 @@ class Settings(BaseSettings): # type: ignore chroma_segment_manager_impl: str = ( "chromadb.segment.impl.manager.local.LocalSegmentManager" ) + chroma_segment_directory_impl: str = ( + "chromadb.segment.impl.manager.segment_directory.DockerComposeSegmentDirectory" + ) tenant_id: str = "default" topic_namespace: str = "default" diff --git a/chromadb/proto/__init__.py b/chromadb/proto/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/chromadb/proto/chroma_pb2.py b/chromadb/proto/chroma_pb2.py index ca8952697af..2a302c67154 100644 --- a/chromadb/proto/chroma_pb2.py +++ b/chromadb/proto/chroma_pb2.py @@ -6,37 +6,59 @@ from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01\x62\x06proto3' -) + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma\"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 
\x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding\"\xca\x01\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\x05topic\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ncollection\x18\x05 \x01(\tH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x08\n\x06_topicB\r\n\x0b_collectionB\x0b\n\t_metadata\"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value\"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01\"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata\"S\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector\"q\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06seq_id\x18\x02 \x01(\x0c\x12\x10\n\x08\x64istance\x18\x03 \x01(\x01\x12#\n\x06vector\x18\x04 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector\"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult\"(\n\x15SegmentServerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"4\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t\"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord\"\x86\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t\"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 \x03(\x0b\x32\x1a.chroma.VectorQueryResults*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*(\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x32\x94\x01\n\rSegmentServer\x12?\n\x0bLoadSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x12\x42\n\x0eReleaseSegment\x12\x0f.chroma.Segment\x1a\x1d.chroma.SegmentServerResponse\"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse\"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages( - DESCRIPTOR, "chromadb.proto.chroma_pb2", _globals -) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chromadb.proto.chroma_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _UPDATEMETADATA_METADATAENTRY._options = None - 
_UPDATEMETADATA_METADATAENTRY._serialized_options = b"8\001" - _globals["_OPERATION"]._serialized_start = 563 - _globals["_OPERATION"]._serialized_end = 619 - _globals["_SCALARENCODING"]._serialized_start = 621 - _globals["_SCALARENCODING"]._serialized_end = 661 - _globals["_VECTOR"]._serialized_start = 39 - _globals["_VECTOR"]._serialized_end = 124 - _globals["_UPDATEMETADATAVALUE"]._serialized_start = 126 - _globals["_UPDATEMETADATAVALUE"]._serialized_end = 224 - _globals["_UPDATEMETADATA"]._serialized_start = 227 - _globals["_UPDATEMETADATA"]._serialized_end = 377 - _globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_start = 301 - _globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_end = 377 - _globals["_SUBMITEMBEDDINGRECORD"]._serialized_start = 380 - _globals["_SUBMITEMBEDDINGRECORD"]._serialized_end = 561 + + DESCRIPTOR._options = None + _UPDATEMETADATA_METADATAENTRY._options = None + _UPDATEMETADATA_METADATAENTRY._serialized_options = b'8\001' + _globals['_OPERATION']._serialized_start=1406 + _globals['_OPERATION']._serialized_end=1462 + _globals['_SCALARENCODING']._serialized_start=1464 + _globals['_SCALARENCODING']._serialized_end=1504 + _globals['_SEGMENTSCOPE']._serialized_start=1506 + _globals['_SEGMENTSCOPE']._serialized_end=1546 + _globals['_VECTOR']._serialized_start=39 + _globals['_VECTOR']._serialized_end=124 + _globals['_SEGMENT']._serialized_start=127 + _globals['_SEGMENT']._serialized_end=329 + _globals['_UPDATEMETADATAVALUE']._serialized_start=331 + _globals['_UPDATEMETADATAVALUE']._serialized_end=429 + _globals['_UPDATEMETADATA']._serialized_start=432 + _globals['_UPDATEMETADATA']._serialized_end=582 + _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_start=506 + _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_end=582 + _globals['_SUBMITEMBEDDINGRECORD']._serialized_start=585 + _globals['_SUBMITEMBEDDINGRECORD']._serialized_end=766 + _globals['_VECTOREMBEDDINGRECORD']._serialized_start=768 + _globals['_VECTOREMBEDDINGRECORD']._serialized_end=851 + _globals['_VECTORQUERYRESULT']._serialized_start=853 + _globals['_VECTORQUERYRESULT']._serialized_end=966 + _globals['_VECTORQUERYRESULTS']._serialized_start=968 + _globals['_VECTORQUERYRESULTS']._serialized_end=1032 + _globals['_SEGMENTSERVERRESPONSE']._serialized_start=1034 + _globals['_SEGMENTSERVERRESPONSE']._serialized_end=1074 + _globals['_GETVECTORSREQUEST']._serialized_start=1076 + _globals['_GETVECTORSREQUEST']._serialized_end=1128 + _globals['_GETVECTORSRESPONSE']._serialized_start=1130 + _globals['_GETVECTORSRESPONSE']._serialized_end=1198 + _globals['_QUERYVECTORSREQUEST']._serialized_start=1201 + _globals['_QUERYVECTORSREQUEST']._serialized_end=1335 + _globals['_QUERYVECTORSRESPONSE']._serialized_start=1337 + _globals['_QUERYVECTORSRESPONSE']._serialized_end=1404 + _globals['_SEGMENTSERVER']._serialized_start=1549 + _globals['_SEGMENTSERVER']._serialized_end=1697 + _globals['_VECTORREADER']._serialized_start=1700 + _globals['_VECTORREADER']._serialized_end=1862 # @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/chroma_pb2.pyi b/chromadb/proto/chroma_pb2.pyi index b13327e982f..6d06e074c06 100644 --- a/chromadb/proto/chroma_pb2.pyi +++ b/chromadb/proto/chroma_pb2.pyi @@ -1,247 +1,161 @@ -""" -@generated by mypy-protobuf. Do not edit manually! 
-isort:skip_file -""" -import builtins -import collections.abc -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import sys -import typing - -if sys.version_info >= (3, 10): - import typing as typing_extensions -else: - import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor - -class _Operation: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - -class _OperationEnumTypeWrapper( - google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_Operation.ValueType], - builtins.type, -): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - ADD: _Operation.ValueType # 0 - UPDATE: _Operation.ValueType # 1 - UPSERT: _Operation.ValueType # 2 - DELETE: _Operation.ValueType # 3 - -class Operation(_Operation, metaclass=_OperationEnumTypeWrapper): ... - -ADD: Operation.ValueType # 0 -UPDATE: Operation.ValueType # 1 -UPSERT: Operation.ValueType # 2 -DELETE: Operation.ValueType # 3 -global___Operation = Operation - -class _ScalarEncoding: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - -class _ScalarEncodingEnumTypeWrapper( - google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[ - _ScalarEncoding.ValueType - ], - builtins.type, -): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - FLOAT32: _ScalarEncoding.ValueType # 0 - INT32: _ScalarEncoding.ValueType # 1 - -class ScalarEncoding(_ScalarEncoding, metaclass=_ScalarEncodingEnumTypeWrapper): ... - -FLOAT32: ScalarEncoding.ValueType # 0 -INT32: ScalarEncoding.ValueType # 1 -global___ScalarEncoding = ScalarEncoding - -@typing_extensions.final -class Vector(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - DIMENSION_FIELD_NUMBER: builtins.int - VECTOR_FIELD_NUMBER: builtins.int - ENCODING_FIELD_NUMBER: builtins.int - dimension: builtins.int - vector: builtins.bytes - encoding: global___ScalarEncoding.ValueType - def __init__( - self, - *, - dimension: builtins.int = ..., - vector: builtins.bytes = ..., - encoding: global___ScalarEncoding.ValueType = ..., - ) -> None: ... - def ClearField( - self, - field_name: typing_extensions.Literal[ - "dimension", b"dimension", "encoding", b"encoding", "vector", b"vector" - ], - ) -> None: ... - -global___Vector = Vector - -@typing_extensions.final -class UpdateMetadataValue(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - STRING_VALUE_FIELD_NUMBER: builtins.int - INT_VALUE_FIELD_NUMBER: builtins.int - FLOAT_VALUE_FIELD_NUMBER: builtins.int - string_value: builtins.str - int_value: builtins.int - float_value: builtins.float - def __init__( - self, - *, - string_value: builtins.str = ..., - int_value: builtins.int = ..., - float_value: builtins.float = ..., - ) -> None: ... - def HasField( - self, - field_name: typing_extensions.Literal[ - "float_value", - b"float_value", - "int_value", - b"int_value", - "string_value", - b"string_value", - "value", - b"value", - ], - ) -> builtins.bool: ... - def ClearField( - self, - field_name: typing_extensions.Literal[ - "float_value", - b"float_value", - "int_value", - b"int_value", - "string_value", - b"string_value", - "value", - b"value", - ], - ) -> None: ... - def WhichOneof( - self, oneof_group: typing_extensions.Literal["value", b"value"] - ) -> ( - typing_extensions.Literal["string_value", "int_value", "float_value"] | None - ): ... 
- -global___UpdateMetadataValue = UpdateMetadataValue - -@typing_extensions.final -class UpdateMetadata(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - @typing_extensions.final - class MetadataEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: builtins.str - @property - def value(self) -> global___UpdateMetadataValue: ... - def __init__( - self, - *, - key: builtins.str = ..., - value: global___UpdateMetadataValue | None = ..., - ) -> None: ... - def HasField( - self, field_name: typing_extensions.Literal["value", b"value"] - ) -> builtins.bool: ... - def ClearField( - self, - field_name: typing_extensions.Literal["key", b"key", "value", b"value"], - ) -> None: ... - - METADATA_FIELD_NUMBER: builtins.int - @property - def metadata( - self, - ) -> google.protobuf.internal.containers.MessageMap[ - builtins.str, global___UpdateMetadataValue - ]: ... - def __init__( - self, - *, - metadata: collections.abc.Mapping[builtins.str, global___UpdateMetadataValue] - | None = ..., - ) -> None: ... - def ClearField( - self, field_name: typing_extensions.Literal["metadata", b"metadata"] - ) -> None: ... - -global___UpdateMetadata = UpdateMetadata - -@typing_extensions.final -class SubmitEmbeddingRecord(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - ID_FIELD_NUMBER: builtins.int - VECTOR_FIELD_NUMBER: builtins.int - METADATA_FIELD_NUMBER: builtins.int - OPERATION_FIELD_NUMBER: builtins.int - id: builtins.str - @property - def vector(self) -> global___Vector: ... - @property - def metadata(self) -> global___UpdateMetadata: ... - operation: global___Operation.ValueType - def __init__( - self, - *, - id: builtins.str = ..., - vector: global___Vector | None = ..., - metadata: global___UpdateMetadata | None = ..., - operation: global___Operation.ValueType = ..., - ) -> None: ... - def HasField( - self, - field_name: typing_extensions.Literal[ - "_metadata", - b"_metadata", - "_vector", - b"_vector", - "metadata", - b"metadata", - "vector", - b"vector", - ], - ) -> builtins.bool: ... - def ClearField( - self, - field_name: typing_extensions.Literal[ - "_metadata", - b"_metadata", - "_vector", - b"_vector", - "id", - b"id", - "metadata", - b"metadata", - "operation", - b"operation", - "vector", - b"vector", - ], - ) -> None: ... - @typing.overload - def WhichOneof( - self, oneof_group: typing_extensions.Literal["_metadata", b"_metadata"] - ) -> typing_extensions.Literal["metadata"] | None: ... - @typing.overload - def WhichOneof( - self, oneof_group: typing_extensions.Literal["_vector", b"_vector"] - ) -> typing_extensions.Literal["vector"] | None: ... 
- -global___SubmitEmbeddingRecord = SubmitEmbeddingRecord +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class Operation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] + ADD: _ClassVar[Operation] + UPDATE: _ClassVar[Operation] + UPSERT: _ClassVar[Operation] + DELETE: _ClassVar[Operation] + +class ScalarEncoding(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] + FLOAT32: _ClassVar[ScalarEncoding] + INT32: _ClassVar[ScalarEncoding] + +class SegmentScope(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] + VECTOR: _ClassVar[SegmentScope] + METADATA: _ClassVar[SegmentScope] +ADD: Operation +UPDATE: Operation +UPSERT: Operation +DELETE: Operation +FLOAT32: ScalarEncoding +INT32: ScalarEncoding +VECTOR: SegmentScope +METADATA: SegmentScope + +class Vector(_message.Message): + __slots__ = ["dimension", "vector", "encoding"] + DIMENSION_FIELD_NUMBER: _ClassVar[int] + VECTOR_FIELD_NUMBER: _ClassVar[int] + ENCODING_FIELD_NUMBER: _ClassVar[int] + dimension: int + vector: bytes + encoding: ScalarEncoding + def __init__(self, dimension: _Optional[int] = ..., vector: _Optional[bytes] = ..., encoding: _Optional[_Union[ScalarEncoding, str]] = ...) -> None: ... + +class Segment(_message.Message): + __slots__ = ["id", "type", "scope", "topic", "collection", "metadata"] + ID_FIELD_NUMBER: _ClassVar[int] + TYPE_FIELD_NUMBER: _ClassVar[int] + SCOPE_FIELD_NUMBER: _ClassVar[int] + TOPIC_FIELD_NUMBER: _ClassVar[int] + COLLECTION_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + id: str + type: str + scope: SegmentScope + topic: str + collection: str + metadata: UpdateMetadata + def __init__(self, id: _Optional[str] = ..., type: _Optional[str] = ..., scope: _Optional[_Union[SegmentScope, str]] = ..., topic: _Optional[str] = ..., collection: _Optional[str] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ...) -> None: ... + +class UpdateMetadataValue(_message.Message): + __slots__ = ["string_value", "int_value", "float_value"] + STRING_VALUE_FIELD_NUMBER: _ClassVar[int] + INT_VALUE_FIELD_NUMBER: _ClassVar[int] + FLOAT_VALUE_FIELD_NUMBER: _ClassVar[int] + string_value: str + int_value: int + float_value: float + def __init__(self, string_value: _Optional[str] = ..., int_value: _Optional[int] = ..., float_value: _Optional[float] = ...) -> None: ... + +class UpdateMetadata(_message.Message): + __slots__ = ["metadata"] + class MetadataEntry(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: UpdateMetadataValue + def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[UpdateMetadataValue, _Mapping]] = ...) -> None: ... + METADATA_FIELD_NUMBER: _ClassVar[int] + metadata: _containers.MessageMap[str, UpdateMetadataValue] + def __init__(self, metadata: _Optional[_Mapping[str, UpdateMetadataValue]] = ...) -> None: ... 
+ +class SubmitEmbeddingRecord(_message.Message): + __slots__ = ["id", "vector", "metadata", "operation"] + ID_FIELD_NUMBER: _ClassVar[int] + VECTOR_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + OPERATION_FIELD_NUMBER: _ClassVar[int] + id: str + vector: Vector + metadata: UpdateMetadata + operation: Operation + def __init__(self, id: _Optional[str] = ..., vector: _Optional[_Union[Vector, _Mapping]] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., operation: _Optional[_Union[Operation, str]] = ...) -> None: ... + +class VectorEmbeddingRecord(_message.Message): + __slots__ = ["id", "seq_id", "vector"] + ID_FIELD_NUMBER: _ClassVar[int] + SEQ_ID_FIELD_NUMBER: _ClassVar[int] + VECTOR_FIELD_NUMBER: _ClassVar[int] + id: str + seq_id: bytes + vector: Vector + def __init__(self, id: _Optional[str] = ..., seq_id: _Optional[bytes] = ..., vector: _Optional[_Union[Vector, _Mapping]] = ...) -> None: ... + +class VectorQueryResult(_message.Message): + __slots__ = ["id", "seq_id", "distance", "vector"] + ID_FIELD_NUMBER: _ClassVar[int] + SEQ_ID_FIELD_NUMBER: _ClassVar[int] + DISTANCE_FIELD_NUMBER: _ClassVar[int] + VECTOR_FIELD_NUMBER: _ClassVar[int] + id: str + seq_id: bytes + distance: float + vector: Vector + def __init__(self, id: _Optional[str] = ..., seq_id: _Optional[bytes] = ..., distance: _Optional[float] = ..., vector: _Optional[_Union[Vector, _Mapping]] = ...) -> None: ... + +class VectorQueryResults(_message.Message): + __slots__ = ["results"] + RESULTS_FIELD_NUMBER: _ClassVar[int] + results: _containers.RepeatedCompositeFieldContainer[VectorQueryResult] + def __init__(self, results: _Optional[_Iterable[_Union[VectorQueryResult, _Mapping]]] = ...) -> None: ... + +class SegmentServerResponse(_message.Message): + __slots__ = ["success"] + SUCCESS_FIELD_NUMBER: _ClassVar[int] + success: bool + def __init__(self, success: bool = ...) -> None: ... + +class GetVectorsRequest(_message.Message): + __slots__ = ["ids", "segment_id"] + IDS_FIELD_NUMBER: _ClassVar[int] + SEGMENT_ID_FIELD_NUMBER: _ClassVar[int] + ids: _containers.RepeatedScalarFieldContainer[str] + segment_id: str + def __init__(self, ids: _Optional[_Iterable[str]] = ..., segment_id: _Optional[str] = ...) -> None: ... + +class GetVectorsResponse(_message.Message): + __slots__ = ["records"] + RECORDS_FIELD_NUMBER: _ClassVar[int] + records: _containers.RepeatedCompositeFieldContainer[VectorEmbeddingRecord] + def __init__(self, records: _Optional[_Iterable[_Union[VectorEmbeddingRecord, _Mapping]]] = ...) -> None: ... + +class QueryVectorsRequest(_message.Message): + __slots__ = ["vectors", "k", "allowed_ids", "include_embeddings", "segment_id"] + VECTORS_FIELD_NUMBER: _ClassVar[int] + K_FIELD_NUMBER: _ClassVar[int] + ALLOWED_IDS_FIELD_NUMBER: _ClassVar[int] + INCLUDE_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int] + SEGMENT_ID_FIELD_NUMBER: _ClassVar[int] + vectors: _containers.RepeatedCompositeFieldContainer[Vector] + k: int + allowed_ids: _containers.RepeatedScalarFieldContainer[str] + include_embeddings: bool + segment_id: str + def __init__(self, vectors: _Optional[_Iterable[_Union[Vector, _Mapping]]] = ..., k: _Optional[int] = ..., allowed_ids: _Optional[_Iterable[str]] = ..., include_embeddings: bool = ..., segment_id: _Optional[str] = ...) -> None: ... 
+ +class QueryVectorsResponse(_message.Message): + __slots__ = ["results"] + RESULTS_FIELD_NUMBER: _ClassVar[int] + results: _containers.RepeatedCompositeFieldContainer[VectorQueryResults] + def __init__(self, results: _Optional[_Iterable[_Union[VectorQueryResults, _Mapping]]] = ...) -> None: ... diff --git a/chromadb/proto/chroma_pb2_grpc.py b/chromadb/proto/chroma_pb2_grpc.py new file mode 100644 index 00000000000..f5cc85a36bd --- /dev/null +++ b/chromadb/proto/chroma_pb2_grpc.py @@ -0,0 +1,208 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from chromadb.proto import chroma_pb2 as chromadb_dot_proto_dot_chroma__pb2 + + +class SegmentServerStub(object): + """Segment Server Interface + + TODO: figure out subpackaging, ideally this file is colocated with the segment server implementation + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.LoadSegment = channel.unary_unary( + '/chroma.SegmentServer/LoadSegment', + request_serializer=chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, + ) + self.ReleaseSegment = channel.unary_unary( + '/chroma.SegmentServer/ReleaseSegment', + request_serializer=chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, + ) + + +class SegmentServerServicer(object): + """Segment Server Interface + + TODO: figure out subpackaging, ideally this file is colocated with the segment server implementation + """ + + def LoadSegment(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ReleaseSegment(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_SegmentServerServicer_to_server(servicer, server): + rpc_method_handlers = { + 'LoadSegment': grpc.unary_unary_rpc_method_handler( + servicer.LoadSegment, + request_deserializer=chromadb_dot_proto_dot_chroma__pb2.Segment.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.SerializeToString, + ), + 'ReleaseSegment': grpc.unary_unary_rpc_method_handler( + servicer.ReleaseSegment, + request_deserializer=chromadb_dot_proto_dot_chroma__pb2.Segment.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'chroma.SegmentServer', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. 
+class SegmentServer(object): + """Segment Server Interface + + TODO: figure out subpackaging, ideally this file is colocated with the segment server implementation + """ + + @staticmethod + def LoadSegment(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/chroma.SegmentServer/LoadSegment', + chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ReleaseSegment(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/chroma.SegmentServer/ReleaseSegment', + chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + +class VectorReaderStub(object): + """Vector Reader Interface + + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.GetVectors = channel.unary_unary( + '/chroma.VectorReader/GetVectors', + request_serializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsResponse.FromString, + ) + self.QueryVectors = channel.unary_unary( + '/chroma.VectorReader/QueryVectors', + request_serializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsRequest.SerializeToString, + response_deserializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsResponse.FromString, + ) + + +class VectorReaderServicer(object): + """Vector Reader Interface + + """ + + def GetVectors(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def QueryVectors(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_VectorReaderServicer_to_server(servicer, server): + rpc_method_handlers = { + 'GetVectors': grpc.unary_unary_rpc_method_handler( + servicer.GetVectors, + request_deserializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsResponse.SerializeToString, + ), + 'QueryVectors': grpc.unary_unary_rpc_method_handler( + servicer.QueryVectors, + request_deserializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsRequest.FromString, + response_serializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'chroma.VectorReader', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. 
+class VectorReader(object): + """Vector Reader Interface + + """ + + @staticmethod + def GetVectors(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/chroma.VectorReader/GetVectors', + chromadb_dot_proto_dot_chroma__pb2.GetVectorsRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.GetVectorsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def QueryVectors(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/chroma.VectorReader/QueryVectors', + chromadb_dot_proto_dot_chroma__pb2.QueryVectorsRequest.SerializeToString, + chromadb_dot_proto_dot_chroma__pb2.QueryVectorsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/chromadb/proto/chromadb/proto/chroma.proto b/chromadb/proto/chromadb/proto/chroma.proto new file mode 100644 index 00000000000..ddc7f11bc26 --- /dev/null +++ b/chromadb/proto/chromadb/proto/chroma.proto @@ -0,0 +1,115 @@ +syntax = "proto3"; + +package chroma; + +enum Operation { + ADD = 0; + UPDATE = 1; + UPSERT = 2; + DELETE = 3; +} + +enum ScalarEncoding { + FLOAT32 = 0; + INT32 = 1; +} + +message Vector { + int32 dimension = 1; + bytes vector = 2; + ScalarEncoding encoding = 3; +} + +enum SegmentScope { + VECTOR = 0; + METADATA = 1; +} + +message Segment { + string id = 1; + string type = 2; + SegmentScope scope = 3; + optional string topic = 4; // TODO should channel <> segment binding exist here? + // If a segment has a collection, it implies that this segment implements the full + // collection and can be used to service queries (for it's given scope.) 
+ optional string collection = 5; + optional UpdateMetadata metadata = 6; +} + +message UpdateMetadataValue { + oneof value { + string string_value = 1; + int64 int_value = 2; + double float_value = 3; + } +} + +message UpdateMetadata { + map metadata = 1; +} + +message SubmitEmbeddingRecord { + string id = 1; + optional Vector vector = 2; + optional UpdateMetadata metadata = 3; + Operation operation = 4; +} + +message VectorEmbeddingRecord { + string id = 1; + bytes seq_id = 2; + Vector vector = 3; // TODO: we need to rethink source of truth for vector dimensionality and encoding +} + +message VectorQueryResult { + string id = 1; + bytes seq_id = 2; + double distance = 3; + optional Vector vector = 4; +} + +message VectorQueryResults { + repeated VectorQueryResult results = 1; +} + +/* Segment Server Interface */ + +// TODO: figure out subpackaging, ideally this file is colocated with the segment server implementation +service SegmentServer { + rpc LoadSegment (Segment) returns (SegmentServerResponse) {} + rpc ReleaseSegment (Segment) returns (SegmentServerResponse) {} // TODO: this maybe should only take id/type/scope +} + +// TODO: enum of succcess/failure/or already loaded +message SegmentServerResponse { + bool success = 1; +} + +/* Vector Reader Interface */ + +service VectorReader { + rpc GetVectors(GetVectorsRequest) returns (GetVectorsResponse) {} + rpc QueryVectors(QueryVectorsRequest) returns (QueryVectorsResponse) {} +} + +message GetVectorsRequest { + repeated string ids = 1; + string segment_id = 2; +} + +message GetVectorsResponse { + repeated VectorEmbeddingRecord records = 1; +} + +message QueryVectorsRequest { + repeated Vector vectors = 1; + int32 k = 2; + repeated string allowed_ids = 3; + bool include_embeddings = 4; + string segment_id = 5; + // TODO: options as in types.py, its currently unused so can add later +} + +message QueryVectorsResponse { + repeated VectorQueryResults results = 1; +} diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py index 15d1363b05c..5ff7bab085d 100644 --- a/chromadb/proto/convert.py +++ b/chromadb/proto/convert.py @@ -1,18 +1,27 @@ import array -from typing import Optional, Tuple, Union +from uuid import UUID +from typing import Dict, Optional, Tuple, Union from chromadb.api.types import Embedding import chromadb.proto.chroma_pb2 as proto +from chromadb.utils.messageid import bytes_to_int, int_to_bytes from chromadb.types import ( EmbeddingRecord, Metadata, Operation, ScalarEncoding, + Segment, + SegmentScope, SeqId, SubmitEmbeddingRecord, Vector, + VectorEmbeddingRecord, + VectorQueryResult, ) +# TODO: Unit tests for this file, handling optional states etc + + def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> proto.Vector: if encoding == ScalarEncoding.FLOAT32: as_bytes = array.array("f", vector).tobytes() @@ -48,7 +57,7 @@ def from_proto_vector(vector: proto.Vector) -> Tuple[Embedding, ScalarEncoding]: return (as_array.tolist(), out_encoding) -def from_proto_operation(operation: proto.Operation.ValueType) -> Operation: +def from_proto_operation(operation: proto.Operation) -> Operation: if operation == proto.Operation.ADD: return Operation.ADD elif operation == proto.Operation.UPDATE: @@ -64,7 +73,7 @@ def from_proto_operation(operation: proto.Operation.ValueType) -> Operation: def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]: if not metadata.metadata: return None - out_metadata = {} + out_metadata: Dict[str, Union[str, int, float]] = {} for key, value in 
metadata.metadata.items(): if value.HasField("string_value"): out_metadata[key] = value.string_value @@ -92,6 +101,52 @@ def from_proto_submit( return record +def from_proto_segment(segment: proto.Segment) -> Segment: + return Segment( + id=UUID(hex=segment.id), + type=segment.type, + scope=from_proto_segment_scope(segment.scope), + topic=segment.topic, + collection=None + if not segment.HasField("collection") + else UUID(hex=segment.collection), + metadata=from_proto_metadata(segment.metadata), + ) + + +def to_proto_segment(segment: Segment) -> proto.Segment: + return proto.Segment( + id=segment["id"].hex, + type=segment["type"], + scope=to_proto_segment_scope(segment["scope"]), + topic=segment["topic"], + collection=None if segment["collection"] is None else segment["collection"].hex, + metadata=None + if segment["metadata"] is None + else { + k: to_proto_metadata_update_value(v) for k, v in segment["metadata"].items() + }, # TODO: refactor out to_proto_metadata + ) + + +def from_proto_segment_scope(segment_scope: proto.SegmentScope) -> SegmentScope: + if segment_scope == proto.SegmentScope.VECTOR: + return SegmentScope.VECTOR + elif segment_scope == proto.SegmentScope.METADATA: + return SegmentScope.METADATA + else: + raise RuntimeError(f"Unknown segment scope {segment_scope}") + + +def to_proto_segment_scope(segment_scope: SegmentScope) -> proto.SegmentScope: + if segment_scope == SegmentScope.VECTOR: + return proto.SegmentScope.VECTOR + elif segment_scope == SegmentScope.METADATA: + return proto.SegmentScope.METADATA + else: + raise RuntimeError(f"Unknown segment scope {segment_scope}") + + def to_proto_metadata_update_value( value: Union[str, int, float, None] ) -> proto.UpdateMetadataValue: @@ -110,7 +165,7 @@ def to_proto_metadata_update_value( ) -def to_proto_operation(operation: Operation) -> proto.Operation.ValueType: +def to_proto_operation(operation: Operation) -> proto.Operation: if operation == Operation.ADD: return proto.Operation.ADD elif operation == Operation.UPDATE: @@ -148,3 +203,43 @@ def to_proto_submit( else None, operation=to_proto_operation(submit_record["operation"]), ) + + +def from_proto_vector_embedding_record( + embedding_record: proto.VectorEmbeddingRecord, +) -> VectorEmbeddingRecord: + return VectorEmbeddingRecord( + id=embedding_record.id, + seq_id=from_proto_seq_id(embedding_record.seq_id), + embedding=from_proto_vector(embedding_record.vector)[0], + ) + + +def to_proto_vector_embedding_record( + embedding_record: VectorEmbeddingRecord, + encoding: ScalarEncoding, +) -> proto.VectorEmbeddingRecord: + return proto.VectorEmbeddingRecord( + id=embedding_record["id"], + seq_id=to_proto_seq_id(embedding_record["seq_id"]), + vector=to_proto_vector(embedding_record["embedding"], encoding), + ) + + +def from_proto_vector_query_result( + vector_query_result: proto.VectorQueryResult, +) -> VectorQueryResult: + return VectorQueryResult( + id=vector_query_result.id, + seq_id=from_proto_seq_id(vector_query_result.seq_id), + distance=vector_query_result.distance, + embedding=from_proto_vector(vector_query_result.vector)[0], + ) + + +def to_proto_seq_id(seq_id: SeqId) -> bytes: + return int_to_bytes(seq_id) + + +def from_proto_seq_id(seq_id: bytes) -> SeqId: + return bytes_to_int(seq_id) diff --git a/chromadb/segment/__init__.py b/chromadb/segment/__init__.py index e92bccee6fa..2c2570796fc 100644 --- a/chromadb/segment/__init__.py +++ b/chromadb/segment/__init__.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, TypeVar, Type +from typing import Callable, 
Optional, Sequence, TypeVar, Type from abc import abstractmethod from chromadb.types import ( Collection, @@ -15,6 +15,14 @@ ) from chromadb.config import Component, System from uuid import UUID +from enum import Enum + + +class SegmentType(Enum): + SQLITE = "urn:chroma:segment/metadata/sqlite" + HNSW_LOCAL_MEMORY = "urn:chroma:segment/vector/hnsw-local-memory" + HNSW_LOCAL_PERSISTED = "urn:chroma:segment/vector/hnsw-local-persisted" + HNSW_DISTRIBUTED = "urn:chroma:segment/vector/hnsw-distributed" class SegmentImplementation(Component): @@ -118,3 +126,19 @@ def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None it can preload segments as needed. This is only a hint, and implementations are free to ignore it.""" pass + + +class SegmentDirectory(Component): + """A segment directory is a data interface that manages the location of segments. Concretely, this + means that for clustered chroma, it provides the grpc endpoint for a segment.""" + + @abstractmethod + def get_segment_endpoint(self, segment: Segment) -> str: + """Return the segment residence for a given segment ID""" + + @abstractmethod + def register_updated_segment_callback( + self, callback: Callable[[Segment], None] + ) -> None: + """Register a callback that will be called when a segment is updated""" + pass diff --git a/chromadb/segment/impl/distributed/server.py b/chromadb/segment/impl/distributed/server.py new file mode 100644 index 00000000000..f7ea2f2ecaf --- /dev/null +++ b/chromadb/segment/impl/distributed/server.py @@ -0,0 +1,135 @@ +from typing import Any, Dict, Type, cast +from uuid import UUID +from chromadb.config import Settings, System, get_class +from chromadb.proto.chroma_pb2_grpc import ( + SegmentServerServicer, + add_SegmentServerServicer_to_server, + VectorReaderServicer, + add_VectorReaderServicer_to_server, +) +import chromadb.proto.chroma_pb2 as proto +import grpc +from concurrent import futures +from chromadb.proto.convert import ( + from_proto_segment, + to_proto_seq_id, + to_proto_vector, + to_proto_vector_embedding_record, +) +from chromadb.segment import SegmentImplementation, SegmentType, VectorReader +from chromadb.config import System +from chromadb.types import ScalarEncoding, Segment, SegmentScope +import logging + + +# Run this with python -m chromadb.segment.impl.distributed.server + +# TODO: for now the distirbuted segment type is serviced by a persistent local segment, since +# the only real material difference is the way the segment is loaded and persisted. 
+# we should refactor our the index logic from the segment logic, and then we can have a +# distributed segment implementation that uses the same index impl but has a different segment wrapper +# that handles the distributed logic and storage + +SEGMENT_TYPE_IMPLS = { + SegmentType.HNSW_DISTRIBUTED: "chromadb.segment.impl.vector.local_persistent_hnsw.PersistentLocalHnswSegment", +} + + +class SegmentServer(SegmentServerServicer, VectorReaderServicer): + _segment_cache: Dict[UUID, SegmentImplementation] = {} + _system: System + + def __init__(self, system: System) -> None: + super().__init__() + self._system = system + + def LoadSegment( + self, request: proto.Segment, context: Any + ) -> proto.SegmentServerResponse: + logging.info(f"LoadSegment scope {request.type}") + id = UUID(hex=request.id) + if id in self._segment_cache: + return proto.SegmentServerResponse( + success=True, + ) + else: + if request.scope == proto.SegmentScope.METADATA: + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Metadata segments are not yet implemented") + return proto.SegmentServerResponse(success=False) + elif request.scope == proto.SegmentScope.VECTOR: + logging.info(f"Loading segment {request}") + if request.type == SegmentType.HNSW_DISTRIBUTED.value: + self._create_instance(from_proto_segment(request)) + return proto.SegmentServerResponse(success=True) + else: + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Segment type not implemented yet") + return proto.SegmentServerResponse(success=False) + else: + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Segment scope not implemented") + return proto.SegmentServerResponse(success=False) + + def ReleaseSegment( + self, request: proto.Segment, context: Any + ) -> proto.SegmentServerResponse: + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Release segment not implemented yet") + return proto.SegmentServerResponse(success=False) + + def QueryVectors( + self, request: proto.QueryVectorsRequest, context: Any + ) -> proto.QueryVectorsResponse: + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Query segment not implemented yet") + return proto.QueryVectorsResponse() + + def GetVectors( + self, request: proto.GetVectorsRequest, context: Any + ) -> proto.GetVectorsResponse: + segment_id = UUID(hex=request.segment_id) + if segment_id not in self._segment_cache: + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details("Segment not found") + return proto.GetVectorsResponse() + else: + segment = self._segment_cache[segment_id] + segment = cast(VectorReader, segment) + segment_results = segment.get_vectors(request.ids) + return_records = [] + for record in segment_results: + # TODO: encoding should be based on stored encoding for segment + # For now we just assume float32 + return_record = to_proto_vector_embedding_record( + record, ScalarEncoding.FLOAT32 + ) + return_records.append(return_record) + return proto.GetVectorsResponse(records=return_records) + + def _cls(self, segment: Segment) -> Type[SegmentImplementation]: + classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])] + cls = get_class(classname, SegmentImplementation) + return cls + + def _create_instance(self, segment: Segment) -> None: + if segment["id"] not in self._segment_cache: + cls = self._cls(segment) + instance = cls(self._system, segment) + instance.start() + self._segment_cache[segment["id"]] = instance + + +if __name__ == "__main__": + 
logging.basicConfig(level=logging.INFO) + system = System(Settings()) + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + segment_server = SegmentServer(system) + add_SegmentServerServicer_to_server(segment_server, server) # type: ignore + add_VectorReaderServicer_to_server(segment_server, server) # type: ignore + server.add_insecure_port( + f"[::]:{system.settings.require('chroma_server_grpc_port')}" + ) + system.start() + server.start() + server.wait_for_termination() diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py new file mode 100644 index 00000000000..ea6c2f0267f --- /dev/null +++ b/chromadb/segment/impl/manager/distributed.py @@ -0,0 +1,157 @@ +from threading import Lock + +import grpc +from chromadb.proto.chroma_pb2_grpc import SegmentServerStub +from chromadb.proto.convert import to_proto_segment +from chromadb.segment import ( + SegmentImplementation, + SegmentManager, + MetadataReader, + SegmentType, + VectorReader, + S, +) +from chromadb.config import System, get_class +from chromadb.db.system import SysDB +from overrides import override +from enum import Enum +from chromadb.segment import SegmentDirectory +from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata +from typing import Dict, List, Type, Sequence, Optional, cast +from uuid import UUID, uuid4 +from collections import defaultdict + +# TODO: it is odd that the segment manager is different for distributed vs local +# implementations. This should be refactored to be more consistent and shared. +# needed in this is the ability to specify the desired segment types for a collection +# It is odd that segment manager is coupled to the segment implementation. We need to rethink +# this abstraction. 
+ +SEGMENT_TYPE_IMPLS = { + SegmentType.SQLITE: "chromadb.segment.impl.metadata.sqlite.SqliteMetadataSegment", + SegmentType.HNSW_DISTRIBUTED: "chromadb.segment.impl.vector.grpc_segment.GrpcVectorSegment", +} + + +class DistributedSegmentManager(SegmentManager): + _sysdb: SysDB + _system: System + _instances: Dict[UUID, SegmentImplementation] + _segment_cache: Dict[ + UUID, Dict[SegmentScope, Segment] + ] # collection_id -> scope -> segment + _segment_directory: SegmentDirectory + _lock: Lock + _segment_server_stubs: Dict[str, SegmentServerStub] # grpc_url -> grpc stub + + def __init__(self, system: System): + super().__init__(system) + self._sysdb = self.require(SysDB) + self._segment_directory = self.require(SegmentDirectory) + self._system = system + self._instances = {} + self._segment_cache = defaultdict(dict) + self._segment_server_stubs = {} + self._lock = Lock() + + @override + def create_segments(self, collection: Collection) -> Sequence[Segment]: + vector_segment = _segment( + SegmentType.HNSW_DISTRIBUTED, SegmentScope.VECTOR, collection + ) + metadata_segment = _segment( + SegmentType.SQLITE, SegmentScope.METADATA, collection + ) + return [vector_segment, metadata_segment] + + @override + def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: + raise NotImplementedError() + + @override + def get_segment(self, collection_id: UUID, type: type[S]) -> S: + if type == MetadataReader: + scope = SegmentScope.METADATA + elif type == VectorReader: + scope = SegmentScope.VECTOR + else: + raise ValueError(f"Invalid segment type: {type}") + + if scope not in self._segment_cache[collection_id]: + segments = self._sysdb.get_segments(collection=collection_id, scope=scope) + known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()]) + # Get the first segment of a known type + segment = next(filter(lambda s: s["type"] in known_types, segments)) + grpc_url = self._segment_directory.get_segment_endpoint(segment) + if segment["metadata"] is not None: + segment["metadata"]["grpc_url"] = grpc_url # type: ignore + else: + segment["metadata"] = {"grpc_url": grpc_url} + # TODO: Register a callback to update the segment when it gets moved + # self._segment_directory.register_updated_segment_callback() + self._segment_cache[collection_id][scope] = segment + + # Instances must be atomically created, so we use a lock to ensure that only one thread + # creates the instance. 
+ with self._lock: + instance = self._instance(self._segment_cache[collection_id][scope]) + return cast(S, instance) + + @override + def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None: + # TODO: this should call load/release on the target node, node should be stored in metadata + # for now this is fine, but cache invalidation is a problem btwn sysdb and segment manager + types = [MetadataReader, VectorReader] + for type in types: + self.get_segment( + collection_id, type + ) # TODO: this is a hack that mirrors local segment manager to force load the relevant instances + if type == VectorReader: + # Load the remote segment + segments = self._sysdb.get_segments( + collection=collection_id, scope=SegmentScope.VECTOR + ) + known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()]) + segment = next(filter(lambda s: s["type"] in known_types, segments)) + grpc_url = self._segment_directory.get_segment_endpoint(segment) + + if grpc_url not in self._segment_server_stubs: + channel = grpc.insecure_channel(grpc_url) + self._segment_server_stubs[grpc_url] = SegmentServerStub(channel) # type: ignore + + self._segment_server_stubs[grpc_url].LoadSegment( + to_proto_segment(segment) + ) + + # TODO: rethink duplication from local segment manager + def _cls(self, segment: Segment) -> Type[SegmentImplementation]: + classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])] + cls = get_class(classname, SegmentImplementation) + return cls + + def _instance(self, segment: Segment) -> SegmentImplementation: + if segment["id"] not in self._instances: + cls = self._cls(segment) + instance = cls(self._system, segment) + instance.start() + self._instances[segment["id"]] = instance + return self._instances[segment["id"]] + + +# TODO: rethink duplication from local segment manager +def _segment(type: SegmentType, scope: SegmentScope, collection: Collection) -> Segment: + """Create a metadata dict, propagating metadata correctly for the given segment type.""" + cls = get_class(SEGMENT_TYPE_IMPLS[type], SegmentImplementation) + collection_metadata = collection.get("metadata", None) + metadata: Optional[Metadata] = None + if collection_metadata: + metadata = cls.propagate_collection_metadata(collection_metadata) + + return Segment( + id=uuid4(), + type=type.value, + scope=scope, + topic=collection["topic"], + collection=collection["id"], + metadata=metadata, + ) diff --git a/chromadb/segment/impl/manager/local.py b/chromadb/segment/impl/manager/local.py index e13452d4113..a5b797e31c6 100644 --- a/chromadb/segment/impl/manager/local.py +++ b/chromadb/segment/impl/manager/local.py @@ -3,13 +3,13 @@ SegmentImplementation, SegmentManager, MetadataReader, + SegmentType, VectorReader, S, ) from chromadb.config import System, get_class from chromadb.db.system import SysDB from overrides import override -from enum import Enum from chromadb.segment.impl.vector.local_persistent_hnsw import ( PersistentLocalHnswSegment, ) @@ -27,12 +27,6 @@ import ctypes -class SegmentType(Enum): - SQLITE = "urn:chroma:segment/metadata/sqlite" - HNSW_LOCAL_MEMORY = "urn:chroma:segment/vector/hnsw-local-memory" - HNSW_LOCAL_PERSISTED = "urn:chroma:segment/vector/hnsw-local-persisted" - - SEGMENT_TYPE_IMPLS = { SegmentType.SQLITE: "chromadb.segment.impl.metadata.sqlite.SqliteMetadataSegment", SegmentType.HNSW_LOCAL_MEMORY: "chromadb.segment.impl.vector.local_hnsw.LocalHnswSegment", @@ -62,6 +56,8 @@ def __init__(self, system: System): self._segment_cache = defaultdict(dict) self._lock = Lock() + # TODO: 
prototyping with distributed segment for now, but this should be a configurable option + # we need to think about how to handle this configuration if self._system.settings.require("is_persistent"): self._vector_segment_type = SegmentType.HNSW_LOCAL_PERSISTED if platform.system() != "Windows": diff --git a/chromadb/segment/impl/manager/segment_directory.py b/chromadb/segment/impl/manager/segment_directory.py new file mode 100644 index 00000000000..7e086e26c5e --- /dev/null +++ b/chromadb/segment/impl/manager/segment_directory.py @@ -0,0 +1,36 @@ +from typing import Callable +from overrides import EnforceOverrides, override + +from chromadb.segment import SegmentDirectory +from chromadb.types import Segment + + +class DockerComposeSegmentDirectory(SegmentDirectory, EnforceOverrides): + """A segment directory that uses docker-compose to manage segment endpoints""" + + @override + def get_segment_endpoint(self, segment: Segment) -> str: + # This is just a stub for now, as we don't have a real coordinator to assign and manage this + return "segment-worker:50051" + + @override + def register_updated_segment_callback( + self, callback: Callable[[Segment], None] + ) -> None: + # Updates are not supported for docker-compose yet, as there is only a single, static + # indexing node + pass + + +class KubernetesSegmentDirectory(SegmentDirectory, EnforceOverrides): + @override + def get_segment_endpoint(self, segment: Segment) -> str: + return "segment-worker.chroma:50051" + + @override + def register_updated_segment_callback( + self, callback: Callable[[Segment], None] + ) -> None: + # Updates are not supported for docker-compose yet, as there is only a single, static + # indexing node + pass diff --git a/chromadb/segment/impl/vector/grpc_segment.py b/chromadb/segment/impl/vector/grpc_segment.py new file mode 100644 index 00000000000..0aac3baa253 --- /dev/null +++ b/chromadb/segment/impl/vector/grpc_segment.py @@ -0,0 +1,96 @@ +from overrides import EnforceOverrides, override +from typing import List, Optional, Sequence +from chromadb.config import System +from chromadb.proto.convert import ( + from_proto_vector, + from_proto_vector_embedding_record, + from_proto_vector_query_result, + to_proto_vector, +) +from chromadb.segment import MetadataReader, VectorReader +from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams +from chromadb.types import ( + Metadata, + ScalarEncoding, + Segment, + VectorEmbeddingRecord, + VectorQuery, + VectorQueryResult, +) +from chromadb.proto.chroma_pb2_grpc import VectorReaderStub +from chromadb.proto.chroma_pb2 import ( + GetVectorsRequest, + GetVectorsResponse, + QueryVectorsRequest, + QueryVectorsResponse, +) +import grpc + + +class GrpcVectorSegment(VectorReader, EnforceOverrides): + _vector_reader_stub: VectorReaderStub + _segment: Segment + + def __init__(self, system: System, segment: Segment): + # TODO: move to start() method + # TODO: close channel in stop() method + if segment["metadata"] is None or segment["metadata"]["grpc_url"] is None: + raise Exception("Missing grpc_url in segment metadata") + + channel = grpc.insecure_channel(segment["metadata"]["grpc_url"]) + self._vector_reader_stub = VectorReaderStub(channel) # type: ignore + self._segment = segment + + @override + def get_vectors( + self, ids: Optional[Sequence[str]] = None + ) -> Sequence[VectorEmbeddingRecord]: + request = GetVectorsRequest(ids=ids, segment_id=self._segment["id"].hex) + response: GetVectorsResponse = self._vector_reader_stub.GetVectors(request) + results: 
List[VectorEmbeddingRecord] = [] + for vector in response.records: + result = from_proto_vector_embedding_record(vector) + results.append(result) + return results + + @override + def query_vectors( + self, query: VectorQuery + ) -> Sequence[Sequence[VectorQueryResult]]: + request = QueryVectorsRequest( + vectors=[ + to_proto_vector(vector=v, encoding=ScalarEncoding.FLOAT32) + for v in query["vectors"] + ], + k=query["k"], + allowed_ids=query["allowed_ids"], + include_embeddings=query["include_embeddings"], + segment_id=self._segment["id"].hex, + ) + response: QueryVectorsResponse = self._vector_reader_stub.QueryVectors(request) + results: List[List[VectorQueryResult]] = [] + for result in response.results: + curr_result: List[VectorQueryResult] = [] + for r in result.results: + curr_result.append(from_proto_vector_query_result(r)) + results.append(curr_result) + return results + + @override + def count(self) -> int: + raise NotImplementedError() + + @override + def max_seqid(self) -> int: + return 0 + + @staticmethod + @override + def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]: + # Great example of why language sharing is nice. + segment_metadata = PersistentHnswParams.extract(metadata) + return segment_metadata + + @override + def delete(self) -> None: + raise NotImplementedError() diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index fe91622e383..aa4e530e384 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -191,6 +191,21 @@ def fastapi_persistent() -> Generator[System, None, None]: return _fastapi_fixture(is_persistent=True) +def basic_http_client() -> Generator[System, None, None]: + settings = Settings( + chroma_api_impl="chromadb.api.fastapi.FastAPI", + chroma_server_host="localhost", + chroma_server_http_port="8000", + allow_reset=True, + ) + system = System(settings) + api = system.instance(API) + _await_server(api) + system.start() + yield system + system.stop() + + def fastapi_server_basic_auth() -> Generator[System, None, None]: server_auth_file = os.path.abspath(os.path.join(".", "server.htpasswd")) with open(server_auth_file, "w") as f: @@ -327,6 +342,8 @@ def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]: fixtures.append(integration) if "CHROMA_INTEGRATION_TEST_ONLY" in os.environ: fixtures = [integration] + if "CHROMA_CLUSTER_TEST_ONLY" in os.environ: + fixtures = [basic_http_client] return fixtures @@ -439,5 +456,5 @@ def produce_fns( yield request.param -def pytest_configure(config): +def pytest_configure(config): # type: ignore embeddings_queue._called_from_test = True diff --git a/chromadb/test/ingest/test_producer_consumer.py b/chromadb/test/ingest/test_producer_consumer.py index 1163889a246..de2ed592d07 100644 --- a/chromadb/test/ingest/test_producer_consumer.py +++ b/chromadb/test/ingest/test_producer_consumer.py @@ -178,7 +178,6 @@ async def test_backfill( producer, consumer = producer_consumer producer.reset_state() consumer.reset_state() - topic_name = full_topic_name("test_topic") producer.create_topic(topic_name) embeddings = produce_fns(producer, topic_name, sample_embeddings, 3)[0] diff --git a/docker-compose.cluster.test.yml b/docker-compose.cluster.test.yml new file mode 100644 index 00000000000..8b3f83eda7f --- /dev/null +++ b/docker-compose.cluster.test.yml @@ -0,0 +1,96 @@ +# This docker compose file is not meant to be used. It is a work in progress +# for the distributed version of Chroma. It is not yet functional. 
+ +version: '3.9' + +networks: + net: + driver: bridge + +services: + server: + image: server + build: + context: . + dockerfile: Dockerfile + volumes: + - ./:/chroma + - index_data:/index_data + command: uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 8000 --log-config chromadb/log_config.yml + environment: + - IS_PERSISTENT=TRUE + - CHROMA_PRODUCER_IMPL=chromadb.ingest.impl.pulsar.PulsarProducer + - CHROMA_CONSUMER_IMPL=chromadb.ingest.impl.pulsar.PulsarConsumer + - CHROMA_SEGMENT_MANAGER_IMPL=chromadb.segment.impl.manager.distributed.DistributedSegmentManager + - PULSAR_BROKER_URL=pulsar + - PULSAR_BROKER_PORT=6650 + - PULSAR_ADMIN_PORT=8080 + - ANONYMIZED_TELEMETRY=False + - ALLOW_RESET=True + ports: + - 8000:8000 + depends_on: + pulsar: + condition: service_healthy + networks: + - net + + segment-worker: + image: segment-worker + build: + context: . + dockerfile: Dockerfile + volumes: + - ./:/chroma + - index_data:/index_data + command: python -m chromadb.segment.impl.distributed.server + environment: + - IS_PERSISTENT=TRUE + - CHROMA_PRODUCER_IMPL=chromadb.ingest.impl.pulsar.PulsarProducer + - CHROMA_CONSUMER_IMPL=chromadb.ingest.impl.pulsar.PulsarConsumer + - PULSAR_BROKER_URL=pulsar + - PULSAR_BROKER_PORT=6650 + - PULSAR_ADMIN_PORT=8080 + - CHROMA_SERVER_GRPC_PORT=50051 + - ANONYMIZED_TELEMETRY=False + - ALLOW_RESET=True + ports: + - 50051:50051 + depends_on: + pulsar: + condition: service_healthy + networks: + - net + + pulsar: + image: apachepulsar/pulsar + volumes: + - pulsardata:/pulsar/data + - pulsarconf:/pulsar/conf + command: bin/pulsar standalone + ports: + - 6650:6650 + - 8080:8080 + networks: + - net + healthcheck: + test: + [ + "CMD", + "curl", + "-f", + "localhost:8080/admin/v2/brokers/health" + ] + interval: 3s + timeout: 1m + retries: 10 + +volumes: + index_data: + driver: local + backups: + driver: local + pulsardata: + driver: local + pulsarconf: + driver: local diff --git a/docker-compose.cluster.yml b/docker-compose.cluster.yml index d36ed16906f..2aaee58c566 100644 --- a/docker-compose.cluster.yml +++ b/docker-compose.cluster.yml @@ -16,11 +16,12 @@ services: volumes: - ./:/chroma - index_data:/index_data - command: uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 8000 --log-config log_config.yml + command: uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 8000 --log-config chromadb/log_config.yml environment: - IS_PERSISTENT=TRUE - CHROMA_PRODUCER_IMPL=chromadb.ingest.impl.pulsar.PulsarProducer - CHROMA_CONSUMER_IMPL=chromadb.ingest.impl.pulsar.PulsarConsumer + - CHROMA_SEGMENT_MANAGER_IMPL=chromadb.segment.impl.manager.distributed.DistributedSegmentManager - PULSAR_BROKER_URL=pulsar - PULSAR_BROKER_PORT=6650 - PULSAR_ADMIN_PORT=8080 @@ -32,6 +33,31 @@ services: networks: - net + segment-worker: + image: segment-worker + build: + context: . 
+ dockerfile: Dockerfile + volumes: + - ./:/chroma + - index_data:/index_data + command: python -m chromadb.segment.impl.distributed.server + environment: + - IS_PERSISTENT=TRUE + - CHROMA_PRODUCER_IMPL=chromadb.ingest.impl.pulsar.PulsarProducer + - CHROMA_CONSUMER_IMPL=chromadb.ingest.impl.pulsar.PulsarConsumer + - PULSAR_BROKER_URL=pulsar + - PULSAR_BROKER_PORT=6650 + - PULSAR_ADMIN_PORT=8080 + - CHROMA_SERVER_GRPC_PORT=50051 + ports: + - 50051:50051 + depends_on: + pulsar: + condition: service_healthy + networks: + - net + pulsar: image: apachepulsar/pulsar volumes: diff --git a/pyproject.toml b/pyproject.toml index 7db0fe821ef..f69ae22140c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ 'overrides >= 7.3.1', 'importlib-resources', 'graphlib_backport >= 1.0.3; python_version < "3.9"', + 'grpcio >= 1.58.0', 'bcrypt >= 4.0.1', 'typer >= 0.9.0', ] diff --git a/requirements.txt b/requirements.txt index 8d00aa134cb..6501c7d9d2c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ bcrypt==4.0.1 chroma-hnswlib==0.7.3 fastapi>=0.95.2 graphlib_backport==1.0.3; python_version < '3.9' +grpcio==1.58.0 importlib-resources numpy==1.21.6; python_version < '3.8' numpy>=1.22.4; python_version >= '3.8' diff --git a/requirements_dev.txt b/requirements_dev.txt index 9354d39b725..4dce86e2efe 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,5 +1,6 @@ black==23.3.0 # match what's in pyproject.toml build +grpcio-tools httpx hypothesis hypothesis[numpy] From fc4c8b547444efafa8ddf75fbd53a9b8e1a7eabe Mon Sep 17 00:00:00 2001 From: Anton Troynikov Date: Wed, 4 Oct 2023 09:01:33 -0700 Subject: [PATCH 37/39] [BUG]: Chat with your documents example exhibits flaky retrieval (#1203) ## Description of changes *Summarize the changes made by this PR.* In https://github.com/chroma-core/chroma/issues/1115 @BChip noticed flaky retrieval performance. The issue was difficult to replicate because of nondeterminism inherent in the HNSW graph construction on loading, but I was able to track it down through repeated testing. The issue is caused by ingesting all the empty lines in the document, which make up 50% of all the lines in each file, which outputs the same embedding for all of them, causing the HNSW graph to sometimes be degenerate. The fix is to skip the empty lines. We should consider how we can mitigate this in the future since this is not easy to detect after the fact, and is likely to be something users run into. ## Test plan Failures no longer occur after manual invocation. 
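The change itself is three lines in the loader; a minimal sketch of the loading loop after the fix (the file name and variable names here are illustrative, not the exact example code):

```
documents = []
metadatas = []
with open("document.txt") as f:  # illustrative path
    for line_number, line in enumerate(f, start=1):
        line = line.strip()
        if len(line) == 0:
            # Blank lines all produce the same embedding, which can make the
            # HNSW graph degenerate, so skip them before ingestion.
            continue
        documents.append(line)
        metadatas.append({"filename": "document.txt", "line_number": line_number})
```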
## Documentation Changes N/A --- examples/chat_with_your_documents/load_data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/chat_with_your_documents/load_data.py b/examples/chat_with_your_documents/load_data.py index 574a2127d10..b9ffdbb116a 100644 --- a/examples/chat_with_your_documents/load_data.py +++ b/examples/chat_with_your_documents/load_data.py @@ -22,6 +22,9 @@ def main( ): # Strip whitespace and append the line to the documents list line = line.strip() + # Skip empty lines + if len(line) == 0: + continue documents.append(line) metadatas.append({"filename": filename, "line_number": line_number}) From e3a60a9a1a4368f3dc9f1c281e49200d765f0e04 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Fri, 6 Oct 2023 17:48:42 +0300 Subject: [PATCH 38/39] [BUG]: Emergency Fix Integration Tests (#1210) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - An issue was introduced in the integration tests where Python testing with `pytest` was commented out (bug introduced with #1114) ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python ## Documentation Changes N/A --- bin/integration-test | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bin/integration-test b/bin/integration-test index 54b4e387e08..38a763c8650 100755 --- a/bin/integration-test +++ b/bin/integration-test @@ -51,9 +51,9 @@ export CHROMA_INTEGRATION_TEST_ONLY=1 export CHROMA_API_IMPL=chromadb.api.fastapi.FastAPI export CHROMA_SERVER_HOST=localhost export CHROMA_SERVER_HTTP_PORT=8000 -# -#echo testing: python -m pytest "$@" -#python -m pytest "$@" + +echo testing: python -m pytest "$@" +python -m pytest "$@" cd clients/js yarn From e357ef397b3324edec4104756155a8836b03ed65 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Fri, 6 Oct 2023 18:09:56 +0300 Subject: [PATCH 39/39] [BUG]: Python client sqlite issue (#1211) Refs: #1206 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Conditional import of sqlite for python client ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python Additional test in older disto (debian buster) was run: ``` pip install chromadb-client root@c26a0fadfcdc:~# python Python 3.10.12 (main, Jun 13 2023, 12:02:28) [GCC 8.3.0] on linux Type "help", "copyright", "credits" or "license" for more information. >>> import sqlite3 >>> print(sqlite3.sqlite_version_info) (3, 27, 2) >>> import chromadb >>> ``` ## Documentation Changes N/A --- chromadb/__init__.py | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index aa5a3edd7ea..ffc32392e07 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -1,6 +1,5 @@ from typing import Dict import logging -import sqlite3 import chromadb.config from chromadb.config import Settings, System from chromadb.api import API @@ -54,22 +53,31 @@ except ImportError: IN_COLAB = False -if sqlite3.sqlite_version_info < (3, 35, 0): - if IN_COLAB: - # In Colab, hotswap to pysqlite-binary if it's too old - import subprocess - import sys - - subprocess.check_call( - [sys.executable, "-m", "pip", "install", "pysqlite3-binary"] - ) - __import__("pysqlite3") - sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") - else: - raise RuntimeError( - "\033[91mYour system has an unsupported version of sqlite3. 
Chroma requires sqlite3 >= 3.35.0.\033[0m\n" - "\033[94mPlease visit https://docs.trychroma.com/troubleshooting#sqlite to learn how to upgrade.\033[0m" - ) +is_client = False +try: + from chromadb.is_thin_client import is_thin_client # type: ignore + is_client = is_thin_client +except ImportError: + is_client = False + +if not is_client: + import sqlite3 + if sqlite3.sqlite_version_info < (3, 35, 0): + if IN_COLAB: + # In Colab, hotswap to pysqlite-binary if it's too old + import subprocess + import sys + + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "pysqlite3-binary"] + ) + __import__("pysqlite3") + sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") + else: + raise RuntimeError( + "\033[91mYour system has an unsupported version of sqlite3. Chroma requires sqlite3 >= 3.35.0.\033[0m\n" + "\033[94mPlease visit https://docs.trychroma.com/troubleshooting#sqlite to learn how to upgrade.\033[0m" + ) def configure(**kwargs) -> None: # type: ignore
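With this guard in place, a thin-client install that provides `chromadb.is_thin_client` never touches sqlite3 at all, which is why the Debian buster session in the test plan above imports cleanly with SQLite 3.27.2. Users of the full package who still hit the RuntimeError on an old system SQLite can mirror what the Colab branch above does automatically; a minimal sketch, assuming `pysqlite3-binary` has already been installed:

```
import sys

# Replace the stdlib sqlite3 module with the pysqlite3 binary build before
# chromadb's import-time version check runs.
__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

import chromadb  # now sees the newer bundled SQLite
```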