Skip to content

Commit

Permalink
Implement Record/Replay (#90)
Browse files Browse the repository at this point in the history
* Implement record/replay mechanism.

* Refinements/fixes

* Fix unit test failures

* Formatting and typing

* Misc fixes

* Add changelog entry.

* Add unit tests.

---------

Co-authored-by: Michelle Ark <[email protected]>
  • Loading branch information
peterallenwebb and MichelleArk authored Feb 28, 2024
1 parent f873d6b commit ae8ffe0
Show file tree
Hide file tree
Showing 5 changed files with 414 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240227-145400.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Implement record/replay mechanism
time: 2024-02-27T14:54:00.94815-05:00
custom:
Author: peterallenwebb
Issue: "9689"
118 changes: 118 additions & 0 deletions dbt_common/clients/system.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dbt_common.exceptions.base
import dataclasses
import errno
import fnmatch
import functools
Expand All @@ -25,7 +26,9 @@
SystemReportReturnCode,
)
from dbt_common.exceptions import DbtInternalError
from dbt_common.record import record_function, Recorder, Record
from dbt_common.utils.connection import connection_exception_retry

from pathspec import PathSpec # type: ignore

if sys.platform == "win32":
Expand All @@ -35,6 +38,49 @@
c_bool = None


@dataclasses.dataclass
class FindMatchingParams:
root_path: str
relative_paths_to_search: List[str]
file_pattern: str
# ignore_spec: Optional[PathSpec] = None

def __init__(
self,
root_path: str,
relative_paths_to_search: List[str],
file_pattern: str,
ignore_spec: Optional[Any] = None,
):
self.root_path = root_path
rps = list(relative_paths_to_search)
rps.sort()
self.relative_paths_to_search = rps
self.file_pattern = file_pattern

def _include(self) -> bool:
# Do not record or replay filesystem searches that were performed against
# files which are actually part of dbt's implementation.
return (
"dbt/include/global_project" not in self.root_path
and "/plugins/postgres/dbt/include/" not in self.root_path
)


@dataclasses.dataclass
class FindMatchingResult:
matches: List[Dict[str, Any]]


@Recorder.register_record_type
class FindMatchingRecord(Record):
"""Record of calls to the directory search function find_matching()"""

params_cls = FindMatchingParams
result_cls = FindMatchingResult


