diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py
new file mode 100644
index 0000000..780c01c
--- /dev/null
+++ b/refurb/checks/readability/fluid_interface.py
@@ -0,0 +1,113 @@
+from dataclasses import dataclass
+
+from mypy.nodes import (
+    Block,
+    Statement,
+    AssignmentStmt,
+    MypyFile,
+    CallExpr,
+    MemberExpr,
+    NameExpr,
+)
+
+from refurb.checks.common import check_block_like
+from refurb.error import Error
+
+
+@dataclass
+class ErrorInfo(Error):
+    """
+    https://towardsdatascience.com/the-unreasonable-effectiveness-of-method-chaining-in-pandas-15c2109e3c69
+    Sometimes a return statement can be written more succinctly:
+
+    Bad:
+
+    ```
+    def get_tensors(device: str) -> torch.Tensor:
+        a = torch.ones(2, 1)
+        a = a.long()
+        a = a.to(device)
+        return a
+
+
+    def process(file_name: str):
+        common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
+        df = spark.read.parquet(file_name)
+        df = df \
+            .withColumnRenamed('col1', 'col1_renamed') \
+            .withColumnRenamed('col2', 'col2_renamed')
+        df = df \
+            .select(common_columns) \
+            .withColumn('service_type', F.lit('green'))
+        return df
+    ```
+
+    Good:
+
+    ```
+    def get_tensors(device: str) -> torch.Tensor:
+        a = (
+            torch.ones(2, 1)
+            .long()
+            .to(device)
+        )
+        return a
+
+    def process(file_name: str):
+        common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
+        df = (
+            spark.read.parquet(file_name)
+            .withColumnRenamed('col1', 'col1_renamed')
+            .withColumnRenamed('col2', 'col2_renamed')
+            .select(common_columns)
+            .withColumn('service_type', F.lit('green'))
+        )
+        return df
+    ```
+    """
+
+    name = "fluid-interface"
+    code = 999
+    categories = ("readability",)
+
+
+def check(node: Block | MypyFile, errors: list[Error]) -> None:
+    check_block_like(check_stmts, node, errors)
+
+
+def check_call(node, name: str = "") -> bool:
+    """Return True if `node` is a method-call chain rooted at a bare name.
+
+    When `name` is non-empty, the root of the chain must be that exact name.
+    """
+    match node:
+        # Single chain: `root.method(...)`
+        case CallExpr(callee=MemberExpr(expr=NameExpr(name=root))):
+            return name == "" or root == name
+        # Nested chain: recurse toward the root of the chain
+        case CallExpr(callee=MemberExpr(expr=call_node)):
+            return check_call(call_node, name)
+
+    return False
+
+
+def check_stmts(stmts: list[Statement], errors: list[Error]) -> None:
+    last = ""
+
+    for stmt in stmts:
+        match stmt:
+            case AssignmentStmt(lvalues=[NameExpr(name=name)], rvalue=rvalue):
+                # Flag `x = ...` immediately followed by `x = x.method(...)`:
+                # the two statements can be merged into one fluent chain.
+                if last != "" and last == name and check_call(rvalue, name):
+                    errors.append(
+                        ErrorInfo.from_node(
+                            stmt,
+                            "Assignment statements should be chained",
+                        )
+                    )
+
+                last = name
+
+            case _:
+                last = ""
diff --git a/test/data/err_999.py b/test/data/err_999.py
new file mode 100644
index 0000000..b231351
--- /dev/null
+++ b/test/data/err_999.py
@@ -0,0 +1,70 @@
+class torch:
+    @staticmethod
+    def ones(*args):
+        return torch
+
+    @staticmethod
+    def long():
+        return torch
+
+    @staticmethod
+    def to(device: str):
+        return torch.Tensor()
+
+    class Tensor:
+        pass
+
+
+def transform(x):
+    return x
+
+
+class spark:
+    class read:
+        @staticmethod
+        def parquet(file_name: str):
+            return spark.DataFrame()
+
+    class functions:
+        @staticmethod
+        def lit(constant):
+            return constant
+
+    class DataFrame:
+        @staticmethod
+        def withColumnRenamed(col_in, col_out):
+            return spark.DataFrame()
+
+        @staticmethod
+        def withColumn(col_in, col_out):
+            return spark.DataFrame()
+
+        @staticmethod
+        def select(*args):
+            return spark.DataFrame()
+
+# these will match
+def get_tensors(device: str) -> torch.Tensor:
+    a = torch.ones(2, 1)
+    a = a.long()
+    a = a.to(device)
+    return a
+
+
+def process(file_name: str):
+    common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
+    df = spark.read.parquet(file_name)
+    df = df \
+        .withColumnRenamed('col1', 'col1_renamed') \
+        .withColumnRenamed('col2', 'col2_renamed')
+    df = df \
+        .select(common_columns) \
+        .withColumn('service_type', spark.functions.lit('green'))
+    return df
+
+
+# these will not
+def no_match():
+    y = 10
+    y = transform(y)
+    return y