Skip to content

Commit

Permalink
testing updates
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlm committed Sep 25, 2023
1 parent 46ee5ae commit bb64447
Show file tree
Hide file tree
Showing 6 changed files with 402 additions and 56 deletions.
9 changes: 9 additions & 0 deletions botocore/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def compute_client_args(
self._compute_connect_timeout(config_kwargs)
self._compute_user_agent_appid_config(config_kwargs)
self._compute_request_compression_config(config_kwargs)
self._compute_account_id_endpoint_mode(config_kwargs)
s3_config = self.compute_s3_config(client_config)

is_s3_service = self._is_s3_service(service_name)
Expand Down Expand Up @@ -601,6 +602,14 @@ def _validate_min_compression_size(self, min_size):

return min_size

def _compute_account_id_endpoint_mode(self, config_kwargs):
ep_mode = config_kwargs.get('account_id_endpoint_mode')
if ep_mode is None:
ep_mode = self._config_store.get_config_variable(
'account_id_endpoint_mode'
)
config_kwargs['account_id_endpoint_mode'] = ep_mode

def _ensure_boolean(self, val):
if isinstance(val, bool):
return val
Expand Down
95 changes: 52 additions & 43 deletions botocore/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,14 +558,7 @@ def _get_provider_params(
customized_builtins = self._get_customized_builtins(
operation_model, call_args, request_context
)
is_presign = request_context.get('is_presign_request')
should_sign = self._requested_auth_scheme != UNSIGNED
# account ID routing is disabled for presigned and unsigned requests.
if not is_presign and should_sign:
config = request_context['client_config']
account_id_endpoint_mode = config.account_id_endpoint_mode
else:
account_id_endpoint_mode = 'disabled'
self._resolve_account_id_builtin(request_context, customized_builtins)
for param_name, param_def in self._param_definitions.items():
param_val = self._resolve_param_from_context(
param_name=param_name,
Expand All @@ -576,12 +569,61 @@ def _get_provider_params(
param_val = self._resolve_param_as_builtin(
builtin_name=param_def.builtin,
builtins=customized_builtins,
account_id_endpoint_mode=account_id_endpoint_mode,
)
if param_val is not None:
provider_params[param_name] = param_val
return provider_params

def _resolve_account_id_builtin(self, request_context, builtins):
"""Resolve the ``AWS::Auth::AccountId`` builtin if account ID based
routing is enabled. It initially defaults to None to avoid resolving
credentials during client creation.
"""
account_id_endpoint_mode = self._resolve_account_id_endpoint_mode(
request_context
)
if account_id_endpoint_mode == 'disabled':
return
# This will make a call to resolve credentials if they are not already
# or need to be refreshed.
frozen_creds = self._credentials.get_frozen_credentials()
account_id = frozen_creds.account_id
if account_id is None:
if account_id_endpoint_mode == 'preferred':
LOG.debug(
'`account_id_endpoint_mode` is set to `preferred`, but no '
'account ID was found.'
)
elif account_id_endpoint_mode == 'required':
raise AccountIDNotFound()
else:
builtins[EndpointResolverBuiltins.AWS_ACCOUNT_ID] = account_id

def _resolve_account_id_endpoint_mode(self, request_context):
"""Resolve the account ID endpoint mode for the request. Account ID
based routing is always disabled for presigned and unsigned requests.
Otherwise, the mode is determined by the ``account_id_endpoint_mode``
config setting.
"""
is_presign = request_context.get('is_presign_request')
should_sign = self._requested_auth_scheme != UNSIGNED
if not is_presign and should_sign:
config = request_context['client_config']
act_id_ep_mode = config.account_id_endpoint_mode
return self._validate_account_id_endpoint_mode(act_id_ep_mode)
return 'disabled'

def _validate_account_id_endpoint_mode(self, account_id_endpoint_mode):
if account_id_endpoint_mode not in VALID_ACCOUNT_ID_ENDPOINT_MODES:
valid_modes_str = ', '.join(VALID_ACCOUNT_ID_ENDPOINT_MODES)
error_msg = (
f"Invalid value '{account_id_endpoint_mode}' for "
"account_id_endpoint_mode. Valid values are: "
f"{valid_modes_str}"
)
raise InvalidConfigError(error_msg=error_msg)
return account_id_endpoint_mode

def _resolve_param_from_context(
self, param_name, operation_model, call_args
):
Expand Down Expand Up @@ -617,44 +659,11 @@ def _resolve_param_as_client_context_param(self, param_name):
client_ctx_varname = client_ctx_params[param_name]
return self._client_context.get(client_ctx_varname)

def _resolve_param_as_builtin(
self, builtin_name, builtins, account_id_endpoint_mode
):
def _resolve_param_as_builtin(self, builtin_name, builtins):
if builtin_name not in EndpointResolverBuiltins.__members__.values():
raise UnknownEndpointResolutionBuiltInName(name=builtin_name)

if builtin_name == EndpointResolverBuiltins.AWS_ACCOUNT_ID:
return self._resolve_account_id_builtin(account_id_endpoint_mode)

return builtins.get(builtin_name)

def _resolve_account_id_builtin(self, account_id_endpoint_mode):
self._validate_account_id_endpoint_mode(account_id_endpoint_mode)

if account_id_endpoint_mode == 'disabled':
return None

account_id = self._credentials.account_id
if account_id is None:
if account_id_endpoint_mode == 'preferred':
LOG.warning(
'`account_id_endpoint_mode` is enabled but no account ID was found!'
)
elif account_id_endpoint_mode == 'required':
raise AccountIDNotFound()

return account_id

def _validate_account_id_endpoint_mode(account_id_endpoint_mode):
if account_id_endpoint_mode not in VALID_ACCOUNT_ID_ENDPOINT_MODES:
valid_modes_str = ', '.join(VALID_ACCOUNT_ID_ENDPOINT_MODES)
error_msg = (
f"Invalid value '{account_id_endpoint_mode}' for "
"account_id_endpoint_mode. Valid values are: "
f"{valid_modes_str}"
)
raise InvalidConfigError(error_msg=error_msg)

@instance_cache
def _get_static_context_params(self, operation_model):
"""Mapping of param names to static param value for an operation"""
Expand Down
72 changes: 64 additions & 8 deletions tests/functional/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,11 +1150,67 @@ def add_credential_response(self, stubber):
stubber.add_response(body=json.dumps(response).encode('utf-8'))


def _load_account_id_test_cases():
path = os.path.join(
os.path.dirname(__file__),
'credentials',
'accountid-source-testcases.json',
)
with open(os.path.join(path, 'account_id_test_cases.json')) as f:
return json.load(f)
@pytest.mark.parametrize(
'config, access_key, secret_key, token, account_id',
[
(
(
'[default]\n'
'aws_access_key_id = foo\n'
'aws_secret_access_key = bar\n'
'aws_account_id = 123456789001'
),
'foo',
'bar',
None,
'123456789001',
),
(
(
'[default]\n'
'aws_access_key_id = foo\n'
'aws_secret_access_key = bar\n'
),
'foo',
'bar',
None,
None,
),
(
(
'[default]\n'
'aws_access_key_id = foo\n'
'aws_secret_access_key = bar\n'
'aws_session_token = baz\n'
'aws_account_id = 123456789001\n'
),
'foo',
'bar',
'baz',
'123456789001',
),
(
(
'[default]\n'
'aws_access_key_id = foo\n'
'aws_secret_access_key = bar\n'
'aws_session_token = baz\n'
),
'foo',
'bar',
'baz',
None,
),
],
)
def test_config_provider(config, access_key, secret_key, token, account_id):
with temporary_file('w') as f:
f.write(config)
f.flush()
with mock.patch('os.environ', {'AWS_CONFIG_FILE': f.name}):
session = Session()
credentials = session.get_credentials()
assert credentials.access_key == access_key
assert credentials.secret_key == secret_key
assert credentials.token == token
assert credentials.account_id == account_id
83 changes: 83 additions & 0 deletions tests/unit/data/endpoints/valid-rules/aws-account-id.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
{
"parameters": {
"Region": {
"type": "string",
"builtIn": "AWS::Region",
"documentation": "The region to dispatch this request, eg. `us-east-1`."
},
"AccountId": {
"type": "string",
"builtIn": "AWS::Auth::AccountId",
"documentation": "The account ID to dispatch this request, eg. `us-east-1`."
}
},
"rules": [
{
"documentation": "Template the account ID into the URI when account ID is set",
"conditions": [
{
"fn": "isSet",
"argv": [
{
"ref": "AccountId"
}
]
},
{
"fn": "isSet",
"argv": [
{
"ref": "Region"
}
]
}
],
"endpoint": {
"url": "https://{AccountId}.amazonaws.com",
"properties": {
"authSchemes": [
{
"name": "sigv4",
"signingName": "serviceName",
"signingRegion": "{Region}"
}
]
}
},
"type": "endpoint"
},
{
"documentation": "Fallback when account ID isn't set",
"conditions": [
{
"fn": "isSet",
"argv": [
{
"ref": "Region"
}
]
}
],
"endpoint": {
"url": "https://amazonaws.com",
"properties": {
"authSchemes": [
{
"name": "sigv4",
"signingName": "serviceName",
"signingRegion": "{Region}"
}
]
}
},
"type": "endpoint"
},
{
"documentation": "fallback when region is unset",
"conditions": [],
"error": "Region must be set to resolve a valid endpoint",
"type": "error"
}
],
"version": "1.3"
}
Loading

0 comments on commit bb64447

Please sign in to comment.