Skip to content

Commit

Permalink
Add tenant/db name verification. Fix fastapi createtenant request. Ad…
Browse files Browse the repository at this point in the history
…d property test
  • Loading branch information
HammadB committed Oct 20, 2023
1 parent c66327b commit 3553ad9
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 12 deletions.
6 changes: 6 additions & 0 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def heartbeat(self) -> int:

@override
def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
if len(name) < 3:
raise ValueError("Database name must be at least 3 characters long")

self._sysdb.create_database(
id=uuid4(),
name=name,
Expand All @@ -104,6 +107,9 @@ def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:

@override
def create_tenant(self, name: str) -> None:
if len(name) < 3:
raise ValueError("Tenant name must be at least 3 characters long")

self._sysdb.create_tenant(
name=name,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ CREATE TABLE collections_tmp (
);

-- Create default tenant and database
INSERT INTO tenants (id) VALUES ('default'); -- should ids be uuids?
INSERT INTO tenants (id) VALUES ('default'); -- The default tenant id is 'default' others are UUIDs
INSERT INTO databases (id, name, tenant_id) VALUES ('default', 'default', 'default');

INSERT INTO collections_tmp (id, name, topic, dimension, database_id)
Expand Down
5 changes: 3 additions & 2 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from chromadb.server.fastapi.types import (
AddEmbedding,
CreateDatabase,
CreateTenant,
DeleteEmbedding,
GetEmbedding,
QueryEmbedding,
Expand Down Expand Up @@ -244,8 +245,8 @@ def create_database(
) -> None:
return self._api.create_database(database.name, tenant)

def create_tenant(self, name: str) -> None:
return self._api.create_tenant(name)
def create_tenant(self, tenant: CreateTenant) -> None:
return self._api.create_tenant(tenant.name)

def list_collections(
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
Expand Down
4 changes: 4 additions & 0 deletions chromadb/server/fastapi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,7 @@ class UpdateCollection(BaseModel): # type: ignore

class CreateDatabase(BaseModel):
name: str


class CreateTenant(BaseModel):
name: str
15 changes: 15 additions & 0 deletions chromadb/test/client/test_database_tenant.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from chromadb.api.client import AdminClient, Client


Expand Down Expand Up @@ -59,3 +60,17 @@ def test_database_tenant_collections(client: Client) -> None:
client.delete_collection("collection")
collections = client.list_collections()
assert len(collections) == 0


def test_min_len_name(client: Client) -> None:
client.reset()

# Create a new database in the default tenant with a name of length 1
# and expect an error
admin_client = AdminClient.from_system(client._system)
with pytest.raises(Exception):
admin_client.create_database("a")

# Create a tenant with a name of length 1 and expect an error
with pytest.raises(Exception):
admin_client.create_tenant("a")
2 changes: 2 additions & 0 deletions chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,12 @@ class Record(TypedDict):
# TODO: support empty strings everywhere
sql_alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_"
safe_text = st.text(alphabet=sql_alphabet, min_size=1)
tenant_database_name = st.text(alphabet=sql_alphabet, min_size=3)

# Workaround for FastAPI json encoding peculiarities
# https://github.com/tiangolo/fastapi/blob/8ac8d70d52bb0dd9eb55ba4e22d3e383943da05c/fastapi/encoders.py#L104
safe_text = safe_text.filter(lambda s: not s.startswith("_sa"))
tenant_database_name = tenant_database_name.filter(lambda s: not s.startswith("_sa"))

safe_integers = st.integers(
min_value=-(2**31), max_value=2**31 - 1
Expand Down
32 changes: 23 additions & 9 deletions chromadb/test/property/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@

class CollectionStateMachine(RuleBasedStateMachine):
collections: Bundle[strategies.Collection]
model: Dict[str, Optional[types.CollectionMetadata]]
_model: Dict[str, Optional[types.CollectionMetadata]]

collections = Bundle("collections")

def __init__(self, api: ClientAPI):
super().__init__()
self.model = {}
self._model = {}
self.api = api

@initialize()
def initialize(self) -> None:
self.api.reset()
self.model = {}
self._model = {}

@rule(target=collections, coll=strategies.collections())
def create_coll(
Expand All @@ -54,7 +54,7 @@ def create_coll(
metadata=coll.metadata,
embedding_function=coll.embedding_function,
)
self.model[coll.name] = coll.metadata
self.set_model(coll.name, coll.metadata)

assert c.name == coll.name
assert c.metadata == coll.metadata
Expand All @@ -74,7 +74,7 @@ def get_coll(self, coll: strategies.Collection) -> None:
def delete_coll(self, coll: strategies.Collection) -> None:
if coll.name in self.model:
self.api.delete_collection(name=coll.name)
del self.model[coll.name]
self.delete_from_model(coll.name)
else:
with pytest.raises(Exception):
self.api.delete_collection(name=coll.name)
Expand Down Expand Up @@ -140,7 +140,7 @@ def get_or_create_coll(
coll.metadata = (
self.model[coll.name] if new_metadata is None else new_metadata
)
self.model[coll.name] = coll.metadata
self.set_model(coll.name, coll.metadata)

# Update API
c = self.api.get_or_create_collection(
Expand Down Expand Up @@ -183,16 +183,16 @@ def modify_coll(
)
return multiple()
coll.metadata = new_metadata
self.model[coll.name] = coll.metadata
self.set_model(coll.name, coll.metadata)

if new_name is not None:
if new_name in self.model and new_name != coll.name:
with pytest.raises(Exception):
c.modify(metadata=new_metadata, name=new_name)
return multiple()

del self.model[coll.name]
self.model[new_name] = coll.metadata
self.delete_from_model(coll.name)
self.set_model(new_name, coll.metadata)
coll.name = new_name

c.modify(metadata=new_metadata, name=new_name)
Expand All @@ -202,6 +202,20 @@ def modify_coll(
assert c.metadata == coll.metadata
return multiple(coll)

def set_model(
self, name: str, metadata: Optional[types.CollectionMetadata]
) -> None:
model = self.model
model[name] = metadata

def delete_from_model(self, name: str) -> None:
model = self.model
del model[name]

@property
def model(self) -> Dict[str, Optional[types.CollectionMetadata]]:
return self._model


def test_collections(caplog: pytest.LogCaptureFixture, api: ClientAPI) -> None:
caplog.set_level(logging.ERROR)
Expand Down
105 changes: 105 additions & 0 deletions chromadb/test/property/test_collections_with_database_tenant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import logging
from typing import Dict, Optional, Tuple
import pytest
from chromadb.api import AdminAPI
import chromadb.api.types as types
from chromadb.api.client import AdminClient, Client
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
from chromadb.test.property.test_collections import CollectionStateMachine
from hypothesis.stateful import (
Bundle,
rule,
initialize,
multiple,
run_state_machine_as_test,
MultipleResults,
)
import chromadb.test.property.strategies as strategies


class TenantDatabaseCollectionStateMachine(CollectionStateMachine):
"""A collection state machine test that includes tenant and database information,
and switches between them."""

tenants: Bundle[str]
databases: Bundle[Tuple[str, str]] # database to tenant it belongs to
tenant_to_database_to_model: Dict[
str, Dict[str, Dict[str, Optional[types.CollectionMetadata]]]
]
admin_client: AdminAPI
curr_tenant: str
curr_database: str

tenants = Bundle("tenants")
databases = Bundle("databases")

def __init__(self, client: Client):
super().__init__(client)
self.api = client
self.admin_client = AdminClient.from_system(client._system)

@initialize()
def initialize(self) -> None:
self.api.reset()
self.tenant_to_database_to_model = {}
self.curr_tenant = DEFAULT_TENANT
self.curr_database = DEFAULT_DATABASE
self.api.set_tenant(DEFAULT_TENANT)
self.api.set_database(DEFAULT_DATABASE)
self.tenant_to_database_to_model[self.curr_tenant] = {}
self.tenant_to_database_to_model[self.curr_tenant][self.curr_database] = {}

@rule(target=tenants, name=strategies.tenant_database_name)
def create_tenant(self, name: str) -> MultipleResults[str]:
# Check if tenant already exists
if name in self.tenant_to_database_to_model:
with pytest.raises(Exception):
self.admin_client.create_tenant(name)
return multiple()

self.admin_client.create_tenant(name)
# When we create a tenant, create a default database for it just for testing
# since the state machine could call collection operations before creating a
# database
self.admin_client.create_database(DEFAULT_DATABASE, tenant=name)
self.tenant_to_database_to_model[name] = {}
self.tenant_to_database_to_model[name][DEFAULT_DATABASE] = {}
return multiple(name)

@rule(target=databases, name=strategies.tenant_database_name)
def create_database(self, name: str) -> MultipleResults[Tuple[str, str]]:
# If database already exists in current tenant, raise an error
if name in self.tenant_to_database_to_model[self.curr_tenant]:
with pytest.raises(Exception):
self.admin_client.create_database(name, tenant=self.curr_tenant)
return multiple()

self.admin_client.create_database(name, tenant=self.curr_tenant)
self.tenant_to_database_to_model[self.curr_tenant][name] = {}
return multiple((name, self.curr_tenant))

@rule(database=databases)
def set_database_and_tenant(self, database: Dict[str, str]) -> None:
# Get a database and switch to the database and the tenant it belongs to
database_name = database[0]
tenant_name = database[1]
self.api.set_tenant(tenant_name)
self.api.set_database(database_name)
self.curr_database = database_name
self.curr_tenant = tenant_name

@rule(tenant=tenants)
def set_tenant(self, tenant: str) -> None:
self.api.set_tenant(tenant)
self.api.set_database(DEFAULT_DATABASE)
self.curr_tenant = tenant
self.curr_database = DEFAULT_DATABASE

@property
def model(self) -> Dict[str, Optional[types.CollectionMetadata]]:
return self.tenant_to_database_to_model[self.curr_tenant][self.curr_database]


def test_collections(caplog: pytest.LogCaptureFixture, client: Client) -> None:
caplog.set_level(logging.ERROR)
run_state_machine_as_test(lambda: TenantDatabaseCollectionStateMachine(client)) # type: ignore

0 comments on commit 3553ad9

Please sign in to comment.