Skip to content

Commit

Permalink
feat: Working implementation and updated CIP
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Oct 19, 2023
1 parent 9b3d8a4 commit e63f57a
Show file tree
Hide file tree
Showing 14 changed files with 1,170 additions and 99 deletions.
203 changes: 195 additions & 8 deletions chromadb/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
Contains only Auth abstractions, no implementations.
"""
import base64
from functools import partial
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Any,
Callable,
Optional,
Dict,
TypeVar,
Tuple,
Generic,
Union,
)
from attr import dataclass

from overrides import EnforceOverrides, override
from pydantic import SecretStr
Expand All @@ -35,6 +40,21 @@ class AuthInfoType(Enum):
METADATA = "metadata" # gRPC


class UserIdentity(EnforceOverrides, ABC):
@abstractmethod
def get_user_id(self) -> str:
...


class SimpleUserIdentity(UserIdentity):
def __init__(self, user_id: str) -> None:
self._user_id = user_id

@override
def get_user_id(self) -> str:
return self._user_id


class ClientAuthResponse(EnforceOverrides, ABC):
@abstractmethod
def get_auth_info_type(self) -> AuthInfoType:
Expand Down Expand Up @@ -87,11 +107,12 @@ def inject_credentials(self, injection_context: T) -> None:
class ServerAuthenticationRequest(EnforceOverrides, ABC, Generic[T]):
@abstractmethod
def get_auth_info(
self, auth_info_type: AuthInfoType, auth_info_id: Optional[str] = None
self, auth_info_type: AuthInfoType, auth_info_id: str
) -> T:
"""
This method should return the necessary auth info based on the type of authentication (e.g. header, cookie, url)
and a given id for the respective auth type (e.g. name of the header, cookie, url param).
This method should return the necessary auth info based on the type of
authentication (e.g. header, cookie, url) and a given id for the respective
auth type (e.g. name of the header, cookie, url param).
:param auth_info_type: The type of auth info to return
:param auth_info_id: The id of the auth info to return
Expand All @@ -101,16 +122,41 @@ def get_auth_info(


class ServerAuthenticationResponse(EnforceOverrides, ABC):
@abstractmethod
def success(self) -> bool:
...

@abstractmethod
def get_user_identity(self) -> Optional[UserIdentity]:
...


class SimpleServerAuthenticationResponse(ServerAuthenticationResponse):
""" Simple implementation of ServerAuthenticationResponse"""
_auth_success: bool
_user_identity: Optional[UserIdentity]

def __init__(self, auth_success: bool, user_identity: Optional[UserIdentity]) \
-> None:
self._auth_success = auth_success
self._user_identity = user_identity

@override
def success(self) -> bool:
raise NotImplementedError()
return self._auth_success

@override
def get_user_identity(self) -> Optional[UserIdentity]:
return self._user_identity


class ServerAuthProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def authenticate(self, request: ServerAuthenticationRequest[T]) -> bool:
def authenticate(self, request: ServerAuthenticationRequest[T]) \
-> ServerAuthenticationResponse:
pass


Expand All @@ -121,7 +167,7 @@ def __init__(self, system: System) -> None:
@abstractmethod
def authenticate(
self, request: ServerAuthenticationRequest[T]
) -> Optional[ServerAuthenticationResponse]:
) -> ServerAuthenticationResponse:
...

@abstractmethod
Expand Down Expand Up @@ -155,8 +201,8 @@ def name(cls) -> str:

class AbstractCredentials(EnforceOverrides, ABC, Generic[T]):
"""
The class is used by Auth Providers to encapsulate credentials received from the server
and pass them to a ServerAuthCredentialsProvider.
The class is used by Auth Providers to encapsulate credentials received
from the server and pass them to a ServerAuthCredentialsProvider.
"""

@abstractmethod
Expand Down Expand Up @@ -204,4 +250,145 @@ def __init__(self, system: System) -> None:

@abstractmethod
def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool:
...

@abstractmethod
def get_user_identity(self, credentials: AbstractCredentials[T]) \
-> Optional[UserIdentity]:
...


# --- AuthZ ---#

# TODO move this to basic impl

@dataclass
class AuthzUser:
id: Optional[str]
attributes: Optional[Dict[str, Any]] = None
claims: Optional[Dict[str, Any]] = None


@dataclass
class AuthzResource:
id: Optional[str]
type: Optional[str]
namespace: Optional[str]
attributes: Optional[Dict[str, Any]] = None


class DynamicAuthzResource:
id: Optional[Union[str, Callable[..., str]]]
namespace: Optional[Union[str, Callable[..., str]]]
type: Optional[Union[str, Callable[..., str]]]
attributes: Optional[Union[Dict[str, Any], Callable[..., Dict[str, Any]]]]

def __init__(self, id: Optional[Union[str, Callable[..., str]]] = None,
namespace: Optional[Union[str, Callable[..., str]
]] = "default_database",
attributes: Optional[Union[Dict[str, Any],
Callable[..., Dict[str, Any]]]]
= lambda **kwargs: {},
type: Optional[Union[str, Callable[..., str]]
] = "default_database",
) -> None:
self.id = id
self.namespace = namespace
self.attributes = attributes
self.type = type

def to_authz_resource(self, **kwargs):
return AuthzResource(
id=self.id(**kwargs) if callable(self.id) else self.id,
namespace=self.namespace(**kwargs) if callable(
self.namespace) else self.namespace,
type=self.type(**kwargs) if callable(
self.type) else self.type,
attributes=self.attributes(**kwargs) if callable(
self.attributes) else self.attributes,
)


class AuthzDynamicParams:
@staticmethod
def from_function_name(**kwargs):
return partial(lambda **kwargs: kwargs['function'].__name__,
**kwargs)

