Skip to content

Commit

Permalink
Modify list_collections client methods to return a list of collection…
Browse files Browse the repository at this point in the history
… names
  • Loading branch information
itaismith committed Dec 9, 2024
1 parent 337fe73 commit ce7320e
Show file tree
Hide file tree
Showing 11 changed files with 59 additions and 36 deletions.
11 changes: 5 additions & 6 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Sequence, Optional
from typing import Sequence, Optional, List
from uuid import UUID

from overrides import override
Expand Down Expand Up @@ -31,8 +31,7 @@
from chromadb.config import Component, Settings
from chromadb.types import Database, Tenant, Collection as CollectionModel
import chromadb.utils.embedding_functions as ef
from chromadb.api.models.Collection import Collection

from chromadb.api.models.Collection import Collection, CollectionName

# Re-export the async version
from chromadb.api.async_api import ( # noqa: F401
Expand Down Expand Up @@ -347,19 +346,19 @@ def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Sequence[Collection]:
) -> List[CollectionName]:
"""List all collections.
Args:
limit: The maximum number of entries to return. Defaults to None.
offset: The number of entries to skip before returning. Defaults to None.
Returns:
Sequence[Collection]: A list of collections
List[CollectionName]: A list of collection names
Examples:
```python
client.list_collections()
# [collection(name="my_collection", metadata={})]
# ["my_collection"]
```
"""
pass
Expand Down
8 changes: 4 additions & 4 deletions chromadb/api/async_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Sequence, Optional
from typing import Sequence, Optional, List
from uuid import UUID

from overrides import override
Expand Down Expand Up @@ -338,19 +338,19 @@ async def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Sequence[AsyncCollection]:
) -> List[str]:
"""List all collections.
Args:
limit: The maximum number of entries to return. Defaults to None.
offset: The number of entries to skip before returning. Defaults to None.
Returns:
Sequence[Collection]: A list of collections
List[str]: A list of collection names.
Examples:
```python
await client.list_collections()
# [collection(name="my_collection", metadata={})]
# ["my_collection"]
```
"""
pass
Expand Down
14 changes: 5 additions & 9 deletions chromadb/api/async_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import httpx
from typing import Optional, Sequence
from typing import Optional, Sequence, List
from uuid import UUID
from overrides import override

from chromadb.api.models.Collection import CollectionName
from chromadb.auth import UserIdentity
from chromadb.auth.utils import maybe_set_tenant_and_database
from chromadb.api import AsyncAdminAPI, AsyncClientAPI, AsyncServerAPI
Expand Down Expand Up @@ -152,17 +154,11 @@ async def heartbeat(self) -> int:
@override
async def list_collections(
self, limit: Optional[int] = None, offset: Optional[int] = None
) -> Sequence[AsyncCollection]:
) -> List[CollectionName]:
models = await self._server.list_collections(
limit, offset, tenant=self.tenant, database=self.database
)
return [
AsyncCollection(
client=self._server,
model=model,
)
for model in models
]
return [CollectionName(model.name) for model in models]

@override
async def count_collections(self) -> int:
Expand Down
8 changes: 4 additions & 4 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Sequence
from typing import Optional, Sequence, List
from uuid import UUID

