Skip to content

Commit

Permalink
Prevent overriding dbt profile fields with profile args of "type" or …
Browse files Browse the repository at this point in the history
…"method" (astronomer#702)

The issue is described in astronomer#696 which was discovered when a user was
creating a ProfileConfig with
`GoogleCloudServiceAccountDictProfileMapping(profile_args={"method":
"service-account"})` which was overriding the dbt profile method:

https://github.com/astronomer/astronomer-cosmos/blob/24aa38e528e299ef51ca6baf32f5a6185887d432/cosmos/profiles/bigquery/service_account_keyfile_dict.py#L21

when the profile args are mapped to the created profile below:


https://github.com/astronomer/astronomer-cosmos/blob/24aa38e528e299ef51ca6baf32f5a6185887d432/cosmos/profiles/bigquery/service_account_keyfile_dict.py#L42-L52

This is not an issue with the profile mapping example above and could
happen with any profile mapping by changing the "type" from
`dbt_profile_type` or "method" (if used) from `dbt_profile_method` in
the class.

The fix in this PR is to not allow args with "type" or "method" that are
different from the class variables in `profile_args`.

I think this is better than logging a warning because if either of those
fields are different the dbt run with the created profile will fail
anyways. This also allows backwards compatibility in the case users have
these already set in their profile args and it matches the class
variables.

Closes astronomer#696
  • Loading branch information
jbandoro authored Nov 23, 2023
1 parent bc8a309 commit 8f7a04b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
31 changes: 27 additions & 4 deletions cosmos/profiles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
if TYPE_CHECKING:
from airflow.models import Connection

DBT_PROFILE_TYPE_FIELD = "type"
DBT_PROFILE_METHOD_FIELD = "method"

logger = get_logger(__name__)


Expand All @@ -41,6 +44,26 @@ class BaseProfileMapping(ABC):
def __init__(self, conn_id: str, profile_args: dict[str, Any] | None = None):
self.conn_id = conn_id
self.profile_args = profile_args or {}
self._validate_profile_args()

def _validate_profile_args(self) -> None:
"""
Check if profile_args contains keys that should not be overridden from the
class variables when creating the profile.
"""
for profile_field in [DBT_PROFILE_TYPE_FIELD, DBT_PROFILE_METHOD_FIELD]:
if profile_field in self.profile_args and self.profile_args.get(profile_field) != getattr(
self, f"dbt_profile_{profile_field}"
):
raise CosmosValueError(
"`profile_args` for {0} has {1}='{2}' that will override the dbt profile required value of '{3}'. "
"To fix this, remove {1} from `profile_args`.".format(
self.__class__.__name__,
profile_field,
self.profile_args.get(profile_field),
getattr(self, f"dbt_profile_{profile_field}"),
)
)

@property
def conn(self) -> Connection:
Expand Down Expand Up @@ -100,11 +123,11 @@ def mock_profile(self) -> dict[str, Any]:
where live connection values don't matter.
"""
mock_profile = {
"type": self.dbt_profile_type,
DBT_PROFILE_TYPE_FIELD: self.dbt_profile_type,
}

if self.dbt_profile_method:
mock_profile["method"] = self.dbt_profile_method
mock_profile[DBT_PROFILE_METHOD_FIELD] = self.dbt_profile_method

for field in self.required_fields:
# if someone has passed in a value for this field, use it
Expand Down Expand Up @@ -199,11 +222,11 @@ def get_dbt_value(self, name: str) -> Any:
def mapped_params(self) -> dict[str, Any]:
"Turns the self.airflow_param_mapping into a dictionary of dbt fields and their values."
mapped_params = {
"type": self.dbt_profile_type,
DBT_PROFILE_TYPE_FIELD: self.dbt_profile_type,
}

if self.dbt_profile_method:
mapped_params["method"] = self.dbt_profile_method
mapped_params[DBT_PROFILE_METHOD_FIELD] = self.dbt_profile_method

for dbt_field in self.airflow_param_mapping:
mapped_params[dbt_field] = self.get_dbt_value(dbt_field)
Expand Down
31 changes: 31 additions & 0 deletions tests/profiles/test_base_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
from cosmos.profiles.base import BaseProfileMapping
from cosmos.exceptions import CosmosValueError


class TestProfileMapping(BaseProfileMapping):
dbt_profile_method: str = "fake-method"
dbt_profile_type: str = "fake-type"

def profile(self):
raise NotImplementedError


@pytest.mark.parametrize("profile_arg", ["type", "method"])
def test_validate_profile_args(profile_arg: str):
"""
An error should be raised if the profile_args contains a key that should not be overridden from the class variables.
"""
profile_args = {profile_arg: "fake-value"}
dbt_profile_value = getattr(TestProfileMapping, f"dbt_profile_{profile_arg}")

expected_cosmos_error = (
f"`profile_args` for TestProfileMapping has {profile_arg}='fake-value' that will override the dbt profile required value of "
f"'{dbt_profile_value}'. To fix this, remove {profile_arg} from `profile_args`."
)

with pytest.raises(CosmosValueError, match=expected_cosmos_error):
TestProfileMapping(
conn_id="fake_conn_id",
profile_args=profile_args,
)

0 comments on commit 8f7a04b

Please sign in to comment.