Skip to content

Commit

Permalink
fix: Apply code changes to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dominik003 committed Jul 29, 2024
1 parent ee0535b commit 4ca31f3
Show file tree
Hide file tree
Showing 24 changed files with 355 additions and 223 deletions.
2 changes: 1 addition & 1 deletion backend/capellacollab/config/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ class AuthenticationConfig(BaseConfig):
mapping: ClaimMappingConfig = ClaimMappingConfig()
scopes: list[str] = pydantic.Field(
default=["openid", "profile", "offline_access"],
description="List of scopes that application neeeds to access the required attributes.",
description="List of scopes that the application needs to access the required attributes.",
)
client: AuthOauthClientConfig = AuthOauthClientConfig()
redirect_uri: str = pydantic.Field(
Expand Down
42 changes: 17 additions & 25 deletions backend/capellacollab/core/authentication/api_key_cookie.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,24 @@
auth_config = config.authentication


class JWTConfigBorg:
_shared_state: dict[str, str] = {}

def __init__(
self, provider_config: oidc_provider.AbstractOIDCProviderConfig
):
self.__dict__ = self._shared_state
self.provider_config = provider_config

if not hasattr(self, "jwks_client"):
self.jwks_client = jwt.PyJWKClient(
uri=self.provider_config.get_jwks_uri()
class JWTConfig:
_jwks_client = None

def __init__(self, oidc_config: oidc_provider.AbstractOIDCProviderConfig):
self.oidc_config = oidc_config

if JWTConfig._jwks_client is None:
JWTConfig._jwks_client = jwt.PyJWKClient(
uri=self.oidc_config.get_jwks_uri()
)
self.jwks_client = JWTConfig._jwks_client


class JWTAPIKeyCookie(security.APIKeyCookie):
def __init__(
self, provider_config: oidc_provider.AbstractOIDCProviderConfig
):
def __init__(self, oidc_config: oidc_provider.AbstractOIDCProviderConfig):
super().__init__(name="id_token", auto_error=True)
self.provider_config = provider_config
self.jwt_config = JWTConfigBorg(provider_config)
self.oidc_config = oidc_config
self.jwt_config = JWTConfig(oidc_config)

async def __call__(self, request: fastapi.Request) -> str:
token: str | None = await super().__call__(request)
Expand All @@ -59,14 +55,10 @@ def validate_token(self, token: str) -> dict[str, t.Any]:
return jwt.decode(
jwt=token,
key=signing_key.key,
algorithms=self.provider_config.get_supported_signing_algorithms(),
audience=self.provider_config.get_client_id(),
issuer=self.provider_config.get_issuer(),
options={
"verify_exp": True,
"verify_iat": True,
"verify_nbf": True,
},
algorithms=self.oidc_config.get_supported_signing_algorithms(),
audience=self.oidc_config.get_client_id(),
issuer=self.oidc_config.get_issuer(),
options={"require": ["exp", "iat"]},
)
except jwt_exceptions.ExpiredSignatureError:
raise exceptions.TokenSignatureExpired()
Expand Down
24 changes: 9 additions & 15 deletions backend/capellacollab/core/authentication/injectables.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,20 @@


@functools.lru_cache
def get_cached_oidc_provider_config() -> (
oidc_provider.AbstractOIDCProviderConfig
):
def get_cached_oidc_config() -> oidc_provider.AbstractOIDCProviderConfig:
return oidc_provider.WellKnownOIDCProviderConfig()


async def get_oidc_provider_config() -> (
oidc_provider.AbstractOIDCProviderConfig
):
return get_cached_oidc_provider_config()
async def get_oidc_config() -> oidc_provider.AbstractOIDCProviderConfig:
return get_cached_oidc_config()


async def get_oidc_provider(
oidc_provider_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
get_oidc_provider_config
oidc_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
get_oidc_config
),
) -> oidc_provider.AbstractOIDCProvider:
return oidc_provider.OIDCProvider(oidc_provider_config)
return oidc_provider.OIDCProvider(oidc_config)


class OpenAPIFakeBase(security_base.SecurityBase):
Expand Down Expand Up @@ -83,15 +79,13 @@ class OpenAPIPersonalAccessToken(OpenAPIFakeBase):

async def get_username(
request: fastapi.Request,
provider_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
get_oidc_provider_config
oidc_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
get_oidc_config
),
_unused1=fastapi.Depends(OpenAPIPersonalAccessToken()),
) -> str:
if request.cookies.get("id_token"):
username = await api_key_cookie.JWTAPIKeyCookie(provider_config)(
request
)
username = await api_key_cookie.JWTAPIKeyCookie(oidc_config)(request)
return username

