From 1f672cb96244ef498d9e3a6faf25ffdb5223c974 Mon Sep 17 00:00:00 2001 From: davidlm Date: Tue, 19 Sep 2023 18:26:07 -0400 Subject: [PATCH] add credentials to resolver and use _resolve_account_id helper --- botocore/args.py | 10 ++++----- botocore/credentials.py | 45 ++++++++++++++--------------------------- botocore/regions.py | 30 +++++++++++++-------------- tests/unit/test_args.py | 9 ++------- 4 files changed, 36 insertions(+), 58 deletions(-) diff --git a/botocore/args.py b/botocore/args.py index 3d714b0b8a..2359169d19 100644 --- a/botocore/args.py +++ b/botocore/args.py @@ -645,7 +645,6 @@ def _build_endpoint_resolver( endpoint_bridge=endpoint_bridge, client_endpoint_url=endpoint_url, legacy_endpoint_url=endpoint.host, - credentials=credentials, ) # botocore does not support client context parameters generically # for every service. Instead, the s3 config section entries are @@ -669,6 +668,7 @@ def _build_endpoint_resolver( event_emitter=event_emitter, use_ssl=is_secure, requested_auth_scheme=sig_version, + credentials=credentials, ) def compute_endpoint_resolver_builtin_defaults( @@ -679,7 +679,6 @@ def compute_endpoint_resolver_builtin_defaults( endpoint_bridge, client_endpoint_url, legacy_endpoint_url, - credentials, ): # EndpointRulesetResolver rulesets may accept an "SDK::Endpoint" as # input. If the endpoint_url argument of create_client() is set, it @@ -708,9 +707,6 @@ def compute_endpoint_resolver_builtin_defaults( else: force_path_style = s3_config.get('addressing_style') == 'path' - account_id_getter = None - if credentials is not None: - account_id_getter = credentials.account_id_getter return { EPRBuiltins.AWS_REGION: region_name, EPRBuiltins.AWS_USE_FIPS: ( @@ -757,7 +753,9 @@ def compute_endpoint_resolver_builtin_defaults( 's3_disable_multiregion_access_points', False ), EPRBuiltins.SDK_ENDPOINT: given_endpoint, - EPRBuiltins.AWS_ACCOUNT_ID: account_id_getter, + # account ID is calculated later if account based routing is + # enabled and configured for the service + EPRBuiltins.AWS_ACCOUNT_ID: None, } def _compute_user_agent_appid_config(self, config_kwargs): diff --git a/botocore/credentials.py b/botocore/credentials.py index 3d0ae6f466..1c2b64de85 100644 --- a/botocore/credentials.py +++ b/botocore/credentials.py @@ -340,9 +340,6 @@ def get_frozen_credentials(self): self.access_key, self.secret_key, self.token, self.account_id ) - def account_id_getter(self): - return self.account_id - class RefreshableCredentials(Credentials): """ @@ -679,7 +676,7 @@ def _get_credentials(self): def fetch_credentials(self): return self._get_cached_credentials() - def _get_cached_credentials_response(self): + def _get_cached_credentials(self): """Get up-to-date credentials. This will check the cache for up-to-date credentials, calling assume @@ -691,10 +688,16 @@ def _get_cached_credentials_response(self): self._write_to_cache(response) else: logger.debug("Credentials for role retrieved from cache.") - return response - def _get_cached_credentials(self, response): - raise NotImplementedError('_get_cached_credentials()') + creds = response['Credentials'] + expiration = _serialize_if_needed(creds['Expiration'], iso=True) + return { + 'access_key': creds['AccessKeyId'], + 'secret_key': creds['SecretAccessKey'], + 'token': creds['SessionToken'], + 'expiry_time': expiration, + 'account_id': self._resolve_account_id(response), + } def _load_from_cache(self): if self._cache_key in self._cache: @@ -716,6 +719,9 @@ def _is_expired(self, credentials): seconds = total_seconds(end_time - _local_now()) return seconds < self._expiry_window_seconds + def _resolve_account_id(self, response=None): + raise NotImplementedError('_resolve_account_id()') + class BaseAssumeRoleCredentialFetcher(CachedCredentialFetcher): def __init__( @@ -769,18 +775,6 @@ def _create_cache_key(self): argument_hash = sha1(args.encode('utf-8')).hexdigest() return self._make_file_safe(argument_hash) - def _get_cached_credentials(self): - response = self._get_cached_credentials_response() - creds = response['Credentials'] - expiration = _serialize_if_needed(creds['Expiration'], iso=True) - return { - 'access_key': creds['AccessKeyId'], - 'secret_key': creds['SecretAccessKey'], - 'token': creds['SessionToken'], - 'expiry_time': expiration, - 'account_id': self._resolve_account_id(response), - } - def _resolve_account_id(self, response): user_arn = response['AssumedRoleUser']['Arn'] return ArnParser().parse_arn(user_arn)['account'] @@ -2202,17 +2196,8 @@ def _get_credentials(self): } return credentials - def _get_cached_credentials(self): - response = self._get_cached_credentials_response() - creds = response['Credentials'] - expiration = _serialize_if_needed(creds['Expiration'], iso=True) - return { - 'access_key': creds['AccessKeyId'], - 'secret_key': creds['SecretAccessKey'], - 'token': creds['SessionToken'], - 'expiry_time': expiration, - 'account_id': self._account_id, - } + def _resolve_account_id(self, response=None): + return self._account_id class SSOProvider(CredentialProvider): diff --git a/botocore/regions.py b/botocore/regions.py index c9fd8c3ada..7130a7e7c1 100644 --- a/botocore/regions.py +++ b/botocore/regions.py @@ -475,6 +475,7 @@ def __init__( event_emitter, use_ssl=True, requested_auth_scheme=None, + credentials=None, ): self._provider = EndpointProvider( ruleset_data=endpoint_ruleset_data, @@ -488,6 +489,7 @@ def __init__( self._use_ssl = use_ssl self._requested_auth_scheme = requested_auth_scheme self._instance_cache = {} + self._credentials = credentials def construct_endpoint( self, @@ -556,11 +558,13 @@ def _get_provider_params( customized_builtins = self._get_customized_builtins( operation_model, call_args, request_context ) - if not request_context.get('is_presign_request'): + is_presign = request_context.get('is_presign_request') + should_sign = self._requested_auth_scheme != UNSIGNED + # account ID routing is disabled for presigned and unsigned requests. + if not is_presign and should_sign: config = request_context['client_config'] account_id_endpoint_mode = config.account_id_endpoint_mode else: - # Presigning requests do not support account ID routing.. account_id_endpoint_mode = 'disabled' for param_name, param_def in self._param_definitions.items(): param_val = self._resolve_param_from_context( @@ -618,24 +622,20 @@ def _resolve_param_as_builtin( ): if builtin_name not in EndpointResolverBuiltins.__members__.values(): raise UnknownEndpointResolutionBuiltInName(name=builtin_name) - builtin_value = builtins.get(builtin_name) - if ( - builtin_name == EndpointResolverBuiltins.AWS_ACCOUNT_ID - and builtin_value is not None - ): - builtin_value = self._resolve_account_id_builtin( - builtin_value, account_id_endpoint_mode - ) - return builtin_value - def _resolve_account_id_builtin(self, builtin, account_id_endpoint_mode): + if builtin_name == EndpointResolverBuiltins.AWS_ACCOUNT_ID: + return self._resolve_account_id_builtin(account_id_endpoint_mode) + + return builtins.get(builtin_name) + + def _resolve_account_id_builtin(self, account_id_endpoint_mode): self._validate_account_id_endpoint_mode(account_id_endpoint_mode) if account_id_endpoint_mode == 'disabled': return None - builtin_value = builtin() - if builtin_value is None: + account_id = self._credentials.account_id + if account_id is None: if account_id_endpoint_mode == 'preferred': LOG.warning( '`account_id_endpoint_mode` is enabled but no account ID was found!' @@ -643,7 +643,7 @@ def _resolve_account_id_builtin(self, builtin, account_id_endpoint_mode): elif account_id_endpoint_mode == 'required': raise AccountIDNotFound() - return builtin_value + return account_id def _validate_account_id_endpoint_mode(account_id_endpoint_mode): if account_id_endpoint_mode not in VALID_ACCOUNT_ID_ENDPOINT_MODES: diff --git a/tests/unit/test_args.py b/tests/unit/test_args.py index 12f3572fe8..01266f6497 100644 --- a/tests/unit/test_args.py +++ b/tests/unit/test_args.py @@ -665,7 +665,6 @@ def call_compute_endpoint_resolver_builtin_defaults(self, **overrides): 'endpoint_bridge': self.bridge, 'client_endpoint_url': None, 'legacy_endpoint_url': 'https://my.legacy.endpoint.com', - 'credentials': self.credentials, } kwargs = {**defaults, **overrides} return self.args_create.compute_endpoint_resolver_builtin_defaults( @@ -688,9 +687,7 @@ def test_builtins_defaults(self): bins['AWS::S3::DisableMultiRegionAccessPoints'], False ) self.assertEqual(bins['SDK::Endpoint'], None) - self.assertEqual( - bins['AWS::Auth::AccountId'], self.credentials.account_id_getter - ) + self.assertEqual(bins['AWS::Auth::AccountId'], None) def test_aws_region(self): bins = self.call_compute_endpoint_resolver_builtin_defaults( @@ -857,6 +854,4 @@ def test_sdk_endpoint_legacy_set_without_builtin_data(self): def test_aws_account_id(self): bins = self.call_compute_endpoint_resolver_builtin_defaults() - self.assertEqual( - bins['AWS::Auth::AccountId'](), self.credentials.account_id - ) + self.assertIsNone(bins['AWS::Auth::AccountId'])