Skip to content

Commit

Permalink
Use new context invocation class to cache environment variables. (#9489)
Browse files Browse the repository at this point in the history
* Use new context invocation class.

* Adjust new constructor param on InvocationContext, make tests robust

* Add changelog entry.

* Clarify parameter name
  • Loading branch information
peterallenwebb authored Feb 2, 2024
1 parent db65e62 commit ef03ea2
Show file tree
Hide file tree
Showing 14 changed files with 95 additions and 28 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240202-112644.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Cache environment variables
time: 2024-02-02T11:26:44.614393-05:00
custom:
Author: peterallenwebb
Issue: "9489"
6 changes: 6 additions & 0 deletions core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import os

import dbt.tracking
from dbt_common.context import set_invocation_context
from dbt_common.invocation import reset_invocation_id

from dbt.version import installed as installed_version
from dbt.adapters.factory import adapter_management
from dbt.flags import set_flags, get_flag_dict
Expand Down Expand Up @@ -45,6 +49,8 @@ def wrapper(*args, **kwargs):
assert isinstance(ctx, Context)
ctx.obj = ctx.obj or {}

set_invocation_context(os.environ)

# Flags
flags = Flags(ctx)
ctx.obj["flags"] = flags
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/config/renderer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, Any, Tuple, Optional, Union, Callable
import re
import os
from datetime import date

from dbt.clients.jinja import get_rendered
Expand All @@ -11,6 +10,7 @@
from dbt.context.base import BaseContext
from dbt.adapters.contracts.connection import HasCredentials
from dbt.exceptions import DbtProjectError
from dbt_common.context import get_invocation_context
from dbt_common.exceptions import CompilationError, RecursionError
from dbt_common.utils import deep_map_render

Expand Down Expand Up @@ -212,7 +212,7 @@ def render_value(self, value: Any, keypath: Optional[Keypath] = None) -> Any:
)
if m:
found = m.group(1)
value = os.environ[found]
value = get_invocation_context().env[found]
replace_this = SECRET_PLACEHOLDER.format(found)
return rendered.replace(replace_this, value)
else:
Expand Down
8 changes: 5 additions & 3 deletions core/dbt/context/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
SetStrictWrongTypeError,
ZipStrictWrongTypeError,
)
from dbt_common.context import get_invocation_context
from dbt_common.exceptions.macros import MacroReturn
from dbt_common.events.functions import fire_event, get_invocation_id
from dbt.events.types import JinjaLogInfo, JinjaLogDebug
Expand Down Expand Up @@ -303,8 +304,9 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
return_value = None
if var.startswith(SECRET_ENV_PREFIX):
raise SecretEnvVarLocationError(var)
if var in os.environ:
return_value = os.environ[var]
env = get_invocation_context().env
if var in env:
return_value = env[var]
elif default is not None:
return_value = default

Expand All @@ -313,7 +315,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
# that so we can skip partial parsing. Otherwise the file will be scheduled for
# reparsing. If the default changes, the file will have been updated and therefore
# will be scheduled for reparsing anyways.
self.env_vars[var] = return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER
self.env_vars[var] = return_value if var in env else DEFAULT_ENV_PLACEHOLDER

return return_value
else:
Expand Down
10 changes: 6 additions & 4 deletions core/dbt/context/configured.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Any, Dict, Optional

from dbt_common.context import get_invocation_context

from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
from dbt.adapters.contracts.connection import AdapterRequiredConfig
from dbt.node_types import NodeType
Expand Down Expand Up @@ -89,8 +90,9 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
return_value = None
if var.startswith(SECRET_ENV_PREFIX):
raise SecretEnvVarLocationError(var)
if var in os.environ:
return_value = os.environ[var]
env = get_invocation_context().env
if var in env:
return_value = env[var]
elif default is not None:
return_value = default

Expand All @@ -101,7 +103,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
# reparsing. If the default changes, the file will have been updated and therefore
# will be scheduled for reparsing anyways.
self.schema_yaml_vars.env_vars[var] = (
return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER
return_value if var in env else DEFAULT_ENV_PLACEHOLDER
)

