Skip to content

Commit

Permalink
updates based on pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rpmcginty committed Jul 11, 2024
1 parent 9f11bd5 commit 47ac51f
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 42 deletions.
63 changes: 32 additions & 31 deletions src/aibs_informatics_cdk_lib/cicd/pipeline/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import logging
from abc import abstractmethod
from importlib.resources import files
from pathlib import Path
from typing import (
Callable,
Expand All @@ -15,7 +16,6 @@
Union,
cast,
)
from importlib.resources import files

import aws_cdk as cdk
import constructs
Expand Down Expand Up @@ -111,10 +111,26 @@ def deploy_stage(self) -> Tuple[List[pipelines.Step], cdk.Stage, List[pipelines.
Defaults to None.
"""

def decorator_pipeline_stage(func):
def decorator_pipeline_stage(
func: Callable[[], Union[cdk.Stage, Tuple[cdk.Stage]]]
) -> Callable[
[],
Tuple[Optional[Sequence[pipelines.Step]], cdk.Stage, Optional[Sequence[pipelines.Step]]],
]:
@functools.wraps(func)
def wrapper_pipeline_stage(*args, **kwargs):
return func(*args, **kwargs)
def wrapper_pipeline_stage(
*args, **kwargs
) -> Tuple[
Optional[Sequence[pipelines.Step]], cdk.Stage, Optional[Sequence[pipelines.Step]]
]:
results = func(*args, **kwargs)
if isinstance(results, cdk.Stage):
return None, results, None
assert isinstance(results, tuple) and len(results) == 3
assert isinstance(results[0], list) or results[0] is None
assert isinstance(results[1], cdk.Stage)
assert isinstance(results[2], list) or results[2] is None
return results

wrapper_pipeline_stage._pipeline_stage_info = PipelineStageInfo( # type: ignore[attr-defined]
order=order, name=name, pre_steps=pre_steps, post_steps=post_steps
Expand Down Expand Up @@ -216,25 +232,13 @@ def build_pipeline(self):
# Add Stages
for stage_method in self.get_stage_methods():
stage_info: PipelineStageInfo = stage_method._pipeline_stage_info # type: ignore[union-attr]
stage = stage_method()
pre_steps = stage_info.pre_steps
post_steps = stage_info.post_steps
if isinstance(stage, cdk.Stage):
stage = stage
elif (
isinstance(stage, tuple)
and len(stage) == 3
and isinstance(stage[0], list)
and isinstance(stage[1], cdk.Stage)
and isinstance(stage[2], list)
):
pre_steps = [*(pre_steps or []), *(cast(List[pipelines.Step], stage[0]))]
post_steps = [*(post_steps or []), *(cast(List[pipelines.Step], stage[2]))]
stage = stage[1]
else:
raise ValueError(
"Stage must be a cdk.Stage or a tuple of pre_steps, stage, post_steps"
)
pre_steps, stage, post_steps = stage_method()

if stage_info.pre_steps is not None:
pre_steps = [*stage_info.pre_steps, *(pre_steps or [])]
if stage_info.post_steps is not None:
post_steps = [*stage_info.post_steps, *(post_steps or [])]

self.pipeline.add_stage(stage, pre=pre_steps, post=post_steps)

# Add Promotion Stage
Expand Down Expand Up @@ -361,7 +365,9 @@ def add_promotion_stage(self, pipeline: pipelines.CodePipeline):
# 1. Read the release script file
# 2. Base64 encode the file
# 3. Decode the base64 encoded file and write it to the release script path
# TODO: i think importlib
# TODO: i think `importlib.resources.files` is preferred way to go here, but
# it requires specifying the package path. This is a bit more
# difficult to do in this context. So we are using the Path approach.
f"echo {base64.b64encode((Path(__file__).parent / 'scripts' / 'cicd-release.sh').read_text().encode()).decode()} | base64 --decode > $RELEASE_SCRIPT_PATH"
),
# Run the release script
Expand Down Expand Up @@ -512,12 +518,7 @@ def get_pipeline_source(

def get_stage_methods(
self,
) -> List[
Union[
Callable[[], cdk.Stage],
Callable[[], Tuple[Sequence[pipelines.Step], cdk.Stage, Sequence[pipelines.Step]]],
]
]:
) -> List[Callable[[], Tuple[Sequence[pipelines.Step], cdk.Stage, Sequence[pipelines.Step]]]]:
# Get all methods of the instance
methods = [
getattr(self, method_name)
Expand All @@ -530,8 +531,8 @@ def get_stage_methods(

# Sort methods by their order attribute
stage_methods.sort(key=lambda method: method._pipeline_stage_info.order) # type: ignore[attr-defined]
# Return the sorted methods

# Return the sorted methods
return stage_methods

@staticmethod
Expand Down
12 changes: 12 additions & 0 deletions src/aibs_informatics_cdk_lib/project/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,18 @@ def get_project_config_and_env_base(
def get_config(
node: constructs.Node, project_config_cls: Type[BaseProjectConfig[G, S]] = ProjectConfig
) -> S:
"""
Retrieves the stage configuration for a given node.
Args:
node (constructs.Node): The node for which to retrieve the configuration.
project_config_cls (Type[BaseProjectConfig[G, S]], optional): The project configuration class to use.
Defaults to ProjectConfig.
Returns:
S: The stage configuration object.
"""
project_config, env_base = get_project_config_and_env_base( # type: ignore
node, project_config_cls=project_config_cls
)
Expand Down
23 changes: 12 additions & 11 deletions test/aibs_informatics_cdk_lib/project/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from unittest import mock

import aws_cdk as cdk
import constructs
Expand All @@ -13,7 +14,6 @@
LABEL_KEY,
LABEL_KEY_ALIAS,
EnvBase,
EnvType,
)

from aibs_informatics_cdk_lib.project.utils import (
Expand Down Expand Up @@ -111,13 +111,14 @@ def test__get_env_base__context_and_env_vars(env_vars, dummy_node):

def test__set_env_base__env_vars_only(env_vars):
env_base = EnvBase("dev")
set_env_base(env_base)
assert os.environ.get(ENV_BASE_KEY) == "dev"
assert os.environ.get(ENV_TYPE_KEY) == "dev"
assert os.environ.get(ENV_LABEL_KEY) is None

env_base = EnvBase("prod-marmot")
set_env_base(env_base)
assert os.environ.get(ENV_BASE_KEY) == "prod-marmot"
assert os.environ.get(ENV_TYPE_KEY) == "prod"
assert os.environ.get(ENV_LABEL_KEY) == "marmot"
with mock.patch.dict(os.environ, clear=True):
set_env_base(env_base)
assert os.environ.get(ENV_BASE_KEY) == "dev"
assert os.environ.get(ENV_TYPE_KEY) == "dev"
assert os.environ.get(ENV_LABEL_KEY) is None

env_base = EnvBase("prod-marmot")
set_env_base(env_base)
assert os.environ.get(ENV_BASE_KEY) == "prod-marmot"
assert os.environ.get(ENV_TYPE_KEY) == "prod"
assert os.environ.get(ENV_LABEL_KEY) == "marmot"

0 comments on commit 47ac51f

Please sign in to comment.