Skip to content

Commit

Permalink
[ENH]: Authz tenant and DB resource attribute hook (#1317)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Added attribute injection hook to add database and tenant to authz
resource attributes

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python

## Documentation Changes
N/A
  • Loading branch information
tazarov authored Nov 2, 2023
1 parent 4ec8d7a commit cdcafc8
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 10 deletions.
9 changes: 9 additions & 0 deletions chromadb/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,15 @@ def from_function_kwargs(**kwargs: Any) -> Callable[..., str]:
lambda **kwargs: kwargs["function_kwargs"][kwargs["arg_name"]], **kwargs
)

@staticmethod
def dict_from_function_kwargs(**kwargs: Any) -> Callable[..., Dict[str, Any]]:
return partial(
lambda **kwargs: {
k: kwargs["function_kwargs"][k] for k in kwargs["arg_names"]
},
**kwargs,
)


@dataclass
class AuthzAction:
Expand Down
1 change: 1 addition & 0 deletions chromadb/auth/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
@wraps(f)
def wrapped(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
_dynamic_kwargs = {
"api": args[0]._api,
"function": f,
"function_args": args,
"function_kwargs": kwargs,
Expand Down
29 changes: 25 additions & 4 deletions chromadb/auth/fastapi_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional, Sequence, cast
from chromadb.api import ServerAPI
from chromadb.auth import AuthzResourceTypes


Expand All @@ -25,7 +26,27 @@ def find_key_with_value_of_type(


def attr_from_resource_object(
type: AuthzResourceTypes, **kwargs: Any
type: AuthzResourceTypes,
additional_attrs: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> Callable[..., Dict[str, Any]]:
def _wrap(**wkwargs: Any) -> Dict[str, Any]:
obj = find_key_with_value_of_type(type, **wkwargs)
if additional_attrs:
obj.update({k: wkwargs["function_kwargs"][k]
for k in additional_attrs})
return obj

return partial(_wrap, **kwargs)


def attr_from_collection_lookup(
collection_id_arg: str, **kwargs: Any
) -> Callable[..., Dict[str, Any]]:
obj = find_key_with_value_of_type(type, **kwargs)
return partial(lambda **kwargs: obj, **kwargs)
def _wrap(**kwargs: Any) -> Dict[str, Any]:
_api = cast(ServerAPI, kwargs["api"])
col = _api.get_collection(
id=kwargs["function_kwargs"][collection_id_arg])
return {"tenant": col.tenant, "database": col.database}

return partial(_wrap, **kwargs)
29 changes: 27 additions & 2 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
FastAPIChromaAuthzMiddlewareWrapper,
authz_context,
)
from chromadb.auth.fastapi_utils import attr_from_resource_object
from chromadb.auth.fastapi_utils import (
attr_from_collection_lookup,
attr_from_resource_object,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
import chromadb.server
import chromadb.api
Expand Down Expand Up @@ -281,7 +284,9 @@ def version(self) -> str:
action=AuthzResourceActions.CREATE_DATABASE,
resource=DynamicAuthzResource(
type=AuthzResourceTypes.DB,
attributes=attr_from_resource_object(type=AuthzResourceTypes.DB),
attributes=attr_from_resource_object(
type=AuthzResourceTypes.DB, additional_attrs=["tenant"]
),
),
)
def create_database(
Expand All @@ -295,6 +300,9 @@ def create_database(
resource=DynamicAuthzResource(
id="*",
type=AuthzResourceTypes.DB,
attributes=AuthzDynamicParams.dict_from_function_kwargs(
arg_names=["tenant", "database"]
),
),
)
def get_database(self, database: str, tenant: str = DEFAULT_TENANT) -> Database:
Expand Down Expand Up @@ -343,6 +351,9 @@ def list_collections(
resource=DynamicAuthzResource(
id="*",
type=AuthzResourceTypes.DB,
attributes=AuthzDynamicParams.dict_from_function_kwargs(
arg_names=["tenant", "database"]
),
),
)
def create_collection(
Expand All @@ -365,6 +376,9 @@ def create_collection(
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_name"),
type=AuthzResourceTypes.COLLECTION,
attributes=AuthzDynamicParams.dict_from_function_kwargs(
arg_names=["tenant", "database"]
),
),
)
def get_collection(
Expand All @@ -383,6 +397,7 @@ def get_collection(
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"),
type=AuthzResourceTypes.COLLECTION,
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def update_collection(
Expand All @@ -400,6 +415,9 @@ def update_collection(
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_name"),
type=AuthzResourceTypes.COLLECTION,
attributes=AuthzDynamicParams.dict_from_function_kwargs(
arg_names=["tenant", "database"]
),
),
)
def delete_collection(
Expand All @@ -418,6 +436,7 @@ def delete_collection(
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"),
type=AuthzResourceTypes.COLLECTION,
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def add(self, collection_id: str, add: AddEmbedding) -> None:
Expand All @@ -439,6 +458,7 @@ def add(self, collection_id: str, add: AddEmbedding) -> None:
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"),
type=AuthzResourceTypes.COLLECTION,
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def update(self, collection_id: str, add: UpdateEmbedding) -> None:
Expand All @@ -456,6 +476,7 @@ def update(self, collection_id: str, add: UpdateEmbedding) -> None:
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"),
type=AuthzResourceTypes.COLLECTION,
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def upsert(self, collection_id: str, upsert: AddEmbedding) -> None:
Expand All @@ -473,6 +494,7 @@ def upsert(self, collection_id: str, upsert: AddEmbedding) -> None:
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"),
type=AuthzResourceTypes.COLLECTION,
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
Expand All @@ -493,6 +515,7 @@ def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"),
type=AuthzResourceTypes.COLLECTION,
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
Expand All @@ -509,6 +532,7 @@ def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"),
type=AuthzResourceTypes.COLLECTION,
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def count(self, collection_id: str) -> int:
Expand All @@ -531,6 +555,7 @@ def reset(self) -> bool:
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"),
type=AuthzResourceTypes.COLLECTION,
attributes=attr_from_collection_lookup(collection_id_arg="collection_id"),
),
)
def get_nearest_neighbors(
Expand Down
8 changes: 4 additions & 4 deletions examples/basic_functionality/authz/authz.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/Users/tazarov\n"
"/Users/tazarov/experiments/chroma-experiments/authz-tenant-db-hook\n"
]
},
{
Expand All @@ -21,7 +21,7 @@
" 'documents': ['test21']}"
]
},
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -87,7 +87,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.11.2"
}
},
"nbformat": 4,
Expand Down

0 comments on commit cdcafc8

Please sign in to comment.