Skip to content

Commit

Permalink
fix: Small refactoring to address comments for review
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Oct 26, 2023
1 parent 17fd267 commit 38fd73a
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 126 deletions.
38 changes: 4 additions & 34 deletions chromadb/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from pydantic import SecretStr

from chromadb.config import (
DEFAULT_DATABASE,
DEFAULT_TENANT,
Component,
System,
)
Expand Down Expand Up @@ -272,9 +274,6 @@ def get_user_identity(
...


# --- AuthZ ---#


class AuthzResourceTypes(str, Enum):
DB = "db"
COLLECTION = "collection"
Expand All @@ -296,7 +295,6 @@ class AuthzResourceActions(str, Enum):
DELETE = "delete"
GET = "get"
QUERY = "query"
PEEK = "peek"
COUNT = "count"
UPDATE = "update"
UPSERT = "upsert"
Expand All @@ -306,7 +304,7 @@ class AuthzResourceActions(str, Enum):
@dataclass
class AuthzUser:
id: Optional[str]
tenant: Optional[str] = "*"
tenant: Optional[str] = DEFAULT_TENANT
attributes: Optional[Dict[str, Any]] = None
claims: Optional[Dict[str, Any]] = None

Expand All @@ -329,7 +327,7 @@ def __init__(
attributes: Optional[
Union[Dict[str, Any], Callable[..., Dict[str, Any]]]
] = lambda **kwargs: {},
type: Optional[Union[str, Callable[..., str]]] = "default_database",
type: Optional[Union[str, Callable[..., str]]] = DEFAULT_DATABASE,
) -> None:
self.id = id
self.attributes = attributes
Expand All @@ -345,27 +343,6 @@ def to_authz_resource(self, **kwargs: Any) -> AuthzResource:
)


def find_key_with_value_of_type(
type: AuthzResourceTypes, **kwargs: Any
) -> Dict[str, Any]:
from chromadb.server.fastapi.types import (
CreateCollection,
CreateDatabase,
CreateTenant,
)

for key, value in kwargs.items():
if type == AuthzResourceTypes.DB and isinstance(value, CreateDatabase):
return dict(value)
elif type == AuthzResourceTypes.COLLECTION and isinstance(
value, CreateCollection
):
return dict(value)
elif type == AuthzResourceTypes.TENANT and isinstance(value, CreateTenant):
return dict(value)
return {}


class AuthzDynamicParams:
@staticmethod
def from_function_name(**kwargs: Any) -> Callable[..., str]:
Expand All @@ -383,13 +360,6 @@ def from_function_kwargs(**kwargs: Any) -> Callable[..., str]:
lambda **kwargs: kwargs["function_kwargs"][kwargs["arg_name"]], **kwargs
)

@staticmethod
def attr_from_resource_object(
type: AuthzResourceTypes, **kwargs: Any
) -> Callable[..., Dict[str, Any]]:
obj = find_key_with_value_of_type(type, **kwargs)
return partial(lambda **kwargs: obj, **kwargs)


@dataclass
class AuthzAction:
Expand Down
21 changes: 14 additions & 7 deletions chromadb/auth/authz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, cast
from typing import Any, Dict, Set, cast
from overrides import override
import yaml
from chromadb.auth import (
Expand All @@ -8,7 +8,7 @@
ServerAuthorizationProvider,
)
from chromadb.auth.registry import register_provider, resolve_provider
from chromadb.config import System
from chromadb.config import DEFAULT_TENANT, System

from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
Expand Down Expand Up @@ -64,14 +64,18 @@ def __init__(self, system: System) -> None:
self.require(_cls),
)
_config = self._authz_config_provider.get_configuration()
self._authz_tuples = []
self._authz_tuples_map: Dict[str, Set[Any]] = {}
for u in _config["users"]:
_actions = _config["roles_mapping"][u["role"]]["actions"]
for a in _actions:
tenant = u["tenant"] if "tenant" in u else "*"
self._authz_tuples.append((u["id"], tenant, *a.split(":")))
tenant = u["tenant"] if "tenant" in u else DEFAULT_TENANT
if u["id"] not in self._authz_tuples_map.keys():
self._authz_tuples_map[u["id"]] = set()
self._authz_tuples_map[u["id"]].add(
(u["id"], tenant, *a.split(":"))
)
logger.debug(
f"Loaded {len(self._authz_tuples)} permissions for "
f"Loaded {len(self._authz_tuples_map)} permissions for "
f"({len(_config['users'])}) users"
)
logger.info(
Expand All @@ -92,7 +96,10 @@ def authorize(self, context: AuthorizationContext) -> bool:
)

