Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MSPE-391: Adding Demand Execution Constructs to CDK Lib #9

Merged
merged 6 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/aibs_informatics_cdk_lib/common/aws/iam_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ def grant_managed_policies(


# Action-name patterns for AWS Lambda (presumably consumed as IAM policy
# actions elsewhere in this module -- confirm against callers).
# Wildcard over every action under the `lambda:` prefix.
LAMBDA_FULL_ACCESS_ACTIONS = ["lambda:*"]
# Restricted to actions whose names start with `Get` or `List`.
LAMBDA_READ_ONLY_ACTIONS = [
    "lambda:Get*",
    "lambda:List*",
]

# Wildcard over every action under the `s3:` prefix.
S3_FULL_ACCESS_ACTIONS = ["s3:*"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import constructs
from aibs_informatics_core.utils.decorators import cached_property
from aibs_informatics_core.utils.hashing import generate_path_hash
from aws_cdk import aws_ecr_assets as ecr_assets
from aws_cdk import aws_lambda as lambda_
from aws_cdk import aws_s3_assets, aws_s3_deployment

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)

AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR = "AIBS_INFORMATICS_AWS_LAMBDA_REPO"
AIBS_INFORMATICS_AWS_LAMBDA_REPO = "git@github.com/AllenInstitute/aibs-informatics-aws-lambda.git"
AIBS_INFORMATICS_AWS_LAMBDA_REPO = "git@github.com:AllenInstitute/aibs-informatics-aws-lambda.git"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,7 +56,7 @@ def resolve_repo_path(cls, repo_url: str, repo_path_env_var: Optional[str]) -> P
return repo_path


class AIBSInformaticsCodeAssets(constructs.Construct):
class AIBSInformaticsCodeAssets(constructs.Construct, AssetsMixin):
def __init__(
self,
scope: constructs.Construct,
Expand All @@ -76,16 +76,9 @@ def AIBS_INFORMATICS_AWS_LAMBDA(self) -> CodeAsset:
CodeAsset: The code asset
"""

if AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR in os.environ:
logger.info(f"Using {AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR} from environment")
repo_path = os.getenv(AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR)
if not repo_path or not is_local_repo(repo_path):
raise ValueError(
f"Environment variable {AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR} is not a valid git repo"
)
repo_path = Path(repo_path)
else:
repo_path = clone_repo(AIBS_INFORMATICS_AWS_LAMBDA_REPO, skip_if_exists=True)
repo_path = self.resolve_repo_path(
AIBS_INFORMATICS_AWS_LAMBDA_REPO, AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR
)

asset_hash = generate_path_hash(
path=str(repo_path.resolve()),
Expand Down Expand Up @@ -179,8 +172,27 @@ def AIBS_INFORMATICS_AWS_LAMBDA(self) -> aws_ecr_assets.DockerImageAsset:
platform=aws_ecr_assets.Platform.LINUX_AMD64,
asset_name="aibs-informatics-aws-lambda",
file="docker/Dockerfile",
extra_hash=generate_path_hash(
path=str(repo_path.resolve()),
excludes=PYTHON_REGEX_EXCLUDES,
),
njmei marked this conversation as resolved.
Show resolved Hide resolved
exclude=[
*PYTHON_GLOB_EXCLUDES,
*GLOBAL_GLOB_EXCLUDES,
],
)


class AIBSInformaticsAssets(constructs.Construct):
    """Parent construct that groups the code assets and the Docker assets.

    Attributes:
        env_base: The environment base handed to the constructor.
        code_assets: Child ``AIBSInformaticsCodeAssets`` construct ("CodeAssets").
        docker_assets: Child ``AIBSInformaticsDockerAssets`` construct ("DockerAssets").
    """

    def __init__(
        self,
        scope: constructs.Construct,
        construct_id: str,
        env_base: EnvBase,
        runtime: Optional[lambda_.Runtime] = None,
    ) -> None:
        """Create both asset constructs as children of this construct.

        Args:
            scope: Parent construct.
            construct_id: Identifier of this construct within ``scope``.
            env_base: Environment base forwarded to both child constructs.
            runtime: Optional Lambda runtime, forwarded to the code assets only.
        """
        super().__init__(scope, construct_id)
        self.env_base = env_base

        # Both children are scoped under this construct.
        self.code_assets = AIBSInformaticsCodeAssets(
            self,
            "CodeAssets",
            env_base,
            runtime=runtime,
        )
        self.docker_assets = AIBSInformaticsDockerAssets(
            self,
            "DockerAssets",
            env_base,
        )
16 changes: 8 additions & 8 deletions src/aibs_informatics_cdk_lib/constructs_/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ def is_test_or_prod(self) -> bool:
def construct_tags(self) -> List[cdk.Tag]:
return []

@property
def aws_region(self) -> str:
    """Region of the CDK stack that contains this construct."""
    owning_stack = cdk.Stack.of(self.as_construct())
    return owning_stack.region

@property
def aws_account(self) -> str:
    """Account of the CDK stack that contains this construct."""
    owning_stack = cdk.Stack.of(self.as_construct())
    return owning_stack.account

def add_tags(self, *tags: cdk.Tag):
for tag in tags:
cdk.Tags.of(self.as_construct()).add(key=tag.key, value=tag.value)
Expand Down Expand Up @@ -107,11 +115,3 @@ def __init__(self, scope: Construct, id: Optional[str], env_base: EnvBase) -> No
@property
def construct_id(self) -> str:
    """Identifier of this construct, read from its node."""
    node = self.node
    return node.id

@property
def aws_region(self) -> str:
return cdk.Stack.of(self).region

@property
def aws_account(self) -> str:
return cdk.Stack.of(self).account
14 changes: 7 additions & 7 deletions src/aibs_informatics_cdk_lib/constructs_/batch/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,28 @@

LOW_PRIORITY_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.SPOT_CAPACITY_OPTIMIZED,
instance_types=SPOT_INSTANCE_TYPES,
instance_types=[*SPOT_INSTANCE_TYPES],
use_spot=True,
use_fargate=False,
use_public_subnets=False,
)
NORMAL_PRIORITY_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE,
instance_types=SPOT_INSTANCE_TYPES,
instance_types=[*SPOT_INSTANCE_TYPES],
use_spot=True,
use_fargate=False,
use_public_subnets=False,
)
HIGH_PRIORITY_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE,
instance_types=ON_DEMAND_INSTANCE_TYPES,
instance_types=[*ON_DEMAND_INSTANCE_TYPES],
use_spot=False,
use_fargate=False,
use_public_subnets=False,
)
PUBLIC_SUBNET_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE,
instance_types=TRANSFER_INSTANCE_TYPES,
instance_types=[*TRANSFER_INSTANCE_TYPES],
use_spot=False,
use_fargate=False,
use_public_subnets=True,
Expand All @@ -52,21 +52,21 @@
)
LAMBDA_SMALL_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE,
instance_types=LAMBDA_SMALL_INSTANCE_TYPES,
instance_types=[*LAMBDA_SMALL_INSTANCE_TYPES],
use_spot=False,
use_fargate=False,
use_public_subnets=False,
)
LAMBDA_MEDIUM_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE,
instance_types=LAMBDA_MEDIUM_INSTANCE_TYPES,
instance_types=[*LAMBDA_MEDIUM_INSTANCE_TYPES],
use_spot=False,
use_fargate=False,
use_public_subnets=False,
)
LAMBDA_LARGE_BATCH_ENV_CONFIG = BatchEnvironmentConfig(
allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE,
instance_types=LAMBDA_LARGE_INSTANCE_TYPES,
instance_types=[*LAMBDA_LARGE_INSTANCE_TYPES],
use_spot=False,
use_fargate=False,
use_public_subnets=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def create_launch_template(
launch_template = ec2.LaunchTemplate(
self,
launch_template_name,
launch_template_name=launch_template_name,
# NOTE: unsetting because of complications when multiple batch environments are created
# launch_template_name=launch_template_name,
njmei marked this conversation as resolved.
Show resolved Hide resolved
instance_initiated_shutdown_behavior=ec2.InstanceInitiatedShutdownBehavior.TERMINATE,
security_group=security_group,
user_data=user_data,
Expand Down Expand Up @@ -121,7 +122,7 @@ def create_launch_template(

@dataclass
class BatchLaunchTemplateUserData:
env_base: str
env_base: EnvBase
batch_env_name: str
user_data_text: str = field(init=False)

Expand Down
8 changes: 6 additions & 2 deletions src/aibs_informatics_cdk_lib/constructs_/cw/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
from copy import deepcopy
from math import ceil
from typing import Any, Dict, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple, cast

import aws_cdk as cdk
import constructs
Expand Down Expand Up @@ -126,7 +126,11 @@ def create_widgets_and_alarms(
else:
metric_name = graph_metric.label or metric_config["statistic"]
else:
metric_name = metric_config["metric"]
metric = metric_config["metric"]
if isinstance(metric, cw.Metric):
metric_name = metric.metric_name
else:
metric_name = str(metric)
metric_label = metric_config.get(
"label",
re.sub(
Expand Down
15 changes: 2 additions & 13 deletions src/aibs_informatics_cdk_lib/constructs_/cw/types.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
import re
from collections import defaultdict
from copy import deepcopy
from math import ceil
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union

import aws_cdk as cdk
import constructs
from aibs_informatics_core.env import EnvBase
from aws_cdk import aws_cloudwatch as cw
from aws_cdk import aws_cloudwatch_actions as cw_actions
from aws_cdk import aws_sns as sns
from typing import Dict, List, Literal, TypedDict, Union

from aibs_informatics_cdk_lib.constructs_.base import EnvBaseConstruct, EnvBaseConstructMixins
from aws_cdk import aws_cloudwatch as cw


class _AlarmMetricConfigOptional(TypedDict, total=False):
Expand Down
52 changes: 32 additions & 20 deletions src/aibs_informatics_cdk_lib/constructs_/efs/file_system.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, Optional, Tuple, TypeVar, Union
from typing import Any, Literal, Optional, Tuple, TypeVar, Union, cast

import aws_cdk as cdk
import constructs
from aibs_informatics_aws_utils.batch import to_mount_point, to_volume
from aibs_informatics_aws_utils.constants.efs import (
EFS_MOUNT_POINT_PATH_VAR,
EFS_ROOT_ACCESS_POINT_TAG,
EFS_ROOT_PATH,
EFS_SCRATCH_ACCESS_POINT_TAG,
EFS_SCRATCH_PATH,
EFS_SHARED_ACCESS_POINT_TAG,
EFS_SHARED_PATH,
EFS_TMP_ACCESS_POINT_TAG,
EFS_TMP_PATH,
EFSTag,
)
Expand All @@ -32,6 +26,7 @@

from aibs_informatics_cdk_lib.common.aws.iam_utils import grant_managed_policies
from aibs_informatics_cdk_lib.constructs_.base import EnvBaseConstruct, EnvBaseConstructMixins
from aibs_informatics_cdk_lib.constructs_.sfn.utils import convert_to_sfn_api_action_case

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,7 +70,7 @@ def __init__(
**kwargs,
)

self._root_access_point = self.create_access_point("root", EFS_ROOT_PATH)
self._root_access_point = self.create_access_point(name="root", path=EFS_ROOT_PATH)

@property
def root_access_point(self) -> efs.AccessPoint:
Expand All @@ -98,12 +93,11 @@ def create_access_point(
efs.AccessPoint: _description_
"""
ap_tags = [tag if isinstance(tag, EFSTag) else EFSTag(*tag) for tag in tags]

if not any(tag.key == "Name" for tag in ap_tags):
ap_tags.insert(0, EFSTag("Name", name))

cfn_access_point = efs.CfnAccessPoint(
self,
self.get_stack_of(self),
self.get_construct_id(name, "cfn-ap"),
file_system_id=self.file_system_id,
access_point_tags=[
Expand Down Expand Up @@ -169,11 +163,14 @@ def __init__(
vpc=vpc,
)

self.shared_access_point = self.file_system.create_access_point("shared", EFS_SHARED_PATH)
self.shared_access_point = self.file_system.create_access_point(
name="shared", path=EFS_SHARED_PATH
)
self.scratch_access_point = self.file_system.create_access_point(
"scratch", EFS_SCRATCH_PATH
name="scratch", path=EFS_SCRATCH_PATH
)
self.tmp_access_point = self.file_system.create_access_point("tmp", EFS_TMP_PATH)
self.tmp_access_point = self.file_system.create_access_point(name="tmp", path=EFS_TMP_PATH)
self.file_system.add_tags(cdk.Tag("blah", self.env_base))

@property
def file_system(self) -> EnvBaseFileSystem:
Expand Down Expand Up @@ -250,26 +247,41 @@ def file_system_id(self) -> str:
else:
raise ValueError("No file system or access point provided")

def to_batch_mount_point(self, name: str) -> dict[str, Any]:
return to_mount_point(self.mount_point, self.read_only, source_volume=name) # type: ignore

def to_batch_volume(self, name: str) -> dict[str, Any]:
@property
def access_point_id(self) -> Optional[str]:
    """ID of the configured access point, or ``None`` when there is none."""
    access_point = self.access_point
    return access_point.access_point_id if access_point else None

def to_batch_mount_point(self, name: str, sfn_format: bool = False) -> dict[str, Any]:
    """Build the mount-point mapping for this configuration.

    Args:
        name: Passed to ``to_mount_point`` as ``source_volume``.
        sfn_format: When true, the mapping is passed through
            ``convert_to_sfn_api_action_case`` before being returned.

    Returns:
        The mapping produced by ``to_mount_point``, optionally converted.
    """
    batch_mount_point: dict[str, Any] = to_mount_point(
        self.mount_point, self.read_only, source_volume=name
    )  # type: ignore[arg-type]  # typed dict should be accepted
    if not sfn_format:
        return batch_mount_point
    return convert_to_sfn_api_action_case(batch_mount_point)

def to_batch_volume(self, name: str, sfn_format: bool = False) -> dict[str, Any]:
efs_volume_configuration: dict[str, Any] = {
"fileSystemId": self.file_system_id,
}
if self.access_point:
efs_volume_configuration["transitEncryption"] = "ENABLED"
# TODO: Consider adding IAM
efs_volume_configuration["authorizationConfig"] = {
"accessPointId": self.access_point.access_point_id,
"iam": "ENABLED",
"iam": "DISABLED",
njmei marked this conversation as resolved.
Show resolved Hide resolved
}
else:
efs_volume_configuration["rootDirectory"] = self.root_directory or "/"
return to_volume(
volume: dict[str, Any] = to_volume(
None,
name=name,
efs_volume_configuration=efs_volume_configuration, # type: ignore
) # type: ignore
)
if sfn_format:
return convert_to_sfn_api_action_case(volume)
return volume


def create_access_point(
Expand Down
2 changes: 1 addition & 1 deletion src/aibs_informatics_cdk_lib/constructs_/s3/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
):
self.env_base = env_base
self._full_bucket_name = bucket_name
if self._full_bucket_name is not None:
if bucket_name is not None:
self._full_bucket_name = env_base.get_bucket_name(
base_name=bucket_name, account_id=account_id, region=region
)
Expand Down
Loading
Loading