diff --git a/shared/src/main/scala/scala/xml/transform/NestingTransformer.scala b/shared/src/main/scala/scala/xml/transform/NestingTransformer.scala new file mode 100644 index 000000000..321e195b1 --- /dev/null +++ b/shared/src/main/scala/scala/xml/transform/NestingTransformer.scala @@ -0,0 +1,19 @@ +/* __ *\ +** ________ ___ / / ___ Scala API ** +** / __/ __// _ | / / / _ | (c) 2002-2020, LAMP/EPFL ** +** __\ \/ /__/ __ |/ /__/ __ | (c) 2011-2020, Lightbend, Inc. ** +** /____/\___/_/ |_/____/_/ | | http://scala-lang.org/ ** +** |/ ** +\* */ + +package scala +package xml +package transform + +import scala.collection.Seq + +class NestingTransformer(rule: RewriteRule) extends BasicTransformer { + override def transform(n: Node): Seq[Node] = { + rule.transform(super.transform(n)) + } +} diff --git a/shared/src/main/scala/scala/xml/transform/RuleTransformer.scala b/shared/src/main/scala/scala/xml/transform/RuleTransformer.scala index d8c1a56d0..b7e62644a 100644 --- a/shared/src/main/scala/scala/xml/transform/RuleTransformer.scala +++ b/shared/src/main/scala/scala/xml/transform/RuleTransformer.scala @@ -13,6 +13,9 @@ package transform import scala.collection.Seq class RuleTransformer(rules: RewriteRule*) extends BasicTransformer { - override def transform(n: Node): Seq[Node] = - rules.foldLeft(super.transform(n)) { (res, rule) => rule transform res } + private val transformers = rules.map(new NestingTransformer(_)) + override def transform(n: Node): Seq[Node] = { + if (transformers.isEmpty) n + else transformers.tail.foldLeft(transformers.head.transform(n)) { (res, transformer) => transformer.transform(res) } + } } diff --git a/shared/src/test/scala-2.x/scala/xml/TransformersTest.scala b/shared/src/test/scala-2.x/scala/xml/TransformersTest.scala index 59ca25240..f8c5b661a 100644 --- a/shared/src/test/scala-2.x/scala/xml/TransformersTest.scala +++ b/shared/src/test/scala-2.x/scala/xml/TransformersTest.scala @@ -60,7 +60,7 @@ class TransformersTest { @Test def preserveReferentialComplexityInLinearComplexity = { // SI-4528 var i = 0 - + val xmlNode =

Hello Example

new RuleTransformer(new RewriteRule { @@ -77,4 +77,19 @@ class TransformersTest { assertEquals(1, i) } + + @Test + def appliesRulesRecursivelyOnPreviousChanges = { // #257 + def add(outer: Elem, inner: Node) = new RewriteRule { + override def transform(n: Node): Seq[Node] = n match { + case e: Elem if e.label == outer.label => e.copy(child = e.child ++ inner) + case other => other + } + } + + def transformer = new RuleTransformer(add(, ), add(, )) + + assertEquals(, transformer()) + } } +