From 8f7a04b75fac734a8fc8138aba71a43585cfa5c5 Mon Sep 17 00:00:00 2001 From: Justin Bandoro <79104794+jbandoro@users.noreply.github.com> Date: Thu, 23 Nov 2023 06:07:34 -0800 Subject: [PATCH] Prevent overriding dbt profile fields with profile args of "type" or "method" (#702) The issue is described in #696 which was discovered when a user was creating a ProfileConfig with `GoogleCloudServiceAccountDictProfileMapping(profile_args={"method": "service-account"})` which was overriding the dbt profile method: https://github.com/astronomer/astronomer-cosmos/blob/24aa38e528e299ef51ca6baf32f5a6185887d432/cosmos/profiles/bigquery/service_account_keyfile_dict.py#L21 when the profile args are mapped to the created profile below: https://github.com/astronomer/astronomer-cosmos/blob/24aa38e528e299ef51ca6baf32f5a6185887d432/cosmos/profiles/bigquery/service_account_keyfile_dict.py#L42-L52 This is not an issue with the profile mapping example above and could happen with any profile mapping by changing the "type" from `dbt_profile_type` or "method" (if used) from `dbt_profile_method` in the class. The fix in this PR is to not allow args with "type" or "method" that are different from the class variables in `profile_args`. I think this is better than logging a warning because if either of those fields are different the dbt run with the created profile will fail anyways. This also allows backwards compatibility in the case users have these already set in their profile args and it matches the class variables. Closes #696 --- cosmos/profiles/base.py | 31 +++++++++++++++++++++++++---- tests/profiles/test_base_profile.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 4 deletions(-) create mode 100644 tests/profiles/test_base_profile.py diff --git a/cosmos/profiles/base.py b/cosmos/profiles/base.py index 36f2ccbc5..b1cebb38b 100644 --- a/cosmos/profiles/base.py +++ b/cosmos/profiles/base.py @@ -18,6 +18,9 @@ if TYPE_CHECKING: from airflow.models import Connection +DBT_PROFILE_TYPE_FIELD = "type" +DBT_PROFILE_METHOD_FIELD = "method" + logger = get_logger(__name__) @@ -41,6 +44,26 @@ class BaseProfileMapping(ABC): def __init__(self, conn_id: str, profile_args: dict[str, Any] | None = None): self.conn_id = conn_id self.profile_args = profile_args or {} + self._validate_profile_args() + + def _validate_profile_args(self) -> None: + """ + Check if profile_args contains keys that should not be overridden from the + class variables when creating the profile. + """ + for profile_field in [DBT_PROFILE_TYPE_FIELD, DBT_PROFILE_METHOD_FIELD]: + if profile_field in self.profile_args and self.profile_args.get(profile_field) != getattr( + self, f"dbt_profile_{profile_field}" + ): + raise CosmosValueError( + "`profile_args` for {0} has {1}='{2}' that will override the dbt profile required value of '{3}'. " + "To fix this, remove {1} from `profile_args`.".format( + self.__class__.__name__, + profile_field, + self.profile_args.get(profile_field), + getattr(self, f"dbt_profile_{profile_field}"), + ) + ) @property def conn(self) -> Connection: @@ -100,11 +123,11 @@ def mock_profile(self) -> dict[str, Any]: where live connection values don't matter. """ mock_profile = { - "type": self.dbt_profile_type, + DBT_PROFILE_TYPE_FIELD: self.dbt_profile_type, } if self.dbt_profile_method: - mock_profile["method"] = self.dbt_profile_method + mock_profile[DBT_PROFILE_METHOD_FIELD] = self.dbt_profile_method for field in self.required_fields: # if someone has passed in a value for this field, use it @@ -199,11 +222,11 @@ def get_dbt_value(self, name: str) -> Any: def mapped_params(self) -> dict[str, Any]: "Turns the self.airflow_param_mapping into a dictionary of dbt fields and their values." mapped_params = { - "type": self.dbt_profile_type, + DBT_PROFILE_TYPE_FIELD: self.dbt_profile_type, } if self.dbt_profile_method: - mapped_params["method"] = self.dbt_profile_method + mapped_params[DBT_PROFILE_METHOD_FIELD] = self.dbt_profile_method for dbt_field in self.airflow_param_mapping: mapped_params[dbt_field] = self.get_dbt_value(dbt_field) diff --git a/tests/profiles/test_base_profile.py b/tests/profiles/test_base_profile.py new file mode 100644 index 000000000..1b1ba3e8a --- /dev/null +++ b/tests/profiles/test_base_profile.py @@ -0,0 +1,31 @@ +import pytest +from cosmos.profiles.base import BaseProfileMapping +from cosmos.exceptions import CosmosValueError + + +class TestProfileMapping(BaseProfileMapping): + dbt_profile_method: str = "fake-method" + dbt_profile_type: str = "fake-type" + + def profile(self): + raise NotImplementedError + + +@pytest.mark.parametrize("profile_arg", ["type", "method"]) +def test_validate_profile_args(profile_arg: str): + """ + An error should be raised if the profile_args contains a key that should not be overridden from the class variables. + """ + profile_args = {profile_arg: "fake-value"} + dbt_profile_value = getattr(TestProfileMapping, f"dbt_profile_{profile_arg}") + + expected_cosmos_error = ( + f"`profile_args` for TestProfileMapping has {profile_arg}='fake-value' that will override the dbt profile required value of " + f"'{dbt_profile_value}'. To fix this, remove {profile_arg} from `profile_args`." + ) + + with pytest.raises(CosmosValueError, match=expected_cosmos_error): + TestProfileMapping( + conn_id="fake_conn_id", + profile_args=profile_args, + )