diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py
new file mode 100644
index 0000000..338a0b5
--- /dev/null
+++ b/refurb/checks/readability/fluid_interface.py
@@ -0,0 +1,108 @@
+from dataclasses import dataclass
+
+from mypy.nodes import (
+    AssignmentStmt,
+    Block,
+    CallExpr,
+    Expression,
+    MemberExpr,
+    MypyFile,
+    NameExpr,
+    Statement,
+)
+
+from refurb.checks.common import check_block_like
+from refurb.error import Error
+
+
+@dataclass
+class ErrorInfo(Error):
+    r"""When an API has a Fluent Interface (the ability to chain multiple calls
+    together), you should chain those calls instead of repeatedly assigning and
+    using the value. Sometimes a return statement can be written more succinctly:
+
+    Bad:
+
+    ```python
+    def get_tensors(device: str) -> torch.Tensor:
+        a = torch.ones(2, 1)
+        a = a.long()
+        a = a.to(device)
+        return a
+
+
+    def process(file_name: str):
+        common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
+        df = spark.read.parquet(file_name)
+        df = df \
+            .withColumnRenamed('col1', 'col1_renamed') \
+            .withColumnRenamed('col2', 'col2_renamed')
+        df = df \
+            .select(common_columns) \
+            .withColumn('service_type', F.lit('green'))
+        return df
+    ```
+
+    Good:
+
+    ```python
+    def get_tensors(device: str) -> torch.Tensor:
+        a = (
+            torch.ones(2, 1)
+            .long()
+            .to(device)
+        )
+        return a
+
+    def process(file_name: str):
+        common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
+        df = (
+            spark.read.parquet(file_name)
+            .withColumnRenamed('col1', 'col1_renamed')
+            .withColumnRenamed('col2', 'col2_renamed')
+            .select(common_columns)
+            .withColumn('service_type', F.lit('green'))
+        )
+        return df
+    ```
+    """
+
+    name = "use-fluid-interface"
+    code = 184
+    categories = ("readability",)
+
+
+def check(node: Block | MypyFile, errors: list[Error]) -> None:
+    check_block_like(check_stmts, node, errors)
+
+
+def check_call(node: Expression) -> bool:
+    match node:
+        # Single chain: a method call whose receiver is a bare name
+        case CallExpr(callee=MemberExpr(expr=NameExpr())):
+            return True
+        # Nested chain: recurse into the expression the method is called on
+        case CallExpr(callee=MemberExpr(expr=call_node)):
+            return check_call(call_node)
+
+    return False
+
+
+def check_stmts(stmts: list[Statement], errors: list[Error]) -> None:
+    last = ""
+
+    for stmt in stmts:
+        match stmt:
+            case AssignmentStmt(lvalues=[NameExpr(name=name)], rvalue=rvalue):
+                if last and last == name and check_call(rvalue):
+                    errors.append(
+                        ErrorInfo.from_node(
+                            stmt,
+                            "Assignment statements should be chained",
+                        )
+                    )
+
+                last = name
+
+            case _:
+                last = ""
diff --git a/test/data/err_184.py b/test/data/err_184.py
new file mode 100644
index 0000000..825bc17
--- /dev/null
+++ b/test/data/err_184.py
@@ -0,0 +1,82 @@
+class torch:
+    @staticmethod
+    def ones(*args):
+        return torch
+
+    @staticmethod
+    def long():
+        return torch
+
+    @staticmethod
+    def to(device: str):
+        return torch.Tensor()
+
+    class Tensor:
+        pass
+
+
+def transform(x):
+    return x
+
+
+class spark:
+    class read:
+        @staticmethod
+        def parquet(file_name: str):
+            return spark.DataFrame()
+
+    class functions:
+        @staticmethod
+        def lit(constant):
+            return constant
+
+        @staticmethod
+        def col(col_name):
+            return col_name
+
+    class DataFrame:
+        @staticmethod
+        def withColumnRenamed(col_in, col_out):
+            return spark.DataFrame()
+
+        @staticmethod
+        def withColumn(col_in, col_out):
+            return spark.DataFrame()
+
+        @staticmethod
+        def select(*args):
+            return spark.DataFrame()
+
+# these will match
+def get_tensors(device: str) -> torch.Tensor:
+    a = torch.ones(2, 1)
+    a = a.long()
+    a = a.to(device)
+    return a
+
+
+def process(file_name: str):
+    common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
+    df = spark.read.parquet(file_name)
+    df = df \
+        .withColumnRenamed('col1', 'col1_renamed') \
+        .withColumnRenamed('col2', 'col2_renamed')
+    df = df \
+        .select(common_columns) \
+        .withColumn('service_type', spark.functions.lit('green'))
+    return df
+
+
+def projection(df_in: spark.DataFrame) -> spark.DataFrame:
+    df = (
+        df_in.select(["col1", "col2"])
+        .withColumnRenamed("col1", "col1a")
+    )
+    return df.withColumn("col2a", spark.functions.col("col2").cast("date"))
+
+
+# these will not
+def no_match():
+    y = 10
+    y = transform(y)
+    return y
diff --git a/test/data/err_184.txt b/test/data/err_184.txt
new file mode 100644
index 0000000..51e0489
--- /dev/null
+++ b/test/data/err_184.txt
@@ -0,0 +1,4 @@
+test/data/err_184.py:53:5 [FURB184]: Assignment statements should be chained
+test/data/err_184.py:54:5 [FURB184]: Assignment statements should be chained
+test/data/err_184.py:61:5 [FURB184]: Assignment statements should be chained
+test/data/err_184.py:64:5 [FURB184]: Assignment statements should be chained
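
A note on how `check_call` walks a chain: mypy hands the check the outermost
call first (`a.long().to(device)` is a `CallExpr` whose callee's `expr` is the
inner `CallExpr`), so the function recurses through `MemberExpr.expr` until it
bottoms out at a bare `NameExpr`. The sketch below reproduces the same
recursion with the stdlib `ast` module; `is_chained_call` is a hypothetical
stand-in for illustration only, not the code from this diff, which operates on
mypy's typed AST.

```python
# Standalone analogue of check_call, built on the stdlib ast module rather
# than mypy.nodes (illustration only; all names here are hypothetical).
import ast


def is_chained_call(node: ast.AST) -> bool:
    """Return True for method calls like `a.long()` or `a.long().to(device)`."""
    match node:
        # Single chain: the attribute access bottoms out at a bare name
        case ast.Call(func=ast.Attribute(value=ast.Name())):
            return True
        # Nested chain: recurse into whatever the method was called on
        case ast.Call(func=ast.Attribute(value=inner)):
            return is_chained_call(inner)
    return False


rvalue = ast.parse("a.long().to(device)", mode="eval").body
assert is_chained_call(rvalue)  # chained method call
assert not is_chained_call(ast.parse("transform(y)", mode="eval").body)
```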
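
The statement scan in `check_stmts` follows the same shape: `last` remembers
the target of the previous assignment, any other kind of statement resets it,
and a repeated assignment whose right-hand side is a chained call gets
flagged. Sketched with the same assumed ast-based helper as above:

```python
def find_chainable(stmts: list[ast.stmt]) -> list[int]:
    """Return line numbers of assignments that could be chained (sketch)."""
    flagged: list[int] = []
    last = ""
    for stmt in stmts:
        match stmt:
            case ast.Assign(targets=[ast.Name(id=name)], value=value):
                # Same name assigned twice in a row, second time via a chain
                if last and last == name and is_chained_call(value):
                    flagged.append(stmt.lineno)
                last = name
            case _:
                last = ""  # any non-assignment resets the run
    return flagged


# The snippet is only parsed, never executed, so undefined names are fine.
src = "a = torch.ones(2, 1)\na = a.long()\na = a.to(device)\n"
print(find_chainable(ast.parse(src).body))  # -> [2, 3]
```

Those two hits correspond to the FURB184 reports at test/data/err_184.py:53
and :54 in the expected output above.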