diff --git a/chromadb/auth/__init__.py b/chromadb/auth/__init__.py index 6ae0936d167..b34308c8d9e 100644 --- a/chromadb/auth/__init__.py +++ b/chromadb/auth/__init__.py @@ -374,6 +374,15 @@ def from_function_kwargs(**kwargs: Any) -> Callable[..., str]: lambda **kwargs: kwargs["function_kwargs"][kwargs["arg_name"]], **kwargs ) + @staticmethod + def dict_from_function_kwargs(**kwargs: Any) -> Callable[..., Dict[str, Any]]: + return partial( + lambda **kwargs: { + k: kwargs["function_kwargs"][k] for k in kwargs["arg_names"] + }, + **kwargs, + ) + @dataclass class AuthzAction: diff --git a/chromadb/auth/fastapi.py b/chromadb/auth/fastapi.py index 692d0f25a28..f93611f4765 100644 --- a/chromadb/auth/fastapi.py +++ b/chromadb/auth/fastapi.py @@ -162,6 +162,7 @@ def decorator(f: Callable[..., Any]) -> Callable[..., Any]: @wraps(f) def wrapped(*args: Any, **kwargs: Dict[Any, Any]) -> Any: _dynamic_kwargs = { + "api": args[0]._api, "function": f, "function_args": args, "function_kwargs": kwargs, diff --git a/chromadb/auth/fastapi_utils.py b/chromadb/auth/fastapi_utils.py index e80084702e9..881eee0d3ef 100644 --- a/chromadb/auth/fastapi_utils.py +++ b/chromadb/auth/fastapi_utils.py @@ -1,5 +1,6 @@ from functools import partial -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional, Sequence, cast +from chromadb.api import ServerAPI from chromadb.auth import AuthzResourceTypes @@ -25,7 +26,27 @@ def find_key_with_value_of_type( def attr_from_resource_object( - type: AuthzResourceTypes, **kwargs: Any + type: AuthzResourceTypes, + additional_attrs: Optional[Sequence[str]] = None, + **kwargs: Any, +) -> Callable[..., Dict[str, Any]]: + def _wrap(**wkwargs: Any) -> Dict[str, Any]: + obj = find_key_with_value_of_type(type, **wkwargs) + if additional_attrs: + obj.update({k: wkwargs["function_kwargs"][k] + for k in additional_attrs}) + return obj + + return partial(_wrap, **kwargs) + + +def attr_from_collection_lookup( + collection_id_arg: str, **kwargs: Any ) -> Callable[..., Dict[str, Any]]: - obj = find_key_with_value_of_type(type, **kwargs) - return partial(lambda **kwargs: obj, **kwargs) + def _wrap(**kwargs: Any) -> Dict[str, Any]: + _api = cast(ServerAPI, kwargs["api"]) + col = _api.get_collection( + id=kwargs["function_kwargs"][collection_id_arg]) + return {"tenant": col.tenant, "database": col.database} + + return partial(_wrap, **kwargs) diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index e76d43023a4..28b84e4c380 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -23,7 +23,10 @@ FastAPIChromaAuthzMiddlewareWrapper, authz_context, ) -from chromadb.auth.fastapi_utils import attr_from_resource_object +from chromadb.auth.fastapi_utils import ( + attr_from_collection_lookup, + attr_from_resource_object, +) from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System import chromadb.server import chromadb.api @@ -281,7 +284,9 @@ def version(self) -> str: action=AuthzResourceActions.CREATE_DATABASE, resource=DynamicAuthzResource( type=AuthzResourceTypes.DB, - attributes=attr_from_resource_object(type=AuthzResourceTypes.DB), + attributes=attr_from_resource_object( + type=AuthzResourceTypes.DB, additional_attrs=["tenant"] + ), ), ) def create_database( @@ -295,6 +300,9 @@ def create_database( resource=DynamicAuthzResource( id="*", type=AuthzResourceTypes.DB, + attributes=AuthzDynamicParams.dict_from_function_kwargs( + arg_names=["tenant", "database"] + ), ), ) def get_database(self, database: str, tenant: str = DEFAULT_TENANT) -> Database: @@ -343,6 +351,9 @@ def list_collections( resource=DynamicAuthzResource( id="*", type=AuthzResourceTypes.DB, + attributes=AuthzDynamicParams.dict_from_function_kwargs( + arg_names=["tenant", "database"] + ), ), ) def create_collection( @@ -365,6 +376,9 @@ def create_collection( resource=DynamicAuthzResource( id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_name"), type=AuthzResourceTypes.COLLECTION, + attributes=AuthzDynamicParams.dict_from_function_kwargs( + arg_names=["tenant", "database"] + ), ), ) def get_collection( @@ -383,6 +397,7 @@ def get_collection( resource=DynamicAuthzResource( id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), type=AuthzResourceTypes.COLLECTION, + attributes=attr_from_collection_lookup(collection_id_arg="collection_id"), ), ) def update_collection( @@ -400,6 +415,9 @@ def update_collection( resource=DynamicAuthzResource( id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_name"), type=AuthzResourceTypes.COLLECTION, + attributes=AuthzDynamicParams.dict_from_function_kwargs( + arg_names=["tenant", "database"] + ), ), ) def delete_collection( @@ -418,6 +436,7 @@ def delete_collection( resource=DynamicAuthzResource( id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), type=AuthzResourceTypes.COLLECTION, + attributes=attr_from_collection_lookup(collection_id_arg="collection_id"), ), ) def add(self, collection_id: str, add: AddEmbedding) -> None: @@ -439,6 +458,7 @@ def add(self, collection_id: str, add: AddEmbedding) -> None: resource=DynamicAuthzResource( id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), type=AuthzResourceTypes.COLLECTION, + attributes=attr_from_collection_lookup(collection_id_arg="collection_id"), ), ) def update(self, collection_id: str, add: UpdateEmbedding) -> None: @@ -456,6 +476,7 @@ def update(self, collection_id: str, add: UpdateEmbedding) -> None: resource=DynamicAuthzResource( id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), type=AuthzResourceTypes.COLLECTION, + attributes=attr_from_collection_lookup(collection_id_arg="collection_id"), ), ) def upsert(self, collection_id: str, upsert: AddEmbedding) -> None: @@ -473,6 +494,7 @@ def upsert(self, collection_id: str, upsert: AddEmbedding) -> None: resource=DynamicAuthzResource( id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), type=AuthzResourceTypes.COLLECTION, + attributes=attr_from_collection_lookup(collection_id_arg="collection_id"), ), ) def get(self, collection_id: str, get: GetEmbedding) -> GetResult: @@ -493,6 +515,7 @@ def get(self, collection_id: str, get: GetEmbedding) -> GetResult: resource=DynamicAuthzResource( id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), type=AuthzResourceTypes.COLLECTION, + attributes=attr_from_collection_lookup(collection_id_arg="collection_id"), ), ) def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]: @@ -509,6 +532,7 @@ def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]: resource=DynamicAuthzResource( id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), type=AuthzResourceTypes.COLLECTION, + attributes=attr_from_collection_lookup(collection_id_arg="collection_id"), ), ) def count(self, collection_id: str) -> int: @@ -531,6 +555,7 @@ def reset(self) -> bool: resource=DynamicAuthzResource( id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), type=AuthzResourceTypes.COLLECTION, + attributes=attr_from_collection_lookup(collection_id_arg="collection_id"), ), ) def get_nearest_neighbors( diff --git a/examples/basic_functionality/authz/authz.ipynb b/examples/basic_functionality/authz/authz.ipynb index c70df77c702..97abebd5785 100644 --- a/examples/basic_functionality/authz/authz.ipynb +++ b/examples/basic_functionality/authz/authz.ipynb @@ -2,14 +2,14 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "/Users/tazarov\n" + "/Users/tazarov/experiments/chroma-experiments/authz-tenant-db-hook\n" ] }, { @@ -21,7 +21,7 @@ " 'documents': ['test21']}" ] }, - "execution_count": 2, + "execution_count": 1, "metadata": {}, "output_type": "execute_result" } @@ -87,7 +87,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.11.2" } }, "nbformat": 4,