Skip to content

Commit

Permalink
Add types to semver code
Browse files Browse the repository at this point in the history
  • Loading branch information
peterallenwebb committed Sep 8, 2024
1 parent f63804a commit 05d6ae7
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 45 deletions.
10 changes: 8 additions & 2 deletions dbt_common/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,14 @@ def __init__(self, params, result) -> None:

def to_dict(self) -> Dict[str, Any]:
return {
"params": self.params._to_dict() if hasattr(self.params, "_to_dict") else dataclasses.asdict(self.params), # type: ignore
"result": self.result._to_dict() if hasattr(self.result, "_to_dict") else dataclasses.asdict(self.result) if self.result is not None else None, # type: ignore
"params": self.params._to_dict()
if hasattr(self.params, "_to_dict")
else dataclasses.asdict(self.params),
"result": self.result._to_dict()
if hasattr(self.result, "_to_dict")
else dataclasses.asdict(self.result)
if self.result is not None
else None,
}

@classmethod
Expand Down
66 changes: 39 additions & 27 deletions dbt_common/semver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
import re
from typing import Iterable, List, Union
from typing import Any, Iterable, List, Union

import dbt_common.exceptions.base
from dbt_common.exceptions import VersionsNotCompatibleError
Expand Down Expand Up @@ -67,9 +67,9 @@ class VersionSpecification(dbtClassMixin):
_VERSION_REGEX = re.compile(_VERSION_REGEX_PAT_STR, re.VERBOSE)


def _cmp(a, b) -> int:
def _cmp(a: Any, b: Any) -> int:
"""Return negative if a<b, zero if a==b, positive if a>b."""
return (a > b) - (a < b)
return int((a > b) - (a < b))


@dataclass
Expand Down Expand Up @@ -102,7 +102,9 @@ def from_version_string(cls, version_string: str) -> "VersionSpecifier":

matched = {k: v for k, v in match.groupdict().items() if v is not None}

return cls.from_dict(matched)
spec = cls.from_dict(matched)
assert isinstance(spec, VersionSpecifier)
return spec

def __str__(self) -> str:
return self.to_version_string()
Expand Down Expand Up @@ -198,10 +200,11 @@ def __lt__(self, other: "VersionSpecifier") -> bool:
def __gt__(self, other: "VersionSpecifier") -> bool:
return self.compare(other) == 1

def __eq___(self, other: "VersionSpecifier") -> bool:
def __eq__(self, other: object) -> bool:
assert isinstance(other, VersionSpecifier)
return self.compare(other) == 0

def __cmp___(self, other: "VersionSpecifier") -> int:
def __cmp__(self, other: "VersionSpecifier") -> int:
return self.compare(other)

@property
Expand All @@ -221,8 +224,8 @@ def is_exact(self) -> bool:
return self.matcher == Matchers.EXACT

@classmethod
def _nat_cmp(cls, a, b) -> int:
def cmp_prerelease_tag(a, b):
def _nat_cmp(cls, a: str, b: str) -> int:
def cmp_prerelease_tag(a: Union[str, int], b: Union[str, int]) -> int:
if isinstance(a, int) and isinstance(b, int):
return _cmp(a, b)
elif isinstance(a, int):
Expand All @@ -234,10 +237,10 @@ def cmp_prerelease_tag(a, b):

a, b = a or "", b or ""
a_parts, b_parts = a.split("."), b.split(".")
a_parts = [int(x) if re.match(r"^\d+$", x) else x for x in a_parts]
b_parts = [int(x) if re.match(r"^\d+$", x) else x for x in b_parts]
for sub_a, sub_b in zip(a_parts, b_parts):
cmp_result = cmp_prerelease_tag(sub_a, sub_b)
a_parts_2 = [int(x) if re.match(r"^\d+$", x) else x for x in a_parts]
b_parts_2 = [int(x) if re.match(r"^\d+$", x) else x for x in b_parts]
for sub_a, sub_b in zip(a_parts_2, b_parts_2):
cmp_result = cmp_prerelease_tag(sub_a, sub_b) # type: ignore
if cmp_result != 0:
return cmp_result
else:
Expand All @@ -249,21 +252,25 @@ class VersionRange:
start: VersionSpecifier
end: VersionSpecifier

