Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retry to nocredentials error #24

Merged
merged 5 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/aibs_informatics_aws_utils/efs/mount_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from aibs_informatics_core.models.aws.efs import AccessPointId, EFSPath, FileSystemId
from aibs_informatics_core.utils.decorators import retry
from aibs_informatics_core.utils.hashing import sha256_hexdigest
from aibs_informatics_core.utils.os_operations import get_env_var
from botocore.exceptions import NoCredentialsError

from aibs_informatics_aws_utils.constants.efs import (
EFS_MOUNT_POINT_ID_VAR,
Expand Down Expand Up @@ -377,22 +379,23 @@ def __repr__(self) -> str:


@cache
@retry(retryable_exceptions=(NoCredentialsError), tries=5, backoff=2.0)
def detect_mount_points() -> List[MountPointConfiguration]:
mount_points: List[MountPointConfiguration] = []

if batch_job_id := get_env_var("AWS_BATCH_JOB_ID"):
logger.info(f"Detected Batch job {batch_job_id}")
batch_mp_configs = _detect_moint_points_from_batch_job(batch_job_id)
batch_mp_configs = _detect_mount_points_from_batch_job(batch_job_id)
logger.info(f"Detected {len(batch_mp_configs)} EFS mount points from Batch")
mount_points.extend(batch_mp_configs)
elif lambda_function_name := get_env_var("AWS_LAMBDA_FUNCTION_NAME"):
logger.info(f"Detected Lambda function {lambda_function_name}")
lambda_mp_configs = _detect_moint_points_from_lambda(lambda_function_name)
lambda_mp_configs = _detect_mount_points_from_lambda(lambda_function_name)
logger.info(f"Detected {len(lambda_mp_configs)} EFS mount points from Lambda")
mount_points.extend(lambda_mp_configs)
else:
logger.info("No Lambda or Batch environment detected. Using environment variables.")
env_mount_points = _detect_moint_points_from_env()
env_mount_points = _detect_mount_points_from_env()
logger.info(
f"Detected {len(env_mount_points)} EFS mount points from environment variables"
)
Expand Down Expand Up @@ -441,7 +444,7 @@ def deduplicate_mount_points(
# ------------------------------------


def _detect_moint_points_from_lambda(lambda_function_name: str) -> List[MountPointConfiguration]:
def _detect_mount_points_from_lambda(lambda_function_name: str) -> List[MountPointConfiguration]:
mount_points: List[MountPointConfiguration] = []
lambda_ = get_lambda_client()
response = lambda_.get_function_configuration(FunctionName=lambda_function_name)
Expand All @@ -458,9 +461,9 @@ def _detect_moint_points_from_lambda(lambda_function_name: str) -> List[MountPoi
return _remove_invalid_mount_points(mount_points)


def _detect_moint_points_from_batch_job(batch_job_id: str) -> List[MountPointConfiguration]:
def _detect_mount_points_from_batch_job(batch_job_id: str) -> List[MountPointConfiguration]:
mount_points: List[MountPointConfiguration] = []
batch = AWSService.BATCH.get_client()
batch = get_batch_client()
response = batch.describe_jobs(jobs=[batch_job_id])
job_container = response.get("jobs", [{}])[0].get("container", {})
batch_mount_points = job_container.get("mountPoints")
Expand Down Expand Up @@ -496,7 +499,7 @@ def _detect_moint_points_from_batch_job(batch_job_id: str) -> List[MountPointCon
return _remove_invalid_mount_points(mount_points)


def _detect_moint_points_from_env() -> List[MountPointConfiguration]:
def _detect_mount_points_from_env() -> List[MountPointConfiguration]:
mount_points: List[MountPointConfiguration] = []

for k, v in os.environ.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ def assertLocalFileSystem_partition(
self.assertSetEqual(expected_node_paths, local_node_paths)


@moto.mock_aws
class EFSFileSystemTests(EFSTestsBase):
def setUp(self) -> None:
super().setUp()
Expand Down
19 changes: 17 additions & 2 deletions test/aibs_informatics_aws_utils/efs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,36 @@

import boto3
import moto
import moto.core
import moto.core.config
import moto.core.decorator


class EFSTestsBase(AwsBaseTest):
def setUp(self) -> None:
super().setUp()

self.mock_efs = moto.mock_aws()
# HACK: We must define moto.mock_aws here instead of as a decorator because moto gives us
# issues when we try to use the moto.mock_aws decorator in the child classes
# (as https://github.com/getmoto/moto/issues/7063 suggests).
# We previously double decorated - decorating the child and parent class - but this
# now fails in python 3.12 (perhaps cleanup of mock collisions leniency).
#
# This is a workaround until we can find a better solution. If configs need to be
# overridden, they can be passed in via the mock_aws_config property on the child class.
self.mock_efs = moto.mock_aws(config=self.mock_aws_config)
self.mock_efs.start()

self.set_aws_credentials()
self._file_store_name_id_map: Dict[str, str] = {}

def tearDown(self) -> None:
super().tearDown()
self.mock_efs.stop()
return super().tearDown()

@property
def mock_aws_config(self) -> Optional[moto.core.config.DefaultConfig]:
return None

@property
def efs_client(self):
Expand Down
1 change: 0 additions & 1 deletion test/aibs_informatics_aws_utils/efs/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)


@moto.mock_aws
class EFSTests(EFSTestsBase):
def test__list_efs_file_systems__filters_based_on_tag(self):
file_system_id1 = self.create_file_system("fs1", env="dev")
Expand Down
22 changes: 15 additions & 7 deletions test/aibs_informatics_aws_utils/efs/test_mount_point.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import json
from pathlib import Path
from test.aibs_informatics_aws_utils.efs.base import EFSTestsBase
from typing import Optional, Tuple, Union
from unittest import mock, skip
from typing import TYPE_CHECKING, Optional, Tuple, Union

import boto3
import moto
from moto.core.config import DefaultConfig

from aibs_informatics_aws_utils.constants.efs import (
EFS_MOUNT_POINT_ID_VAR,
Expand All @@ -18,13 +18,21 @@
detect_mount_points,
)

if TYPE_CHECKING:
from mypy_boto3_lambda.type_defs import FileSystemConfigTypeDef
else: # pragma: no cover
FileSystemConfigTypeDef = dict


@moto.mock_aws(config={"batch": {"use_docker": False}})
class MountPointConfigurationTests(EFSTestsBase):
def setUp(self) -> None:
super().setUp()
detect_mount_points.cache_clear()

@property
def mock_aws_config(self) -> Optional[DefaultConfig]:
return {"batch": {"use_docker": False}, "core": {"reset_boto3_session": True}}

def setUpEFS(self, *access_points: Tuple[str, Path], file_system_name: Optional[str] = None):
self.create_file_system(file_system_name)
for access_point_name, access_point_path in access_points:
Expand Down Expand Up @@ -136,13 +144,13 @@ def test__detect_mount_points__lambda_config_overrides(self):

# Set up lambda
lambda_client = boto3.client("lambda")
file_system_configs = [
file_system_configs: list[FileSystemConfigTypeDef] = [
{
"Arn": (c1.access_point or {}).get("AccessPointArn"), # type: ignore,
"Arn": (c1.access_point or {}).get("AccessPointArn"), # type: ignore
"LocalMountPath": c1.mount_point.as_posix(),
},
{
"Arn": (c2.access_point or {}).get("AccessPointArn"), # type: ignore,
"Arn": (c2.access_point or {}).get("AccessPointArn"), # type: ignore
"LocalMountPath": c2.mount_point.as_posix(),
},
]
Expand Down Expand Up @@ -331,7 +339,7 @@ def test__detect_mount_points__batch_job_config_overrides(self):

describe_job_response = batch_client.describe_jobs(jobs=[job_id])
with self.stub(batch_client) as batch_stubber:
describe_job_response["jobs"][0]["container"][
describe_job_response["jobs"][0]["container"][ # type: ignore
"mountPoints"
] = batch_mount_point_configs
batch_stubber.add_response(
Expand Down
Loading