Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH]: CIP - Authorization #1250

Merged
merged 23 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
9b3d8a4
feat: CIP - Authorization
tazarov Oct 16, 2023
e63f57a
feat: Working implementation and updated CIP
tazarov Oct 19, 2023
a4ee5e5
Merge branch 'main' into feature/authz
tazarov Oct 19, 2023
19daf90
feat: Authorization
tazarov Oct 19, 2023
d86ac2e
chore: Updated README.md for authz example
tazarov Oct 19, 2023
02b4ac0
chore: Added property tests
tazarov Oct 23, 2023
3e2803f
fix: Issue with subscripting a primitive tuple type
tazarov Oct 23, 2023
57ec406
fix: Made rmtree to ignore errors
tazarov Oct 24, 2023
0f86c81
fix: moving authz in a separate test
tazarov Oct 24, 2023
1f6a8b8
fix: Temp fix for telemetry
tazarov Oct 24, 2023
393b42e
feat: expanding client auth provided capabilities to allow for multip…
tazarov Oct 25, 2023
d0c2d02
feat: Addendum to last commit
tazarov Oct 25, 2023
7ee8263
Merge remote-tracking branch 'origin/main' into feature/authz
tazarov Oct 25, 2023
c9cd702
feat: Refactored to support tenant and db
tazarov Oct 25, 2023
69937f5
chore: Updated docs with new actions for db and tenant
tazarov Oct 25, 2023
ed333aa
fix: Moving Authz in a separate test
tazarov Oct 26, 2023
17fd267
fix: Reduce the number of examples to prevent `Exception of type 'Sys…
tazarov Oct 26, 2023
38fd73a
fix: Small refactoring to address comments for review
tazarov Oct 26, 2023
bc8804b
fix: Fixed an issue with test
tazarov Oct 26, 2023
5c61770
fix: Fixed wrong tenant assignment in SimpleUserIdentity
tazarov Oct 26, 2023
37b8c5d
feat: Enhanced UserIdentity with additional attributes
tazarov Oct 26, 2023
33aed93
fix: Fixed delete collection failing test
tazarov Oct 26, 2023
8abbf83
fix: Fix for peek, needs to include get as action in the role
tazarov Oct 26, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/chroma-integration-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ jobs:
matrix:
python: ['3.8']
platform: [ubuntu-latest, windows-latest]
testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore='chromadb/test/test_cli.py'",
testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore='chromadb/test/test_cli.py' --ignore='chromadb/test/auth/test_simple_rbac_authz.py'",
"chromadb/test/property/test_add.py",
"chromadb/test/test_cli.py",
"chromadb/test/auth/test_simple_rbac_authz.py",
"chromadb/test/property/test_collections.py",
"chromadb/test/property/test_cross_version_persist.py",
"chromadb/test/property/test_embeddings.py",
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/chroma-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ jobs:
matrix:
python: ['3.8', '3.9', '3.10', '3.11']
platform: [ubuntu-latest, windows-latest]
testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore-glob 'chromadb/test/stress/*'",
testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore-glob 'chromadb/test/stress/*' --ignore='chromadb/test/auth/test_simple_rbac_authz.py'",
"chromadb/test/auth/test_simple_rbac_authz.py",
"chromadb/test/property/test_add.py",
"chromadb/test/property/test_collections.py",
"chromadb/test/property/test_cross_version_persist.py",
Expand Down
255 changes: 244 additions & 11 deletions chromadb/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,29 @@
Contains only Auth abstractions, no implementations.
"""
import base64
from functools import partial
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Any,
Callable,
List,
Optional,
Dict,
TypeVar,
Tuple,
Generic,
Union,
)
from dataclasses import dataclass

from overrides import EnforceOverrides, override
from pydantic import SecretStr

from chromadb.config import (
DEFAULT_DATABASE,
DEFAULT_TENANT,
Component,
System,
)
Expand All @@ -35,13 +43,53 @@ class AuthInfoType(Enum):
METADATA = "metadata" # gRPC


class UserIdentity(EnforceOverrides, ABC):
HammadB marked this conversation as resolved.
Show resolved Hide resolved
@abstractmethod
def get_user_id(self) -> str:
...

@abstractmethod
def get_user_tenant(self) -> Optional[str]:
...

@abstractmethod
def get_user_attributes(self) -> Optional[Dict[str, Any]]:
...


class SimpleUserIdentity(UserIdentity):
def __init__(
self,
user_id: str,
tenant: Optional[str] = None,
attributes: Optional[Dict[str, Any]] = None,
) -> None:
self._user_id = user_id
self._tenant = tenant
self._attributes = attributes

@override
def get_user_id(self) -> str:
return self._user_id

@override
def get_user_tenant(self) -> Optional[str]:
return self._tenant if self._tenant else DEFAULT_TENANT

@override
def get_user_attributes(self) -> Optional[Dict[str, Any]]:
return self._attributes


class ClientAuthResponse(EnforceOverrides, ABC):
@abstractmethod
def get_auth_info_type(self) -> AuthInfoType:
...

@abstractmethod
def get_auth_info(self) -> Tuple[str, SecretStr]:
def get_auth_info(
self,
) -> Union[Tuple[str, SecretStr], List[Tuple[str, SecretStr]]]:
...


Expand Down Expand Up @@ -86,12 +134,11 @@ def inject_credentials(self, injection_context: T) -> None:

class ServerAuthenticationRequest(EnforceOverrides, ABC, Generic[T]):
@abstractmethod
def get_auth_info(
self, auth_info_type: AuthInfoType, auth_info_id: Optional[str] = None
) -> T:
def get_auth_info(self, auth_info_type: AuthInfoType, auth_info_id: str) -> T:
"""
This method should return the necessary auth info based on the type of authentication (e.g. header, cookie, url)
and a given id for the respective auth type (e.g. name of the header, cookie, url param).
This method should return the necessary auth info based on the type of
authentication (e.g. header, cookie, url) and a given id for the respective
auth type (e.g. name of the header, cookie, url param).

:param auth_info_type: The type of auth info to return
:param auth_info_id: The id of the auth info to return
Expand All @@ -101,16 +148,44 @@ def get_auth_info(


class ServerAuthenticationResponse(EnforceOverrides, ABC):
@abstractmethod
def success(self) -> bool:
...

@abstractmethod
def get_user_identity(self) -> Optional[UserIdentity]:
...


class SimpleServerAuthenticationResponse(ServerAuthenticationResponse):
"""Simple implementation of ServerAuthenticationResponse"""

_auth_success: bool
_user_identity: Optional[UserIdentity]

def __init__(
self, auth_success: bool, user_identity: Optional[UserIdentity]
) -> None:
self._auth_success = auth_success
self._user_identity = user_identity

@override
def success(self) -> bool:
raise NotImplementedError()
return self._auth_success

@override
def get_user_identity(self) -> Optional[UserIdentity]:
return self._user_identity


class ServerAuthProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def authenticate(self, request: ServerAuthenticationRequest[T]) -> bool:
def authenticate(
self, request: ServerAuthenticationRequest[T]
) -> ServerAuthenticationResponse:
pass


Expand All @@ -121,7 +196,7 @@ def __init__(self, system: System) -> None:
@abstractmethod
def authenticate(
self, request: ServerAuthenticationRequest[T]
) -> Optional[ServerAuthenticationResponse]:
) -> ServerAuthenticationResponse:
...

@abstractmethod
Expand Down Expand Up @@ -155,8 +230,8 @@ def name(cls) -> str:

class AbstractCredentials(EnforceOverrides, ABC, Generic[T]):
"""
The class is used by Auth Providers to encapsulate credentials received from the server
and pass them to a ServerAuthCredentialsProvider.
The class is used by Auth Providers to encapsulate credentials received
from the server and pass them to a ServerAuthCredentialsProvider.
"""

@abstractmethod
Expand Down Expand Up @@ -204,4 +279,162 @@ def __init__(self, system: System) -> None:

@abstractmethod
def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool:
...

@abstractmethod
def get_user_identity(
self, credentials: AbstractCredentials[T]
) -> Optional[UserIdentity]:
...


class AuthzResourceTypes(str, Enum):
DB = "db"
COLLECTION = "collection"
TENANT = "tenant"


class AuthzResourceActions(str, Enum):
CREATE_DATABASE = "create_database"
GET_DATABASE = "get_database"
CREATE_TENANT = "create_tenant"
GET_TENANT = "get_tenant"
LIST_COLLECTIONS = "list_collections"
GET_COLLECTION = "get_collection"
CREATE_COLLECTION = "create_collection"
GET_OR_CREATE_COLLECTION = "get_or_create_collection"
DELETE_COLLECTION = "delete_collection"
UPDATE_COLLECTION = "update_collection"
ADD = "add"
DELETE = "delete"
GET = "get"
QUERY = "query"
COUNT = "count"
UPDATE = "update"
UPSERT = "upsert"
RESET = "reset"


@dataclass
class AuthzUser:
id: Optional[str]
tenant: Optional[str] = DEFAULT_TENANT
attributes: Optional[Dict[str, Any]] = None
claims: Optional[Dict[str, Any]] = None


@dataclass
class AuthzResource:
id: Optional[str]
type: Optional[str]
attributes: Optional[Dict[str, Any]] = None


class DynamicAuthzResource:
id: Optional[Union[str, Callable[..., str]]]
type: Optional[Union[str, Callable[..., str]]]
attributes: Optional[Union[Dict[str, Any], Callable[..., Dict[str, Any]]]]

def __init__(
self,
id: Optional[Union[str, Callable[..., str]]] = None,
attributes: Optional[
Union[Dict[str, Any], Callable[..., Dict[str, Any]]]
] = lambda **kwargs: {},
type: Optional[Union[str, Callable[..., str]]] = DEFAULT_DATABASE,
) -> None:
self.id = id
self.attributes = attributes
self.type = type

def to_authz_resource(self, **kwargs: Any) -> AuthzResource:
return AuthzResource(
id=self.id(**kwargs) if callable(self.id) else self.id,
type=self.type(**kwargs) if callable(self.type) else self.type,
attributes=self.attributes(**kwargs)
if callable(self.attributes)
else self.attributes,
)


class AuthzDynamicParams:
@staticmethod
def from_function_name(**kwargs: Any) -> Callable[..., str]:
return partial(lambda **kwargs: kwargs["function"].__name__, **kwargs)

@staticmethod
def from_function_args(**kwargs: Any) -> Callable[..., str]:
return partial(
lambda **kwargs: kwargs["function_args"][kwargs["arg_num"]], **kwargs
)

@staticmethod
def from_function_kwargs(**kwargs: Any) -> Callable[..., str]:
return partial(
lambda **kwargs: kwargs["function_kwargs"][kwargs["arg_name"]], **kwargs
)


@dataclass
class AuthzAction:
id: str
attributes: Optional[Dict[str, Any]] = None


@dataclass
class AuthorizationContext:
user: AuthzUser
resource: AuthzResource
action: AuthzAction


class ServerAuthorizationProvider(Component):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def authorize(self, context: AuthorizationContext) -> bool:
pass


class AuthorizationRequestContext(EnforceOverrides, ABC, Generic[T]):
@abstractmethod
def get_request(self) -> T:
...


class ChromaAuthzMiddleware(Component, Generic[T, S]):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def pre_process(self, request: AuthorizationRequestContext[S]) -> None:
...

@abstractmethod
def ignore_operation(self, verb: str, path: str) -> bool:
...

@abstractmethod
def instrument_server(self, app: T) -> None:
...


class ServerAuthorizationConfigurationProvider(Component, Generic[T]):
def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def get_configuration(self) -> T:
pass


class AuthorizationError(ChromaError):
@override
def code(self) -> int:
return 403

@classmethod
@override
def name(cls) -> str:
return "AuthorizationError"
Loading