diff --git a/dev-requirements.txt b/dev-requirements.txt index f6d97a5..9e20cfa 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,4 @@ -black==24.1.1 +black==24.2.0 click==8.1.7 colorama==0.4.6 coverage==7.4.1 @@ -12,6 +12,6 @@ pathspec==0.12.1 platformdirs==4.2.0 pluggy==1.4.0 pytest-cov==4.1.0 -pytest==8.0.0 -ruff==0.2.0 -typos==1.18.0 +pytest==8.0.1 +ruff==0.2.2 +typos==1.18.2 diff --git a/refurb/checks/builtin/writelines.py b/refurb/checks/builtin/writelines.py index b0ea27a..6dd881b 100644 --- a/refurb/checks/builtin/writelines.py +++ b/refurb/checks/builtin/writelines.py @@ -1,6 +1,15 @@ from dataclasses import dataclass -from mypy.nodes import Block, CallExpr, ExpressionStmt, ForStmt, MemberExpr, NameExpr, WithStmt +from mypy.nodes import ( + Block, + CallExpr, + Expression, + ExpressionStmt, + ForStmt, + MemberExpr, + NameExpr, + WithStmt, +) from refurb.checks.common import get_mypy_type, is_equivalent, is_same_type, stringify from refurb.error import Error @@ -44,6 +53,11 @@ class ErrorInfo(Error): categories = ("builtin", "readability") +def is_file_object(f: Expression) -> bool: + # TODO: support more file-like types + return is_same_type(get_mypy_type(f), "io.TextIOWrapper", "io.BufferedWriter") + + def check(node: WithStmt, errors: list[Error]) -> None: match node: case WithStmt( @@ -70,10 +84,7 @@ def check(node: WithStmt, errors: list[Error]) -> None: ) as for_stmt ] ), - ) if ( - is_same_type(get_mypy_type(f), "io.TextIOWrapper", "io.BufferedWriter") - and is_equivalent(f, write_base) - ): + ) if is_file_object(f) and is_equivalent(f, write_base): old = stringify(for_stmt) new = f"{stringify(f)}.writelines({stringify(source)})" diff --git a/refurb/checks/common.py b/refurb/checks/common.py index fbb196d..53c66a8 100644 --- a/refurb/checks/common.py +++ b/refurb/checks/common.py @@ -225,8 +225,8 @@ def get_common_expr_in_comparison_chain( case ( ComparisonExpr(operators=[lhs_oper], operands=[a, b]), ComparisonExpr(operators=[rhs_oper], operands=[c, d]), - ) if ( - lhs_oper == rhs_oper == cmp_oper and (indices := get_common_expr_positions(a, b, c, d)) + ) if lhs_oper == rhs_oper == cmp_oper and ( + indices := get_common_expr_positions(a, b, c, d) ): return a, indices @@ -402,7 +402,7 @@ def _stringify(node: Node) -> str: arg_names=arg_names, arg_kinds=arg_kinds, body=Block(body=[ReturnStmt(expr=Expression() as expr)]), - ) if (all(kind == ArgKind.ARG_POS for kind in arg_kinds) and all(arg_names)): + ) if all(kind == ArgKind.ARG_POS for kind in arg_kinds) and all(arg_names): if arg_names: args = " " # type: ignore args += ", ".join(arg_names) # type: ignore diff --git a/refurb/checks/pathlib/simplify_ctor.py b/refurb/checks/pathlib/simplify_ctor.py index 1b2eb68..342c8c4 100644 --- a/refurb/checks/pathlib/simplify_ctor.py +++ b/refurb/checks/pathlib/simplify_ctor.py @@ -49,7 +49,7 @@ def check(node: CallExpr, errors: list[Error]) -> None: case CallExpr( args=[RefExpr(fullname=arg) as arg_ref], callee=RefExpr(fullname="pathlib.Path") as func_ref, - ) if ((arg := normalize_os_path(arg)) in {"os.curdir", "os.path.curdir"}): + ) if (arg := normalize_os_path(arg)) in {"os.curdir", "os.path.curdir"}: func_name = stringify(func_ref) arg_name = stringify(arg_ref) diff --git a/refurb/checks/readability/no_or_default.py b/refurb/checks/readability/no_or_default.py index 990f768..085e0a3 100644 --- a/refurb/checks/readability/no_or_default.py +++ b/refurb/checks/readability/no_or_default.py @@ -68,9 +68,8 @@ def check(node: OpExpr, errors: list[Error]) -> None: | FloatExpr(value=0.0) | NameExpr(fullname="builtins.False") ) as rhs, - ) if ( - (expected_type := mypy_type_to_python_type(get_mypy_type(rhs))) - and is_same_type(get_mypy_type(lhs), expected_type) + ) if (expected_type := mypy_type_to_python_type(get_mypy_type(rhs))) and is_same_type( + get_mypy_type(lhs), expected_type ): lhs_expr = stringify(lhs) diff --git a/refurb/checks/readability/use_comprehension.py b/refurb/checks/readability/use_comprehension.py index 95d5182..67a26ad 100644 --- a/refurb/checks/readability/use_comprehension.py +++ b/refurb/checks/readability/use_comprehension.py @@ -97,9 +97,8 @@ def check_stmts(stmts: list[Statement], errors: list[Error]) -> None: errors.append(ErrorInfo.from_node(assign)) case ForStmt(body=Block(body=[stmt])) if ( - (name := get_append_func_callee_name(stmt)) - and name.fullname == assign.fullname - ): + name := get_append_func_callee_name(stmt) + ) and name.fullname == assign.fullname: name_visitor = ReadCountVisitor(name) name_visitor.accept(stmt) diff --git a/refurb/checks/readability/use_func_name.py b/refurb/checks/readability/use_func_name.py index 3a4075f..89e61cc 100644 --- a/refurb/checks/readability/use_func_name.py +++ b/refurb/checks/readability/use_func_name.py @@ -83,9 +83,8 @@ def check(node: LambdaExpr, errors: list[Error]) -> None: body=Block( body=[ReturnStmt(expr=CallExpr(callee=RefExpr() as ref) as func)], ), - ) if ( - get_lambda_arg_names(lambda_args) == get_func_arg_names(func.args) - and all(kind == ArgKind.ARG_POS for kind in func.arg_kinds) + ) if get_lambda_arg_names(lambda_args) == get_func_arg_names(func.args) and all( + kind == ArgKind.ARG_POS for kind in func.arg_kinds ): func_name = stringify(ref)