diff --git a/dbt_common/context.py b/dbt_common/context.py index 35223c35..07434432 100644 --- a/dbt_common/context.py +++ b/dbt_common/context.py @@ -1,6 +1,5 @@ from contextvars import ContextVar -import os -from typing import List, Mapping +from typing import List, Mapping, Optional from dbt_common.constants import SECRET_ENV_PREFIX @@ -8,27 +7,27 @@ class InvocationContext: def __init__(self, env: Mapping[str, str]): self._env = env - self._env_secrets: List[str] = None + self._env_secrets: Optional[List[str]] = None # This class will also eventually manage the invocation_id, flags, event manager, etc. @property def env(self) -> Mapping[str, str]: - if self._env is None: - self._env = os.environ - return self._env @property def env_secrets(self) -> List[str]: - return [v for k, v in self.env.items() if k.startswith(SECRET_ENV_PREFIX) and v.strip()] - + if self._env_secrets is None: + self._env_secrets = [ + v for k, v in self.env.items() if k.startswith(SECRET_ENV_PREFIX) and v.strip() + ] + return self._env_secrets _INVOCATION_CONTEXT_VAR: ContextVar[InvocationContext] = ContextVar("DBT_INVOCATION_CONTEXT_VAR") -def set_invocation_context() -> None: - _INVOCATION_CONTEXT_VAR.set(InvocationContext()) +def set_invocation_context(env: Mapping[str, str]) -> None: + _INVOCATION_CONTEXT_VAR.set(InvocationContext(env)) def get_invocation_context() -> InvocationContext: diff --git a/tests/unit/test_invocation_context.py b/tests/unit/test_invocation_context.py index 9653996f..f1ae4418 100644 --- a/tests/unit/test_invocation_context.py +++ b/tests/unit/test_invocation_context.py @@ -1,11 +1,13 @@ from dbt_common.constants import SECRET_ENV_PREFIX from dbt_common.context import InvocationContext + def test_invocation_context_env(): test_env = {"VAR_1": "value1", "VAR_2": "value2"} ic = InvocationContext(env=test_env) assert ic.env == test_env + def test_invocation_context_secrets(): test_env = { f"{SECRET_ENV_PREFIX}_VAR_1": "secret1", @@ -14,4 +16,4 @@ def test_invocation_context_secrets(): f"foo{SECRET_ENV_PREFIX}": "nonsecret", } ic = InvocationContext(env=test_env) - assert set(ic.env_secrets) == set(["secret1", "secret2"]) \ No newline at end of file + assert set(ic.env_secrets) == set(["secret1", "secret2"])