diff --git a/bento_lib/auth/middleware/base.py b/bento_lib/auth/middleware/base.py index a239c2c..019db43 100644 --- a/bento_lib/auth/middleware/base.py +++ b/bento_lib/auth/middleware/base.py @@ -3,7 +3,7 @@ import requests from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Callable from ..exceptions import BentoAuthException @@ -48,7 +48,15 @@ def check_require_token(require_token: bool, token: str | None) -> None: def mk_authz_url(self, path: str) -> str: return f"{self._bento_authz_service_url.rstrip('/')}{path}" - def _extract_token_and_build_headers(self, request: Any, require_token: bool) -> dict: + def _extract_token_and_build_headers( + self, + request: Any, + require_token: bool, + headers_getter: Callable[[Any], dict[str, str]] | None = None, + ) -> dict[str, str]: + if headers_getter: + return headers_getter(request) + tkn_header = self.get_authz_header_value(request) self.check_require_token(require_token, tkn_header) return {"Authorization": tkn_header} if tkn_header else {} @@ -58,11 +66,18 @@ def _gen_exc_non_200_error_from_authz(self, code: int, content: bytes): # Generic error - don't leak errors from authz service! raise BentoAuthException("Error from authz service", status_code=500) - def authz_post(self, request: Any, path: str, body: dict, require_token: bool = False) -> dict: + def authz_post( + self, + request: Any, + path: str, + body: dict, + require_token: bool = False, + headers_getter: Callable[[Any], dict[str, str]] | None = None, + ) -> dict: res = requests.post( self.mk_authz_url(path), json=body, - headers=self._extract_token_and_build_headers(request, require_token), + headers=self._extract_token_and_build_headers(request, require_token, headers_getter), verify=self._verify_ssl) if res.status_code != 200: # Invalid authorization service response @@ -70,12 +85,19 @@ def authz_post(self, request: Any, path: str, body: dict, require_token: bool = return res.json() - async def async_authz_post(self, request: Any, path: str, body: dict, require_token: bool = False) -> dict: + async def async_authz_post( + self, + request: Any, + path: str, + body: dict, + require_token: bool = False, + headers_getter: Callable[[Any], dict[str, str]] | None = None, + ) -> dict: async with aiohttp.ClientSession() as session: async with session.post( self.mk_authz_url(path), json=body, - headers=self._extract_token_and_build_headers(request, require_token), + headers=self._extract_token_and_build_headers(request, require_token, headers_getter), ssl=(None if self._verify_ssl else False)) as res: if res.status != 200: # Invalid authorization service response @@ -95,6 +117,7 @@ def check_authz_evaluate( resource: dict, require_token: bool = True, set_authz_flag: bool = False, + headers_getter: Callable[[Any], dict[str, str]] | None = None, ): if not self.enabled: return @@ -104,6 +127,7 @@ def check_authz_evaluate( "/policy/evaluate", body={"requested_resource": resource, "required_permissions": list(permissions)}, require_token=require_token, + headers_getter=headers_getter, ) if not res.get("result"): @@ -121,6 +145,7 @@ async def async_check_authz_evaluate( resource: dict, require_token: bool = True, set_authz_flag: bool = False, + headers_getter: Callable[[Any], dict[str, str]] | None = None, ): if not self.enabled: return @@ -130,6 +155,7 @@ async def async_check_authz_evaluate( "/policy/evaluate", body={"requested_resource": resource, "required_permissions": list(permissions)}, require_token=require_token, + headers_getter=headers_getter, ) if not res.get("result"): diff --git a/bento_lib/package.cfg b/bento_lib/package.cfg index 5bdd696..2080101 100644 --- a/bento_lib/package.cfg +++ b/bento_lib/package.cfg @@ -1,5 +1,5 @@ [package] name = bento_lib -version = 7.0.0a3 +version = 7.0.0a4 authors = David Lougheed, Paul Pillot author_emails = david.lougheed@mail.mcgill.ca, paul.pillot@computationalgenomics.ca diff --git a/tests/test_platform_fastapi.py b/tests/test_platform_fastapi.py index 9942024..94c6649 100644 --- a/tests/test_platform_fastapi.py +++ b/tests/test_platform_fastapi.py @@ -12,6 +12,7 @@ from pydantic import BaseModel from bento_lib.auth.exceptions import BentoAuthException +from bento_lib.auth.middleware.constants import RESOURCE_EVERYTHING from bento_lib.auth.middleware.fastapi import FastApiAuthMiddleware from bento_lib.responses.fastapi_errors import ( http_exception_handler_factory, @@ -36,6 +37,11 @@ class TestBody(BaseModel): test2: str +class TestTokenBody(BaseModel): + token: str + payload: str + + # Standard test app ----------------------------------------------------------- test_app = FastAPI() @@ -132,6 +138,20 @@ def auth_get_500(): raise HTTPException(500, "Internal Server Error") +@test_app_auth.post("/post-with-token-in-body") +async def auth_post_with_token_in_body(request: Request, body: TestTokenBody): + token = body.token + await auth_middleware.async_check_authz_evaluate( + request, + frozenset({PERMISSION_INGEST_DATA}), + RESOURCE_EVERYTHING, + require_token=True, + set_authz_flag=True, + headers_getter=(lambda _r: {"Authorization": f"Bearer {token}"}), + ) + return JSONResponse({"payload": body.payload}) + + # Auth test app (disabled auth middleware) ------------------------------------ test_app_auth_disabled = FastAPI() @@ -253,6 +273,13 @@ def test_fastapi_auth_options_call(aioresponse: aioresponses, fastapi_client_aut assert r.status_code == 200 +def test_fastapi_auth_post_with_token_in_body(aioresponse: aioresponses, fastapi_client_auth: TestClient): + aioresponse.post("https://bento-auth.local/policy/evaluate", status=200, payload={"result": True}) + r = fastapi_client_auth.post("/post-with-token-in-body", json={"token": "test", "payload": "hello world"}) + assert r.status_code == 200 + assert r.text == '{"payload":"hello world"}' + + @pytest.mark.asyncio async def test_fastapi_auth_disabled(aioresponse: aioresponses, fastapi_client_auth_disabled: TestClient): # middleware is disabled, should work anyway diff --git a/tests/test_platform_flask.py b/tests/test_platform_flask.py index b441684..86c82f1 100644 --- a/tests/test_platform_flask.py +++ b/tests/test_platform_flask.py @@ -8,6 +8,7 @@ from flask.testing import FlaskClient from werkzeug.exceptions import BadRequest, NotFound, InternalServerError +from bento_lib.auth.middleware.constants import RESOURCE_EVERYTHING from bento_lib.auth.middleware.flask import FlaskAuthMiddleware from .common import ( @@ -99,6 +100,20 @@ def auth_500(): def auth_404(): raise NotFound() + @test_app_auth.route("/post-with-token-in-body", methods=["POST"]) + def auth_post_with_token_in_body(): + token = request.json["token"] + payload = request.json["payload"] + auth_middleware.check_authz_evaluate( + request, + frozenset({PERMISSION_INGEST_DATA}), + RESOURCE_EVERYTHING, + require_token=True, + set_authz_flag=True, + headers_getter=(lambda _r: {"Authorization": f"Bearer {token}"}), + ) + return jsonify({"payload": payload}) + with test_app_auth.test_client() as client: yield client @@ -208,7 +223,15 @@ def test_flask_auth_404(flask_client_auth: FlaskClient): @responses.activate -def test_fastapi_auth_disabled(flask_client_auth_disabled_with_middleware: tuple[FlaskClient, FlaskAuthMiddleware]): +def test_flask_auth_post_with_token_in_body(flask_client_auth: FlaskClient): + responses.add(responses.POST, "https://bento-auth.local/policy/evaluate", json={"result": True}, status=200) + r = flask_client_auth.post("/post-with-token-in-body", json={"token": "test", "payload": "hello world"}) + assert r.status_code == 200 + assert r.text == '{"payload":"hello world"}\n' + + +@responses.activate +def test_flask_auth_disabled(flask_client_auth_disabled_with_middleware: tuple[FlaskClient, FlaskAuthMiddleware]): flask_client_auth_disabled, auth_middleware_disabled = flask_client_auth_disabled_with_middleware # middleware is disabled, should work anyway