Skip to content

Commit

Permalink
Merge pull request #24 from AllenInstitute/add-retry-to-nocredentials…
Browse files Browse the repository at this point in the history
…-error

Add retry to nocredentials error
  • Loading branch information
rpmcginty authored Dec 19, 2024
2 parents 5dcff74 + b4414dd commit 3e75ba3
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 18 deletions.
17 changes: 10 additions & 7 deletions src/aibs_informatics_aws_utils/efs/mount_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from aibs_informatics_core.models.aws.efs import AccessPointId, EFSPath, FileSystemId
from aibs_informatics_core.utils.decorators import retry
from aibs_informatics_core.utils.hashing import sha256_hexdigest
from aibs_informatics_core.utils.os_operations import get_env_var
from botocore.exceptions import NoCredentialsError

from aibs_informatics_aws_utils.constants.efs import (
EFS_MOUNT_POINT_ID_VAR,
Expand Down Expand Up @@ -377,22 +379,23 @@ def __repr__(self) -> str:


@cache
@retry(retryable_exceptions=(NoCredentialsError), tries=5, backoff=2.0)
def detect_mount_points() -> List[MountPointConfiguration]:
mount_points: List[MountPointConfiguration] = []

if batch_job_id := get_env_var("AWS_BATCH_JOB_ID"):
logger.info(f"Detected Batch job {batch_job_id}")
batch_mp_configs = _detect_moint_points_from_batch_job(batch_job_id)
batch_mp_configs = _detect_mount_points_from_batch_job(batch_job_id)
logger.info(f"Detected {len(batch_mp_configs)} EFS mount points from Batch")
mount_points.extend(batch_mp_configs)
elif lambda_function_name := get_env_var("AWS_LAMBDA_FUNCTION_NAME"):
logger.info(f"Detected Lambda function {lambda_function_name}")
lambda_mp_configs = _detect_moint_points_from_lambda(lambda_function_name)
lambda_mp_configs = _detect_mount_points_from_lambda(lambda_function_name)
logger.info(f"Detected {len(lambda_mp_configs)} EFS mount points from Lambda")
mount_points.extend(lambda_mp_configs)
else:
logger.info("No Lambda or Batch environment detected. Using environment variables.")
env_mount_points = _detect_moint_points_from_env()
env_mount_points = _detect_mount_points_from_env()
logger.info(
f"Detected {len(env_mount_points)} EFS mount points from environment variables"
)
Expand Down Expand Up @@ -441,7 +444,7 @@ def deduplicate_mount_points(
# ------------------------------------


def _detect_moint_points_from_lambda(lambda_function_name: str) -> List[MountPointConfiguration]:
def _detect_mount_points_from_lambda(lambda_function_name: str) -> List[MountPointConfiguration]:
mount_points: List[MountPointConfiguration] = []
lambda_ = get_lambda_client()
response = lambda_.get_function_configuration(FunctionName=lambda_function_name)
Expand All @@ -458,9 +461,9 @@ def _detect_moint_points_from_lambda(lambda_function_name: str) -> List[MountPoi
return _remove_invalid_mount_points(mount_points)


def _detect_moint_points_from_batch_job(batch_job_id: str) -> List[MountPointConfiguration]:
def _detect_mount_points_from_batch_job(batch_job_id: str) -> List[MountPointConfiguration]:
mount_points: List[MountPointConfiguration] = []
batch = AWSService.BATCH.get_client()
batch = get_batch_client()
response = batch.describe_jobs(jobs=[batch_job_id])
job_container = response.get("jobs", [{}])[0].get("container", {})
batch_mount_points = job_container.get("mountPoints")
Expand Down Expand Up @@ -496,7 +499,7 @@ def _detect_moint_points_from_batch_job(batch_job_id: str) -> List[MountPointCon
return _remove_invalid_mount_points(mount_points)


def _detect_moint_points_from_env() -> List[MountPointConfiguration]:
def _detect_mount_points_from_env() -> List[MountPointConfiguration]:
mount_points: List[MountPointConfiguration] = []

for k, v in os.environ.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ def assertLocalFileSystem_partition(
self.assertSetEqual(expected_node_paths, local_node_paths)


@moto.mock_aws
class EFSFileSystemTests(EFSTestsBase):
def setUp(self) -> None:
super().setUp()
Expand Down
6 changes: 6 additions & 0 deletions test/aibs_informatics_aws_utils/data_sync/test_operations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from pathlib import Path
from test.aibs_informatics_aws_utils.base import AwsBaseTest
from typing import Optional, Union
Expand All @@ -6,6 +7,7 @@
from aibs_informatics_core.models.aws.s3 import S3URI
from aibs_informatics_core.models.data_sync import RemoteToLocalConfig
from aibs_informatics_core.utils.os_operations import find_all_paths
from pytest import mark

from aibs_informatics_aws_utils.data_sync.operations import sync_data
from aibs_informatics_aws_utils.s3 import get_s3_client, get_s3_resource, is_object, list_s3_paths
Expand Down Expand Up @@ -360,6 +362,10 @@ def test__sync_data__s3_to_local__file__does_not_exist(self):
)
assert not destination_path.exists()

@mark.xfail(
sys.platform == "darwin",
reason="Test does not run on macOS (tmp dir is /private which is not accessible)",
)
def test__sync_data__s3_to_local__file__auto_custom_tmp_dir__succeeds(self):
fs = self.setUpLocalFS()
self.setUpBucket()
Expand Down
19 changes: 17 additions & 2 deletions test/aibs_informatics_aws_utils/efs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,36 @@

import boto3
import moto
import moto.core
import moto.core.config
import moto.core.decorator


class EFSTestsBase(AwsBaseTest):
def setUp(self) -> None:
super().setUp()

self.mock_efs = moto.mock_aws()
# HACK: We must define moto.mock_aws here instead of as a decorator because moto gives us
# issues when we try to use the moto.mock_aws decorator in the child classes
# (as https://github.com/getmoto/moto/issues/7063 suggests).
# We previously double decorated - decorating the child and parent class - but this
# now fails in python 3.12 (perhaps cleanup of mock collisions leniency).
#
# This is a workaround until we can find a better solution. If configs need to be
# overridden, they can be passed in via the mock_aws_config property on the child class.
self.mock_efs = moto.mock_aws(config=self.mock_aws_config)
self.mock_efs.start()

self.set_aws_credentials()
self._file_store_name_id_map: Dict[str, str] = {}

def tearDown(self) -> None:
super().tearDown()
self.mock_efs.stop()
return super().tearDown()

@property
def mock_aws_config(self) -> Optional[moto.core.config.DefaultConfig]:
return None

@property
def efs_client(self):
Expand Down
1 change: 0 additions & 1 deletion test/aibs_informatics_aws_utils/efs/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)


@moto.mock_aws
class EFSTests(EFSTestsBase):
def test__list_efs_file_systems__filters_based_on_tag(self):
file_system_id1 = self.create_file_system("fs1", env="dev")
Expand Down
22 changes: 15 additions & 7 deletions test/aibs_informatics_aws_utils/efs/test_mount_point.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import json
from pathlib import Path
from test.aibs_informatics_aws_utils.efs.base import EFSTestsBase
from typing import Optional, Tuple, Union
from unittest import mock, skip
from typing import TYPE_CHECKING, Optional, Tuple, Union

import boto3
import moto
from moto.core.config import DefaultConfig

from aibs_informatics_aws_utils.constants.efs import (
EFS_MOUNT_POINT_ID_VAR,
Expand All @@ -18,13 +18,21 @@
detect_mount_points,
)

if TYPE_CHECKING:
from mypy_boto3_lambda.type_defs import FileSystemConfigTypeDef
else: # pragma: no cover
FileSystemConfigTypeDef = dict


@moto.mock_aws(config={"batch": {"use_docker": False}})
class MountPointConfigurationTests(EFSTestsBase):
def setUp(self) -> None:
super().setUp()
detect_mount_points.cache_clear()

@property
def mock_aws_config(self) -> Optional[DefaultConfig]:
return {"batch": {"use_docker": False}, "core": {"reset_boto3_session": True}}

def setUpEFS(self, *access_points: Tuple[str, Path], file_system_name: Optional[str] = None):
self.create_file_system(file_system_name)
for access_point_name, access_point_path in access_points:
Expand Down Expand Up @@ -136,13 +144,13 @@ def test__detect_mount_points__lambda_config_overrides(self):

# Set up lambda
lambda_client = boto3.client("lambda")
file_system_configs = [
file_system_configs: list[FileSystemConfigTypeDef] = [
{
"Arn": (c1.access_point or {}).get("AccessPointArn"), # type: ignore,
"Arn": (c1.access_point or {}).get("AccessPointArn"), # type: ignore
"LocalMountPath": c1.mount_point.as_posix(),
},
{
"Arn": (c2.access_point or {}).get("AccessPointArn"), # type: ignore,
"Arn": (c2.access_point or {}).get("AccessPointArn"), # type: ignore
"LocalMountPath": c2.mount_point.as_posix(),
},
]
Expand Down Expand Up @@ -331,7 +339,7 @@ def test__detect_mount_points__batch_job_config_overrides(self):

describe_job_response = batch_client.describe_jobs(jobs=[job_id])
with self.stub(batch_client) as batch_stubber:
describe_job_response["jobs"][0]["container"][
describe_job_response["jobs"][0]["container"][ # type: ignore
"mountPoints"
] = batch_mount_point_configs
batch_stubber.add_response(
Expand Down

0 comments on commit 3e75ba3

Please sign in to comment.