diff --git a/chromadb/api/models/AsyncCollection.py b/chromadb/api/models/AsyncCollection.py index 02491ba1035..efa7dda4b23 100644 --- a/chromadb/api/models/AsyncCollection.py +++ b/chromadb/api/models/AsyncCollection.py @@ -21,6 +21,7 @@ ID, OneOrMany, WhereDocument, + IncludeEnum, ) from chromadb.api.models.CollectionCommon import CollectionCommon @@ -74,7 +75,16 @@ async def add( ids, embeddings, metadatas, documents, images, uris ) - await self._client._add(ids, self.id, embeddings, metadatas, documents, uris) + await self._client._add( + ids=ids, + collection_id=self.id, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + uris=uris, + tenant=self.tenant, + database=self.database, + ) async def count(self) -> int: """The total number of embeddings added to the database @@ -83,7 +93,11 @@ async def count(self) -> int: int: The total number of embeddings added to the database """ - return await self._client._count(collection_id=self.id) + return await self._client._count( + collection_id=self.id, + tenant=self.tenant, + database=self.database, + ) async def get( self, @@ -92,7 +106,7 @@ async def get( limit: Optional[int] = None, offset: Optional[int] = None, where_document: Optional[WhereDocument] = None, - include: Include = ["metadatas", "documents"], + include: Include = [IncludeEnum.metadatas, IncludeEnum.documents], ) -> GetResult: """Get embeddings and their associate data from the data store. If no ids or where filter is provided returns all embeddings up to limit starting at offset. @@ -117,14 +131,16 @@ async def get( ) = self._validate_and_prepare_get_request(ids, where, where_document, include) get_results = await self._client._get( - self.id, - valid_ids, - valid_where, - None, - limit, - offset, + collection_id=self.id, + ids=valid_ids, + where=valid_where, + sort=None, + limit=limit, + offset=offset, where_document=valid_where_document, include=valid_include, + tenant=self.tenant, + database=self.database, ) return self._transform_get_response(get_results, valid_include) @@ -138,11 +154,18 @@ async def peek(self, limit: int = 10) -> GetResult: Returns: GetResult: A GetResult object containing the results. """ - return self._transform_peek_response(await self._client._peek(self.id, limit)) + return self._transform_peek_response( + await self._client._peek( + collection_id=self.id, + n=limit, + tenant=self.tenant, + database=self.database, + ) + ) async def query( self, - query_embeddings: Optional[ + query_embeddings: Optional[ # type: ignore[type-arg] Union[ OneOrMany[Embedding], OneOrMany[np.ndarray], @@ -154,7 +177,11 @@ async def query( n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, - include: Include = ["metadatas", "documents", "distances"], + include: Include = [ + IncludeEnum.metadatas, + IncludeEnum.documents, + IncludeEnum.distances, + ], ) -> QueryResult: """Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts. @@ -201,6 +228,8 @@ async def query( where=valid_where, where_document=valid_where_document, include=include, + tenant=self.tenant, + database=self.database, ) return self._transform_query_response(query_results, include) @@ -230,7 +259,7 @@ async def modify( async def update( self, ids: OneOrMany[ID], - embeddings: Optional[ + embeddings: Optional[ # type: ignore[type-arg] Union[ OneOrMany[Embedding], OneOrMany[np.ndarray], @@ -262,12 +291,21 @@ async def update( ids, embeddings, metadatas, documents, images, uris ) - await self._client._update(self.id, ids, embeddings, metadatas, documents, uris) + await self._client._update( + collection_id=self.id, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + uris=uris, + tenant=self.tenant, + database=self.database, + ) async def upsert( self, ids: OneOrMany[ID], - embeddings: Optional[ + embeddings: Optional[ # type: ignore[type-arg] Union[ OneOrMany[Embedding], OneOrMany[np.ndarray], @@ -306,6 +344,8 @@ async def upsert( metadatas=metadatas, documents=documents, uris=uris, + tenant=self.tenant, + database=self.database, ) async def delete( @@ -331,4 +371,6 @@ async def delete( ids, where, where_document ) - await self._client._delete(self.id, ids, where, where_document) + await self._client._delete( + collection_id=self.id, ids=ids, where=where, where_document=where_document + ) diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index f3aaf2f57dd..6aaf2c2b864 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -18,6 +18,7 @@ ID, OneOrMany, WhereDocument, + IncludeEnum, ) import logging @@ -36,12 +37,16 @@ def count(self) -> int: int: The total number of embeddings added to the database """ - return self._client._count(collection_id=self.id) + return self._client._count( + collection_id=self.id, + tenant=self.tenant, + database=self.database, + ) def add( self, ids: OneOrMany[ID], - embeddings: Optional[ # type: ignore[type-arg] + embeddings: Optional[ Union[ OneOrMany[Embedding], OneOrMany[PyEmbedding], @@ -82,7 +87,16 @@ def add( ids, embeddings, metadatas, documents, images, uris ) - self._client._add(ids, self.id, embeddings, metadatas, documents, uris) + self._client._add( + ids=ids, + collection_id=self.id, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + uris=uris, + tenant=self.tenant, + database=self.database, + ) def get( self, @@ -91,7 +105,7 @@ def get( limit: Optional[int] = None, offset: Optional[int] = None, where_document: Optional[WhereDocument] = None, - include: Include = ["metadatas", "documents"], + include: Include = [IncludeEnum.metadatas, IncludeEnum.documents], ) -> GetResult: """Get embeddings and their associate data from the data store. If no ids or where filter is provided returns all embeddings up to limit starting at offset. @@ -116,14 +130,16 @@ def get( ) = self._validate_and_prepare_get_request(ids, where, where_document, include) get_results = self._client._get( - self.id, - valid_ids, - valid_where, - None, - limit, - offset, + collection_id=self.id, + ids=valid_ids, + where=valid_where, + sort=None, + limit=limit, + offset=offset, where_document=valid_where_document, include=valid_include, + tenant=self.tenant, + database=self.database, ) return self._transform_get_response(get_results, include) @@ -137,11 +153,18 @@ def peek(self, limit: int = 10) -> GetResult: Returns: GetResult: A GetResult object containing the results. """ - return self._transform_peek_response(self._client._peek(self.id, limit)) + return self._transform_peek_response( + self._client._peek( + collection_id=self.id, + n=limit, + tenant=self.tenant, + database=self.database, + ) + ) def query( self, - query_embeddings: Optional[ # type: ignore[type-arg] + query_embeddings: Optional[ Union[ OneOrMany[Embedding], OneOrMany[PyEmbedding], @@ -153,7 +176,11 @@ def query( n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, - include: Include = ["metadatas", "documents", "distances"], + include: Include = [ + IncludeEnum.metadatas, + IncludeEnum.documents, + IncludeEnum.distances, + ], ) -> QueryResult: """Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts. @@ -201,6 +228,8 @@ def query( where=valid_where, where_document=valid_where_document, include=include, + tenant=self.tenant, + database=self.database, ) return self._transform_query_response(query_results, include) @@ -223,7 +252,13 @@ def modify( # Note there is a race condition here where the metadata can be updated # but another thread sees the cached local metadata. # TODO: fixme - self._client._modify(id=self.id, new_name=name, new_metadata=metadata) + self._client._modify( + id=self.id, + new_name=name, + new_metadata=metadata, + tenant=self.tenant, + database=self.database, + ) self._update_model_after_modify_success(name, metadata) @@ -262,12 +297,21 @@ def update( ids, embeddings, metadatas, documents, images, uris ) - self._client._update(self.id, ids, embeddings, metadatas, documents, uris) + self._client._update( + collection_id=self.id, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + uris=uris, + tenant=self.tenant, + database=self.database, + ) def upsert( self, ids: OneOrMany[ID], - embeddings: Optional[ # type: ignore[type-arg] + embeddings: Optional[ Union[ OneOrMany[Embedding], OneOrMany[PyEmbedding], @@ -306,6 +350,8 @@ def upsert( metadatas=metadatas, documents=documents, uris=uris, + tenant=self.tenant, + database=self.database, ) def delete( @@ -331,4 +377,11 @@ def delete( ids, where, where_document ) - self._client._delete(self.id, ids, where, where_document) + self._client._delete( + collection_id=self.id, + ids=ids, + where=where, + where_document=where_document, + tenant=self.tenant, + database=self.database, + )