Skip to content

Commit

Permalink
Intial change for Snowflake encrypted private key
Browse files Browse the repository at this point in the history
  • Loading branch information
DanMawdsleyBA committed Nov 4, 2023
1 parent b64eb9a commit 529b0c5
Show file tree
Hide file tree
Showing 6 changed files with 313 additions and 6 deletions.
5 changes: 4 additions & 1 deletion cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from .redshift.user_pass import RedshiftUserPasswordProfileMapping
from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping
from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .snowflake.user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .snowflake.user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping
from .snowflake.user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .spark.thrift import SparkThriftProfileMapping
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
Expand All @@ -32,6 +33,7 @@
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
SnowflakeUserPasswordProfileMapping,
SnowflakeEncryptedPrivateKeyFilePemProfileMapping,
SnowflakeEncryptedPrivateKeyPemProfileMapping,
SnowflakePrivateKeyPemProfileMapping,
SparkThriftProfileMapping,
Expand Down Expand Up @@ -71,6 +73,7 @@ def get_automatic_profile_mapping(
"RedshiftUserPasswordProfileMapping",
"SnowflakeUserPasswordProfileMapping",
"SnowflakePrivateKeyPemProfileMapping",
"SnowflakeEncryptedPrivateKeyFilePemProfileMapping",
"SparkThriftProfileMapping",
"ExasolUserPasswordProfileMapping",
"TrinoLDAPProfileMapping",
Expand Down
4 changes: 3 additions & 1 deletion cosmos/profiles/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from .user_pass import SnowflakeUserPasswordProfileMapping
from .user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping
from .user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping

__all__ = [
"SnowflakeUserPasswordProfileMapping",
"SnowflakePrivateKeyPemProfileMapping",
"SnowflakeEncryptedPrivateKeyFilePemProfileMapping",
"SnowflakeEncryptedPrivateKeyPemProfileMapping",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key."
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
"""

airflow_connection_type: str = "snowflake"
dbt_profile_type: str = "snowflake"
is_community: bool = True

required_fields = [
"account",
"user",
"database",
"warehouse",
"schema",
"private_key",
"private_key_passphrase",
]
secret_fields = [
"private_key",
"private_key_passphrase",
]
airflow_param_mapping = {
"account": "extra.account",
"user": "login",
"database": "extra.database",
"warehouse": "extra.warehouse",
"schema": "schema",
"role": "extra.role",
"private_key": "extra.private_key_content",
"private_key_passphrase": "password",
}

@property
def conn(self) -> Connection:
"""
Snowflake can be odd because the fields used to be stored with keys in the format
'extra__snowflake__account', but now are stored as 'account'.
This standardizes the keys to be 'account', 'database', etc.
"""
conn = super().conn

conn_dejson = conn.extra_dejson

if conn_dejson.get("extra__snowflake__account"):
conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()}

conn.extra = json.dumps(conn_dejson)

return conn

@property
def profile(self) -> dict[str, Any | None]:
"Gets profile."
profile_vars = {
**self.mapped_params,
**self.profile_args,
"private_key": self.get_env_var_format("private_key"),
"private_key_passphrase": self.get_env_var_format("private_key_passphrase"),
}

# remove any null values
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
"Transform the account to the format <account>.<region> if it's not already."
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key."
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path."
from __future__ import annotations

import json
Expand All @@ -10,9 +10,9 @@
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
"Tests for the Snowflake user/private key file profile."

import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.snowflake import (
SnowflakeEncryptedPrivateKeyFilePemProfileMapping,
)


@pytest.fixture()
def mock_snowflake_conn(): # type: ignore
"""
Sets the connection as an environment variable.
"""
conn = Connection(
conn_id="my_snowflake_pk_connection",
conn_type="snowflake",
login="my_user",
schema="my_schema",
password="secret",
extra=json.dumps(
{
"account": "my_account",
"region": "my_region",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_content": "my_private_key",
}
),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


def test_connection_claiming() -> None:
"""
Tests that the Snowflake profile mapping claims the correct connection type.
"""
potential_values = {
"conn_type": "snowflake",
"login": "my_user",
"schema": "my_database",
"password": "secret",
"extra": json.dumps(
{
"account": "my_account",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_content": "my_private_key",
}
),
}

# if we're missing any of the values, it shouldn't claim
for key in potential_values:
values = potential_values.copy()
del values[key]
conn = Connection(**values) # type: ignore

print("testing with", values)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(
conn,
)
assert not profile_mapping.can_claim_connection()

# test when we're missing the account
conn = Connection(**potential_values) # type: ignore
conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}'
print("testing with", conn.extra)
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# test when we're missing the database
conn = Connection(**potential_values) # type: ignore
conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}'
print("testing with", conn.extra)
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# test when we're missing the warehouse
conn = Connection(**potential_values) # type: ignore
conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}'
print("testing with", conn.extra)
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# if we have them all, it should claim
conn = Connection(**potential_values) # type: ignore
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn)
assert profile_mapping.can_claim_connection()


def test_profile_mapping_selected(
mock_snowflake_conn: Connection,
) -> None:
"""
Tests that the correct profile mapping is selected.
"""
profile_mapping = get_automatic_profile_mapping(
mock_snowflake_conn.conn_id,
)
assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyFilePemProfileMapping)


def test_profile_args(
mock_snowflake_conn: Connection,
) -> None:
"""
Tests that the profile values get set correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_snowflake_conn.conn_id,
)

mock_account = mock_snowflake_conn.extra_dejson.get("account")
mock_region = mock_snowflake_conn.extra_dejson.get("region")

assert profile_mapping.profile == {
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"schema": mock_snowflake_conn.schema,
"account": f"{mock_account}.{mock_region}",
"database": mock_snowflake_conn.extra_dejson.get("database"),
"warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"),
}


def test_profile_args_overrides(
mock_snowflake_conn: Connection,
) -> None:
"""
Tests that you can override the profile values.
"""
profile_mapping = get_automatic_profile_mapping(
mock_snowflake_conn.conn_id,
profile_args={"database": "my_db_override"},
)
assert profile_mapping.profile_args == {
"database": "my_db_override",
}

mock_account = mock_snowflake_conn.extra_dejson.get("account")
mock_region = mock_snowflake_conn.extra_dejson.get("region")

assert profile_mapping.profile == {
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"schema": mock_snowflake_conn.schema,
"account": f"{mock_account}.{mock_region}",
"database": "my_db_override",
"warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"),
}


def test_profile_env_vars(
mock_snowflake_conn: Connection,
) -> None:
"""
Tests that the environment variables get set correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_snowflake_conn.conn_id,
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password,
}


def test_old_snowflake_format() -> None:
"""
Tests that the old format still works.
"""
conn = Connection(
conn_id="my_snowflake_connection",
conn_type="snowflake",
login="my_user",
schema="my_schema",
password="secret",
extra=json.dumps(
{
"extra__snowflake__account": "my_account",
"extra__snowflake__database": "my_database",
"extra__snowflake__warehouse": "my_warehouse",
"extra__snowflake__private_key_content": "my_private_key",
}
),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn)
assert profile_mapping.profile == {
"type": conn.conn_type,
"user": conn.login,
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"schema": conn.schema,
"account": conn.extra_dejson.get("account"),
"database": conn.extra_dejson.get("database"),
"warehouse": conn.extra_dejson.get("warehouse"),
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"Tests for the Snowflake user/private key profile."
"Tests for the Snowflake user/private key file profile."

import json
from unittest.mock import patch
Expand Down

0 comments on commit 529b0c5

Please sign in to comment.