diff --git a/dbt_common/contracts/config/base.py b/dbt_common/contracts/config/base.py index a16b4d9b..42acb1bf 100644 --- a/dbt_common/contracts/config/base.py +++ b/dbt_common/contracts/config/base.py @@ -17,11 +17,11 @@ @dataclass class BaseConfig(AdditionalPropertiesAllowed, Replaceable): # enable syntax like: config['key'] - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: return self.get(key) # like doing 'get' on a dictionary - def get(self, key, default=None): + def get(self, key: str, default: Any = None) -> Any: if hasattr(self, key): return getattr(self, key) elif key in self._extra: @@ -30,13 +30,13 @@ def get(self, key, default=None): return default # enable syntax like: config['key'] = value - def __setitem__(self, key, value): + def __setitem__(self, key: str, value) -> None: if hasattr(self, key): setattr(self, key, value) else: self._extra[key] = value - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: if hasattr(self, key): msg = ( 'Error, tried to delete config key "{}": Cannot delete ' "built-in keys" @@ -60,7 +60,7 @@ def _content_iterator(self, include_condition: Callable[[Field], bool]): def __iter__(self): yield from self._content_iterator(include_condition=lambda f: True) - def __len__(self): + def __len__(self) -> int: return len(self._get_fields()) + len(self._extra) @staticmethod @@ -221,7 +221,7 @@ def _merge_field_value( merge_behavior: MergeBehavior, self_value: Any, other_value: Any, -): +) -> Any: if merge_behavior == MergeBehavior.Clobber: return other_value elif merge_behavior == MergeBehavior.Append: diff --git a/dbt_common/invocation.py b/dbt_common/invocation.py index 0e5d3206..adbd7d18 100644 --- a/dbt_common/invocation.py +++ b/dbt_common/invocation.py @@ -7,6 +7,6 @@ def get_invocation_id() -> str: return _INVOCATION_ID -def reset_invocation_id(): +def reset_invocation_id() -> None: global _INVOCATION_ID _INVOCATION_ID = str(uuid.uuid4()) diff --git a/dbt_common/semver.py b/dbt_common/semver.py index fbdcefa5..4c411911 100644 --- a/dbt_common/semver.py +++ b/dbt_common/semver.py @@ -67,7 +67,7 @@ class VersionSpecification(dbtClassMixin): _VERSION_REGEX = re.compile(_VERSION_REGEX_PAT_STR, re.VERBOSE) -def _cmp(a, b): +def _cmp(a, b) -> int: """Return negative if ab.""" return (a > b) - (a < b) @@ -123,7 +123,7 @@ def to_range(self) -> "VersionRange": return VersionRange(start=range_start, end=range_end) - def compare(self, other): + def compare(self, other: "VersionSpecifier") -> int: if self.is_unbounded or other.is_unbounded: return 0 @@ -192,16 +192,16 @@ def compare(self, other): return 0 - def __lt__(self, other) -> bool: + def __lt__(self, other: "VersionSpecifier") -> bool: return self.compare(other) == -1 - def __gt__(self, other) -> bool: + def __gt__(self, other: "VersionSpecifier") -> bool: return self.compare(other) == 1 - def __eq___(self, other) -> bool: + def __eq___(self, other: "VersionSpecifier") -> bool: return self.compare(other) == 0 - def __cmp___(self, other): + def __cmp___(self, other: "VersionSpecifier") -> int: return self.compare(other) @property @@ -221,7 +221,7 @@ def is_exact(self) -> bool: return self.matcher == Matchers.EXACT @classmethod - def _nat_cmp(cls, a, b): + def _nat_cmp(cls, a, b) -> int: def cmp_prerelease_tag(a, b): if isinstance(a, int) and isinstance(b, int): return _cmp(a, b) @@ -358,23 +358,23 @@ def __init__(self, *args, **kwargs) -> None: matcher=Matchers.EXACT, major=None, minor=None, patch=None, prerelease=None, build=None ) - def __str__(self): + def __str__(self) -> str: return "*" @property - def is_unbounded(self): + def is_unbounded(self) -> bool: return True @property - def is_lower_bound(self): + def is_lower_bound(self) -> bool: return False @property - def is_upper_bound(self): + def is_upper_bound(self) -> bool: return False @property - def is_exact(self): + def is_exact(self) -> bool: return False diff --git a/dbt_common/utils/casting.py b/dbt_common/utils/casting.py index f366db7f..076d4c12 100644 --- a/dbt_common/utils/casting.py +++ b/dbt_common/utils/casting.py @@ -1,7 +1,7 @@ # This is useful for proto generated classes in particular, since # the default for protobuf for strings is the empty string, so # Optional[str] types don't work for generated Python classes. -from typing import Any, Dict, Optional +from typing import Any, Dict, Mapping, Optional def cast_to_str(string: Optional[str]) -> str: @@ -18,8 +18,9 @@ def cast_to_int(integer: Optional[int]) -> int: return integer -def cast_dict_to_dict_of_strings(dct: Dict[Any, Any]) -> Dict[str, str]: +def cast_dict_to_dict_of_strings(dct: Mapping[Any, Any]) -> Dict[str, str]: new_dct: Dict[str, str] = {} + for k, v in dct.items(): new_dct[str(k)] = str(v) return new_dct diff --git a/dbt_common/utils/connection.py b/dbt_common/utils/connection.py index 5c76fe7f..cc4e1ce4 100644 --- a/dbt_common/utils/connection.py +++ b/dbt_common/utils/connection.py @@ -1,4 +1,5 @@ import time +from typing import Callable from dbt_common.events.types import RecordRetryException, RetryExternalCall from dbt_common.exceptions import ConnectionError @@ -7,7 +8,7 @@ import requests -def connection_exception_retry(fn, max_attempts: int, attempt: int = 0): +def connection_exception_retry(fn: Callable, max_attempts: int, attempt: int = 0): """Handle connection retries gracefully. Attempts to run a function that makes an external call, if the call fails diff --git a/dbt_common/utils/jinja.py b/dbt_common/utils/jinja.py index 260ccb6a..c9d9fa8e 100644 --- a/dbt_common/utils/jinja.py +++ b/dbt_common/utils/jinja.py @@ -1,3 +1,5 @@ +from typing import Optional + from dbt_common.exceptions import DbtInternalError @@ -5,20 +7,20 @@ DOCS_PREFIX = "dbt_docs__" -def get_dbt_macro_name(name) -> str: +def get_dbt_macro_name(name: str) -> str: if name is None: raise DbtInternalError("Got None for a macro name!") return f"{MACRO_PREFIX}{name}" -def get_dbt_docs_name(name) -> str: +def get_dbt_docs_name(name: str) -> str: if name is None: raise DbtInternalError("Got None for a doc name!") return f"{DOCS_PREFIX}{name}" def get_materialization_macro_name( - materialization_name, adapter_type=None, with_prefix=True + materialization_name: str, adapter_type: Optional[str] = None, with_prefix: bool = True ) -> str: if adapter_type is None: adapter_type = "default" @@ -26,10 +28,10 @@ def get_materialization_macro_name( return get_dbt_macro_name(name) if with_prefix else name -def get_docs_macro_name(docs_name, with_prefix=True): +def get_docs_macro_name(docs_name: str, with_prefix: bool = True) -> str: return get_dbt_docs_name(docs_name) if with_prefix else docs_name -def get_test_macro_name(test_name, with_prefix=True): +def get_test_macro_name(test_name: str, with_prefix: bool = True) -> str: name = f"test_{test_name}" return get_dbt_macro_name(name) if with_prefix else name