Skip to content

Commit

Permalink
updates to compute stack
Browse files Browse the repository at this point in the history
rpmcginty committed Apr 19, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent d82f195 commit 392da29
Showing 1 changed file with 78 additions and 21 deletions.
99 changes: 78 additions & 21 deletions src/aibs_informatics_cdk_lib/stacks/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from email.policy import default
from typing import Any, Iterable, List, Optional, TypeVar, Union
from urllib import request

@@ -36,7 +37,11 @@
from aibs_informatics_cdk_lib.constructs_.batch.launch_template import BatchLaunchTemplateBuilder
from aibs_informatics_cdk_lib.constructs_.batch.types import BatchEnvironmentDescriptor
from aibs_informatics_cdk_lib.constructs_.ec2 import EnvBaseVpc
from aibs_informatics_cdk_lib.constructs_.efs.file_system import EFSEcosystem, EnvBaseFileSystem
from aibs_informatics_cdk_lib.constructs_.efs.file_system import (
EFSEcosystem,
EnvBaseFileSystem,
MountPointConfiguration,
)
from aibs_informatics_cdk_lib.constructs_.s3 import EnvBaseBucket, LifecycleRuleGenerator
from aibs_informatics_cdk_lib.constructs_.sfn.fragments.batch import SubmitJobFragment
from aibs_informatics_cdk_lib.stacks.base import EnvBaseStack
@@ -97,7 +102,10 @@ def __init__(
env_base: EnvBase,
vpc: ec2.Vpc,
buckets: Optional[Iterable[s3.Bucket]] = None,
file_systems: Optional[Iterable[efs.FileSystem]] = None,
file_systems: Optional[Iterable[Union[efs.FileSystem, efs.IFileSystem]]] = None,
mount_point_configs: Optional[Iterable[MountPointConfiguration]] = None,
create_state_machine: bool = True,
state_machine_name: Optional[str] = "submit-job",
**kwargs,
) -> None:
super().__init__(scope, id, env_base, **kwargs)
@@ -110,13 +118,26 @@ def __init__(

file_system_list = list(file_systems or [])

if mount_point_configs:
mount_point_config_list = list(mount_point_configs)
file_system_list = self._update_file_systems_from_mount_point_configs(
file_system_list, mount_point_config_list
)
else:
mount_point_config_list = self._get_mount_point_configs(file_system_list)

# Validation to ensure that the file systems are not duplicated
self._validate_mount_point_configs(mount_point_config_list)

self.grant_storage_access(*bucket_list, *file_system_list)

self.create_step_functions(file_system=file_system_list[0] if file_system_list else None)
self.create_step_functions(
name=state_machine_name, mount_point_configs=mount_point_config_list
)

self.export_values()

def grant_storage_access(self, *resources: Union[s3.Bucket, efs.FileSystem]):
def grant_storage_access(self, *resources: Union[s3.Bucket, efs.FileSystem, efs.IFileSystem]):
self.batch.grant_instance_role_permissions(read_write_resources=list(resources))

for batch_environment in self.batch.environments:
@@ -162,9 +183,12 @@ def create_batch_environments(self):
launch_template_builder=lt_builder,
)

def create_step_functions(self, file_system: Optional[efs.FileSystem] = None):

state_machine_core_name = "submit-job"
def create_step_functions(
self,
name: Optional[str] = None,
mount_point_configs: Optional[list[MountPointConfiguration]] = None,
):
state_machine_core_name = name or "submit-job"
defaults: dict[str, Any] = {}
defaults["command"] = []
defaults["job_queue"] = self.on_demand_batch_environment.job_queue.job_queue_arn
@@ -174,23 +198,14 @@ def create_step_functions(self, file_system: Optional[efs.FileSystem] = None):
defaults["gpu"] = "0"
defaults["platform_capabilities"] = ["EC2"]

if file_system:
file_system.file_system_id
if mount_point_configs:
defaults["mount_points"] = [
convert_key_case(to_mount_point("/opt/efs", False, "efs-root-volume"), pascalcase)
convert_key_case(mpc.to_batch_mount_point(f"efs-vol{i}"), pascalcase)
for i, mpc in enumerate(mount_point_configs)
]
defaults["volumes"] = [
convert_key_case(
to_volume(
None,
"efs-root-volume",
{
"fileSystemId": file_system.file_system_id,
"rootDirectory": "/",
},
),
pascalcase,
)
convert_key_case(mpc.to_batch_volume(f"efs-vol{i}"), pascalcase)
for i, mpc in enumerate(mount_point_configs)
]

start = sfn.Pass(
@@ -273,3 +288,45 @@ def export_values(self) -> None:
self.export_value(self.on_demand_batch_environment.job_queue.job_queue_arn)
self.export_value(self.spot_batch_environment.job_queue.job_queue_arn)
self.export_value(self.fargate_batch_environment.job_queue.job_queue_arn)

## Private methods

def _validate_mount_point_configs(self, mount_point_configs: List[MountPointConfiguration]):
_ = {}
for mpc in mount_point_configs:
if mpc.mount_point in _ and _[mpc.mount_point] != mpc:
raise ValueError(
f"Mount point {mpc.mount_point} is duplicated. "
"Cannot have multiple mount points configurations with the same name."
)
_[mpc.mount_point] = mpc

def _get_mount_point_configs(
self, file_systems: Optional[List[Union[efs.FileSystem, efs.IFileSystem]]]
) -> List[MountPointConfiguration]:
mount_point_configs = []
if file_systems:
for fs in file_systems:
mount_point_configs.append(MountPointConfiguration.from_file_system(fs))
return mount_point_configs

def _update_file_systems_from_mount_point_configs(
    self,
    file_systems: List[Union[efs.FileSystem, efs.IFileSystem]],
    mount_point_configs: List[MountPointConfiguration],
) -> List[Union[efs.FileSystem, efs.IFileSystem]]:
    """Merge file systems referenced by mount point configs into the given list.

    Any file system a configuration references that is not already present
    (keyed by file system id) is added, resolved from the configuration's
    access point when the configuration carries no file system directly.

    Args:
        file_systems: file systems already known to the caller.
        mount_point_configs: configurations whose file systems must be included.

    Returns:
        The de-duplicated file systems, keyed by file system id.

    Raises:
        ValueError: if a configuration has neither a file system nor an
            access point to resolve one from.
    """
    # Index the known file systems by id; later duplicates overwrite earlier ones.
    merged: dict[str, Union[efs.FileSystem, efs.IFileSystem]] = {}
    for fs in file_systems:
        merged[fs.file_system_id] = fs

    for mpc in mount_point_configs:
        if mpc.file_system_id in merged:
            continue
        if not mpc.file_system and mpc.access_point:
            # No direct file system on the config — resolve via its access point.
            merged[mpc.file_system_id] = mpc.access_point.file_system
        elif mpc.file_system:
            merged[mpc.file_system_id] = mpc.file_system
        else:
            raise ValueError(
                "Mount point configuration must have a file system or access point."
            )

    return list(merged.values())

0 comments on commit 392da29

Please sign in to comment.