diff --git a/.github/workflows/chroma-integration-test.yml b/.github/workflows/chroma-integration-test.yml index 4d9c0085d18..2f731028801 100644 --- a/.github/workflows/chroma-integration-test.yml +++ b/.github/workflows/chroma-integration-test.yml @@ -17,9 +17,10 @@ jobs: matrix: python: ['3.8'] platform: [ubuntu-latest, windows-latest] - testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore='chromadb/test/test_cli.py'", + testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore='chromadb/test/test_cli.py' --ignore='chromadb/test/auth/test_simple_rbac_authz.py'", "chromadb/test/property/test_add.py", "chromadb/test/test_cli.py", + "chromadb/test/auth/test_simple_rbac_authz.py", "chromadb/test/property/test_collections.py", "chromadb/test/property/test_cross_version_persist.py", "chromadb/test/property/test_embeddings.py", diff --git a/.github/workflows/chroma-test.yml b/.github/workflows/chroma-test.yml index 4ef9c64ed7b..12a5de4b6ed 100644 --- a/.github/workflows/chroma-test.yml +++ b/.github/workflows/chroma-test.yml @@ -18,7 +18,8 @@ jobs: matrix: python: ['3.8', '3.9', '3.10', '3.11'] platform: [ubuntu-latest, windows-latest] - testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore-glob 'chromadb/test/stress/*'", + testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore-glob 'chromadb/test/stress/*' --ignore='chromadb/test/auth/test_simple_rbac_authz.py'", + "chromadb/test/auth/test_simple_rbac_authz.py", "chromadb/test/property/test_add.py", "chromadb/test/property/test_collections.py", "chromadb/test/property/test_cross_version_persist.py", diff --git a/chromadb/auth/__init__.py b/chromadb/auth/__init__.py index 67f296bb35c..6ae0936d167 100644 --- a/chromadb/auth/__init__.py +++ b/chromadb/auth/__init__.py @@ -2,21 +2,29 @@ Contains only Auth abstractions, no implementations. """ import base64 +from functools import partial import logging from abc import ABC, abstractmethod from enum import Enum from typing import ( + Any, + Callable, + List, Optional, Dict, TypeVar, Tuple, Generic, + Union, ) +from dataclasses import dataclass from overrides import EnforceOverrides, override from pydantic import SecretStr from chromadb.config import ( + DEFAULT_DATABASE, + DEFAULT_TENANT, Component, System, ) @@ -35,13 +43,53 @@ class AuthInfoType(Enum): METADATA = "metadata" # gRPC +class UserIdentity(EnforceOverrides, ABC): + @abstractmethod + def get_user_id(self) -> str: + ... + + @abstractmethod + def get_user_tenant(self) -> Optional[str]: + ... + + @abstractmethod + def get_user_attributes(self) -> Optional[Dict[str, Any]]: + ... + + +class SimpleUserIdentity(UserIdentity): + def __init__( + self, + user_id: str, + tenant: Optional[str] = None, + attributes: Optional[Dict[str, Any]] = None, + ) -> None: + self._user_id = user_id + self._tenant = tenant + self._attributes = attributes + + @override + def get_user_id(self) -> str: + return self._user_id + + @override + def get_user_tenant(self) -> Optional[str]: + return self._tenant if self._tenant else DEFAULT_TENANT + + @override + def get_user_attributes(self) -> Optional[Dict[str, Any]]: + return self._attributes + + class ClientAuthResponse(EnforceOverrides, ABC): @abstractmethod def get_auth_info_type(self) -> AuthInfoType: ... @abstractmethod - def get_auth_info(self) -> Tuple[str, SecretStr]: + def get_auth_info( + self, + ) -> Union[Tuple[str, SecretStr], List[Tuple[str, SecretStr]]]: ... @@ -86,12 +134,11 @@ def inject_credentials(self, injection_context: T) -> None: class ServerAuthenticationRequest(EnforceOverrides, ABC, Generic[T]): @abstractmethod - def get_auth_info( - self, auth_info_type: AuthInfoType, auth_info_id: Optional[str] = None - ) -> T: + def get_auth_info(self, auth_info_type: AuthInfoType, auth_info_id: str) -> T: """ - This method should return the necessary auth info based on the type of authentication (e.g. header, cookie, url) - and a given id for the respective auth type (e.g. name of the header, cookie, url param). + This method should return the necessary auth info based on the type of + authentication (e.g. header, cookie, url) and a given id for the respective + auth type (e.g. name of the header, cookie, url param). :param auth_info_type: The type of auth info to return :param auth_info_id: The id of the auth info to return @@ -101,8 +148,34 @@ def get_auth_info( class ServerAuthenticationResponse(EnforceOverrides, ABC): + @abstractmethod + def success(self) -> bool: + ... + + @abstractmethod + def get_user_identity(self) -> Optional[UserIdentity]: + ... + + +class SimpleServerAuthenticationResponse(ServerAuthenticationResponse): + """Simple implementation of ServerAuthenticationResponse""" + + _auth_success: bool + _user_identity: Optional[UserIdentity] + + def __init__( + self, auth_success: bool, user_identity: Optional[UserIdentity] + ) -> None: + self._auth_success = auth_success + self._user_identity = user_identity + + @override def success(self) -> bool: - raise NotImplementedError() + return self._auth_success + + @override + def get_user_identity(self) -> Optional[UserIdentity]: + return self._user_identity class ServerAuthProvider(Component): @@ -110,7 +183,9 @@ def __init__(self, system: System) -> None: super().__init__(system) @abstractmethod - def authenticate(self, request: ServerAuthenticationRequest[T]) -> bool: + def authenticate( + self, request: ServerAuthenticationRequest[T] + ) -> ServerAuthenticationResponse: pass @@ -121,7 +196,7 @@ def __init__(self, system: System) -> None: @abstractmethod def authenticate( self, request: ServerAuthenticationRequest[T] - ) -> Optional[ServerAuthenticationResponse]: + ) -> ServerAuthenticationResponse: ... @abstractmethod @@ -155,8 +230,8 @@ def name(cls) -> str: class AbstractCredentials(EnforceOverrides, ABC, Generic[T]): """ - The class is used by Auth Providers to encapsulate credentials received from the server - and pass them to a ServerAuthCredentialsProvider. + The class is used by Auth Providers to encapsulate credentials received + from the server and pass them to a ServerAuthCredentialsProvider. """ @abstractmethod @@ -204,4 +279,162 @@ def __init__(self, system: System) -> None: @abstractmethod def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool: + ... + + @abstractmethod + def get_user_identity( + self, credentials: AbstractCredentials[T] + ) -> Optional[UserIdentity]: + ... + + +class AuthzResourceTypes(str, Enum): + DB = "db" + COLLECTION = "collection" + TENANT = "tenant" + + +class AuthzResourceActions(str, Enum): + CREATE_DATABASE = "create_database" + GET_DATABASE = "get_database" + CREATE_TENANT = "create_tenant" + GET_TENANT = "get_tenant" + LIST_COLLECTIONS = "list_collections" + GET_COLLECTION = "get_collection" + CREATE_COLLECTION = "create_collection" + GET_OR_CREATE_COLLECTION = "get_or_create_collection" + DELETE_COLLECTION = "delete_collection" + UPDATE_COLLECTION = "update_collection" + ADD = "add" + DELETE = "delete" + GET = "get" + QUERY = "query" + COUNT = "count" + UPDATE = "update" + UPSERT = "upsert" + RESET = "reset" + + +@dataclass +class AuthzUser: + id: Optional[str] + tenant: Optional[str] = DEFAULT_TENANT + attributes: Optional[Dict[str, Any]] = None + claims: Optional[Dict[str, Any]] = None + + +@dataclass +class AuthzResource: + id: Optional[str] + type: Optional[str] + attributes: Optional[Dict[str, Any]] = None + + +class DynamicAuthzResource: + id: Optional[Union[str, Callable[..., str]]] + type: Optional[Union[str, Callable[..., str]]] + attributes: Optional[Union[Dict[str, Any], Callable[..., Dict[str, Any]]]] + + def __init__( + self, + id: Optional[Union[str, Callable[..., str]]] = None, + attributes: Optional[ + Union[Dict[str, Any], Callable[..., Dict[str, Any]]] + ] = lambda **kwargs: {}, + type: Optional[Union[str, Callable[..., str]]] = DEFAULT_DATABASE, + ) -> None: + self.id = id + self.attributes = attributes + self.type = type + + def to_authz_resource(self, **kwargs: Any) -> AuthzResource: + return AuthzResource( + id=self.id(**kwargs) if callable(self.id) else self.id, + type=self.type(**kwargs) if callable(self.type) else self.type, + attributes=self.attributes(**kwargs) + if callable(self.attributes) + else self.attributes, + ) + + +class AuthzDynamicParams: + @staticmethod + def from_function_name(**kwargs: Any) -> Callable[..., str]: + return partial(lambda **kwargs: kwargs["function"].__name__, **kwargs) + + @staticmethod + def from_function_args(**kwargs: Any) -> Callable[..., str]: + return partial( + lambda **kwargs: kwargs["function_args"][kwargs["arg_num"]], **kwargs + ) + + @staticmethod + def from_function_kwargs(**kwargs: Any) -> Callable[..., str]: + return partial( + lambda **kwargs: kwargs["function_kwargs"][kwargs["arg_name"]], **kwargs + ) + + +@dataclass +class AuthzAction: + id: str + attributes: Optional[Dict[str, Any]] = None + + +@dataclass +class AuthorizationContext: + user: AuthzUser + resource: AuthzResource + action: AuthzAction + + +class ServerAuthorizationProvider(Component): + def __init__(self, system: System) -> None: + super().__init__(system) + + @abstractmethod + def authorize(self, context: AuthorizationContext) -> bool: pass + + +class AuthorizationRequestContext(EnforceOverrides, ABC, Generic[T]): + @abstractmethod + def get_request(self) -> T: + ... + + +class ChromaAuthzMiddleware(Component, Generic[T, S]): + def __init__(self, system: System) -> None: + super().__init__(system) + + @abstractmethod + def pre_process(self, request: AuthorizationRequestContext[S]) -> None: + ... + + @abstractmethod + def ignore_operation(self, verb: str, path: str) -> bool: + ... + + @abstractmethod + def instrument_server(self, app: T) -> None: + ... + + +class ServerAuthorizationConfigurationProvider(Component, Generic[T]): + def __init__(self, system: System) -> None: + super().__init__(system) + + @abstractmethod + def get_configuration(self) -> T: + pass + + +class AuthorizationError(ChromaError): + @override + def code(self) -> int: + return 403 + + @classmethod + @override + def name(cls) -> str: + return "AuthorizationError" diff --git a/chromadb/auth/authz/__init__.py b/chromadb/auth/authz/__init__.py new file mode 100644 index 00000000000..0cb350e8ca4 --- /dev/null +++ b/chromadb/auth/authz/__init__.py @@ -0,0 +1,110 @@ +import logging +from typing import Any, Dict, Set, cast +from overrides import override +import yaml +from chromadb.auth import ( + AuthorizationContext, + ServerAuthorizationConfigurationProvider, + ServerAuthorizationProvider, +) +from chromadb.auth.registry import register_provider, resolve_provider +from chromadb.config import DEFAULT_TENANT, System + +from chromadb.telemetry.opentelemetry import ( + OpenTelemetryGranularity, + trace_method, +) + +logger = logging.getLogger(__name__) + + +@register_provider("local_authz_config") +class LocalUserConfigAuthorizationConfigurationProvider( + ServerAuthorizationConfigurationProvider[Dict[str, Any]] +): + _config_file: str + _config: Dict[str, Any] + + def __init__(self, system: System) -> None: + super().__init__(system) + self._settings = system.settings + if self._settings.chroma_server_authz_config_file: + self._config_file = str(system.settings.chroma_server_authz_config_file) + with open(self._config_file, "r") as f: + self._config = yaml.safe_load(f) + elif self._settings.chroma_server_authz_config: + self._config = self._settings.chroma_server_authz_config + else: + raise ValueError( + "No configuration (CHROMA_SERVER_AUTHZ_CONFIG_FILE) file or " + "configuration (CHROMA_SERVER_AUTHZ_CONFIG) provided for " + "LocalUserConfigAuthorizationConfigurationProvider" + ) + + @override + def get_configuration(self) -> Dict[str, Any]: + return self._config + + +@register_provider("simple_rbac") +class SimpleRBACAuthorizationProvider(ServerAuthorizationProvider): + _authz_config_provider: ServerAuthorizationConfigurationProvider[Dict[str, Any]] + + def __init__(self, system: System) -> None: + super().__init__(system) + self._settings = system.settings + system.settings.require("chroma_server_authz_config_provider") + if self._settings.chroma_server_authz_config_provider: + _cls = resolve_provider( + self._settings.chroma_server_authz_config_provider, + ServerAuthorizationConfigurationProvider, + ) + self._authz_config_provider = cast( + ServerAuthorizationConfigurationProvider[Dict[str, Any]], + self.require(_cls), + ) + _config = self._authz_config_provider.get_configuration() + self._authz_tuples_map: Dict[str, Set[Any]] = {} + for u in _config["users"]: + _actions = _config["roles_mapping"][u["role"]]["actions"] + for a in _actions: + tenant = u["tenant"] if "tenant" in u else DEFAULT_TENANT + if u["id"] not in self._authz_tuples_map.keys(): + self._authz_tuples_map[u["id"]] = set() + self._authz_tuples_map[u["id"]].add( + (u["id"], tenant, *a.split(":")) + ) + logger.debug( + f"Loaded {len(self._authz_tuples_map)} permissions for " + f"({len(_config['users'])}) users" + ) + logger.info( + "Authorization Provider SimpleRBACAuthorizationProvider initialized" + ) + + @trace_method( + "SimpleRBACAuthorizationProvider.authorize", + OpenTelemetryGranularity.ALL, + ) + @override + def authorize(self, context: AuthorizationContext) -> bool: + _authz_tuple = ( + context.user.id, + context.user.tenant, + context.resource.type, + context.action.id, + ) + + policy_decision = False + if ( + context.user.id in self._authz_tuples_map.keys() + and _authz_tuple in self._authz_tuples_map[context.user.id] + ): + policy_decision = True + logger.debug( + f"Authorization decision: Access " + f"{'granted' if policy_decision else 'denied'} for " + f"user [{context.user.id}] attempting to [{context.action.id}]" + f" on [{context.resource}]" + ) + return policy_decision diff --git a/chromadb/auth/basic/__init__.py b/chromadb/auth/basic/__init__.py index a9888598a22..d561895faf5 100644 --- a/chromadb/auth/basic/__init__.py +++ b/chromadb/auth/basic/__init__.py @@ -14,11 +14,11 @@ BasicAuthCredentials, ClientAuthCredentialsProvider, ClientAuthResponse, + SimpleServerAuthenticationResponse, ) from chromadb.auth.registry import register_provider, resolve_provider from chromadb.config import System from chromadb.telemetry.opentelemetry import ( - OpenTelemetryClient, OpenTelemetryGranularity, trace_method, ) @@ -91,12 +91,20 @@ def __init__(self, system: System) -> None: @trace_method("BasicAuthServerProvider.authenticate", OpenTelemetryGranularity.ALL) @override - def authenticate(self, request: ServerAuthenticationRequest[Any]) -> bool: + def authenticate( + self, request: ServerAuthenticationRequest[Any] + ) -> SimpleServerAuthenticationResponse: try: _auth_header = request.get_auth_info(AuthInfoType.HEADER, "Authorization") - return self._credentials_provider.validate_credentials( + _validation = self._credentials_provider.validate_credentials( BasicAuthCredentials.from_header(_auth_header) ) + return SimpleServerAuthenticationResponse( + _validation, + self._credentials_provider.get_user_identity( + BasicAuthCredentials.from_header(_auth_header) + ), + ) except Exception as e: logger.error(f"BasicAuthServerProvider.authenticate failed: {repr(e)}") - return False + return SimpleServerAuthenticationResponse(False, None) diff --git a/chromadb/auth/fastapi.py b/chromadb/auth/fastapi.py index 14b531e48e8..17cf4c3d0e5 100644 --- a/chromadb/auth/fastapi.py +++ b/chromadb/auth/fastapi.py @@ -1,24 +1,34 @@ # FAST API code +from contextvars import ContextVar +from functools import wraps import logging -from typing import Optional, Dict, List, cast, Any - +from typing import Callable, Optional, Dict, List, Union, cast, Any from overrides import override from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import Response, JSONResponse from starlette.types import ASGIApp -from chromadb.config import System +from chromadb.config import DEFAULT_TENANT, System from chromadb.auth import ( + AuthorizationContext, + AuthorizationError, + AuthorizationRequestContext, + AuthzAction, + AuthzResource, + AuthzResourceActions, + AuthzUser, + DynamicAuthzResource, ServerAuthenticationRequest, AuthInfoType, ServerAuthenticationResponse, ServerAuthProvider, ChromaAuthMiddleware, + ChromaAuthzMiddleware, + ServerAuthorizationProvider, ) from chromadb.auth.registry import resolve_provider from chromadb.telemetry.opentelemetry import ( - OpenTelemetryClient, OpenTelemetryGranularity, trace_method, ) @@ -32,7 +42,7 @@ def __init__(self, request: Request) -> None: @override def get_auth_info( - self, auth_info_type: AuthInfoType, auth_info_id: Optional[str] = None + self, auth_info_type: AuthInfoType, auth_info_id: str ) -> Optional[str]: if auth_info_type == AuthInfoType.HEADER: return str(self._request.headers[auth_info_id]) @@ -83,11 +93,12 @@ def __init__(self, system: System) -> None: @override def authenticate( self, request: ServerAuthenticationRequest[Any] - ) -> Optional[ServerAuthenticationResponse]: - return FastAPIServerAuthenticationResponse( - self._auth_provider.authenticate(request) - ) + ) -> ServerAuthenticationResponse: + return self._auth_provider.authenticate(request) + @trace_method( + "FastAPIChromaAuthMiddleware.ignore_operation", OpenTelemetryGranularity.ALL + ) @override def ignore_operation(self, verb: str, path: str) -> bool: if ( @@ -100,8 +111,9 @@ def ignore_operation(self, verb: str, path: str) -> bool: @override def instrument_server(self, app: ASGIApp) -> None: - # We can potentially add an `/auth` endpoint to the server to allow for more complex auth flows - return + # We can potentially add an `/auth` endpoint to the server to allow for more + # complex auth flows + raise NotImplementedError("Not implemented yet") class FastAPIChromaAuthMiddlewareWrapper(BaseHTTPMiddleware): # type: ignore @@ -110,8 +122,14 @@ def __init__( ) -> None: super().__init__(app) self._middleware = auth_middleware - self._middleware.instrument_server(app) + try: + self._middleware.instrument_server(app) + except NotImplementedError: + pass + @trace_method( + "FastAPIChromaAuthMiddlewareWrapper.dispatch", OpenTelemetryGranularity.ALL + ) @override async def dispatch( self, request: Request, call_next: RequestResponseEndpoint @@ -126,4 +144,147 @@ async def dispatch( ) if not response or not response.success(): return JSONResponse({"error": "Unauthorized"}, status_code=401) + request.state.user_identity = response.get_user_identity() + return await call_next(request) + + +request_var: ContextVar[Optional[Request]] = ContextVar("request_var", default=None) +authz_provider: ContextVar[Optional[ServerAuthorizationProvider]] = ContextVar( + "authz_provider", default=None +) + + +def authz_context( + action: Union[str, AuthzResourceActions, List[str], List[AuthzResourceActions]], + resource: Union[AuthzResource, DynamicAuthzResource], +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(f: Callable[..., Any]) -> Callable[..., Any]: + @wraps(f) + def wrapped(*args: Any, **kwargs: Dict[Any, Any]) -> Any: + _dynamic_kwargs = { + "function": f, + "function_args": args, + "function_kwargs": kwargs, + } + request = request_var.get() + if request: + _provider = authz_provider.get() + a_list: List[Union[str, AuthzAction]] = [] + if not isinstance(action, List): + a_list = [action] + else: + a_list = cast(List[Union[str, AuthzAction]], action) + a_authz_responses = [] + for a in a_list: + _action = a if isinstance(a, AuthzAction) else AuthzAction(id=a) + _resource = ( + resource + if isinstance(resource, AuthzResource) + else resource.to_authz_resource(**_dynamic_kwargs) + ) + _context = AuthorizationContext( + user=AuthzUser( + id=request.state.user_identity.get_user_id() + if hasattr(request.state, "user_identity") + else "Anonymous", + tenant=request.state.user_identity.get_user_tenant() + if hasattr(request.state, "user_identity") + else DEFAULT_TENANT, + ), + resource=_resource, + action=_action, + ) + + if _provider: + a_authz_responses.append(_provider.authorize(_context)) + if not any(a_authz_responses): + raise AuthorizationError("Unauthorized") + return f(*args, **kwargs) + + return wrapped + + return decorator + + +class FastAPIAuthorizationRequestContext(AuthorizationRequestContext[Request]): + _request: Request + + def __init__(self, request: Request) -> None: + self._request = request + pass + + @override + def get_request(self) -> Request: + return self._request + + +class FastAPIChromaAuthzMiddleware(ChromaAuthzMiddleware[ASGIApp, Request]): + _authz_provider: ServerAuthorizationProvider + + def __init__(self, system: System) -> None: + super().__init__(system) + self._system = system + self._settings = system.settings + self._settings.require("chroma_server_authz_provider") + self._ignore_auth_paths: Dict[ + str, List[str] + ] = self._settings.chroma_server_authz_ignore_paths + if self._settings.chroma_server_authz_provider: + logger.debug( + "Server Authorization Provider: " + f"{self._settings.chroma_server_authz_provider}" + ) + _cls = resolve_provider( + self._settings.chroma_server_authz_provider, ServerAuthorizationProvider + ) + self._authz_provider = cast(ServerAuthorizationProvider, self.require(_cls)) + + @override + def pre_process(self, request: AuthorizationRequestContext[Request]) -> None: + rest_request = request.get_request() + request_var.set(rest_request) + authz_provider.set(self._authz_provider) + + @override + def ignore_operation(self, verb: str, path: str) -> bool: + if ( + path in self._ignore_auth_paths.keys() + and verb.upper() in self._ignore_auth_paths[path] + ): + logger.debug(f"Skipping authz for path {path} and method {verb}") + return True + return False + + @override + def instrument_server(self, app: ASGIApp) -> None: + # We can potentially add an `/auth` endpoint to the server to allow + # for more complex auth flows + raise NotImplementedError("Not implemented yet") + + +class FastAPIChromaAuthzMiddlewareWrapper(BaseHTTPMiddleware): # type: ignore + def __init__( + self, app: ASGIApp, authz_middleware: FastAPIChromaAuthzMiddleware + ) -> None: + super().__init__(app) + self._middleware = authz_middleware + try: + self._middleware.instrument_server(app) + except NotImplementedError: + pass + + @trace_method( + "FastAPIChromaAuthzMiddlewareWrapper.dispatch", OpenTelemetryGranularity.ALL + ) + @override + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + if self._middleware.ignore_operation(request.method, request.url.path): + logger.debug( + f"Skipping authz for path {request.url.path} " + "and method {request.method}" + ) + return await call_next(request) + self._middleware.pre_process(FastAPIAuthorizationRequestContext(request)) return await call_next(request) diff --git a/chromadb/auth/fastapi_utils.py b/chromadb/auth/fastapi_utils.py new file mode 100644 index 00000000000..e80084702e9 --- /dev/null +++ b/chromadb/auth/fastapi_utils.py @@ -0,0 +1,31 @@ +from functools import partial +from typing import Any, Callable, Dict +from chromadb.auth import AuthzResourceTypes + + +def find_key_with_value_of_type( + type: AuthzResourceTypes, **kwargs: Any +) -> Dict[str, Any]: + from chromadb.server.fastapi.types import ( + CreateCollection, + CreateDatabase, + CreateTenant, + ) + + for key, value in kwargs.items(): + if type == AuthzResourceTypes.DB and isinstance(value, CreateDatabase): + return dict(value) + elif type == AuthzResourceTypes.COLLECTION and isinstance( + value, CreateCollection + ): + return dict(value) + elif type == AuthzResourceTypes.TENANT and isinstance(value, CreateTenant): + return dict(value) + return {} + + +def attr_from_resource_object( + type: AuthzResourceTypes, **kwargs: Any +) -> Callable[..., Dict[str, Any]]: + obj = find_key_with_value_of_type(type, **kwargs) + return partial(lambda **kwargs: obj, **kwargs) diff --git a/chromadb/auth/providers.py b/chromadb/auth/providers.py index 2982b9e15a6..8eb3f4697cb 100644 --- a/chromadb/auth/providers.py +++ b/chromadb/auth/providers.py @@ -1,6 +1,6 @@ import importlib import logging -from typing import cast, Dict, TypeVar, Any +from typing import Optional, cast, Dict, TypeVar, Any import requests from overrides import override @@ -12,11 +12,11 @@ AuthInfoType, ClientAuthProvider, ClientAuthProtocolAdapter, + SimpleUserIdentity, ) from chromadb.auth.registry import register_provider, resolve_provider from chromadb.config import System from chromadb.telemetry.opentelemetry import ( - OpenTelemetryClient, OpenTelemetryGranularity, trace_method, ) @@ -36,7 +36,8 @@ def __init__(self, system: System) -> None: self.bc = importlib.import_module("bcrypt") except ImportError: raise ValueError( - "The bcrypt python package is not installed. Please install it with `pip install bcrypt`" + "The bcrypt python package is not installed. " + "Please install it with `pip install bcrypt`" ) @trace_method( @@ -48,11 +49,13 @@ def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool: _creds = cast(Dict[str, SecretStr], credentials.get_credentials()) if len(_creds) != 2: logger.error( - "Returned credentials did match expected format: dict[username:SecretStr, password: SecretStr]" + "Returned credentials did match expected format: " + "dict[username:SecretStr, password: SecretStr]" ) return False if "username" not in _creds or "password" not in _creds: - logger.error("Returned credentials do not contain username or password") + logger.error( + "Returned credentials do not contain username or password") return False _usr_check = bool( _creds["username"].get_secret_value() @@ -63,6 +66,13 @@ def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool: self._creds["password"].get_secret_value().encode("utf-8"), ) + @override + def get_user_identity( + self, credentials: AbstractCredentials[T] + ) -> Optional[SimpleUserIdentity]: + _creds = cast(Dict[str, SecretStr], credentials.get_credentials()) + return SimpleUserIdentity(_creds["username"].get_secret_value()) + @register_provider("htpasswd_file") class HtpasswdFileServerAuthCredentialsProvider(HtpasswdServerAuthCredentialsProvider): @@ -70,7 +80,7 @@ def __init__(self, system: System) -> None: super().__init__(system) system.settings.require("chroma_server_auth_credentials_file") _file = str(system.settings.chroma_server_auth_credentials_file) - with open(_file) as f: + with open(_file, "r") as f: _raw_creds = [v for v in f.readline().strip().split(":")] self._creds = { "username": SecretStr(_raw_creds[0]), @@ -82,7 +92,8 @@ def __init__(self, system: System) -> None: or "password" not in self._creds ): raise ValueError( - "Invalid Htpasswd credentials found in [chroma_server_auth_credentials]. " + "Invalid Htpasswd credentials found in " + "[chroma_server_auth_credentials]. " "Must be :." ) @@ -106,7 +117,8 @@ def __init__(self, system: System) -> None: or "password" not in self._creds ): raise ValueError( - "Invalid Htpasswd credentials found in [chroma_server_auth_credentials]. " + "Invalid Htpasswd credentials found in " + "[chroma_server_auth_credentials]. " "Must be :." ) @@ -155,9 +167,14 @@ def session(self) -> requests.Session: def inject_credentials(self, injection_context: requests.PreparedRequest) -> None: if self._auth_header.get_auth_info_type() == AuthInfoType.HEADER: _header_info = self._auth_header.get_auth_info() - injection_context.headers[_header_info[0]] = _header_info[ - 1 - ].get_secret_value() + if isinstance(_header_info, tuple): + injection_context.headers[_header_info[0]] = _header_info[ + 1 + ].get_secret_value() + else: + for header in _header_info: + injection_context.headers[header[0] + ] = header[1].get_secret_value() else: raise ValueError( f"Unsupported auth type: {self._auth_header.get_auth_info_type()}" @@ -172,7 +189,8 @@ class ConfigurationClientAuthCredentialsProvider( def __init__(self, system: System) -> None: super().__init__(system) system.settings.require("chroma_client_auth_credentials") - self._creds = SecretStr(str(system.settings.chroma_client_auth_credentials)) + self._creds = SecretStr( + str(system.settings.chroma_client_auth_credentials)) @override def get_credentials(self) -> SecretStr: diff --git a/chromadb/auth/registry.py b/chromadb/auth/registry.py index 7e844753199..af0f0f903e6 100644 --- a/chromadb/auth/registry.py +++ b/chromadb/auth/registry.py @@ -11,6 +11,8 @@ ServerAuthConfigurationProvider, ServerAuthCredentialsProvider, ClientAuthProvider, + ServerAuthorizationConfigurationProvider, + ServerAuthorizationProvider, ) from chromadb.utils import get_class @@ -23,6 +25,8 @@ "ServerAuthConfigurationProvider", "ServerAuthCredentialsProvider", "ClientAuthProtocolAdapter", + "ServerAuthorizationProvider", + "ServerAuthorizationConfigurationProvider", ] _provider_registry = { @@ -33,6 +37,8 @@ "server_auth_providers": {}, "server_auth_config_providers": {}, "server_auth_credentials_providers": {}, + "server_authz_providers": {}, + "server_authz_config_providers": {}, } # type: Dict[str, Dict[str, Type[ProviderTypes]]] @@ -63,12 +69,17 @@ def decorator(cls: Type[ProviderTypes]) -> Type[ProviderTypes]: _provider_registry["server_auth_config_providers"][short_hand] = cls elif issubclass(cls, ServerAuthCredentialsProvider): _provider_registry["server_auth_credentials_providers"][short_hand] = cls + elif issubclass(cls, ServerAuthorizationProvider): + _provider_registry["server_authz_providers"][short_hand] = cls + elif issubclass(cls, ServerAuthorizationConfigurationProvider): + _provider_registry["server_authz_config_providers"][short_hand] = cls else: raise ValueError( "Only ClientAuthProvider, ClientAuthConfigurationProvider, " "ClientAuthCredentialsProvider, ServerAuthProvider, " - "ServerAuthConfigurationProvider, and ServerAuthCredentialsProvider, ClientAuthProtocolAdapter " - "can be registered." + "ServerAuthConfigurationProvider, and ServerAuthCredentialsProvider, " + "ClientAuthProtocolAdapter, ServerAuthorizationProvider, " + "ServerAuthorizationConfigurationProvider can be registered." ) return cls @@ -94,12 +105,17 @@ def resolve_provider( _key = "server_auth_config_providers" elif issubclass(cls, ServerAuthCredentialsProvider): _key = "server_auth_credentials_providers" + elif issubclass(cls, ServerAuthorizationProvider): + _key = "server_authz_providers" + elif issubclass(cls, ServerAuthorizationConfigurationProvider): + _key = "server_authz_config_providers" else: raise ValueError( "Only ClientAuthProvider, ClientAuthConfigurationProvider, " "ClientAuthCredentialsProvider, ServerAuthProvider, " - "ServerAuthConfigurationProvider, and ServerAuthCredentialsProvider,ClientAuthProtocolAdapter " - "can be registered." + "ServerAuthConfigurationProvider, and ServerAuthCredentialsProvider, " + "ClientAuthProtocolAdapter, ServerAuthorizationProvider," + "ServerAuthorizationConfigurationProvider, can be registered." ) if class_or_name in _provider_registry[_key]: return _provider_registry[_key][class_or_name] diff --git a/chromadb/auth/token/__init__.py b/chromadb/auth/token/__init__.py index 6dfa8635942..3280cc2a4e8 100644 --- a/chromadb/auth/token/__init__.py +++ b/chromadb/auth/token/__init__.py @@ -1,10 +1,12 @@ +import json import logging import string from enum import Enum -from typing import Tuple, Any, cast, Dict, TypeVar +from typing import List, Optional, Tuple, Any, TypedDict, cast, Dict, TypeVar from overrides import override from pydantic import SecretStr +import yaml from chromadb.auth import ( ServerAuthProvider, @@ -16,11 +18,12 @@ ClientAuthResponse, SecretStrAbstractCredentials, AbstractCredentials, + SimpleServerAuthenticationResponse, + SimpleUserIdentity, ) from chromadb.auth.registry import register_provider, resolve_provider from chromadb.config import System from chromadb.telemetry.opentelemetry import ( - OpenTelemetryClient, OpenTelemetryGranularity, trace_method, ) @@ -103,6 +106,81 @@ def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool: return False return _creds["token"].get_secret_value() == self._token.get_secret_value() + @override + def get_user_identity( + self, credentials: AbstractCredentials[T] + ) -> Optional[SimpleUserIdentity]: + return None + + +class Token(TypedDict): + token: str + secret: str + + +class User(TypedDict): + id: str + role: str + tenant: Optional[str] + tokens: List[Token] + + +@register_provider("user_token_config") +class UserTokenConfigServerAuthCredentialsProvider(ServerAuthCredentialsProvider): + _users: List[User] + _token_user_mapping: Dict[str, str] # reverse mapping of token to user + + def __init__(self, system: System) -> None: + super().__init__(system) + if system.settings.chroma_server_auth_credentials_file: + system.settings.require("chroma_server_auth_credentials_file") + user_file = str(system.settings.chroma_server_auth_credentials_file) + with open(user_file) as f: + self._users = cast(List[User], yaml.safe_load(f)["users"]) + elif system.settings.chroma_server_auth_credentials: + self._users = cast( + List[User], json.loads(system.settings.chroma_server_auth_credentials) + ) + self._token_user_mapping = {} + for user in self._users: + for t in user["tokens"]: + token_str = t["token"] + check_token(token_str) + if token_str in self._token_user_mapping: + raise ValueError("Token already exists for another user") + self._token_user_mapping[token_str] = user["id"] + + def find_user_by_id(self, _user_id: str) -> Optional[User]: + for user in self._users: + if user["id"] == _user_id: + return user + return None + + @override + def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool: + _creds = cast(Dict[str, SecretStr], credentials.get_credentials()) + if "token" not in _creds: + logger.error("Returned credentials do not contain token") + return False + return _creds["token"].get_secret_value() in self._token_user_mapping.keys() + + @override + def get_user_identity( + self, credentials: AbstractCredentials[T] + ) -> Optional[SimpleUserIdentity]: + _creds = cast(Dict[str, SecretStr], credentials.get_credentials()) + if "token" not in _creds: + logger.error("Returned credentials do not contain token") + return None + # below is just simple identity mapping and may need future work for more + # complex use cases + _user_id = self._token_user_mapping[_creds["token"].get_secret_value()] + _user = self.find_user_by_id(_user_id) + return SimpleUserIdentity( + user_id=_user_id, + tenant=_user["tenant"] if _user and "tenant" in _user else "*", + ) + class TokenAuthCredentials(SecretStrAbstractCredentials): _token: SecretStr @@ -161,19 +239,23 @@ def __init__(self, system: System) -> None: @trace_method("TokenAuthServerProvider.authenticate", OpenTelemetryGranularity.ALL) @override - def authenticate(self, request: ServerAuthenticationRequest[Any]) -> bool: + def authenticate( + self, request: ServerAuthenticationRequest[Any] + ) -> SimpleServerAuthenticationResponse: try: _auth_header = request.get_auth_info( AuthInfoType.HEADER, self._token_transport_header.value ) - return self._credentials_provider.validate_credentials( - TokenAuthCredentials.from_header( - _auth_header, self._token_transport_header - ) + _token_creds = TokenAuthCredentials.from_header( + _auth_header, self._token_transport_header + ) + return SimpleServerAuthenticationResponse( + self._credentials_provider.validate_credentials(_token_creds), + self._credentials_provider.get_user_identity(_token_creds), ) except Exception as e: logger.error(f"TokenAuthServerProvider.authenticate failed: {repr(e)}") - return False + return SimpleServerAuthenticationResponse(False, None) @register_provider("token") diff --git a/chromadb/config.py b/chromadb/config.py index 993b2a33d03..767cac4a3e4 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -63,7 +63,8 @@ # TODO: Don't use concrete types here to avoid circular deps. Strings are fine for right here! _abstract_type_keys: Dict[str, str] = { - "chromadb.api.API": "chroma_api_impl", # NOTE: this is to support legacy api construction. Use ServerAPI instead + # NOTE: this is to support legacy api construction. Use ServerAPI instead + "chromadb.api.API": "chroma_api_impl", "chromadb.api.ServerAPI": "chroma_api_impl", "chromadb.telemetry.product.ProductTelemetryClient": "chroma_product_telemetry_impl", "chromadb.ingest.Producer": "chroma_producer_impl", @@ -85,8 +86,8 @@ class Settings(BaseSettings): # type: ignore # Legacy config has to be kept around because pydantic will error # on nonexisting keys chroma_db_impl: Optional[str] = None - - chroma_api_impl: str = "chromadb.api.segment.SegmentAPI" # Can be "chromadb.api.segment.SegmentAPI" or "chromadb.api.fastapi.FastAPI" + # Can be "chromadb.api.segment.SegmentAPI" or "chromadb.api.fastapi.FastAPI" + chroma_api_impl: str = "chromadb.api.segment.SegmentAPI" chroma_product_telemetry_impl: str = "chromadb.telemetry.product.posthog.Posthog" # Required for backwards compatibility chroma_telemetry_impl: str = chroma_product_telemetry_impl @@ -120,7 +121,8 @@ class Settings(BaseSettings): # type: ignore chroma_server_ssl_enabled: Optional[bool] = False chroma_server_api_default_path: Optional[str] = "/api/v1" chroma_server_grpc_port: Optional[str] = None - chroma_server_cors_allow_origins: List[str] = [] # eg ["http://localhost:3000"] + # eg ["http://localhost:3000"] + chroma_server_cors_allow_origins: List[str] = [] pulsar_broker_url: Optional[str] = None pulsar_admin_port: Optional[str] = "8080" @@ -178,6 +180,35 @@ def chroma_server_auth_credentials_file_non_empty_file_exists( chroma_client_auth_token_transport_header: Optional[str] = None chroma_server_auth_token_transport_header: Optional[str] = None + chroma_server_authz_provider: Optional[str] = None + + chroma_server_authz_ignore_paths: Dict[str, List[str]] = { + "/api/v1": ["GET"], + "/api/v1/heartbeat": ["GET"], + "/api/v1/version": ["GET"], + } + chroma_server_authz_config_file: Optional[str] = None + + chroma_server_authz_config: Optional[Dict[str, Any]] = None + + @validator( + "chroma_server_authz_config_file", pre=True, always=True, allow_reuse=True + ) + def chroma_server_authz_config_file_non_empty_file_exists( + cls: Type["Settings"], v: str # type: ignore + ) -> Optional[str]: + if v and not v.strip(): + raise ValueError( + "chroma_server_authz_config_file cannot be empty or just whitespace" + ) + if v and not os.path.isfile(os.path.join(v)): + raise ValueError(f"chroma_server_authz_config_file [{v}] does not exist") + return v + + chroma_server_authz_config_provider: Optional[ + str + ] = "chromadb.auth.authz.LocalUserConfigAuthorizationConfigurationProvider" + anonymized_telemetry: bool = True chroma_otel_collection_endpoint: Optional[str] = "" diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index b66bf33bda6..e76d43023a4 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -7,14 +7,23 @@ from fastapi.routing import APIRoute from fastapi import HTTPException, status from uuid import UUID - import chromadb from chromadb.api.models.Collection import Collection from chromadb.api.types import GetResult, QueryResult +from chromadb.auth import ( + AuthzDynamicParams, + AuthzResourceActions, + AuthzResourceTypes, + DynamicAuthzResource, +) from chromadb.auth.fastapi import ( FastAPIChromaAuthMiddleware, FastAPIChromaAuthMiddlewareWrapper, + FastAPIChromaAuthzMiddleware, + FastAPIChromaAuthzMiddlewareWrapper, + authz_context, ) +from chromadb.auth.fastapi_utils import attr_from_resource_object from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System import chromadb.server import chromadb.api @@ -125,11 +134,17 @@ def __init__(self, settings: Settings): allow_origins=settings.chroma_server_cors_allow_origins, allow_methods=["*"], ) + + if settings.chroma_server_authz_provider: + self._app.add_middleware( + FastAPIChromaAuthzMiddlewareWrapper, + authz_middleware=self._api.require(FastAPIChromaAuthzMiddleware), + ) + if settings.chroma_server_auth_provider: - self._auth_middleware = self._api.require(FastAPIChromaAuthMiddleware) self._app.add_middleware( FastAPIChromaAuthMiddlewareWrapper, - auth_middleware=self._auth_middleware, + auth_middleware=self._api.require(FastAPIChromaAuthMiddleware), ) self.router = ChromaAPIRouter() @@ -262,30 +277,74 @@ def version(self) -> str: return self._api.get_version() @trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.CREATE_DATABASE, + resource=DynamicAuthzResource( + type=AuthzResourceTypes.DB, + attributes=attr_from_resource_object(type=AuthzResourceTypes.DB), + ), + ) def create_database( self, database: CreateDatabase, tenant: str = DEFAULT_TENANT ) -> None: return self._api.create_database(database.name, tenant) @trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.GET_DATABASE, + resource=DynamicAuthzResource( + id="*", + type=AuthzResourceTypes.DB, + ), + ) def get_database(self, database: str, tenant: str = DEFAULT_TENANT) -> Database: return self._api.get_database(database, tenant) @trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.CREATE_TENANT, + resource=DynamicAuthzResource( + type=AuthzResourceTypes.TENANT, + ), + ) def create_tenant(self, tenant: CreateTenant) -> None: return self._api.create_tenant(tenant.name) @trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.GET_TENANT, + resource=DynamicAuthzResource( + id="*", + type=AuthzResourceTypes.TENANT, + ), + ) def get_tenant(self, tenant: str) -> Tenant: return self._api.get_tenant(tenant) @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.LIST_COLLECTIONS, + resource=DynamicAuthzResource( + id="*", + type=AuthzResourceTypes.DB, + ), + ) def list_collections( self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE ) -> Sequence[Collection]: return self._api.list_collections(tenant=tenant, database=database) @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=[ + AuthzResourceActions.CREATE_COLLECTION, + AuthzResourceActions.GET_OR_CREATE_COLLECTION, + ], + resource=DynamicAuthzResource( + id="*", + type=AuthzResourceTypes.DB, + ), + ) def create_collection( self, collection: CreateCollection, @@ -301,6 +360,13 @@ def create_collection( ) @trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.GET_COLLECTION, + resource=DynamicAuthzResource( + id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_name"), + type=AuthzResourceTypes.COLLECTION, + ), + ) def get_collection( self, collection_name: str, @@ -312,6 +378,13 @@ def get_collection( ) @trace_method("FastAPI.update_collection", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.UPDATE_COLLECTION, + resource=DynamicAuthzResource( + id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), + type=AuthzResourceTypes.COLLECTION, + ), + ) def update_collection( self, collection_id: str, collection: UpdateCollection ) -> None: @@ -322,6 +395,13 @@ def update_collection( ) @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.DELETE_COLLECTION, + resource=DynamicAuthzResource( + id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_name"), + type=AuthzResourceTypes.COLLECTION, + ), + ) def delete_collection( self, collection_name: str, @@ -333,40 +413,68 @@ def delete_collection( ) @trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.ADD, + resource=DynamicAuthzResource( + id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), + type=AuthzResourceTypes.COLLECTION, + ), + ) def add(self, collection_id: str, add: AddEmbedding) -> None: try: result = self._api._add( collection_id=_uuid(collection_id), - embeddings=add.embeddings, - metadatas=add.metadatas, - documents=add.documents, + embeddings=add.embeddings, # type: ignore + metadatas=add.metadatas, # type: ignore + documents=add.documents, # type: ignore ids=add.ids, ) except InvalidDimensionException as e: raise HTTPException(status_code=500, detail=str(e)) - return result + return result # type: ignore @trace_method("FastAPI.update", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.UPDATE, + resource=DynamicAuthzResource( + id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), + type=AuthzResourceTypes.COLLECTION, + ), + ) def update(self, collection_id: str, add: UpdateEmbedding) -> None: return self._api._update( ids=add.ids, collection_id=_uuid(collection_id), embeddings=add.embeddings, - documents=add.documents, - metadatas=add.metadatas, + documents=add.documents, # type: ignore + metadatas=add.metadatas, # type: ignore ) @trace_method("FastAPI.upsert", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.UPSERT, + resource=DynamicAuthzResource( + id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), + type=AuthzResourceTypes.COLLECTION, + ), + ) def upsert(self, collection_id: str, upsert: AddEmbedding) -> None: return self._api._upsert( collection_id=_uuid(collection_id), ids=upsert.ids, - embeddings=upsert.embeddings, - documents=upsert.documents, - metadatas=upsert.metadatas, + embeddings=upsert.embeddings, # type: ignore + documents=upsert.documents, # type: ignore + metadatas=upsert.metadatas, # type: ignore ) @trace_method("FastAPI.get", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.GET, + resource=DynamicAuthzResource( + id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), + type=AuthzResourceTypes.COLLECTION, + ), + ) def get(self, collection_id: str, get: GetEmbedding) -> GetResult: return self._api._get( collection_id=_uuid(collection_id), @@ -380,22 +488,51 @@ def get(self, collection_id: str, get: GetEmbedding) -> GetResult: ) @trace_method("FastAPI.delete", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.DELETE, + resource=DynamicAuthzResource( + id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), + type=AuthzResourceTypes.COLLECTION, + ), + ) def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]: return self._api._delete( - where=delete.where, + where=delete.where, # type: ignore ids=delete.ids, collection_id=_uuid(collection_id), where_document=delete.where_document, ) @trace_method("FastAPI.count", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.COUNT, + resource=DynamicAuthzResource( + id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), + type=AuthzResourceTypes.COLLECTION, + ), + ) def count(self, collection_id: str) -> int: return self._api._count(_uuid(collection_id)) + @trace_method("FastAPI.reset", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.RESET, + resource=DynamicAuthzResource( + id="*", + type=AuthzResourceTypes.DB, + ), + ) def reset(self) -> bool: return self._api.reset() @trace_method("FastAPI.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION) + @authz_context( + action=AuthzResourceActions.QUERY, + resource=DynamicAuthzResource( + id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"), + type=AuthzResourceTypes.COLLECTION, + ), + ) def get_nearest_neighbors( self, collection_id: str, query: QueryEmbedding ) -> QueryResult: diff --git a/chromadb/test/auth/test_simple_rbac_authz.py b/chromadb/test/auth/test_simple_rbac_authz.py new file mode 100644 index 00000000000..06c8bde569d --- /dev/null +++ b/chromadb/test/auth/test_simple_rbac_authz.py @@ -0,0 +1,332 @@ +import json +import random +import string +from typing import Dict, Any, Tuple +import uuid +import hypothesis.strategies as st +import pytest +from hypothesis import given, settings +from chromadb import AdminClient + +from chromadb.api import AdminAPI, ServerAPI +from chromadb.api.models.Collection import Collection +from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System +from chromadb.test.conftest import _fastapi_fixture + + +valid_action_space = [ + "tenant:create_tenant", + "tenant:get_tenant", + "db:create_database", + "db:get_database", + "db:reset", + "db:list_collections", + "collection:get_collection", + "db:create_collection", + "db:get_or_create_collection", + "collection:delete_collection", + "collection:update_collection", + "collection:add", + "collection:delete", + "collection:get", + "collection:query", + "collection:peek", + "collection:update", + "collection:upsert", + "collection:count", +] + +role_name = st.text(alphabet=string.ascii_letters, min_size=1, max_size=20) + +user_name = st.text(alphabet=string.ascii_letters, min_size=1, max_size=20) + +actions = st.lists( + st.sampled_from(valid_action_space), min_size=1, max_size=len(valid_action_space) +) + + +@st.composite +def master_user(draw: st.DrawFn) -> Tuple[Dict[str, Any], Dict[str, Any]]: + return { + "role": "__master_role__", + "id": "__master__", + "tenant": DEFAULT_TENANT, + "tokens": [ + { + "token": f"{random.randint(1,1000000)}_" + + draw( + st.text( + alphabet=string.ascii_letters + string.digits, + min_size=1, + max_size=25, + ) + ) + } + for _ in range(2) + ], + }, { + "__master_role__": { + "actions": valid_action_space, + "unauthorized_actions": [], + } + } + + +@st.composite +def user_role_config(draw: st.DrawFn) -> Tuple[Dict[str, Any], Dict[str, Any]]: + role = draw(role_name) + user = draw(user_name) + actions_list = draw(actions) + if any( + action in actions_list + for action in [ + "collection:add", + "collection:delete", + "collection:get", + "collection:query", + "collection:peek", + "collection:update", + "collection:upsert", + "collection:count", + ] + ): + actions_list.append("collection:get_collection") + if any( + action in actions_list + for action in [ + "collection:peek", + ] + ): + actions_list.append("collection:get") + actions_list.extend( + [ + "tenant:get_tenant", + "db:get_database", + ] + ) + unauthorized_actions = set(valid_action_space) - set(actions_list) + _role_config = { + f"{role}": { + "actions": actions_list, + "unauthorized_actions": list(unauthorized_actions), + } + } + + return { + "role": role, + "id": user, + "tenant": DEFAULT_TENANT, + "tokens": [ + { + "token": f"{random.randint(1,1000000)}_" + + draw( + st.text( + alphabet=string.ascii_letters + string.digits, + min_size=1, + max_size=25, + ) + ) + } + for _ in range(2) + ], + }, _role_config + + +@st.composite +def rbac_config(draw: st.DrawFn) -> Dict[str, Any]: + user_roles = draw( + st.lists(user_role_config().filter( + lambda t: t[0]), min_size=1, max_size=10) + ) + muser_role = draw(st.lists(master_user(), min_size=1, max_size=1)) + users = [] + roles = [] + for user, role in user_roles: + users.append(user) + roles.append(role) + + for muser, mrole in muser_role: + users.append(muser) + roles.append(mrole) + roles_mapping = {} + for role in roles: + roles_mapping.update(role) + _rbac_config = { + "roles_mapping": roles_mapping, + "users": users, + } + return _rbac_config + + +@st.composite +def token_config(draw: st.DrawFn) -> Dict[str, Any]: + token_header = draw(st.sampled_from( + ["AUTHORIZATION", "X_CHROMA_TOKEN", None])) + server_provider = draw( + st.sampled_from( + ["token", "chromadb.auth.token.TokenAuthServerProvider"]) + ) + client_provider = draw( + st.sampled_from( + ["token", "chromadb.auth.token.TokenAuthClientProvider"]) + ) + server_authz_provider = draw( + st.sampled_from( + ["chromadb.auth.authz.SimpleRBACAuthorizationProvider"]) + ) + server_credentials_provider = draw(st.sampled_from(["user_token_config"])) + # _rbac_config = draw(rbac_config()) + persistence = draw(st.booleans()) + return { + "token_transport_header": token_header, + "chroma_server_auth_credentials_file": None, + "chroma_server_auth_provider": server_provider, + "chroma_client_auth_provider": client_provider, + "chroma_server_authz_config_file": None, + "chroma_server_auth_credentials_provider": server_credentials_provider, + "chroma_server_authz_provider": server_authz_provider, + "is_persistent": persistence, + } + + +api_executors = { + "db:create_database": lambda api, mapi, aapi: ( + aapi.create_database(f"test-{uuid.uuid4()}") + ), + "db:get_database": lambda api, mapi, aapi: (aapi.get_database(DEFAULT_DATABASE),), + "tenant:create_tenant": lambda api, mapi, aapi: ( + aapi.create_tenant(f"test-{uuid.uuid4()}") + ), + "tenant:get_tenant": lambda api, mapi, aapi: (aapi.get_tenant(DEFAULT_TENANT),), + "db:reset": lambda api, mapi, _: api.reset(), + "db:list_collections": lambda api, mapi, _: api.list_collections(), + "collection:get_collection": lambda api, mapi, _: ( + # pre-condition + mcol := mapi.create_collection(f"test-get-{uuid.uuid4()}"), + api.get_collection(f"{mcol.name}"), + ), + "db:create_collection": lambda api, mapi, _: ( + api.create_collection(f"test-create-{uuid.uuid4()}"), + ), + "db:get_or_create_collection": lambda api, mapi, _: ( + api.get_or_create_collection(f"test-get-or-create-{uuid.uuid4()}") + ), + "collection:delete_collection": lambda api, mapi, _: ( + # pre-condition + mcol := mapi.create_collection(f"test-delete-col-{uuid.uuid4()}"), + api.delete_collection(f"{mcol.name}"), + ), + "collection:update_collection": lambda api, mapi, _: ( + # pre-condition + mcol := mapi.create_collection(f"test-modify-col-{uuid.uuid4()}"), + col := Collection(api, f"{mcol.name}", mcol.id), + col.modify(metadata={"test": "test"}), + ), + "collection:add": lambda api, mapi, _: ( + mcol := mapi.create_collection(f"test-add-doc-{uuid.uuid4()}"), + col := Collection(api, f"{mcol.name}", mcol.id), + col.add(documents=["test"], ids=["1"]), + ), + "collection:delete": lambda api, mapi, _: ( + mcol := mapi.create_collection(f"test-delete-doc-{uuid.uuid4()}"), + mcol.add(documents=["test"], ids=["1"]), + col := Collection(client=api, name=f"{mcol.name}", id=mcol.id), + col.delete(ids=["1"]), + ), + "collection:get": lambda api, mapi, _: ( + mcol := mapi.create_collection(f"test-get-doc-{uuid.uuid4()}"), + mcol.add(documents=["test"], ids=["1"]), + col := Collection(api, f"{mcol.name}", mcol.id), + col.get(ids=["1"]), + ), + "collection:query": lambda api, mapi, _: ( + mcol := mapi.create_collection(f"test-query-doc-{uuid.uuid4()}"), + mcol.add(documents=["test"], ids=["1"]), + col := Collection(api, f"{mcol.name}", mcol.id), + col.query(query_texts=["test"]), + ), + "collection:peek": lambda api, mapi, _: ( + mcol := mapi.create_collection(f"test-peek-{uuid.uuid4()}"), + mcol.add(documents=["test"], ids=["1"]), + col := Collection(api, f"{mcol.name}", mcol.id), + col.peek(), + ), + "collection:update": lambda api, mapi, _: ( + mcol := mapi.create_collection(f"test-update-{uuid.uuid4()}"), + mcol.add(documents=["test"], ids=["1"]), + col := Collection(api, f"{mcol.name}", mcol.id), + col.update(ids=["1"], documents=["test1"]), + ), + "collection:upsert": lambda api, mapi, _: ( + mcol := mapi.create_collection(f"test-upsert-{uuid.uuid4()}"), + mcol.add(documents=["test"], ids=["1"]), + col := Collection(api, f"{mcol.name}", mcol.id), + col.upsert(ids=["1"], documents=["test1"]), + ), + "collection:count": lambda api, mapi, _: ( + mcol := mapi.create_collection(f"test-count-{uuid.uuid4()}"), + mcol.add(documents=["test"], ids=["1"]), + col := Collection(api, f"{mcol.name}", mcol.id), + col.count(), + ), +} + + +def master_api(_settings: Settings) -> Tuple[ServerAPI, AdminAPI]: + system = System(_settings) + api = system.instance(ServerAPI) + admin_api = AdminClient(api.get_settings()) + system.start() + return api, admin_api + + +@settings(max_examples=10) +@given(token_config=token_config(), rbac_config=rbac_config()) +def test_authz(token_config: Dict[str, Any], rbac_config: Dict[str, Any]) -> None: + authz_config = rbac_config + token_config["chroma_server_authz_config"] = rbac_config + token_config["chroma_server_auth_credentials"] = json.dumps( + rbac_config["users"]) + random_user = random.choice( + [user for user in authz_config["users"] if user["id"] != "__master__"] + ) + _master_user = [ + user for user in authz_config["users"] if user["id"] == "__master__" + ][0] + random_token = random.choice(random_user["tokens"])["token"] + api = _fastapi_fixture( + is_persistent=token_config["is_persistent"], + chroma_server_auth_provider=token_config["chroma_server_auth_provider"], + chroma_server_auth_credentials_provider=token_config[ + "chroma_server_auth_credentials_provider" + ], + chroma_server_auth_credentials=token_config["chroma_server_auth_credentials"], + chroma_client_auth_provider=token_config["chroma_client_auth_provider"], + chroma_client_auth_token_transport_header=token_config[ + "token_transport_header" + ], + chroma_server_auth_token_transport_header=token_config[ + "token_transport_header" + ], + chroma_server_authz_provider=token_config["chroma_server_authz_provider"], + chroma_server_authz_config=token_config["chroma_server_authz_config"], + chroma_client_auth_credentials=random_token, + ) + _sys: System = next(api) + _sys.reset_state() + _master_settings = Settings(**dict(_sys.settings)) + _master_settings.chroma_client_auth_credentials = _master_user["tokens"][0]["token"] + _master_api, admin_api = master_api(_master_settings) + _api = _sys.instance(ServerAPI) + _api.heartbeat() + for action in authz_config["roles_mapping"][random_user["role"]]["actions"]: + print(action) + api_executors[action](_api, _master_api, admin_api) # type: ignore + for unauthorized_action in authz_config["roles_mapping"][random_user["role"]][ + "unauthorized_actions" + ]: + with pytest.raises(Exception) as ex: + api_executors[unauthorized_action]( + _api, _master_api, admin_api + ) # type: ignore + assert "Unauthorized" in str(ex) or "Forbidden" in str(ex) diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index 9e716cb0058..55a389980c9 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -6,6 +6,8 @@ import tempfile import time from typing import ( + Any, + Dict, Generator, Iterator, List, @@ -48,12 +50,14 @@ NOT_CLUSTER_ONLY = os.getenv("CHROMA_CLUSTER_TEST_ONLY") != "1" + def skip_if_not_cluster() -> pytest.MarkDecorator: return pytest.mark.skipif( NOT_CLUSTER_ONLY, reason="Requires Kubernetes to be running with a valid config", ) + def find_free_port() -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", 0)) @@ -70,6 +74,9 @@ def _run_server( chroma_server_auth_credentials_file: Optional[str] = None, chroma_server_auth_credentials: Optional[str] = None, chroma_server_auth_token_transport_header: Optional[str] = None, + chroma_server_authz_provider: Optional[str] = None, + chroma_server_authz_config_file: Optional[str] = None, + chroma_server_authz_config: Optional[Dict[str, Any]] = None, ) -> None: """Run a Chroma server locally""" if is_persistent and persist_directory: @@ -87,6 +94,9 @@ def _run_server( chroma_server_auth_credentials_file=chroma_server_auth_credentials_file, chroma_server_auth_credentials=chroma_server_auth_credentials, chroma_server_auth_token_transport_header=chroma_server_auth_token_transport_header, + chroma_server_authz_provider=chroma_server_authz_provider, + chroma_server_authz_config_file=chroma_server_authz_config_file, + chroma_server_authz_config=chroma_server_authz_config, ) else: settings = Settings( @@ -102,6 +112,9 @@ def _run_server( chroma_server_auth_credentials_file=chroma_server_auth_credentials_file, chroma_server_auth_credentials=chroma_server_auth_credentials, chroma_server_auth_token_transport_header=chroma_server_auth_token_transport_header, + chroma_server_authz_provider=chroma_server_authz_provider, + chroma_server_authz_config_file=chroma_server_authz_config_file, + chroma_server_authz_config=chroma_server_authz_config, ) server = chromadb.server.fastapi.FastAPI(settings) uvicorn.run(server.app(), host="0.0.0.0", port=port, log_level="error") @@ -130,6 +143,9 @@ def _fastapi_fixture( chroma_server_auth_credentials: Optional[str] = None, chroma_client_auth_token_transport_header: Optional[str] = None, chroma_server_auth_token_transport_header: Optional[str] = None, + chroma_server_authz_provider: Optional[str] = None, + chroma_server_authz_config_file: Optional[str] = None, + chroma_server_authz_config: Optional[Dict[str, Any]] = None, ) -> Generator[System, None, None]: """Fixture generator that launches a server in a separate process, and yields a fastapi client connect to it""" @@ -146,6 +162,9 @@ def _fastapi_fixture( Optional[str], Optional[str], Optional[str], + Optional[str], + Optional[str], + Optional[Dict[str, Any]], ] = ( port, False, @@ -155,6 +174,9 @@ def _fastapi_fixture( chroma_server_auth_credentials_file, chroma_server_auth_credentials, chroma_server_auth_token_transport_header, + chroma_server_authz_provider, + chroma_server_authz_config_file, + chroma_server_authz_config, ) persist_directory = None if is_persistent: @@ -168,6 +190,9 @@ def _fastapi_fixture( chroma_server_auth_credentials_file, chroma_server_auth_credentials, chroma_server_auth_token_transport_header, + chroma_server_authz_provider, + chroma_server_authz_config_file, + chroma_server_authz_config, ) proc = ctx.Process(target=_run_server, args=args, daemon=True) proc.start() diff --git a/docs/cip/CIP-10112023_Authorization.md b/docs/cip/CIP-10112023_Authorization.md new file mode 100644 index 00000000000..1be1e4c51b6 --- /dev/null +++ b/docs/cip/CIP-10112023_Authorization.md @@ -0,0 +1,299 @@ +# CIP-10112023: Authorization + +## Status + +Current Status: `Under Discussion` + +## **Motivation** + +The motivation for introducing an authorization feature in Chroma is to address the lack of a proper authorization model that many users are struggling with, especially those who deploy production apps. Additionally, as Chroma is gearing up for production-grade deployments out of the box, it is essential to have a proper authorization model in place for distributed and hosted Chroma instances. + +## **Public Interfaces** + +No changes to public interfaces are proposed in this CIP. + +## **Proposed Changes** + +In this CIP we propose the introduction of abstractions necessary for implementing a multi-user authorization scheme in a pluggable way. We also propose a baseline implementation of such a scheme which will be shipped with Chroma as a default authorization provider. + +It is important to keep in mind that the client/server interaction in Chroma is meant to be stateless, as such the Authorization approach must also follow the same principle. This means that the server must not store any state about the user's authorization. The authorization decision must be made on a per-request basis. + +The diagram below illustrates the levels of abstractions we introduce: + +![Server-Side Authorization Workflow](assets/CIP-10112023_Authorization_Workflow.png) + +- (1) Client Sends Request to Chroma Server +- (2) Authentication Middleware intercepts the request +- (3a) Authentication Provider attempts to authenticate the user +- (3b) Authentication Provider returns success (with user identity) or failure +- (3c) Authentication Middleware returns success or failure to server +- (4) Server passes the request with user identity to the Authorization Middleware +- (5) Authorization Middleware creates Authorization Request Context +- (6) Authorization Context Decorator (at API endpoint) intercepts the call and using Authorization Request Context creates an Authorization Context that is then passed to the Authorization Provider +- (7)Authorization Context Decorator raises and error if the Authorization Provider returns a failure or passes the request to the API endpoint if the Authorization Provider returns success +- (8a) Request is passed to the API endpoint for execution +- (8b) Response is returned to the client + +In the above diagram we highlight the new abstractions we introduce in this CIP and we also demonstrate the interop with the existing Authentication + +### Concepts + +#### Basic Authorization Terms + +##### User + +A user is an entity that can perform actions on resources. A user can be a human or a machine. + +##### Resource + +A resource is an entity that can be acted upon. A resource can be a database, a collection, a document. + +> Note: In this release we do not support document as a resource. + +##### Action + +An action is an operation that can be performed on a resource. An action can be `read`, `write`, `delete`, `update`, `create`, `list`, `count`, `query`, `peek`, `get`, `add`, `upsert`, `get_or_create`. Actions are resource specific. + +##### Role + +A role is a collection of actions that a user can perform on a resource. This pertains to RBAC or Role Based Access Control. + +#### Chroma Authorization Terms + +##### ServerAuthorizationProvider + +The `ServerAuthorizationProvider` is a class that abstracts a provider that will authorize requests to the Chroma server (FastAPI). In practical terms the provider will integrate with an external authorization service (e.g. Auth0, Okta, Permit.io etc.) and will be responsible for allowing or denying the user request. + +In our baseline implementation we will provide a simple file-based RBAC authorization provider that will read authorization configuration from a YAML file. + +##### ServerAuthzConfigurationProvider + +The `ServerAuthzConfigurationProvider` is a class that abstracts a the configuration needed for authorization provider to work. In practice that implies, reading secrets from environment variables, reading configuration from a file, or reading configuration from a database or secrets file, or even KMS. + +In our baseline implementation the AuthzConfigurationProvider will read configuration from a YAML file that contains the authorization configuration. + +##### ServerAuthorizationRequest + +The `ServerAuthorizationRequest` encapsulates the authorization context. + +##### ServerAuthorizationResponse + +Authorization response provides authorization provider evaluation response. It returns a boolean response indicating whether the request is allowed or denied. + +##### ChromaAuthzMiddleware + +The `ChromaAuthzMiddleware` is an abstraction for the server-side middleware. At the time of writing we only support FastAPI. The middleware interface supports several methods: + +- `authorize` - authorizes the request against the authorization provider. +- `ignore_operation` - determines whether or not the operation should be ignored by the middleware +- `instrument_server` - an optional method for additional server instrumentation. For example, header injection. + +##### AuthorizationError + +Error thrown when an authorization request is disallowed/denied by the authorization provider. Depending on authorization provider's implementation such error may also be thrown when the authorization provider is not available or an internal error ocurred. + +Client semantics of this error is a 403 Unauthorized error being returned over HTTP interface. + +##### AuthorizationContext + +The AuthorizationContext is composed of three components as defined in #Basic Authorization Terms: + +- User +- Resource +- Action + + +```json +{ +"user": {"id": "API Token or User Id"}, +"resource": {"namespace": "*", "id": "collection_id","type": "database"}, +"action": {"id":"get_or_create"}, +} +``` + +We intentionally want to keep this as minimal as possible to avoid any unnecessary complexity and to allow users to easily understand the authorization model. However the context is just an abstraction of the above representation and each authorization provider will need to implement the above and if necessary extend it to support additional information. + +We propose the following classes to represent the above: + +```python +@dataclass +class AuthzUser: + id: Optional[str] + attributes: Optional[Dict[str, Any]] = None + claims: Optional[Dict[str, Any]] = None + + +@dataclass +class AuthzResource: + id: Optional[str] + type: Optional[str] + namespace: Optional[str] + attributes: Optional[Dict[str, Any]] = None + + +@dataclass +class AuthzAction: + id: str + attributes: Optional[Dict[str, Any]] = None + + +@dataclass +class AuthorizationContext: + user: AuthzUser + resource: AuthzResource + action: AuthzAction + +``` + +##### User Identity + +In this CIP we also introduce a handover or bridge mechanism from authentication to authorization which we term `User Identity`. The object is meant to encapsulate the user identity and possibly also claims, roles and attributes in the future. + +```python +class UserIdentity(EnforceOverrides, ABC): + @abstractmethod + def get_user_id(self) -> str: + ... +``` + +### Baseline Implementation + +In this section we propose a minimal implementation example of the authorization framework which will also ship in Chroma as a default authorization provider and a reference implementation. Our reference implementation relies on static configuration files in YAML format. + +We introduce the following implementations: + +- `LocalUserConfigAuthorizationConfigurationProvider` - a simple authz configuration to read the yaml configuration file. +- `SimpleRBACAuthorizationProvider` - a simple RBAC authorization provider that reads the configuration from the configuration provider, creates a list of tuples for every user and his role action mappings (e.g. `('user@example.com','tenant_x', 'db', 'list_collections')`) and evaluates the authorization request against the list of tuples. + +#### Authentication and Authorization Config Scheme + +In our baseline implementation we propose the following configuration scheme: + +```yaml +resource_type_action: # This is here just for reference + - tenant:create_tenant + - tenant:get_tenant + - db:create_database + - db:get_database + - db:reset + - db:list_collections + - collection:get_collection + - db:create_collection + - db:get_or_create_collection + - collection:delete_collection + - collection:update_collection + - collection:add + - collection:delete + - collection:get + - collection:query + - collection:peek #from API perspective this is the same as collection:get + - collection:count + - collection:update + - collection:upsert + +roles_mapping: + admin: + actions: + [ + "tenant:create_tenant", + "tenant:get_tenant", + "db:create_database", + "db:get_database", + "db:reset", + "db:list_collections", + "collection:get_collection", + "db:create_collection", + "db:get_or_create_collection", + "collection:delete_collection", + "collection:update_collection", + "collection:add", + "collection:delete", + "collection:get", + "collection:query", + "collection:peek", + "collection:update", + "collection:upsert", + "collection:count", + ] + write: + actions: + [ + "tenant:get_tenant", + "db:get_database", + "db:list_collections", + "collection:get_collection", + "db:create_collection", + "db:get_or_create_collection", + "collection:delete_collection", + "collection:update_collection", + "collection:add", + "collection:delete", + "collection:get", + "collection:query", + "collection:peek", + "collection:update", + "collection:upsert", + "collection:count", + ] + db_read: + actions: + [ + "tenant:get_tenant", + "db:get_database", + "db:list_collections", + "collection:get_collection", + "db:create_collection", + "db:get_or_create_collection", + "collection:delete_collection", + "collection:update_collection", + ] + collection_read: + actions: + [ + "tenant:get_tenant", + "db:get_database", + "db:list_collections", + "collection:get_collection", + "collection:get", + "collection:query", + "collection:peek", + "collection:count", + ] + collection_x_read: + actions: + [ + "tenant:get_tenant", + "db:get_database", + "collection:get_collection", + "collection:get", + "collection:query", + "collection:peek", + "collection:count", + ] + resources: [""] #not yet supported +users: + - id: user@example.com + role: admin + tenant: my_tenant + tokens: + - token: test-token-admin + secret: my_api_secret # not yet supported + - id: Anonymous + role: db_read + tokens: + - token: my_api_token + secret: my_api_secret + +``` + +## **Compatibility, Deprecation, and Migration Plan** + +This CIP is backwards compatible with older versions of Chroma clients. + +## **Test Plan** + +Property and Integration tests. + +## **Rejected Alternatives** + +We considered several alternatives that are more vendor specific (such az Auth0, Okta, Permit.io etc.), but we decided to go with a more generic approach that will allow users to be able to extend the authorization framework to support additional features and providers. diff --git a/docs/cip/assets/CIP-10112023_Authorization_Workflow.png b/docs/cip/assets/CIP-10112023_Authorization_Workflow.png new file mode 100644 index 00000000000..d526f3fd27b Binary files /dev/null and b/docs/cip/assets/CIP-10112023_Authorization_Workflow.png differ diff --git a/examples/basic_functionality/authz/README.md b/examples/basic_functionality/authz/README.md new file mode 100644 index 00000000000..1201ac7374a --- /dev/null +++ b/examples/basic_functionality/authz/README.md @@ -0,0 +1,155 @@ +# Authorization + +## Configuration + +### Resource Actions + +```yaml +resource_type_action: # This is here just for reference + - tenant:create_tenant + - tenant:get_tenant + - db:create_database + - db:get_database + - db:reset + - db:list_collections + - collection:get_collection + - db:create_collection + - db:get_or_create_collection + - collection:delete_collection + - collection:update_collection + - collection:add + - collection:delete + - collection:get + - collection:query + - collection:peek #from API perspective this is the same as collection:get + - collection:count + - collection:update + - collection:upsert +``` + +### Role Mapping + +Following are the role mappings where we define roles and the actions they can perform. The actions spaces is taken from the resource actions defined above. + +> **Note**: We also plan to support resource level authorization soon but for now only RBAC is available. + +```yaml +roles_mapping: + admin: + actions: + [ + db:list_collections, + collection:get_collection, + db:create_collection, + db:get_or_create_collection, + collection:delete_collection, + collection:update_collection, + collection:add, + collection:delete, + collection:get, + collection:query, + collection:peek, + collection:update, + collection:upsert, + collection:count, + ] + write: + actions: + [ + db:list_collections, + collection:get_collection, + db:create_collection, + db:get_or_create_collection, + collection:delete_collection, + collection:update_collection, + collection:add, + collection:delete, + collection:get, + collection:query, + collection:peek, + collection:update, + collection:upsert, + collection:count, + ] + db_read: + actions: + [ + db:list_collections, + collection:get_collection, + db:create_collection, + db:get_or_create_collection, + collection:delete_collection, + collection:update_collection, + ] + collection_read: + actions: + [ + db:list_collections, + collection:get_collection, + collection:get, + collection:query, + collection:peek, + collection:count, + ] + collection_x_read: + actions: + [ + collection:get_collection, + collection:get, + collection:query, + collection:peek, + collection:count, + ] + resources: [""] #not yet supported +``` + +You can update the roll mapping as per your requirements. + +### Users + +Last piece of the puzzle is the user configuration. Here we define the user id, role and the tokens they can use to authenticate. + +> **Note**: In our example we use both AuthN and AuthZ where AuthN verifies whether a token is valid e.g. user has that token and AuthZ verifies whether the user has the right role to perform the action. + +```yaml +users: + - id: user@example.com + role: admin + tokens: + - token: test-token-admin + secret: my_api_secret # not yet supported + - id: Anonymous + role: admin + tokens: + - token: my_api_token + secret: my_api_secret +``` + +## Starting the Server + +```bash +IS_PERSISTENT=1 \ +CHROMA_SERVER_AUTHZ_PROVIDER="chromadb.auth.authz.SimpleRBACAuthorizationProvider" \ +CHROMA_SERVER_AUTH_CREDENTIALS_FILE=examples/basic_functionality/authz/authz.yaml \ +CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER="user_token_config" \ +CHROMA_SERVER_AUTH_PROVIDER="chromadb.auth.token.TokenAuthServerProvider" \ +CHROMA_SERVER_AUTHZ_CONFIG_FILE=examples/basic_functionality/authz/authz.yaml \ +uvicorn chromadb.app:app --workers 1 --host 0.0.0.0 --port 8000 --proxy-headers --log-config chromadb/log_config.yml --reload +``` + +## Testing the authorization + +```python +import chromadb +from chromadb.config import Settings + +client = chromadb.HttpClient("http://localhost:8000/", + settings=Settings(chroma_client_auth_provider="chromadb.auth.token.TokenAuthClientProvider", + chroma_client_auth_credentials="test-token-admin")) + +client.list_collections() +collection = client.get_or_create_collection("test_collection") + +collection.add(documents=["test"],ids=["1"]) +collection.get() +``` diff --git a/examples/basic_functionality/authz/authz.ipynb b/examples/basic_functionality/authz/authz.ipynb new file mode 100644 index 00000000000..c70df77c702 --- /dev/null +++ b/examples/basic_functionality/authz/authz.ipynb @@ -0,0 +1,95 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/Users/tazarov\n" + ] + }, + { + "data": { + "text/plain": [ + "{'ids': ['1'],\n", + " 'embeddings': None,\n", + " 'metadatas': [None],\n", + " 'documents': ['test21']}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%cd ../../../\n", + "import chromadb\n", + "from chromadb.config import Settings\n", + "\n", + "client = chromadb.HttpClient(\"http://localhost:8000/\",\n", + " settings=Settings(chroma_client_auth_provider=\"chromadb.auth.token.TokenAuthClientProvider\",\n", + " chroma_client_auth_credentials=\"test-token-admin\"))\n", + "\n", + "client.list_collections()\n", + "collection = client.get_or_create_collection(\"test_collection\")\n", + "\n", + "collection.add(documents=[\"test21\"],ids=[\"1\"])\n", + "collection.get(ids=[\"1\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "ename": "HTTPError", + "evalue": "400 Client Error: Bad Request for url: http://localhost:8000/api/v1/collections/4487accd-6160-454c-a5f2-26d6e87ce5ef/upsert", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mHTTPError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/Users/tazarov/experiments/chroma-experiments/chroma-authz/examples/basic_functionality/authz/authz.ipynb Cell 2\u001b[0m line \u001b[0;36m6\n\u001b[1;32m 2\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mchromadb\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mapi\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmodels\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mCollection\u001b[39;00m \u001b[39mimport\u001b[39;00m Collection\n\u001b[1;32m 5\u001b[0m col \u001b[39m=\u001b[39m Collection(client, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mtest-upsert-\u001b[39m\u001b[39m{\u001b[39;00muuid\u001b[39m.\u001b[39muuid4()\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m, uuid\u001b[39m.\u001b[39muuid4())\n\u001b[0;32m----> 6\u001b[0m col\u001b[39m.\u001b[39;49mupsert(documents\u001b[39m=\u001b[39;49m[\u001b[39m\"\u001b[39;49m\u001b[39mtest\u001b[39;49m\u001b[39m\"\u001b[39;49m],ids\u001b[39m=\u001b[39;49m[\u001b[39m\"\u001b[39;49m\u001b[39m1\u001b[39;49m\u001b[39m\"\u001b[39;49m])\n", + "File \u001b[0;32m~/experiments/chroma-experiments/chroma-authz/chromadb/api/models/Collection.py:299\u001b[0m, in \u001b[0;36mCollection.upsert\u001b[0;34m(self, ids, embeddings, metadatas, documents)\u001b[0m\n\u001b[1;32m 283\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist.\u001b[39;00m\n\u001b[1;32m 284\u001b[0m \n\u001b[1;32m 285\u001b[0m \u001b[39mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 292\u001b[0m \u001b[39m None\u001b[39;00m\n\u001b[1;32m 293\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 295\u001b[0m ids, embeddings, metadatas, documents \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_validate_embedding_set(\n\u001b[1;32m 296\u001b[0m ids, embeddings, metadatas, documents\n\u001b[1;32m 297\u001b[0m )\n\u001b[0;32m--> 299\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_client\u001b[39m.\u001b[39;49m_upsert(\n\u001b[1;32m 300\u001b[0m collection_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mid,\n\u001b[1;32m 301\u001b[0m ids\u001b[39m=\u001b[39;49mids,\n\u001b[1;32m 302\u001b[0m embeddings\u001b[39m=\u001b[39;49membeddings,\n\u001b[1;32m 303\u001b[0m metadatas\u001b[39m=\u001b[39;49mmetadatas,\n\u001b[1;32m 304\u001b[0m documents\u001b[39m=\u001b[39;49mdocuments,\n\u001b[1;32m 305\u001b[0m )\n", + "File \u001b[0;32m~/experiments/chroma-experiments/chroma-authz/chromadb/api/fastapi.py:382\u001b[0m, in \u001b[0;36m_upsert\u001b[0;34m(self, collection_id, ids, embeddings, metadatas, documents)\u001b[0m\n\u001b[1;32m 379\u001b[0m batch \u001b[39m=\u001b[39m (ids, embeddings, metadatas, documents)\n\u001b[1;32m 380\u001b[0m validate_batch(batch, {\u001b[39m\"\u001b[39m\u001b[39mmax_batch_size\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_batch_size})\n\u001b[1;32m 381\u001b[0m resp \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_submit_batch(\n\u001b[0;32m--> 382\u001b[0m batch, \u001b[39m\"\u001b[39m\u001b[39m/collections/\u001b[39m\u001b[39m\"\u001b[39m \u001b[39m+\u001b[39m \u001b[39mstr\u001b[39m(collection_id) \u001b[39m+\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m/update\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 383\u001b[0m )\n\u001b[1;32m 384\u001b[0m resp\u001b[39m.\u001b[39mraise_for_status()\n\u001b[1;32m 385\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mTrue\u001b[39;00m\n", + "File \u001b[0;32m~/experiments/chroma-experiments/chroma-authz/venv/lib/python3.11/site-packages/requests/models.py:1021\u001b[0m, in \u001b[0;36mResponse.raise_for_status\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1016\u001b[0m http_error_msg \u001b[39m=\u001b[39m (\n\u001b[1;32m 1017\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstatus_code\u001b[39m}\u001b[39;00m\u001b[39m Server Error: \u001b[39m\u001b[39m{\u001b[39;00mreason\u001b[39m}\u001b[39;00m\u001b[39m for url: \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39murl\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m 1018\u001b[0m )\n\u001b[1;32m 1020\u001b[0m \u001b[39mif\u001b[39;00m http_error_msg:\n\u001b[0;32m-> 1021\u001b[0m \u001b[39mraise\u001b[39;00m HTTPError(http_error_msg, response\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m)\n", + "\u001b[0;31mHTTPError\u001b[0m: 400 Client Error: Bad Request for url: http://localhost:8000/api/v1/collections/4487accd-6160-454c-a5f2-26d6e87ce5ef/upsert" + ] + } + ], + "source": [ + "import uuid\n", + "from chromadb.api.models.Collection import Collection\n", + "\n", + "col = Collection(client, f\"test-upsert-{uuid.uuid4()}\", uuid.uuid4())\n", + "col.upsert(documents=[\"test\"],ids=[\"1\"])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/basic_functionality/authz/authz.yaml b/examples/basic_functionality/authz/authz.yaml new file mode 100644 index 00000000000..1e07a0ec9e4 --- /dev/null +++ b/examples/basic_functionality/authz/authz.yaml @@ -0,0 +1,113 @@ +resource_type_action: # This is here just for reference + - tenant:create_tenant + - tenant:get_tenant + - db:create_database + - db:get_database + - db:reset + - db:list_collections + - collection:get_collection + - db:create_collection + - db:get_or_create_collection + - collection:delete_collection + - collection:update_collection + - collection:add + - collection:delete + - collection:get + - collection:query + - collection:peek #from API perspective this is the same as collection:get + - collection:count + - collection:update + - collection:upsert + +roles_mapping: + admin: + actions: + [ + "tenant:create_tenant", + "tenant:get_tenant", + "db:create_database", + "db:get_database", + "db:reset", + "db:list_collections", + "collection:get_collection", + "db:create_collection", + "db:get_or_create_collection", + "collection:delete_collection", + "collection:update_collection", + "collection:add", + "collection:delete", + "collection:get", + "collection:query", + "collection:peek", + "collection:update", + "collection:upsert", + "collection:count", + ] + write: + actions: + [ + "tenant:get_tenant", + "db:get_database", + "db:list_collections", + "collection:get_collection", + "db:create_collection", + "db:get_or_create_collection", + "collection:delete_collection", + "collection:update_collection", + "collection:add", + "collection:delete", + "collection:get", + "collection:query", + "collection:peek", + "collection:update", + "collection:upsert", + "collection:count", + ] + db_read: + actions: + [ + "tenant:get_tenant", + "db:get_database", + "db:list_collections", + "collection:get_collection", + "db:create_collection", + "db:get_or_create_collection", + "collection:delete_collection", + "collection:update_collection", + ] + collection_read: + actions: + [ + "tenant:get_tenant", + "db:get_database", + "db:list_collections", + "collection:get_collection", + "collection:get", + "collection:query", + "collection:peek", + "collection:count", + ] + collection_x_read: + actions: + [ + "tenant:get_tenant", + "db:get_database", + "collection:get_collection", + "collection:get", + "collection:query", + "collection:peek", + "collection:count", + ] + resources: [""] #not yet supported +users: + - id: user@example.com + role: admin + tenant: my_tenant + tokens: + - token: test-token-admin + secret: my_api_secret # not yet supported + - id: Anonymous + role: db_read + tokens: + - token: my_api_token + secret: my_api_secret diff --git a/pyproject.toml b/pyproject.toml index 04d7e79c27a..d223ea338e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ 'typer >= 0.9.0', 'kubernetes>=28.1.0', 'tenacity>=8.2.3', + 'PyYAML>=6.0.0', ] [tool.black] diff --git a/requirements.txt b/requirements.txt index 4ad9f29aeb2..338083ef3d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,4 @@ tqdm>=4.65.0 typer>=0.9.0 typing_extensions>=4.5.0 uvicorn[standard]==0.18.3 +PyYAML>=6.0.0