Skip to content

Commit

Permalink
clean up and fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
emmyoop committed May 28, 2024
1 parent 84cca75 commit 306576e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
13 changes: 6 additions & 7 deletions dbt_common/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,12 @@ def load(cls, file_name: str) -> Dict[str, List[Record]]:
for record_type_name in loaded_dct:
# TODO: this breaks with QueryRecord on replay since it's
# not in common so isn't part of cls._record_cls_by_name yet

record_cls = cls._record_cls_by_name[record_type_name]
rec_list = []
for record_dct in loaded_dct[record_type_name]:
rec = record_cls.from_dict(record_dct)
rec_list.append(rec) # type: ignore
records_by_type[record_type_name] = rec_list

return records_by_type

def expect_record(self, params: Any) -> Any:
Expand Down Expand Up @@ -168,7 +166,9 @@ def get_record_mode_from_env() -> Optional[RecorderMode]:
if record_mode.lower() == "record":
return RecorderMode.RECORD
# replaying requires a file path, otherwise treat as noop
elif record_mode.lower() == "replay" and os.environ["DBT_RECORDER_REPLAY_PATH"] is not None:
elif (
record_mode.lower() == "replay" and os.environ.get("DBT_RECORDER_REPLAY_PATH") is not None
):
return RecorderMode.REPLAY

# if you don't specify record/replay it's a noop
Expand All @@ -193,10 +193,9 @@ def get_record_types_from_env() -> Optional[List]:

for type in record_types:
# Types not defined in common are not in the record_types list yet
# TODO: is there a better way to do this without hardcoding? We can't just
# wait for later because if it's QueryRecord (not defined in common) we don't
# want to remove it to ensure everything else is filtered out.... This is also
# a problem with replaying QueryRecords generally
# TODO: This is related to a problem with replay noted above. Will solve
# at a future date. Leaving it hardcoded for now to unblock. Will remove
# after resolving MNTL-308.
if type not in Recorder._record_cls_by_name and type != "QueryRecord":
print(f"Invalid record type: {type}") # TODO: remove after testing
record_types.remove(type)
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str:

def test_decorator_replays():
prev = os.environ.get("DBT_RECORDER_MODE", None)
prev_path = os.environ.get("DBT_RECORDER_REPLAY_PATH", None)
try:
os.environ["DBT_RECORDER_MODE"] = "Replay"
os.environ["DBT_RECORDER_REPLAY_PATH"] = "record.json"
recorder = Recorder(RecorderMode.REPLAY, None)
set_invocation_context({})
get_invocation_context().recorder = recorder
Expand All @@ -79,3 +81,7 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str:
os.environ.pop("DBT_RECORDER_MODE", None)
else:
os.environ["DBT_RECORDER_MODE"] = prev
if prev_path is None:
os.environ.pop("DBT_RECORDER_REPLAY_PATH", None)
else:
os.environ["DBT_RECORDER_REPLAY_PATH"] = prev_path

0 comments on commit 306576e

Please sign in to comment.