Skip to content

Commit

Permalink
feat: add a Fernet key to encrypt/decrypt the state dict
Browse files Browse the repository at this point in the history
  • Loading branch information
aldbr committed Feb 19, 2024
1 parent b622745 commit 665bd9b
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 14 deletions.
10 changes: 9 additions & 1 deletion diracx-core/src/diracx/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
if TYPE_CHECKING:
from pydantic.config import BaseConfig
from pydantic.fields import ModelField

from cryptography.fernet import Fernet

T = TypeVar("T")

Expand All @@ -42,6 +42,14 @@ def validate(cls, value: Any) -> SecretStr:
return super().validate(value)


class FernetKey(SecretStr):
fernet: Fernet

def __init__(self, data: str):
super().__init__(data)
self.fernet = Fernet(self.get_secret_value())


class LocalFileUrl(AnyUrl):
host_required = False
allowed_schemes = {"file"}
Expand Down
41 changes: 29 additions & 12 deletions diracx-routers/src/diracx/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from authlib.jose import JoseError, JsonWebKey, JsonWebToken
from authlib.oidc.core import IDToken
from cachetools import TTLCache
from cryptography.fernet import Fernet
from fastapi import (
Depends,
Form,
Expand All @@ -41,7 +42,7 @@
SecurityProperty,
UnevaluatedProperty,
)
from diracx.core.settings import ServiceSettingsBase, TokenSigningKey
from diracx.core.settings import FernetKey, ServiceSettingsBase, TokenSigningKey
from diracx.db.sql.auth.schema import FlowStatus, RefreshTokenStatus