policy_decision = False
if _authz_tuple in self._authz_tuples:
if (
context.user.id in self._authz_tuples_map.keys()
and _authz_tuple in self._authz_tuples_map[context.user.id]
):
policy_decision = True
logger.debug(
f"Authorization decision: Access "
Expand Down
10 changes: 0 additions & 10 deletions chromadb/auth/basic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,3 @@ def authenticate(
except Exception as e:
logger.error(f"BasicAuthServerProvider.authenticate failed: {repr(e)}")
return SimpleServerAuthenticationResponse(False, None)

# @override
# def get_auth_info_type(self, request: ServerAuthenticationRequest[Any]) \
# -> UserIdentity:
# _auth_header = request.get_auth_info(
# AuthInfoType.HEADER, "Authorization")
# _creds = BasicAuthCredentials.from_header(_auth_header)
# return SimpleUserIdentity(
# _creds.get_credentials()["username"].get_secret_value()
# )
21 changes: 12 additions & 9 deletions chromadb/auth/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from starlette.responses import Response, JSONResponse
from starlette.types import ASGIApp

from chromadb.config import System
from chromadb.config import DEFAULT_TENANT, System
from chromadb.auth import (
AuthorizationContext,
AuthorizationError,
Expand Down Expand Up @@ -113,7 +113,7 @@ def ignore_operation(self, verb: str, path: str) -> bool:
def instrument_server(self, app: ASGIApp) -> None:
# We can potentially add an `/auth` endpoint to the server to allow for more
# complex auth flows
return
raise NotImplementedError("Not implemented yet")


class FastAPIChromaAuthMiddlewareWrapper(BaseHTTPMiddleware): # type: ignore
Expand All @@ -122,7 +122,10 @@ def __init__(
) -> None:
super().__init__(app)
self._middleware = auth_middleware
self._middleware.instrument_server(app)
try:
self._middleware.instrument_server(app)
except NotImplementedError:
pass

@trace_method(
"FastAPIChromaAuthMiddlewareWrapper.dispatch", OpenTelemetryGranularity.ALL
Expand All @@ -145,9 +148,6 @@ async def dispatch(
return await call_next(request)


# AuthZ


request_var: ContextVar[Optional[Request]] = ContextVar("request_var", default=None)
authz_provider: ContextVar[Optional[ServerAuthorizationProvider]] = ContextVar(
"authz_provider", default=None
Expand Down Expand Up @@ -189,7 +189,7 @@ def wrapped(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
else "Anonymous",
tenant=request.state.user_identity.get_user_tenant()
if hasattr(request.state, "user_identity")
else "*",
else DEFAULT_TENANT,
),
resource=_resource,
action=_action,
Expand Down Expand Up @@ -259,7 +259,7 @@ def ignore_operation(self, verb: str, path: str) -> bool:
def instrument_server(self, app: ASGIApp) -> None:
# We can potentially add an `/auth` endpoint to the server to allow
# for more complex auth flows
return
raise NotImplementedError("Not implemented yet")


class FastAPIChromaAuthzMiddlewareWrapper(BaseHTTPMiddleware): # type: ignore
Expand All @@ -268,7 +268,10 @@ def __init__(
) -> None:
super().__init__(app)
self._middleware = authz_middleware
self._middleware.instrument_server(app)
try:
self._middleware.instrument_server(app)
except NotImplementedError:
pass

@trace_method(
"FastAPIChromaAuthzMiddlewareWrapper.dispatch", OpenTelemetryGranularity.ALL
Expand Down
31 changes: 31 additions & 0 deletions chromadb/auth/fastapi_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from functools import partial
from typing import Any, Callable, Dict
from chromadb.auth import AuthzResourceTypes


def find_key_with_value_of_type(
type: AuthzResourceTypes, **kwargs: Any
) -> Dict[str, Any]:
from chromadb.server.fastapi.types import (
CreateCollection,
CreateDatabase,
CreateTenant,
)

for key, value in kwargs.items():
if type == AuthzResourceTypes.DB and isinstance(value, CreateDatabase):
return dict(value)
elif type == AuthzResourceTypes.COLLECTION and isinstance(
value, CreateCollection
):
return dict(value)
elif type == AuthzResourceTypes.TENANT and isinstance(value, CreateTenant):
return dict(value)
return {}


def attr_from_resource_object(
type: AuthzResourceTypes, **kwargs: Any
) -> Callable[..., Dict[str, Any]]:
obj = find_key_with_value_of_type(type, **kwargs)
return partial(lambda **kwargs: obj, **kwargs)
11 changes: 5 additions & 6 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
FastAPIChromaAuthzMiddlewareWrapper,
authz_context,
)
from chromadb.auth.fastapi_utils import attr_from_resource_object
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
import chromadb.server
import chromadb.api
Expand Down Expand Up @@ -280,9 +281,7 @@ def version(self) -> str:
action=AuthzResourceActions.CREATE_DATABASE,
resource=DynamicAuthzResource(
type=AuthzResourceTypes.DB,
attributes=AuthzDynamicParams.attr_from_resource_object(
type=AuthzResourceTypes.DB
),
attributes=attr_from_resource_object(type=AuthzResourceTypes.DB),
),
)
def create_database(
Expand Down Expand Up @@ -444,7 +443,7 @@ def add(self, collection_id: str, add: AddEmbedding) -> None:
)
def update(self, collection_id: str, add: UpdateEmbedding) -> None:
return self._api._update(
ids=add.ids, # type: ignore
ids=add.ids,
collection_id=_uuid(collection_id),
embeddings=add.embeddings,
documents=add.documents, # type: ignore
Expand All @@ -461,7 +460,7 @@ def update(self, collection_id: str, add: UpdateEmbedding) -> None:
)
def upsert(self, collection_id: str, upsert: AddEmbedding) -> None:
return self._api._upsert(
collection_id=_uuid(collection_id), # type: ignore
collection_id=_uuid(collection_id),
ids=upsert.ids,
embeddings=upsert.embeddings, # type: ignore
documents=upsert.documents, # type: ignore
Expand All @@ -470,7 +469,7 @@ def upsert(self, collection_id: str, upsert: AddEmbedding) -> None:

@trace_method("FastAPI.get", OpenTelemetryGranularity.OPERATION)
@authz_context(
action=[AuthzResourceActions.GET, AuthzResourceActions.PEEK],
action=AuthzResourceActions.GET,
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(arg_name="collection_id"),
type=AuthzResourceTypes.COLLECTION,
Expand Down
Loading

0 comments on commit 38fd73a

Please sign in to comment.