Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CT-2860] Add IAM Authentication to dbt-postgres #8187

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230722-201041.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: add IAM Authentication to dbt-postgres
time: 2023-07-22T20:10:41.9286694+02:00
custom:
Author: christopherscholz
Issue: "8186"
193 changes: 149 additions & 44 deletions plugins/postgres/dbt/adapters/postgres/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,39 @@
import psycopg2
from psycopg2.extensions import string_types

import boto3

import dbt.exceptions
from dbt.adapters.base import Credentials
from dbt.adapters.sql import SQLConnectionManager
from dbt.contracts.connection import AdapterResponse
from dbt.dataclass_schema import StrEnum
from hologram.helpers import StrLiteral
from dbt.events import AdapterLogger

from dbt.helper_types import Port
from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional


logger = AdapterLogger("Postgres")


# Authentication methods accepted in the profile's `method` field.
class PostgresConnectionMethod(StrEnum):
# Connect with the password stored in the credentials.
DATABASE = "database"
# Connect with a token generated through boto3's RDS client.
IAM = "iam"


@dataclass
class PostgresCredentials(Credentials):
host: str
user: str
port: Port
password: str # on postgres the password is mandatory
password: Optional[str] = None
connect_timeout: int = 10
method: Optional[PostgresConnectionMethod] = PostgresConnectionMethod.DATABASE
iam_profile: Optional[str] = None
region: Optional[str] = None
role: Optional[str] = None
search_path: Optional[str] = None
keepalives_idle: int = 0 # 0 means to use the default value
Expand All @@ -44,6 +56,17 @@ def type(self):
def unique_field(self):
return self.host

@classmethod
def validate(cls, data: Any):
    """Validate raw profile data against the schema of its connection method.

    The data is first validated with the ``validate`` found after
    ``Credentials`` in the MRO, then validated again by the credentials
    class that matches ``data["method"]``.

    Args:
        data: Mapping of raw profile values.
    """
    super(Credentials, cls).validate(data)

    method_credentials = {
        PostgresConnectionMethod.DATABASE: PostgresCredentialsDatabase,
        PostgresConnectionMethod.IAM: PostgresCredentialsIAM,
    }

    # `method` may be missing or explicitly null. Both select the database
    # method, matching how the connection factory treats `method is None`.
    method = data.get("method")
    if method is None:
        method = PostgresConnectionMethod.DATABASE

    method_credentials[method].validate(data)

def _connection_keys(self):
return (
"host",
Expand All @@ -52,6 +75,9 @@ def _connection_keys(self):
"database",
"schema",
"connect_timeout",
"method",
"iam_profile",
"region",
"role",
"search_path",
"keepalives_idle",
Expand All @@ -64,6 +90,125 @@ def _connection_keys(self):
)


# Credentials variant that PostgresCredentials.validate selects for the
# "database" method: `password` is annotated as a plain `str`, and `method`
# is annotated as the literal "database" (defaulting to it).
@dataclass
class PostgresCredentialsDatabase(PostgresCredentials):
password: str
method: Optional[
StrLiteral(PostgresConnectionMethod.DATABASE)
] = PostgresConnectionMethod.DATABASE

@classmethod
def validate(cls, data: Any):
# Start the MRO lookup after `Credentials`, which skips
# PostgresCredentials.validate; that method dispatches back to this class,
# so resolving to it here would recurse.
super(Credentials, cls).validate(data)


# Credentials variant that PostgresCredentials.validate selects for the
# "iam" method: `password` is annotated as None, `method` is annotated as the
# literal "iam", and `iam_profile` / `region` are optional strings.
@dataclass
class PostgresCredentialsIAM(PostgresCredentials):
password: None
method: StrLiteral(PostgresConnectionMethod.IAM)
iam_profile: Optional[str] = None
region: Optional[str] = None

@classmethod
def validate(cls, data: Any):
# Start the MRO lookup after `Credentials`, which skips
# PostgresCredentials.validate; that method dispatches back to this class,
# so resolving to it here would recurse.
super(Credentials, cls).validate(data)


