Skip to content

Commit

Permalink
Add integration tests for IAM User auth (#774)
Browse files Browse the repository at this point in the history
* move connection fixtures into the functional scope
* add iam user creds to the test.env template
* add test for database connection method
* add iam user auth test
* maintain existing behavior when not providing profile
* add AWS IAM profile in CI
* pull in new env vars in CI
* updates to make space for iam role

---------

Co-authored-by: Colin Rogers <[email protected]>
  • Loading branch information
mikealfare and colin-rogers-dbt authored May 7, 2024
1 parent 12b5cd7 commit 2d653c6
Show file tree
Hide file tree
Showing 16 changed files with 1,114 additions and 794 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20240419-145208.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Support IAM user auth via direct parameters, in addition to the existing profile
method
time: 2024-04-19T14:52:08.086607-04:00
custom:
Author: mikealfare
Issue: "760"
155 changes: 95 additions & 60 deletions dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from multiprocessing import Lock
from contextlib import contextmanager
from typing import Tuple, Union, Optional, List
from typing import Any, Callable, Dict, Tuple, Union, Optional, List
from dataclasses import dataclass, field

import agate
Expand Down Expand Up @@ -116,11 +116,13 @@ class RedshiftCredentials(Credentials):
ra3_node: Optional[bool] = False
connect_timeout: Optional[int] = None
role: Optional[str] = None
sslmode: Optional[UserSSLMode] = field(default_factory=UserSSLMode.default)
sslmode: UserSSLMode = field(default_factory=UserSSLMode.default)
retries: int = 1
region: Optional[str] = None
# opt-in by default per team deliberation on https://peps.python.org/pep-0249/#autocommit
autocommit: Optional[bool] = True
access_key_id: Optional[str] = None
secret_access_key: Optional[str] = None

_ALIASES = {"dbname": "database", "pass": "password"}

Expand All @@ -142,14 +144,14 @@ def _connection_keys(self):
"region",
"sslmode",
"region",
"iam_profile",
"autocreate",
"db_groups",
"ra3_node",
"connect_timeout",
"role",
"retries",
"autocommit",
"access_key_id",
)

@property
Expand All @@ -160,74 +162,107 @@ def unique_field(self) -> str:
class RedshiftConnectMethodFactory:
credentials: RedshiftCredentials

def __init__(self, credentials):
def __init__(self, credentials) -> None:
self.credentials = credentials

def get_connect_method(self):
method = self.credentials.method
def get_connect_method(self) -> Callable[[], redshift_connector.Connection]:

# Support missing 'method' for backwards compatibility
method = self.credentials.method or RedshiftConnectionMethod.DATABASE
if method == RedshiftConnectionMethod.DATABASE:
kwargs = self._database_kwargs
elif method == RedshiftConnectionMethod.IAM:
kwargs = self._iam_user_kwargs
else:
raise FailedToConnectError(f"Invalid 'method' in profile: '{method}'")

def connect() -> redshift_connector.Connection:
c = redshift_connector.connect(**kwargs)
if self.credentials.autocommit:
c.autocommit = True
if self.credentials.role:
c.cursor().execute(f"set role {self.credentials.role}")
return c

return connect

@property
def _database_kwargs(self) -> Dict[str, Any]:
logger.debug("Connecting to redshift with 'database' credentials method")
kwargs = self._base_kwargs

if self.credentials.user and self.credentials.password:
kwargs.update(
user=self.credentials.user,
password=self.credentials.password,
)
else:
raise FailedToConnectError(
"'user' and 'password' fields are required for 'database' credentials method"
)

return kwargs

@property
def _iam_user_kwargs(self) -> Dict[str, Any]:
logger.debug("Connecting to redshift with 'iam' credentials method")
kwargs = self._iam_kwargs

if self.credentials.access_key_id and self.credentials.secret_access_key:
kwargs.update(
access_key_id=self.credentials.access_key_id,
secret_access_key=self.credentials.secret_access_key,
)
elif self.credentials.access_key_id or self.credentials.secret_access_key:
raise FailedToConnectError(
"'access_key_id' and 'secret_access_key' are both needed if providing explicit credentials"
)
else:
kwargs.update(profile=self.credentials.iam_profile)

if user := self.credentials.user:
kwargs.update(db_user=user)
else:
raise FailedToConnectError("'user' field is required for 'iam' credentials method")

