diff --git a/.changes/unreleased/Under the Hood-20240618-155025.yaml b/.changes/unreleased/Under the Hood-20240618-155025.yaml new file mode 100644 index 00000000..b540d3d7 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240618-155025.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Deserialize Record objects on a just-in-time basis. +time: 2024-06-18T15:50:25.985387-04:00 +custom: + Author: peterallenwebb + Issue: "151" diff --git a/dbt_common/clients/system.py b/dbt_common/clients/system.py index bcf798d2..00a1ac69 100644 --- a/dbt_common/clients/system.py +++ b/dbt_common/clients/system.py @@ -62,7 +62,9 @@ def _include(self) -> bool: # Do not record or replay filesystem searches that were performed against # files which are actually part of dbt's implementation. return ( - "dbt/include/global_project" not in self.root_path + "dbt/include" + not in self.root_path # TODO: This actually obviates the next two checks but is probably too coarse? + and "dbt/include/global_project" not in self.root_path and "/plugins/postgres/dbt/include/" not in self.root_path ) diff --git a/dbt_common/record.py b/dbt_common/record.py index 9fed4c3d..c204faa6 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -2,7 +2,7 @@ external systems during a command invocation, so that the command can be re-run later with the recording 'replayed' to dbt. -The rationale for and architecture of this module is described in detail in the +The rationale for and architecture of this module are described in detail in the docs/guides/record_replay.md document in this repository. """ import functools @@ -12,7 +12,7 @@ from deepdiff import DeepDiff # type: ignore from enum import Enum -from typing import Any, Dict, List, Mapping, Optional, Type +from typing import Any, Callable, Dict, List, Mapping, Optional, Type from dbt_common.context import get_invocation_context @@ -129,10 +129,11 @@ def __init__( previous_recording_path: Optional[str] = None, ) -> None: self.mode = mode - self.types = types + self.recorded_types = types self._records_by_type: Dict[str, List[Record]] = {} + self._unprocessed_records_by_type: Dict[str, List[Dict[str, Any]]] = {} self._replay_diffs: List["Diff"] = [] - self.diff: Diff + self.diff: Optional[Diff] = None self.previous_recording_path = previous_recording_path self.current_recording_path = current_recording_path @@ -146,7 +147,7 @@ def __init__( ) if self.mode == RecorderMode.REPLAY: - self._records_by_type = self.load(self.previous_recording_path) + self._unprocessed_records_by_type = self.load(self.previous_recording_path) @classmethod def register_record_type(cls, rec_type) -> Any: @@ -161,7 +162,14 @@ def add_record(self, record: Record) -> None: self._records_by_type[rec_cls_name].append(record) def pop_matching_record(self, params: Any) -> Optional[Record]: - rec_type_name = self._record_name_by_params_name[type(params).__name__] + rec_type_name = self._record_name_by_params_name.get(type(params).__name__) + + if rec_type_name is None: + raise Exception( + f"A record of type {type(params).__name__} was requested, but no such type has been registered." + ) + + self._ensure_records_processed(rec_type_name) records = self._records_by_type[rec_type_name] match: Optional[Record] = None for rec in records: @@ -186,22 +194,20 @@ def _to_dict(self) -> Dict: return dct @classmethod - def load(cls, file_name: str) -> Dict[str, List[Record]]: + def load(cls, file_name: str) -> Dict[str, List[Dict[str, Any]]]: with open(file_name) as file: - loaded_dct = json.load(file) + return json.load(file) - records_by_type: Dict[str, List[Record]] = {} + def _ensure_records_processed(self, record_type_name: str) -> None: + if record_type_name in self._records_by_type: + return - for record_type_name in loaded_dct: - # TODO: this breaks with QueryRecord on replay since it's - # not in common so isn't part of cls._record_cls_by_name yet - record_cls = cls._record_cls_by_name[record_type_name] - rec_list = [] - for record_dct in loaded_dct[record_type_name]: - rec = record_cls.from_dict(record_dct) - rec_list.append(rec) # type: ignore - records_by_type[record_type_name] = rec_list - return records_by_type + rec_list = [] + record_cls = self._record_cls_by_name[record_type_name] + for record_dct in self._unprocessed_records_by_type[record_type_name]: + rec = record_cls.from_dict(record_dct) + rec_list.append(rec) # type: ignore + self._records_by_type[record_type_name] = rec_list def expect_record(self, params: Any) -> Any: record = self.pop_matching_record(params) @@ -209,16 +215,19 @@ def expect_record(self, params: Any) -> Any: if record is None: raise Exception() + if record.result is None: + return None + result_tuple = dataclasses.astuple(record.result) return result_tuple[0] if len(result_tuple) == 1 else result_tuple def write_diffs(self, diff_file_name) -> None: - json.dump( - self.diff.calculate_diff(), - open(diff_file_name, "w"), - ) + assert self.diff is not None + with open(diff_file_name, "w") as f: + json.dump(self.diff.calculate_diff(), f) def print_diffs(self) -> None: + assert self.diff is not None print(repr(self.diff.calculate_diff())) @@ -273,7 +282,12 @@ def get_record_types_from_dict(fp: str) -> List: return list(loaded_dct.keys()) -def record_function(record_type, method=False, tuple_result=False): +def record_function( + record_type, + method: bool = False, + tuple_result: bool = False, + id_field_name: Optional[str] = None, +) -> Callable: def record_function_inner(func_to_record): # To avoid runtime overhead and other unpleasantness, we only apply the # record/replay decorator if a relevant env var is set. @@ -291,12 +305,17 @@ def record_replay_wrapper(*args, **kwargs): if recorder is None: return func_to_record(*args, **kwargs) - if recorder.types is not None and record_type.__name__ not in recorder.types: + if ( + recorder.recorded_types is not None + and record_type.__name__ not in recorder.recorded_types + ): return func_to_record(*args, **kwargs) # For methods, peel off the 'self' argument before calling the # params constructor. param_args = args[1:] if method else args + if method and id_field_name is not None: + param_args = (getattr(args[0], id_field_name),) + param_args params = record_type.params_cls(*param_args, **kwargs) @@ -313,7 +332,7 @@ def record_replay_wrapper(*args, **kwargs): r = func_to_record(*args, **kwargs) result = ( None - if r is None or record_type.result_cls is None + if record_type.result_cls is None else record_type.result_cls(*r) if tuple_result else record_type.result_cls(r)