diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/batch.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/batch.py index e169e31..dd2e379 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/batch.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/batch.py @@ -140,6 +140,7 @@ def from_defaults( defaults["job_queue"] = job_queue defaults["environment"] = environment or {} defaults["memory"] = memory + defaults["image"] = image defaults["vcpus"] = vcpus defaults["gpu"] = "0" defaults["platform_capabilities"] = ["EC2"] @@ -153,7 +154,7 @@ def from_defaults( scope, id, env_base=env_base, - name="SubmitJobCore", + name=name, image=sfn.JsonPath.string_at("$.request.image"), command=sfn.JsonPath.string_at("$.request.command"), job_queue=sfn.JsonPath.string_at("$.request.job_queue"), diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/batch.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/batch.py index 23d5da5..f3a5829 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/batch.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/batch.py @@ -1,4 +1,5 @@ from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Mapping, Optional, Sequence, Union +from unittest import result import constructs from aibs_informatics_aws_utils.constants.lambda_ import ( @@ -116,11 +117,35 @@ def __init__( """ super().__init__(scope, id, env_base) key_prefix = key_prefix or S3_SCRATCH_KEY_PREFIX + request_key = sfn.JsonPath.format( - f"{key_prefix}{{}}/request.json", sfn.JsonPath.execution_name + f"{key_prefix}{{}}/{{}}/request.json", + sfn.JsonPath.execution_name, + sfn.JsonPath.string_at("$.taskResult.prep.task_id"), ) response_key = sfn.JsonPath.format( - f"{key_prefix}{{}}/response.json", sfn.JsonPath.execution_name + f"{key_prefix}{{}}/{{}}/response.json", + sfn.JsonPath.execution_name, + sfn.JsonPath.string_at("$.taskResult.prep.task_id"), + ) + + start = sfn.Pass( + self, + f"{id} Prep S3 Keys", + parameters={ + "task_id": sfn.JsonPath.uuid(), + # "requestKey": sfn.JsonPath.format( + # f"{key_prefix}{{}}/{{}}/request.json", + # sfn.JsonPath.execution_name, + # sfn.JsonPath.uuid(), + # ), + # "responseKey": sfn.JsonPath.format( + # f"{key_prefix}{{}}/{{}}/response.json", + # sfn.JsonPath.execution_name, + # sfn.JsonPath.uuid(), + # ), + }, + result_path="$.taskResult.prep", ) if mount_point_configs: @@ -146,7 +171,9 @@ def __init__( sfn.JsonPath.string_at("$.taskResult.put.Key"), ), AWS_LAMBDA_EVENT_RESPONSE_LOCATION_KEY: sfn.JsonPath.format( - "s3://{}/{}", bucket_name, response_key + "s3://{}/{}", + bucket_name, + response_key, ), EnvBase.ENV_BASE_KEY: self.env_base, "AWS_REGION": self.aws_region, @@ -182,7 +209,7 @@ def __init__( output_path="$[0]", ) - self.definition = put_payload.next(submit_job).next(get_response) + self.definition = start.next(put_payload).next(submit_job).next(get_response) @property def start_state(self) -> sfn.State: @@ -334,11 +361,25 @@ def __init__( """ super().__init__(scope, id, env_base) key_prefix = key_prefix or S3_SCRATCH_KEY_PREFIX + request_key = sfn.JsonPath.format( - f"{key_prefix}{{}}/request.json", sfn.JsonPath.execution_name + f"{key_prefix}{{}}/{{}}/request.json", + sfn.JsonPath.execution_name, + sfn.JsonPath.string_at("$.taskResult.prep.task_id"), ) response_key = sfn.JsonPath.format( - f"{key_prefix}{{}}/response.json", sfn.JsonPath.execution_name + f"{key_prefix}{{}}/{{}}/response.json", + sfn.JsonPath.execution_name, + sfn.JsonPath.string_at("$.taskResult.prep.task_id"), + ) + + start = sfn.Pass( + self, + f"{id} Prep S3 Keys", + parameters={ + "task_id": sfn.JsonPath.uuid(), + }, + result_path="$.taskResult.prep", ) if mount_point_configs: @@ -368,7 +409,7 @@ def __init__( "--input", sfn.JsonPath.format("s3://{}/{}", "$.Bucket", "$.Key"), "--output-location", - sfn.JsonPath.format("s3://{}/{}", bucket_name, response_key), + sfn.JsonPath.format("s3://{}/{}", bucket_name, "$.t"), ], image=image, environment=environment, diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/data_sync.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/data_sync.py index e4c5528..57cfcfd 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/data_sync.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/data_sync.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Iterable, List, Optional, Union +from typing import Iterable, List, Optional, Union import constructs from aibs_informatics_core.env import EnvBase @@ -7,7 +7,6 @@ from aws_cdk import aws_iam as iam from aws_cdk import aws_s3 as s3 from aws_cdk import aws_stepfunctions as sfn -from aws_cdk import aws_stepfunctions_tasks as sfn_tasks from aibs_informatics_cdk_lib.common.aws.iam_utils import ( SFN_STATES_EXECUTION_ACTIONS, @@ -172,7 +171,7 @@ def __init__( id=f"{id}: Batch Data Sync", env_base=env_base, name="batch-data-sync", - payload_path="$.requests", + payload_path="$", image=( aibs_informatics_docker_asset if isinstance(aibs_informatics_docker_asset, str) diff --git a/src/aibs_informatics_cdk_lib/constructs_/sfn/states/batch.py b/src/aibs_informatics_cdk_lib/constructs_/sfn/states/batch.py index 0a938f7..dc76325 100644 --- a/src/aibs_informatics_cdk_lib/constructs_/sfn/states/batch.py +++ b/src/aibs_informatics_cdk_lib/constructs_/sfn/states/batch.py @@ -82,7 +82,10 @@ def register_job_definition( """ job_definition_name = sfn.JsonPath.format( - f"{{}}-{{}}", job_definition_name, sfn.JsonPath.execution_name + f"{{}}-{{}}-{{}}", + job_definition_name, + sfn.JsonPath.execution_name, + sfn.JsonPath.uuid(), ) if not isinstance(environment, str): environment_pairs = to_key_value_pairs(dict(environment or {})) @@ -151,7 +154,9 @@ def submit_job( result_path: Optional[str] = "$", output_path: Optional[str] = "$", ) -> sfn.Chain: - job_name = sfn.JsonPath.format(f"{job_name}-{{}}", sfn.JsonPath.execution_name) + job_name = sfn.JsonPath.format( + f"{job_name}-{{}}-{{}}", sfn.JsonPath.execution_name, sfn.JsonPath.uuid() + ) if not isinstance(environment, str): environment_pairs = to_key_value_pairs(dict(environment or {})) else: