Skip to content

Commit

Permalink
In build command run unit tests before models (#9273)
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank authored Dec 20, 2023
1 parent 4e87f46 commit a0177e3
Show file tree
Hide file tree
Showing 27 changed files with 5,254 additions and 2,799 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Breaking Changes-20231129-091921.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Breaking Changes
body: Rm --dry-run flag from 'dbt deps --add-package', in favor of just 'dbt deps
--lock'
time: 2023-11-29T09:19:21.071212+01:00
custom:
Author: jtcohen6
Issue: "9100"
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231212-150556.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: In build command run unit tests before models
time: 2023-12-12T15:05:56.778829-05:00
custom:
Author: gshank
Issue: "9128"
7 changes: 0 additions & 7 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,6 @@ def debug(ctx, **kwargs):
@p.target
@p.vars
@p.source
@p.dry_run
@p.lock
@p.upgrade
@p.add_package
Expand All @@ -483,12 +482,6 @@ def deps(ctx, **kwargs):
message=f"Version is required in --add-package when a package when source is {flags.SOURCE}",
option_name="--add-package",
)
else:
if flags.DRY_RUN:
raise BadOptionUsage(
message="Invalid flag `--dry-run` when not using `--add-package`.",
option_name="--dry-run",
)
task = DepsTask(flags, ctx.obj["project"])
results = task.run()
success = task.interpret_results(results)
Expand Down
7 changes: 0 additions & 7 deletions core/dbt/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,6 @@
hidden=True,
)

dry_run = click.option(
"--dry-run",
envvar=None,
help="Option to run `dbt deps --add-package` without updating package-lock.yml file.",
is_flag=True,
)