return return_value
Expand Down
19 changes: 13 additions & 6 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
Iterable,
Mapping,
)

from typing_extensions import Protocol

from dbt.adapters.base.column import Column
from dbt.artifacts.resources import NodeVersion, RefArgs
from dbt_common.clients.jinja import MacroProtocol
from dbt_common.context import get_invocation_context
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
from dbt_common.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack, UnitTestMacroGenerator
Expand Down Expand Up @@ -1353,8 +1355,11 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
return_value = None
if var.startswith(SECRET_ENV_PREFIX):
raise SecretEnvVarLocationError(var)
if var in os.environ:
return_value = os.environ[var]

env = get_invocation_context().env

if var in env:
return_value = env[var]
elif default is not None:
return_value = default

Expand All @@ -1373,7 +1378,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
# reparsing. If the default changes, the file will have been updated and therefore
# will be scheduled for reparsing anyways.
self.manifest.env_vars[var] = (
return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER
return_value if var in env else DEFAULT_ENV_PLACEHOLDER
)

# hooks come from dbt_project.yml which doesn't have a real file_id
Expand Down Expand Up @@ -1784,8 +1789,10 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
return_value = None
if var.startswith(SECRET_ENV_PREFIX):
raise SecretEnvVarLocationError(var)
if var in os.environ:
return_value = os.environ[var]

env = get_invocation_context().env
if var in env:
return_value = env[var]
elif default is not None:
return_value = default

Expand All @@ -1797,7 +1804,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
# reparsing. If the default changes, the file will have been updated and therefore
# will be scheduled for reparsing anyways.
self.manifest.env_vars[var] = (
return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER
return_value if var in env else DEFAULT_ENV_PLACEHOLDER
)
# the "model" should only be test nodes, but just in case, check
# TODO CT-211
Expand Down
14 changes: 8 additions & 6 deletions core/dbt/context/secret.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Any, Dict, Optional

from dbt_common.context import get_invocation_context

from .base import BaseContext, contextmember

from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
Expand Down Expand Up @@ -30,24 +31,25 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
# if this is a 'secret' env var, just return the name of the env var
# instead of rendering the actual value here, to avoid any risk of
# Jinja manipulation. it will be subbed out later, in SecretRenderer.render_value
if var in os.environ and var.startswith(SECRET_ENV_PREFIX):
env = get_invocation_context().env
if var in env and var.startswith(SECRET_ENV_PREFIX):
return SECRET_PLACEHOLDER.format(var)

elif var in os.environ:
return_value = os.environ[var]
if var in env:
return_value = env[var]
elif default is not None:
return_value = default

if return_value is not None:
# store env vars in the internal manifest to power partial parsing
# if it's a 'secret' env var, we shouldn't even get here
# but just to be safe — don't save secrets
# but just to be safe, don't save secrets
if not var.startswith(SECRET_ENV_PREFIX):
# If the environment variable is set from a default, store a string indicating
# that so we can skip partial parsing. Otherwise the file will be scheduled for
# reparsing. If the default changes, the file will have been updated and therefore
# will be scheduled for reparsing anyways.
self.env_vars[var] = return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER
self.env_vars[var] = return_value if var in env else DEFAULT_ENV_PLACEHOLDER
return return_value
else:
raise EnvVarMissingError(var)
Expand Down
6 changes: 3 additions & 3 deletions core/dbt/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import json
import logging
import os
import sys
import time
import warnings
Expand All @@ -12,15 +11,16 @@
from typing import Optional, List, ContextManager, Callable, Dict, Any, Set

import logbook
from dbt.constants import SECRET_ENV_PREFIX

from dbt_common.context import get_invocation_context
from dbt_common.dataclass_schema import dbtClassMixin

