Skip to content

Commit

Permalink
Reload token from session file on connection.
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleksandr Movchan committed Dec 19, 2024
1 parent f41e8fa commit 5ac88cd
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 25 deletions.
44 changes: 28 additions & 16 deletions aana/configs/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from os import PathLike
from typing import Literal

from pydantic import model_validator
from pydantic_settings import BaseSettings
Expand Down Expand Up @@ -40,27 +41,37 @@ class PostgreSQLConfig(TypedDict):
class SnowflakeConfig(TypedDict, total=False):
"""Config values for Snowflake.
For now two connection methods are supported: user/password and OAuth token from session file.
For user/password connection method you need to provide: account, user, password.
For OAuth connection method you need to provide: account, host and set authenticator to "oauth".
For OAuth connection method, we only support getting the token from the session file for now since
this is what we need to deploy the app in Snowflake cloud.
Other attributes (database, schema, warehouse, role) are optional and can be set for both connection methods.
Attributes:
account (str): The account name.
user (str): The user to connect to the Snowflake server.
host (str): The host of the Snowflake server.
token (str): The token to connect to the Snowflake server.
password (str): The password to connect to the Snowflake server.
database (str): The database name.
schema (str): The schema name.
warehouse (str): The warehouse name.
role (str): The role name.
user (str | None): The user to connect to the Snowflake server.
host (str | None): The host of the Snowflake server.
authenticator (str | None): The authenticator to use to connect to the Snowflake server (only "oauth" or None are supported).
password (str | None): The password to connect to the Snowflake server.
database (str | None): The database name.
schema (str | None): The schema name.
warehouse (str | None): The warehouse name.
role (str | None): The role name.
"""

account: str
user: str
host: str
token: str
password: str
database: str
schema: str
warehouse: str
role: str
user: str | None
host: str | None
authenticator: Literal["oauth"] | None
password: str | None
database: str | None
schema: str | None
warehouse: str | None
role: str | None


class DbSettings(BaseSettings):
Expand Down Expand Up @@ -121,6 +132,7 @@ def update_from_alias_env_vars(self):
"SNOWFLAKE_WAREHOUSE": "warehouse",
"SNOWFLAKE_ROLE": "role",
"SNOWFLAKE_TOKEN": "token",
"SNOWFLAKE_AUTHENTICATOR": "authenticator",
}
for env_var, key in mapping.items():
if not self.datastore_config.get(key) and os.environ.get(env_var):
Expand Down
34 changes: 25 additions & 9 deletions aana/storage/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,24 @@ def create_snowflake_engine(db_config: "DbSettings"): # noqa: C901
Returns:
sqlalchemy.engine.Engine: SQLAlchemy engine instance.
"""
datastore_config = db_config.datastore_config

# If token is not provided, check if token file exists
SNOWFLAKE_TOKEN_PATH = Path("/snowflake/session/token")
if SNOWFLAKE_TOKEN_PATH.exists() and "token" not in datastore_config:
token = SNOWFLAKE_TOKEN_PATH.read_text()
datastore_config["token"] = token

# Set authenticator to oauth if token is provided
if "token" in datastore_config:
datastore_config["authenticator"] = "oauth"
def get_token():
return SNOWFLAKE_TOKEN_PATH.read_text()

datastore_config = db_config.datastore_config

# If authenticator is oauth, we get the token from the session file.
# For now we only support oauth authenticator and token from the session file since
# this is what we need to deploy the app in snowflake cloud.
# Here we just check if the token file exists.
if "authenticator" in datastore_config:
if datastore_config["authenticator"] != "oauth":
raise ValueError( # noqa: TRY003
f"Unsupported authenticator: {datastore_config['authenticator']}"
)
if not SNOWFLAKE_TOKEN_PATH.exists():
raise FileNotFoundError(f"Token file not found: {SNOWFLAKE_TOKEN_PATH}") # noqa: TRY003

connection_string = SNOWFLAKE_URL(**datastore_config)
engine = create_engine(
Expand All @@ -113,6 +120,15 @@ def create_snowflake_engine(db_config: "DbSettings"): # noqa: C901
pool_recycle=db_config.pool_recycle,
)

@event.listens_for(engine, "do_connect")
def receive_do_connect(dialect, conn_rec, cargs, cparams):
"""Handle connection parameters before connecting to Snowflake."""
# If we are using oauth authenticator,
# we need to refresh the token before each connection.
# Otherwise, the token will expire and the connection will fail.
if cparams.get("authenticator") == "oauth":
cparams["token"] = get_token()

@event.listens_for(engine, "before_cursor_execute")
def preprocess_parameters(
conn, cursor, statement, parameters, context, executemany
Expand Down

0 comments on commit 5ac88cd

Please sign in to comment.