diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py
index 338a0b5..ac101b3 100644
--- a/refurb/checks/readability/fluid_interface.py
+++ b/refurb/checks/readability/fluid_interface.py
@@ -8,10 +8,12 @@
     CallExpr,
     MemberExpr,
     NameExpr,
+    ReturnStmt,
 )
 
 from refurb.checks.common import check_block_like
 from refurb.error import Error
+from refurb.visitor import TraverserVisitor
 
 
 @dataclass
@@ -24,11 +26,10 @@ class ErrorInfo(Error):
 
     ```python
     def get_tensors(device: str) -> torch.Tensor:
-        a = torch.ones(2, 1)
-        a = a.long()
-        a = a.to(device)
-        return a
-
+        t1 = torch.ones(2, 1)
+        t2 = t1.long()
+        t3 = t2.to(device)
+        return t3
 
     def process(file_name: str):
         common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
@@ -46,12 +47,12 @@ def process(file_name: str):
 
     ```python
     def get_tensors(device: str) -> torch.Tensor:
-        a = (
+        t3 = (
             torch.ones(2, 1)
             .long()
             .to(device)
         )
-        return a
+        return t3
 
     def process(file_name: str):
         common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
@@ -75,33 +76,79 @@ def check(node: Block | MypyFile, errors: list[Error]) -> None:
     check_block_like(check_stmts, node, errors)
 
 
-def check_call(node) -> bool:
+def check_call(node, name: str | None = None) -> bool:
     match node:
         # Single chain
-        case CallExpr(callee=MemberExpr(expr=NameExpr(name=x), name=y)):
-            return True
+        case CallExpr(callee=MemberExpr(expr=NameExpr(name=x), name=_)):
+            return name is None or name == x
 
         # Nested
-        case CallExpr(callee=MemberExpr(expr=call_node, name=y)):
-            return check_call(call_node)
+        case CallExpr(callee=MemberExpr(expr=call_node, name=_)):
+            return check_call(call_node, name)
 
     return False
 
 
+class NameReferenceVisitor(TraverserVisitor):
+    name: NameExpr
+    referenced: bool
+
+    def __init__(self, name: NameExpr, stmt: Statement) -> None:
+        super().__init__()
+        self.name = name
+        self.stmt = stmt
+        self.referenced = False
+
+    def visit_name_expr(self, node: NameExpr) -> None:
+        if not self.referenced and node.fullname == self.name.fullname:
+            self.referenced = True
+
+    @property
+    def was_referenced(self) -> bool:
+        return self.referenced
+
+
 def check_stmts(stmts: list[Statement], errors: list[Error]) -> None:
     last = ""
+    visitors: list[NameReferenceVisitor] = []
 
     for stmt in stmts:
+        for visitor in visitors:
+            visitor.accept(stmt)
+
         match stmt:
             case AssignmentStmt(lvalues=[NameExpr(name=name)], rvalue=rvalue):
-                if last and f"{last}'" == name and check_call(rvalue):
+                if last and check_call(rvalue, name=last):
+                    if f"{last}'" == name:
+                        errors.append(
+                            ErrorInfo.from_node(
+                                stmt,
+                                f"Assignment statement should be chained",
+                            )
+                        )
+                    else:
+                        # Defer the error until we know the variable is not referenced elsewhere
+                        name_expr = NameExpr(name=last)
+                        name_expr.fullname = last
+                        visitors.append(NameReferenceVisitor(name_expr, stmt))
+
+                last = name
+            case ReturnStmt(expr=rvalue):
+                if last and check_call(rvalue, name=last):
                     errors.append(
                         ErrorInfo.from_node(
                             stmt,
-                            f"Assignment statements should be chained",
+                            f"Return statement should be chained",
                         )
                     )
-
-                last = name
-
             case _:
                 last = ""
+
+    # Report deferred assignments whose variable was never referenced again
+    for visitor in visitors:
+        if not visitor.referenced:
+            errors.append(
+                ErrorInfo.from_node(
+                    visitor.stmt,
+                    f"Assignment statement should be chained",
+                )
+            )
diff --git a/test/data/err_184.py b/test/data/err_184.py
index 825bc17..43034cd 100644
--- a/test/data/err_184.py
+++ b/test/data/err_184.py
@@ -47,6 +47,12 @@ def withColumn(col_in, col_out):
 def select(*args):
     return spark.DataFrame()
 
+class F:
+    @staticmethod
+    def lit(value):
+        return value
+
+
 # these will match
 def get_tensors(device: str) -> torch.Tensor:
     a = torch.ones(2, 1)
@@ -75,7 +81,36 @@ def projection(df_in: spark.DataFrame) -> spark.DataFrame:
     return df.withColumn("col2a", spark.functions.col("col2").cast("date"))
 
 
+def assign_multiple(df):
+    df = df.select("column")
+    result_df = df.select("another_column")
+    final_df = result_df.withColumn("column2", F.lit("abc"))
+    return final_df
+
+
+# not yet supported
+def assign_alternating(df, df2):
+    df = df.select("column")
+    df2 = df2.select("another_column")
+    df = df.withColumn("column2", F.lit("abc"))
+    return df, df2
+
+
 # these will not
+def assign_multiple_referenced(df, df2):
+    df = df.select("column")
+    result_df = df.select("another_column")
+    return df, result_df
+
+
+def invalid(df_in: spark.DataFrame, alternative_df: spark.DataFrame) -> spark.DataFrame:
+    df = (
+        df_in.select(["col1", "col2"])
+        .withColumnRenamed("col1", "col1a")
+    )
+    return alternative_df.withColumn("col2a", spark.functions.col("col2").cast("date"))
+
+
 def no_match():
     y = 10
     y = transform(y)
diff --git a/test/data/err_184.txt b/test/data/err_184.txt
deleted file mode 100644
index 51e0489..0000000
--- a/test/data/err_184.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-test/data/err_184.py:53:5 [FURB184]: Assignment statements should be chained
-test/data/err_184.py:54:5 [FURB184]: Assignment statements should be chained
-test/data/err_184.py:61:5 [FURB184]: Assignment statements should be chained
-test/data/err_184.py:64:5 [FURB184]: Assignment statements should be chained
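For reference, a rough sketch of how the updated FURB184 check is expected to treat the two new kinds of test cases. `Frame`, `should_be_flagged`, and `should_not_be_flagged` are hypothetical stand-ins (not part of refurb, pyspark, or the test suite); the method names simply mirror the fluent style used in test/data/err_184.py:

```python
class Frame:
    """Hypothetical stand-in for a fluent API such as a Spark DataFrame."""

    def select(self, *args):
        return self

    def withColumn(self, *args):
        return self


def should_be_flagged(df: Frame) -> Frame:
    # `tmp` exists only to feed the next call and is never read again, so the
    # updated check should report the second assignment as chainable.
    tmp = df.select("column")
    result = tmp.withColumn("column2", "abc")
    return result


def should_not_be_flagged(df: Frame) -> tuple[Frame, Frame]:
    # `tmp` is referenced again in the return value, so collapsing the two
    # calls into one chain would change behavior; the check should stay silent.
    tmp = df.select("column")
    result = tmp.withColumn("column2", "abc")
    return tmp, result
```

The deferred `NameReferenceVisitor` is what separates the two cases: an assignment to a new name is only reported once the block ends without any further reference to the previous variable.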