diff --git a/diracx-core/src/diracx/core/settings.py b/diracx-core/src/diracx/core/settings.py
index dcc36e61..d726b098 100644
--- a/diracx-core/src/diracx/core/settings.py
+++ b/diracx-core/src/diracx/core/settings.py
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
     from pydantic.config import BaseConfig
     from pydantic.fields import ModelField
 
-
+from cryptography.fernet import Fernet
 T = TypeVar("T")
 
 
@@ -42,6 +42,15 @@ def validate(cls, value: Any) -> SecretStr:
         return super().validate(value)
 
 
+class FernetKey(SecretStr):
+    """Secret string holding a Fernet key, with the matching cipher attached"""
+    fernet: Fernet
+
+    def __init__(self, data: str):
+        super().__init__(data)
+        self.fernet = Fernet(self.get_secret_value())
+
+
 class LocalFileUrl(AnyUrl):
     host_required = False
     allowed_schemes = {"file"}
diff --git a/diracx-db/src/diracx/db/sql/auth/db.py b/diracx-db/src/diracx/db/sql/auth/db.py
index fcdb8212..1596e6a4 100644
--- a/diracx-db/src/diracx/db/sql/auth/db.py
+++ b/diracx-db/src/diracx/db/sql/auth/db.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import hashlib
 import secrets
 from datetime import datetime
 from uuid import uuid4
@@ -63,7 +64,7 @@ async def get_device_flow(self, device_code: str, max_validity: int):
             ),
         ).with_for_update()
         stmt = stmt.where(
-            DeviceFlows.device_code == device_code,
+            DeviceFlows.device_code == hashlib.sha256(device_code.encode()).hexdigest(),
         )
         res = dict((await self.conn.execute(stmt)).one()._mapping)
 
@@ -74,7 +75,10 @@ async def get_device_flow(self, device_code: str, max_validity: int):
             # Update the status to Done before returning
             await self.conn.execute(
                 update(DeviceFlows)
-                .where(DeviceFlows.device_code == device_code)
+                .where(
+                    DeviceFlows.device_code
+                    == hashlib.sha256(device_code.encode()).hexdigest()
+                )
                 .values(status=FlowStatus.DONE)
             )
             return res
@@ -119,14 +123,17 @@ async def insert_device_flow(
                 secrets.choice(USER_CODE_ALPHABET)
                 for _ in range(DeviceFlows.user_code.type.length)  # type: ignore
             )
-            # user_code = "2QRKPY"
             device_code = secrets.token_urlsafe()
+
+            # Hash the device_code to avoid leaking information
+            hashed_device_code = hashlib.sha256(device_code.encode()).hexdigest()
+
             stmt = insert(DeviceFlows).values(
                 client_id=client_id,
                 scope=scope,
                 audience=audience,
                 user_code=user_code,
-                device_code=device_code,
+                device_code=hashed_device_code,
             )
             try:
                 await self.conn.execute(stmt)
