Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add browser IDC authentication method #950

Merged
merged 13 commits into from
Nov 25, 2024
Merged
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20241122-143326.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Add browser identity center authentication method.
time: 2024-11-22T14:33:26.549878-08:00
custom:
Author: versusfacit
Issue: "898"
266 changes: 165 additions & 101 deletions dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,24 @@ def get_message(self) -> str:
logger = AdapterLogger("Redshift")


class IdentityCenterTokenType(StrEnum):
ACCESS_TOKEN = "ACCESS_TOKEN"
EXT_JWT = "EXT_JWT"


class RedshiftConnectionMethod(StrEnum):
DATABASE = "database"
IAM = "iam"
IAM_ROLE = "iam_role"
IAM_IDENTITY_CENTER_BROWSER = "browser_identity_center"

@staticmethod
def uses_identity_center(method: str) -> bool:
return method.endswith("identity_center")

@staticmethod
def is_iam(method: str) -> bool:
VersusFacit marked this conversation as resolved.
Show resolved Hide resolved
return not RedshiftConnectionMethod.uses_identity_center(method)


class UserSSLMode(StrEnum):
Expand Down Expand Up @@ -128,6 +142,17 @@ class RedshiftCredentials(Credentials):
access_key_id: Optional[str] = None
secret_access_key: Optional[str] = None

#
# IAM identity center methods
#

# browser
idc_region: Optional[str] = None
issuer_url: Optional[str] = None
idp_listen_port: Optional[int] = 7890
idc_client_display_name: Optional[str] = "Amazon Redshift driver"
idp_response_timeout: Optional[int] = None

_ALIASES = {"dbname": "database", "pass": "password"}

@property
Expand Down Expand Up @@ -163,131 +188,171 @@ def unique_field(self) -> str:
return self.host


class RedshiftConnectMethodFactory:
credentials: RedshiftCredentials

def __init__(self, credentials) -> None:
self.credentials = credentials

def get_connect_method(self) -> Callable[[], redshift_connector.Connection]:
def get_connection_method(
credentials: RedshiftCredentials,
) -> Callable[[], redshift_connector.Connection]:
#
# Helper Methods
#
def __validate_required_fields(method_name: str, required_fields: Tuple[str, ...]):
missing_fields: List[str] = [
field for field in required_fields if getattr(credentials, field, None) is None
]
if missing_fields:
fields_str: str = "', '".join(missing_fields)
raise FailedToConnectError(
f"'{fields_str}' field(s) are required for '{method_name}' credentials method"
)

# Support missing 'method' for backwards compatibility
method = self.credentials.method or RedshiftConnectionMethod.DATABASE
if method == RedshiftConnectionMethod.DATABASE:
kwargs = self._database_kwargs
elif method == RedshiftConnectionMethod.IAM:
kwargs = self._iam_user_kwargs
elif method == RedshiftConnectionMethod.IAM_ROLE:
kwargs = self._iam_role_kwargs
else:
raise FailedToConnectError(f"Invalid 'method' in profile: '{method}'")
def __base_kwargs(credentials) -> Dict[str, Any]:
redshift_ssl_config: Dict[str, Any] = RedshiftSSLConfig.parse(
credentials.sslmode
).to_dict()
return {
"host": credentials.host,
"port": int(credentials.port) if credentials.port else 5439,
"database": credentials.database,
"region": credentials.region,
"auto_create": credentials.autocreate,
"db_groups": credentials.db_groups,
"timeout": credentials.connect_timeout,
**redshift_ssl_config,
}

def connect() -> redshift_connector.Connection:
c = redshift_connector.connect(**kwargs)
if self.credentials.autocommit:
c.autocommit = True
if self.credentials.role:
c.cursor().execute(f"set role {self.credentials.role}")
return c
def __iam_kwargs(credentials) -> Dict[str, Any]:

return connect
# iam True except for identity center methods
iam: bool = RedshiftConnectionMethod.is_iam(credentials.method)

@property
def _database_kwargs(self) -> Dict[str, Any]:
logger.debug("Connecting to redshift with 'database' credentials method")
kwargs = self._base_kwargs

