Skip to content

Commit

Permalink
Deserialize records "just in time" in order to avoid import order iss…
Browse files Browse the repository at this point in the history
…sue (#151)

* Deserialize records "just in time" in order to avoid import order issues.

* Add changelog entry

* Typing and formatting fixes

* Typing

---------

Co-authored-by: Emily Rockman <[email protected]>
  • Loading branch information
peterallenwebb and emmyoop authored Jun 20, 2024
1 parent 00fbf91 commit 64374bc
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 23 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240618-155025.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Deserialize Record objects on a just-in-time basis.
time: 2024-06-18T15:50:25.985387-04:00
custom:
Author: peterallenwebb
Issue: "151"
4 changes: 3 additions & 1 deletion dbt_common/clients/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def _include(self) -> bool:
# Do not record or replay filesystem searches that were performed against
# files which are actually part of dbt's implementation.
return (
"dbt/include/global_project" not in self.root_path
"dbt/include"
not in self.root_path # TODO: This actually obviates the next two checks but is probably too coarse?
and "dbt/include/global_project" not in self.root_path
and "/plugins/postgres/dbt/include/" not in self.root_path
)

Expand Down
50 changes: 28 additions & 22 deletions dbt_common/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
external systems during a command invocation, so that the command can be re-run
later with the recording 'replayed' to dbt.
The rationale for and architecture of this module is described in detail in the
The rationale for and architecture of this module are described in detail in the
docs/guides/record_replay.md document in this repository.
"""
import functools
Expand Down Expand Up @@ -129,10 +129,11 @@ def __init__(
previous_recording_path: Optional[str] = None,
) -> None:
self.mode = mode
self.types = types
self.recorded_types = types
self._records_by_type: Dict[str, List[Record]] = {}
self._unprocessed_records_by_type: Dict[str, List[Dict[str, Any]]] = {}
self._replay_diffs: List["Diff"] = []
self.diff: Diff
self.diff: Optional[Diff] = None
self.previous_recording_path = previous_recording_path
self.current_recording_path = current_recording_path

Expand All @@ -146,7 +147,7 @@ def __init__(
)

if self.mode == RecorderMode.REPLAY:
self._records_by_type = self.load(self.previous_recording_path)
self._unprocessed_records_by_type = self.load(self.previous_recording_path)

@classmethod
def register_record_type(cls, rec_type) -> Any:
Expand All @@ -162,6 +163,7 @@ def add_record(self, record: Record) -> None:

def pop_matching_record(self, params: Any) -> Optional[Record]:
rec_type_name = self._record_name_by_params_name[type(params).__name__]
self._ensure_records_processed(rec_type_name)
records = self._records_by_type[rec_type_name]
match: Optional[Record] = None
for rec in records:
Expand All @@ -186,39 +188,40 @@ def _to_dict(self) -> Dict:
return dct

@classmethod
def load(cls, file_name: str) -> Dict[str, List[Record]]:
def load(cls, file_name: str) -> Dict[str, List[Dict[str, Any]]]:
with open(file_name) as file:
loaded_dct = json.load(file)
return json.load(file)

records_by_type: Dict[str, List[Record]] = {}
def _ensure_records_processed(self, record_type_name: str) -> None:
if record_type_name in self._records_by_type:
return

for record_type_name in loaded_dct:
# TODO: this breaks with QueryRecord on replay since it's
# not in common so isn't part of cls._record_cls_by_name yet
record_cls = cls._record_cls_by_name[record_type_name]
rec_list = []
for record_dct in loaded_dct[record_type_name]:
rec = record_cls.from_dict(record_dct)
rec_list.append(rec) # type: ignore
records_by_type[record_type_name] = rec_list
return records_by_type
rec_list = []
record_cls = self._record_cls_by_name[record_type_name]
for record_dct in self._unprocessed_records_by_type[record_type_name]:
rec = record_cls.from_dict(record_dct)
rec_list.append(rec) # type: ignore
self._records_by_type[record_type_name] = rec_list

def expect_record(self, params: Any) -> Any:
    """Return the previously recorded result that corresponds to ``params``.

    The matching record is obtained from ``self.pop_matching_record`` and its
    ``result`` dataclass is unpacked into plain values.

    Args:
        params: The parameters object describing the call being replayed.

    Returns:
        ``None`` if the matching record has no result; the sole field value
        if the result dataclass has exactly one field; otherwise a tuple of
        all field values.

    Raises:
        Exception: If no record matching ``params`` is available.
    """
    record = self.pop_matching_record(params)

    if record is None:
        # Same exception type as before, but now it says what was missing so a
        # failed replay can actually be diagnosed.
        raise Exception(
            f"No matching record found for params of type "
            f"{type(params).__name__}: {params!r}"
        )

    if record.result is None:
        return None

    # A single-field result is returned bare rather than as a 1-tuple.
    result_tuple = dataclasses.astuple(record.result)
    return result_tuple[0] if len(result_tuple) == 1 else result_tuple

def write_diffs(self, diff_file_name) -> None:
json.dump(
self.diff.calculate_diff(),
open(diff_file_name, "w"),
)
assert self.diff is not None
with open(diff_file_name, "w") as f:
json.dump(self.diff.calculate_diff(), f)

def print_diffs(self) -> None:
assert self.diff is not None
print(repr(self.diff.calculate_diff()))


Expand Down Expand Up @@ -291,7 +294,10 @@ def record_replay_wrapper(*args, **kwargs):
if recorder is None:
return func_to_record(*args, **kwargs)

if recorder.types is not None and record_type.__name__ not in recorder.types:
if (
recorder.recorded_types is not None
and record_type.__name__ not in recorder.recorded_types
):
return func_to_record(*args, **kwargs)

# For methods, peel off the 'self' argument before calling the
Expand Down

0 comments on commit 64374bc

Please sign in to comment.