Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

state:modified.vars #10502

Closed
wants to merge 11 commits into from
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20240729-173203.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Include models that depend on changed vars in state:modified, add state:modified.vars
selection method
time: 2024-07-29T17:32:03.368508-04:00
custom:
Author: michelleark
Issue: "4304"
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20240923-190758.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Allow singular tests to be documented in properties.yml
time: 2024-09-23T19:07:58.151069+01:00
custom:
Author: aranke
Issue: "9005"
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class ParsedResource(ParsedResourceMandatory):
unrendered_config_call_dict: Dict[str, Any] = field(default_factory=dict)
relation_name: Optional[str] = None
raw_code: str = ""
vars: Dict[str, Any] = field(default_factory=dict)

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
Expand Down
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/exposure.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Exposure(GraphResource):
tags: List[str] = field(default_factory=list)
config: ExposureConfig = field(default_factory=ExposureConfig)
unrendered_config: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
url: Optional[str] = None
depends_on: DependsOn = field(default_factory=DependsOn)
refs: List[RefArgs] = field(default_factory=list)
Expand Down
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/source_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,6 @@ class SourceDefinition(ParsedSourceMandatory):
config: SourceConfig = field(default_factory=SourceConfig)
patch_path: Optional[str] = None
unrendered_config: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
relation_name: Optional[str] = None
created_at: float = field(default_factory=lambda: time.time())
13 changes: 4 additions & 9 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,8 @@ def _parse_versions(versions: Union[List[str], str]) -> List[VersionSpecifier]:
return [VersionSpecifier.from_version_string(v) for v in versions]


def _all_source_paths(
model_paths: List[str],
seed_paths: List[str],
snapshot_paths: List[str],
analysis_paths: List[str],
macro_paths: List[str],
) -> List[str]:
paths = chain(model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths)
def _all_source_paths(*args: List[str]) -> List[str]:
paths = chain(*args)
# Strip trailing slashes since the path is the same even though the name is not
stripped_paths = map(lambda s: s.rstrip("/"), paths)
return list(set(stripped_paths))
Expand Down Expand Up @@ -409,7 +403,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ["snapshots"])

all_source_paths: List[str] = _all_source_paths(
model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths
model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths, test_paths
)

docs_paths: List[str] = value_or(cfg.docs_paths, all_source_paths)
Expand Down Expand Up @@ -652,6 +646,7 @@ def all_source_paths(self) -> List[str]:
self.snapshot_paths,
self.analysis_paths,
self.macro_paths,
self.test_paths,
)

@property
Expand Down
38 changes: 26 additions & 12 deletions core/dbt/context/configured.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,35 @@
self.resource_type = NodeType.Model


class SchemaYamlVars:
def __init__(self):
self.env_vars = {}
self.vars = {}


class ConfiguredVar(Var):
def __init__(
self,
context: Dict[str, Any],
config: AdapterRequiredConfig,
project_name: str,
schema_yaml_vars: Optional[SchemaYamlVars] = None,
):
super().__init__(context, config.cli_vars)
self._config = config
self._project_name = project_name
self.schema_yaml_vars = schema_yaml_vars

def __call__(self, var_name, default=Var._VAR_NOTSET):
my_config = self._config.load_dependencies()[self._project_name]

var_found = False
var_value = None

# cli vars > active project > local project
if var_name in self._config.cli_vars:
return self._config.cli_vars[var_name]
var_found = True
var_value = self._config.cli_vars[var_name]

adapter_type = self._config.credentials.type
lookup = FQNLookup(self._project_name)
Expand All @@ -58,19 +70,21 @@
all_vars.add(my_config.vars.vars_for(lookup, adapter_type))
all_vars.add(active_vars)

if var_name in all_vars:
return all_vars[var_name]
if not var_found and var_name in all_vars:
var_found = True
var_value = all_vars[var_name]

if default is not Var._VAR_NOTSET:
return default

return self.get_missing_var(var_name)
if not var_found and default is not Var._VAR_NOTSET:
var_found = True
var_value = default

if not var_found:
return self.get_missing_var(var_name)

Check warning on line 82 in core/dbt/context/configured.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/context/configured.py#L82

Added line #L82 was not covered by tests
else:
if self.schema_yaml_vars:
self.schema_yaml_vars.vars[var_name] = var_value

class SchemaYamlVars:
def __init__(self):
self.env_vars = {}
self.vars = {}
return var_value


class SchemaYamlContext(ConfiguredContext):
Expand All @@ -82,7 +96,7 @@