if self.credentials.user and self.credentials.password:
kwargs.update(
user=self.credentials.user,
password=self.credentials.password,
)
cluster_identifier: Optional[str]
if "serverless" in credentials.host or RedshiftConnectionMethod.uses_identity_center(
credentials.method
):
cluster_identifier = None
elif credentials.cluster_id:
cluster_identifier = credentials.cluster_id
else:
raise FailedToConnectError(
"'user' and 'password' fields are required for 'database' credentials method"
"Failed to use IAM method:"
" 'cluster_id' must be provided for provisioned cluster"
" 'host' must be provided for serverless endpoint"
)

return kwargs
iam_specific_kwargs: Dict[str, Any] = {
"iam": iam,
"user": "",
"password": "",
"cluster_identifier": cluster_identifier,
}

return __base_kwargs(credentials) | iam_specific_kwargs

@property
def _iam_user_kwargs(self) -> Dict[str, Any]:
logger.debug("Connecting to redshift with 'iam' credentials method")
kwargs = self._iam_kwargs

if self.credentials.access_key_id and self.credentials.secret_access_key:
kwargs.update(
access_key_id=self.credentials.access_key_id,
secret_access_key=self.credentials.secret_access_key,
)
elif self.credentials.access_key_id or self.credentials.secret_access_key:
def __database_kwargs(credentials) -> Dict[str, Any]:
logger.debug("Connecting to Redshift with 'database' credentials method")

__validate_required_fields("database", ("user", "password"))

db_credentials: Dict[str, Any] = {
"user": credentials.user,
"password": credentials.password,
}

return __base_kwargs(credentials) | db_credentials

def __iam_user_kwargs(credentials) -> Dict[str, Any]:
logger.debug("Connecting to Redshift with 'iam' credentials method")

iam_credentials: Dict[str, Any]
if credentials.access_key_id and credentials.secret_access_key:
iam_credentials = {
"access_key_id": credentials.access_key_id,
"secret_access_key": credentials.secret_access_key,
}
elif credentials.access_key_id or credentials.secret_access_key:
raise FailedToConnectError(
"'access_key_id' and 'secret_access_key' are both needed if providing explicit credentials"
)
else:
kwargs.update(profile=self.credentials.iam_profile)
iam_credentials = {"profile": credentials.iam_profile}

if user := self.credentials.user:
kwargs.update(db_user=user)
else:
raise FailedToConnectError("'user' field is required for 'iam' credentials method")
__validate_required_fields("iam", ("user",))
iam_credentials["db_user"] = credentials.user

return kwargs
return __iam_kwargs(credentials) | iam_credentials

@property
def _iam_role_kwargs(self) -> Dict[str, Optional[Any]]:
logger.debug("Connecting to redshift with 'iam_role' credentials method")
kwargs = self._iam_kwargs
def __iam_role_kwargs(credentials) -> Dict[str, Any]:
logger.debug("Connecting to Redshift with 'iam_role' credentials method")
role_kwargs = {
"db_user": None,
"group_federation": "serverless" not in credentials.host,
}

# It's a role, we're ignoring the user
kwargs.update(db_user=None)
if credentials.iam_profile:
role_kwargs["profile"] = credentials.iam_profile

# Serverless shouldn't get group_federation, Provisoned clusters should
if "serverless" in self.credentials.host:
kwargs.update(group_federation=False)
else:
kwargs.update(group_federation=True)
return __iam_kwargs(credentials) | role_kwargs

if iam_profile := self.credentials.iam_profile:
kwargs.update(profile=iam_profile)
def __iam_idc_browser_kwargs(credentials) -> Dict[str, Any]:
logger.debug("Connecting to Redshift with '{credentials.method}' credentials method")

return kwargs
__IDP_TIMEOUT: int = 60
__LISTEN_PORT_DEFAULT: int = 7890