@@ -172,7 +179,10 @@ async def authorization_flow_insert_id_token(
         :raises: AuthorizationError if no such uuid or status not pending
         """
 
         code = secrets.token_urlsafe()
+        # Hash the code to avoid leaking information
+        hashed_code = hashlib.sha256(code.encode()).hexdigest()
+
         stmt = update(AuthorizationFlows)
 
         stmt = stmt.where(
@@ -181,7 +191,7 @@ async def authorization_flow_insert_id_token(
             AuthorizationFlows.creation_time > substract_date(seconds=max_validity),
         )
 
-        stmt = stmt.values(id_token=id_token, code=code, status=FlowStatus.READY)
+        stmt = stmt.values(id_token=id_token, code=hashed_code, status=FlowStatus.READY)
         res = await self.conn.execute(stmt)
 
         if res.rowcount != 1:
@@ -190,15 +200,16 @@ async def authorization_flow_insert_id_token(
         stmt = select(AuthorizationFlows.code, AuthorizationFlows.redirect_uri)
         stmt = stmt.where(AuthorizationFlows.uuid == uuid)
         row = (await self.conn.execute(stmt)).one()
-        return row.code, row.redirect_uri
+        return code, row.redirect_uri
 
     async def get_authorization_flow(self, code: str, max_validity: int):
+        hashed_code = hashlib.sha256(code.encode()).hexdigest()
         # The with_for_update
         # prevents that the token is retrieved
         # multiple time concurrently
         stmt = select(AuthorizationFlows).with_for_update()
         stmt = stmt.where(
-            AuthorizationFlows.code == code,
+            AuthorizationFlows.code == hashed_code,
             AuthorizationFlows.creation_time > substract_date(seconds=max_validity),
         )
 
@@ -208,7 +219,7 @@ async def get_authorization_flow(self, code: str, max_validity: int):
             # Update the status to Done before returning
             await self.conn.execute(
                 update(AuthorizationFlows)
-                .where(AuthorizationFlows.code == code)
+                .where(AuthorizationFlows.code == hashed_code)
                 .values(status=FlowStatus.DONE)
             )
 
diff --git a/diracx-db/src/diracx/db/sql/auth/schema.py b/diracx-db/src/diracx/db/sql/auth/schema.py
index c85822ed..8a2f5fd8 100644
--- a/diracx-db/src/diracx/db/sql/auth/schema.py
+++ b/diracx-db/src/diracx/db/sql/auth/schema.py
@@ -46,7 +46,7 @@ class DeviceFlows(Base):
     client_id = Column(String(255))
     scope = Column(String(1024))
     audience = Column(String(255))
-    device_code = Column(String(128), unique=True)  # hash it ?
+    device_code = Column(String(128), unique=True)  # Should be a hash
     id_token = NullColumn(JSON())
 
 
@@ -61,7 +61,7 @@ class AuthorizationFlows(Base):
     code_challenge = Column(String(255))
     code_challenge_method = Column(String(8))
     redirect_uri = Column(String(255))
-    code = NullColumn(String(255))  # hash it ?
+    code = NullColumn(String(255))  # Should be a hash
     id_token = NullColumn(JSON())
 
 
diff --git a/diracx-routers/src/diracx/routers/auth.py b/diracx-routers/src/diracx/routers/auth.py
index 44163a19..cb9f0cb0 100644
--- a/diracx-routers/src/diracx/routers/auth.py
+++ b/diracx-routers/src/diracx/routers/auth.py
@@ -16,6 +16,7 @@
 from authlib.jose import JoseError, JsonWebKey, JsonWebToken
 from authlib.oidc.core import IDToken
 from cachetools import TTLCache
+from cryptography.fernet import Fernet
 from fastapi import (
     Depends,
     Form,
@@ -41,7 +42,7 @@
     SecurityProperty,
     UnevaluatedProperty,
 )
-from diracx.core.settings import ServiceSettingsBase, TokenSigningKey
+from diracx.core.settings import FernetKey, ServiceSettingsBase, TokenSigningKey
 from diracx.db.sql.auth.schema import FlowStatus, RefreshTokenStatus
 
 from .dependencies import (
@@ -64,6 +65,9 @@ class AuthSettings(ServiceSettingsBase, env_prefix="DIRACX_SERVICE_AUTH_"):
     device_flow_expiration_seconds: int = 600
     authorization_flow_expiration_seconds: int = 300
 
+    # State key is used to encrypt/decrypt the state dict passed to the IAM
+    state_key: FernetKey
+
     token_issuer: str = "http://lhcbdirac.cern.ch/"
     token_audience: str = "dirac"
     token_key: TokenSigningKey
@@ -379,13 +383,13 @@ async def initiate_device_flow(
         "user_code": user_code,
         "device_code": device_code,
         "verification_uri_complete": f"{verification_uri}?user_code={user_code}",
-        "verification_uri": str(request.url.replace(query={})),
+        "verification_uri": verification_uri,
         "expires_in": settings.device_flow_expiration_seconds,
     }
 
 
 async def initiate_authorization_flow_with_iam(
-    config, vo: str, redirect_uri: str, state: dict[str, str]
+    config, vo: str, redirect_uri: str, state: dict[str, str], cipher_suite: Fernet
 ):
     # code_verifier: https://www.rfc-editor.org/rfc/rfc7636#section-4.1
     code_verifier = secrets.token_hex()
@@ -404,10 +408,9 @@ async def initiate_authorization_flow_with_iam(
     # Take these two from CS/.well-known
     authorization_endpoint = server_metadata["authorization_endpoint"]
 
-    # TODO : encrypt it for good
-    encrypted_state = base64.urlsafe_b64encode(
-        json.dumps(state | {"vo": vo, "code_verifier": code_verifier}).encode()
-    ).decode()
+    encrypted_state = encrypt_state(
+        state | {"vo": vo, "code_verifier": code_verifier}, cipher_suite
+    )
 
     urlParams = [
         "response_type=code",
@@ -501,15 +504,28 @@ async def do_device_flow(
     }
 
     authorization_flow_url = await initiate_authorization_flow_with_iam(
-        config, parsed_scope["vo"], redirect_uri, state_for_iam
+        config,
+        parsed_scope["vo"],
+        redirect_uri,
+        state_for_iam,
+        settings.state_key.fernet,
     )
     return RedirectResponse(authorization_flow_url)
 
 
-def decrypt_state(state):
+def encrypt_state(state_dict: dict[str, str], cipher_suite: Fernet) -> str:
+    """Encrypt the state dict and return it as a string"""
+    return cipher_suite.encrypt(
+        base64.urlsafe_b64encode(json.dumps(state_dict).encode())
+    ).decode()
+
+
+def decrypt_state(state: str, cipher_suite: Fernet) -> dict[str, str]:
+    """Decrypt the state string and return it as a dict"""
     try:
-        # TODO: There have been better schemes like rot13
-        return json.loads(base64.urlsafe_b64decode(state).decode())
+        return json.loads(
+            base64.urlsafe_b64decode(cipher_suite.decrypt(state.encode())).decode()
+        )
     except Exception as e:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state"
@@ -532,7 +548,7 @@ async def finish_device_flow(
     can map it to the corresponding device flow using the user_code
     in the cookie/session
     """
-    decrypted_state = decrypt_state(state)
+    decrypted_state = decrypt_state(state, settings.state_key.fernet)
     assert decrypted_state["grant_type"] == GrantType.device_code
 
     id_token = await get_token_from_iam(
@@ -981,6 +997,7 @@ async def authorization_flow(
         parsed_scope["vo"],
         f"{request.url.replace(query='')}/complete",
         state_for_iam,
+        settings.state_key.fernet,
     )
 
     return responses.RedirectResponse(authorization_flow_url)
@@ -995,7 +1012,7 @@ async def authorization_flow_complete(
     config: Config,
     settings: AuthSettings,
 ):
-    decrypted_state = decrypt_state(state)
+    decrypted_state = decrypt_state(state, settings.state_key.fernet)
     assert decrypted_state["grant_type"] == GrantType.authorization_code
 
     id_token = await get_token_from_iam(
diff --git a/diracx-routers/tests/auth/test_standard.py b/diracx-routers/tests/auth/test_standard.py
index 2f9679ee..d282c473 100644
--- a/diracx-routers/tests/auth/test_standard.py
+++ b/diracx-routers/tests/auth/test_standard.py
@@ -8,16 +8,21 @@
 import httpx
 import jwt
 import pytest
+from cryptography.fernet import Fernet
 from cryptography.hazmat.primitives import serialization
 from cryptography.hazmat.primitives.asymmetric import rsa
+from fastapi import HTTPException
 from pytest_httpx import HTTPXMock
 
 from diracx.core.config import Config
 from diracx.core.properties import NORMAL_USER, PROXY_MANAGEMENT, SecurityProperty
 from diracx.routers.auth import (
     AuthSettings,
+    GrantType,
     _server_metadata_cache,
     create_token,
+    decrypt_state,
+    encrypt_state,
     get_server_metadata,
     parse_and_validate_scope,
 )
@@ -370,6 +375,7 @@ async def test_refresh_token_invalid(test_client, auth_httpx_mock: HTTPXMock):
 
     new_auth_settings = AuthSettings(
         token_key=pem,
+        state_key=Fernet.generate_key(),
         allowed_redirects=[
             "http://diracx.test.invalid:8000/api/docs/oauth2-redirect",
         ],
@@ -680,3 +686,38 @@ def test_parse_scopes_invalid(vos, groups, scope, expected_error):
     available_properties = SecurityProperty.available_properties()
     with pytest.raises(ValueError, match=expected_error):
         parse_and_validate_scope(scope, config, available_properties)
+
+
+def test_encrypt_decrypt_state_valid_state(fernet_key):
+    """Test that decrypt_state returns the correct state"""
+    fernet = Fernet(fernet_key)
+    # Create a valid state
+    state_dict = {
+        "vo": "lhcb",
+        "code_verifier": secrets.token_hex(),
+        "user_code": "AE19U",
+        "grant_type": GrantType.device_code.value,
+    }
+
+    state = encrypt_state(state_dict, fernet)
+    result = decrypt_state(state, fernet)
+
+    assert result == state_dict
+
+    # Create an empty state
+    state_dict = {}
+
+    state = encrypt_state(state_dict, fernet)
+    result = decrypt_state(state, fernet)
+
+    assert result == state_dict
+
+
+def test_encrypt_decrypt_state_invalid_state(fernet_key):
+    """Test that decrypt_state raises an error when the state is invalid"""
+    state = "invalid_state"  # Invalid state string
+
+    with pytest.raises(HTTPException) as exc_info:
+        decrypt_state(state, Fernet(fernet_key))
+    assert exc_info.value.status_code == 400
+    assert exc_info.value.detail == "Invalid state"
diff --git a/diracx-testing/src/diracx/testing/__init__.py b/diracx-testing/src/diracx/testing/__init__.py
index 99a04468..591b8741 100644
--- a/diracx-testing/src/diracx/testing/__init__.py
+++ b/diracx-testing/src/diracx/testing/__init__.py
@@ -69,11 +69,19 @@ def rsa_private_key_pem() -> str:
 
 
 @pytest.fixture(scope="session")
-def test_auth_settings(rsa_private_key_pem) -> AuthSettings:
+def fernet_key() -> str:
+    from cryptography.fernet import Fernet
+
+    return Fernet.generate_key().decode()
+
+
+@pytest.fixture(scope="session")
+def test_auth_settings(rsa_private_key_pem, fernet_key) -> AuthSettings:
     from diracx.routers.auth import AuthSettings
 
     yield AuthSettings(
         token_key=rsa_private_key_pem,
+        state_key=fernet_key,
         allowed_redirects=[
             "http://diracx.test.invalid:8000/api/docs/oauth2-redirect",
         ],
diff --git a/run_local.sh b/run_local.sh
index 94fe4db8..30d81225 100755
--- a/run_local.sh
+++ b/run_local.sh
@@ -10,6 +10,8 @@ mkdir -p "${tmp_dir}/signing-key" "${tmp_dir}/cs_store/"
 signing_key="${tmp_dir}/signing-key/rsa256.key"
 ssh-keygen -P "" -t rsa -b 4096 -m PEM -f "${signing_key}"
 
+state_key="$(head -c 32 /dev/urandom | base64 | tr '+/' '-_')"
+
 # Make a fake CS
 dirac internal generate-cs "${tmp_dir}/cs_store/initialRepo"
 
@@ -31,6 +33,7 @@ export DIRACX_DB_URL_JOBLOGGINGDB="sqlite+aiosqlite:///:memory:"
 export DIRACX_DB_URL_SANDBOXMETADATADB="sqlite+aiosqlite:///:memory:"
 export DIRACX_DB_URL_TASKQUEUEDB="sqlite+aiosqlite:///:memory:"
 export DIRACX_SERVICE_AUTH_TOKEN_KEY="file://${signing_key}"
+export DIRACX_SERVICE_AUTH_STATE_KEY="${state_key}"
 export DIRACX_SERVICE_AUTH_ALLOWED_REDIRECTS='["http://'$(hostname| tr -s '[:upper:]' '[:lower:]')':8000/docs/oauth2-redirect"]'
 export DIRACX_SANDBOX_STORE_BUCKET_NAME=sandboxes
 export DIRACX_SANDBOX_STORE_AUTO_CREATE_BUCKET=true