Skip to content

Commit

Permalink
Support for Snowflake encrypted private key environment variable (ast…
Browse files Browse the repository at this point in the history
…ronomer#649)

Adds a snowflake mapping for encrypted private key using an environment variable

Closes: astronomer#632

Breaking Change?
This does rename the previous SnowflakeEncryptedPrivateKeyFilePemProfileMapping to SnowflakeEncryptedPrivateKeyFilePemProfileMapping but this makes it clearer as a new SnowflakeEncryptedPrivateKeyPemProfileMapping is added which supports the env variable. Also was only released as a pre-release change
  • Loading branch information
DanMawdsleyBA authored and arojasb3 committed Jul 14, 2024
1 parent a655e67 commit 8897f79
Show file tree
Hide file tree
Showing 7 changed files with 338 additions and 13 deletions.
5 changes: 4 additions & 1 deletion cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from .redshift.user_pass import RedshiftUserPasswordProfileMapping
from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping
from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .snowflake.user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .snowflake.user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping
from .snowflake.user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .spark.thrift import SparkThriftProfileMapping
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
Expand All @@ -32,6 +33,7 @@
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
SnowflakeUserPasswordProfileMapping,
SnowflakeEncryptedPrivateKeyFilePemProfileMapping,
SnowflakeEncryptedPrivateKeyPemProfileMapping,
SnowflakePrivateKeyPemProfileMapping,
SparkThriftProfileMapping,
Expand Down Expand Up @@ -71,6 +73,7 @@ def get_automatic_profile_mapping(
"RedshiftUserPasswordProfileMapping",
"SnowflakeUserPasswordProfileMapping",
"SnowflakePrivateKeyPemProfileMapping",
"SnowflakeEncryptedPrivateKeyFilePemProfileMapping",
"SparkThriftProfileMapping",
"ExasolUserPasswordProfileMapping",
"TrinoLDAPProfileMapping",
Expand Down
4 changes: 3 additions & 1 deletion cosmos/profiles/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from .user_pass import SnowflakeUserPasswordProfileMapping
from .user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping
from .user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping

__all__ = [
"SnowflakeUserPasswordProfileMapping",
"SnowflakePrivateKeyPemProfileMapping",
"SnowflakeEncryptedPrivateKeyFilePemProfileMapping",
"SnowflakeEncryptedPrivateKeyPemProfileMapping",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key."
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
"""

airflow_connection_type: str = "snowflake"
dbt_profile_type: str = "snowflake"
is_community: bool = True

required_fields = [
"account",
"user",
"database",
"warehouse",
"schema",
"private_key",
"private_key_passphrase",
]
secret_fields = [
"private_key",
"private_key_passphrase",
]
airflow_param_mapping = {
"account": "extra.account",
"user": "login",
"database": "extra.database",
"warehouse": "extra.warehouse",
"schema": "schema",
"role": "extra.role",
"private_key": "extra.private_key_content",
"private_key_passphrase": "password",
}

def can_claim_connection(self) -> bool:
# Make sure this isn't a private key path credential
result = super().can_claim_connection()
if result and self.conn.extra_dejson.get("private_key_file") is not None:
return False
return result

@property
def conn(self) -> Connection:
"""
Snowflake can be odd because the fields used to be stored with keys in the format
'extra__snowflake__account', but now are stored as 'account'.
This standardizes the keys to be 'account', 'database', etc.
"""
conn = super().conn

conn_dejson = conn.extra_dejson

if conn_dejson.get("extra__snowflake__account"):
conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()}

conn.extra = json.dumps(conn_dejson)

return conn

@property
def profile(self) -> dict[str, Any | None]:
"Gets profile."
profile_vars = {
**self.mapped_params,
**self.profile_args,
"private_key": self.get_env_var_format("private_key"),
"private_key_passphrase": self.get_env_var_format("private_key_passphrase"),
}

# remove any null values
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
"Transform the account to the format <account>.<region> if it's not already."
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key."
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path."
from __future__ import annotations

import json
Expand All @@ -10,9 +10,9 @@
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
"""
Expand Down Expand Up @@ -44,6 +44,13 @@ class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
"private_key_path": "extra.private_key_file",
}

def can_claim_connection(self) -> bool:
# Make sure this isn't a private key environmentvariable
result = super().can_claim_connection()
if result and self.conn.extra_dejson.get("private_key_content") is not None:
return False
return result

@property
def conn(self) -> Connection:
"""
Expand Down
5 changes: 4 additions & 1 deletion cosmos/profiles/snowflake/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping):
def can_claim_connection(self) -> bool:
# Make sure this isn't a private key path credential
result = super().can_claim_connection()
if result and self.conn.extra_dejson.get("private_key_file") is not None:
if result and (
self.conn.extra_dejson.get("private_key_file") is not None
or self.conn.extra_dejson.get("private_key_content") is not None
):
return False
return result

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"Tests for the Snowflake user/private key profile."
"Tests for the Snowflake user/private key environmentvariable profile."

import json
from unittest.mock import patch
Expand Down Expand Up @@ -29,7 +29,7 @@ def mock_snowflake_conn(): # type: ignore
"region": "my_region",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_file": "path/to/private_key.p8",
"private_key_content": "my_private_key",
}
),
)
Expand All @@ -52,7 +52,7 @@ def test_connection_claiming() -> None:
"account": "my_account",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_file": "path/to/private_key.p8",
"private_key_content": "my_private_key",
}
),
}
Expand Down Expand Up @@ -130,8 +130,8 @@ def test_profile_args(
assert profile_mapping.profile == {
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"),
"schema": mock_snowflake_conn.schema,
"account": f"{mock_account}.{mock_region}",
"database": mock_snowflake_conn.extra_dejson.get("database"),
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_profile_args_overrides(
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"),
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"schema": mock_snowflake_conn.schema,
"account": f"{mock_account}.{mock_region}",
"database": "my_db_override",
Expand All @@ -178,6 +178,7 @@ def test_profile_env_vars(
mock_snowflake_conn.conn_id,
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": mock_snowflake_conn.extra_dejson.get("private_key_content"),
"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password,
}

Expand All @@ -197,7 +198,7 @@ def test_old_snowflake_format() -> None:
"extra__snowflake__account": "my_account",
"extra__snowflake__database": "my_database",
"extra__snowflake__warehouse": "my_warehouse",
"extra__snowflake__private_key_file": "path/to/private_key.p8",
"extra__snowflake__private_key_content": "my_private_key",
}
),
)
Expand All @@ -207,8 +208,8 @@ def test_old_snowflake_format() -> None:
assert profile_mapping.profile == {
"type": conn.conn_type,
"user": conn.login,
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": conn.extra_dejson.get("private_key_file"),
"schema": conn.schema,
"account": conn.extra_dejson.get("account"),
"database": conn.extra_dejson.get("database"),
Expand Down
Loading

0 comments on commit 8897f79

Please sign in to comment.