diff --git a/src/aibs_informatics_cdk_lib/common/aws/iam_utils.py b/src/aibs_informatics_cdk_lib/common/aws/iam_utils.py index 90b3bea..6c05b79 100644 --- a/src/aibs_informatics_cdk_lib/common/aws/iam_utils.py +++ b/src/aibs_informatics_cdk_lib/common/aws/iam_utils.py @@ -170,6 +170,10 @@ def grant_managed_policies( LAMBDA_FULL_ACCESS_ACTIONS = ["lambda:*"] +LAMBDA_READ_ONLY_ACTIONS = [ + "lambda:Get*", + "lambda:List*", +] S3_FULL_ACCESS_ACTIONS = ["s3:*"] diff --git a/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset.py b/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset.py index 3ffc7b6..0472f80 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset.py +++ b/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset.py @@ -8,7 +8,6 @@ import constructs from aibs_informatics_core.utils.decorators import cached_property from aibs_informatics_core.utils.hashing import generate_path_hash -from aws_cdk import aws_ecr_assets as ecr_assets from aws_cdk import aws_lambda as lambda_ from aws_cdk import aws_s3_assets, aws_s3_deployment diff --git a/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset_definitions.py b/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset_definitions.py index 09d4c3f..35d2637 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset_definitions.py +++ b/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset_definitions.py @@ -21,7 +21,7 @@ ) AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR = "AIBS_INFORMATICS_AWS_LAMBDA_REPO" -AIBS_INFORMATICS_AWS_LAMBDA_REPO = "git@github.com/AllenInstitute/aibs-informatics-aws-lambda.git" +AIBS_INFORMATICS_AWS_LAMBDA_REPO = "git@github.com:AllenInstitute/aibs-informatics-aws-lambda.git" logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ def resolve_repo_path(cls, repo_url: str, repo_path_env_var: Optional[str]) -> P return repo_path -class AIBSInformaticsCodeAssets(constructs.Construct): +class AIBSInformaticsCodeAssets(constructs.Construct, AssetsMixin): def __init__( self, scope: constructs.Construct, @@ -76,16 +76,9 @@ def AIBS_INFORMATICS_AWS_LAMBDA(self) -> CodeAsset: CodeAsset: The code asset """ - if AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR in os.environ: - logger.info(f"Using {AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR} from environment") - repo_path = os.getenv(AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR) - if not repo_path or not is_local_repo(repo_path): - raise ValueError( - f"Environment variable {AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR} is not a valid git repo" - ) - repo_path = Path(repo_path) - else: - repo_path = clone_repo(AIBS_INFORMATICS_AWS_LAMBDA_REPO, skip_if_exists=True) + repo_path = self.resolve_repo_path( + AIBS_INFORMATICS_AWS_LAMBDA_REPO, AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR + ) asset_hash = generate_path_hash( path=str(repo_path.resolve()), @@ -179,8 +172,27 @@ def AIBS_INFORMATICS_AWS_LAMBDA(self) -> aws_ecr_assets.DockerImageAsset: platform=aws_ecr_assets.Platform.LINUX_AMD64, asset_name="aibs-informatics-aws-lambda", file="docker/Dockerfile", + extra_hash=generate_path_hash( + path=str(repo_path.resolve()), + excludes=PYTHON_REGEX_EXCLUDES, + ), exclude=[ *PYTHON_GLOB_EXCLUDES, *GLOBAL_GLOB_EXCLUDES, ], ) + + +class AIBSInformaticsAssets(constructs.Construct): + def __init__( + self, + scope: constructs.Construct, + construct_id: str, + env_base: EnvBase, + runtime: Optional[lambda_.Runtime] = None, + ) -> None: + super().__init__(scope, construct_id) + self.env_base = env_base + + self.code_assets = AIBSInformaticsCodeAssets(self, 
"CodeAssets", env_base, runtime=runtime) + self.docker_assets = AIBSInformaticsDockerAssets(self, "DockerAssets", env_base) diff --git a/src/aibs_informatics_cdk_lib/constructs_/base.py b/src/aibs_informatics_cdk_lib/constructs_/base.py index 545c694..b9bbec7 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/base.py +++ b/src/aibs_informatics_cdk_lib/constructs_/base.py @@ -32,6 +32,14 @@ def is_test_or_prod(self) -> bool: def construct_tags(self) -> List[cdk.Tag]: return [] + @property + def aws_region(self) -> str: + return cdk.Stack.of(self.as_construct()).region + + @property + def aws_account(self) -> str: + return cdk.Stack.of(self.as_construct()).account + def add_tags(self, *tags: cdk.Tag): for tag in tags: cdk.Tags.of(self.as_construct()).add(key=tag.key, value=tag.value) @@ -107,11 +115,3 @@ def __init__(self, scope: Construct, id: Optional[str], env_base: EnvBase) -> No @property def construct_id(self) -> str: return self.node.id - - @property - def aws_region(self) -> str: - return cdk.Stack.of(self).region - - @property - def aws_account(self) -> str: - return cdk.Stack.of(self).account diff --git a/src/aibs_informatics_cdk_lib/constructs_/batch/defaults.py b/src/aibs_informatics_cdk_lib/constructs_/batch/defaults.py index 4ba368c..806afeb 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/batch/defaults.py +++ b/src/aibs_informatics_cdk_lib/constructs_/batch/defaults.py @@ -12,28 +12,28 @@ LOW_PRIORITY_BATCH_ENV_CONFIG = BatchEnvironmentConfig( allocation_strategy=batch.AllocationStrategy.SPOT_CAPACITY_OPTIMIZED, - instance_types=SPOT_INSTANCE_TYPES, + instance_types=[*SPOT_INSTANCE_TYPES], use_spot=True, use_fargate=False, use_public_subnets=False, ) NORMAL_PRIORITY_BATCH_ENV_CONFIG = BatchEnvironmentConfig( allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, - instance_types=SPOT_INSTANCE_TYPES, + instance_types=[*SPOT_INSTANCE_TYPES], use_spot=True, use_fargate=False, use_public_subnets=False, ) HIGH_PRIORITY_BATCH_ENV_CONFIG = BatchEnvironmentConfig( allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, - instance_types=ON_DEMAND_INSTANCE_TYPES, + instance_types=[*ON_DEMAND_INSTANCE_TYPES], use_spot=False, use_fargate=False, use_public_subnets=False, ) PUBLIC_SUBNET_BATCH_ENV_CONFIG = BatchEnvironmentConfig( allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, - instance_types=TRANSFER_INSTANCE_TYPES, + instance_types=[*TRANSFER_INSTANCE_TYPES], use_spot=False, use_fargate=False, use_public_subnets=True, @@ -52,21 +52,21 @@ ) LAMBDA_SMALL_BATCH_ENV_CONFIG = BatchEnvironmentConfig( allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, - instance_types=LAMBDA_SMALL_INSTANCE_TYPES, + instance_types=[*LAMBDA_SMALL_INSTANCE_TYPES], use_spot=False, use_fargate=False, use_public_subnets=False, ) LAMBDA_MEDIUM_BATCH_ENV_CONFIG = BatchEnvironmentConfig( allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, - instance_types=LAMBDA_MEDIUM_INSTANCE_TYPES, + instance_types=[*LAMBDA_MEDIUM_INSTANCE_TYPES], use_spot=False, use_fargate=False, use_public_subnets=False, ) LAMBDA_LARGE_BATCH_ENV_CONFIG = BatchEnvironmentConfig( allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, - instance_types=LAMBDA_LARGE_INSTANCE_TYPES, + instance_types=[*LAMBDA_LARGE_INSTANCE_TYPES], use_spot=False, use_fargate=False, use_public_subnets=False, diff --git a/src/aibs_informatics_cdk_lib/constructs_/batch/launch_template.py b/src/aibs_informatics_cdk_lib/constructs_/batch/launch_template.py index 146931f..1abcb4d 
100644 --- a/src/aibs_informatics_cdk_lib/constructs_/batch/launch_template.py +++ b/src/aibs_informatics_cdk_lib/constructs_/batch/launch_template.py @@ -58,7 +58,8 @@ def create_launch_template( launch_template = ec2.LaunchTemplate( self, launch_template_name, - launch_template_name=launch_template_name, + # NOTE: unsetting because of complications when multiple batch environments are created + # launch_template_name=launch_template_name, instance_initiated_shutdown_behavior=ec2.InstanceInitiatedShutdownBehavior.TERMINATE, security_group=security_group, user_data=user_data, @@ -121,7 +122,7 @@ def create_launch_template( @dataclass class BatchLaunchTemplateUserData: - env_base: str + env_base: EnvBase batch_env_name: str user_data_text: str = field(init=False) diff --git a/src/aibs_informatics_cdk_lib/constructs_/cw/dashboard.py b/src/aibs_informatics_cdk_lib/constructs_/cw/dashboard.py index 8ae266f..f0c1b67 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/cw/dashboard.py +++ b/src/aibs_informatics_cdk_lib/constructs_/cw/dashboard.py @@ -2,7 +2,7 @@ from collections import defaultdict from copy import deepcopy from math import ceil -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple, cast import aws_cdk as cdk import constructs @@ -126,7 +126,11 @@ def create_widgets_and_alarms( else: metric_name = graph_metric.label or metric_config["statistic"] else: - metric_name = metric_config["metric"] + metric = metric_config["metric"] + if isinstance(metric, cw.Metric): + metric_name = metric.metric_name + else: + metric_name = str(metric) metric_label = metric_config.get( "label", re.sub( diff --git a/src/aibs_informatics_cdk_lib/constructs_/cw/types.py b/src/aibs_informatics_cdk_lib/constructs_/cw/types.py index a843250..a1efaf0 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/cw/types.py +++ b/src/aibs_informatics_cdk_lib/constructs_/cw/types.py @@ -1,17 +1,6 @@ -import re -from collections import defaultdict -from copy import deepcopy -from math import ceil -from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union - -import aws_cdk as cdk -import constructs -from aibs_informatics_core.env import EnvBase -from aws_cdk import aws_cloudwatch as cw -from aws_cdk import aws_cloudwatch_actions as cw_actions -from aws_cdk import aws_sns as sns +from typing import Dict, List, Literal, TypedDict, Union -from aibs_informatics_cdk_lib.constructs_.base import EnvBaseConstruct, EnvBaseConstructMixins +from aws_cdk import aws_cloudwatch as cw class _AlarmMetricConfigOptional(TypedDict, total=False): diff --git a/src/aibs_informatics_cdk_lib/constructs_/efs/file_system.py b/src/aibs_informatics_cdk_lib/constructs_/efs/file_system.py index 4aa6081..4d39dbf 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/efs/file_system.py +++ b/src/aibs_informatics_cdk_lib/constructs_/efs/file_system.py @@ -1,20 +1,14 @@ import logging from dataclasses import dataclass -from pathlib import Path -from typing import Any, Literal, Optional, Tuple, TypeVar, Union +from typing import Any, Literal, Optional, Tuple, TypeVar, Union, cast import aws_cdk as cdk import constructs from aibs_informatics_aws_utils.batch import to_mount_point, to_volume from aibs_informatics_aws_utils.constants.efs import ( - EFS_MOUNT_POINT_PATH_VAR, - EFS_ROOT_ACCESS_POINT_TAG, EFS_ROOT_PATH, - EFS_SCRATCH_ACCESS_POINT_TAG, EFS_SCRATCH_PATH, - EFS_SHARED_ACCESS_POINT_TAG, EFS_SHARED_PATH, - EFS_TMP_ACCESS_POINT_TAG, EFS_TMP_PATH, EFSTag, ) @@ 
-32,6 +26,7 @@ from aibs_informatics_cdk_lib.common.aws.iam_utils import grant_managed_policies from aibs_informatics_cdk_lib.constructs_.base import EnvBaseConstruct, EnvBaseConstructMixins +from aibs_informatics_cdk_lib.constructs_.sfn.utils import convert_to_sfn_api_action_case logger = logging.getLogger(__name__) @@ -75,7 +70,7 @@ def __init__( **kwargs, ) - self._root_access_point = self.create_access_point("root", EFS_ROOT_PATH) + self._root_access_point = self.create_access_point(name="root", path=EFS_ROOT_PATH) @property def root_access_point(self) -> efs.AccessPoint: @@ -98,12 +93,11 @@ def create_access_point( efs.AccessPoint: _description_ """ ap_tags = [tag if isinstance(tag, EFSTag) else EFSTag(*tag) for tag in tags] - if not any(tag.key == "Name" for tag in ap_tags): ap_tags.insert(0, EFSTag("Name", name)) cfn_access_point = efs.CfnAccessPoint( - self, + self.get_stack_of(self), self.get_construct_id(name, "cfn-ap"), file_system_id=self.file_system_id, access_point_tags=[ @@ -169,11 +163,14 @@ def __init__( vpc=vpc, ) - self.shared_access_point = self.file_system.create_access_point("shared", EFS_SHARED_PATH) + self.shared_access_point = self.file_system.create_access_point( + name="shared", path=EFS_SHARED_PATH + ) self.scratch_access_point = self.file_system.create_access_point( - "scratch", EFS_SCRATCH_PATH + name="scratch", path=EFS_SCRATCH_PATH ) - self.tmp_access_point = self.file_system.create_access_point("tmp", EFS_TMP_PATH) + self.tmp_access_point = self.file_system.create_access_point(name="tmp", path=EFS_TMP_PATH) + self.file_system.add_tags(cdk.Tag("blah", self.env_base)) @property def file_system(self) -> EnvBaseFileSystem: @@ -250,26 +247,41 @@ def file_system_id(self) -> str: else: raise ValueError("No file system or access point provided") - def to_batch_mount_point(self, name: str) -> dict[str, Any]: - return to_mount_point(self.mount_point, self.read_only, source_volume=name) # type: ignore - - def to_batch_volume(self, name: str) -> dict[str, Any]: + @property + def access_point_id(self) -> Optional[str]: + if self.access_point: + return self.access_point.access_point_id + return None + + def to_batch_mount_point(self, name: str, sfn_format: bool = False) -> dict[str, Any]: + mount_point: dict[str, Any] = to_mount_point( + self.mount_point, self.read_only, source_volume=name + ) # type: ignore[arg-type] # typed dict should be accepted + if sfn_format: + return convert_to_sfn_api_action_case(mount_point) + return mount_point + + def to_batch_volume(self, name: str, sfn_format: bool = False) -> dict[str, Any]: efs_volume_configuration: dict[str, Any] = { "fileSystemId": self.file_system_id, } if self.access_point: efs_volume_configuration["transitEncryption"] = "ENABLED" + # TODO: Consider adding IAM efs_volume_configuration["authorizationConfig"] = { "accessPointId": self.access_point.access_point_id, - "iam": "ENABLED", + "iam": "DISABLED", } else: efs_volume_configuration["rootDirectory"] = self.root_directory or "/" - return to_volume( + volume: dict[str, Any] = to_volume( None, name=name, efs_volume_configuration=efs_volume_configuration, # type: ignore - ) # type: ignore + ) + if sfn_format: + return convert_to_sfn_api_action_case(volume) + return volume def create_access_point( diff --git a/src/aibs_informatics_cdk_lib/constructs_/s3/bucket.py b/src/aibs_informatics_cdk_lib/constructs_/s3/bucket.py index a333d89..33394ed 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/s3/bucket.py +++ b/src/aibs_informatics_cdk_lib/constructs_/s3/bucket.py @@ 
-29,7 +29,7 @@ def __init__( ): self.env_base = env_base self._full_bucket_name = bucket_name - if self._full_bucket_name is not None: + if bucket_name is not None: self._full_bucket_name = env_base.get_bucket_name( base_name=bucket_name, account_id=account_id, region=region ) diff --git a/src/aibs_informatics_cdk_lib/stacks/compute.py b/src/aibs_informatics_cdk_lib/constructs_/service/compute.py similarity index 56% rename from src/aibs_informatics_cdk_lib/stacks/compute.py rename to src/aibs_informatics_cdk_lib/constructs_/service/compute.py index 59b9631..b0bf225 100644 --- a/src/aibs_informatics_cdk_lib/stacks/compute.py +++ b/src/aibs_informatics_cdk_lib/constructs_/service/compute.py @@ -1,20 +1,28 @@ from abc import abstractmethod -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Union -from aibs_informatics_aws_utils import AWS_REGION_VAR +import aws_cdk as cdk from aibs_informatics_core.env import EnvBase from aws_cdk import aws_batch_alpha as batch from aws_cdk import aws_ec2 as ec2 from aws_cdk import aws_efs as efs from aws_cdk import aws_iam as iam +from aws_cdk import aws_lambda as lambda_ from aws_cdk import aws_s3 as s3 from aws_cdk import aws_stepfunctions as sfn +from aws_cdk import aws_stepfunctions_tasks as sfn_tasks from constructs import Construct from aibs_informatics_cdk_lib.common.aws.iam_utils import ( + LAMBDA_READ_ONLY_ACTIONS, batch_policy_statement, + lambda_policy_statement, s3_policy_statement, ) +from aibs_informatics_cdk_lib.constructs_.assets.code_asset_definitions import ( + AIBSInformaticsCodeAssets, +) +from aibs_informatics_cdk_lib.constructs_.base import EnvBaseConstruct from aibs_informatics_cdk_lib.constructs_.batch.infrastructure import ( Batch, BatchEnvironment, @@ -38,10 +46,9 @@ from aibs_informatics_cdk_lib.constructs_.sfn.fragments.informatics import ( BatchInvokedLambdaFunction, ) -from aibs_informatics_cdk_lib.stacks.base import EnvBaseStack -class BaseComputeStack(EnvBaseStack): +class BaseBatchComputeConstruct(EnvBaseConstruct): def __init__( self, scope: Construct, @@ -52,11 +59,18 @@ def __init__( buckets: Optional[Iterable[s3.Bucket]] = None, file_systems: Optional[Iterable[Union[efs.FileSystem, efs.IFileSystem]]] = None, mount_point_configs: Optional[Iterable[MountPointConfiguration]] = None, + instance_role_policy_statements: Optional[List[iam.PolicyStatement]] = None, **kwargs, ) -> None: super().__init__(scope, id, env_base, **kwargs) self.batch_name = batch_name - self.batch = Batch(self, batch_name, self.env_base, vpc=vpc) + self.batch = Batch( + self, + batch_name, + self.env_base, + vpc=vpc, + instance_role_policy_statements=instance_role_policy_statements, + ) self.create_batch_environments() @@ -139,7 +153,7 @@ def _update_file_systems_from_mount_point_configs( return list(file_system_map.values()) -class ComputeStack(BaseComputeStack): +class BatchCompute(BaseBatchComputeConstruct): @property def primary_batch_environment(self) -> BatchEnvironment: return self.on_demand_batch_environment @@ -149,7 +163,7 @@ def create_batch_environments(self): self, f"{self.name}-lt-builder", env_base=self.env_base ) self.on_demand_batch_environment = self.batch.setup_batch_environment( - descriptor=BatchEnvironmentDescriptor("on-demand"), + descriptor=BatchEnvironmentDescriptor(f"{self.name}-on-demand"), config=BatchEnvironmentConfig( allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, instance_types=[ec2.InstanceType(_) for _ in ON_DEMAND_INSTANCE_TYPES], 
@@ -161,7 +175,7 @@ def create_batch_environments(self): ) self.spot_batch_environment = self.batch.setup_batch_environment( - descriptor=BatchEnvironmentDescriptor("spot"), + descriptor=BatchEnvironmentDescriptor(f"{self.name}-spot"), config=BatchEnvironmentConfig( allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, instance_types=[ec2.InstanceType(_) for _ in SPOT_INSTANCE_TYPES], @@ -173,7 +187,7 @@ def create_batch_environments(self): ) self.fargate_batch_environment = self.batch.setup_batch_environment( - descriptor=BatchEnvironmentDescriptor("fargate"), + descriptor=BatchEnvironmentDescriptor(f"{self.name}-fargate"), config=BatchEnvironmentConfig( allocation_strategy=None, instance_types=None, @@ -185,17 +199,17 @@ def create_batch_environments(self): ) -class LambdaComputeStack(ComputeStack): +class LambdaCompute(BatchCompute): @property def primary_batch_environment(self) -> BatchEnvironment: return self.lambda_batch_environment def create_batch_environments(self): lt_builder = BatchLaunchTemplateBuilder( - self, f"${self.name}-lt-builder", env_base=self.env_base + self, f"{self.name}-lt-builder", env_base=self.env_base ) self.lambda_batch_environment = self.batch.setup_batch_environment( - descriptor=BatchEnvironmentDescriptor("lambda"), + descriptor=BatchEnvironmentDescriptor(f"{self.name}-lambda"), config=BatchEnvironmentConfig( allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, instance_types=[ @@ -206,12 +220,13 @@ def create_batch_environments(self): use_spot=False, use_fargate=False, use_public_subnets=False, + minv_cpus=2, ), launch_template_builder=lt_builder, ) self.lambda_small_batch_environment = self.batch.setup_batch_environment( - descriptor=BatchEnvironmentDescriptor("lambda-small"), + descriptor=BatchEnvironmentDescriptor(f"{self.name}-lambda-small"), config=BatchEnvironmentConfig( allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, instance_types=[*LAMBDA_SMALL_INSTANCE_TYPES], @@ -223,19 +238,20 @@ def create_batch_environments(self): ) self.lambda_medium_batch_environment = self.batch.setup_batch_environment( - descriptor=BatchEnvironmentDescriptor("lambda-medium"), + descriptor=BatchEnvironmentDescriptor(f"{self.name}-lambda-medium"), config=BatchEnvironmentConfig( allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, instance_types=[*LAMBDA_MEDIUM_INSTANCE_TYPES], use_spot=False, use_fargate=False, use_public_subnets=False, + minv_cpus=2, ), launch_template_builder=lt_builder, ) self.lambda_large_batch_environment = self.batch.setup_batch_environment( - descriptor=BatchEnvironmentDescriptor("lambda-large"), + descriptor=BatchEnvironmentDescriptor(f"{self.name}-lambda-large"), config=BatchEnvironmentConfig( allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, instance_types=[*LAMBDA_LARGE_INSTANCE_TYPES], @@ -245,161 +261,3 @@ def create_batch_environments(self): ), launch_template_builder=lt_builder, ) - - -class ComputeWorkflowStack(EnvBaseStack): - def __init__( - self, - scope: Construct, - id: Optional[str], - env_base: EnvBase, - batch_environment: BatchEnvironment, - primary_bucket: s3.Bucket, - buckets: Optional[Iterable[s3.Bucket]] = None, - mount_point_configs: Optional[Iterable[MountPointConfiguration]] = None, - **kwargs, - ) -> None: - super().__init__(scope, id, env_base, **kwargs) - self.primary_bucket = primary_bucket - self.buckets = list(buckets or []) - self.batch_environment = batch_environment - self.mount_point_configs = list(mount_point_configs) if 
mount_point_configs else None - - self.create_submit_job_step_function() - self.create_lambda_invoke_step_function() - - def create_submit_job_step_function(self): - fragment = SubmitJobWithDefaultsFragment( - self, - "submit-job-fragment", - self.env_base, - job_queue=self.batch_environment.job_queue.job_queue_arn, - mount_point_configs=self.mount_point_configs, - ) - state_machine_name = self.get_resource_name("submit-job") - - self.batch_submit_job_state_machine = fragment.to_state_machine( - state_machine_name=state_machine_name, - role=iam.Role( - self, - self.env_base.get_construct_id(state_machine_name, "role"), - assumed_by=iam.ServicePrincipal("states.amazonaws.com"), # type: ignore - managed_policies=[ - iam.ManagedPolicy.from_aws_managed_policy_name( - "AmazonAPIGatewayInvokeFullAccess" - ), - iam.ManagedPolicy.from_aws_managed_policy_name("AWSStepFunctionsFullAccess"), - iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchLogsFullAccess"), - iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchEventsFullAccess"), - ], - inline_policies={ - "default": iam.PolicyDocument( - statements=[ - batch_policy_statement(self.env_base), - ] - ), - }, - ), - ) - - def create_lambda_invoke_step_function( - self, - ): - defaults: dict[str, Any] = {} - - defaults["job_queue"] = self.batch_environment.job_queue_name - defaults["memory"] = "1024" - defaults["vcpus"] = "1" - defaults["gpu"] = "0" - defaults["platform_capabilities"] = ["EC2"] - - if self.mount_point_configs: - mount_points, volumes = AWSBatchMixins.convert_to_mount_point_and_volumes( - self.mount_point_configs - ) - defaults["mount_points"] = mount_points - defaults["volumes"] = volumes - - start = sfn.Pass( - self, - "Start", - parameters={ - "image": sfn.JsonPath.string_at("$.image"), - "handler": sfn.JsonPath.string_at("$.handler"), - "payload": sfn.JsonPath.object_at("$.request"), - # We will merge the rest with the defaults - "input": sfn.JsonPath.object_at("$"), - "default": defaults, - }, - ) - - merge = sfn.Pass( - self, - "Merge", - parameters={ - "image": sfn.JsonPath.string_at("$.image"), - "handler": sfn.JsonPath.string_at("$.handler"), - "payload": sfn.JsonPath.string_at("$.payload"), - "compute": sfn.JsonPath.json_merge( - sfn.JsonPath.object_at("$.default"), sfn.JsonPath.object_at("$.input") - ), - }, - ) - - batch_invoked_lambda = BatchInvokedLambdaFunction( - self, - "Data Chain", - env_base=self.env_base, - image="$.image", - name="run-lambda-function", - handler="$.handler", - payload_path="$.payload", - bucket_name=self.primary_bucket.bucket_name, - job_queue=self.batch_environment.job_queue_name, - environment={ - EnvBase.ENV_BASE_KEY: self.env_base, - "AWS_REGION": self.aws_region, - "AWS_ACCOUNT_ID": self.aws_account, - }, - memory=sfn.JsonPath.string_at("$.compute.memory"), - vcpus=sfn.JsonPath.string_at("$.compute.vcpus"), - mount_points=sfn.JsonPath.string_at("$.compute.mount_points"), - volumes=sfn.JsonPath.string_at("$.compute.volumes"), - platform_capabilities=sfn.JsonPath.string_at("$.compute.platform_capabilities"), - ) - - # fmt: off - definition = ( - start - .next(merge) - .next(batch_invoked_lambda.to_single_state()) - ) - # fmt: on - - self.batch_invoked_lambda_state_machine = create_state_machine( - self, - env_base=self.env_base, - name=self.env_base.get_state_machine_name("batch-invoked-lambda"), - definition=definition, - role=iam.Role( - self, - self.env_base.get_construct_id("batch-invoked-lambda", "role"), - assumed_by=iam.ServicePrincipal("states.amazonaws.com"), # type: 
ignore - managed_policies=[ - iam.ManagedPolicy.from_aws_managed_policy_name( - "AmazonAPIGatewayInvokeFullAccess" - ), - iam.ManagedPolicy.from_aws_managed_policy_name("AWSStepFunctionsFullAccess"), - iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchLogsFullAccess"), - iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchEventsFullAccess"), - ], - inline_policies={ - "default": iam.PolicyDocument( - statements=[ - batch_policy_statement(self.env_base), - s3_policy_statement(self.env_base), - ] - ), - }, - ), - ) diff --git a/src/aibs_informatics_cdk_lib/constructs_/service/storage.py b/src/aibs_informatics_cdk_lib/constructs_/service/storage.py new file mode 100644 index 0000000..2ae3b22 --- /dev/null +++ b/src/aibs_informatics_cdk_lib/constructs_/service/storage.py @@ -0,0 +1,54 @@ +from typing import Optional + +import aws_cdk as cdk +from aibs_informatics_core.env import EnvBase +from aws_cdk import aws_ec2 as ec2 +from constructs import Construct + +from aibs_informatics_cdk_lib.constructs_.base import EnvBaseConstruct +from aibs_informatics_cdk_lib.constructs_.efs.file_system import EFSEcosystem, EnvBaseFileSystem +from aibs_informatics_cdk_lib.constructs_.s3 import EnvBaseBucket, LifecycleRuleGenerator + + +class Storage(EnvBaseConstruct): + def __init__( + self, + scope: Construct, + id: Optional[str], + env_base: EnvBase, + name: str, + vpc: ec2.Vpc, + removal_policy: cdk.RemovalPolicy = cdk.RemovalPolicy.RETAIN, + **kwargs, + ) -> None: + super().__init__(scope, id, env_base, **kwargs) + + self._bucket = EnvBaseBucket( + self, + "Bucket", + self.env_base, + bucket_name=name, + removal_policy=removal_policy, + lifecycle_rules=[ + LifecycleRuleGenerator.expire_files_under_prefix(), + LifecycleRuleGenerator.expire_files_with_scratch_tags(), + LifecycleRuleGenerator.use_storage_class_as_default(), + ], + ) + + self._efs_ecosystem = EFSEcosystem( + self, id="EFS", env_base=self.env_base, file_system_name=name, vpc=vpc + ) + self._file_system = self._efs_ecosystem.file_system + + @property + def bucket(self) -> EnvBaseBucket: + return self._bucket + + @property + def efs_ecosystem(self) -> EFSEcosystem: + return self._efs_ecosystem + + @property + def file_system(self) -> EnvBaseFileSystem: + return self._file_system diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/base.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/base.py index 9a3ea21..1bb33ef 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/base.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/base.py @@ -1,6 +1,5 @@ -import builtins from abc import abstractmethod -from typing import Any, Dict, List, Mapping, Optional, Type, TypeVar, cast +from typing import Any, Dict, List, Mapping, Optional, Sequence, TypeVar, Union, cast import aws_cdk as cdk import constructs @@ -14,13 +13,84 @@ from aibs_informatics_cdk_lib.common.aws.core_utils import build_lambda_arn from aibs_informatics_cdk_lib.common.aws.sfn_utils import JsonReferencePath from aibs_informatics_cdk_lib.constructs_.base import EnvBaseConstructMixins -from aibs_informatics_cdk_lib.constructs_.sfn.utils import convert_reference_paths T = TypeVar("T", bound=ValidatedStr) F = TypeVar("F", bound="StateMachineFragment") +def create_log_options( + scope: constructs.Construct, + id: str, + env_base: EnvBase, + removal_policy: Optional[cdk.RemovalPolicy] = None, + retention: Optional[logs_.RetentionDays] = None, +) -> sfn.LogOptions: + return sfn.LogOptions( + destination=logs_.LogGroup( + scope, + 
env_base.get_construct_id(id, "state-loggroup"), + log_group_name=env_base.get_state_machine_log_group_name(id), + removal_policy=removal_policy or cdk.RemovalPolicy.DESTROY, + retention=retention or logs_.RetentionDays.ONE_MONTH, + ) + ) + + +def create_role( + scope: constructs.Construct, + id: str, + env_base: EnvBase, + assumed_by: iam.IPrincipal = iam.ServicePrincipal("states.amazonaws.com"), # type: ignore[assignment] + managed_policies: Optional[Sequence[Union[iam.IManagedPolicy, str]]] = None, + inline_policies: Optional[Mapping[str, iam.PolicyDocument]] = None, + inline_policies_from_statements: Optional[Mapping[str, Sequence[iam.PolicyStatement]]] = None, + include_default_managed_policies: bool = True, +) -> iam.Role: + construct_id = env_base.get_construct_id(id, "role") + + if managed_policies is not None: + managed_policies = [ + iam.ManagedPolicy.from_aws_managed_policy_name(policy) + if isinstance(policy, str) + else policy + for policy in managed_policies + ] + + if inline_policies is None: + inline_policies = {} + if inline_policies_from_statements: + inline_policies = { + **inline_policies, + **{ + name: iam.PolicyDocument(statements=statements) + for name, statements in inline_policies_from_statements.items() + }, + } + + return iam.Role( + scope, + construct_id, + assumed_by=assumed_by, # type: ignore + managed_policies=[ + *(managed_policies or []), + *[ + iam.ManagedPolicy.from_aws_managed_policy_name(policy) + for policy in ( + [ + "AWSStepFunctionsFullAccess", + "CloudWatchLogsFullAccess", + "CloudWatchEventsFullAccess", + ] + if include_default_managed_policies + else [] + ) + ], + ], + inline_policies=inline_policies, + ) + + class StateMachineMixins(EnvBaseConstructMixins): def get_fn(self, function_name: str) -> lambda_.IFunction: cache_attr = "_function_cache" @@ -45,7 +115,7 @@ def get_state_machine_from_name(self, state_machine_name: str) -> sfn.IStateMach resource_cache = cast(Dict[str, sfn.IStateMachine], getattr(self, cache_attr)) if state_machine_name not in resource_cache: resource_cache[state_machine_name] = sfn.StateMachine.from_state_machine_name( - scope=self, + scope=self.as_construct(), id=self.env_base.get_construct_id(state_machine_name, "from-name"), state_machine_name=self.env_base.get_state_machine_name(state_machine_name), ) @@ -55,7 +125,8 @@ def get_state_machine_from_name(self, state_machine_name: str) -> sfn.IStateMach def create_state_machine( scope: constructs.Construct, env_base: EnvBase, - name: str, + id: str, + name: Optional[str], definition: sfn.IChainable, role: Optional[iam.Role] = None, logs: Optional[sfn.LogOptions] = None, @@ -63,15 +134,15 @@ def create_state_machine( ) -> sfn.StateMachine: return sfn.StateMachine( scope, - env_base.get_construct_id(name), - state_machine_name=env_base.get_state_machine_name(name), + env_base.get_construct_id(id), + state_machine_name=env_base.get_state_machine_name(name) if name else None, logs=( logs or sfn.LogOptions( destination=logs_.LogGroup( scope, - env_base.get_construct_id(name, "state-loggroup"), - log_group_name=env_base.get_state_machine_log_group_name(name), + env_base.get_construct_id(id, "state-loggroup"), + log_group_name=env_base.get_state_machine_log_group_name(name or id), removal_policy=cdk.RemovalPolicy.DESTROY, retention=logs_.RetentionDays.ONE_MONTH, ) @@ -100,50 +171,60 @@ def start_state(self) -> sfn.State: def end_states(self) -> List[sfn.INextable]: return self.definition.end_states - @classmethod def enclose( - cls: Type[F], - scope: constructs.Construct, + 
self, id: str, - definition: sfn.IChainable, + input_path: Optional[str] = None, result_path: Optional[str] = None, - ) -> F: - chain = ( - sfn.Chain.start(definition) - if not isinstance(definition, (sfn.Chain, sfn.StateMachineFragment)) - else definition - ) + ) -> sfn.Chain: + """Enclose the current state machine fragment within a parallel state. + + Notes: + - If input_path is not provided, it will default to "$" + - If result_path is not provided, it will default to input_path + + Args: + id (str): an identifier for the parallel state + input_path (Optional[str], optional): input path for the enclosed state. + Defaults to "$". + result_path (Optional[str], optional): result path to put output of enclosed state. + Defaults to same as input_path. + + Returns: + sfn.Chain: the new state machine fragment + """ + if input_path is None: + input_path = "$" + if result_path is None: + result_path = input_path - pre = sfn.Pass( - scope, f"{id} Parallel Prep", parameters={"input": sfn.JsonPath.entire_payload} + chain = ( + sfn.Chain.start(self.definition) + if not isinstance(self.definition, (sfn.Chain, sfn.StateMachineFragment)) + else self.definition ) if isinstance(chain, sfn.Chain): parallel = chain.to_single_state( - id=f"{id} Parallel", input_path="$.input", result_path="$.result" + id=f"{id} Enclosure", input_path=input_path, result_path=result_path ) else: - parallel = chain.to_single_state(input_path="$.input", result_path="$.result") + parallel = chain.to_single_state(input_path=input_path, result_path=result_path) + definition = sfn.Chain.start(parallel) - mod_result_path = JsonReferencePath("$.input") if result_path and result_path != sfn.JsonPath.DISCARD: - mod_result_path = mod_result_path + result_path - - post = sfn.Pass( - scope, - f"{id} Parallel Post", - input_path="$.result[0]", - result_path=mod_result_path.as_reference, - output_path="$.input", - ) - - new_definition = sfn.Chain.start(pre).next(parallel).next(post) - (self := cls(scope, id)).definition = new_definition + restructure = sfn.Pass( + self, + f"{id} Enclosure Post", + input_path=f"{result_path}[0]", + result_path=result_path, + ) + definition = definition.next(restructure) - return self + return definition -class EnvBaseStateMachineFragment(sfn.StateMachineFragment, StateMachineMixins): +class EnvBaseStateMachineFragment(StateMachineFragment, StateMachineMixins): def __init__( self, scope: constructs.Construct, @@ -153,22 +234,6 @@ def __init__( super().__init__(scope, id) self.env_base = env_base - @property - def definition(self) -> sfn.IChainable: - return self._definition - - @definition.setter - def definition(self, value: sfn.IChainable): - self._definition = value - - @property - def start_state(self) -> sfn.State: - return self.definition.start_state - - @property - def end_states(self) -> List[sfn.INextable]: - return self.definition.end_states - def to_single_state( self, *, @@ -197,29 +262,45 @@ def to_state_machine( logs: Optional[sfn.LogOptions] = None, timeout: Optional[cdk.Duration] = None, ) -> sfn.StateMachine: + if role is None: + role = create_role( + self, + state_machine_name, + self.env_base, + managed_policies=self.required_managed_policies, + inline_policies_from_statements={ + "default": self.required_inline_policy_statements, + }, + ) + else: + for policy in self.required_managed_policies: + if isinstance(policy, str): + policy = iam.ManagedPolicy.from_aws_managed_policy_name(policy) + role.add_managed_policy(policy) + + for statement in self.required_inline_policy_statements: + 
role.add_to_policy(statement) + return sfn.StateMachine( self, self.get_construct_id(state_machine_name), state_machine_name=self.env_base.get_state_machine_name(state_machine_name), - logs=( - logs - or sfn.LogOptions( - destination=logs_.LogGroup( - self, - self.get_construct_id(state_machine_name, "state-loggroup"), - log_group_name=self.env_base.get_state_machine_log_group_name( - state_machine_name - ), - removal_policy=cdk.RemovalPolicy.DESTROY, - retention=logs_.RetentionDays.ONE_MONTH, - ) - ) - ), - role=cast(iam.IRole, role), + logs=logs or create_log_options(self, state_machine_name, self.env_base), + role=( + role if role is not None else create_role(self, state_machine_name, self.env_base) + ), # type: ignore[arg-type] definition_body=sfn.DefinitionBody.from_chainable(self.definition), timeout=timeout, ) + @property + def required_managed_policies(self) -> Sequence[Union[iam.ManagedPolicy, str]]: + return [] + + @property + def required_inline_policy_statements(self) -> Sequence[iam.PolicyStatement]: + return [] + class LazyLoadStateMachineFragment(EnvBaseStateMachineFragment): @property @@ -255,10 +336,11 @@ def __init__( @property def task(self) -> sfn.IChainable: + assert self._task, "Task must be set" return self._task @task.setter - def task(self, value: sfn.IChainable): + def task(self, value: Optional[sfn.IChainable]): self._task = value @property diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/batch.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/batch.py index 62b7233..e169e31 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/batch.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/batch.py @@ -4,16 +4,12 @@ from aibs_informatics_core.env import EnvBase from aibs_informatics_core.utils.tools.dicttools import convert_key_case from aibs_informatics_core.utils.tools.strtools import pascalcase -from aws_cdk import aws_batch_alpha as batch from aws_cdk import aws_stepfunctions as sfn from aibs_informatics_cdk_lib.constructs_.efs.file_system import MountPointConfiguration -from aibs_informatics_cdk_lib.constructs_.sfn.fragments.base import ( - EnvBaseStateMachineFragment, - StateMachineFragment, -) +from aibs_informatics_cdk_lib.constructs_.sfn.fragments.base import EnvBaseStateMachineFragment from aibs_informatics_cdk_lib.constructs_.sfn.states.batch import BatchOperation -from aibs_informatics_cdk_lib.constructs_.sfn.utils import enclosed_chain +from aibs_informatics_cdk_lib.constructs_.sfn.states.common import CommonOperation if TYPE_CHECKING: from mypy_boto3_batch.type_defs import MountPointTypeDef, VolumeTypeDef @@ -99,13 +95,14 @@ def __init__( job_definition=sfn.JsonPath.string_at("$.taskResult.register.JobDefinitionArn"), ) - register = StateMachineFragment.enclose( + register = CommonOperation.enclose_chainable( self, id + " Register", register_chain, result_path="$.taskResult.register" ) - submit = StateMachineFragment.enclose( + # submit = StateMachineFragment.enclose( + submit = CommonOperation.enclose_chainable( self, id + " Submit", submit_chain, result_path="$.taskResult.submit" - ).to_single_state(output_path="$[0]") - deregister = StateMachineFragment.enclose( + ).to_single_state(id=f"{id} Enclosure", output_path="$[0]") + deregister = CommonOperation.enclose_chainable( self, id + " Deregister", deregister_chain, result_path="$.taskResult.deregister" ) submit.add_catch( @@ -231,15 +228,8 @@ def __init__( "default": defaults, }, ) - - merge = sfn.Pass( - self, - "Merge", - parameters={ - 
"request": sfn.JsonPath.json_merge( - sfn.JsonPath.object_at("$.default"), sfn.JsonPath.object_at("$.input") - ), - }, + merge_chain = CommonOperation.merge_defaults( + self, f"{id}", defaults=defaults, input_path="$.input", result_path="$.request" ) submit_job = SubmitJobFragment( @@ -262,4 +252,4 @@ def __init__( platform_capabilities=sfn.JsonPath.string_at("$.request.platform_capabilities"), ).to_single_state() - self.definition = start.next(merge).next(submit_job) + self.definition = start.next(merge_chain).next(submit_job) diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics.py index f011b68..19c4ac3 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Literal, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Mapping, Optional, Sequence, Union import constructs from aibs_informatics_aws_utils.constants.lambda_ import ( @@ -9,14 +9,28 @@ ) from aibs_informatics_aws_utils.constants.s3 import S3_SCRATCH_KEY_PREFIX from aibs_informatics_core.env import EnvBase +from aws_cdk import aws_batch_alpha as batch +from aws_cdk import aws_ecr_assets as ecr_assets +from aws_cdk import aws_iam as iam +from aws_cdk import aws_s3 as s3 from aws_cdk import aws_stepfunctions as sfn - +from aws_cdk import aws_stepfunctions_tasks as sfn_tasks + +from aibs_informatics_cdk_lib.common.aws.iam_utils import ( + SFN_STATES_EXECUTION_ACTIONS, + SFN_STATES_READ_ACCESS_ACTIONS, + batch_policy_statement, + s3_policy_statement, + sfn_policy_statement, +) +from aibs_informatics_cdk_lib.constructs_.base import EnvBaseConstructMixins from aibs_informatics_cdk_lib.constructs_.efs.file_system import MountPointConfiguration from aibs_informatics_cdk_lib.constructs_.sfn.fragments.base import EnvBaseStateMachineFragment from aibs_informatics_cdk_lib.constructs_.sfn.fragments.batch import ( AWSBatchMixins, SubmitJobFragment, ) +from aibs_informatics_cdk_lib.constructs_.sfn.states.common import CommonOperation from aibs_informatics_cdk_lib.constructs_.sfn.states.s3 import S3Operation if TYPE_CHECKING: @@ -26,7 +40,21 @@ VolumeTypeDef = dict -class BatchInvokedLambdaFunction(EnvBaseStateMachineFragment, AWSBatchMixins): +class BatchInvokedBaseFragment(EnvBaseStateMachineFragment, EnvBaseConstructMixins): + @property + def required_managed_policies(self) -> Sequence[Union[iam.IManagedPolicy, str]]: + return super().required_managed_policies + + @property + def required_inline_policy_statements(self) -> List[iam.PolicyStatement]: + return [ + *super().required_inline_policy_statements, + batch_policy_statement(self.env_base), + s3_policy_statement(self.env_base), + ] + + +class BatchInvokedLambdaFunction(BatchInvokedBaseFragment, AWSBatchMixins): def __init__( self, scope: constructs.Construct, @@ -109,6 +137,22 @@ def __init__( result_path="$.taskResult.put", ) + default_environment = { + AWS_LAMBDA_FUNCTION_NAME_KEY: name, + AWS_LAMBDA_FUNCTION_HANDLER_KEY: handler, + AWS_LAMBDA_EVENT_PAYLOAD_KEY: sfn.JsonPath.format( + "s3://{}/{}", + sfn.JsonPath.string_at("$.taskResult.put.Bucket"), + sfn.JsonPath.string_at("$.taskResult.put.Key"), + ), + AWS_LAMBDA_EVENT_RESPONSE_LOCATION_KEY: sfn.JsonPath.format( + "s3://{}/{}", bucket_name, response_key + ), + EnvBase.ENV_BASE_KEY: self.env_base, + "AWS_REGION": self.aws_region, + 
"AWS_ACCOUNT_ID": self.aws_account, + } + submit_job = SubmitJobFragment( self, f"{id} Batch", @@ -119,16 +163,7 @@ def __init__( image=image, environment={ **(environment if environment else {}), - AWS_LAMBDA_FUNCTION_NAME_KEY: name, - AWS_LAMBDA_FUNCTION_HANDLER_KEY: handler, - AWS_LAMBDA_EVENT_PAYLOAD_KEY: sfn.JsonPath.format( - "s3://{}/{}", - sfn.JsonPath.string_at("$.taskResult.put.Bucket"), - sfn.JsonPath.string_at("$.taskResult.put.Key"), - ), - AWS_LAMBDA_EVENT_RESPONSE_LOCATION_KEY: sfn.JsonPath.format( - "s3://{}/{}", bucket_name, response_key - ), + **default_environment, }, memory=memory, vcpus=vcpus, @@ -143,7 +178,7 @@ def __init__( bucket_name=bucket_name, key=response_key, ).to_single_state( - "Get Response from S3", + f"{id} Get Response from S3", output_path="$[0]", ) @@ -157,8 +192,91 @@ def start_state(self) -> sfn.State: def end_states(self) -> List[sfn.INextable]: return self.definition.end_states + @classmethod + def with_defaults( + cls, + scope: constructs.Construct, + id: str, + env_base: EnvBase, + name: str, + job_queue: str, + bucket_name: str, + key_prefix: Optional[str] = None, + image_path: Optional[str] = None, + handler_path: Optional[str] = None, + payload_path: Optional[str] = None, + command: Optional[List[str]] = None, + memory: str = "1024", + vcpus: str = "1", + environment: Optional[Mapping[str, str]] = None, + mount_point_configs: Optional[List[MountPointConfiguration]] = None, + platform_capabilities: Optional[List[Literal["EC2", "FARGATE"]]] = None, + ) -> "BatchInvokedLambdaFunction": + defaults: dict[str, Any] = {} + + defaults["job_queue"] = job_queue + defaults["memory"] = memory + defaults["vcpus"] = vcpus + defaults["environment"] = environment or {} + defaults["platform_capabilities"] = platform_capabilities or ["EC2"] + defaults["bucket_name"] = bucket_name + + defaults["command"] = command if command else [] + + if mount_point_configs: + mount_points, volumes = cls.convert_to_mount_point_and_volumes(mount_point_configs) + defaults["mount_points"] = mount_points + defaults["volumes"] = volumes + + fragment = BatchInvokedLambdaFunction( + scope, + id, + env_base=env_base, + name=name, + image=sfn.JsonPath.string_at("$.image"), + handler=sfn.JsonPath.string_at("$.handler"), + job_queue=sfn.JsonPath.string_at("$.merged.job_queue"), + bucket_name=sfn.JsonPath.string_at("$.merged.bucket_name"), + key_prefix=key_prefix, + payload_path=sfn.JsonPath.string_at("$.payload"), + command=sfn.JsonPath.string_at("$.merged.command"), + environment=environment, + memory=sfn.JsonPath.string_at("$.merged.memory"), + vcpus=sfn.JsonPath.string_at("$.merged.vcpus"), + # TODO: Handle GPU parameter better - right now, we cannot handle cases where it is + # not specified. Setting to zero causes issues with the Batch API. + # If it is set to zero, then the json list of resources are not properly set. 
+ # gpu=sfn.JsonPath.string_at("$.merged.gpu"), + mount_points=sfn.JsonPath.string_at("$.merged.mount_points"), + volumes=sfn.JsonPath.string_at("$.merged.volumes"), + platform_capabilities=sfn.JsonPath.string_at("$.merged.platform_capabilities"), + ) + + start = sfn.Pass( + fragment, + f"Start {id}", + parameters={ + "image": sfn.JsonPath.string_at(image_path or "$.image"), + "handler": sfn.JsonPath.string_at(handler_path or "$.handler"), + "payload": sfn.JsonPath.object_at(payload_path or "$.payload"), + # We will merge the rest with the defaults + "input": sfn.JsonPath.object_at("$"), + }, + ) + + merge_chain = CommonOperation.merge_defaults( + fragment, + f"Merge {id}", + input_path="$.input", + defaults=defaults, + result_path="$.merged", + ) + + fragment.definition = start.next(merge_chain).next(fragment.definition) + return fragment -class BatchInvokedExecutorFragment(EnvBaseStateMachineFragment, AWSBatchMixins): + +class BatchInvokedExecutorFragment(BatchInvokedBaseFragment, AWSBatchMixins): def __init__( self, scope: constructs.Construct, @@ -265,7 +383,7 @@ def __init__( bucket_name=bucket_name, key=response_key, ).to_single_state( - "Get Response from S3", + f"{id} Get Response from S3", output_path="$[0]", ) @@ -278,3 +396,331 @@ def start_state(self) -> sfn.State: @property def end_states(self) -> List[sfn.INextable]: return self.definition.end_states + + +class DataSyncFragment(BatchInvokedBaseFragment, EnvBaseConstructMixins): + def __init__( + self, + scope: constructs.Construct, + id: str, + env_base: EnvBase, + aibs_informatics_docker_asset: Union[ecr_assets.DockerImageAsset, str], + batch_job_queue: Union[batch.JobQueue, str], + scaffolding_bucket: s3.Bucket, + mount_point_configs: Optional[Iterable[MountPointConfiguration]] = None, + ) -> None: + """Sync data from one s3 bucket to another + + + Args: + scope (Construct): construct scope + id (str): id + env_base (EnvBase): env base + aibs_informatics_docker_asset (DockerImageAsset|str): Docker image asset or image uri + str for the aibs informatics aws lambda + batch_job_queue (JobQueue|str): Default batch job queue or job queue name str that + the batch job will be submitted to. This can be override by the payload. + primary_bucket (Bucket): Primary bucket used for request/response json blobs used in + the batch invoked lambda function. + mount_point_configs (Optional[Iterable[MountPointConfiguration]], optional): + List of mount point configurations to use. These can be overridden in the payload. 
+ + """ + super().__init__(scope, id, env_base) + + aibs_informatics_image_uri = ( + aibs_informatics_docker_asset + if isinstance(aibs_informatics_docker_asset, str) + else aibs_informatics_docker_asset.image_uri + ) + + self.batch_job_queue_name = ( + batch_job_queue if isinstance(batch_job_queue, str) else batch_job_queue.job_queue_name + ) + + start = sfn.Pass( + self, + "Input Restructure", + parameters={ + "handler": "aibs_informatics_aws_lambda.handlers.data_sync.data_sync_handler", + "image": aibs_informatics_image_uri, + "payload": sfn.JsonPath.object_at("$"), + }, + ) + + self.fragment = BatchInvokedLambdaFunction.with_defaults( + self, + "Data Sync", + env_base=self.env_base, + name="data-sync", + job_queue=self.batch_job_queue_name, + bucket_name=scaffolding_bucket.bucket_name, + handler_path="$.handler", + image_path="$.image", + payload_path="$.payload", + memory="1024", + vcpus="1", + mount_point_configs=list(mount_point_configs) if mount_point_configs else None, + environment={ + EnvBase.ENV_BASE_KEY: self.env_base, + "AWS_REGION": self.aws_region, + "AWS_ACCOUNT_ID": self.aws_account, + }, + ) + + self.definition = start.next(self.fragment.to_single_state()) + + @property + def required_managed_policies(self) -> List[Union[iam.IManagedPolicy, str]]: + return [ + *super().required_managed_policies, + *[_ for _ in self.fragment.required_managed_policies], + ] + + @property + def required_inline_policy_statements(self) -> List[iam.PolicyStatement]: + return [ + *self.fragment.required_inline_policy_statements, + *super().required_inline_policy_statements, + sfn_policy_statement( + self.env_base, + actions=SFN_STATES_EXECUTION_ACTIONS + SFN_STATES_READ_ACCESS_ACTIONS, + ), + ] + + +class DemandExecutionFragment(EnvBaseStateMachineFragment, EnvBaseConstructMixins): + def __init__( + self, + scope: constructs.Construct, + id: str, + env_base: EnvBase, + aibs_informatics_docker_asset: Union[ecr_assets.DockerImageAsset, str], + scaffolding_bucket: s3.Bucket, + scaffolding_job_queue: Union[batch.JobQueue, str], + batch_invoked_lambda_state_machine: sfn.StateMachine, + data_sync_state_machine: sfn.StateMachine, + shared_mount_point_config: Optional[MountPointConfiguration], + scratch_mount_point_config: Optional[MountPointConfiguration], + ) -> None: + super().__init__(scope, id, env_base) + + # ----------------- Validation ----------------- + if not (shared_mount_point_config and scratch_mount_point_config) or not ( + shared_mount_point_config or scratch_mount_point_config + ): + raise ValueError( + "If shared or scratch mount point configurations are provided," + "Both shared and scratch mount point configurations must be provided." 
+ ) + + # ------------------- Setup ------------------- + + config_scaffolding_path = "config.scaffolding" + config_setup_results_path = f"{config_scaffolding_path}.setup_results" + config_batch_args_path = f"{config_setup_results_path}.batch_args" + + config_cleanup_results_path = f"tasks.cleanup.cleanup_results" + + # Create common kwargs for the batch invoked lambda functions + # - specify the bucket name and job queue + # - specify the mount points and volumes if provided + batch_invoked_lambda_kwargs: dict[str, Any] = { + "bucket_name": scaffolding_bucket.bucket_name, + "image": aibs_informatics_docker_asset + if isinstance(aibs_informatics_docker_asset, str) + else aibs_informatics_docker_asset.image_uri, + "job_queue": scaffolding_job_queue + if isinstance(scaffolding_job_queue, str) + else scaffolding_job_queue.job_queue_name, + } + + # Create request input for the demand scaffolding + file_system_configurations = {} + + # Update arguments with mount points and volumes if provided + if shared_mount_point_config or scratch_mount_point_config: + mount_points = [] + volumes = [] + if shared_mount_point_config: + # update file system configurations for scaffolding function + file_system_configurations["shared"] = { + "file_system": shared_mount_point_config.file_system_id, + "access_point": shared_mount_point_config.access_point_id, + "container_path": shared_mount_point_config.mount_point, + } + # add to mount point and volumes list for batch invoked lambda functions + mount_points.append( + shared_mount_point_config.to_batch_mount_point("shared", sfn_format=True) + ) + volumes.append( + shared_mount_point_config.to_batch_volume("shared", sfn_format=True) + ) + + if scratch_mount_point_config: + # update file system configurations for scaffolding function + file_system_configurations["scratch"] = { + "file_system": scratch_mount_point_config.file_system_id, + "access_point": scratch_mount_point_config.access_point_id, + "container_path": scratch_mount_point_config.mount_point, + } + # add to mount point and volumes list for batch invoked lambda functions + mount_points.append( + scratch_mount_point_config.to_batch_mount_point("scratch", sfn_format=True) + ) + volumes.append( + scratch_mount_point_config.to_batch_volume("scratch", sfn_format=True) + ) + + batch_invoked_lambda_kwargs["mount_points"] = mount_points + batch_invoked_lambda_kwargs["volumes"] = volumes + + start_state = sfn.Pass( + self, + f"Start Demand Batch Task", + parameters={ + "request": { + "demand_execution": sfn.JsonPath.object_at("$"), + "file_system_configurations": file_system_configurations, + } + }, + ) + + prep_scaffolding_task = CommonOperation.enclose_chainable( + self, + "Prepare Demand Scaffolding", + sfn.Pass( + self, + "Pass: Prepare Demand Scaffolding", + parameters={ + "handler": "aibs_informatics_aws_lambda.handlers.demand.scaffolding.handler", + "payload": sfn.JsonPath.object_at("$"), + **batch_invoked_lambda_kwargs, + }, + ).next( + sfn_tasks.StepFunctionsStartExecution( + self, + "SM: Prepare Demand Scaffolding", + state_machine=batch_invoked_lambda_state_machine, + integration_pattern=sfn.IntegrationPattern.RUN_JOB, + associate_with_parent=False, + input_path="$", + output_path=f"$.Output", + ) + ), + input_path="$.request", + result_path=f"$.{config_scaffolding_path}", + ) + + create_def_and_prepare_job_args_task = CommonOperation.enclose_chainable( + self, + "Create Definition and Prep Job Args", + sfn.Pass( + self, + "Pass: Create Definition and Prep Job Args", + parameters={ + "handler": 
"aibs_informatics_aws_lambda.handlers.batch.create.handler", + "payload": sfn.JsonPath.object_at("$"), + **batch_invoked_lambda_kwargs, + }, + ).next( + sfn_tasks.StepFunctionsStartExecution( + self, + "SM: Create Definition and Prep Job Args", + state_machine=batch_invoked_lambda_state_machine, + integration_pattern=sfn.IntegrationPattern.RUN_JOB, + associate_with_parent=False, + input_path="$", + output_path=f"$.Output", + ) + ), + input_path="$.batch_create_request", + result_path=f"$", + ) + + setup_tasks = ( + sfn.Parallel( + self, + "Execution Setup Steps", + input_path=f"$.{config_scaffolding_path}.setup_configs", + result_path=f"$.{'.'.join(config_batch_args_path.split('.')[:-1])}", + result_selector={f'{config_batch_args_path.split(".")[-1]}.$': "$[0]"}, + ) + .branch(create_def_and_prepare_job_args_task) + .branch( + sfn.Map( + self, + "Transfer Inputs TO Batch Job", + items_path="$.data_sync_requests", + ).iterator( + sfn_tasks.StepFunctionsStartExecution( + self, + "Transfer Input", + state_machine=data_sync_state_machine, + integration_pattern=sfn.IntegrationPattern.RUN_JOB, + associate_with_parent=False, + result_path=sfn.JsonPath.DISCARD, + ) + ) + ) + ) + + execution_task = sfn.CustomState( + self, + f"Submit Batch Job", + state_json={ + "Type": "Task", + "Resource": "arn:aws:states:::batch:submitJob.sync", + # fmt: off + "Parameters": { + "JobName.$": sfn.JsonPath.string_at(f"$.{config_batch_args_path}.job_name"), + "JobDefinition.$": sfn.JsonPath.string_at(f"$.{config_batch_args_path}.job_definition_arn"), + "JobQueue.$": sfn.JsonPath.string_at(f"$.{config_batch_args_path}.job_queue_arn"), + "Parameters.$": sfn.JsonPath.object_at(f"$.{config_batch_args_path}.parameters"), + "ContainerOverrides.$": sfn.JsonPath.object_at(f"$.{config_batch_args_path}.container_overrides"), + }, + # fmt: on + "ResultPath": "$.tasks.batch_submit_task", + }, + ) + + cleanup_tasks = sfn.Chain.start( + sfn.Map( + self, + "Transfer Results FROM Batch Job", + input_path=f"$.{config_scaffolding_path}.cleanup_configs.data_sync_requests", + result_path=f"$.{config_cleanup_results_path}.transfer_results", + ).iterator( + sfn_tasks.StepFunctionsStartExecution( + self, + "Transfer Result", + state_machine=data_sync_state_machine, + integration_pattern=sfn.IntegrationPattern.RUN_JOB, + associate_with_parent=False, + result_path=sfn.JsonPath.DISCARD, + ) + ) + ).to_single_state("Execution Cleanup Steps", output_path="$[0]") + + # fmt: off + definition = ( + start_state + .next(prep_scaffolding_task) + .next(setup_tasks) + .next(execution_task) + .next(cleanup_tasks) + ) + # fmt: on + self.definition = definition + + @property + def required_inline_policy_statements(self) -> List[iam.PolicyStatement]: + return [ + *super().required_inline_policy_statements, + batch_policy_statement(self.env_base), + s3_policy_statement(self.env_base), + sfn_policy_statement( + self.env_base, + actions=SFN_STATES_EXECUTION_ACTIONS + SFN_STATES_READ_ACCESS_ACTIONS, + ), + ] diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/lambda_.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/lambda_.py index 08575f6..a949384 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/lambda_.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/lambda_.py @@ -1,21 +1,12 @@ -from typing import TYPE_CHECKING, List, Mapping, Optional, Union, cast +from typing import cast import constructs -from aibs_informatics_aws_utils.constants.lambda_ import ( - AWS_LAMBDA_EVENT_PAYLOAD_KEY, - 
AWS_LAMBDA_EVENT_RESPONSE_LOCATION_KEY, - AWS_LAMBDA_FUNCTION_HANDLER_KEY, - AWS_LAMBDA_FUNCTION_NAME_KEY, -) -from aibs_informatics_aws_utils.constants.s3 import S3_SCRATCH_KEY_PREFIX from aibs_informatics_core.env import EnvBase -from aws_cdk import aws_ecr_assets as ecr_assets from aws_cdk import aws_lambda as lambda_ from aws_cdk import aws_stepfunctions as sfn from aws_cdk import aws_stepfunctions_tasks as stepfn_tasks from aibs_informatics_cdk_lib.constructs_.sfn.fragments.base import EnvBaseStateMachineFragment -from aibs_informatics_cdk_lib.constructs_.sfn.states.s3 import S3Operation class LambdaFunctionFragment(EnvBaseStateMachineFragment): diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/states/common.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/states/common.py new file mode 100644 index 0000000..ed6779c --- /dev/null +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/states/common.py @@ -0,0 +1,116 @@ +from typing import Any, Optional + +import constructs +from aws_cdk import aws_stepfunctions as sfn + + +class CommonOperation: + @classmethod + def merge_defaults( + cls, + scope: constructs.Construct, + id: str, + defaults: dict[str, Any], + input_path: str = "$", + result_path: Optional[str] = None, + ) -> sfn.Chain: + """Wrapper chain that merges input with defaults. + + + + Args: + scope (constructs.Construct): construct scope + id (str): identifier for the states created + defaults (dict[str, Any]): default values to merge with input + input_path (str, optional): Input path of object to merge. Defaults to "$". + result_path (Optional[str], optional): result path to store merged results. + Defaults to whatever input_path is defined as. + + Returns: + sfn.Chain: the new chain that merges defaults with input + """ + new_input_path = result_path if result_path is not None else input_path + init_state = sfn.Pass( + scope, + "Merge Defaults", + parameters={ + "input": sfn.JsonPath.object_at("$"), + "default": defaults, + }, + ) + merge = sfn.Pass( + scope, + "Merge", + parameters={ + "merged": sfn.JsonPath.json_merge( + sfn.JsonPath.object_at("$.default"), sfn.JsonPath.object_at("$.input") + ), + }, + output_path="$.merged", + ) + + parallel = init_state.next(merge).to_single_state( + id=id, input_path=input_path, result_path=new_input_path + ) + restructure = sfn.Pass( + scope, + f"{id} Restructure", + input_path=f"{new_input_path}[0]", + result_path=new_input_path, + ) + return parallel.next(restructure) + + @classmethod + def enclose_chainable( + cls, + scope: constructs.Construct, + id: str, + definition: sfn.IChainable, + input_path: Optional[str] = None, + result_path: Optional[str] = None, + ) -> sfn.Chain: + """Enclose the current state machine fragment within a parallel state. + + Notes: + - If input_path is not provided, it will default to "$" + - If result_path is not provided, it will default to input_path + + Args: + id (str): an identifier for the parallel state + input_path (Optional[str], optional): input path for the enclosed state. + Defaults to "$". + result_path (Optional[str], optional): result path to put output of enclosed state. + Defaults to same as input_path. 
+ + Returns: + sfn.Chain: the new state machine fragment + """ + if input_path is None: + input_path = "$" + if result_path is None: + result_path = input_path + + chain = ( + sfn.Chain.start(definition) + if not isinstance(definition, (sfn.Chain, sfn.StateMachineFragment)) + else definition + ) + + if isinstance(chain, sfn.Chain): + parallel = chain.to_single_state( + id=f"{id} Enclosure", input_path=input_path, result_path=result_path + ) + else: + parallel = chain.to_single_state(input_path=input_path, result_path=result_path) + definition = sfn.Chain.start(parallel) + + if result_path and result_path != sfn.JsonPath.DISCARD: + restructure = sfn.Pass( + scope, + f"{id} Enclosure Post", + input_path=f"{result_path}[0]", + result_path=result_path, + ) + definition = definition.next(restructure) + + return definition diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/states/s3.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/states/s3.py index 283dfa4..6cdaa48 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/states/s3.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/states/s3.py @@ -3,10 +3,7 @@ import constructs from aws_cdk import aws_stepfunctions as sfn -from aibs_informatics_cdk_lib.constructs_.sfn.utils import ( - convert_reference_paths, - convert_reference_paths_in_mapping, -) +from aibs_informatics_cdk_lib.constructs_.sfn.utils import convert_reference_paths_in_mapping class S3Operation: diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/utils.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/utils.py index 53d64b1..03c8d85 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/utils.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/utils.py @@ -1,16 +1,37 @@ import re from functools import reduce -from typing import Any, ClassVar, List, Mapping, Pattern, Union, cast +from typing import Any, ClassVar, Dict, List, Mapping, Optional, Pattern, TypeVar, Union, cast import aws_cdk as cdk +import constructs from aibs_informatics_core.utils.json import JSON +from aibs_informatics_core.utils.tools.dicttools import convert_key_case +from aibs_informatics_core.utils.tools.strtools import pascalcase from aws_cdk import aws_stepfunctions as sfn +T = TypeVar("T") + def convert_reference_paths_in_mapping(parameters: Mapping[str, Any]) -> Mapping[str, Any]: return {k: convert_reference_paths(v) for k, v in parameters.items()} +def convert_to_sfn_api_action_case(parameters: T) -> T: + """Converts a dictionary of parameters to the format expected by the Step Functions for service integration. + + Even if a native API specifies a parameter in camelCase, the Step Functions SDK expects it in pascal case. 
+ + https://docs.aws.amazon.com/step-functions/latest/dg/supported-services-awssdk.html#use-awssdk-integ + + Args: + parameters (Dict[str, Any]): parameters for SDK action + + Returns: + Dict[str, Any]: parameters for SDK action in pascal case + """ + return convert_key_case(parameters, pascalcase) + + def convert_reference_paths(parameters: JSON) -> JSON: if isinstance(parameters, dict): return {k: convert_reference_paths(v) for k, v in parameters.items()} @@ -28,15 +49,57 @@ def convert_reference_paths(parameters: JSON) -> JSON: def enclosed_chain( - id: str, chain: sfn.Chain, result_path: str = "$", output_path: str = "$" + scope: constructs.Construct, + id: str, + definition: sfn.IChainable, + input_path: Optional[str] = None, + result_path: Optional[str] = None, ) -> sfn.Chain: - pre = sfn.Pass(id + " Parallel Prep", result_path=result_path, output_path=output_path) # type: ignore - parallel = chain.to_single_state( - id=f"{id} Parallel", input_path=result_path, output_path="$[0]" + """Enclose the current state machine fragment within a parallel state. + + Notes: + - If input_path is not provided, it will default to "$" + - If result_path is not provided, it will default to input_path + + Args: + id (str): an identifier for the parallel state + input_path (Optional[str], optional): input path for the enclosed state. + Defaults to "$". + result_path (Optional[str], optional): result path to put output of enclosed state. + Defaults to same as input_path. + + Returns: + sfn.Chain: the new state machine fragment + """ + if input_path is None: + input_path = "$" + if result_path is None: + result_path = input_path + + chain = ( + sfn.Chain.start(definition) + if not isinstance(definition, (sfn.Chain, sfn.StateMachineFragment)) + else definition ) - post = sfn.Pass(id + " Parallel Post", result_path=result_path, output_path=output_path) # type: ignore - return sfn.Chain.start(pre).next(parallel).next(post) + if isinstance(chain, sfn.Chain): + parallel = chain.to_single_state( + id=f"{id} Enclosure", input_path=input_path, result_path=result_path + ) + else: + parallel = chain.to_single_state(input_path=input_path, result_path=result_path) + definition = sfn.Chain.start(parallel) + + if result_path and result_path != sfn.JsonPath.DISCARD: + restructure = sfn.Pass( + scope, + f"{id} Enclosure Post", + input_path=f"{result_path}[0]", + result_path=result_path, + ) + definition = definition.next(restructure) + + return definition class JsonReferencePath(str): diff --git a/src/aibs_informatics_cdk_lib/stacks/data_sync.py b/src/aibs_informatics_cdk_lib/stacks/data_sync.py deleted file mode 100644 index c3d7e8c..0000000 --- a/src/aibs_informatics_cdk_lib/stacks/data_sync.py +++ /dev/null @@ -1,474 +0,0 @@ -from pathlib import Path -from typing import TYPE_CHECKING, List, Mapping, Optional, Union, cast - -import aws_cdk as cdk -import constructs -from aibs_informatics_aws_utils.batch import to_mount_point, to_volume -from aibs_informatics_aws_utils.constants.efs import EFS_MOUNT_POINT_PATH_VAR -from aibs_informatics_core.env import EnvBase, ResourceNameBaseEnum -from aws_cdk import aws_batch_alpha as batch -from aws_cdk import aws_ec2 as ec2 -from aws_cdk import aws_ecr_assets as ecr_assets -from aws_cdk import aws_efs as efs -from aws_cdk import aws_iam as iam -from aws_cdk import aws_lambda as lambda_ -from aws_cdk import aws_logs as logs -from aws_cdk import aws_s3 as s3 -from aws_cdk import aws_stepfunctions as sfn -from aws_cdk import aws_stepfunctions_tasks as stepfn_tasks - -from 
aibs_informatics_cdk_lib.common.aws.iam_utils import ( - batch_policy_statement, - dynamodb_policy_statement, - lambda_policy_statement, - s3_policy_statement, -) -from aibs_informatics_cdk_lib.constructs_.assets.code_asset import ( - GLOBAL_GLOB_EXCLUDES, - PYTHON_GLOB_EXCLUDES, - CodeAsset, -) -from aibs_informatics_cdk_lib.constructs_.batch.types import IBatchEnvironmentDescriptor -from aibs_informatics_cdk_lib.constructs_.ec2.network import EnvBaseVpc -from aibs_informatics_cdk_lib.constructs_.efs.file_system import ( - EnvBaseFileSystem, - create_access_point, - grant_file_system_access, -) -from aibs_informatics_cdk_lib.constructs_.s3 import EnvBaseBucket, grant_bucket_access -from aibs_informatics_cdk_lib.constructs_.sfn.fragments.batch import SubmitJobFragment -from aibs_informatics_cdk_lib.constructs_.sfn.fragments.informatics import ( - BatchInvokedLambdaFunction, -) -from aibs_informatics_cdk_lib.constructs_.sfn.states.s3 import S3Operation -from aibs_informatics_cdk_lib.stacks.base import EnvBaseStack - -if TYPE_CHECKING: - from mypy_boto3_batch.type_defs import ( - KeyValuePairTypeDef, - MountPointTypeDef, - RegisterJobDefinitionRequestRequestTypeDef, - ResourceRequirementTypeDef, - VolumeTypeDef, - ) -else: - ResourceRequirementTypeDef = dict - MountPointTypeDef = dict - VolumeTypeDef = dict - KeyValuePairTypeDef = dict - RegisterJobDefinitionRequestRequestTypeDef = dict - - -DATA_SYNC_ASSET_NAME = "aibs_informatics_aws_lambda" - -EFS_MOUNT_PATH = "/opt/efs" -EFS_VOLUME_NAME = "efs-file-system" - - -class DataSyncFunctions(ResourceNameBaseEnum): - PUT_JSON_TO_FILE = "put-json-to-file" - GET_JSON_FROM_FILE = "get-json-from-file" - DATA_SYNC = "data-sync" - BATCH_DATA_SYNC = "batch-data-sync" - PREPARE_BATCH_DATA_SYNC = "prep-batch-data-sync" - - -class DataSyncStack(EnvBaseStack): - def __init__( - self, - scope: constructs.Construct, - id: str, - env_base: EnvBase, - asset_directory: Union[Path, str], - vpc: ec2.Vpc, - primary_bucket: s3.Bucket, - file_system: efs.FileSystem, - batch_job_queue: batch.JobQueue, - s3_buckets: List[s3.Bucket], - **kwargs, - ): - super().__init__(scope, id, env_base, **kwargs) - - self._vpc = vpc - if self._vpc is None: - self._vpc = EnvBaseVpc(self, "vpc", env_base=self.env_base) - - self._primary_bucket = primary_bucket - self._buckets = [primary_bucket, *s3_buckets] - - self._file_system = file_system - - if isinstance(self._file_system, EnvBaseFileSystem): - self._lambda_file_system = self._file_system.as_lambda_file_system() - else: - root_ap = create_access_point(self, self._file_system, "data-sync", "/") - self._lambda_file_system = lambda_.FileSystem.from_efs_access_point( - root_ap, "/mnt/efs" - ) - - self._job_queue = batch_job_queue - - # Create code and docker asset for AWS lambda package - asset_directory = Path(asset_directory) - self._code_asset = CodeAsset.create_py_code_asset( - path=asset_directory, - context_path=None, - runtime=lambda_.Runtime.PYTHON_3_11, - environment={self.env_base.ENV_BASE_KEY: self.env_base}, - ) - # self._docker_asset = ecr_assets.DockerImageAsset( - # self, - # "docker", - # asset_name=DATA_SYNC_ASSET_NAME, - # directory=asset_directory.resolve().as_posix(), - # file="docker/Dockerfile", - # build_ssh="default", - # platform=ecr_assets.Platform.LINUX_AMD64, - # exclude=[*GLOBAL_GLOB_EXCLUDES, *PYTHON_GLOB_EXCLUDES] - # ) - - self.create_lambda_functions() - self.create_step_functions() - - @property - def buckets(self) -> List[s3.Bucket]: - return self._buckets - - @property - def file_system(self) -> 
efs.FileSystem: - return self._file_system - - @property - def lambda_file_system(self) -> lambda_.FileSystem: - return self._lambda_file_system - - @property - def vpc(self) -> ec2.Vpc: - return self._vpc - - # @property - # def docker_image_asset(self) -> ecr_assets.DockerImageAsset: - # return self._docker_asset - - @property - def code_asset(self) -> CodeAsset: - return self._code_asset - - def create_lambda_functions(self) -> None: - # ---------------------------------------------------------- - # Data Transfer functions - # ---------------------------------------------------------- - - data_sync_efs_lambda_arguments = dict( - filesystem=self.lambda_file_system, - vpc=self.vpc, - ) - - # self.put_json_to_file_fn: lambda_.Function = lambda_.Function( - # self, - # self.get_construct_id(DataSyncFunctions.PUT_JSON_TO_FILE), - # description=f"Lambda to put content to EFS/S3 [{self.env_base}]", - # function_name=self.get_resource_name(DataSyncFunctions.PUT_JSON_TO_FILE), - # handler="aibs_informatics_aws_lambda.handlers.data_sync.put_json_to_file_handler", - # runtime=self.code_asset.default_runtime, - # code=self.code_asset.as_code, - # memory_size=128, - # timeout=cdk.Duration.seconds(30), - # environment=self.code_asset.environment, - # **data_sync_efs_lambda_arguments, - # ) - # grant_bucket_access(self.buckets, self.put_json_to_file_fn.role, "rw") - # grant_file_system_access(self.file_system, self.put_json_to_file_fn) - - # self.get_json_from_file_fn: lambda_.Function = lambda_.Function( - # self, - # self.get_construct_id(DataSyncFunctions.GET_JSON_FROM_FILE), - # description=f"Lambda to get content from EFS/S3 [{self.env_base}]", - # function_name=self.get_resource_name(DataSyncFunctions.GET_JSON_FROM_FILE), - # handler="aibs_informatics_aws_lambda.handlers.data_sync.get_json_from_file_handler", - # runtime=self.code_asset.default_runtime, - # code=self.code_asset.as_code, - # memory_size=128, - # timeout=cdk.Duration.seconds(30), - # environment=self.code_asset.environment, - # **data_sync_efs_lambda_arguments, - # ) - # grant_bucket_access(self.buckets, self.get_json_from_file_fn.role, "rw") - # grant_file_system_access(self.file_system, self.get_json_from_file_fn) - - self.data_sync_fn = lambda_.Function( - self, - self.get_construct_id(DataSyncFunctions.DATA_SYNC), - description=f"Lambda to transfer data between Local (e.g. EFS) and remote (e.g. S3) [{self.env_base}]", - function_name=self.get_resource_name(DataSyncFunctions.DATA_SYNC), - handler="aibs_informatics_aws_lambda.handlers.data_sync.data_sync_handler", - runtime=self.code_asset.default_runtime, - code=self.code_asset.as_code, - memory_size=10240, - timeout=cdk.Duration.minutes(15), - environment=self.code_asset.environment, - **data_sync_efs_lambda_arguments, - ) - grant_bucket_access(self.buckets, self.data_sync_fn.role, "rw") - grant_file_system_access(self.file_system, self.data_sync_fn) - - self.batch_data_sync_fn_handler = ( - "aibs_informatics_aws_lambda.handlers.common.data_sync.batch_data_sync_handler" - ) - self.batch_data_sync_fn = lambda_.Function( - self, - self.get_construct_id(DataSyncFunctions.BATCH_DATA_SYNC), - description=f"Lambda to sync data between local (e.g. EFS) and remote (e.g. 
S3) in batch [{self.env_base}]", - function_name=self.get_resource_name(DataSyncFunctions.BATCH_DATA_SYNC), - handler=self.batch_data_sync_fn_handler, - runtime=self.code_asset.default_runtime, - code=self.code_asset.as_code, - memory_size=10240, - timeout=cdk.Duration.minutes(15), - environment=self.code_asset.environment, - **data_sync_efs_lambda_arguments, - ) - grant_bucket_access(self.buckets, self.batch_data_sync_fn.role, "rw") - grant_file_system_access(self.file_system, self.batch_data_sync_fn) - - self.prepare_batch_data_sync_fn = lambda_.Function( - self, - self.get_construct_id(DataSyncFunctions.PREPARE_BATCH_DATA_SYNC), - description=f"Lambda to prepare batch data transfer [{self.env_base}]", - function_name=self.get_resource_name(DataSyncFunctions.PREPARE_BATCH_DATA_SYNC), - handler="aibs_informatics_aws_lambda.handlers.data_sync.prep_batch_data_sync_handler", - runtime=self.code_asset.default_runtime, - code=self.code_asset.as_code, - memory_size=1024, - timeout=cdk.Duration.minutes(10), - environment=self.code_asset.environment, - **data_sync_efs_lambda_arguments, - ) - grant_bucket_access(self.buckets, self.prepare_batch_data_sync_fn.role, "rw") - grant_file_system_access(self.file_system, self.prepare_batch_data_sync_fn) - # allow read access to S3 - self.add_managed_policies(self.prepare_batch_data_sync_fn.role, "AmazonS3ReadOnlyAccess") - - def create_step_functions(self): - start_pass_state = sfn.Pass( - self, - "Data Sync: Start", - parameters={ - "request": sfn.JsonPath.string_at("$"), - }, - ) - prep_batch_sync_task_name = "prep-batch-data-sync-requests" - prep_batch_sync_lambda_task = stepfn_tasks.LambdaInvoke( - self, - "Prep Batch Data Sync", - lambda_function=self.prepare_batch_data_sync_fn, - payload=sfn.TaskInput.from_json_path_at("$.request"), - payload_response_only=False, - result_path=f"$.tasks.{prep_batch_sync_task_name}.response", - ) - - batch_sync_map_state = sfn.Map( - self, - "Batch Data Sync: Map Start", - comment="Runs requests for batch sync in parallel", - items_path=f"$.tasks.{prep_batch_sync_task_name}.response.Payload.requests", - result_path=sfn.JsonPath.DISCARD, - ) - batch_sync_map_state.iterator( - BatchInvokedLambdaFunction( - self, - "Batch Data Sync Chain", - env_base=self.env_base, - image=f"{self.account}.dkr.ecr.{self.region}.amazonaws.com/aibs-informatics-aws-lambda:latest", - name="batch-data-sync", - handler=self.batch_data_sync_fn_handler, - bucket_name=self._primary_bucket.bucket_name, - job_queue=self._job_queue.job_queue_name, - environment={EFS_MOUNT_POINT_PATH_VAR: EFS_MOUNT_PATH}, - memory=2048, - vcpus=1, - mount_points=[to_mount_point(EFS_MOUNT_PATH, False, EFS_VOLUME_NAME)], - volumes=[ - to_volume( - None, - EFS_VOLUME_NAME, - { - "fileSystemId": self.file_system.file_system_id, - "rootDirectory": "/", - }, - ) - ], - platform_capabilities=( - ["FARGATE"] - if any( - [ - isinstance(ce.compute_environment, batch.FargateComputeEnvironment) - for ce in self._job_queue.compute_environments - ] - ) - else None - ), - ) - ) - # fmt: off - definition = ( - start_pass_state - .next(prep_batch_sync_lambda_task) - .next(batch_sync_map_state) - ) - # fmt: on - - data_sync_state_machine_name = self.get_resource_name("data-sync") - self.data_sync_state_machine = sfn.StateMachine( - self, - self.env_base.get_construct_id("data-sync", "state-machine"), - state_machine_name=data_sync_state_machine_name, - logs=sfn.LogOptions( - destination=logs.LogGroup( - self, - self.get_construct_id(data_sync_state_machine_name, "state-loggroup"), - 
log_group_name=self.env_base.get_state_machine_log_group_name("data-sync"), - removal_policy=cdk.RemovalPolicy.DESTROY, - retention=logs.RetentionDays.ONE_MONTH, - ) - ), - role=iam.Role( - self, - self.env_base.get_construct_id(data_sync_state_machine_name, "role"), - assumed_by=iam.ServicePrincipal("states.amazonaws.com"), - managed_policies=[ - iam.ManagedPolicy.from_aws_managed_policy_name( - "AmazonAPIGatewayInvokeFullAccess" - ), - iam.ManagedPolicy.from_aws_managed_policy_name("AWSStepFunctionsFullAccess"), - iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchLogsFullAccess"), - iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchEventsFullAccess"), - ], - inline_policies={ - "default": iam.PolicyDocument( - statements=[ - batch_policy_statement(self.env_base), - dynamodb_policy_statement(self.env_base), - s3_policy_statement(self.env_base), - lambda_policy_statement(self.env_base), - ] - ), - }, - ), - definition_body=sfn.DefinitionBody.from_chainable(definition), - timeout=cdk.Duration.hours(12), - ) - - def create_rclone_step_function(self): - start_pass_state = sfn.Pass( - self, - "Rclone: Start", - parameters={ - "request": sfn.JsonPath.string_at("$"), - }, - ) - default_args = { - "command": "sync", - } - default_rclone_image = "rclone/rclone:latest" - prep_batch_sync_task_name = "prep-batch-data-sync-requests" - prep_batch_sync_lambda_task = stepfn_tasks.LambdaInvoke( - self, - "Prep Batch Data Sync", - lambda_function=self.prepare_batch_data_sync_fn, - payload=sfn.TaskInput.from_json_path_at("$.request"), - payload_response_only=False, - result_path=f"$.tasks.{prep_batch_sync_task_name}.response", - ) - - batch_sync_map_state = sfn.Map( - self, - "Batch Data Sync: Map Start", - comment="Runs requests for batch sync in parallel", - items_path=f"$.tasks.{prep_batch_sync_task_name}.response.Payload.requests", - result_path=sfn.JsonPath.DISCARD, - ) - batch_sync_map_state.iterator( - BatchInvokedLambdaFunction( - self, - "Batch Data Sync Chain", - env_base=self.env_base, - image=f"{self.account}.dkr.ecr.{self.region}.amazonaws.com/aibs-informatics-aws-lambda:latest", - name="batch-data-sync", - handler=self.batch_data_sync_fn_handler, - bucket_name=self._primary_bucket.bucket_name, - job_queue=self._job_queue.job_queue_name, - environment={EFS_MOUNT_POINT_PATH_VAR: EFS_MOUNT_PATH}, - memory=2048, - vcpus=1, - mount_points=[to_mount_point(EFS_MOUNT_PATH, False, EFS_VOLUME_NAME)], - volumes=[ - to_volume( - None, - EFS_VOLUME_NAME, - { - "fileSystemId": self.file_system.file_system_id, - "rootDirectory": "/", - }, - ) - ], - platform_capabilities=( - ["FARGATE"] - if any( - [ - isinstance(ce.compute_environment, batch.FargateComputeEnvironment) - for ce in self._job_queue.compute_environments - ] - ) - else None - ), - ) - ) - # fmt: off - definition = ( - start_pass_state - .next(prep_batch_sync_lambda_task) - .next(batch_sync_map_state) - ) - # fmt: on - - data_sync_state_machine_name = self.get_resource_name("data-sync") - self.data_sync_state_machine = sfn.StateMachine( - self, - self.env_base.get_construct_id("data-sync", "state-machine"), - state_machine_name=data_sync_state_machine_name, - logs=sfn.LogOptions( - destination=logs.LogGroup( - self, - self.get_construct_id(data_sync_state_machine_name, "state-loggroup"), - log_group_name=self.env_base.get_state_machine_log_group_name("data-sync"), - removal_policy=cdk.RemovalPolicy.DESTROY, - retention=logs.RetentionDays.ONE_MONTH, - ) - ), - role=iam.Role( - self, - 
self.env_base.get_construct_id(data_sync_state_machine_name, "role"), - assumed_by=iam.ServicePrincipal("states.amazonaws.com"), - managed_policies=[ - iam.ManagedPolicy.from_aws_managed_policy_name( - "AmazonAPIGatewayInvokeFullAccess" - ), - iam.ManagedPolicy.from_aws_managed_policy_name("AWSStepFunctionsFullAccess"), - iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchLogsFullAccess"), - iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchEventsFullAccess"), - ], - inline_policies={ - "default": iam.PolicyDocument( - statements=[ - batch_policy_statement(self.env_base), - dynamodb_policy_statement(self.env_base), - s3_policy_statement(self.env_base), - lambda_policy_statement(self.env_base), - ] - ), - }, - ), - definition_body=sfn.DefinitionBody.from_chainable(definition), - timeout=cdk.Duration.hours(12), - ) diff --git a/src/aibs_informatics_cdk_lib/stacks/network.py b/src/aibs_informatics_cdk_lib/stacks/network.py deleted file mode 100644 index ebb97b2..0000000 --- a/src/aibs_informatics_cdk_lib/stacks/network.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Optional - -from aibs_informatics_core.env import EnvBase -from constructs import Construct - -from aibs_informatics_cdk_lib.constructs_.ec2 import EnvBaseVpc -from aibs_informatics_cdk_lib.stacks.base import EnvBaseStack - - -class NetworkStack(EnvBaseStack): - def __init__(self, scope: Construct, id: Optional[str], env_base: EnvBase, **kwargs) -> None: - super().__init__(scope, id, env_base, **kwargs) - self._vpc = EnvBaseVpc(self, "Vpc", self.env_base, max_azs=4) - - @property - def vpc(self) -> EnvBaseVpc: - return self._vpc diff --git a/src/aibs_informatics_cdk_lib/stacks/storage.py b/src/aibs_informatics_cdk_lib/stacks/storage.py deleted file mode 100644 index 3cdb36c..0000000 --- a/src/aibs_informatics_cdk_lib/stacks/storage.py +++ /dev/null @@ -1,96 +0,0 @@ -import json -from email.policy import default -from typing import Any, Iterable, List, Optional, TypeVar, Union -from urllib import request - -import aws_cdk as cdk -from aibs_informatics_aws_utils.batch import to_mount_point, to_volume -from aibs_informatics_aws_utils.constants.efs import ( - EFS_ROOT_ACCESS_POINT_TAG, - EFS_ROOT_PATH, - EFS_SCRATCH_ACCESS_POINT_TAG, - EFS_SCRATCH_PATH, - EFS_SHARED_ACCESS_POINT_TAG, - EFS_SHARED_PATH, - EFS_TMP_ACCESS_POINT_TAG, - EFS_TMP_PATH, -) -from aibs_informatics_core.env import EnvBase -from aibs_informatics_core.utils.tools.dicttools import convert_key_case -from aibs_informatics_core.utils.tools.strtools import pascalcase -from aws_cdk import aws_batch_alpha as batch -from aws_cdk import aws_ec2 as ec2 -from aws_cdk import aws_efs as efs -from aws_cdk import aws_iam as iam -from aws_cdk import aws_lambda as lambda_ -from aws_cdk import aws_logs as logs -from aws_cdk import aws_s3 as s3 -from aws_cdk import aws_stepfunctions as sfn -from constructs import Construct - -from aibs_informatics_cdk_lib.common.aws.iam_utils import batch_policy_statement -from aibs_informatics_cdk_lib.constructs_.assets.code_asset_definitions import ( - AIBSInformaticsCodeAssets, -) -from aibs_informatics_cdk_lib.constructs_.batch.infrastructure import Batch, BatchEnvironmentConfig -from aibs_informatics_cdk_lib.constructs_.batch.instance_types import ( - ON_DEMAND_INSTANCE_TYPES, - SPOT_INSTANCE_TYPES, -) -from aibs_informatics_cdk_lib.constructs_.batch.launch_template import BatchLaunchTemplateBuilder -from aibs_informatics_cdk_lib.constructs_.batch.types import BatchEnvironmentDescriptor -from 
aibs_informatics_cdk_lib.constructs_.ec2 import EnvBaseVpc -from aibs_informatics_cdk_lib.constructs_.efs.file_system import ( - EFSEcosystem, - EnvBaseFileSystem, - MountPointConfiguration, -) -from aibs_informatics_cdk_lib.constructs_.s3 import EnvBaseBucket, LifecycleRuleGenerator -from aibs_informatics_cdk_lib.constructs_.sfn.fragments.batch import ( - SubmitJobFragment, - SubmitJobWithDefaultsFragment, -) -from aibs_informatics_cdk_lib.stacks.base import EnvBaseStack - - -class StorageStack(EnvBaseStack): - def __init__( - self, - scope: Construct, - id: Optional[str], - env_base: EnvBase, - name: str, - vpc: ec2.Vpc, - **kwargs, - ) -> None: - super().__init__(scope, id, env_base, **kwargs) - - self._bucket = EnvBaseBucket( - self, - "Bucket", - self.env_base, - bucket_name=name, - removal_policy=self.removal_policy, - lifecycle_rules=[ - LifecycleRuleGenerator.expire_files_under_prefix(), - LifecycleRuleGenerator.expire_files_with_scratch_tags(), - LifecycleRuleGenerator.use_storage_class_as_default(), - ], - ) - - self._efs_ecosystem = EFSEcosystem( - self, id="EFS", env_base=self.env_base, file_system_name=name, vpc=vpc - ) - self._file_system = self._efs_ecosystem.file_system - - @property - def bucket(self) -> EnvBaseBucket: - return self._bucket - - @property - def efs_ecosystem(self) -> EFSEcosystem: - return self._efs_ecosystem - - @property - def file_system(self) -> EnvBaseFileSystem: - return self._file_system diff --git a/src/aibs_informatics_core_app/app.py b/src/aibs_informatics_core_app/app.py index c57f545..2eea2de 100644 --- a/src/aibs_informatics_core_app/app.py +++ b/src/aibs_informatics_core_app/app.py @@ -1,7 +1,4 @@ #!/usr/bin/env python3 -import logging -from pathlib import Path -from typing import List, Optional import aws_cdk as cdk from constructs import Construct @@ -9,15 +6,13 @@ from aibs_informatics_cdk_lib.constructs_.efs.file_system import MountPointConfiguration from aibs_informatics_cdk_lib.project.config import StageConfig from aibs_informatics_cdk_lib.project.utils import get_config -from aibs_informatics_cdk_lib.stacks.assets import AIBSInformaticsAssetsStack -from aibs_informatics_cdk_lib.stacks.compute import ( - ComputeStack, - ComputeWorkflowStack, - LambdaComputeStack, -) -from aibs_informatics_cdk_lib.stacks.network import NetworkStack -from aibs_informatics_cdk_lib.stacks.storage import StorageStack from aibs_informatics_cdk_lib.stages.base import ConfigBasedStage +from aibs_informatics_core_app.stacks.assets import AIBSInformaticsAssetsStack +from aibs_informatics_core_app.stacks.core import CoreStack +from aibs_informatics_core_app.stacks.demand_execution import ( + DemandExecutionInfrastructureStack, + DemandExecutionStack, +) class InfraStage(ConfigBasedStage): @@ -29,67 +24,36 @@ def __init__(self, scope: Construct, config: StageConfig, **kwargs) -> None: self.env_base, env=self.env, ) - network = NetworkStack( - self, - self.get_stack_name("Network"), - self.env_base, - env=self.env, - ) - storage = StorageStack( - self, - self.get_stack_name("Storage"), - self.env_base, - "core", - vpc=network.vpc, - env=self.env, - ) - - efs_ecosystem = storage.efs_ecosystem - - ap_mount_point_configs = [ - MountPointConfiguration.from_access_point( - efs_ecosystem.scratch_access_point, "/opt/scratch" - ), - MountPointConfiguration.from_access_point(efs_ecosystem.tmp_access_point, "/opt/tmp"), - MountPointConfiguration.from_access_point( - efs_ecosystem.shared_access_point, "/opt/shared", read_only=True - ), - ] - fs_mount_point_configs = [ - 
MountPointConfiguration.from_file_system(storage.file_system, None, "/opt/efs"), - ] - - compute = ComputeStack( + core = CoreStack( self, - self.get_stack_name("Compute"), + self.get_stack_name("Core"), self.env_base, - batch_name="batch", - vpc=network.vpc, - buckets=[storage.bucket], - file_systems=[storage.file_system], - mount_point_configs=fs_mount_point_configs, + name="core", env=self.env, ) - lambda_compute = LambdaComputeStack( + demand_execution_infra = DemandExecutionInfrastructureStack( self, - self.get_stack_name("LambdaCompute"), + self.get_stack_name("DemandExecutionInfra"), self.env_base, - batch_name="lambda-batch", - vpc=network.vpc, - buckets=[storage.bucket], - file_systems=[storage.file_system], - mount_point_configs=fs_mount_point_configs, + vpc=core.vpc, + buckets=[core.bucket], + mount_point_configs=[ + MountPointConfiguration.from_file_system(core.file_system, None, "/opt/efs"), + ], env=self.env, ) - compute_workflow = ComputeWorkflowStack( + demand_execution = DemandExecutionStack( self, - self.get_stack_name("ComputeWorkflow"), + self.get_stack_name("DemandExecution"), env_base=self.env_base, - batch_environment=lambda_compute.primary_batch_environment, - primary_bucket=storage.bucket, - mount_point_configs=fs_mount_point_configs, + assets=assets.assets, + scaffolding_bucket=core.bucket, + efs_ecosystem=core.efs_ecosystem, + data_sync_job_queue=demand_execution_infra.infra_compute.lambda_medium_batch_environment.job_queue_name, + scaffolding_job_queue=demand_execution_infra.infra_compute.primary_batch_environment.job_queue_name, + execution_job_queue=demand_execution_infra.execution_compute.primary_batch_environment.job_queue_name, env=self.env, ) diff --git a/src/aibs_informatics_core_app/stacks/__init__.py b/src/aibs_informatics_core_app/stacks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/aibs_informatics_cdk_lib/stacks/assets.py b/src/aibs_informatics_core_app/stacks/assets.py similarity index 73% rename from src/aibs_informatics_cdk_lib/stacks/assets.py rename to src/aibs_informatics_core_app/stacks/assets.py index a7a7eea..d9f6052 100644 --- a/src/aibs_informatics_cdk_lib/stacks/assets.py +++ b/src/aibs_informatics_core_app/stacks/assets.py @@ -4,6 +4,7 @@ from aibs_informatics_core.env import EnvBase from aibs_informatics_cdk_lib.constructs_.assets.code_asset_definitions import ( + AIBSInformaticsAssets, AIBSInformaticsCodeAssets, AIBSInformaticsDockerAssets, ) @@ -19,11 +20,11 @@ def __init__( **kwargs, ): super().__init__(scope, id, env_base, **kwargs) - self.code_assets = AIBSInformaticsCodeAssets( + + self.assets = AIBSInformaticsAssets( self, - "aibs-info-code-assets", + "aibs-info-assets", self.env_base, ) - self.docker_assets = AIBSInformaticsDockerAssets( - self, "aibs-info-docker-assets", self.env_base - ) + self.code_assets = self.assets.code_assets + self.docker_assets = self.assets.docker_assets diff --git a/src/aibs_informatics_core_app/stacks/core.py b/src/aibs_informatics_core_app/stacks/core.py new file mode 100644 index 0000000..d363f22 --- /dev/null +++ b/src/aibs_informatics_core_app/stacks/core.py @@ -0,0 +1,56 @@ +from typing import Optional + +from aibs_informatics_core.env import EnvBase +from constructs import Construct + +from aibs_informatics_cdk_lib.constructs_.ec2.network import EnvBaseVpc +from aibs_informatics_cdk_lib.constructs_.efs.file_system import EFSEcosystem, EnvBaseFileSystem +from aibs_informatics_cdk_lib.constructs_.s3 import EnvBaseBucket, LifecycleRuleGenerator +from 
aibs_informatics_cdk_lib.stacks.base import EnvBaseStack + + +class CoreStack(EnvBaseStack): + def __init__( + self, + scope: Construct, + id: Optional[str], + env_base: EnvBase, + name: str, + **kwargs, + ) -> None: + super().__init__(scope, id, env_base, **kwargs) + self._vpc = EnvBaseVpc(self, "Vpc", self.env_base, max_azs=4) + + self._bucket = EnvBaseBucket( + self, + "Bucket", + self.env_base, + bucket_name=name, + removal_policy=self.removal_policy, + lifecycle_rules=[ + LifecycleRuleGenerator.expire_files_under_prefix(), + LifecycleRuleGenerator.expire_files_with_scratch_tags(), + LifecycleRuleGenerator.use_storage_class_as_default(), + ], + ) + + self._efs_ecosystem = EFSEcosystem( + self, id="EFS", env_base=self.env_base, file_system_name=name, vpc=self.vpc + ) + self._file_system = self._efs_ecosystem.file_system + + @property + def vpc(self) -> EnvBaseVpc: + return self._vpc + + @property + def bucket(self) -> EnvBaseBucket: + return self._bucket + + @property + def efs_ecosystem(self) -> EFSEcosystem: + return self._efs_ecosystem + + @property + def file_system(self) -> EnvBaseFileSystem: + return self._file_system diff --git a/src/aibs_informatics_core_app/stacks/demand_execution.py b/src/aibs_informatics_core_app/stacks/demand_execution.py new file mode 100644 index 0000000..182bcd3 --- /dev/null +++ b/src/aibs_informatics_core_app/stacks/demand_execution.py @@ -0,0 +1,152 @@ +from typing import Iterable, Optional, Union + +import constructs +from aibs_informatics_core.env import EnvBase +from aws_cdk import aws_batch_alpha as batch +from aws_cdk import aws_ec2 as ec2 +from aws_cdk import aws_s3 as s3 + +from aibs_informatics_cdk_lib.common.aws.iam_utils import ( + BATCH_FULL_ACCESS_ACTIONS, + batch_policy_statement, +) +from aibs_informatics_cdk_lib.constructs_.assets.code_asset_definitions import ( + AIBSInformaticsAssets, +) +from aibs_informatics_cdk_lib.constructs_.efs.file_system import ( + EFSEcosystem, + MountPointConfiguration, +) +from aibs_informatics_cdk_lib.constructs_.service.compute import BatchCompute, LambdaCompute +from aibs_informatics_cdk_lib.constructs_.sfn.fragments.informatics import ( + BatchInvokedLambdaFunction, + DataSyncFragment, + DemandExecutionFragment, +) +from aibs_informatics_cdk_lib.stacks.base import EnvBaseStack + +DATA_SYNC_ASSET_NAME = "aibs_informatics_aws_lambda" + +EFS_MOUNT_PATH = "/opt/efs" +EFS_VOLUME_NAME = "efs-file-system" + + +class DemandExecutionInfrastructureStack(EnvBaseStack): + def __init__( + self, + scope: constructs.Construct, + id: Optional[str], + env_base: EnvBase, + vpc: ec2.Vpc, + buckets: Optional[Iterable[s3.Bucket]] = None, + mount_point_configs: Optional[Iterable[MountPointConfiguration]] = None, + **kwargs, + ) -> None: + super().__init__(scope, id, env_base, **kwargs) + self.execution_compute = BatchCompute( + self, + id="demand", + env_base=env_base, + batch_name="demand", + vpc=vpc, + buckets=buckets, + mount_point_configs=mount_point_configs, + ) + self.infra_compute = LambdaCompute( + self, + id="demand-infra", + env_base=env_base, + vpc=vpc, + batch_name="demand-infra", + buckets=buckets, + mount_point_configs=mount_point_configs, + instance_role_policy_statements=[ + batch_policy_statement(env_base=env_base, actions=BATCH_FULL_ACCESS_ACTIONS), + ], + ) + + +class DemandExecutionStack(EnvBaseStack): + def __init__( + self, + scope: constructs.Construct, + id: str, + env_base: EnvBase, + assets: AIBSInformaticsAssets, + scaffolding_bucket: s3.Bucket, + efs_ecosystem: EFSEcosystem, + 
scaffolding_job_queue: Union[batch.JobQueue, str], + data_sync_job_queue: Union[batch.JobQueue, str], + execution_job_queue: Union[batch.JobQueue, str], + **kwargs, + ): + super().__init__(scope, id, env_base, **kwargs) + + self.efs_ecosystem = efs_ecosystem + + self.execution_job_queue = ( + execution_job_queue + if isinstance(execution_job_queue, str) + else execution_job_queue.job_queue_arn + ) + self.data_sync_job_queue = ( + data_sync_job_queue + if isinstance(data_sync_job_queue, str) + else data_sync_job_queue.job_queue_arn + ) + self.scaffolding_job_queue = ( + scaffolding_job_queue + if isinstance(scaffolding_job_queue, str) + else scaffolding_job_queue.job_queue_arn + ) + + self._assets = assets + + root_mount_point_config = MountPointConfiguration.from_access_point( + self.efs_ecosystem.file_system.root_access_point, EFS_MOUNT_PATH + ) + shared_mount_point_config = MountPointConfiguration.from_access_point( + self.efs_ecosystem.shared_access_point, "/opt/shared", read_only=True + ) + scratch_mount_point_config = MountPointConfiguration.from_access_point( + self.efs_ecosystem.scratch_access_point, "/opt/scratch" + ) + + batch_invoked_lambda_fragment = BatchInvokedLambdaFunction.with_defaults( + self, + "batch-invoked-lambda", + env_base=self.env_base, + name="batch-invoked-lambda", + job_queue=self.scaffolding_job_queue, + bucket_name=scaffolding_bucket.bucket_name, + mount_point_configs=[root_mount_point_config], + ) + self.batch_invoked_lambda_state_machine = batch_invoked_lambda_fragment.to_state_machine( + "batch-invoked-lambda-state-machine" + ) + + data_sync = DataSyncFragment( + self, + "data-sync", + env_base=self.env_base, + aibs_informatics_docker_asset=self._assets.docker_assets.AIBS_INFORMATICS_AWS_LAMBDA, + batch_job_queue=self.data_sync_job_queue, + scaffolding_bucket=scaffolding_bucket, + mount_point_configs=[root_mount_point_config], + ) + + self.data_sync_state_machine = data_sync.to_state_machine("data-sync-v2") + + demand_execution = DemandExecutionFragment( + self, + "demand-execution", + env_base=self.env_base, + aibs_informatics_docker_asset=self._assets.docker_assets.AIBS_INFORMATICS_AWS_LAMBDA, + scaffolding_bucket=scaffolding_bucket, + scaffolding_job_queue=self.scaffolding_job_queue, + batch_invoked_lambda_state_machine=self.batch_invoked_lambda_state_machine, + data_sync_state_machine=self.data_sync_state_machine, + shared_mount_point_config=shared_mount_point_config, + scratch_mount_point_config=scratch_mount_point_config, + ) + self.demand_execution_state_machine = demand_execution.to_state_machine("demand-execution")
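
Usage sketch for the reworked enclosed_chain helper in src/aibs_informatics_cdk_lib/constructs_/sfn/utils.py (illustrative only, not part of the diff): it assumes an existing construct scope and a throwaway Pass state, and the state ids ("Do Work", "Do Work Enclosure", "Do Work Enclosure Post") follow the naming the helper derives from its id argument.

from aws_cdk import aws_stepfunctions as sfn
from constructs import Construct

from aibs_informatics_cdk_lib.constructs_.sfn.utils import enclosed_chain


def add_enclosed_work(scope: Construct) -> sfn.Chain:
    # Any IChainable works here; enclosed_chain wraps non-Chain definitions
    # with sfn.Chain.start() before enclosing them in a Parallel state.
    work = sfn.Pass(scope, "Do Work", parameters={"echo.$": "$"})

    # Produces a single "Do Work Enclosure" Parallel state that reads from
    # $.request, followed by a "Do Work Enclosure Post" Pass that un-nests the
    # Parallel output (a one-element array) back onto $.request.
    return enclosed_chain(
        scope,
        "Do Work",
        work,
        input_path="$.request",
        result_path="$.request",
    )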