@property
def _iam_kwargs(self) -> Dict[str, Any]:
kwargs = self._base_kwargs
kwargs.update(
iam=True,
user="",
password="",
__validate_required_fields(
"browser_identity_center", ("method", "idc_region", "issuer_url")
)

if "serverless" in self.credentials.host:
kwargs.update(cluster_identifier=None)
elif cluster_id := self.credentials.cluster_id:
kwargs.update(cluster_identifier=cluster_id)
else:
raise FailedToConnectError(
"Failed to use IAM method:"
" 'cluster_id' must be provided for provisioned cluster"
" 'host' must be provided for serverless endpoint"
)
idp_timeout: int = (
timeout
if (timeout := credentials.idp_response_timeout) or timeout == 0
else __IDP_TIMEOUT
)

return kwargs
idp_listen_port: int = (
port if (port := credentials.idp_listen_port) else __LISTEN_PORT_DEFAULT
)

@property
def _base_kwargs(self) -> Dict[str, Any]:
kwargs = {
"host": self.credentials.host,
"port": int(self.credentials.port) if self.credentials.port else int(5439),
"database": self.credentials.database,
"region": self.credentials.region,
"auto_create": self.credentials.autocreate,
"db_groups": self.credentials.db_groups,
"timeout": self.credentials.connect_timeout,
idc_kwargs: Dict[str, Any] = {
"credentials_provider": "BrowserIdcAuthPlugin",
"issuer_url": credentials.issuer_url,
"listen_port": idp_listen_port,
"idc_region": credentials.idc_region,
"idc_client_display_name": credentials.idc_client_display_name,
"idp_response_timeout": idp_timeout,
}
redshift_ssl_config = RedshiftSSLConfig.parse(self.credentials.sslmode)
kwargs.update(redshift_ssl_config.to_dict())
return kwargs

return __iam_kwargs(credentials) | idc_kwargs

#
# Head of function execution
#

method_to_kwargs_function = {
None: __database_kwargs,
RedshiftConnectionMethod.DATABASE: __database_kwargs,
RedshiftConnectionMethod.IAM: __iam_user_kwargs,
RedshiftConnectionMethod.IAM_ROLE: __iam_role_kwargs,
RedshiftConnectionMethod.IAM_IDENTITY_CENTER_BROWSER: __iam_idc_browser_kwargs,
}

try:
kwargs_function: Callable[[RedshiftCredentials], Dict[str, Any]] = (
VersusFacit marked this conversation as resolved.
Show resolved Hide resolved
method_to_kwargs_function[credentials.method]
)
except KeyError:
raise FailedToConnectError(f"Invalid 'method' in profile: '{credentials.method}'")

kwargs: Dict[str, Any] = kwargs_function(credentials)

def connect() -> redshift_connector.Connection:
c = redshift_connector.connect(**kwargs)
if credentials.autocommit:
c.autocommit = True
if credentials.role:
c.cursor().execute(f"set role {credentials.role}")
return c

return connect


class RedshiftConnectionManager(SQLConnectionManager):
Expand Down Expand Up @@ -373,7 +438,6 @@ def open(cls, connection):
return connection

credentials = connection.credentials
connect_method_factory = RedshiftConnectMethodFactory(credentials)

def exponential_backoff(attempt: int):
return attempt * attempt
Expand All @@ -387,7 +451,7 @@ def exponential_backoff(attempt: int):

open_connection = cls.retry_connection(
connection,
connect=connect_method_factory.get_connect_method(),
connect=get_connection_method(credentials),
logger=logger,
retry_limit=credentials.retries,
retry_timeout=exponential_backoff,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _plugin_version() -> str:
"dbt-postgres>=1.8,<1.10",
# dbt-redshift depends deeply on this package. it does not follow SemVer, therefore there have been breaking changes in previous patch releases
# Pin to the patch or minor version, and bump in each new minor version of dbt-redshift.
"redshift-connector<2.1.1,>=2.0.913,!=2.0.914",
"redshift-connector==2.1.3",
VersusFacit marked this conversation as resolved.
Show resolved Hide resolved
# add dbt-core to ensure backwards compatibility of installation, this is not a functional dependency
"dbt-core>=1.8.0b3",
# installed via dbt-core but referenced directly; don't pin to avoid version conflicts with dbt-core
Expand Down
Loading
Loading