STDOUT_LOG_FORMAT = "{record.message}"
DEBUG_LOG_FORMAT = "{record.time:%Y-%m-%d %H:%M:%S.%f%z} ({record.thread_name}): {record.message}"


# Return the values of the secret environment variables for this invocation.
# NOTE(review): this is a diff hunk with both sides shown. The os.environ
# comprehension (values whose key starts with SECRET_ENV_PREFIX) appears to be
# the removed line, and the get_invocation_context().env_secrets lookup its
# replacement. If both lines stayed in the body, only the first `return`
# would ever run. Confirm against the committed file.
def get_secret_env() -> List[str]:
return [v for k, v in os.environ.items() if k.startswith(SECRET_ENV_PREFIX)]
return get_invocation_context().env_secrets


ExceptionInformation = str
Expand Down
5 changes: 4 additions & 1 deletion core/dbt/parser/partial.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os
from copy import deepcopy
from typing import MutableMapping, Dict, List, Callable

from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.files import (
AnySourceFile,
ParseFileType,
parse_file_type_to_parser,
SchemaSourceFile,
)
from dbt_common.context import get_invocation_context
from dbt_common.events.functions import fire_event
from dbt_common.events.base_types import EventLevel
from dbt.events.types import (
Expand Down Expand Up @@ -159,7 +161,8 @@ def build_file_diff(self):
deleted = len(deleted) + len(deleted_schema_files)
changed = len(changed) + len(changed_schema_files)
event = PartialParsingEnabled(deleted=deleted, added=len(added), changed=changed)
if os.environ.get("DBT_PP_TEST"):

if get_invocation_context().env.get("DBT_PP_TEST"):
fire_event(event, level=EventLevel.INFO)
else:
fire_event(event)
Expand Down
8 changes: 7 additions & 1 deletion core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from pathlib import Path
from typing import AbstractSet, Optional, Dict, List, Set, Tuple, Iterable

from dbt_common.context import get_invocation_context, _INVOCATION_CONTEXT_VAR
import dbt_common.utils.formatting

import dbt.exceptions
import dbt.tracking
import dbt.utils
Expand Down Expand Up @@ -377,7 +379,7 @@ def execute_nodes(self):
with TextOnly():
fire_event(Formatting(""))

pool = ThreadPool(num_threads)
pool = ThreadPool(num_threads, self._pool_thread_initializer, [get_invocation_context()])
try:
self.run_queue(pool)
except FailFastError as failure:
Expand Down Expand Up @@ -414,6 +416,10 @@ def execute_nodes(self):

return self.node_results

# Thread-pool initializer: binds the given invocation context to
# _INVOCATION_CONTEXT_VAR in whichever thread calls it.
# NOTE(review): execute_nodes passes this to ThreadPool together with
# get_invocation_context(), presumably so that worker threads see the same
# invocation context as the thread that created the pool. Confirm.
@staticmethod
def _pool_thread_initializer(invocation_context):
_INVOCATION_CONTEXT_VAR.set(invocation_context)

def _mark_dependent_errors(
self, node_id: str, result: RunResult, cause: Optional[RunResult]
) -> None:
Expand Down
15 changes: 15 additions & 0 deletions core/dbt/tests/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextvars import ContextVar, copy_context
from io import StringIO
import os
import shutil
Expand All @@ -12,6 +13,7 @@
from dbt.cli.main import dbtRunner
from dbt.logger import log_manager
from dbt.contracts.graph.manifest import Manifest
from dbt_common.context import _INVOCATION_CONTEXT_VAR, InvocationContext
from dbt_common.events.functions import (
fire_event,
capture_stdout_logs,
Expand Down Expand Up @@ -631,3 +633,16 @@ def get_model_file(project, relation: BaseRelation) -> str:

def set_model_file(project, relation: BaseRelation, model_sql: str):
write_file(model_sql, project.project_root, "models", f"{relation.name}.sql")


def safe_set_invocation_context():
    """Install a fresh invocation context built from the current ``os.environ``.

    This works around a problem with how the pytest runner interacts with
    ContextVars: the ``_INVOCATION_CONTEXT_VAR`` reference imported here may
    have been loaded in a separate context. The variable is therefore looked
    up by name in a copy of the current context, and the imported reference
    is used only when no variable with that name is found.
    """
    wanted_name = _INVOCATION_CONTEXT_VAR.name

    # Find the first context variable in the current context with that name.
    target: Optional[ContextVar] = None
    for candidate in copy_context():
        if candidate.name == wanted_name:
            target = candidate
            break

    # Nothing registered under that name yet: fall back to the imported variable.
    if target is None:
        target = _INVOCATION_CONTEXT_VAR

    target.set(InvocationContext(os.environ))
10 changes: 9 additions & 1 deletion tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from dbt.task.base import ConfiguredTask

from dbt.flags import set_from_args
from dbt.tests.util import safe_set_invocation_context

from .utils import normalize

Expand Down Expand Up @@ -362,9 +363,10 @@ def test_eq(self):

def test_invalid_env_vars(self):
self.env_override["env_value_port"] = "hello"
renderer = empty_profile_renderer()
with mock.patch.dict(os.environ, self.env_override):
with self.assertRaises(dbt.exceptions.DbtProfileError) as exc:
safe_set_invocation_context()
renderer = empty_profile_renderer()
dbt.config.Profile.from_raw_profile_info(
self.default_profile_data["default"],
"default",
Expand Down Expand Up @@ -442,6 +444,7 @@ def test_profile_override(self):
def test_env_vars(self):
self.args.target = "with-vars"
with mock.patch.dict(os.environ, self.env_override):
safe_set_invocation_context() # reset invocation context with new env
profile = self.from_args()
from_raw = self.from_raw_profile_info(target_override="with-vars")

Expand All @@ -460,6 +463,7 @@ def test_env_vars_env_target(self):
self.write_profile(self.default_profile_data)
self.env_override["env_value_target"] = "with-vars"
with mock.patch.dict(os.environ, self.env_override):
safe_set_invocation_context() # reset invocation context with new env
profile = self.from_args()
from_raw = self.from_raw_profile_info(target_override="with-vars")

Expand All @@ -478,6 +482,7 @@ def test_invalid_env_vars(self):
self.args.target = "with-vars"
with mock.patch.dict(os.environ, self.env_override):
with self.assertRaises(dbt.exceptions.DbtProfileError) as exc:
safe_set_invocation_context() # reset invocation context with new env
self.from_args()

self.assertIn("Could not convert value 'hello' into type 'number'", str(exc.exception))
Expand All @@ -487,6 +492,7 @@ def test_cli_and_env_vars(self):
self.args.vars = {"cli_value_host": "cli-postgres-host"}
renderer = dbt.config.renderer.ProfileRenderer({"cli_value_host": "cli-postgres-host"})
with mock.patch.dict(os.environ, self.env_override):
safe_set_invocation_context() # reset invocation context with new env
profile = self.from_args(renderer=renderer)
from_raw = self.from_raw_profile_info(
target_override="cli-and-env-vars",
Expand Down Expand Up @@ -971,6 +977,7 @@ def setUp(self):
def test_cli_and_env_vars(self):
renderer = dbt.config.renderer.DbtProjectYamlRenderer(None, {"cli_version": "0.1.2"})
with mock.patch.dict(os.environ, self.env_override):
safe_set_invocation_context() # reset invocation context with new env
project = dbt.config.Project.from_project_root(
self.project_dir,
renderer,
Expand Down Expand Up @@ -1262,6 +1269,7 @@ def test_cli_and_env_vars(self):
self.args.project_dir = self.project_dir
set_from_args(self.args, None)
with mock.patch.dict(os.environ, self.env_override):
safe_set_invocation_context() # reset invocation context with new env
config = dbt.config.RuntimeConfig.from_args(self.args)

self.assertEqual(config.version, "0.1.2")
Expand Down
Loading

0 comments on commit ef03ea2

Please sign in to comment.