def _try_combine_exact(self, a, b):
def _try_combine_exact(self, a: VersionSpecifier, b: VersionSpecifier) -> VersionSpecifier:
if a.compare(b) == 0:
return a
else:
raise VersionsNotCompatibleError()

def _try_combine_lower_bound_with_exact(self, lower, exact):
def _try_combine_lower_bound_with_exact(
self, lower: VersionSpecifier, exact: VersionSpecifier
) -> VersionSpecifier:
comparison = lower.compare(exact)

if comparison < 0 or (comparison == 0 and lower.matcher == Matchers.GREATER_THAN_OR_EQUAL):
return exact

raise VersionsNotCompatibleError()

def _try_combine_lower_bound(self, a, b):
def _try_combine_lower_bound(
self, a: VersionSpecifier, b: VersionSpecifier
) -> VersionSpecifier:
if b.is_unbounded:
return a
elif a.is_unbounded:
Expand All @@ -280,18 +287,22 @@ def _try_combine_lower_bound(self, a, b):
elif a.is_exact:
return self._try_combine_lower_bound_with_exact(b, a)

elif b.is_exact:
else:
return self._try_combine_lower_bound_with_exact(a, b)

def _try_combine_upper_bound_with_exact(self, upper, exact):
def _try_combine_upper_bound_with_exact(
self, upper: VersionSpecifier, exact: VersionSpecifier
) -> VersionSpecifier:
comparison = upper.compare(exact)

if comparison > 0 or (comparison == 0 and upper.matcher == Matchers.LESS_THAN_OR_EQUAL):
return exact

raise VersionsNotCompatibleError()

def _try_combine_upper_bound(self, a, b):
def _try_combine_upper_bound(
self, a: VersionSpecifier, b: VersionSpecifier
) -> VersionSpecifier:
if b.is_unbounded:
return a
elif a.is_unbounded:
Expand All @@ -308,15 +319,14 @@ def _try_combine_upper_bound(self, a, b):
elif a.is_exact:
return self._try_combine_upper_bound_with_exact(b, a)

elif b.is_exact:
else:
return self._try_combine_upper_bound_with_exact(a, b)

def reduce(self, other):
def reduce(self, other: "VersionRange") -> "VersionRange":
start = None

if self.start.is_exact and other.start.is_exact:
start = end = self._try_combine_exact(self.start, other.start)

else:
start = self._try_combine_lower_bound(self.start, other.start)
end = self._try_combine_upper_bound(self.end, other.end)
Expand All @@ -326,7 +336,7 @@ def reduce(self, other):

return VersionRange(start=start, end=end)

def __str__(self):
def __str__(self) -> str:
result = []

if self.start.is_unbounded and self.end.is_unbounded:
Expand All @@ -340,7 +350,7 @@ def __str__(self):

return ", ".join(result)

def to_version_string_pair(self):
def to_version_string_pair(self) -> List[str]:
to_return = []

if not self.start.is_unbounded:
Expand All @@ -353,7 +363,7 @@ def to_version_string_pair(self):


class UnboundedVersionSpecifier(VersionSpecifier):
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(
matcher=Matchers.EXACT, major=None, minor=None, patch=None, prerelease=None, build=None
)
Expand Down Expand Up @@ -418,7 +428,7 @@ def reduce_versions(*args: Union[VersionSpecifier, VersionRange, str]) -> Versio
return to_return


def versions_compatible(*args) -> bool:
def versions_compatible(*args: Union[VersionSpecifier, VersionRange, str]) -> bool:
if len(args) == 1:
return True

Expand All @@ -429,7 +439,9 @@ def versions_compatible(*args) -> bool:
return False


def find_possible_versions(requested_range, available_versions: Iterable[str]):
def find_possible_versions(
requested_range: VersionRange, available_versions: Iterable[str]
) -> List[str]:
possible_versions = []

for version_string in available_versions:
Expand All @@ -443,7 +455,7 @@ def find_possible_versions(requested_range, available_versions: Iterable[str]):


def resolve_to_specific_version(
requested_range, available_versions: Iterable[str]
requested_range: VersionRange, available_versions: Iterable[str]
) -> Optional[str]:
max_version = None
max_version_string = None
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_behavior_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dbt_common.exceptions.base import CompilationError


