Skip to content

Commit

Permalink
Load oauth token before connection. Fix for regex to replace VALUES. …
Browse files Browse the repository at this point in the history
…Add debug prints for oauth issues.
  • Loading branch information
Aleksandr Movchan committed Dec 20, 2024
1 parent 5ac88cd commit 922c6a2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 2 additions & 0 deletions aana/configs/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class SnowflakeConfig(TypedDict, total=False):
user (str | None): The user to connect to the Snowflake server.
host (str | None): The host of the Snowflake server.
authenticator (str | None): The authenticator to use to connect to the Snowflake server (only "oauth" or None are supported).
token (str | None): The OAuth token to connect to the Snowflake (don't set it manually, it's read from the session file).
password (str | None): The password to connect to the Snowflake server.
database (str | None): The database name.
schema (str | None): The schema name.
Expand All @@ -67,6 +68,7 @@ class SnowflakeConfig(TypedDict, total=False):
user: str | None
host: str | None
authenticator: Literal["oauth"] | None
token: str | None
password: str | None
database: str | None
schema: str | None
Expand Down
7 changes: 6 additions & 1 deletion aana/storage/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,11 @@ def get_token():
)
if not SNOWFLAKE_TOKEN_PATH.exists():
raise FileNotFoundError(f"Token file not found: {SNOWFLAKE_TOKEN_PATH}") # noqa: TRY003
datastore_config["token"] = get_token()

print("datastore_config", datastore_config)
connection_string = SNOWFLAKE_URL(**datastore_config)
print("connection_string", connection_string)
engine = create_engine(
connection_string,
pool_size=db_config.pool_size,
Expand All @@ -126,8 +129,10 @@ def receive_do_connect(dialect, conn_rec, cargs, cparams):
# If we are using oauth authenticator,
# we need to refresh the token before each connection.
# Otherwise, the token will expire and the connection will fail.
print("cparams before", cparams)
if cparams.get("authenticator") == "oauth":
cparams["token"] = get_token()
print("cparams after", cparams)

@event.listens_for(engine, "before_cursor_execute")
def preprocess_parameters(
Expand Down Expand Up @@ -157,7 +162,7 @@ def compile_insert(insert_stmt, compiler, **kwargs):
# Locate the VALUES clause and replace it
def replace_values_with_select(sql):
# Regex to find `VALUES (...)` ensuring balanced parentheses
pattern = r"VALUES\s*(\((?:[^)(]+|\((?:[^)(]+|\([^)(]*\))*\))*\))"
pattern = r"VALUES\s*\(((?:[^()]+|\([^()]*\))*)\)"

Check failure

Code scanning / CodeQL

Inefficient regular expression High

This part of the regular expression may cause exponential backtracking on strings starting with 'VALUES(' and containing many repetitions of '''.
match = re.search(pattern, sql)
if match:
values_clause = match.group(1) # Captures the `(...)` after VALUES
Expand Down

0 comments on commit 922c6a2

Please sign in to comment.