From ca65d1067a2c508e9d8090d270e580b4dfea4777 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 10:56:05 +0000 Subject: [PATCH 01/13] Intial change for Snowflake encrypted private key --- cosmos/profiles/__init__.py | 5 +- cosmos/profiles/snowflake/__init__.py | 4 +- .../user_encrypted_privatekey_env_variable.py | 86 +++++++ ...y.py => user_encrypted_privatekey_file.py} | 6 +- ..._user_encrypted_privatekey_env_variable.py | 216 ++++++++++++++++++ ...owflake_user_encrypted_privatekey_file.py} | 2 +- 6 files changed, 313 insertions(+), 6 deletions(-) create mode 100644 cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py rename cosmos/profiles/snowflake/{user_encrypted_privatekey.py => user_encrypted_privatekey_file.py} (95%) create mode 100644 tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py rename tests/profiles/snowflake/{test_snowflake_user_encrypted_privatekey.py => test_snowflake_user_encrypted_privatekey_file.py} (99%) diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index 47c7309ab..8280cd950 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -16,7 +16,8 @@ from .redshift.user_pass import RedshiftUserPasswordProfileMapping from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping -from .snowflake.user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping +from .snowflake.user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping +from .snowflake.user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping from .spark.thrift import SparkThriftProfileMapping from .trino.certificate import TrinoCertificateProfileMapping from .trino.jwt import TrinoJWTProfileMapping @@ -32,6 +33,7 @@ PostgresUserPasswordProfileMapping, RedshiftUserPasswordProfileMapping, 
SnowflakeUserPasswordProfileMapping, + SnowflakeEncryptedPrivateKeyFilePemProfileMapping, SnowflakeEncryptedPrivateKeyPemProfileMapping, SnowflakePrivateKeyPemProfileMapping, SparkThriftProfileMapping, @@ -71,6 +73,7 @@ def get_automatic_profile_mapping( "RedshiftUserPasswordProfileMapping", "SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping", + "SnowflakeEncryptedPrivateKeyFilePemProfileMapping", "SparkThriftProfileMapping", "ExasolUserPasswordProfileMapping", "TrinoLDAPProfileMapping", diff --git a/cosmos/profiles/snowflake/__init__.py b/cosmos/profiles/snowflake/__init__.py index 26c3fb595..fdf323a76 100644 --- a/cosmos/profiles/snowflake/__init__.py +++ b/cosmos/profiles/snowflake/__init__.py @@ -2,10 +2,12 @@ from .user_pass import SnowflakeUserPasswordProfileMapping from .user_privatekey import SnowflakePrivateKeyPemProfileMapping -from .user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping +from .user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping +from .user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping __all__ = [ "SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping", + "SnowflakeEncryptedPrivateKeyFilePemProfileMapping", "SnowflakeEncryptedPrivateKeyPemProfileMapping", ] diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py new file mode 100644 index 000000000..3fa7aaaf4 --- /dev/null +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -0,0 +1,86 @@ +"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key." 
+from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +from ..base import BaseProfileMapping + +if TYPE_CHECKING: + from airflow.models import Connection + + +class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): + """ + Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. + https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication + https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html + """ + + airflow_connection_type: str = "snowflake" + dbt_profile_type: str = "snowflake" + is_community: bool = True + + required_fields = [ + "account", + "user", + "database", + "warehouse", + "schema", + "private_key", + "private_key_passphrase", + ] + secret_fields = [ + "private_key", + "private_key_passphrase", + ] + airflow_param_mapping = { + "account": "extra.account", + "user": "login", + "database": "extra.database", + "warehouse": "extra.warehouse", + "schema": "schema", + "role": "extra.role", + "private_key": "extra.private_key_content", + "private_key_passphrase": "password", + } + + @property + def conn(self) -> Connection: + """ + Snowflake can be odd because the fields used to be stored with keys in the format + 'extra__snowflake__account', but now are stored as 'account'. + + This standardizes the keys to be 'account', 'database', etc. + """ + conn = super().conn + + conn_dejson = conn.extra_dejson + + if conn_dejson.get("extra__snowflake__account"): + conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()} + + conn.extra = json.dumps(conn_dejson) + + return conn + + @property + def profile(self) -> dict[str, Any | None]: + "Gets profile." 
+ profile_vars = { + **self.mapped_params, + **self.profile_args, + "private_key": self.get_env_var_format("private_key"), + "private_key_passphrase": self.get_env_var_format("private_key_passphrase"), + } + + # remove any null values + return self.filter_null(profile_vars) + + def transform_account(self, account: str) -> str: + "Transform the account to the format <account>.<region> if it's not already." + region = self.conn.extra_dejson.get("region") + if region and region not in account: + account = f"{account}.{region}" + + return str(account) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py similarity index 95% rename from cosmos/profiles/snowflake/user_encrypted_privatekey.py rename to cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index 0623598be..96e45080a 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -1,4 +1,4 @@ -"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key." +"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path." from __future__ import annotations import json @@ -10,9 +10,9 @@ from airflow.models import Connection -class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): +class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): """ - Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. + Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path. 
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html """ diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py new file mode 100644 index 000000000..64c33d337 --- /dev/null +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -0,0 +1,216 @@ +"Tests for the Snowflake user/private key file profile." + +import json +from unittest.mock import patch + +import pytest +from airflow.models.connection import Connection + +from cosmos.profiles import get_automatic_profile_mapping +from cosmos.profiles.snowflake import ( + SnowflakeEncryptedPrivateKeyFilePemProfileMapping, +) + + +@pytest.fixture() +def mock_snowflake_conn(): # type: ignore + """ + Sets the connection as an environment variable. + """ + conn = Connection( + conn_id="my_snowflake_pk_connection", + conn_type="snowflake", + login="my_user", + schema="my_schema", + password="secret", + extra=json.dumps( + { + "account": "my_account", + "region": "my_region", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_content": "my_private_key", + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +def test_connection_claiming() -> None: + """ + Tests that the Snowflake profile mapping claims the correct connection type. 
+ """ + potential_values = { + "conn_type": "snowflake", + "login": "my_user", + "schema": "my_database", + "password": "secret", + "extra": json.dumps( + { + "account": "my_account", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_content": "my_private_key", + } + ), + } + + # if we're missing any of the values, it shouldn't claim + for key in potential_values: + values = potential_values.copy() + del values[key] + conn = Connection(**values) # type: ignore + + print("testing with", values) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping( + conn, + ) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the account + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the database + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the warehouse + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = 
SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # if we have them all, it should claim + conn = Connection(**potential_values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert profile_mapping.can_claim_connection() + + +def test_profile_mapping_selected( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the correct profile mapping is selected. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyFilePemProfileMapping) + + +def test_profile_args( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the profile values get set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + + mock_account = mock_snowflake_conn.extra_dejson.get("account") + mock_region = mock_snowflake_conn.extra_dejson.get("region") + + assert profile_mapping.profile == { + "type": mock_snowflake_conn.conn_type, + "user": mock_snowflake_conn.login, + "private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}", + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "schema": mock_snowflake_conn.schema, + "account": f"{mock_account}.{mock_region}", + "database": mock_snowflake_conn.extra_dejson.get("database"), + "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + } + + +def test_profile_args_overrides( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that you can override the profile values. 
+ """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + profile_args={"database": "my_db_override"}, + ) + assert profile_mapping.profile_args == { + "database": "my_db_override", + } + + mock_account = mock_snowflake_conn.extra_dejson.get("account") + mock_region = mock_snowflake_conn.extra_dejson.get("region") + + assert profile_mapping.profile == { + "type": mock_snowflake_conn.conn_type, + "user": mock_snowflake_conn.login, + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}", + "schema": mock_snowflake_conn.schema, + "account": f"{mock_account}.{mock_region}", + "database": "my_db_override", + "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + } + + +def test_profile_env_vars( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the environment variables get set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password, + } + + +def test_old_snowflake_format() -> None: + """ + Tests that the old format still works. 
+ """ + conn = Connection( + conn_id="my_snowflake_connection", + conn_type="snowflake", + login="my_user", + schema="my_schema", + password="secret", + extra=json.dumps( + { + "extra__snowflake__account": "my_account", + "extra__snowflake__database": "my_database", + "extra__snowflake__warehouse": "my_warehouse", + "extra__snowflake__private_key_content": "my_private_key", + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert profile_mapping.profile == { + "type": conn.conn_type, + "user": conn.login, + "private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}", + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "schema": conn.schema, + "account": conn.extra_dejson.get("account"), + "database": conn.extra_dejson.get("database"), + "warehouse": conn.extra_dejson.get("warehouse"), + } diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py similarity index 99% rename from tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py rename to tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py index b61b85094..b3ccde977 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py @@ -1,4 +1,4 @@ -"Tests for the Snowflake user/private key profile." +"Tests for the Snowflake user/private key file profile." 
import json from unittest.mock import patch From 32c7f3d1a896f8f98a05ef161c714a64b7d27783 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:02:21 +0000 Subject: [PATCH 02/13] Updating test description --- .../test_snowflake_user_encrypted_privatekey_env_variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index 64c33d337..314bb489e 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -1,4 +1,4 @@ -"Tests for the Snowflake user/private key file profile." +"Tests for the Snowflake user/private key enviroment variable profile." import json from unittest.mock import patch From ba8af7dfd2cb764ff37e5e393c7c7cb00424166b Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:13:32 +0000 Subject: [PATCH 03/13] Work around for user/password mapping --- cosmos/profiles/snowflake/user_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index 2e1025a2c..ce3ea6472 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -44,7 +44,7 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key path credential result = super().can_claim_connection() - if result and self.conn.extra_dejson.get("private_key_file") is not None: + if result and self.conn.extra_dejson.get("private_key_file") is not None and self.conn.extra_dejson.get("private_key_content") is not None: return False return result From c5668a79c80682b3799a6dd137f61ceb91329357 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Nov 2023 11:14:01 +0000 Subject: [PATCH 04/13] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/profiles/snowflake/user_pass.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index ce3ea6472..9ced6b691 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -44,7 +44,11 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key path credential result = super().can_claim_connection() - if result and self.conn.extra_dejson.get("private_key_file") is not None and self.conn.extra_dejson.get("private_key_content") is not None: + if ( + result + and self.conn.extra_dejson.get("private_key_file") is not None + and self.conn.extra_dejson.get("private_key_content") is not None + ): return False return result From feaae3e3f4384a47157bbf1b2de586a0da0ec9f1 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:26:10 +0000 Subject: [PATCH 05/13] Adding conditions on claiming connections --- .../user_encrypted_privatekey_env_variable.py | 10 ++++++++++ .../snowflake/user_encrypted_privatekey_file.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py index 3fa7aaaf4..f4c57154b 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -45,6 +45,16 @@ class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): "private_key_passphrase": "password", } + def 
can_claim_connection(self) -> bool: + # Make sure this isn't a private key path credential + result = super().can_claim_connection() + if ( + result + and self.conn.extra_dejson.get("private_key_file") is not None + ): + return False + return result + @property def conn(self) -> Connection: """ diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index 96e45080a..8c57931bb 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -44,6 +44,16 @@ class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): "private_key_path": "extra.private_key_file", } + def can_claim_connection(self) -> bool: + # Make sure this isn't a private key enviroment variable + result = super().can_claim_connection() + if ( + result + and self.conn.extra_dejson.get("private_key_content") is not None + ): + return False + return result + @property def conn(self) -> Connection: """ From 6c8ec36b9defa4bffc695475e93ea2b74fc8142b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Nov 2023 11:26:35 +0000 Subject: [PATCH 06/13] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../snowflake/user_encrypted_privatekey_env_variable.py | 5 +---- cosmos/profiles/snowflake/user_encrypted_privatekey_file.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py index f4c57154b..fecfa97fe 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -48,10 +48,7 @@ 
class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key path credential result = super().can_claim_connection() - if ( - result - and self.conn.extra_dejson.get("private_key_file") is not None - ): + if result and self.conn.extra_dejson.get("private_key_file") is not None: return False return result diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index 8c57931bb..c806658e1 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -47,10 +47,7 @@ class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key enviroment variable result = super().can_claim_connection() - if ( - result - and self.conn.extra_dejson.get("private_key_content") is not None - ): + if result and self.conn.extra_dejson.get("private_key_content") is not None: return False return result From 8d95a49a971b6642d0f159b4969e609b7e05b922 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:40:57 +0000 Subject: [PATCH 07/13] Updating import on tests --- ...ake_user_encrypted_privatekey_env_variable.py | 16 ++++++++-------- ...t_snowflake_user_encrypted_privatekey_file.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index 314bb489e..4ed746fb7 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -8,7 +8,7 @@ from cosmos.profiles import get_automatic_profile_mapping from cosmos.profiles.snowflake import ( 
- SnowflakeEncryptedPrivateKeyFilePemProfileMapping, + SnowflakeEncryptedPrivateKeyPemProfileMapping, ) @@ -66,7 +66,7 @@ def test_connection_claiming() -> None: print("testing with", values) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping( + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping( conn, ) assert not profile_mapping.can_claim_connection() @@ -76,7 +76,7 @@ def test_connection_claiming() -> None: conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # test when we're missing the database @@ -84,7 +84,7 @@ def test_connection_claiming() -> None: conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # test when we're missing the warehouse @@ -92,13 +92,13 @@ def test_connection_claiming() -> None: conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # if we have them 
all, it should claim conn = Connection(**potential_values) # type: ignore with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert profile_mapping.can_claim_connection() @@ -111,7 +111,7 @@ def test_profile_mapping_selected( profile_mapping = get_automatic_profile_mapping( mock_snowflake_conn.conn_id, ) - assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyFilePemProfileMapping) + assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyPemProfileMapping) def test_profile_args( @@ -203,7 +203,7 @@ def test_old_snowflake_format() -> None: ) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert profile_mapping.profile == { "type": conn.conn_type, "user": conn.login, diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py index b3ccde977..d8c3aedcf 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py @@ -8,7 +8,7 @@ from cosmos.profiles import get_automatic_profile_mapping from cosmos.profiles.snowflake import ( - SnowflakeEncryptedPrivateKeyPemProfileMapping, + SnowflakeEncryptedPrivateKeyFilePemProfileMapping, ) @@ -66,7 +66,7 @@ def test_connection_claiming() -> None: print("testing with", values) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping( + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping( conn, ) assert not profile_mapping.can_claim_connection() @@ -76,7 +76,7 @@ 
def test_connection_claiming() -> None: conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # test when we're missing the database @@ -84,7 +84,7 @@ def test_connection_claiming() -> None: conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # test when we're missing the warehouse @@ -92,13 +92,13 @@ def test_connection_claiming() -> None: conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # if we have them all, it should claim conn = Connection(**potential_values) # type: ignore with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert profile_mapping.can_claim_connection() @@ -111,7 +111,7 @@ def test_profile_mapping_selected( profile_mapping = get_automatic_profile_mapping( mock_snowflake_conn.conn_id, ) - 
assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyPemProfileMapping) + assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyFilePemProfileMapping) def test_profile_args( @@ -203,7 +203,7 @@ def test_old_snowflake_format() -> None: ) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert profile_mapping.profile == { "type": conn.conn_type, "user": conn.login, From 643f03f6a19ba2814d6435e4acab637202da73bb Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:47:46 +0000 Subject: [PATCH 08/13] Or condition for user pass --- cosmos/profiles/snowflake/user_pass.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index 9ced6b691..c24571a39 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -46,8 +46,8 @@ def can_claim_connection(self) -> bool: result = super().can_claim_connection() if ( result - and self.conn.extra_dejson.get("private_key_file") is not None - and self.conn.extra_dejson.get("private_key_content") is not None + and (self.conn.extra_dejson.get("private_key_file") is not None + or self.conn.extra_dejson.get("private_key_content") is not None) ): return False return result From 07341668e08b1d9c5ed5fde0129be4f1fde25a8a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Nov 2023 11:48:08 +0000 Subject: [PATCH 09/13] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/profiles/snowflake/user_pass.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git 
a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index c24571a39..fa634d1a2 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -44,10 +44,9 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key path credential result = super().can_claim_connection() - if ( - result - and (self.conn.extra_dejson.get("private_key_file") is not None - or self.conn.extra_dejson.get("private_key_content") is not None) + if result and ( + self.conn.extra_dejson.get("private_key_file") is not None + or self.conn.extra_dejson.get("private_key_content") is not None ): return False return result From 8f46b5542890e03a60c66c63eec6e35f75185868 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:52:24 +0000 Subject: [PATCH 10/13] Adding private key to env test --- .../test_snowflake_user_encrypted_privatekey_env_variable.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index 4ed746fb7..7ed7f7d3d 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -178,6 +178,7 @@ def test_profile_env_vars( mock_snowflake_conn.conn_id, ) assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": mock_snowflake_conn.extra_dejson.get("private_key_content"), "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password, } From 188fe56994e29037eda1777974faa45a183ebc1f Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:57:42 +0000 Subject: [PATCH 11/13] Fixing typo --- cosmos/profiles/snowflake/user_encrypted_privatekey_file.py | 2 +- 
.../test_snowflake_user_encrypted_privatekey_env_variable.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index c806658e1..6831cbd28 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -45,7 +45,7 @@ class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): } def can_claim_connection(self) -> bool: - # Make sure this isn't a private key enviroment variable + # Make sure this isn't a private key environment variable result = super().can_claim_connection() if result and self.conn.extra_dejson.get("private_key_content") is not None: return False diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index 7ed7f7d3d..2c7515f72 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -1,4 +1,4 @@ -"Tests for the Snowflake user/private key enviroment variable profile." +"Tests for the Snowflake user/private key environment variable profile." import json from unittest.mock import patch From 9cd46d2353dff874cc94c66baa0b146255ec42f5 Mon Sep 17 00:00:00 2001 From: Joppe Vos <44348300+joppevos@users.noreply.github.com> Date: Wed, 8 Nov 2023 16:29:49 +0100 Subject: [PATCH 12/13] Add `operator_args` `full_refresh` as a templated field (#623) This allows you to fully refresh a model from the console. Full-refresh/backfill is a common task. Using Airflow parameters makes this easy. Without this, you'd have to trigger an entire deployment. In our setup, company analysts manage their models without modifying the DAG code. This empowers such users. 
Example of usage: ```python with DAG( dag_id="jaffle", params={"full_refresh": Param(default=False, type="boolean")}, render_template_as_native_obj=True ): task = DbtTaskGroup( operator_args={"full_refresh": "{{ params.get('full_refresh') }}", "install_deps": True}, ) ``` Closes: #151 --- cosmos/operators/local.py | 3 +++ tests/operators/test_local.py | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 1c00f476c..145741096 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -397,6 +397,8 @@ class DbtSeedLocalOperator(DbtLocalBaseOperator): ui_color = "#F58D7E" + template_fields: Sequence[str] = DbtBaseOperator.template_fields + ("full_refresh",) # type: ignore[operator] + def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None: self.full_refresh = full_refresh super().__init__(**kwargs) @@ -434,6 +436,7 @@ class DbtRunLocalOperator(DbtLocalBaseOperator): ui_color = "#7352BA" ui_fgcolor = "#F4F2FC" + template_fields: Sequence[str] = DbtBaseOperator.template_fields + ("full_refresh",) # type: ignore[operator] def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None: self.full_refresh = full_refresh diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 580d49e6c..14213b335 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -420,6 +420,19 @@ def test_calculate_openlineage_events_completes_openlineage_errors(mock_processo assert err_msg in caplog.text +@pytest.mark.parametrize( + "operator_class,expected_template", + [ + (DbtSeedLocalOperator, ("env", "vars", "full_refresh")), + (DbtRunLocalOperator, ("env", "vars", "full_refresh")), + ], +) +def test_dbt_base_operator_template_fields(operator_class, expected_template): + # Check if value of template fields is what we expect for the operators we're validating + dbt_base_operator = operator_class(profile_config=profile_config, 
task_id="my-task", project_dir="my/dir") + assert dbt_base_operator.template_fields == expected_template + + @patch.object(DbtDocsGCSLocalOperator, "required_files", ["file1", "file2"]) def test_dbt_docs_gcs_local_operator(): mock_gcs = MagicMock() From 3c85872d945dffd692fc2223ff7e8f3d5f0f3a81 Mon Sep 17 00:00:00 2001 From: Benjamin Dornel Date: Thu, 9 Nov 2023 17:40:37 +0800 Subject: [PATCH 13/13] Add aws_session_token for Athena mapping --- cosmos/profiles/__init__.py | 1 + cosmos/profiles/athena/access_key.py | 8 ++++---- tests/profiles/athena/test_athena_access_key.py | 8 +++++++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index 8280cd950..1f39a91a0 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -64,6 +64,7 @@ def get_automatic_profile_mapping( __all__ = [ + "AthenaAccessKeyProfileMapping", "BaseProfileMapping", "GoogleCloudServiceAccountFileProfileMapping", "GoogleCloudServiceAccountDictProfileMapping", diff --git a/cosmos/profiles/athena/access_key.py b/cosmos/profiles/athena/access_key.py index b79bb793a..a8f71c2b7 100644 --- a/cosmos/profiles/athena/access_key.py +++ b/cosmos/profiles/athena/access_key.py @@ -26,12 +26,11 @@ class AthenaAccessKeyProfileMapping(BaseProfileMapping): "s3_staging_dir", "schema", ] - secret_fields = [ - "aws_secret_access_key", - ] + secret_fields = ["aws_secret_access_key", "aws_session_token"] airflow_param_mapping = { "aws_access_key_id": "login", "aws_secret_access_key": "password", + "aws_session_token": "extra.aws_session_token", "aws_profile_name": "extra.aws_profile_name", "database": "extra.database", "debug_query_state": "extra.debug_query_state", @@ -53,7 +52,8 @@ def profile(self) -> dict[str, Any | None]: profile = { **self.mapped_params, **self.profile_args, - # aws_secret_access_key should always get set as env var + # aws_secret_access_key and aws_session_token should always get set as env var 
"aws_secret_access_key": self.get_env_var_format("aws_secret_access_key"), + "aws_session_token": self.get_env_var_format("aws_session_token"), } return self.filter_null(profile) diff --git a/tests/profiles/athena/test_athena_access_key.py b/tests/profiles/athena/test_athena_access_key.py index 2063ef6ed..22c8efa2c 100644 --- a/tests/profiles/athena/test_athena_access_key.py +++ b/tests/profiles/athena/test_athena_access_key.py @@ -22,6 +22,7 @@ def mock_athena_conn(): # type: ignore password="my_aws_secret_key", extra=json.dumps( { + "aws_session_token": "token123", "database": "my_database", "region_name": "my_region", "s3_staging_dir": "s3://my_bucket/dbt/", @@ -107,6 +108,7 @@ def test_athena_profile_args( "type": "athena", "aws_access_key_id": mock_athena_conn.login, "aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", + "aws_session_token": "{{ env_var('COSMOS_CONN_AWS_AWS_SESSION_TOKEN') }}", "database": mock_athena_conn.extra_dejson.get("database"), "region_name": mock_athena_conn.extra_dejson.get("region_name"), "s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"), @@ -122,17 +124,20 @@ def test_athena_profile_args_overrides( """ profile_mapping = get_automatic_profile_mapping( mock_athena_conn.conn_id, - profile_args={"schema": "my_custom_schema", "database": "my_custom_db"}, + profile_args={"schema": "my_custom_schema", "database": "my_custom_db", "aws_session_token": "override_token"}, ) + assert profile_mapping.profile_args == { "schema": "my_custom_schema", "database": "my_custom_db", + "aws_session_token": "override_token", } assert profile_mapping.profile == { "type": "athena", "aws_access_key_id": mock_athena_conn.login, "aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", + "aws_session_token": "{{ env_var('COSMOS_CONN_AWS_AWS_SESSION_TOKEN') }}", "database": "my_custom_db", "region_name": mock_athena_conn.extra_dejson.get("region_name"), "s3_staging_dir": 
mock_athena_conn.extra_dejson.get("s3_staging_dir"), @@ -151,4 +156,5 @@ def test_athena_profile_env_vars( ) assert profile_mapping.env_vars == { "COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_athena_conn.password, + "COSMOS_CONN_AWS_AWS_SESSION_TOKEN": mock_athena_conn.extra_dejson.get("aws_session_token"), }