Skip to content

Commit

Permalink
Support "just in time" loading of records, and add ID fields (#154)
Browse files Browse the repository at this point in the history
* Deserialize records "just in time" in order to avoid import order issues.

* Add changelog entry

* Typing and formatting fixes

* Typing

* Add id field to record.

* Add more informative error message.

* Tweak the way results are stored.

* Typing and formatting fixes.

---------

Co-authored-by: Emily Rockman <[email protected]>
  • Loading branch information
peterallenwebb and emmyoop authored Jun 25, 2024
1 parent 00fbf91 commit a2ccd41
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 27 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240618-155025.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Deserialize Record objects on a just-in-time basis.
time: 2024-06-18T15:50:25.985387-04:00
custom:
Author: peterallenwebb
Issue: "151"
4 changes: 3 additions & 1 deletion dbt_common/clients/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def _include(self) -> bool:
# Do not record or replay filesystem searches that were performed against
# files which are actually part of dbt's implementation.
return (
"dbt/include/global_project" not in self.root_path
"dbt/include"
not in self.root_path # TODO: This actually obviates the next two checks but is probably too coarse?
and "dbt/include/global_project" not in self.root_path
and "/plugins/postgres/dbt/include/" not in self.root_path
)

Expand Down
71 changes: 45 additions & 26 deletions dbt_common/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
external systems during a command invocation, so that the command can be re-run
later with the recording 'replayed' to dbt.
The rationale for and architecture of this module is described in detail in the
The rationale for and architecture of this module are described in detail in the
docs/guides/record_replay.md document in this repository.
"""
import functools
Expand All @@ -12,7 +12,7 @@

from deepdiff import DeepDiff # type: ignore
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional, Type
from typing import Any, Callable, Dict, List, Mapping, Optional, Type

from dbt_common.context import get_invocation_context

Expand Down Expand Up @@ -129,10 +129,11 @@ def __init__(
previous_recording_path: Optional[str] = None,
) -> None:
self.mode = mode
self.types = types
self.recorded_types = types
self._records_by_type: Dict[str, List[Record]] = {}
self._unprocessed_records_by_type: Dict[str, List[Dict[str, Any]]] = {}
self._replay_diffs: List["Diff"] = []
self.diff: Diff
self.diff: Optional[Diff] = None
self.previous_recording_path = previous_recording_path
self.current_recording_path = current_recording_path

Expand All @@ -146,7 +147,7 @@ def __init__(
)

if self.mode == RecorderMode.REPLAY:
self._records_by_type = self.load(self.previous_recording_path)
self._unprocessed_records_by_type = self.load(self.previous_recording_path)

@classmethod
def register_record_type(cls, rec_type) -> Any:
Expand All @@ -161,7 +162,14 @@ def add_record(self, record: Record) -> None:
self._records_by_type[rec_cls_name].append(record)

def pop_matching_record(self, params: Any) -> Optional[Record]:
rec_type_name = self._record_name_by_params_name[type(params).__name__]
rec_type_name = self._record_name_by_params_name.get(type(params).__name__)

if rec_type_name is None:
raise Exception(
f"A record of type {type(params).__name__} was requested, but no such type has been registered."
)

self._ensure_records_processed(rec_type_name)
records = self._records_by_type[rec_type_name]
match: Optional[Record] = None
for rec in records:
Expand All @@ -186,39 +194,40 @@ def _to_dict(self) -> Dict:
return dct

@classmethod
def load(cls, file_name: str) -> Dict[str, List[Record]]:
def load(cls, file_name: str) -> Dict[str, List[Dict[str, Any]]]:
with open(file_name) as file:
loaded_dct = json.load(file)
return json.load(file)

records_by_type: Dict[str, List[Record]] = {}
def _ensure_records_processed(self, record_type_name: str) -> None:
if record_type_name in self._records_by_type:
return

for record_type_name in loaded_dct:
# TODO: this breaks with QueryRecord on replay since it's
# not in common so isn't part of cls._record_cls_by_name yet
record_cls = cls._record_cls_by_name[record_type_name]
rec_list = []
for record_dct in loaded_dct[record_type_name]:
rec = record_cls.from_dict(record_dct)
rec_list.append(rec) # type: ignore
records_by_type[record_type_name] = rec_list
return records_by_type
rec_list = []
record_cls = self._record_cls_by_name[record_type_name]
for record_dct in self._unprocessed_records_by_type[record_type_name]:
rec = record_cls.from_dict(record_dct)
rec_list.append(rec) # type: ignore
self._records_by_type[record_type_name] = rec_list

def expect_record(self, params: Any) -> Any:
record = self.pop_matching_record(params)

if record is None:
raise Exception()

if record.result is None:
return None

result_tuple = dataclasses.astuple(record.result)
return result_tuple[0] if len(result_tuple) == 1 else result_tuple

def write_diffs(self, diff_file_name) -> None:
json.dump(
self.diff.calculate_diff(),
open(diff_file_name, "w"),
)
assert self.diff is not None
with open(diff_file_name, "w") as f:
json.dump(self.diff.calculate_diff(), f)

def print_diffs(self) -> None:
assert self.diff is not None
print(repr(self.diff.calculate_diff()))


Expand Down Expand Up @@ -273,7 +282,12 @@ def get_record_types_from_dict(fp: str) -> List:
return list(loaded_dct.keys())


def record_function(record_type, method=False, tuple_result=False):
def record_function(
record_type,
method: bool = False,
tuple_result: bool = False,
id_field_name: Optional[str] = None,
) -> Callable:
def record_function_inner(func_to_record):
# To avoid runtime overhead and other unpleasantness, we only apply the
# record/replay decorator if a relevant env var is set.
Expand All @@ -291,12 +305,17 @@ def record_replay_wrapper(*args, **kwargs):
if recorder is None:
return func_to_record(*args, **kwargs)

if recorder.types is not None and record_type.__name__ not in recorder.types:
if (
recorder.recorded_types is not None
and record_type.__name__ not in recorder.recorded_types
):
return func_to_record(*args, **kwargs)

# For methods, peel off the 'self' argument before calling the
# params constructor.
param_args = args[1:] if method else args
if method and id_field_name is not None:
param_args = (getattr(args[0], id_field_name),) + param_args

params = record_type.params_cls(*param_args, **kwargs)

Expand All @@ -313,7 +332,7 @@ def record_replay_wrapper(*args, **kwargs):
r = func_to_record(*args, **kwargs)
result = (
None
if r is None or record_type.result_cls is None
if record_type.result_cls is None
else record_type.result_cls(*r)
if tuple_result
else record_type.result_cls(r)
Expand Down

0 comments on commit a2ccd41

Please sign in to comment.