From 4b2599ce41edfb571dfb5e39c883f653cb061dc3 Mon Sep 17 00:00:00 2001 From: Ivaylo Bratoev Date: Tue, 30 Apr 2024 21:27:52 +0300 Subject: [PATCH] [ENH] Use InvalidCollectionException/Error consistently (#2048) Use the InvalidCollectionException consistently in the Python client and APIs. Introduce InvalidCollectionError to the JS client and handle it accordingly. This is the first step to improving error handling in both client and completing #565. Add tests for all related scenarios. --- chromadb/api/segment.py | 9 ++ chromadb/test/test_api.py | 100 +++++++++++++++++++++ clients/js/src/ChromaFetch.ts | 5 ++ clients/js/src/Errors.ts | 16 ++++ clients/js/test/add.collections.test.ts | 11 +++ clients/js/test/delete.collection.test.ts | 12 ++- clients/js/test/get.collection.test.ts | 15 +++- clients/js/test/peek.collection.test.ts | 10 +++ clients/js/test/query.collection.test.ts | 20 +++-- clients/js/test/update.collection.test.ts | 15 ++++ clients/js/test/upsert.collections.test.ts | 15 ++++ 11 files changed, 219 insertions(+), 9 deletions(-) diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 0c29f38f03c..aa15d8c71e6 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -323,6 +323,8 @@ def _modify( if new_metadata: validate_update_metadata(new_metadata) + self._validate_collection(id) + # TODO eventually we'll want to use OptionalArgument and Unspecified in the # signature of `_modify` but not changing the API right now. if new_name and new_metadata: @@ -498,6 +500,8 @@ def _get( } ) + self._validate_collection(collection_id) + where = validate_where(where) if where is not None and len(where) > 0 else None where_document = ( validate_where_document(where_document) @@ -651,6 +655,8 @@ def _delete( @override def _count(self, collection_id: UUID) -> int: add_attributes_to_current_span({"collection_id": str(collection_id)}) + self._validate_collection(collection_id) + metadata_segment = self._manager.get_segment(collection_id, MetadataReader) return metadata_segment.count() @@ -843,6 +849,9 @@ def _get_collection(self, collection_id: UUID) -> t.Collection: self._collection_cache[collection_id] = collections[0] return self._collection_cache[collection_id] + def _validate_collection(self, collection_id: UUID) -> None: + self._get_collection(collection_id) + def _records( operation: t.Operation, diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 3cef03da19c..e2830b394e1 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -7,6 +7,7 @@ from chromadb.api.fastapi import FastAPI from chromadb.api.types import QueryResult, EmbeddingFunction, Document from chromadb.config import Settings +from chromadb.errors import InvalidCollectionException import chromadb.server.fastapi import pytest import tempfile @@ -224,6 +225,17 @@ def test_add(api): assert collection.count() == 2 +def test_collection_add_with_invalid_collection_throws(api): + api.reset() + collection = api.create_collection("test") + api.delete_collection("test") + + with pytest.raises( + InvalidCollectionException, match=r"Collection .* does not exist." + ): + collection.add(**batch_records) + + def test_get_or_create(api): api.reset() @@ -272,6 +284,17 @@ def test_get_from_db(api): assert records[key] is None +def test_collection_get_with_invalid_collection_throws(api): + api.reset() + collection = api.create_collection("test") + api.delete_collection("test") + + with pytest.raises( + InvalidCollectionException, match=r"Collection .* does not exist." + ): + collection.get() + + def test_reset_db(api): api.reset() @@ -350,6 +373,17 @@ def test_delete_with_index(api): collection.query(query_embeddings=[[1.1, 2.3, 3.2]], n_results=1) +def test_collection_delete_with_invalid_collection_throws(api): + api.reset() + collection = api.create_collection("test") + api.delete_collection("test") + + with pytest.raises( + InvalidCollectionException, match=r"Collection .* does not exist." + ): + collection.delete(ids=["id1"]) + + def test_count(api): api.reset() collection = api.create_collection("testspace") @@ -358,6 +392,17 @@ def test_count(api): assert collection.count() == 2 +def test_collection_count_with_invalid_collection_throws(api): + api.reset() + collection = api.create_collection("test") + api.delete_collection("test") + + with pytest.raises( + InvalidCollectionException, match=r"Collection .* does not exist." + ): + collection.count() + + def test_modify(api): api.reset() collection = api.create_collection("testspace") @@ -367,6 +412,17 @@ def test_modify(api): assert collection.name == "testspace2" +def test_collection_modify_with_invalid_collection_throws(api): + api.reset() + collection = api.create_collection("test") + api.delete_collection("test") + + with pytest.raises( + InvalidCollectionException, match=r"Collection .* does not exist." + ): + collection.modify(name="test2") + + def test_modify_error_on_existing_name(api): api.reset() @@ -511,6 +567,39 @@ def test_peek(api): assert peek[key] is None +def test_collection_peek_with_invalid_collection_throws(api): + api.reset() + collection = api.create_collection("test") + api.delete_collection("test") + + with pytest.raises( + InvalidCollectionException, match=r"Collection .* does not exist." + ): + collection.peek() + + +def test_collection_query_with_invalid_collection_throws(api): + api.reset() + collection = api.create_collection("test") + api.delete_collection("test") + + with pytest.raises( + InvalidCollectionException, match=r"Collection .* does not exist." + ): + collection.query(query_texts=["test"]) + + +def test_collection_update_with_invalid_collection_throws(api): + api.reset() + collection = api.create_collection("test") + api.delete_collection("test") + + with pytest.raises( + InvalidCollectionException, match=r"Collection .* does not exist." + ): + collection.update(ids=["id1"], documents=["test"]) + + # TEST METADATA AND METADATA FILTERING # region @@ -1427,6 +1516,17 @@ def test_upsert(api): assert get_result["documents"][0] is None +def test_collection_upsert_with_invalid_collection_throws(api): + api.reset() + collection = api.create_collection("test") + api.delete_collection("test") + + with pytest.raises( + InvalidCollectionException, match=r"Collection .* does not exist." + ): + collection.upsert(**initial_records) + + # test to make sure add, query, update, upsert error on invalid embeddings input diff --git a/clients/js/src/ChromaFetch.ts b/clients/js/src/ChromaFetch.ts index 9cdd86b5068..7a7596af534 100644 --- a/clients/js/src/ChromaFetch.ts +++ b/clients/js/src/ChromaFetch.ts @@ -7,6 +7,7 @@ import { ChromaServerError, ChromaValueError, ChromaError, + createErrorByType, } from "./Errors"; import { FetchAPI } from "./generated"; @@ -50,6 +51,10 @@ export const chromaFetch: FetchAPI = async ( const clonedResp = resp.clone(); const respBody = await clonedResp.json(); if (!clonedResp.ok) { + const error = createErrorByType(respBody?.error, respBody?.message); + if (error) { + throw error; + } switch (resp.status) { case 400: throw new ChromaClientError( diff --git a/clients/js/src/Errors.ts b/clients/js/src/Errors.ts index fd5c146c5cc..678e6742d1c 100644 --- a/clients/js/src/Errors.ts +++ b/clients/js/src/Errors.ts @@ -88,3 +88,19 @@ export class ChromaValueError extends Error { super(message); } } + +export class InvalidCollectionError extends Error { + name = "InvalidCollectionError"; + constructor(message: string, public readonly cause?: unknown) { + super(message); + } +} + +export function createErrorByType(type: string, message: string) { + switch (type) { + case "InvalidCollection": + return new InvalidCollectionError(message); + default: + return undefined; + } +} diff --git a/clients/js/test/add.collections.test.ts b/clients/js/test/add.collections.test.ts index 41b3de3fef5..9ecf2371cc3 100644 --- a/clients/js/test/add.collections.test.ts +++ b/clients/js/test/add.collections.test.ts @@ -6,6 +6,8 @@ import { IncludeEnum } from "../src/types"; import { OpenAIEmbeddingFunction } from "../src/embeddings/OpenAIEmbeddingFunction"; import { CohereEmbeddingFunction } from "../src/embeddings/CohereEmbeddingFunction"; import { OllamaEmbeddingFunction } from "../src/embeddings/OllamaEmbeddingFunction"; +import { InvalidCollectionError } from "../src/Errors"; + test("it should add single embeddings to a collection", async () => { await chroma.reset(); const collection = await chroma.createCollection({ name: "test" }); @@ -96,6 +98,15 @@ test("add documents", async () => { expect(results.documents[0]).toBe("This is a test"); }); +test("should error on non existing collection", async () => { + await chroma.reset(); + const collection = await chroma.createCollection({ name: "test" }); + await chroma.deleteCollection({ name: "test" }); + expect(async () => { + await collection.add({ ids: IDS, embeddings: EMBEDDINGS }); + }).rejects.toThrow(InvalidCollectionError); +}); + test("It should return an error when inserting duplicate IDs in the same batch", async () => { await chroma.reset(); const collection = await chroma.createCollection({ name: "test" }); diff --git a/clients/js/test/delete.collection.test.ts b/clients/js/test/delete.collection.test.ts index c4a3f8310ee..5f5e40ddbd1 100644 --- a/clients/js/test/delete.collection.test.ts +++ b/clients/js/test/delete.collection.test.ts @@ -1,6 +1,7 @@ import { expect, test } from "@jest/globals"; import chroma from "./initClient"; import { EMBEDDINGS, IDS, METADATAS } from "./data"; +import { InvalidCollectionError } from "../src/Errors"; test("it should delete a collection", async () => { await chroma.reset(); @@ -18,6 +19,15 @@ test("it should delete a collection", async () => { var remainingEmbeddings = await collection.get(); expect(["test2", "test3"]).toEqual( - expect.arrayContaining(remainingEmbeddings.ids), + expect.arrayContaining(remainingEmbeddings.ids) ); }); + +test("should error on non existing collection", async () => { + await chroma.reset(); + const collection = await chroma.createCollection({ name: "test" }); + await chroma.deleteCollection({ name: "test" }); + expect(async () => { + await collection.delete({ where: { test: "test1" } }); + }).rejects.toThrow(InvalidCollectionError); +}); diff --git a/clients/js/test/get.collection.test.ts b/clients/js/test/get.collection.test.ts index 2bfd4fae789..23f1d42e45b 100644 --- a/clients/js/test/get.collection.test.ts +++ b/clients/js/test/get.collection.test.ts @@ -1,7 +1,7 @@ import { expect, test } from "@jest/globals"; import chroma from "./initClient"; import { DOCUMENTS, EMBEDDINGS, IDS, METADATAS } from "./data"; -import { ChromaValueError } from "../src/Errors"; +import { ChromaValueError, InvalidCollectionError } from "../src/Errors"; test("it should get a collection", async () => { await chroma.reset(); @@ -47,7 +47,7 @@ test("wrong code returns an error", async () => { expect(error).toBeDefined(); expect(error).toBeInstanceOf(ChromaValueError); expect(error.message).toMatchInlineSnapshot( - `"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, got $contains"`, + `"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, got $contains"` ); } }); @@ -101,10 +101,19 @@ test("test gt, lt, in a simple small way", async () => { expect(["test2", "test3"]).toEqual(expect.arrayContaining(items.ids)); }); +test("should error on non existing collection", async () => { + await chroma.reset(); + const collection = await chroma.createCollection({ name: "test" }); + await chroma.deleteCollection({ name: "test" }); + expect(async () => { + await collection.get({ ids: IDS }); + }).rejects.toThrow(InvalidCollectionError); +}); + test("it should throw an error if the collection does not exist", async () => { await chroma.reset(); await expect( - async () => await chroma.getCollection({ name: "test" }), + async () => await chroma.getCollection({ name: "test" }) ).rejects.toThrow(Error); }); diff --git a/clients/js/test/peek.collection.test.ts b/clients/js/test/peek.collection.test.ts index 70d7b5674e8..0636a6ae121 100644 --- a/clients/js/test/peek.collection.test.ts +++ b/clients/js/test/peek.collection.test.ts @@ -1,6 +1,7 @@ import { expect, test } from "@jest/globals"; import chroma from "./initClient"; import { IDS, EMBEDDINGS } from "./data"; +import { InvalidCollectionError } from "../src/Errors"; test("it should peek a collection", async () => { await chroma.reset(); @@ -12,3 +13,12 @@ test("it should peek a collection", async () => { expect(results.ids.length).toBe(2); expect(["test1", "test2"]).toEqual(expect.arrayContaining(results.ids)); }); + +test("should error on non existing collection", async () => { + await chroma.reset(); + const collection = await chroma.createCollection({ name: "test" }); + await chroma.deleteCollection({ name: "test" }); + expect(async () => { + await collection.peek(); + }).rejects.toThrow(InvalidCollectionError); +}); diff --git a/clients/js/test/query.collection.test.ts b/clients/js/test/query.collection.test.ts index 41adff61ba6..39b3f2c2a10 100644 --- a/clients/js/test/query.collection.test.ts +++ b/clients/js/test/query.collection.test.ts @@ -4,6 +4,7 @@ import { IncludeEnum } from "../src/types"; import { EMBEDDINGS, IDS, METADATAS, DOCUMENTS } from "./data"; import { IEmbeddingFunction } from "../src/embeddings/IEmbeddingFunction"; +import { InvalidCollectionError } from "../src/Errors"; export class TestEmbeddingFunction implements IEmbeddingFunction { constructor() {} @@ -58,7 +59,7 @@ test("it should get embedding with matching documents", async () => { expect(["test1"]).toEqual(expect.arrayContaining(results.ids[0])); expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids[0])); expect(["This is a test"]).toEqual( - expect.arrayContaining(results.documents[0]), + expect.arrayContaining(results.documents[0]) ); const results2 = await collection.query({ @@ -124,7 +125,7 @@ test("it should query a collection with text", async () => { expect(["test1"]).toEqual(expect.arrayContaining(results.ids[0])); expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids[0])); expect(["This is a test"]).toEqual( - expect.arrayContaining(results.documents[0]), + expect.arrayContaining(results.documents[0]) ); }); @@ -154,7 +155,7 @@ test("it should query a collection with text and where", async () => { expect(["test3"]).toEqual(expect.arrayContaining(results.ids[0])); expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids[0])); expect(["This is a third test"]).toEqual( - expect.arrayContaining(results.documents[0]), + expect.arrayContaining(results.documents[0]) ); }); @@ -184,7 +185,7 @@ test("it should query a collection with text and where in", async () => { expect(["test3"]).toEqual(expect.arrayContaining(results.ids[0])); expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids[0])); expect(["This is a third test"]).toEqual( - expect.arrayContaining(results.documents[0]), + expect.arrayContaining(results.documents[0]) ); }); @@ -214,6 +215,15 @@ test("it should query a collection with text and where nin", async () => { expect(["test3"]).toEqual(expect.arrayContaining(results.ids[0])); expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids[0])); expect(["This is a third test"]).toEqual( - expect.arrayContaining(results.documents[0]), + expect.arrayContaining(results.documents[0]) ); }); + +test("should error on non existing collection", async () => { + await chroma.reset(); + const collection = await chroma.createCollection({ name: "test" }); + await chroma.deleteCollection({ name: "test" }); + expect(async () => { + await collection.query({ queryEmbeddings: [1, 2, 3] }); + }).rejects.toThrow(InvalidCollectionError); +}); diff --git a/clients/js/test/update.collection.test.ts b/clients/js/test/update.collection.test.ts index 77537ac6bfb..c89320b8a9e 100644 --- a/clients/js/test/update.collection.test.ts +++ b/clients/js/test/update.collection.test.ts @@ -2,6 +2,7 @@ import { expect, test } from "@jest/globals"; import chroma from "./initClient"; import { IncludeEnum } from "../src/types"; import { IDS, DOCUMENTS, EMBEDDINGS, METADATAS } from "./data"; +import { InvalidCollectionError } from "../src/Errors"; test("it should get embedding with matching documents", async () => { await chroma.reset(); @@ -47,6 +48,20 @@ test("it should get embedding with matching documents", async () => { expect(results2.documents[0]).toEqual("doc1new"); }); +test("should error on non existing collection", async () => { + await chroma.reset(); + const collection = await chroma.createCollection({ name: "test" }); + await chroma.deleteCollection({ name: "test" }); + expect(async () => { + await collection.update({ + ids: ["test1"], + embeddings: [[1, 2, 3, 4, 5, 6, 7, 8, 9, 11]], + metadatas: [{ test: "meta1" }], + documents: ["doc1"], + }); + }).rejects.toThrow(InvalidCollectionError); +}); + // this currently fails // test("it should update metadata or documents to array of Nones", async () => { // await chroma.reset(); diff --git a/clients/js/test/upsert.collections.test.ts b/clients/js/test/upsert.collections.test.ts index 0fc1cdf1ef4..52871230a24 100644 --- a/clients/js/test/upsert.collections.test.ts +++ b/clients/js/test/upsert.collections.test.ts @@ -1,5 +1,6 @@ import { expect, test } from "@jest/globals"; import chroma from "./initClient"; +import { InvalidCollectionError } from "../src/Errors"; test("it should upsert embeddings to a collection", async () => { await chroma.reset(); @@ -24,3 +25,17 @@ test("it should upsert embeddings to a collection", async () => { const count2 = await collection.count(); expect(count2).toBe(3); }); + +test("should error on non existing collection", async () => { + await chroma.reset(); + const collection = await chroma.createCollection({ name: "test" }); + await chroma.deleteCollection({ name: "test" }); + expect(async () => { + await collection.upsert({ + ids: ["test1"], + embeddings: [[1, 2, 3, 4, 5, 6, 7, 8, 9, 11]], + metadatas: [{ test: "meta1" }], + documents: ["doc1"], + }); + }).rejects.toThrow(InvalidCollectionError); +});