diff --git a/.changes/unreleased/Features-20240425-011440.yaml b/.changes/unreleased/Features-20240425-011440.yaml new file mode 100644 index 000000000..a8197dd6f --- /dev/null +++ b/.changes/unreleased/Features-20240425-011440.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add support for IAM Role auth +time: 2024-04-25T01:14:40.601575-04:00 +custom: + Author: mikealfare,abbywh + Issue: "623" diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 3b54717f3..752c81e32 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -37,6 +37,7 @@ def get_message(self) -> str: class RedshiftConnectionMethod(StrEnum): DATABASE = "database" IAM = "iam" + IAM_ROLE = "iam_role" class UserSSLMode(StrEnum): @@ -102,9 +103,9 @@ def parse(cls, user_sslmode: UserSSLMode) -> "RedshiftSSLConfig": @dataclass class RedshiftCredentials(Credentials): host: str - user: str port: Port method: str = RedshiftConnectionMethod.DATABASE # type: ignore + user: Optional[str] = None password: Optional[str] = None # type: ignore cluster_id: Optional[str] = field( default=None, @@ -173,6 +174,8 @@ def get_connect_method(self) -> Callable[[], redshift_connector.Connection]: kwargs = self._database_kwargs elif method == RedshiftConnectionMethod.IAM: kwargs = self._iam_user_kwargs + elif method == RedshiftConnectionMethod.IAM_ROLE: + kwargs = self._iam_role_kwargs else: raise FailedToConnectError(f"Invalid 'method' in profile: '{method}'") @@ -227,6 +230,20 @@ def _iam_user_kwargs(self) -> Dict[str, Any]: return kwargs + @property + def _iam_role_kwargs(self) -> Dict[str, Optional[Any]]: + logger.debug("Connecting to redshift with 'iam_role' credentials method") + kwargs = self._iam_kwargs + kwargs.update( + group_federation=True, + db_user=None, + ) + + if iam_profile := self.credentials.iam_profile: + kwargs.update(profile=iam_profile) + + return kwargs + @property def _iam_kwargs(self) -> Dict[str, Any]: kwargs = self._base_kwargs diff --git a/dbt/include/redshift/profile_template.yml b/dbt/include/redshift/profile_template.yml index 41f33e87e..d78356923 100644 --- a/dbt/include/redshift/profile_template.yml +++ b/dbt/include/redshift/profile_template.yml @@ -15,6 +15,8 @@ prompts: hide_input: true iam: _fixed_method: iam + iam_role: + _fixed_method: iam_role dbname: hint: 'default database that dbt will build objects in' schema: diff --git a/test.env.example b/test.env.example index 83c682036..6816b4ec2 100644 --- a/test.env.example +++ b/test.env.example @@ -9,12 +9,17 @@ REDSHIFT_TEST_USER= REDSHIFT_TEST_PASS= REDSHIFT_TEST_REGION= -# IAM User Authentication Method +# IAM Methods REDSHIFT_TEST_CLUSTER_ID= + +# IAM User Authentication Method REDSHIFT_TEST_IAM_USER_PROFILE= REDSHIFT_TEST_IAM_USER_ACCESS_KEY_ID= REDSHIFT_TEST_IAM_USER_SECRET_ACCESS_KEY= +# IAM Role Authentication Method +REDSHIFT_TEST_IAM_ROLE_PROFILE= + # Database users for testing DBT_TEST_USER_1=dbt_test_user_1 DBT_TEST_USER_2=dbt_test_user_2 diff --git a/tests/functional/test_auth_method.py b/tests/functional/test_auth_method.py index 0eb33c0fa..b2273e02c 100644 --- a/tests/functional/test_auth_method.py +++ b/tests/functional/test_auth_method.py @@ -85,3 +85,19 @@ def dbt_profile_target(self): "host": "", # host is a required field in dbt-core "port": 0, # port is a required field in dbt-core } + + +class TestIAMRoleAuthProfile(AuthMethod): + @pytest.fixture(scope="class") + def dbt_profile_target(self): + return { + "type": "redshift", + "method": RedshiftConnectionMethod.IAM_ROLE.value, + "cluster_id": os.getenv("REDSHIFT_TEST_CLUSTER_ID"), + "dbname": os.getenv("REDSHIFT_TEST_DBNAME"), + "iam_profile": os.getenv("REDSHIFT_TEST_IAM_ROLE_PROFILE"), + "threads": 1, + "retries": 6, + "host": "", # host is a required field in dbt-core + "port": 0, # port is a required field in dbt-core + } diff --git a/tests/unit/test_auth_method.py b/tests/unit/test_auth_method.py index 5b39db354..bd9912d0c 100644 --- a/tests/unit/test_auth_method.py +++ b/tests/unit/test_auth_method.py @@ -393,3 +393,66 @@ def test_profile_invalid_serverless(self): **DEFAULT_SSL_CONFIG, ) self.assertTrue("'host' must be provided" in context.exception.msg) + + +class TestIAMRoleMethod(AuthMethod): + + def test_no_cluster_id(self): + self.config.credentials = self.config.credentials.replace(method="iam_role") + with self.assertRaises(FailedToConnectError) as context: + connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) + connect_method_factory.get_connect_method() + + self.assertTrue("'cluster_id' must be provided" in context.exception.msg) + + @mock.patch("redshift_connector.connect", MagicMock()) + def test_default(self): + self.config.credentials = self.config.credentials.replace( + method="iam_role", + cluster_id="my_redshift", + ) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + iam=True, + host="thishostshouldnotexist.test.us-east-1", + database="redshift", + cluster_identifier="my_redshift", + db_user=None, + password="", + user="", + region=None, + timeout=None, + auto_create=False, + db_groups=[], + port=5439, + group_federation=True, + **DEFAULT_SSL_CONFIG, + ) + + @mock.patch("redshift_connector.connect", MagicMock()) + def test_profile(self): + self.config.credentials = self.config.credentials.replace( + method="iam_role", + cluster_id="my_redshift", + iam_profile="test", + ) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + iam=True, + host="thishostshouldnotexist.test.us-east-1", + database="redshift", + cluster_identifier="my_redshift", + db_user=None, + password="", + user="", + region=None, + timeout=None, + auto_create=False, + db_groups=[], + profile="test", + port=5439, + group_federation=True, + **DEFAULT_SSL_CONFIG, + )