From c46ea2a0f7fe9a5b953e8a5d97d7d318942cdbaa Mon Sep 17 00:00:00 2001 From: Max Isom Date: Mon, 29 Apr 2024 17:14:08 -0700 Subject: [PATCH] [ENH] add `included` to `.get()` & `.query()` response (#2044) --- chromadb/api/fastapi.py | 2 ++ chromadb/api/segment.py | 3 ++ chromadb/api/types.py | 2 ++ chromadb/test/property/test_filtering.py | 2 ++ chromadb/test/test_api.py | 39 +++++++++++++++++++++--- clients/js/src/types.ts | 12 +++++--- clients/js/test/get.collection.test.ts | 1 + clients/js/test/query.collection.test.ts | 4 ++- 8 files changed, 55 insertions(+), 10 deletions(-) diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 22dffe04ddf..f375d269579 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -427,6 +427,7 @@ def _get( documents=body.get("documents", None), data=None, uris=body.get("uris", None), + included=body["included"], ) @trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION) @@ -581,6 +582,7 @@ def _query( documents=body.get("documents", None), uris=body.get("uris", None), data=None, + included=body["included"], ) @trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL) diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 1440a843fc9..0c29f38f03c 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -531,6 +531,7 @@ def _get( documents=[] if "documents" in include else None, uris=[] if "uris" in include else None, data=[] if "data" in include else None, + included=include, ) vectors: Sequence[t.VectorEmbeddingRecord] = [] @@ -574,6 +575,7 @@ def _get( documents=documents if "documents" in include else None, # type: ignore uris=uris if "uris" in include else None, # type: ignore data=None, + included=include, ) @trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION) @@ -766,6 +768,7 @@ def _query( documents=documents if documents else None, uris=uris if uris else None, data=None, + included=include, ) @trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION) diff --git a/chromadb/api/types.py b/chromadb/api/types.py index cafa3d594f3..706dd870eae 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -157,6 +157,7 @@ class GetResult(TypedDict): uris: Optional[URIs] data: Optional[Loadable] metadatas: Optional[List[Metadata]] + included: Include class QueryResult(TypedDict): @@ -167,6 +168,7 @@ class QueryResult(TypedDict): data: Optional[List[Loadable]] metadatas: Optional[List[List[Metadata]]] distances: Optional[List[List[float]]] + included: Include class IndexMetadata(TypedDict): diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index 9129c023df7..2826caf12eb 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -338,6 +338,7 @@ def test_empty_filter(api: ServerAPI) -> None: assert res["embeddings"] == [[]] assert res["distances"] == [[]] assert res["metadatas"] == [[]] + assert set(res["included"]) == set(["embeddings", "distances", "metadatas"]) res = coll.query( query_embeddings=test_query_embeddings, @@ -348,6 +349,7 @@ def test_empty_filter(api: ServerAPI) -> None: assert res["embeddings"] is None assert res["distances"] == [[], []] assert res["metadatas"] == [[], []] + assert set(res["included"]) == set(["metadatas", "documents", "distances"]) def test_boolean_metadata(api: ServerAPI) -> None: diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index f91ea3ca927..3cef03da19c 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -91,6 +91,8 @@ def test_persist_index_loading(api_fixture, request): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -118,6 +120,8 @@ def __call__(self, input): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -146,6 +150,8 @@ def __call__(self, input): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -260,6 +266,8 @@ def test_get_from_db(api): for key in records.keys(): if (key in includes) or (key == "ids"): assert len(records[key]) == 2 + elif key == "included": + assert set(records[key]) == set(includes) else: assert records[key] is None @@ -290,6 +298,8 @@ def test_get_nearest_neighbors(api): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -302,6 +312,8 @@ def test_get_nearest_neighbors(api): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -314,6 +326,8 @@ def test_get_nearest_neighbors(api): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 2 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -437,6 +451,8 @@ def test_increment_index_on(api): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -489,6 +505,8 @@ def test_peek(api): for key in peek.keys(): if key in ["embeddings", "documents", "metadatas"] or key == "ids": assert len(peek[key]) == 2 + elif key == "included": + assert set(peek[key]) == set(["embeddings", "metadatas", "documents"]) else: assert peek[key] is None @@ -994,22 +1012,26 @@ def test_query_include(api): collection = api.create_collection("test_query_include") collection.add(**records) + include = ["metadatas", "documents", "distances"] items = collection.query( query_embeddings=[0, 0, 0], - include=["metadatas", "documents", "distances"], + include=include, n_results=1, ) assert items["embeddings"] is None assert items["ids"][0][0] == "id1" assert items["metadatas"][0][0]["int_value"] == 1 + assert set(items["included"]) == set(include) + include = ["embeddings", "documents", "distances"] items = collection.query( query_embeddings=[0, 0, 0], - include=["embeddings", "documents", "distances"], + include=include, n_results=1, ) assert items["metadatas"] is None assert items["ids"][0][0] == "id1" + assert set(items["included"]) == set(include) items = collection.query( query_embeddings=[[0, 0, 0], [1, 2, 1.2]], @@ -1029,22 +1051,27 @@ def test_get_include(api): collection = api.create_collection("test_get_include") collection.add(**records) - items = collection.get(include=["metadatas", "documents"], where={"int_value": 1}) + include = ["metadatas", "documents"] + items = collection.get(include=include, where={"int_value": 1}) assert items["embeddings"] is None assert items["ids"][0] == "id1" assert items["metadatas"][0]["int_value"] == 1 assert items["documents"][0] == "this document is first" + assert set(items["included"]) == set(include) - items = collection.get(include=["embeddings", "documents"]) + include = ["embeddings", "documents"] + items = collection.get(include=include) assert items["metadatas"] is None assert items["ids"][0] == "id1" assert approx_equal(items["embeddings"][1][0], 1.2) + assert set(items["included"]) == set(include) items = collection.get(include=[]) assert items["documents"] is None assert items["metadatas"] is None assert items["embeddings"] is None assert items["ids"][0] == "id1" + assert items["included"] == [] with pytest.raises(ValueError, match="include"): items = collection.get(include=["metadatas", "undefined"]) @@ -1172,6 +1199,8 @@ def test_persist_index_loading_params(api, request): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -1290,6 +1319,8 @@ def test_get_nearest_neighbors_where_n_results_more_than_element(api): for key in results.keys(): if key in includes or key == "ids": assert len(results[key][0]) == 2 + elif key == "included": + assert set(results[key]) == set(includes) else: assert results[key] is None diff --git a/clients/js/src/types.ts b/clients/js/src/types.ts index 92e66f516c7..5e1aa8885a5 100644 --- a/clients/js/src/types.ts +++ b/clients/js/src/types.ts @@ -32,8 +32,8 @@ type WhereOperator = "$gt" | "$gte" | "$lt" | "$lte" | "$ne" | "$eq"; type OperatorExpression = { [key in WhereOperator | InclusionOperator | LogicalOperator]?: - | LiteralValue - | ListLiteralValue; + | LiteralValue + | ListLiteralValue; }; type BaseWhere = { @@ -50,9 +50,9 @@ type WhereDocumentOperator = "$contains" | "$not_contains" | LogicalOperator; export type WhereDocument = { [key in WhereDocumentOperator]?: - | LiteralValue - | LiteralNumber - | WhereDocument[]; + | LiteralValue + | LiteralNumber + | WhereDocument[]; }; export type CollectionType = { @@ -67,6 +67,7 @@ export type GetResponse = { documents: (null | Document)[]; metadatas: (null | Metadata)[]; error: null | string; + included: IncludeEnum[] }; export type QueryResponse = { @@ -75,6 +76,7 @@ export type QueryResponse = { documents: (null | Document)[][]; metadatas: (null | Metadata)[][]; distances: null | number[][]; + included: IncludeEnum[] }; export type AddResponse = { diff --git a/clients/js/test/get.collection.test.ts b/clients/js/test/get.collection.test.ts index 5dc972454b8..805b3819264 100644 --- a/clients/js/test/get.collection.test.ts +++ b/clients/js/test/get.collection.test.ts @@ -17,6 +17,7 @@ test("it should get a collection", async () => { expect(results.ids.length).toBe(1); expect(["test1"]).toEqual(expect.arrayContaining(results.ids)); expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids)); + expect(results.included).toEqual(expect.arrayContaining(["metadatas", "documents"])) const results2 = await collection.get({ where: { test: "test1" } }); expect(results2).toBeDefined(); diff --git a/clients/js/test/query.collection.test.ts b/clients/js/test/query.collection.test.ts index 2809716535c..4afb0120ca3 100644 --- a/clients/js/test/query.collection.test.ts +++ b/clients/js/test/query.collection.test.ts @@ -6,7 +6,7 @@ import { EMBEDDINGS, IDS, METADATAS, DOCUMENTS } from "./data"; import { IEmbeddingFunction } from "../src/embeddings/IEmbeddingFunction"; export class TestEmbeddingFunction implements IEmbeddingFunction { - constructor() {} + constructor() { } public async generate(texts: string[]): Promise { let embeddings: number[][] = []; @@ -29,6 +29,7 @@ test("it should query a collection", async () => { expect(results).toBeInstanceOf(Object); expect(["test1", "test2"]).toEqual(expect.arrayContaining(results.ids[0])); expect(["test3"]).not.toEqual(expect.arrayContaining(results.ids[0])); + expect(results.included).toEqual(expect.arrayContaining(["metadatas", "documents"])) }); // test where_document @@ -68,6 +69,7 @@ test("it should get embedding with matching documents", async () => { // expect(results2.embeddings[0][0]).toBeInstanceOf(Array); expect(results2.embeddings![0].length).toBe(1); expect(results2.embeddings![0][0]).toEqual([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + expect(results2.included).toEqual(expect.arrayContaining(["embeddings"])) }); test("it should exclude documents matching - not_contains", async () => {