diff --git a/src/aibs_informatics_cdk_lib/project/config.py b/src/aibs_informatics_cdk_lib/project/config.py index 81d0ba5..3287f6b 100644 --- a/src/aibs_informatics_cdk_lib/project/config.py +++ b/src/aibs_informatics_cdk_lib/project/config.py @@ -146,11 +146,25 @@ def get_global_config_cls(cls) -> Type[G]: def get_stage_config_cls(cls) -> Type[S]: return cls.model_fields["default_config"].annotation # type: ignore - def get_stage_config(self, env_type: Union[str, EnvType]) -> S: - """Get default config with `EnvType` overrides""" + def get_stage_config( + self, env_type: Union[str, EnvType], env_label: Optional[str] = None + ) -> S: + """Get default config with `EnvType` overrides (and optional env_label override) + + Args: + env_type (Union[str, EnvType]): The EnvType override for the stage config to + be retrieved. + env_label (Optional[str], optional): An optional env_label override for + the stage config to be retrieved. Defaults to None (no override). + Raises: + e: If the stage config model validation fails (or any other error) + + Returns: + S: A stage config object + """ try: - return self.get_stage_config_cls().model_validate( + stage_config = self.get_stage_config_cls().model_validate( { **DeepChainMap( self.default_config_overrides[EnvType(env_type)], @@ -161,6 +175,12 @@ def get_stage_config(self, env_type: Union[str, EnvType]) -> S: except Exception as e: raise e + if env_label is None: + return stage_config + else: + stage_config.env.label = env_label + return stage_config + @classmethod def parse_file(cls: Type[P], path: Union[str, Path]) -> P: path = Path(path) diff --git a/test/aibs_informatics_cdk_lib/project/test_config.py b/test/aibs_informatics_cdk_lib/project/test_config.py index 9e48b9f..c82e5e5 100644 --- a/test/aibs_informatics_cdk_lib/project/test_config.py +++ b/test/aibs_informatics_cdk_lib/project/test_config.py @@ -159,6 +159,36 @@ def test__parse_file__test_loads_json_and_yml(self): another_proj_config = ProjectConfig.load_config(proj_config_json_path) self.assertEqual(proj_config, another_proj_config) + def test__get_stage_config__providing_an_optional_env_label_override(self): + global_config = create_global_config() + default_config = create_stage_config() + + default_config_overrides = { + EnvType.DEV: { + "env": { + "env_type": "dev", + "label": "marmot", + "account": "111222333", + } + } + } + proj_config = ProjectConfig( + global_config=global_config, + default_config=default_config, + default_config_overrides=default_config_overrides, + ) + expected_config = StageConfig.model_validate( + { + **default_config.model_copy().model_dump(exclude_unset=True), + } + ) + expected_config.env.env_type = EnvType.DEV + expected_config.env.label = "overridelabel" + expected_config.env.account = "111222333" + + resolved_config = proj_config.get_stage_config("dev", "overridelabel") + self.assertEqual(resolved_config, expected_config) + class ConfigProviderTests(BaseTest): def test__get_stage_config__fails_with_invalid_env_type(self):