Skip to content

Commit

Permalink
Add record/replay support (#1106)
Browse files Browse the repository at this point in the history
* Add record/replay support.

* Add group to record types.

* Re-organize record/replay code to match dbt-adapters

* Add changelog entry.

* Update .changes/unreleased/Under the Hood-20240716-174655.yaml

Co-authored-by: Colin Rogers <[email protected]>

---------

Co-authored-by: Colin Rogers <[email protected]>
  • Loading branch information
peterallenwebb and colin-rogers-dbt authored Jul 16, 2024
1 parent 6857e6b commit d51584d
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 13 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240716-174655.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Add support for experimental record/replay testing.
time: 2024-07-16T17:46:55.11204-04:00
custom:
Author: peterallenwebb
Issue: "1106"
40 changes: 27 additions & 13 deletions dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,15 @@
DbtConfigError,
)
from dbt_common.exceptions import DbtDatabaseError
from dbt_common.record import get_record_mode_from_env, RecorderMode
from dbt.adapters.exceptions.connection import FailedToConnectError
from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials
from dbt.adapters.sql import SQLConnectionManager
from dbt.adapters.events.logging import AdapterLogger
from dbt_common.events.functions import warn_or_error
from dbt.adapters.events.types import AdapterEventWarning, AdapterEventError
from dbt_common.ui import line_wrap_message, warning_tag
from dbt.adapters.snowflake.record import SnowflakeRecordReplayHandle

from dbt.adapters.snowflake.auth import private_key_from_file, private_key_from_string

Expand Down Expand Up @@ -372,20 +374,32 @@ def connect():

if creds.query_tag:
session_parameters.update({"QUERY_TAG": creds.query_tag})
handle = None

# In replay mode, we won't connect to a real database at all, while
# in record and diff modes we do, but insert an intermediate handle
# object which monitors native connection activity.
rec_mode = get_record_mode_from_env()
handle = None
if rec_mode != RecorderMode.REPLAY:
handle = snowflake.connector.connect(
account=creds.account,
database=creds.database,
schema=creds.schema,
warehouse=creds.warehouse,
role=creds.role,
autocommit=True,
client_session_keep_alive=creds.client_session_keep_alive,
application="dbt",
insecure_mode=creds.insecure_mode,
session_parameters=session_parameters,
**creds.auth_args(),
)

handle = snowflake.connector.connect(
account=creds.account,
database=creds.database,
schema=creds.schema,
warehouse=creds.warehouse,
role=creds.role,
autocommit=True,
client_session_keep_alive=creds.client_session_keep_alive,
application="dbt",
insecure_mode=creds.insecure_mode,
session_parameters=session_parameters,
**creds.auth_args(),
)
if rec_mode is not None:
# If using the record/replay mechanism, regardless of mode, we
# use a wrapper.
handle = SnowflakeRecordReplayHandle(handle, connection)

return handle

Expand Down
2 changes: 2 additions & 0 deletions dbt/adapters/snowflake/record/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from dbt.adapters.snowflake.record.cursor.cursor import SnowflakeRecordReplayCursor
from dbt.adapters.snowflake.record.handle import SnowflakeRecordReplayHandle
21 changes: 21 additions & 0 deletions dbt/adapters/snowflake/record/cursor/cursor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from dbt_common.record import record_function

from dbt.adapters.record import RecordReplayCursor
from dbt.adapters.snowflake.record.cursor.sfqid import CursorGetSfqidRecord
from dbt.adapters.snowflake.record.cursor.sqlstate import CursorGetSqlStateRecord


class SnowflakeRecordReplayCursor(RecordReplayCursor):
"""A custom extension of RecordReplayCursor that adds the sqlstate
and sfqid properties which are specific to snowflake-connector."""

@property
@property
@record_function(CursorGetSqlStateRecord, method=True, id_field_name="connection_name")
def sqlstate(self):
return self.native_cursor.sqlstate

@property
@record_function(CursorGetSfqidRecord, method=True, id_field_name="connection_name")
def sfqid(self):
return self.native_cursor.sfqid
21 changes: 21 additions & 0 deletions dbt/adapters/snowflake/record/cursor/sfqid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import dataclasses
from typing import Optional

from dbt_common.record import Record, Recorder


@dataclasses.dataclass
class CursorGetSfqidParams:
connection_name: str


@dataclasses.dataclass
class CursorGetSfqidResult:
msg: Optional[str]


@Recorder.register_record_type
class CursorGetSfqidRecord(Record):
params_cls = CursorGetSfqidParams
result_cls = CursorGetSfqidResult
group = "Database"
21 changes: 21 additions & 0 deletions dbt/adapters/snowflake/record/cursor/sqlstate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import dataclasses
from typing import Optional

from dbt_common.record import Record, Recorder


@dataclasses.dataclass
class CursorGetSqlStateParams:
connection_name: str


@dataclasses.dataclass
class CursorGetSqlStateResult:
msg: Optional[str]


@Recorder.register_record_type
class CursorGetSqlStateRecord(Record):
params_cls = CursorGetSqlStateParams
result_cls = CursorGetSqlStateResult
group = "Database"
12 changes: 12 additions & 0 deletions dbt/adapters/snowflake/record/handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dbt.adapters.record import RecordReplayHandle

from dbt.adapters.snowflake.record.cursor.cursor import SnowflakeRecordReplayCursor


class SnowflakeRecordReplayHandle(RecordReplayHandle):
"""A custom extension of RecordReplayHandle that returns a
snowflake-connector-specific SnowflakeRecordReplayCursor object."""

def cursor(self):
cursor = None if self.native_handle is None else self.native_handle.cursor()
return SnowflakeRecordReplayCursor(cursor, self.connection)

0 comments on commit d51584d

Please sign in to comment.