@contextproperty()
def var(self) -> ConfiguredVar:
return ConfiguredVar(self._ctx, self.config, self._project_name)
return ConfiguredVar(self._ctx, self.config, self._project_name, self.schema_yaml_vars)

@contextmember()
def env_var(self, var: str, default: Optional[str] = None) -> str:
Expand Down
8 changes: 8 additions & 0 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,14 @@ def get_missing_var(self, var_name):
# in the parser, just always return None.
return None

def __call__(self, var_name: str, default: Any = ModelConfiguredVar._VAR_NOTSET) -> Any:
var_value = super().__call__(var_name, default)

if self._node and hasattr(self._node, "vars"):
self._node.vars[var_name] = var_value

return var_value


class RuntimeVar(ModelConfiguredVar):
pass
Expand Down
17 changes: 17 additions & 0 deletions core/dbt/contracts/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ class SchemaSourceFile(BaseSourceFile):
# created too, but those are in 'sources'
sop: List[SourceKey] = field(default_factory=list)
env_vars: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
unrendered_configs: Dict[str, Any] = field(default_factory=dict)
pp_dict: Optional[Dict[str, Any]] = None
pp_test_index: Optional[Dict[str, Any]] = None
Expand Down Expand Up @@ -318,6 +319,22 @@ def get_all_test_ids(self):
test_ids.extend(self.data_tests[key][name])
return test_ids

def add_vars(self, vars: Dict[str, Any], yaml_key: str, name: str) -> None:
if yaml_key not in self.vars:
self.vars[yaml_key] = {}

if name not in self.vars[yaml_key]:
self.vars[yaml_key][name] = vars

def get_vars(self, yaml_key: str, name: str) -> Dict[str, Any]:
if yaml_key not in self.vars:
return {}

if name not in self.vars[yaml_key]:
return {}

return self.vars[yaml_key][name]

def add_unrendered_config(self, unrendered_config, yaml_key, name, version=None):
versioned_name = f"{name}_v{version}" if version is not None else name

Expand Down
50 changes: 49 additions & 1 deletion core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
SavedQuery,
SeedNode,
SemanticModel,
SingularTestNode,
SourceDefinition,
UnitTestDefinition,
UnitTestFileFixture,
Expand Down Expand Up @@ -89,7 +90,7 @@
RefName = str


def find_unique_id_for_package(storage, key, package: Optional[PackageName]):
def find_unique_id_for_package(storage, key, package: Optional[PackageName]) -> Optional[UniqueID]:
if key not in storage:
return None

Expand Down Expand Up @@ -470,6 +471,43 @@
_versioned_types: ClassVar[set] = set()


class SingularTestLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)

def get_unique_id(self, search_name, package: Optional[PackageName]) -> Optional[UniqueID]:
return find_unique_id_for_package(self.storage, search_name, package)

def find(
self, search_name, package: Optional[PackageName], manifest: "Manifest"
) -> Optional[SingularTestNode]:
unique_id = self.get_unique_id(search_name, package)
if unique_id is not None:
return self.perform_lookup(unique_id, manifest)
return None

Check warning on line 488 in core/dbt/contracts/graph/manifest.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/contracts/graph/manifest.py#L485-L488

Added lines #L485 - L488 were not covered by tests

def add_singular_test(self, source: SingularTestNode) -> None:
if source.search_name not in self.storage:
self.storage[source.search_name] = {}

self.storage[source.search_name][source.package_name] = source.unique_id

def populate(self, manifest: "Manifest") -> None:
for node in manifest.nodes.values():
if isinstance(node, SingularTestNode):
self.add_singular_test(node)

def perform_lookup(self, unique_id: UniqueID, manifest: "Manifest") -> SingularTestNode:
if unique_id not in manifest.nodes:
raise dbt_common.exceptions.DbtInternalError(

Check warning on line 503 in core/dbt/contracts/graph/manifest.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/contracts/graph/manifest.py#L502-L503

Added lines #L502 - L503 were not covered by tests
f"Singular test {unique_id} found in cache but not found in manifest"
)
node = manifest.nodes[unique_id]
assert isinstance(node, SingularTestNode)
return node

Check warning on line 508 in core/dbt/contracts/graph/manifest.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/contracts/graph/manifest.py#L506-L508

Added lines #L506 - L508 were not covered by tests


