Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Snowflake encrypted private key environment variable #649

Merged
5 changes: 4 additions & 1 deletion cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from .redshift.user_pass import RedshiftUserPasswordProfileMapping
from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping
from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .snowflake.user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .snowflake.user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping
from .snowflake.user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .spark.thrift import SparkThriftProfileMapping
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
Expand All @@ -32,6 +33,7 @@
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
SnowflakeUserPasswordProfileMapping,
SnowflakeEncryptedPrivateKeyFilePemProfileMapping,
SnowflakeEncryptedPrivateKeyPemProfileMapping,
SnowflakePrivateKeyPemProfileMapping,
SparkThriftProfileMapping,
Expand Down Expand Up @@ -71,6 +73,7 @@ def get_automatic_profile_mapping(
"RedshiftUserPasswordProfileMapping",
"SnowflakeUserPasswordProfileMapping",
"SnowflakePrivateKeyPemProfileMapping",
"SnowflakeEncryptedPrivateKeyFilePemProfileMapping",
"SparkThriftProfileMapping",
"ExasolUserPasswordProfileMapping",
"TrinoLDAPProfileMapping",
Expand Down
4 changes: 3 additions & 1 deletion cosmos/profiles/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from .user_pass import SnowflakeUserPasswordProfileMapping
from .user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping
from .user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping

__all__ = [
"SnowflakeUserPasswordProfileMapping",
"SnowflakePrivateKeyPemProfileMapping",
"SnowflakeEncryptedPrivateKeyFilePemProfileMapping",
"SnowflakeEncryptedPrivateKeyPemProfileMapping",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key."
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection

Check warning on line 10 in cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py

View check run for this annotation

Codecov / codecov/patch

cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py#L10

Added line #L10 was not covered by tests


class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
"""

airflow_connection_type: str = "snowflake"
dbt_profile_type: str = "snowflake"
is_community: bool = True

required_fields = [
"account",
"user",
"database",
"warehouse",
"schema",
"private_key",
"private_key_passphrase",
]
secret_fields = [
"private_key",
"private_key_passphrase",
]
airflow_param_mapping = {
"account": "extra.account",
"user": "login",
"database": "extra.database",
"warehouse": "extra.warehouse",
"schema": "schema",
"role": "extra.role",
"private_key": "extra.private_key_content",
"private_key_passphrase": "password",
}

def can_claim_connection(self) -> bool:
# Make sure this isn't a private key path credential
result = super().can_claim_connection()
if result and self.conn.extra_dejson.get("private_key_file") is not None:
return False

Check warning on line 52 in cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py

View check run for this annotation

Codecov / codecov/patch

cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py#L52

Added line #L52 was not covered by tests
return result

@property
def conn(self) -> Connection:
"""
Snowflake can be odd because the fields used to be stored with keys in the format
'extra__snowflake__account', but now are stored as 'account'.

This standardizes the keys to be 'account', 'database', etc.
"""
conn = super().conn

conn_dejson = conn.extra_dejson

if conn_dejson.get("extra__snowflake__account"):
conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()}

conn.extra = json.dumps(conn_dejson)

return conn

@property
def profile(self) -> dict[str, Any | None]:
"Gets profile."
profile_vars = {
**self.mapped_params,
**self.profile_args,
"private_key": self.get_env_var_format("private_key"),
"private_key_passphrase": self.get_env_var_format("private_key_passphrase"),
}

# remove any null values
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
"Transform the account to the format <account>.<region> if it's not already."
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key."
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path."
from __future__ import annotations

import json
Expand All @@ -10,9 +10,9 @@
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
"""
Expand Down Expand Up @@ -44,6 +44,13 @@
"private_key_path": "extra.private_key_file",
}

def can_claim_connection(self) -> bool:
# Make sure this isn't a private key environmentvariable
result = super().can_claim_connection()
if result and self.conn.extra_dejson.get("private_key_content") is not None:
return False

Check warning on line 51 in cosmos/profiles/snowflake/user_encrypted_privatekey_file.py

View check run for this annotation

Codecov / codecov/patch

cosmos/profiles/snowflake/user_encrypted_privatekey_file.py#L51

Added line #L51 was not covered by tests
return result

@property
def conn(self) -> Connection:
"""
Expand Down
5 changes: 4 additions & 1 deletion cosmos/profiles/snowflake/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping):
def can_claim_connection(self) -> bool:
# Make sure this isn't a private key path credential
result = super().can_claim_connection()
if result and self.conn.extra_dejson.get("private_key_file") is not None:
if result and (
self.conn.extra_dejson.get("private_key_file") is not None
or self.conn.extra_dejson.get("private_key_content") is not None
):
return False
return result

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"Tests for the Snowflake user/private key profile."
"Tests for the Snowflake user/private key environmentvariable profile."

import json
from unittest.mock import patch
Expand Down Expand Up @@ -29,7 +29,7 @@ def mock_snowflake_conn(): # type: ignore
"region": "my_region",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_file": "path/to/private_key.p8",
"private_key_content": "my_private_key",
}
),
)
Expand All @@ -52,7 +52,7 @@ def test_connection_claiming() -> None:
"account": "my_account",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_file": "path/to/private_key.p8",
"private_key_content": "my_private_key",
}
),
}
Expand Down Expand Up @@ -130,8 +130,8 @@ def test_profile_args(
assert profile_mapping.profile == {
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"),
"schema": mock_snowflake_conn.schema,
"account": f"{mock_account}.{mock_region}",
"database": mock_snowflake_conn.extra_dejson.get("database"),
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_profile_args_overrides(
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"),
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"schema": mock_snowflake_conn.schema,
"account": f"{mock_account}.{mock_region}",
"database": "my_db_override",
Expand All @@ -178,6 +178,7 @@ def test_profile_env_vars(
mock_snowflake_conn.conn_id,
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": mock_snowflake_conn.extra_dejson.get("private_key_content"),
"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password,
}

Expand All @@ -197,7 +198,7 @@ def test_old_snowflake_format() -> None:
"extra__snowflake__account": "my_account",
"extra__snowflake__database": "my_database",
"extra__snowflake__warehouse": "my_warehouse",
"extra__snowflake__private_key_file": "path/to/private_key.p8",
"extra__snowflake__private_key_content": "my_private_key",
}
),
)
Expand All @@ -207,8 +208,8 @@ def test_old_snowflake_format() -> None:
assert profile_mapping.profile == {
"type": conn.conn_type,
"user": conn.login,
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": conn.extra_dejson.get("private_key_file"),
"schema": conn.schema,
"account": conn.extra_dejson.get("account"),
"database": conn.extra_dejson.get("database"),
Expand Down
Loading
Loading