diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/__init__.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/__init__.py index ecd71ff..003aa5e 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/__init__.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/__init__.py @@ -2,6 +2,7 @@ "BatchInvokedExecutorFragment", "BatchInvokedLambdaFunction", "DataSyncFragment", + "DistributedDataSyncFragment", "DemandExecutionFragment", "CleanFileSystemFragment", "CleanFileSystemTriggerConfig", @@ -13,6 +14,7 @@ ) from aibs_informatics_cdk_lib.constructs_.sfn.fragments.informatics.data_sync import ( DataSyncFragment, + DistributedDataSyncFragment, ) from aibs_informatics_cdk_lib.constructs_.sfn.fragments.informatics.demand_execution import ( DemandExecutionFragment, diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/data_sync.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/data_sync.py index fd4f0ac..e4c5528 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/data_sync.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/data_sync.py @@ -7,6 +7,7 @@ from aws_cdk import aws_iam as iam from aws_cdk import aws_s3 as s3 from aws_cdk import aws_stepfunctions as sfn +from aws_cdk import aws_stepfunctions_tasks as sfn_tasks from aibs_informatics_cdk_lib.common.aws.iam_utils import ( SFN_STATES_EXECUTION_ACTIONS, @@ -15,6 +16,7 @@ ) from aibs_informatics_cdk_lib.constructs_.base import EnvBaseConstructMixins from aibs_informatics_cdk_lib.constructs_.efs.file_system import MountPointConfiguration +from aibs_informatics_cdk_lib.constructs_.sfn.fragments.base import EnvBaseStateMachineFragment from aibs_informatics_cdk_lib.constructs_.sfn.fragments.informatics.batch import ( BatchInvokedBaseFragment, BatchInvokedLambdaFunction, @@ -110,3 +112,88 @@ def 
class DistributedDataSyncFragment(BatchInvokedBaseFragment):
    """State machine fragment that prepares and then fans out batch data sync requests.

    The fragment chains three states:

    1. A ``Pass`` state that wraps the fragment input under ``$.request``.
    2. A batch-invoked lambda function (``prepare_batch_data_sync_handler``) whose
       result is stored at ``$.tasks.prep-batch-data-sync-requests.response``.
    3. A ``Map`` state that iterates over ``...response.requests`` and runs a
       batch-invoked lambda function (``batch_data_sync_handler``) for each item.
       The map result is discarded.
    """

    def __init__(
        self,
        scope: constructs.Construct,
        id: str,
        env_base: EnvBase,
        aibs_informatics_docker_asset: Union[ecr_assets.DockerImageAsset, str],
        batch_job_queue: Union[batch.JobQueue, str],
        scaffolding_bucket: s3.Bucket,
        mount_point_configs: Optional[Iterable[MountPointConfiguration]] = None,
    ) -> None:
        """Build the fragment definition.

        Args:
            scope: Construct scope; forwarded to the base class and to both
                batch-invoked lambda function constructs.
            id: Construct id; also used as the prefix of every state name.
            env_base: Environment base forwarded to the base class and to both
                batch-invoked lambda function constructs.
            aibs_informatics_docker_asset: Either an image URI string, or a
                docker image asset whose ``image_uri`` is used.
            batch_job_queue: Either a job queue name string, or a job queue
                whose ``job_queue_name`` is used.
            scaffolding_bucket: Bucket whose ``bucket_name`` is passed to both
                batch-invoked lambda functions.
            mount_point_configs: Optional mount point configurations forwarded
                to both batch-invoked lambda functions. Any iterable is
                accepted, including one-shot iterators.
        """
        super().__init__(scope, id, env_base)

        # Resolve the str-or-construct arguments once; both functions share them.
        image_uri = (
            aibs_informatics_docker_asset
            if isinstance(aibs_informatics_docker_asset, str)
            else aibs_informatics_docker_asset.image_uri
        )
        job_queue_name = (
            batch_job_queue
            if isinstance(batch_job_queue, str)
            else batch_job_queue.job_queue_name
        )

        # Materialize the iterable exactly once. Converting it separately for each
        # function would exhaust a one-shot iterator (e.g. a generator) on the first
        # conversion and silently hand the second function an empty list.
        resolved_mount_points = list(mount_point_configs) if mount_point_configs else None

        def _mount_points_copy() -> Optional[List[MountPointConfiguration]]:
            # Each construct gets its own list object, as before.
            if resolved_mount_points is None:
                return None
            return list(resolved_mount_points)

        start_pass_state = sfn.Pass(
            self,
            f"{id}: Start",
            parameters={
                "request": sfn.JsonPath.object_at("$"),
            },
        )
        prep_batch_sync_task_name = "prep-batch-data-sync-requests"

        prep_batch_sync = BatchInvokedLambdaFunction(
            scope=scope,
            id=f"{id}: Prep Batch Data Sync",
            env_base=env_base,
            name=prep_batch_sync_task_name,
            payload_path="$.request",
            image=image_uri,
            handler="aibs_informatics_aws_lambda.handlers.data_sync.prepare_batch_data_sync_handler",
            job_queue=job_queue_name,
            bucket_name=scaffolding_bucket.bucket_name,
            memory=1024,
            vcpus=1,
            mount_point_configs=_mount_points_copy(),
        ).enclose(result_path=f"$.tasks.{prep_batch_sync_task_name}.response")

        batch_sync_map_state = sfn.Map(
            self,
            f"{id}: Batch Data Sync: Map Start",
            comment="Runs requests for batch sync in parallel",
            items_path=f"$.tasks.{prep_batch_sync_task_name}.response.requests",
            result_path=sfn.JsonPath.DISCARD,
        )

        # NOTE(review): the iterator's payload is read from "$.requests" of each map
        # item; confirm that every element of "...response.requests" actually carries
        # a "requests" key (otherwise "$" may have been intended).
        batch_sync_map_state.iterator(
            BatchInvokedLambdaFunction(
                scope=scope,
                id=f"{id}: Batch Data Sync",
                env_base=env_base,
                name="batch-data-sync",
                payload_path="$.requests",
                image=image_uri,
                handler="aibs_informatics_aws_lambda.handlers.data_sync.batch_data_sync_handler",
                job_queue=job_queue_name,
                bucket_name=scaffolding_bucket.bucket_name,
                memory=2048,
                vcpus=1,
                mount_point_configs=_mount_points_copy(),
            )
        )
        # fmt: off
        self.definition = (
            start_pass_state
            .next(prep_batch_sync)
            .next(batch_sync_map_state)
        )
        # fmt: on
shared_mount_point_config or scratch_mount_point_config or tmp_mount_point_config: mount_points = [] volumes = [] if shared_mount_point_config: @@ -104,18 +107,34 @@ def __init__( volumes.append( scratch_mount_point_config.to_batch_volume("scratch", sfn_format=True) ) + if tmp_mount_point_config: + # update file system configurations for scaffolding function + file_system_configurations["tmp"] = { + "file_system": tmp_mount_point_config.file_system_id, + "access_point": tmp_mount_point_config.access_point_id, + "container_path": tmp_mount_point_config.mount_point, + } + # add to mount point and volumes list for batch invoked lambda functions + mount_points.append( + tmp_mount_point_config.to_batch_mount_point("tmp", sfn_format=True) + ) + volumes.append(tmp_mount_point_config.to_batch_volume("tmp", sfn_format=True)) batch_invoked_lambda_kwargs["mount_points"] = mount_points batch_invoked_lambda_kwargs["volumes"] = volumes + request = { + "demand_execution": sfn.JsonPath.object_at("$"), + "file_system_configurations": file_system_configurations, + } + if context_manager_configuration: + request["context_manager_configuration"] = context_manager_configuration + start_state = sfn.Pass( self, f"Start Demand Batch Task", parameters={ - "request": { - "demand_execution": sfn.JsonPath.object_at("$"), - "file_system_configurations": file_system_configurations, - } + "request": request, }, )