From 05acc41613c0c2eadcf4237a13545da66ade69da Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Tue, 25 Jun 2024 18:11:14 -0400 Subject: [PATCH] Add misc. type annotations --- dbt_common/contracts/config/base.py | 12 ++++----- dbt_common/invocation.py | 2 +- dbt_common/semver.py | 40 ++++++++++++++--------------- dbt_common/utils/casting.py | 4 +-- dbt_common/utils/connection.py | 3 ++- dbt_common/utils/jinja.py | 14 ++++++---- 6 files changed, 40 insertions(+), 35 deletions(-) diff --git a/dbt_common/contracts/config/base.py b/dbt_common/contracts/config/base.py index a16b4d9b..42acb1bf 100644 --- a/dbt_common/contracts/config/base.py +++ b/dbt_common/contracts/config/base.py @@ -17,11 +17,11 @@ @dataclass class BaseConfig(AdditionalPropertiesAllowed, Replaceable): # enable syntax like: config['key'] - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: return self.get(key) # like doing 'get' on a dictionary - def get(self, key, default=None): + def get(self, key: str, default: Any = None) -> Any: if hasattr(self, key): return getattr(self, key) elif key in self._extra: @@ -30,13 +30,13 @@ def get(self, key, default=None): return default # enable syntax like: config['key'] = value - def __setitem__(self, key, value): + def __setitem__(self, key: str, value) -> None: if hasattr(self, key): setattr(self, key, value) else: self._extra[key] = value - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: if hasattr(self, key): msg = ( 'Error, tried to delete config key "{}": Cannot delete ' "built-in keys" @@ -60,7 +60,7 @@ def _content_iterator(self, include_condition: Callable[[Field], bool]): def __iter__(self): yield from self._content_iterator(include_condition=lambda f: True) - def __len__(self): + def __len__(self) -> int: return len(self._get_fields()) + len(self._extra) @staticmethod @@ -221,7 +221,7 @@ def _merge_field_value( merge_behavior: MergeBehavior, self_value: Any, other_value: Any, -): +) -> Any: if merge_behavior == MergeBehavior.Clobber: return other_value elif merge_behavior == MergeBehavior.Append: diff --git a/dbt_common/invocation.py b/dbt_common/invocation.py index 0e5d3206..adbd7d18 100644 --- a/dbt_common/invocation.py +++ b/dbt_common/invocation.py @@ -7,6 +7,6 @@ def get_invocation_id() -> str: return _INVOCATION_ID -def reset_invocation_id(): +def reset_invocation_id() -> None: global _INVOCATION_ID _INVOCATION_ID = str(uuid.uuid4()) diff --git a/dbt_common/semver.py b/dbt_common/semver.py index 951f4e8e..ef0182ba 100644 --- a/dbt_common/semver.py +++ b/dbt_common/semver.py @@ -67,14 +67,14 @@ class VersionSpecification(dbtClassMixin): _VERSION_REGEX = re.compile(_VERSION_REGEX_PAT_STR, re.VERBOSE) -def _cmp(a, b): +def _cmp(a, b) -> int: """Return negative if ab.""" return (a > b) - (a < b) @dataclass class VersionSpecifier(VersionSpecification): - def to_version_string(self, skip_matcher=False): + def to_version_string(self, skip_matcher: bool = False) -> str: prerelease = "" build = "" matcher = "" @@ -92,7 +92,7 @@ def to_version_string(self, skip_matcher=False): ) @classmethod - def from_version_string(cls, version_string): + def from_version_string(cls, version_string: str) -> "VersionSpecifier": match = _VERSION_REGEX.match(version_string) if not match: @@ -104,7 +104,7 @@ def from_version_string(cls, version_string): return cls.from_dict(matched) - def __str__(self): + def __str__(self) -> str: return self.to_version_string() def to_range(self) -> "VersionRange": @@ -123,7 +123,7 @@ def to_range(self) -> "VersionRange": return VersionRange(start=range_start, end=range_end) - def compare(self, other): + def compare(self, other: "VersionSpecifier") -> int: if self.is_unbounded or other.is_unbounded: return 0 @@ -192,36 +192,36 @@ def compare(self, other): return 0 - def __lt__(self, other): + def __lt__(self, other: "VersionSpecifier") -> bool: return self.compare(other) == -1 - def __gt__(self, other): + def __gt__(self, other: "VersionSpecifier") -> bool: return self.compare(other) == 1 - def __eq___(self, other): + def __eq___(self, other: "VersionSpecifier") -> bool: return self.compare(other) == 0 - def __cmp___(self, other): + def __cmp___(self, other: "VersionSpecifier") -> int: return self.compare(other) @property - def is_unbounded(self): + def is_unbounded(self) -> bool: return False @property - def is_lower_bound(self): + def is_lower_bound(self) -> bool: return self.matcher in [Matchers.GREATER_THAN, Matchers.GREATER_THAN_OR_EQUAL] @property - def is_upper_bound(self): + def is_upper_bound(self) -> bool: return self.matcher in [Matchers.LESS_THAN, Matchers.LESS_THAN_OR_EQUAL] @property - def is_exact(self): + def is_exact(self) -> bool: return self.matcher == Matchers.EXACT @classmethod - def _nat_cmp(cls, a, b): + def _nat_cmp(cls, a, b) -> int: def cmp_prerelease_tag(a, b): if isinstance(a, int) and isinstance(b, int): return _cmp(a, b) @@ -358,23 +358,23 @@ def __init__(self, *args, **kwargs) -> None: matcher=Matchers.EXACT, major=None, minor=None, patch=None, prerelease=None, build=None ) - def __str__(self): + def __str__(self) -> str: return "*" @property - def is_unbounded(self): + def is_unbounded(self) -> bool: return True @property - def is_lower_bound(self): + def is_lower_bound(self) -> bool: return False @property - def is_upper_bound(self): + def is_upper_bound(self) -> bool: return False @property - def is_exact(self): + def is_exact(self) -> bool: return False @@ -418,7 +418,7 @@ def reduce_versions(*args): return to_return -def versions_compatible(*args): +def versions_compatible(*args) -> bool: if len(args) == 1: return True diff --git a/dbt_common/utils/casting.py b/dbt_common/utils/casting.py index 811ea376..f86ca191 100644 --- a/dbt_common/utils/casting.py +++ b/dbt_common/utils/casting.py @@ -1,7 +1,7 @@ # This is useful for proto generated classes in particular, since # the default for protobuf for strings is the empty string, so # Optional[str] types don't work for generated Python classes. -from typing import Optional +from typing import Any, Dict, Mapping, Optional def cast_to_str(string: Optional[str]) -> str: @@ -18,7 +18,7 @@ def cast_to_int(integer: Optional[int]) -> int: return integer -def cast_dict_to_dict_of_strings(dct): +def cast_dict_to_dict_of_strings(dct: Mapping[Any, Any]) -> Dict[str, str]: new_dct = {} for k, v in dct.items(): new_dct[str(k)] = str(v) diff --git a/dbt_common/utils/connection.py b/dbt_common/utils/connection.py index 5c76fe7f..cc4e1ce4 100644 --- a/dbt_common/utils/connection.py +++ b/dbt_common/utils/connection.py @@ -1,4 +1,5 @@ import time +from typing import Callable from dbt_common.events.types import RecordRetryException, RetryExternalCall from dbt_common.exceptions import ConnectionError @@ -7,7 +8,7 @@ import requests -def connection_exception_retry(fn, max_attempts: int, attempt: int = 0): +def connection_exception_retry(fn: Callable, max_attempts: int, attempt: int = 0): """Handle connection retries gracefully. Attempts to run a function that makes an external call, if the call fails diff --git a/dbt_common/utils/jinja.py b/dbt_common/utils/jinja.py index 36464cbe..c9d9fa8e 100644 --- a/dbt_common/utils/jinja.py +++ b/dbt_common/utils/jinja.py @@ -1,3 +1,5 @@ +from typing import Optional + from dbt_common.exceptions import DbtInternalError @@ -5,29 +7,31 @@ DOCS_PREFIX = "dbt_docs__" -def get_dbt_macro_name(name): +def get_dbt_macro_name(name: str) -> str: if name is None: raise DbtInternalError("Got None for a macro name!") return f"{MACRO_PREFIX}{name}" -def get_dbt_docs_name(name): +def get_dbt_docs_name(name: str) -> str: if name is None: raise DbtInternalError("Got None for a doc name!") return f"{DOCS_PREFIX}{name}" -def get_materialization_macro_name(materialization_name, adapter_type=None, with_prefix=True): +def get_materialization_macro_name( + materialization_name: str, adapter_type: Optional[str] = None, with_prefix: bool = True +) -> str: if adapter_type is None: adapter_type = "default" name = f"materialization_{materialization_name}_{adapter_type}" return get_dbt_macro_name(name) if with_prefix else name -def get_docs_macro_name(docs_name, with_prefix=True): +def get_docs_macro_name(docs_name: str, with_prefix: bool = True) -> str: return get_dbt_docs_name(docs_name) if with_prefix else docs_name -def get_test_macro_name(test_name, with_prefix=True): +def get_test_macro_name(test_name: str, with_prefix: bool = True) -> str: name = f"test_{test_name}" return get_dbt_macro_name(name) if with_prefix else name