Skip to content

Commit

Permalink
Support Snowflake encrypted private key path
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanstillfront committed Oct 17, 2023
1 parent 28268dc commit 40e25b8
Show file tree
Hide file tree
Showing 5 changed files with 309 additions and 1 deletion.
2 changes: 2 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .redshift.user_pass import RedshiftUserPasswordProfileMapping
from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping
from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .snowflake.user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .spark.thrift import SparkThriftProfileMapping
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
Expand All @@ -29,6 +30,7 @@
DatabricksTokenProfileMapping,
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
SnowflakeEncryptedPrivateKeyPemProfileMapping,
SnowflakeUserPasswordProfileMapping,
SnowflakePrivateKeyPemProfileMapping,
SparkThriftProfileMapping,
Expand Down
7 changes: 6 additions & 1 deletion cosmos/profiles/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,10 @@

from .user_pass import SnowflakeUserPasswordProfileMapping
from .user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping

__all__ = ["SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping"]
__all__ = [
"SnowflakeUserPasswordProfileMapping",
"SnowflakePrivateKeyPemProfileMapping",
"SnowflakeEncryptedPrivateKeyPemProfileMapping",
]
85 changes: 85 additions & 0 deletions cosmos/profiles/snowflake/user_encrypted_privatekey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key."
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
"""

airflow_connection_type: str = "snowflake"
dbt_profile_type: str = "snowflake"
is_community: bool = True

required_fields = [
"account",
"user",
"database",
"warehouse",
"schema",
"private_key_passphrase",
"private_key_path",
]
secret_fields = [
"private_key_passphrase",
]
airflow_param_mapping = {
"account": "extra.account",
"user": "login",
"database": "extra.database",
"warehouse": "extra.warehouse",
"schema": "schema",
"role": "extra.role",
"private_key_passphrase": "password",
"private_key_path": "extra.private_key_file",
}

@property
def conn(self) -> Connection:
"""
Snowflake can be odd because the fields used to be stored with keys in the format
'extra__snowflake__account', but now are stored as 'account'.
This standardizes the keys to be 'account', 'database', etc.
"""
conn = super().conn

conn_dejson = conn.extra_dejson

if conn_dejson.get("extra__snowflake__account"):
conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()}

conn.extra = json.dumps(conn_dejson)

return conn

@property
def profile(self) -> dict[str, Any | None]:
"Gets profile."
profile_vars = {
**self.mapped_params,
**self.profile_args,
# private_key_passphrase should always get set as env var
"private_key_passphrase": self.get_env_var_format("private_key_passphrase"),
}

# remove any null values
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
"Transform the account to the format <account>.<region> if it's not already."
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
7 changes: 7 additions & 0 deletions cosmos/profiles/snowflake/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping):
"role": "extra.role",
}

def can_claim_connection(self) -> bool:
# Make sure this isn't a private key path credential
result = super().can_claim_connection()
if result and self.conn.extra_dejson.get("private_key_file") is not None:
return False
return result

@property
def conn(self) -> Connection:
"""
Expand Down
209 changes: 209 additions & 0 deletions tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"Tests for the Snowflake user/private key profile."

import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.snowflake import (
SnowflakeEncryptedPrivateKeyPemProfileMapping,
)


@pytest.fixture()
def mock_snowflake_conn(): # type: ignore
"""
Sets the connection as an environment variable.
"""
conn = Connection(
conn_id="my_snowflake_pk_connection",
conn_type="snowflake",
login="my_user",
schema="my_schema",
password="secret",
extra=json.dumps(
{
"account": "my_account",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_file": "path/to/private_key.p8",
}
),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


def test_connection_claiming() -> None:
"""
Tests that the Snowflake profile mapping claims the correct connection type.
"""
potential_values = {
"conn_type": "snowflake",
"login": "my_user",
"schema": "my_database",
"password": "secret",
"extra": json.dumps(
{
"account": "my_account",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_file": "path/to/private_key.p8",
}
),
}

# if we're missing any of the values, it shouldn't claim
for key in potential_values:
values = potential_values.copy()
del values[key]
conn = Connection(**values) # type: ignore

print("testing with", values)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(
conn,
)
assert not profile_mapping.can_claim_connection()

# test when we're missing the account
conn = Connection(**potential_values) # type: ignore
conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}'
print("testing with", conn.extra)
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# test when we're missing the database
conn = Connection(**potential_values) # type: ignore
conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}'
print("testing with", conn.extra)
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# test when we're missing the warehouse
conn = Connection(**potential_values) # type: ignore
conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}'
print("testing with", conn.extra)
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# if we have them all, it should claim
conn = Connection(**potential_values) # type: ignore
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn)
assert profile_mapping.can_claim_connection()


def test_profile_mapping_selected(
mock_snowflake_conn: Connection,
) -> None:
"""
Tests that the correct profile mapping is selected.
"""
profile_mapping = get_automatic_profile_mapping(
mock_snowflake_conn.conn_id,
)
assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyPemProfileMapping)


def test_profile_args(
mock_snowflake_conn: Connection,
) -> None:
"""
Tests that the profile values get set correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_snowflake_conn.conn_id,
)

assert profile_mapping.profile == {
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"),
"schema": mock_snowflake_conn.schema,
"account": mock_snowflake_conn.extra_dejson.get("account"),
"database": mock_snowflake_conn.extra_dejson.get("database"),
"warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"),
}


def test_profile_args_overrides(
mock_snowflake_conn: Connection,
) -> None:
"""
Tests that you can override the profile values.
"""
profile_mapping = get_automatic_profile_mapping(
mock_snowflake_conn.conn_id,
profile_args={"database": "my_db_override"},
)
assert profile_mapping.profile_args == {
"database": "my_db_override",
}

assert profile_mapping.profile == {
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"),
"schema": mock_snowflake_conn.schema,
"account": mock_snowflake_conn.extra_dejson.get("account"),
"database": "my_db_override",
"warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"),
}


def test_profile_env_vars(
mock_snowflake_conn: Connection,
) -> None:
"""
Tests that the environment variables get set correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_snowflake_conn.conn_id,
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password,
}


def test_old_snowflake_format() -> None:
"""
Tests that the old format still works.
"""
conn = Connection(
conn_id="my_snowflake_connection",
conn_type="snowflake",
login="my_user",
schema="my_schema",
password="secret",
extra=json.dumps(
{
"extra__snowflake__account": "my_account",
"extra__snowflake__database": "my_database",
"extra__snowflake__warehouse": "my_warehouse",
"extra__snowflake__private_key_file": "path/to/private_key.p8",
}
),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn)
assert profile_mapping.profile == {
"type": conn.conn_type,
"user": conn.login,
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": conn.extra_dejson.get("private_key_file"),
"schema": conn.schema,
"account": conn.extra_dejson.get("account"),
"database": conn.extra_dejson.get("database"),
"warehouse": conn.extra_dejson.get("warehouse"),
}

0 comments on commit 40e25b8

Please sign in to comment.