diff --git a/.changes/unreleased/Features-20240716-102457.yaml b/.changes/unreleased/Features-20240716-102457.yaml new file mode 100644 index 00000000..096b4259 --- /dev/null +++ b/.changes/unreleased/Features-20240716-102457.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add to and to_columns to ColumnLevelConstraint and ModelLevelConstraint contracts +time: 2024-07-16T10:24:57.11251-04:00 +custom: + Author: michelleark + Issue: "168" diff --git a/.changes/unreleased/Fixes-20240715-205355.yaml b/.changes/unreleased/Fixes-20240715-205355.yaml new file mode 100644 index 00000000..780a6ad9 --- /dev/null +++ b/.changes/unreleased/Fixes-20240715-205355.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Fix case-insensitive env vars for Windows +time: 2024-07-15T20:53:55.946355+01:00 +custom: + Author: peterallenwebb aranke + Issue: "166" diff --git a/.changes/unreleased/Under the Hood-20240618-155025.yaml b/.changes/unreleased/Under the Hood-20240618-155025.yaml new file mode 100644 index 00000000..b540d3d7 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240618-155025.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Deserialize Record objects on a just-in-time basis. +time: 2024-06-18T15:50:25.985387-04:00 +custom: + Author: peterallenwebb + Issue: "151" diff --git a/.changes/unreleased/Under the Hood-20240716-125753.yaml b/.changes/unreleased/Under the Hood-20240716-125753.yaml new file mode 100644 index 00000000..55b36cb3 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240716-125753.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Add record grouping mechanism to record/replay. +time: 2024-07-16T12:57:53.434099-04:00 +custom: + Author: peterallenwebb + Issue: "169" diff --git a/dbt_common/__about__.py b/dbt_common/__about__.py index d619c757..a55413d1 100644 --- a/dbt_common/__about__.py +++ b/dbt_common/__about__.py @@ -1 +1 @@ -version = "1.4.0" +version = "1.7.0" diff --git a/dbt_common/clients/system.py b/dbt_common/clients/system.py index bcf798d2..afa7f744 100644 --- a/dbt_common/clients/system.py +++ b/dbt_common/clients/system.py @@ -38,6 +38,15 @@ c_bool = None +def _record_path(path: str) -> bool: + return ( + # TODO: The first check here obviates the next two checks but is probably too coarse? + "dbt/include" not in path + and "dbt/include/global_project" not in path + and "/plugins/postgres/dbt/include/" not in path + ) + + @dataclasses.dataclass class FindMatchingParams: root_path: str @@ -61,10 +70,7 @@ def __init__( def _include(self) -> bool: # Do not record or replay filesystem searches that were performed against # files which are actually part of dbt's implementation. - return ( - "dbt/include/global_project" not in self.root_path - and "/plugins/postgres/dbt/include/" not in self.root_path - ) + return _record_path(self.root_path) @dataclasses.dataclass @@ -148,10 +154,7 @@ class LoadFileParams: def _include(self) -> bool: # Do not record or replay file reads that were performed against files # which are actually part of dbt's implementation. - return ( - "dbt/include/global_project" not in self.path - and "/plugins/postgres/dbt/include/" not in self.path - ) + return _record_path(self.path) @dataclasses.dataclass @@ -246,10 +249,7 @@ class WriteFileParams: def _include(self) -> bool: # Do not record or replay file reads that were performed against files # which are actually part of dbt's implementation. - return ( - "dbt/include/global_project" not in self.path - and "/plugins/postgres/dbt/include/" not in self.path - ) + return _record_path(self.path) @Recorder.register_record_type diff --git a/dbt_common/context.py b/dbt_common/context.py index a46b1dd2..947d409a 100644 --- a/dbt_common/context.py +++ b/dbt_common/context.py @@ -1,15 +1,47 @@ +import os from contextvars import ContextVar, copy_context -from typing import List, Mapping, Optional +from typing import List, Mapping, Optional, Iterator from dbt_common.constants import PRIVATE_ENV_PREFIX, SECRET_ENV_PREFIX +from dbt_common.record import Recorder + + +class CaseInsensitiveMapping(Mapping): + def __init__(self, env: Mapping[str, str]): + self._env = {k.casefold(): (k, v) for k, v in env.items()} + + def __getitem__(self, key: str) -> str: + return self._env[key.casefold()][1] + + def __len__(self) -> int: + return len(self._env) + + def __iter__(self) -> Iterator[str]: + for item in self._env.items(): + yield item[0] class InvocationContext: def __init__(self, env: Mapping[str, str]): - self._env = {k: v for k, v in env.items() if not k.startswith(PRIVATE_ENV_PREFIX)} + self._env: Mapping[str, str] + + env_public = {} + env_private = {} + + for k, v in env.items(): + if k.startswith(PRIVATE_ENV_PREFIX): + env_private[k] = v + else: + env_public[k] = v + + if os.name == "nt": + self._env = CaseInsensitiveMapping(env_public) + else: + self._env = env_public + self._env_secrets: Optional[List[str]] = None - self._env_private = {k: v for k, v in env.items() if k.startswith(PRIVATE_ENV_PREFIX)} - self.recorder = None + self._env_private = env_private + self.recorder: Optional[Recorder] = None # This class will also eventually manage the invocation_id, flags, event manager, etc. @property @@ -32,7 +64,7 @@ def env_secrets(self) -> List[str]: _INVOCATION_CONTEXT_VAR: ContextVar[InvocationContext] = ContextVar("DBT_INVOCATION_CONTEXT_VAR") -def reliably_get_invocation_var() -> ContextVar: +def reliably_get_invocation_var() -> ContextVar[InvocationContext]: invocation_var: Optional[ContextVar] = next( (cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None ) diff --git a/dbt_common/contracts/constraints.py b/dbt_common/contracts/constraints.py index c01ee6f8..4e2d9c7a 100644 --- a/dbt_common/contracts/constraints.py +++ b/dbt_common/contracts/constraints.py @@ -36,6 +36,8 @@ class ColumnLevelConstraint(dbtClassMixin): warn_unsupported: bool = ( True # Warn if constraint is not supported by the platform and won't be in DDL ) + to: Optional[str] = None + to_columns: List[str] = field(default_factory=list) @dataclass diff --git a/dbt_common/dataclass_schema.py b/dbt_common/dataclass_schema.py index 0bad081f..867d5a4c 100644 --- a/dbt_common/dataclass_schema.py +++ b/dbt_common/dataclass_schema.py @@ -1,4 +1,4 @@ -from typing import ClassVar, cast, get_type_hints, List, Tuple, Dict, Any, Optional +from typing import Any, cast, ClassVar, Dict, get_type_hints, List, Optional, Tuple import re import jsonschema from dataclasses import fields, Field @@ -26,7 +26,7 @@ class ValidationError(jsonschema.ValidationError): class DateTimeSerialization(SerializationStrategy): - def serialize(self, value) -> str: + def serialize(self, value: datetime) -> str: out = value.isoformat() # Assume UTC if timezone is missing if value.tzinfo is None: @@ -127,7 +127,7 @@ def _get_fields(cls) -> List[Tuple[Field, str]]: # copied from hologram. Used in tests @classmethod - def _get_field_names(cls): + def _get_field_names(cls) -> List[str]: return [element[1] for element in cls._get_fields()] @@ -152,7 +152,7 @@ def validate(cls, value): # These classes must be in this order or it doesn't work class StrEnum(str, SerializableType, Enum): - def __str__(self): + def __str__(self) -> str: return self.value # https://docs.python.org/3.6/library/enum.html#using-automatic-values diff --git a/dbt_common/exceptions/base.py b/dbt_common/exceptions/base.py index db619326..d966a28d 100644 --- a/dbt_common/exceptions/base.py +++ b/dbt_common/exceptions/base.py @@ -1,5 +1,5 @@ import builtins -from typing import List, Any, Optional +from typing import Any, List, Optional import os from dbt_common.constants import SECRET_ENV_PREFIX @@ -37,7 +37,7 @@ def __init__(self, msg: str): self.msg = scrub_secrets(msg, env_secrets()) @property - def type(self): + def type(self) -> str: return "Internal" def process_stack(self): @@ -59,7 +59,7 @@ def process_stack(self): return lines - def __str__(self): + def __str__(self) -> str: if hasattr(self.msg, "split"): split_msg = self.msg.split("\n") else: diff --git a/dbt_common/helper_types.py b/dbt_common/helper_types.py index 0ca435b7..8611f39f 100644 --- a/dbt_common/helper_types.py +++ b/dbt_common/helper_types.py @@ -19,7 +19,7 @@ class NVEnum(StrEnum): novalue = "novalue" - def __eq__(self, other): + def __eq__(self, other) -> bool: return isinstance(other, NVEnum) @@ -59,7 +59,7 @@ def includes(self, item_name: str) -> bool: item_name in self.include or self.include in self.INCLUDE_ALL ) and item_name not in self.exclude - def _validate_items(self, items: List[str]): + def _validate_items(self, items: List[str]) -> None: pass diff --git a/dbt_common/record.py b/dbt_common/record.py index 9fed4c3d..612ddf75 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -2,7 +2,7 @@ external systems during a command invocation, so that the command can be re-run later with the recording 'replayed' to dbt. -The rationale for and architecture of this module is described in detail in the +The rationale for and architecture of this module are described in detail in the docs/guides/record_replay.md document in this repository. """ import functools @@ -10,11 +10,8 @@ import json import os -from deepdiff import DeepDiff # type: ignore from enum import Enum -from typing import Any, Dict, List, Mapping, Optional, Type - -from dbt_common.context import get_invocation_context +from typing import Any, Callable, Dict, List, Mapping, Optional, Type class Record: @@ -23,7 +20,8 @@ class Record: to the request, and the 'result' is what is returned.""" params_cls: type - result_cls: Optional[type] + result_cls: Optional[type] = None + group: Optional[str] = None def __init__(self, params, result) -> None: self.params = params @@ -54,6 +52,11 @@ def from_dict(cls, dct: Mapping) -> "Record": class Diff: def __init__(self, current_recording_path: str, previous_recording_path: str) -> None: + # deepdiff is expensive to import, so we only do it here when we need it + from deepdiff import DeepDiff # type: ignore + + self.diff = DeepDiff + self.current_recording_path = current_recording_path self.previous_recording_path = previous_recording_path @@ -69,7 +72,7 @@ def diff_query_records(self, current: List, previous: List) -> Dict[str, Any]: if previous[i].get("result").get("table") is not None: previous[i]["result"]["table"] = json.loads(previous[i]["result"]["table"]) - return DeepDiff(previous, current, ignore_order=True, verbose_level=2) + return self.diff(previous, current, ignore_order=True, verbose_level=2) def diff_env_records(self, current: List, previous: List) -> Dict[str, Any]: # The mode and filepath may change. Ignore them. @@ -79,12 +82,12 @@ def diff_env_records(self, current: List, previous: List) -> Dict[str, Any]: "root[0]['result']['env']['DBT_RECORDER_MODE']", ] - return DeepDiff( + return self.diff( previous, current, ignore_order=True, verbose_level=2, exclude_paths=exclude_paths ) def diff_default(self, current: List, previous: List) -> Dict[str, Any]: - return DeepDiff(previous, current, ignore_order=True, verbose_level=2) + return self.diff(previous, current, ignore_order=True, verbose_level=2) def calculate_diff(self) -> Dict[str, Any]: with open(self.current_recording_path) as current_recording: @@ -129,10 +132,11 @@ def __init__( previous_recording_path: Optional[str] = None, ) -> None: self.mode = mode - self.types = types + self.recorded_types = types self._records_by_type: Dict[str, List[Record]] = {} + self._unprocessed_records_by_type: Dict[str, List[Dict[str, Any]]] = {} self._replay_diffs: List["Diff"] = [] - self.diff: Diff + self.diff: Optional[Diff] = None self.previous_recording_path = previous_recording_path self.current_recording_path = current_recording_path @@ -146,7 +150,7 @@ def __init__( ) if self.mode == RecorderMode.REPLAY: - self._records_by_type = self.load(self.previous_recording_path) + self._unprocessed_records_by_type = self.load(self.previous_recording_path) @classmethod def register_record_type(cls, rec_type) -> Any: @@ -161,7 +165,14 @@ def add_record(self, record: Record) -> None: self._records_by_type[rec_cls_name].append(record) def pop_matching_record(self, params: Any) -> Optional[Record]: - rec_type_name = self._record_name_by_params_name[type(params).__name__] + rec_type_name = self._record_name_by_params_name.get(type(params).__name__) + + if rec_type_name is None: + raise Exception( + f"A record of type {type(params).__name__} was requested, but no such type has been registered." + ) + + self._ensure_records_processed(rec_type_name) records = self._records_by_type[rec_type_name] match: Optional[Record] = None for rec in records: @@ -186,22 +197,20 @@ def _to_dict(self) -> Dict: return dct @classmethod - def load(cls, file_name: str) -> Dict[str, List[Record]]: + def load(cls, file_name: str) -> Dict[str, List[Dict[str, Any]]]: with open(file_name) as file: - loaded_dct = json.load(file) + return json.load(file) - records_by_type: Dict[str, List[Record]] = {} + def _ensure_records_processed(self, record_type_name: str) -> None: + if record_type_name in self._records_by_type: + return - for record_type_name in loaded_dct: - # TODO: this breaks with QueryRecord on replay since it's - # not in common so isn't part of cls._record_cls_by_name yet - record_cls = cls._record_cls_by_name[record_type_name] - rec_list = [] - for record_dct in loaded_dct[record_type_name]: - rec = record_cls.from_dict(record_dct) - rec_list.append(rec) # type: ignore - records_by_type[record_type_name] = rec_list - return records_by_type + rec_list = [] + record_cls = self._record_cls_by_name[record_type_name] + for record_dct in self._unprocessed_records_by_type[record_type_name]: + rec = record_cls.from_dict(record_dct) + rec_list.append(rec) # type: ignore + self._records_by_type[record_type_name] = rec_list def expect_record(self, params: Any) -> Any: record = self.pop_matching_record(params) @@ -209,16 +218,19 @@ def expect_record(self, params: Any) -> Any: if record is None: raise Exception() + if record.result is None: + return None + result_tuple = dataclasses.astuple(record.result) return result_tuple[0] if len(result_tuple) == 1 else result_tuple def write_diffs(self, diff_file_name) -> None: - json.dump( - self.diff.calculate_diff(), - open(diff_file_name, "w"), - ) + assert self.diff is not None + with open(diff_file_name, "w") as f: + json.dump(self.diff.calculate_diff(), f) def print_diffs(self) -> None: + assert self.diff is not None print(repr(self.diff.calculate_diff())) @@ -273,7 +285,12 @@ def get_record_types_from_dict(fp: str) -> List: return list(loaded_dct.keys()) -def record_function(record_type, method=False, tuple_result=False): +def record_function( + record_type, + method: bool = False, + tuple_result: bool = False, + id_field_name: Optional[str] = None, +) -> Callable: def record_function_inner(func_to_record): # To avoid runtime overhead and other unpleasantness, we only apply the # record/replay decorator if a relevant env var is set. @@ -281,9 +298,11 @@ def record_function_inner(func_to_record): return func_to_record @functools.wraps(func_to_record) - def record_replay_wrapper(*args, **kwargs): - recorder: Recorder = None + def record_replay_wrapper(*args, **kwargs) -> Any: + recorder: Optional[Recorder] = None try: + from dbt_common.context import get_invocation_context + recorder = get_invocation_context().recorder except LookupError: pass @@ -291,12 +310,17 @@ def record_replay_wrapper(*args, **kwargs): if recorder is None: return func_to_record(*args, **kwargs) - if recorder.types is not None and record_type.__name__ not in recorder.types: + if recorder.recorded_types is not None and not ( + record_type.__name__ in recorder.recorded_types + or record_type.group in recorder.recorded_types + ): return func_to_record(*args, **kwargs) # For methods, peel off the 'self' argument before calling the # params constructor. param_args = args[1:] if method else args + if method and id_field_name is not None: + param_args = (getattr(args[0], id_field_name),) + param_args params = record_type.params_cls(*param_args, **kwargs) @@ -313,7 +337,7 @@ def record_replay_wrapper(*args, **kwargs): r = func_to_record(*args, **kwargs) result = ( None - if r is None or record_type.result_cls is None + if record_type.result_cls is None else record_type.result_cls(*r) if tuple_result else record_type.result_cls(r) diff --git a/dbt_common/semver.py b/dbt_common/semver.py index ef0182ba..4c411911 100644 --- a/dbt_common/semver.py +++ b/dbt_common/semver.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import re -from typing import List +from typing import List, Iterable import dbt_common.exceptions.base from dbt_common.exceptions import VersionsNotCompatibleError @@ -429,7 +429,7 @@ def versions_compatible(*args) -> bool: return False -def find_possible_versions(requested_range, available_versions): +def find_possible_versions(requested_range, available_versions: Iterable[str]): possible_versions = [] for version_string in available_versions: @@ -442,7 +442,9 @@ def find_possible_versions(requested_range, available_versions): return [v.to_version_string(skip_matcher=True) for v in sorted_versions] -def resolve_to_specific_version(requested_range, available_versions): +def resolve_to_specific_version( + requested_range, available_versions: Iterable[str] +) -> Optional[str]: max_version = None max_version_string = None diff --git a/dbt_common/utils/casting.py b/dbt_common/utils/casting.py index f86ca191..076d4c12 100644 --- a/dbt_common/utils/casting.py +++ b/dbt_common/utils/casting.py @@ -19,7 +19,8 @@ def cast_to_int(integer: Optional[int]) -> int: def cast_dict_to_dict_of_strings(dct: Mapping[Any, Any]) -> Dict[str, str]: - new_dct = {} + new_dct: Dict[str, str] = {} + for k, v in dct.items(): new_dct[str(k)] = str(v) return new_dct diff --git a/dbt_common/utils/executor.py b/dbt_common/utils/executor.py index 0dd8490c..529b02be 100644 --- a/dbt_common/utils/executor.py +++ b/dbt_common/utils/executor.py @@ -1,9 +1,12 @@ import concurrent.futures from contextlib import contextmanager -from contextvars import ContextVar from typing import Protocol, Optional -from dbt_common.context import get_invocation_context, reliably_get_invocation_var +from dbt_common.context import ( + get_invocation_context, + reliably_get_invocation_var, + InvocationContext, +) class ConnectingExecutor(concurrent.futures.Executor): @@ -63,7 +66,7 @@ class HasThreadingConfig(Protocol): threads: Optional[int] -def _thread_initializer(invocation_context: ContextVar) -> None: +def _thread_initializer(invocation_context: InvocationContext) -> None: invocation_var = reliably_get_invocation_var() invocation_var.set(invocation_context) diff --git a/docs/guides/record_replay.md b/docs/guides/record_replay.md index aff4c77c..b6dfc7b8 100644 --- a/docs/guides/record_replay.md +++ b/docs/guides/record_replay.md @@ -31,13 +31,16 @@ The final detail needed is to define the classes specified by `params_cls` and ` With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them. ## How to record/replay -If `DBT_RECORDER_MODE` is not `replay` or `record`, case insensitive, this is a no-op. Invalid values are ignored and do not throw exceptions. -`DBT_RECODER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses. Any invalid types will be ignored. `all` is a valid type and behaves the same as not populating the env var. +Record/replay behavior is activated and configured via environment variables. When DBT_RECORDER_MODE is unset, the entire subsystem is disabled, and the decorators described above have no effect at all. This helps isolate the subsystem from core's application code, reducing the risk of performance impact or regressions. + +The record/replay subsystem is activated by setting the `DBT_RECORDER_MODE` variable to `replay`, `record`, or `diff`, case insensitive. Invalid values are ignored and do not throw exceptions. + +`DBT_RECODER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses or groups of such classes. For example, all records of database/DWH interaction performed by adapters belong to the `Database` group. Any invalid type or group name will be ignored. `all` is a valid value for this variable and has the same effect as not populating the variable. ```bash -DBT_RECORDER_MODE=record DBT_RECODER_TYPES=QueryRecord,GetEnvRecord dbt run +DBT_RECORDER_MODE=record DBT_RECODER_TYPES=Database dbt run ``` replay need the file to replay diff --git a/pyproject.toml b/pyproject.toml index 64fc04fc..ba306437 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,7 @@ ignore = ["E203", "E501", "E741", "W503", "W504"] exclude = [ "dbt_common/events/types_pb2.py", "venv", + ".venv", "env*" ] per-file-ignores = ["*/__init__.py: F401"] diff --git a/tests/unit/test_agate_helper.py b/tests/unit/test_agate_helper.py index 4c12bcd8..fff0d4c6 100644 --- a/tests/unit/test_agate_helper.py +++ b/tests/unit/test_agate_helper.py @@ -46,13 +46,13 @@ class TestAgateHelper(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.tempdir = mkdtemp() - def tearDown(self): + def tearDown(self) -> None: rmtree(self.tempdir) - def test_from_csv(self): + def test_from_csv(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_DATA.encode("utf-8")) @@ -61,7 +61,7 @@ def test_from_csv(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_bom_from_csv(self): + def test_bom_from_csv(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_BOM_DATA.encode("utf-8")) @@ -70,7 +70,7 @@ def test_bom_from_csv(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_from_csv_all_reserved(self): + def test_from_csv_all_reserved(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_DATA.encode("utf-8")) @@ -79,7 +79,7 @@ def test_from_csv_all_reserved(self): for expected, row in zip(EXPECTED_STRINGS, tbl): self.assertEqual(list(row), expected) - def test_from_data(self): + def test_from_data(self) -> None: column_names = ["a", "b", "c", "d", "e", "f", "g"] data = [ { @@ -106,7 +106,7 @@ def test_from_data(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_datetime_formats(self): + def test_datetime_formats(self) -> None: path = os.path.join(self.tempdir, "input.csv") datetimes = [ "20180806T11:33:29.000Z", @@ -120,7 +120,7 @@ def test_datetime_formats(self): tbl = agate_helper.from_csv(path, ()) self.assertEqual(tbl[0][0], expected) - def test_merge_allnull(self): + def test_merge_allnull(self) -> None: t1 = agate_helper.table_from_rows([(1, "a", None), (2, "b", None)], ("a", "b", "c")) t2 = agate_helper.table_from_rows([(3, "c", None), (4, "d", None)], ("a", "b", "c")) result = agate_helper.merge_tables([t1, t2]) @@ -130,7 +130,7 @@ def test_merge_allnull(self): assert isinstance(result.column_types[2], agate_helper.Integer) self.assertEqual(len(result), 4) - def test_merge_mixed(self): + def test_merge_mixed(self) -> None: t1 = agate_helper.table_from_rows( [(1, "a", None, None), (2, "b", None, None)], ("a", "b", "c", "d") ) @@ -181,7 +181,7 @@ def test_merge_mixed(self): assert isinstance(result.column_types[3], agate.data_types.Number) self.assertEqual(len(result), 6) - def test_nocast_string_types(self): + def test_nocast_string_types(self) -> None: # String fields should not be coerced into a representative type # See: https://github.com/dbt-labs/dbt-core/issues/2984 @@ -202,7 +202,7 @@ def test_nocast_string_types(self): for i, row in enumerate(tbl): self.assertEqual(list(row), expected[i]) - def test_nocast_bool_01(self): + def test_nocast_bool_01(self) -> None: # True and False values should not be cast to 1 and 0, and vice versa # See: https://github.com/dbt-labs/dbt-core/issues/4511 diff --git a/tests/unit/test_connection_retries.py b/tests/unit/test_connection_retries.py index 817af7a2..44fc72f5 100644 --- a/tests/unit/test_connection_retries.py +++ b/tests/unit/test_connection_retries.py @@ -19,20 +19,23 @@ def test_no_retry(self): assert result == expected -def no_success_fn(): +def no_success_fn() -> str: raise RequestException("You'll never pass") return "failure" class TestMaxRetries: - def test_no_retry(self): + def test_no_retry(self) -> None: fn_to_retry = functools.partial(no_success_fn) with pytest.raises(ConnectionError): connection_exception_retry(fn_to_retry, 3) -def single_retry_fn(): +counter = 0 + + +def single_retry_fn() -> str: global counter if counter == 0: counter += 1 @@ -45,7 +48,7 @@ def single_retry_fn(): class TestSingleRetry: - def test_no_retry(self): + def test_no_retry(self) -> None: global counter counter = 0 diff --git a/tests/unit/test_contextvars.py b/tests/unit/test_contextvars.py index 4eb58e6c..1aa9425f 100644 --- a/tests/unit/test_contextvars.py +++ b/tests/unit/test_contextvars.py @@ -1,7 +1,7 @@ from dbt_common.events.contextvars import log_contextvars, get_node_info, set_log_contextvars -def test_contextvars(): +def test_contextvars() -> None: node_info = { "unique_id": "model.test.my_model", "started_at": None, diff --git a/tests/unit/test_contracts_util.py b/tests/unit/test_contracts_util.py index 2a620370..d2fc4493 100644 --- a/tests/unit/test_contracts_util.py +++ b/tests/unit/test_contracts_util.py @@ -13,7 +13,7 @@ class ExampleMergableClass(Mergeable): class TestMergableClass(unittest.TestCase): - def test_mergeability(self): + def test_mergeability(self) -> None: mergeable1 = ExampleMergableClass( attr_a="loses", attr_b=None, attr_c=["I'll", "still", "exist"] ) diff --git a/tests/unit/test_core_dbt_utils.py b/tests/unit/test_core_dbt_utils.py index 8a0e836e..7419cd8d 100644 --- a/tests/unit/test_core_dbt_utils.py +++ b/tests/unit/test_core_dbt_utils.py @@ -7,30 +7,30 @@ class TestCommonDbtUtils(unittest.TestCase): - def test_connection_exception_retry_none(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add(self), 5) + def test_connection_exception_retry_none(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add(), 5) self.assertEqual(1, counter) - def test_connection_exception_retry_success_requests_exception(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_requests_exception(self), 5) + def test_connection_exception_retry_success_requests_exception(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_requests_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned None, plus 1 retry - def test_connection_exception_retry_max(self): - Counter._reset(self) + def test_connection_exception_retry_max(self) -> None: + Counter._reset() with self.assertRaises(ConnectionError): - connection_exception_retry(lambda: Counter._add_with_exception(self), 5) + connection_exception_retry(lambda: Counter._add_with_exception(), 5) self.assertEqual(6, counter) # 6 = original attempt plus 5 retries - def test_connection_exception_retry_success_failed_untar(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_untar_exception(self), 5) + def test_connection_exception_retry_success_failed_untar(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_untar_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned ReadError, plus 1 retry - def test_connection_exception_retry_success_failed_eofexception(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_eof_exception(self), 5) + def test_connection_exception_retry_success_failed_eofexception(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_eof_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned EOFError, plus 1 retry @@ -38,36 +38,42 @@ def test_connection_exception_retry_success_failed_eofexception(self): class Counter: - def _add(self): + @classmethod + def _add(cls) -> None: global counter counter += 1 # All exceptions that Requests explicitly raises inherit from # requests.exceptions.RequestException so we want to make sure that raises plus one exception # that inherit from it for sanity - def _add_with_requests_exception(self): + @classmethod + def _add_with_requests_exception(cls) -> None: global counter counter += 1 if counter < 2: raise requests.exceptions.RequestException - def _add_with_exception(self): + @classmethod + def _add_with_exception(cls) -> None: global counter counter += 1 raise requests.exceptions.ConnectionError - def _add_with_untar_exception(self): + @classmethod + def _add_with_untar_exception(cls) -> None: global counter counter += 1 if counter < 2: raise tarfile.ReadError - def _add_with_eof_exception(self): + @classmethod + def _add_with_eof_exception(cls) -> None: global counter counter += 1 if counter < 2: raise EOFError - def _reset(self): + @classmethod + def _reset(cls) -> None: global counter counter = 0 diff --git a/tests/unit/test_diff.py b/tests/unit/test_diff.py index 791263f3..54f735e3 100644 --- a/tests/unit/test_diff.py +++ b/tests/unit/test_diff.py @@ -1,4 +1,6 @@ import json +from typing import Any, Dict + import pytest from dbt_common.record import Diff @@ -191,7 +193,7 @@ def open_mock(file, *args, **kwargs): return open_mock -def test_calculate_diff_no_diff(monkeypatch): +def test_calculate_diff_no_diff(monkeypatch) -> None: # Mock data for the files current_recording_data = { "GetEnvRecord": [ @@ -251,11 +253,11 @@ def test_calculate_diff_no_diff(monkeypatch): previous_recording_path=previous_recording_path, ) result = diff_instance.calculate_diff() - expected_result = {"GetEnvRecord": {}, "DefaultKey": {}} + expected_result: Dict[str, Any] = {"GetEnvRecord": {}, "DefaultKey": {}} assert result == expected_result -def test_calculate_diff_with_diff(monkeypatch): +def test_calculate_diff_with_diff(monkeypatch) -> None: # Mock data for the files current_recording_data = { "GetEnvRecord": [ diff --git a/tests/unit/test_event_handler.py b/tests/unit/test_event_handler.py index 80d5ae2b..f38938b6 100644 --- a/tests/unit/test_event_handler.py +++ b/tests/unit/test_event_handler.py @@ -5,7 +5,7 @@ from dbt_common.events.event_manager import TestEventManager -def test_event_logging_handler_emits_records_correctly(): +def test_event_logging_handler_emits_records_correctly() -> None: event_manager = TestEventManager() handler = DbtEventLoggingHandler(event_manager=event_manager, level=logging.DEBUG) log = logging.getLogger("test") @@ -27,7 +27,7 @@ def test_event_logging_handler_emits_records_correctly(): assert event_manager.event_history[5][1] == EventLevel.ERROR -def test_set_package_logging_sets_level_correctly(): +def test_set_package_logging_sets_level_correctly() -> None: event_manager = TestEventManager() log = logging.getLogger("test") set_package_logging("test", logging.DEBUG, event_manager) diff --git a/tests/unit/test_helper_types.py b/tests/unit/test_helper_types.py index 1a9519de..ba98803c 100644 --- a/tests/unit/test_helper_types.py +++ b/tests/unit/test_helper_types.py @@ -1,11 +1,12 @@ import pytest +from typing import List, Union from dbt_common.helper_types import IncludeExclude, WarnErrorOptions from dbt_common.dataclass_schema import ValidationError class TestIncludeExclude: - def test_init_invalid(self): + def test_init_invalid(self) -> None: with pytest.raises(ValidationError): IncludeExclude(include="invalid") @@ -22,14 +23,16 @@ def test_init_invalid(self): (["ItemA", "ItemB"], [], True), ], ) - def test_includes(self, include, exclude, expected_includes): + def test_includes( + self, include: Union[str, List[str]], exclude: List[str], expected_includes: bool + ) -> None: include_exclude = IncludeExclude(include=include, exclude=exclude) assert include_exclude.includes("ItemA") == expected_includes class TestWarnErrorOptions: - def test_init_invalid_error(self): + def test_init_invalid_error(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include=["InvalidError"], valid_error_names=set(["ValidError"])) @@ -38,14 +41,14 @@ def test_init_invalid_error(self): include="*", exclude=["InvalidError"], valid_error_names=set(["ValidError"]) ) - def test_init_invalid_error_default_valid_error_names(self): + def test_init_invalid_error_default_valid_error_names(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include=["InvalidError"]) with pytest.raises(ValidationError): WarnErrorOptions(include="*", exclude=["InvalidError"]) - def test_init_valid_error(self): + def test_init_valid_error(self) -> None: warn_error_options = WarnErrorOptions( include=["ValidError"], valid_error_names=set(["ValidError"]) ) @@ -58,18 +61,18 @@ def test_init_valid_error(self): assert warn_error_options.include == "*" assert warn_error_options.exclude == ["ValidError"] - def test_init_default_silence(self): + def test_init_default_silence(self) -> None: my_options = WarnErrorOptions(include="*") assert my_options.silence == [] - def test_init_invalid_silence_event(self): + def test_init_invalid_silence_event(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include="*", silence=["InvalidError"]) - def test_init_valid_silence_event(self): + def test_init_valid_silence_event(self) -> None: all_events = ["MySilencedEvent"] my_options = WarnErrorOptions( - include="*", silence=all_events, valid_error_names=all_events + include="*", silence=all_events, valid_error_names=set(all_events) ) assert my_options.silence == all_events @@ -81,14 +84,16 @@ def test_init_valid_silence_event(self): ("*", ["ItemB"], True), ], ) - def test_includes(self, include, silence, expected_includes): + def test_includes( + self, include: Union[str, List[str]], silence: List[str], expected_includes: bool + ) -> None: include_exclude = WarnErrorOptions( include=include, silence=silence, valid_error_names={"ItemA", "ItemB"} ) assert include_exclude.includes("ItemA") == expected_includes - def test_silenced(self): + def test_silenced(self) -> None: my_options = WarnErrorOptions(include="*", silence=["ItemA"], valid_error_names={"ItemA"}) assert my_options.silenced("ItemA") assert not my_options.silenced("ItemB") diff --git a/tests/unit/test_invocation_context.py b/tests/unit/test_invocation_context.py index b6697f8e..fbf060ba 100644 --- a/tests/unit/test_invocation_context.py +++ b/tests/unit/test_invocation_context.py @@ -1,14 +1,29 @@ +import os + +import pytest + from dbt_common.constants import PRIVATE_ENV_PREFIX, SECRET_ENV_PREFIX -from dbt_common.context import InvocationContext +from dbt_common.context import InvocationContext, CaseInsensitiveMapping -def test_invocation_context_env(): +def test_invocation_context_env() -> None: test_env = {"VAR_1": "value1", "VAR_2": "value2"} ic = InvocationContext(env=test_env) assert ic.env == test_env -def test_invocation_context_secrets(): +@pytest.mark.skipif( + os.name != "nt", reason="Test for case-insensitive env vars, only run on Windows" +) +def test_invocation_context_windows() -> None: + test_env = {"var_1": "lowercase", "vAr_2": "mixedcase", "VAR_3": "uppercase"} + ic = InvocationContext(env=test_env) + assert ic.env == CaseInsensitiveMapping( + {"var_1": "lowercase", "var_2": "mixedcase", "var_3": "uppercase"} + ) + + +def test_invocation_context_secrets() -> None: test_env = { f"{SECRET_ENV_PREFIX}_VAR_1": "secret1", f"{SECRET_ENV_PREFIX}VAR_2": "secret2", @@ -16,10 +31,10 @@ def test_invocation_context_secrets(): f"foo{SECRET_ENV_PREFIX}": "non-secret", } ic = InvocationContext(env=test_env) - assert set(ic.env_secrets) == set(["secret1", "secret2"]) + assert set(ic.env_secrets) == {"secret1", "secret2"} -def test_invocation_context_private(): +def test_invocation_context_private() -> None: test_env = { f"{PRIVATE_ENV_PREFIX}_VAR_1": "private1", f"{PRIVATE_ENV_PREFIX}VAR_2": "private2", diff --git a/tests/unit/test_jinja.py b/tests/unit/test_jinja.py index f038a1ec..e906a0ac 100644 --- a/tests/unit/test_jinja.py +++ b/tests/unit/test_jinja.py @@ -1,23 +1,26 @@ import unittest +from dbt_common.clients._jinja_blocks import BlockTag from dbt_common.clients.jinja import extract_toplevel_blocks from dbt_common.exceptions import CompilationError class TestBlockLexer(unittest.TestCase): - def test_basic(self): + def test_basic(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" blocks = extract_toplevel_blocks( block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_multiple(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_multiple(self) -> None: body_one = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' body_two = ( "{{ config(bar=1)}}\r\nselect * from {% if foo %} thing " @@ -37,7 +40,7 @@ def test_multiple(self): ) self.assertEqual(len(blocks), 2) - def test_comments(self): + def test_comments(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' comment = "{# my comment #}" block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" @@ -45,12 +48,14 @@ def test_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_evil_comments(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_evil_comments(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' comment = ( "{# external comment {% othertype bar %} select * from " @@ -61,12 +66,14 @@ def test_evil_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_nested_comments(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_nested_comments(self) -> None: body = ( '{# my comment #} {{ config(foo="bar") }}' "\r\nselect * from {# my other comment embedding {% endmytype %} #} this.that\r\n" @@ -80,33 +87,43 @@ def test_nested_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_complex_file(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_complex_file(self) -> None: blocks = extract_toplevel_blocks( complex_snapshot_file, allowed_blocks={"mytype", "myothertype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 3) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].full_block, "{% mytype foo %} some stuff {% endmytype %}") - self.assertEqual(blocks[0].contents, " some stuff ") - self.assertEqual(blocks[1].block_type_name, "mytype") - self.assertEqual(blocks[1].block_name, "bar") - self.assertEqual(blocks[1].full_block, bar_block) - self.assertEqual(blocks[1].contents, bar_block[16:-15].rstrip()) - self.assertEqual(blocks[2].block_type_name, "myothertype") - self.assertEqual(blocks[2].block_name, "x") - self.assertEqual(blocks[2].full_block, x_block.strip()) + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.full_block, "{% mytype foo %} some stuff {% endmytype %}") + self.assertEqual(b0.contents, " some stuff ") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "mytype") + self.assertEqual(b1.block_name, "bar") + self.assertEqual(b1.full_block, bar_block) + self.assertEqual(b1.contents, bar_block[16:-15].rstrip()) + + b2 = blocks[2] + assert isinstance(b2, BlockTag) + self.assertEqual(b2.block_type_name, "myothertype") + self.assertEqual(b2.block_name, "x") + self.assertEqual(b2.full_block, x_block.strip()) self.assertEqual( - blocks[2].contents, + b2.contents, x_block[len("\n{% myothertype x %}") : -len("{% endmyothertype %}\n")], ) - def test_peaceful_macro_coexistence(self): + def test_peaceful_macro_coexistence(self) -> None: body = ( "{# my macro #} {% macro foo(a, b) %} do a thing " "{%- endmacro %} {# my model #} {% a b %} test {% enda %}" @@ -116,15 +133,22 @@ def test_peaceful_macro_coexistence(self): ) self.assertEqual(len(blocks), 4) self.assertEqual(blocks[0].full_block, "{# my macro #} ") - self.assertEqual(blocks[1].block_type_name, "macro") - self.assertEqual(blocks[1].block_name, "foo") - self.assertEqual(blocks[1].contents, " do a thing") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "macro") + self.assertEqual(b1.block_name, "foo") + self.assertEqual(b1.contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") - self.assertEqual(blocks[3].block_type_name, "a") - self.assertEqual(blocks[3].block_name, "b") - self.assertEqual(blocks[3].contents, " test ") - def test_macro_with_trailing_data(self): + b3 = blocks[3] + assert isinstance(b3, BlockTag) + self.assertEqual(b3.block_type_name, "a") + self.assertEqual(b3.block_name, "b") + self.assertEqual(b3.contents, " test ") + + def test_macro_with_trailing_data(self) -> None: body = ( "{# my macro #} {% macro foo(a, b) %} do a thing {%- endmacro %} " "{# my model #} {% a b %} test {% enda %} raw data so cool" @@ -134,16 +158,24 @@ def test_macro_with_trailing_data(self): ) self.assertEqual(len(blocks), 5) self.assertEqual(blocks[0].full_block, "{# my macro #} ") - self.assertEqual(blocks[1].block_type_name, "macro") - self.assertEqual(blocks[1].block_name, "foo") - self.assertEqual(blocks[1].contents, " do a thing") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "macro") + self.assertEqual(b1.block_name, "foo") + self.assertEqual(b1.contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") - self.assertEqual(blocks[3].block_type_name, "a") - self.assertEqual(blocks[3].block_name, "b") - self.assertEqual(blocks[3].contents, " test ") + + b3 = blocks[3] + assert isinstance(b3, BlockTag) + self.assertEqual(b3.block_type_name, "a") + self.assertEqual(b3.block_name, "b") + self.assertEqual(b3.contents, " test ") + self.assertEqual(blocks[4].full_block, " raw data so cool") - def test_macro_with_crazy_args(self): + def test_macro_with_crazy_args(self) -> None: body = ( """{% macro foo(a, b=asdf("cool this is 'embedded'" * 3) + external_var, c)%}""" "cool{# block comment with {% endmacro %} in it #} stuff here " @@ -151,38 +183,44 @@ def test_macro_with_crazy_args(self): ) blocks = extract_toplevel_blocks(body, allowed_blocks={"macro"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "macro") - self.assertEqual(blocks[0].block_name, "foo") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "macro") + self.assertEqual(b0.block_name, "foo") self.assertEqual( blocks[0].contents, "cool{# block comment with {% endmacro %} in it #} stuff here " ) - def test_materialization_parse(self): + def test_materialization_parse(self) -> None: body = "{% materialization xxx, default %} ... {% endmaterialization %}" blocks = extract_toplevel_blocks( body, allowed_blocks={"materialization"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "materialization") - self.assertEqual(blocks[0].block_name, "xxx") - self.assertEqual(blocks[0].full_block, body) + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "materialization") + self.assertEqual(b0.block_name, "xxx") + self.assertEqual(b0.full_block, body) body = '{% materialization xxx, adapter="other" %} ... {% endmaterialization %}' blocks = extract_toplevel_blocks( body, allowed_blocks={"materialization"}, collect_raw_data=False ) + b0 = blocks[0] + assert isinstance(b0, BlockTag) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "materialization") - self.assertEqual(blocks[0].block_name, "xxx") - self.assertEqual(blocks[0].full_block, body) + self.assertEqual(b0.block_type_name, "materialization") + self.assertEqual(b0.block_name, "xxx") + self.assertEqual(b0.full_block, body) - def test_nested_not_ok(self): + def test_nested_not_ok(self) -> None: # we don't allow nesting same blocks body = "{% myblock a %} {% myblock b %} {% endmyblock %} {% endmyblock %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock"}) - def test_incomplete_block_failure(self): + def test_incomplete_block_failure(self) -> None: fullbody = "{% myblock foo %} {% endmyblock %}" for length in range(len("{% myblock foo %}"), len(fullbody) - 1): body = fullbody[:length] @@ -194,45 +232,45 @@ def test_wrong_end_failure(self): with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"}) - def test_comment_no_end_failure(self): + def test_comment_no_end_failure(self) -> None: body = "{# " with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_comment_only(self): + def test_comment_only(self) -> None: body = "{# myblock #}" blocks = extract_toplevel_blocks(body) self.assertEqual(len(blocks), 1) blocks = extract_toplevel_blocks(body, collect_raw_data=False) self.assertEqual(len(blocks), 0) - def test_comment_block_self_closing(self): + def test_comment_block_self_closing(self) -> None: # test the case where a comment start looks a lot like it closes itself # (but it doesn't in jinja!) body = "{#} {% myblock foo %} {#}" blocks = extract_toplevel_blocks(body, collect_raw_data=False) self.assertEqual(len(blocks), 0) - def test_embedded_self_closing_comment_block(self): + def test_embedded_self_closing_comment_block(self) -> None: body = "{% myblock foo %} {#}{% endmyblock %} {#}{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) self.assertEqual(blocks[0].contents, " {#}{% endmyblock %} {#}") - def test_set_statement(self): + def test_set_statement(self) -> None: body = "{% set x = 1 %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_set_block(self): + def test_set_block(self) -> None: body = "{% set x %}1{% endset %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_crazy_set_statement(self): + def test_crazy_set_statement(self) -> None: body = ( '{% set x = (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}' '{% set y = otherthing("{% myblock foo %}") %}' @@ -244,19 +282,19 @@ def test_crazy_set_statement(self): self.assertEqual(blocks[0].full_block, "{% otherblock bar %}x{% endotherblock %}") self.assertEqual(blocks[0].block_type_name, "otherblock") - def test_do_statement(self): + def test_do_statement(self) -> None: body = "{% do thing.update() %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_deceptive_do_statement(self): + def test_deceptive_do_statement(self) -> None: body = "{% do thing %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_do_block(self): + def test_do_block(self) -> None: body = "{% do %}thing.update(){% enddo %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks( body, allowed_blocks={"do", "myblock"}, collect_raw_data=False @@ -266,7 +304,7 @@ def test_do_block(self): self.assertEqual(blocks[0].block_type_name, "do") self.assertEqual(blocks[1].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_crazy_do_statement(self): + def test_crazy_do_statement(self) -> None: body = ( '{% do (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}' '{% do otherthing("{% myblock foo %}") %}{% myblock x %}hi{% endmyblock %}' @@ -280,7 +318,7 @@ def test_crazy_do_statement(self): self.assertEqual(blocks[1].full_block, "{% myblock x %}hi{% endmyblock %}") self.assertEqual(blocks[1].block_type_name, "myblock") - def test_awful_jinja(self): + def test_awful_jinja(self) -> None: blocks = extract_toplevel_blocks( if_you_do_this_you_are_awful, allowed_blocks={"snapshot", "materialization"}, @@ -304,63 +342,71 @@ def test_awful_jinja(self): self.assertEqual(blocks[1].block_type_name, "materialization") self.assertEqual(blocks[1].contents, "\nhi\n") - def test_quoted_endblock_within_block(self): + def test_quoted_endblock_within_block(self) -> None: body = '{% myblock something -%} {% set x = ("{% endmyblock %}") %} {% endmyblock %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].block_type_name, "myblock") self.assertEqual(blocks[0].contents, '{% set x = ("{% endmyblock %}") %} ') - def test_docs_block(self): + def test_docs_block(self) -> None: body = ( "{% docs __my_doc__ %} asdf {# nope {% enddocs %}} #} {% enddocs %}" '{% docs __my_other_doc__ %} asdf "{% enddocs %}' ) blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) self.assertEqual(len(blocks), 2) - self.assertEqual(blocks[0].block_type_name, "docs") - self.assertEqual(blocks[0].contents, " asdf {# nope {% enddocs %}} #} ") - self.assertEqual(blocks[0].block_name, "__my_doc__") - self.assertEqual(blocks[1].block_type_name, "docs") - self.assertEqual(blocks[1].contents, ' asdf "') - self.assertEqual(blocks[1].block_name, "__my_other_doc__") - - def test_docs_block_expr(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "docs") + self.assertEqual(b0.contents, " asdf {# nope {% enddocs %}} #} ") + self.assertEqual(b0.block_name, "__my_doc__") + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "docs") + self.assertEqual(b1.contents, ' asdf "') + self.assertEqual(b1.block_name, "__my_other_doc__") + + def test_docs_block_expr(self) -> None: body = '{% docs more_doc %} asdf {{ "{% enddocs %}" ~ "}}" }}{% enddocs %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "docs") - self.assertEqual(blocks[0].contents, ' asdf {{ "{% enddocs %}" ~ "}}" }}') - self.assertEqual(blocks[0].block_name, "more_doc") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "docs") + self.assertEqual(b0.contents, ' asdf {{ "{% enddocs %}" ~ "}}" }}') + self.assertEqual(b0.block_name, "more_doc") - def test_unclosed_model_quotes(self): + def test_unclosed_model_quotes(self) -> None: # test case for https://github.com/dbt-labs/dbt-core/issues/1533 body = '{% model my_model -%} select * from "something"."something_else{% endmodel %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"model"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "model") - self.assertEqual(blocks[0].contents, 'select * from "something"."something_else') - self.assertEqual(blocks[0].block_name, "my_model") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "model") + self.assertEqual(b0.contents, 'select * from "something"."something_else') + self.assertEqual(b0.block_name, "my_model") - def test_if(self): + def test_if(self) -> None: # if you conditionally define your macros/models, don't body = "{% if true %}{% macro my_macro() %} adsf {% endmacro %}{% endif %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_if_innocuous(self): + def test_if_innocuous(self) -> None: body = "{% if true %}{% something %}asdfasd{% endsomething %}{% endif %}" blocks = extract_toplevel_blocks(body) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) - def test_for(self): + def test_for(self) -> None: # no for-loops over macros. body = "{% for x in range(10) %}{% macro my_macro() %} adsf {% endmacro %}{% endfor %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_for_innocuous(self): + def test_for_innocuous(self) -> None: # no for-loops over macros. body = ( "{% for x in range(10) %}{% something my_something %} adsf " @@ -370,7 +416,7 @@ def test_for_innocuous(self): self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) - def test_endif(self): + def test_endif(self) -> None: body = "{% snapshot foo %}select * from thing{% endsnapshot%}{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) @@ -382,7 +428,7 @@ def test_endif(self): str(err.exception), ) - def test_if_endfor(self): + def test_if_endfor(self) -> None: body = "{% if x %}...{% endfor %}{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) @@ -391,7 +437,7 @@ def test_if_endfor(self): str(err.exception), ) - def test_if_endfor_newlines(self): + def test_if_endfor_newlines(self) -> None: body = "{% if x %}\n ...\n {% endfor %}\n{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) diff --git a/tests/unit/test_model_config.py b/tests/unit/test_model_config.py index 0cc1e711..57a14438 100644 --- a/tests/unit/test_model_config.py +++ b/tests/unit/test_model_config.py @@ -14,7 +14,7 @@ class ThingWithMergeBehavior(dbtClassMixin): keysappended: Dict[str, int] = field(metadata={"merge": MergeBehavior.DictKeyAppend}) -def test_merge_behavior_meta(): +def test_merge_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(MergeBehavior) == { @@ -29,15 +29,14 @@ def test_merge_behavior_meta(): assert existing == initial_existing -def test_merge_behavior_from_field(): - fields = [f[0] for f in ThingWithMergeBehavior._get_fields()] - fields = {name: f for f, name in ThingWithMergeBehavior._get_fields()} - assert set(fields) == {"default_behavior", "appended", "updated", "clobbered", "keysappended"} - assert MergeBehavior.from_field(fields["default_behavior"]) == MergeBehavior.Clobber - assert MergeBehavior.from_field(fields["appended"]) == MergeBehavior.Append - assert MergeBehavior.from_field(fields["updated"]) == MergeBehavior.Update - assert MergeBehavior.from_field(fields["clobbered"]) == MergeBehavior.Clobber - assert MergeBehavior.from_field(fields["keysappended"]) == MergeBehavior.DictKeyAppend +def test_merge_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithMergeBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "appended", "updated", "clobbered", "keysappended"} + assert MergeBehavior.from_field(fields2["default_behavior"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields2["appended"]) == MergeBehavior.Append + assert MergeBehavior.from_field(fields2["updated"]) == MergeBehavior.Update + assert MergeBehavior.from_field(fields2["clobbered"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields2["keysappended"]) == MergeBehavior.DictKeyAppend @dataclass @@ -47,7 +46,7 @@ class ThingWithShowBehavior(dbtClassMixin): shown: float = field(metadata={"show_hide": ShowBehavior.Show}) -def test_show_behavior_meta(): +def test_show_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(ShowBehavior) == {ShowBehavior.Hide, ShowBehavior.Show} @@ -57,13 +56,12 @@ def test_show_behavior_meta(): assert existing == initial_existing -def test_show_behavior_from_field(): - fields = [f[0] for f in ThingWithShowBehavior._get_fields()] - fields = {name: f for f, name in ThingWithShowBehavior._get_fields()} - assert set(fields) == {"default_behavior", "hidden", "shown"} - assert ShowBehavior.from_field(fields["default_behavior"]) == ShowBehavior.Show - assert ShowBehavior.from_field(fields["hidden"]) == ShowBehavior.Hide - assert ShowBehavior.from_field(fields["shown"]) == ShowBehavior.Show +def test_show_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithShowBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "hidden", "shown"} + assert ShowBehavior.from_field(fields2["default_behavior"]) == ShowBehavior.Show + assert ShowBehavior.from_field(fields2["hidden"]) == ShowBehavior.Hide + assert ShowBehavior.from_field(fields2["shown"]) == ShowBehavior.Show @dataclass @@ -73,7 +71,7 @@ class ThingWithCompareBehavior(dbtClassMixin): excluded: str = field(metadata={"compare": CompareBehavior.Exclude}) -def test_compare_behavior_meta(): +def test_compare_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(CompareBehavior) == {CompareBehavior.Include, CompareBehavior.Exclude} @@ -83,10 +81,9 @@ def test_compare_behavior_meta(): assert existing == initial_existing -def test_compare_behavior_from_field(): - fields = [f[0] for f in ThingWithCompareBehavior._get_fields()] - fields = {name: f for f, name in ThingWithCompareBehavior._get_fields()} - assert set(fields) == {"default_behavior", "included", "excluded"} - assert CompareBehavior.from_field(fields["default_behavior"]) == CompareBehavior.Include - assert CompareBehavior.from_field(fields["included"]) == CompareBehavior.Include - assert CompareBehavior.from_field(fields["excluded"]) == CompareBehavior.Exclude +def test_compare_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithCompareBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "included", "excluded"} + assert CompareBehavior.from_field(fields2["default_behavior"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields2["included"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields2["excluded"]) == CompareBehavior.Exclude diff --git a/tests/unit/test_proto_events.py b/tests/unit/test_proto_events.py index 32eb08ae..d21b5062 100644 --- a/tests/unit/test_proto_events.py +++ b/tests/unit/test_proto_events.py @@ -18,7 +18,7 @@ } -def test_events(): +def test_events() -> None: # M020 event event_code = "M020" event = RetryExternalCall(attempt=3, max=5) @@ -45,7 +45,7 @@ def test_events(): assert new_msg.data.attempt == msg.data.attempt -def test_extra_dict_on_event(monkeypatch): +def test_extra_dict_on_event(monkeypatch) -> None: monkeypatch.setenv("DBT_ENV_CUSTOM_ENV_env_key", "env_value") reset_metadata_vars() diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index b0371498..6e02d710 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -69,7 +69,7 @@ def setup(): os.environ["DBT_RECORDER_FILE_PATH"] = prev_fp -def test_decorator_records(setup): +def test_decorator_records(setup) -> None: os.environ["DBT_RECORDER_MODE"] = "Record" recorder = Recorder(RecorderMode.RECORD, None) set_invocation_context({}) @@ -116,7 +116,7 @@ def not_test_func(a: int, b: str, c: Optional[str] = None) -> str: assert NotTestRecord not in recorder._records_by_type -def test_decorator_replays(setup): +def test_decorator_replays(setup) -> None: os.environ["DBT_RECORDER_MODE"] = "Replay" os.environ["DBT_RECORDER_FILE_PATH"] = "record.json" recorder = Recorder(RecorderMode.REPLAY, None) diff --git a/tests/unit/test_semver.py b/tests/unit/test_semver.py index ae48e592..383d3479 100644 --- a/tests/unit/test_semver.py +++ b/tests/unit/test_semver.py @@ -1,6 +1,6 @@ import itertools import unittest -from typing import List +from typing import List, Optional from dbt_common.exceptions import VersionsNotCompatibleError from dbt_common.semver import ( @@ -23,9 +23,11 @@ def semver_regex_versioning(versions: List[str]) -> bool: return True -def create_range(start_version_string, end_version_string): - start = UnboundedVersionSpecifier() - end = UnboundedVersionSpecifier() +def create_range( + start_version_string: Optional[str], end_version_string: Optional[str] +) -> VersionRange: + start: VersionSpecifier = UnboundedVersionSpecifier() + end: VersionSpecifier = UnboundedVersionSpecifier() if start_version_string is not None: start = VersionSpecifier.from_version_string(start_version_string) @@ -37,24 +39,24 @@ def create_range(start_version_string, end_version_string): class TestSemver(unittest.TestCase): - def assertVersionSetResult(self, inputs, output_range): + def assertVersionSetResult(self, inputs, output_range) -> None: expected = create_range(*output_range) for permutation in itertools.permutations(inputs): self.assertEqual(reduce_versions(*permutation), expected) - def assertInvalidVersionSet(self, inputs): + def assertInvalidVersionSet(self, inputs) -> None: for permutation in itertools.permutations(inputs): with self.assertRaises(VersionsNotCompatibleError): reduce_versions(*permutation) - def test__versions_compatible(self): + def test__versions_compatible(self) -> None: self.assertTrue(versions_compatible("0.0.1", "0.0.1")) self.assertFalse(versions_compatible("0.0.1", "0.0.2")) self.assertTrue(versions_compatible(">0.0.1", "0.0.2")) self.assertFalse(versions_compatible("0.4.5a1", "0.4.5a2")) - def test__semver_regex_versions(self): + def test__semver_regex_versions(self) -> None: self.assertTrue( semver_regex_versioning( [ @@ -140,7 +142,7 @@ def test__semver_regex_versions(self): ) ) - def test__reduce_versions(self): + def test__reduce_versions(self) -> None: self.assertVersionSetResult(["0.0.1", "0.0.1"], ["=0.0.1", "=0.0.1"]) self.assertVersionSetResult(["0.0.1"], ["=0.0.1", "=0.0.1"]) @@ -175,7 +177,7 @@ def test__reduce_versions(self): self.assertInvalidVersionSet(["<0.0.3", ">=0.0.3"]) self.assertInvalidVersionSet(["<0.0.3", ">0.0.3"]) - def test__resolve_to_specific_version(self): + def test__resolve_to_specific_version(self) -> None: self.assertEqual( resolve_to_specific_version(create_range(">0.0.1", None), ["0.0.1", "0.0.2"]), "0.0.2" ) @@ -253,7 +255,7 @@ def test__resolve_to_specific_version(self): "0.9.1", ) - def test__filter_installable(self): + def test__filter_installable(self) -> None: installable = filter_installable( [ "1.1.0", diff --git a/tests/unit/test_system_client.py b/tests/unit/test_system_client.py index a4dcc323..d2cf27ed 100644 --- a/tests/unit/test_system_client.py +++ b/tests/unit/test_system_client.py @@ -12,39 +12,39 @@ class SystemClient(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.tmp_dir = mkdtemp() self.profiles_path = "{}/profiles.yml".format(self.tmp_dir) - def set_up_profile(self): + def set_up_profile(self) -> None: with open(self.profiles_path, "w") as f: f.write("ORIGINAL_TEXT") - def get_profile_text(self): + def get_profile_text(self) -> str: with open(self.profiles_path, "r") as f: return f.read() - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.tmp_dir) except Exception as e: # noqa: F841 pass - def test__make_file_when_exists(self): + def test__make_file_when_exists(self) -> None: self.set_up_profile() written = dbt_common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") self.assertFalse(written) self.assertEqual(self.get_profile_text(), "ORIGINAL_TEXT") - def test__make_file_when_not_exists(self): + def test__make_file_when_not_exists(self) -> None: written = dbt_common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") self.assertTrue(written) self.assertEqual(self.get_profile_text(), "NEW_TEXT") - def test__make_file_with_overwrite(self): + def test__make_file_with_overwrite(self) -> None: self.set_up_profile() written = dbt_common.clients.system.make_file( self.profiles_path, contents="NEW_TEXT", overwrite=True @@ -53,12 +53,12 @@ def test__make_file_with_overwrite(self): self.assertTrue(written) self.assertEqual(self.get_profile_text(), "NEW_TEXT") - def test__make_dir_from_str(self): + def test__make_dir_from_str(self) -> None: test_dir_str = self.tmp_dir + "/test_make_from_str/sub_dir" dbt_common.clients.system.make_directory(test_dir_str) self.assertTrue(Path(test_dir_str).is_dir()) - def test__make_dir_from_pathobj(self): + def test__make_dir_from_pathobj(self) -> None: test_dir_pathobj = Path(self.tmp_dir + "/test_make_from_pathobj/sub_dir") dbt_common.clients.system.make_directory(test_dir_pathobj) self.assertTrue(test_dir_pathobj.is_dir()) @@ -72,7 +72,7 @@ class TestRunCmd(unittest.TestCase): not_a_file = "zzzbbfasdfasdfsdaq" - def setUp(self): + def setUp(self) -> None: self.tempdir = mkdtemp() self.run_dir = os.path.join(self.tempdir, "run_dir") self.does_not_exist = os.path.join(self.tempdir, "does_not_exist") @@ -86,10 +86,10 @@ def setUp(self): with open(self.empty_file, "w") as fp: # noqa: F841 pass # "touch" - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) - def test__executable_does_not_exist(self): + def test__executable_does_not_exist(self) -> None: with self.assertRaises(ExecutableError) as exc: dbt_common.clients.system.run_cmd(self.run_dir, [self.does_not_exist]) @@ -99,7 +99,7 @@ def test__executable_does_not_exist(self): self.assertIn("could not find", msg) self.assertIn(self.does_not_exist.lower(), msg) - def test__not_exe(self): + def test__not_exe(self) -> None: with self.assertRaises(ExecutableError) as exc: dbt_common.clients.system.run_cmd(self.run_dir, [self.empty_file]) @@ -112,14 +112,14 @@ def test__not_exe(self): self.assertIn("permissions", msg) self.assertIn(self.empty_file.lower(), msg) - def test__cwd_does_not_exist(self): + def test__cwd_does_not_exist(self) -> None: with self.assertRaises(WorkingDirectoryError) as exc: dbt_common.clients.system.run_cmd(self.does_not_exist, self.exists_cmd) msg = str(exc.exception).lower() self.assertIn("does not exist", msg) self.assertIn(self.does_not_exist.lower(), msg) - def test__cwd_not_directory(self): + def test__cwd_not_directory(self) -> None: with self.assertRaises(WorkingDirectoryError) as exc: dbt_common.clients.system.run_cmd(self.empty_file, self.exists_cmd) @@ -127,7 +127,7 @@ def test__cwd_not_directory(self): self.assertIn("not a directory", msg) self.assertIn(self.empty_file.lower(), msg) - def test__cwd_no_permissions(self): + def test__cwd_no_permissions(self) -> None: # it would be nice to add a windows test. Possible path to that is via # `psexec` (to get SYSTEM privs), use `icacls` to set permissions on # the directory for the test user. I'm pretty sure windows users can't @@ -145,18 +145,18 @@ def test__cwd_no_permissions(self): self.assertIn("permissions", msg) self.assertIn(self.run_dir.lower(), msg) - def test__ok(self): + def test__ok(self) -> None: out, err = dbt_common.clients.system.run_cmd(self.run_dir, self.exists_cmd) self.assertEqual(out.strip(), b"hello") self.assertEqual(err.strip(), b"") class TestFindMatching(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.base_dir = mkdtemp() self.tempdir = mkdtemp(dir=self.base_dir) - def test_find_matching_lowercase_file_pattern(self): + def test_find_matching_lowercase_file_pattern(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir) as named_file: file_path = os.path.dirname(named_file.name) relative_path = os.path.basename(file_path) @@ -175,7 +175,7 @@ def test_find_matching_lowercase_file_pattern(self): ] self.assertEqual(out, expected_output) - def test_find_matching_uppercase_file_pattern(self): + def test_find_matching_uppercase_file_pattern(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".SQL", dir=self.tempdir) as named_file: file_path = os.path.dirname(named_file.name) relative_path = os.path.basename(file_path) @@ -190,12 +190,12 @@ def test_find_matching_uppercase_file_pattern(self): ] self.assertEqual(out, expected_output) - def test_find_matching_file_pattern_not_found(self): + def test_find_matching_file_pattern_not_found(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".SQLT", dir=self.tempdir): out = dbt_common.clients.system.find_matching(self.tempdir, [""], "*.sql") self.assertEqual(out, []) - def test_ignore_spec(self): + def test_ignore_spec(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir): out = dbt_common.clients.system.find_matching( self.tempdir, @@ -207,7 +207,7 @@ def test_ignore_spec(self): ) self.assertEqual(out, []) - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.base_dir) except Exception as e: # noqa: F841 @@ -215,18 +215,18 @@ def tearDown(self): class TestUntarPackage(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.base_dir = mkdtemp() self.tempdir = mkdtemp(dir=self.base_dir) self.tempdest = mkdtemp(dir=self.base_dir) - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.base_dir) except Exception as e: # noqa: F841 pass - def test_untar_package_success(self): + def test_untar_package_success(self) -> None: # set up a valid tarball to test against with NamedTemporaryFile( prefix="my-package.2", suffix=".tar.gz", dir=self.tempdir, delete=False @@ -244,7 +244,7 @@ def test_untar_package_success(self): path = Path(os.path.join(self.tempdest, relative_file_a)) assert path.is_file() - def test_untar_package_failure(self): + def test_untar_package_failure(self) -> None: # create a text file then rename it as a tar (so it's invalid) with NamedTemporaryFile( prefix="a", suffix=".txt", dir=self.tempdir, delete=False @@ -259,7 +259,7 @@ def test_untar_package_failure(self): with self.assertRaises(tarfile.ReadError) as exc: # noqa: F841 dbt_common.clients.system.untar_package(tar_file_path, self.tempdest) - def test_untar_package_empty(self): + def test_untar_package_empty(self) -> None: # create a tarball with nothing in it with NamedTemporaryFile( prefix="my-empty-package.2", suffix=".tar.gz", dir=self.tempdir diff --git a/tests/unit/test_ui.py b/tests/unit/test_ui.py index 22e431d5..5b70b1d1 100644 --- a/tests/unit/test_ui.py +++ b/tests/unit/test_ui.py @@ -1,11 +1,11 @@ from dbt_common.ui import warning_tag, error_tag -def test_warning_tag(): +def test_warning_tag() -> None: tagged = warning_tag("hi") assert "WARNING" in tagged -def test_error_tag(): +def test_error_tag() -> None: tagged = error_tag("hi") assert "ERROR" in tagged diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 250c20cc..93c57046 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -5,7 +5,7 @@ class TestDeepMerge(unittest.TestCase): - def test__simple_cases(self): + def test__simple_cases(self) -> None: cases = [ {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, { @@ -27,7 +27,7 @@ def test__simple_cases(self): class TestMerge(unittest.TestCase): - def test__simple_cases(self): + def test__simple_cases(self) -> None: cases = [ {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, { @@ -49,7 +49,7 @@ def test__simple_cases(self): class TestDeepMap(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.input_value = { "foo": { "bar": "hello", @@ -74,7 +74,7 @@ def intify_all(value, _): except (TypeError, ValueError): return -1 - def test__simple_cases(self): + def test__simple_cases(self) -> None: expected = { "foo": { "bar": -1, @@ -104,7 +104,7 @@ def special_keypath(value, keypath): else: return value - def test__keypath(self): + def test__keypath(self) -> None: expected = { "foo": { "bar": "hello", @@ -128,11 +128,11 @@ def test__keypath(self): actual = dbt_common.utils.dict.deep_map_render(self.special_keypath, expected) self.assertEqual(actual, expected) - def test__noop(self): + def test__noop(self) -> None: actual = dbt_common.utils.dict.deep_map_render(lambda x, _: x, self.input_value) self.assertEqual(actual, self.input_value) - def test_trivial(self): + def test_trivial(self) -> None: cases = [[], {}, 1, "abc", None, True] for case in cases: result = dbt_common.utils.dict.deep_map_render(lambda x, _: x, case)