Skip to content

Commit

Permalink
Merge pull request #215 from aldbr/main_FEAT_hash-encrypt-auth
Browse files Browse the repository at this point in the history
Hash & Encrypt: auth module
  • Loading branch information
chrisburr authored Mar 3, 2024
2 parents 791f37b + dba04d7 commit c14c07d
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 25 deletions.
10 changes: 9 additions & 1 deletion diracx-core/src/diracx/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
if TYPE_CHECKING:
from pydantic.config import BaseConfig
from pydantic.fields import ModelField

from cryptography.fernet import Fernet

T = TypeVar("T")

Expand All @@ -42,6 +42,14 @@ def validate(cls, value: Any) -> SecretStr:
return super().validate(value)


class FernetKey(SecretStr):
fernet: Fernet

def __init__(self, data: str):
super().__init__(data)
self.fernet = Fernet(self.get_secret_value())


class LocalFileUrl(AnyUrl):
host_required = False
allowed_schemes = {"file"}
Expand Down
27 changes: 19 additions & 8 deletions diracx-db/src/diracx/db/sql/auth/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import hashlib
import secrets
from datetime import datetime
from uuid import uuid4
Expand Down Expand Up @@ -63,7 +64,7 @@ async def get_device_flow(self, device_code: str, max_validity: int):
),
).with_for_update()
stmt = stmt.where(
DeviceFlows.device_code == device_code,
DeviceFlows.device_code == hashlib.sha256(device_code.encode()).hexdigest(),
)
res = dict((await self.conn.execute(stmt)).one()._mapping)

Expand All @@ -74,7 +75,10 @@ async def get_device_flow(self, device_code: str, max_validity: int):
# Update the status to Done before returning
await self.conn.execute(
update(DeviceFlows)
.where(DeviceFlows.device_code == device_code)
.where(
DeviceFlows.device_code
== hashlib.sha256(device_code.encode()).hexdigest()
)
.values(status=FlowStatus.DONE)
)
return res
Expand Down Expand Up @@ -119,14 +123,17 @@ async def insert_device_flow(
secrets.choice(USER_CODE_ALPHABET)
for _ in range(DeviceFlows.user_code.type.length) # type: ignore
)
# user_code = "2QRKPY"
device_code = secrets.token_urlsafe()

# Hash the the device_code to avoid leaking information
hashed_device_code = hashlib.sha256(device_code.encode()).hexdigest()