from .dependencies import (
Expand All @@ -64,6 +65,9 @@ class AuthSettings(ServiceSettingsBase, env_prefix="DIRACX_SERVICE_AUTH_"):
device_flow_expiration_seconds: int = 600
authorization_flow_expiration_seconds: int = 300

# State key is used to encrypt/decrypt the state dict passed to the IAM
state_key: FernetKey

token_issuer: str = "http://lhcbdirac.cern.ch/"
token_audience: str = "dirac"
token_key: TokenSigningKey
Expand Down Expand Up @@ -385,7 +389,7 @@ async def initiate_device_flow(


async def initiate_authorization_flow_with_iam(
config, vo: str, redirect_uri: str, state: dict[str, str]
config, vo: str, redirect_uri: str, state: dict[str, str], cipher_suite: Fernet
):
# code_verifier: https://www.rfc-editor.org/rfc/rfc7636#section-4.1
code_verifier = secrets.token_hex()
Expand All @@ -404,10 +408,9 @@ async def initiate_authorization_flow_with_iam(
# Take these two from CS/.well-known
authorization_endpoint = server_metadata["authorization_endpoint"]

# TODO : encrypt it for good
encrypted_state = base64.urlsafe_b64encode(
json.dumps(state | {"vo": vo, "code_verifier": code_verifier}).encode()
).decode()
encrypted_state = encrypt_state(
state | {"vo": vo, "code_verifier": code_verifier}, cipher_suite
)

urlParams = [
"response_type=code",
Expand Down Expand Up @@ -501,15 +504,28 @@ async def do_device_flow(
}

authorization_flow_url = await initiate_authorization_flow_with_iam(
config, parsed_scope["vo"], redirect_uri, state_for_iam
config,
parsed_scope["vo"],
redirect_uri,
state_for_iam,
settings.state_key.fernet,
)
return RedirectResponse(authorization_flow_url)


def decrypt_state(state):
def encrypt_state(state_dict: dict[str, str], cipher_suite: Fernet) -> str:
"""Encrypt the state dict and return it as a string"""
return cipher_suite.encrypt(
base64.urlsafe_b64encode(json.dumps(state_dict).encode())
).decode()


def decrypt_state(state: str, cipher_suite: Fernet) -> dict[str, str]:
"""Decrypt the state string and return it as a dict"""
try:
# TODO: There have been better schemes like rot13
return json.loads(base64.urlsafe_b64decode(state).decode())
return json.loads(
base64.urlsafe_b64decode(cipher_suite.decrypt(state.encode())).decode()
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state"
Expand All @@ -532,7 +548,7 @@ async def finish_device_flow(
can map it to the corresponding device flow using the user_code
in the cookie/session
"""
decrypted_state = decrypt_state(state)
decrypted_state = decrypt_state(state, settings.state_key.fernet)
assert decrypted_state["grant_type"] == GrantType.device_code

id_token = await get_token_from_iam(
Expand Down Expand Up @@ -981,6 +997,7 @@ async def authorization_flow(
parsed_scope["vo"],
f"{request.url.replace(query='')}/complete",
state_for_iam,
settings.state_key.fernet,
)

return responses.RedirectResponse(authorization_flow_url)
Expand All @@ -995,7 +1012,7 @@ async def authorization_flow_complete(
config: Config,
settings: AuthSettings,
):
decrypted_state = decrypt_state(state)
decrypted_state = decrypt_state(state, settings.state_key.fernet)
assert decrypted_state["grant_type"] == GrantType.authorization_code

id_token = await get_token_from_iam(
Expand Down
41 changes: 41 additions & 0 deletions diracx-routers/tests/auth/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,21 @@
import httpx
import jwt
import pytest
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from fastapi import HTTPException
from pytest_httpx import HTTPXMock

from diracx.core.config import Config
from diracx.core.properties import NORMAL_USER, PROXY_MANAGEMENT, SecurityProperty
from diracx.routers.auth import (
AuthSettings,
GrantType,
_server_metadata_cache,
create_token,
decrypt_state,
encrypt_state,
get_server_metadata,
parse_and_validate_scope,
)
Expand Down Expand Up @@ -370,6 +375,7 @@ async def test_refresh_token_invalid(test_client, auth_httpx_mock: HTTPXMock):

new_auth_settings = AuthSettings(
token_key=pem,
state_key=Fernet.generate_key(),
allowed_redirects=[
"http://diracx.test.invalid:8000/api/docs/oauth2-redirect",
],
Expand Down Expand Up @@ -680,3 +686,38 @@ def test_parse_scopes_invalid(vos, groups, scope, expected_error):
available_properties = SecurityProperty.available_properties()
with pytest.raises(ValueError, match=expected_error):
parse_and_validate_scope(scope, config, available_properties)


def test_encrypt_decrypt_state_valid_state(fernet_key):
"""Test that decrypt_state returns the correct state"""
fernet = Fernet(fernet_key)
# Create a valid state
state_dict = {
"vo": "lhcb",
"code_verifier": secrets.token_hex(),
"user_code": "AE19U",
"grant_type": GrantType.device_code.value,
}

state = encrypt_state(state_dict, fernet)
result = decrypt_state(state, fernet)

assert result == state_dict

# Create an empty state
state_dict = {}

state = encrypt_state(state_dict, fernet)
result = decrypt_state(state, fernet)

assert result == state_dict


def test_encrypt_decrypt_state_invalid_state(fernet_key):
"""Test that decrypt_state raises an error when the state is invalid"""
state = "invalid_state" # Invalid state string

with pytest.raises(HTTPException) as exc_info:
decrypt_state(state, fernet_key)
assert exc_info.value.status_code == 400
assert exc_info.value.detail == "Invalid state"
10 changes: 9 additions & 1 deletion diracx-testing/src/diracx/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,19 @@ def rsa_private_key_pem() -> str:


@pytest.fixture(scope="session")
def test_auth_settings(rsa_private_key_pem) -> AuthSettings:
def fernet_key() -> str:
from cryptography.fernet import Fernet

return Fernet.generate_key().decode()


@pytest.fixture(scope="session")
def test_auth_settings(rsa_private_key_pem, fernet_key) -> AuthSettings:
from diracx.routers.auth import AuthSettings

yield AuthSettings(
token_key=rsa_private_key_pem,
state_key=fernet_key,
allowed_redirects=[
"http://diracx.test.invalid:8000/api/docs/oauth2-redirect",
],
Expand Down

0 comments on commit 665bd9b

Please sign in to comment.