Skip to content

Commit

Permalink
Merge pull request #9 from AllenInstitute/feature/MSPE-391
Browse files Browse the repository at this point in the history
MSPE-391: Adding Demand Execution Constructs to CDK Lib
  • Loading branch information
rpmcginty authored Jun 5, 2024
2 parents 9166d25 + 1decff6 commit 5563f55
Show file tree
Hide file tree
Showing 27 changed files with 1,222 additions and 1,018 deletions.
4 changes: 4 additions & 0 deletions src/aibs_informatics_cdk_lib/common/aws/iam_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ def grant_managed_policies(


LAMBDA_FULL_ACCESS_ACTIONS = ["lambda:*"]
LAMBDA_READ_ONLY_ACTIONS = [
"lambda:Get*",
"lambda:List*",
]

S3_FULL_ACCESS_ACTIONS = ["s3:*"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import constructs
from aibs_informatics_core.utils.decorators import cached_property
from aibs_informatics_core.utils.hashing import generate_path_hash
from aws_cdk import aws_ecr_assets as ecr_assets
from aws_cdk import aws_lambda as lambda_
from aws_cdk import aws_s3_assets, aws_s3_deployment

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)

AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR = "AIBS_INFORMATICS_AWS_LAMBDA_REPO"
AIBS_INFORMATICS_AWS_LAMBDA_REPO = "[email protected]/AllenInstitute/aibs-informatics-aws-lambda.git"
AIBS_INFORMATICS_AWS_LAMBDA_REPO = "[email protected]:AllenInstitute/aibs-informatics-aws-lambda.git"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,7 +56,7 @@ def resolve_repo_path(cls, repo_url: str, repo_path_env_var: Optional[str]) -> P
return repo_path


class AIBSInformaticsCodeAssets(constructs.Construct):
class AIBSInformaticsCodeAssets(constructs.Construct, AssetsMixin):
def __init__(
self,
scope: constructs.Construct,
Expand All @@ -76,16 +76,9 @@ def AIBS_INFORMATICS_AWS_LAMBDA(self) -> CodeAsset:
CodeAsset: The code asset
"""

if AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR in os.environ:
logger.info(f"Using {AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR} from environment")
repo_path = os.getenv(AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR)
if not repo_path or not is_local_repo(repo_path):
raise ValueError(
f"Environment variable {AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR} is not a valid git repo"
)
repo_path = Path(repo_path)
else:
repo_path = clone_repo(AIBS_INFORMATICS_AWS_LAMBDA_REPO, skip_if_exists=True)
repo_path = self.resolve_repo_path(
AIBS_INFORMATICS_AWS_LAMBDA_REPO, AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR
)

asset_hash = generate_path_hash(
path=str(repo_path.resolve()),
Expand Down Expand Up @@ -179,8 +172,27 @@ def AIBS_INFORMATICS_AWS_LAMBDA(self) -> aws_ecr_assets.DockerImageAsset:
platform=aws_ecr_assets.Platform.LINUX_AMD64,
asset_name="aibs-informatics-aws-lambda",
file="docker/Dockerfile",
extra_hash=generate_path_hash(
path=str(repo_path.resolve()),
excludes=PYTHON_REGEX_EXCLUDES,
),
exclude=[
*PYTHON_GLOB_EXCLUDES,
*GLOBAL_GLOB_EXCLUDES,
],
)


class AIBSInformaticsAssets(constructs.Construct):
def __init__(
self,
scope: constructs.Construct,
construct_id: str,
env_base: EnvBase,
runtime: Optional[lambda_.Runtime] = None,
) -> None:
super().__init__(scope, construct_id)
self.env_base = env_base

self.code_assets = AIBSInformaticsCodeAssets(self, "CodeAssets", env_base, runtime=runtime)
self.docker_assets = AIBSInformaticsDockerAssets(self, "DockerAssets", env_base)
16 changes: 8 additions & 8 deletions src/aibs_informatics_cdk_lib/constructs_/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ def is_test_or_prod(self) -> bool:
def construct_tags(self) -> List[cdk.Tag]:
return []

@property
def aws_region(self) -> str:
return cdk.Stack.of(self.as_construct()).region

@property
def aws_account(self) -> str:
return cdk.Stack.of(self.as_construct()).account

def add_tags(self, *tags: cdk.Tag):
for tag in tags:
cdk.Tags.of(self.as_construct()).add(key=tag.key, value=tag.value)
Expand Down Expand Up @@ -107,11 +115,3 @@ def __init__(self, scope: Construct, id: Optional[str], env_base: EnvBase) -> No
@property
def construct_id(self) -> str:
return self.node.id

@property
def aws_region(self) -> str:
return cdk.Stack.of(self).region

@property
def aws_account(self) -> str:
return cdk.Stack.of(self).account
14 changes: 7 additions & 7 deletions src/aibs_informatics_cdk_lib/constructs_/batch/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,28 @@

LOW_PRIORITY_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.SPOT_CAPACITY_OPTIMIZED,
instance_types=SPOT_INSTANCE_TYPES,
instance_types=[*SPOT_INSTANCE_TYPES],
use_spot=True,
use_fargate=False,
use_public_subnets=False,
)
NORMAL_PRIORITY_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE,
instance_types=SPOT_INSTANCE_TYPES,
instance_types=[*SPOT_INSTANCE_TYPES],
use_spot=True,
use_fargate=False,
use_public_subnets=False,
)
HIGH_PRIORITY_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE,
instance_types=ON_DEMAND_INSTANCE_TYPES,
instance_types=[*ON_DEMAND_INSTANCE_TYPES],
use_spot=False,
use_fargate=False,
use_public_subnets=False,
)
PUBLIC_SUBNET_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE,
instance_types=TRANSFER_INSTANCE_TYPES,
instance_types=[*TRANSFER_INSTANCE_TYPES],
use_spot=False,
use_fargate=False,
use_public_subnets=True,
Expand All @@ -52,21 +52,21 @@
)
LAMBDA_SMALL_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE,
instance_types=LAMBDA_SMALL_INSTANCE_TYPES,
instance_types=[*LAMBDA_SMALL_INSTANCE_TYPES],
use_spot=False,
use_fargate=False,
use_public_subnets=False,
)
LAMBDA_MEDIUM_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE,
instance_types=LAMBDA_MEDIUM_INSTANCE_TYPES,
instance_types=[*LAMBDA_MEDIUM_INSTANCE_TYPES],
use_spot=False,
use_fargate=False,
use_public_subnets=False,
)
LAMBDA_LARGE_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE,
instance_types=LAMBDA_LARGE_INSTANCE_TYPES,
instance_types=[*LAMBDA_LARGE_INSTANCE_TYPES],
use_spot=False,
use_fargate=False,
use_public_subnets=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def create_launch_template(
launch_template = ec2.LaunchTemplate(
self,
launch_template_name,
launch_template_name=launch_template_name,
# NOTE: unsetting because of complications when multiple batch environments are created
# launch_template_name=launch_template_name,
instance_initiated_shutdown_behavior=ec2.InstanceInitiatedShutdownBehavior.TERMINATE,
security_group=security_group,
user_data=user_data,
Expand Down Expand Up @@ -121,7 +122,7 @@ def create_launch_template(

@dataclass
class BatchLaunchTemplateUserData:
env_base: str
env_base: EnvBase
batch_env_name: str
user_data_text: str = field(init=False)

Expand Down
8 changes: 6 additions & 2 deletions src/aibs_informatics_cdk_lib/constructs_/cw/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
from copy import deepcopy
from math import ceil
from typing import Any, Dict, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple, cast

import aws_cdk as cdk
import constructs
Expand Down Expand Up @@ -126,7 +126,11 @@ def create_widgets_and_alarms(
else:
metric_name = graph_metric.label or metric_config["statistic"]
else:
metric_name = metric_config["metric"]
metric = metric_config["metric"]
if isinstance(metric, cw.Metric):
metric_name = metric.metric_name
else:
metric_name = str(metric)
metric_label = metric_config.get(
"label",
re.sub(
Expand Down
15 changes: 2 additions & 13 deletions src/aibs_informatics_cdk_lib/constructs_/cw/types.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
import re
from collections import defaultdict
from copy import deepcopy
from math import ceil
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union

import aws_cdk as cdk
import constructs
from aibs_informatics_core.env import EnvBase
from aws_cdk import aws_cloudwatch as cw
from aws_cdk import aws_cloudwatch_actions as cw_actions
from aws_cdk import aws_sns as sns
from typing import Dict, List, Literal, TypedDict, Union

from aibs_informatics_cdk_lib.constructs_.base import EnvBaseConstruct, EnvBaseConstructMixins
from aws_cdk import aws_cloudwatch as cw


class _AlarmMetricConfigOptional(TypedDict, total=False):
Expand Down
52 changes: 32 additions & 20 deletions src/aibs_informatics_cdk_lib/constructs_/efs/file_system.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, Optional, Tuple, TypeVar, Union
from typing import Any, Literal, Optional, Tuple, TypeVar, Union, cast

import aws_cdk as cdk
import constructs
from aibs_informatics_aws_utils.batch import to_mount_point, to_volume
from aibs_informatics_aws_utils.constants.efs import (
EFS_MOUNT_POINT_PATH_VAR,
EFS_ROOT_ACCESS_POINT_TAG,
EFS_ROOT_PATH,
EFS_SCRATCH_ACCESS_POINT_TAG,
EFS_SCRATCH_PATH,
EFS_SHARED_ACCESS_POINT_TAG,
EFS_SHARED_PATH,
EFS_TMP_ACCESS_POINT_TAG,
EFS_TMP_PATH,
EFSTag,
)
Expand All @@ -32,6 +26,7 @@

from aibs_informatics_cdk_lib.common.aws.iam_utils import grant_managed_policies
from aibs_informatics_cdk_lib.constructs_.base import EnvBaseConstruct, EnvBaseConstructMixins
from aibs_informatics_cdk_lib.constructs_.sfn.utils import convert_to_sfn_api_action_case

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,7 +70,7 @@ def __init__(
**kwargs,
)

self._root_access_point = self.create_access_point("root", EFS_ROOT_PATH)
self._root_access_point = self.create_access_point(name="root", path=EFS_ROOT_PATH)

@property
def root_access_point(self) -> efs.AccessPoint:
Expand All @@ -98,12 +93,11 @@ def create_access_point(
efs.AccessPoint: _description_
"""
ap_tags = [tag if isinstance(tag, EFSTag) else EFSTag(*tag) for tag in tags]

if not any(tag.key == "Name" for tag in ap_tags):
ap_tags.insert(0, EFSTag("Name", name))

cfn_access_point = efs.CfnAccessPoint(
self,
self.get_stack_of(self),
self.get_construct_id(name, "cfn-ap"),
file_system_id=self.file_system_id,
access_point_tags=[
Expand Down Expand Up @@ -169,11 +163,14 @@ def __init__(
vpc=vpc,
)

self.shared_access_point = self.file_system.create_access_point("shared", EFS_SHARED_PATH)
self.shared_access_point = self.file_system.create_access_point(
name="shared", path=EFS_SHARED_PATH
)
self.scratch_access_point = self.file_system.create_access_point(
"scratch", EFS_SCRATCH_PATH
name="scratch", path=EFS_SCRATCH_PATH
)
self.tmp_access_point = self.file_system.create_access_point("tmp", EFS_TMP_PATH)
self.tmp_access_point = self.file_system.create_access_point(name="tmp", path=EFS_TMP_PATH)
self.file_system.add_tags(cdk.Tag("blah", self.env_base))

@property
def file_system(self) -> EnvBaseFileSystem:
Expand Down Expand Up @@ -250,26 +247,41 @@ def file_system_id(self) -> str:
else:
raise ValueError("No file system or access point provided")

def to_batch_mount_point(self, name: str) -> dict[str, Any]:
return to_mount_point(self.mount_point, self.read_only, source_volume=name) # type: ignore

def to_batch_volume(self, name: str) -> dict[str, Any]:
@property
def access_point_id(self) -> Optional[str]:
if self.access_point:
return self.access_point.access_point_id
return None

def to_batch_mount_point(self, name: str, sfn_format: bool = False) -> dict[str, Any]:
mount_point: dict[str, Any] = to_mount_point(
self.mount_point, self.read_only, source_volume=name
) # type: ignore[arg-type] # typed dict should be accepted
if sfn_format:
return convert_to_sfn_api_action_case(mount_point)
return mount_point

def to_batch_volume(self, name: str, sfn_format: bool = False) -> dict[str, Any]:
efs_volume_configuration: dict[str, Any] = {
"fileSystemId": self.file_system_id,
}
if self.access_point:
efs_volume_configuration["transitEncryption"] = "ENABLED"
# TODO: Consider adding IAM
efs_volume_configuration["authorizationConfig"] = {
"accessPointId": self.access_point.access_point_id,
"iam": "ENABLED",
"iam": "DISABLED",
}
else:
efs_volume_configuration["rootDirectory"] = self.root_directory or "/"
return to_volume(
volume: dict[str, Any] = to_volume(
None,
name=name,
efs_volume_configuration=efs_volume_configuration, # type: ignore
) # type: ignore
)
if sfn_format:
return convert_to_sfn_api_action_case(volume)
return volume


def create_access_point(
Expand Down
2 changes: 1 addition & 1 deletion src/aibs_informatics_cdk_lib/constructs_/s3/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
):
self.env_base = env_base
self._full_bucket_name = bucket_name
if self._full_bucket_name is not None:
if bucket_name is not None:
self._full_bucket_name = env_base.get_bucket_name(
base_name=bucket_name, account_id=account_id, region=region
)
Expand Down
Loading

0 comments on commit 5563f55

Please sign in to comment.