diff --git a/cosmos/profiles/base.py b/cosmos/profiles/base.py index 36f2ccbc5..b1cebb38b 100644 --- a/cosmos/profiles/base.py +++ b/cosmos/profiles/base.py @@ -18,6 +18,9 @@ if TYPE_CHECKING: from airflow.models import Connection +DBT_PROFILE_TYPE_FIELD = "type" +DBT_PROFILE_METHOD_FIELD = "method" + logger = get_logger(__name__) @@ -41,6 +44,26 @@ class BaseProfileMapping(ABC): def __init__(self, conn_id: str, profile_args: dict[str, Any] | None = None): self.conn_id = conn_id self.profile_args = profile_args or {} + self._validate_profile_args() + + def _validate_profile_args(self) -> None: + """ + Check if profile_args contains keys that should not be overridden from the + class variables when creating the profile. + """ + for profile_field in [DBT_PROFILE_TYPE_FIELD, DBT_PROFILE_METHOD_FIELD]: + if profile_field in self.profile_args and self.profile_args.get(profile_field) != getattr( + self, f"dbt_profile_{profile_field}" + ): + raise CosmosValueError( + "`profile_args` for {0} has {1}='{2}' that will override the dbt profile required value of '{3}'. " + "To fix this, remove {1} from `profile_args`.".format( + self.__class__.__name__, + profile_field, + self.profile_args.get(profile_field), + getattr(self, f"dbt_profile_{profile_field}"), + ) + ) @property def conn(self) -> Connection: @@ -100,11 +123,11 @@ def mock_profile(self) -> dict[str, Any]: where live connection values don't matter. """ mock_profile = { - "type": self.dbt_profile_type, + DBT_PROFILE_TYPE_FIELD: self.dbt_profile_type, } if self.dbt_profile_method: - mock_profile["method"] = self.dbt_profile_method + mock_profile[DBT_PROFILE_METHOD_FIELD] = self.dbt_profile_method for field in self.required_fields: # if someone has passed in a value for this field, use it @@ -199,11 +222,11 @@ def get_dbt_value(self, name: str) -> Any: def mapped_params(self) -> dict[str, Any]: "Turns the self.airflow_param_mapping into a dictionary of dbt fields and their values." mapped_params = { - "type": self.dbt_profile_type, + DBT_PROFILE_TYPE_FIELD: self.dbt_profile_type, } if self.dbt_profile_method: - mapped_params["method"] = self.dbt_profile_method + mapped_params[DBT_PROFILE_METHOD_FIELD] = self.dbt_profile_method for dbt_field in self.airflow_param_mapping: mapped_params[dbt_field] = self.get_dbt_value(dbt_field) diff --git a/tests/profiles/test_base_profile.py b/tests/profiles/test_base_profile.py new file mode 100644 index 000000000..1b1ba3e8a --- /dev/null +++ b/tests/profiles/test_base_profile.py @@ -0,0 +1,31 @@ +import pytest +from cosmos.profiles.base import BaseProfileMapping +from cosmos.exceptions import CosmosValueError + + +class TestProfileMapping(BaseProfileMapping): + dbt_profile_method: str = "fake-method" + dbt_profile_type: str = "fake-type" + + def profile(self): + raise NotImplementedError + + +@pytest.mark.parametrize("profile_arg", ["type", "method"]) +def test_validate_profile_args(profile_arg: str): + """ + An error should be raised if the profile_args contains a key that should not be overridden from the class variables. + """ + profile_args = {profile_arg: "fake-value"} + dbt_profile_value = getattr(TestProfileMapping, f"dbt_profile_{profile_arg}") + + expected_cosmos_error = ( + f"`profile_args` for TestProfileMapping has {profile_arg}='fake-value' that will override the dbt profile required value of " + f"'{dbt_profile_value}'. To fix this, remove {profile_arg} from `profile_args`." + ) + + with pytest.raises(CosmosValueError, match=expected_cosmos_error): + TestProfileMapping( + conn_id="fake_conn_id", + profile_args=profile_args, + )