From ca65d1067a2c508e9d8090d270e580b4dfea4777 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 10:56:05 +0000 Subject: [PATCH 01/16] Intial change for Snowflake encrypted private key --- cosmos/profiles/__init__.py | 5 +- cosmos/profiles/snowflake/__init__.py | 4 +- .../user_encrypted_privatekey_env_variable.py | 86 +++++++ ...y.py => user_encrypted_privatekey_file.py} | 6 +- ..._user_encrypted_privatekey_env_variable.py | 216 ++++++++++++++++++ ...owflake_user_encrypted_privatekey_file.py} | 2 +- 6 files changed, 313 insertions(+), 6 deletions(-) create mode 100644 cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py rename cosmos/profiles/snowflake/{user_encrypted_privatekey.py => user_encrypted_privatekey_file.py} (95%) create mode 100644 tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py rename tests/profiles/snowflake/{test_snowflake_user_encrypted_privatekey.py => test_snowflake_user_encrypted_privatekey_file.py} (99%) diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index 47c7309ab..8280cd950 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -16,7 +16,8 @@ from .redshift.user_pass import RedshiftUserPasswordProfileMapping from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping -from .snowflake.user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping +from .snowflake.user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping +from .snowflake.user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping from .spark.thrift import SparkThriftProfileMapping from .trino.certificate import TrinoCertificateProfileMapping from .trino.jwt import TrinoJWTProfileMapping @@ -32,6 +33,7 @@ PostgresUserPasswordProfileMapping, RedshiftUserPasswordProfileMapping, SnowflakeUserPasswordProfileMapping, + SnowflakeEncryptedPrivateKeyFilePemProfileMapping, SnowflakeEncryptedPrivateKeyPemProfileMapping, SnowflakePrivateKeyPemProfileMapping, SparkThriftProfileMapping, @@ -71,6 +73,7 @@ def get_automatic_profile_mapping( "RedshiftUserPasswordProfileMapping", "SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping", + "SnowflakeEncryptedPrivateKeyFilePemProfileMapping", "SparkThriftProfileMapping", "ExasolUserPasswordProfileMapping", "TrinoLDAPProfileMapping", diff --git a/cosmos/profiles/snowflake/__init__.py b/cosmos/profiles/snowflake/__init__.py index 26c3fb595..fdf323a76 100644 --- a/cosmos/profiles/snowflake/__init__.py +++ b/cosmos/profiles/snowflake/__init__.py @@ -2,10 +2,12 @@ from .user_pass import SnowflakeUserPasswordProfileMapping from .user_privatekey import SnowflakePrivateKeyPemProfileMapping -from .user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping +from .user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping +from .user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping __all__ = [ "SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping", + "SnowflakeEncryptedPrivateKeyFilePemProfileMapping", "SnowflakeEncryptedPrivateKeyPemProfileMapping", ] diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py new file mode 100644 index 000000000..3fa7aaaf4 --- /dev/null +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -0,0 +1,86 @@ +"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key." +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +from ..base import BaseProfileMapping + +if TYPE_CHECKING: + from airflow.models import Connection + + +class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): + """ + Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. + https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication + https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html + """ + + airflow_connection_type: str = "snowflake" + dbt_profile_type: str = "snowflake" + is_community: bool = True + + required_fields = [ + "account", + "user", + "database", + "warehouse", + "schema", + "private_key", + "private_key_passphrase", + ] + secret_fields = [ + "private_key", + "private_key_passphrase", + ] + airflow_param_mapping = { + "account": "extra.account", + "user": "login", + "database": "extra.database", + "warehouse": "extra.warehouse", + "schema": "schema", + "role": "extra.role", + "private_key": "extra.private_key_content", + "private_key_passphrase": "password", + } + + @property + def conn(self) -> Connection: + """ + Snowflake can be odd because the fields used to be stored with keys in the format + 'extra__snowflake__account', but now are stored as 'account'. + + This standardizes the keys to be 'account', 'database', etc. + """ + conn = super().conn + + conn_dejson = conn.extra_dejson + + if conn_dejson.get("extra__snowflake__account"): + conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()} + + conn.extra = json.dumps(conn_dejson) + + return conn + + @property + def profile(self) -> dict[str, Any | None]: + "Gets profile." + profile_vars = { + **self.mapped_params, + **self.profile_args, + "private_key": self.get_env_var_format("private_key"), + "private_key_passphrase": self.get_env_var_format("private_key_passphrase"), + } + + # remove any null values + return self.filter_null(profile_vars) + + def transform_account(self, account: str) -> str: + "Transform the account to the format . if it's not already." + region = self.conn.extra_dejson.get("region") + if region and region not in account: + account = f"{account}.{region}" + + return str(account) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py similarity index 95% rename from cosmos/profiles/snowflake/user_encrypted_privatekey.py rename to cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index 0623598be..96e45080a 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -1,4 +1,4 @@ -"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key." +"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path." from __future__ import annotations import json @@ -10,9 +10,9 @@ from airflow.models import Connection -class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): +class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): """ - Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. + Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path. https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html """ diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py new file mode 100644 index 000000000..64c33d337 --- /dev/null +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -0,0 +1,216 @@ +"Tests for the Snowflake user/private key file profile." + +import json +from unittest.mock import patch + +import pytest +from airflow.models.connection import Connection + +from cosmos.profiles import get_automatic_profile_mapping +from cosmos.profiles.snowflake import ( + SnowflakeEncryptedPrivateKeyFilePemProfileMapping, +) + + +@pytest.fixture() +def mock_snowflake_conn(): # type: ignore + """ + Sets the connection as an environment variable. + """ + conn = Connection( + conn_id="my_snowflake_pk_connection", + conn_type="snowflake", + login="my_user", + schema="my_schema", + password="secret", + extra=json.dumps( + { + "account": "my_account", + "region": "my_region", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_content": "my_private_key", + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +def test_connection_claiming() -> None: + """ + Tests that the Snowflake profile mapping claims the correct connection type. + """ + potential_values = { + "conn_type": "snowflake", + "login": "my_user", + "schema": "my_database", + "password": "secret", + "extra": json.dumps( + { + "account": "my_account", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_content": "my_private_key", + } + ), + } + + # if we're missing any of the values, it shouldn't claim + for key in potential_values: + values = potential_values.copy() + del values[key] + conn = Connection(**values) # type: ignore + + print("testing with", values) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping( + conn, + ) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the account + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the database + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the warehouse + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # if we have them all, it should claim + conn = Connection(**potential_values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert profile_mapping.can_claim_connection() + + +def test_profile_mapping_selected( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the correct profile mapping is selected. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyFilePemProfileMapping) + + +def test_profile_args( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the profile values get set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + + mock_account = mock_snowflake_conn.extra_dejson.get("account") + mock_region = mock_snowflake_conn.extra_dejson.get("region") + + assert profile_mapping.profile == { + "type": mock_snowflake_conn.conn_type, + "user": mock_snowflake_conn.login, + "private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}", + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "schema": mock_snowflake_conn.schema, + "account": f"{mock_account}.{mock_region}", + "database": mock_snowflake_conn.extra_dejson.get("database"), + "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + } + + +def test_profile_args_overrides( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that you can override the profile values. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + profile_args={"database": "my_db_override"}, + ) + assert profile_mapping.profile_args == { + "database": "my_db_override", + } + + mock_account = mock_snowflake_conn.extra_dejson.get("account") + mock_region = mock_snowflake_conn.extra_dejson.get("region") + + assert profile_mapping.profile == { + "type": mock_snowflake_conn.conn_type, + "user": mock_snowflake_conn.login, + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}", + "schema": mock_snowflake_conn.schema, + "account": f"{mock_account}.{mock_region}", + "database": "my_db_override", + "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + } + + +def test_profile_env_vars( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the environment variables get set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password, + } + + +def test_old_snowflake_format() -> None: + """ + Tests that the old format still works. + """ + conn = Connection( + conn_id="my_snowflake_connection", + conn_type="snowflake", + login="my_user", + schema="my_schema", + password="secret", + extra=json.dumps( + { + "extra__snowflake__account": "my_account", + "extra__snowflake__database": "my_database", + "extra__snowflake__warehouse": "my_warehouse", + "extra__snowflake__private_key_content": "my_private_key", + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert profile_mapping.profile == { + "type": conn.conn_type, + "user": conn.login, + "private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}", + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "schema": conn.schema, + "account": conn.extra_dejson.get("account"), + "database": conn.extra_dejson.get("database"), + "warehouse": conn.extra_dejson.get("warehouse"), + } diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py similarity index 99% rename from tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py rename to tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py index b61b85094..b3ccde977 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py @@ -1,4 +1,4 @@ -"Tests for the Snowflake user/private key profile." +"Tests for the Snowflake user/private key file profile." import json from unittest.mock import patch From 32c7f3d1a896f8f98a05ef161c714a64b7d27783 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:02:21 +0000 Subject: [PATCH 02/16] Updating test description --- .../test_snowflake_user_encrypted_privatekey_env_variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index 64c33d337..314bb489e 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -1,4 +1,4 @@ -"Tests for the Snowflake user/private key file profile." +"Tests for the Snowflake user/private key enviroment variable profile." import json from unittest.mock import patch From ba8af7dfd2cb764ff37e5e393c7c7cb00424166b Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:13:32 +0000 Subject: [PATCH 03/16] Work around for user/password mapping --- cosmos/profiles/snowflake/user_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index 2e1025a2c..ce3ea6472 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -44,7 +44,7 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key path credential result = super().can_claim_connection() - if result and self.conn.extra_dejson.get("private_key_file") is not None: + if result and self.conn.extra_dejson.get("private_key_file") is not None and self.conn.extra_dejson.get("private_key_content") is not None: return False return result From c5668a79c80682b3799a6dd137f61ceb91329357 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Nov 2023 11:14:01 +0000 Subject: [PATCH 04/16] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/profiles/snowflake/user_pass.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index ce3ea6472..9ced6b691 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -44,7 +44,11 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key path credential result = super().can_claim_connection() - if result and self.conn.extra_dejson.get("private_key_file") is not None and self.conn.extra_dejson.get("private_key_content") is not None: + if ( + result + and self.conn.extra_dejson.get("private_key_file") is not None + and self.conn.extra_dejson.get("private_key_content") is not None + ): return False return result From feaae3e3f4384a47157bbf1b2de586a0da0ec9f1 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:26:10 +0000 Subject: [PATCH 05/16] Adding conditions on claiming connections --- .../user_encrypted_privatekey_env_variable.py | 10 ++++++++++ .../snowflake/user_encrypted_privatekey_file.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py index 3fa7aaaf4..f4c57154b 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -45,6 +45,16 @@ class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): "private_key_passphrase": "password", } + def can_claim_connection(self) -> bool: + # Make sure this isn't a private key path credential + result = super().can_claim_connection() + if ( + result + and self.conn.extra_dejson.get("private_key_file") is not None + ): + return False + return result + @property def conn(self) -> Connection: """ diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index 96e45080a..8c57931bb 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -44,6 +44,16 @@ class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): "private_key_path": "extra.private_key_file", } + def can_claim_connection(self) -> bool: + # Make sure this isn't a private key enviroment variable + result = super().can_claim_connection() + if ( + result + and self.conn.extra_dejson.get("private_key_content") is not None + ): + return False + return result + @property def conn(self) -> Connection: """ From 6c8ec36b9defa4bffc695475e93ea2b74fc8142b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Nov 2023 11:26:35 +0000 Subject: [PATCH 06/16] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../snowflake/user_encrypted_privatekey_env_variable.py | 5 +---- cosmos/profiles/snowflake/user_encrypted_privatekey_file.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py index f4c57154b..fecfa97fe 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -48,10 +48,7 @@ class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key path credential result = super().can_claim_connection() - if ( - result - and self.conn.extra_dejson.get("private_key_file") is not None - ): + if result and self.conn.extra_dejson.get("private_key_file") is not None: return False return result diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index 8c57931bb..c806658e1 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -47,10 +47,7 @@ class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key enviroment variable result = super().can_claim_connection() - if ( - result - and self.conn.extra_dejson.get("private_key_content") is not None - ): + if result and self.conn.extra_dejson.get("private_key_content") is not None: return False return result From 8d95a49a971b6642d0f159b4969e609b7e05b922 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:40:57 +0000 Subject: [PATCH 07/16] Updating import on tests --- ...ake_user_encrypted_privatekey_env_variable.py | 16 ++++++++-------- ...t_snowflake_user_encrypted_privatekey_file.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index 314bb489e..4ed746fb7 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -8,7 +8,7 @@ from cosmos.profiles import get_automatic_profile_mapping from cosmos.profiles.snowflake import ( - SnowflakeEncryptedPrivateKeyFilePemProfileMapping, + SnowflakeEncryptedPrivateKeyPemProfileMapping, ) @@ -66,7 +66,7 @@ def test_connection_claiming() -> None: print("testing with", values) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping( + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping( conn, ) assert not profile_mapping.can_claim_connection() @@ -76,7 +76,7 @@ def test_connection_claiming() -> None: conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # test when we're missing the database @@ -84,7 +84,7 @@ def test_connection_claiming() -> None: conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # test when we're missing the warehouse @@ -92,13 +92,13 @@ def test_connection_claiming() -> None: conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # if we have them all, it should claim conn = Connection(**potential_values) # type: ignore with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert profile_mapping.can_claim_connection() @@ -111,7 +111,7 @@ def test_profile_mapping_selected( profile_mapping = get_automatic_profile_mapping( mock_snowflake_conn.conn_id, ) - assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyFilePemProfileMapping) + assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyPemProfileMapping) def test_profile_args( @@ -203,7 +203,7 @@ def test_old_snowflake_format() -> None: ) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert profile_mapping.profile == { "type": conn.conn_type, "user": conn.login, diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py index b3ccde977..d8c3aedcf 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py @@ -8,7 +8,7 @@ from cosmos.profiles import get_automatic_profile_mapping from cosmos.profiles.snowflake import ( - SnowflakeEncryptedPrivateKeyPemProfileMapping, + SnowflakeEncryptedPrivateKeyFilePemProfileMapping, ) @@ -66,7 +66,7 @@ def test_connection_claiming() -> None: print("testing with", values) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping( + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping( conn, ) assert not profile_mapping.can_claim_connection() @@ -76,7 +76,7 @@ def test_connection_claiming() -> None: conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # test when we're missing the database @@ -84,7 +84,7 @@ def test_connection_claiming() -> None: conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # test when we're missing the warehouse @@ -92,13 +92,13 @@ def test_connection_claiming() -> None: conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # if we have them all, it should claim conn = Connection(**potential_values) # type: ignore with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert profile_mapping.can_claim_connection() @@ -111,7 +111,7 @@ def test_profile_mapping_selected( profile_mapping = get_automatic_profile_mapping( mock_snowflake_conn.conn_id, ) - assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyPemProfileMapping) + assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyFilePemProfileMapping) def test_profile_args( @@ -203,7 +203,7 @@ def test_old_snowflake_format() -> None: ) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert profile_mapping.profile == { "type": conn.conn_type, "user": conn.login, From 643f03f6a19ba2814d6435e4acab637202da73bb Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:47:46 +0000 Subject: [PATCH 08/16] Or condition for user pass --- cosmos/profiles/snowflake/user_pass.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index 9ced6b691..c24571a39 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -46,8 +46,8 @@ def can_claim_connection(self) -> bool: result = super().can_claim_connection() if ( result - and self.conn.extra_dejson.get("private_key_file") is not None - and self.conn.extra_dejson.get("private_key_content") is not None + and (self.conn.extra_dejson.get("private_key_file") is not None + or self.conn.extra_dejson.get("private_key_content") is not None) ): return False return result From 07341668e08b1d9c5ed5fde0129be4f1fde25a8a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Nov 2023 11:48:08 +0000 Subject: [PATCH 09/16] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/profiles/snowflake/user_pass.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index c24571a39..fa634d1a2 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -44,10 +44,9 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key path credential result = super().can_claim_connection() - if ( - result - and (self.conn.extra_dejson.get("private_key_file") is not None - or self.conn.extra_dejson.get("private_key_content") is not None) + if result and ( + self.conn.extra_dejson.get("private_key_file") is not None + or self.conn.extra_dejson.get("private_key_content") is not None ): return False return result From 8f46b5542890e03a60c66c63eec6e35f75185868 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:52:24 +0000 Subject: [PATCH 10/16] Adding private key to env test --- .../test_snowflake_user_encrypted_privatekey_env_variable.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index 4ed746fb7..7ed7f7d3d 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -178,6 +178,7 @@ def test_profile_env_vars( mock_snowflake_conn.conn_id, ) assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": mock_snowflake_conn.extra_dejson.get("private_key_content"), "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password, } From 188fe56994e29037eda1777974faa45a183ebc1f Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:57:42 +0000 Subject: [PATCH 11/16] Fixing typo --- cosmos/profiles/snowflake/user_encrypted_privatekey_file.py | 2 +- .../test_snowflake_user_encrypted_privatekey_env_variable.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index c806658e1..6831cbd28 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -45,7 +45,7 @@ class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): } def can_claim_connection(self) -> bool: - # Make sure this isn't a private key enviroment variable + # Make sure this isn't a private key environmentvariable result = super().can_claim_connection() if result and self.conn.extra_dejson.get("private_key_content") is not None: return False diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index 7ed7f7d3d..2c7515f72 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -1,4 +1,4 @@ -"Tests for the Snowflake user/private key enviroment variable profile." +"Tests for the Snowflake user/private key environmentvariable profile." import json from unittest.mock import patch From 6c88bb4af4468ab45016f31f46a6db6dc2084d1e Mon Sep 17 00:00:00 2001 From: Joppe Vos Date: Wed, 8 Nov 2023 15:02:46 +0100 Subject: [PATCH 12/16] extend deps command with profile flags --- cosmos/operators/local.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 1c00f476c..2bae5ab44 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -206,17 +206,11 @@ def run_command( tmp_project_dir, ) - # if we need to install deps, do so - if self.install_deps: - self.run_subprocess( - command=[self.dbt_executable_path, "deps"], - env=env, - output_encoding=self.output_encoding, - cwd=tmp_project_dir, - ) - with self.profile_config.ensure_profile() as (profile_path, env_vars): + with self.profile_config.ensure_profile() as profile_values: + (profile_path, env_vars) = profile_values env.update(env_vars) - full_cmd = cmd + [ + + flags = [ "--profiles-dir", str(profile_path.parent), "--profile", @@ -225,6 +219,18 @@ def run_command( self.profile_config.target_name, ] + if self.install_deps: + deps_command = [self.dbt_executable_path, "deps"] + deps_command.extend(flags) + self.run_subprocess( + command=[self.dbt_executable_path, "deps"], + env=env, + output_encoding=self.output_encoding, + cwd=tmp_project_dir, + ) + + full_cmd = cmd + flags + logger.info("Trying to run the command:\n %s\nFrom %s", full_cmd, tmp_project_dir) logger.info("Using environment variables keys: %s", env.keys()) result = self.run_subprocess( From 85ea5c10bed7e7158b9426a571e40c7d328da561 Mon Sep 17 00:00:00 2001 From: Joppe Vos Date: Thu, 9 Nov 2023 14:46:13 +0100 Subject: [PATCH 13/16] create tests to check for arguments --- cosmos/operators/local.py | 2 +- tests/operators/test_local.py | 23 ++++++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 2bae5ab44..2f0d45a97 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -223,7 +223,7 @@ def run_command( deps_command = [self.dbt_executable_path, "deps"] deps_command.extend(flags) self.run_subprocess( - command=[self.dbt_executable_path, "deps"], + command=deps_command, env=env, output_encoding=self.output_encoding, cwd=tmp_project_dir, diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 580d49e6c..114d10ef8 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -4,7 +4,7 @@ import shutil import tempfile from pathlib import Path -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, patch, call, ANY import pytest from airflow import DAG @@ -445,3 +445,24 @@ def test_dbt_docs_gcs_local_operator(): call(filename="fake-dir/target/file2", bucket_name="fake-bucket", object_name="fake-folder/file2"), ] mock_hook.upload.assert_has_calls(expected_upload_calls) + + +@patch("cosmos.operators.local.DbtLocalBaseOperator.store_compiled_sql") +@patch("cosmos.operators.local.DbtLocalBaseOperator.exception_handling") +@patch("cosmos.config.ProfileConfig.ensure_profile") +@patch("cosmos.operators.local.DbtLocalBaseOperator.run_dbt_deps") +def test_operator_execute_deps_parameters( + mock_build_and_run_cmd, mock_ensure_profile, mock_exception_handling, mock_store_compiled_sql +): + expected_call_kwargs = ["--profiles-dir", "/path/to", "--profile", "default", "--target", "dev"] + task = DbtLocalBaseOperator( + profile_config=real_profile_config, + task_id="my-task", + project_dir=DBT_PROJ_DIR, + install_deps=True, + emit_datasets=False, + dbt_executable_path="/usr/local/bin/dbt", + ) + mock_ensure_profile.return_value.__enter__.return_value = (Path("/path/to/profile"), {"ENV_VAR": "value"}) + task.execute(context={"task_instance": MagicMock()}) + mock_build_and_run_cmd.assert_called_once_with(expected_call_kwargs, env=ANY, tmp_project_dir=ANY) From fc6a5fad4419cb9d666f23fc8061c2ae352cb87b Mon Sep 17 00:00:00 2001 From: Joppe Vos Date: Thu, 9 Nov 2023 15:29:31 +0100 Subject: [PATCH 14/16] adjust failing test --- tests/operators/test_virtualenv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/operators/test_virtualenv.py b/tests/operators/test_virtualenv.py index 142a251a7..13dba8f94 100644 --- a/tests/operators/test_virtualenv.py +++ b/tests/operators/test_virtualenv.py @@ -60,7 +60,7 @@ def test_run_command( dbt_cmd = run_command_args[2] assert python_cmd[0][0][0].endswith("/bin/python") assert python_cmd[0][-1][-1] == "from importlib.metadata import version; print(version('dbt-core'))" - assert dbt_deps[0][0][-1] == "deps" + assert dbt_deps[0][0][1] == "deps" assert dbt_deps[0][0][0].endswith("/bin/dbt") assert dbt_deps[0][0][0] == dbt_cmd[0][0][0] assert dbt_cmd[0][0][1] == "do-something" From d3a5d0a92e541d4d1da0798f371c037e6c1a9188 Mon Sep 17 00:00:00 2001 From: Joppe Vos Date: Thu, 9 Nov 2023 15:30:05 +0100 Subject: [PATCH 15/16] adjust test to check for call_args --- tests/operators/test_local.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 114d10ef8..b40b74654 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -4,7 +4,7 @@ import shutil import tempfile from pathlib import Path -from unittest.mock import MagicMock, patch, call, ANY +from unittest.mock import MagicMock, patch, call import pytest from airflow import DAG @@ -450,12 +450,21 @@ def test_dbt_docs_gcs_local_operator(): @patch("cosmos.operators.local.DbtLocalBaseOperator.store_compiled_sql") @patch("cosmos.operators.local.DbtLocalBaseOperator.exception_handling") @patch("cosmos.config.ProfileConfig.ensure_profile") -@patch("cosmos.operators.local.DbtLocalBaseOperator.run_dbt_deps") +@patch("cosmos.operators.local.DbtLocalBaseOperator.run_subprocess") def test_operator_execute_deps_parameters( mock_build_and_run_cmd, mock_ensure_profile, mock_exception_handling, mock_store_compiled_sql ): - expected_call_kwargs = ["--profiles-dir", "/path/to", "--profile", "default", "--target", "dev"] - task = DbtLocalBaseOperator( + expected_call_kwargs = [ + "/usr/local/bin/dbt", + "deps", + "--profiles-dir", + "/path/to", + "--profile", + "default", + "--target", + "dev", + ] + task = DbtRunLocalOperator( profile_config=real_profile_config, task_id="my-task", project_dir=DBT_PROJ_DIR, @@ -465,4 +474,4 @@ def test_operator_execute_deps_parameters( ) mock_ensure_profile.return_value.__enter__.return_value = (Path("/path/to/profile"), {"ENV_VAR": "value"}) task.execute(context={"task_instance": MagicMock()}) - mock_build_and_run_cmd.assert_called_once_with(expected_call_kwargs, env=ANY, tmp_project_dir=ANY) + assert mock_build_and_run_cmd.call_args.kwargs["command"] == expected_call_kwargs From c63515ab4898600c8e4cc0a6c125a675bb7b710a Mon Sep 17 00:00:00 2001 From: Joppe Vos Date: Thu, 9 Nov 2023 17:18:25 +0100 Subject: [PATCH 16/16] Index on the first call in the list --- tests/operators/test_local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index b40b74654..2ccdfe1cf 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -474,4 +474,4 @@ def test_operator_execute_deps_parameters( ) mock_ensure_profile.return_value.__enter__.return_value = (Path("/path/to/profile"), {"ENV_VAR": "value"}) task.execute(context={"task_instance": MagicMock()}) - assert mock_build_and_run_cmd.call_args.kwargs["command"] == expected_call_kwargs + assert mock_build_and_run_cmd.call_args_list[0].kwargs["command"] == expected_call_kwargs