@record_function(FindMatchingRecord)
def find_matching(
root_path: str,
relative_paths_to_search: List[str],
Expand Down Expand Up @@ -94,6 +140,34 @@ def find_matching(
return matching


@dataclasses.dataclass
class LoadFileParams:
path: str
strip: bool = True

def _include(self) -> bool:
# Do not record or replay file reads that were performed against files
# which are actually part of dbt's implementation.
return (
"dbt/include/global_project" not in self.path
and "/plugins/postgres/dbt/include/" not in self.path
)


@dataclasses.dataclass
class LoadFileResult:
contents: str


@Recorder.register_record_type
class LoadFileRecord(Record):
"""Record of file load operation"""

params_cls = LoadFileParams
result_cls = LoadFileResult


@record_function(LoadFileRecord)
def load_file_contents(path: str, strip: bool = True) -> str:
path = convert_path(path)
with open(path, "rb") as handle:
Expand Down Expand Up @@ -164,6 +238,29 @@ def supports_symlinks() -> bool:
return getattr(os, "symlink", None) is not None


@dataclasses.dataclass
class WriteFileParams:
path: str
contents: str

def _include(self) -> bool:
# Do not record or replay file reads that were performed against files
# which are actually part of dbt's implementation.
return (
"dbt/include/global_project" not in self.path
and "/plugins/postgres/dbt/include/" not in self.path
)


@Recorder.register_record_type
class WriteFileRecord(Record):
"""Record of a file write operation."""

params_cls = WriteFileParams
result_cls = None


@record_function(WriteFileRecord)
def write_file(path: str, contents: str = "") -> bool:
path = convert_path(path)
try:
Expand Down Expand Up @@ -573,3 +670,24 @@ def rmtree(path):
"""
path = convert_path(path)
return shutil.rmtree(path, onerror=chmod_and_retry)


@dataclasses.dataclass
class GetEnvParams:
pass


@dataclasses.dataclass
class GetEnvResult:
env: Dict[str, str]


@Recorder.register_record_type
class GetEnvRecord(Record):
params_cls = GetEnvParams
result_cls = GetEnvResult


@record_function(GetEnvRecord)
def get_env() -> Dict[str, str]:
return dict(os.environ)
1 change: 1 addition & 0 deletions dbt_common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class InvocationContext:
def __init__(self, env: Mapping[str, str]):
self._env = env
self._env_secrets: Optional[List[str]] = None
self.recorder = None
# This class will also eventually manage the invocation_id, flags, event manager, etc.

@property
Expand Down
208 changes: 208 additions & 0 deletions dbt_common/record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
"""The record module provides a mechanism for recording dbt's interaction with
external systems during a command invocation, so that the command can be re-run
later with the recording 'replayed' to dbt.
If dbt behaves sufficiently deterministically, we will be able to use the
record/replay mechanism in several interesting test and debugging scenarios.
"""
import functools
import dataclasses
import json
import os
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional, Type

from dbt_common.context import get_invocation_context


class Record:
"""An instance of this abstract Record class represents a request made by dbt
to an external process or the operating system. The 'params' are the arguments
to the request, and the 'result' is what is returned."""

params_cls: type
result_cls: Optional[type]

def __init__(self, params, result) -> None:
self.params = params
self.result = result

def to_dict(self) -> Dict[str, Any]:
return {
"params": self.params._to_dict() if hasattr(self.params, "_to_dict") else dataclasses.asdict(self.params), # type: ignore
"result": self.result._to_dict() if hasattr(self.result, "_to_dict") else dataclasses.asdict(self.result) if self.result is not None else None, # type: ignore
}

@classmethod
def from_dict(cls, dct: Mapping) -> "Record":
p = (
cls.params_cls._from_dict(dct["params"])
if hasattr(cls.params_cls, "_from_dict")
else cls.params_cls(**dct["params"])
)
r = (
cls.result_cls._from_dict(dct["result"]) # type: ignore
if hasattr(cls.result_cls, "_from_dict")
else cls.result_cls(**dct["result"])
if cls.result_cls is not None
else None
)
return cls(params=p, result=r)


class Diff:
"""Marker class for diffs?"""

pass


class RecorderMode(Enum):
RECORD = 1
REPLAY = 2


class Recorder:
_record_cls_by_name: Dict[str, Type] = {}
_record_name_by_params_name: Dict[str, str] = {}

def __init__(self, mode: RecorderMode, recording_path: Optional[str] = None) -> None:
self.mode = mode
self._records_by_type: Dict[str, List[Record]] = {}
self._replay_diffs: List["Diff"] = []

if recording_path is not None:
self._records_by_type = self.load(recording_path)

@classmethod
def register_record_type(cls, rec_type) -> Any:
cls._record_cls_by_name[rec_type.__name__] = rec_type
cls._record_name_by_params_name[rec_type.params_cls.__name__] = rec_type.__name__
return rec_type

def add_record(self, record: Record) -> None:
rec_cls_name = record.__class__.__name__ # type: ignore
if rec_cls_name not in self._records_by_type:
self._records_by_type[rec_cls_name] = []
self._records_by_type[rec_cls_name].append(record)

def pop_matching_record(self, params: Any) -> Optional[Record]:
rec_type_name = self._record_name_by_params_name[type(params).__name__]
records = self._records_by_type[rec_type_name]
match: Optional[Record] = None
for rec in records:
if rec.params == params:
match = rec
records.remove(match)
break

return match

def write(self, file_name) -> None:
with open(file_name, "w") as file:
json.dump(self._to_dict(), file)

def _to_dict(self) -> Dict:
dct: Dict[str, Any] = {}

for record_type in self._records_by_type:
record_list = [r.to_dict() for r in self._records_by_type[record_type]]
dct[record_type] = record_list

return dct

@classmethod
def load(cls, file_name: str) -> Dict[str, List[Record]]:
with open(file_name) as file:
loaded_dct = json.load(file)

records_by_type: Dict[str, List[Record]] = {}

for record_type_name in loaded_dct:
record_cls = cls._record_cls_by_name[record_type_name]
rec_list = []
for record_dct in loaded_dct[record_type_name]:
rec = record_cls.from_dict(record_dct)
rec_list.append(rec) # type: ignore
records_by_type[record_type_name] = rec_list

return records_by_type

def expect_record(self, params: Any) -> Any:
record = self.pop_matching_record(params)

if record is None:
raise Exception()

result_tuple = dataclasses.astuple(record.result)
return result_tuple[0] if len(result_tuple) == 1 else result_tuple

def write_diffs(self, diff_file_name) -> None:
json.dump(
self._replay_diffs,
open(diff_file_name, "w"),
)

def print_diffs(self) -> None:
print(repr(self._replay_diffs))


def get_record_mode_from_env() -> Optional[RecorderMode]:
replay_val = os.environ.get("DBT_REPLAY")
if replay_val is not None and replay_val != "0" and replay_val.lower() != "false":
return RecorderMode.REPLAY

record_val = os.environ.get("DBT_RECORD")
if record_val is not None and record_val != "0" and record_val.lower() != "false":
return RecorderMode.RECORD

return None


def record_function(record_type, method=False, tuple_result=False):
def record_function_inner(func_to_record):
# To avoid runtime overhead and other unpleasantness, we only apply the
# record/replay decorator if a relevant env var is set.
if get_record_mode_from_env() is None:
return func_to_record

@functools.wraps(func_to_record)
def record_replay_wrapper(*args, **kwargs):
recorder: Recorder = None
try:
recorder = get_invocation_context().recorder
except LookupError:
pass

if recorder is None:
return func_to_record(*args, **kwargs)

# For methods, peel off the 'self' argument before calling the
# params constructor.
param_args = args[1:] if method else args

params = record_type.params_cls(*param_args, **kwargs)

include = True
if hasattr(params, "_include"):
include = params._include()

if not include:
return func_to_record(*args, **kwargs)

if recorder.mode == RecorderMode.REPLAY:
return recorder.expect_record(params)

r = func_to_record(*args, **kwargs)
result = (
None
if r is None or record_type.result_cls is None
else record_type.result_cls(*r)
if tuple_result
else record_type.result_cls(r)
)
recorder.add_record(record_type(params=params, result=result))
return r

return record_replay_wrapper

return record_function_inner
Loading

0 comments on commit ae8ffe0

Please sign in to comment.