diff --git a/.changes/unreleased/Features-20240202-112644.yaml b/.changes/unreleased/Features-20240202-112644.yaml new file mode 100644 index 00000000000..95fc50e8163 --- /dev/null +++ b/.changes/unreleased/Features-20240202-112644.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Cache environment variables +time: 2024-02-02T11:26:44.614393-05:00 +custom: + Author: peterallenwebb + Issue: "9489" diff --git a/core/dbt/cli/requires.py b/core/dbt/cli/requires.py index 4667e583fa7..cb98d1764bb 100644 --- a/core/dbt/cli/requires.py +++ b/core/dbt/cli/requires.py @@ -1,5 +1,9 @@ +import os + import dbt.tracking +from dbt_common.context import set_invocation_context from dbt_common.invocation import reset_invocation_id + from dbt.version import installed as installed_version from dbt.adapters.factory import adapter_management from dbt.flags import set_flags, get_flag_dict @@ -45,6 +49,8 @@ def wrapper(*args, **kwargs): assert isinstance(ctx, Context) ctx.obj = ctx.obj or {} + set_invocation_context(os.environ) + # Flags flags = Flags(ctx) ctx.obj["flags"] = flags diff --git a/core/dbt/config/renderer.py b/core/dbt/config/renderer.py index 42fd15d55a7..eee740893b8 100644 --- a/core/dbt/config/renderer.py +++ b/core/dbt/config/renderer.py @@ -1,6 +1,5 @@ from typing import Dict, Any, Tuple, Optional, Union, Callable import re -import os from datetime import date from dbt.clients.jinja import get_rendered @@ -11,6 +10,7 @@ from dbt.context.base import BaseContext from dbt.adapters.contracts.connection import HasCredentials from dbt.exceptions import DbtProjectError +from dbt_common.context import get_invocation_context from dbt_common.exceptions import CompilationError, RecursionError from dbt_common.utils import deep_map_render @@ -212,7 +212,7 @@ def render_value(self, value: Any, keypath: Optional[Keypath] = None) -> Any: ) if m: found = m.group(1) - value = os.environ[found] + value = get_invocation_context().env[found] replace_this = SECRET_PLACEHOLDER.format(found) return rendered.replace(replace_this, value) else: diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py index e992c885d73..e969506f625 100644 --- a/core/dbt/context/base.py +++ b/core/dbt/context/base.py @@ -20,6 +20,7 @@ SetStrictWrongTypeError, ZipStrictWrongTypeError, ) +from dbt_common.context import get_invocation_context from dbt_common.exceptions.macros import MacroReturn from dbt_common.events.functions import fire_event, get_invocation_id from dbt.events.types import JinjaLogInfo, JinjaLogDebug @@ -303,8 +304,9 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: return_value = None if var.startswith(SECRET_ENV_PREFIX): raise SecretEnvVarLocationError(var) - if var in os.environ: - return_value = os.environ[var] + env = get_invocation_context().env + if var in env: + return_value = env[var] elif default is not None: return_value = default @@ -313,7 +315,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: # that so we can skip partial parsing. Otherwise the file will be scheduled for # reparsing. If the default changes, the file will have been updated and therefore # will be scheduled for reparsing anyways. - self.env_vars[var] = return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER + self.env_vars[var] = return_value if var in env else DEFAULT_ENV_PLACEHOLDER return return_value else: diff --git a/core/dbt/context/configured.py b/core/dbt/context/configured.py index 67a10142ac8..c5c95bbfbcb 100644 --- a/core/dbt/context/configured.py +++ b/core/dbt/context/configured.py @@ -1,6 +1,7 @@ -import os from typing import Any, Dict, Optional +from dbt_common.context import get_invocation_context + from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER from dbt.adapters.contracts.connection import AdapterRequiredConfig from dbt.node_types import NodeType @@ -89,8 +90,9 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: return_value = None if var.startswith(SECRET_ENV_PREFIX): raise SecretEnvVarLocationError(var) - if var in os.environ: - return_value = os.environ[var] + env = get_invocation_context().env + if var in env: + return_value = env[var] elif default is not None: return_value = default @@ -101,7 +103,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: # reparsing. If the default changes, the file will have been updated and therefore # will be scheduled for reparsing anyways. self.schema_yaml_vars.env_vars[var] = ( - return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER + return_value if var in env else DEFAULT_ENV_PLACEHOLDER ) return return_value diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 51b7010e109..a8acc505644 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -13,11 +13,13 @@ Iterable, Mapping, ) + from typing_extensions import Protocol from dbt.adapters.base.column import Column from dbt.artifacts.resources import NodeVersion, RefArgs from dbt_common.clients.jinja import MacroProtocol +from dbt_common.context import get_invocation_context from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names from dbt_common.clients import agate_helper from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack, UnitTestMacroGenerator @@ -1353,8 +1355,11 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: return_value = None if var.startswith(SECRET_ENV_PREFIX): raise SecretEnvVarLocationError(var) - if var in os.environ: - return_value = os.environ[var] + + env = get_invocation_context().env + + if var in env: + return_value = env[var] elif default is not None: return_value = default @@ -1373,7 +1378,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: # reparsing. If the default changes, the file will have been updated and therefore # will be scheduled for reparsing anyways. self.manifest.env_vars[var] = ( - return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER + return_value if var in env else DEFAULT_ENV_PLACEHOLDER ) # hooks come from dbt_project.yml which doesn't have a real file_id @@ -1784,8 +1789,10 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: return_value = None if var.startswith(SECRET_ENV_PREFIX): raise SecretEnvVarLocationError(var) - if var in os.environ: - return_value = os.environ[var] + + env = get_invocation_context().env + if var in env: + return_value = env[var] elif default is not None: return_value = default @@ -1797,7 +1804,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: # reparsing. If the default changes, the file will have been updated and therefore # will be scheduled for reparsing anyways. self.manifest.env_vars[var] = ( - return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER + return_value if var in env else DEFAULT_ENV_PLACEHOLDER ) # the "model" should only be test nodes, but just in case, check # TODO CT-211 diff --git a/core/dbt/context/secret.py b/core/dbt/context/secret.py index 2c75546c42a..6de99fd5e5b 100644 --- a/core/dbt/context/secret.py +++ b/core/dbt/context/secret.py @@ -1,6 +1,7 @@ -import os from typing import Any, Dict, Optional +from dbt_common.context import get_invocation_context + from .base import BaseContext, contextmember from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER @@ -30,24 +31,25 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: # if this is a 'secret' env var, just return the name of the env var # instead of rendering the actual value here, to avoid any risk of # Jinja manipulation. it will be subbed out later, in SecretRenderer.render_value - if var in os.environ and var.startswith(SECRET_ENV_PREFIX): + env = get_invocation_context().env + if var in env and var.startswith(SECRET_ENV_PREFIX): return SECRET_PLACEHOLDER.format(var) - elif var in os.environ: - return_value = os.environ[var] + if var in env: + return_value = env[var] elif default is not None: return_value = default if return_value is not None: # store env vars in the internal manifest to power partial parsing # if it's a 'secret' env var, we shouldn't even get here - # but just to be safe — don't save secrets + # but just to be safe, don't save secrets if not var.startswith(SECRET_ENV_PREFIX): # If the environment variable is set from a default, store a string indicating # that so we can skip partial parsing. Otherwise the file will be scheduled for # reparsing. If the default changes, the file will have been updated and therefore # will be scheduled for reparsing anyways. - self.env_vars[var] = return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER + self.env_vars[var] = return_value if var in env else DEFAULT_ENV_PLACEHOLDER return return_value else: raise EnvVarMissingError(var) diff --git a/core/dbt/logger.py b/core/dbt/logger.py index e3ec87a1914..33332417f2b 100644 --- a/core/dbt/logger.py +++ b/core/dbt/logger.py @@ -3,7 +3,6 @@ import json import logging -import os import sys import time import warnings @@ -12,7 +11,8 @@ from typing import Optional, List, ContextManager, Callable, Dict, Any, Set import logbook -from dbt.constants import SECRET_ENV_PREFIX + +from dbt_common.context import get_invocation_context from dbt_common.dataclass_schema import dbtClassMixin STDOUT_LOG_FORMAT = "{record.message}" @@ -20,7 +20,7 @@ def get_secret_env() -> List[str]: - return [v for k, v in os.environ.items() if k.startswith(SECRET_ENV_PREFIX)] + return get_invocation_context().env_secrets ExceptionInformation = str diff --git a/core/dbt/parser/partial.py b/core/dbt/parser/partial.py index 2e79ce25262..32b4760f5a8 100644 --- a/core/dbt/parser/partial.py +++ b/core/dbt/parser/partial.py @@ -1,6 +1,7 @@ import os from copy import deepcopy from typing import MutableMapping, Dict, List, Callable + from dbt.contracts.graph.manifest import Manifest from dbt.contracts.files import ( AnySourceFile, @@ -8,6 +9,7 @@ parse_file_type_to_parser, SchemaSourceFile, ) +from dbt_common.context import get_invocation_context from dbt_common.events.functions import fire_event from dbt_common.events.base_types import EventLevel from dbt.events.types import ( @@ -159,7 +161,8 @@ def build_file_diff(self): deleted = len(deleted) + len(deleted_schema_files) changed = len(changed) + len(changed_schema_files) event = PartialParsingEnabled(deleted=deleted, added=len(added), changed=changed) - if os.environ.get("DBT_PP_TEST"): + + if get_invocation_context().env.get("DBT_PP_TEST"): fire_event(event, level=EventLevel.INFO) else: fire_event(event) diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 145d5778784..8b361988517 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -7,7 +7,9 @@ from pathlib import Path from typing import AbstractSet, Optional, Dict, List, Set, Tuple, Iterable +from dbt_common.context import get_invocation_context, _INVOCATION_CONTEXT_VAR import dbt_common.utils.formatting + import dbt.exceptions import dbt.tracking import dbt.utils @@ -377,7 +379,7 @@ def execute_nodes(self): with TextOnly(): fire_event(Formatting("")) - pool = ThreadPool(num_threads) + pool = ThreadPool(num_threads, self._pool_thread_initializer, [get_invocation_context()]) try: self.run_queue(pool) except FailFastError as failure: @@ -414,6 +416,10 @@ def execute_nodes(self): return self.node_results + @staticmethod + def _pool_thread_initializer(invocation_context): + _INVOCATION_CONTEXT_VAR.set(invocation_context) + def _mark_dependent_errors( self, node_id: str, result: RunResult, cause: Optional[RunResult] ) -> None: diff --git a/core/dbt/tests/util.py b/core/dbt/tests/util.py index ec4c2204f49..437f25aa2e2 100644 --- a/core/dbt/tests/util.py +++ b/core/dbt/tests/util.py @@ -1,3 +1,4 @@ +from contextvars import ContextVar, copy_context from io import StringIO import os import shutil @@ -12,6 +13,7 @@ from dbt.cli.main import dbtRunner from dbt.logger import log_manager from dbt.contracts.graph.manifest import Manifest +from dbt_common.context import _INVOCATION_CONTEXT_VAR, InvocationContext from dbt_common.events.functions import ( fire_event, capture_stdout_logs, @@ -631,3 +633,16 @@ def get_model_file(project, relation: BaseRelation) -> str: def set_model_file(project, relation: BaseRelation, model_sql: str): write_file(model_sql, project.project_root, "models", f"{relation.name}.sql") + + +def safe_set_invocation_context(): + """In order to deal with a problem with the way the pytest runner interacts + with ContextVars, this function provides a mechanism for setting the + invocation context reliably, using its name rather than the reference + variable, which may have been loaded in a separate context.""" + invocation_var: Optional[ContextVar] = next( + iter([cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name]), None + ) + if invocation_var is None: + invocation_var = _INVOCATION_CONTEXT_VAR + invocation_var.set(InvocationContext(os.environ)) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index e74c277d18c..b3e5d166e83 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -26,6 +26,7 @@ from dbt.task.base import ConfiguredTask from dbt.flags import set_from_args +from dbt.tests.util import safe_set_invocation_context from .utils import normalize @@ -362,9 +363,10 @@ def test_eq(self): def test_invalid_env_vars(self): self.env_override["env_value_port"] = "hello" - renderer = empty_profile_renderer() with mock.patch.dict(os.environ, self.env_override): with self.assertRaises(dbt.exceptions.DbtProfileError) as exc: + safe_set_invocation_context() + renderer = empty_profile_renderer() dbt.config.Profile.from_raw_profile_info( self.default_profile_data["default"], "default", @@ -442,6 +444,7 @@ def test_profile_override(self): def test_env_vars(self): self.args.target = "with-vars" with mock.patch.dict(os.environ, self.env_override): + safe_set_invocation_context() # reset invocation context with new env profile = self.from_args() from_raw = self.from_raw_profile_info(target_override="with-vars") @@ -460,6 +463,7 @@ def test_env_vars_env_target(self): self.write_profile(self.default_profile_data) self.env_override["env_value_target"] = "with-vars" with mock.patch.dict(os.environ, self.env_override): + safe_set_invocation_context() # reset invocation context with new env profile = self.from_args() from_raw = self.from_raw_profile_info(target_override="with-vars") @@ -478,6 +482,7 @@ def test_invalid_env_vars(self): self.args.target = "with-vars" with mock.patch.dict(os.environ, self.env_override): with self.assertRaises(dbt.exceptions.DbtProfileError) as exc: + safe_set_invocation_context() # reset invocation context with new env self.from_args() self.assertIn("Could not convert value 'hello' into type 'number'", str(exc.exception)) @@ -487,6 +492,7 @@ def test_cli_and_env_vars(self): self.args.vars = {"cli_value_host": "cli-postgres-host"} renderer = dbt.config.renderer.ProfileRenderer({"cli_value_host": "cli-postgres-host"}) with mock.patch.dict(os.environ, self.env_override): + safe_set_invocation_context() # reset invocation context with new env profile = self.from_args(renderer=renderer) from_raw = self.from_raw_profile_info( target_override="cli-and-env-vars", @@ -971,6 +977,7 @@ def setUp(self): def test_cli_and_env_vars(self): renderer = dbt.config.renderer.DbtProjectYamlRenderer(None, {"cli_version": "0.1.2"}) with mock.patch.dict(os.environ, self.env_override): + safe_set_invocation_context() # reset invocation context with new env project = dbt.config.Project.from_project_root( self.project_dir, renderer, @@ -1262,6 +1269,7 @@ def test_cli_and_env_vars(self): self.args.project_dir = self.project_dir set_from_args(self.args, None) with mock.patch.dict(os.environ, self.env_override): + safe_set_invocation_context() # reset invocation context with new env config = dbt.config.RuntimeConfig.from_args(self.args) self.assertEqual(config.version, "0.1.2") diff --git a/tests/unit/test_graph_runnable_task.py b/tests/unit/test_graph_runnable_task.py index 5a678c0d490..50fda4f2c82 100644 --- a/tests/unit/test_graph_runnable_task.py +++ b/tests/unit/test_graph_runnable_task.py @@ -4,6 +4,8 @@ from dbt.task.runnable import GraphRunnableTask from typing import AbstractSet, Any, Dict, Optional +from dbt.tests.util import safe_set_invocation_context + @dataclass class MockArgs: @@ -46,6 +48,9 @@ def defer_to_manifest(self, adapter, selected_uids: AbstractSet[str]): def test_graph_runnable_task_cancels_connection_on_system_exit(): + + safe_set_invocation_context() + task = MockRunnableTask(exception_class=SystemExit) with pytest.raises(SystemExit): @@ -56,6 +61,9 @@ def test_graph_runnable_task_cancels_connection_on_system_exit(): def test_graph_runnable_task_cancels_connection_on_keyboard_interrupt(): + + safe_set_invocation_context() + task = MockRunnableTask(exception_class=KeyboardInterrupt) with pytest.raises(KeyboardInterrupt): diff --git a/tests/unit/test_partial_parsing.py b/tests/unit/test_partial_parsing.py index 12caac35013..beac86abe38 100644 --- a/tests/unit/test_partial_parsing.py +++ b/tests/unit/test_partial_parsing.py @@ -6,12 +6,15 @@ from dbt.contracts.graph.nodes import ModelNode from dbt.contracts.files import ParseFileType, SourceFile, SchemaSourceFile, FilePath, FileHash from dbt.node_types import NodeType +from dbt.tests.util import safe_set_invocation_context from .utils import normalize class TestPartialParsing(unittest.TestCase): def setUp(self): + safe_set_invocation_context() + project_name = "my_test" project_root = "/users/root" sql_model_file = SourceFile( @@ -156,7 +159,6 @@ def get_python_model(self, name): ) def test_simple(self): - # Nothing has changed self.assertIsNotNone(self.partial_parsing) self.assertTrue(self.partial_parsing.skip_parsing())