From e09ac6d640609c43602f7e39c74a58255bfa525b Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Wed, 18 Oct 2023 11:32:24 +0100 Subject: [PATCH 1/2] Add tests to sources, snapshots and seeds when using TestBehavior.AFTER_EACH (#599) Previously Cosmos would only create a task group when using `TestBehavior.AFTER_EACH` for nodes of the type `DbtResourceType.MODEL`. This change adds the same behavior to snapshots and seeds. For this to work as expected with sources, we would need to create a default operator to handle `DbtResourceType.SOURCE`, which is outside the scope of the current ticket. Once this operator exists, sources will also lead to creating a task group. All the test selectors were tested successfully with dbt 1.6. This screenshot illustrates the validation of this feature, with an adapted version of Jaffle Shop: Screenshot 2023-10-13 at 19 43 52 The modifications that were done to jaffle_shop were: Appended the following lines to `dev/dags/dbt/jaffle_shop/models/schema.yml`: ``` seeds: - name: raw_customers description: Raw data from customers columns: - name: id tests: - unique - not_null snapshots: - name: orders_snapshot description: Snapshot of orders columns: - name: orders_snapshot.order_id tests: - unique - not_null ``` And created the file `dev/dags/dbt/jaffle_shop/snapshots/orders_snapshot.sql` with: ``` {% snapshot orders_snapshot %} {{ config( target_database='postgres', target_schema='public', unique_key='order_id', strategy='timestamp', updated_at='order_date', ) }} select * from {{ ref('jaffle_shop', 'orders') }} {% endsnapshot %} ``` Closes: #474 --- cosmos/airflow/graph.py | 17 +++++--- cosmos/constants.py | 8 ++-- tests/airflow/test_graph.py | 77 +++++++++++++++++++++++++++++++++++-- 3 files changed, 89 insertions(+), 13 deletions(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index c4a0aa61a..3e7961ed1 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -56,7 +56,7 @@ def 
create_test_task_metadata( execution_mode: ExecutionMode, task_args: dict[str, Any], on_warning_callback: Callable[..., Any] | None = None, - model_name: str | None = None, + node: DbtNode | None = None, ) -> TaskMetadata: """ Create the metadata that will be used to instantiate the Airflow Task that will be used to run the Dbt test node. @@ -66,13 +66,18 @@ def create_test_task_metadata( :param task_args: Arguments to be used to instantiate an Airflow Task :param on_warning_callback: A callback function called on warnings with additional Context variables “test_names” and “test_results” of type List. - :param model_name: If the test relates to a specific model, the name of the model it relates to + :param node: If the test relates to a specific node, the node reference :returns: The metadata necessary to instantiate the source dbt node as an Airflow task. """ task_args = dict(task_args) task_args["on_warning_callback"] = on_warning_callback - if model_name is not None: - task_args["models"] = model_name + if node is not None: + if node.resource_type == DbtResourceType.MODEL: + task_args["models"] = node.name + elif node.resource_type == DbtResourceType.SOURCE: + task_args["select"] = f"source:{node.unique_id[len('source.'):]}" + else: # tested with node.resource_type == DbtResourceType.SEED or DbtResourceType.SNAPSHOT + task_args["select"] = node.name return TaskMetadata( id=test_task_name, operator_class=calculate_operator_class( @@ -112,6 +117,8 @@ def create_task_metadata( task_id = "run" else: task_id = f"{node.name}_{node.resource_type.value}" + if use_task_group is True: + task_id = node.resource_type.value task_metadata = TaskMetadata( id=task_id, @@ -163,7 +170,7 @@ def generate_task_or_group( "test", execution_mode, task_args=task_args, - model_name=node.name, + node=node, on_warning_callback=on_warning_callback, ) test_task = create_airflow_task(test_meta, dag, task_group=model_task_group) diff --git a/cosmos/constants.py b/cosmos/constants.py index 
2ce9a09fe..cd59c8173 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -70,7 +70,7 @@ def _missing_value_(cls, value): # type: ignore DEFAULT_DBT_RESOURCES = DbtResourceType.__members__.values() - -TESTABLE_DBT_RESOURCES = { - DbtResourceType.MODEL -} # TODO: extend with DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED) +# dbt test runs tests defined on models, sources, snapshots, and seeds. +# It expects that you have already created those resources through the appropriate commands. +# https://docs.getdbt.com/reference/commands/test +TESTABLE_DBT_RESOURCES = {DbtResourceType.MODEL, DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED} diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index 57a3462f7..2eb93c613 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -5,6 +5,7 @@ import pytest from airflow import __version__ as airflow_version from airflow.models import DAG +from airflow.utils.task_group import TaskGroup from packaging import version from cosmos.airflow.graph import ( @@ -13,6 +14,7 @@ calculate_operator_class, create_task_metadata, create_test_task_metadata, + generate_task_or_group, ) from cosmos.config import ProfileConfig from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior @@ -101,6 +103,50 @@ def test_build_airflow_graph_with_after_each(): assert dag.leaves[0].task_id == "child_run" +@pytest.mark.parametrize( + "node_type,task_suffix", + [(DbtResourceType.MODEL, "run"), (DbtResourceType.SEED, "seed"), (DbtResourceType.SNAPSHOT, "snapshot")], +) +def test_create_task_group_for_after_each_supported_nodes(node_type, task_suffix): + """ + dbt test runs tests defined on models, sources, snapshots, and seeds. + It expects that you have already created those resources through the appropriate commands. 
+ https://docs.getdbt.com/reference/commands/test + """ + with DAG("test-task-group-after-each", start_date=datetime(2022, 1, 1)) as dag: + node = DbtNode( + name="dbt_node", + unique_id="dbt_node", + resource_type=node_type, + file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql", + tags=["has_child"], + config={"materialized": "view"}, + depends_on=[], + has_test=True, + ) + output = generate_task_or_group( + dag=dag, + task_group=None, + node=node, + execution_mode=ExecutionMode.LOCAL, + task_args={ + "project_dir": SAMPLE_PROJ_PATH, + "profile_config": ProfileConfig( + profile_name="default", + target_name="default", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ), + }, + test_behavior=TestBehavior.AFTER_EACH, + on_warning_callback=None, + ) + assert isinstance(output, TaskGroup) + assert list(output.children.keys()) == [f"dbt_node.{task_suffix}", "dbt_node.test"] + + @pytest.mark.skipif( version.parse(airflow_version) < version.parse("2.4"), reason="Airflow DAG did not have task_group_dict until the 2.4 release", @@ -259,7 +305,12 @@ def test_create_task_metadata_seed(caplog, use_task_group): args={}, use_task_group=use_task_group, ) - assert metadata.id == "my_seed_seed" + + if not use_task_group: + assert metadata.id == "my_seed_seed" + else: + assert metadata.id == "seed" + assert metadata.operator_class == "cosmos.operators.docker.DbtSeedDockerOperator" assert metadata.arguments == {"models": "my_seed"} @@ -280,14 +331,32 @@ def test_create_task_metadata_snapshot(caplog): assert metadata.arguments == {"models": "my_snapshot"} -def test_create_test_task_metadata(): +@pytest.mark.parametrize( + "node_type,node_unique_id,selector_key,selector_value", + [ + (DbtResourceType.MODEL, "node_name", "models", "node_name"), + (DbtResourceType.SEED, "node_name", "select", "node_name"), + (DbtResourceType.SOURCE, "source.node_name", "select", "source:node_name"), + (DbtResourceType.SNAPSHOT, 
"node_name", "select", "node_name"), + ], +) +def test_create_test_task_metadata(node_type, node_unique_id, selector_key, selector_value): + sample_node = DbtNode( + name="node_name", + unique_id=node_unique_id, + resource_type=node_type, + depends_on=[], + file_path="", + tags=[], + config={}, + ) metadata = create_test_task_metadata( test_task_name="test_no_nulls", execution_mode=ExecutionMode.LOCAL, task_args={"task_arg": "value"}, on_warning_callback=True, - model_name="my_model", + node=sample_node, ) assert metadata.id == "test_no_nulls" assert metadata.operator_class == "cosmos.operators.local.DbtTestLocalOperator" - assert metadata.arguments == {"task_arg": "value", "on_warning_callback": True, "models": "my_model"} + assert metadata.arguments == {"task_arg": "value", "on_warning_callback": True, selector_key: selector_value} From c82e6bc66dd31d6e1404fc24548d978489f4f0ab Mon Sep 17 00:00:00 2001 From: Perttu Salonen Date: Wed, 18 Oct 2023 18:09:17 +0300 Subject: [PATCH 2/2] Add Vertica ProfileMapping Add `ProfileMapping` for Vertica, based on the adapter: https://github.com/vertica/dbt-vertica Closes: #538 Signed-off-by: Perttu Salonen --- cosmos/profiles/__init__.py | 3 + cosmos/profiles/vertica/__init__.py | 5 + cosmos/profiles/vertica/user_pass.py | 76 +++++++ pyproject.toml | 14 +- .../vertica/test_vertica_user_pass.py | 191 ++++++++++++++++++ 5 files changed, 284 insertions(+), 5 deletions(-) create mode 100644 cosmos/profiles/vertica/__init__.py create mode 100644 cosmos/profiles/vertica/user_pass.py create mode 100644 tests/profiles/vertica/test_vertica_user_pass.py diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index dae6e2c04..e75b6c25e 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -20,6 +20,7 @@ from .trino.certificate import TrinoCertificateProfileMapping from .trino.jwt import TrinoJWTProfileMapping from .trino.ldap import TrinoLDAPProfileMapping +from .vertica.user_pass import 
VerticaUserPasswordProfileMapping profile_mappings: list[Type[BaseProfileMapping]] = [ AthenaAccessKeyProfileMapping, @@ -36,6 +37,7 @@ TrinoLDAPProfileMapping, TrinoCertificateProfileMapping, TrinoJWTProfileMapping, + VerticaUserPasswordProfileMapping, ] @@ -72,4 +74,5 @@ def get_automatic_profile_mapping( "TrinoLDAPProfileMapping", "TrinoCertificateProfileMapping", "TrinoJWTProfileMapping", + "VerticaUserPasswordProfileMapping", ] diff --git a/cosmos/profiles/vertica/__init__.py b/cosmos/profiles/vertica/__init__.py new file mode 100644 index 000000000..4a88f2edd --- /dev/null +++ b/cosmos/profiles/vertica/__init__.py @@ -0,0 +1,5 @@ +"Vertica Airflow connection -> dbt profile mappings" + +from .user_pass import VerticaUserPasswordProfileMapping + +__all__ = ["VerticaUserPasswordProfileMapping"] diff --git a/cosmos/profiles/vertica/user_pass.py b/cosmos/profiles/vertica/user_pass.py new file mode 100644 index 000000000..494185e05 --- /dev/null +++ b/cosmos/profiles/vertica/user_pass.py @@ -0,0 +1,76 @@ +"Maps Airflow Vertica connections using user + password authentication to dbt profiles." +from __future__ import annotations + +from typing import Any + +from ..base import BaseProfileMapping + + +class VerticaUserPasswordProfileMapping(BaseProfileMapping): + """ + Maps Airflow Vertica connections using user + password authentication to dbt profiles. 
+ https://docs.getdbt.com/reference/warehouse-setups/vertica-setup + https://airflow.apache.org/docs/apache-airflow-providers-vertica/stable/connections/vertica.html + """ + + airflow_connection_type: str = "vertica" + dbt_profile_type: str = "vertica" + + required_fields = [ + "host", + "user", + "password", + "database", + "schema", + ] + secret_fields = [ + "password", + ] + airflow_param_mapping = { + "host": "host", + "user": "login", + "password": "password", + "port": "port", + "schema": "schema", + "database": "extra.database", + "autocommit": "extra.autocommit", + "backup_server_node": "extra.backup_server_node", + "binary_transfer": "extra.binary_transfer", + "connection_load_balance": "extra.connection_load_balance", + "connection_timeout": "extra.connection_timeout", + "disable_copy_local": "extra.disable_copy_local", + "kerberos_host_name": "extra.kerberos_host_name", + "kerberos_service_name": "extra.kerberos_service_name", + "log_level": "extra.log_level", + "log_path": "extra.log_path", + "oauth_access_token": "extra.oauth_access_token", + "request_complex_types": "extra.request_complex_types", + "session_label": "extra.session_label", + "ssl": "extra.ssl", + "unicode_error": "extra.unicode_error", + "use_prepared_statements": "extra.use_prepared_statements", + "workload": "extra.workload", + } + + @property + def profile(self) -> dict[str, Any | None]: + "Gets profile. The password is stored in an environment variable." + profile = { + "port": 5433, + **self.mapped_params, + **self.profile_args, + # password should always get set as env var + "password": self.get_env_var_format("password"), + } + + return self.filter_null(profile) + + @property + def mock_profile(self) -> dict[str, Any | None]: + "Gets mock profile. Defaults port to 5433." 
+ parent_mock = super().mock_profile + + return { + "port": 5433, + **parent_mock, + } diff --git a/pyproject.toml b/pyproject.toml index 512af2bbe..ef1fba2f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dbt-all = [ "dbt-redshift", "dbt-snowflake", "dbt-spark", + "dbt-vertica", ] dbt-athena = [ "dbt-athena-community", @@ -80,6 +81,9 @@ dbt-snowflake = [ dbt-spark = [ "dbt-spark", ] +dbt-vertica = [ + "dbt-vertica<=1.5.4", +] openlineage = [ "openlineage-integration-common", "openlineage-airflow", @@ -165,18 +169,18 @@ test = 'pytest -vv --durations=0 . -m "not integration" --ignore=tests/test_exam test-cov = """pytest -vv --cov=cosmos --cov-report=term-missing --cov-report=xml --durations=0 -m "not integration" --ignore=tests/test_example_dags.py --ignore=tests/test_example_dags_no_connections.py""" # we install using the following workaround to overcome installation conflicts, such as: # apache-airflow 2.3.0 and dbt-core [0.13.0 - 1.5.2] and jinja2>=3.0.0 because these package versions have conflicting dependencies -test-integration-setup = """pip uninstall -y dbt-core dbt-databricks dbt-sqlite dbt-postgres dbt-sqlite; \ +test-integration-setup = """pip uninstall -y dbt-postgres dbt-databricks dbt-vertica; \ rm -rf airflow.*; \ airflow db init; \ -pip install 'dbt-core' 'dbt-databricks' 'dbt-postgres' 'openlineage-airflow'""" -test-integration = """rm -rf dbt/jaffle_shop/dbt_packages; +pip install 'dbt-core' 'dbt-databricks' 'dbt-postgres' 'dbt-vertica' 'openlineage-airflow'""" +test-integration = """rm -rf dbt/jaffle_shop/dbt_packages; +pytest -vv \ --cov=cosmos \ --cov-report=term-missing \ --cov-report=xml \ --durations=0 \ -m integration \ --k 'not (sqlite or example_cosmos_sources or example_cosmos_python_models or example_virtualenv or cosmos_manifest_example)' -""" +-k 'not (sqlite or example_cosmos_sources or example_cosmos_python_models or example_virtualenv or cosmos_manifest_example)'""" test-integration-expensive = """pytest -vv \ --cov=cosmos \
--cov-report=term-missing \ diff --git a/tests/profiles/vertica/test_vertica_user_pass.py b/tests/profiles/vertica/test_vertica_user_pass.py new file mode 100644 index 000000000..953a3c553 --- /dev/null +++ b/tests/profiles/vertica/test_vertica_user_pass.py @@ -0,0 +1,191 @@ +"Tests for the vertica profile." + +from unittest.mock import patch + +import pytest +from airflow.models.connection import Connection + +from cosmos.profiles import get_automatic_profile_mapping +from cosmos.profiles.vertica.user_pass import ( + VerticaUserPasswordProfileMapping, +) + + +@pytest.fixture() +def mock_vertica_conn(): # type: ignore + """ + Patches BaseHook.get_connection to return a mock Vertica connection. + """ + conn = Connection( + conn_id="my_vertica_connection", + conn_type="vertica", + host="my_host", + login="my_user", + password="my_password", + port=5433, + schema="my_schema", + extra='{"database": "my_database"}', + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +@pytest.fixture() +def mock_vertica_conn_custom_port(): # type: ignore + """ + Patches BaseHook.get_connection to return a mock Vertica connection that uses a custom port. + """ + conn = Connection( + conn_id="my_vertica_connection", + conn_type="vertica", + host="my_host", + login="my_user", + password="my_password", + port=7472, + schema="my_schema", + extra='{"database": "my_database"}', + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +def test_connection_claiming() -> None: + """ + Tests that the vertica profile mapping claims the correct connection type.
+ """ + # should only claim when: + # - conn_type == vertica + # and the following exist: + # - host + # - user + # - password + # - database (read from the connection's extra) + # - schema + # (port is optional; the profile defaults it to 5433) + potential_values = { + "conn_type": "vertica", + "host": "my_host", + "login": "my_user", + "password": "my_password", + "schema": "my_schema", + "extra": '{"database": "my_database"}', + } + + # if we're missing any of the values, it shouldn't claim + for key in potential_values: + values = potential_values.copy() + del values[key] + conn = Connection(**values) # type: ignore + + print("testing with", values) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = VerticaUserPasswordProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # also test when there's no database + conn = Connection(**potential_values) # type: ignore + conn.extra = "" + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = VerticaUserPasswordProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # if we have them all, it should claim + conn = Connection(**potential_values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = VerticaUserPasswordProfileMapping(conn) + assert profile_mapping.can_claim_connection() + + +def test_profile_mapping_selected( + mock_vertica_conn: Connection, +) -> None: + """ + Tests that the correct profile mapping is selected. + """ + profile_mapping = get_automatic_profile_mapping( + mock_vertica_conn.conn_id, + {"schema": "my_schema"}, + ) + assert isinstance(profile_mapping, VerticaUserPasswordProfileMapping) + + +def test_mock_profile() -> None: + """ + Tests that the mock profile port value gets set correctly.
+ """ + profile = VerticaUserPasswordProfileMapping("mock_conn_id") + assert profile.mock_profile.get("port") == 5433 + + +def test_profile_mapping_keeps_custom_port(mock_vertica_conn_custom_port: Connection) -> None: + profile = VerticaUserPasswordProfileMapping(mock_vertica_conn_custom_port.conn_id, {"schema": "my_schema"}) + assert profile.profile["port"] == 7472 + + +def test_profile_args( + mock_vertica_conn: Connection, +) -> None: + """ + Tests that the profile values get set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_vertica_conn.conn_id, + profile_args={"schema": "my_schema"}, + ) + assert profile_mapping.profile_args == { + "schema": "my_schema", + } + + assert profile_mapping.profile == { + "type": mock_vertica_conn.conn_type, + "host": mock_vertica_conn.host, + "user": mock_vertica_conn.login, + "password": "{{ env_var('COSMOS_CONN_VERTICA_PASSWORD') }}", + "port": mock_vertica_conn.port, + "schema": "my_schema", + "database": mock_vertica_conn.extra_dejson.get("database"), + } + + +def test_profile_args_overrides( + mock_vertica_conn: Connection, +) -> None: + """ + Tests that you can override the profile values. + """ + profile_mapping = get_automatic_profile_mapping( + mock_vertica_conn.conn_id, + profile_args={"schema": "my_schema", "database": "my_db_override"}, + ) + assert profile_mapping.profile_args == { + "schema": "my_schema", + "database": "my_db_override", + } + + assert profile_mapping.profile == { + "type": mock_vertica_conn.conn_type, + "host": mock_vertica_conn.host, + "user": mock_vertica_conn.login, + "password": "{{ env_var('COSMOS_CONN_VERTICA_PASSWORD') }}", + "port": mock_vertica_conn.port, + "database": "my_db_override", + "schema": "my_schema", + } + + +def test_profile_env_vars( + mock_vertica_conn: Connection, +) -> None: + """ + Tests that the environment variables get set correctly. 
+ """ + profile_mapping = get_automatic_profile_mapping( + mock_vertica_conn.conn_id, + profile_args={"schema": "my_schema"}, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_VERTICA_PASSWORD": mock_vertica_conn.password, + }