Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More Type Annotations #177

Merged
merged 3 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dbt_common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dbt_common.record import Recorder


class CaseInsensitiveMapping(Mapping):
class CaseInsensitiveMapping(Mapping[str, str]):
def __init__(self, env: Mapping[str, str]):
self._env = {k.casefold(): (k, v) for k, v in env.items()}

Expand Down Expand Up @@ -65,7 +65,7 @@ def env_secrets(self) -> List[str]:


def reliably_get_invocation_var() -> ContextVar[InvocationContext]:
invocation_var: Optional[ContextVar] = next(
invocation_var: Optional[ContextVar[InvocationContext]] = next(
(cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None
)

Expand Down
12 changes: 6 additions & 6 deletions dbt_common/contracts/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass, Field

from itertools import chain
from typing import Callable, Dict, Any, List, TypeVar, Type
from typing import Any, Callable, Dict, Iterator, List, Type, TypeVar

from dbt_common.contracts.config.metadata import Metadata
from dbt_common.exceptions import CompilationError, DbtInternalError
Expand Down Expand Up @@ -45,7 +45,7 @@ def __delitem__(self, key: str) -> None:
else:
del self._extra[key]

def _content_iterator(self, include_condition: Callable[[Field], bool]):
def _content_iterator(self, include_condition: Callable[[Field[Any]], bool]) -> Iterator[str]:
seen = set()
for fld, _ in self._get_fields():
seen.add(fld.name)
Expand All @@ -57,7 +57,7 @@ def _content_iterator(self, include_condition: Callable[[Field], bool]):
seen.add(key)
yield key

def __iter__(self):
def __iter__(self) -> Iterator[str]:
yield from self._content_iterator(include_condition=lambda f: True)

def __len__(self) -> int:
Expand All @@ -76,7 +76,7 @@ def compare_key(
elif key in unrendered and key not in other:
return False
else:
return unrendered[key] == other[key]
return bool(unrendered[key] == other[key])

@classmethod
def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool:
Expand Down Expand Up @@ -203,11 +203,11 @@ def metadata_key(cls) -> str:
return "compare"

@classmethod
def should_include(cls, fld: Field) -> bool:
def should_include(cls, fld: Field[Any]) -> bool:
return cls.from_field(fld) == cls.Include


def _listify(value: Any) -> List:
def _listify(value: Any) -> List[Any]:
if isinstance(value, list):
return value[:]
else:
Expand Down
7 changes: 4 additions & 3 deletions dbt_common/contracts/util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import dataclasses
from typing import Any


# TODO: remove from dbt_common.contracts.util:: Replaceable + references
class Replaceable:
def replace(self, **kwargs):
return dataclasses.replace(self, **kwargs)
def replace(self, **kwargs: Any):
return dataclasses.replace(self, **kwargs) # type: ignore


class Mergeable(Replaceable):
Expand All @@ -15,7 +16,7 @@ def merged(self, *args):
replacements = {}
cls = type(self)
for arg in args:
for field in dataclasses.fields(cls):
for field in dataclasses.fields(cls): # type: ignore
value = getattr(arg, field.name)
if value is not None:
replacements[field.name] = value
Expand Down
2 changes: 1 addition & 1 deletion dbt_common/dataclass_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def json_schema(cls):
return json_schema

@classmethod
def validate(cls, data):
def validate(cls, data: Any) -> None:
json_schema = cls.json_schema()
validator = jsonschema.Draft7Validator(json_schema)
error = next(iter(validator.iter_errors(data)), None)
Expand Down
2 changes: 1 addition & 1 deletion dbt_common/events/contextvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_contextvars(prefix: str) -> Dict[str, Any]:
return rv


def get_node_info():
def get_node_info() -> Dict[str, Any]:
cvars = get_contextvars(LOG_PREFIX)
if "node_info" in cvars:
return cvars["node_info"]
Expand Down
26 changes: 13 additions & 13 deletions dbt_common/exceptions/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import builtins
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional
import os

from dbt_common.constants import SECRET_ENV_PREFIX
Expand All @@ -23,7 +23,7 @@ class DbtBaseException(Exception):
CODE = -32000
MESSAGE = "Server Error"

def data(self):
def data(self) -> Dict[str, Any]:
# if overriding, make sure the result is json-serializable.
return {
"type": self.__class__.__name__,
Expand All @@ -32,15 +32,15 @@ def data(self):


class DbtInternalError(DbtBaseException):
def __init__(self, msg: str):
def __init__(self, msg: str) -> None:
self.stack: List = []
self.msg = scrub_secrets(msg, env_secrets())

@property
def type(self) -> str:
return "Internal"

def process_stack(self):
def process_stack(self) -> List[str]:
lines = []
stack = self.stack
first = True
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(self, msg: str, node=None) -> None:
self.node = node
self.msg = scrub_secrets(msg, env_secrets())

def add_node(self, node=None):
def add_node(self, node=None) -> None:
if node is not None and node is not self.node:
if self.node is not None:
self.stack.append(self.node)
Expand All @@ -91,7 +91,7 @@ def add_node(self, node=None):
def type(self):
return "Runtime"

def node_to_string(self, node: Any):
def node_to_string(self, node: Any) -> str:
"""Given a node-like object we attempt to create the best identifier we can."""
result = ""
if hasattr(node, "resource_type"):
Expand All @@ -103,7 +103,7 @@ def node_to_string(self, node: Any):

return result.strip() if result != "" else "<Unknown>"

def process_stack(self):
def process_stack(self) -> List[str]:
lines = []
stack = self.stack + [self.node]
first = True
Expand All @@ -122,7 +122,7 @@ def process_stack(self):

return lines

def validator_error_message(self, exc: builtins.Exception):
def validator_error_message(self, exc: builtins.Exception) -> str:
"""Given a dbt.dataclass_schema.ValidationError return the relevant parts as a string.

dbt.dataclass_schema.ValidationError is basically a jsonschema.ValidationError)
Expand All @@ -132,7 +132,7 @@ def validator_error_message(self, exc: builtins.Exception):
path = "[%s]" % "][".join(map(repr, exc.relative_path))
return f"at path {path}: {exc.message}"

def __str__(self, prefix: str = "! "):
def __str__(self, prefix: str = "! ") -> str:
node_string = ""

if self.node is not None:
Expand All @@ -149,7 +149,7 @@ def __str__(self, prefix: str = "! "):

return lines[0] + "\n" + "\n".join([" " + line for line in lines[1:]])

def data(self):
def data(self) -> Dict[str, Any]:
result = DbtBaseException.data(self)
if self.node is None:
return result
Expand Down Expand Up @@ -236,7 +236,7 @@ class DbtDatabaseError(DbtRuntimeError):
CODE = 10003
MESSAGE = "Database Error"

def process_stack(self):
def process_stack(self) -> List[str]:
lines = []

if hasattr(self.node, "build_path") and self.node.build_path:
Expand All @@ -250,7 +250,7 @@ def type(self):


class UnexpectedNullError(DbtDatabaseError):
def __init__(self, field_name: str, source):
def __init__(self, field_name: str, source) -> None:
self.field_name = field_name
self.source = source
msg = (
Expand All @@ -268,7 +268,7 @@ def __init__(self, cwd: str, cmd: List[str], msg: str = "Error running command")
self.cmd = cmd_scrubbed
self.args = (cwd, cmd_scrubbed, msg)

def __str__(self):
def __str__(self, prefix: str = "! ") -> str:
if len(self.cmd) == 0:
return f"{self.msg}: No arguments given"
return f'{self.msg}: "{self.cmd[0]}"'
Loading