Skip to content

Commit

Permalink
More type annotations.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterallenwebb committed Aug 1, 2024
1 parent bef3b7d commit 2018bc1
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 19 deletions.
4 changes: 2 additions & 2 deletions dbt_common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dbt_common.record import Recorder


class CaseInsensitiveMapping(Mapping):
class CaseInsensitiveMapping(Mapping[str, str]):
def __init__(self, env: Mapping[str, str]):
self._env = {k.casefold(): (k, v) for k, v in env.items()}

Expand Down Expand Up @@ -65,7 +65,7 @@ def env_secrets(self) -> List[str]:


def reliably_get_invocation_var() -> ContextVar[InvocationContext]:
invocation_var: Optional[ContextVar] = next(
invocation_var: Optional[ContextVar[InvocationContext]] = next(
(cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None
)

Expand Down
20 changes: 9 additions & 11 deletions dbt_common/contracts/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
from dataclasses import dataclass, Field

from itertools import chain
from typing import Callable, Dict, Any, List, TypeVar, Type
from typing import Callable, Dict, Any, List, Type, Self, Iterator

from dbt_common.contracts.config.metadata import Metadata
from dbt_common.exceptions import CompilationError, DbtInternalError
from dbt_common.contracts.config.properties import AdditionalPropertiesAllowed
from dbt_common.contracts.util import Replaceable

T = TypeVar("T", bound="BaseConfig")


@dataclass
class BaseConfig(AdditionalPropertiesAllowed, Replaceable):
Expand Down Expand Up @@ -45,7 +43,7 @@ def __delitem__(self, key: str) -> None:
else:
del self._extra[key]

def _content_iterator(self, include_condition: Callable[[Field], bool]):
def _content_iterator(self, include_condition: Callable[[Field[Any]], bool]) -> Iterator[str]:
seen = set()
for fld, _ in self._get_fields():
seen.add(fld.name)
Expand All @@ -57,7 +55,7 @@ def _content_iterator(self, include_condition: Callable[[Field], bool]):
seen.add(key)
yield key

def __iter__(self):
def __iter__(self) -> Iterator[str]:
yield from self._content_iterator(include_condition=lambda f: True)

def __len__(self) -> int:
Expand All @@ -76,7 +74,7 @@ def compare_key(
elif key in unrendered and key not in other:
return False
else:
return unrendered[key] == other[key]
return bool(unrendered[key] == other[key])

@classmethod
def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool:
Expand Down Expand Up @@ -146,8 +144,8 @@ def _merge_dicts(cls, src: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, An
return result

def update_from(
self: T, data: Dict[str, Any], config_cls: Type[BaseConfig], validate: bool = True
) -> T:
self, data: Dict[str, Any], config_cls: Type[BaseConfig], validate: bool = True
) -> Self:
"""Update and validate config given a dict.
Given a dict of keys, update the current config from them, validate
Expand All @@ -169,7 +167,7 @@ def update_from(
self.validate(dct)
return self.from_dict(dct)

def finalize_and_validate(self: T) -> T:
def finalize_and_validate(self) -> Self:
dct = self.to_dict(omit_none=False)
self.validate(dct)
return self.from_dict(dct)
Expand Down Expand Up @@ -203,11 +201,11 @@ def metadata_key(cls) -> str:
return "compare"

@classmethod
def should_include(cls, fld: Field) -> bool:
def should_include(cls, fld: Field[Any]) -> bool:
return cls.from_field(fld) == cls.Include


def _listify(value: Any) -> List:
def _listify(value: Any) -> List[Any]:
if isinstance(value, list):
return value[:]
else:
Expand Down
9 changes: 5 additions & 4 deletions dbt_common/contracts/util.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import dataclasses
from typing import Any, Self


# TODO: remove from dbt_common.contracts.util:: Replaceable + references
class Replaceable:
def replace(self, **kwargs):
return dataclasses.replace(self, **kwargs)
def replace(self, **kwargs: Any) -> Self:
return dataclasses.replace(self, **kwargs) # type: ignore


class Mergeable(Replaceable):
def merged(self, *args):
def merged(self, *args: Self) -> Self:
"""Perform a shallow merge, where the last non-None write wins. This is
intended to merge dataclasses that are a collection of optional values.
"""
replacements = {}
cls = type(self)
for arg in args:
for field in dataclasses.fields(cls):
for field in dataclasses.fields(cls): # type: ignore
value = getattr(arg, field.name)
if value is not None:
replacements[field.name] = value
Expand Down
2 changes: 1 addition & 1 deletion dbt_common/dataclass_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def json_schema(cls):
return json_schema

@classmethod
def validate(cls, data):
def validate(cls, data: Any) -> None:
json_schema = cls.json_schema()
validator = jsonschema.Draft7Validator(json_schema)
error = next(iter(validator.iter_errors(data)), None)
Expand Down
2 changes: 1 addition & 1 deletion dbt_common/events/contextvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_contextvars(prefix: str) -> Dict[str, Any]:
return rv


def get_node_info():
def get_node_info() -> Dict[str, Any]:
cvars = get_contextvars(LOG_PREFIX)
if "node_info" in cvars:
return cvars["node_info"]
Expand Down

0 comments on commit 2018bc1

Please sign in to comment.