From dbb2308e126d7306594d99a0b0521ba73822b724 Mon Sep 17 00:00:00 2001 From: Peter Webb Date: Wed, 26 Jun 2024 13:30:23 -0400 Subject: [PATCH 01/10] Support "just in time" loading of records, and add ID fields (#154) * Deserialize records "just in time" in order to avoid import order issues. * Add changelog entry * Typing and formatting fixes * Typing * Add id field to record. * Add more informative error message. * Tweak the way results are stored. * Typing and formatting fixes. --------- Co-authored-by: Emily Rockman --- .../Under the Hood-20240618-155025.yaml | 6 ++ dbt_common/clients/system.py | 4 +- dbt_common/record.py | 71 ++++++++++++------- 3 files changed, 54 insertions(+), 27 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20240618-155025.yaml diff --git a/.changes/unreleased/Under the Hood-20240618-155025.yaml b/.changes/unreleased/Under the Hood-20240618-155025.yaml new file mode 100644 index 00000000..b540d3d7 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240618-155025.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Deserialize Record objects on a just-in-time basis. +time: 2024-06-18T15:50:25.985387-04:00 +custom: + Author: peterallenwebb + Issue: "151" diff --git a/dbt_common/clients/system.py b/dbt_common/clients/system.py index bcf798d2..00a1ac69 100644 --- a/dbt_common/clients/system.py +++ b/dbt_common/clients/system.py @@ -62,7 +62,9 @@ def _include(self) -> bool: # Do not record or replay filesystem searches that were performed against # files which are actually part of dbt's implementation. return ( - "dbt/include/global_project" not in self.root_path + "dbt/include" + not in self.root_path # TODO: This actually obviates the next two checks but is probably too coarse? + and "dbt/include/global_project" not in self.root_path and "/plugins/postgres/dbt/include/" not in self.root_path ) diff --git a/dbt_common/record.py b/dbt_common/record.py index 9fed4c3d..c204faa6 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -2,7 +2,7 @@ external systems during a command invocation, so that the command can be re-run later with the recording 'replayed' to dbt. -The rationale for and architecture of this module is described in detail in the +The rationale for and architecture of this module are described in detail in the docs/guides/record_replay.md document in this repository. """ import functools @@ -12,7 +12,7 @@ from deepdiff import DeepDiff # type: ignore from enum import Enum -from typing import Any, Dict, List, Mapping, Optional, Type +from typing import Any, Callable, Dict, List, Mapping, Optional, Type from dbt_common.context import get_invocation_context @@ -129,10 +129,11 @@ def __init__( previous_recording_path: Optional[str] = None, ) -> None: self.mode = mode - self.types = types + self.recorded_types = types self._records_by_type: Dict[str, List[Record]] = {} + self._unprocessed_records_by_type: Dict[str, List[Dict[str, Any]]] = {} self._replay_diffs: List["Diff"] = [] - self.diff: Diff + self.diff: Optional[Diff] = None self.previous_recording_path = previous_recording_path self.current_recording_path = current_recording_path @@ -146,7 +147,7 @@ def __init__( ) if self.mode == RecorderMode.REPLAY: - self._records_by_type = self.load(self.previous_recording_path) + self._unprocessed_records_by_type = self.load(self.previous_recording_path) @classmethod def register_record_type(cls, rec_type) -> Any: @@ -161,7 +162,14 @@ def add_record(self, record: Record) -> None: self._records_by_type[rec_cls_name].append(record) def pop_matching_record(self, params: Any) -> Optional[Record]: - rec_type_name = self._record_name_by_params_name[type(params).__name__] + rec_type_name = self._record_name_by_params_name.get(type(params).__name__) + + if rec_type_name is None: + raise Exception( + f"A record of type {type(params).__name__} was requested, but no such type has been registered." + ) + + self._ensure_records_processed(rec_type_name) records = self._records_by_type[rec_type_name] match: Optional[Record] = None for rec in records: @@ -186,22 +194,20 @@ def _to_dict(self) -> Dict: return dct @classmethod - def load(cls, file_name: str) -> Dict[str, List[Record]]: + def load(cls, file_name: str) -> Dict[str, List[Dict[str, Any]]]: with open(file_name) as file: - loaded_dct = json.load(file) + return json.load(file) - records_by_type: Dict[str, List[Record]] = {} + def _ensure_records_processed(self, record_type_name: str) -> None: + if record_type_name in self._records_by_type: + return - for record_type_name in loaded_dct: - # TODO: this breaks with QueryRecord on replay since it's - # not in common so isn't part of cls._record_cls_by_name yet - record_cls = cls._record_cls_by_name[record_type_name] - rec_list = [] - for record_dct in loaded_dct[record_type_name]: - rec = record_cls.from_dict(record_dct) - rec_list.append(rec) # type: ignore - records_by_type[record_type_name] = rec_list - return records_by_type + rec_list = [] + record_cls = self._record_cls_by_name[record_type_name] + for record_dct in self._unprocessed_records_by_type[record_type_name]: + rec = record_cls.from_dict(record_dct) + rec_list.append(rec) # type: ignore + self._records_by_type[record_type_name] = rec_list def expect_record(self, params: Any) -> Any: record = self.pop_matching_record(params) @@ -209,16 +215,19 @@ def expect_record(self, params: Any) -> Any: if record is None: raise Exception() + if record.result is None: + return None + result_tuple = dataclasses.astuple(record.result) return result_tuple[0] if len(result_tuple) == 1 else result_tuple def write_diffs(self, diff_file_name) -> None: - json.dump( - self.diff.calculate_diff(), - open(diff_file_name, "w"), - ) + assert self.diff is not None + with open(diff_file_name, "w") as f: + json.dump(self.diff.calculate_diff(), f) def print_diffs(self) -> None: + assert self.diff is not None print(repr(self.diff.calculate_diff())) @@ -273,7 +282,12 @@ def get_record_types_from_dict(fp: str) -> List: return list(loaded_dct.keys()) -def record_function(record_type, method=False, tuple_result=False): +def record_function( + record_type, + method: bool = False, + tuple_result: bool = False, + id_field_name: Optional[str] = None, +) -> Callable: def record_function_inner(func_to_record): # To avoid runtime overhead and other unpleasantness, we only apply the # record/replay decorator if a relevant env var is set. @@ -291,12 +305,17 @@ def record_replay_wrapper(*args, **kwargs): if recorder is None: return func_to_record(*args, **kwargs) - if recorder.types is not None and record_type.__name__ not in recorder.types: + if ( + recorder.recorded_types is not None + and record_type.__name__ not in recorder.recorded_types + ): return func_to_record(*args, **kwargs) # For methods, peel off the 'self' argument before calling the # params constructor. param_args = args[1:] if method else args + if method and id_field_name is not None: + param_args = (getattr(args[0], id_field_name),) + param_args params = record_type.params_cls(*param_args, **kwargs) @@ -313,7 +332,7 @@ def record_replay_wrapper(*args, **kwargs): r = func_to_record(*args, **kwargs) result = ( None - if r is None or record_type.result_cls is None + if record_type.result_cls is None else record_type.result_cls(*r) if tuple_result else record_type.result_cls(r) From 002377d299e24912d8f43132c3ff71b9ffe335f9 Mon Sep 17 00:00:00 2001 From: Peter Webb Date: Wed, 26 Jun 2024 16:10:15 -0400 Subject: [PATCH 02/10] Deserialize records "just in time" in order to avoid import order isssue (#151) * Deserialize records "just in time" in order to avoid import order issues. * Add changelog entry * Typing and formatting fixes * Typing --------- Co-authored-by: Emily Rockman From d25c29a73925a39fdf6fb0737868683dc61ae84b Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Tue, 2 Jul 2024 12:00:38 -0500 Subject: [PATCH 03/10] bump version to 1.5.0 (#160) --- dbt_common/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt_common/__about__.py b/dbt_common/__about__.py index d619c757..e3a0f015 100644 --- a/dbt_common/__about__.py +++ b/dbt_common/__about__.py @@ -1 +1 @@ -version = "1.4.0" +version = "1.5.0" From feda22dc08950efb21e4ae4feb1b4fd20a9b090b Mon Sep 17 00:00:00 2001 From: Peter Webb Date: Thu, 4 Jul 2024 17:26:03 -0400 Subject: [PATCH 04/10] Improve Type Annotations Coverage (#162) * Add some type annotations * More type annotation. * Freedom from typlessness. * More annotations. * Clean up linter issues --- dbt_common/context.py | 5 +- dbt_common/dataclass_schema.py | 8 +- dbt_common/exceptions/base.py | 6 +- dbt_common/helper_types.py | 4 +- dbt_common/record.py | 8 +- dbt_common/semver.py | 30 ++-- dbt_common/utils/casting.py | 6 +- dbt_common/utils/executor.py | 9 +- dbt_common/utils/jinja.py | 8 +- tests/unit/test_agate_helper.py | 22 +-- tests/unit/test_connection_retries.py | 11 +- tests/unit/test_contextvars.py | 2 +- tests/unit/test_contracts_util.py | 2 +- tests/unit/test_core_dbt_utils.py | 48 ++--- tests/unit/test_diff.py | 8 +- tests/unit/test_event_handler.py | 4 +- tests/unit/test_helper_types.py | 27 +-- tests/unit/test_invocation_context.py | 8 +- tests/unit/test_jinja.py | 246 +++++++++++++++----------- tests/unit/test_model_config.py | 49 +++-- tests/unit/test_proto_events.py | 4 +- tests/unit/test_record.py | 4 +- tests/unit/test_semver.py | 24 +-- tests/unit/test_system_client.py | 56 +++--- tests/unit/test_ui.py | 4 +- tests/unit/test_utils.py | 14 +- 26 files changed, 343 insertions(+), 274 deletions(-) diff --git a/dbt_common/context.py b/dbt_common/context.py index a46b1dd2..d1775c55 100644 --- a/dbt_common/context.py +++ b/dbt_common/context.py @@ -2,6 +2,7 @@ from typing import List, Mapping, Optional from dbt_common.constants import PRIVATE_ENV_PREFIX, SECRET_ENV_PREFIX +from dbt_common.record import Recorder class InvocationContext: @@ -9,7 +10,7 @@ def __init__(self, env: Mapping[str, str]): self._env = {k: v for k, v in env.items() if not k.startswith(PRIVATE_ENV_PREFIX)} self._env_secrets: Optional[List[str]] = None self._env_private = {k: v for k, v in env.items() if k.startswith(PRIVATE_ENV_PREFIX)} - self.recorder = None + self.recorder: Optional[Recorder] = None # This class will also eventually manage the invocation_id, flags, event manager, etc. @property @@ -32,7 +33,7 @@ def env_secrets(self) -> List[str]: _INVOCATION_CONTEXT_VAR: ContextVar[InvocationContext] = ContextVar("DBT_INVOCATION_CONTEXT_VAR") -def reliably_get_invocation_var() -> ContextVar: +def reliably_get_invocation_var() -> ContextVar[InvocationContext]: invocation_var: Optional[ContextVar] = next( (cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None ) diff --git a/dbt_common/dataclass_schema.py b/dbt_common/dataclass_schema.py index 0bad081f..867d5a4c 100644 --- a/dbt_common/dataclass_schema.py +++ b/dbt_common/dataclass_schema.py @@ -1,4 +1,4 @@ -from typing import ClassVar, cast, get_type_hints, List, Tuple, Dict, Any, Optional +from typing import Any, cast, ClassVar, Dict, get_type_hints, List, Optional, Tuple import re import jsonschema from dataclasses import fields, Field @@ -26,7 +26,7 @@ class ValidationError(jsonschema.ValidationError): class DateTimeSerialization(SerializationStrategy): - def serialize(self, value) -> str: + def serialize(self, value: datetime) -> str: out = value.isoformat() # Assume UTC if timezone is missing if value.tzinfo is None: @@ -127,7 +127,7 @@ def _get_fields(cls) -> List[Tuple[Field, str]]: # copied from hologram. Used in tests @classmethod - def _get_field_names(cls): + def _get_field_names(cls) -> List[str]: return [element[1] for element in cls._get_fields()] @@ -152,7 +152,7 @@ def validate(cls, value): # These classes must be in this order or it doesn't work class StrEnum(str, SerializableType, Enum): - def __str__(self): + def __str__(self) -> str: return self.value # https://docs.python.org/3.6/library/enum.html#using-automatic-values diff --git a/dbt_common/exceptions/base.py b/dbt_common/exceptions/base.py index db619326..d966a28d 100644 --- a/dbt_common/exceptions/base.py +++ b/dbt_common/exceptions/base.py @@ -1,5 +1,5 @@ import builtins -from typing import List, Any, Optional +from typing import Any, List, Optional import os from dbt_common.constants import SECRET_ENV_PREFIX @@ -37,7 +37,7 @@ def __init__(self, msg: str): self.msg = scrub_secrets(msg, env_secrets()) @property - def type(self): + def type(self) -> str: return "Internal" def process_stack(self): @@ -59,7 +59,7 @@ def process_stack(self): return lines - def __str__(self): + def __str__(self) -> str: if hasattr(self.msg, "split"): split_msg = self.msg.split("\n") else: diff --git a/dbt_common/helper_types.py b/dbt_common/helper_types.py index 0ca435b7..8611f39f 100644 --- a/dbt_common/helper_types.py +++ b/dbt_common/helper_types.py @@ -19,7 +19,7 @@ class NVEnum(StrEnum): novalue = "novalue" - def __eq__(self, other): + def __eq__(self, other) -> bool: return isinstance(other, NVEnum) @@ -59,7 +59,7 @@ def includes(self, item_name: str) -> bool: item_name in self.include or self.include in self.INCLUDE_ALL ) and item_name not in self.exclude - def _validate_items(self, items: List[str]): + def _validate_items(self, items: List[str]) -> None: pass diff --git a/dbt_common/record.py b/dbt_common/record.py index c204faa6..8fe068bb 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -14,8 +14,6 @@ from enum import Enum from typing import Any, Callable, Dict, List, Mapping, Optional, Type -from dbt_common.context import get_invocation_context - class Record: """An instance of this abstract Record class represents a request made by dbt @@ -295,9 +293,11 @@ def record_function_inner(func_to_record): return func_to_record @functools.wraps(func_to_record) - def record_replay_wrapper(*args, **kwargs): - recorder: Recorder = None + def record_replay_wrapper(*args, **kwargs) -> Any: + recorder: Optional[Recorder] = None try: + from dbt_common.context import get_invocation_context + recorder = get_invocation_context().recorder except LookupError: pass diff --git a/dbt_common/semver.py b/dbt_common/semver.py index 951f4e8e..fbdcefa5 100644 --- a/dbt_common/semver.py +++ b/dbt_common/semver.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import re -from typing import List +from typing import List, Iterable import dbt_common.exceptions.base from dbt_common.exceptions import VersionsNotCompatibleError @@ -74,7 +74,7 @@ def _cmp(a, b): @dataclass class VersionSpecifier(VersionSpecification): - def to_version_string(self, skip_matcher=False): + def to_version_string(self, skip_matcher: bool = False) -> str: prerelease = "" build = "" matcher = "" @@ -92,7 +92,7 @@ def to_version_string(self, skip_matcher=False): ) @classmethod - def from_version_string(cls, version_string): + def from_version_string(cls, version_string: str) -> "VersionSpecifier": match = _VERSION_REGEX.match(version_string) if not match: @@ -104,7 +104,7 @@ def from_version_string(cls, version_string): return cls.from_dict(matched) - def __str__(self): + def __str__(self) -> str: return self.to_version_string() def to_range(self) -> "VersionRange": @@ -192,32 +192,32 @@ def compare(self, other): return 0 - def __lt__(self, other): + def __lt__(self, other) -> bool: return self.compare(other) == -1 - def __gt__(self, other): + def __gt__(self, other) -> bool: return self.compare(other) == 1 - def __eq___(self, other): + def __eq___(self, other) -> bool: return self.compare(other) == 0 def __cmp___(self, other): return self.compare(other) @property - def is_unbounded(self): + def is_unbounded(self) -> bool: return False @property - def is_lower_bound(self): + def is_lower_bound(self) -> bool: return self.matcher in [Matchers.GREATER_THAN, Matchers.GREATER_THAN_OR_EQUAL] @property - def is_upper_bound(self): + def is_upper_bound(self) -> bool: return self.matcher in [Matchers.LESS_THAN, Matchers.LESS_THAN_OR_EQUAL] @property - def is_exact(self): + def is_exact(self) -> bool: return self.matcher == Matchers.EXACT @classmethod @@ -418,7 +418,7 @@ def reduce_versions(*args): return to_return -def versions_compatible(*args): +def versions_compatible(*args) -> bool: if len(args) == 1: return True @@ -429,7 +429,7 @@ def versions_compatible(*args): return False -def find_possible_versions(requested_range, available_versions): +def find_possible_versions(requested_range, available_versions: Iterable[str]): possible_versions = [] for version_string in available_versions: @@ -442,7 +442,9 @@ def find_possible_versions(requested_range, available_versions): return [v.to_version_string(skip_matcher=True) for v in sorted_versions] -def resolve_to_specific_version(requested_range, available_versions): +def resolve_to_specific_version( + requested_range, available_versions: Iterable[str] +) -> Optional[str]: max_version = None max_version_string = None diff --git a/dbt_common/utils/casting.py b/dbt_common/utils/casting.py index 811ea376..f366db7f 100644 --- a/dbt_common/utils/casting.py +++ b/dbt_common/utils/casting.py @@ -1,7 +1,7 @@ # This is useful for proto generated classes in particular, since # the default for protobuf for strings is the empty string, so # Optional[str] types don't work for generated Python classes. -from typing import Optional +from typing import Any, Dict, Optional def cast_to_str(string: Optional[str]) -> str: @@ -18,8 +18,8 @@ def cast_to_int(integer: Optional[int]) -> int: return integer -def cast_dict_to_dict_of_strings(dct): - new_dct = {} +def cast_dict_to_dict_of_strings(dct: Dict[Any, Any]) -> Dict[str, str]: + new_dct: Dict[str, str] = {} for k, v in dct.items(): new_dct[str(k)] = str(v) return new_dct diff --git a/dbt_common/utils/executor.py b/dbt_common/utils/executor.py index 0dd8490c..529b02be 100644 --- a/dbt_common/utils/executor.py +++ b/dbt_common/utils/executor.py @@ -1,9 +1,12 @@ import concurrent.futures from contextlib import contextmanager -from contextvars import ContextVar from typing import Protocol, Optional -from dbt_common.context import get_invocation_context, reliably_get_invocation_var +from dbt_common.context import ( + get_invocation_context, + reliably_get_invocation_var, + InvocationContext, +) class ConnectingExecutor(concurrent.futures.Executor): @@ -63,7 +66,7 @@ class HasThreadingConfig(Protocol): threads: Optional[int] -def _thread_initializer(invocation_context: ContextVar) -> None: +def _thread_initializer(invocation_context: InvocationContext) -> None: invocation_var = reliably_get_invocation_var() invocation_var.set(invocation_context) diff --git a/dbt_common/utils/jinja.py b/dbt_common/utils/jinja.py index 36464cbe..260ccb6a 100644 --- a/dbt_common/utils/jinja.py +++ b/dbt_common/utils/jinja.py @@ -5,19 +5,21 @@ DOCS_PREFIX = "dbt_docs__" -def get_dbt_macro_name(name): +def get_dbt_macro_name(name) -> str: if name is None: raise DbtInternalError("Got None for a macro name!") return f"{MACRO_PREFIX}{name}" -def get_dbt_docs_name(name): +def get_dbt_docs_name(name) -> str: if name is None: raise DbtInternalError("Got None for a doc name!") return f"{DOCS_PREFIX}{name}" -def get_materialization_macro_name(materialization_name, adapter_type=None, with_prefix=True): +def get_materialization_macro_name( + materialization_name, adapter_type=None, with_prefix=True +) -> str: if adapter_type is None: adapter_type = "default" name = f"materialization_{materialization_name}_{adapter_type}" diff --git a/tests/unit/test_agate_helper.py b/tests/unit/test_agate_helper.py index 4c12bcd8..fff0d4c6 100644 --- a/tests/unit/test_agate_helper.py +++ b/tests/unit/test_agate_helper.py @@ -46,13 +46,13 @@ class TestAgateHelper(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.tempdir = mkdtemp() - def tearDown(self): + def tearDown(self) -> None: rmtree(self.tempdir) - def test_from_csv(self): + def test_from_csv(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_DATA.encode("utf-8")) @@ -61,7 +61,7 @@ def test_from_csv(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_bom_from_csv(self): + def test_bom_from_csv(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_BOM_DATA.encode("utf-8")) @@ -70,7 +70,7 @@ def test_bom_from_csv(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_from_csv_all_reserved(self): + def test_from_csv_all_reserved(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_DATA.encode("utf-8")) @@ -79,7 +79,7 @@ def test_from_csv_all_reserved(self): for expected, row in zip(EXPECTED_STRINGS, tbl): self.assertEqual(list(row), expected) - def test_from_data(self): + def test_from_data(self) -> None: column_names = ["a", "b", "c", "d", "e", "f", "g"] data = [ { @@ -106,7 +106,7 @@ def test_from_data(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_datetime_formats(self): + def test_datetime_formats(self) -> None: path = os.path.join(self.tempdir, "input.csv") datetimes = [ "20180806T11:33:29.000Z", @@ -120,7 +120,7 @@ def test_datetime_formats(self): tbl = agate_helper.from_csv(path, ()) self.assertEqual(tbl[0][0], expected) - def test_merge_allnull(self): + def test_merge_allnull(self) -> None: t1 = agate_helper.table_from_rows([(1, "a", None), (2, "b", None)], ("a", "b", "c")) t2 = agate_helper.table_from_rows([(3, "c", None), (4, "d", None)], ("a", "b", "c")) result = agate_helper.merge_tables([t1, t2]) @@ -130,7 +130,7 @@ def test_merge_allnull(self): assert isinstance(result.column_types[2], agate_helper.Integer) self.assertEqual(len(result), 4) - def test_merge_mixed(self): + def test_merge_mixed(self) -> None: t1 = agate_helper.table_from_rows( [(1, "a", None, None), (2, "b", None, None)], ("a", "b", "c", "d") ) @@ -181,7 +181,7 @@ def test_merge_mixed(self): assert isinstance(result.column_types[3], agate.data_types.Number) self.assertEqual(len(result), 6) - def test_nocast_string_types(self): + def test_nocast_string_types(self) -> None: # String fields should not be coerced into a representative type # See: https://github.com/dbt-labs/dbt-core/issues/2984 @@ -202,7 +202,7 @@ def test_nocast_string_types(self): for i, row in enumerate(tbl): self.assertEqual(list(row), expected[i]) - def test_nocast_bool_01(self): + def test_nocast_bool_01(self) -> None: # True and False values should not be cast to 1 and 0, and vice versa # See: https://github.com/dbt-labs/dbt-core/issues/4511 diff --git a/tests/unit/test_connection_retries.py b/tests/unit/test_connection_retries.py index 817af7a2..44fc72f5 100644 --- a/tests/unit/test_connection_retries.py +++ b/tests/unit/test_connection_retries.py @@ -19,20 +19,23 @@ def test_no_retry(self): assert result == expected -def no_success_fn(): +def no_success_fn() -> str: raise RequestException("You'll never pass") return "failure" class TestMaxRetries: - def test_no_retry(self): + def test_no_retry(self) -> None: fn_to_retry = functools.partial(no_success_fn) with pytest.raises(ConnectionError): connection_exception_retry(fn_to_retry, 3) -def single_retry_fn(): +counter = 0 + + +def single_retry_fn() -> str: global counter if counter == 0: counter += 1 @@ -45,7 +48,7 @@ def single_retry_fn(): class TestSingleRetry: - def test_no_retry(self): + def test_no_retry(self) -> None: global counter counter = 0 diff --git a/tests/unit/test_contextvars.py b/tests/unit/test_contextvars.py index 4eb58e6c..1aa9425f 100644 --- a/tests/unit/test_contextvars.py +++ b/tests/unit/test_contextvars.py @@ -1,7 +1,7 @@ from dbt_common.events.contextvars import log_contextvars, get_node_info, set_log_contextvars -def test_contextvars(): +def test_contextvars() -> None: node_info = { "unique_id": "model.test.my_model", "started_at": None, diff --git a/tests/unit/test_contracts_util.py b/tests/unit/test_contracts_util.py index 2a620370..d2fc4493 100644 --- a/tests/unit/test_contracts_util.py +++ b/tests/unit/test_contracts_util.py @@ -13,7 +13,7 @@ class ExampleMergableClass(Mergeable): class TestMergableClass(unittest.TestCase): - def test_mergeability(self): + def test_mergeability(self) -> None: mergeable1 = ExampleMergableClass( attr_a="loses", attr_b=None, attr_c=["I'll", "still", "exist"] ) diff --git a/tests/unit/test_core_dbt_utils.py b/tests/unit/test_core_dbt_utils.py index 8a0e836e..7419cd8d 100644 --- a/tests/unit/test_core_dbt_utils.py +++ b/tests/unit/test_core_dbt_utils.py @@ -7,30 +7,30 @@ class TestCommonDbtUtils(unittest.TestCase): - def test_connection_exception_retry_none(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add(self), 5) + def test_connection_exception_retry_none(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add(), 5) self.assertEqual(1, counter) - def test_connection_exception_retry_success_requests_exception(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_requests_exception(self), 5) + def test_connection_exception_retry_success_requests_exception(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_requests_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned None, plus 1 retry - def test_connection_exception_retry_max(self): - Counter._reset(self) + def test_connection_exception_retry_max(self) -> None: + Counter._reset() with self.assertRaises(ConnectionError): - connection_exception_retry(lambda: Counter._add_with_exception(self), 5) + connection_exception_retry(lambda: Counter._add_with_exception(), 5) self.assertEqual(6, counter) # 6 = original attempt plus 5 retries - def test_connection_exception_retry_success_failed_untar(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_untar_exception(self), 5) + def test_connection_exception_retry_success_failed_untar(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_untar_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned ReadError, plus 1 retry - def test_connection_exception_retry_success_failed_eofexception(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_eof_exception(self), 5) + def test_connection_exception_retry_success_failed_eofexception(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_eof_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned EOFError, plus 1 retry @@ -38,36 +38,42 @@ def test_connection_exception_retry_success_failed_eofexception(self): class Counter: - def _add(self): + @classmethod + def _add(cls) -> None: global counter counter += 1 # All exceptions that Requests explicitly raises inherit from # requests.exceptions.RequestException so we want to make sure that raises plus one exception # that inherit from it for sanity - def _add_with_requests_exception(self): + @classmethod + def _add_with_requests_exception(cls) -> None: global counter counter += 1 if counter < 2: raise requests.exceptions.RequestException - def _add_with_exception(self): + @classmethod + def _add_with_exception(cls) -> None: global counter counter += 1 raise requests.exceptions.ConnectionError - def _add_with_untar_exception(self): + @classmethod + def _add_with_untar_exception(cls) -> None: global counter counter += 1 if counter < 2: raise tarfile.ReadError - def _add_with_eof_exception(self): + @classmethod + def _add_with_eof_exception(cls) -> None: global counter counter += 1 if counter < 2: raise EOFError - def _reset(self): + @classmethod + def _reset(cls) -> None: global counter counter = 0 diff --git a/tests/unit/test_diff.py b/tests/unit/test_diff.py index 791263f3..54f735e3 100644 --- a/tests/unit/test_diff.py +++ b/tests/unit/test_diff.py @@ -1,4 +1,6 @@ import json +from typing import Any, Dict + import pytest from dbt_common.record import Diff @@ -191,7 +193,7 @@ def open_mock(file, *args, **kwargs): return open_mock -def test_calculate_diff_no_diff(monkeypatch): +def test_calculate_diff_no_diff(monkeypatch) -> None: # Mock data for the files current_recording_data = { "GetEnvRecord": [ @@ -251,11 +253,11 @@ def test_calculate_diff_no_diff(monkeypatch): previous_recording_path=previous_recording_path, ) result = diff_instance.calculate_diff() - expected_result = {"GetEnvRecord": {}, "DefaultKey": {}} + expected_result: Dict[str, Any] = {"GetEnvRecord": {}, "DefaultKey": {}} assert result == expected_result -def test_calculate_diff_with_diff(monkeypatch): +def test_calculate_diff_with_diff(monkeypatch) -> None: # Mock data for the files current_recording_data = { "GetEnvRecord": [ diff --git a/tests/unit/test_event_handler.py b/tests/unit/test_event_handler.py index 80d5ae2b..f38938b6 100644 --- a/tests/unit/test_event_handler.py +++ b/tests/unit/test_event_handler.py @@ -5,7 +5,7 @@ from dbt_common.events.event_manager import TestEventManager -def test_event_logging_handler_emits_records_correctly(): +def test_event_logging_handler_emits_records_correctly() -> None: event_manager = TestEventManager() handler = DbtEventLoggingHandler(event_manager=event_manager, level=logging.DEBUG) log = logging.getLogger("test") @@ -27,7 +27,7 @@ def test_event_logging_handler_emits_records_correctly(): assert event_manager.event_history[5][1] == EventLevel.ERROR -def test_set_package_logging_sets_level_correctly(): +def test_set_package_logging_sets_level_correctly() -> None: event_manager = TestEventManager() log = logging.getLogger("test") set_package_logging("test", logging.DEBUG, event_manager) diff --git a/tests/unit/test_helper_types.py b/tests/unit/test_helper_types.py index 1a9519de..ba98803c 100644 --- a/tests/unit/test_helper_types.py +++ b/tests/unit/test_helper_types.py @@ -1,11 +1,12 @@ import pytest +from typing import List, Union from dbt_common.helper_types import IncludeExclude, WarnErrorOptions from dbt_common.dataclass_schema import ValidationError class TestIncludeExclude: - def test_init_invalid(self): + def test_init_invalid(self) -> None: with pytest.raises(ValidationError): IncludeExclude(include="invalid") @@ -22,14 +23,16 @@ def test_init_invalid(self): (["ItemA", "ItemB"], [], True), ], ) - def test_includes(self, include, exclude, expected_includes): + def test_includes( + self, include: Union[str, List[str]], exclude: List[str], expected_includes: bool + ) -> None: include_exclude = IncludeExclude(include=include, exclude=exclude) assert include_exclude.includes("ItemA") == expected_includes class TestWarnErrorOptions: - def test_init_invalid_error(self): + def test_init_invalid_error(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include=["InvalidError"], valid_error_names=set(["ValidError"])) @@ -38,14 +41,14 @@ def test_init_invalid_error(self): include="*", exclude=["InvalidError"], valid_error_names=set(["ValidError"]) ) - def test_init_invalid_error_default_valid_error_names(self): + def test_init_invalid_error_default_valid_error_names(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include=["InvalidError"]) with pytest.raises(ValidationError): WarnErrorOptions(include="*", exclude=["InvalidError"]) - def test_init_valid_error(self): + def test_init_valid_error(self) -> None: warn_error_options = WarnErrorOptions( include=["ValidError"], valid_error_names=set(["ValidError"]) ) @@ -58,18 +61,18 @@ def test_init_valid_error(self): assert warn_error_options.include == "*" assert warn_error_options.exclude == ["ValidError"] - def test_init_default_silence(self): + def test_init_default_silence(self) -> None: my_options = WarnErrorOptions(include="*") assert my_options.silence == [] - def test_init_invalid_silence_event(self): + def test_init_invalid_silence_event(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include="*", silence=["InvalidError"]) - def test_init_valid_silence_event(self): + def test_init_valid_silence_event(self) -> None: all_events = ["MySilencedEvent"] my_options = WarnErrorOptions( - include="*", silence=all_events, valid_error_names=all_events + include="*", silence=all_events, valid_error_names=set(all_events) ) assert my_options.silence == all_events @@ -81,14 +84,16 @@ def test_init_valid_silence_event(self): ("*", ["ItemB"], True), ], ) - def test_includes(self, include, silence, expected_includes): + def test_includes( + self, include: Union[str, List[str]], silence: List[str], expected_includes: bool + ) -> None: include_exclude = WarnErrorOptions( include=include, silence=silence, valid_error_names={"ItemA", "ItemB"} ) assert include_exclude.includes("ItemA") == expected_includes - def test_silenced(self): + def test_silenced(self) -> None: my_options = WarnErrorOptions(include="*", silence=["ItemA"], valid_error_names={"ItemA"}) assert my_options.silenced("ItemA") assert not my_options.silenced("ItemB") diff --git a/tests/unit/test_invocation_context.py b/tests/unit/test_invocation_context.py index b6697f8e..3dc832d3 100644 --- a/tests/unit/test_invocation_context.py +++ b/tests/unit/test_invocation_context.py @@ -2,13 +2,13 @@ from dbt_common.context import InvocationContext -def test_invocation_context_env(): +def test_invocation_context_env() -> None: test_env = {"VAR_1": "value1", "VAR_2": "value2"} ic = InvocationContext(env=test_env) assert ic.env == test_env -def test_invocation_context_secrets(): +def test_invocation_context_secrets() -> None: test_env = { f"{SECRET_ENV_PREFIX}_VAR_1": "secret1", f"{SECRET_ENV_PREFIX}VAR_2": "secret2", @@ -16,10 +16,10 @@ def test_invocation_context_secrets(): f"foo{SECRET_ENV_PREFIX}": "non-secret", } ic = InvocationContext(env=test_env) - assert set(ic.env_secrets) == set(["secret1", "secret2"]) + assert set(ic.env_secrets) == {"secret1", "secret2"} -def test_invocation_context_private(): +def test_invocation_context_private() -> None: test_env = { f"{PRIVATE_ENV_PREFIX}_VAR_1": "private1", f"{PRIVATE_ENV_PREFIX}VAR_2": "private2", diff --git a/tests/unit/test_jinja.py b/tests/unit/test_jinja.py index f038a1ec..e906a0ac 100644 --- a/tests/unit/test_jinja.py +++ b/tests/unit/test_jinja.py @@ -1,23 +1,26 @@ import unittest +from dbt_common.clients._jinja_blocks import BlockTag from dbt_common.clients.jinja import extract_toplevel_blocks from dbt_common.exceptions import CompilationError class TestBlockLexer(unittest.TestCase): - def test_basic(self): + def test_basic(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" blocks = extract_toplevel_blocks( block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_multiple(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_multiple(self) -> None: body_one = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' body_two = ( "{{ config(bar=1)}}\r\nselect * from {% if foo %} thing " @@ -37,7 +40,7 @@ def test_multiple(self): ) self.assertEqual(len(blocks), 2) - def test_comments(self): + def test_comments(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' comment = "{# my comment #}" block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" @@ -45,12 +48,14 @@ def test_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_evil_comments(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_evil_comments(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' comment = ( "{# external comment {% othertype bar %} select * from " @@ -61,12 +66,14 @@ def test_evil_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_nested_comments(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_nested_comments(self) -> None: body = ( '{# my comment #} {{ config(foo="bar") }}' "\r\nselect * from {# my other comment embedding {% endmytype %} #} this.that\r\n" @@ -80,33 +87,43 @@ def test_nested_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_complex_file(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_complex_file(self) -> None: blocks = extract_toplevel_blocks( complex_snapshot_file, allowed_blocks={"mytype", "myothertype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 3) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].full_block, "{% mytype foo %} some stuff {% endmytype %}") - self.assertEqual(blocks[0].contents, " some stuff ") - self.assertEqual(blocks[1].block_type_name, "mytype") - self.assertEqual(blocks[1].block_name, "bar") - self.assertEqual(blocks[1].full_block, bar_block) - self.assertEqual(blocks[1].contents, bar_block[16:-15].rstrip()) - self.assertEqual(blocks[2].block_type_name, "myothertype") - self.assertEqual(blocks[2].block_name, "x") - self.assertEqual(blocks[2].full_block, x_block.strip()) + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.full_block, "{% mytype foo %} some stuff {% endmytype %}") + self.assertEqual(b0.contents, " some stuff ") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "mytype") + self.assertEqual(b1.block_name, "bar") + self.assertEqual(b1.full_block, bar_block) + self.assertEqual(b1.contents, bar_block[16:-15].rstrip()) + + b2 = blocks[2] + assert isinstance(b2, BlockTag) + self.assertEqual(b2.block_type_name, "myothertype") + self.assertEqual(b2.block_name, "x") + self.assertEqual(b2.full_block, x_block.strip()) self.assertEqual( - blocks[2].contents, + b2.contents, x_block[len("\n{% myothertype x %}") : -len("{% endmyothertype %}\n")], ) - def test_peaceful_macro_coexistence(self): + def test_peaceful_macro_coexistence(self) -> None: body = ( "{# my macro #} {% macro foo(a, b) %} do a thing " "{%- endmacro %} {# my model #} {% a b %} test {% enda %}" @@ -116,15 +133,22 @@ def test_peaceful_macro_coexistence(self): ) self.assertEqual(len(blocks), 4) self.assertEqual(blocks[0].full_block, "{# my macro #} ") - self.assertEqual(blocks[1].block_type_name, "macro") - self.assertEqual(blocks[1].block_name, "foo") - self.assertEqual(blocks[1].contents, " do a thing") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "macro") + self.assertEqual(b1.block_name, "foo") + self.assertEqual(b1.contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") - self.assertEqual(blocks[3].block_type_name, "a") - self.assertEqual(blocks[3].block_name, "b") - self.assertEqual(blocks[3].contents, " test ") - def test_macro_with_trailing_data(self): + b3 = blocks[3] + assert isinstance(b3, BlockTag) + self.assertEqual(b3.block_type_name, "a") + self.assertEqual(b3.block_name, "b") + self.assertEqual(b3.contents, " test ") + + def test_macro_with_trailing_data(self) -> None: body = ( "{# my macro #} {% macro foo(a, b) %} do a thing {%- endmacro %} " "{# my model #} {% a b %} test {% enda %} raw data so cool" @@ -134,16 +158,24 @@ def test_macro_with_trailing_data(self): ) self.assertEqual(len(blocks), 5) self.assertEqual(blocks[0].full_block, "{# my macro #} ") - self.assertEqual(blocks[1].block_type_name, "macro") - self.assertEqual(blocks[1].block_name, "foo") - self.assertEqual(blocks[1].contents, " do a thing") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "macro") + self.assertEqual(b1.block_name, "foo") + self.assertEqual(b1.contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") - self.assertEqual(blocks[3].block_type_name, "a") - self.assertEqual(blocks[3].block_name, "b") - self.assertEqual(blocks[3].contents, " test ") + + b3 = blocks[3] + assert isinstance(b3, BlockTag) + self.assertEqual(b3.block_type_name, "a") + self.assertEqual(b3.block_name, "b") + self.assertEqual(b3.contents, " test ") + self.assertEqual(blocks[4].full_block, " raw data so cool") - def test_macro_with_crazy_args(self): + def test_macro_with_crazy_args(self) -> None: body = ( """{% macro foo(a, b=asdf("cool this is 'embedded'" * 3) + external_var, c)%}""" "cool{# block comment with {% endmacro %} in it #} stuff here " @@ -151,38 +183,44 @@ def test_macro_with_crazy_args(self): ) blocks = extract_toplevel_blocks(body, allowed_blocks={"macro"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "macro") - self.assertEqual(blocks[0].block_name, "foo") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "macro") + self.assertEqual(b0.block_name, "foo") self.assertEqual( blocks[0].contents, "cool{# block comment with {% endmacro %} in it #} stuff here " ) - def test_materialization_parse(self): + def test_materialization_parse(self) -> None: body = "{% materialization xxx, default %} ... {% endmaterialization %}" blocks = extract_toplevel_blocks( body, allowed_blocks={"materialization"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "materialization") - self.assertEqual(blocks[0].block_name, "xxx") - self.assertEqual(blocks[0].full_block, body) + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "materialization") + self.assertEqual(b0.block_name, "xxx") + self.assertEqual(b0.full_block, body) body = '{% materialization xxx, adapter="other" %} ... {% endmaterialization %}' blocks = extract_toplevel_blocks( body, allowed_blocks={"materialization"}, collect_raw_data=False ) + b0 = blocks[0] + assert isinstance(b0, BlockTag) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "materialization") - self.assertEqual(blocks[0].block_name, "xxx") - self.assertEqual(blocks[0].full_block, body) + self.assertEqual(b0.block_type_name, "materialization") + self.assertEqual(b0.block_name, "xxx") + self.assertEqual(b0.full_block, body) - def test_nested_not_ok(self): + def test_nested_not_ok(self) -> None: # we don't allow nesting same blocks body = "{% myblock a %} {% myblock b %} {% endmyblock %} {% endmyblock %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock"}) - def test_incomplete_block_failure(self): + def test_incomplete_block_failure(self) -> None: fullbody = "{% myblock foo %} {% endmyblock %}" for length in range(len("{% myblock foo %}"), len(fullbody) - 1): body = fullbody[:length] @@ -194,45 +232,45 @@ def test_wrong_end_failure(self): with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"}) - def test_comment_no_end_failure(self): + def test_comment_no_end_failure(self) -> None: body = "{# " with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_comment_only(self): + def test_comment_only(self) -> None: body = "{# myblock #}" blocks = extract_toplevel_blocks(body) self.assertEqual(len(blocks), 1) blocks = extract_toplevel_blocks(body, collect_raw_data=False) self.assertEqual(len(blocks), 0) - def test_comment_block_self_closing(self): + def test_comment_block_self_closing(self) -> None: # test the case where a comment start looks a lot like it closes itself # (but it doesn't in jinja!) body = "{#} {% myblock foo %} {#}" blocks = extract_toplevel_blocks(body, collect_raw_data=False) self.assertEqual(len(blocks), 0) - def test_embedded_self_closing_comment_block(self): + def test_embedded_self_closing_comment_block(self) -> None: body = "{% myblock foo %} {#}{% endmyblock %} {#}{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) self.assertEqual(blocks[0].contents, " {#}{% endmyblock %} {#}") - def test_set_statement(self): + def test_set_statement(self) -> None: body = "{% set x = 1 %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_set_block(self): + def test_set_block(self) -> None: body = "{% set x %}1{% endset %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_crazy_set_statement(self): + def test_crazy_set_statement(self) -> None: body = ( '{% set x = (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}' '{% set y = otherthing("{% myblock foo %}") %}' @@ -244,19 +282,19 @@ def test_crazy_set_statement(self): self.assertEqual(blocks[0].full_block, "{% otherblock bar %}x{% endotherblock %}") self.assertEqual(blocks[0].block_type_name, "otherblock") - def test_do_statement(self): + def test_do_statement(self) -> None: body = "{% do thing.update() %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_deceptive_do_statement(self): + def test_deceptive_do_statement(self) -> None: body = "{% do thing %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_do_block(self): + def test_do_block(self) -> None: body = "{% do %}thing.update(){% enddo %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks( body, allowed_blocks={"do", "myblock"}, collect_raw_data=False @@ -266,7 +304,7 @@ def test_do_block(self): self.assertEqual(blocks[0].block_type_name, "do") self.assertEqual(blocks[1].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_crazy_do_statement(self): + def test_crazy_do_statement(self) -> None: body = ( '{% do (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}' '{% do otherthing("{% myblock foo %}") %}{% myblock x %}hi{% endmyblock %}' @@ -280,7 +318,7 @@ def test_crazy_do_statement(self): self.assertEqual(blocks[1].full_block, "{% myblock x %}hi{% endmyblock %}") self.assertEqual(blocks[1].block_type_name, "myblock") - def test_awful_jinja(self): + def test_awful_jinja(self) -> None: blocks = extract_toplevel_blocks( if_you_do_this_you_are_awful, allowed_blocks={"snapshot", "materialization"}, @@ -304,63 +342,71 @@ def test_awful_jinja(self): self.assertEqual(blocks[1].block_type_name, "materialization") self.assertEqual(blocks[1].contents, "\nhi\n") - def test_quoted_endblock_within_block(self): + def test_quoted_endblock_within_block(self) -> None: body = '{% myblock something -%} {% set x = ("{% endmyblock %}") %} {% endmyblock %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].block_type_name, "myblock") self.assertEqual(blocks[0].contents, '{% set x = ("{% endmyblock %}") %} ') - def test_docs_block(self): + def test_docs_block(self) -> None: body = ( "{% docs __my_doc__ %} asdf {# nope {% enddocs %}} #} {% enddocs %}" '{% docs __my_other_doc__ %} asdf "{% enddocs %}' ) blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) self.assertEqual(len(blocks), 2) - self.assertEqual(blocks[0].block_type_name, "docs") - self.assertEqual(blocks[0].contents, " asdf {# nope {% enddocs %}} #} ") - self.assertEqual(blocks[0].block_name, "__my_doc__") - self.assertEqual(blocks[1].block_type_name, "docs") - self.assertEqual(blocks[1].contents, ' asdf "') - self.assertEqual(blocks[1].block_name, "__my_other_doc__") - - def test_docs_block_expr(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "docs") + self.assertEqual(b0.contents, " asdf {# nope {% enddocs %}} #} ") + self.assertEqual(b0.block_name, "__my_doc__") + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "docs") + self.assertEqual(b1.contents, ' asdf "') + self.assertEqual(b1.block_name, "__my_other_doc__") + + def test_docs_block_expr(self) -> None: body = '{% docs more_doc %} asdf {{ "{% enddocs %}" ~ "}}" }}{% enddocs %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "docs") - self.assertEqual(blocks[0].contents, ' asdf {{ "{% enddocs %}" ~ "}}" }}') - self.assertEqual(blocks[0].block_name, "more_doc") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "docs") + self.assertEqual(b0.contents, ' asdf {{ "{% enddocs %}" ~ "}}" }}') + self.assertEqual(b0.block_name, "more_doc") - def test_unclosed_model_quotes(self): + def test_unclosed_model_quotes(self) -> None: # test case for https://github.com/dbt-labs/dbt-core/issues/1533 body = '{% model my_model -%} select * from "something"."something_else{% endmodel %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"model"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "model") - self.assertEqual(blocks[0].contents, 'select * from "something"."something_else') - self.assertEqual(blocks[0].block_name, "my_model") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "model") + self.assertEqual(b0.contents, 'select * from "something"."something_else') + self.assertEqual(b0.block_name, "my_model") - def test_if(self): + def test_if(self) -> None: # if you conditionally define your macros/models, don't body = "{% if true %}{% macro my_macro() %} adsf {% endmacro %}{% endif %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_if_innocuous(self): + def test_if_innocuous(self) -> None: body = "{% if true %}{% something %}asdfasd{% endsomething %}{% endif %}" blocks = extract_toplevel_blocks(body) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) - def test_for(self): + def test_for(self) -> None: # no for-loops over macros. body = "{% for x in range(10) %}{% macro my_macro() %} adsf {% endmacro %}{% endfor %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_for_innocuous(self): + def test_for_innocuous(self) -> None: # no for-loops over macros. body = ( "{% for x in range(10) %}{% something my_something %} adsf " @@ -370,7 +416,7 @@ def test_for_innocuous(self): self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) - def test_endif(self): + def test_endif(self) -> None: body = "{% snapshot foo %}select * from thing{% endsnapshot%}{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) @@ -382,7 +428,7 @@ def test_endif(self): str(err.exception), ) - def test_if_endfor(self): + def test_if_endfor(self) -> None: body = "{% if x %}...{% endfor %}{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) @@ -391,7 +437,7 @@ def test_if_endfor(self): str(err.exception), ) - def test_if_endfor_newlines(self): + def test_if_endfor_newlines(self) -> None: body = "{% if x %}\n ...\n {% endfor %}\n{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) diff --git a/tests/unit/test_model_config.py b/tests/unit/test_model_config.py index 0cc1e711..57a14438 100644 --- a/tests/unit/test_model_config.py +++ b/tests/unit/test_model_config.py @@ -14,7 +14,7 @@ class ThingWithMergeBehavior(dbtClassMixin): keysappended: Dict[str, int] = field(metadata={"merge": MergeBehavior.DictKeyAppend}) -def test_merge_behavior_meta(): +def test_merge_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(MergeBehavior) == { @@ -29,15 +29,14 @@ def test_merge_behavior_meta(): assert existing == initial_existing -def test_merge_behavior_from_field(): - fields = [f[0] for f in ThingWithMergeBehavior._get_fields()] - fields = {name: f for f, name in ThingWithMergeBehavior._get_fields()} - assert set(fields) == {"default_behavior", "appended", "updated", "clobbered", "keysappended"} - assert MergeBehavior.from_field(fields["default_behavior"]) == MergeBehavior.Clobber - assert MergeBehavior.from_field(fields["appended"]) == MergeBehavior.Append - assert MergeBehavior.from_field(fields["updated"]) == MergeBehavior.Update - assert MergeBehavior.from_field(fields["clobbered"]) == MergeBehavior.Clobber - assert MergeBehavior.from_field(fields["keysappended"]) == MergeBehavior.DictKeyAppend +def test_merge_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithMergeBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "appended", "updated", "clobbered", "keysappended"} + assert MergeBehavior.from_field(fields2["default_behavior"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields2["appended"]) == MergeBehavior.Append + assert MergeBehavior.from_field(fields2["updated"]) == MergeBehavior.Update + assert MergeBehavior.from_field(fields2["clobbered"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields2["keysappended"]) == MergeBehavior.DictKeyAppend @dataclass @@ -47,7 +46,7 @@ class ThingWithShowBehavior(dbtClassMixin): shown: float = field(metadata={"show_hide": ShowBehavior.Show}) -def test_show_behavior_meta(): +def test_show_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(ShowBehavior) == {ShowBehavior.Hide, ShowBehavior.Show} @@ -57,13 +56,12 @@ def test_show_behavior_meta(): assert existing == initial_existing -def test_show_behavior_from_field(): - fields = [f[0] for f in ThingWithShowBehavior._get_fields()] - fields = {name: f for f, name in ThingWithShowBehavior._get_fields()} - assert set(fields) == {"default_behavior", "hidden", "shown"} - assert ShowBehavior.from_field(fields["default_behavior"]) == ShowBehavior.Show - assert ShowBehavior.from_field(fields["hidden"]) == ShowBehavior.Hide - assert ShowBehavior.from_field(fields["shown"]) == ShowBehavior.Show +def test_show_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithShowBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "hidden", "shown"} + assert ShowBehavior.from_field(fields2["default_behavior"]) == ShowBehavior.Show + assert ShowBehavior.from_field(fields2["hidden"]) == ShowBehavior.Hide + assert ShowBehavior.from_field(fields2["shown"]) == ShowBehavior.Show @dataclass @@ -73,7 +71,7 @@ class ThingWithCompareBehavior(dbtClassMixin): excluded: str = field(metadata={"compare": CompareBehavior.Exclude}) -def test_compare_behavior_meta(): +def test_compare_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(CompareBehavior) == {CompareBehavior.Include, CompareBehavior.Exclude} @@ -83,10 +81,9 @@ def test_compare_behavior_meta(): assert existing == initial_existing -def test_compare_behavior_from_field(): - fields = [f[0] for f in ThingWithCompareBehavior._get_fields()] - fields = {name: f for f, name in ThingWithCompareBehavior._get_fields()} - assert set(fields) == {"default_behavior", "included", "excluded"} - assert CompareBehavior.from_field(fields["default_behavior"]) == CompareBehavior.Include - assert CompareBehavior.from_field(fields["included"]) == CompareBehavior.Include - assert CompareBehavior.from_field(fields["excluded"]) == CompareBehavior.Exclude +def test_compare_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithCompareBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "included", "excluded"} + assert CompareBehavior.from_field(fields2["default_behavior"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields2["included"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields2["excluded"]) == CompareBehavior.Exclude diff --git a/tests/unit/test_proto_events.py b/tests/unit/test_proto_events.py index 32eb08ae..d21b5062 100644 --- a/tests/unit/test_proto_events.py +++ b/tests/unit/test_proto_events.py @@ -18,7 +18,7 @@ } -def test_events(): +def test_events() -> None: # M020 event event_code = "M020" event = RetryExternalCall(attempt=3, max=5) @@ -45,7 +45,7 @@ def test_events(): assert new_msg.data.attempt == msg.data.attempt -def test_extra_dict_on_event(monkeypatch): +def test_extra_dict_on_event(monkeypatch) -> None: monkeypatch.setenv("DBT_ENV_CUSTOM_ENV_env_key", "env_value") reset_metadata_vars() diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index b0371498..6e02d710 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -69,7 +69,7 @@ def setup(): os.environ["DBT_RECORDER_FILE_PATH"] = prev_fp -def test_decorator_records(setup): +def test_decorator_records(setup) -> None: os.environ["DBT_RECORDER_MODE"] = "Record" recorder = Recorder(RecorderMode.RECORD, None) set_invocation_context({}) @@ -116,7 +116,7 @@ def not_test_func(a: int, b: str, c: Optional[str] = None) -> str: assert NotTestRecord not in recorder._records_by_type -def test_decorator_replays(setup): +def test_decorator_replays(setup) -> None: os.environ["DBT_RECORDER_MODE"] = "Replay" os.environ["DBT_RECORDER_FILE_PATH"] = "record.json" recorder = Recorder(RecorderMode.REPLAY, None) diff --git a/tests/unit/test_semver.py b/tests/unit/test_semver.py index ae48e592..383d3479 100644 --- a/tests/unit/test_semver.py +++ b/tests/unit/test_semver.py @@ -1,6 +1,6 @@ import itertools import unittest -from typing import List +from typing import List, Optional from dbt_common.exceptions import VersionsNotCompatibleError from dbt_common.semver import ( @@ -23,9 +23,11 @@ def semver_regex_versioning(versions: List[str]) -> bool: return True -def create_range(start_version_string, end_version_string): - start = UnboundedVersionSpecifier() - end = UnboundedVersionSpecifier() +def create_range( + start_version_string: Optional[str], end_version_string: Optional[str] +) -> VersionRange: + start: VersionSpecifier = UnboundedVersionSpecifier() + end: VersionSpecifier = UnboundedVersionSpecifier() if start_version_string is not None: start = VersionSpecifier.from_version_string(start_version_string) @@ -37,24 +39,24 @@ def create_range(start_version_string, end_version_string): class TestSemver(unittest.TestCase): - def assertVersionSetResult(self, inputs, output_range): + def assertVersionSetResult(self, inputs, output_range) -> None: expected = create_range(*output_range) for permutation in itertools.permutations(inputs): self.assertEqual(reduce_versions(*permutation), expected) - def assertInvalidVersionSet(self, inputs): + def assertInvalidVersionSet(self, inputs) -> None: for permutation in itertools.permutations(inputs): with self.assertRaises(VersionsNotCompatibleError): reduce_versions(*permutation) - def test__versions_compatible(self): + def test__versions_compatible(self) -> None: self.assertTrue(versions_compatible("0.0.1", "0.0.1")) self.assertFalse(versions_compatible("0.0.1", "0.0.2")) self.assertTrue(versions_compatible(">0.0.1", "0.0.2")) self.assertFalse(versions_compatible("0.4.5a1", "0.4.5a2")) - def test__semver_regex_versions(self): + def test__semver_regex_versions(self) -> None: self.assertTrue( semver_regex_versioning( [ @@ -140,7 +142,7 @@ def test__semver_regex_versions(self): ) ) - def test__reduce_versions(self): + def test__reduce_versions(self) -> None: self.assertVersionSetResult(["0.0.1", "0.0.1"], ["=0.0.1", "=0.0.1"]) self.assertVersionSetResult(["0.0.1"], ["=0.0.1", "=0.0.1"]) @@ -175,7 +177,7 @@ def test__reduce_versions(self): self.assertInvalidVersionSet(["<0.0.3", ">=0.0.3"]) self.assertInvalidVersionSet(["<0.0.3", ">0.0.3"]) - def test__resolve_to_specific_version(self): + def test__resolve_to_specific_version(self) -> None: self.assertEqual( resolve_to_specific_version(create_range(">0.0.1", None), ["0.0.1", "0.0.2"]), "0.0.2" ) @@ -253,7 +255,7 @@ def test__resolve_to_specific_version(self): "0.9.1", ) - def test__filter_installable(self): + def test__filter_installable(self) -> None: installable = filter_installable( [ "1.1.0", diff --git a/tests/unit/test_system_client.py b/tests/unit/test_system_client.py index a4dcc323..d2cf27ed 100644 --- a/tests/unit/test_system_client.py +++ b/tests/unit/test_system_client.py @@ -12,39 +12,39 @@ class SystemClient(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.tmp_dir = mkdtemp() self.profiles_path = "{}/profiles.yml".format(self.tmp_dir) - def set_up_profile(self): + def set_up_profile(self) -> None: with open(self.profiles_path, "w") as f: f.write("ORIGINAL_TEXT") - def get_profile_text(self): + def get_profile_text(self) -> str: with open(self.profiles_path, "r") as f: return f.read() - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.tmp_dir) except Exception as e: # noqa: F841 pass - def test__make_file_when_exists(self): + def test__make_file_when_exists(self) -> None: self.set_up_profile() written = dbt_common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") self.assertFalse(written) self.assertEqual(self.get_profile_text(), "ORIGINAL_TEXT") - def test__make_file_when_not_exists(self): + def test__make_file_when_not_exists(self) -> None: written = dbt_common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") self.assertTrue(written) self.assertEqual(self.get_profile_text(), "NEW_TEXT") - def test__make_file_with_overwrite(self): + def test__make_file_with_overwrite(self) -> None: self.set_up_profile() written = dbt_common.clients.system.make_file( self.profiles_path, contents="NEW_TEXT", overwrite=True @@ -53,12 +53,12 @@ def test__make_file_with_overwrite(self): self.assertTrue(written) self.assertEqual(self.get_profile_text(), "NEW_TEXT") - def test__make_dir_from_str(self): + def test__make_dir_from_str(self) -> None: test_dir_str = self.tmp_dir + "/test_make_from_str/sub_dir" dbt_common.clients.system.make_directory(test_dir_str) self.assertTrue(Path(test_dir_str).is_dir()) - def test__make_dir_from_pathobj(self): + def test__make_dir_from_pathobj(self) -> None: test_dir_pathobj = Path(self.tmp_dir + "/test_make_from_pathobj/sub_dir") dbt_common.clients.system.make_directory(test_dir_pathobj) self.assertTrue(test_dir_pathobj.is_dir()) @@ -72,7 +72,7 @@ class TestRunCmd(unittest.TestCase): not_a_file = "zzzbbfasdfasdfsdaq" - def setUp(self): + def setUp(self) -> None: self.tempdir = mkdtemp() self.run_dir = os.path.join(self.tempdir, "run_dir") self.does_not_exist = os.path.join(self.tempdir, "does_not_exist") @@ -86,10 +86,10 @@ def setUp(self): with open(self.empty_file, "w") as fp: # noqa: F841 pass # "touch" - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) - def test__executable_does_not_exist(self): + def test__executable_does_not_exist(self) -> None: with self.assertRaises(ExecutableError) as exc: dbt_common.clients.system.run_cmd(self.run_dir, [self.does_not_exist]) @@ -99,7 +99,7 @@ def test__executable_does_not_exist(self): self.assertIn("could not find", msg) self.assertIn(self.does_not_exist.lower(), msg) - def test__not_exe(self): + def test__not_exe(self) -> None: with self.assertRaises(ExecutableError) as exc: dbt_common.clients.system.run_cmd(self.run_dir, [self.empty_file]) @@ -112,14 +112,14 @@ def test__not_exe(self): self.assertIn("permissions", msg) self.assertIn(self.empty_file.lower(), msg) - def test__cwd_does_not_exist(self): + def test__cwd_does_not_exist(self) -> None: with self.assertRaises(WorkingDirectoryError) as exc: dbt_common.clients.system.run_cmd(self.does_not_exist, self.exists_cmd) msg = str(exc.exception).lower() self.assertIn("does not exist", msg) self.assertIn(self.does_not_exist.lower(), msg) - def test__cwd_not_directory(self): + def test__cwd_not_directory(self) -> None: with self.assertRaises(WorkingDirectoryError) as exc: dbt_common.clients.system.run_cmd(self.empty_file, self.exists_cmd) @@ -127,7 +127,7 @@ def test__cwd_not_directory(self): self.assertIn("not a directory", msg) self.assertIn(self.empty_file.lower(), msg) - def test__cwd_no_permissions(self): + def test__cwd_no_permissions(self) -> None: # it would be nice to add a windows test. Possible path to that is via # `psexec` (to get SYSTEM privs), use `icacls` to set permissions on # the directory for the test user. I'm pretty sure windows users can't @@ -145,18 +145,18 @@ def test__cwd_no_permissions(self): self.assertIn("permissions", msg) self.assertIn(self.run_dir.lower(), msg) - def test__ok(self): + def test__ok(self) -> None: out, err = dbt_common.clients.system.run_cmd(self.run_dir, self.exists_cmd) self.assertEqual(out.strip(), b"hello") self.assertEqual(err.strip(), b"") class TestFindMatching(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.base_dir = mkdtemp() self.tempdir = mkdtemp(dir=self.base_dir) - def test_find_matching_lowercase_file_pattern(self): + def test_find_matching_lowercase_file_pattern(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir) as named_file: file_path = os.path.dirname(named_file.name) relative_path = os.path.basename(file_path) @@ -175,7 +175,7 @@ def test_find_matching_lowercase_file_pattern(self): ] self.assertEqual(out, expected_output) - def test_find_matching_uppercase_file_pattern(self): + def test_find_matching_uppercase_file_pattern(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".SQL", dir=self.tempdir) as named_file: file_path = os.path.dirname(named_file.name) relative_path = os.path.basename(file_path) @@ -190,12 +190,12 @@ def test_find_matching_uppercase_file_pattern(self): ] self.assertEqual(out, expected_output) - def test_find_matching_file_pattern_not_found(self): + def test_find_matching_file_pattern_not_found(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".SQLT", dir=self.tempdir): out = dbt_common.clients.system.find_matching(self.tempdir, [""], "*.sql") self.assertEqual(out, []) - def test_ignore_spec(self): + def test_ignore_spec(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir): out = dbt_common.clients.system.find_matching( self.tempdir, @@ -207,7 +207,7 @@ def test_ignore_spec(self): ) self.assertEqual(out, []) - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.base_dir) except Exception as e: # noqa: F841 @@ -215,18 +215,18 @@ def tearDown(self): class TestUntarPackage(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.base_dir = mkdtemp() self.tempdir = mkdtemp(dir=self.base_dir) self.tempdest = mkdtemp(dir=self.base_dir) - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.base_dir) except Exception as e: # noqa: F841 pass - def test_untar_package_success(self): + def test_untar_package_success(self) -> None: # set up a valid tarball to test against with NamedTemporaryFile( prefix="my-package.2", suffix=".tar.gz", dir=self.tempdir, delete=False @@ -244,7 +244,7 @@ def test_untar_package_success(self): path = Path(os.path.join(self.tempdest, relative_file_a)) assert path.is_file() - def test_untar_package_failure(self): + def test_untar_package_failure(self) -> None: # create a text file then rename it as a tar (so it's invalid) with NamedTemporaryFile( prefix="a", suffix=".txt", dir=self.tempdir, delete=False @@ -259,7 +259,7 @@ def test_untar_package_failure(self): with self.assertRaises(tarfile.ReadError) as exc: # noqa: F841 dbt_common.clients.system.untar_package(tar_file_path, self.tempdest) - def test_untar_package_empty(self): + def test_untar_package_empty(self) -> None: # create a tarball with nothing in it with NamedTemporaryFile( prefix="my-empty-package.2", suffix=".tar.gz", dir=self.tempdir diff --git a/tests/unit/test_ui.py b/tests/unit/test_ui.py index 22e431d5..5b70b1d1 100644 --- a/tests/unit/test_ui.py +++ b/tests/unit/test_ui.py @@ -1,11 +1,11 @@ from dbt_common.ui import warning_tag, error_tag -def test_warning_tag(): +def test_warning_tag() -> None: tagged = warning_tag("hi") assert "WARNING" in tagged -def test_error_tag(): +def test_error_tag() -> None: tagged = error_tag("hi") assert "ERROR" in tagged diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 250c20cc..93c57046 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -5,7 +5,7 @@ class TestDeepMerge(unittest.TestCase): - def test__simple_cases(self): + def test__simple_cases(self) -> None: cases = [ {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, { @@ -27,7 +27,7 @@ def test__simple_cases(self): class TestMerge(unittest.TestCase): - def test__simple_cases(self): + def test__simple_cases(self) -> None: cases = [ {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, { @@ -49,7 +49,7 @@ def test__simple_cases(self): class TestDeepMap(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.input_value = { "foo": { "bar": "hello", @@ -74,7 +74,7 @@ def intify_all(value, _): except (TypeError, ValueError): return -1 - def test__simple_cases(self): + def test__simple_cases(self) -> None: expected = { "foo": { "bar": -1, @@ -104,7 +104,7 @@ def special_keypath(value, keypath): else: return value - def test__keypath(self): + def test__keypath(self) -> None: expected = { "foo": { "bar": "hello", @@ -128,11 +128,11 @@ def test__keypath(self): actual = dbt_common.utils.dict.deep_map_render(self.special_keypath, expected) self.assertEqual(actual, expected) - def test__noop(self): + def test__noop(self) -> None: actual = dbt_common.utils.dict.deep_map_render(lambda x, _: x, self.input_value) self.assertEqual(actual, self.input_value) - def test_trivial(self): + def test_trivial(self) -> None: cases = [[], {}, 1, "abc", None, True] for case in cases: result = dbt_common.utils.dict.deep_map_render(lambda x, _: x, case) From 6e3828e0b0eced0ae4a42b2534748cf1b005a4aa Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Mon, 15 Jul 2024 09:32:37 -0500 Subject: [PATCH 05/10] conditionally import deepdiff (#164) only import deepdiff when needed --- dbt_common/record.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/dbt_common/record.py b/dbt_common/record.py index 8fe068bb..b33d4b5a 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -10,7 +10,6 @@ import json import os -from deepdiff import DeepDiff # type: ignore from enum import Enum from typing import Any, Callable, Dict, List, Mapping, Optional, Type @@ -52,6 +51,11 @@ def from_dict(cls, dct: Mapping) -> "Record": class Diff: def __init__(self, current_recording_path: str, previous_recording_path: str) -> None: + # deepdiff is expensive to import, so we only do it here when we need it + from deepdiff import DeepDiff # type: ignore + + self.diff = DeepDiff + self.current_recording_path = current_recording_path self.previous_recording_path = previous_recording_path @@ -67,7 +71,7 @@ def diff_query_records(self, current: List, previous: List) -> Dict[str, Any]: if previous[i].get("result").get("table") is not None: previous[i]["result"]["table"] = json.loads(previous[i]["result"]["table"]) - return DeepDiff(previous, current, ignore_order=True, verbose_level=2) + return self.diff(previous, current, ignore_order=True, verbose_level=2) def diff_env_records(self, current: List, previous: List) -> Dict[str, Any]: # The mode and filepath may change. Ignore them. @@ -77,12 +81,12 @@ def diff_env_records(self, current: List, previous: List) -> Dict[str, Any]: "root[0]['result']['env']['DBT_RECORDER_MODE']", ] - return DeepDiff( + return self.diff( previous, current, ignore_order=True, verbose_level=2, exclude_paths=exclude_paths ) def diff_default(self, current: List, previous: List) -> Dict[str, Any]: - return DeepDiff(previous, current, ignore_order=True, verbose_level=2) + return self.diff(previous, current, ignore_order=True, verbose_level=2) def calculate_diff(self) -> Dict[str, Any]: with open(self.current_recording_path) as current_recording: From 71f4d53c084e5960b4804c46eff435abc8718e57 Mon Sep 17 00:00:00 2001 From: Kshitij Aranke Date: Mon, 15 Jul 2024 21:42:40 +0100 Subject: [PATCH 06/10] Fix case-insensitive env vars for Windows (#166) * Draft fix of case-insensitive env vars for Windows. * add test for windows * add test skip reason --------- Co-authored-by: Peter Allen Webb --- .../unreleased/Fixes-20240715-205355.yaml | 6 +++ dbt_common/context.py | 37 +++++++++++++++++-- pyproject.toml | 1 + tests/unit/test_invocation_context.py | 17 ++++++++- 4 files changed, 57 insertions(+), 4 deletions(-) create mode 100644 .changes/unreleased/Fixes-20240715-205355.yaml diff --git a/.changes/unreleased/Fixes-20240715-205355.yaml b/.changes/unreleased/Fixes-20240715-205355.yaml new file mode 100644 index 00000000..780a6ad9 --- /dev/null +++ b/.changes/unreleased/Fixes-20240715-205355.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Fix case-insensitive env vars for Windows +time: 2024-07-15T20:53:55.946355+01:00 +custom: + Author: peterallenwebb aranke + Issue: "166" diff --git a/dbt_common/context.py b/dbt_common/context.py index d1775c55..947d409a 100644 --- a/dbt_common/context.py +++ b/dbt_common/context.py @@ -1,15 +1,46 @@ +import os from contextvars import ContextVar, copy_context -from typing import List, Mapping, Optional +from typing import List, Mapping, Optional, Iterator from dbt_common.constants import PRIVATE_ENV_PREFIX, SECRET_ENV_PREFIX from dbt_common.record import Recorder +class CaseInsensitiveMapping(Mapping): + def __init__(self, env: Mapping[str, str]): + self._env = {k.casefold(): (k, v) for k, v in env.items()} + + def __getitem__(self, key: str) -> str: + return self._env[key.casefold()][1] + + def __len__(self) -> int: + return len(self._env) + + def __iter__(self) -> Iterator[str]: + for item in self._env.items(): + yield item[0] + + class InvocationContext: def __init__(self, env: Mapping[str, str]): - self._env = {k: v for k, v in env.items() if not k.startswith(PRIVATE_ENV_PREFIX)} + self._env: Mapping[str, str] + + env_public = {} + env_private = {} + + for k, v in env.items(): + if k.startswith(PRIVATE_ENV_PREFIX): + env_private[k] = v + else: + env_public[k] = v + + if os.name == "nt": + self._env = CaseInsensitiveMapping(env_public) + else: + self._env = env_public + self._env_secrets: Optional[List[str]] = None - self._env_private = {k: v for k, v in env.items() if k.startswith(PRIVATE_ENV_PREFIX)} + self._env_private = env_private self.recorder: Optional[Recorder] = None # This class will also eventually manage the invocation_id, flags, event manager, etc. diff --git a/pyproject.toml b/pyproject.toml index 64fc04fc..ba306437 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,7 @@ ignore = ["E203", "E501", "E741", "W503", "W504"] exclude = [ "dbt_common/events/types_pb2.py", "venv", + ".venv", "env*" ] per-file-ignores = ["*/__init__.py: F401"] diff --git a/tests/unit/test_invocation_context.py b/tests/unit/test_invocation_context.py index 3dc832d3..fbf060ba 100644 --- a/tests/unit/test_invocation_context.py +++ b/tests/unit/test_invocation_context.py @@ -1,5 +1,9 @@ +import os + +import pytest + from dbt_common.constants import PRIVATE_ENV_PREFIX, SECRET_ENV_PREFIX -from dbt_common.context import InvocationContext +from dbt_common.context import InvocationContext, CaseInsensitiveMapping def test_invocation_context_env() -> None: @@ -8,6 +12,17 @@ def test_invocation_context_env() -> None: assert ic.env == test_env +@pytest.mark.skipif( + os.name != "nt", reason="Test for case-insensitive env vars, only run on Windows" +) +def test_invocation_context_windows() -> None: + test_env = {"var_1": "lowercase", "vAr_2": "mixedcase", "VAR_3": "uppercase"} + ic = InvocationContext(env=test_env) + assert ic.env == CaseInsensitiveMapping( + {"var_1": "lowercase", "var_2": "mixedcase", "var_3": "uppercase"} + ) + + def test_invocation_context_secrets() -> None: test_env = { f"{SECRET_ENV_PREFIX}_VAR_1": "secret1", From 8edae92ac4e1c5814be028281215fd90ab29d9c6 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Tue, 16 Jul 2024 10:28:23 -0400 Subject: [PATCH 07/10] add to and to_column(s) to ColumnLevelConstraint and ModelLevelConstraint contracts (#163) * add to and to_column(s) to ColumnLevelConstraint and ModelLevelConstraint contracts * fix schema * linting * changelog entry --- .changes/unreleased/Features-20240716-102457.yaml | 6 ++++++ dbt_common/contracts/constraints.py | 2 ++ 2 files changed, 8 insertions(+) create mode 100644 .changes/unreleased/Features-20240716-102457.yaml diff --git a/.changes/unreleased/Features-20240716-102457.yaml b/.changes/unreleased/Features-20240716-102457.yaml new file mode 100644 index 00000000..096b4259 --- /dev/null +++ b/.changes/unreleased/Features-20240716-102457.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add to and to_columns to ColumnLevelConstraint and ModelLevelConstraint contracts +time: 2024-07-16T10:24:57.11251-04:00 +custom: + Author: michelleark + Issue: "168" diff --git a/dbt_common/contracts/constraints.py b/dbt_common/contracts/constraints.py index c01ee6f8..4e2d9c7a 100644 --- a/dbt_common/contracts/constraints.py +++ b/dbt_common/contracts/constraints.py @@ -36,6 +36,8 @@ class ColumnLevelConstraint(dbtClassMixin): warn_unsupported: bool = ( True # Warn if constraint is not supported by the platform and won't be in DDL ) + to: Optional[str] = None + to_columns: List[str] = field(default_factory=list) @dataclass From de7d932ce857f080a57a54ff341e38e5e52008f0 Mon Sep 17 00:00:00 2001 From: Kshitij Aranke Date: Tue, 16 Jul 2024 20:18:34 +0100 Subject: [PATCH 08/10] minor version bump to 1.6.0 (#167) --- dbt_common/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt_common/__about__.py b/dbt_common/__about__.py index e3a0f015..38ec8ede 100644 --- a/dbt_common/__about__.py +++ b/dbt_common/__about__.py @@ -1 +1 @@ -version = "1.5.0" +version = "1.6.0" From 35bdf68db9f0e94ecda1cf2aec419c0fbc95742f Mon Sep 17 00:00:00 2001 From: Peter Webb Date: Tue, 16 Jul 2024 15:37:43 -0400 Subject: [PATCH 09/10] Add group property to records. Correct filter on file records. (#169) * Add group property to records. Correct filter on file records. * Update documentation. * Add changelog entry. --- .../Under the Hood-20240716-125753.yaml | 6 +++++ dbt_common/clients/system.py | 26 +++++++++---------- dbt_common/record.py | 9 ++++--- docs/guides/record_replay.md | 9 ++++--- 4 files changed, 29 insertions(+), 21 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20240716-125753.yaml diff --git a/.changes/unreleased/Under the Hood-20240716-125753.yaml b/.changes/unreleased/Under the Hood-20240716-125753.yaml new file mode 100644 index 00000000..55b36cb3 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240716-125753.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Add record grouping mechanism to record/replay. +time: 2024-07-16T12:57:53.434099-04:00 +custom: + Author: peterallenwebb + Issue: "169" diff --git a/dbt_common/clients/system.py b/dbt_common/clients/system.py index 00a1ac69..afa7f744 100644 --- a/dbt_common/clients/system.py +++ b/dbt_common/clients/system.py @@ -38,6 +38,15 @@ c_bool = None +def _record_path(path: str) -> bool: + return ( + # TODO: The first check here obviates the next two checks but is probably too coarse? + "dbt/include" not in path + and "dbt/include/global_project" not in path + and "/plugins/postgres/dbt/include/" not in path + ) + + @dataclasses.dataclass class FindMatchingParams: root_path: str @@ -61,12 +70,7 @@ def __init__( def _include(self) -> bool: # Do not record or replay filesystem searches that were performed against # files which are actually part of dbt's implementation. - return ( - "dbt/include" - not in self.root_path # TODO: This actually obviates the next two checks but is probably too coarse? - and "dbt/include/global_project" not in self.root_path - and "/plugins/postgres/dbt/include/" not in self.root_path - ) + return _record_path(self.root_path) @dataclasses.dataclass @@ -150,10 +154,7 @@ class LoadFileParams: def _include(self) -> bool: # Do not record or replay file reads that were performed against files # which are actually part of dbt's implementation. - return ( - "dbt/include/global_project" not in self.path - and "/plugins/postgres/dbt/include/" not in self.path - ) + return _record_path(self.path) @dataclasses.dataclass @@ -248,10 +249,7 @@ class WriteFileParams: def _include(self) -> bool: # Do not record or replay file reads that were performed against files # which are actually part of dbt's implementation. - return ( - "dbt/include/global_project" not in self.path - and "/plugins/postgres/dbt/include/" not in self.path - ) + return _record_path(self.path) @Recorder.register_record_type diff --git a/dbt_common/record.py b/dbt_common/record.py index b33d4b5a..612ddf75 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -20,7 +20,8 @@ class Record: to the request, and the 'result' is what is returned.""" params_cls: type - result_cls: Optional[type] + result_cls: Optional[type] = None + group: Optional[str] = None def __init__(self, params, result) -> None: self.params = params @@ -309,9 +310,9 @@ def record_replay_wrapper(*args, **kwargs) -> Any: if recorder is None: return func_to_record(*args, **kwargs) - if ( - recorder.recorded_types is not None - and record_type.__name__ not in recorder.recorded_types + if recorder.recorded_types is not None and not ( + record_type.__name__ in recorder.recorded_types + or record_type.group in recorder.recorded_types ): return func_to_record(*args, **kwargs) diff --git a/docs/guides/record_replay.md b/docs/guides/record_replay.md index aff4c77c..b6dfc7b8 100644 --- a/docs/guides/record_replay.md +++ b/docs/guides/record_replay.md @@ -31,13 +31,16 @@ The final detail needed is to define the classes specified by `params_cls` and ` With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them. ## How to record/replay -If `DBT_RECORDER_MODE` is not `replay` or `record`, case insensitive, this is a no-op. Invalid values are ignored and do not throw exceptions. -`DBT_RECODER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses. Any invalid types will be ignored. `all` is a valid type and behaves the same as not populating the env var. +Record/replay behavior is activated and configured via environment variables. When DBT_RECORDER_MODE is unset, the entire subsystem is disabled, and the decorators described above have no effect at all. This helps isolate the subsystem from core's application code, reducing the risk of performance impact or regressions. + +The record/replay subsystem is activated by setting the `DBT_RECORDER_MODE` variable to `replay`, `record`, or `diff`, case insensitive. Invalid values are ignored and do not throw exceptions. + +`DBT_RECODER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses or groups of such classes. For example, all records of database/DWH interaction performed by adapters belong to the `Database` group. Any invalid type or group name will be ignored. `all` is a valid value for this variable and has the same effect as not populating the variable. ```bash -DBT_RECORDER_MODE=record DBT_RECODER_TYPES=QueryRecord,GetEnvRecord dbt run +DBT_RECORDER_MODE=record DBT_RECODER_TYPES=Database dbt run ``` replay need the file to replay From db99ddda5f0d90c0499a000038a9cbb13d8d0b20 Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Tue, 30 Jul 2024 10:18:34 -0500 Subject: [PATCH 10/10] Update to minor version 1.7 (#172) --- dbt_common/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt_common/__about__.py b/dbt_common/__about__.py index 38ec8ede..a55413d1 100644 --- a/dbt_common/__about__.py +++ b/dbt_common/__about__.py @@ -1 +1 @@ -version = "1.6.0" +version = "1.7.0"