Skip to content

Commit

Permalink
code cleanup and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlm committed Sep 27, 2023
1 parent 79aa6ab commit 31e83d9
Show file tree
Hide file tree
Showing 7 changed files with 331 additions and 130 deletions.
2 changes: 1 addition & 1 deletion botocore/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class Config:
:type account_id_endpoint_mode: str
:param account_id_endpoint_mode: Enables or disables account ID based
endpoint routing for supported operations.
endpoint routing for supported operations.
Defaults to None.
"""
Expand Down
95 changes: 61 additions & 34 deletions botocore/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
ContainerMetadataFetcher,
FileWebIdentityTokenLoader,
InstanceMetadataFetcher,
InvalidArnException,
JSONFileCache,
SSOTokenLoader,
parse_key_val_file,
Expand Down Expand Up @@ -306,22 +307,27 @@ class Credentials:
:param str access_key: The access key part of the credentials.
:param str secret_key: The secret key part of the credentials.
:param str token: The security token, valid only for session credentials.
:param str account_id: The account ID associated with the credentials.
:param str method: A string which identifies where the credentials
were found.
:param str account_id: The account ID associated with the credentials.
"""

def __init__(
self, access_key, secret_key, token=None, account_id=None, method=None
self,
access_key,
secret_key,
token=None,
method=None,
account_id=None,
):
self.access_key = access_key
self.secret_key = secret_key
self.token = token
self.account_id = account_id

if method is None:
method = 'explicit'
self.method = method
self.account_id = account_id

self._normalize()

