diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index f3d9909718b..f079a30747e 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -19,8 +19,7 @@ from fastapi.responses import JSONResponse, ORJSONResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.routing import APIRoute -from fastapi import HTTPException, status - +from fastapi import status from chromadb.api.configuration import CollectionConfigurationInternal from pydantic import BaseModel from chromadb.api.types import ( @@ -41,7 +40,6 @@ from chromadb.api import ServerAPI from chromadb.errors import ( ChromaError, - InvalidDimensionException, InvalidHTTPVersion, RateLimitError, QuotaError, @@ -773,46 +771,42 @@ async def add( collection_id: str, body: AddEmbedding = Body(...), ) -> bool: - try: - - def process_add(request: Request, raw_body: bytes) -> bool: - add = validate_model(AddEmbedding, orjson.loads(raw_body)) - self.auth_request( - request.headers, - AuthzAction.ADD, - tenant, - database_name, - collection_id, - ) - self._set_request_context(request=request) - add_attributes_to_current_span({"tenant": tenant}) - return self._api._add( - collection_id=_uuid(collection_id), - ids=add.ids, - embeddings=cast( - Embeddings, - convert_list_embeddings_to_np(add.embeddings) - if add.embeddings - else None, - ), - metadatas=add.metadatas, # type: ignore - documents=add.documents, # type: ignore - uris=add.uris, # type: ignore - tenant=tenant, - database=database_name, - ) - - return cast( - bool, - await to_thread.run_sync( - process_add, - request, - await request.body(), - limiter=self._capacity_limiter, + def process_add(request: Request, raw_body: bytes) -> bool: + add = validate_model(AddEmbedding, orjson.loads(raw_body)) + self.auth_request( + request.headers, + AuthzAction.ADD, + tenant, + database_name, + collection_id, + ) + self._set_request_context(request=request) + add_attributes_to_current_span({"tenant": tenant}) + return self._api._add( + collection_id=_uuid(collection_id), + ids=add.ids, + embeddings=cast( + Embeddings, + convert_list_embeddings_to_np(add.embeddings) + if add.embeddings + else None, ), + metadatas=add.metadatas, # type: ignore + documents=add.documents, # type: ignore + uris=add.uris, # type: ignore + tenant=tenant, + database=database_name, ) - except InvalidDimensionException as e: - raise HTTPException(status_code=500, detail=str(e)) + + return cast( + bool, + await to_thread.run_sync( + process_add, + request, + await request.body(), + limiter=self._capacity_limiter, + ), + ) @trace_method("FastAPI.update", OpenTelemetryGranularity.OPERATION) async def update( @@ -1639,42 +1633,38 @@ async def delete_collection_v1( async def add_v1( self, request: Request, collection_id: str, body: AddEmbedding = Body(...) ) -> bool: - try: - - def process_add(request: Request, raw_body: bytes) -> bool: - add = validate_model(AddEmbedding, orjson.loads(raw_body)) - self.auth_and_get_tenant_and_database_for_request( - request.headers, - AuthzAction.ADD, - None, - None, - collection_id, - ) - return self._api._add( - collection_id=_uuid(collection_id), - ids=add.ids, - embeddings=cast( - Embeddings, - convert_list_embeddings_to_np(add.embeddings) - if add.embeddings - else None, - ), - metadatas=add.metadatas, # type: ignore - documents=add.documents, # type: ignore - uris=add.uris, # type: ignore - ) - - return cast( - bool, - await to_thread.run_sync( - process_add, - request, - await request.body(), - limiter=self._capacity_limiter, + def process_add(request: Request, raw_body: bytes) -> bool: + add = validate_model(AddEmbedding, orjson.loads(raw_body)) + self.auth_and_get_tenant_and_database_for_request( + request.headers, + AuthzAction.ADD, + None, + None, + collection_id, + ) + return self._api._add( + collection_id=_uuid(collection_id), + ids=add.ids, + embeddings=cast( + Embeddings, + convert_list_embeddings_to_np(add.embeddings) + if add.embeddings + else None, ), + metadatas=add.metadatas, # type: ignore + documents=add.documents, # type: ignore + uris=add.uris, # type: ignore ) - except InvalidDimensionException as e: - raise HTTPException(status_code=500, detail=str(e)) + + return cast( + bool, + await to_thread.run_sync( + process_add, + request, + await request.body(), + limiter=self._capacity_limiter, + ), + ) @trace_method("FastAPI.update_v1", OpenTelemetryGranularity.OPERATION) async def update_v1( diff --git a/chromadb/test/api/test_collection.py b/chromadb/test/api/test_collection.py index e367ebc21f8..287d3984165 100644 --- a/chromadb/test/api/test_collection.py +++ b/chromadb/test/api/test_collection.py @@ -1,10 +1,13 @@ from chromadb.api import ClientAPI +import numpy as np + +from chromadb.errors import InvalidDimensionException def test_duplicate_collection_create( client: ClientAPI, ) -> None: - collection = client.create_collection( + _ = client.create_collection( name="test", metadata={"hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128}, ) @@ -28,10 +31,34 @@ def test_not_existing_collection_delete( client: ClientAPI, ) -> None: try: - collection = client.delete_collection( + _ = client.delete_collection( name="test101", ) assert False, "Expected exception" except Exception as e: print("Collection deletion failed as expected with error ", e) assert "does not exist" in e.args[0] + + +def test_collection_dimension_mismatch( + client: ClientAPI, +) -> None: + collection = client.create_collection( + name="test", + ) + D = 768 + N = 5 + embeddings = np.random.random(size=(N, D)) + ids = [str(i) for i in range(N)] + + collection.add(ids=ids, embeddings=embeddings) # type: ignore[arg-type] + + WRONG_D = 512 + wrong_embeddings = np.random.random(size=(N, WRONG_D)) + try: + collection.add(ids=ids, embeddings=wrong_embeddings) # type: ignore[arg-type] + assert False, "Expected exception" + except InvalidDimensionException: + print("Dimension mismatch failed as expected") + except Exception as e: + assert False, f"Unexpected exception {e}"