return kwargs

@property
def _iam_kwargs(self) -> Dict[str, Any]:
kwargs = self._base_kwargs
kwargs.update(
iam=True,
user="",
password="",
)

if cluster_id := self.credentials.cluster_id:
kwargs.update(cluster_identifier=cluster_id)
elif "serverless" in self.credentials.host:
kwargs.update(cluster_identifier=None)
else:
raise FailedToConnectError(
"Failed to use IAM method:"
" 'cluster_id' must be provided for provisioned cluster"
" 'host' must be provided for serverless endpoint"
)

return kwargs

@property
def _base_kwargs(self) -> Dict[str, Any]:
kwargs = {
"host": self.credentials.host,
"database": self.credentials.database,
"port": int(self.credentials.port) if self.credentials.port else int(5439),
"database": self.credentials.database,
"region": self.credentials.region,
"auto_create": self.credentials.autocreate,
"db_groups": self.credentials.db_groups,
"region": self.credentials.region,
"timeout": self.credentials.connect_timeout,
}

redshift_ssl_config = RedshiftSSLConfig.parse(self.credentials.sslmode)
kwargs.update(redshift_ssl_config.to_dict())

# Support missing 'method' for backwards compatibility
if method == RedshiftConnectionMethod.DATABASE or method is None:
# this requirement is really annoying to encode into json schema,
# so validate it here
if self.credentials.password is None:
raise FailedToConnectError(
"'password' field is required for 'database' credentials"
)

def connect():
logger.debug("Connecting to redshift with username/password based auth...")
c = redshift_connector.connect(
user=self.credentials.user,
password=self.credentials.password,
**kwargs,
)
if self.credentials.autocommit:
c.autocommit = True
if self.credentials.role:
c.cursor().execute("set role {}".format(self.credentials.role))
return c

elif method == RedshiftConnectionMethod.IAM:
if not self.credentials.cluster_id and "serverless" not in self.credentials.host:
raise FailedToConnectError(
"Failed to use IAM method. 'cluster_id' must be provided for provisioned cluster. "
"'host' must be provided for serverless endpoint."
)

def connect():
logger.debug("Connecting to redshift with IAM based auth...")
c = redshift_connector.connect(
iam=True,
db_user=self.credentials.user,
password="",
user="",
cluster_identifier=self.credentials.cluster_id,
profile=self.credentials.iam_profile,
**kwargs,
)
if self.credentials.autocommit:
c.autocommit = True
if self.credentials.role:
c.cursor().execute("set role {}".format(self.credentials.role))
return c

else:
raise FailedToConnectError("Invalid 'method' in profile: '{}'".format(method))

return connect
return kwargs


class RedshiftConnectionManager(SQLConnectionManager):
Expand Down
23 changes: 13 additions & 10 deletions test.env.example
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
# Note: Make sure you have a Redshift account that is set up so these fields are easy to complete.

### Test Environment field definitions
# These will all be gathered from account information or created by you.
# Endpoint for Redshift connection

# Database Authentication Method
REDSHIFT_TEST_HOST=
# Username on your account
REDSHIFT_TEST_USER=
# Password for Redshift account
REDSHIFT_TEST_PASS=
# Local port to connect on
REDSHIFT_TEST_PORT=
# Name of Redshift database in your account to test against
REDSHIFT_TEST_DBNAME=
# Users for testing
REDSHIFT_TEST_USER=
REDSHIFT_TEST_PASS=
REDSHIFT_TEST_REGION=

# IAM User Authentication Method
REDSHIFT_TEST_CLUSTER_ID=
REDSHIFT_TEST_IAM_USER_PROFILE=
REDSHIFT_TEST_IAM_USER_ACCESS_KEY_ID=
REDSHIFT_TEST_IAM_USER_SECRET_ACCESS_KEY=

# Database users for testing
DBT_TEST_USER_1=dbt_test_user_1
DBT_TEST_USER_2=dbt_test_user_2
DBT_TEST_USER_3=dbt_test_user_3
21 changes: 0 additions & 21 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,22 +1 @@
import pytest
import os

# Import the functional fixtures as a plugin
# Note: fixtures with session scope need to be local

pytest_plugins = ["dbt.tests.fixtures.project"]


