Skip to content

Commit

Permalink
Add management schema feature
Browse files Browse the repository at this point in the history
  • Loading branch information
bneijt committed Jun 17, 2022
1 parent f8d347e commit 5b4be03
Show file tree
Hide file tree
Showing 10 changed files with 20,161 additions and 1 deletion.
6 changes: 6 additions & 0 deletions core/dbt/config/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class Profile(HasCredentials):
threads: int
credentials: Credentials
profile_env_vars: Dict[str, Any]
manage_schemas: bool

def __init__(
self,
Expand All @@ -99,6 +100,7 @@ def __init__(
user_config: UserConfig,
threads: int,
credentials: Credentials,
manage_schemas: bool = False,
):
"""Explicitly defining `__init__` to work around bug in Python 3.9.7
https://bugs.python.org/issue45081
Expand All @@ -109,6 +111,7 @@ def __init__(
self.threads = threads
self.credentials = credentials
self.profile_env_vars = {} # never available on init
self.manage_schemas = manage_schemas

def to_profile_info(self, serialize_credentials: bool = False) -> Dict[str, Any]:
"""Unlike to_project_config, this dict is not a mirror of any existing
Expand Down Expand Up @@ -240,6 +243,7 @@ def from_credentials(
profile_name: str,
target_name: str,
user_config: Optional[Dict[str, Any]] = None,
manage_schemas: bool = False,
) -> "Profile":
"""Create a profile from an existing set of Credentials and the
remaining information.
Expand All @@ -264,6 +268,7 @@ def from_credentials(
user_config=user_config_obj,
threads=threads,
credentials=credentials,
manage_schemas=manage_schemas,
)
profile.validate()
return profile
Expand Down Expand Up @@ -355,6 +360,7 @@ def from_raw_profile_info(
target_name=target_name,
threads=threads,
user_config=user_config,
manage_schemas=profile_data.get("manage_schemas", False),
)

@classmethod
Expand Down
7 changes: 6 additions & 1 deletion core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from dbt.contracts.project import (
Project as ProjectContract,
SemverString,
SchemaManagementConfiguration,
)
from dbt.contracts.project import PackageConfig
from dbt.dataclass_schema import ValidationError
Expand Down Expand Up @@ -348,12 +349,13 @@ def create_project(self, rendered: RenderComponents) -> "Project":
)
test_paths: List[str] = value_or(cfg.test_paths, ["tests"])
analysis_paths: List[str] = value_or(cfg.analysis_paths, ["analyses"])
snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ["snapshots"])
snapshot_paths: List[SchemaManagementConfiguration] = value_or(cfg.snapshot_paths, ["snapshots"])

all_source_paths: List[str] = _all_source_paths(
model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths
)

managed_schemas: List[str] = value_or(cfg.managed_schemas, [])
docs_paths: List[str] = value_or(cfg.docs_paths, all_source_paths)
asset_paths: List[str] = value_or(cfg.asset_paths, [])
target_path: str = value_or(cfg.target_path, "target")
Expand Down Expand Up @@ -417,6 +419,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
asset_paths=asset_paths,
target_path=target_path,
snapshot_paths=snapshot_paths,
managed_schemas=managed_schemas,
clean_targets=clean_targets,
log_path=log_path,
packages_install_path=packages_install_path,
Expand Down Expand Up @@ -524,6 +527,7 @@ class Project:
asset_paths: List[str]
target_path: str
snapshot_paths: List[str]
managed_schemas: List[SchemaManagementConfiguration]
clean_targets: List[str]
log_path: str
packages_install_path: str
Expand Down Expand Up @@ -597,6 +601,7 @@ def to_project_config(self, with_packages=False):
"asset-paths": self.asset_paths,
"target-path": self.target_path,
"snapshot-paths": self.snapshot_paths,
"managed-schemas": [schema.to_dict() for schema in self.managed_schemas],
"clean-targets": self.clean_targets,
"log-path": self.log_path,
"quoting": self.quoting,
Expand Down
3 changes: 3 additions & 0 deletions core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def from_parts(
asset_paths=project.asset_paths,
target_path=project.target_path,
snapshot_paths=project.snapshot_paths,
managed_schemas=project.managed_schemas,
clean_targets=project.clean_targets,
log_path=project.log_path,
packages_install_path=project.packages_install_path,
Expand Down Expand Up @@ -118,6 +119,7 @@ def from_parts(
args=args,
cli_vars=cli_vars,
dependencies=dependencies,
manage_schemas=profile.manage_schemas,
)

# Called by 'load_projects' in this class
Expand Down Expand Up @@ -520,6 +522,7 @@ def from_parts(
asset_paths=project.asset_paths,
target_path=project.target_path,
snapshot_paths=project.snapshot_paths,
managed_schemas=project.managed_schemas,
clean_targets=project.clean_targets,
log_path=project.log_path,
packages_install_path=project.packages_install_path,
Expand Down
1 change: 1 addition & 0 deletions core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def resolve(self, connection: Connection) -> Connection:
class Credentials(ExtensibleDbtClassMixin, Replaceable, metaclass=abc.ABCMeta):
database: str
schema: str
manage_schemas: bool
_ALIASES: ClassVar[Dict[str, str]] = field(default={}, init=False)

@abc.abstractproperty
Expand Down
9 changes: 9 additions & 0 deletions core/dbt/contracts/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,13 @@ class RegistryPackageMetadata(
}


@dataclass
class SchemaManagementConfiguration(HyphenatedDbtClassMixin, Replaceable):
    # A single schema-management entry; every field is optional and defaults to None.

    # Database the entry applies to.
    database: Optional[str] = None
    # Schema the entry applies to.
    schema: Optional[str] = None
    # Action name for this entry, kept as a free-form string here.
    # NOTE(review): presumably restricted to "warn" / "drop" by the consumer -- confirm.
    action: Optional[str] = None


@dataclass
class Project(HyphenatedDbtClassMixin, Replaceable):
name: Name
Expand All @@ -197,6 +204,7 @@ class Project(HyphenatedDbtClassMixin, Replaceable):
asset_paths: Optional[List[str]] = None
target_path: Optional[str] = None
snapshot_paths: Optional[List[str]] = None
managed_schemas: Optional[List[SchemaManagementConfiguration]] = None
clean_targets: Optional[List[str]] = None
profile: Optional[str] = None
log_path: Optional[str] = None
Expand Down Expand Up @@ -264,6 +272,7 @@ class ProfileConfig(HyphenatedDbtClassMixin, Replaceable):
threads: int
# TODO: make this a dynamic union of some kind?
credentials: Optional[Dict[str, Any]]
manage_schemas: Optional[bool] = field(metadata={"preserve_underscore": True})


@dataclass
Expand Down
52 changes: 52 additions & 0 deletions core/dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from dbt.contracts.graph.parsed import ParsedHookNode
from dbt.contracts.results import NodeStatus, RunResult, RunStatus, RunningStatus
from dbt.exceptions import (
warn_or_error,
warn,
CompilationException,
InternalException,
RuntimeException,
Expand Down Expand Up @@ -466,6 +468,7 @@ def after_run(self, adapter, results):
}
with adapter.connection_named("master"):
self.safe_run_hooks(adapter, RunHookType.End, extras)
self.manage_schema(adapter, results)

def after_hooks(self, adapter, results, elapsed):
self.print_results_line(results, elapsed)
Expand All @@ -486,3 +489,52 @@ def get_runner_type(self, _):
def task_end_messages(self, results):
if results:
print_run_end_messages(results)

def manage_schema(self, adapter, results: List[RunResult]):
    """Warn about or drop relations in managed schemas that this run did not build.

    Reads ``manage_schemas`` (the on/off switch) and ``managed_schemas`` (the
    per-schema entries) from ``self.config``. For every configured
    (database, schema) pair, the relations reported by
    ``adapter.list_relations`` are compared with the relational,
    non-ephemeral nodes found in ``results``. Every relation left over is
    handled according to the entry's action: ``"warn"`` (also used when the
    entry has no action) emits a warning, ``"drop"`` calls
    ``adapter.drop_relation``.

    Nothing is listed or dropped when the switch is off, when no schemas are
    configured, or when any result has status Error, Fail or Skipped.

    :param adapter: adapter used to list and drop relations.
    :param results: results of the run that just finished.
    """
    manage_schemas_enabled: bool = self.config.manage_schemas
    actions_by_schema: Dict[Tuple[str, str], str] = {
        (ms.database or "", ms.schema or ""): ms.action or "warn"
        for ms in self.config.managed_schemas
    }

    if not manage_schemas_enabled:
        # Only report the mismatch when the project really lists schemas to
        # manage; otherwise every run of a project that does not use the
        # feature would emit this warning.
        if actions_by_schema:
            warn("Schema's configured to be managed, but manage_schemas is false in the profile")
        return

    if not actions_by_schema:
        warn_or_error("Schema management enabled for connection but no schema's configured to manage")
        return

    # Never manage schemas after an incomplete run: relations belonging to
    # failed or skipped nodes would look like leftovers.
    blocking_statuses = (NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped)
    if any(r.status in blocking_statuses for r in results):
        warn("One or more models failed, skipping schema management")
        return

    # NOTE(review): `results` only holds the nodes of this invocation, so a
    # run limited by node selection would presumably make every unselected
    # model look dangling -- confirm before relying on the "drop" action.
    built_relations: Set[Tuple[str, str, str]] = {
        (r.node.database, r.node.schema, r.node.identifier)
        for r in results
        if r.node.is_relational and not r.node.is_ephemeral_model
    }

    for (database, schema), action in actions_by_schema.items():
        existing_relations = {
            (database, schema, relation.identifier): relation
            for relation in adapter.list_relations(database, schema)
        }
        if not existing_relations:
            warn_or_error(f"No modules in managed schema '{schema}' for database '{database}'")
            continue

        dangling = [
            (key, relation)
            for key, relation in existing_relations.items()
            if key not in built_relations
        ]
        if not dangling:
            continue

        if action == "drop":
            for _, relation in dangling:
                adapter.drop_relation(relation)
        elif action == "warn":
            for (_, _, identifier), _ in dangling:
                warn(
                    f"Relation '{database}.{schema}.{identifier}' is in a managed "
                    "schema but was not built by this run"
                )
        else:
            # An unrecognised action used to be ignored without any feedback.
            warn(
                f"Unknown action '{action}' for managed schema '{schema}' in "
                f"database '{database}', expected 'warn' or 'drop'"
            )

5 changes: 5 additions & 0 deletions core/dbt/tests/fixtures/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,11 @@ def get_tables_in_schema(self):
result = self.run_sql(sql, fetch="all")
return {model_name: materialization for (model_name, materialization) in result}

def update_models(self, models: dict):
    """Replace the models of the test project.

    The ``models`` entry under the project root is removed, then
    ``write_project_files`` is called with the project root, ``"models"``
    and the given ``models`` mapping.
    """
    models_dir = self.project_root.join("models")
    models_dir.remove()
    write_project_files(self.project_root, "models", models)


# This is the main fixture that is used in all functional tests. It pulls in the other
# fixtures that are necessary to set up a dbt project, and saves some of the information
Expand Down
1 change: 1 addition & 0 deletions tests/functional/schema_management/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Test schema management as introduced by https://github.com/dbt-labs/dbt-core/issues/4957
77 changes: 77 additions & 0 deletions tests/functional/schema_management/test_drop_dangling_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import pytest
import os

from dbt.tests.util import (
run_dbt,
check_table_does_exist,
check_table_does_not_exist,
)

# SQL source shared by the test models: configured with materialized = "table"
# and selecting three literal (num, letter) rows.
model = """
{{
config(
materialized = "table"
)
}}
SELECT * FROM (
VALUES (1, 'one'),
(2, 'two'),
(3, 'three')
) AS t (num,letter)
"""



class TestDanglingModels:
    """Check that a table left behind by a deleted model is dropped on the next run."""

    @pytest.fixture(scope="class")
    def models(self):
        # Both models share the same SQL; only their file names differ.
        return dict.fromkeys(("model_a.sql", "model_b.sql"), model)

    @pytest.fixture(scope="class")
    def dbt_profile_target(self):
        # Postgres target taken from the POSTGRES_TEST_* environment variables,
        # with schema management switched on.
        target = {
            "type": "postgres",
            "threads": 4,
            "host": "localhost",
            "port": int(os.getenv("POSTGRES_TEST_PORT", 5432)),
            "user": os.getenv("POSTGRES_TEST_USER", "root"),
            "pass": os.getenv("POSTGRES_TEST_PASS", "password"),
            "dbname": os.getenv("POSTGRES_TEST_DATABASE", "dbt"),
        }
        target["manage_schemas"] = True
        return target

    @pytest.fixture(scope="class")
    def project_config_update(self, unique_schema):
        # A single managed schema: the test's own schema, with action "drop".
        managed_schema = {
            "database": os.getenv("POSTGRES_TEST_DATABASE", "dbt"),
            "schema": unique_schema,
            "action": "drop",
        }
        return {"managed-schemas": [managed_schema]}

    def test_drop(self, project):
        # First run: both model tables must exist afterwards.
        run_dbt(["run"])
        for table_name in ("model_a", "model_b"):
            check_table_does_exist(project.adapter, table_name)
        check_table_does_not_exist(project.adapter, "baz")

        # Keep only model_b in the project and run again: the table of the
        # removed model_a must be gone while model_b is still there.
        remaining_models = {"model_b.sql": model}
        project.update_models(remaining_models)
        run_dbt(["run"])
        check_table_does_not_exist(project.adapter, "model_a")
        check_table_does_exist(project.adapter, "model_b")
Loading

0 comments on commit 5b4be03

Please sign in to comment.