def _packages_to_search(
current_project: str,
node_package: str,
Expand Down Expand Up @@ -869,6 +907,9 @@
_analysis_lookup: Optional[AnalysisLookup] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
)
_singular_test_lookup: Optional[SingularTestLookup] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
)
_parsing_info: ParsingInfo = field(
default_factory=ParsingInfo,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
Expand Down Expand Up @@ -1264,6 +1305,12 @@
self._analysis_lookup = AnalysisLookup(self)
return self._analysis_lookup

@property
def singular_test_lookup(self) -> SingularTestLookup:
if self._singular_test_lookup is None:
self._singular_test_lookup = SingularTestLookup(self)
return self._singular_test_lookup

@property
def external_node_unique_ids(self):
return [node.unique_id for node in self.nodes.values() if node.is_external_node]
Expand Down Expand Up @@ -1708,6 +1755,7 @@
self._semantic_model_by_measure_lookup,
self._disabled_lookup,
self._analysis_lookup,
self._singular_test_lookup,
)
return self.__class__, args

Expand Down
27 changes: 27 additions & 0 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,12 @@ def same_contract(self, old, adapter_type=None) -> bool:
# This would only apply to seeds
return True

def same_vars(self, old) -> bool:
if get_flags().state_modified_compare_vars:
return self.vars == old.vars
else:
return True

def same_contents(self, old, adapter_type) -> bool:
if old is None:
return False
Expand All @@ -382,6 +388,7 @@ def same_contents(self, old, adapter_type) -> bool:
and self.same_persisted_description(old)
and self.same_fqn(old)
and self.same_database_representation(old)
and self.same_vars(old)
and same_contract
and True
)
Expand Down Expand Up @@ -1251,6 +1258,12 @@ def same_config(self, old: "SourceDefinition") -> bool:
old.unrendered_config,
)

def same_vars(self, other: "SourceDefinition") -> bool:
if get_flags().state_modified_compare_vars:
return self.vars == other.vars
else:
return True

def same_contents(self, old: Optional["SourceDefinition"]) -> bool:
# existing when it didn't before is a change!
if old is None:
Expand All @@ -1271,6 +1284,7 @@ def same_contents(self, old: Optional["SourceDefinition"]) -> bool:
and self.same_quoting(old)
and self.same_freshness(old)
and self.same_external(old)
and self.same_vars(old)
and True
)

Expand Down Expand Up @@ -1367,6 +1381,12 @@ def same_config(self, old: "Exposure") -> bool:
old.unrendered_config,
)

def same_vars(self, old: "Exposure") -> bool:
if get_flags().state_modified_compare_vars:
return self.vars == old.vars
else:
return True

def same_contents(self, old: Optional["Exposure"]) -> bool:
# existing when it didn't before is a change!
# metadata/tags changes are not "changes"
Expand All @@ -1383,6 +1403,7 @@ def same_contents(self, old: Optional["Exposure"]) -> bool:
and self.same_label(old)
and self.same_depends_on(old)
and self.same_config(old)
and self.same_vars(old)
and True
)

Expand Down Expand Up @@ -1634,6 +1655,7 @@ class ParsedNodePatch(ParsedPatch):
latest_version: Optional[NodeVersion]
constraints: List[Dict[str, Any]]
deprecation_date: Optional[datetime]
vars: Dict[str, Any]
time_spine: Optional[TimeSpine] = None


Expand All @@ -1642,6 +1664,11 @@ class ParsedMacroPatch(ParsedPatch):
arguments: List[MacroArgument] = field(default_factory=list)


@dataclass
class ParsedSingularTestPatch(ParsedPatch):
pass


# ====================================
# Node unions/categories
# ====================================
Expand Down
5 changes: 5 additions & 0 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ class UnparsedAnalysisUpdate(HasConfig, HasColumnDocs, HasColumnProps, HasYamlMe
access: Optional[str] = None


@dataclass
class UnparsedSingularTestUpdate(HasConfig, HasColumnProps, HasYamlMetadata):
pass


@dataclass
class UnparsedNodeUpdate(HasConfig, HasColumnTests, HasColumnAndTestProps, HasYamlMetadata):
quote_columns: Optional[bool] = None
Expand Down
2 changes: 2 additions & 0 deletions core/dbt/contracts/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ class ProjectFlags(ExtensibleDbtClassMixin):
require_resource_names_without_spaces: bool = False
source_freshness_run_project_hooks: bool = False
state_modified_compare_more_unrendered_values: bool = False
state_modified_compare_vars: bool = False

@property
def project_only_flags(self) -> Dict[str, Any]:
Expand All @@ -350,6 +351,7 @@ def project_only_flags(self) -> Dict[str, Any]:
"require_resource_names_without_spaces": self.require_resource_names_without_spaces,
"source_freshness_run_project_hooks": self.source_freshness_run_project_hooks,
"state_modified_compare_more_unrendered_values": self.state_modified_compare_more_unrendered_values,
"state_modified_compare_vars": self.state_modified_compare_vars,
}


Expand Down
Loading
Loading