stmt = insert(DeviceFlows).values(
client_id=client_id,
scope=scope,
audience=audience,
user_code=user_code,
device_code=device_code,
device_code=hashed_device_code,
)
try:
await self.conn.execute(stmt)
Expand Down Expand Up @@ -172,7 +179,10 @@ async def authorization_flow_insert_id_token(
:raises: AuthorizationError if no such uuid or status not pending
"""

# Hash the code to avoid leaking information
code = secrets.token_urlsafe()
hashed_code = hashlib.sha256(code.encode()).hexdigest()

stmt = update(AuthorizationFlows)

stmt = stmt.where(
Expand All @@ -181,7 +191,7 @@ async def authorization_flow_insert_id_token(
AuthorizationFlows.creation_time > substract_date(seconds=max_validity),
)

stmt = stmt.values(id_token=id_token, code=code, status=FlowStatus.READY)
stmt = stmt.values(id_token=id_token, code=hashed_code, status=FlowStatus.READY)
res = await self.conn.execute(stmt)

if res.rowcount != 1:
Expand All @@ -190,15 +200,16 @@ async def authorization_flow_insert_id_token(
stmt = select(AuthorizationFlows.code, AuthorizationFlows.redirect_uri)
stmt = stmt.where(AuthorizationFlows.uuid == uuid)
row = (await self.conn.execute(stmt)).one()
return row.code, row.redirect_uri
return code, row.redirect_uri

async def get_authorization_flow(self, code: str, max_validity: int):
hashed_code = hashlib.sha256(code.encode()).hexdigest()
# The with_for_update
# prevents that the token is retrieved
# multiple time concurrently
stmt = select(AuthorizationFlows).with_for_update()
stmt = stmt.where(
AuthorizationFlows.code == code,
AuthorizationFlows.code == hashed_code,
AuthorizationFlows.creation_time > substract_date(seconds=max_validity),
)

Expand All @@ -208,7 +219,7 @@ async def get_authorization_flow(self, code: str, max_validity: int):
# Update the status to Done before returning
await self.conn.execute(
update(AuthorizationFlows)
.where(AuthorizationFlows.code == code)
.where(AuthorizationFlows.code == hashed_code)
.values(status=FlowStatus.DONE)
)

Expand Down
4 changes: 2 additions & 2 deletions diracx-db/src/diracx/db/sql/auth/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class DeviceFlows(Base):
client_id = Column(String(255))
scope = Column(String(1024))
audience = Column(String(255))
device_code = Column(String(128), unique=True) # hash it ?
device_code = Column(String(128), unique=True) # Should be a hash
id_token = NullColumn(JSON())


Expand All @@ -61,7 +61,7 @@ class AuthorizationFlows(Base):
code_challenge = Column(String(255))
code_challenge_method = Column(String(8))
redirect_uri = Column(String(255))
code = NullColumn(String(255)) # hash it ?
code = NullColumn(String(255)) # Should be a hash
id_token = NullColumn(JSON())


Expand Down
43 changes: 30 additions & 13 deletions diracx-routers/src/diracx/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from authlib.jose import JoseError, JsonWebKey, JsonWebToken
from authlib.oidc.core import IDToken
from cachetools import TTLCache
from cryptography.fernet import Fernet
from fastapi import (
Depends,
Form,
Expand All @@ -41,7 +42,7 @@
SecurityProperty,
UnevaluatedProperty,
)
from diracx.core.settings import ServiceSettingsBase, TokenSigningKey
from diracx.core.settings import FernetKey, ServiceSettingsBase, TokenSigningKey
from diracx.db.sql.auth.schema import FlowStatus, RefreshTokenStatus

from .dependencies import (
Expand All @@ -64,6 +65,9 @@ class AuthSettings(ServiceSettingsBase, env_prefix="DIRACX_SERVICE_AUTH_"):
device_flow_expiration_seconds: int = 600
authorization_flow_expiration_seconds: int = 300

# State key is used to encrypt/decrypt the state dict passed to the IAM
state_key: FernetKey

token_issuer: str = "http://lhcbdirac.cern.ch/"
token_audience: str = "dirac"
token_key: TokenSigningKey
Expand Down Expand Up @@ -379,13 +383,13 @@ async def initiate_device_flow(
"user_code": user_code,
"device_code": device_code,
"verification_uri_complete": f"{verification_uri}?user_code={user_code}",
"verification_uri": str(request.url.replace(query={})),
"verification_uri": verification_uri,
"expires_in": settings.device_flow_expiration_seconds,
}


async def initiate_authorization_flow_with_iam(
config, vo: str, redirect_uri: str, state: dict[str, str]
config, vo: str, redirect_uri: str, state: dict[str, str], cipher_suite: Fernet
):
# code_verifier: https://www.rfc-editor.org/rfc/rfc7636#section-4.1
code_verifier = secrets.token_hex()
Expand All @@ -404,10 +408,9 @@ async def initiate_authorization_flow_with_iam(
# Take these two from CS/.well-known
authorization_endpoint = server_metadata["authorization_endpoint"]

# TODO : encrypt it for good
encrypted_state = base64.urlsafe_b64encode(
json.dumps(state | {"vo": vo, "code_verifier": code_verifier}).encode()
).decode()
encrypted_state = encrypt_state(
state | {"vo": vo, "code_verifier": code_verifier}, cipher_suite
)

urlParams = [
"response_type=code",
Expand Down Expand Up @@ -501,15 +504,28 @@ async def do_device_flow(
}

authorization_flow_url = await initiate_authorization_flow_with_iam(
config, parsed_scope["vo"], redirect_uri, state_for_iam
config,
parsed_scope["vo"],
redirect_uri,
state_for_iam,
settings.state_key.fernet,
)
return RedirectResponse(authorization_flow_url)


def decrypt_state(state):
def encrypt_state(state_dict: dict[str, str], cipher_suite: Fernet) -> str:
"""Encrypt the state dict and return it as a string"""
return cipher_suite.encrypt(
base64.urlsafe_b64encode(json.dumps(state_dict).encode())
).decode()


def decrypt_state(state: str, cipher_suite: Fernet) -> dict[str, str]:
"""Decrypt the state string and return it as a dict"""
try:
# TODO: There have been better schemes like rot13
return json.loads(base64.urlsafe_b64decode(state).decode())
return json.loads(
base64.urlsafe_b64decode(cipher_suite.decrypt(state.encode())).decode()
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state"
Expand All @@ -532,7 +548,7 @@ async def finish_device_flow(
can map it to the corresponding device flow using the user_code
in the cookie/session
"""
decrypted_state = decrypt_state(state)
decrypted_state = decrypt_state(state, settings.state_key.fernet)
assert decrypted_state["grant_type"] == GrantType.device_code

id_token = await get_token_from_iam(
Expand Down Expand Up @@ -981,6 +997,7 @@ async def authorization_flow(
parsed_scope["vo"],
f"{request.url.replace(query='')}/complete",
state_for_iam,
settings.state_key.fernet,
)

return responses.RedirectResponse(authorization_flow_url)
Expand All @@ -995,7 +1012,7 @@ async def authorization_flow_complete(
config: Config,
settings: AuthSettings,
):
decrypted_state = decrypt_state(state)
decrypted_state = decrypt_state(state, settings.state_key.fernet)
assert decrypted_state["grant_type"] == GrantType.authorization_code

id_token = await get_token_from_iam(
Expand Down
41 changes: 41 additions & 0 deletions diracx-routers/tests/auth/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,21 @@
import httpx
import jwt
import pytest
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from fastapi import HTTPException
from pytest_httpx import HTTPXMock

from diracx.core.config import Config
from diracx.core.properties import NORMAL_USER, PROXY_MANAGEMENT, SecurityProperty
from diracx.routers.auth import (
AuthSettings,
GrantType,
_server_metadata_cache,
create_token,
decrypt_state,
encrypt_state,
get_server_metadata,
parse_and_validate_scope,
)
Expand Down Expand Up @@ -370,6 +375,7 @@ async def test_refresh_token_invalid(test_client, auth_httpx_mock: HTTPXMock):

new_auth_settings = AuthSettings(
token_key=pem,
state_key=Fernet.generate_key(),
allowed_redirects=[
"http://diracx.test.invalid:8000/api/docs/oauth2-redirect",
],
Expand Down Expand Up @@ -680,3 +686,38 @@ def test_parse_scopes_invalid(vos, groups, scope, expected_error):
available_properties = SecurityProperty.available_properties()
with pytest.raises(ValueError, match=expected_error):
parse_and_validate_scope(scope, config, available_properties)


def test_encrypt_decrypt_state_valid_state(fernet_key):
"""Test that decrypt_state returns the correct state"""
fernet = Fernet(fernet_key)
# Create a valid state
state_dict = {
"vo": "lhcb",
"code_verifier": secrets.token_hex(),
"user_code": "AE19U",
"grant_type": GrantType.device_code.value,
}

state = encrypt_state(state_dict, fernet)
result = decrypt_state(state, fernet)

assert result == state_dict

# Create an empty state
state_dict = {}

state = encrypt_state(state_dict, fernet)
result = decrypt_state(state, fernet)

assert result == state_dict


def test_encrypt_decrypt_state_invalid_state(fernet_key):
"""Test that decrypt_state raises an error when the state is invalid"""
state = "invalid_state" # Invalid state string

with pytest.raises(HTTPException) as exc_info:
decrypt_state(state, fernet_key)
assert exc_info.value.status_code == 400
assert exc_info.value.detail == "Invalid state"
10 changes: 9 additions & 1 deletion diracx-testing/src/diracx/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,19 @@ def rsa_private_key_pem() -> str:


@pytest.fixture(scope="session")
def test_auth_settings(rsa_private_key_pem) -> AuthSettings:
def fernet_key() -> str:
from cryptography.fernet import Fernet

return Fernet.generate_key().decode()


@pytest.fixture(scope="session")
def test_auth_settings(rsa_private_key_pem, fernet_key) -> AuthSettings:
from diracx.routers.auth import AuthSettings

yield AuthSettings(
token_key=rsa_private_key_pem,
state_key=fernet_key,
allowed_redirects=[
"http://diracx.test.invalid:8000/api/docs/oauth2-redirect",
],
Expand Down
3 changes: 3 additions & 0 deletions run_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ mkdir -p "${tmp_dir}/signing-key" "${tmp_dir}/cs_store/"
signing_key="${tmp_dir}/signing-key/rsa256.key"
ssh-keygen -P "" -t rsa -b 4096 -m PEM -f "${signing_key}"

state_key="$(head -c 32 /dev/urandom | base64)"

# Make a fake CS
dirac internal generate-cs "${tmp_dir}/cs_store/initialRepo"

Expand All @@ -31,6 +33,7 @@ export DIRACX_DB_URL_JOBLOGGINGDB="sqlite+aiosqlite:///:memory:"
export DIRACX_DB_URL_SANDBOXMETADATADB="sqlite+aiosqlite:///:memory:"
export DIRACX_DB_URL_TASKQUEUEDB="sqlite+aiosqlite:///:memory:"
export DIRACX_SERVICE_AUTH_TOKEN_KEY="file://${signing_key}"
export DIRACX_SERVICE_AUTH_STATE_KEY="${state_key}"
export DIRACX_SERVICE_AUTH_ALLOWED_REDIRECTS='["http://'$(hostname| tr -s '[:upper:]' '[:lower:]')':8000/docs/oauth2-redirect"]'
export DIRACX_SANDBOX_STORE_BUCKET_NAME=sandboxes
export DIRACX_SANDBOX_STORE_AUTO_CREATE_BUCKET=true
Expand Down

0 comments on commit c14c07d

Please sign in to comment.