From 90c926d79123b27691016c3430250e8d19406ae5 Mon Sep 17 00:00:00 2001 From: Avram Tudor Date: Fri, 15 Nov 2024 12:55:28 +0200 Subject: [PATCH] feat: add fallback folder when looking up public keys (#119) * feat: add fallback folder when looking up public keys * move exception message to the function that raises it --------- Co-authored-by: Avram Tudor --- skynet/auth/jwt.py | 32 +++++++++++++++++++++++++++----- skynet/env.py | 1 + 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/skynet/auth/jwt.py b/skynet/auth/jwt.py index 35ad8ae..ee149e0 100644 --- a/skynet/auth/jwt.py +++ b/skynet/auth/jwt.py @@ -5,22 +5,44 @@ from fastapi import HTTPException from skynet import http_client -from skynet.env import asap_pub_keys_auds, asap_pub_keys_folder, asap_pub_keys_max_cache_size, asap_pub_keys_url +from skynet.env import ( + asap_pub_keys_auds, + asap_pub_keys_fallback_folder, + asap_pub_keys_folder, + asap_pub_keys_max_cache_size, + asap_pub_keys_url, +) from skynet.logs import get_logger log = get_logger(__name__) +def is_valid_key(key: str) -> bool: + return key.startswith('-----BEGIN PUBLIC KEY-----') + + @alru_cache(maxsize=asap_pub_keys_max_cache_size) async def get_public_key(kid: str) -> str: encoded_pub_key_name = sha256(kid.encode('UTF-8')).hexdigest() pub_key_remote_filename = f'{encoded_pub_key_name}.pem' - url = f'{asap_pub_keys_url}/{asap_pub_keys_folder}/{pub_key_remote_filename}' log.info(f'Fetching public key {kid} from {url}') + key = await http_client.get(url, 'text') + + if is_valid_key(key): + return key - return await http_client.get(url, 'text') + if asap_pub_keys_fallback_folder: + url = f'{asap_pub_keys_url}/{asap_pub_keys_fallback_folder}/{pub_key_remote_filename}' + + log.info(f'Fetching public key {kid} from {url}') + key = await http_client.get(url, 'text') + + if is_valid_key(key): + return key + + raise Exception(f'Failed to retrieve public key {kid}') async def authorize(jwt_incoming: str) -> dict: @@ -36,8 +58,8 @@ async def authorize(jwt_incoming: str) -> dict: try: public_key = await get_public_key(kid) - except Exception: - raise HTTPException(status_code=401, detail=f'Failed to retrieve public key. {kid}') + except Exception as ex: + raise HTTPException(status_code=401, detail=str(ex)) try: decoded = jwt.decode(jwt_incoming, public_key, algorithms=['RS256', 'HS512'], audience=asap_pub_keys_auds) diff --git a/skynet/env.py b/skynet/env.py index e3cda0f..f7ce585 100644 --- a/skynet/env.py +++ b/skynet/env.py @@ -51,6 +51,7 @@ def tobool(val: str | None): bypass_auth = tobool(os.environ.get('BYPASS_AUTHORIZATION')) asap_pub_keys_url = os.environ.get('ASAP_PUB_KEYS_REPO_URL') asap_pub_keys_folder = os.environ.get('ASAP_PUB_KEYS_FOLDER') +asap_pub_keys_fallback_folder = os.environ.get('ASAP_PUB_KEYS_FALLBACK_FOLDER') asap_pub_keys_auds = os.environ.get('ASAP_PUB_KEYS_AUDS', '').strip().split(',') asap_pub_keys_max_cache_size = int(os.environ.get('ASAP_PUB_KEYS_MAX_CACHE_SIZE', 512))