Skip to content

Commit

Permalink
Modify list_collections client methods to return a list of collection…
Browse files Browse the repository at this point in the history
… names
  • Loading branch information
itaismith committed Dec 11, 2024
1 parent 337fe73 commit e5ec6e9
Show file tree
Hide file tree
Showing 12 changed files with 85 additions and 33 deletions.
9 changes: 4 additions & 5 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
from chromadb.config import Component, Settings
from chromadb.types import Database, Tenant, Collection as CollectionModel
import chromadb.utils.embedding_functions as ef
from chromadb.api.models.Collection import Collection

from chromadb.api.models.Collection import Collection, CollectionName

# Re-export the async version
from chromadb.api.async_api import ( # noqa: F401
Expand Down Expand Up @@ -347,19 +346,19 @@ def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Sequence[Collection]:
) -> Sequence[CollectionName]:
"""List all collections.
Args:
limit: The maximum number of entries to return. Defaults to None.
offset: The number of entries to skip before returning. Defaults to None.
Returns:
Sequence[Collection]: A list of collections
Sequence[CollectionName]: A list of collection names
Examples:
```python
client.list_collections()
# [collection(name="my_collection", metadata={})]
# ["my_collection"]
```
"""
pass
Expand Down
6 changes: 3 additions & 3 deletions chromadb/api/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,19 +338,19 @@ async def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Sequence[AsyncCollection]:
) -> Sequence[str]:
"""List all collections.
Args:
limit: The maximum number of entries to return. Defaults to None.
offset: The number of entries to skip before returning. Defaults to None.
Returns:
Sequence[Collection]: A list of collections
Sequence[str]: A list of collection names.
Examples:
```python
await client.list_collections()
# [collection(name="my_collection", metadata={})]
# ["my_collection"]
```
"""
pass
Expand Down
12 changes: 4 additions & 8 deletions chromadb/api/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import Optional, Sequence
from uuid import UUID
from overrides import override

from chromadb.api.models.Collection import CollectionName
from chromadb.auth import UserIdentity
from chromadb.auth.utils import maybe_set_tenant_and_database
from chromadb.api import AsyncAdminAPI, AsyncClientAPI, AsyncServerAPI
Expand Down Expand Up @@ -152,17 +154,11 @@ async def heartbeat(self) -> int:
@override
async def list_collections(
self, limit: Optional[int] = None, offset: Optional[int] = None
) -> Sequence[AsyncCollection]:
) -> Sequence[CollectionName]:
models = await self._server.list_collections(
limit, offset, tenant=self.tenant, database=self.database
)
return [
AsyncCollection(
client=self._server,
model=model,
)
for model in models
]
return [CollectionName(model.name) for model in models]

@override
async def count_collections(self) -> int:
Expand Down
6 changes: 3 additions & 3 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from chromadb.auth.utils import maybe_set_tenant_and_database
from chromadb.config import Settings, System
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE
from chromadb.api.models.Collection import Collection
from chromadb.api.models.Collection import Collection, CollectionName
from chromadb.errors import ChromaError
from chromadb.types import Database, Tenant, Where, WhereDocument
import chromadb.utils.embedding_functions as ef
Expand Down Expand Up @@ -118,9 +118,9 @@ def heartbeat(self) -> int:
@override
def list_collections(
self, limit: Optional[int] = None, offset: Optional[int] = None
) -> Sequence[Collection]:
) -> Sequence[CollectionName]:
return [
Collection(client=self._server, model=model)
CollectionName(model.name)
for model in self._server.list_collections(
limit, offset, tenant=self.tenant, database=self.database
)
Expand Down
40 changes: 40 additions & 0 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from typing import TYPE_CHECKING, Optional, Union

from chromadb.api.models.CollectionCommon import CollectionCommon
Expand Down Expand Up @@ -380,3 +381,42 @@ def delete(
tenant=self.tenant,
database=self.database,
)



class CollectionName(str):
    """A ``str`` subclass returned by ``list_collections`` in place of Collection objects.

    It behaves exactly like the collection's name, but gives users an
    indicative error when they treat it as a Collection object.

    Attribute access goes through ``str.__getattribute__`` first. If a valid
    ``str`` method or property is found, it is used. Only when that lookup
    fails is the fallback ``__getattr__`` defined here invoked; it raises a
    migration error if the requested attribute is a public Collection method
    or property, and a regular ``AttributeError`` otherwise.

    For example:
        collection_name = client.list_collections()[0]  # collection_name == "test"

        collection_name.startswith("t")  # Evaluates to True.
        # __getattribute__ is invoked first, selecting startswith from str.

        collection_name.add(ids=[...], documents=[...])  # Raises the error defined below.
        # __getattribute__ is invoked first, not finding a match in str.
        # __getattr__ from this class is invoked and raises an error.
    """

    def __getattr__(self, item: str):
        """Handle attribute lookups that ``str`` could not resolve.

        Args:
            item: Name of the attribute that was not found on ``str``.

        Raises:
            NotImplementedError: If ``item`` is a public attribute or method
                of ``Collection``.
            AttributeError: For every other unknown attribute.
        """
        # Private and dunder names (e.g. protocol probes such as
        # ``__deepcopy__``) can never match a public Collection member, so
        # skip the comparatively expensive class introspection for them.
        if not item.startswith("_"):
            collection_members = {
                member for member, _ in inspect.getmembers(Collection)
            }
            if item in collection_members:
                raise NotImplementedError(
                    "In Chroma v0.6.0, list_collections only returns collection names. "
                    f"Use Client.get_collection(collection_name) to access {item}. "
                    "See https://docs.trychroma.com/deployment/migration for more information."
                )

        raise AttributeError(f"'CollectionName' object has no attribute '{item}'")
3 changes: 2 additions & 1 deletion chromadb/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def vacuum(
sqlite, system.instance(SegmentManager)
)

for collection in collections:
for collection_name in collections:
collection = client.get_collection(collection_name)
sqlite.purge_log(collection_id=collection.id)
progress.update(task, advance=1)
except Exception as e:
Expand Down
20 changes: 13 additions & 7 deletions chromadb/test/client/test_database_tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,38 @@ def test_database_tenant_collections(client_factories: ClientFactories) -> None:
# List collections in the default database
collections = client.list_collections()
assert len(collections) == 1
assert collections[0].name == "collection"
assert collections[0].metadata == {"database": DEFAULT_DATABASE}
assert collections[0] == "collection"
collection = client.get_collection(collections[0])
assert collection.metadata == {"database": DEFAULT_DATABASE}

# List collections in the new database
client.set_tenant(tenant=DEFAULT_TENANT, database="test_db")
collections = client.list_collections()
assert len(collections) == 1
assert collections[0].metadata == {"database": "test_db"}
collection = client.get_collection(collections[0])
assert collection.metadata == {"database": "test_db"}

# Update the metadata in both databases to different values
client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE)
client.list_collections()[0].modify(metadata={"database": "default2"})
collection = client.get_collection(client.list_collections()[0])
collection.modify(metadata={"database": "default2"})

