Skip to content

Commit

Permalink
[ENTERPRISE-1418] Add support for plain JWT authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
llam15 committed Jun 10, 2024
1 parent 4480734 commit 5f6a4db
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
35 changes: 29 additions & 6 deletions dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from dbt.adapters.sql import SQLConnectionManager
from dbt.adapters.events.logging import AdapterLogger
from dbt_common.events.functions import warn_or_error
from dbt.adapters.events.types import AdapterEventWarning
from dbt.adapters.events.types import AdapterEventWarning, AdapterEventError
from dbt_common.ui import line_wrap_message, warning_tag


Expand All @@ -70,7 +70,7 @@ class SnowflakeAdapterResponse(AdapterResponse):
@dataclass
class SnowflakeCredentials(Credentials):
account: str
user: str
user: Optional[str] = None
warehouse: Optional[str] = None
role: Optional[str] = None
password: Optional[str] = None
Expand All @@ -96,15 +96,29 @@ class SnowflakeCredentials(Credentials):
reuse_connections: Optional[bool] = None

def __post_init__(self):
if self.authenticator != "oauth" and (
self.oauth_client_secret or self.oauth_client_id or self.token
):
if self.authenticator != "oauth" and (self.oauth_client_secret or self.oauth_client_id):
# the user probably forgot to set 'authenticator' like I keep doing
warn_or_error(
AdapterEventWarning(
base_msg="Authenticator is not set to oauth, but an oauth-only parameter is set! Did you mean to set authenticator: oauth?"
)
)

if self.authenticator not in ["oauth", "jwt"]:
if self.token:
warn_or_error(
AdapterEventWarning(
base_msg=(
"The token parameter was set, but the authenticator was "
"not set to 'oauth' or 'jwt'."
)
)
)

if not self.user:
# The user attribute is only optional if 'authenticator' is 'jwt' or 'oauth'
warn_or_error(AdapterEventError(base_msg="'user' is a required property."))

self.account = self.account.replace("_", "-")

@property
Expand Down Expand Up @@ -146,6 +160,8 @@ def auth_args(self):
# Pull all of the optional authentication args for the connector,
# let connector handle the actual arg validation
result = {}
if self.user:
result["user"] = self.user
if self.password:
result["password"] = self.password
if self.host:
Expand Down Expand Up @@ -180,6 +196,14 @@ def auth_args(self):
)

result["token"] = token

elif self.authenticator == "jwt":
# If authenticator is 'jwt', then the 'token' value should be used
# unmodified. We expose this as 'jwt' in the profile, but the value
# passed into the snowflake.connect method should still be 'oauth'
result["token"] = self.token
result["authenticator"] = "oauth"

# enable id token cache for linux
result["client_store_temporary_credential"] = True
# enable mfa token cache for linux
Expand Down Expand Up @@ -346,7 +370,6 @@ def connect():

handle = snowflake.connector.connect(
account=creds.account,
user=creds.user,
database=creds.database,
schema=creds.schema,
warehouse=creds.warehouse,
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/test_snowflake_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,38 @@ def test_authenticator_private_key_authentication_no_passphrase(self, mock_get_p
]
)

def test_authenticator_jwt_authentication(self):
self.config.credentials = self.config.credentials.replace(
authenticator="jwt", token="my-jwt-token", user=None
)
self.adapter = SnowflakeAdapter(self.config, get_context("spawn"))
conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config")

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls(
[
mock.call(
account="test-account",
autocommit=True,
client_session_keep_alive=False,
database="test_database",
role=None,
schema="public",
warehouse="test_warehouse",
authenticator="oauth",
token="my-jwt-token",
private_key=None,
application="dbt",
client_request_mfa_token=True,
client_store_temporary_credential=True,
insecure_mode=False,
session_parameters={},
reuse_connections=None,
)
]
)

def test_query_tag(self):
self.config.credentials = self.config.credentials.replace(
password="test_password", query_tag="test_query_tag"
Expand Down

0 comments on commit 5f6a4db

Please sign in to comment.