Skip to content

Commit

Permalink
fix sfn batch and s3 states to handle map type
Browse files (browse the repository at this point in the history)
  • Loading branch information
rpmcginty committed Jun 26, 2024
1 parent 2b099e6 commit f41cc0d
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 13 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -140,6 +140,7 @@ def from_defaults(
defaults["job_queue"] = job_queue
defaults["environment"] = environment or {}
defaults["memory"] = memory
defaults["image"] = image
defaults["vcpus"] = vcpus
defaults["gpu"] = "0"
defaults["platform_capabilities"] = ["EC2"]
Expand All @@ -153,7 +154,7 @@ def from_defaults(
scope,
id,
env_base=env_base,
name="SubmitJobCore",
name=name,
image=sfn.JsonPath.string_at("$.request.image"),
command=sfn.JsonPath.string_at("$.request.command"),
job_queue=sfn.JsonPath.string_at("$.request.job_queue"),
Expand Down
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Mapping, Optional, Sequence, Union
from unittest import result

import constructs
from aibs_informatics_aws_utils.constants.lambda_ import (
Expand Down Expand Up @@ -116,11 +117,35 @@ def __init__(
"""
super().__init__(scope, id, env_base)
key_prefix = key_prefix or S3_SCRATCH_KEY_PREFIX

request_key = sfn.JsonPath.format(
f"{key_prefix}{{}}/request.json", sfn.JsonPath.execution_name
f"{key_prefix}{{}}/{{}}/request.json",
sfn.JsonPath.execution_name,
sfn.JsonPath.string_at("$.taskResult.prep.task_id"),
)
response_key = sfn.JsonPath.format(
f"{key_prefix}{{}}/response.json", sfn.JsonPath.execution_name
f"{key_prefix}{{}}/{{}}/response.json",
sfn.JsonPath.execution_name,
sfn.JsonPath.string_at("$.taskResult.prep.task_id"),
)

start = sfn.Pass(
self,
f"{id} Prep S3 Keys",
parameters={
"task_id": sfn.JsonPath.uuid(),
# "requestKey": sfn.JsonPath.format(
# f"{key_prefix}{{}}/{{}}/request.json",
# sfn.JsonPath.execution_name,
# sfn.JsonPath.uuid(),
# ),
# "responseKey": sfn.JsonPath.format(
# f"{key_prefix}{{}}/{{}}/response.json",
# sfn.JsonPath.execution_name,
# sfn.JsonPath.uuid(),
# ),
},
result_path="$.taskResult.prep",
)

if mount_point_configs:
Expand All @@ -146,7 +171,9 @@ def __init__(
sfn.JsonPath.string_at("$.taskResult.put.Key"),
),
AWS_LAMBDA_EVENT_RESPONSE_LOCATION_KEY: sfn.JsonPath.format(
"s3://{}/{}", bucket_name, response_key
"s3://{}/{}",
bucket_name,
response_key,
),
EnvBase.ENV_BASE_KEY: self.env_base,
"AWS_REGION": self.aws_region,
Expand Down Expand Up @@ -182,7 +209,7 @@ def __init__(
output_path="$[0]",
)

self.definition = put_payload.next(submit_job).next(get_response)
self.definition = start.next(put_payload).next(submit_job).next(get_response)

@property
def start_state(self) -> sfn.State:
Expand Down Expand Up @@ -334,11 +361,25 @@ def __init__(
"""
super().__init__(scope, id, env_base)
key_prefix = key_prefix or S3_SCRATCH_KEY_PREFIX

request_key = sfn.JsonPath.format(
f"{key_prefix}{{}}/request.json", sfn.JsonPath.execution_name
f"{key_prefix}{{}}/{{}}/request.json",
sfn.JsonPath.execution_name,
sfn.JsonPath.string_at("$.taskResult.prep.task_id"),
)
response_key = sfn.JsonPath.format(
f"{key_prefix}{{}}/response.json", sfn.JsonPath.execution_name
f"{key_prefix}{{}}/{{}}/response.json",
sfn.JsonPath.execution_name,
sfn.JsonPath.string_at("$.taskResult.prep.task_id"),
)

start = sfn.Pass(
self,
f"{id} Prep S3 Keys",
parameters={
"task_id": sfn.JsonPath.uuid(),
},
result_path="$.taskResult.prep",
)

if mount_point_configs:
Expand Down Expand Up @@ -368,7 +409,7 @@ def __init__(
"--input",
sfn.JsonPath.format("s3://{}/{}", "$.Bucket", "$.Key"),
"--output-location",
sfn.JsonPath.format("s3://{}/{}", bucket_name, response_key),
sfn.JsonPath.format("s3://{}/{}", bucket_name, "$.t"),
],
image=image,
environment=environment,
Expand Down
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
from typing import Iterable, List, Optional, Union

import constructs
from aibs_informatics_core.env import EnvBase
Expand All @@ -7,7 +7,6 @@
from aws_cdk import aws_iam as iam
from aws_cdk import aws_s3 as s3
from aws_cdk import aws_stepfunctions as sfn
from aws_cdk import aws_stepfunctions_tasks as sfn_tasks

from aibs_informatics_cdk_lib.common.aws.iam_utils import (
SFN_STATES_EXECUTION_ACTIONS,
Expand Down Expand Up @@ -172,7 +171,7 @@ def __init__(
id=f"{id}: Batch Data Sync",
env_base=env_base,
name="batch-data-sync",
payload_path="$.requests",
payload_path="$",
image=(
aibs_informatics_docker_asset
if isinstance(aibs_informatics_docker_asset, str)
Expand Down
9 changes: 7 additions & 2 deletions src/aibs_informatics_cdk_lib/constructs_/sfn/states/batch.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -82,7 +82,10 @@ def register_job_definition(
"""

job_definition_name = sfn.JsonPath.format(
f"{{}}-{{}}", job_definition_name, sfn.JsonPath.execution_name
f"{{}}-{{}}-{{}}",
job_definition_name,
sfn.JsonPath.execution_name,
sfn.JsonPath.uuid(),
)
if not isinstance(environment, str):
environment_pairs = to_key_value_pairs(dict(environment or {}))
Expand Down Expand Up @@ -151,7 +154,9 @@ def submit_job(
result_path: Optional[str] = "$",
output_path: Optional[str] = "$",
) -> sfn.Chain:
job_name = sfn.JsonPath.format(f"{job_name}-{{}}", sfn.JsonPath.execution_name)
job_name = sfn.JsonPath.format(
f"{job_name}-{{}}-{{}}", sfn.JsonPath.execution_name, sfn.JsonPath.uuid()
)
if not isinstance(environment, str):
environment_pairs = to_key_value_pairs(dict(environment or {}))
else:
Expand Down

0 comments on commit f41cc0d

Please sign in to comment.