# The profile dictionary, used to write out profiles.yml
@pytest.fixture(scope="class")
def dbt_profile_target():
return {
"type": "redshift",
"threads": 1,
"retries": 6,
"host": os.getenv("REDSHIFT_TEST_HOST"),
"port": int(os.getenv("REDSHIFT_TEST_PORT")),
"user": os.getenv("REDSHIFT_TEST_USER"),
"pass": os.getenv("REDSHIFT_TEST_PASS"),
"dbname": os.getenv("REDSHIFT_TEST_DBNAME"),
}
19 changes: 19 additions & 0 deletions tests/functional/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os

import pytest


# The profile dictionary, used to write out profiles.yml
@pytest.fixture(scope="class")
def dbt_profile_target():
return {
"type": "redshift",
"host": os.getenv("REDSHIFT_TEST_HOST"),
"port": int(os.getenv("REDSHIFT_TEST_PORT")),
"dbname": os.getenv("REDSHIFT_TEST_DBNAME"),
"user": os.getenv("REDSHIFT_TEST_USER"),
"pass": os.getenv("REDSHIFT_TEST_PASS"),
"region": os.getenv("REDSHIFT_TEST_REGION"),
"threads": 1,
"retries": 6,
}
87 changes: 87 additions & 0 deletions tests/functional/test_auth_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os

import pytest

from dbt.adapters.redshift.connections import RedshiftConnectionMethod
from dbt.tests.util import run_dbt


MY_SEED = """
id,name
1,apple
2,banana
3,cherry
""".strip()


MY_VIEW = """
select * from {{ ref("my_seed") }}
"""


class AuthMethod:

@pytest.fixture(scope="class")
def seeds(self):
yield {"my_seed.csv": MY_SEED}

@pytest.fixture(scope="class")
def models(self):
yield {"my_view.sql": MY_VIEW}

def test_connection(self, project):
run_dbt(["seed"])
results = run_dbt(["run"])
assert len(results) == 1


class TestDatabaseMethod(AuthMethod):
@pytest.fixture(scope="class")
def dbt_profile_target(self):
return {
"type": "redshift",
"method": RedshiftConnectionMethod.DATABASE.value,
"host": os.getenv("REDSHIFT_TEST_HOST"),
"port": int(os.getenv("REDSHIFT_TEST_PORT")),
"dbname": os.getenv("REDSHIFT_TEST_DBNAME"),
"user": os.getenv("REDSHIFT_TEST_USER"),
"pass": os.getenv("REDSHIFT_TEST_PASS"),
"threads": 1,
"retries": 6,
}


class TestIAMUserMethodProfile(AuthMethod):
@pytest.fixture(scope="class")
def dbt_profile_target(self):
return {
"type": "redshift",
"method": RedshiftConnectionMethod.IAM.value,
"cluster_id": os.getenv("REDSHIFT_TEST_CLUSTER_ID"),
"dbname": os.getenv("REDSHIFT_TEST_DBNAME"),
"iam_profile": os.getenv("REDSHIFT_TEST_IAM_USER_PROFILE"),
"user": os.getenv("REDSHIFT_TEST_USER"),
"threads": 1,
"retries": 6,
"host": "", # host is a required field in dbt-core
"port": 0, # port is a required field in dbt-core
}


class TestIAMUserMethodExplicit(AuthMethod):
@pytest.fixture(scope="class")
def dbt_profile_target(self):
return {
"type": "redshift",
"method": RedshiftConnectionMethod.IAM.value,
"cluster_id": os.getenv("REDSHIFT_TEST_CLUSTER_ID"),
"dbname": os.getenv("REDSHIFT_TEST_DBNAME"),
"access_key_id": os.getenv("REDSHIFT_TEST_IAM_USER_ACCESS_KEY_ID"),
"secret_access_key": os.getenv("REDSHIFT_TEST_IAM_USER_SECRET_ACCESS_KEY"),
"region": os.getenv("REDSHIFT_TEST_REGION"),
"user": os.getenv("REDSHIFT_TEST_USER"),
"threads": 1,
"retries": 6,
"host": "", # host is a required field in dbt-core
"port": 0, # port is a required field in dbt-core
}
2 changes: 1 addition & 1 deletion tests/unit/mock_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from unittest import mock
from contextlib import contextmanager
from unittest import mock

from dbt.adapters.base import BaseAdapter

Expand Down
Loading

0 comments on commit 2d653c6

Please sign in to comment.