From 31e83d95069124e98c5d11c9d1d5a6a23dc8e8a6 Mon Sep 17 00:00:00 2001 From: davidlm Date: Tue, 26 Sep 2023 20:26:10 -0400 Subject: [PATCH] code cleanup and tests --- botocore/config.py | 2 +- botocore/credentials.py | 95 ++++++++----- botocore/session.py | 6 +- tests/__init__.py | 8 +- tests/functional/test_credentials.py | 134 ++++++++++++++++-- tests/unit/test_args.py | 12 -- tests/unit/test_credentials.py | 204 +++++++++++++++++---------- 7 files changed, 331 insertions(+), 130 deletions(-) diff --git a/botocore/config.py b/botocore/config.py index bbc1d62d8b..a86e347d77 100644 --- a/botocore/config.py +++ b/botocore/config.py @@ -223,7 +223,7 @@ class Config: :type account_id_endpoint_mode: str :param account_id_endpoint_mode: Enables or disables account ID based - endpoint routing for supported operations. + endpoint routing for supported operations. Defaults to None. """ diff --git a/botocore/credentials.py b/botocore/credentials.py index 1c2b64de85..63d61ffb36 100644 --- a/botocore/credentials.py +++ b/botocore/credentials.py @@ -48,6 +48,7 @@ ContainerMetadataFetcher, FileWebIdentityTokenLoader, InstanceMetadataFetcher, + InvalidArnException, JSONFileCache, SSOTokenLoader, parse_key_val_file, @@ -306,22 +307,27 @@ class Credentials: :param str access_key: The access key part of the credentials. :param str secret_key: The secret key part of the credentials. :param str token: The security token, valid only for session credentials. - :param str account_id: The account ID associated with the credentials. :param str method: A string which identifies where the credentials were found. + :param str account_id: The account ID associated with the credentials. """ def __init__( - self, access_key, secret_key, token=None, account_id=None, method=None + self, + access_key, + secret_key, + token=None, + method=None, + account_id=None, ): self.access_key = access_key self.secret_key = secret_key self.token = token - self.account_id = account_id if method is None: method = 'explicit' self.method = method + self.account_id = account_id self._normalize() @@ -349,11 +355,12 @@ class RefreshableCredentials(Credentials): :param str access_key: The access key part of the credentials. :param str secret_key: The secret key part of the credentials. :param str token: The security token, valid only for session credentials. - :param str account_id: The account ID associated with the credentials. + :param datetime expiry_time: The time when the credentials will expire. :param function refresh_using: Callback function to refresh the credentials. :param str method: A string which identifies where the credentials were found. :param function time_fetcher: Callback function to retrieve current time. + :param str account_id: The account ID associated with the credentials. """ # The time at which we'll attempt to refresh, but not @@ -368,17 +375,16 @@ def __init__( access_key, secret_key, token, - account_id, expiry_time, refresh_using, method, time_fetcher=_local_now, + account_id=None, ): self._refresh_using = refresh_using self._access_key = access_key self._secret_key = secret_key self._token = token - self._account_id = account_id self._expiry_time = expiry_time self._time_fetcher = time_fetcher self._refresh_lock = threading.Lock() @@ -387,6 +393,7 @@ def __init__( access_key, secret_key, token, account_id ) self._normalize() + self._account_id = account_id def _normalize(self): self._access_key = botocore.compat.ensure_unicode(self._access_key) @@ -398,10 +405,10 @@ def create_from_metadata(cls, metadata, refresh_using, method): access_key=metadata['access_key'], secret_key=metadata['secret_key'], token=metadata['token'], - account_id=metadata.get('account_id'), expiry_time=cls._expiry_datetime(metadata['expiry_time']), method=method, refresh_using=refresh_using, + account_id=metadata.get('account_id'), ) return instance @@ -446,8 +453,16 @@ def token(self, value): @property def account_id(self): - frozen_credentials = self.get_frozen_credentials() - return frozen_credentials.account_id + """Warning: Using this property can lead to race conditions if you + access another property subsequently along the refresh boundary. + Please use get_frozen_credentials instead. + """ + self._refresh() + return self._account_id + + @account_id.setter + def account_id(self, value): + self._account_id = value def _seconds_remaining(self): delta = self._expiry_time - self._time_fetcher() @@ -585,7 +600,7 @@ def _set_from_data(self, data): logger.debug( "Retrieved credentials will expire at: %s", self._expiry_time ) - self._account_id = data.get('account_id') + self.account_id = data.get('account_id') self._normalize() def get_frozen_credentials(self): @@ -696,7 +711,7 @@ def _get_cached_credentials(self): 'secret_key': creds['SecretAccessKey'], 'token': creds['SessionToken'], 'expiry_time': expiration, - 'account_id': self._resolve_account_id(response), + 'account_id': creds.get('AccountId'), } def _load_from_cache(self): @@ -719,9 +734,6 @@ def _is_expired(self, credentials): seconds = total_seconds(end_time - _local_now()) return seconds < self._expiry_window_seconds - def _resolve_account_id(self, response=None): - raise NotImplementedError('_resolve_account_id()') - class BaseAssumeRoleCredentialFetcher(CachedCredentialFetcher): def __init__( @@ -747,6 +759,7 @@ def __init__( self._generate_assume_role_name() super().__init__(cache, expiry_window_seconds) + self._arn_parser = ArnParser() def _generate_assume_role_name(self): self._role_session_name = 'botocore-session-%s' % (int(time.time())) @@ -776,8 +789,16 @@ def _create_cache_key(self): return self._make_file_safe(argument_hash) def _resolve_account_id(self, response): - user_arn = response['AssumedRoleUser']['Arn'] - return ArnParser().parse_arn(user_arn)['account'] + account_id = None + user_arn = response.get('AssumedRoleUser', {}).get('Arn') + if user_arn is not None: + try: + account_id = self._arn_parser.parse_arn(user_arn)['account'] + except InvalidArnException: + logger.debug( + 'Unable to parse account ID from ARN: %s', user_arn + ) + response['Credentials']['AccountId'] = account_id class AssumeRoleCredentialFetcher(BaseAssumeRoleCredentialFetcher): @@ -839,7 +860,9 @@ def _get_credentials(self): """Get credentials by calling assume role.""" kwargs = self._assume_role_kwargs() client = self._create_client() - return client.assume_role(**kwargs) + response = client.assume_role(**kwargs) + self._resolve_account_id(response) + return response def _assume_role_kwargs(self): """Get the arguments for assume role based on current configuration.""" @@ -926,7 +949,9 @@ def _get_credentials(self): # the token, explicitly configure the client to not sign requests. config = Config(signature_version=UNSIGNED) client = self._client_creator('sts', config=config) - return client.assume_role_with_web_identity(**kwargs) + response = client.assume_role_with_web_identity(**kwargs) + self._resolve_account_id(response) + return response def _assume_role_kwargs(self): """Get the arguments for assume role based on current configuration.""" @@ -1009,8 +1034,8 @@ def load(self): access_key=creds_dict['access_key'], secret_key=creds_dict['secret_key'], token=creds_dict.get('token'), - account_id=creds_dict['account_id'], method=self.METHOD, + account_id=creds_dict['account_id'], ) def _retrieve_credentials_using(self, credential_process): @@ -1116,8 +1141,8 @@ def __init__(self, environ=None, mapping=None): :param mapping: An optional mapping of variable names to environment variable names. Use this if you want to change the mapping of access_key->AWS_ACCESS_KEY_ID, etc. - The dict can have up to 3 keys: ``access_key``, ``secret_key``, - ``session_token``. + The dict can have up to 5 keys: ``access_key``, ``secret_key``, + ``token``, ``expiry_time``, and ``account_id``. """ if environ is None: environ = os.environ @@ -1163,6 +1188,7 @@ def load(self): logger.info('Found credentials in environment variables.') fetcher = self._create_credentials_fetcher() credentials = fetcher(require_expiry=False) + expiry_time = credentials['expiry_time'] if expiry_time is not None: expiry_time = parse(expiry_time) @@ -1170,18 +1196,18 @@ def load(self): credentials['access_key'], credentials['secret_key'], credentials['token'], - credentials['account_id'], expiry_time, refresh_using=fetcher, method=self.METHOD, + account_id=credentials['account_id'], ) return Credentials( credentials['access_key'], credentials['secret_key'], credentials['token'], - credentials['account_id'], method=self.METHOD, + account_id=credentials['account_id'], ) else: return None @@ -1223,10 +1249,7 @@ def fetch_credentials(require_expiry=True): raise PartialCredentialsError( provider=method, cred_var=mapping['expiry_time'] ) - credentials['account_id'] = None - account_id = environ.get(mapping['account_id'], '') - if account_id: - credentials['account_id'] = account_id + credentials['account_id'] = environ.get(mapping['account_id']) return credentials @@ -1305,13 +1328,13 @@ def load(self): config, self.ACCESS_KEY, self.SECRET_KEY ) token = self._get_session_token(config) - account_id = config.get(self.ACCOUNT_ID) + account_id = self._get_account_id(config) return Credentials( access_key, secret_key, token, - account_id, method=self.METHOD, + account_id=account_id, ) def _get_session_token(self, config): @@ -1319,6 +1342,9 @@ def _get_session_token(self, config): if token_envvar in config: return config[token_envvar] + def _get_account_id(self, config): + return config.get(self.ACCOUNT_ID) + class ConfigProvider(CredentialProvider): """INI based config provider with profile sections.""" @@ -1368,14 +1394,14 @@ def load(self): access_key, secret_key = self._extract_creds_from_mapping( profile_config, self.ACCESS_KEY, self.SECRET_KEY ) - account_id = profile_config.get(self.ACCOUNT_ID) token = self._get_session_token(profile_config) + account_id = self._get_account_id(profile_config) return Credentials( access_key, secret_key, token, - account_id, method=self.METHOD, + account_id=account_id, ) else: return None @@ -1385,6 +1411,9 @@ def _get_session_token(self, profile_config): if token_name in profile_config: return profile_config[token_name] + def _get_account_id(self, profile_config): + return profile_config.get(self.ACCOUNT_ID) + class BotoProvider(CredentialProvider): METHOD = 'boto-config' @@ -2192,13 +2221,11 @@ def _get_credentials(self): 'SecretAccessKey': credentials['secretAccessKey'], 'SessionToken': credentials['sessionToken'], 'Expiration': self._parse_timestamp(credentials['expiration']), + 'AccountId': self._account_id, }, } return credentials - def _resolve_account_id(self, response=None): - return self._account_id - class SSOProvider(CredentialProvider): METHOD = 'sso' diff --git a/botocore/session.py b/botocore/session.py index 0bfb82080d..cdb5a3feaf 100644 --- a/botocore/session.py +++ b/botocore/session.py @@ -840,8 +840,8 @@ def create_client( aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, - aws_account_id=None, config=None, + aws_account_id=None, ): """Create a botocore client. @@ -911,6 +911,10 @@ def create_client( :rtype: botocore.client.BaseClient :return: A botocore client instance + :type aws_account_id: string + :param aws_account_id: The AWS account ID to use when creating the client. + Same semantics as aws_access_key_id above. + """ default_client_config = self.get_default_client_config() # If a config is provided and a default config is set, then diff --git a/tests/__init__.py b/tests/__init__.py index b098f705f3..fb9bc4e68a 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -320,7 +320,13 @@ def __init__( if refresh_function is None: refresh_function = self._do_refresh super().__init__( - '0', '0', '0', '0', expires_in, refresh_function, 'INTREFRESH' + '0', + '0', + '0', + expires_in, + refresh_function, + 'INTREFRESH', + account_id='0', ) self.creds_last_for = creds_last_for self.refresh_counter = 0 diff --git a/tests/functional/test_credentials.py b/tests/functional/test_credentials.py index 3446836619..cbc36ca1cb 100644 --- a/tests/functional/test_credentials.py +++ b/tests/functional/test_credentials.py @@ -209,7 +209,7 @@ def create_random_credentials(self): 'fake-%s' % random_chars(15), 'fake-%s' % random_chars(35), 'fake-%s' % random_chars(45), - 'fake-%s' % random_chars(12), + account_id='fake-%s' % random_chars(12), ) def assert_creds_equal(self, c1, c2): @@ -773,6 +773,67 @@ def test_assume_role_uses_correct_region(self): self.assert_creds_equal(creds, expected_creds) self.assertEqual(self.actual_client_region, 'cn-north-1') + def test_assume_role_account_id_config(self): + config = ( + "[profile assume-role]\n" + "role_arn = arn:aws:iam::123456789002:role/MyRole\n" + "source_profile = assume-creds\n\n" + "[profile assume-creds]\n" + "aws_access_key_id = abc123\n" + "aws_secret_access_key = def456\n" + "aws_account_id = 123456789001" + ) + self.write_config(config) + session = Session(profile='assume-role') + create_client, expected_creds = self.create_stubbed_sts_client(session) + session.create_client = create_client + + resolver = create_credential_resolver(session) + provider = resolver.get_provider('assume-role') + creds = provider.load() + self.assert_creds_equal(creds, expected_creds) + + def test_chained_assume_role_account_id(self): + config = ( + "[profile final-role]\n" + "role_arn = arn:aws:iam::123456789003:role/MyRole\n" + "source_profile = chained-role\n\n" + "[profile chained-role]\n" + "role_arn = arn:aws:iam::123456789002:role/MyRole\n" + "source_profile = assume-creds\n\n" + "[profile assume-creds]\n" + "aws_access_key_id = abc123\n" + "aws_secret_access_key = def456\n" + "aws_account_id = 123456789001" + ) + self.write_config(config) + session = Session(profile='final-role') + create_client, expected_creds = self.create_stubbed_sts_client(session) + session.create_client = create_client + + resolver = create_credential_resolver(session) + provider = resolver.get_provider('assume-role') + creds = provider.load() + self.assert_creds_equal(creds, expected_creds) + + def test_environment_credential_source_account_id(self): + config = ( + "[profile assume-role]\n" + "role_arn = arn:aws:iam::123456789002:role/MyRole\n" + "credential_source = Environment\n" + "aws_account_id = 123456789001" + ) + self.write_config(config) + self.env_provider.load.return_value = self.create_random_credentials() + session, stubber = self.create_session(profile='assume-role') + expected_creds = self.create_random_credentials() + response = self.create_assume_role_response(expected_creds) + stubber.add_response('assume_role', response) + creds = session.get_credentials() + self.assert_creds_equal(creds, expected_creds) + stubber.assert_no_pending_responses() + self.assertEqual(self.env_provider.load.call_count, 1) + class TestAssumeRoleWithWebIdentity(BaseAssumeRoleTest): def setUp(self): @@ -791,9 +852,7 @@ def assert_session_credentials(self, expected_params, **kwargs): response = self.create_assume_role_response(expected_creds) session = StubbedSession(**kwargs) stubber = session.stub( - 'sts', - aws_access_key_id='spam', - aws_secret_access_key='eggs', + 'sts', config=Config(signature_version=UNSIGNED) ) stubber.add_response( 'assume_role_with_web_identity', response, expected_params @@ -1150,6 +1209,15 @@ def add_credential_response(self, stubber): stubber.add_response(body=json.dumps(response).encode('utf-8')) +def assert_credentials(access_key, secret_key, token, account_id): + session = Session() + credentials = session.get_credentials() + assert credentials.access_key == access_key + assert credentials.secret_key == secret_key + assert credentials.token == token + assert credentials.account_id == account_id + + @pytest.mark.parametrize( 'config, access_key, secret_key, token, account_id', [ @@ -1208,9 +1276,55 @@ def test_config_provider(config, access_key, secret_key, token, account_id): f.write(config) f.flush() with mock.patch('os.environ', {'AWS_CONFIG_FILE': f.name}): - session = Session() - credentials = session.get_credentials() - assert credentials.access_key == access_key - assert credentials.secret_key == secret_key - assert credentials.token == token - assert credentials.account_id == account_id + assert_credentials(access_key, secret_key, token, account_id) + + +@pytest.mark.parametrize( + "env_vars, access_key, secret_key, token, account_id", + [ + ( + { + "AWS_ACCOUNT_ID": "123456789001", + "AWS_ACCESS_KEY_ID": "foo", + "AWS_SECRET_ACCESS_KEY": "bar", + }, + "foo", + "bar", + None, + "123456789001", + ), + ( + {"AWS_ACCESS_KEY_ID": "foo", "AWS_SECRET_ACCESS_KEY": "bar"}, + "foo", + "bar", + None, + None, + ), + ( + { + "AWS_ACCOUNT_ID": "123456789001", + "AWS_ACCESS_KEY_ID": "foo", + "AWS_SECRET_ACCESS_KEY": "bar", + "AWS_SESSION_TOKEN": "baz", + }, + "foo", + "bar", + "baz", + "123456789001", + ), + ( + { + "AWS_ACCESS_KEY_ID": "foo", + "AWS_SECRET_ACCESS_KEY": "bar", + "AWS_SESSION_TOKEN": "baz", + }, + "foo", + "bar", + "baz", + None, + ), + ], +) +def test_env_provider(env_vars, access_key, secret_key, token, account_id): + with mock.patch('os.environ', env_vars): + assert_credentials(access_key, secret_key, token, account_id) diff --git a/tests/unit/test_args.py b/tests/unit/test_args.py index 01266f6497..c84aa6ee2d 100644 --- a/tests/unit/test_args.py +++ b/tests/unit/test_args.py @@ -17,7 +17,6 @@ from botocore.client import ClientEndpointBridge from botocore.config import Config from botocore.configprovider import ConfigValueStore -from botocore.credentials import Credentials from botocore.hooks import HierarchicalEmitter from botocore.model import ServiceModel from botocore.useragent import UserAgentString @@ -649,13 +648,6 @@ def setUp(self): # assume a legacy endpoint resolver that uses the builtin # endpoints.json file self.bridge.endpoint_resolver.uses_builtin_data = True - self.credentials = Credentials( - access_key='foo', - secret_key='bar', - token='baz', - account_id='fiz', - method='test', - ) def call_compute_endpoint_resolver_builtin_defaults(self, **overrides): defaults = { @@ -851,7 +843,3 @@ def test_sdk_endpoint_legacy_set_without_builtin_data(self): legacy_endpoint_url='https://my.legacy.endpoint.com', ) self.assertEqual(bins['SDK::Endpoint'], None) - - def test_aws_account_id(self): - bins = self.call_compute_endpoint_resolver_builtin_defaults() - self.assertIsNone(bins['AWS::Auth::AccountId']) diff --git a/tests/unit/test_credentials.py b/tests/unit/test_credentials.py index 1aab8e40b8..70ca361445 100644 --- a/tests/unit/test_credentials.py +++ b/tests/unit/test_credentials.py @@ -748,7 +748,7 @@ def test_mfa_refresh_enabled(self): ] self.assertEqual(calls, expected_calls) - def test_no_cache_account_id(self): + def test_account_id(self): response = { 'Credentials': { 'AccessKeyId': 'foo', @@ -771,76 +771,6 @@ def test_no_cache_account_id(self): self.assertEqual(response, expected_response) - def test_retrieves_from_cache_account_id(self): - date_in_future = datetime.utcnow() + timedelta(seconds=1000) - utc_timestamp = date_in_future.isoformat() + 'Z' - cache_key = '793d6e2f27667ab2da104824407e486bfec24a47' - cache = { - cache_key: { - 'Credentials': { - 'AccessKeyId': 'foo-cached', - 'SecretAccessKey': 'bar-cached', - 'SessionToken': 'baz-cached', - 'Expiration': utc_timestamp, - 'AccountId': '123456789012-cached', - }, - 'AssumedRoleUser': { - 'AssumedRoleId': 'ARO123EXAMPLE123:myrole', - 'Arn': 'arn:aws:sts::123456789012-cached:assumed-role/myrole', - }, - } - } - client_creator = mock.Mock() - refresher = credentials.AssumeRoleCredentialFetcher( - client_creator, self.source_creds, self.role_arn, cache=cache - ) - - expected_response = self.get_expected_creds_from_response( - cache[cache_key] - ) - response = refresher.fetch_credentials() - - self.assertEqual(response, expected_response) - client_creator.assert_not_called() - - def test_expired_cache_account_id(self): - response = { - 'Credentials': { - 'AccessKeyId': 'foo', - 'SecretAccessKey': 'bar', - 'SessionToken': 'baz', - 'Expiration': self.some_future_time().isoformat(), - }, - 'AssumedRoleUser': { - 'AssumedRoleId': 'ARO123EXAMPLE123:myrole', - 'Arn': 'arn:aws:sts::123456789012:assumed-role/myrole', - }, - } - client_creator = self.create_client_creator(with_response=response) - cache = { - 'development--myrole': { - 'Credentials': { - 'AccessKeyId': 'foo-cached', - 'SecretAccessKey': 'bar-cached', - 'SessionToken': 'baz-cached', - 'Expiration': datetime.now(tzlocal()), - 'AccountId': '123456789012-cached', - }, - 'AssumedRoleUser': { - 'AssumedRoleId': 'ARO123EXAMPLE123:myrole', - 'Arn': 'arn:aws:sts::123456789012:assumed-role/myrole', - }, - } - } - - refresher = credentials.AssumeRoleCredentialFetcher( - client_creator, self.source_creds, self.role_arn, cache=cache - ) - expected = self.get_expected_creds_from_response(response) - response = refresher.fetch_credentials() - - self.assertEqual(response, expected) - class TestAssumeRoleWithWebIdentityCredentialFetcher(BaseEnvVar): def setUp(self): @@ -954,6 +884,29 @@ def test_assume_role_in_cache_but_expired(self): self.assertEqual(response, expected) + def test_account_id(self): + response = { + 'Credentials': { + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': self.some_future_time().isoformat(), + }, + 'AssumedRoleUser': { + 'AssumedRoleId': 'ARO123EXAMPLE123:myrole', + 'Arn': 'arn:aws:sts::123456789012:assumed-role/myrole', + }, + } + client_creator = self.create_client_creator(with_response=response) + refresher = credentials.AssumeRoleWithWebIdentityCredentialFetcher( + client_creator, self.load_token, self.role_arn + ) + + expected_response = self.get_expected_creds_from_response(response) + response = refresher.fetch_credentials() + + self.assertEqual(response, expected_response) + class TestAssumeRoleWithWebIdentityCredentialProvider(unittest.TestCase): def setUp(self): @@ -3518,6 +3471,7 @@ def test_missing_session_token(self): 'SecretAccessKey': 'bar', # Missing session token. 'Expiration': '2999-01-01T00:00:00Z', + 'AccountId': '1234567890', } ) @@ -3528,6 +3482,7 @@ def test_missing_session_token(self): self.assertEqual(creds.secret_key, 'bar') self.assertIsNone(creds.token) self.assertEqual(creds.method, 'custom-process') + self.assertEqual(creds.account_id, '1234567890') def test_missing_expiration(self): self.loaded_config['profiles'] = { @@ -3540,6 +3495,7 @@ def test_missing_expiration(self): 'SecretAccessKey': 'bar', 'SessionToken': 'baz', # Missing expiration. + 'AccountId': '1234567890', } ) @@ -3550,6 +3506,31 @@ def test_missing_expiration(self): self.assertEqual(creds.secret_key, 'bar') self.assertEqual(creds.token, 'baz') self.assertEqual(creds.method, 'custom-process') + self.assertEqual(creds.account_id, '1234567890') + + def test_missing_account_id(self): + self.loaded_config['profiles'] = { + 'default': {'credential_process': 'my-process'} + } + self._set_process_return_value( + { + 'Version': 1, + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': '2999-01-01T00:00:00Z', + # Missing AccountId. + } + ) + + provider = self.create_process_provider() + creds = provider.load() + self.assertIsNotNone(creds) + self.assertEqual(creds.access_key, 'foo') + self.assertEqual(creds.secret_key, 'bar') + self.assertEqual(creds.token, 'baz') + self.assertEqual(creds.method, 'custom-process') + self.assertIsNone(creds.account_id) def test_missing_expiration_and_session_token(self): self.loaded_config['profiles'] = { @@ -3561,6 +3542,7 @@ def test_missing_expiration_and_session_token(self): 'AccessKeyId': 'foo', 'SecretAccessKey': 'bar', # Missing session token and expiration + 'AccountId': '1234567890', } ) @@ -3571,6 +3553,81 @@ def test_missing_expiration_and_session_token(self): self.assertEqual(creds.secret_key, 'bar') self.assertIsNone(creds.token) self.assertEqual(creds.method, 'custom-process') + self.assertEqual(creds.account_id, '1234567890') + + def test_missing_expiration_session_token_account_id(self): + self.loaded_config['profiles'] = { + 'default': {'credential_process': 'my-process'} + } + self._set_process_return_value( + { + 'Version': 1, + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + # Missing session token, expiration and account ID + } + ) + + provider = self.create_process_provider() + creds = provider.load() + self.assertIsNotNone(creds) + self.assertEqual(creds.access_key, 'foo') + self.assertEqual(creds.secret_key, 'bar') + self.assertIsNone(creds.token) + self.assertEqual(creds.method, 'custom-process') + self.assertIsNone(creds.account_id) + + def test_account_id_from_profile(self): + self.loaded_config['profiles'] = { + 'default': { + 'credential_process': 'my-process', + 'aws_account_id': '1234567890', + } + } + self._set_process_return_value( + { + 'Version': 1, + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': '2999-01-01T00:00:00Z', + # Missing AccountId. + } + ) + provider = self.create_process_provider() + creds = provider.load() + self.assertIsNotNone(creds) + self.assertEqual(creds.access_key, 'foo') + self.assertEqual(creds.secret_key, 'bar') + self.assertEqual(creds.token, 'baz') + self.assertEqual(creds.method, 'custom-process') + self.assertEqual(creds.account_id, '1234567890') + + def test_account_id_from_process_takes_precedence(self): + self.loaded_config['profiles'] = { + 'default': { + 'credential_process': 'my-process', + 'aws_account_id': '1234567890', + } + } + self._set_process_return_value( + { + 'Version': 1, + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': '2999-01-01T00:00:00Z', + 'AccountId': '0987654321', + } + ) + provider = self.create_process_provider() + creds = provider.load() + self.assertIsNotNone(creds) + self.assertEqual(creds.access_key, 'foo') + self.assertEqual(creds.secret_key, 'bar') + self.assertEqual(creds.token, 'baz') + self.assertEqual(creds.method, 'custom-process') + self.assertEqual(creds.account_id, '0987654321') class TestProfileProviderBuilder(unittest.TestCase): @@ -3753,6 +3810,7 @@ def test_load_sso_credentials_without_cache(self): self.assertEqual(credentials.access_key, 'foo') self.assertEqual(credentials.secret_key, 'bar') self.assertEqual(credentials.token, 'baz') + self.assertEqual(credentials.account_id, '1234567890') def test_load_sso_credentials_with_cache(self): cached_creds = { @@ -3761,6 +3819,7 @@ def test_load_sso_credentials_with_cache(self): 'SecretAccessKey': 'cached-sak', 'SessionToken': 'cached-st', 'Expiration': self.expires_at.strftime('%Y-%m-%dT%H:%M:%S%Z'), + 'AccountId': '1234567890-cached', } } self.cache[self.cached_creds_key] = cached_creds @@ -3768,6 +3827,7 @@ def test_load_sso_credentials_with_cache(self): self.assertEqual(credentials.access_key, 'cached-akid') self.assertEqual(credentials.secret_key, 'cached-sak') self.assertEqual(credentials.token, 'cached-st') + self.assertEqual(credentials.account_id, '1234567890-cached') def test_load_sso_credentials_with_cache_expired(self): cached_creds = { @@ -3776,6 +3836,7 @@ def test_load_sso_credentials_with_cache_expired(self): 'SecretAccessKey': 'expired-sak', 'SessionToken': 'expired-st', 'Expiration': '2002-10-22T20:52:11UTC', + 'AccountId': '1234567890-expired', } } self.cache[self.cached_creds_key] = cached_creds @@ -3786,6 +3847,7 @@ def test_load_sso_credentials_with_cache_expired(self): self.assertEqual(credentials.access_key, 'foo') self.assertEqual(credentials.secret_key, 'bar') self.assertEqual(credentials.token, 'baz') + self.assertEqual(credentials.account_id, '1234567890') def test_required_config_not_set(self): del self.config['sso_start_url']