diff --git a/.changes/unreleased/Features-20230401-193614.yaml b/.changes/unreleased/Features-20230401-193614.yaml new file mode 100644 index 00000000000..0c37af174a2 --- /dev/null +++ b/.changes/unreleased/Features-20230401-193614.yaml @@ -0,0 +1,6 @@ +kind: Features +body: 'New command: ''dbt clone''' +time: 2023-04-01T19:36:14.622217+02:00 +custom: + Author: jtcohen6 + Issue: "7256" diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index 30051063095..3cf2d3d9a27 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -25,6 +25,7 @@ from dbt.task.build import BuildTask from dbt.task.clean import CleanTask from dbt.task.compile import CompileTask +from dbt.task.clone import CloneTask from dbt.task.debug import DebugTask from dbt.task.deps import DepsTask from dbt.task.freshness import FreshnessTask @@ -393,6 +394,42 @@ def show(ctx, **kwargs): return results, success +# dbt clone +@cli.command("clone") +@click.pass_context +@p.exclude +@p.full_refresh +@p.profile +@p.profiles_dir +@p.project_dir +@p.resource_type +@p.select +@p.selector +@p.state # required +@p.target +@p.target_path +@p.threads +@p.vars +@p.version_check +@requires.preflight +@requires.profile +@requires.project +@requires.runtime_config +@requires.manifest +@requires.postflight +def clone(ctx, **kwargs): + """Create clones of selected nodes based on their location in the manifest provided to --state.""" + task = CloneTask( + ctx.obj["flags"], + ctx.obj["runtime_config"], + ctx.obj["manifest"], + ) + + results = task.run() + success = task.interpret_results(results) + return results, success + + # dbt debug @cli.command("debug") @click.pass_context diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 8a74cc08d1e..754522605d3 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -1427,6 +1427,20 @@ def this(self) -> Optional[RelationProxy]: return None return self.db_wrapper.Relation.create_from(self.config, self.model) + @contextproperty + def state_relation(self) -> Optional[RelationProxy]: + """ + For commands which add information about this node's corresponding + production version (via a --state artifact), access the Relation + object for that stateful other + """ + if getattr(self.model, "state_relation", None): + return self.db_wrapper.Relation.create_from_node( + self.config, self.model.state_relation # type: ignore + ) + else: + return None + # This is called by '_context_for', used in 'render_with_context' def generate_parser_model_context( diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index f4caefbf6e4..d72b1032049 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -37,6 +37,7 @@ GraphMemberNode, ResultNode, BaseNode, + StateRelation, ManifestOrPublicNode, ) from dbt.contracts.graph.unparsed import SourcePatch, NodeVersion, UnparsedVersion @@ -1103,6 +1104,30 @@ def merge_from_artifact( sample = list(islice(merged, 5)) fire_event(MergedFromState(num_merged=len(merged), sample=sample)) + # Called by CloneTask.defer_to_manifest + def add_from_artifact( + self, + other: "WritableManifest", + ) -> None: + """Update this manifest by *adding* information about each node's location + in the other manifest. + + Only non-ephemeral refable nodes are examined. + """ + refables = set(NodeType.refable()) + for unique_id, node in other.nodes.items(): + current = self.nodes.get(unique_id) + if current and (node.resource_type in refables and not node.is_ephemeral): + other_node = other.nodes[unique_id] + state_relation = StateRelation( + other_node.database, other_node.schema, other_node.alias + ) + self.nodes[unique_id] = current.replace(state_relation=state_relation) + + # Rebuild the flat_graph, which powers the 'graph' context variable, + # now that we've deferred some nodes + self.build_flat_graph() + # Methods that were formerly in ParseResult def add_macro(self, source_file: SourceFile, macro: Macro): @@ -1316,6 +1341,8 @@ def __post_serialize__(self, dct): for unique_id, node in dct["nodes"].items(): if "config_call_dict" in node: del node["config_call_dict"] + if "state_relation" in node: + del node["state_relation"] return dct diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 26938b37554..4e9ccfa59e2 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -270,6 +270,17 @@ def add_public_node(self, value: str): self.public_nodes.append(value) +@dataclass +class StateRelation(dbtClassMixin): + database: Optional[str] + schema: str + alias: str + + @property + def identifier(self): + return self.alias + + @dataclass class ParsedNodeMandatory(GraphNode, HasRelationMetadata, Replaceable): alias: str @@ -358,7 +369,7 @@ def __post_serialize__(self, dct): @classmethod def _deserialize(cls, dct: Dict[str, int]): # The serialized ParsedNodes do not differ from each other - # in fields that would allow 'from_dict' to distinguis + # in fields that would allow 'from_dict' to distinguish # between them. resource_type = dct["resource_type"] if resource_type == "model": @@ -615,6 +626,7 @@ class ModelNode(CompiledNode): constraints: List[ModelLevelConstraint] = field(default_factory=list) version: Optional[NodeVersion] = None latest_version: Optional[NodeVersion] = None + state_relation: Optional[StateRelation] = None @property def is_latest_version(self) -> bool: @@ -797,6 +809,7 @@ class SeedNode(ParsedNode): # No SQLDefaults! # and we need the root_path to load the seed later root_path: Optional[str] = None depends_on: MacroDependsOn = field(default_factory=MacroDependsOn) + state_relation: Optional[StateRelation] = None def same_seeds(self, other: "SeedNode") -> bool: # for seeds, we check the hashes. If the hashes are different types, @@ -995,6 +1008,7 @@ class IntermediateSnapshotNode(CompiledNode): class SnapshotNode(CompiledNode): resource_type: NodeType = field(metadata={"restrict": [NodeType.Snapshot]}) config: SnapshotConfig + state_relation: Optional[StateRelation] = None # ==================================== diff --git a/core/dbt/include/global_project/macros/materializations/models/clone.sql b/core/dbt/include/global_project/macros/materializations/models/clone.sql new file mode 100644 index 00000000000..8525983619e --- /dev/null +++ b/core/dbt/include/global_project/macros/materializations/models/clone.sql @@ -0,0 +1,115 @@ +{% macro can_clone_tables() %} + {{ return(adapter.dispatch('can_clone_tables', 'dbt')()) }} +{% endmacro %} + + +{% macro default__can_clone_tables() %} + {{ return(False) }} +{% endmacro %} + + +{% macro snowflake__can_clone_tables() %} + {{ return(True) }} +{% endmacro %} + + +{% macro get_pointer_sql(to_relation) %} + {{ return(adapter.dispatch('get_pointer_sql', 'dbt')(to_relation)) }} +{% endmacro %} + + +{% macro default__get_pointer_sql(to_relation) %} + {% set pointer_sql %} + select * from {{ to_relation }} + {% endset %} + {{ return(pointer_sql) }} +{% endmacro %} + + +{% macro get_clone_table_sql(this_relation, state_relation) %} + {{ return(adapter.dispatch('get_clone_table_sql', 'dbt')(this_relation, state_relation)) }} +{% endmacro %} + + +{% macro default__get_clone_table_sql(this_relation, state_relation) %} + create or replace table {{ this_relation }} clone {{ state_relation }} +{% endmacro %} + + +{% macro snowflake__get_clone_table_sql(this_relation, state_relation) %} + create or replace + {{ "transient" if config.get("transient", true) }} + table {{ this_relation }} + clone {{ state_relation }} + {{ "copy grants" if config.get("copy_grants", false) }} +{% endmacro %} + + +{%- materialization clone, default -%} + + {%- set relations = {'relations': []} -%} + + {%- if not state_relation -%} + -- nothing to do + {{ log("No relation found in state manifest for " ~ model.unique_id) }} + {{ return(relations) }} + {%- endif -%} + + {%- set existing_relation = load_cached_relation(this) -%} + + {%- if existing_relation and not flags.FULL_REFRESH -%} + -- noop! + {{ log("Relation " ~ existing_relation ~ " already exists") }} + {{ return(relations) }} + {%- endif -%} + + {%- set other_existing_relation = load_cached_relation(state_relation) -%} + + -- If this is a database that can do zero-copy cloning of tables, and the other relation is a table, then this will be a table + -- Otherwise, this will be a view + + {% set can_clone_tables = can_clone_tables() %} + + {%- if other_existing_relation and other_existing_relation.type == 'table' and can_clone_tables -%} + + {%- set target_relation = this.incorporate(type='table') -%} + {% if existing_relation is not none and not existing_relation.is_table %} + {{ log("Dropping relation " ~ existing_relation ~ " because it is of type " ~ existing_relation.type) }} + {{ drop_relation_if_exists(existing_relation) }} + {% endif %} + + -- as a general rule, data platforms that can clone tables can also do atomic 'create or replace' + {% call statement('main') %} + {{ get_clone_table_sql(target_relation, state_relation) }} + {% endcall %} + + {% set should_revoke = should_revoke(existing_relation, full_refresh_mode=True) %} + {% do apply_grants(target_relation, grant_config, should_revoke=should_revoke) %} + {% do persist_docs(target_relation, model) %} + + {{ return({'relations': [target_relation]}) }} + + {%- else -%} + + {%- set target_relation = this.incorporate(type='view') -%} + + -- TODO: this should probably be illegal + -- I'm just doing it out of convenience to reuse the 'view' materialization logic + {%- do context.update({ + 'sql': get_pointer_sql(state_relation), + 'compiled_code': get_pointer_sql(state_relation) + }) -%} + + -- reuse the view materialization + -- TODO: support actual dispatch for materialization macros + {% set search_name = "materialization_view_" ~ adapter.type() %} + {% if not search_name in context %} + {% set search_name = "materialization_view_default" %} + {% endif %} + {% set materialization_macro = context[search_name] %} + {% set relations = materialization_macro() %} + {{ return(relations) }} + + {%- endif -%} + +{%- endmaterialization -%} diff --git a/core/dbt/task/clone.py b/core/dbt/task/clone.py new file mode 100644 index 00000000000..cab0d6a2de4 --- /dev/null +++ b/core/dbt/task/clone.py @@ -0,0 +1,185 @@ +import threading +from typing import AbstractSet, Optional, Any, List, Iterable, Set + +from dbt.dataclass_schema import dbtClassMixin + +from dbt.contracts.graph.manifest import WritableManifest +from dbt.contracts.results import RunStatus, RunResult +from dbt.exceptions import DbtInternalError, DbtRuntimeError, CompilationError +from dbt.graph import ResourceTypeSelector +from dbt.node_types import NodeType +from dbt.parser.manifest import write_manifest +from dbt.task.base import BaseRunner +from dbt.task.runnable import GraphRunnableTask +from dbt.task.run import _validate_materialization_relations_dict +from dbt.adapters.base import BaseRelation +from dbt.clients.jinja import MacroGenerator +from dbt.context.providers import generate_runtime_model_context + + +class CloneRunner(BaseRunner): + def before_execute(self): + pass + + def after_execute(self, result): + pass + + def _build_run_model_result(self, model, context): + result = context["load_result"]("main") + if result: + status = RunStatus.Success + message = str(result.response) + else: + status = RunStatus.Success + message = "No-op" + adapter_response = {} + if result and isinstance(result.response, dbtClassMixin): + adapter_response = result.response.to_dict(omit_none=True) + return RunResult( + node=model, + status=status, + timing=[], + thread_id=threading.current_thread().name, + execution_time=0, + message=message, + adapter_response=adapter_response, + failures=None, + ) + + def compile(self, manifest): + # no-op + return self.node + + def _materialization_relations(self, result: Any, model) -> List[BaseRelation]: + if isinstance(result, str): + msg = ( + 'The materialization ("{}") did not explicitly return a ' + "list of relations to add to the cache.".format(str(model.get_materialization())) + ) + raise CompilationError(msg, node=model) + + if isinstance(result, dict): + return _validate_materialization_relations_dict(result, model) + + msg = ( + "Invalid return value from materialization, expected a dict " + 'with key "relations", got: {}'.format(str(result)) + ) + raise CompilationError(msg, node=model) + + def execute(self, model, manifest): + context = generate_runtime_model_context(model, self.config, manifest) + materialization_macro = manifest.find_materialization_macro_by_name( + self.config.project_name, "clone", self.adapter.type() + ) + + if "config" not in context: + raise DbtInternalError( + "Invalid materialization context generated, missing config: {}".format(context) + ) + + context_config = context["config"] + + hook_ctx = self.adapter.pre_model_hook(context_config) + try: + result = MacroGenerator( + materialization_macro, context, stack=context["context_macro_stack"] + )() + finally: + self.adapter.post_model_hook(context_config, hook_ctx) + + for relation in self._materialization_relations(result, model): + self.adapter.cache_added(relation.incorporate(dbt_created=True)) + + return self._build_run_model_result(model, context) + + +class CloneTask(GraphRunnableTask): + def raise_on_first_error(self): + return False + + def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRelation]: + if self.manifest is None: + raise DbtInternalError("manifest was None in get_model_schemas") + result: Set[BaseRelation] = set() + + for node in self.manifest.nodes.values(): + if node.unique_id not in selected_uids: + continue + if node.is_relational and not node.is_ephemeral: + relation = adapter.Relation.create_from(self.config, node) + result.add(relation.without_identifier()) + + # cache the 'other' schemas too! + if node.state_relation: # type: ignore + other_relation = adapter.Relation.create_from_node( + self.config, node.state_relation # type: ignore + ) + result.add(other_relation.without_identifier()) + + return result + + def before_run(self, adapter, selected_uids: AbstractSet[str]): + with adapter.connection_named("master"): + # unlike in other tasks, we want to add information from the --state manifest *before* caching! + self.defer_to_manifest(adapter, selected_uids) + # only create *our* schemas, but cache *other* schemas in addition + schemas_to_create = super().get_model_schemas(adapter, selected_uids) + self.create_schemas(adapter, schemas_to_create) + schemas_to_cache = self.get_model_schemas(adapter, selected_uids) + self.populate_adapter_cache(adapter, schemas_to_cache) + + @property + def resource_types(self): + if not self.args.resource_types: + return NodeType.refable() + + values = set(self.args.resource_types) + + if "all" in values: + values.remove("all") + values.update(NodeType.refable()) + + values = [NodeType(val) for val in values if val in NodeType.refable()] + + return list(values) + + def get_node_selector(self) -> ResourceTypeSelector: + resource_types = self.resource_types + + if self.manifest is None or self.graph is None: + raise DbtInternalError("manifest and graph must be set to get perform node selection") + return ResourceTypeSelector( + graph=self.graph, + manifest=self.manifest, + previous_state=self.previous_state, + resource_types=resource_types, + ) + + def get_runner_type(self, _): + return CloneRunner + + def _get_deferred_manifest(self) -> Optional[WritableManifest]: + state = self.previous_state + if state is None: + raise DbtRuntimeError( + "--state is required for cloning relations from another environment" + ) + + if state.manifest is None: + raise DbtRuntimeError(f'Could not find manifest in --state path: "{self.args.state}"') + return state.manifest + + # Note that this is different behavior from --defer with other commands, which *merge* + # selected nodes from this manifest + unselected nodes from the other manifest + def defer_to_manifest(self, adapter, selected_uids: AbstractSet[str]): + deferred_manifest = self._get_deferred_manifest() + if deferred_manifest is None: + return + if self.manifest is None: + raise DbtInternalError( + "Expected to defer to manifest, but there is no runtime manifest to defer from!" + ) + self.manifest.add_from_artifact(other=deferred_manifest) + # TODO: is it wrong to write the manifest here? I think it's right... + write_manifest(self.manifest, self.config.target_path) diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 4b1cea04727..d80ed9cce62 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -14,6 +14,7 @@ from datetime import datetime from dbt import tracking from dbt import utils +from dbt.flags import get_flags from dbt.adapters.base import BaseRelation from dbt.clients.jinja import MacroGenerator from dbt.context.providers import generate_runtime_model_context @@ -444,7 +445,10 @@ def before_run(self, adapter, selected_uids: AbstractSet[str]): with adapter.connection_named("master"): required_schemas = self.get_model_schemas(adapter, selected_uids) self.create_schemas(adapter, required_schemas) - self.populate_adapter_cache(adapter, required_schemas) + if get_flags().CACHE_SELECTED_ONLY is True: + self.populate_adapter_cache(adapter, required_schemas) + else: + self.populate_adapter_cache(adapter) self.defer_to_manifest(adapter, selected_uids) self.safe_run_hooks(adapter, RunHookType.Start, {}) diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 494acf98904..de998ee31c9 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -380,10 +380,7 @@ def populate_adapter_cache(self, adapter, required_schemas: Set[BaseRelation] = return start_populate_cache = time.perf_counter() - if get_flags().CACHE_SELECTED_ONLY is True: - adapter.set_relations_cache(self.manifest, required_schemas=required_schemas) - else: - adapter.set_relations_cache(self.manifest) + adapter.set_relations_cache(self.manifest, required_schemas=required_schemas) cache_populate_time = time.perf_counter() - start_populate_cache if dbt.tracking.active_user is not None: dbt.tracking.track_runnable_timing( @@ -392,6 +389,11 @@ def populate_adapter_cache(self, adapter, required_schemas: Set[BaseRelation] = def before_run(self, adapter, selected_uids: AbstractSet[str]): with adapter.connection_named("master"): + if get_flags().CACHE_SELECTED_ONLY is True: + required_schemas = self.get_model_schemas(adapter, selected_uids) + self.populate_adapter_cache(adapter, required_schemas) + else: + self.populate_adapter_cache(adapter) self.populate_adapter_cache(adapter) self.defer_to_manifest(adapter, selected_uids) diff --git a/test/unit/test_context.py b/test/unit/test_context.py index 1c02a650b9a..5bb895b844e 100644 --- a/test/unit/test_context.py +++ b/test/unit/test_context.py @@ -230,7 +230,7 @@ def assert_has_keys(required_keys: Set[str], maybe_keys: Set[str], ctx: Dict[str "submit_python_job", "dbt_metadata_envs", } -REQUIRED_MODEL_KEYS = REQUIRED_MACRO_KEYS | {"this", "compiled_code"} +REQUIRED_MODEL_KEYS = REQUIRED_MACRO_KEYS | {"this", "compiled_code", "state_relation"} MAYBE_KEYS = frozenset({"debug"}) diff --git a/test/unit/test_manifest.py b/test/unit/test_manifest.py index ccca812810e..95be155f462 100644 --- a/test/unit/test_manifest.py +++ b/test/unit/test_manifest.py @@ -90,6 +90,7 @@ "relation_name", "contract", "access", + "state_relation", "version", "latest_version", "constraints", diff --git a/tests/functional/defer_state/test_defer_state.py b/tests/functional/defer_state/test_defer_state.py index a50f09af0d1..9384694e23f 100644 --- a/tests/functional/defer_state/test_defer_state.py +++ b/tests/functional/defer_state/test_defer_state.py @@ -272,3 +272,68 @@ def test_run_defer_deleted_upstream(self, project, unique_schema, other_schema): ) results = run_dbt(["test", "--state", "state", "--defer", "--favor-state"]) assert other_schema not in results[0].node.compiled_code + + +get_schema_name_sql = """ +{% macro generate_schema_name(custom_schema_name, node) -%} + {%- set default_schema = target.schema -%} + + {%- if custom_schema_name is not none -%} + {{ return(default_schema ~ '_' ~ custom_schema_name|trim) }} + + -- put seeds into a separate schema in "prod", to verify that cloning in "dev" still works + {%- elif target.name == 'default' and node.resource_type == 'seed' -%} + {{ return(default_schema ~ '_' ~ 'seeds') }} + + {%- else -%} + {{ return(default_schema) }} + {%- endif -%} + +{%- endmacro %} +""" + + +class TestCloneToOther(BaseDeferState): + def build_and_save_state(self): + results = run_dbt(["build"]) + assert len(results) == 6 + + # copy files + self.copy_state() + + @pytest.fixture(scope="class") + def macros(self): + return { + "macros.sql": macros_sql, + "infinite_macros.sql": infinite_macros_sql, + "get_schema_name.sql": get_schema_name_sql, + } + + def test_clone(self, project, unique_schema, other_schema): + project.create_test_schema(other_schema) + self.build_and_save_state() + + clone_args = ["clone", "--state", "state", "--target", "otherschema"] + + results = run_dbt(clone_args) + # TODO: need an "adapter zone" version of this test that checks to see + # how many of the cloned objects are "pointers" (views) versus "true clones" (tables) + # e.g. on Postgres we expect to see 4 views + # whereas on Snowflake we'd expect to see 3 cloned tables + 1 view + assert [r.message for r in results] == ["CREATE VIEW"] * 4 + schema_relations = project.adapter.list_relations( + database=project.database, schema=other_schema + ) + assert [r.type for r in schema_relations] == ["view"] * 4 + + # objects already exist, so this is a no-op + results = run_dbt(clone_args) + assert [r.message for r in results] == ["No-op"] * 4 + + # recreate all objects + results = run_dbt(clone_args + ["--full-refresh"]) + assert [r.message for r in results] == ["CREATE VIEW"] * 4 + + # select only models this time + results = run_dbt(clone_args + ["--resource-type", "model"]) + assert len(results) == 2