diff --git a/pyproject.toml b/pyproject.toml
index 1d828ce..971b44f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -141,7 +141,7 @@ no_site_packages = true
 # Untyped definitions and calls
 # https://mypy.readthedocs.io/en/stable/config_file.html#untyped-definitions-and-calls
 # TODO: enable and fix errors
-check_untyped_defs = false
+check_untyped_defs = true
 
 # Miscellaneous strictness flags
diff --git a/src/aibs_informatics_cdk_lib/common/aws/iam_utils.py b/src/aibs_informatics_cdk_lib/common/aws/iam_utils.py
index 5c2643f..51452f9 100644
--- a/src/aibs_informatics_cdk_lib/common/aws/iam_utils.py
+++ b/src/aibs_informatics_cdk_lib/common/aws/iam_utils.py
@@ -12,12 +12,16 @@
     build_sfn_arn,
 )
 
+BATCH_READ_ONLY_ACTIONS = [
+    "batch:Describe*",
+    "batch:List*",
+]
+
 BATCH_FULL_ACCESS_ACTIONS = [
     "batch:RegisterJobDefinition",
     "batch:DeregisterJobDefinition",
     "batch:DescribeJobDefinitions",
-    "batch:List*",
-    "batch:Describe*",
+    *BATCH_READ_ONLY_ACTIONS,
     "batch:*",
 ]
diff --git a/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset.py b/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset.py
index f642d49..3ffc7b6 100644
--- a/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset.py
+++ b/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset.py
@@ -165,6 +165,7 @@ def create_py_code_asset(
         cls,
         path: Path,
         context_path: Optional[Path],
+        requirements_file: Optional[Path] = None,
         includes: Optional[Sequence[str]] = None,
         excludes: Optional[Sequence[str]] = None,
         runtime: lambda_.Runtime = lambda_.Runtime.PYTHON_3_11,
@@ -219,7 +220,7 @@
                 "ssh -vT git@github.com || true",
                 # Must make sure that the package is not installing using --editable mode
                 "python3 -m pip install --upgrade pip --no-cache",
-                "pip3 install . --no-cache -t /asset-output",
+                f"pip3 install {'-r ' + requirements_file.as_posix() if requirements_file else '.'} --no-cache -t /asset-output",
                 # TODO: remove botocore and boto3 from asset output
                 # Must make asset output permissions accessible to lambda
                 "find /asset-output -type d -print0 | xargs -0 chmod 755",
diff --git a/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset_definitions.py b/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset_definitions.py
index 319e202..09d4c3f 100644
--- a/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset_definitions.py
+++ b/src/aibs_informatics_cdk_lib/constructs_/assets/code_asset_definitions.py
@@ -8,11 +8,13 @@
 from aibs_informatics_core.env import EnvBase
 from aibs_informatics_core.utils.decorators import cached_property
 from aibs_informatics_core.utils.hashing import generate_path_hash
+from aws_cdk import aws_ecr_assets
 from aws_cdk import aws_lambda as lambda_
 from aws_cdk import aws_s3_assets
 
 from aibs_informatics_cdk_lib.common.git import clone_repo, is_local_repo
 from aibs_informatics_cdk_lib.constructs_.assets.code_asset import (
+    GLOBAL_GLOB_EXCLUDES,
     PYTHON_GLOB_EXCLUDES,
     PYTHON_REGEX_EXCLUDES,
     CodeAsset,
@@ -24,6 +26,36 @@
 logger = logging.getLogger(__name__)
 
 
+class AssetsMixin:
+    @classmethod
+    def resolve_repo_path(cls, repo_url: str, repo_path_env_var: Optional[str]) -> Path:
+        """Resolves the repo path from the environment or clones the repo from the url
+
+        This method is useful for quickly swapping between locally modified changes and the remote repo.
+
+        This should typically be used in the context of defining a code asset for a static name
+        (e.g. AIBS_INFORMATICS_AWS_LAMBDA).
You can then use the env var option to point to a local + repo path for development. + + Args: + repo_url (str): The git repo url. This is required. + If the repo path is not in the environment, the repo will be cloned from this url. + repo_path_env_var (Optional[str]): The environment variable that contains the repo path. + This is optional. This is useful for local development. + + Returns: + Path: The path to the repo + """ + if repo_path_env_var and (repo_path := os.getenv(repo_path_env_var)) is not None: + logger.info(f"Using {repo_path_env_var} from environment") + if not is_local_repo(repo_path): + raise ValueError(f"Env variable {repo_path_env_var} is not a valid git repo") + repo_path = Path(repo_path) + else: + repo_path = clone_repo(repo_url, skip_if_exists=True) + return repo_path + + class AIBSInformaticsCodeAssets(constructs.Construct): def __init__( self, @@ -92,7 +124,7 @@ def AIBS_INFORMATICS_AWS_LAMBDA(self) -> CodeAsset: "ssh -vT git@github.com || true", # Must make sure that the package is not installing using --editable mode "python3 -m pip install --upgrade pip --no-cache", - "pip3 install --no-cache -r requirements-lambda.txt -t /asset-output", + "pip3 install . --no-cache -t /asset-output", # TODO: remove botocore and boto3 from asset output # Must make asset output permissions accessible to lambda "find /asset-output -type d -print0 | xargs -0 chmod 755", @@ -117,3 +149,38 @@ def AIBS_INFORMATICS_AWS_LAMBDA(self) -> CodeAsset: self.env_base.ENV_BASE_KEY: self.env_base, }, ) + + +class AIBSInformaticsDockerAssets(constructs.Construct, AssetsMixin): + def __init__( + self, + scope: constructs.Construct, + construct_id: str, + env_base: EnvBase, + ) -> None: + super().__init__(scope, construct_id) + self.env_base = env_base + + @cached_property + def AIBS_INFORMATICS_AWS_LAMBDA(self) -> aws_ecr_assets.DockerImageAsset: + """Returns a NEW docker asset for aibs-informatics-aws-lambda + + Returns: + aws_ecr_assets.DockerImageAsset: The docker asset + """ + repo_path = self.resolve_repo_path( + AIBS_INFORMATICS_AWS_LAMBDA_REPO, AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR + ) + return aws_ecr_assets.DockerImageAsset( + self, + "aibs-informatics-aws-lambda", + directory=repo_path.as_posix(), + build_ssh="default", + platform=aws_ecr_assets.Platform.LINUX_AMD64, + asset_name="aibs-informatics-aws-lambda", + file="docker/Dockerfile", + exclude=[ + *PYTHON_GLOB_EXCLUDES, + *GLOBAL_GLOB_EXCLUDES, + ], + ) diff --git a/src/aibs_informatics_cdk_lib/constructs_/base.py b/src/aibs_informatics_cdk_lib/constructs_/base.py index 5195e2c..545c694 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/base.py +++ b/src/aibs_informatics_cdk_lib/constructs_/base.py @@ -34,7 +34,7 @@ def construct_tags(self) -> List[cdk.Tag]: def add_tags(self, *tags: cdk.Tag): for tag in tags: - cdk.Tags.of(self).add(key=tag.key, value=tag.value) + cdk.Tags.of(self.as_construct()).add(key=tag.key, value=tag.value) def normalize_construct_id( self, construct_id: str, max_size: int = 64, hash_size: int = 8 @@ -71,11 +71,14 @@ def get_resource_name(self, name: Union[ResourceNameBaseEnum, str]) -> str: return name.get_name(self.env_base) return self.env_base.get_resource_name(name) - def get_stack_of(self, construct: Construct) -> Optional[Stack]: - try: - return cdk.Stack.of(construct) - except: - return None + def get_stack_of(self, construct: Optional[Construct] = None) -> Stack: + if construct is None: + construct = self.as_construct() + return cdk.Stack.of(construct) + + def as_construct(self) -> Construct: 
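+        """Return self typed as a Construct; the mixin assumes it is mixed into a Construct subclass."""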
+ assert isinstance(self, Construct) + return self @classmethod def build_construct_id(cls, env_base: EnvBase, *names: str) -> str: @@ -104,3 +107,11 @@ def __init__(self, scope: Construct, id: Optional[str], env_base: EnvBase) -> No @property def construct_id(self) -> str: return self.node.id + + @property + def aws_region(self) -> str: + return cdk.Stack.of(self).region + + @property + def aws_account(self) -> str: + return cdk.Stack.of(self).account diff --git a/src/aibs_informatics_cdk_lib/constructs_/batch/infrastructure.py b/src/aibs_informatics_cdk_lib/constructs_/batch/infrastructure.py index b3448a8..9c56be9 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/batch/infrastructure.py +++ b/src/aibs_informatics_cdk_lib/constructs_/batch/infrastructure.py @@ -23,7 +23,9 @@ from aws_cdk import aws_s3 as s3 from aibs_informatics_cdk_lib.common.aws.iam_utils import ( + BATCH_READ_ONLY_ACTIONS, S3_READ_ONLY_ACCESS_ACTIONS, + batch_policy_statement, dynamodb_policy_statement, lambda_policy_statement, ) @@ -149,6 +151,7 @@ def create_instance_role( effect=iam.Effect.ALLOW, resources=["*"], ), + batch_policy_statement(actions=BATCH_READ_ONLY_ACTIONS, env_base=self.env_base), lambda_policy_statement(actions=["lambda:InvokeFunction"], env_base=self.env_base), dynamodb_policy_statement( env_base=self.env_base, diff --git a/src/aibs_informatics_cdk_lib/constructs_/cw/dashboard.py b/src/aibs_informatics_cdk_lib/constructs_/cw/dashboard.py index 03c8331..8ae266f 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/cw/dashboard.py +++ b/src/aibs_informatics_cdk_lib/constructs_/cw/dashboard.py @@ -105,7 +105,7 @@ def create_widgets_and_alarms( Returns: Tuple[List[cw.IWidget], List[cw.IAlarm]]: List of widgets and list of alarms """ - self_stack = self.get_stack_of(self) + self_stack = cdk.Stack.of(self.as_construct()) graph_widgets: List[cw.IWidget] = [] metric_alarms: List[cw.IAlarm] = [] diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/base.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/base.py index 7b75bc8..9a3ea21 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/base.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/base.py @@ -29,7 +29,7 @@ def get_fn(self, function_name: str) -> lambda_.IFunction: resource_cache = cast(Dict[str, lambda_.IFunction], getattr(self, cache_attr)) if function_name not in resource_cache: resource_cache[function_name] = lambda_.Function.from_function_arn( - scope=self, + scope=self.as_construct(), id=self.env_base.get_construct_id(function_name, "from-arn"), function_arn=build_lambda_arn( resource_type="function", @@ -52,6 +52,37 @@ def get_state_machine_from_name(self, state_machine_name: str) -> sfn.IStateMach return resource_cache[state_machine_name] +def create_state_machine( + scope: constructs.Construct, + env_base: EnvBase, + name: str, + definition: sfn.IChainable, + role: Optional[iam.Role] = None, + logs: Optional[sfn.LogOptions] = None, + timeout: Optional[cdk.Duration] = None, +) -> sfn.StateMachine: + return sfn.StateMachine( + scope, + env_base.get_construct_id(name), + state_machine_name=env_base.get_state_machine_name(name), + logs=( + logs + or sfn.LogOptions( + destination=logs_.LogGroup( + scope, + env_base.get_construct_id(name, "state-loggroup"), + log_group_name=env_base.get_state_machine_log_group_name(name), + removal_policy=cdk.RemovalPolicy.DESTROY, + retention=logs_.RetentionDays.ONE_MONTH, + ) + ) + ), + role=cast(iam.IRole, role), + 
definition_body=sfn.DefinitionBody.from_chainable(definition), + timeout=timeout, + ) + + class StateMachineFragment(sfn.StateMachineFragment): @property def definition(self) -> sfn.IChainable: @@ -87,9 +118,13 @@ def enclose( scope, f"{id} Parallel Prep", parameters={"input": sfn.JsonPath.entire_payload} ) - parallel = chain.to_single_state( - id=f"{id} Parallel", input_path="$.input", result_path="$.result" - ) + if isinstance(chain, sfn.Chain): + parallel = chain.to_single_state( + id=f"{id} Parallel", input_path="$.input", result_path="$.result" + ) + else: + parallel = chain.to_single_state(input_path="$.input", result_path="$.result") + mod_result_path = JsonReferencePath("$.input") if result_path and result_path != sfn.JsonPath.DISCARD: mod_result_path = mod_result_path + result_path diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/batch.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/batch.py index 01ed963..62b7233 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/batch.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/batch.py @@ -1,10 +1,13 @@ -from typing import TYPE_CHECKING, List, Literal, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, List, Literal, Mapping, Optional, Union import constructs from aibs_informatics_core.env import EnvBase +from aibs_informatics_core.utils.tools.dicttools import convert_key_case +from aibs_informatics_core.utils.tools.strtools import pascalcase from aws_cdk import aws_batch_alpha as batch from aws_cdk import aws_stepfunctions as sfn +from aibs_informatics_cdk_lib.constructs_.efs.file_system import MountPointConfiguration from aibs_informatics_cdk_lib.constructs_.sfn.fragments.base import ( EnvBaseStateMachineFragment, StateMachineFragment, @@ -14,12 +17,28 @@ if TYPE_CHECKING: from mypy_boto3_batch.type_defs import MountPointTypeDef, VolumeTypeDef -else: +else: # pragma: no cover MountPointTypeDef = dict VolumeTypeDef = dict -class SubmitJobFragment(EnvBaseStateMachineFragment): +class AWSBatchMixins: + @classmethod + def convert_to_mount_point_and_volumes( + cls, + mount_point_configs: List[MountPointConfiguration], + ) -> tuple[List[MountPointTypeDef], List[VolumeTypeDef]]: + mount_points = [] + volumes = [] + for i, mpc in enumerate(mount_point_configs): + mount_points.append( + convert_key_case(mpc.to_batch_mount_point(f"efs-vol{i}"), pascalcase) + ) + volumes.append(convert_key_case(mpc.to_batch_volume(f"efs-vol{i}"), pascalcase)) + return mount_points, volumes + + +class SubmitJobFragment(EnvBaseStateMachineFragment, AWSBatchMixins): def __init__( self, scope: constructs.Construct, @@ -103,3 +122,144 @@ def __init__( ) self.definition = register.next(submit).next(deregister) + + @classmethod + def from_defaults( + cls, + scope: constructs.Construct, + id: str, + env_base: EnvBase, + name: str, + job_queue: str, + image: str, + command: str = "", + memory: str = "1024", + vcpus: str = "1", + environment: Optional[Mapping[str, str]] = None, + mount_point_configs: Optional[List[MountPointConfiguration]] = None, + ) -> "SubmitJobFragment": + defaults: dict[str, Any] = {} + defaults["command"] = command + defaults["job_queue"] = job_queue + defaults["environment"] = environment or {} + defaults["memory"] = memory + defaults["vcpus"] = vcpus + defaults["gpu"] = "0" + defaults["platform_capabilities"] = ["EC2"] + + if mount_point_configs: + mount_points, volumes = cls.convert_to_mount_point_and_volumes(mount_point_configs) + defaults["mount_points"] = mount_points + 
defaults["volumes"] = volumes + + submit_job = SubmitJobFragment( + scope, + id, + env_base=env_base, + name="SubmitJobCore", + image=sfn.JsonPath.string_at("$.request.image"), + command=sfn.JsonPath.string_at("$.request.command"), + job_queue=sfn.JsonPath.string_at("$.request.job_queue"), + environment=sfn.JsonPath.string_at("$.request.environment"), + memory=sfn.JsonPath.string_at("$.request.memory"), + vcpus=sfn.JsonPath.string_at("$.request.vcpus"), + # TODO: Handle GPU parameter better - right now, we cannot handle cases where it is + # not specified. Setting to zero causes issues with the Batch API. + # If it is set to zero, then the json list of resources are not properly set. + # gpu=sfn.JsonPath.string_at("$.request.gpu"), + mount_points=sfn.JsonPath.string_at("$.request.mount_points"), + volumes=sfn.JsonPath.string_at("$.request.volumes"), + platform_capabilities=sfn.JsonPath.string_at("$.request.platform_capabilities"), + ) + + # Now we need to add the start and merge states and add to the definition + start = sfn.Pass( + submit_job, + "Start", + parameters={ + "input": sfn.JsonPath.string_at("$"), + "default": defaults, + }, + ) + merge = sfn.Pass( + submit_job, + "Merge", + parameters={ + "request": sfn.JsonPath.json_merge( + sfn.JsonPath.object_at("$.default"), sfn.JsonPath.object_at("$.input") + ), + }, + ) + + submit_job.definition = start.next(merge).next(submit_job.definition) + return submit_job + + +class SubmitJobWithDefaultsFragment(EnvBaseStateMachineFragment, AWSBatchMixins): + def __init__( + self, + scope: constructs.Construct, + id: str, + env_base: EnvBase, + job_queue: str, + command: str = "", + memory: str = "1024", + vcpus: str = "1", + environment: Optional[Mapping[str, str]] = None, + mount_point_configs: Optional[List[MountPointConfiguration]] = None, + ): + super().__init__(scope, id, env_base) + defaults: dict[str, Any] = {} + defaults["command"] = command + defaults["job_queue"] = job_queue + defaults["environment"] = environment or {} + defaults["memory"] = memory + defaults["vcpus"] = vcpus + defaults["gpu"] = "0" + defaults["platform_capabilities"] = ["EC2"] + + if mount_point_configs: + mount_points, volumes = self.convert_to_mount_point_and_volumes(mount_point_configs) + defaults["mount_points"] = mount_points + defaults["volumes"] = volumes + + start = sfn.Pass( + self, + "Start", + parameters={ + "input": sfn.JsonPath.object_at("$"), + "default": defaults, + }, + ) + + merge = sfn.Pass( + self, + "Merge", + parameters={ + "request": sfn.JsonPath.json_merge( + sfn.JsonPath.object_at("$.default"), sfn.JsonPath.object_at("$.input") + ), + }, + ) + + submit_job = SubmitJobFragment( + self, + "SubmitJobCore", + env_base=self.env_base, + name="SubmitJobCore", + image=sfn.JsonPath.string_at("$.request.image"), + command=sfn.JsonPath.string_at("$.request.command"), + job_queue=sfn.JsonPath.string_at("$.request.job_queue"), + environment=sfn.JsonPath.string_at("$.request.environment"), + memory=sfn.JsonPath.string_at("$.request.memory"), + vcpus=sfn.JsonPath.string_at("$.request.vcpus"), + # TODO: Handle GPU parameter better - right now, we cannot handle cases where it is + # not specified. Setting to zero causes issues with the Batch API. + # If it is set to zero, then the json list of resources are not properly set. 
+ # gpu=sfn.JsonPath.string_at("$.request.gpu"), + mount_points=sfn.JsonPath.string_at("$.request.mount_points"), + volumes=sfn.JsonPath.string_at("$.request.volumes"), + platform_capabilities=sfn.JsonPath.string_at("$.request.platform_capabilities"), + ).to_single_state() + + self.definition = start.next(merge).next(submit_job) diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics.py index 90ebc15..f011b68 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics.py @@ -11,9 +11,12 @@ from aibs_informatics_core.env import EnvBase from aws_cdk import aws_stepfunctions as sfn +from aibs_informatics_cdk_lib.constructs_.efs.file_system import MountPointConfiguration from aibs_informatics_cdk_lib.constructs_.sfn.fragments.base import EnvBaseStateMachineFragment -from aibs_informatics_cdk_lib.constructs_.sfn.fragments.batch import SubmitJobFragment -from aibs_informatics_cdk_lib.constructs_.sfn.states.batch import BatchOperation +from aibs_informatics_cdk_lib.constructs_.sfn.fragments.batch import ( + AWSBatchMixins, + SubmitJobFragment, +) from aibs_informatics_cdk_lib.constructs_.sfn.states.s3 import S3Operation if TYPE_CHECKING: @@ -23,56 +26,7 @@ VolumeTypeDef = dict -class RcloneFragment(EnvBaseStateMachineFragment): - def __init__( - self, - scope: constructs.Construct, - id: str, - env_base: EnvBase, - name: str, - image: str, - job_queue: str, - bucket_name: str, - payload_path: Optional[str] = None, - command: Optional[Union[List[str], str]] = None, - environment: Optional[Mapping[str, str]] = None, - memory: Optional[Union[int, str]] = None, - vcpus: Optional[Union[int, str]] = None, - mount_points: Optional[List[MountPointTypeDef]] = None, - volumes: Optional[List[VolumeTypeDef]] = None, - platform_capabilities: Optional[List[Literal["EC2", "FARGATE"]]] = None, - ) -> None: - pass - # submit_job = SubmitJobFragment( - # self, - # f"{id} Batch", - # env_base=env_base, - # name=name, - # job_queue=job_queue, - # command=command or [], - # image=image, - # environment={ - # **(environment if environment else {}), - # AWS_LAMBDA_FUNCTION_NAME_KEY: name, - # # AWS_LAMBDA_FUNCTION_HANDLER_KEY: handler, - # AWS_LAMBDA_EVENT_PAYLOAD_KEY: sfn.JsonPath.format( - # "s3://{}/{}", - # sfn.JsonPath.string_at("$.taskResult.put.Bucket"), - # sfn.JsonPath.string_at("$.taskResult.put.Key"), - # ), - # AWS_LAMBDA_EVENT_RESPONSE_LOCATION_KEY: sfn.JsonPath.format( - # "s3://{}/{}", bucket_name, response_key - # ), - # }, - # memory=memory, - # vcpus=vcpus, - # mount_points=mount_points or [], - # volumes=volumes or [], - # platform_capabilities=platform_capabilities, - # ) - - -class BatchInvokedLambdaFunction(EnvBaseStateMachineFragment): +class BatchInvokedLambdaFunction(EnvBaseStateMachineFragment, AWSBatchMixins): def __init__( self, scope: constructs.Construct, @@ -89,9 +43,10 @@ def __init__( environment: Optional[Mapping[str, str]] = None, memory: Optional[Union[int, str]] = None, vcpus: Optional[Union[int, str]] = None, - mount_points: Optional[List[MountPointTypeDef]] = None, - volumes: Optional[List[VolumeTypeDef]] = None, - platform_capabilities: Optional[List[Literal["EC2", "FARGATE"]]] = None, + mount_points: Optional[Union[List[MountPointTypeDef], str]] = None, + volumes: Optional[Union[List[VolumeTypeDef], str]] = None, + mount_point_configs: Optional[List[MountPointConfiguration]] = None, + 
platform_capabilities: Optional[Union[List[Literal["EC2", "FARGATE"]], str]] = None, ) -> None: """Invoke a command on image via batch with a payload from s3 @@ -140,6 +95,11 @@ def __init__( f"{key_prefix}{{}}/response.json", sfn.JsonPath.execution_name ) + if mount_point_configs: + if mount_points or volumes: + raise ValueError("Cannot specify both mount_point_configs and mount_points") + mount_points, volumes = self.convert_to_mount_point_and_volumes(mount_point_configs) + put_payload = S3Operation.put_payload( self, f"{id} Put Request to S3", @@ -179,11 +139,12 @@ def __init__( get_response = S3Operation.get_payload( self, - f"{id} Get Response from S3", + f"{id}", bucket_name=bucket_name, key=response_key, - result_path="$.taskResult.get", - output_path="$.taskResult.get", + ).to_single_state( + "Get Response from S3", + output_path="$[0]", ) self.definition = put_payload.next(submit_job).next(get_response) @@ -197,7 +158,7 @@ def end_states(self) -> List[sfn.INextable]: return self.definition.end_states -class BatchInvokedExecutorFragment(EnvBaseStateMachineFragment): +class BatchInvokedExecutorFragment(EnvBaseStateMachineFragment, AWSBatchMixins): def __init__( self, scope: constructs.Construct, @@ -213,6 +174,7 @@ def __init__( environment: Optional[Union[Mapping[str, str], str]] = None, memory: Optional[Union[int, str]] = None, vcpus: Optional[Union[int, str]] = None, + mount_point_configs: Optional[List[MountPointConfiguration]] = None, mount_points: Optional[List[MountPointTypeDef]] = None, volumes: Optional[List[VolumeTypeDef]] = None, ) -> None: @@ -260,12 +222,18 @@ def __init__( f"{key_prefix}{{}}/response.json", sfn.JsonPath.execution_name ) + if mount_point_configs: + if mount_points or volumes: + raise ValueError("Cannot specify both mount_point_configs and mount_points") + mount_points, volumes = self.convert_to_mount_point_and_volumes(mount_point_configs) + put_payload = S3Operation.put_payload( self, - f"Put Request to S3", + f"{id} Put Request to S3", payload=payload_path or sfn.JsonPath.entire_payload, bucket_name=bucket_name, key=request_key, + result_path="$.taskResult.put", ) submit_job = SubmitJobFragment( @@ -293,9 +261,12 @@ def __init__( get_response = S3Operation.get_payload( self, - f"Get Response from S3", + f"{id}", bucket_name=bucket_name, key=response_key, + ).to_single_state( + "Get Response from S3", + output_path="$[0]", ) self.definition = put_payload.next(submit_job).next(get_response) diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/states/s3.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/states/s3.py index c816eca..283dfa4 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/states/s3.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/states/s3.py @@ -3,7 +3,10 @@ import constructs from aws_cdk import aws_stepfunctions as sfn -from aibs_informatics_cdk_lib.constructs_.sfn.utils import convert_reference_paths +from aibs_informatics_cdk_lib.constructs_.sfn.utils import ( + convert_reference_paths, + convert_reference_paths_in_mapping, +) class S3Operation: @@ -61,7 +64,7 @@ def put_object( init = sfn.Pass( scope, id + " PutObject Prep", - parameters=convert_reference_paths( + parameters=convert_reference_paths_in_mapping( { "Bucket": bucket_name, "Key": key, @@ -90,7 +93,7 @@ def put_object( end = sfn.Pass( scope, id + " PutObject Post", - parameters=convert_reference_paths( + parameters=convert_reference_paths_in_mapping( { "Bucket": f"{result_path or '$'}.Bucket", "Key": f"{result_path or '$'}.Key", @@ -155,7 +158,7 @@ def 
get_object( init = sfn.Pass( scope, id + " GetObject Prep", - parameters=convert_reference_paths( + parameters=convert_reference_paths_in_mapping( { "Bucket": bucket_name, "Key": key, @@ -164,12 +167,19 @@ def get_object( result_path=result_path or "$", ) + if result_path: + bucket_path = f"{result_path}.Bucket" + key_path = f"{result_path}.Key" + else: + bucket_path = "$.Bucket" + key_path = "$.Key" + state_json = { "Type": "Task", "Resource": "arn:aws:states:::aws-sdk:s3:getObject", "Parameters": { - "Bucket.$": "$.Bucket", - "Key.$": "$.Key", + "Bucket.$": bucket_path, + "Key.$": key_path, }, "ResultSelector": { "Body.$": "$.Body", @@ -245,37 +255,65 @@ def get_payload( bucket_name: str, key: str, result_path: Optional[str] = "$", - output_path: Optional[str] = "$", ) -> sfn.Chain: """Gets a payload from s3 This chain fetches object and then passes the body through a json parser - Example: - Context: - {"bucket": "woah", "key": "wait/what"} - Definition: - S3Operation.get_payload(..., bucket_name="$.bucket", key="$.key") - Result: - # text '{"a": "b"}' is fetched from s3://woah/wait/what - {"a": "b"} + The resulting payload will be stored to path specified by result_path + + Examples: + Use Case #1 - No result path + Context: + {"bucket": "woah", "key": "wait/what"} + Definition: + S3Operation.get_payload(..., bucket_name="$.bucket", key="$.key") + Result: + # text '{"a": "b"}' is fetched from s3://woah/wait/what + {"a": "b"} + + Use Case #2 - result path specified + Context: + {"bucket": "woah", "key": "wait/what"} + Definition: + S3Operation.get_payload( + ..., bucket_name="$.bucket", key="$.key", result_path="$.payload" + ) + Result: + # text '{"a": "b"}' is fetched from s3://woah/wait/what + { + "bucket": "woah", + "key": "wait/what", + "payload": {"a": "b"} + } Args: scope (constructs.Construct): cdk construct id (str): id - bucket_name (str): bucket name - key (str): key + bucket_name (str): bucket name. Can be a reference path + key (str): key name. Can be a reference path + result_path (Optional[str], optional): path to store the payload. Defaults to "$". 
Returns: sfn.Chain: chain """ - get_chain = S3Operation.get_object(scope, id, bucket_name, key, result_path, output_path) + get_chain = S3Operation.get_object(scope, id, bucket_name, key, result_path) post = sfn.Pass( scope, id + " Post", - parameters={"Payload": sfn.JsonPath.string_to_json("$.Body")}, + parameters={ + "Payload": sfn.JsonPath.string_to_json( + sfn.JsonPath.string_at(f"{result_path}.Body") + ) + }, result_path=result_path, - output_path="$.Payload", ) - return get_chain.next(post) + restructure = sfn.Pass( + scope, + id + " Restructure", + input_path=f"{result_path}.Payload", + result_path=result_path, + ) + + return get_chain.next(post).next(restructure) diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/utils.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/utils.py index 601aaa7..53d64b1 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/utils.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/utils.py @@ -1,12 +1,16 @@ import re from functools import reduce -from typing import Any, ClassVar, List, Pattern, Union, cast +from typing import Any, ClassVar, List, Mapping, Pattern, Union, cast import aws_cdk as cdk from aibs_informatics_core.utils.json import JSON from aws_cdk import aws_stepfunctions as sfn +def convert_reference_paths_in_mapping(parameters: Mapping[str, Any]) -> Mapping[str, Any]: + return {k: convert_reference_paths(v) for k, v in parameters.items()} + + def convert_reference_paths(parameters: JSON) -> JSON: if isinstance(parameters, dict): return {k: convert_reference_paths(v) for k, v in parameters.items()} diff --git a/src/aibs_informatics_cdk_lib/stacks/assets.py b/src/aibs_informatics_cdk_lib/stacks/assets.py new file mode 100644 index 0000000..a7a7eea --- /dev/null +++ b/src/aibs_informatics_cdk_lib/stacks/assets.py @@ -0,0 +1,29 @@ +from typing import Optional + +import constructs +from aibs_informatics_core.env import EnvBase + +from aibs_informatics_cdk_lib.constructs_.assets.code_asset_definitions import ( + AIBSInformaticsCodeAssets, + AIBSInformaticsDockerAssets, +) +from aibs_informatics_cdk_lib.stacks.base import EnvBaseStack + + +class AIBSInformaticsAssetsStack(EnvBaseStack): + def __init__( + self, + scope: constructs.Construct, + id: Optional[str], + env_base: EnvBase, + **kwargs, + ): + super().__init__(scope, id, env_base, **kwargs) + self.code_assets = AIBSInformaticsCodeAssets( + self, + "aibs-info-code-assets", + self.env_base, + ) + self.docker_assets = AIBSInformaticsDockerAssets( + self, "aibs-info-docker-assets", self.env_base + ) diff --git a/src/aibs_informatics_cdk_lib/stacks/base.py b/src/aibs_informatics_cdk_lib/stacks/base.py index 691ec7b..1626164 100644 --- a/src/aibs_informatics_cdk_lib/stacks/base.py +++ b/src/aibs_informatics_cdk_lib/stacks/base.py @@ -45,13 +45,21 @@ def __init__( ) -> None: super().__init__( scope, - id or env_base.get_construct_id(self.__class__), + id or env_base.get_construct_id(str(self.__class__)), env=env, **kwargs, ) self.env_base = env_base self.add_tags(*self.stack_tags) + @property + def aws_region(self) -> str: + return cdk.Stack.of(self).region + + @property + def aws_account(self) -> str: + return cdk.Stack.of(self).account + @property def stack_tags(self) -> List[cdk.Tag]: return [ diff --git a/src/aibs_informatics_cdk_lib/stacks/compute.py b/src/aibs_informatics_cdk_lib/stacks/compute.py new file mode 100644 index 0000000..59b9631 --- /dev/null +++ b/src/aibs_informatics_cdk_lib/stacks/compute.py @@ -0,0 +1,405 @@ +from abc import abstractmethod +from typing import 
Any, Dict, Iterable, List, Optional, Tuple, Union + +from aibs_informatics_aws_utils import AWS_REGION_VAR +from aibs_informatics_core.env import EnvBase +from aws_cdk import aws_batch_alpha as batch +from aws_cdk import aws_ec2 as ec2 +from aws_cdk import aws_efs as efs +from aws_cdk import aws_iam as iam +from aws_cdk import aws_s3 as s3 +from aws_cdk import aws_stepfunctions as sfn +from constructs import Construct + +from aibs_informatics_cdk_lib.common.aws.iam_utils import ( + batch_policy_statement, + s3_policy_statement, +) +from aibs_informatics_cdk_lib.constructs_.batch.infrastructure import ( + Batch, + BatchEnvironment, + BatchEnvironmentConfig, +) +from aibs_informatics_cdk_lib.constructs_.batch.instance_types import ( + LAMBDA_LARGE_INSTANCE_TYPES, + LAMBDA_MEDIUM_INSTANCE_TYPES, + LAMBDA_SMALL_INSTANCE_TYPES, + ON_DEMAND_INSTANCE_TYPES, + SPOT_INSTANCE_TYPES, +) +from aibs_informatics_cdk_lib.constructs_.batch.launch_template import BatchLaunchTemplateBuilder +from aibs_informatics_cdk_lib.constructs_.batch.types import BatchEnvironmentDescriptor +from aibs_informatics_cdk_lib.constructs_.efs.file_system import MountPointConfiguration +from aibs_informatics_cdk_lib.constructs_.sfn.fragments.base import create_state_machine +from aibs_informatics_cdk_lib.constructs_.sfn.fragments.batch import ( + AWSBatchMixins, + SubmitJobWithDefaultsFragment, +) +from aibs_informatics_cdk_lib.constructs_.sfn.fragments.informatics import ( + BatchInvokedLambdaFunction, +) +from aibs_informatics_cdk_lib.stacks.base import EnvBaseStack + + +class BaseComputeStack(EnvBaseStack): + def __init__( + self, + scope: Construct, + id: Optional[str], + env_base: EnvBase, + vpc: ec2.Vpc, + batch_name: str, + buckets: Optional[Iterable[s3.Bucket]] = None, + file_systems: Optional[Iterable[Union[efs.FileSystem, efs.IFileSystem]]] = None, + mount_point_configs: Optional[Iterable[MountPointConfiguration]] = None, + **kwargs, + ) -> None: + super().__init__(scope, id, env_base, **kwargs) + self.batch_name = batch_name + self.batch = Batch(self, batch_name, self.env_base, vpc=vpc) + + self.create_batch_environments() + + bucket_list = list(buckets or []) + + file_system_list = list(file_systems or []) + + if mount_point_configs: + mount_point_config_list = list(mount_point_configs) + file_system_list = self._update_file_systems_from_mount_point_configs( + file_system_list, mount_point_config_list + ) + else: + mount_point_config_list = self._get_mount_point_configs(file_system_list) + + # Validation to ensure that the file systems are not duplicated + self._validate_mount_point_configs(mount_point_config_list) + + self.grant_storage_access(*bucket_list, *file_system_list) + + @property + @abstractmethod + def primary_batch_environment(self) -> BatchEnvironment: + raise NotImplementedError() + + @abstractmethod + def create_batch_environments(self): + raise NotImplementedError() + + @property + def name(self) -> str: + return self.batch_name + + def grant_storage_access(self, *resources: Union[s3.Bucket, efs.FileSystem, efs.IFileSystem]): + self.batch.grant_instance_role_permissions(read_write_resources=list(resources)) + + for batch_environment in self.batch.environments: + for resource in resources: + if isinstance(resource, efs.FileSystem): + batch_environment.grant_file_system_access(resource) + + def _validate_mount_point_configs(self, mount_point_configs: List[MountPointConfiguration]): + _ = {} + for mpc in mount_point_configs: + if mpc.mount_point in _ and _[mpc.mount_point] != mpc: + raise ValueError( 
+ f"Mount point {mpc.mount_point} is duplicated. " + "Cannot have multiple mount points configurations with the same name." + ) + _[mpc.mount_point] = mpc + + def _get_mount_point_configs( + self, file_systems: Optional[List[Union[efs.FileSystem, efs.IFileSystem]]] + ) -> List[MountPointConfiguration]: + mount_point_configs = [] + if file_systems: + for fs in file_systems: + mount_point_configs.append(MountPointConfiguration.from_file_system(fs)) + return mount_point_configs + + def _update_file_systems_from_mount_point_configs( + self, + file_systems: List[Union[efs.FileSystem, efs.IFileSystem]], + mount_point_configs: List[MountPointConfiguration], + ) -> List[Union[efs.FileSystem, efs.IFileSystem]]: + file_system_map: dict[str, Union[efs.FileSystem, efs.IFileSystem]] = { + fs.file_system_id: fs for fs in file_systems + } + for mpc in mount_point_configs: + if mpc.file_system_id not in file_system_map: + if not mpc.file_system and mpc.access_point: + file_system_map[mpc.file_system_id] = mpc.access_point.file_system + elif mpc.file_system: + file_system_map[mpc.file_system_id] = mpc.file_system + else: + raise ValueError( + "Mount point configuration must have a file system or access point." + ) + + return list(file_system_map.values()) + + +class ComputeStack(BaseComputeStack): + @property + def primary_batch_environment(self) -> BatchEnvironment: + return self.on_demand_batch_environment + + def create_batch_environments(self): + lt_builder = BatchLaunchTemplateBuilder( + self, f"{self.name}-lt-builder", env_base=self.env_base + ) + self.on_demand_batch_environment = self.batch.setup_batch_environment( + descriptor=BatchEnvironmentDescriptor("on-demand"), + config=BatchEnvironmentConfig( + allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, + instance_types=[ec2.InstanceType(_) for _ in ON_DEMAND_INSTANCE_TYPES], + use_spot=False, + use_fargate=False, + use_public_subnets=False, + ), + launch_template_builder=lt_builder, + ) + + self.spot_batch_environment = self.batch.setup_batch_environment( + descriptor=BatchEnvironmentDescriptor("spot"), + config=BatchEnvironmentConfig( + allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, + instance_types=[ec2.InstanceType(_) for _ in SPOT_INSTANCE_TYPES], + use_spot=True, + use_fargate=False, + use_public_subnets=False, + ), + launch_template_builder=lt_builder, + ) + + self.fargate_batch_environment = self.batch.setup_batch_environment( + descriptor=BatchEnvironmentDescriptor("fargate"), + config=BatchEnvironmentConfig( + allocation_strategy=None, + instance_types=None, + use_spot=False, + use_fargate=True, + use_public_subnets=False, + ), + launch_template_builder=lt_builder, + ) + + +class LambdaComputeStack(ComputeStack): + @property + def primary_batch_environment(self) -> BatchEnvironment: + return self.lambda_batch_environment + + def create_batch_environments(self): + lt_builder = BatchLaunchTemplateBuilder( + self, f"${self.name}-lt-builder", env_base=self.env_base + ) + self.lambda_batch_environment = self.batch.setup_batch_environment( + descriptor=BatchEnvironmentDescriptor("lambda"), + config=BatchEnvironmentConfig( + allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, + instance_types=[ + *LAMBDA_SMALL_INSTANCE_TYPES, + *LAMBDA_MEDIUM_INSTANCE_TYPES, + *LAMBDA_LARGE_INSTANCE_TYPES, + ], + use_spot=False, + use_fargate=False, + use_public_subnets=False, + ), + launch_template_builder=lt_builder, + ) + + self.lambda_small_batch_environment = self.batch.setup_batch_environment( + 
descriptor=BatchEnvironmentDescriptor("lambda-small"), + config=BatchEnvironmentConfig( + allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, + instance_types=[*LAMBDA_SMALL_INSTANCE_TYPES], + use_spot=False, + use_fargate=False, + use_public_subnets=False, + ), + launch_template_builder=lt_builder, + ) + + self.lambda_medium_batch_environment = self.batch.setup_batch_environment( + descriptor=BatchEnvironmentDescriptor("lambda-medium"), + config=BatchEnvironmentConfig( + allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, + instance_types=[*LAMBDA_MEDIUM_INSTANCE_TYPES], + use_spot=False, + use_fargate=False, + use_public_subnets=False, + ), + launch_template_builder=lt_builder, + ) + + self.lambda_large_batch_environment = self.batch.setup_batch_environment( + descriptor=BatchEnvironmentDescriptor("lambda-large"), + config=BatchEnvironmentConfig( + allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, + instance_types=[*LAMBDA_LARGE_INSTANCE_TYPES], + use_spot=False, + use_fargate=False, + use_public_subnets=False, + ), + launch_template_builder=lt_builder, + ) + + +class ComputeWorkflowStack(EnvBaseStack): + def __init__( + self, + scope: Construct, + id: Optional[str], + env_base: EnvBase, + batch_environment: BatchEnvironment, + primary_bucket: s3.Bucket, + buckets: Optional[Iterable[s3.Bucket]] = None, + mount_point_configs: Optional[Iterable[MountPointConfiguration]] = None, + **kwargs, + ) -> None: + super().__init__(scope, id, env_base, **kwargs) + self.primary_bucket = primary_bucket + self.buckets = list(buckets or []) + self.batch_environment = batch_environment + self.mount_point_configs = list(mount_point_configs) if mount_point_configs else None + + self.create_submit_job_step_function() + self.create_lambda_invoke_step_function() + + def create_submit_job_step_function(self): + fragment = SubmitJobWithDefaultsFragment( + self, + "submit-job-fragment", + self.env_base, + job_queue=self.batch_environment.job_queue.job_queue_arn, + mount_point_configs=self.mount_point_configs, + ) + state_machine_name = self.get_resource_name("submit-job") + + self.batch_submit_job_state_machine = fragment.to_state_machine( + state_machine_name=state_machine_name, + role=iam.Role( + self, + self.env_base.get_construct_id(state_machine_name, "role"), + assumed_by=iam.ServicePrincipal("states.amazonaws.com"), # type: ignore + managed_policies=[ + iam.ManagedPolicy.from_aws_managed_policy_name( + "AmazonAPIGatewayInvokeFullAccess" + ), + iam.ManagedPolicy.from_aws_managed_policy_name("AWSStepFunctionsFullAccess"), + iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchLogsFullAccess"), + iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchEventsFullAccess"), + ], + inline_policies={ + "default": iam.PolicyDocument( + statements=[ + batch_policy_statement(self.env_base), + ] + ), + }, + ), + ) + + def create_lambda_invoke_step_function( + self, + ): + defaults: dict[str, Any] = {} + + defaults["job_queue"] = self.batch_environment.job_queue_name + defaults["memory"] = "1024" + defaults["vcpus"] = "1" + defaults["gpu"] = "0" + defaults["platform_capabilities"] = ["EC2"] + + if self.mount_point_configs: + mount_points, volumes = AWSBatchMixins.convert_to_mount_point_and_volumes( + self.mount_point_configs + ) + defaults["mount_points"] = mount_points + defaults["volumes"] = volumes + + start = sfn.Pass( + self, + "Start", + parameters={ + "image": sfn.JsonPath.string_at("$.image"), + "handler": sfn.JsonPath.string_at("$.handler"), + "payload": 
sfn.JsonPath.object_at("$.request"), + # We will merge the rest with the defaults + "input": sfn.JsonPath.object_at("$"), + "default": defaults, + }, + ) + + merge = sfn.Pass( + self, + "Merge", + parameters={ + "image": sfn.JsonPath.string_at("$.image"), + "handler": sfn.JsonPath.string_at("$.handler"), + "payload": sfn.JsonPath.string_at("$.payload"), + "compute": sfn.JsonPath.json_merge( + sfn.JsonPath.object_at("$.default"), sfn.JsonPath.object_at("$.input") + ), + }, + ) + + batch_invoked_lambda = BatchInvokedLambdaFunction( + self, + "Data Chain", + env_base=self.env_base, + image="$.image", + name="run-lambda-function", + handler="$.handler", + payload_path="$.payload", + bucket_name=self.primary_bucket.bucket_name, + job_queue=self.batch_environment.job_queue_name, + environment={ + EnvBase.ENV_BASE_KEY: self.env_base, + "AWS_REGION": self.aws_region, + "AWS_ACCOUNT_ID": self.aws_account, + }, + memory=sfn.JsonPath.string_at("$.compute.memory"), + vcpus=sfn.JsonPath.string_at("$.compute.vcpus"), + mount_points=sfn.JsonPath.string_at("$.compute.mount_points"), + volumes=sfn.JsonPath.string_at("$.compute.volumes"), + platform_capabilities=sfn.JsonPath.string_at("$.compute.platform_capabilities"), + ) + + # fmt: off + definition = ( + start + .next(merge) + .next(batch_invoked_lambda.to_single_state()) + ) + # fmt: on + + self.batch_invoked_lambda_state_machine = create_state_machine( + self, + env_base=self.env_base, + name=self.env_base.get_state_machine_name("batch-invoked-lambda"), + definition=definition, + role=iam.Role( + self, + self.env_base.get_construct_id("batch-invoked-lambda", "role"), + assumed_by=iam.ServicePrincipal("states.amazonaws.com"), # type: ignore + managed_policies=[ + iam.ManagedPolicy.from_aws_managed_policy_name( + "AmazonAPIGatewayInvokeFullAccess" + ), + iam.ManagedPolicy.from_aws_managed_policy_name("AWSStepFunctionsFullAccess"), + iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchLogsFullAccess"), + iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchEventsFullAccess"), + ], + inline_policies={ + "default": iam.PolicyDocument( + statements=[ + batch_policy_statement(self.env_base), + s3_policy_statement(self.env_base), + ] + ), + }, + ), + ) diff --git a/src/aibs_informatics_cdk_lib/stacks/core.py b/src/aibs_informatics_cdk_lib/stacks/core.py deleted file mode 100644 index a3d3114..0000000 --- a/src/aibs_informatics_cdk_lib/stacks/core.py +++ /dev/null @@ -1,332 +0,0 @@ -import json -from email.policy import default -from typing import Any, Iterable, List, Optional, TypeVar, Union -from urllib import request - -import aws_cdk as cdk -from aibs_informatics_aws_utils.batch import to_mount_point, to_volume -from aibs_informatics_aws_utils.constants.efs import ( - EFS_ROOT_ACCESS_POINT_TAG, - EFS_ROOT_PATH, - EFS_SCRATCH_ACCESS_POINT_TAG, - EFS_SCRATCH_PATH, - EFS_SHARED_ACCESS_POINT_TAG, - EFS_SHARED_PATH, - EFS_TMP_ACCESS_POINT_TAG, - EFS_TMP_PATH, -) -from aibs_informatics_core.env import EnvBase -from aibs_informatics_core.utils.tools.dicttools import convert_key_case -from aibs_informatics_core.utils.tools.strtools import pascalcase -from aws_cdk import aws_batch_alpha as batch -from aws_cdk import aws_ec2 as ec2 -from aws_cdk import aws_efs as efs -from aws_cdk import aws_iam as iam -from aws_cdk import aws_lambda as lambda_ -from aws_cdk import aws_logs as logs -from aws_cdk import aws_s3 as s3 -from aws_cdk import aws_stepfunctions as sfn -from constructs import Construct - -from aibs_informatics_cdk_lib.common.aws.iam_utils import 
batch_policy_statement -from aibs_informatics_cdk_lib.constructs_.batch.infrastructure import Batch, BatchEnvironmentConfig -from aibs_informatics_cdk_lib.constructs_.batch.instance_types import ( - ON_DEMAND_INSTANCE_TYPES, - SPOT_INSTANCE_TYPES, -) -from aibs_informatics_cdk_lib.constructs_.batch.launch_template import BatchLaunchTemplateBuilder -from aibs_informatics_cdk_lib.constructs_.batch.types import BatchEnvironmentDescriptor -from aibs_informatics_cdk_lib.constructs_.ec2 import EnvBaseVpc -from aibs_informatics_cdk_lib.constructs_.efs.file_system import ( - EFSEcosystem, - EnvBaseFileSystem, - MountPointConfiguration, -) -from aibs_informatics_cdk_lib.constructs_.s3 import EnvBaseBucket, LifecycleRuleGenerator -from aibs_informatics_cdk_lib.constructs_.sfn.fragments.batch import SubmitJobFragment -from aibs_informatics_cdk_lib.stacks.base import EnvBaseStack - - -class NetworkStack(EnvBaseStack): - def __init__(self, scope: Construct, id: Optional[str], env_base: EnvBase, **kwargs) -> None: - super().__init__(scope, id, env_base, **kwargs) - self._vpc = EnvBaseVpc(self, "Vpc", self.env_base, max_azs=4) - - @property - def vpc(self) -> EnvBaseVpc: - return self._vpc - - -class StorageStack(EnvBaseStack): - def __init__( - self, - scope: Construct, - id: Optional[str], - env_base: EnvBase, - name: str, - vpc: ec2.Vpc, - **kwargs, - ) -> None: - super().__init__(scope, id, env_base, **kwargs) - - self._bucket = EnvBaseBucket( - self, - "Bucket", - self.env_base, - bucket_name=name, - removal_policy=self.removal_policy, - lifecycle_rules=[ - LifecycleRuleGenerator.expire_files_under_prefix(), - LifecycleRuleGenerator.expire_files_with_scratch_tags(), - LifecycleRuleGenerator.use_storage_class_as_default(), - ], - ) - - self._efs_ecosystem = EFSEcosystem(self, "EFS", self.env_base, name, vpc=vpc) - self._file_system = self._efs_ecosystem.file_system - - @property - def bucket(self) -> EnvBaseBucket: - return self._bucket - - @property - def file_system(self) -> EnvBaseFileSystem: - return self._file_system - - -class ComputeStack(EnvBaseStack): - def __init__( - self, - scope: Construct, - id: Optional[str], - env_base: EnvBase, - vpc: ec2.Vpc, - buckets: Optional[Iterable[s3.Bucket]] = None, - file_systems: Optional[Iterable[Union[efs.FileSystem, efs.IFileSystem]]] = None, - mount_point_configs: Optional[Iterable[MountPointConfiguration]] = None, - create_state_machine: bool = True, - state_machine_name: Optional[str] = "submit-job", - **kwargs, - ) -> None: - super().__init__(scope, id, env_base, **kwargs) - - self.batch = Batch(self, "Batch", self.env_base, vpc=vpc) - - self.create_batch_environments() - - bucket_list = list(buckets or []) - - file_system_list = list(file_systems or []) - - if mount_point_configs: - mount_point_config_list = list(mount_point_configs) - file_system_list = self._update_file_systems_from_mount_point_configs( - file_system_list, mount_point_config_list - ) - else: - mount_point_config_list = self._get_mount_point_configs(file_system_list) - - # Validation to ensure that the file systems are not duplicated - self._validate_mount_point_configs(mount_point_config_list) - - self.grant_storage_access(*bucket_list, *file_system_list) - - self.create_step_functions( - name=state_machine_name, mount_point_configs=mount_point_config_list - ) - - self.export_values() - - def grant_storage_access(self, *resources: Union[s3.Bucket, efs.FileSystem, efs.IFileSystem]): - self.batch.grant_instance_role_permissions(read_write_resources=list(resources)) - - for 
batch_environment in self.batch.environments: - for resource in resources: - if isinstance(resource, efs.FileSystem): - batch_environment.grant_file_system_access(resource) - - def create_batch_environments(self): - lt_builder = BatchLaunchTemplateBuilder(self, "lt-builder", env_base=self.env_base) - self.on_demand_batch_environment = self.batch.setup_batch_environment( - descriptor=BatchEnvironmentDescriptor("on-demand"), - config=BatchEnvironmentConfig( - allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, - instance_types=list(map(ec2.InstanceType, ON_DEMAND_INSTANCE_TYPES)), - use_spot=False, - use_fargate=False, - use_public_subnets=False, - ), - launch_template_builder=lt_builder, - ) - - self.spot_batch_environment = self.batch.setup_batch_environment( - descriptor=BatchEnvironmentDescriptor("spot"), - config=BatchEnvironmentConfig( - allocation_strategy=batch.AllocationStrategy.BEST_FIT_PROGRESSIVE, - instance_types=list(map(ec2.InstanceType, SPOT_INSTANCE_TYPES)), - use_spot=True, - use_fargate=False, - use_public_subnets=False, - ), - launch_template_builder=lt_builder, - ) - - self.fargate_batch_environment = self.batch.setup_batch_environment( - descriptor=BatchEnvironmentDescriptor("fargate"), - config=BatchEnvironmentConfig( - allocation_strategy=None, - instance_types=None, - use_spot=False, - use_fargate=True, - use_public_subnets=False, - ), - launch_template_builder=lt_builder, - ) - - def create_step_functions( - self, - name: Optional[str] = None, - mount_point_configs: Optional[list[MountPointConfiguration]] = None, - ): - state_machine_core_name = name or "submit-job" - defaults: dict[str, Any] = {} - defaults["command"] = [] - defaults["job_queue"] = self.on_demand_batch_environment.job_queue.job_queue_arn - defaults["environment"] = [] - defaults["memory"] = "1024" - defaults["vcpus"] = "1" - defaults["gpu"] = "0" - defaults["platform_capabilities"] = ["EC2"] - - if mount_point_configs: - defaults["mount_points"] = [ - convert_key_case(mpc.to_batch_mount_point(f"efs-vol{i}"), pascalcase) - for i, mpc in enumerate(mount_point_configs) - ] - defaults["volumes"] = [ - convert_key_case(mpc.to_batch_volume(f"efs-vol{i}"), pascalcase) - for i, mpc in enumerate(mount_point_configs) - ] - - start = sfn.Pass( - self, - "Start", - parameters={ - "input": sfn.JsonPath.string_at("$"), - "default": defaults, - }, - ) - - merge = sfn.Pass( - self, - "Merge", - parameters={ - "request": sfn.JsonPath.json_merge( - sfn.JsonPath.object_at("$.default"), sfn.JsonPath.object_at("$.input") - ), - }, - ) - - submit_job = SubmitJobFragment( - self, - "SubmitJob", - env_base=self.env_base, - name="SubmitJobCore", - image=sfn.JsonPath.string_at("$.request.image"), - command=sfn.JsonPath.string_at("$.request.command"), - job_queue=sfn.JsonPath.string_at("$.request.job_queue"), - environment=sfn.JsonPath.string_at("$.request.environment"), - memory=sfn.JsonPath.string_at("$.request.memory"), - vcpus=sfn.JsonPath.string_at("$.request.vcpus"), - # TODO: Handle GPU parameter better - right now, we cannot handle cases where it is - # not specified. Setting to zero causes issues with the Batch API. - # If it is set to zero, then the json list of resources are not properly set. 
- # gpu=sfn.JsonPath.string_at("$.request.gpu"), - mount_points=sfn.JsonPath.string_at("$.request.mount_points"), - volumes=sfn.JsonPath.string_at("$.request.volumes"), - platform_capabilities=sfn.JsonPath.string_at("$.request.platform_capabilities"), - ).to_single_state() - - definition = start.next(merge).next(submit_job) - - state_machine_name = self.get_resource_name(state_machine_core_name) - self.batch_submit_job_state_machine = sfn.StateMachine( - self, - self.env_base.get_construct_id(state_machine_name, "state-machine"), - state_machine_name=state_machine_name, - logs=sfn.LogOptions( - destination=logs.LogGroup( - self, - self.get_construct_id(state_machine_name, "state-loggroup"), - log_group_name=self.env_base.get_state_machine_log_group_name("submit-job"), - removal_policy=cdk.RemovalPolicy.DESTROY, - retention=logs.RetentionDays.ONE_MONTH, - ) - ), - role=iam.Role( - self, - self.env_base.get_construct_id(state_machine_name, "role"), - assumed_by=iam.ServicePrincipal("states.amazonaws.com"), # type: ignore - managed_policies=[ - iam.ManagedPolicy.from_aws_managed_policy_name( - "AmazonAPIGatewayInvokeFullAccess" - ), - iam.ManagedPolicy.from_aws_managed_policy_name("AWSStepFunctionsFullAccess"), - iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchLogsFullAccess"), - iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchEventsFullAccess"), - ], - inline_policies={ - "default": iam.PolicyDocument( - statements=[batch_policy_statement(self.env_base)] - ), - }, - ), - definition_body=sfn.DefinitionBody.from_chainable(definition), - ) - - def export_values(self) -> None: - self.export_value(self.on_demand_batch_environment.job_queue.job_queue_arn) - self.export_value(self.spot_batch_environment.job_queue.job_queue_arn) - self.export_value(self.fargate_batch_environment.job_queue.job_queue_arn) - - ## Private methods - - def _validate_mount_point_configs(self, mount_point_configs: List[MountPointConfiguration]): - _ = {} - for mpc in mount_point_configs: - if mpc.mount_point in _ and _[mpc.mount_point] != mpc: - raise ValueError( - f"Mount point {mpc.mount_point} is duplicated. " - "Cannot have multiple mount points configurations with the same name." - ) - _[mpc.mount_point] = mpc - - def _get_mount_point_configs( - self, file_systems: Optional[List[Union[efs.FileSystem, efs.IFileSystem]]] - ) -> List[MountPointConfiguration]: - mount_point_configs = [] - if file_systems: - for fs in file_systems: - mount_point_configs.append(MountPointConfiguration.from_file_system(fs)) - return mount_point_configs - - def _update_file_systems_from_mount_point_configs( - self, - file_systems: List[Union[efs.FileSystem, efs.IFileSystem]], - mount_point_configs: List[MountPointConfiguration], - ) -> List[Union[efs.FileSystem, efs.IFileSystem]]: - file_system_map: dict[str, Union[efs.FileSystem, efs.IFileSystem]] = { - fs.file_system_id: fs for fs in file_systems - } - for mpc in mount_point_configs: - if mpc.file_system_id not in file_system_map: - if not mpc.file_system and mpc.access_point: - file_system_map[mpc.file_system_id] = mpc.access_point.file_system - elif mpc.file_system: - file_system_map[mpc.file_system_id] = mpc.file_system - else: - raise ValueError( - "Mount point configuration must have a file system or access point." 
- ) - - return list(file_system_map.values()) diff --git a/src/aibs_informatics_cdk_lib/stacks/network.py b/src/aibs_informatics_cdk_lib/stacks/network.py new file mode 100644 index 0000000..ebb97b2 --- /dev/null +++ b/src/aibs_informatics_cdk_lib/stacks/network.py @@ -0,0 +1,17 @@ +from typing import Optional + +from aibs_informatics_core.env import EnvBase +from constructs import Construct + +from aibs_informatics_cdk_lib.constructs_.ec2 import EnvBaseVpc +from aibs_informatics_cdk_lib.stacks.base import EnvBaseStack + + +class NetworkStack(EnvBaseStack): + def __init__(self, scope: Construct, id: Optional[str], env_base: EnvBase, **kwargs) -> None: + super().__init__(scope, id, env_base, **kwargs) + self._vpc = EnvBaseVpc(self, "Vpc", self.env_base, max_azs=4) + + @property + def vpc(self) -> EnvBaseVpc: + return self._vpc diff --git a/src/aibs_informatics_cdk_lib/stacks/storage.py b/src/aibs_informatics_cdk_lib/stacks/storage.py new file mode 100644 index 0000000..3cdb36c --- /dev/null +++ b/src/aibs_informatics_cdk_lib/stacks/storage.py @@ -0,0 +1,96 @@ +import json +from email.policy import default +from typing import Any, Iterable, List, Optional, TypeVar, Union +from urllib import request + +import aws_cdk as cdk +from aibs_informatics_aws_utils.batch import to_mount_point, to_volume +from aibs_informatics_aws_utils.constants.efs import ( + EFS_ROOT_ACCESS_POINT_TAG, + EFS_ROOT_PATH, + EFS_SCRATCH_ACCESS_POINT_TAG, + EFS_SCRATCH_PATH, + EFS_SHARED_ACCESS_POINT_TAG, + EFS_SHARED_PATH, + EFS_TMP_ACCESS_POINT_TAG, + EFS_TMP_PATH, +) +from aibs_informatics_core.env import EnvBase +from aibs_informatics_core.utils.tools.dicttools import convert_key_case +from aibs_informatics_core.utils.tools.strtools import pascalcase +from aws_cdk import aws_batch_alpha as batch +from aws_cdk import aws_ec2 as ec2 +from aws_cdk import aws_efs as efs +from aws_cdk import aws_iam as iam +from aws_cdk import aws_lambda as lambda_ +from aws_cdk import aws_logs as logs +from aws_cdk import aws_s3 as s3 +from aws_cdk import aws_stepfunctions as sfn +from constructs import Construct + +from aibs_informatics_cdk_lib.common.aws.iam_utils import batch_policy_statement +from aibs_informatics_cdk_lib.constructs_.assets.code_asset_definitions import ( + AIBSInformaticsCodeAssets, +) +from aibs_informatics_cdk_lib.constructs_.batch.infrastructure import Batch, BatchEnvironmentConfig +from aibs_informatics_cdk_lib.constructs_.batch.instance_types import ( + ON_DEMAND_INSTANCE_TYPES, + SPOT_INSTANCE_TYPES, +) +from aibs_informatics_cdk_lib.constructs_.batch.launch_template import BatchLaunchTemplateBuilder +from aibs_informatics_cdk_lib.constructs_.batch.types import BatchEnvironmentDescriptor +from aibs_informatics_cdk_lib.constructs_.ec2 import EnvBaseVpc +from aibs_informatics_cdk_lib.constructs_.efs.file_system import ( + EFSEcosystem, + EnvBaseFileSystem, + MountPointConfiguration, +) +from aibs_informatics_cdk_lib.constructs_.s3 import EnvBaseBucket, LifecycleRuleGenerator +from aibs_informatics_cdk_lib.constructs_.sfn.fragments.batch import ( + SubmitJobFragment, + SubmitJobWithDefaultsFragment, +) +from aibs_informatics_cdk_lib.stacks.base import EnvBaseStack + + +class StorageStack(EnvBaseStack): + def __init__( + self, + scope: Construct, + id: Optional[str], + env_base: EnvBase, + name: str, + vpc: ec2.Vpc, + **kwargs, + ) -> None: + super().__init__(scope, id, env_base, **kwargs) + + self._bucket = EnvBaseBucket( + self, + "Bucket", + self.env_base, + bucket_name=name, + 
removal_policy=self.removal_policy, + lifecycle_rules=[ + LifecycleRuleGenerator.expire_files_under_prefix(), + LifecycleRuleGenerator.expire_files_with_scratch_tags(), + LifecycleRuleGenerator.use_storage_class_as_default(), + ], + ) + + self._efs_ecosystem = EFSEcosystem( + self, id="EFS", env_base=self.env_base, file_system_name=name, vpc=vpc + ) + self._file_system = self._efs_ecosystem.file_system + + @property + def bucket(self) -> EnvBaseBucket: + return self._bucket + + @property + def efs_ecosystem(self) -> EFSEcosystem: + return self._efs_ecosystem + + @property + def file_system(self) -> EnvBaseFileSystem: + return self._file_system diff --git a/src/aibs_informatics_core_app/app.py b/src/aibs_informatics_core_app/app.py index 4eb3f55..c57f545 100644 --- a/src/aibs_informatics_core_app/app.py +++ b/src/aibs_informatics_core_app/app.py @@ -6,49 +6,98 @@ import aws_cdk as cdk from constructs import Construct +from aibs_informatics_cdk_lib.constructs_.efs.file_system import MountPointConfiguration from aibs_informatics_cdk_lib.project.config import StageConfig -from aibs_informatics_cdk_lib.project.utils import get_config, resolve_repo_root -from aibs_informatics_cdk_lib.stacks import data_sync -from aibs_informatics_cdk_lib.stacks.core import ComputeStack, NetworkStack, StorageStack -from aibs_informatics_cdk_lib.stacks.data_sync import DataSyncStack +from aibs_informatics_cdk_lib.project.utils import get_config +from aibs_informatics_cdk_lib.stacks.assets import AIBSInformaticsAssetsStack +from aibs_informatics_cdk_lib.stacks.compute import ( + ComputeStack, + ComputeWorkflowStack, + LambdaComputeStack, +) +from aibs_informatics_cdk_lib.stacks.network import NetworkStack +from aibs_informatics_cdk_lib.stacks.storage import StorageStack from aibs_informatics_cdk_lib.stages.base import ConfigBasedStage class InfraStage(ConfigBasedStage): def __init__(self, scope: Construct, config: StageConfig, **kwargs) -> None: super().__init__(scope, "Infra", config, **kwargs) - network = NetworkStack(self, "Network", self.env_base, env=self.env) + assets = AIBSInformaticsAssetsStack( + self, + self.get_stack_name("Assets"), + self.env_base, + env=self.env, + ) + network = NetworkStack( + self, + self.get_stack_name("Network"), + self.env_base, + env=self.env, + ) storage = StorageStack( - self, "Storage", self.env_base, "core", vpc=network.vpc, env=self.env + self, + self.get_stack_name("Storage"), + self.env_base, + "core", + vpc=network.vpc, + env=self.env, ) + + efs_ecosystem = storage.efs_ecosystem + + ap_mount_point_configs = [ + MountPointConfiguration.from_access_point( + efs_ecosystem.scratch_access_point, "/opt/scratch" + ), + MountPointConfiguration.from_access_point(efs_ecosystem.tmp_access_point, "/opt/tmp"), + MountPointConfiguration.from_access_point( + efs_ecosystem.shared_access_point, "/opt/shared", read_only=True + ), + ] + fs_mount_point_configs = [ + MountPointConfiguration.from_file_system(storage.file_system, None, "/opt/efs"), + ] + compute = ComputeStack( self, - "Compute", + self.get_stack_name("Compute"), self.env_base, + batch_name="batch", vpc=network.vpc, buckets=[storage.bucket], file_systems=[storage.file_system], + mount_point_configs=fs_mount_point_configs, env=self.env, ) - # data_sync = DataSyncStack( - # self, - # "DataSync", - # self.env_base, - # asset_directory=Path(resolve_repo_root()) / ".." 
/ "aibs-informatics-aws-lambda", - # vpc=network.vpc, - # primary_bucket=storage.bucket, - # s3_buckets=[], - # file_system=storage.file_system, - # batch_job_queue=compute.on_demand_batch_environment.job_queue, - # env=self.env, - # ) + lambda_compute = LambdaComputeStack( + self, + self.get_stack_name("LambdaCompute"), + self.env_base, + batch_name="lambda-batch", + vpc=network.vpc, + buckets=[storage.bucket], + file_systems=[storage.file_system], + mount_point_configs=fs_mount_point_configs, + env=self.env, + ) + + compute_workflow = ComputeWorkflowStack( + self, + self.get_stack_name("ComputeWorkflow"), + env_base=self.env_base, + batch_environment=lambda_compute.primary_batch_environment, + primary_bucket=storage.bucket, + mount_point_configs=fs_mount_point_configs, + env=self.env, + ) def main(): app = cdk.App() - config = get_config(app.node) + config: StageConfig = get_config(app.node) InfraStage(app, config)