From be027df5f9c7b20642b6499a03ceac0f08620436 Mon Sep 17 00:00:00 2001 From: Ryan McGinty Date: Tue, 22 Oct 2024 17:55:39 -0700 Subject: [PATCH 1/3] update EnvBAse to be more easily subclassed --- src/aibs_informatics_core/env.py | 118 +++++++++++++++++++++++++------ 1 file changed, 96 insertions(+), 22 deletions(-) diff --git a/src/aibs_informatics_core/env.py b/src/aibs_informatics_core/env.py index 3f6cee1..207b783 100644 --- a/src/aibs_informatics_core/env.py +++ b/src/aibs_informatics_core/env.py @@ -20,7 +20,7 @@ import re from enum import Enum -from typing import Literal, Optional, Tuple, Union +from typing import Generic, Literal, Optional, Tuple, Type, TypeVar, Union, overload from aibs_informatics_core.collections import StrEnum, ValidatedStr from aibs_informatics_core.exceptions import ApplicationException @@ -44,6 +44,7 @@ class EnvType(StrEnum): SupportedDelim = Literal["-", "_", ":", "/"] +E = TypeVar("E", bound="EnvBase") class EnvBase(ValidatedStr): @@ -136,12 +137,12 @@ def to_type_and_label(self) -> Tuple[EnvType, Optional[str]]: return (EnvType(env_type), env_label) @classmethod - def from_type_and_label(cls, env_type: EnvType, env_label: Optional[str] = None) -> "EnvBase": + def from_type_and_label(cls: Type[E], env_type: EnvType, env_label: Optional[str] = None) -> E: """Returns [-]""" - return EnvBase(f"{env_type.value}-{env_label}" if env_label else env_type.value) + return cls(f"{env_type.value}-{env_label}" if env_label else env_type.value) @classmethod - def from_env(cls) -> "EnvBase": + def from_env(cls: Type[E]) -> E: """Get value from environment variables 1. Checks for env base variables. @@ -159,7 +160,7 @@ def from_env(cls) -> "EnvBase": """ env_base = get_env_var(ENV_BASE_KEY, ENV_BASE_KEY_ALIAS) if env_base: - return EnvBase(env_base) + return cls(env_base) else: try: env_type = cls.load_env_type__from_env() @@ -184,9 +185,9 @@ def load_env_label__from_env(cls) -> Optional[str]: return get_env_var(ENV_LABEL_KEY, ENV_LABEL_KEY_ALIAS, LABEL_KEY, LABEL_KEY_ALIAS) -class EnvBaseMixins: +class EnvBaseMixinsBase(Generic[E]): @property - def env_base(self) -> EnvBase: + def env_base(self) -> E: """returns env base If env base has not been set, it sets the value using environment variables @@ -201,7 +202,7 @@ def env_base(self) -> EnvBase: return self._env_base @env_base.setter - def env_base(self, env_base: Optional[EnvBase] = None): + def env_base(self, env_base: Optional[E] = None): """Sets env base If None is provided, env variables are read to infer env base @@ -209,16 +210,31 @@ def env_base(self, env_base: Optional[EnvBase] = None): Args: env_base (Optional[EnvBase], optional): environment base to use. Defaults to None. """ - self._env_base = get_env_base(env_base) + self._env_base = get_env_base(env_base, env_base_class=self._env_base_class()) + + @classmethod + def _env_base_class(cls) -> Type[E]: + return cls.__orig_bases__[0].__args__[0] # type: ignore + + +# ---------------------------------- +# Mixins & Enums +# ---------------------------------- + + +class EnvBaseMixins(EnvBaseMixinsBase[EnvBase]): + pass class EnvBaseEnumMixins: - def prefix_with(self, env_base: Optional[EnvBase] = None) -> str: - env_base: EnvBase = get_env_base(env_base) + def prefix_with( + self, env_base: Optional[E] = None, env_base_class: Optional[Type[E]] = None + ) -> str: + env_base_ = get_env_base(env_base, env_base_class) if isinstance(self, Enum): - return env_base.prefixed(self.value) + return env_base_.prefixed(self.value) else: - return env_base.prefixed(str(self)) + return env_base_.prefixed(str(self)) class ResourceNameBaseEnum(str, EnvBaseEnumMixins, Enum): @@ -229,15 +245,69 @@ def get_name(self, env_base: EnvBase) -> str: return self.prefix_with(env_base) -def get_env_base(env_base: Optional[Union[str, EnvBase]] = None) -> EnvBase: +# ---------------------------------- +# getter functions +# ---------------------------------- + + +@overload +def get_env_base() -> EnvBase: + ... + + +@overload +def get_env_base(env_base: Union[str, EnvBase]) -> EnvBase: + ... + + +@overload +def get_env_base(env_base: Literal[None]) -> EnvBase: + ... + + +@overload +def get_env_base(env_base: Literal[None], env_base_class: Literal[None]) -> EnvBase: + ... + + +@overload +def get_env_base(env_base: Literal[None], env_base_class: Type[E]) -> E: + ... + + +@overload +def get_env_base(env_base: Union[str, E], env_base_class: Type[E]) -> E: + ... + + +@overload +def get_env_base(env_base: Union[str, E], env_base_class: Literal[None]) -> EnvBase: + ... + + +def get_env_base( + env_base: Optional[Union[str, E]] = None, + env_base_class: Optional[Type[Union[E, EnvBase]]] = None, +) -> Union[E, EnvBase]: """Will look for the env_base as an environment variable.""" + env_base_cls: Type[E] = env_base_class or EnvBase # type: ignore[assignment] if env_base: - return env_base if isinstance(env_base, EnvBase) else EnvBase(env_base) - return EnvBase.from_env() + if isinstance(env_base, env_base_cls): + return env_base + else: + return env_base_cls(env_base) + return env_base_cls.from_env() + + +# def get_env_base(env_base: Optional[Union[str, EnvBase]] = None) -> EnvBase: +# """Will look for the env_base as an environment variable.""" +# return get_any_env_base(env_base) def get_env_type( - env_type: Optional[Union[str, EnvType]] = None, default_env_type: Optional[EnvType] = None + env_type: Optional[Union[str, EnvType]] = None, + default_env_type: Optional[EnvType] = None, + env_base_class: Optional[E] = None, ) -> EnvType: """Loads EnvType from environment or normalizes input @@ -247,7 +317,7 @@ def get_env_type( if env_type: return env_type if isinstance(env_type, EnvType) else EnvType(env_type) try: - return EnvBase.from_env().env_type + return (env_base_class or EnvBase).from_env().env_type except Exception as e: if default_env_type: return default_env_type @@ -261,7 +331,10 @@ class _Missing: MISSING = _Missing() -def get_env_label(env_label: Union[Optional[str], _Missing] = MISSING) -> Optional[str]: +def get_env_label( + env_label: Union[Optional[str], _Missing] = MISSING, + env_base_class: Optional[E] = None, +) -> Optional[str]: """Get Environment label If not specified, it will be loaded from envirionment (First checking for env base, then label) @@ -272,13 +345,14 @@ def get_env_label(env_label: Union[Optional[str], _Missing] = MISSING) -> Option Returns: Optional[str]: env label string """ + env_base_cls = env_base_class or EnvBase if isinstance(env_label, _Missing): try: # First check if EnvBase exists and - return EnvBase.from_env().env_label + return env_base_cls.from_env().env_label except: # next check for env label. - return EnvBase.load_env_label__from_env() + return env_base_cls.load_env_label__from_env() # Right now env label regex is only baked into EnvBase, so let's # create an EnvBase to validate - return EnvBase.from_type_and_label(EnvType.DEV, env_label=env_label).env_label + return env_base_cls.from_type_and_label(EnvType.DEV, env_label=env_label).env_label From f950ee8e2850de71de85177dc46310f104f04b84 Mon Sep 17 00:00:00 2001 From: Ryan McGinty Date: Tue, 22 Oct 2024 17:55:54 -0700 Subject: [PATCH 2/3] add test coverage for Resolvable --- .../models/demand_execution/resolvables.py | 2 +- .../demand_execution/test_resolvables.py | 62 ++++++++++++++++++- 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/src/aibs_informatics_core/models/demand_execution/resolvables.py b/src/aibs_informatics_core/models/demand_execution/resolvables.py index 24fdd54..132557d 100644 --- a/src/aibs_informatics_core/models/demand_execution/resolvables.py +++ b/src/aibs_informatics_core/models/demand_execution/resolvables.py @@ -136,7 +136,7 @@ def from_any( return value elif isinstance(value, dict): obj = cls.from_dict(value, partial=True) - if obj.local is None: + if obj.local is None or cls.is_missing(obj.local): if default_local is None: raise ValueError(f"Local is None for {value}. No default provided") obj.local = default_local diff --git a/test/aibs_informatics_core/models/demand_execution/test_resolvables.py b/test/aibs_informatics_core/models/demand_execution/test_resolvables.py index b9b4b8e..f724c98 100644 --- a/test/aibs_informatics_core/models/demand_execution/test_resolvables.py +++ b/test/aibs_informatics_core/models/demand_execution/test_resolvables.py @@ -1,5 +1,5 @@ from test.base import does_not_raise -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Type, Union import marshmallow as mm from pytest import mark, param, raises @@ -7,6 +7,7 @@ from aibs_informatics_core.exceptions import ValidationError from aibs_informatics_core.models.aws.s3 import S3URI from aibs_informatics_core.models.demand_execution.resolvables import ( + R, Resolvable, S3Resolvable, StringifiedDownloadable, @@ -190,6 +191,63 @@ def test__get_resolvable_from_value__parses_stuff( assert actual == expected +@mark.parametrize( + "resolvable_class, value, default_local, default_remote, expected, raise_expectation", + [ + param( + S3Resolvable, + S3Resolvable("/tmp/somefile", S3URI("s3://bucket/key")), + None, + None, + S3Resolvable("/tmp/somefile", S3URI("s3://bucket/key")), + does_not_raise(), + id="object is returned as is", + ), + param( + Uploadable, + {"remote": "s3://bucket/key"}, + "/tmp/somefile", + None, + Uploadable("/tmp/somefile", "s3://bucket/key"), + does_not_raise(), + id="default local fills missing local in input", + ), + param( + Uploadable, + {"remote": "s3://bucket/key"}, + None, + None, + None, + raises(ValueError, match=r"Local is None for .*\. No default provided"), + id="ERROR: no local in input or default", + ), + param( + S3Resolvable, + 42, + None, + None, + None, + raises(ValueError), + id="ERROR: invalid type", + ), + ], +) +def test__from_any__works_as_intended( + resolvable_class: Type[R], + value: Any, + default_local: Optional[str], + default_remote: Optional[str], + expected: Optional[Resolvable], + raise_expectation, +): + with raise_expectation: + actual = resolvable_class.from_any( + value=value, default_local=default_local, default_remote=default_remote + ) + if expected is not None: + assert actual == expected + + @mark.parametrize( "value, expected, raise_expectation", [ @@ -213,7 +271,7 @@ def test__get_resolvable_from_value__parses_stuff( ), ], ) -def test__from_str__works(value: str, expected: S3Resolvable, raise_expectation): +def test__from_str__works(value: str, expected: Resolvable, raise_expectation): with raise_expectation: actual = expected.from_str(value) assert actual == expected From 4e85a9b6cd9fb3620381f8d2d69263dfb6a6eed5 Mon Sep 17 00:00:00 2001 From: Ryan McGinty Date: Wed, 23 Oct 2024 11:11:54 -0700 Subject: [PATCH 3/3] remove commented out code --- src/aibs_informatics_core/env.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/aibs_informatics_core/env.py b/src/aibs_informatics_core/env.py index 207b783..83e1f53 100644 --- a/src/aibs_informatics_core/env.py +++ b/src/aibs_informatics_core/env.py @@ -299,11 +299,6 @@ def get_env_base( return env_base_cls.from_env() -# def get_env_base(env_base: Optional[Union[str, EnvBase]] = None) -> EnvBase: -# """Will look for the env_base as an environment variable.""" -# return get_any_env_base(env_base) - - def get_env_type( env_type: Optional[Union[str, EnvType]] = None, default_env_type: Optional[EnvType] = None,