Skip to content

Commit

Permalink
Even More Type Annotations (#180)
Browse files Browse the repository at this point in the history
* Even more type annotations.

* Fix some issues for Python 3.8

* Type annotations

* Generic type to work around lack of Self in Python before 3.11
  • Loading branch information
peterallenwebb authored Aug 19, 2024
1 parent 3521bd8 commit 4b8a41e
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 46 deletions.
4 changes: 2 additions & 2 deletions dbt_common/clients/agate_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from codecs import BOM_UTF8

import agate # type: ignore
import agate
import datetime
import isodate
import json
Expand Down Expand Up @@ -149,7 +149,7 @@ def as_matrix(table):
return [r.values() for r in table.rows.values()]


def from_csv(abspath, text_columns, delimiter=","):
def from_csv(abspath, text_columns, delimiter=",") -> agate.Table:
type_tester = build_type_tester(text_columns=text_columns)
with open(abspath, encoding="utf-8") as fp:
if fp.read(1) != BOM:
Expand Down
14 changes: 8 additions & 6 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type
from typing_extensions import Protocol

import jinja2 # type: ignore
import jinja2.ext # type: ignore
import jinja2.nativetypes # type: ignore
import jinja2.nodes # type: ignore
import jinja2.parser # type: ignore
import jinja2.sandbox # type: ignore
import jinja2
import jinja2.ext
import jinja2.nativetypes
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox

from dbt_common.tests import test_caching_enabled
from dbt_common.utils.jinja import (
Expand Down Expand Up @@ -124,6 +124,7 @@ def new_context(
"shared or locals parameters."
)

vars = {} if vars is None else vars
parent = ChainMap(vars, self.globals) if self.globals else vars

return self.environment.context_class(self.environment, parent, self.name, self.blocks)
Expand Down Expand Up @@ -544,6 +545,7 @@ def render_template(template, ctx: Dict[str, Any], node=None) -> str:

def _get_blocks_hash(text: str, allowed_blocks: Optional[Set[str]], collect_raw_data: bool) -> int:
"""Provides a hash function over the arguments to extract_toplevel_blocks, in order to support caching."""
allowed_blocks = allowed_blocks or set()
allowed_tuple = tuple(sorted(allowed_blocks) or [])
return text.__hash__() + allowed_tuple.__hash__() + collect_raw_data.__hash__()

Expand Down
11 changes: 8 additions & 3 deletions dbt_common/contracts/util.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import dataclasses
from typing import Any
from typing import Any, TypeVar

_R = TypeVar("_R", bound="Replaceable")


# TODO: remove from dbt_common.contracts.util:: Replaceable + references
class Replaceable:
def replace(self, **kwargs: Any):
def replace(self: _R, **kwargs: Any) -> _R:
return dataclasses.replace(self, **kwargs) # type: ignore


_M = TypeVar("_M", bound="Mergeable")


class Mergeable(Replaceable):
def merged(self, *args):
def merged(self: _M, *args: Any) -> _M:
"""Perform a shallow merge, where the last non-None write wins. This is
intended to merge dataclasses that are a collection of optional values.
"""
Expand Down
7 changes: 3 additions & 4 deletions dbt_common/dataclass_schema.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from typing import Any, cast, ClassVar, Dict, get_type_hints, List, Optional, Tuple
from typing import Any, ClassVar, Dict, get_type_hints, List, Optional, Tuple, Union
import re
import jsonschema
from dataclasses import fields, Field
from enum import Enum
from datetime import datetime
from dateutil.parser import parse

# type: ignore
from mashumaro.config import (
TO_DICT_ADD_OMIT_NONE_FLAG,
ADD_SERIALIZATION_CONTEXT,
Expand All @@ -33,8 +32,8 @@ def serialize(self, value: datetime) -> str:
out += "Z"
return out

def deserialize(self, value) -> datetime:
return value if isinstance(value, datetime) else parse(cast(str, value))
def deserialize(self, value: Union[datetime, str]) -> datetime:
return value if isinstance(value, datetime) else parse(value)


class dbtMashConfig(MashBaseConfig):
Expand Down
4 changes: 2 additions & 2 deletions dbt_common/exceptions/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class WorkingDirectoryError(CommandError):
def __init__(self, cwd: str, cmd: List[str], msg: str) -> None:
super().__init__(cwd, cmd, msg)

def __str__(self):
def __str__(self, prefix: str = "! ") -> str:
return f'{self.msg}: "{self.cwd}"'


Expand All @@ -46,5 +46,5 @@ def __init__(
self.stderr = scrub_secrets(stderr.decode("utf-8"), env_secrets())
self.args = (cwd, self.cmd, returncode, self.stdout, self.stderr, msg)

def __str__(self):
def __str__(self, prefix: str = "! ") -> str:
return f"{self.msg} running: {self.cmd}"
4 changes: 2 additions & 2 deletions dbt_common/semver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
import re
from typing import List, Iterable
from typing import Iterable, List, Union

import dbt_common.exceptions.base
from dbt_common.exceptions import VersionsNotCompatibleError
Expand Down Expand Up @@ -378,7 +378,7 @@ def is_exact(self) -> bool:
return False


def reduce_versions(*args):
def reduce_versions(*args: Union[VersionSpecifier, VersionRange, str]) -> VersionRange:
version_specifiers = []

for version in args:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_connection_retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from dbt_common.utils.connection import connection_exception_retry


def no_retry_fn():
def no_retry_fn() -> str:
return "success"


class TestNoRetries:
def test_no_retry(self):
def test_no_retry(self) -> None:
fn_to_retry = functools.partial(no_retry_fn)
result = connection_exception_retry(fn_to_retry, 3)

Expand Down
36 changes: 19 additions & 17 deletions tests/unit/test_diff.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import json
from typing import Any, Dict
from typing import Any, Dict, List

import pytest
from dbt_common.record import Diff

Case = List[Dict[str, Any]]


@pytest.fixture
def current_query():
def current_query() -> Case:
return [
{
"params": {
Expand All @@ -21,7 +23,7 @@ def current_query():


@pytest.fixture
def query_modified_order():
def query_modified_order() -> Case:
return [
{
"params": {
Expand All @@ -36,7 +38,7 @@ def query_modified_order():


@pytest.fixture
def query_modified_value():
def query_modified_value() -> Case:
return [
{
"params": {
Expand All @@ -51,7 +53,7 @@ def query_modified_value():


@pytest.fixture
def current_simple():
def current_simple() -> Case:
return [
{
"params": {
Expand All @@ -65,7 +67,7 @@ def current_simple():


@pytest.fixture
def current_simple_modified():
def current_simple_modified() -> Case:
return [
{
"params": {
Expand All @@ -79,7 +81,7 @@ def current_simple_modified():


@pytest.fixture
def env_record():
def env_record() -> Case:
return [
{
"params": {},
Expand All @@ -94,7 +96,7 @@ def env_record():


@pytest.fixture
def modified_env_record():
def modified_env_record() -> Case:
return [
{
"params": {},
Expand All @@ -108,30 +110,30 @@ def modified_env_record():
]


def test_diff_query_records_no_diff(current_query, query_modified_order):
def test_diff_query_records_no_diff(current_query: Case, query_modified_order: Case) -> None:
# Setup: Create an instance of Diff
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
result = diff_instance.diff_query_records(current_query, query_modified_order)
# the order changed but the diff should be empty
expected_result = {}
expected_result: Dict[str, Any] = {}
assert result == expected_result # Replace expected_result with what you actually expect


def test_diff_query_records_with_diff(current_query, query_modified_value):
def test_diff_query_records_with_diff(current_query: Case, query_modified_value: Case) -> None:
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
result = diff_instance.diff_query_records(current_query, query_modified_value)
# the values changed this time
expected_result = {
expected_result: Dict[str, Any] = {
"values_changed": {"root[0]['result']['table'][1]['b']": {"new_value": 7, "old_value": 10}}
}
assert result == expected_result


def test_diff_env_records(env_record, modified_env_record):
def test_diff_env_records(env_record: Case, modified_env_record: Case) -> None:
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
Expand All @@ -147,17 +149,17 @@ def test_diff_env_records(env_record, modified_env_record):
assert result == expected_result


def test_diff_default_no_diff(current_simple):
def test_diff_default_no_diff(current_simple: Case) -> None:
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
# use the same list to ensure no diff
result = diff_instance.diff_default(current_simple, current_simple)
expected_result = {}
expected_result: Dict[str, Any] = {}
assert result == expected_result


def test_diff_default_with_diff(current_simple, current_simple_modified):
def test_diff_default_with_diff(current_simple: Case, current_simple_modified: Case) -> None:
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
Expand All @@ -170,7 +172,7 @@ def test_diff_default_with_diff(current_simple, current_simple_modified):

# Mock out reading the files so we don't have to
class MockFile:
def __init__(self, json_data):
def __init__(self, json_data) -> None:
self.json_data = json_data

def __enter__(self):
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def valid_error_names() -> Set[str]:


class TestWarnOrError:
def test_fires_error(self, valid_error_names: Set[str]):
def test_fires_error(self, valid_error_names: Set[str]) -> None:
functions.WARN_ERROR_OPTIONS = WarnErrorOptions(
include="*", valid_error_names=valid_error_names
)
Expand All @@ -49,8 +49,8 @@ def test_fires_warning(
self,
valid_error_names: Set[str],
event_catcher: EventCatcher,
set_event_manager_with_catcher,
):
set_event_manager_with_catcher: None,
) -> None:
functions.WARN_ERROR_OPTIONS = WarnErrorOptions(
include="*", exclude=list(valid_error_names), valid_error_names=valid_error_names
)
Expand All @@ -62,8 +62,8 @@ def test_silenced(
self,
valid_error_names: Set[str],
event_catcher: EventCatcher,
set_event_manager_with_catcher,
):
set_event_manager_with_catcher: None,
) -> None:
functions.WARN_ERROR_OPTIONS = WarnErrorOptions(
include="*", silence=list(valid_error_names), valid_error_names=valid_error_names
)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_incomplete_block_failure(self) -> None:
with self.assertRaises(CompilationError):
extract_toplevel_blocks(body, allowed_blocks={"myblock"})

def test_wrong_end_failure(self):
def test_wrong_end_failure(self) -> None:
body = "{% myblock foo %} {% endotherblock %}"
with self.assertRaises(CompilationError):
extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"})
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from typing import Any, Tuple, Union

import dbt_common.exceptions
import dbt_common.utils.dict
Expand Down Expand Up @@ -68,7 +69,7 @@ def setUp(self) -> None:
}

@staticmethod
def intify_all(value, _):
def intify_all(value, _) -> int:
try:
return int(value)
except (TypeError, ValueError):
Expand Down Expand Up @@ -98,7 +99,7 @@ def test__simple_cases(self) -> None:
self.assertEqual(actual, expected)

@staticmethod
def special_keypath(value, keypath):
def special_keypath(value: Any, keypath: Tuple[Union[str, int], ...]) -> Any:
if tuple(keypath) == ("foo", "baz", 1):
return "hello"
else:
Expand Down

0 comments on commit 4b8a41e

Please sign in to comment.