Skip to content

Commit

Permalink
[ENH] Use InvalidCollectionException/Error consistently (#2048)
Browse files Browse the repository at this point in the history
Use the InvalidCollectionException consistently in the Python client and
APIs.

Introduce InvalidCollectionError to the JS client and handle it
accordingly.

This is the first step to improving error handling in both client and
completing #565.

Add tests for all related scenarios.
  • Loading branch information
ibratoev authored Apr 30, 2024
1 parent f070188 commit 4b2599c
Show file tree
Hide file tree
Showing 11 changed files with 219 additions and 9 deletions.
9 changes: 9 additions & 0 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ def _modify(
if new_metadata:
validate_update_metadata(new_metadata)

self._validate_collection(id)

# TODO eventually we'll want to use OptionalArgument and Unspecified in the
# signature of `_modify` but not changing the API right now.
if new_name and new_metadata:
Expand Down Expand Up @@ -498,6 +500,8 @@ def _get(
}
)

self._validate_collection(collection_id)

where = validate_where(where) if where is not None and len(where) > 0 else None
where_document = (
validate_where_document(where_document)
Expand Down Expand Up @@ -651,6 +655,8 @@ def _delete(
@override
def _count(self, collection_id: UUID) -> int:
add_attributes_to_current_span({"collection_id": str(collection_id)})
self._validate_collection(collection_id)

metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
return metadata_segment.count()

Expand Down Expand Up @@ -843,6 +849,9 @@ def _get_collection(self, collection_id: UUID) -> t.Collection:
self._collection_cache[collection_id] = collections[0]
return self._collection_cache[collection_id]

def _validate_collection(self, collection_id: UUID) -> None:
self._get_collection(collection_id)


def _records(
operation: t.Operation,
Expand Down
100 changes: 100 additions & 0 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from chromadb.api.fastapi import FastAPI
from chromadb.api.types import QueryResult, EmbeddingFunction, Document
from chromadb.config import Settings
from chromadb.errors import InvalidCollectionException
import chromadb.server.fastapi
import pytest
import tempfile
Expand Down Expand Up @@ -224,6 +225,17 @@ def test_add(api):
assert collection.count() == 2


def test_collection_add_with_invalid_collection_throws(api):
api.reset()
collection = api.create_collection("test")
api.delete_collection("test")

with pytest.raises(
InvalidCollectionException, match=r"Collection .* does not exist."
):
collection.add(**batch_records)


def test_get_or_create(api):
api.reset()

Expand Down Expand Up @@ -272,6 +284,17 @@ def test_get_from_db(api):
assert records[key] is None


def test_collection_get_with_invalid_collection_throws(api):
api.reset()
collection = api.create_collection("test")
api.delete_collection("test")

with pytest.raises(
InvalidCollectionException, match=r"Collection .* does not exist."
):
collection.get()


def test_reset_db(api):
api.reset()

Expand Down Expand Up @@ -350,6 +373,17 @@ def test_delete_with_index(api):
collection.query(query_embeddings=[[1.1, 2.3, 3.2]], n_results=1)


def test_collection_delete_with_invalid_collection_throws(api):
api.reset()
collection = api.create_collection("test")
api.delete_collection("test")

with pytest.raises(
InvalidCollectionException, match=r"Collection .* does not exist."
):
collection.delete(ids=["id1"])


def test_count(api):
api.reset()
collection = api.create_collection("testspace")
Expand All @@ -358,6 +392,17 @@ def test_count(api):
assert collection.count() == 2


def test_collection_count_with_invalid_collection_throws(api):
api.reset()
collection = api.create_collection("test")
api.delete_collection("test")

with pytest.raises(
InvalidCollectionException, match=r"Collection .* does not exist."
):
collection.count()


def test_modify(api):
api.reset()
collection = api.create_collection("testspace")
Expand All @@ -367,6 +412,17 @@ def test_modify(api):
assert collection.name == "testspace2"


def test_collection_modify_with_invalid_collection_throws(api):
api.reset()
collection = api.create_collection("test")
api.delete_collection("test")

with pytest.raises(
InvalidCollectionException, match=r"Collection .* does not exist."
):
collection.modify(name="test2")


def test_modify_error_on_existing_name(api):
api.reset()

Expand Down Expand Up @@ -511,6 +567,39 @@ def test_peek(api):
assert peek[key] is None


def test_collection_peek_with_invalid_collection_throws(api):
api.reset()
collection = api.create_collection("test")
api.delete_collection("test")

with pytest.raises(
InvalidCollectionException, match=r"Collection .* does not exist."
):
collection.peek()


def test_collection_query_with_invalid_collection_throws(api):
api.reset()
collection = api.create_collection("test")
api.delete_collection("test")

with pytest.raises(
InvalidCollectionException, match=r"Collection .* does not exist."
):
collection.query(query_texts=["test"])


def test_collection_update_with_invalid_collection_throws(api):
api.reset()
collection = api.create_collection("test")
api.delete_collection("test")

with pytest.raises(
InvalidCollectionException, match=r"Collection .* does not exist."
):
collection.update(ids=["id1"], documents=["test"])


# TEST METADATA AND METADATA FILTERING
# region

Expand Down Expand Up @@ -1427,6 +1516,17 @@ def test_upsert(api):
assert get_result["documents"][0] is None


def test_collection_upsert_with_invalid_collection_throws(api):
api.reset()
collection = api.create_collection("test")
api.delete_collection("test")

with pytest.raises(
InvalidCollectionException, match=r"Collection .* does not exist."
):
collection.upsert(**initial_records)


# test to make sure add, query, update, upsert error on invalid embeddings input


Expand Down
5 changes: 5 additions & 0 deletions clients/js/src/ChromaFetch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
ChromaServerError,
ChromaValueError,
ChromaError,
createErrorByType,
} from "./Errors";
import { FetchAPI } from "./generated";

Expand Down Expand Up @@ -50,6 +51,10 @@ export const chromaFetch: FetchAPI = async (
const clonedResp = resp.clone();
const respBody = await clonedResp.json();
if (!clonedResp.ok) {
const error = createErrorByType(respBody?.error, respBody?.message);
if (error) {
throw error;
}
switch (resp.status) {
case 400:
throw new ChromaClientError(
Expand Down
16 changes: 16 additions & 0 deletions clients/js/src/Errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,19 @@ export class ChromaValueError extends Error {
super(message);
}
}

export class InvalidCollectionError extends Error {
name = "InvalidCollectionError";
constructor(message: string, public readonly cause?: unknown) {
super(message);
}
}

export function createErrorByType(type: string, message: string) {
switch (type) {
case "InvalidCollection":
return new InvalidCollectionError(message);
default:
return undefined;
}
}
11 changes: 11 additions & 0 deletions clients/js/test/add.collections.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import { IncludeEnum } from "../src/types";
import { OpenAIEmbeddingFunction } from "../src/embeddings/OpenAIEmbeddingFunction";
import { CohereEmbeddingFunction } from "../src/embeddings/CohereEmbeddingFunction";
import { OllamaEmbeddingFunction } from "../src/embeddings/OllamaEmbeddingFunction";
import { InvalidCollectionError } from "../src/Errors";

test("it should add single embeddings to a collection", async () => {
await chroma.reset();
const collection = await chroma.createCollection({ name: "test" });
Expand Down Expand Up @@ -96,6 +98,15 @@ test("add documents", async () => {
expect(results.documents[0]).toBe("This is a test");
});

test("should error on non existing collection", async () => {
await chroma.reset();
const collection = await chroma.createCollection({ name: "test" });
await chroma.deleteCollection({ name: "test" });
expect(async () => {
await collection.add({ ids: IDS, embeddings: EMBEDDINGS });
}).rejects.toThrow(InvalidCollectionError);
});

test("It should return an error when inserting duplicate IDs in the same batch", async () => {
await chroma.reset();
const collection = await chroma.createCollection({ name: "test" });
Expand Down
12 changes: 11 additions & 1 deletion clients/js/test/delete.collection.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { expect, test } from "@jest/globals";
import chroma from "./initClient";
import { EMBEDDINGS, IDS, METADATAS } from "./data";
import { InvalidCollectionError } from "../src/Errors";

test("it should delete a collection", async () => {
await chroma.reset();
Expand All @@ -18,6 +19,15 @@ test("it should delete a collection", async () => {

var remainingEmbeddings = await collection.get();
expect(["test2", "test3"]).toEqual(
expect.arrayContaining(remainingEmbeddings.ids),
expect.arrayContaining(remainingEmbeddings.ids)
);
});

test("should error on non existing collection", async () => {
await chroma.reset();
const collection = await chroma.createCollection({ name: "test" });
await chroma.deleteCollection({ name: "test" });
expect(async () => {
await collection.delete({ where: { test: "test1" } });
}).rejects.toThrow(InvalidCollectionError);
});
15 changes: 12 additions & 3 deletions clients/js/test/get.collection.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { expect, test } from "@jest/globals";
import chroma from "./initClient";
import { DOCUMENTS, EMBEDDINGS, IDS, METADATAS } from "./data";
import { ChromaValueError } from "../src/Errors";
import { ChromaValueError, InvalidCollectionError } from "../src/Errors";

test("it should get a collection", async () => {
await chroma.reset();
Expand Down Expand Up @@ -47,7 +47,7 @@ test("wrong code returns an error", async () => {
expect(error).toBeDefined();
expect(error).toBeInstanceOf(ChromaValueError);
expect(error.message).toMatchInlineSnapshot(
`"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, got $contains"`,
`"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, got $contains"`
);
}
});
Expand Down Expand Up @@ -101,10 +101,19 @@ test("test gt, lt, in a simple small way", async () => {
expect(["test2", "test3"]).toEqual(expect.arrayContaining(items.ids));
});

test("should error on non existing collection", async () => {
await chroma.reset();
const collection = await chroma.createCollection({ name: "test" });
await chroma.deleteCollection({ name: "test" });
expect(async () => {
await collection.get({ ids: IDS });
}).rejects.toThrow(InvalidCollectionError);
});

test("it should throw an error if the collection does not exist", async () => {
await chroma.reset();

await expect(
async () => await chroma.getCollection({ name: "test" }),
async () => await chroma.getCollection({ name: "test" })
).rejects.toThrow(Error);
});
10 changes: 10 additions & 0 deletions clients/js/test/peek.collection.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { expect, test } from "@jest/globals";
import chroma from "./initClient";
import { IDS, EMBEDDINGS } from "./data";
import { InvalidCollectionError } from "../src/Errors";

test("it should peek a collection", async () => {
await chroma.reset();
Expand All @@ -12,3 +13,12 @@ test("it should peek a collection", async () => {
expect(results.ids.length).toBe(2);
expect(["test1", "test2"]).toEqual(expect.arrayContaining(results.ids));
});

test("should error on non existing collection", async () => {
await chroma.reset();
const collection = await chroma.createCollection({ name: "test" });
await chroma.deleteCollection({ name: "test" });
expect(async () => {
await collection.peek();
}).rejects.toThrow(InvalidCollectionError);
});
Loading

0 comments on commit 4b2599c

Please sign in to comment.