Skip to content

Commit

Permalink
Merge pull request #7 from AllenInstitute/feature/asset-and-sfn-changes
Browse files Browse the repository at this point in the history
Base changes, updates to Core stacks, Sfn SM utils
  • Loading branch information
rpmcginty authored May 22, 2024
2 parents 732c1a9 + 4125d4a commit bb40e93
Show file tree
Hide file tree
Showing 19 changed files with 1,021 additions and 455 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ no_site_packages = true
# Untyped definitions and calls
# https://mypy.readthedocs.io/en/stable/config_file.html#untyped-definitions-and-calls
# TODO: enable and fix errors
check_untyped_defs = false
check_untyped_defs = true


# Miscellaneous strictness flags
Expand Down
8 changes: 6 additions & 2 deletions src/aibs_informatics_cdk_lib/common/aws/iam_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
build_sfn_arn,
)

BATCH_READ_ONLY_ACTIONS = [
"batch:Describe*",
"batch:List*",
]

BATCH_FULL_ACCESS_ACTIONS = [
"batch:RegisterJobDefinition",
"batch:DeregisterJobDefinition",
"batch:DescribeJobDefinitions",
"batch:List*",
"batch:Describe*",
*BATCH_READ_ONLY_ACTIONS,
"batch:*",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def create_py_code_asset(
cls,
path: Path,
context_path: Optional[Path],
requirements_file: Optional[Path] = None,
includes: Optional[Sequence[str]] = None,
excludes: Optional[Sequence[str]] = None,
runtime: lambda_.Runtime = lambda_.Runtime.PYTHON_3_11,
Expand Down Expand Up @@ -219,7 +220,7 @@ def create_py_code_asset(
"ssh -vT [email protected] || true",
# Must make sure that the package is not installing using --editable mode
"python3 -m pip install --upgrade pip --no-cache",
"pip3 install . --no-cache -t /asset-output",
f"pip3 install {'-r ' + requirements_file.as_posix() if requirements_file else '.'} --no-cache -t /asset-output",
# TODO: remove botocore and boto3 from asset output
# Must make asset output permissions accessible to lambda
"find /asset-output -type d -print0 | xargs -0 chmod 755",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from aibs_informatics_core.env import EnvBase
from aibs_informatics_core.utils.decorators import cached_property
from aibs_informatics_core.utils.hashing import generate_path_hash
from aws_cdk import aws_ecr_assets
from aws_cdk import aws_lambda as lambda_
from aws_cdk import aws_s3_assets

from aibs_informatics_cdk_lib.common.git import clone_repo, is_local_repo
from aibs_informatics_cdk_lib.constructs_.assets.code_asset import (
GLOBAL_GLOB_EXCLUDES,
PYTHON_GLOB_EXCLUDES,
PYTHON_REGEX_EXCLUDES,
CodeAsset,
Expand All @@ -24,6 +26,36 @@
logger = logging.getLogger(__name__)


class AssetsMixin:
@classmethod
def resolve_repo_path(cls, repo_url: str, repo_path_env_var: Optional[str]) -> Path:
"""Resolves the repo path from the environment or clones the repo from the url
This method is useful to quickly swapping between locally modified changes and the remote repo.
This should typically be used in the context of defining a code asset for a static name
(e.g. AIBS_INFORMATICS_AWS_LAMBDA). You can then use the env var option to point to a local
repo path for development.
Args:
repo_url (str): The git repo url. This is required.
If the repo path is not in the environment, the repo will be cloned from this url.
repo_path_env_var (Optional[str]): The environment variable that contains the repo path.
This is optional. This is useful for local development.
Returns:
Path: The path to the repo
"""
if repo_path_env_var and (repo_path := os.getenv(repo_path_env_var)) is not None:
logger.info(f"Using {repo_path_env_var} from environment")
if not is_local_repo(repo_path):
raise ValueError(f"Env variable {repo_path_env_var} is not a valid git repo")
repo_path = Path(repo_path)
else:
repo_path = clone_repo(repo_url, skip_if_exists=True)
return repo_path


class AIBSInformaticsCodeAssets(constructs.Construct):
def __init__(
self,
Expand Down Expand Up @@ -92,7 +124,7 @@ def AIBS_INFORMATICS_AWS_LAMBDA(self) -> CodeAsset:
"ssh -vT [email protected] || true",
# Must make sure that the package is not installing using --editable mode
"python3 -m pip install --upgrade pip --no-cache",
"pip3 install --no-cache -r requirements-lambda.txt -t /asset-output",
"pip3 install . --no-cache -t /asset-output",
# TODO: remove botocore and boto3 from asset output
# Must make asset output permissions accessible to lambda
"find /asset-output -type d -print0 | xargs -0 chmod 755",
Expand All @@ -117,3 +149,38 @@ def AIBS_INFORMATICS_AWS_LAMBDA(self) -> CodeAsset:
self.env_base.ENV_BASE_KEY: self.env_base,
},
)


class AIBSInformaticsDockerAssets(constructs.Construct, AssetsMixin):
def __init__(
self,
scope: constructs.Construct,
construct_id: str,
env_base: EnvBase,
) -> None:
super().__init__(scope, construct_id)
self.env_base = env_base

@cached_property
def AIBS_INFORMATICS_AWS_LAMBDA(self) -> aws_ecr_assets.DockerImageAsset:
"""Returns a NEW docker asset for aibs-informatics-aws-lambda
Returns:
aws_ecr_assets.DockerImageAsset: The docker asset
"""
repo_path = self.resolve_repo_path(
AIBS_INFORMATICS_AWS_LAMBDA_REPO, AIBS_INFORMATICS_AWS_LAMBDA_REPO_ENV_VAR
)
return aws_ecr_assets.DockerImageAsset(
self,
"aibs-informatics-aws-lambda",
directory=repo_path.as_posix(),
build_ssh="default",
platform=aws_ecr_assets.Platform.LINUX_AMD64,
asset_name="aibs-informatics-aws-lambda",
file="docker/Dockerfile",
exclude=[
*PYTHON_GLOB_EXCLUDES,
*GLOBAL_GLOB_EXCLUDES,
],
)
23 changes: 17 additions & 6 deletions src/aibs_informatics_cdk_lib/constructs_/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def construct_tags(self) -> List[cdk.Tag]:

def add_tags(self, *tags: cdk.Tag):
for tag in tags:
cdk.Tags.of(self).add(key=tag.key, value=tag.value)
cdk.Tags.of(self.as_construct()).add(key=tag.key, value=tag.value)

def normalize_construct_id(
self, construct_id: str, max_size: int = 64, hash_size: int = 8
Expand Down Expand Up @@ -71,11 +71,14 @@ def get_resource_name(self, name: Union[ResourceNameBaseEnum, str]) -> str:
return name.get_name(self.env_base)
return self.env_base.get_resource_name(name)

def get_stack_of(self, construct: Construct) -> Optional[Stack]:
try:
return cdk.Stack.of(construct)
except:
return None
def get_stack_of(self, construct: Optional[Construct] = None) -> Stack:
if construct is None:
construct = self.as_construct()
return cdk.Stack.of(construct)

def as_construct(self) -> Construct:
assert isinstance(self, Construct)
return self

@classmethod
def build_construct_id(cls, env_base: EnvBase, *names: str) -> str:
Expand Down Expand Up @@ -104,3 +107,11 @@ def __init__(self, scope: Construct, id: Optional[str], env_base: EnvBase) -> No
@property
def construct_id(self) -> str:
return self.node.id

@property
def aws_region(self) -> str:
return cdk.Stack.of(self).region

@property
def aws_account(self) -> str:
return cdk.Stack.of(self).account
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from aws_cdk import aws_s3 as s3

from aibs_informatics_cdk_lib.common.aws.iam_utils import (
BATCH_READ_ONLY_ACTIONS,
S3_READ_ONLY_ACCESS_ACTIONS,
batch_policy_statement,
dynamodb_policy_statement,
lambda_policy_statement,
)
Expand Down Expand Up @@ -149,6 +151,7 @@ def create_instance_role(
effect=iam.Effect.ALLOW,
resources=["*"],
),
batch_policy_statement(actions=BATCH_READ_ONLY_ACTIONS, env_base=self.env_base),
lambda_policy_statement(actions=["lambda:InvokeFunction"], env_base=self.env_base),
dynamodb_policy_statement(
env_base=self.env_base,
Expand Down
2 changes: 1 addition & 1 deletion src/aibs_informatics_cdk_lib/constructs_/cw/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def create_widgets_and_alarms(
Returns:
Tuple[List[cw.IWidget], List[cw.IAlarm]]: List of widgets and list of alarms
"""
self_stack = self.get_stack_of(self)
self_stack = cdk.Stack.of(self.as_construct())

graph_widgets: List[cw.IWidget] = []
metric_alarms: List[cw.IAlarm] = []
Expand Down
43 changes: 39 additions & 4 deletions src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_fn(self, function_name: str) -> lambda_.IFunction:
resource_cache = cast(Dict[str, lambda_.IFunction], getattr(self, cache_attr))
if function_name not in resource_cache:
resource_cache[function_name] = lambda_.Function.from_function_arn(
scope=self,
scope=self.as_construct(),
id=self.env_base.get_construct_id(function_name, "from-arn"),
function_arn=build_lambda_arn(
resource_type="function",
Expand All @@ -52,6 +52,37 @@ def get_state_machine_from_name(self, state_machine_name: str) -> sfn.IStateMach
return resource_cache[state_machine_name]


def create_state_machine(
scope: constructs.Construct,
env_base: EnvBase,
name: str,
definition: sfn.IChainable,
role: Optional[iam.Role] = None,
logs: Optional[sfn.LogOptions] = None,
timeout: Optional[cdk.Duration] = None,
) -> sfn.StateMachine:
return sfn.StateMachine(
scope,
env_base.get_construct_id(name),
state_machine_name=env_base.get_state_machine_name(name),
logs=(
logs
or sfn.LogOptions(
destination=logs_.LogGroup(
scope,
env_base.get_construct_id(name, "state-loggroup"),
log_group_name=env_base.get_state_machine_log_group_name(name),
removal_policy=cdk.RemovalPolicy.DESTROY,
retention=logs_.RetentionDays.ONE_MONTH,
)
)
),
role=cast(iam.IRole, role),
definition_body=sfn.DefinitionBody.from_chainable(definition),
timeout=timeout,
)


class StateMachineFragment(sfn.StateMachineFragment):
@property
def definition(self) -> sfn.IChainable:
Expand Down Expand Up @@ -87,9 +118,13 @@ def enclose(
scope, f"{id} Parallel Prep", parameters={"input": sfn.JsonPath.entire_payload}
)

parallel = chain.to_single_state(
id=f"{id} Parallel", input_path="$.input", result_path="$.result"
)
if isinstance(chain, sfn.Chain):
parallel = chain.to_single_state(
id=f"{id} Parallel", input_path="$.input", result_path="$.result"
)
else:
parallel = chain.to_single_state(input_path="$.input", result_path="$.result")

mod_result_path = JsonReferencePath("$.input")
if result_path and result_path != sfn.JsonPath.DISCARD:
mod_result_path = mod_result_path + result_path
Expand Down
Loading

0 comments on commit bb40e93

Please sign in to comment.