diff --git a/src/aibs_informatics_cdk_lib/cicd/pipeline/base.py b/src/aibs_informatics_cdk_lib/cicd/pipeline/base.py index 508450a..ea1b4d7 100644 --- a/src/aibs_informatics_cdk_lib/cicd/pipeline/base.py +++ b/src/aibs_informatics_cdk_lib/cicd/pipeline/base.py @@ -24,31 +24,17 @@ from aws_cdk import aws_codepipeline_actions from aws_cdk import aws_codestarnotifications as codestarnotifications from aws_cdk import aws_iam as iam -from aws_cdk import aws_s3 as s3 -from aws_cdk import aws_secretsmanager as secretsmanager from aws_cdk import aws_sns as sns from aws_cdk import pipelines -from aws_cdk.aws_codebuild import ( - BuildEnvironment, - BuildEnvironmentVariable, - BuildSpec, - LinuxBuildImage, -) -from dataclasses_json import global_config +from aws_cdk.aws_codebuild import BuildEnvironment, BuildEnvironmentVariable, BuildSpec from aibs_informatics_cdk_lib.common.aws.core_utils import build_arn -from aibs_informatics_cdk_lib.common.aws.iam_utils import ( - CODE_BUILD_IAM_POLICY, - DYNAMODB_READ_ACTIONS, - S3_FULL_ACCESS_ACTIONS, -) +from aibs_informatics_cdk_lib.common.aws.iam_utils import CODE_BUILD_IAM_POLICY from aibs_informatics_cdk_lib.project.config import ( BaseProjectConfig, CodePipelineSourceConfig, - Env, GlobalConfig, PipelineConfig, - ProjectConfig, StageConfig, ) from aibs_informatics_cdk_lib.stacks.base import EnvBaseStack @@ -200,10 +186,11 @@ def __init__( ) -> None: self.project_config = config self.stage_config = config.get_stage_config(env_base.env_type) + self.stage_config.env.label = env_base.env_label env = cdk.Environment( account=self.stage_config.env.account, region=self.stage_config.env.region ) - super().__init__(scope, id, config=config, env=env, **kwargs) + super().__init__(scope, id, env_base=env_base, env=env, **kwargs) self.build_pipeline() @abstractmethod @@ -228,7 +215,7 @@ def build_pipeline(self): # Add Stages for stage_method in self.get_stage_methods(): - stage_info: PipelineStageInfo = stage_method._pipeline_stage_info # type: ignore[attr-defined] + stage_info: PipelineStageInfo = stage_method._pipeline_stage_info # type: ignore[union-attr] stage = stage_method() pre_steps = stage_info.pre_steps post_steps = stage_info.post_steps diff --git a/src/aibs_informatics_cdk_lib/cicd/target.py b/src/aibs_informatics_cdk_lib/cicd/target.py index b7526b3..4241c46 100644 --- a/src/aibs_informatics_cdk_lib/cicd/target.py +++ b/src/aibs_informatics_cdk_lib/cicd/target.py @@ -1,6 +1,79 @@ from enum import Enum +from typing import Optional, Type, TypeVar, Union +import constructs +from aibs_informatics_core.utils.os_operations import get_env_var -class CDKStackTarget(str, Enum): - PIPELINE = "pipeline" - INFRA = "infra" +from aibs_informatics_cdk_lib.project.utils import _get_from_context + +CDK_STACK_TARGET_ENV_VAR = "CDK_STACK_TARGET" + +T = TypeVar("T", bound="CDKStackTargetBaseEnum") + + +class CDKStackTargetBaseEnum(Enum): + """Base class for CDK stack target types + + Usage: + class MyCDKStackTarget(str, CDKStackTargetBaseEnum): + INFRA = "pipeline" + + """ + + @classmethod + def from_env(cls: Type[T], default: Union[str, T]) -> T: + target = get_env_var(CDK_STACK_TARGET_ENV_VAR) + target = target or default + return cls(target) + + @classmethod + def from_context( + cls: Type[T], + node: constructs.Node, + default: Union[str, T], + context_keys: Optional[list[str]] = None, + ) -> T: + """Resolves the CDK stack target type from context + + Args: + cls (Type[T]): subclassed CDKStackTargetBase + node (constructs.Node): cdk construct node + default (str): default to use. + context_keys (Optional[list[str]], optional): overrides for context names. + Defaults to None. + + Returns: + T: CDKStackTargetBase instance + """ + context_keys = context_keys or ["target", "stack_target"] + + target = _get_from_context(node, context_keys) or default + + return cls(target) + + @classmethod + def from_context_or_env( + cls: Type[T], + node: constructs.Node, + default: Union[str, T], + context_keys: Optional[list[str]] = None, + ) -> T: + """Resolves the CDK stack target type from context or environment + + Order of resolution: + 1. CDK context value (specifying -c K=V) + 2. env variable + 3. default value ("dev") + + Args: + cls (Type[T]): subclassed CDKStackTargetBase + node (constructs.Node): cdk construct node + default (str): default to use. + context_keys (Optional[list[str]], optional): overrides for context names. + Defaults to None. + + """ + + return cls.from_context( + node=node, default=cls.from_env(default), context_keys=context_keys + ) diff --git a/src/aibs_informatics_cdk_lib/project/config.py b/src/aibs_informatics_cdk_lib/project/config.py index 7539362..81d0ba5 100644 --- a/src/aibs_informatics_cdk_lib/project/config.py +++ b/src/aibs_informatics_cdk_lib/project/config.py @@ -85,8 +85,8 @@ class CodePipelineBuildConfig(BaseModel): class CodePipelineSourceConfig(BaseModel): repository: str branch: Annotated[str, PlainValidator(EnvVarStr.validate)] - codestar_connection: Optional[UniqueIDType] - oauth_secret_name: Optional[str] + codestar_connection: Optional[UniqueIDType] = None + oauth_secret_name: Optional[str] = None @model_validator(mode="after") @classmethod diff --git a/src/aibs_informatics_cdk_lib/project/utils.py b/src/aibs_informatics_cdk_lib/project/utils.py index e13cae6..ce270f0 100644 --- a/src/aibs_informatics_cdk_lib/project/utils.py +++ b/src/aibs_informatics_cdk_lib/project/utils.py @@ -6,8 +6,9 @@ ] import logging +import os import pathlib -from typing import List, Optional, Type, Union +from typing import List, Optional, Tuple, Type, Union import constructs from aibs_informatics_core.env import ( @@ -21,7 +22,7 @@ ) from aibs_informatics_core.utils.os_operations import get_env_var, set_env_var -from aibs_informatics_cdk_lib.project.config import BaseProjectConfig, G, ProjectConfig, S +from aibs_informatics_cdk_lib.project.config import BaseProjectConfig, G, P, ProjectConfig, S logger = logging.getLogger(__name__) @@ -120,19 +121,41 @@ def get_env_base(node: constructs.Node) -> EnvBase: return EnvBase.from_type_and_label(env_type=env_type, env_label=env_label) -def get_config( - node: constructs.Node, project_config_cls: Type[BaseProjectConfig[G, S]] = ProjectConfig -) -> S: - env_base = get_env_base(node) +def set_env_base(env_base: EnvBase) -> None: + """Set the environment base + Args: + env_base (EnvBase): environment base + """ set_env_var(EnvBase.ENV_BASE_KEY, env_base) set_env_var(EnvBase.ENV_TYPE_KEY, env_base.env_type) if env_base.env_label: set_env_var(EnvBase.ENV_LABEL_KEY, env_base.env_label) + else: + os.environ.pop(EnvBase.ENV_LABEL_KEY, None) + + +def get_project_config_and_env_base( + node: constructs.Node, project_config_cls: Type[P] = ProjectConfig +) -> Tuple[P, EnvBase]: + env_base = get_env_base(node) + + config = project_config_cls.load_config() + return config, env_base + + +def get_config( + node: constructs.Node, project_config_cls: Type[BaseProjectConfig[G, S]] = ProjectConfig +) -> S: + project_config, env_base = get_project_config_and_env_base( # type: ignore + node, project_config_cls=project_config_cls + ) + set_env_base(env_base) + + stage_config: S = project_config.get_stage_config(env_type=env_base.env_type) + stage_config.env.label = env_base.env_label - config: S = project_config_cls.load_stage_config(env_type=env_base.env_type) - config.env.label = env_base.env_label - return config + return stage_config def _get_from_context( diff --git a/test/aibs_informatics_cdk_lib/project/test_utils.py b/test/aibs_informatics_cdk_lib/project/test_utils.py index ac76403..5da56cd 100644 --- a/test/aibs_informatics_cdk_lib/project/test_utils.py +++ b/test/aibs_informatics_cdk_lib/project/test_utils.py @@ -1,3 +1,5 @@ +import os + import aws_cdk as cdk import constructs import pytest @@ -19,6 +21,7 @@ ENV_LABEL_KEYS, ENV_TYPE_KEYS, get_env_base, + set_env_base, ) USER = "marmotdev" @@ -104,3 +107,17 @@ def test__get_env_base__context_and_env_vars(env_vars, dummy_node): # Base from context supercedes type/label dummy_node.set_context(ENV_BASE_KEY, "dev") assert get_env_base(dummy_node) == EnvBase("dev") + + +def test__set_env_base__env_vars_only(env_vars): + env_base = EnvBase("dev") + set_env_base(env_base) + assert os.environ.get(ENV_BASE_KEY) == "dev" + assert os.environ.get(ENV_TYPE_KEY) == "dev" + assert os.environ.get(ENV_LABEL_KEY) is None + + env_base = EnvBase("prod-marmot") + set_env_base(env_base) + assert os.environ.get(ENV_BASE_KEY) == "prod-marmot" + assert os.environ.get(ENV_TYPE_KEY) == "prod" + assert os.environ.get(ENV_LABEL_KEY) == "marmot"