diff --git a/refurb/checks/common.py b/refurb/checks/common.py index c835d49..74fbc23 100644 --- a/refurb/checks/common.py +++ b/refurb/checks/common.py @@ -532,6 +532,9 @@ def get_mypy_type(node: Node) -> Type | None: case FloatExpr(): return _get_builtin_mypy_type("float") + case NameExpr(fullname="builtins.True" | "builtins.False"): + return _get_builtin_mypy_type("bool") + case DictExpr(): return _get_builtin_mypy_type("dict") diff --git a/refurb/checks/readability/no_is_bool_compare.py b/refurb/checks/readability/no_is_bool_compare.py index 9b98ec5..e3849a6 100644 --- a/refurb/checks/readability/no_is_bool_compare.py +++ b/refurb/checks/readability/no_is_bool_compare.py @@ -1,8 +1,9 @@ from dataclasses import dataclass +from typing import TypeGuard -from mypy.nodes import ComparisonExpr, Expression, NameExpr, Var +from mypy.nodes import ComparisonExpr, Expression, NameExpr -from refurb.checks.common import is_same_type, stringify +from refurb.checks.common import get_mypy_type, is_same_type, stringify from refurb.error import Error @@ -36,7 +37,7 @@ class ErrorInfo(Error): categories = ("logical", "readability", "truthy") -def is_bool_literal(expr: Expression) -> bool: +def is_bool_literal(expr: Expression) -> TypeGuard[NameExpr]: match expr: case NameExpr(fullname="builtins.True" | "builtins.False"): return True @@ -45,11 +46,7 @@ def is_bool_literal(expr: Expression) -> bool: def is_bool_variable(expr: Expression) -> bool: - match expr: - case NameExpr(node=Var(type=ty)) if is_same_type(ty, bool): - return True - - return False + return is_same_type(get_mypy_type(expr), bool) def is_truthy(oper: str, name: str) -> bool: @@ -62,7 +59,7 @@ def check(node: ComparisonExpr, errors: list[Error]) -> None: match node: case ComparisonExpr( operators=["is" | "is not" | "==" | "!=" as oper], - operands=[NameExpr() as lhs, NameExpr() as rhs], + operands=[lhs, rhs], ): if is_bool_literal(lhs) and is_bool_variable(rhs): expr = stringify(rhs) diff --git a/refurb/checks/readability/use_str_method.py b/refurb/checks/readability/use_str_method.py index b5935fb..47ec4e5 100644 --- a/refurb/checks/readability/use_str_method.py +++ b/refurb/checks/readability/use_str_method.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from typing import Any -from mypy.nodes import ArgKind, Block, CallExpr, LambdaExpr, MemberExpr, NameExpr, ReturnStmt, Var +from mypy.nodes import ArgKind, Block, CallExpr, LambdaExpr, MemberExpr, NameExpr, ReturnStmt -from refurb.checks.common import is_same_type, stringify +from refurb.checks.common import get_mypy_type, is_same_type, stringify from refurb.error import Error @@ -78,7 +78,7 @@ def check(node: LambdaExpr, errors: list[Error]) -> None: ReturnStmt( expr=CallExpr( callee=MemberExpr( - expr=NameExpr(name=member_base_name, node=Var(type=ty)), + expr=NameExpr(name=member_base_name) as member_base, name=str_func_name, ), args=[], @@ -89,7 +89,7 @@ def check(node: LambdaExpr, errors: list[Error]) -> None: ) if ( arg_name == member_base_name and str_func_name in STR_FUNCS - and is_same_type(ty, str, None, Any) + and is_same_type(get_mypy_type(member_base), str, None, Any) ): msg = f"Replace `{stringify(node)}` with `str.{str_func_name}`" diff --git a/test/data/err_149.py b/test/data/err_149.py index 08109db..691dc46 100644 --- a/test/data/err_149.py +++ b/test/data/err_149.py @@ -16,6 +16,11 @@ _ = True == b _ = False == b +class Wrapper: + b: bool + +_ = Wrapper().b == True + # these should not diff --git a/test/data/err_149.txt b/test/data/err_149.txt index 4a692f1..166f3af 100644 --- a/test/data/err_149.txt +++ b/test/data/err_149.txt @@ -10,3 +10,4 @@ test/data/err_149.py:14:5 [FURB149]: Replace `b != True` with `not b` test/data/err_149.py:15:5 [FURB149]: Replace `b != False` with `b` test/data/err_149.py:16:5 [FURB149]: Replace `True == b` with `b` test/data/err_149.py:17:5 [FURB149]: Replace `False == b` with `not b` +test/data/err_149.py:22:5 [FURB149]: Replace `Wrapper().b == True` with `Wrapper().b`