Expand Down Expand Up @@ -349,11 +355,12 @@ class RefreshableCredentials(Credentials):
:param str access_key: The access key part of the credentials.
:param str secret_key: The secret key part of the credentials.
:param str token: The security token, valid only for session credentials.
:param str account_id: The account ID associated with the credentials.
:param datetime expiry_time: The time when the credentials will expire.
:param function refresh_using: Callback function to refresh the credentials.
:param str method: A string which identifies where the credentials
were found.
:param function time_fetcher: Callback function to retrieve current time.
:param str account_id: The account ID associated with the credentials.
"""

# The time at which we'll attempt to refresh, but not
Expand All @@ -368,17 +375,16 @@ def __init__(
access_key,
secret_key,
token,
account_id,
expiry_time,
refresh_using,
method,
time_fetcher=_local_now,
account_id=None,
):
self._refresh_using = refresh_using
self._access_key = access_key
self._secret_key = secret_key
self._token = token
self._account_id = account_id
self._expiry_time = expiry_time
self._time_fetcher = time_fetcher
self._refresh_lock = threading.Lock()
Expand All @@ -387,6 +393,7 @@ def __init__(
access_key, secret_key, token, account_id
)
self._normalize()
self._account_id = account_id

def _normalize(self):
self._access_key = botocore.compat.ensure_unicode(self._access_key)
Expand All @@ -398,10 +405,10 @@ def create_from_metadata(cls, metadata, refresh_using, method):
access_key=metadata['access_key'],
secret_key=metadata['secret_key'],
token=metadata['token'],
account_id=metadata.get('account_id'),
expiry_time=cls._expiry_datetime(metadata['expiry_time']),
method=method,
refresh_using=refresh_using,
account_id=metadata.get('account_id'),
)
return instance

Expand Down Expand Up @@ -446,8 +453,16 @@ def token(self, value):

@property
def account_id(self):
frozen_credentials = self.get_frozen_credentials()
return frozen_credentials.account_id
"""Warning: Using this property can lead to race conditions if you
access another property subsequently along the refresh boundary.
Please use get_frozen_credentials instead.
"""
self._refresh()
return self._account_id

@account_id.setter
def account_id(self, value):
self._account_id = value

def _seconds_remaining(self):
delta = self._expiry_time - self._time_fetcher()
Expand Down Expand Up @@ -585,7 +600,7 @@ def _set_from_data(self, data):
logger.debug(
"Retrieved credentials will expire at: %s", self._expiry_time
)
self._account_id = data.get('account_id')
self.account_id = data.get('account_id')
self._normalize()

def get_frozen_credentials(self):
Expand Down Expand Up @@ -696,7 +711,7 @@ def _get_cached_credentials(self):
'secret_key': creds['SecretAccessKey'],
'token': creds['SessionToken'],
'expiry_time': expiration,
'account_id': self._resolve_account_id(response),
'account_id': creds.get('AccountId'),
}

def _load_from_cache(self):
Expand All @@ -719,9 +734,6 @@ def _is_expired(self, credentials):
seconds = total_seconds(end_time - _local_now())
return seconds < self._expiry_window_seconds

def _resolve_account_id(self, response=None):
raise NotImplementedError('_resolve_account_id()')


class BaseAssumeRoleCredentialFetcher(CachedCredentialFetcher):
def __init__(
Expand All @@ -747,6 +759,7 @@ def __init__(
self._generate_assume_role_name()

super().__init__(cache, expiry_window_seconds)
self._arn_parser = ArnParser()

def _generate_assume_role_name(self):
self._role_session_name = 'botocore-session-%s' % (int(time.time()))
Expand Down Expand Up @@ -776,8 +789,16 @@ def _create_cache_key(self):
return self._make_file_safe(argument_hash)

def _resolve_account_id(self, response):
user_arn = response['AssumedRoleUser']['Arn']
return ArnParser().parse_arn(user_arn)['account']
account_id = None
user_arn = response.get('AssumedRoleUser', {}).get('Arn')
if user_arn is not None:
try:
account_id = self._arn_parser.parse_arn(user_arn)['account']
except InvalidArnException:
logger.debug(
'Unable to parse account ID from ARN: %s', user_arn
)
response['Credentials']['AccountId'] = account_id


class AssumeRoleCredentialFetcher(BaseAssumeRoleCredentialFetcher):
Expand Down Expand Up @@ -839,7 +860,9 @@ def _get_credentials(self):
"""Get credentials by calling assume role."""
kwargs = self._assume_role_kwargs()
client = self._create_client()
return client.assume_role(**kwargs)
response = client.assume_role(**kwargs)
self._resolve_account_id(response)
return response

def _assume_role_kwargs(self):
"""Get the arguments for assume role based on current configuration."""
Expand Down Expand Up @@ -926,7 +949,9 @@ def _get_credentials(self):
# the token, explicitly configure the client to not sign requests.
config = Config(signature_version=UNSIGNED)
client = self._client_creator('sts', config=config)
return client.assume_role_with_web_identity(**kwargs)
response = client.assume_role_with_web_identity(**kwargs)
self._resolve_account_id(response)
return response

def _assume_role_kwargs(self):
"""Get the arguments for assume role based on current configuration."""
Expand Down Expand Up @@ -1009,8 +1034,8 @@ def load(self):
access_key=creds_dict['access_key'],
secret_key=creds_dict['secret_key'],
token=creds_dict.get('token'),
account_id=creds_dict['account_id'],
method=self.METHOD,
account_id=creds_dict['account_id'],
)

def _retrieve_credentials_using(self, credential_process):
Expand Down Expand Up @@ -1116,8 +1141,8 @@ def __init__(self, environ=None, mapping=None):
:param mapping: An optional mapping of variable names to
environment variable names. Use this if you want to
change the mapping of access_key->AWS_ACCESS_KEY_ID, etc.
The dict can have up to 3 keys: ``access_key``, ``secret_key``,
``session_token``.
The dict can have up to 5 keys: ``access_key``, ``secret_key``,
``token``, ``expiry_time``, and ``account_id``.
"""
if environ is None:
environ = os.environ
Expand Down Expand Up @@ -1163,25 +1188,26 @@ def load(self):
logger.info('Found credentials in environment variables.')
fetcher = self._create_credentials_fetcher()
credentials = fetcher(require_expiry=False)

