diff --git a/src/sumo/wrapper/_auth_provider.py b/src/sumo/wrapper/_auth_provider.py index 43d086b..619f13c 100644 --- a/src/sumo/wrapper/_auth_provider.py +++ b/src/sumo/wrapper/_auth_provider.py @@ -68,7 +68,7 @@ def get_authorization(self): def store_shared_access_key_for_case(self, case_uuid, token): with open( - get_token_path(self._resource_id + "+" + case_uuid, ".sharedkey"), + get_token_path(self._resource_id, ".sharedkey", case_uuid), "w", ) as f: f.write(token) @@ -93,6 +93,7 @@ def __init__(self, access_token): self._access_token = access_token payload = jwt.decode(access_token, options={"verify_signature": False}) self._expires = payload["exp"] + self._resource_id=payload["aud"] return def get_token(self): @@ -117,10 +118,17 @@ def __init__(self, refresh_token, client_id, authority, resource_id): pass -def get_token_path(resource_id, suffix): - return os.path.join( - os.path.expanduser("~"), ".sumo", str(resource_id) + suffix - ) +def get_token_path(resource_id, suffix, case_uuid=None): + if case_uuid is not None: + return os.path.join( + os.path.expanduser("~"), + ".sumo", + str(resource_id) + "+" + str(case_uuid) + suffix, + ) + else: + return os.path.join( + os.path.expanduser("~"), ".sumo", str(resource_id) + suffix + ) @tn.retry( @@ -176,8 +184,8 @@ def get_token_cache(resource_id, suffix): retry_error_callback=_return_last_value, before_sleep=_log_retry_info, ) -def protect_token_cache(resource_id, suffix): - token_path = get_token_path(resource_id, suffix) +def protect_token_cache(resource_id, suffix, case_uuid=None): + token_path = get_token_path(resource_id, suffix, case_uuid) if sys.platform.startswith("linux") or sys.platform == "darwin": filemode = stat.filemode(os.stat(token_path).st_mode) @@ -369,9 +377,9 @@ class AuthProviderSumoToken(AuthProvider): retry_error_callback=_return_last_value, before_sleep=_log_retry_info, ) - def __init__(self, resource_id): - protect_token_cache(resource_id, ".sharedkey") - token_path = get_token_path(resource_id, ".sharedkey") + def __init__(self, resource_id, case_uuid=None): + protect_token_cache(resource_id, ".sharedkey", case_uuid) + token_path = get_token_path(resource_id, ".sharedkey", case_uuid) with open(token_path, "r") as f: self._token = f.readline().strip() return @@ -411,13 +419,8 @@ def get_auth_provider( if access_token: return AuthProviderAccessToken(access_token) # ELSE - if case_uuid is not None and os.path.exists( - get_token_path(resource_id + "+" + case_uuid, ".sharedkey") - ): - return AuthProviderSumoToken(resource_id + "+" + case_uuid) - # ELSE - if os.path.exists(get_token_path(resource_id, ".sharedkey")): - return AuthProviderSumoToken(resource_id) + if os.path.exists(get_token_path(resource_id, ".sharedkey", case_uuid)): + return AuthProviderSumoToken(resource_id, case_uuid) # ELSE if os.path.exists(get_token_path(resource_id, ".token")): auth_silent = AuthProviderSilent(client_id, authority, resource_id)