empty = click.option(
"--empty/--no-empty",
envvar="DBT_EMPTY",
Expand Down
4 changes: 3 additions & 1 deletion core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class Linker:
def __init__(self, data=None) -> None:
if data is None:
data = {}
self.graph = nx.DiGraph(**data)
self.graph: nx.DiGraph = nx.DiGraph(**data)

def edges(self):
return self.graph.edges()
Expand Down Expand Up @@ -243,6 +243,7 @@ def add_test_edges(self, manifest: Manifest) -> None:
# Get all tests that depend on any upstream nodes.
upstream_tests = []
for upstream_node in upstream_nodes:
# This gets tests with unique_ids starting with "test."
upstream_tests += _get_tests_for_node(manifest, upstream_node)

for upstream_test in upstream_tests:
Expand Down Expand Up @@ -471,6 +472,7 @@ def compile(self, manifest: Manifest, write=True, add_test_edges=False) -> Graph
summaries["_invocation_id"] = get_invocation_id()
summaries["linked"] = linker.get_graph_summary(manifest)

# This is only called for the "build" command
if add_test_edges:
manifest.build_parent_and_child_maps()
linker.add_test_edges(manifest)
Expand Down
10 changes: 6 additions & 4 deletions core/dbt/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Graph:
"""

def __init__(self, graph) -> None:
self.graph = graph
self.graph: nx.DiGraph = graph

def nodes(self) -> Set[UniqueId]:
return set(self.graph.nodes())
Expand Down Expand Up @@ -83,10 +83,10 @@ def get_subset_graph(self, selected: Iterable[UniqueId]) -> "Graph":
removed nodes are preserved as explicit new edges.
"""

new_graph = self.graph.copy()
include_nodes = set(selected)
new_graph: nx.DiGraph = self.graph.copy()
include_nodes: Set[UniqueId] = set(selected)

still_removing = True
still_removing: bool = True
while still_removing:
nodes_to_remove = list(
node
Expand Down Expand Up @@ -129,6 +129,8 @@ def get_subset_graph(self, selected: Iterable[UniqueId]) -> "Graph":
return Graph(new_graph)

def subgraph(self, nodes: Iterable[UniqueId]) -> "Graph":
# Take the original networkx graph and return a subgraph containing only
# the selected unique_id nodes.
return Graph(self.graph.subgraph(nodes))

def get_dependent_nodes(self, node: UniqueId):
Expand Down
9 changes: 7 additions & 2 deletions core/dbt/graph/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __init__(
include_empty_nodes: bool = False,
) -> None:
super().__init__(manifest, previous_state)
self.full_graph = graph
self.include_empty_nodes = include_empty_nodes
self.full_graph: Graph = graph
self.include_empty_nodes: bool = include_empty_nodes

# build a subgraph containing only non-empty, enabled nodes and enabled
# sources.
Expand Down Expand Up @@ -258,6 +258,8 @@ def expand_selection(
node = self.manifest.nodes[unique_id]
elif unique_id in self.manifest.unit_tests:
node = self.manifest.unit_tests[unique_id] # type: ignore
# Test nodes that are not selected themselves, but whose parents are selected.
# (Does not include unit tests because they can only have one parent.)
if can_select_indirectly(node):
# should we add it in directly?
if indirect_selection == IndirectSelection.Eager or set(
Expand Down Expand Up @@ -325,8 +327,11 @@ def get_graph_queue(self, spec: SelectionSpec) -> GraphQueue:
"""Returns a queue over nodes in the graph that tracks progress of
dependencies.
"""
# Filtering happens in get_selected
selected_nodes = self.get_selected(spec)
# Save to global variable
selected_resources.set_selected_resources(selected_nodes)
# Construct a new graph using the selected_nodes
new_graph = self.full_graph.get_subset_graph(selected_nodes)
# should we give a way here for consumers to mutate the graph?
return GraphQueue(new_graph.graph, self.manifest, selected_nodes)
Expand Down
139 changes: 123 additions & 16 deletions core/dbt/task/build.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import threading
from typing import Dict, List, Set

from .run import RunTask, ModelRunner as run_model_runner
from .snapshot import SnapshotRunner as snapshot_model_runner
Expand All @@ -8,7 +9,7 @@
from dbt.adapters.factory import get_adapter
from dbt.contracts.results import NodeStatus
from dbt.exceptions import DbtInternalError
from dbt.graph import ResourceTypeSelector
from dbt.graph import ResourceTypeSelector, GraphQueue, Graph
from dbt.node_types import NodeType
from dbt.task.test import TestSelector
from dbt.task.base import BaseRunner
Expand Down Expand Up @@ -77,7 +78,7 @@ class BuildTask(RunTask):
I.E. a resource of type Model is handled by the ModelRunner which is
imported as run_model_runner."""

MARK_DEPENDENT_ERRORS_STATUSES = [NodeStatus.Error, NodeStatus.Fail]
MARK_DEPENDENT_ERRORS_STATUSES = [NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped]

RUNNER_MAP = {
NodeType.Model: run_model_runner,
Expand All @@ -88,28 +89,133 @@ class BuildTask(RunTask):
}
ALL_RESOURCE_VALUES = frozenset({x for x in RUNNER_MAP.keys()})

@property
def resource_types(self):
def __init__(self, args, config, manifest) -> None:
super().__init__(args, config, manifest)
self.selected_unit_tests: Set = set()
self.model_to_unit_test_map: Dict[str, List] = {}

def resource_types(self, no_unit_tests=False):
if self.args.include_saved_query:
self.RUNNER_MAP[NodeType.SavedQuery] = SavedQueryRunner
self.ALL_RESOURCE_VALUES = self.ALL_RESOURCE_VALUES.union({NodeType.SavedQuery})

if not self.args.resource_types:
return list(self.ALL_RESOURCE_VALUES)

values = set(self.args.resource_types)

if "all" in values:
values.remove("all")
values.update(self.ALL_RESOURCE_VALUES)
resource_types = list(self.ALL_RESOURCE_VALUES)
else:
resource_types = set(self.args.resource_types)

if "all" in resource_types:
resource_types.remove("all")
resource_types.update(self.ALL_RESOURCE_VALUES)

# First we get selected_nodes including unit tests, then without,
# and do a set difference.
if no_unit_tests is True and NodeType.Unit in resource_types:
resource_types.remove(NodeType.Unit)
return list(resource_types)

def get_graph_queue(self) -> GraphQueue:
    """Build the job queue for ``build`` while holding unit tests out of it.

    Overrides ``get_graph_queue`` in runnable.py. The selection spec is
    resolved twice, once with a selector that allows unit tests and once with
    a selector that does not. The difference between the two selections is the
    set of selected unit tests; it is stored on ``self.selected_unit_tests``
    and handed to ``build_model_to_unit_test_map``.

    Returns:
        The queue produced by the selector that excludes unit tests.
    """
    # The spec is derived from self.selection_arg and self.exclusion_arg.
    selection_spec = self.get_selection_spec()

    # Unique ids selected when unit tests are allowed ...
    ids_with_unit_tests = self.get_node_selector(no_unit_tests=False).get_selected(
        selection_spec
    )

    # ... and the unique ids selected when they are not.
    queue_selector = self.get_node_selector(no_unit_tests=True)
    ids_without_unit_tests = queue_selector.get_selected(selection_spec)

    # Whatever only shows up in the first selection is a unit test.
    self.selected_unit_tests = ids_with_unit_tests - ids_without_unit_tests
    self.build_model_to_unit_test_map(self.selected_unit_tests)

    # The selector's get_graph_queue drops NodeTypes that the node selector
    # was not configured with (filter_selection).
    return queue_selector.get_graph_queue(selection_spec)

# overrides handle_job_queue in runnable.py
def handle_job_queue(self, pool, callback):
if self.run_count == 0:
self.num_nodes = self.num_nodes + len(self.selected_unit_tests)
node = self.job_queue.get()
if (
node.resource_type == NodeType.Model
and self.model_to_unit_test_map
and node.unique_id in self.model_to_unit_test_map
):
self.handle_model_with_unit_tests_node(node, pool, callback)

return list(values)
else:
self.handle_job_queue_node(node, pool, callback)

def get_node_selector(self) -> ResourceTypeSelector:
def handle_model_with_unit_tests_node(self, node, pool, callback):
    """Dispatch a model node that has selected unit tests mapped to it.

    The work itself is done by ``call_model_and_unit_tests_runner``. In
    single-threaded mode it is called inline and its result is passed to
    ``callback``; otherwise it is submitted to ``pool`` with ``callback``
    attached.
    """
    self._raise_set_error()
    run_model_and_tests = self.call_model_and_unit_tests_runner
    if not self.config.args.single_threaded:
        pool.apply_async(run_model_and_tests, args=[node], callback=callback)
        return
    callback(run_model_and_tests(node))

def call_model_and_unit_tests_runner(self, node) -> RunResult:
    """Run the selected unit tests mapped to ``node``, then run ``node`` itself.

    Each unit test listed for the model in ``self.model_to_unit_test_map`` is
    run through ``call_runner`` and its result is passed to
    ``_handle_result``. If the model is already marked in
    ``self._skipped_children``, its unit tests are skipped too. If a unit test
    ends in one of ``MARK_DEPENDENT_ERRORS_STATUSES``, the model is marked to
    be skipped.

    Args:
        node: the model node; its ``unique_id`` must be a key of
            ``self.model_to_unit_test_map``.

    Returns:
        The result of calling the model's own runner.

    Raises:
        DbtInternalError: if the manifest has not been set.
    """
    # Explicit check instead of ``assert`` so it is not stripped under -O.
    if self.manifest is None:
        raise DbtInternalError("manifest was None in call_model_and_unit_tests_runner")

    model_id = node.unique_id
    for unit_test_unique_id in self.model_to_unit_test_map[model_id]:
        unit_test_node = self.manifest.unit_tests[unit_test_unique_id]
        unit_test_runner = self.get_runner(unit_test_node)
        # If the model is marked skip, also skip the unit tests
        if model_id in self._skipped_children:
            # cause is only for ephemeral nodes
            unit_test_runner.do_skip(cause=None)
        result = self.call_runner(unit_test_runner)
        self._handle_result(result)
        if result.status in self.MARK_DEPENDENT_ERRORS_STATUSES:
            # Mark the model to be skipped, but keep any cause already
            # recorded for it: a unit test skipped *because* the model was
            # already marked must not overwrite that entry with None.
            self._skipped_children.setdefault(model_id, None)

    runner = self.get_runner(node)
    if runner.node.unique_id in self._skipped_children:
        cause = self._skipped_children.pop(runner.node.unique_id)
        runner.do_skip(cause=cause)
    return self.call_runner(runner)

def handle_job_queue_node(self, node, pool, callback):
    """Dispatch a node that is not a model with unit tests attached.

    A runner is built for ``node``. If the node's unique_id is recorded in
    ``self._skipped_children``, the entry is removed and the runner is told to
    skip with the recorded cause. The runner is then executed inline
    (single-threaded mode, result passed to ``callback``) or submitted to
    ``pool`` with ``callback`` attached.
    """
    self._raise_set_error()
    runner = self.get_runner(node)

    # Only now do we know what is about to run, so this is where a skip
    # decided by an upstream failure gets applied.
    node_id = runner.node.unique_id
    if node_id in self._skipped_children:
        runner.do_skip(cause=self._skipped_children.pop(node_id))

    if self.config.args.single_threaded:
        callback(self.call_runner(runner))
        return
    pool.apply_async(self.call_runner, args=[runner], callback=callback)

def build_model_to_unit_test_map(self, selected_unit_tests):
    """Group the selected unit tests by model, so they can be processed before it.

    Each id in ``selected_unit_tests`` is looked up in
    ``self.manifest.unit_tests``. The resulting dict maps a model unique_id to
    the list of unique_ids of the selected unit tests for that model, and is
    stored on ``self.model_to_unit_test_map``.
    """
    tests_by_model = {}
    for selected_id in selected_unit_tests:
        unit_test = self.manifest.unit_tests[selected_id]
        # The first depends_on entry is used as the model the test belongs to.
        tested_model_id = unit_test.depends_on.nodes[0]
        tests_by_model.setdefault(tested_model_id, []).append(unit_test.unique_id)
    self.model_to_unit_test_map = tests_by_model

# We return two different kinds of selectors, one with unit tests and one without
def get_node_selector(self, no_unit_tests=False) -> ResourceTypeSelector:
if self.manifest is None or self.graph is None:
raise DbtInternalError("manifest and graph must be set to get node selection")

resource_types = self.resource_types
resource_types = self.resource_types(no_unit_tests)

if resource_types == [NodeType.Test]:
return TestSelector(
Expand All @@ -127,9 +233,10 @@ def get_node_selector(self) -> ResourceTypeSelector:
def get_runner_type(self, node):
return self.RUNNER_MAP.get(node.resource_type)

def compile_manifest(self):
# Special build compile_manifest method to pass add_test_edges to the compiler
def compile_manifest(self) -> None:
if self.manifest is None:
raise DbtInternalError("compile_manifest called before manifest was loaded")
adapter = get_adapter(self.config)
compiler = adapter.get_compiler()
self.graph = compiler.compile(self.manifest, add_test_edges=True)
self.graph: Graph = compiler.compile(self.manifest, add_test_edges=True)
5 changes: 3 additions & 2 deletions core/dbt/task/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,9 @@ def run(self) -> None:
if previous_hash != current_hash:
self.lock()

# Early return when dry run or lock only.
if self.args.dry_run or self.args.lock:
# Early return when 'dbt deps --lock'
# Just resolve packages and write lock file, don't actually install packages
if self.args.lock:
return

if system.path_exists(self.project.packages_install_path):
Expand Down
31 changes: 19 additions & 12 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
ModelMetadata,
NodeCount,
)
from dbt.node_types import NodeType
from dbt.parser.manifest import write_manifest
from dbt.task.base import ConfiguredTask, BaseRunner
from .printer import (
Expand Down Expand Up @@ -222,8 +223,9 @@ def call_runner(self, runner: BaseRunner) -> RunResult:
)
)
# `_event_status` dict is only used for logging. Make sure
# it gets deleted when we're done with it
runner.node.clear_event_status()
# it gets deleted when we're done with it, except for unit tests
if not runner.node.resource_type == NodeType.Unit:
runner.node.clear_event_status()

fail_fast = get_flags().FAIL_FAST

Expand Down Expand Up @@ -275,16 +277,7 @@ def callback(result):
self.job_queue.mark_done(result.node.unique_id)

while not self.job_queue.empty():
node = self.job_queue.get()
self._raise_set_error()
runner = self.get_runner(node)
# we finally know what we're running! Make sure we haven't decided
# to skip it due to upstream failures
if runner.node.unique_id in self._skipped_children:
cause = self._skipped_children.pop(runner.node.unique_id)
runner.do_skip(cause=cause)
args = (runner,)
self._submit(pool, args, callback)
self.handle_job_queue(pool, callback)

# block on completion
if get_flags().FAIL_FAST:
Expand All @@ -301,6 +294,19 @@ def callback(result):

return

def handle_job_queue(self, pool, callback):
    """Take the next node off the job queue and submit a runner for it.

    The build command overrides this method. If the node's unique_id is
    recorded in ``self._skipped_children``, the entry is removed and the
    runner is told to skip with the recorded cause before it is submitted via
    ``self._submit``.
    """
    next_node = self.job_queue.get()
    self._raise_set_error()
    runner = self.get_runner(next_node)

    # Only now do we know what is about to run, so this is where a skip
    # decided by an upstream failure gets applied.
    node_id = runner.node.unique_id
    if node_id in self._skipped_children:
        runner.do_skip(cause=self._skipped_children.pop(node_id))

    self._submit(pool, [runner], callback)

def _handle_result(self, result: RunResult):
"""Mark the result as completed, insert the `CompileResultNode` into
the manifest, and mark any descendants (potentially with a 'cause' if
Expand All @@ -315,6 +321,7 @@ def _handle_result(self, result: RunResult):
if self.manifest is None:
raise DbtInternalError("manifest was None in _handle_result")

# True when result.status is NodeStatus.Error; the build command also includes Fail and Skipped
if result.status in self.MARK_DEPENDENT_ERRORS_STATUSES:
if is_ephemeral:
cause = result
Expand Down
Loading

0 comments on commit a0177e3

Please sign in to comment.