def test_behavior_default():
def test_behavior_default() -> None:
behavior = Behavior(
[
{"name": "default_false_flag", "default": False},
Expand All @@ -17,7 +17,7 @@ def test_behavior_default():
assert behavior.default_true_flag.setting is True


def test_behavior_user_override():
def test_behavior_user_override() -> None:
behavior = Behavior(
[
{"name": "flag_default_false", "default": False},
Expand All @@ -43,7 +43,7 @@ def test_behavior_user_override():
assert behavior.flag_default_true_override_true.setting is True


def test_behavior_unregistered_flag_raises_correct_exception():
def test_behavior_unregistered_flag_raises_correct_exception() -> None:
behavior = Behavior(
[
{"name": "behavior_flag_exists", "default": False},
Expand All @@ -56,7 +56,7 @@ def test_behavior_unregistered_flag_raises_correct_exception():
assert behavior.behavior_flag_does_not_exist


def test_behavior_flag_can_be_used_as_conditional():
def test_behavior_flag_can_be_used_as_conditional() -> None:
behavior = Behavior(
[
{"name": "flag_false", "default": False},
Expand Down
25 changes: 16 additions & 9 deletions tests/unit/test_diff.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
from typing import Any, Dict, List
from inspect import Traceback
from typing import Any, Callable, Dict, List, Optional, Type

import pytest
from dbt_common.record import Diff
from tox.pytest import MonkeyPatch

Case = List[Dict[str, Any]]

Expand Down Expand Up @@ -172,30 +174,35 @@ def test_diff_default_with_diff(current_simple: Case, current_simple_modified: C

# Mock out reading the files so we don't have to
class MockFile:
def __init__(self, json_data) -> None:
def __init__(self, json_data: Any) -> None:
self.json_data = json_data

def __enter__(self):
def __enter__(self) -> "MockFile":
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_val: Optional[Exception],
exc_tb: Optional[Traceback],
) -> None:
pass

def read(self):
def read(self) -> str:
return json.dumps(self.json_data)


# Create a Mock Open Function
def mock_open(mock_files):
def open_mock(file, *args, **kwargs):
def mock_open(mock_files: Dict[str, Any]) -> Callable[..., MockFile]:
def open_mock(file: str, *args: Any, **kwargs: Any) -> MockFile:
if file in mock_files:
return MockFile(mock_files[file])
raise FileNotFoundError(f"No mock file found for {file}")

return open_mock


def test_calculate_diff_no_diff(monkeypatch) -> None:
def test_calculate_diff_no_diff(monkeypatch: MonkeyPatch) -> None:
# Mock data for the files
current_recording_data = {
"GetEnvRecord": [
Expand Down Expand Up @@ -259,7 +266,7 @@ def test_calculate_diff_no_diff(monkeypatch) -> None:
assert result == expected_result


def test_calculate_diff_with_diff(monkeypatch) -> None:
def test_calculate_diff_with_diff(monkeypatch: MonkeyPatch) -> None:
# Mock data for the files
current_recording_data = {
"GetEnvRecord": [
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def code(self) -> str:
return "Z050"

def message(self) -> str:
assert isinstance(self.msg, str)
return self.msg


Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_semver.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ def create_range(


class TestSemver(unittest.TestCase):
def assertVersionSetResult(self, inputs, output_range) -> None:
def assertVersionSetResult(self, inputs: List[str], output_range: List[Optional[str]]) -> None:
expected = create_range(*output_range)

for permutation in itertools.permutations(inputs):
self.assertEqual(reduce_versions(*permutation), expected)

def assertInvalidVersionSet(self, inputs) -> None:
def assertInvalidVersionSet(self, inputs: List[str]) -> None:
for permutation in itertools.permutations(inputs):
with self.assertRaises(VersionsNotCompatibleError):
reduce_versions(*permutation)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def setUp(self) -> None:
}

@staticmethod
def intify_all(value, _) -> int:
def intify_all(value: Any, _: Any) -> int:
try:
return int(value)
except (TypeError, ValueError):
Expand Down

0 comments on commit 05d6ae7

Please sign in to comment.