Skip to content

Commit

Permalink
Add some type annotations (#156)
Browse files Browse the repository at this point in the history
Add misc. type annotations
  • Loading branch information
peterallenwebb authored Jul 30, 2024
1 parent db99ddd commit bef3b7d
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 27 deletions.
12 changes: 6 additions & 6 deletions dbt_common/contracts/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
@dataclass
class BaseConfig(AdditionalPropertiesAllowed, Replaceable):
# enable syntax like: config['key']
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
return self.get(key)

# like doing 'get' on a dictionary
def get(self, key, default=None):
def get(self, key: str, default: Any = None) -> Any:
if hasattr(self, key):
return getattr(self, key)
elif key in self._extra:
Expand All @@ -30,13 +30,13 @@ def get(self, key, default=None):
return default

# enable syntax like: config['key'] = value
def __setitem__(self, key, value):
def __setitem__(self, key: str, value) -> None:
if hasattr(self, key):
setattr(self, key, value)
else:
self._extra[key] = value

def __delitem__(self, key):
def __delitem__(self, key: str) -> None:
if hasattr(self, key):
msg = (
'Error, tried to delete config key "{}": Cannot delete ' "built-in keys"
Expand All @@ -60,7 +60,7 @@ def _content_iterator(self, include_condition: Callable[[Field], bool]):
def __iter__(self):
yield from self._content_iterator(include_condition=lambda f: True)

def __len__(self):
def __len__(self) -> int:
return len(self._get_fields()) + len(self._extra)

@staticmethod
Expand Down Expand Up @@ -221,7 +221,7 @@ def _merge_field_value(
merge_behavior: MergeBehavior,
self_value: Any,
other_value: Any,
):
) -> Any:
if merge_behavior == MergeBehavior.Clobber:
return other_value
elif merge_behavior == MergeBehavior.Append:
Expand Down
2 changes: 1 addition & 1 deletion dbt_common/invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ def get_invocation_id() -> str:
return _INVOCATION_ID


def reset_invocation_id():
def reset_invocation_id() -> None:
global _INVOCATION_ID
_INVOCATION_ID = str(uuid.uuid4())
24 changes: 12 additions & 12 deletions dbt_common/semver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class VersionSpecification(dbtClassMixin):
_VERSION_REGEX = re.compile(_VERSION_REGEX_PAT_STR, re.VERBOSE)


def _cmp(a, b):
def _cmp(a, b) -> int:
"""Return negative if a<b, zero if a==b, positive if a>b."""
return (a > b) - (a < b)

Expand Down Expand Up @@ -123,7 +123,7 @@ def to_range(self) -> "VersionRange":

return VersionRange(start=range_start, end=range_end)

def compare(self, other):
def compare(self, other: "VersionSpecifier") -> int:
if self.is_unbounded or other.is_unbounded:
return 0

Expand Down Expand Up @@ -192,16 +192,16 @@ def compare(self, other):

return 0

def __lt__(self, other) -> bool:
def __lt__(self, other: "VersionSpecifier") -> bool:
return self.compare(other) == -1

def __gt__(self, other) -> bool:
def __gt__(self, other: "VersionSpecifier") -> bool:
return self.compare(other) == 1

def __eq___(self, other) -> bool:
def __eq___(self, other: "VersionSpecifier") -> bool:
return self.compare(other) == 0

def __cmp___(self, other):
def __cmp___(self, other: "VersionSpecifier") -> int:
return self.compare(other)

@property
Expand All @@ -221,7 +221,7 @@ def is_exact(self) -> bool:
return self.matcher == Matchers.EXACT

@classmethod
def _nat_cmp(cls, a, b):
def _nat_cmp(cls, a, b) -> int:
def cmp_prerelease_tag(a, b):
if isinstance(a, int) and isinstance(b, int):
return _cmp(a, b)
Expand Down Expand Up @@ -358,23 +358,23 @@ def __init__(self, *args, **kwargs) -> None:
matcher=Matchers.EXACT, major=None, minor=None, patch=None, prerelease=None, build=None
)

def __str__(self):
def __str__(self) -> str:
return "*"

@property
def is_unbounded(self):
def is_unbounded(self) -> bool:
return True

@property
def is_lower_bound(self):
def is_lower_bound(self) -> bool:
return False

@property
def is_upper_bound(self):
def is_upper_bound(self) -> bool:
return False

@property
def is_exact(self):
def is_exact(self) -> bool:
return False


Expand Down
5 changes: 3 additions & 2 deletions dbt_common/utils/casting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This is useful for proto generated classes in particular, since
# the default for protobuf for strings is the empty string, so
# Optional[str] types don't work for generated Python classes.
from typing import Any, Dict, Optional
from typing import Any, Dict, Mapping, Optional


def cast_to_str(string: Optional[str]) -> str:
Expand All @@ -18,8 +18,9 @@ def cast_to_int(integer: Optional[int]) -> int:
return integer


def cast_dict_to_dict_of_strings(dct: Dict[Any, Any]) -> Dict[str, str]:
def cast_dict_to_dict_of_strings(dct: Mapping[Any, Any]) -> Dict[str, str]:
new_dct: Dict[str, str] = {}

for k, v in dct.items():
new_dct[str(k)] = str(v)
return new_dct
3 changes: 2 additions & 1 deletion dbt_common/utils/connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from typing import Callable

from dbt_common.events.types import RecordRetryException, RetryExternalCall
from dbt_common.exceptions import ConnectionError
Expand All @@ -7,7 +8,7 @@
import requests


def connection_exception_retry(fn, max_attempts: int, attempt: int = 0):
def connection_exception_retry(fn: Callable, max_attempts: int, attempt: int = 0):
"""Handle connection retries gracefully.
Attempts to run a function that makes an external call, if the call fails
Expand Down
12 changes: 7 additions & 5 deletions dbt_common/utils/jinja.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,37 @@
from typing import Optional

from dbt_common.exceptions import DbtInternalError


MACRO_PREFIX = "dbt_macro__"
DOCS_PREFIX = "dbt_docs__"


def get_dbt_macro_name(name) -> str:
def get_dbt_macro_name(name: str) -> str:
if name is None:
raise DbtInternalError("Got None for a macro name!")
return f"{MACRO_PREFIX}{name}"


def get_dbt_docs_name(name) -> str:
def get_dbt_docs_name(name: str) -> str:
if name is None:
raise DbtInternalError("Got None for a doc name!")
return f"{DOCS_PREFIX}{name}"


def get_materialization_macro_name(
materialization_name, adapter_type=None, with_prefix=True
materialization_name: str, adapter_type: Optional[str] = None, with_prefix: bool = True
) -> str:
if adapter_type is None:
adapter_type = "default"
name = f"materialization_{materialization_name}_{adapter_type}"
return get_dbt_macro_name(name) if with_prefix else name


def get_docs_macro_name(docs_name, with_prefix=True):
def get_docs_macro_name(docs_name: str, with_prefix: bool = True) -> str:
return get_dbt_docs_name(docs_name) if with_prefix else docs_name


def get_test_macro_name(test_name, with_prefix=True):
def get_test_macro_name(test_name: str, with_prefix: bool = True) -> str:
name = f"test_{test_name}"
return get_dbt_macro_name(name) if with_prefix else name

0 comments on commit bef3b7d

Please sign in to comment.