Skip to content

Commit

Permalink
Merge pull request #131 from bento-platform/chore/mark-authz-done-mixin
Browse files Browse the repository at this point in the history
chore: move mark authz done fn into mixin for use in authz service
  • Loading branch information
davidlougheed authored Oct 20, 2023
2 parents b816a5e + 4bd7147 commit 32c177e
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 19 deletions.
14 changes: 4 additions & 10 deletions bento_lib/auth/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
from typing import Any, Callable, Iterable

from ..exceptions import BentoAuthException
from ..types import EvaluationResultMatrix
from .mark_authz_done_mixin import MarkAuthzDoneMixin

__all__ = ["EvaluationResultMatrix", "BaseAuthMiddleware"]
__all__ = ["BaseAuthMiddleware"]


EvaluationResultMatrix = tuple[tuple[bool, ...], ...]


class BaseAuthMiddleware(ABC):
class BaseAuthMiddleware(ABC, MarkAuthzDoneMixin):
def __init__(
self,
bento_authz_service_url: str,
Expand Down Expand Up @@ -192,11 +191,6 @@ async def async_evaluate_one(
)
)[0][0]

@staticmethod
@abstractmethod
def mark_authz_done(request: Any): # pragma: no cover
pass

def check_authz_evaluate(
self,
request: Any,
Expand Down
12 changes: 12 additions & 0 deletions bento_lib/auth/middleware/mark_authz_done_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from abc import abstractmethod
from typing import Any


__all__ = ["MarkAuthzDoneMixin"]


class MarkAuthzDoneMixin:
@staticmethod
@abstractmethod
def mark_authz_done(request: Any): # pragma: no cover
pass
13 changes: 13 additions & 0 deletions bento_lib/auth/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Type

from ..auth.middleware.mark_authz_done_mixin import MarkAuthzDoneMixin

__all__ = [
"EvaluationResultMatrix",
"MarkAuthzDoneType",
]

EvaluationResultMatrix = tuple[tuple[bool, ...], ...]

# Allow subclass OR instance, since mark_authz_done is a static method:
MarkAuthzDoneType = MarkAuthzDoneMixin | Type[MarkAuthzDoneMixin]
2 changes: 1 addition & 1 deletion bento_lib/package.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[package]
name = bento_lib
version = 9.0.0a2
version = 9.0.0a3
authors = David Lougheed, Paul Pillot
author_emails = [email protected], [email protected]
8 changes: 4 additions & 4 deletions bento_lib/responses/fastapi_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from starlette.responses import Response
from typing import Callable

from ..auth.middleware.fastapi import FastApiAuthMiddleware
from ..auth.exceptions import BentoAuthException
from ..auth.types import MarkAuthzDoneType
from .errors import http_error

__all__ = [
Expand All @@ -26,7 +26,7 @@ def _log_if_500(logger: logging.Logger, code: int, exc: Exception) -> None:

def http_exception_handler_factory(
logger: logging.Logger,
authz: FastApiAuthMiddleware | None = None,
authz: MarkAuthzDoneType | None = None,
**kwargs,
) -> Callable[[Request, HTTPException], Response]:
def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
Expand All @@ -41,7 +41,7 @@ def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse

def bento_auth_exception_handler_factory(
logger: logging.Logger,
authz: FastApiAuthMiddleware | None = None,
authz: MarkAuthzDoneType | None = None,
**kwargs,
) -> Callable[[Request, BentoAuthException], Response]:
def bento_auth_exception_handler(request: Request, exc: BentoAuthException) -> JSONResponse:
Expand All @@ -55,7 +55,7 @@ def bento_auth_exception_handler(request: Request, exc: BentoAuthException) -> J


def validation_exception_handler_factory(
authz: FastApiAuthMiddleware | None = None,
authz: MarkAuthzDoneType | None = None,
**kwargs,
) -> Callable[[Request, RequestValidationError], Response]:
def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
Expand Down
8 changes: 4 additions & 4 deletions bento_lib/responses/flask_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from functools import partial
from typing import Callable

from bento_lib.auth.middleware.flask import FlaskAuthMiddleware
from bento_lib.responses import errors
from ..auth.types import MarkAuthzDoneType
from ..responses import errors


__all__ = [
Expand Down Expand Up @@ -38,7 +38,7 @@ def flask_error_wrap_with_traceback(fn: Callable, *args, **kwargs) -> Callable:
service_name = kwargs.pop("service_name", "Bento Service")

logger = kwargs.pop("logger", None)
authz: FlaskAuthMiddleware | None = kwargs.pop("authz", None)
authz: MarkAuthzDoneType | None = kwargs.pop("authz", None)

def handle_error(e):
if logger:
Expand All @@ -62,7 +62,7 @@ def flask_error_wrap(fn: Callable, *args, **kwargs) -> Callable:
:return: The wrapped function
"""

authz: FlaskAuthMiddleware | None = kwargs.pop("authz", None)
authz: MarkAuthzDoneType | None = kwargs.pop("authz", None)

def handle_error(e):
if authz:
Expand Down

0 comments on commit 32c177e

Please sign in to comment.