Skip to content

Commit

Permalink
state:modified vars, behind "state_modified_compare_vars" behaviour f…
Browse files Browse the repository at this point in the history
…lag (#10793)
  • Loading branch information
MichelleArk authored Sep 30, 2024
1 parent 2ff3f20 commit d1857b3
Show file tree
Hide file tree
Showing 22 changed files with 614 additions and 21 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20240729-173203.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Include models that depend on changed vars in state:modified, add state:modified.vars
selection method
time: 2024-07-29T17:32:03.368508-04:00
custom:
Author: michelleark
Issue: "4304"
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class ParsedResource(ParsedResourceMandatory):
unrendered_config_call_dict: Dict[str, Any] = field(default_factory=dict)
relation_name: Optional[str] = None
raw_code: str = ""
vars: Dict[str, Any] = field(default_factory=dict)

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
Expand Down
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/exposure.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Exposure(GraphResource):
tags: List[str] = field(default_factory=list)
config: ExposureConfig = field(default_factory=ExposureConfig)
unrendered_config: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
url: Optional[str] = None
depends_on: DependsOn = field(default_factory=DependsOn)
refs: List[RefArgs] = field(default_factory=list)
Expand Down
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/source_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,6 @@ class SourceDefinition(ParsedSourceMandatory):
config: SourceConfig = field(default_factory=SourceConfig)
patch_path: Optional[str] = None
unrendered_config: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
relation_name: Optional[str] = None
created_at: float = field(default_factory=lambda: time.time())
38 changes: 26 additions & 12 deletions core/dbt/context/configured.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,35 @@ def __init__(self, package_name: str):
self.resource_type = NodeType.Model


class SchemaYamlVars:
def __init__(self):
self.env_vars = {}
self.vars = {}


class ConfiguredVar(Var):
def __init__(
self,
context: Dict[str, Any],
config: AdapterRequiredConfig,
project_name: str,
schema_yaml_vars: Optional[SchemaYamlVars] = None,
):
super().__init__(context, config.cli_vars)
self._config = config
self._project_name = project_name
self.schema_yaml_vars = schema_yaml_vars

def __call__(self, var_name, default=Var._VAR_NOTSET):
my_config = self._config.load_dependencies()[self._project_name]

var_found = False
var_value = None

# cli vars > active project > local project
if var_name in self._config.cli_vars:
return self._config.cli_vars[var_name]
var_found = True
var_value = self._config.cli_vars[var_name]

adapter_type = self._config.credentials.type
lookup = FQNLookup(self._project_name)
Expand All @@ -58,19 +70,21 @@ def __call__(self, var_name, default=Var._VAR_NOTSET):
all_vars.add(my_config.vars.vars_for(lookup, adapter_type))
all_vars.add(active_vars)

if var_name in all_vars:
return all_vars[var_name]
if not var_found and var_name in all_vars:
var_found = True
var_value = all_vars[var_name]

if default is not Var._VAR_NOTSET:
return default

return self.get_missing_var(var_name)
if not var_found and default is not Var._VAR_NOTSET:
var_found = True
var_value = default

if not var_found:
return self.get_missing_var(var_name)
else:
if self.schema_yaml_vars:
self.schema_yaml_vars.vars[var_name] = var_value

class SchemaYamlVars:
def __init__(self):
self.env_vars = {}
self.vars = {}
return var_value


class SchemaYamlContext(ConfiguredContext):
Expand All @@ -82,7 +96,7 @@ def __init__(self, config, project_name: str, schema_yaml_vars: Optional[SchemaY

@contextproperty()
def var(self) -> ConfiguredVar:
return ConfiguredVar(self._ctx, self.config, self._project_name)
return ConfiguredVar(self._ctx, self.config, self._project_name, self.schema_yaml_vars)

@contextmember()
def env_var(self, var: str, default: Optional[str] = None) -> str:
Expand Down
8 changes: 8 additions & 0 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,14 @@ def get_missing_var(self, var_name):
# in the parser, just always return None.
return None

def __call__(self, var_name: str, default: Any = ModelConfiguredVar._VAR_NOTSET) -> Any:
var_value = super().__call__(var_name, default)

if self._node and hasattr(self._node, "vars"):
self._node.vars[var_name] = var_value

return var_value


class RuntimeVar(ModelConfiguredVar):
pass
Expand Down
17 changes: 17 additions & 0 deletions core/dbt/contracts/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ class SchemaSourceFile(BaseSourceFile):
sop: List[SourceKey] = field(default_factory=list)
env_vars: Dict[str, Any] = field(default_factory=dict)
unrendered_configs: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
pp_dict: Optional[Dict[str, Any]] = None
pp_test_index: Optional[Dict[str, Any]] = None

Expand Down Expand Up @@ -353,6 +354,22 @@ def delete_from_unrendered_configs(self, yaml_key, name):
if not self.unrendered_configs[yaml_key]:
del self.unrendered_configs[yaml_key]

def add_vars(self, vars: Dict[str, Any], yaml_key: str, name: str) -> None:
if yaml_key not in self.vars:
self.vars[yaml_key] = {}

if name not in self.vars[yaml_key]:
self.vars[yaml_key][name] = vars

def get_vars(self, yaml_key: str, name: str) -> Dict[str, Any]:
if yaml_key not in self.vars:
return {}

if name not in self.vars[yaml_key]:
return {}

return self.vars[yaml_key][name]

def add_env_var(self, var, yaml_key, name):
if yaml_key not in self.env_vars:
self.env_vars[yaml_key] = {}
Expand Down
32 changes: 32 additions & 0 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,19 +369,30 @@ def same_contract(self, old, adapter_type=None) -> bool:
# This would only apply to seeds
return True

def same_vars(self, old) -> bool:
return self.vars == old.vars

def same_contents(self, old, adapter_type) -> bool:
if old is None:
return False

# Need to ensure that same_contract is called because it
# could throw an error
same_contract = self.same_contract(old, adapter_type)

# Legacy behaviour
if not get_flags().state_modified_compare_vars:
same_vars = True
else:
same_vars = self.same_vars(old)

return (
self.same_body(old)
and self.same_config(old)
and self.same_persisted_description(old)
and self.same_fqn(old)
and self.same_database_representation(old)
and same_vars
and same_contract
and True
)
Expand Down Expand Up @@ -1251,6 +1262,9 @@ def same_config(self, old: "SourceDefinition") -> bool:
old.unrendered_config,
)

def same_vars(self, other: "SourceDefinition") -> bool:
return self.vars == other.vars

def same_contents(self, old: Optional["SourceDefinition"]) -> bool:
# existing when it didn't before is a change!
if old is None:
Expand All @@ -1264,13 +1278,20 @@ def same_contents(self, old: Optional["SourceDefinition"]) -> bool:
# freshness changes are changes, I guess
# metadata/tags changes are not "changes"
# patching/description changes are not "changes"
# Legacy behaviour
if not get_flags().state_modified_compare_vars:
same_vars = True
else:
same_vars = self.same_vars(old)

return (
self.same_database_representation(old)
and self.same_fqn(old)
and self.same_config(old)
and self.same_quoting(old)
and self.same_freshness(old)
and self.same_external(old)
and same_vars
and True
)

Expand Down Expand Up @@ -1367,12 +1388,21 @@ def same_config(self, old: "Exposure") -> bool:
old.unrendered_config,
)

def same_vars(self, old: "Exposure") -> bool:
return self.vars == old.vars

def same_contents(self, old: Optional["Exposure"]) -> bool:
# existing when it didn't before is a change!
# metadata/tags changes are not "changes"
if old is None:
return True

# Legacy behaviour
if not get_flags().state_modified_compare_vars:
same_vars = True
else:
same_vars = self.same_vars(old)

return (
self.same_fqn(old)
and self.same_exposure_type(old)
Expand All @@ -1383,6 +1413,7 @@ def same_contents(self, old: Optional["Exposure"]) -> bool:
and self.same_label(old)
and self.same_depends_on(old)
and self.same_config(old)
and same_vars
and True
)

Expand Down Expand Up @@ -1634,6 +1665,7 @@ class ParsedNodePatch(ParsedPatch):
latest_version: Optional[NodeVersion]
constraints: List[Dict[str, Any]]
deprecation_date: Optional[datetime]
vars: Dict[str, Any]
time_spine: Optional[TimeSpine] = None


Expand Down
2 changes: 2 additions & 0 deletions core/dbt/contracts/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ class ProjectFlags(ExtensibleDbtClassMixin):
source_freshness_run_project_hooks: bool = False
skip_nodes_if_on_run_start_fails: bool = False
state_modified_compare_more_unrendered_values: bool = False
state_modified_compare_vars: bool = False

@property
def project_only_flags(self) -> Dict[str, Any]:
Expand All @@ -352,6 +353,7 @@ def project_only_flags(self) -> Dict[str, Any]:
"source_freshness_run_project_hooks": self.source_freshness_run_project_hooks,
"skip_nodes_if_on_run_start_fails": self.skip_nodes_if_on_run_start_fails,
"state_modified_compare_more_unrendered_values": self.state_modified_compare_more_unrendered_values,
"state_modified_compare_vars": self.state_modified_compare_vars,
}


