From 5ac88cd0f163e5038c9879aa0916ec8fb6a43bde Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Thu, 19 Dec 2024 12:55:38 +0000 Subject: [PATCH] Reload token from session file on connection. --- aana/configs/db.py | 44 ++++++++++++++++++++++++++++---------------- aana/storage/op.py | 34 +++++++++++++++++++++++++--------- 2 files changed, 53 insertions(+), 25 deletions(-) diff --git a/aana/configs/db.py b/aana/configs/db.py index 4f2df17d..6e8ad901 100644 --- a/aana/configs/db.py +++ b/aana/configs/db.py @@ -1,5 +1,6 @@ import os from os import PathLike +from typing import Literal from pydantic import model_validator from pydantic_settings import BaseSettings @@ -40,27 +41,37 @@ class PostgreSQLConfig(TypedDict): class SnowflakeConfig(TypedDict, total=False): """Config values for Snowflake. + For now two connection methods are supported: user/password and OAuth token from session file. + + For user/password connection method you need to provide: account, user, password. + For OAuth connection method you need to provide: account, host and set authenticator to "oauth". + + For OAuth connection method, we only support getting the token from the session file for now since + this is what we need to deploy the app in Snowflake cloud. + + Other attributes (database, schema, warehouse, role) are optional and can be set for both connection methods. + Attributes: account (str): The account name. - user (str): The user to connect to the Snowflake server. - host (str): The host of the Snowflake server. - token (str): The token to connect to the Snowflake server. - password (str): The password to connect to the Snowflake server. - database (str): The database name. - schema (str): The schema name. - warehouse (str): The warehouse name. - role (str): The role name. + user (str | None): The user to connect to the Snowflake server. + host (str | None): The host of the Snowflake server. + authenticator (str | None): The authenticator to use to connect to the Snowflake server (only "oauth" or None are supported). + password (str | None): The password to connect to the Snowflake server. + database (str | None): The database name. + schema (str | None): The schema name. + warehouse (str | None): The warehouse name. + role (str | None): The role name. """ account: str - user: str - host: str - token: str - password: str - database: str - schema: str - warehouse: str - role: str + user: str | None + host: str | None + authenticator: Literal["oauth"] | None + password: str | None + database: str | None + schema: str | None + warehouse: str | None + role: str | None class DbSettings(BaseSettings): @@ -121,6 +132,7 @@ def update_from_alias_env_vars(self): "SNOWFLAKE_WAREHOUSE": "warehouse", "SNOWFLAKE_ROLE": "role", "SNOWFLAKE_TOKEN": "token", + "SNOWFLAKE_AUTHENTICATOR": "authenticator", } for env_var, key in mapping.items(): if not self.datastore_config.get(key) and os.environ.get(env_var): diff --git a/aana/storage/op.py b/aana/storage/op.py index c134b945..01b0c36e 100644 --- a/aana/storage/op.py +++ b/aana/storage/op.py @@ -93,17 +93,24 @@ def create_snowflake_engine(db_config: "DbSettings"): # noqa: C901 Returns: sqlalchemy.engine.Engine: SQLAlchemy engine instance. """ - datastore_config = db_config.datastore_config - - # If token is not provided, check if token file exists SNOWFLAKE_TOKEN_PATH = Path("/snowflake/session/token") - if SNOWFLAKE_TOKEN_PATH.exists() and "token" not in datastore_config: - token = SNOWFLAKE_TOKEN_PATH.read_text() - datastore_config["token"] = token - # Set authenticator to oauth if token is provided - if "token" in datastore_config: - datastore_config["authenticator"] = "oauth" + def get_token(): + return SNOWFLAKE_TOKEN_PATH.read_text() + + datastore_config = db_config.datastore_config + + # If authenticator is oauth, we get the token from the session file. + # For now we only support oauth authenticator and token from the session file since + # this is what we need to deploy the app in snowflake cloud. + # Here we just check if the token file exists. + if "authenticator" in datastore_config: + if datastore_config["authenticator"] != "oauth": + raise ValueError( # noqa: TRY003 + f"Unsupported authenticator: {datastore_config['authenticator']}" + ) + if not SNOWFLAKE_TOKEN_PATH.exists(): + raise FileNotFoundError(f"Token file not found: {SNOWFLAKE_TOKEN_PATH}") # noqa: TRY003 connection_string = SNOWFLAKE_URL(**datastore_config) engine = create_engine( @@ -113,6 +120,15 @@ def create_snowflake_engine(db_config: "DbSettings"): # noqa: C901 pool_recycle=db_config.pool_recycle, ) + @event.listens_for(engine, "do_connect") + def receive_do_connect(dialect, conn_rec, cargs, cparams): + """Handle connection parameters before connecting to Snowflake.""" + # If we are using oauth authenticator, + # we need to refresh the token before each connection. + # Otherwise, the token will expire and the connection will fail. + if cparams.get("authenticator") == "oauth": + cparams["token"] = get_token() + @event.listens_for(engine, "before_cursor_execute") def preprocess_parameters( conn, cursor, statement, parameters, context, executemany