Skip to content

Commit

Permalink
feat: add fallback folder when looking up public keys (#119)
Browse files Browse the repository at this point in the history
* feat: add fallback folder when looking up public keys

* move exception message to the function that raises it

---------

Co-authored-by: Avram Tudor <[email protected]>
  • Loading branch information
quitrk and Avram Tudor authored Nov 15, 2024
1 parent f06a2f2 commit 90c926d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
32 changes: 27 additions & 5 deletions skynet/auth/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,44 @@
from fastapi import HTTPException

from skynet import http_client
from skynet.env import asap_pub_keys_auds, asap_pub_keys_folder, asap_pub_keys_max_cache_size, asap_pub_keys_url
from skynet.env import (
asap_pub_keys_auds,
asap_pub_keys_fallback_folder,
asap_pub_keys_folder,
asap_pub_keys_max_cache_size,
asap_pub_keys_url,
)
from skynet.logs import get_logger

log = get_logger(__name__)


def is_valid_key(key: str) -> bool:
return key.startswith('-----BEGIN PUBLIC KEY-----')


@alru_cache(maxsize=asap_pub_keys_max_cache_size)
async def get_public_key(kid: str) -> str:
encoded_pub_key_name = sha256(kid.encode('UTF-8')).hexdigest()
pub_key_remote_filename = f'{encoded_pub_key_name}.pem'

url = f'{asap_pub_keys_url}/{asap_pub_keys_folder}/{pub_key_remote_filename}'

log.info(f'Fetching public key {kid} from {url}')
key = await http_client.get(url, 'text')

if is_valid_key(key):
return key

return await http_client.get(url, 'text')
if asap_pub_keys_fallback_folder:
url = f'{asap_pub_keys_url}/{asap_pub_keys_fallback_folder}/{pub_key_remote_filename}'

log.info(f'Fetching public key {kid} from {url}')
key = await http_client.get(url, 'text')

if is_valid_key(key):
return key

raise Exception(f'Failed to retrieve public key {kid}')


async def authorize(jwt_incoming: str) -> dict:
Expand All @@ -36,8 +58,8 @@ async def authorize(jwt_incoming: str) -> dict:

try:
public_key = await get_public_key(kid)
except Exception:
raise HTTPException(status_code=401, detail=f'Failed to retrieve public key. {kid}')
except Exception as ex:
raise HTTPException(status_code=401, detail=str(ex))

try:
decoded = jwt.decode(jwt_incoming, public_key, algorithms=['RS256', 'HS512'], audience=asap_pub_keys_auds)
Expand Down
1 change: 1 addition & 0 deletions skynet/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def tobool(val: str | None):
bypass_auth = tobool(os.environ.get('BYPASS_AUTHORIZATION'))
asap_pub_keys_url = os.environ.get('ASAP_PUB_KEYS_REPO_URL')
asap_pub_keys_folder = os.environ.get('ASAP_PUB_KEYS_FOLDER')
asap_pub_keys_fallback_folder = os.environ.get('ASAP_PUB_KEYS_FALLBACK_FOLDER')
asap_pub_keys_auds = os.environ.get('ASAP_PUB_KEYS_AUDS', '').strip().split(',')
asap_pub_keys_max_cache_size = int(os.environ.get('ASAP_PUB_KEYS_MAX_CACHE_SIZE', 512))

Expand Down

0 comments on commit 90c926d

Please sign in to comment.