class PostgresConnectMethodFactory:
    """Build the zero-argument ``connect`` callable for a set of credentials.

    The returned callable opens a ``psycopg2`` connection using either the
    stored password (``database`` method, also used when ``method`` is None)
    or a token obtained from boto3's RDS client (``iam`` method).
    """

    credentials: PostgresCredentials

    def __init__(self, credentials):
        self.credentials = credentials

    def _base_connect_kwargs(self):
        """Return the ``psycopg2.connect`` keyword arguments shared by all methods."""
        credentials = self.credentials
        kwargs = {
            "host": credentials.host,
            "dbname": credentials.database,
            "port": int(credentials.port) if credentials.port else 5432,
            "user": credentials.user,
            "connect_timeout": credentials.connect_timeout,
        }

        # we don't want to pass 0 along to connect() as postgres will try to
        # make an invalid setsockopt() call (contrary to the docs).
        if credentials.keepalives_idle:
            kwargs["keepalives_idle"] = credentials.keepalives_idle

        # psycopg2 doesn't support search_path officially,
        # see https://github.com/psycopg/psycopg2/issues/465
        search_path = credentials.search_path
        if search_path is not None and search_path != "":
            # see https://postgresql.org/docs/9.5/libpq-connect.html
            kwargs["options"] = "-c search_path={}".format(search_path.replace(" ", "\\ "))

        if credentials.sslmode:
            kwargs["sslmode"] = credentials.sslmode

        if credentials.sslcert is not None:
            kwargs["sslcert"] = credentials.sslcert

        if credentials.sslkey is not None:
            kwargs["sslkey"] = credentials.sslkey

        if credentials.sslrootcert is not None:
            kwargs["sslrootcert"] = credentials.sslrootcert

        if credentials.application_name:
            kwargs["application_name"] = credentials.application_name

        return kwargs

    def _set_role(self, handle):
        """Issue ``set role`` on ``handle`` when a role is configured; return ``handle``."""
        if self.credentials.role:
            handle.cursor().execute("set role {}".format(self.credentials.role))
        return handle

    def _generate_iam_token(self, port):
        """Request an auth token from boto3's RDS client for ``port``.

        ``port`` is the same value that is handed to ``psycopg2.connect``, so
        the token is always requested for the port actually being connected to.
        """
        session_kwargs = {}
        if self.credentials.iam_profile:
            session_kwargs["profile_name"] = self.credentials.iam_profile
        if self.credentials.region:
            session_kwargs["region_name"] = self.credentials.region
        session = boto3.Session(**session_kwargs)

        client = session.client("rds")
        token_kwargs = {
            "DBHostname": self.credentials.host,
            "Port": port,
            "DBUsername": self.credentials.user,
        }
        if self.credentials.region:
            token_kwargs["Region"] = self.credentials.region
        return client.generate_db_auth_token(**token_kwargs)

    def get_connect_method(self):
        """Return a callable that opens a connection for the configured method.

        Returns:
            A zero-argument function returning an open ``psycopg2`` connection.

        Raises:
            dbt.exceptions.FailedToConnectError: if ``method`` is neither
                ``database``, ``iam`` nor None.
        """
        method = self.credentials.method
        kwargs = self._base_connect_kwargs()

        # Support missing 'method' for backwards compatibility
        if method == PostgresConnectionMethod.DATABASE or method is None:

            def connect():
                logger.debug("Connecting to postgres with username/password based auth...")
                handle = psycopg2.connect(
                    password=self.credentials.password,
                    **kwargs,
                )
                return self._set_role(handle)

        elif method == PostgresConnectionMethod.IAM:

            def connect():
                logger.debug("Connecting to postgres with IAM based auth...")
                # A token is requested on every call, so each invocation of
                # this callable authenticates with a newly generated token.
                token = self._generate_iam_token(kwargs["port"])
                # Pass the token per call instead of storing it in the shared
                # kwargs dict.
                handle = psycopg2.connect(
                    password=token,
                    **kwargs,
                )
                return self._set_role(handle)

        else:
            raise dbt.exceptions.FailedToConnectError(
                "Invalid 'method' in profile: '{}'".format(method)
            )

        return connect


class PostgresConnectionManager(SQLConnectionManager):
TYPE = "postgres"

Expand Down Expand Up @@ -102,47 +247,7 @@ def open(cls, connection):
return connection

credentials = cls.get_credentials(connection.credentials)
kwargs = {}
# we don't want to pass 0 along to connect() as postgres will try to
# call an invalid setsockopt() call (contrary to the docs).
if credentials.keepalives_idle:
kwargs["keepalives_idle"] = credentials.keepalives_idle

# psycopg2 doesn't support search_path officially,
# see https://github.com/psycopg/psycopg2/issues/465
search_path = credentials.search_path
if search_path is not None and search_path != "":
# see https://postgresql.org/docs/9.5/libpq-connect.html
kwargs["options"] = "-c search_path={}".format(search_path.replace(" ", "\\ "))

if credentials.sslmode:
kwargs["sslmode"] = credentials.sslmode

if credentials.sslcert is not None:
kwargs["sslcert"] = credentials.sslcert

if credentials.sslkey is not None:
kwargs["sslkey"] = credentials.sslkey

if credentials.sslrootcert is not None:
kwargs["sslrootcert"] = credentials.sslrootcert

if credentials.application_name:
kwargs["application_name"] = credentials.application_name