expiry_time = credentials['expiry_time']
if expiry_time is not None:
expiry_time = parse(expiry_time)
return RefreshableCredentials(
credentials['access_key'],
credentials['secret_key'],
credentials['token'],
credentials['account_id'],
expiry_time,
refresh_using=fetcher,
method=self.METHOD,
account_id=credentials['account_id'],
)

return Credentials(
credentials['access_key'],
credentials['secret_key'],
credentials['token'],
credentials['account_id'],
method=self.METHOD,
account_id=credentials['account_id'],
)
else:
return None
Expand Down Expand Up @@ -1223,10 +1249,7 @@ def fetch_credentials(require_expiry=True):
raise PartialCredentialsError(
provider=method, cred_var=mapping['expiry_time']
)
credentials['account_id'] = None
account_id = environ.get(mapping['account_id'], '')
if account_id:
credentials['account_id'] = account_id
credentials['account_id'] = environ.get(mapping['account_id'])

return credentials

Expand Down Expand Up @@ -1305,20 +1328,23 @@ def load(self):
config, self.ACCESS_KEY, self.SECRET_KEY
)
token = self._get_session_token(config)
account_id = config.get(self.ACCOUNT_ID)
account_id = self._get_account_id(config)
return Credentials(
access_key,
secret_key,
token,
account_id,
method=self.METHOD,
account_id=account_id,
)

def _get_session_token(self, config):
for token_envvar in self.TOKENS:
if token_envvar in config:
return config[token_envvar]

def _get_account_id(self, config):
return config.get(self.ACCOUNT_ID)


class ConfigProvider(CredentialProvider):
"""INI based config provider with profile sections."""
Expand Down Expand Up @@ -1368,14 +1394,14 @@ def load(self):
access_key, secret_key = self._extract_creds_from_mapping(
profile_config, self.ACCESS_KEY, self.SECRET_KEY
)
account_id = profile_config.get(self.ACCOUNT_ID)
token = self._get_session_token(profile_config)
account_id = self._get_account_id(profile_config)
return Credentials(
access_key,
secret_key,
token,
account_id,
method=self.METHOD,
account_id=account_id,
)
else:
return None
Expand All @@ -1385,6 +1411,9 @@ def _get_session_token(self, profile_config):
if token_name in profile_config:
return profile_config[token_name]

def _get_account_id(self, profile_config):
return profile_config.get(self.ACCOUNT_ID)


class BotoProvider(CredentialProvider):
METHOD = 'boto-config'
Expand Down Expand Up @@ -2192,13 +2221,11 @@ def _get_credentials(self):
'SecretAccessKey': credentials['secretAccessKey'],
'SessionToken': credentials['sessionToken'],
'Expiration': self._parse_timestamp(credentials['expiration']),
'AccountId': self._account_id,
},
}
return credentials

def _resolve_account_id(self, response=None):
return self._account_id


class SSOProvider(CredentialProvider):
METHOD = 'sso'
Expand Down
6 changes: 5 additions & 1 deletion botocore/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,8 +840,8 @@ def create_client(
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
aws_account_id=None,
config=None,
aws_account_id=None,
):
"""Create a botocore client.
Expand Down Expand Up @@ -911,6 +911,10 @@ def create_client(
:rtype: botocore.client.BaseClient
:return: A botocore client instance
:type aws_account_id: string
:param aws_account_id: The AWS account ID to use when creating the client.
Same semantics as aws_access_key_id above.
"""
default_client_config = self.get_default_client_config()
# If a config is provided and a default config is set, then
Expand Down
8 changes: 7 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,13 @@ def __init__(
if refresh_function is None:
refresh_function = self._do_refresh
super().__init__(
'0', '0', '0', '0', expires_in, refresh_function, 'INTREFRESH'
'0',
'0',
'0',
expires_in,
refresh_function,
'INTREFRESH',
account_id='0',
)
self.creds_last_for = creds_last_for
self.refresh_counter = 0
Expand Down
Loading

0 comments on commit 31e83d9

Please sign in to comment.