Skip to content

Commit

Permalink
Merge pull request #13 from AllenInstitute/feature/OCSDV-347-make-env…
Browse files Browse the repository at this point in the history
…-base-more-subclassable

Make EnvBase more easily subclassed
  • Loading branch information
rpmcginty authored Oct 23, 2024
2 parents 804808d + 4e85a9b commit 94701eb
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 25 deletions.
113 changes: 91 additions & 22 deletions src/aibs_informatics_core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import re
from enum import Enum
from typing import Literal, Optional, Tuple, Union
from typing import Generic, Literal, Optional, Tuple, Type, TypeVar, Union, overload

from aibs_informatics_core.collections import StrEnum, ValidatedStr
from aibs_informatics_core.exceptions import ApplicationException
Expand All @@ -44,6 +44,7 @@ class EnvType(StrEnum):


SupportedDelim = Literal["-", "_", ":", "/"]
E = TypeVar("E", bound="EnvBase")


class EnvBase(ValidatedStr):
Expand Down Expand Up @@ -136,12 +137,12 @@ def to_type_and_label(self) -> Tuple[EnvType, Optional[str]]:
return (EnvType(env_type), env_label)

@classmethod
def from_type_and_label(cls, env_type: EnvType, env_label: Optional[str] = None) -> "EnvBase":
def from_type_and_label(cls: Type[E], env_type: EnvType, env_label: Optional[str] = None) -> E:
"""Returns <env_name>[-<env_label>]"""
return EnvBase(f"{env_type.value}-{env_label}" if env_label else env_type.value)
return cls(f"{env_type.value}-{env_label}" if env_label else env_type.value)

@classmethod
def from_env(cls) -> "EnvBase":
def from_env(cls: Type[E]) -> E:
"""Get value from environment variables
1. Checks for env base variables.
Expand All @@ -159,7 +160,7 @@ def from_env(cls) -> "EnvBase":
"""
env_base = get_env_var(ENV_BASE_KEY, ENV_BASE_KEY_ALIAS)
if env_base:
return EnvBase(env_base)
return cls(env_base)
else:
try:
env_type = cls.load_env_type__from_env()
Expand All @@ -184,9 +185,9 @@ def load_env_label__from_env(cls) -> Optional[str]:
return get_env_var(ENV_LABEL_KEY, ENV_LABEL_KEY_ALIAS, LABEL_KEY, LABEL_KEY_ALIAS)