Expand Down
1 change: 1 addition & 0 deletions core/dbt/graph/selector_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[Uniqu
"modified.relation": self.check_modified_factory("same_database_representation"),
"modified.macros": self.check_modified_macros,
"modified.contract": self.check_modified_contract("same_contract", adapter_type),
"modified.vars": self.check_modified_factory("same_vars"),
}
if selector in state_checks:
checker = state_checks[selector]
Expand Down
5 changes: 4 additions & 1 deletion core/dbt/parser/schema_yaml_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def parse_exposure(self, unparsed: UnparsedExposure) -> None:
unique_id = f"{NodeType.Exposure}.{package_name}.{unparsed.name}"
path = self.yaml.path.relative_path

assert isinstance(self.yaml.file, SchemaSourceFile)
exposure_vars = self.yaml.file.get_vars(self.key, unparsed.name)

fqn = self.schema_parser.get_fqn_prefix(path)
fqn.append(unparsed.name)

Expand Down Expand Up @@ -133,6 +136,7 @@ def parse_exposure(self, unparsed: UnparsedExposure) -> None:
maturity=unparsed.maturity,
config=config,
unrendered_config=unrendered_config,
vars=exposure_vars,
)
ctx = generate_parse_exposure(
parsed,
Expand All @@ -144,7 +148,6 @@ def parse_exposure(self, unparsed: UnparsedExposure) -> None:
get_rendered(depends_on_jinja, ctx, parsed, capture_macros=True)
# parsed now has a populated refs/sources/metrics

assert isinstance(self.yaml.file, SchemaSourceFile)
if parsed.config.enabled:
self.manifest.add_exposure(self.yaml.file, parsed)
else:
Expand Down
18 changes: 14 additions & 4 deletions core/dbt/parser/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,10 +403,14 @@ def get_key_dicts(self) -> Iterable[Dict[str, Any]]:

if self.schema_yaml_vars.env_vars:
self.schema_parser.manifest.env_vars.update(self.schema_yaml_vars.env_vars)
for var in self.schema_yaml_vars.env_vars.keys():
schema_file.add_env_var(var, self.key, entry["name"])
for env_var in self.schema_yaml_vars.env_vars.keys():
schema_file.add_env_var(env_var, self.key, entry["name"])
self.schema_yaml_vars.env_vars = {}

if self.schema_yaml_vars.vars:
schema_file.add_vars(self.schema_yaml_vars.vars, self.key, entry["name"])
self.schema_yaml_vars.vars = {}

yield entry

def render_entry(self, dct):
Expand Down Expand Up @@ -677,6 +681,9 @@ def parse_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> None:
# code consistency.
deprecation_date: Optional[datetime.datetime] = None
time_spine: Optional[TimeSpine] = None
assert isinstance(self.yaml.file, SchemaSourceFile)
source_file: SchemaSourceFile = self.yaml.file

if isinstance(block.target, UnparsedModelUpdate):
deprecation_date = block.target.deprecation_date
time_spine = (
Expand Down Expand Up @@ -709,9 +716,9 @@ def parse_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> None:
constraints=block.target.constraints,
deprecation_date=deprecation_date,
time_spine=time_spine,
vars=source_file.get_vars(block.target.yaml_key, block.target.name),
)
assert isinstance(self.yaml.file, SchemaSourceFile)
source_file: SchemaSourceFile = self.yaml.file

if patch.yaml_key in ["models", "seeds", "snapshots"]:
unique_id = self.manifest.ref_lookup.get_unique_id(
patch.name, self.project.project_name, None
Expand Down Expand Up @@ -805,6 +812,8 @@ def patch_node_properties(self, node, patch: "ParsedNodePatch") -> None:
node.description = patch.description
node.columns = patch.columns
node.name = patch.name
# Prefer node-level vars to vars from patch
node.vars = {**patch.vars, **node.vars}

if not isinstance(node, ModelNode):
for attr in ["latest_version", "access", "version", "constraints"]:
Expand Down Expand Up @@ -954,6 +963,7 @@ def parse_patch(self, block: TargetBlock[UnparsedModelUpdate], refs: ParserRef)
latest_version=latest_version,
constraints=unparsed_version.constraints or target.constraints,
deprecation_date=unparsed_version.deprecation_date,
vars=source_file.get_vars(block.target.yaml_key, block.target.name),
)
# Node patched before config because config patching depends on model name,
# which may have been updated in the version patch
Expand Down
6 changes: 6 additions & 0 deletions core/dbt/parser/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ContextConfigGenerator,
UnrenderedConfigGenerator,
)
from dbt.contracts.files import SchemaSourceFile
from dbt.contracts.graph.manifest import Manifest, SourceKey
from dbt.contracts.graph.nodes import (
GenericTestNode,
Expand Down Expand Up @@ -158,6 +159,10 @@ def parse_source(self, target: UnpatchedSourceDefinition) -> SourceDefinition:
rendered=False,
)

schema_file = self.manifest.files[target.file_id]
assert isinstance(schema_file, SchemaSourceFile)
source_vars = schema_file.get_vars("sources", source.name)

if not isinstance(config, SourceConfig):
raise DbtInternalError(
f"Calculated a {type(config)} for a source, but expected a SourceConfig"
Expand Down Expand Up @@ -190,6 +195,7 @@ def parse_source(self, target: UnpatchedSourceDefinition) -> SourceDefinition:
tags=tags,
config=config,
unrendered_config=unrendered_config,
vars=source_vars,
)

if (
Expand Down
Loading

0 comments on commit d1857b3

Please sign in to comment.