diff --git a/shared/src/main/scala/scala/xml/transform/NestingTransformer.scala b/shared/src/main/scala/scala/xml/transform/NestingTransformer.scala
new file mode 100644
index 000000000..321e195b1
--- /dev/null
+++ b/shared/src/main/scala/scala/xml/transform/NestingTransformer.scala
@@ -0,0 +1,19 @@
+/* __ *\
+** ________ ___ / / ___ Scala API **
+** / __/ __// _ | / / / _ | (c) 2002-2020, LAMP/EPFL **
+** __\ \/ /__/ __ |/ /__/ __ | (c) 2011-2020, Lightbend, Inc. **
+** /____/\___/_/ |_/____/_/ | | http://scala-lang.org/ **
+** |/ **
+\* */
+
+package scala
+package xml
+package transform
+
+import scala.collection.Seq
+
+class NestingTransformer(rule: RewriteRule) extends BasicTransformer {
+ override def transform(n: Node): Seq[Node] = {
+ rule.transform(super.transform(n))
+ }
+}
diff --git a/shared/src/main/scala/scala/xml/transform/RuleTransformer.scala b/shared/src/main/scala/scala/xml/transform/RuleTransformer.scala
index d8c1a56d0..b7e62644a 100644
--- a/shared/src/main/scala/scala/xml/transform/RuleTransformer.scala
+++ b/shared/src/main/scala/scala/xml/transform/RuleTransformer.scala
@@ -13,6 +13,9 @@ package transform
import scala.collection.Seq
class RuleTransformer(rules: RewriteRule*) extends BasicTransformer {
- override def transform(n: Node): Seq[Node] =
- rules.foldLeft(super.transform(n)) { (res, rule) => rule transform res }
+ private val transformers = rules.map(new NestingTransformer(_))
+ override def transform(n: Node): Seq[Node] = {
+ if (transformers.isEmpty) n
+ else transformers.tail.foldLeft(transformers.head.transform(n)) { (res, transformer) => transformer.transform(res) }
+ }
}
diff --git a/shared/src/test/scala-2.x/scala/xml/TransformersTest.scala b/shared/src/test/scala-2.x/scala/xml/TransformersTest.scala
index 59ca25240..f8c5b661a 100644
--- a/shared/src/test/scala-2.x/scala/xml/TransformersTest.scala
+++ b/shared/src/test/scala-2.x/scala/xml/TransformersTest.scala
@@ -60,7 +60,7 @@ class TransformersTest {
@Test
def preserveReferentialComplexityInLinearComplexity = { // SI-4528
var i = 0
-
+
val xmlNode = Hello Example
new RuleTransformer(new RewriteRule {
@@ -77,4 +77,19 @@ class TransformersTest {
assertEquals(1, i)
}
+
+ @Test
+ def appliesRulesRecursivelyOnPreviousChanges = { // #257
+ def add(outer: Elem, inner: Node) = new RewriteRule {
+ override def transform(n: Node): Seq[Node] = n match {
+ case e: Elem if e.label == outer.label => e.copy(child = e.child ++ inner)
+ case other => other
+ }
+ }
+
+ def transformer = new RuleTransformer(add(, ), add(, ))
+
+ assertEquals(, transformer())
+ }
}
+