Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More Type Annotations #177

Merged
merged 3 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dbt_common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dbt_common.record import Recorder


class CaseInsensitiveMapping(Mapping):
class CaseInsensitiveMapping(Mapping[str, str]):
def __init__(self, env: Mapping[str, str]):
self._env = {k.casefold(): (k, v) for k, v in env.items()}

Expand Down Expand Up @@ -65,7 +65,7 @@ def env_secrets(self) -> List[str]:


def reliably_get_invocation_var() -> ContextVar[InvocationContext]:
invocation_var: Optional[ContextVar] = next(
invocation_var: Optional[ContextVar[InvocationContext]] = next(
(cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None
)

Expand Down
20 changes: 9 additions & 11 deletions dbt_common/contracts/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
from dataclasses import dataclass, Field

from itertools import chain
from typing import Callable, Dict, Any, List, TypeVar, Type
from typing import Callable, Dict, Any, List, Type, Self, Iterator

from dbt_common.contracts.config.metadata import Metadata
from dbt_common.exceptions import CompilationError, DbtInternalError
from dbt_common.contracts.config.properties import AdditionalPropertiesAllowed
from dbt_common.contracts.util import Replaceable

T = TypeVar("T", bound="BaseConfig")


@dataclass
class BaseConfig(AdditionalPropertiesAllowed, Replaceable):
Expand Down Expand Up @@ -45,7 +43,7 @@ def __delitem__(self, key: str) -> None:
else:
del self._extra[key]

def _content_iterator(self, include_condition: Callable[[Field], bool]):
def _content_iterator(self, include_condition: Callable[[Field[Any]], bool]) -> Iterator[str]:
seen = set()
for fld, _ in self._get_fields():
seen.add(fld.name)
Expand All @@ -57,7 +55,7 @@ def _content_iterator(self, include_condition: Callable[[Field], bool]):
seen.add(key)
yield key

def __iter__(self):
def __iter__(self) -> Iterator[str]:
yield from self._content_iterator(include_condition=lambda f: True)

def __len__(self) -> int:
Expand All @@ -76,7 +74,7 @@ def compare_key(
elif key in unrendered and key not in other:
return False
else:
return unrendered[key] == other[key]
return bool(unrendered[key] == other[key])

@classmethod
def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool:
Expand Down Expand Up @@ -146,8 +144,8 @@ def _merge_dicts(cls, src: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, An
return result

def update_from(
self: T, data: Dict[str, Any], config_cls: Type[BaseConfig], validate: bool = True
) -> T:
self, data: Dict[str, Any], config_cls: Type[BaseConfig], validate: bool = True
) -> Self:
"""Update and validate config given a dict.

Given a dict of keys, update the current config from them, validate
Expand All @@ -169,7 +167,7 @@ def update_from(
self.validate(dct)
return self.from_dict(dct)

def finalize_and_validate(self: T) -> T:
def finalize_and_validate(self) -> Self:
dct = self.to_dict(omit_none=False)
self.validate(dct)
return self.from_dict(dct)
Expand Down Expand Up @@ -203,11 +201,11 @@ def metadata_key(cls) -> str:
return "compare"

@classmethod
def should_include(cls, fld: Field) -> bool:
def should_include(cls, fld: Field[Any]) -> bool:
return cls.from_field(fld) == cls.Include


def _listify(value: Any) -> List:
def _listify(value: Any) -> List[Any]:
if isinstance(value, list):
return value[:]
else:
Expand Down
9 changes: 5 additions & 4 deletions dbt_common/contracts/util.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import dataclasses
from typing import Any, Self


# TODO: remove from dbt_common.contracts.util:: Replaceable + references
class Replaceable:
def replace(self, **kwargs):
return dataclasses.replace(self, **kwargs)
def replace(self, **kwargs: Any) -> Self:
return dataclasses.replace(self, **kwargs) # type: ignore


class Mergeable(Replaceable):
def merged(self, *args):
def merged(self, *args: Self) -> Self:
"""Perform a shallow merge, where the last non-None write wins. This is
intended to merge dataclasses that are a collection of optional values.
"""
replacements = {}
cls = type(self)
for arg in args:
for field in dataclasses.fields(cls):
for field in dataclasses.fields(cls): # type: ignore
value = getattr(arg, field.name)
if value is not None:
replacements[field.name] = value
Expand Down
2 changes: 1 addition & 1 deletion dbt_common/dataclass_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def json_schema(cls):
return json_schema

@classmethod
def validate(cls, data):
def validate(cls, data: Any) -> None:
json_schema = cls.json_schema()
validator = jsonschema.Draft7Validator(json_schema)
error = next(iter(validator.iter_errors(data)), None)
Expand Down
2 changes: 1 addition & 1 deletion dbt_common/events/contextvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_contextvars(prefix: str) -> Dict[str, Any]:
return rv


def get_node_info():
def get_node_info() -> Dict[str, Any]:
cvars = get_contextvars(LOG_PREFIX)
if "node_info" in cvars:
return cvars["node_info"]
Expand Down
Loading