From 2ec33b43f139fe07167d4cf595dfee13f3548b12 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 10:56:05 +0000 Subject: [PATCH] Support for Snowflake encrypted private key environment variable (#649) Adds a snowflake mapping for encrypted private key using an environment variable Closes: #632 Breaking Change? This does rename the previous SnowflakeEncryptedPrivateKeyPemProfileMapping to SnowflakeEncryptedPrivateKeyFilePemProfileMapping but this makes it clearer as a new SnowflakeEncryptedPrivateKeyPemProfileMapping is added which supports the env variable. Also, this was only released as a pre-release change --- cosmos/profiles/__init__.py | 5 +- cosmos/profiles/snowflake/__init__.py | 4 +- .../user_encrypted_privatekey_env_variable.py | 93 ++++++++ ...y.py => user_encrypted_privatekey_file.py} | 13 +- cosmos/profiles/snowflake/user_pass.py | 5 +- ...user_encrypted_privatekey_env_variable.py} | 15 +- ...nowflake_user_encrypted_privatekey_file.py | 216 ++++++++++++++++++ 7 files changed, 338 insertions(+), 13 deletions(-) create mode 100644 cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py rename cosmos/profiles/snowflake/{user_encrypted_privatekey.py => user_encrypted_privatekey_file.py} (85%) rename tests/profiles/snowflake/{test_snowflake_user_encrypted_privatekey.py => test_snowflake_user_encrypted_privatekey_env_variable.py} (92%) create mode 100644 tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index 47c7309ab..8280cd950 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -16,7 +16,8 @@ from .redshift.user_pass import RedshiftUserPasswordProfileMapping from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping -from .snowflake.user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping +from 
.snowflake.user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping +from .snowflake.user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping from .spark.thrift import SparkThriftProfileMapping from .trino.certificate import TrinoCertificateProfileMapping from .trino.jwt import TrinoJWTProfileMapping @@ -32,6 +33,7 @@ PostgresUserPasswordProfileMapping, RedshiftUserPasswordProfileMapping, SnowflakeUserPasswordProfileMapping, + SnowflakeEncryptedPrivateKeyFilePemProfileMapping, SnowflakeEncryptedPrivateKeyPemProfileMapping, SnowflakePrivateKeyPemProfileMapping, SparkThriftProfileMapping, @@ -71,6 +73,7 @@ def get_automatic_profile_mapping( "RedshiftUserPasswordProfileMapping", "SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping", + "SnowflakeEncryptedPrivateKeyFilePemProfileMapping", "SparkThriftProfileMapping", "ExasolUserPasswordProfileMapping", "TrinoLDAPProfileMapping", diff --git a/cosmos/profiles/snowflake/__init__.py b/cosmos/profiles/snowflake/__init__.py index 26c3fb595..fdf323a76 100644 --- a/cosmos/profiles/snowflake/__init__.py +++ b/cosmos/profiles/snowflake/__init__.py @@ -2,10 +2,12 @@ from .user_pass import SnowflakeUserPasswordProfileMapping from .user_privatekey import SnowflakePrivateKeyPemProfileMapping -from .user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping +from .user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping +from .user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping __all__ = [ "SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping", + "SnowflakeEncryptedPrivateKeyFilePemProfileMapping", "SnowflakeEncryptedPrivateKeyPemProfileMapping", ] diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py new file mode 100644 index 
000000000..fecfa97fe --- /dev/null +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -0,0 +1,93 @@ +"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key." +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +from ..base import BaseProfileMapping + +if TYPE_CHECKING: + from airflow.models import Connection + + +class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): + """ + Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. + https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication + https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html + """ + + airflow_connection_type: str = "snowflake" + dbt_profile_type: str = "snowflake" + is_community: bool = True + + required_fields = [ + "account", + "user", + "database", + "warehouse", + "schema", + "private_key", + "private_key_passphrase", + ] + secret_fields = [ + "private_key", + "private_key_passphrase", + ] + airflow_param_mapping = { + "account": "extra.account", + "user": "login", + "database": "extra.database", + "warehouse": "extra.warehouse", + "schema": "schema", + "role": "extra.role", + "private_key": "extra.private_key_content", + "private_key_passphrase": "password", + } + + def can_claim_connection(self) -> bool: + # Make sure this isn't a private key path credential + result = super().can_claim_connection() + if result and self.conn.extra_dejson.get("private_key_file") is not None: + return False + return result + + @property + def conn(self) -> Connection: + """ + Snowflake can be odd because the fields used to be stored with keys in the format + 'extra__snowflake__account', but now are stored as 'account'. + + This standardizes the keys to be 'account', 'database', etc. 
+ """ + conn = super().conn + + conn_dejson = conn.extra_dejson + + if conn_dejson.get("extra__snowflake__account"): + conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()} + + conn.extra = json.dumps(conn_dejson) + + return conn + + @property + def profile(self) -> dict[str, Any | None]: + "Gets profile." + profile_vars = { + **self.mapped_params, + **self.profile_args, + "private_key": self.get_env_var_format("private_key"), + "private_key_passphrase": self.get_env_var_format("private_key_passphrase"), + } + + # remove any null values + return self.filter_null(profile_vars) + + def transform_account(self, account: str) -> str: + "Transform the account to the format . if it's not already." + region = self.conn.extra_dejson.get("region") + if region and region not in account: + account = f"{account}.{region}" + + return str(account) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py similarity index 85% rename from cosmos/profiles/snowflake/user_encrypted_privatekey.py rename to cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index 0623598be..6831cbd28 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -1,4 +1,4 @@ -"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key." +"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path." from __future__ import annotations import json @@ -10,9 +10,9 @@ from airflow.models import Connection -class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): +class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): """ - Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. + Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path. 
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html """ @@ -44,6 +44,13 @@ class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): "private_key_path": "extra.private_key_file", } + def can_claim_connection(self) -> bool: + # Make sure this isn't a private key environmentvariable + result = super().can_claim_connection() + if result and self.conn.extra_dejson.get("private_key_content") is not None: + return False + return result + @property def conn(self) -> Connection: """ diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index 2e1025a2c..fa634d1a2 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -44,7 +44,10 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key path credential result = super().can_claim_connection() - if result and self.conn.extra_dejson.get("private_key_file") is not None: + if result and ( + self.conn.extra_dejson.get("private_key_file") is not None + or self.conn.extra_dejson.get("private_key_content") is not None + ): return False return result diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py similarity index 92% rename from tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py rename to tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index b61b85094..2c7515f72 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -1,4 +1,4 @@ -"Tests for the Snowflake user/private key profile." 
+"Tests for the Snowflake user/private key environmentvariable profile." import json from unittest.mock import patch @@ -29,7 +29,7 @@ def mock_snowflake_conn(): # type: ignore "region": "my_region", "database": "my_database", "warehouse": "my_warehouse", - "private_key_file": "path/to/private_key.p8", + "private_key_content": "my_private_key", } ), ) @@ -52,7 +52,7 @@ def test_connection_claiming() -> None: "account": "my_account", "database": "my_database", "warehouse": "my_warehouse", - "private_key_file": "path/to/private_key.p8", + "private_key_content": "my_private_key", } ), } @@ -130,8 +130,8 @@ def test_profile_args( assert profile_mapping.profile == { "type": mock_snowflake_conn.conn_type, "user": mock_snowflake_conn.login, + "private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}", "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", - "private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"), "schema": mock_snowflake_conn.schema, "account": f"{mock_account}.{mock_region}", "database": mock_snowflake_conn.extra_dejson.get("database"), @@ -160,7 +160,7 @@ def test_profile_args_overrides( "type": mock_snowflake_conn.conn_type, "user": mock_snowflake_conn.login, "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", - "private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"), + "private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}", "schema": mock_snowflake_conn.schema, "account": f"{mock_account}.{mock_region}", "database": "my_db_override", @@ -178,6 +178,7 @@ def test_profile_env_vars( mock_snowflake_conn.conn_id, ) assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": mock_snowflake_conn.extra_dejson.get("private_key_content"), "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password, } @@ -197,7 +198,7 @@ def test_old_snowflake_format() -> None: "extra__snowflake__account": 
"my_account", "extra__snowflake__database": "my_database", "extra__snowflake__warehouse": "my_warehouse", - "extra__snowflake__private_key_file": "path/to/private_key.p8", + "extra__snowflake__private_key_content": "my_private_key", } ), ) @@ -207,8 +208,8 @@ def test_old_snowflake_format() -> None: assert profile_mapping.profile == { "type": conn.conn_type, "user": conn.login, + "private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}", "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", - "private_key_path": conn.extra_dejson.get("private_key_file"), "schema": conn.schema, "account": conn.extra_dejson.get("account"), "database": conn.extra_dejson.get("database"), diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py new file mode 100644 index 000000000..d8c3aedcf --- /dev/null +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py @@ -0,0 +1,216 @@ +"Tests for the Snowflake user/private key file profile." + +import json +from unittest.mock import patch + +import pytest +from airflow.models.connection import Connection + +from cosmos.profiles import get_automatic_profile_mapping +from cosmos.profiles.snowflake import ( + SnowflakeEncryptedPrivateKeyFilePemProfileMapping, +) + + +@pytest.fixture() +def mock_snowflake_conn(): # type: ignore + """ + Sets the connection as an environment variable. 
+ """ + conn = Connection( + conn_id="my_snowflake_pk_connection", + conn_type="snowflake", + login="my_user", + schema="my_schema", + password="secret", + extra=json.dumps( + { + "account": "my_account", + "region": "my_region", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_file": "path/to/private_key.p8", + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +def test_connection_claiming() -> None: + """ + Tests that the Snowflake profile mapping claims the correct connection type. + """ + potential_values = { + "conn_type": "snowflake", + "login": "my_user", + "schema": "my_database", + "password": "secret", + "extra": json.dumps( + { + "account": "my_account", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_file": "path/to/private_key.p8", + } + ), + } + + # if we're missing any of the values, it shouldn't claim + for key in potential_values: + values = potential_values.copy() + del values[key] + conn = Connection(**values) # type: ignore + + print("testing with", values) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping( + conn, + ) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the account + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the database + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' + 
print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the warehouse + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # if we have them all, it should claim + conn = Connection(**potential_values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert profile_mapping.can_claim_connection() + + +def test_profile_mapping_selected( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the correct profile mapping is selected. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyFilePemProfileMapping) + + +def test_profile_args( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the profile values get set correctly. 
+ """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + + mock_account = mock_snowflake_conn.extra_dejson.get("account") + mock_region = mock_snowflake_conn.extra_dejson.get("region") + + assert profile_mapping.profile == { + "type": mock_snowflake_conn.conn_type, + "user": mock_snowflake_conn.login, + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"), + "schema": mock_snowflake_conn.schema, + "account": f"{mock_account}.{mock_region}", + "database": mock_snowflake_conn.extra_dejson.get("database"), + "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + } + + +def test_profile_args_overrides( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that you can override the profile values. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + profile_args={"database": "my_db_override"}, + ) + assert profile_mapping.profile_args == { + "database": "my_db_override", + } + + mock_account = mock_snowflake_conn.extra_dejson.get("account") + mock_region = mock_snowflake_conn.extra_dejson.get("region") + + assert profile_mapping.profile == { + "type": mock_snowflake_conn.conn_type, + "user": mock_snowflake_conn.login, + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"), + "schema": mock_snowflake_conn.schema, + "account": f"{mock_account}.{mock_region}", + "database": "my_db_override", + "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + } + + +def test_profile_env_vars( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the environment variables get set correctly. 
+ """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password, + } + + +def test_old_snowflake_format() -> None: + """ + Tests that the old format still works. + """ + conn = Connection( + conn_id="my_snowflake_connection", + conn_type="snowflake", + login="my_user", + schema="my_schema", + password="secret", + extra=json.dumps( + { + "extra__snowflake__account": "my_account", + "extra__snowflake__database": "my_database", + "extra__snowflake__warehouse": "my_warehouse", + "extra__snowflake__private_key_file": "path/to/private_key.p8", + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert profile_mapping.profile == { + "type": conn.conn_type, + "user": conn.login, + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "private_key_path": conn.extra_dejson.get("private_key_file"), + "schema": conn.schema, + "account": conn.extra_dejson.get("account"), + "database": conn.extra_dejson.get("database"), + "warehouse": conn.extra_dejson.get("warehouse"), + }