def connect():
handle = psycopg2.connect(
dbname=credentials.database,
user=credentials.user,
host=credentials.host,
password=credentials.password,
port=credentials.port,
connect_timeout=credentials.connect_timeout,
**kwargs,
)
if credentials.role:
handle.cursor().execute("set role {}".format(credentials.role))
return handle
connect_method_factory = PostgresConnectMethodFactory(credentials)

retryable_exceptions = [
# OperationalError is subclassed by all psycopg2 Connection Exceptions and it's raised
Expand All @@ -158,7 +263,7 @@ def exponential_backoff(attempt: int):

return cls.retry_connection(
connection,
connect=connect,
connect=connect_method_factory.get_connect_method(),
logger=logger,
retry_limit=credentials.retries,
retry_timeout=exponential_backoff,
Expand Down
2 changes: 2 additions & 0 deletions plugins/postgres/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def _dbt_psycopg2_name():
"{}~=2.8".format(DBT_PSYCOPG2_NAME),
# installed via dbt-core, but referenced directly, don't pin to avoid version conflicts with dbt-core
"agate",
# pinned version, which works with dbt-redshift
"boto3~=1.26.157",
],
zip_safe=False,
classifiers=[
Expand Down
35 changes: 33 additions & 2 deletions tests/unit/test_adapter_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import dbt.exceptions

import boto3

import psycopg2

from dbt.contracts.connection import Connection
Expand Down Expand Up @@ -460,7 +462,6 @@ def setUp(self):
schema="test-schema",
retries=2,
)
self.connection = Connection("postgres", None, self.credentials)

def test_open(self):
"""Test opening a Postgres Connection with failures in the first 3 attempts.
Expand All @@ -472,7 +473,7 @@ def test_open(self):
returns in the 4th attempt.
* The resulting attempt count should be 4.
"""
conn = self.connection
conn = Connection("postgres", None, self.credentials)
attempt = 0

def connect(*args, **kwargs):
Expand All @@ -492,3 +493,33 @@ def connect(*args, **kwargs):
assert attempt == 3
assert conn.state == "open"
assert conn.handle is True

@mock.patch("psycopg2.connect", mock.Mock())
def test_method_iam(self):
    """Test opening a Postgres connection with the ``iam`` method.

    ``boto3.Session`` and ``psycopg2.connect`` are mocked. The test checks
    that the session and RDS client are built from the profile and region,
    and that the generated token is passed to ``psycopg2.connect`` as the
    password.
    """
    self.credentials = self.credentials.replace(
        method="iam", iam_profile="test", region="us-east-1"
    )

    conn = Connection("postgres", None, self.credentials)

    with mock.patch("boto3.Session") as mock_session:
        mock_client = mock.Mock()
        mock_client.generate_db_auth_token.return_value = "secret-token"
        mock_session.return_value.client.return_value = mock_client

        PostgresConnectionManager.open(conn)

    # Assert on the mock handle itself: `boto3.Session` only resolves to the
    # mock while the patch above is active.
    mock_session.assert_called_once_with(profile_name="test", region_name="us-east-1")
    mock_session.return_value.client.assert_called_once_with("rds")
    mock_client.generate_db_auth_token.assert_called_once_with(
        DBHostname="localhost", Port=1111, DBUsername="test-user", Region="us-east-1"
    )
    psycopg2.connect.assert_called_once_with(
        host="localhost",
        dbname="test-db",
        port=1111,
        user="test-user",
        connect_timeout=10,
        application_name="dbt",
        password="secret-token",
    )
11 changes: 11 additions & 0 deletions tests/unit/test_postgres_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,17 @@ def test_debug_connection_fail_nopass(self):
with self.assertRaises(DbtConfigError):
DebugTask.validate_connection(self.target_dict)

def test_debug_connection_iam_ok(self):
# Removes the password, selects the "iam" method and expects
# DebugTask.validate_connection to raise DbtConfigError.
# NOTE(review): the "_ok" suffix suggests a success path, yet the test
# asserts an error — confirm which outcome is intended.
del self.target_dict["pass"]
self.target_dict["method"] = "iam"
with self.assertRaises(DbtConfigError):
DebugTask.validate_connection(self.target_dict)

def test_debug_connection_iam_fail_nopass(self):
# Selects the "iam" method without touching any other key of target_dict and
# expects DebugTask.validate_connection to raise DbtConfigError.
# NOTE(review): the name says "nopass" but "pass" is not removed here;
# presumably the failure comes from a password being present alongside
# method "iam" — confirm and consider renaming.
self.target_dict["method"] = "iam"
with self.assertRaises(DbtConfigError):
DebugTask.validate_connection(self.target_dict)

def test_connection_fail_select(self):
self.mock_execute.side_effect = DatabaseError()
with self.assertRaises(DbtConfigError):
Expand Down