from overrides import override
Expand All @@ -25,7 +25,7 @@
from chromadb.auth.utils import maybe_set_tenant_and_database
from chromadb.config import Settings, System
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE
from chromadb.api.models.Collection import Collection
from chromadb.api.models.Collection import Collection, CollectionName
from chromadb.errors import ChromaError
from chromadb.types import Database, Tenant, Where, WhereDocument
import chromadb.utils.embedding_functions as ef
Expand Down Expand Up @@ -118,9 +118,9 @@ def heartbeat(self) -> int:
@override
def list_collections(
self, limit: Optional[int] = None, offset: Optional[int] = None
) -> Sequence[Collection]:
) -> List[CollectionName]:
return [
Collection(client=self._server, model=model)
CollectionName(model.name)
for model in self._server.list_collections(
limit, offset, tenant=self.tenant, database=self.database
)
Expand Down
19 changes: 19 additions & 0 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from typing import TYPE_CHECKING, Optional, Union

from chromadb.api.models.CollectionCommon import CollectionCommon
Expand Down Expand Up @@ -380,3 +381,21 @@ def delete(
tenant=self.tenant,
database=self.database,
)


def disable_method(method_name):
def method(self, *args, **kwargs):
raise NotImplementedError(f"In Chroma v0.6.0, list_collections only returns collection names. "
f"Use get_collection to access {method_name}. "
f"See https://docs.trychroma.com/deployment/migration for more information.")
return method


class CollectionName(str):
"""
A string wrapper to supply users with indicative message about list_collections only
returning collection names, in lieu of Collection object.
"""
for member, _ in inspect.getmembers(Collection, inspect.isfunction):
if not member.startswith("_"):
locals()[member] = disable_method(member)
3 changes: 2 additions & 1 deletion chromadb/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def vacuum(
sqlite, system.instance(SegmentManager)
)

for collection in collections:
for collection_name in collections:
collection = client.get_collection(collection_name)
sqlite.purge_log(collection_id=collection.id)
progress.update(task, advance=1)
except Exception as e:
Expand Down
20 changes: 13 additions & 7 deletions chromadb/test/client/test_database_tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,38 @@ def test_database_tenant_collections(client_factories: ClientFactories) -> None:
# List collections in the default database
collections = client.list_collections()
assert len(collections) == 1
assert collections[0].name == "collection"
assert collections[0].metadata == {"database": DEFAULT_DATABASE}
assert collections[0] == "collection"
collection = client.get_collection(collections[0])
assert collection.metadata == {"database": DEFAULT_DATABASE}

# List collections in the new database
client.set_tenant(tenant=DEFAULT_TENANT, database="test_db")
collections = client.list_collections()
assert len(collections) == 1
assert collections[0].metadata == {"database": "test_db"}
collection = client.get_collection(collections[0])
assert collection.metadata == {"database": "test_db"}

# Update the metadata in both databases to different values
client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE)
client.list_collections()[0].modify(metadata={"database": "default2"})
collection = client.get_collection(client.list_collections()[0])
collection.modify(metadata={"database": "default2"})

client.set_tenant(tenant=DEFAULT_TENANT, database="test_db")
client.list_collections()[0].modify(metadata={"database": "test_db2"})
collection = client.get_collection(client.list_collections()[0])
collection.modify(metadata={"database": "test_db2"})

# Validate that the metadata was updated
client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE)
collections = client.list_collections()
assert len(collections) == 1
assert collections[0].metadata == {"database": "default2"}
collection = client.get_collection(collections[0])
assert collection.metadata == {"database": "default2"}

client.set_tenant(tenant=DEFAULT_TENANT, database="test_db")
collections = client.list_collections()
assert len(collections) == 1
assert collections[0].metadata == {"database": "test_db2"}
collection = client.get_collection(collections[0])
assert collection.metadata == {"database": "test_db2"}

# Delete the collections and make sure databases are isolated
client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE)
Expand Down
3 changes: 2 additions & 1 deletion chromadb/test/client/test_multiple_clients_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def run_target(n: int) -> None:
client.set_database(database)
seen_collections = client.list_collections()
assert len(seen_collections) == COLLECTION_COUNT
for collection in seen_collections:
for collection_name in seen_collections:
collection = client.get_collection(collection_name)
assert collection.name in collections
assert collection.metadata == {"database": database}
2 changes: 1 addition & 1 deletion chromadb/test/property/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def list_collections(self) -> None:
colls = self.client.list_collections()
assert len(colls) == len(self.model)
for c in colls:
assert c.name in self.model
assert c in self.model

# @rule for list_collections with limit and offset
@rule(
Expand Down
3 changes: 2 additions & 1 deletion chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,8 @@ def test_metadata_cru(client):

# Test list collections
collections = client.list_collections()
for collection in collections:
for collection_name in collections:
collection = client.get_collection(collection_name)
if collection.name == "testspace":
assert collection.metadata is not None
assert collection.metadata["a"] == 2
Expand Down
4 changes: 2 additions & 2 deletions docs/docs.trychroma.com/pages/reference/py-client.md
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,10 @@ class ClientAPI(BaseAPI, ABC)

```python
def list_collections(limit: Optional[int] = None,
offset: Optional[int] = None) -> Sequence[Collection]
offset: Optional[int] = None) -> List[str]
```

List all collections.
List all collection names.

**Arguments**:

Expand Down

0 comments on commit ce7320e

Please sign in to comment.