authorization = request.headers.get("Authorization")
Expand Down
24 changes: 12 additions & 12 deletions backend/capellacollab/core/authentication/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ async def api_get_token(
provider: oidc_provider.AbstractOIDCProvider = fastapi.Depends(
injectables.get_oidc_provider
),
provider_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
injectables.get_oidc_provider_config
oidc_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
injectables.get_oidc_config
),
):
tokens = provider.exchange_code_for_tokens(
token_request.code, token_request.code_verifier
)

validated_id_token = validate_id_token(
tokens["id_token"], provider_config, None
tokens["id_token"], oidc_config, None
)
user = create_or_update_user(db, validated_id_token)

Expand All @@ -72,8 +72,8 @@ async def api_refresh_token(
provider: oidc_provider.AbstractOIDCProvider = fastapi.Depends(
injectables.get_oidc_provider
),
provider_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
injectables.get_oidc_provider_config
oidc_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
injectables.get_oidc_config
),
):
if refresh_token is None or refresh_token == "":
Expand All @@ -82,7 +82,7 @@ async def api_refresh_token(
tokens = provider.refresh_token(refresh_token)

validated_id_token = validate_id_token(
tokens["id_token"], provider_config, None
tokens["id_token"], oidc_config, None
)
user = create_or_update_user(db, validated_id_token)

Expand All @@ -102,11 +102,11 @@ async def validate_token(
request: fastapi.Request,
scope: users_models.Role | None = None,
db: orm.Session = fastapi.Depends(database.get_db),
provider_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
injectables.get_oidc_provider_config
oidc_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
injectables.get_oidc_config
),
):
username = await api_key_cookie.JWTAPIKeyCookie(provider_config)(request)
username = await api_key_cookie.JWTAPIKeyCookie(oidc_config)(request)
if scope and scope.ADMIN:
auth_injectables.RoleVerification(
required_role=users_models.Role.ADMIN
Expand All @@ -116,17 +116,17 @@ async def validate_token(

def validate_id_token(
id_token: str,
provider_config: oidc_provider.AbstractOIDCProviderConfig,
oidc_config: oidc_provider.AbstractOIDCProviderConfig,
nonce: str | None,
) -> dict[str, str]:
validated_id_token = api_key_cookie.JWTAPIKeyCookie(
provider_config
oidc_config
).validate_token(id_token)

if nonce and not hmac.compare_digest(validated_id_token["nonce"], nonce):
raise exceptions.NonceMismatchError()

if provider_config.get_client_id() not in validated_id_token["aud"]:
if oidc_config.get_client_id() not in validated_id_token["aud"]:
raise exceptions.UnauthenticatedError()

return validated_id_token
Expand Down
6 changes: 2 additions & 4 deletions backend/capellacollab/core/logging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,9 @@ async def dispatch(
call_next: base.RequestResponseEndpoint,
):
try:
oidc_provider_config = (
await auth_injectables.get_oidc_provider_config()
)
oidc_config = await auth_injectables.get_oidc_config()
username = await auth_injectables.get_username(
request, oidc_provider_config
request, oidc_config
)
except fastapi.HTTPException:
username = "anonymous"
Expand Down
17 changes: 15 additions & 2 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from testcontainers import postgres

from capellacollab.__main__ import app
from capellacollab.core.authentication import oidc_provider
from capellacollab.core.database import migration

os.environ["DEVELOPMENT_MODE"] = "1"
Expand Down Expand Up @@ -93,8 +94,8 @@ def commit(*args, **kwargs):
@pytest.fixture()
def client(monkeypatch: pytest.MonkeyPatch) -> testclient.TestClient:
monkeypatch.setattr(
"capellacollab.core.authentication.api_key_cookie.JWTConfigBorg",
core_conftest.MockJWTConfigBorg,
"capellacollab.core.authentication.api_key_cookie.JWTConfig",
core_conftest.MockJWTConfig,
)

return testclient.TestClient(app, cookies={"id_token": "any"})
Expand All @@ -103,3 +104,15 @@ def client(monkeypatch: pytest.MonkeyPatch) -> testclient.TestClient:
@pytest.fixture(name="logger")
def fixture_logger() -> logging.LoggerAdapter:
return logging.LoggerAdapter(logging.getLogger())


@pytest.fixture(name="mock_oidc_config")
def fixture_mock_oidc_config():
return core_conftest.MockOIDCProviderConfig()


@pytest.fixture(name="mock_oidc_provider")
def fixture_mock_oidc_provider(
mock_oidc_config: oidc_provider.AbstractOIDCProviderConfig,
):
return core_conftest.MockOIDCProvider(mock_oidc_config)
105 changes: 94 additions & 11 deletions backend/tests/core/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
# SPDX-License-Identifier: Apache-2.0


import typing as t

import pytest
from fastapi import testclient

from capellacollab.__main__ import app
from capellacollab.core.authentication import injectables as auth_injectables
from capellacollab.core.authentication import oidc_provider


class MockPyJWK:
Expand All @@ -18,24 +22,103 @@ def get_signing_key_from_jwt(self, token: str):
return MockPyJWK()


class MockJWTConfigBorg:
_shared_state: dict[str, str] = {}
class MockJWTConfig:
def __init__(
self, oidc_config: oidc_provider.AbstractOIDCProviderConfig
) -> None:
self.oidc_config = oidc_config
self.jwks_client = MockJWKSClient()


class MockOIDCProviderConfig(oidc_provider.AbstractOIDCProviderConfig):
def get_authorization_endpoint(self) -> str:
return "mock-authorization-endpoint"

def get_token_endpoint(self) -> str:
return "mock-token-endpoint"

def get_jwks_uri(self) -> str:
return "mock-jwks-uri"

def get_supported_signing_algorithms(self) -> list[str]:
return ["RS256"]

def get_issuer(self) -> str:
return "mock-issuer"

def get_scopes(self) -> list[str]:
return ["openid", "offline_access", "email"]

def get_client_secret(self) -> str:
return "mock-secret"

def get_client_id(self) -> str:
return "mock-client-id"

def __init__(self) -> None:
self.__dict__ = self._shared_state

if not hasattr(self, "_jwks_client"):
self.jwks_client = MockJWKSClient()
class MockOIDCProvider(oidc_provider.AbstractOIDCProvider):
def __init__(self, oidc_config: oidc_provider.AbstractOIDCProviderConfig):
super().__init__(oidc_config)
self.oidc_config = oidc_config

if not hasattr(self, "_supported_signing_algorithms"):
self.supported_signing_algorithms = ["RS256"]
def get_authorization_url_with_parameters(
self,
) -> t.Tuple[str, str, str, str]:
return (
"mock-auth-url",
"mock-state",
"mock-nonce",
"mock-code-verifier",
)

def exchange_code_for_tokens(
self, authorization_code: str, code_verifier: str
) -> dict[str, t.Any]:
return {
"id_token": "mock-id-token",
"access-token": "mock-access-token",
"refresh-token": "mock-refresh-token",
}

def refresh_token(self, _refresh_token: str) -> dict[str, t.Any]:
return {
"id_token": "mock-id-token",
"access-token": "mock-access-token",
"refresh-token": "mock-refresh-token",
}


@pytest.fixture(name="mock_oidc_provider_and_config")
def fixture_mock_oidc_provider_and_config(
mock_oidc_config: oidc_provider.AbstractOIDCProviderConfig,
mock_oidc_provider: oidc_provider.AbstractOIDCProvider,
):
async def get_mock_oidc_config() -> (
oidc_provider.AbstractOIDCProviderConfig
):
return mock_oidc_config

async def get_mock_oidc_provider() -> oidc_provider.AbstractOIDCProvider:
return mock_oidc_provider

app.dependency_overrides[auth_injectables.get_oidc_config] = (
get_mock_oidc_config
)
app.dependency_overrides[auth_injectables.get_oidc_provider] = (
get_mock_oidc_provider
)

yield

del app.dependency_overrides[auth_injectables.get_oidc_config]
del app.dependency_overrides[auth_injectables.get_oidc_provider]


@pytest.mark.usefixtures("mock_oidc_provider_and_config")
@pytest.fixture(name="unauthorized_client")
def fixture_unauthorized_client(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
"capellacollab.core.authentication.api_key_cookie.JWTConfigBorg",
MockJWTConfigBorg,
"capellacollab.core.authentication.api_key_cookie.JWTConfig",
MockJWTConfig,
)

return testclient.TestClient(app)
Loading

0 comments on commit 4ca31f3

Please sign in to comment.