class EnvBaseMixins:
class EnvBaseMixinsBase(Generic[E]):
@property
def env_base(self) -> EnvBase:
def env_base(self) -> E:
"""returns env base
If env base has not been set, it sets the value using environment variables
Expand All @@ -201,24 +202,39 @@ def env_base(self) -> EnvBase:
return self._env_base

@env_base.setter
def env_base(self, env_base: Optional[EnvBase] = None):
def env_base(self, env_base: Optional[E] = None):
"""Sets env base
If None is provided, env variables are read to infer env base
Args:
env_base (Optional[EnvBase], optional): environment base to use. Defaults to None.
"""
self._env_base = get_env_base(env_base)
self._env_base = get_env_base(env_base, env_base_class=self._env_base_class())

@classmethod
def _env_base_class(cls) -> Type[E]:
return cls.__orig_bases__[0].__args__[0] # type: ignore


# ----------------------------------
# Mixins & Enums
# ----------------------------------


class EnvBaseMixins(EnvBaseMixinsBase[EnvBase]):
pass


class EnvBaseEnumMixins:
def prefix_with(self, env_base: Optional[EnvBase] = None) -> str:
env_base: EnvBase = get_env_base(env_base)
def prefix_with(
self, env_base: Optional[E] = None, env_base_class: Optional[Type[E]] = None
) -> str:
env_base_ = get_env_base(env_base, env_base_class)
if isinstance(self, Enum):
return env_base.prefixed(self.value)
return env_base_.prefixed(self.value)
else:
return env_base.prefixed(str(self))
return env_base_.prefixed(str(self))


class ResourceNameBaseEnum(str, EnvBaseEnumMixins, Enum):
Expand All @@ -229,15 +245,64 @@ def get_name(self, env_base: EnvBase) -> str:
return self.prefix_with(env_base)


def get_env_base(env_base: Optional[Union[str, EnvBase]] = None) -> EnvBase:
# ----------------------------------
# getter functions
# ----------------------------------


@overload
def get_env_base() -> EnvBase:
...


@overload
def get_env_base(env_base: Union[str, EnvBase]) -> EnvBase:
...


@overload
def get_env_base(env_base: Literal[None]) -> EnvBase:
...


@overload
def get_env_base(env_base: Literal[None], env_base_class: Literal[None]) -> EnvBase:
...


@overload
def get_env_base(env_base: Literal[None], env_base_class: Type[E]) -> E:
...


@overload
def get_env_base(env_base: Union[str, E], env_base_class: Type[E]) -> E:
...


@overload
def get_env_base(env_base: Union[str, E], env_base_class: Literal[None]) -> EnvBase:
...


def get_env_base(
env_base: Optional[Union[str, E]] = None,
env_base_class: Optional[Type[Union[E, EnvBase]]] = None,
) -> Union[E, EnvBase]:
"""Will look for the env_base as an environment variable."""
env_base_cls: Type[E] = env_base_class or EnvBase # type: ignore[assignment]
if env_base:
return env_base if isinstance(env_base, EnvBase) else EnvBase(env_base)
return EnvBase.from_env()
if isinstance(env_base, env_base_cls):
return env_base
else:
return env_base_cls(env_base)
return env_base_cls.from_env()


def get_env_type(
env_type: Optional[Union[str, EnvType]] = None, default_env_type: Optional[EnvType] = None
env_type: Optional[Union[str, EnvType]] = None,
default_env_type: Optional[EnvType] = None,
env_base_class: Optional[E] = None,
) -> EnvType:
"""Loads EnvType from environment or normalizes input
Expand All @@ -247,7 +312,7 @@ def get_env_type(
if env_type:
return env_type if isinstance(env_type, EnvType) else EnvType(env_type)
try:
return EnvBase.from_env().env_type
return (env_base_class or EnvBase).from_env().env_type
except Exception as e:
if default_env_type:
return default_env_type
Expand All @@ -261,7 +326,10 @@ class _Missing:
MISSING = _Missing()


def get_env_label(env_label: Union[Optional[str], _Missing] = MISSING) -> Optional[str]:
def get_env_label(
env_label: Union[Optional[str], _Missing] = MISSING,
env_base_class: Optional[E] = None,
) -> Optional[str]:
"""Get Environment label
If not specified, it will be loaded from envirionment (First checking for env base, then label)
Expand All @@ -272,13 +340,14 @@ def get_env_label(env_label: Union[Optional[str], _Missing] = MISSING) -> Option
Returns:
Optional[str]: env label string
"""
env_base_cls = env_base_class or EnvBase
if isinstance(env_label, _Missing):
try:
# First check if EnvBase exists and
return EnvBase.from_env().env_label
return env_base_cls.from_env().env_label
except:
# next check for env label.
return EnvBase.load_env_label__from_env()
return env_base_cls.load_env_label__from_env()
# Right now env label regex is only baked into EnvBase, so let's
# create an EnvBase to validate
return EnvBase.from_type_and_label(EnvType.DEV, env_label=env_label).env_label
return env_base_cls.from_type_and_label(EnvType.DEV, env_label=env_label).env_label
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def from_any(
return value
elif isinstance(value, dict):
obj = cls.from_dict(value, partial=True)
if obj.local is None:
if obj.local is None or cls.is_missing(obj.local):
if default_local is None:
raise ValueError(f"Local is None for {value}. No default provided")
obj.local = default_local
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from test.base import does_not_raise
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Tuple, Type, Union

import marshmallow as mm
from pytest import mark, param, raises

from aibs_informatics_core.exceptions import ValidationError
from aibs_informatics_core.models.aws.s3 import S3URI
from aibs_informatics_core.models.demand_execution.resolvables import (
R,
Resolvable,
S3Resolvable,
StringifiedDownloadable,
Expand Down Expand Up @@ -190,6 +191,63 @@ def test__get_resolvable_from_value__parses_stuff(
assert actual == expected


@mark.parametrize(
"resolvable_class, value, default_local, default_remote, expected, raise_expectation",
[
param(
S3Resolvable,
S3Resolvable("/tmp/somefile", S3URI("s3://bucket/key")),
None,
None,
S3Resolvable("/tmp/somefile", S3URI("s3://bucket/key")),
does_not_raise(),
id="object is returned as is",
),
param(
Uploadable,
{"remote": "s3://bucket/key"},
"/tmp/somefile",
None,
Uploadable("/tmp/somefile", "s3://bucket/key"),
does_not_raise(),
id="default local fills missing local in input",
),
param(
Uploadable,
{"remote": "s3://bucket/key"},
None,
None,
None,
raises(ValueError, match=r"Local is None for .*\. No default provided"),
id="ERROR: no local in input or default",
),
param(
S3Resolvable,
42,
None,
None,
None,
raises(ValueError),
id="ERROR: invalid type",
),
],
)
def test__from_any__works_as_intended(
resolvable_class: Type[R],
value: Any,
default_local: Optional[str],
default_remote: Optional[str],
expected: Optional[Resolvable],
raise_expectation,
):
with raise_expectation:
actual = resolvable_class.from_any(
value=value, default_local=default_local, default_remote=default_remote
)
if expected is not None:
assert actual == expected


@mark.parametrize(
"value, expected, raise_expectation",
[
Expand All @@ -213,7 +271,7 @@ def test__get_resolvable_from_value__parses_stuff(
),
],
)
def test__from_str__works(value: str, expected: S3Resolvable, raise_expectation):
def test__from_str__works(value: str, expected: Resolvable, raise_expectation):
with raise_expectation:
actual = expected.from_str(value)
assert actual == expected
Expand Down

0 comments on commit 94701eb

Please sign in to comment.