diff --git a/.changes/unreleased/Fixes-20240206-160231.yaml b/.changes/unreleased/Fixes-20240206-160231.yaml new file mode 100644 index 00000000..a00205b3 --- /dev/null +++ b/.changes/unreleased/Fixes-20240206-160231.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Make invocation contexts more reliable in testing scenarios. +time: 2024-02-06T16:02:31.81842-05:00 +custom: + Author: peterallenwebb + Issue: "52" diff --git a/dbt_common/context.py b/dbt_common/context.py index 07434432..f7f4b7ec 100644 --- a/dbt_common/context.py +++ b/dbt_common/context.py @@ -1,4 +1,4 @@ -from contextvars import ContextVar +from contextvars import ContextVar, copy_context from typing import List, Mapping, Optional from dbt_common.constants import SECRET_ENV_PREFIX @@ -26,10 +26,23 @@ def env_secrets(self) -> List[str]: _INVOCATION_CONTEXT_VAR: ContextVar[InvocationContext] = ContextVar("DBT_INVOCATION_CONTEXT_VAR") +def _reliably_get_invocation_var() -> ContextVar: + invocation_var: Optional[ContextVar] = next( + (cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None + ) + + if invocation_var is None: + invocation_var = _INVOCATION_CONTEXT_VAR + + return invocation_var + + def set_invocation_context(env: Mapping[str, str]) -> None: - _INVOCATION_CONTEXT_VAR.set(InvocationContext(env)) + invocation_var = _reliably_get_invocation_var() + invocation_var.set(InvocationContext(env)) def get_invocation_context() -> InvocationContext: - ctx = _INVOCATION_CONTEXT_VAR.get() + invocation_var = _reliably_get_invocation_var() + ctx = invocation_var.get() return ctx