diff --git a/dbt_common/record.py b/dbt_common/record.py index 930dc2e0..5e225c0d 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -66,8 +66,11 @@ class Recorder: _record_cls_by_name: Dict[str, Type] = {} _record_name_by_params_name: Dict[str, str] = {} - def __init__(self, mode: RecorderMode, recording_path: Optional[str] = None) -> None: + def __init__( + self, mode: RecorderMode, types: Optional[List], recording_path: Optional[str] = None + ) -> None: self.mode = mode + self.types = types self._records_by_type: Dict[str, List[Record]] = {} self._replay_diffs: List["Diff"] = [] @@ -148,21 +151,58 @@ def print_diffs(self) -> None: def get_record_mode_from_env() -> Optional[RecorderMode]: - replay_val = os.environ.get("DBT_REPLAY") - if replay_val is not None and replay_val != "0" and replay_val.lower() != "false": - return RecorderMode.REPLAY + """ + Get the record mode from the environment variables. - record_val = os.environ.get("DBT_RECORD") - if record_val is not None and record_val != "0" and record_val.lower() != "false": - return RecorderMode.RECORD + If the mode is not set to 'RECORD' or 'REPLAY', return None. + Expected format: 'DBT_RECORDER_MODE=RECORD' + """ + record_mode = os.environ.get("DBT_RECORDER_MODE") - record_val = os.environ.get("DBT_RECORD_QUERIES") - if record_val is not None and record_val != "0" and record_val.lower() != "false": - return RecorderMode.RECORD_QUERIES + if record_mode is None: + return None + if record_mode.lower() == "record": + return RecorderMode.RECORD + elif record_mode.lower() == "replay": + return RecorderMode.REPLAY + + # if you don't specify record/replay it's a noop return None +def get_record_types_from_env() -> Optional[List]: + """ + Get the record subset from the environment variables. + + If no types are provided, there will be no filtering. + Invalid types will be ignored. + Expected format: 'DBT_RECORDER_TYPES=QueryRecord,FileLoadRecord,OtherRecord' + """ + record_types_str = os.environ.get("DBT_RECORDER_TYPES") + + # if all is specified we don't want any type filtering + if record_types_str is None or record_types_str.lower == "all": + return None + + record_types = record_types_str.split(",") + + for type in record_types: + # Types not defined in common are not in the record_types list yet + # TODO: is there a better way to do this without hardcoding? We can't just + # wait for later because if it's QueryRecord (not defined in common) we don't + # want to remove it to ensure everything else is filtered out.... + if type not in Recorder._record_cls_by_name and type != "QueryRecord": + print(f"Invalid record type: {type}") # TODO: remove after testing + record_types.remove(type) + + # if everything is invalid we don't want any type filtering + if len(record_types) == 0: + return None + + return record_types + + def record_function(record_type, method=False, tuple_result=False): def record_function_inner(func_to_record): # To avoid runtime overhead and other unpleasantness, we only apply the @@ -181,10 +221,7 @@ def record_replay_wrapper(*args, **kwargs): if recorder is None: return func_to_record(*args, **kwargs) - if ( - recorder.mode == RecorderMode.RECORD_QUERIES - and record_type.__name__ != "QueryRecord" - ): + if recorder.types is not None and record_type.__name__ not in recorder.types: return func_to_record(*args, **kwargs) # For methods, peel off the 'self' argument before calling the diff --git a/docs/guides/record_replay.md b/docs/guides/record_replay.md index 9d9d87f2..c182da0a 100644 --- a/docs/guides/record_replay.md +++ b/docs/guides/record_replay.md @@ -28,7 +28,18 @@ Note also the `LoadFileRecord` class passed as a parameter to this decorator. Th The final detail needed is to define the classes specified by `params_cls` and `result_cls`, which must be dataclasses with properties whose order and names correspond to the parameters passed to the recorded function. In this case those are the `LoadFileParams` and `LoadFileResult` classes, respectively. -With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them. +With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them. + +## How to record/replay +If `DBT_RECORDER_MODE` is not `replay` or `record`, case insensitive, this is a no-op. Invalid values are ignored and do not throw exceptions. + +`DBT_RECODER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses. Any invalid types will be ignored. `all` is a valid type and behaves the same as not populating the env var. + +example + +```bash +DBT_RECORDER_MODE=record DBT_RECODER_TYPES=QueryRecord,GetEnvRecord dbt run +``` ## Final Thoughts diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index aa7af69b..7829762c 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -25,9 +25,9 @@ class TestRecord(Record): def test_decorator_records(): - prev = os.environ.get("DBT_RECORD", None) + prev = os.environ.get("DBT_RECORDER_MODE", None) try: - os.environ["DBT_RECORD"] = "True" + os.environ["DBT_RECORDER_MODE"] = "Record" recorder = Recorder(RecorderMode.RECORD) set_invocation_context({}) get_invocation_context().recorder = recorder @@ -47,15 +47,15 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str: finally: if prev is None: - os.environ.pop("DBT_RECORD", None) + os.environ.pop("DBT_RECORDER_MODE", None) else: - os.environ["DBT_RECORD"] = prev + os.environ["DBT_RECORDER_MODE"] = prev def test_decorator_replays(): - prev = os.environ.get("DBT_RECORD", None) + prev = os.environ.get("DBT_RECORDER_MODE", None) try: - os.environ["DBT_RECORD"] = "True" + os.environ["DBT_RECORDER_MODE"] = "Replay" recorder = Recorder(RecorderMode.REPLAY) set_invocation_context({}) get_invocation_context().recorder = recorder @@ -76,6 +76,6 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str: finally: if prev is None: - os.environ.pop("DBT_RECORD", None) + os.environ.pop("DBT_RECORDER_MODE", None) else: - os.environ["DBT_RECORD"] = prev + os.environ["DBT_RECORDER_MODE"] = prev