From 306576ec94af4bb7afdfeb8d4ca628091fe05a67 Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Tue, 28 May 2024 09:30:14 -0500 Subject: [PATCH] clean up and fix test --- dbt_common/record.py | 13 ++++++------- tests/unit/test_record.py | 6 ++++++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/dbt_common/record.py b/dbt_common/record.py index 0f4750d8..ab5b16da 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -124,14 +124,12 @@ def load(cls, file_name: str) -> Dict[str, List[Record]]: for record_type_name in loaded_dct: # TODO: this breaks with QueryRecord on replay since it's # not in common so isn't part of cls._record_cls_by_name yet - record_cls = cls._record_cls_by_name[record_type_name] rec_list = [] for record_dct in loaded_dct[record_type_name]: rec = record_cls.from_dict(record_dct) rec_list.append(rec) # type: ignore records_by_type[record_type_name] = rec_list - return records_by_type def expect_record(self, params: Any) -> Any: @@ -168,7 +166,9 @@ def get_record_mode_from_env() -> Optional[RecorderMode]: if record_mode.lower() == "record": return RecorderMode.RECORD # replaying requires a file path, otherwise treat as noop - elif record_mode.lower() == "replay" and os.environ["DBT_RECORDER_REPLAY_PATH"] is not None: + elif ( + record_mode.lower() == "replay" and os.environ.get("DBT_RECORDER_REPLAY_PATH") is not None + ): return RecorderMode.REPLAY # if you don't specify record/replay it's a noop @@ -193,10 +193,9 @@ def get_record_types_from_env() -> Optional[List]: for type in record_types: # Types not defined in common are not in the record_types list yet - # TODO: is there a better way to do this without hardcoding? We can't just - # wait for later because if it's QueryRecord (not defined in common) we don't - # want to remove it to ensure everything else is filtered out.... This is also - # a problem with replaying QueryRecords generally + # TODO: This is related to a problem with replay noted above. Will solve + # at a future date. Leaving it hardcoded for now to unblock. Will remove + # after resolving MNTL-308. if type not in Recorder._record_cls_by_name and type != "QueryRecord": print(f"Invalid record type: {type}") # TODO: remove after testing record_types.remove(type) diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index bce824bc..af9fc0fe 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -54,8 +54,10 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str: def test_decorator_replays(): prev = os.environ.get("DBT_RECORDER_MODE", None) + prev_path = os.environ.get("DBT_RECORDER_REPLAY_PATH", None) try: os.environ["DBT_RECORDER_MODE"] = "Replay" + os.environ["DBT_RECORDER_REPLAY_PATH"] = "record.json" recorder = Recorder(RecorderMode.REPLAY, None) set_invocation_context({}) get_invocation_context().recorder = recorder @@ -79,3 +81,7 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str: os.environ.pop("DBT_RECORDER_MODE", None) else: os.environ["DBT_RECORDER_MODE"] = prev + if prev_path is None: + os.environ.pop("DBT_RECORDER_REPLAY_PATH", None) + else: + os.environ["DBT_RECORDER_REPLAY_PATH"] = prev_path