Skip to content

Commit

Permalink
- Fix for protect_token_cache.
Browse files Browse the repository at this point in the history
- Refactor.
- Added _resource_id to AuthProviderAccessToken so that it can be
  used to create shared access keys (required for PR tests on github).
  • Loading branch information
rwiker committed Dec 19, 2024
1 parent 9397740 commit 6e4aec7
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions src/sumo/wrapper/_auth_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_authorization(self):

def store_shared_access_key_for_case(self, case_uuid, token):
with open(
get_token_path(self._resource_id + "+" + case_uuid, ".sharedkey"),
get_token_path(self._resource_id, ".sharedkey", case_uuid),
"w",
) as f:
f.write(token)
Expand All @@ -93,6 +93,7 @@ def __init__(self, access_token):
self._access_token = access_token
payload = jwt.decode(access_token, options={"verify_signature": False})
self._expires = payload["exp"]
self._resource_id=payload["aud"]
return

def get_token(self):
Expand All @@ -117,10 +118,17 @@ def __init__(self, refresh_token, client_id, authority, resource_id):
pass


def get_token_path(resource_id, suffix):
return os.path.join(
os.path.expanduser("~"), ".sumo", str(resource_id) + suffix
)
def get_token_path(resource_id, suffix, case_uuid=None):
if case_uuid is not None:
return os.path.join(
os.path.expanduser("~"),
".sumo",
str(resource_id) + "+" + str(case_uuid) + suffix,
)
else:
return os.path.join(
os.path.expanduser("~"), ".sumo", str(resource_id) + suffix
)


@tn.retry(
Expand Down Expand Up @@ -176,8 +184,8 @@ def get_token_cache(resource_id, suffix):
retry_error_callback=_return_last_value,
before_sleep=_log_retry_info,
)
def protect_token_cache(resource_id, suffix):
token_path = get_token_path(resource_id, suffix)
def protect_token_cache(resource_id, suffix, case_uuid=None):
token_path = get_token_path(resource_id, suffix, case_uuid)

if sys.platform.startswith("linux") or sys.platform == "darwin":
filemode = stat.filemode(os.stat(token_path).st_mode)
Expand Down Expand Up @@ -369,9 +377,9 @@ class AuthProviderSumoToken(AuthProvider):
retry_error_callback=_return_last_value,
before_sleep=_log_retry_info,
)
def __init__(self, resource_id):
protect_token_cache(resource_id, ".sharedkey")
token_path = get_token_path(resource_id, ".sharedkey")
def __init__(self, resource_id, case_uuid=None):
protect_token_cache(resource_id, ".sharedkey", case_uuid)
token_path = get_token_path(resource_id, ".sharedkey", case_uuid)
with open(token_path, "r") as f:
self._token = f.readline().strip()
return
Expand Down Expand Up @@ -411,13 +419,8 @@ def get_auth_provider(
if access_token:
return AuthProviderAccessToken(access_token)
# ELSE
if case_uuid is not None and os.path.exists(
get_token_path(resource_id + "+" + case_uuid, ".sharedkey")
):
return AuthProviderSumoToken(resource_id + "+" + case_uuid)
# ELSE
if os.path.exists(get_token_path(resource_id, ".sharedkey")):
return AuthProviderSumoToken(resource_id)
if os.path.exists(get_token_path(resource_id, ".sharedkey", case_uuid)):
return AuthProviderSumoToken(resource_id, case_uuid)
# ELSE
if os.path.exists(get_token_path(resource_id, ".token")):
auth_silent = AuthProviderSilent(client_id, authority, resource_id)
Expand Down

0 comments on commit 6e4aec7

Please sign in to comment.