diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py
index a9de9441365..af2a6cd59f7 100644
--- a/core/dbt/cli/main.py
+++ b/core/dbt/cli/main.py
@@ -564,6 +564,7 @@ def parse(ctx, **kwargs):
 @requires.profile
 @requires.project
 @requires.runtime_config
+@requires.catalogs
 @requires.manifest
 def run(ctx, **kwargs):
     """Compile SQL and execute against the current target database."""
diff --git a/core/dbt/cli/requires.py b/core/dbt/cli/requires.py
index 058b16cdfaa..d7a93ffe8d9 100644
--- a/core/dbt/cli/requires.py
+++ b/core/dbt/cli/requires.py
@@ -12,6 +12,7 @@
 from dbt.cli.exceptions import ExceptionExit, ResultExit
 from dbt.cli.flags import Flags
 from dbt.config import RuntimeConfig
+from dbt.config.catalogs import Catalogs
 from dbt.config.runtime import UnsetProfile, load_profile, load_project
 from dbt.context.providers import generate_runtime_macro_context
 from dbt.context.query_header import generate_query_header_context
@@ -313,6 +314,29 @@ def wrapper(*args, **kwargs):
     return update_wrapper(wrapper, func)
 
 
+def catalogs(func):
+    """A decorator used by click command functions for loading catalogs"""
+
+    def wrapper(*args, **kwargs):
+        ctx = args[0]
+        assert isinstance(ctx, Context)
+
+        req_strs = ["flags", "profile"]
+        reqs = [ctx.obj.get(req_str) for req_str in req_strs]
+        if None in reqs:
+            raise DbtProjectError("profile and flags required to load catalogs")
+
+        flags = ctx.obj["flags"]
+        profile = ctx.obj["profile"]
+
+        catalogs = Catalogs.load(flags.PROJECT_DIR, profile.profile_name, flags.VARS)
+        ctx.obj["catalogs"] = catalogs
+
+        return func(*args, **kwargs)
+
+    return update_wrapper(wrapper, func)
+
+
 def manifest(*args0, write=True, write_perf_info=False):
     """A decorator used by click command functions for generating a manifest
     given a profile, project, and runtime config. This also registers the adapter
@@ -346,11 +370,18 @@ def setup_manifest(ctx: Context, write: bool = True, write_perf_info: bool = Fal
 
     runtime_config = ctx.obj["runtime_config"]
 
+    catalogs = ctx.obj["catalogs"] if "catalogs" in ctx.obj else None
+    catalog_integrations = (
+        catalogs.get_active_adapter_write_catalog_integrations() if catalogs else []
+    )
+
     # if a manifest has already been set on the context, don't overwrite it
     if ctx.obj.get("manifest") is None:
         ctx.obj["manifest"] = parse_manifest(
             runtime_config, write_perf_info, write, ctx.obj["flags"].write_json
         )
+        adapter = get_adapter(runtime_config)
+        adapter.add_catalog_integrations(catalog_integrations)
     else:
         register_adapter(runtime_config, get_mp_context())
         adapter = get_adapter(runtime_config)
@@ -358,3 +389,4 @@ def setup_manifest(ctx: Context, write: bool = True, write_perf_info: bool = Fal
     adapter.set_macro_resolver(ctx.obj["manifest"])
     query_header_context = generate_query_header_context(adapter.config, ctx.obj["manifest"])  # type: ignore[attr-defined]
     adapter.connections.set_query_header(query_header_context)
+    adapter.add_catalog_integrations(catalog_integrations)
diff --git a/core/dbt/config/catalogs.py b/core/dbt/config/catalogs.py
new file mode 100644
index 00000000000..ee0480be7e7
--- /dev/null
+++ b/core/dbt/config/catalogs.py
@@ -0,0 +1,145 @@
+import os
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+from dbt.adapters.contracts.catalog import CatalogIntegrationType
+from dbt.adapters.relation_configs.formats import TableFormat
+from dbt.clients.yaml_helper import load_yaml_text
+from dbt.config.renderer import SecretRenderer
+from dbt_common.clients.system import load_file_contents
+from dbt_common.dataclass_schema import dbtClassMixin
+from dbt_common.exceptions import CompilationError, DbtValidationError
+
+
+@dataclass
+class CatalogIntegration(dbtClassMixin):
+    name: str
+    external_volume: str
+    table_format: TableFormat
+    catalog_type: CatalogIntegrationType
+
+
+# satisfies dbt.adapters.protocol.CatalogIntegrationConfig
+@dataclass
+class AdapterCatalogIntegration:
+    catalog_name: str
+    integration_name: str
+    table_format: str
+    catalog_type: str
+    external_volume: Optional[str]
+    namespace: Optional[str]
+    adapter_configs: Optional[Dict]
+
+
+@dataclass
+class Catalog(dbtClassMixin):
+    name: str
+    # If not specified, active_write_integration defaults to the entry in write_integrations if there is exactly one.
+    active_write_integration: Optional[str] = None
+    write_integrations: List[CatalogIntegration] = field(default_factory=list)
+
+    @classmethod
+    def render(
+        cls, raw_catalog: Dict[str, Any], renderer: SecretRenderer, default_profile_name: str
+    ) -> "Catalog":
+        try:
+            rendered_catalog = renderer.render_data(raw_catalog)
+        except CompilationError:
+            # TODO: better error
+            raise
+
+        cls.validate(rendered_catalog)
+
+        write_integrations = []
+        for raw_write_integration in rendered_catalog.get("write_integrations", []):
+            CatalogIntegration.validate(raw_write_integration)
+            write_integrations.append(CatalogIntegration.from_dict(raw_write_integration))
+
+        # Validate + set default active_write_integration if unset
+        active_write_integration = rendered_catalog.get("active_write_integration")
+        valid_write_integration_names = [integration.name for integration in write_integrations]
+        if (
+            active_write_integration
+            and active_write_integration not in valid_write_integration_names
+        ):
+            raise DbtValidationError(
+                f"Catalog '{rendered_catalog['name']}' must specify an 'active_write_integration' from its set of defined 'write_integrations': {valid_write_integration_names}. Got: '{active_write_integration}'."
+            )
+        elif len(write_integrations) > 1 and not active_write_integration:
+            raise DbtValidationError(
+                f"Catalog '{rendered_catalog['name']}' must specify an 'active_write_integration' when multiple 'write_integrations' are provided."
+            )
+        elif not active_write_integration and len(write_integrations) == 1:
+            active_write_integration = write_integrations[0].name
+
+        return cls(
+            name=raw_catalog["name"],
+            active_write_integration=active_write_integration,
+            write_integrations=write_integrations,
+        )
+
+
+@dataclass
+class Catalogs(dbtClassMixin):
+    catalogs: List[Catalog]
+
+    @classmethod
+    def load(cls, catalog_dir: str, profile: str, cli_vars: Dict[str, Any]) -> "Catalogs":
+        catalogs = []
+
+        raw_catalogs = cls._read_catalogs(catalog_dir)
+
+        catalogs_renderer = SecretRenderer(cli_vars)
+        for raw_catalog in raw_catalogs.get("catalogs", []):
+            catalog = Catalog.render(raw_catalog, catalogs_renderer, profile)
+            catalogs.append(catalog)
+
+        return cls(catalogs=catalogs)
+
+    def get_active_adapter_write_catalog_integrations(self):
+        adapter_catalog_integrations: List[AdapterCatalogIntegration] = []
+
+        for catalog in self.catalogs:
+            active_write_integration = list(
+                filter(
+                    lambda c: c.name == catalog.active_write_integration,
+                    catalog.write_integrations,
+                )
+            )[0]
+
+            adapter_catalog_integrations.append(
+                AdapterCatalogIntegration(
+                    catalog_name=catalog.name,
+                    integration_name=catalog.active_write_integration,
+                    table_format=active_write_integration.table_format,
+                    catalog_type=active_write_integration.catalog_type,
+                    external_volume=active_write_integration.external_volume,
+                    namespace=None,  # namespaces on write_integrations are not yet supported
+                    adapter_configs={},  # configs on write_integrations not yet supported
+                )
+            )
+
+        return adapter_catalog_integrations
+
+    @classmethod
+    def _read_catalogs(cls, catalog_dir: str) -> Dict[str, Any]:
+        path = os.path.join(catalog_dir, "catalogs.yml")
+
+        contents = None
+        if os.path.isfile(path):
+            try:
+                contents = load_file_contents(path, strip=False)
+                yaml_content = load_yaml_text(contents)
+                if not yaml_content:
+                    # msg = f"The catalogs.yml file at {path} is empty"
+                    # TODO: better error
+                    raise ValueError
+                    # raise DbtProfileError(INVALID_PROFILE_MESSAGE.format(error_string=msg))
+                return yaml_content
+            # TODO: better error
+            except DbtValidationError:
+                # msg = INVALID_PROFILE_MESSAGE.format(error_string=e)
+                # raise DbtValidationError(msg) from e
+                raise
+
+        return {}
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 5f393349744..91aba73ef33 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,4 +1,4 @@
-git+https://github.com/dbt-labs/dbt-adapters.git@main
+git+https://github.com/dbt-labs/dbt-adapters.git@feature/externalCatalogConfig
 git+https://github.com/dbt-labs/dbt-adapters.git@main#subdirectory=dbt-tests-adapter
 git+https://github.com/dbt-labs/dbt-common.git@main
 git+https://github.com/dbt-labs/dbt-postgres.git@main
diff --git a/tests/functional/catalogs/test_catalogs_parsing.py b/tests/functional/catalogs/test_catalogs_parsing.py
new file mode 100644
index 00000000000..bbd53dd2c6b
--- /dev/null
+++ b/tests/functional/catalogs/test_catalogs_parsing.py
@@ -0,0 +1,36 @@
+from unittest import mock
+
+import pytest
+
+from dbt.tests.util import run_dbt, write_config_file
+
+
+class TestCatalogsParsing:
+
+    @pytest.fixture
+    def catalogs(self):
+        return {
+            "catalogs": [
+                {
+                    "name": "test_catalog",
+                    "write_integrations": [
+                        {
+                            "name": "write_integration_name",
+                            "external_volume": "write_integration_external_volume",
+                            "table_format": "iceberg",
+                            "catalog_type": "glue",
+                        }
+                    ],
+                }
+            ]
+        }
+
+    def test_catalog_parsing_adapter_initialization(self, catalogs, project):
+        write_config_file(catalogs, project.project_root, "catalogs.yml")
+
+        mock_add_catalog_integration = mock.Mock()
+        with mock.patch.object(
+            type(project.adapter), "add_catalog_integrations", mock_add_catalog_integration
+        ):
+            run_dbt(["run"])
+            mock_add_catalog_integration.assert_called_once()
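For reviewers, a minimal sketch of how the new loader is exercised outside of the click decorators. It assumes this branch of dbt-core (and the matching dbt-adapters feature branch) is installed and that PyYAML is available; the temporary project directory, the "my_profile" profile name, and the empty vars dict are placeholders standing in for flags.PROJECT_DIR, profile.profile_name, and flags.VARS in the @requires.catalogs decorator, and the catalogs.yml structure mirrors the fixture in tests/functional/catalogs/test_catalogs_parsing.py:

```python
import os
import tempfile

import yaml

from dbt.config.catalogs import Catalogs

# Mirrors the fixture in tests/functional/catalogs/test_catalogs_parsing.py.
catalogs_config = {
    "catalogs": [
        {
            "name": "test_catalog",
            "write_integrations": [
                {
                    "name": "write_integration_name",
                    "external_volume": "write_integration_external_volume",
                    "table_format": "iceberg",
                    "catalog_type": "glue",
                }
            ],
        }
    ]
}

with tempfile.TemporaryDirectory() as project_dir:
    # Catalogs.load reads <project_dir>/catalogs.yml, the same file the
    # @requires.catalogs decorator loads from flags.PROJECT_DIR.
    with open(os.path.join(project_dir, "catalogs.yml"), "w") as fh:
        yaml.safe_dump(catalogs_config, fh)

    # "my_profile" and the empty dict are placeholders for profile.profile_name
    # and flags.VARS.
    catalogs = Catalogs.load(project_dir, "my_profile", {})

    # With a single write integration, active_write_integration defaults to it,
    # so exactly one AdapterCatalogIntegration is handed to the adapter.
    for integration in catalogs.get_active_adapter_write_catalog_integrations():
        print(integration.catalog_name, integration.integration_name, integration.catalog_type)
```

Note that with more than one entry under write_integrations, Catalog.render requires an explicit active_write_integration and raises DbtValidationError otherwise, so the defaulting shown above only applies to the single-integration case.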