diff --git a/changelogs/fragments/refactor_rds_instance.yml b/changelogs/fragments/refactor_rds_instance.yml new file mode 100644 index 00000000000..434ee27f2bb --- /dev/null +++ b/changelogs/fragments/refactor_rds_instance.yml @@ -0,0 +1,5 @@ +--- +minor_changes: + - module_utils/rds.py - Refactor shared boto3 client functionality, add type hinting and function docstrings (https://github.com/ansible-collections/amazon.aws/pull/2119). + - rds_instance - Refactor shared boto3 client functionality, add type hinting and function docstrings (https://github.com/ansible-collections/amazon.aws/pull/2119). + - rds_instance_info - Refactor shared boto3 client functionality, add type hinting and function docstrings (https://github.com/ansible-collections/amazon.aws/pull/2119). diff --git a/plugins/module_utils/rds.py b/plugins/module_utils/rds.py index 20e0ae5e083..45b6507dde4 100644 --- a/plugins/module_utils/rds.py +++ b/plugins/module_utils/rds.py @@ -8,6 +8,8 @@ from typing import Any from typing import Dict from typing import List +from typing import Optional +from typing import Tuple try: from botocore.exceptions import BotoCoreError @@ -21,6 +23,8 @@ from .botocore import is_boto3_error_code from .core import AnsibleAWSModule +from .errors import AWSErrorHandler +from .exceptions import AnsibleAWSError from .retries import AWSRetry from .tagging import ansible_dict_to_boto3_tag_list from .tagging import boto3_tag_list_to_ansible_dict @@ -83,7 +87,39 @@ ] -def get_rds_method_attribute(method_name, module): +class AnsibleRDSError(AnsibleAWSError): + pass + + +class RDSErrorHandler(AWSErrorHandler): + _CUSTOM_EXCEPTION = AnsibleRDSError + + @classmethod + def _is_missing(cls): + return is_boto3_error_code(["DBInstanceNotFound", "DBSnapshotNotFound", "DBClusterNotFound"]) + + +@RDSErrorHandler.list_error_handler("describe db instances", []) +@AWSRetry.jittered_backoff() +def describe_db_instances(client, **params: Dict) -> List[Dict[str, Any]]: + paginator = client.get_paginator("describe_db_instances") + return paginator.paginate(**params).build_full_result()["DBInstances"] + + +@RDSErrorHandler.list_error_handler("describe db snapshots", []) +@AWSRetry.jittered_backoff() +def describe_db_snapshots(client, **params: Dict) -> List[Dict]: + paginator = client.get_paginator("describe_db_snapshots") + return paginator.paginate(**params).build_full_result()["DBSnapshots"] + + +@RDSErrorHandler.list_error_handler("list tags for resource", []) +@AWSRetry.jittered_backoff() +def list_tags_for_resource(client, resource_arn: str) -> List[Dict[str, str]]: + return client.list_tags_for_resource(ResourceName=resource_arn)["TagList"] + + +def get_rds_method_attribute(method_name: str, module: AnsibleAWSModule) -> Boto3ClientMethod: """ Returns rds attributes of the specified method. @@ -174,7 +210,21 @@ def get_rds_method_attribute(method_name, module): ) -def get_final_identifier(method_name, module): +def get_final_identifier(method_name: str, module: AnsibleAWSModule) -> str: + """ + Returns the final identifier for the resource to which the specified method applies. + + Parameters: + method_name (str): RDS method whose target resource final identifier is returned + module: AnsibleAWSModule + + Returns: + updated_identifier (str): The new resource identifier from module params if not in check mode, there is a new identifier in module params, and + apply_immediately is True; otherwise returns the original resource identifier from module params + + Raises: + NotImplementedError if the provided method is not supported + """ updated_identifier = None apply_immediately = module.params.get("apply_immediately") resource = get_rds_method_attribute(method_name, module).resource @@ -197,7 +247,20 @@ def get_final_identifier(method_name, module): return identifier -def handle_errors(module, exception, method_name, parameters): +def handle_errors(module: AnsibleAWSModule, exception: Any, method_name: str, parameters: Dict[str, Any]) -> bool: + """ + Fails the module with an appropriate error message given the provided exception. + + Parameters: + module: AnsibleAWSModule + exception: Botocore exception to be handled + method_name (str): Name of boto3 rds client method + parameters (dict): Parameters provided to boto3 client method + + Returns: + changed (bool): False if provided exception indicates that no modifications were requested or a read replica promotion was attempted on an + instance/cluseter that is not a read replica; should never return True (the module should always fail instead) + """ if not isinstance(exception, ClientError): module.fail_json_aws(exception, msg=f"Unexpected failure for method {method_name} with parameters {parameters}") @@ -252,7 +315,23 @@ def handle_errors(module, exception, method_name, parameters): return changed -def call_method(client, module, method_name, parameters): +def call_method(client, module: AnsibleAWSModule, method_name: str, parameters: Dict[str, Any]) -> Tuple[Any, bool]: + """Calls the provided boto3 rds client method with the provided parameters. + + Handles check mode determination, whether or not to wait for resource status, and method-specific retry codes. + + Parameters: + client: boto3 rds client + module: Ansible AWS module + method_name (str): Name of the boto3 rds client method to call + parameters (dict): Parameters to pass to the boto3 client method; these must already match expected parameters for the method and + be formatted correctly (CamelCase, Tags and other attributes converted to lists of dicts as needed) + + Returns: + tuple (any, bool): + result (any): Result value from method call + changed (bool): True if changes were made to the resource, False otherwise + """ result = {} changed = True if not module.check_mode: @@ -270,7 +349,19 @@ def call_method(client, module, method_name, parameters): return result, changed -def wait_for_instance_status(client, module, db_instance_id, waiter_name): +def wait_for_instance_status(client, module: AnsibleAWSModule, db_instance_id: str, waiter_name: str) -> None: + """ + Waits until provided instance has reached the expected status for provided waiter. + + Fails the module if an exception is raised while waiting. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + db_instance_id (str): DB instance identifier + waiter_name (str): Name of either a boto3 rds client waiter or an RDS waiter defined in module_utils/waiters.py + """ + def wait(client, db_instance_id, waiter_name): try: waiter = client.get_waiter(waiter_name) @@ -300,7 +391,18 @@ def wait(client, db_instance_id, waiter_name): ) -def wait_for_cluster_status(client, module, db_cluster_id, waiter_name): +def wait_for_cluster_status(client, module: AnsibleAWSModule, db_cluster_id: str, waiter_name: str) -> None: + """ + Waits until provided cluster has reached the expected status for provided waiter. + + Fails the module if an exception is raised while waiting. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + db_cluster_id (str): DB cluster identifier + waiter_name (str): Name of either a boto3 rds client waiter or an RDS waiter defined in module_utils/waiters.py + """ try: get_waiter(client, waiter_name).wait(DBClusterIdentifier=db_cluster_id) except WaiterError as e: @@ -313,7 +415,18 @@ def wait_for_cluster_status(client, module, db_cluster_id, waiter_name): module.fail_json_aws(e, msg=f"Failed with an unexpected error while waiting for the DB cluster {db_cluster_id}") -def wait_for_instance_snapshot_status(client, module, db_snapshot_id, waiter_name): +def wait_for_instance_snapshot_status(client, module: AnsibleAWSModule, db_snapshot_id: str, waiter_name: str) -> None: + """ + Waits until provided instance snapshot has reached the expected status for provided waiter. + + Fails the module if an exception is raised while waiting. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + db_snapshot_id (str): DB instance snapshot identifier + waiter_name (str): Name of a boto3 rds client waiter + """ try: client.get_waiter(waiter_name).wait(DBSnapshotIdentifier=db_snapshot_id) except WaiterError as e: @@ -328,7 +441,18 @@ def wait_for_instance_snapshot_status(client, module, db_snapshot_id, waiter_nam ) -def wait_for_cluster_snapshot_status(client, module, db_snapshot_id, waiter_name): +def wait_for_cluster_snapshot_status(client, module: AnsibleAWSModule, db_snapshot_id: str, waiter_name: str) -> None: + """ + Waits until provided cluster snapshot has reached the expected status for provided waiter. + + Fails the module if an exception is raised while waiting. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + db_snapshot_id (str): DB cluster snapshot identifier + waiter_name (str): Name of a boto3 rds client waiter + """ try: client.get_waiter(waiter_name).wait(DBClusterSnapshotIdentifier=db_snapshot_id) except WaiterError as e: @@ -344,7 +468,16 @@ def wait_for_cluster_snapshot_status(client, module, db_snapshot_id, waiter_name ) -def wait_for_status(client, module, identifier, method_name): +def wait_for_status(client, module: AnsibleAWSModule, identifier: str, method_name: str) -> None: + """ + Waits until provided resource has reached the expected final status for provided method. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + identifier (str): resource identifier + method_name (str): Name of boto3 rds client method on whose final status to wait + """ rds_method_attributes = get_rds_method_attribute(method_name, module) waiter_name = rds_method_attributes.waiter resource = rds_method_attributes.resource @@ -359,14 +492,40 @@ def wait_for_status(client, module, identifier, method_name): wait_for_cluster_snapshot_status(client, module, identifier, waiter_name) -def get_tags(client, module, resource_arn): +def get_tags(client, module: AnsibleAWSModule, resource_arn: str) -> Dict[str, str]: + """ + Returns tags for provided RDS resource, formatted as an Ansible dict. + + Fails the module if an error is raised while retrieving resource tags. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + resource_arn (str): AWS resource ARN + + Returns: + tags (dict): Tags for resource, formatted as an Ansible dict. An empty list is returned if the resource has no tags. + """ try: - return boto3_tag_list_to_ansible_dict(client.list_tags_for_resource(ResourceName=resource_arn)["TagList"]) - except (BotoCoreError, ClientError) as e: - module.fail_json_aws(e, msg="Unable to describe tags") + tags = list_tags_for_resource(client, resource_arn) + except AnsibleRDSError as e: + module.fail_json_aws(e, msg=f"Unable to list tags for resource {resource_arn}") + return boto3_tag_list_to_ansible_dict(tags) -def arg_spec_to_rds_params(options_dict): +def arg_spec_to_rds_params(options_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Converts snake_cased rds module options to CamelCased parameter formats expected by boto3 rds client. + + Does not alter case for keys or values in the following attributes: tags, processor_features. + Includes special handling of certain boto3 params that do not follow standard CamelCase. + + Parameters: + options_dict (dict): Snake-cased options for a boto3 rds client method + + Returns: + camel_options (dct): Options formatted for boto3 rds client + """ tags = options_dict.pop("tags") has_processor_features = False if "processor_features" in options_dict: @@ -383,7 +542,30 @@ def arg_spec_to_rds_params(options_dict): return camel_options -def ensure_tags(client, module, resource_arn, existing_tags, tags, purge_tags): +def ensure_tags( + client, + module: AnsibleAWSModule, + resource_arn: str, + existing_tags: Dict[str, str], + tags: Optional[Dict[str, str]], + purge_tags: bool, +) -> bool: + """ + Compares current resource tages to desired tags and adds/removes tags to ensure desired tags are present. + + A value of None for desired tags results in resource tags being left as is. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + resource_arn (str): AWS resource ARN + existing_tags (dict): Current resource tags formatted as an Ansible dict + tags (dict): Desired resource tags formatted as an Ansible dict + purge_tags (bool): Whether to remove any existing resource tags not present in desired tags + + Returns: + True if resource tags are updated, False if not. + """ if tags is None: return False tags_to_add, tags_to_remove = compare_aws_tags(existing_tags, tags, purge_tags) @@ -405,13 +587,15 @@ def ensure_tags(client, module, resource_arn, existing_tags, tags, purge_tags): return changed -def compare_iam_roles(existing_roles, target_roles, purge_roles): +def compare_iam_roles( + existing_roles: List[Dict[str, str]], target_roles: List[Dict[str, str]], purge_roles: bool +) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]: """ - Returns differences between target and existing IAM roles + Returns differences between target and existing IAM roles. Parameters: - existing_roles (list): Existing IAM roles - target_roles (list): Target IAM roles + existing_roles (list): Existing IAM roles as a list of snake-cased dicts + target_roles (list): Target IAM roles as a list of snake-cased dicts purge_roles (bool): Remove roles not in target_roles if True Returns: @@ -424,7 +608,13 @@ def compare_iam_roles(existing_roles, target_roles, purge_roles): return roles_to_add, roles_to_remove -def update_iam_roles(client, module, instance_id, roles_to_add, roles_to_remove): +def update_iam_roles( + client, + module: AnsibleAWSModule, + instance_id: str, + roles_to_add: List[Dict[str, str]], + roles_to_remove: List[Dict[str, str]], +) -> bool: """ Update a DB instance's associated IAM roles @@ -432,8 +622,8 @@ def update_iam_roles(client, module, instance_id, roles_to_add, roles_to_remove) client: RDS client module: AnsibleAWSModule instance_id (str): DB's instance ID - roles_to_add (list): List of IAM roles to add - roles_to_delete (list): List of IAM roles to delete + roles_to_add (list): List of IAM roles to add in snake-cased dict format + roles_to_delete (list): List of IAM roles to delete in snake-cased dict format Returns: changed (bool): True if changes were successfully made to DB instance's IAM roles; False if not @@ -449,7 +639,7 @@ def update_iam_roles(client, module, instance_id, roles_to_add, roles_to_remove) @AWSRetry.jittered_backoff() def describe_db_cluster_parameter_groups( - module: AnsibleAWSModule, connection: Any, group_name: str + module: AnsibleAWSModule, connection: Any, group_name: Optional[str] ) -> List[Dict[str, Any]]: result = [] try: diff --git a/plugins/modules/rds_instance.py b/plugins/modules/rds_instance.py index 66cdbcc7160..ca43ba3a05e 100644 --- a/plugins/modules/rds_instance.py +++ b/plugins/modules/rds_instance.py @@ -862,30 +862,28 @@ """ from time import sleep - -try: - import botocore -except ImportError: - pass # caught by AnsibleAWSModule - +from typing import Any +from typing import Dict +from typing import List +from typing import Optional from ansible.module_utils._text import to_text from ansible.module_utils.common.dict_transformations import camel_dict_to_snake_dict from ansible.module_utils.six import string_types from ansible_collections.amazon.aws.plugins.module_utils.botocore import get_boto3_client_method_parameters -from ansible_collections.amazon.aws.plugins.module_utils.botocore import is_boto3_error_code from ansible_collections.amazon.aws.plugins.module_utils.botocore import is_boto3_error_message from ansible_collections.amazon.aws.plugins.module_utils.modules import AnsibleAWSModule +from ansible_collections.amazon.aws.plugins.module_utils.rds import AnsibleRDSError from ansible_collections.amazon.aws.plugins.module_utils.rds import arg_spec_to_rds_params from ansible_collections.amazon.aws.plugins.module_utils.rds import call_method from ansible_collections.amazon.aws.plugins.module_utils.rds import compare_iam_roles +from ansible_collections.amazon.aws.plugins.module_utils.rds import describe_db_instances +from ansible_collections.amazon.aws.plugins.module_utils.rds import describe_db_snapshots from ansible_collections.amazon.aws.plugins.module_utils.rds import ensure_tags from ansible_collections.amazon.aws.plugins.module_utils.rds import get_final_identifier from ansible_collections.amazon.aws.plugins.module_utils.rds import get_rds_method_attribute -from ansible_collections.amazon.aws.plugins.module_utils.rds import get_tags from ansible_collections.amazon.aws.plugins.module_utils.rds import update_iam_roles -from ansible_collections.amazon.aws.plugins.module_utils.retries import AWSRetry from ansible_collections.amazon.aws.plugins.module_utils.tagging import ansible_dict_to_boto3_tag_list from ansible_collections.amazon.aws.plugins.module_utils.tagging import boto3_tag_list_to_ansible_dict @@ -920,7 +918,22 @@ ] -def get_rds_method_attribute_name(instance, state, creation_source, read_replica): +def get_rds_method_attribute_name( + instance: Dict, state: str, creation_source: str, read_replica: Optional[bool] +) -> Optional[str]: + """ + Returns the target boto3 rds client method name given the provided module options and current instance state. + + Parameters: + instance (dict): Current instance attributes as returned by get_instance() + state (str): Desired instance state as provided to module options + creation_source (str): Creation source to use for restoring an instance as provided to module options + read_replica (bool): Whether to create (True) or promote (False) a read replica as provided to module options + + Returns: + method_name (str): Name of boto3 rds client method needed to achieve desired state. Returns None if desired state is "absent" or "terminated" and + current instance is None or the current instance status is "deleting" or "deleted" + """ method_name = None if state == "absent" or state == "terminated": if instance and instance["DBInstanceStatus"] not in ["deleting", "deleted"]: @@ -941,50 +954,99 @@ def get_rds_method_attribute_name(instance, state, creation_source, read_replica return method_name -def get_instance(client, module, db_instance_id): +def get_instance(client, module: AnsibleAWSModule, db_instance_id: str) -> Dict[str, Any]: + """ + Returns attributes for the provided db instance. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + db_instance_id (str): DB instance identifier + + Returns: + instance (dict): DB instance attributes with the following boto3 attribute lists converted to dicts of key/value pairs: + - PendingModifiedValues["ProcessorFeatures"] + - ProcessorFeatures + - Tags (from boto3 TagList attribute) + If no matching instance is found, returns an empty dict. + + Raises: + Fails the module if an exception is raised while retrieving the db instance attributes. + """ + instances = None try: - for _i in range(3): - try: - instance = client.describe_db_instances(DBInstanceIdentifier=db_instance_id)["DBInstances"][0] - instance["Tags"] = get_tags(client, module, instance["DBInstanceArn"]) - if instance.get("ProcessorFeatures"): - instance["ProcessorFeatures"] = dict( - (feature["Name"], feature["Value"]) for feature in instance["ProcessorFeatures"] - ) - if instance.get("PendingModifiedValues", {}).get("ProcessorFeatures"): - instance["PendingModifiedValues"]["ProcessorFeatures"] = dict( - (feature["Name"], feature["Value"]) - for feature in instance["PendingModifiedValues"]["ProcessorFeatures"] - ) - break - except is_boto3_error_code("DBInstanceNotFound"): - sleep(3) - else: - instance = {} - except ( - botocore.exceptions.ClientError, - botocore.exceptions.BotoCoreError, - ) as e: # pylint: disable=duplicate-except - module.fail_json_aws(e, msg="Failed to describe DB instances") + instances = describe_db_instances(client, DBInstanceIdentifier=db_instance_id) + except AnsibleRDSError as e: + module.fail_json_aws(e, msg=f"Failed to get DB instance {db_instance_id}") + + if instances: + instance = instances[0] + else: + return {} + + instance["Tags"] = boto3_tag_list_to_ansible_dict(instance.pop("TagList")) + if instance.get("ProcessorFeatures"): + instance["ProcessorFeatures"] = dict( + (feature["Name"], feature["Value"]) for feature in instance["ProcessorFeatures"] + ) + if instance.get("PendingModifiedValues", {}).get("ProcessorFeatures"): + instance["PendingModifiedValues"]["ProcessorFeatures"] = dict( + (feature["Name"], feature["Value"]) for feature in instance["PendingModifiedValues"]["ProcessorFeatures"] + ) + return instance -def get_final_snapshot(client, module, snapshot_identifier): +def get_final_snapshot(client, module: AnsibleAWSModule, snapshot_identifier: str) -> Dict[str, Any]: + """ + Returns the final snapshot given the final snapshot identifer. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + snapshot_identifier (str): Unique snapshot identifier + + Returns: + snapshot (dict): Snapshot attributes as returned by boto3 client + + Raises: + Failes the module if an exception is raised while retrieving the snapshot attributes. + """ + snapshot = {} try: - snapshots = AWSRetry.jittered_backoff()(client.describe_db_snapshots)(DBSnapshotIdentifier=snapshot_identifier) - if len(snapshots.get("DBSnapshots", [])) == 1: - return snapshots["DBSnapshots"][0] - return {} - except is_boto3_error_code("DBSnapshotNotFound"): # May not be using wait: True - return {} - except ( - botocore.exceptions.ClientError, - botocore.exceptions.BotoCoreError, - ) as e: # pylint: disable=duplicate-except - module.fail_json_aws(e, msg="Failed to retrieve information about the final snapshot") + snapshots = describe_db_snapshots(client, DBSnapshotIdentifier=snapshot_identifier) + if len(snapshots) == 1: + snapshot = snapshots[0] + except AnsibleRDSError as e: + module.fail_json_aws(e, msg=f"Failed to retrieve information about the final snapshot: {snapshot_identifier}") + return snapshot -def get_parameters(client, module, parameters, method_name): +def get_parameters(client, module: AnsibleAWSModule, parameters: Dict[str, Any], method_name: str) -> Dict[str, Any]: + """ + Returns a dict of parameters validated and formatted for the provided boto3 client method. + + Performs the following parameters checks and updates: + - Converts parameters supplied as snake_cased module options to CamelCase + - Ensures that all required parameters for the provided method are present + - Ensures that only parameters allowed for the provided method are present, removing any that are not relevant + - Removes parameters with None values + - Converts the following dict parameters to lists of dicts as expected by the boto3 rds client: ProcessorFeatures, Tags + - If method is "modify_db_instance", compares supplied parameters to current instance attributes, determines which parameters need to be modified, and + removes any parameters that do not need to be modified + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + parameters (dict): Parameter options from module argument_spec + method_name: boto3 client method for which to validate parameters + + Returns: + Dict of client parameters formatted for the provided method + + Raises: + Fails the module if any parameters required by the provided method are not provided in module options + """ if method_name == "restore_db_instance_to_point_in_time": parameters["TargetDBInstanceIdentifier"] = module.params["db_instance_identifier"] @@ -1008,13 +1070,29 @@ def get_parameters(client, module, parameters, method_name): if parameters.get("Tags"): parameters["Tags"] = ansible_dict_to_boto3_tag_list(parameters["Tags"]) + # For modify_db_instance method, update parameters to include all params that need to be modified by comparing them to current instance attributes if method_name == "modify_db_instance": parameters = get_options_with_changing_values(client, module, parameters) return parameters -def get_options_with_changing_values(client, module, parameters): +def get_options_with_changing_values(client, module: AnsibleAWSModule, parameters: Dict[str, Any]) -> Dict[str, Any]: + """ + Compares current instance attributes to the provided parameters and module options and returns parameters with values to be modified. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + parameters (dict): Parameters for boto3 client modify_db_instance method + + Returns: + parameters (dict): Updated parameters including only parameters that need to be modified, renamed and formatted as expected by boto3 client + modify_db_instance method + + Raises: + Fails the module if invalid changes are provided for iops or storage_throughput values + """ instance_id = module.params["db_instance_identifier"] purge_cloudwatch_logs = module.params["purge_cloudwatch_logs_exports"] force_update_password = module.params["force_update_password"] @@ -1036,12 +1114,15 @@ def get_options_with_changing_values(client, module, parameters): parameters.pop("Iops", None) instance = get_instance(client, module, instance_id) + + # Determine which parameters need to be modified updated_parameters = get_changing_options_with_inconsistent_keys( parameters, instance, purge_cloudwatch_logs, purge_security_groups ) updated_parameters.update(get_changing_options_with_consistent_keys(parameters, instance)) parameters = updated_parameters + # Validate changes to storage type options if instance.get("StorageType") == "io1": # Bundle Iops and AllocatedStorage while updating io1 RDS Instance current_iops = instance.get("PendingModifiedValues", {}).get("Iops", instance["Iops"]) @@ -1113,8 +1194,26 @@ def get_options_with_changing_values(client, module, parameters): return parameters -def get_current_attributes_with_inconsistent_keys(instance): - options = {} +def get_current_attributes_with_inconsistent_keys(instance: Dict[str, Any]) -> Dict[str, Any]: + """ + Returns current instance attributes whose formats differ from those expected by boto3 client modify_db_instance method, updated to match method options. + + Option formats for the boto3 client modify_db_instance method do not always match their corresponding attributes returned by describe_db_instances. + To ensure that we are accurately comparing the two dicts for changes, this function: + - Checks for pending modified values in these instance attributes and updates the corresponding current attributes to match the pending values + - Converts these instance attribute names and value formats to those expected by the modify_db_instance method + + Parameters: + instance (dict): Current instance attributes as returned by get_instance() + + Returns: + options (dict): Current instance attributes updated to match the boto3 client modify_db_instance method option formatting. Only returns attributes + whose format varies between the returned attributes and the method options (i.e., excludes any attributes whose formats already match what the + boto3 client method expects) + """ + options: Dict[str, Any] = {} + + # Check for any pending cloudwatch logs exports configuration changes and add them to CloudwatchLogsExportConfiguration option if instance.get("PendingModifiedValues", {}).get("PendingCloudwatchLogsExports", {}).get("LogTypesToEnable", []): current_enabled = instance["PendingModifiedValues"]["PendingCloudwatchLogsExports"]["LogTypesToEnable"] current_disabled = instance["PendingModifiedValues"]["PendingCloudwatchLogsExports"]["LogTypesToDisable"] @@ -1122,11 +1221,15 @@ def get_current_attributes_with_inconsistent_keys(instance): "LogTypesToEnable": current_enabled, "LogTypesToDisable": current_disabled, } + # If there are no pending cloudwatch logs exports configuration changes, set CloudwatchLogsExportConfiguration option to match current enabled cloudwatch + # logs exports attribute else: options["CloudwatchLogsExportConfiguration"] = { "LogTypesToEnable": instance.get("EnabledCloudwatchLogsExports", []), "LogTypesToDisable": [], } + + # Check for pending changes on other attributes, if not then set options to current attribute values if instance.get("PendingModifiedValues", {}).get("Port"): options["DBPortNumber"] = instance["PendingModifiedValues"]["Port"] else: @@ -1139,6 +1242,8 @@ def get_current_attributes_with_inconsistent_keys(instance): options["ProcessorFeatures"] = instance["PendingModifiedValues"]["ProcessorFeatures"] else: options["ProcessorFeatures"] = instance.get("ProcessorFeatures", {}) + + # Convert current instance's attributes that are lists of dicts to lists of string values for comparison and set the options to updated lists options["OptionGroupName"] = [g["OptionGroupName"] for g in instance["OptionGroupMemberships"]] options["DBSecurityGroups"] = [ sg["DBSecurityGroupName"] for sg in instance["DBSecurityGroups"] if sg["Status"] in ["adding", "active"] @@ -1150,6 +1255,7 @@ def get_current_attributes_with_inconsistent_keys(instance): parameter_group["DBParameterGroupName"] for parameter_group in instance["DBParameterGroups"] ] options["EnableIAMDatabaseAuthentication"] = instance["IAMDatabaseAuthenticationEnabled"] + # PerformanceInsightsEnabled is not returned on older RDS instances it seems options["EnablePerformanceInsights"] = instance.get("PerformanceInsightsEnabled", False) options["NewDBInstanceIdentifier"] = instance["DBInstanceIdentifier"] @@ -1161,8 +1267,24 @@ def get_current_attributes_with_inconsistent_keys(instance): return options -def get_changing_options_with_inconsistent_keys(modify_params, instance, purge_cloudwatch_logs, purge_security_groups): - changing_params = {} +def get_changing_options_with_inconsistent_keys( + modify_params: Dict[str, Any], instance: Dict[str, Any], purge_cloudwatch_logs: bool, purge_security_groups: bool +) -> Dict[str, Any]: + """ + Compares current instance attributes with provided parameters whose formats are inconsistent between describe_db_instances and modify_db_instance methods. + + Parameters: + modify_params (dict): Parameters to be supplied to boto3 client modify_db_instance method; should already be validated and formatted + instance (dict): Current instance attributes as returned by get_instance() + purge_cloudwatch_logs (bool): True if currently enabled cloudwatch logs exports should be removed from configuration when not in provided + parameters, False if they should be retained + purge_security_groups (bool): True if currently associated security groups should be removed from instance if not in provided parameters, False if + they should be retained + + Returns: + changing_params (dict): Parameters to be modified + """ + changing_params: Dict[str, Any] = {} current_options = get_current_attributes_with_inconsistent_keys(instance) for option, current_option in current_options.items(): desired_option = modify_params.pop(option, None) @@ -1171,17 +1293,22 @@ def get_changing_options_with_inconsistent_keys(modify_params, instance, purge_c # TODO: allow other purge_option module parameters rather than just checking for things to add if isinstance(current_option, list): + # Compare lists if isinstance(desired_option, list): if ( set(desired_option) < set(current_option) and option in ["DBSecurityGroups", "VpcSecurityGroupIds"] and purge_security_groups ): + # There are associated security groups to be purged changing_params[option] = desired_option elif set(desired_option) <= set(current_option): + # Desired option set is entirely contained within current option set and purge is False, nothing to change continue elif isinstance(desired_option, string_types): + # Current option is a list and desired option is a string if desired_option in current_option: + # Desired option is in current options, nothing to change continue # Current option and desired option are the same - continue loop @@ -1191,19 +1318,25 @@ def get_changing_options_with_inconsistent_keys(modify_params, instance, purge_c if option == "ProcessorFeatures" and current_option == boto3_tag_list_to_ansible_dict( desired_option, "Name", "Value" ): + # Processor features are the same, continue loop continue # Current option and desired option are different - add to changing_params list if option == "ProcessorFeatures" and desired_option == []: + # Update to use default processor features changing_params["UseDefaultProcessorFeatures"] = True elif option == "CloudwatchLogsExportConfiguration": + # Update cloudwatch logs enabled/disabled current_option = set(current_option.get("LogTypesToEnable", [])) desired_option = set(desired_option) - format_option = {"EnableLogTypes": [], "DisableLogTypes": []} + format_option: Dict[str, List] = {"EnableLogTypes": [], "DisableLogTypes": []} + # Set enable list to any items from desired not in current format_option["EnableLogTypes"] = list(desired_option.difference(current_option)) if purge_cloudwatch_logs: + # If purge is true, set disable list to difference between current and desired format_option["DisableLogTypes"] = list(current_option.difference(desired_option)) if format_option["EnableLogTypes"] or format_option["DisableLogTypes"]: + # Update cloudwatch logs configuration option to reflect changes changing_params[option] = format_option elif option in ["DBSecurityGroups", "VpcSecurityGroupIds"]: if purge_security_groups: @@ -1216,8 +1349,20 @@ def get_changing_options_with_inconsistent_keys(modify_params, instance, purge_c return changing_params -def get_changing_options_with_consistent_keys(modify_params, instance): - changing_params = {} +def get_changing_options_with_consistent_keys( + modify_params: Dict[str, Any], instance: Dict[str, Any] +) -> Dict[str, Any]: + """ + Compares current instance attributes with provided parameters whose attribute and parameter formats match. + + Parameters: + modify_params (dict): Parameters to be supplied to boto3 client modify_db_instance method; should already be validated and formatted + instance (dict): Current instance attributes as returned by get_instance() + + Returns: + changing_params (dict): Parameters to be modified + """ + changing_params: Dict[str, Any] = {} for param in modify_params: current_option = instance.get("PendingModifiedValues", {}).get(param, None) @@ -1229,7 +1374,18 @@ def get_changing_options_with_consistent_keys(modify_params, instance): return changing_params -def validate_options(client, module, instance): +def validate_options(client, module: AnsibleAWSModule, instance: Dict[str, Any]) -> None: + """ + Validates complex module option logic and fails the module with an error message if options are invalid. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + instance (dict): Current instance attributes as returned by get_instance() + + Raises: + Fails the module if provided module options are incompatible with each other or with current instance attributes + """ state = module.params["state"] skip_final_snapshot = module.params["skip_final_snapshot"] snapshot_id = module.params["final_db_snapshot_identifier"] @@ -1268,7 +1424,19 @@ def validate_options(client, module, instance): ) -def update_instance(client, module, instance, instance_id): +def update_instance(client, module: AnsibleAWSModule, instance: Dict[str, Any], instance_id: str) -> bool: + """ + Ensures that an existing instance's tags, read replica status, and state match what is supplied in module options. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + instance (dict): Current instance attributes as returned by get_instance() + instance_id (str): Existing instance identifier, used to retrieve instance attributes if provided instance dict is empty + + Returns: + changed (bool): True if instance was successfully updated, False if not + """ changed = False # Get newly created DB instance @@ -1285,7 +1453,24 @@ def update_instance(client, module, instance, instance_id): return changed -def promote_replication_instance(client, module, instance, read_replica): +def promote_replication_instance( + client, module: AnsibleAWSModule, instance: Dict[str, Any], read_replica: bool +) -> bool: + """ + Promotes the provided DB instance from a read replica to a standalone instance. + + Only promotes the instance if read_replica is False, which is confusing but is how the module is documented. + Returns changed=False without any warning or error message if the provided instance is not a read replica. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + instance (dict): Current instance attributes as returned by get_instance() + read_replica (bool): False if instance should be promoted + + Returns: + changed (bool): True if provided instance was successfully promoted, False if not + """ changed = False if read_replica is False: # 'StatusInfos' only exists when the instance is a read replica @@ -1303,14 +1488,14 @@ def promote_replication_instance(client, module, instance, read_replica): return changed -def ensure_iam_roles(client, module, instance_id): +def ensure_iam_roles(client, module: AnsibleAWSModule, instance_id: str) -> bool: """ - Ensure specified IAM roles are associated with DB instance + Ensure specified IAM roles are associated with DB instance. Parameters: - client: RDS client - module: AWSModule - instance_id: DB's instance ID + client: boto3 rds client + module: AnsibleAWSModule + instance_id (str): Existing DB instance identifier Returns: changed (bool): True if changes were successfully made to DB instance's IAM roles; False if not @@ -1341,7 +1526,19 @@ def ensure_iam_roles(client, module, instance_id): return changed -def update_instance_state(client, module, instance, state): +def update_instance_state(client, module: AnsibleAWSModule, instance: Dict[str, Any], state: str) -> bool: + """ + Starts, stops, or reboots an instance given the desired state and current instance attributes. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + instance (dict): Current instance attributes as returned by get_instance() + state (str): Desired instance state as provided to module options + + Returns: + changed (bool): True if DB instance state was updated, False if not + """ changed = False if state in ["rebooted", "restarted"]: changed |= reboot_running_db_instance(client, module, instance) @@ -1350,7 +1547,20 @@ def update_instance_state(client, module, instance, state): return changed -def reboot_running_db_instance(client, module, instance): +def reboot_running_db_instance(client, module: AnsibleAWSModule, instance: Dict[str, Any]) -> bool: + """ + Reboots provided instance. + + If the instance is currently stopped or stopping, restarts it first. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + instance (dict): Current instance attributes as returned by get_instance() + + Returns: + changed (bool): True if instance was successfully rebooted, False otherwise. + """ parameters = {"DBInstanceIdentifier": instance["DBInstanceIdentifier"]} if instance["DBInstanceStatus"] in ["stopped", "stopping"]: call_method(client, module, "start_db_instance", parameters) @@ -1360,7 +1570,21 @@ def reboot_running_db_instance(client, module, instance): return changed -def start_or_stop_instance(client, module, instance, state): +def start_or_stop_instance(client, module: AnsibleAWSModule, instance: Dict[str, Any], state: str) -> bool: + """ + Starts or stops provided instance given desired state. + + Checks whether the instance is already in or pending the desired state, if so does not alter it. + + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + instance (dict): Current instance attributes as returned by get_instance() + state (str): Desired instance state as provided to module options + + Returns: + changed (bool): True if instance was started or stopped, False if not + """ changed = False parameters = {"DBInstanceIdentifier": instance["DBInstanceIdentifier"]} if state == "stopped" and instance["DBInstanceStatus"] not in ["stopping", "stopped"]: diff --git a/plugins/modules/rds_instance_info.py b/plugins/modules/rds_instance_info.py index ff3a6215684..7cf721374a4 100644 --- a/plugins/modules/rds_instance_info.py +++ b/plugins/modules/rds_instance_info.py @@ -351,64 +351,47 @@ sample: sg-abcd1234 """ +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + from ansible.module_utils.common.dict_transformations import camel_dict_to_snake_dict -from ansible_collections.amazon.aws.plugins.module_utils.botocore import is_boto3_error_code from ansible_collections.amazon.aws.plugins.module_utils.modules import AnsibleAWSModule -from ansible_collections.amazon.aws.plugins.module_utils.retries import AWSRetry +from ansible_collections.amazon.aws.plugins.module_utils.rds import AnsibleRDSError +from ansible_collections.amazon.aws.plugins.module_utils.rds import describe_db_instances from ansible_collections.amazon.aws.plugins.module_utils.tagging import boto3_tag_list_to_ansible_dict from ansible_collections.amazon.aws.plugins.module_utils.transformation import ansible_dict_to_boto3_filter_list -try: - import botocore -except ImportError: - pass # handled by AnsibleAWSModule - - -@AWSRetry.jittered_backoff() -def _describe_db_instances(conn, **params): - paginator = conn.get_paginator("describe_db_instances") - try: - results = paginator.paginate(**params).build_full_result()["DBInstances"] - except is_boto3_error_code("DBInstanceNotFound"): - results = [] - - return results - -class RdsInstanceInfoFailure(Exception): - def __init__(self, original_e, user_message): - self.original_e = original_e - self.user_message = user_message - super().__init__(self) +def instance_info( + client, module: AnsibleAWSModule, instance_name: Optional[str], filters: Optional[Dict[str, Union[str, List]]] +) -> List[Dict[str, Any]]: + """ + Returns attributes of db instance(s), with instances optionally filtered by provided name and additional filters. + Parameters: + client: boto3 rds client + module: AnsibleAWSModule + instance_name (str, optional): Unique identifier of db instance to describe + filters (dict, optional): Additional boto3-supported filters specifying db instance(s) to describe -def get_instance_tags(conn, arn): - try: - return boto3_tag_list_to_ansible_dict(conn.list_tags_for_resource(ResourceName=arn, aws_retry=True)["TagList"]) - except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: - raise RdsInstanceInfoFailure(e, f"Couldn't get tags for instance {arn}") - - -def instance_info(conn, instance_name, filters): + Returns: + instances (list): List of instance attribute dicts converted from CamelCase to snake_case format + """ params = {} if instance_name: params["DBInstanceIdentifier"] = instance_name if filters: params["Filters"] = ansible_dict_to_boto3_filter_list(filters) - try: - results = _describe_db_instances(conn, **params) - except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: - raise RdsInstanceInfoFailure(e, "Couldn't get instance information") - + results = describe_db_instances(client, **params) for instance in results: - instance["Tags"] = get_instance_tags(conn, arn=instance["DBInstanceArn"]) + instance["Tags"] = boto3_tag_list_to_ansible_dict(instance.pop("TagList")) - return { - "changed": False, - "instances": [camel_dict_to_snake_dict(instance, ignore_list=["Tags"]) for instance in results], - } + return [camel_dict_to_snake_dict(instance, ignore_list=["Tags"]) for instance in results] def main(): @@ -422,15 +405,15 @@ def main(): supports_check_mode=True, ) - conn = module.client("rds", retry_decorator=AWSRetry.jittered_backoff(retries=10)) + client = module.client("rds") instance_name = module.params.get("db_instance_identifier") filters = module.params.get("filters") try: - module.exit_json(**instance_info(conn, instance_name, filters)) - except RdsInstanceInfoFailure as e: - module.fail_json_aws(e.original_e, e.user_message) + module.exit_json(changed=False, instances=instance_info(client, module, instance_name, filters)) + except AnsibleRDSError as e: + module.fail_json_aws(e) if __name__ == "__main__": diff --git a/tests/unit/plugins/modules/test_rds_instance.py b/tests/unit/plugins/modules/test_rds_instance.py new file mode 100644 index 00000000000..0cc17f96d1f --- /dev/null +++ b/tests/unit/plugins/modules/test_rds_instance.py @@ -0,0 +1,104 @@ +# (c) 2024 Red Hat Inc. +# +# This file is part of Ansible +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from ansible_collections.amazon.aws.plugins.module_utils.rds import AnsibleRDSError +from ansible_collections.amazon.aws.plugins.modules.rds_instance import get_final_snapshot +from ansible_collections.amazon.aws.plugins.modules.rds_instance import get_instance + +mod_name = "ansible_collections.amazon.aws.plugins.modules.rds_instance" + + +@pytest.mark.parametrize( + "instances, expected", + [ + ([], {}), + ( + [ + { + "DBInstanceIdentifier": "my-instance", + "DBInstanceArn": "arn:aws:rds:us-east-1:123456789012:og:my-instance", + "TagList": [], + } + ], + { + "DBInstanceIdentifier": "my-instance", + "DBInstanceArn": "arn:aws:rds:us-east-1:123456789012:og:my-instance", + "Tags": {}, + }, + ), + ( + [ + { + "DBInstanceIdentifier": "my-instance", + "DBInstanceArn": "arn:aws:rds:us-east-1:123456789012:og:my-instance", + "TagList": [{"Key": "My Tag", "Value": "My Value"}], + "ProcessorFeatures": [{"Name": "coreCount", "Value": "1"}], + "PendingModifiedValues": { + "ProcessorFeatures": [{"Name": "coreCount", "Value": "2"}], + }, + } + ], + { + "DBInstanceIdentifier": "my-instance", + "DBInstanceArn": "arn:aws:rds:us-east-1:123456789012:og:my-instance", + "Tags": {"My Tag": "My Value"}, + "ProcessorFeatures": {"coreCount": "1"}, + "PendingModifiedValues": {"ProcessorFeatures": {"coreCount": "2"}}, + }, + ), + ], +) +@patch(mod_name + ".describe_db_instances") +def test_get_instance_success(m_describe_db_instances, instances, expected): + client = MagicMock() + module = MagicMock() + m_describe_db_instances.return_value = instances + assert get_instance(client, module, "my-instance") == expected + + +@patch(mod_name + ".describe_db_instances") +def test_get_instance_failure(m_describe_db_instances): + client = MagicMock() + module = MagicMock() + e = AnsibleRDSError() + m_describe_db_instances.side_effect = e + get_instance(client, module, "my-instance") + module.fail_json_aws.assert_called_once_with(e, msg="Failed to get DB instance my-instance") + + +@pytest.mark.parametrize( + "snapshots, expected", + [ + ([], {}), + ( + [{"DBSnapshotIdentifier": "my-snapshot", "DBInstanceIdentifier": "my-instance"}], + {"DBSnapshotIdentifier": "my-snapshot", "DBInstanceIdentifier": "my-instance"}, + ), + ([{"DBSnapshotIdentifier": "snapshot-1"}, {"DBSnapshotIdentifier": "snapshot-2"}], {}), + ], +) +@patch(mod_name + ".describe_db_snapshots") +def test_get_final_snapshot_success(m_describe_db_snapshots, snapshots, expected): + client = MagicMock() + module = MagicMock() + m_describe_db_snapshots.return_value = snapshots + assert get_final_snapshot(client, module, "my-snapshot") == expected + + +@patch(mod_name + ".describe_db_snapshots") +def test_get_final_snapshot_failure(m_describe_db_snapshots): + client = MagicMock() + module = MagicMock() + e = AnsibleRDSError() + m_describe_db_snapshots.side_effect = e + get_final_snapshot(client, module, "my-snapshot") + module.fail_json_aws.assert_called_once_with( + e, msg="Failed to retrieve information about the final snapshot: my-snapshot" + ) diff --git a/tests/unit/plugins/modules/test_rds_instance_info.py b/tests/unit/plugins/modules/test_rds_instance_info.py index 8db20f1a077..2058356ab38 100644 --- a/tests/unit/plugins/modules/test_rds_instance_info.py +++ b/tests/unit/plugins/modules/test_rds_instance_info.py @@ -2,99 +2,70 @@ # # This file is part of Ansible # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from unittest.mock import ANY + from unittest.mock import MagicMock -from unittest.mock import call from unittest.mock import patch -import botocore.exceptions -import pytest - +from ansible_collections.amazon.aws.plugins.module_utils.rds import AnsibleRDSError from ansible_collections.amazon.aws.plugins.modules import rds_instance_info +from ansible_collections.amazon.aws.plugins.modules.rds_instance_info import instance_info mod_name = "ansible_collections.amazon.aws.plugins.modules.rds_instance_info" -def a_boto_exception(): - return botocore.exceptions.UnknownServiceError(service_name="Whoops", known_service_names="Oula") - - -@patch(mod_name + "._describe_db_instances") -@patch(mod_name + ".get_instance_tags") -def test_instance_info_one_instance(m_get_instance_tags, m_describe_db_instances): +@patch(mod_name + ".describe_db_instances") +def test_instance_info_one_instance(m_describe_db_instances): conn = MagicMock() + module = MagicMock() instance_name = "my-instance" - m_get_instance_tags.return_value = [] m_describe_db_instances.return_value = [ { "DBInstanceIdentifier": instance_name, "DBInstanceArn": "arn:aws:rds:us-east-2:123456789012:og:" + instance_name, + "TagList": [], } ] - rds_instance_info.instance_info(conn, instance_name, filters={}) + assert instance_info(conn, module, instance_name, filters={}) == [ + { + "db_instance_identifier": instance_name, + "db_instance_arn": "arn:aws:rds:us-east-2:123456789012:og:" + instance_name, + "tags": {}, + } + ] m_describe_db_instances.assert_called_with(conn, DBInstanceIdentifier=instance_name) - m_get_instance_tags.assert_called_with(conn, arn="arn:aws:rds:us-east-2:123456789012:og:" + instance_name) -@patch(mod_name + "._describe_db_instances") -@patch(mod_name + ".get_instance_tags") -def test_instance_info_all_instances(m_get_instance_tags, m_describe_db_instances): +@patch(mod_name + ".describe_db_instances") +def test_instance_info_all_instances(m_describe_db_instances): conn = MagicMock() - m_get_instance_tags.return_value = [] + module = MagicMock() m_describe_db_instances.return_value = [ { "DBInstanceIdentifier": "first-instance", "DBInstanceArn": "arn:aws:rds:us-east-2:123456789012:og:first-instance", + "TagList": [], }, { "DBInstanceIdentifier": "second-instance", "DBInstanceArn": "arn:aws:rds:us-east-2:123456789012:og:second-instance", + "TagList": [{"Key": "MyTag", "Value": "My tag value"}], }, ] - rds_instance_info.instance_info(conn, instance_name=None, filters={"engine": "postgres"}) + assert instance_info(conn, module, instance_name=None, filters={"engine": "postgres"}) == [ + { + "db_instance_identifier": "first-instance", + "db_instance_arn": "arn:aws:rds:us-east-2:123456789012:og:first-instance", + "tags": {}, + }, + { + "db_instance_identifier": "second-instance", + "db_instance_arn": "arn:aws:rds:us-east-2:123456789012:og:second-instance", + "tags": {"MyTag": "My tag value"}, + }, + ] m_describe_db_instances.assert_called_with(conn, Filters=[{"Name": "engine", "Values": ["postgres"]}]) - assert m_get_instance_tags.call_count == 2 - m_get_instance_tags.assert_has_calls( - [ - call(conn, arn="arn:aws:rds:us-east-2:123456789012:og:first-instance"), - call(conn, arn="arn:aws:rds:us-east-2:123456789012:og:second-instance"), - ] - ) - - -def test_get_instance_tags(): - conn = MagicMock() - conn.list_tags_for_resource.return_value = { - "TagList": [ - {"Key": "My-tag", "Value": "the-value$"}, - ], - "NextToken": "some-token", - } - - tags = rds_instance_info.get_instance_tags(conn, "arn:aws:rds:us-east-2:123456789012:og:second-instance") - conn.list_tags_for_resource.assert_called_with( - ResourceName="arn:aws:rds:us-east-2:123456789012:og:second-instance", - aws_retry=True, - ) - assert tags == {"My-tag": "the-value$"} - - -def test_api_failure_get_tag(): - conn = MagicMock() - conn.list_tags_for_resource.side_effect = a_boto_exception() - - with pytest.raises(rds_instance_info.RdsInstanceInfoFailure): - rds_instance_info.get_instance_tags(conn, "arn:blabla") - - -def test_api_failure_describe(): - conn = MagicMock() - conn.get_paginator.side_effect = a_boto_exception() - - with pytest.raises(rds_instance_info.RdsInstanceInfoFailure): - rds_instance_info.instance_info(conn, None, {}) @patch(mod_name + ".AnsibleAWSModule") @@ -104,18 +75,19 @@ def test_main_success(m_AnsibleAWSModule): rds_instance_info.main() - m_module.client.assert_called_with("rds", retry_decorator=ANY) + m_module.client.assert_called_with("rds") m_module.exit_json.assert_called_with(changed=False, instances=[]) -@patch(mod_name + "._describe_db_instances") +@patch(mod_name + ".describe_db_instances") @patch(mod_name + ".AnsibleAWSModule") def test_main_failure(m_AnsibleAWSModule, m_describe_db_instances): m_module = MagicMock() m_AnsibleAWSModule.return_value = m_module - m_describe_db_instances.side_effect = a_boto_exception() + e = AnsibleRDSError() + m_describe_db_instances.side_effect = e rds_instance_info.main() - m_module.client.assert_called_with("rds", retry_decorator=ANY) - m_module.fail_json_aws.assert_called_with(ANY, "Couldn't get instance information") + m_module.client.assert_called_with("rds") + m_module.fail_json_aws.assert_called_with(e)