@staticmethod
def from_function_args(**kwargs):
return partial(lambda **kwargs: kwargs['function_args'][kwargs['arg_num']],
**kwargs)

@staticmethod
def from_function_kwargs(**kwargs):
return partial(lambda **kwargs: kwargs['function_kwargs'][kwargs['arg_name']],
**kwargs)


@dataclass
class AuthzAction:
id: str
attributes: Optional[Dict[str, Any]] = None


@dataclass
class AuthorizationContext:
user: AuthzUser
resource: AuthzResource
action: AuthzAction


class ServerAuthorizationProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def authorize(self, context: AuthorizationContext) \
-> bool:
pass


class AuthorizationRequestContext(EnforceOverrides, ABC, Generic[T]):
@abstractmethod
def get_request(self) -> T:
...


class ChromaAuthzMiddleware(Component, Generic[T]):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def pre_process(
self, request: AuthorizationRequestContext
) -> None:
...

@abstractmethod
def ignore_operation(self, verb: str, path: str) -> bool:
...

@abstractmethod
def instrument_server(self, app: T) -> None:
...


class ServerAuthorizationConfigurationProvider(Component, Generic[T]):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def get_configuration(self) -> T:
pass


class AuthorizationError(ChromaError):
@override
def code(self) -> int:
return 403

@classmethod
@override
def name(cls) -> str:
return "AuthorizationError"
78 changes: 78 additions & 0 deletions chromadb/auth/authz/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import logging
from typing import Any, Dict, cast
from overrides import override
import yaml
from chromadb.auth import (
AuthorizationContext,
ServerAuthorizationConfigurationProvider,
ServerAuthorizationProvider,
)
from chromadb.auth.registry import register_provider, resolve_provider
from chromadb.config import System

logger = logging.getLogger(__name__)


@register_provider("local_authz_config")
class LocalUserConfigAuthorizationConfigurationProvider(
ServerAuthorizationConfigurationProvider[Dict[str, Any]]
):
_config_file: str
_config: Dict[str, Any]

def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
system.settings.require("chroma_server_authz_config_file")
self._config_file = str(
system.settings.chroma_server_authz_config_file)
with open(self._config_file, "r") as f:
self._config = yaml.safe_load(f)

@override
def get_configuration(self) -> Dict[str, Any]:
return self._config


@register_provider("simple_rbac")
class SimpleRBACAuthorizationProvider(ServerAuthorizationProvider):
_authz_config_provider: ServerAuthorizationConfigurationProvider[Dict[str, Any]]

def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
system.settings.require("chroma_server_authz_config_provider")
if self._settings.chroma_server_authz_config_provider:
_cls = resolve_provider(
self._settings.chroma_server_authz_config_provider,
ServerAuthorizationConfigurationProvider
)
self._authz_config_provider = cast(
ServerAuthorizationConfigurationProvider, self.require(_cls))
_config = self._authz_config_provider.get_configuration()
self._authz_tuples = []
for u in _config["users"]:
_actions = _config["roles_mapping"][u["role"]]["actions"]
for a in _actions:
self._authz_tuples.append((u["id"], *a.split(":")))
logger.debug(f"Loaded {len(self._authz_tuples)} permissions for "
f"({len(_config['users'])}) users")
logger.info(
"Authorization Provider SimpleRBACAuthorizationProvider initialized")

@override
def authorize(self, context: AuthorizationContext) \
-> bool:
_authz_tuple = (context.user.id,
context.resource.type,
context.action.id)

print(_authz_tuple)
policy_decision = False
if _authz_tuple in self._authz_tuples:
policy_decision = True
logger.debug(f"Authorization decision: Access "
f"{'granted' if policy_decision else 'denied'} for "
f"user [{context.user.id}] attempting to [{context.action.id}]"
f" on [{context.resource}]")
return policy_decision
32 changes: 27 additions & 5 deletions chromadb/auth/basic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
BasicAuthCredentials,
ClientAuthCredentialsProvider,
ClientAuthResponse,
SimpleServerAuthenticationResponse,
)
from chromadb.auth.registry import register_provider, resolve_provider
from chromadb.config import System
Expand Down Expand Up @@ -85,12 +86,33 @@ def __init__(self, system: System) -> None:
)

@override
def authenticate(self, request: ServerAuthenticationRequest[Any]) -> bool:
def authenticate(self, request: ServerAuthenticationRequest[Any]) \
-> SimpleServerAuthenticationResponse:
try:
_auth_header = request.get_auth_info(AuthInfoType.HEADER, "Authorization")
return self._credentials_provider.validate_credentials(
_auth_header = request.get_auth_info(
AuthInfoType.HEADER, "Authorization")
_validation = self._credentials_provider.validate_credentials(
BasicAuthCredentials.from_header(_auth_header)
)
return SimpleServerAuthenticationResponse(
_validation,
self._credentials_provider.get_user_identity(
BasicAuthCredentials.from_header(_auth_header)
),
)
except Exception as e:
logger.error(f"BasicAuthServerProvider.authenticate failed: {repr(e)}")
return False
logger.error(
f"BasicAuthServerProvider.authenticate failed: {repr(e)}")
return SimpleServerAuthenticationResponse(
False, None
)

# @override
# def get_auth_info_type(self, request: ServerAuthenticationRequest[Any]) \
# -> UserIdentity:
# _auth_header = request.get_auth_info(
# AuthInfoType.HEADER, "Authorization")
# _creds = BasicAuthCredentials.from_header(_auth_header)
# return SimpleUserIdentity(
# _creds.get_credentials()["username"].get_secret_value()
# )
Loading

0 comments on commit e63f57a

Please sign in to comment.