Skip to content

Commit

Permalink
fix PartialFunctionCondOpt
Browse files Browse the repository at this point in the history
fix #113
  • Loading branch information
xuwei-k committed Jan 13, 2024
1 parent 63dadcb commit 05419b0
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 12 deletions.
14 changes: 14 additions & 0 deletions input/src/main/scala/fix/PartialFunctionCondOptTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,18 @@ object PartialFunctionCondOptTest {
case "3" => Some(12340)
case _ => None
}

def x2(a: String, b: Int): Option[Int] = a match {
case "1" => Some(3)
case "2" =>
Some(
(b match {
case 1 => Some(100)
case 2 => Some(200)
case _ => None
}).getOrElse(-1)
)
case _ =>
None
}
}
13 changes: 13 additions & 0 deletions output/src/main/scala/fix/PartialFunctionCondOptTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,17 @@ object PartialFunctionCondOptTest {
case "2" => 10
case "3" => 12340
}

def x2(a: String, b: Int): Option[Int] = a match {
case "1" => Some(3)
case "2" =>
Some(
(PartialFunction.condOpt(b) {
case 1 => 100
case 2 => 200
}).getOrElse(-1)
)
case _ =>
None
}
}
34 changes: 22 additions & 12 deletions rules/src/main/scala/fix/PartialFunctionCondOpt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,39 @@ import scalafix.v1.SyntacticRule
import scala.meta.Case
import scala.meta.Pat
import scala.meta.Term
import scala.meta.Tree

class PartialFunctionCondOpt extends SyntacticRule("PartialFunctionCondOpt") {
override def fix(implicit doc: SyntacticDocument): Patch = {
doc.tree.collect {
doc.tree.collect { case x => fix0(x).asPatch }.asPatch
}

private def fix0(tree: Tree)(implicit doc: SyntacticDocument): Option[Patch] = PartialFunction
.condOpt(tree) {
case t @ Term.Match.After_4_4_5(expr, init :+ last, _) if init.nonEmpty =>
last match {
case Case(Pat.Wildcard(), None, Term.Name("None")) =>
val values = init.collect { case a @ Case(_, _, Term.Apply.Initial(Term.Name("Some"), x :: Nil)) =>
a.copy(body = x)
val values = init.collect {
case a @ Case(_, _, Term.Apply.Initial(Term.Name("Some"), x :: Nil)) if x.collect { case a =>
fix0(a)
}.flatten.isEmpty =>
a.copy(body = x)
}
if (values.lengthCompare(init.size) == 0) {
Patch.replaceTree(
t,
s"""PartialFunction.condOpt($expr) {
| ${values.mkString("\n ")}
|}""".stripMargin
Some(
Patch.replaceTree(
t,
s"""PartialFunction.condOpt($expr) {
| ${values.mkString("\n ")}
|}""".stripMargin
)
)
} else {
Patch.empty
None
}
case _ =>
Patch.empty
None
}
}.asPatch
}
}
.flatten
}

0 comments on commit 05419b0

Please sign in to comment.