client.set_tenant(tenant=DEFAULT_TENANT, database="test_db")
client.list_collections()[0].modify(metadata={"database": "test_db2"})
collection = client.get_collection(client.list_collections()[0])
collection.modify(metadata={"database": "test_db2"})

# Validate that the metadata was updated
client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE)
collections = client.list_collections()
assert len(collections) == 1
assert collections[0].metadata == {"database": "default2"}
collection = client.get_collection(collections[0])
assert collection.metadata == {"database": "default2"}

client.set_tenant(tenant=DEFAULT_TENANT, database="test_db")
collections = client.list_collections()
assert len(collections) == 1
assert collections[0].metadata == {"database": "test_db2"}
collection = client.get_collection(collections[0])
assert collection.metadata == {"database": "test_db2"}

# Delete the collections and make sure databases are isolated
client.set_tenant(tenant=DEFAULT_TENANT, database=DEFAULT_DATABASE)
Expand Down
3 changes: 2 additions & 1 deletion chromadb/test/client/test_multiple_clients_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def run_target(n: int) -> None:
client.set_database(database)
seen_collections = client.list_collections()
assert len(seen_collections) == COLLECTION_COUNT
for collection in seen_collections:
for collection_name in seen_collections:
collection = client.get_collection(collection_name)
assert collection.name in collections
assert collection.metadata == {"database": database}
4 changes: 2 additions & 2 deletions chromadb/test/property/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def delete_coll(self, coll: strategies.ExternalCollection) -> None:
def list_collections(self) -> None:
colls = self.client.list_collections()
assert len(colls) == len(self.model)
for c in colls:
assert c.name in self.model
for collection_name in colls:
assert collection_name in self.model

# @rule for list_collections with limit and offset
@rule(
Expand Down
3 changes: 2 additions & 1 deletion chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,8 @@ def test_metadata_cru(client):

# Test list collections
collections = client.list_collections()
for collection in collections:
for collection_name in collections:
collection = client.get_collection(collection_name)
if collection.name == "testspace":
assert collection.metadata is not None
assert collection.metadata["a"] == 2
Expand Down
8 changes: 8 additions & 0 deletions docs/docs.trychroma.com/pages/deployment/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ We will aim to provide:

## Migration Log

### v0.6.0

`list_collections` now returns a list of collection *names*, instead of collection objects.

Previously, `list_collections` returned a list of collection objects, each configured with the default embedding function. If one of your collections was created with a different embedding function, using the collection object returned by `list_collections` would result in various errors, because it was not configured with the embedding function the collection was created with.

We are working on embedding function persistence to allow you to configure a collection with an embedding function once, and not have to specify it again (in `get_collection` for example).

### v0.5.17

We no longer support sending empty lists or dictionaries for metadata filtering, ID filtering, etc. For example,
Expand Down
4 changes: 2 additions & 2 deletions docs/docs.trychroma.com/pages/reference/py-client.md
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,10 @@ class ClientAPI(BaseAPI, ABC)

```python
def list_collections(limit: Optional[int] = None,
offset: Optional[int] = None) -> Sequence[Collection]
offset: Optional[int] = None) -> List[str]
```

List all collections.
List all collection names.

**Arguments**:

Expand Down

0 comments on commit e5ec6e9

Please sign in to comment.