diff --git a/src/aibs_informatics_aws_utils/core.py b/src/aibs_informatics_aws_utils/core.py index 1728f27..bde86c9 100644 --- a/src/aibs_informatics_aws_utils/core.py +++ b/src/aibs_informatics_aws_utils/core.py @@ -11,19 +11,8 @@ ] import logging import os -import re from dataclasses import dataclass -from typing import ( - TYPE_CHECKING, - ClassVar, - Generic, - Literal, - Optional, - Pattern, - TypeVar, - Union, - cast, -) +from typing import TYPE_CHECKING, Generic, Literal, Optional, TypeVar, Union, cast import boto3 from aibs_informatics_core.models.aws.core import AWSRegion @@ -32,6 +21,7 @@ from boto3 import Session from boto3.resources.base import ServiceResource from botocore.client import BaseClient, ClientError +from botocore.config import Config from botocore.session import Session as BotocoreSession if TYPE_CHECKING: # pragma: no cover @@ -249,8 +239,22 @@ def get_client( region_name = get_region(region=region or kwargs.get("region_name")) if region_name: kwargs["region_name"] = region_name + + # If config for our client is not set, we want to set it to use "standard" mode + # (default is "legacy") and increase the number of retries to 5 (default is 3) + # See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#available-retry-modes + config: Optional[Config] = kwargs.pop("config", None) + default_config = Config( + connect_timeout=120, read_timeout=120, retries={"max_attempts": 6, "mode": "standard"} + ) + if config is None: + config = default_config + else: + # Have values in pre-existing config (if it exists) take precedence over default_config + config = default_config.merge(other_config=config) + session = session or boto3.Session() - return session.client(service, **kwargs) + return session.client(service, config=config, **kwargs) @cache @@ -280,8 +284,22 @@ def get_resource( region_name = get_region(region=region or kwargs.get("region_name")) if region_name: kwargs["region_name"] = region_name + + # If config for our client is not set, we want to set it to use "standard" mode + # (default is "legacy") and increase the number of retries to 5 (default is 3) + # See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#available-retry-modes + config: Optional[Config] = kwargs.pop("config", None) + default_config = Config( + connect_timeout=120, read_timeout=120, retries={"max_attempts": 6, "mode": "standard"} + ) + if config is None: + config = default_config + else: + # Have values in pre-existing config (if it exists) take precedence over default_config + config = default_config.merge(other_config=config) + session = session or boto3.Session() - return session.resource(service, **kwargs) + return session.resource(service, config=config, **kwargs) @dataclass diff --git a/test/aibs_informatics_aws_utils/test_core.py b/test/aibs_informatics_aws_utils/test_core.py index 83ab3a9..75cf48e 100644 --- a/test/aibs_informatics_aws_utils/test_core.py +++ b/test/aibs_informatics_aws_utils/test_core.py @@ -3,6 +3,7 @@ import moto import pytest +from botocore.config import Config from aibs_informatics_aws_utils.core import ( AWSService, @@ -116,3 +117,71 @@ def test__get_resource__gets_service_resources(self): mock.call("sqs", region=None), ] ) + + +@pytest.mark.parametrize( + "service, preexisting_config, expected_retries_config", + [ + pytest.param( + # service + "s3", + # preexisting_config + None, + # expected_retries_config + {"total_max_attempts": 7, "mode": "standard"}, + id="Basic test case (no preexisting_config provided)", + ), + pytest.param( + # service + "dynamodb", + # preexisting_config + Config(retries={"max_attempts": 8, "mode": "adaptive"}), + # expected_retries_config + {"total_max_attempts": 9, "mode": "adaptive"}, + id="Test preexisting_config doesn't get overridden by default", + ), + ], +) +def test___core__get_client__config_setup_properly( + aws_credentials_fixture, service, preexisting_config, expected_retries_config +): + if preexisting_config: + client = get_client(service=service, config=preexisting_config) + else: + client = get_client(service=service) + + assert expected_retries_config == client._client_config.retries + + +@pytest.mark.parametrize( + "service, preexisting_config, expected_retries_config", + [ + pytest.param( + # service + "s3", + # preexisting_config + None, + # expected_retries_config + {"total_max_attempts": 7, "mode": "standard"}, + id="Basic test case (no preexisting_config provided)", + ), + pytest.param( + # service + "dynamodb", + # preexisting_config + Config(retries={"max_attempts": 8, "mode": "adaptive"}), + # expected_retries_config + {"total_max_attempts": 9, "mode": "adaptive"}, + id="Test preexisting_config doesn't get overridden by default", + ), + ], +) +def test___core__get_resource__config_setup_properly( + aws_credentials_fixture, service, preexisting_config, expected_retries_config +): + if preexisting_config: + resource = get_resource(service=service, config=preexisting_config) + else: + resource = get_resource(service=service) + + assert expected_retries_config == resource.meta.client._client_config.retries