Skip to content

Commit

Permalink
add credentials to resolver and use _resolve_account_id helper
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlm committed Sep 19, 2023
1 parent 8874c64 commit 1f672cb
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 58 deletions.
10 changes: 4 additions & 6 deletions botocore/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,6 @@ def _build_endpoint_resolver(
endpoint_bridge=endpoint_bridge,
client_endpoint_url=endpoint_url,
legacy_endpoint_url=endpoint.host,
credentials=credentials,
)
# botocore does not support client context parameters generically
# for every service. Instead, the s3 config section entries are
Expand All @@ -669,6 +668,7 @@ def _build_endpoint_resolver(
event_emitter=event_emitter,
use_ssl=is_secure,
requested_auth_scheme=sig_version,
credentials=credentials,
)

def compute_endpoint_resolver_builtin_defaults(
Expand All @@ -679,7 +679,6 @@ def compute_endpoint_resolver_builtin_defaults(
endpoint_bridge,
client_endpoint_url,
legacy_endpoint_url,
credentials,
):
# EndpointRulesetResolver rulesets may accept an "SDK::Endpoint" as
# input. If the endpoint_url argument of create_client() is set, it
Expand Down Expand Up @@ -708,9 +707,6 @@ def compute_endpoint_resolver_builtin_defaults(
else:
force_path_style = s3_config.get('addressing_style') == 'path'

account_id_getter = None
if credentials is not None:
account_id_getter = credentials.account_id_getter
return {
EPRBuiltins.AWS_REGION: region_name,
EPRBuiltins.AWS_USE_FIPS: (
Expand Down Expand Up @@ -757,7 +753,9 @@ def compute_endpoint_resolver_builtin_defaults(
's3_disable_multiregion_access_points', False
),
EPRBuiltins.SDK_ENDPOINT: given_endpoint,
EPRBuiltins.AWS_ACCOUNT_ID: account_id_getter,
# account ID is calculated later if account based routing is
# enabled and configured for the service
EPRBuiltins.AWS_ACCOUNT_ID: None,
}

def _compute_user_agent_appid_config(self, config_kwargs):
Expand Down
45 changes: 15 additions & 30 deletions botocore/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,6 @@ def get_frozen_credentials(self):
self.access_key, self.secret_key, self.token, self.account_id
)

def account_id_getter(self):
return self.account_id


class RefreshableCredentials(Credentials):
"""
Expand Down Expand Up @@ -679,7 +676,7 @@ def _get_credentials(self):
def fetch_credentials(self):
return self._get_cached_credentials()

def _get_cached_credentials_response(self):
def _get_cached_credentials(self):
"""Get up-to-date credentials.
This will check the cache for up-to-date credentials, calling assume
Expand All @@ -691,10 +688,16 @@ def _get_cached_credentials_response(self):
self._write_to_cache(response)
else:
logger.debug("Credentials for role retrieved from cache.")
return response

def _get_cached_credentials(self, response):
raise NotImplementedError('_get_cached_credentials()')
creds = response['Credentials']
expiration = _serialize_if_needed(creds['Expiration'], iso=True)
return {
'access_key': creds['AccessKeyId'],
'secret_key': creds['SecretAccessKey'],
'token': creds['SessionToken'],
'expiry_time': expiration,
'account_id': self._resolve_account_id(response),
}

def _load_from_cache(self):
if self._cache_key in self._cache:
Expand All @@ -716,6 +719,9 @@ def _is_expired(self, credentials):
seconds = total_seconds(end_time - _local_now())
return seconds < self._expiry_window_seconds

def _resolve_account_id(self, response=None):
raise NotImplementedError('_resolve_account_id()')


class BaseAssumeRoleCredentialFetcher(CachedCredentialFetcher):
def __init__(
Expand Down Expand Up @@ -769,18 +775,6 @@ def _create_cache_key(self):
argument_hash = sha1(args.encode('utf-8')).hexdigest()
return self._make_file_safe(argument_hash)

def _get_cached_credentials(self):
response = self._get_cached_credentials_response()
creds = response['Credentials']
expiration = _serialize_if_needed(creds['Expiration'], iso=True)
return {
'access_key': creds['AccessKeyId'],
'secret_key': creds['SecretAccessKey'],
'token': creds['SessionToken'],
'expiry_time': expiration,
'account_id': self._resolve_account_id(response),
}

def _resolve_account_id(self, response):
user_arn = response['AssumedRoleUser']['Arn']
return ArnParser().parse_arn(user_arn)['account']
Expand Down Expand Up @@ -2202,17 +2196,8 @@ def _get_credentials(self):
}
return credentials

def _get_cached_credentials(self):
response = self._get_cached_credentials_response()
creds = response['Credentials']
expiration = _serialize_if_needed(creds['Expiration'], iso=True)
return {
'access_key': creds['AccessKeyId'],
'secret_key': creds['SecretAccessKey'],
'token': creds['SessionToken'],
'expiry_time': expiration,
'account_id': self._account_id,
}
def _resolve_account_id(self, response=None):
return self._account_id


class SSOProvider(CredentialProvider):
Expand Down
30 changes: 15 additions & 15 deletions botocore/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ def __init__(
event_emitter,
use_ssl=True,
requested_auth_scheme=None,
credentials=None,
):
self._provider = EndpointProvider(
ruleset_data=endpoint_ruleset_data,
Expand All @@ -488,6 +489,7 @@ def __init__(
self._use_ssl = use_ssl
self._requested_auth_scheme = requested_auth_scheme
self._instance_cache = {}
self._credentials = credentials

def construct_endpoint(
self,
Expand Down Expand Up @@ -556,11 +558,13 @@ def _get_provider_params(
customized_builtins = self._get_customized_builtins(
operation_model, call_args, request_context
)
if not request_context.get('is_presign_request'):
is_presign = request_context.get('is_presign_request')
should_sign = self._requested_auth_scheme != UNSIGNED
# account ID routing is disabled for presigned and unsigned requests.
if not is_presign and should_sign:
config = request_context['client_config']
account_id_endpoint_mode = config.account_id_endpoint_mode
else:
# Presigning requests do not support account ID routing..
account_id_endpoint_mode = 'disabled'
for param_name, param_def in self._param_definitions.items():
param_val = self._resolve_param_from_context(
Expand Down Expand Up @@ -618,32 +622,28 @@ def _resolve_param_as_builtin(
):
if builtin_name not in EndpointResolverBuiltins.__members__.values():
raise UnknownEndpointResolutionBuiltInName(name=builtin_name)
builtin_value = builtins.get(builtin_name)
if (
builtin_name == EndpointResolverBuiltins.AWS_ACCOUNT_ID
and builtin_value is not None
):
builtin_value = self._resolve_account_id_builtin(
builtin_value, account_id_endpoint_mode
)
return builtin_value

def _resolve_account_id_builtin(self, builtin, account_id_endpoint_mode):
if builtin_name == EndpointResolverBuiltins.AWS_ACCOUNT_ID:
return self._resolve_account_id_builtin(account_id_endpoint_mode)

return builtins.get(builtin_name)

def _resolve_account_id_builtin(self, account_id_endpoint_mode):
self._validate_account_id_endpoint_mode(account_id_endpoint_mode)

if account_id_endpoint_mode == 'disabled':
return None

builtin_value = builtin()
if builtin_value is None:
account_id = self._credentials.account_id
if account_id is None:
if account_id_endpoint_mode == 'preferred':
LOG.warning(
'`account_id_endpoint_mode` is enabled but no account ID was found!'
)
elif account_id_endpoint_mode == 'required':
raise AccountIDNotFound()

return builtin_value
return account_id

def _validate_account_id_endpoint_mode(account_id_endpoint_mode):
if account_id_endpoint_mode not in VALID_ACCOUNT_ID_ENDPOINT_MODES:
Expand Down
9 changes: 2 additions & 7 deletions tests/unit/test_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,6 @@ def call_compute_endpoint_resolver_builtin_defaults(self, **overrides):
'endpoint_bridge': self.bridge,
'client_endpoint_url': None,
'legacy_endpoint_url': 'https://my.legacy.endpoint.com',
'credentials': self.credentials,
}
kwargs = {**defaults, **overrides}
return self.args_create.compute_endpoint_resolver_builtin_defaults(
Expand All @@ -688,9 +687,7 @@ def test_builtins_defaults(self):
bins['AWS::S3::DisableMultiRegionAccessPoints'], False
)
self.assertEqual(bins['SDK::Endpoint'], None)
self.assertEqual(
bins['AWS::Auth::AccountId'], self.credentials.account_id_getter
)
self.assertEqual(bins['AWS::Auth::AccountId'], None)

def test_aws_region(self):
bins = self.call_compute_endpoint_resolver_builtin_defaults(
Expand Down Expand Up @@ -857,6 +854,4 @@ def test_sdk_endpoint_legacy_set_without_builtin_data(self):

def test_aws_account_id(self):
bins = self.call_compute_endpoint_resolver_builtin_defaults()
self.assertEqual(
bins['AWS::Auth::AccountId'](), self.credentials.account_id
)
self.assertIsNone(bins['AWS::Auth::AccountId'])

0 comments on commit 1f672cb

Please sign in to comment.