Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix installing deps when using profile_mapping & ExecutionMode.LOCAL #659

Merged
merged 17 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,11 @@ def run_command(
tmp_project_dir,
)

# if we need to install deps, do so
if self.install_deps:
self.run_subprocess(
command=[self.dbt_executable_path, "deps"],
env=env,
output_encoding=self.output_encoding,
cwd=tmp_project_dir,
)
with self.profile_config.ensure_profile() as (profile_path, env_vars):
with self.profile_config.ensure_profile() as profile_values:
(profile_path, env_vars) = profile_values
env.update(env_vars)
full_cmd = cmd + [

flags = [
"--profiles-dir",
str(profile_path.parent),
"--profile",
Expand All @@ -225,6 +219,18 @@ def run_command(
self.profile_config.target_name,
]

if self.install_deps:
    # Run `dbt deps` with the same profile flags as the main command.
    # Previously `deps_command` was assembled with the flags but the bare
    # [dbt, "deps"] list was what actually got executed, so the flags were
    # silently dropped.
    deps_command = [self.dbt_executable_path, "deps", *flags]
    self.run_subprocess(
        command=deps_command,
        env=env,
        output_encoding=self.output_encoding,
        cwd=tmp_project_dir,
    )

full_cmd = cmd + flags

logger.info("Trying to run the command:\n %s\nFrom %s", full_cmd, tmp_project_dir)
logger.info("Using environment variables keys: %s", env.keys())
result = self.run_subprocess(
Expand Down
5 changes: 4 additions & 1 deletion cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from .redshift.user_pass import RedshiftUserPasswordProfileMapping
from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping
from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .snowflake.user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .snowflake.user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping
from .snowflake.user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .spark.thrift import SparkThriftProfileMapping
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
Expand All @@ -32,6 +33,7 @@
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
SnowflakeUserPasswordProfileMapping,
SnowflakeEncryptedPrivateKeyFilePemProfileMapping,
SnowflakeEncryptedPrivateKeyPemProfileMapping,
SnowflakePrivateKeyPemProfileMapping,
SparkThriftProfileMapping,
Expand Down Expand Up @@ -71,6 +73,7 @@ def get_automatic_profile_mapping(
"RedshiftUserPasswordProfileMapping",
"SnowflakeUserPasswordProfileMapping",
"SnowflakePrivateKeyPemProfileMapping",
"SnowflakeEncryptedPrivateKeyFilePemProfileMapping",
"SparkThriftProfileMapping",
"ExasolUserPasswordProfileMapping",
"TrinoLDAPProfileMapping",
Expand Down
4 changes: 3 additions & 1 deletion cosmos/profiles/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from .user_pass import SnowflakeUserPasswordProfileMapping
from .user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping
from .user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping

__all__ = [
"SnowflakeUserPasswordProfileMapping",
"SnowflakePrivateKeyPemProfileMapping",
"SnowflakeEncryptedPrivateKeyFilePemProfileMapping",
"SnowflakeEncryptedPrivateKeyPemProfileMapping",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key passed as an environment variable."
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
    """
    Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.

    The key is read from the ``private_key_content`` extra and the passphrase from the
    connection password; both are exposed to dbt through environment-variable references.

    https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
    https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
    """

    airflow_connection_type: str = "snowflake"
    dbt_profile_type: str = "snowflake"
    is_community: bool = True

    required_fields = [
        "account",
        "user",
        "database",
        "warehouse",
        "schema",
        "private_key",
        "private_key_passphrase",
    ]
    secret_fields = [
        "private_key",
        "private_key_passphrase",
    ]
    airflow_param_mapping = {
        "account": "extra.account",
        "user": "login",
        "database": "extra.database",
        "warehouse": "extra.warehouse",
        "schema": "schema",
        "role": "extra.role",
        "private_key": "extra.private_key_content",
        "private_key_passphrase": "password",
    }

    def can_claim_connection(self) -> bool:
        """
        Return the base-class verdict, except that a connection carrying a
        ``private_key_file`` extra is never claimed by this mapping.
        """
        claimable = super().can_claim_connection()
        if not claimable:
            return claimable
        # A key *file* on the connection means it is a private key path credential.
        if self.conn.extra_dejson.get("private_key_file") is not None:
            return False
        return claimable

    @property
    def conn(self) -> Connection:
        """
        Return the Airflow connection with its extras normalised.

        Snowflake extras used to be stored with keys in the format
        'extra__snowflake__account', but now are stored as 'account'.
        This standardizes the keys to be 'account', 'database', etc.
        """
        connection = super().conn
        extras = connection.extra_dejson

        # Only rewrite the keys when the legacy prefixed account key is present.
        if extras.get("extra__snowflake__account"):
            extras = {name.replace("extra__snowflake__", ""): val for name, val in extras.items()}

        connection.extra = json.dumps(extras)
        return connection

    @property
    def profile(self) -> dict[str, Any | None]:
        """
        Build the dbt profile: mapped params, then user-supplied profile args,
        with the secret fields replaced by env-var references; null values are dropped.
        """
        profile_vars = dict(self.mapped_params)
        profile_vars.update(self.profile_args)
        # Secrets are always emitted as env-var references, overriding anything above.
        for secret_name in ("private_key", "private_key_passphrase"):
            profile_vars[secret_name] = self.get_env_var_format(secret_name)

        return self.filter_null(profile_vars)

    def transform_account(self, account: str) -> str:
        """
        Transform the account to the format <account>.<region> if it's not already.

        The region is taken from the connection's ``region`` extra; when it is missing
        or already appears in ``account``, the account is returned unchanged (as ``str``).
        """
        region = self.conn.extra_dejson.get("region")
        needs_region_suffix = bool(region) and region not in account
        return f"{account}.{region}" if needs_region_suffix else str(account)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key."
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path."
from __future__ import annotations

import json
Expand All @@ -10,9 +10,9 @@
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
"""
Expand Down Expand Up @@ -44,6 +44,13 @@ class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
"private_key_path": "extra.private_key_file",
}

def can_claim_connection(self) -> bool:
    """
    Return the base-class verdict, except that a connection carrying a
    ``private_key_content`` extra is never claimed by this mapping.
    """
    claimable = super().can_claim_connection()
    if not claimable:
        return claimable
    # Inline key content means this is a private key environment variable credential.
    if self.conn.extra_dejson.get("private_key_content") is not None:
        return False
    return claimable

@property
def conn(self) -> Connection:
"""
Expand Down
5 changes: 4 additions & 1 deletion cosmos/profiles/snowflake/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping):
def can_claim_connection(self) -> bool:
    """
    Return the base-class verdict, except that a connection configured with
    a private key is never claimed by the user/password mapping.

    A connection counts as private-key based when its extras contain either
    ``private_key_file`` (key path) or ``private_key_content`` (inline key).
    """
    result = super().can_claim_connection()
    # Reject both private-key variants; the superseded check only looked at
    # ``private_key_file`` and must not remain alongside the combined one.
    if result and (
        self.conn.extra_dejson.get("private_key_file") is not None
        or self.conn.extra_dejson.get("private_key_content") is not None
    ):
        return False
    return result

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"Tests for the Snowflake user/private key profile."
"Tests for the Snowflake user/private key environment variable profile."

import json
from unittest.mock import patch
Expand Down Expand Up @@ -29,7 +29,7 @@ def mock_snowflake_conn(): # type: ignore
"region": "my_region",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_file": "path/to/private_key.p8",
"private_key_content": "my_private_key",
}
),
)
Expand All @@ -52,7 +52,7 @@ def test_connection_claiming() -> None:
"account": "my_account",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_file": "path/to/private_key.p8",
"private_key_content": "my_private_key",
}
),
}
Expand Down Expand Up @@ -130,8 +130,8 @@ def test_profile_args(
assert profile_mapping.profile == {
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"),
"schema": mock_snowflake_conn.schema,
"account": f"{mock_account}.{mock_region}",
"database": mock_snowflake_conn.extra_dejson.get("database"),
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_profile_args_overrides(
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"),
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"schema": mock_snowflake_conn.schema,
"account": f"{mock_account}.{mock_region}",
"database": "my_db_override",
Expand All @@ -178,6 +178,7 @@ def test_profile_env_vars(
mock_snowflake_conn.conn_id,
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": mock_snowflake_conn.extra_dejson.get("private_key_content"),
"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password,
}

Expand All @@ -197,7 +198,7 @@ def test_old_snowflake_format() -> None:
"extra__snowflake__account": "my_account",
"extra__snowflake__database": "my_database",
"extra__snowflake__warehouse": "my_warehouse",
"extra__snowflake__private_key_file": "path/to/private_key.p8",
"extra__snowflake__private_key_content": "my_private_key",
}
),
)
Expand All @@ -207,8 +208,8 @@ def test_old_snowflake_format() -> None:
assert profile_mapping.profile == {
"type": conn.conn_type,
"user": conn.login,
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": conn.extra_dejson.get("private_key_file"),
"schema": conn.schema,
"account": conn.extra_dejson.get("account"),
"database": conn.extra_dejson.get("database"),
Expand Down
Loading
Loading