From 8d7b6ad21de703f660bd5b8ca8b60ab76bce0ee6 Mon Sep 17 00:00:00 2001 From: Diego Casella Date: Tue, 23 Apr 2024 00:19:36 +0200 Subject: [PATCH 1/9] feat: add support for Scala 2 #14 --- build.sbt | 12 +++++++++--- plugin/src/main/resources/scalac-plugin.xml | 4 ++++ .../polentino/redacted/RedactedPlugin.scala | 11 +++++++++++ .../redacted/RedactedPluginComponent.scala | 15 +++++++++++++++ .../polentino/redacted/RedactedPlugin.scala | 2 +- .../polentino/redacted/helpers/AstOps.scala | 0 .../polentino/redacted/helpers/PluginOps.scala | 0 .../polentino/redacted/phases/PatchToString.scala | 0 .../github/polentino/redacted/RedactedSpec.scala | 2 +- 9 files changed, 41 insertions(+), 5 deletions(-) create mode 100644 plugin/src/main/resources/scalac-plugin.xml create mode 100644 plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPlugin.scala create mode 100644 plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala rename plugin/src/main/{scala => scala-3}/io/github/polentino/redacted/RedactedPlugin.scala (88%) rename plugin/src/main/{scala => scala-3}/io/github/polentino/redacted/helpers/AstOps.scala (100%) rename plugin/src/main/{scala => scala-3}/io/github/polentino/redacted/helpers/PluginOps.scala (100%) rename plugin/src/main/{scala => scala-3}/io/github/polentino/redacted/phases/PatchToString.scala (100%) diff --git a/build.sbt b/build.sbt index 2a0f1f8..d30f5c7 100644 --- a/build.sbt +++ b/build.sbt @@ -6,6 +6,7 @@ val scalaCheckVersion = "3.2.17.0" // all LTS versions & latest minor ones val supportedScalaVersions = List( + "2.13.13", "3.1.3", "3.2.2", "3.3.0", @@ -64,7 +65,11 @@ lazy val redactedCompilerPlugin = (project in file("plugin")) .settings(name := "redacted-plugin") .settings( crossCompileSettings, - libraryDependencies += "org.scala-lang" %% "scala3-compiler" % scalaVersion.value + libraryDependencies += (CrossVersion.partialVersion(scalaVersion.value) match { + case Some((3, _)) => "org.scala-lang" %% "scala3-compiler" % scalaVersion.value + case Some((2, _)) => "org.scala-lang" % "scala-compiler" % scalaVersion.value + case v => throw new Exception(s"Scala version $v not recognised") + }) ) lazy val redactedTests = (project in file("tests")) @@ -80,9 +85,10 @@ lazy val redactedTests = (project in file("tests")) ), Test / scalacOptions ++= { val jar = (redactedCompilerPlugin / Compile / packageBin).value - val addPlugin = "-Xplugin:" + jar.getAbsolutePath + val addScala2Plugin = "-Xplugin-require:redacted-plugin" + val addScala3Plugin = "-Xplugin:" + jar.getAbsolutePath val dummy = "-Jdummy=" + jar.lastModified - Seq(addPlugin, dummy) + Seq(addScala2Plugin, addScala3Plugin, dummy) } ) diff --git a/plugin/src/main/resources/scalac-plugin.xml b/plugin/src/main/resources/scalac-plugin.xml new file mode 100644 index 0000000..4af80e0 --- /dev/null +++ b/plugin/src/main/resources/scalac-plugin.xml @@ -0,0 +1,4 @@ + + redacted-plugin + io.github.polentino.redacted.RedactedPlugin + \ No newline at end of file diff --git a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPlugin.scala b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPlugin.scala new file mode 100644 index 0000000..94e4d5a --- /dev/null +++ b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPlugin.scala @@ -0,0 +1,11 @@ +package io.github.polentino.redacted + +import scala.tools.nsc.Global +import scala.tools.nsc.plugins.Plugin +import scala.tools.nsc.plugins.PluginComponent + +final class RedactedPlugin(override val global: Global) extends Plugin { + override val name: String = "redacted-plugin" + override val description: String = "Plugin to prevent leaking sensitive data when logging case classes" + override val components: List[PluginComponent] = List(new RedactedPluginComponent(global)) +} diff --git a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala new file mode 100644 index 0000000..26fd8f3 --- /dev/null +++ b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala @@ -0,0 +1,15 @@ +package io.github.polentino.redacted + +import scala.tools.nsc._ +import scala.tools.nsc.plugins.PluginComponent +import scala.tools.nsc.transform.TypingTransformers + +final class RedactedPluginComponent(val global: Global) extends PluginComponent with TypingTransformers { + import global._ + override val phaseName: String = "redacted-plugin-component" + override val runsAfter: List[String] = List("parser") + + override def newPhase(prev: Phase): Phase = new StdPhase(prev) { + override def apply(unit: CompilationUnit): Unit = () + } +} diff --git a/plugin/src/main/scala/io/github/polentino/redacted/RedactedPlugin.scala b/plugin/src/main/scala-3/io/github/polentino/redacted/RedactedPlugin.scala similarity index 88% rename from plugin/src/main/scala/io/github/polentino/redacted/RedactedPlugin.scala rename to plugin/src/main/scala-3/io/github/polentino/redacted/RedactedPlugin.scala index b30240e..4bd9970 100644 --- a/plugin/src/main/scala/io/github/polentino/redacted/RedactedPlugin.scala +++ b/plugin/src/main/scala-3/io/github/polentino/redacted/RedactedPlugin.scala @@ -7,7 +7,7 @@ import io.github.polentino.redacted.phases._ class RedactedPlugin extends StandardPlugin { override def init(options: List[String]): List[PluginPhase] = List(PatchToString()) - override def name: String = "Redacted" + override def name: String = "redacted-plugin" override def description: String = "Plugin to prevent leaking sensitive data when logging case classes" } diff --git a/plugin/src/main/scala/io/github/polentino/redacted/helpers/AstOps.scala b/plugin/src/main/scala-3/io/github/polentino/redacted/helpers/AstOps.scala similarity index 100% rename from plugin/src/main/scala/io/github/polentino/redacted/helpers/AstOps.scala rename to plugin/src/main/scala-3/io/github/polentino/redacted/helpers/AstOps.scala diff --git a/plugin/src/main/scala/io/github/polentino/redacted/helpers/PluginOps.scala b/plugin/src/main/scala-3/io/github/polentino/redacted/helpers/PluginOps.scala similarity index 100% rename from plugin/src/main/scala/io/github/polentino/redacted/helpers/PluginOps.scala rename to plugin/src/main/scala-3/io/github/polentino/redacted/helpers/PluginOps.scala diff --git a/plugin/src/main/scala/io/github/polentino/redacted/phases/PatchToString.scala b/plugin/src/main/scala-3/io/github/polentino/redacted/phases/PatchToString.scala similarity index 100% rename from plugin/src/main/scala/io/github/polentino/redacted/phases/PatchToString.scala rename to plugin/src/main/scala-3/io/github/polentino/redacted/phases/PatchToString.scala diff --git a/tests/src/test/scala/io/github/polentino/redacted/RedactedSpec.scala b/tests/src/test/scala/io/github/polentino/redacted/RedactedSpec.scala index 19bf6b5..944e732 100644 --- a/tests/src/test/scala/io/github/polentino/redacted/RedactedSpec.scala +++ b/tests/src/test/scala/io/github/polentino/redacted/RedactedSpec.scala @@ -1,6 +1,6 @@ package io.github.polentino.redacted -import org.scalatest.Checkpoints.* +import org.scalatest.Checkpoints._ import org.scalatest.flatspec.AnyFlatSpec import io.github.polentino.redacted.RedactionWithNestedCaseClass.Inner import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks From acf5d879eb5b1e092ef0878256f1af0edc130302 Mon Sep 17 00:00:00 2001 From: Diego Casella Date: Tue, 23 Apr 2024 00:21:09 +0200 Subject: [PATCH 2/9] fix: nopenopenope, no braceless syntax :/ --- .../redacted/phases/PatchToString.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/plugin/src/main/scala-3/io/github/polentino/redacted/phases/PatchToString.scala b/plugin/src/main/scala-3/io/github/polentino/redacted/phases/PatchToString.scala index ff11f25..0725f98 100644 --- a/plugin/src/main/scala-3/io/github/polentino/redacted/phases/PatchToString.scala +++ b/plugin/src/main/scala-3/io/github/polentino/redacted/phases/PatchToString.scala @@ -22,7 +22,6 @@ final case class PatchToString() extends PluginPhase { case None => tree case Some(validatedTree) => val maybeNewTypeDef = for { - template <- getTreeTemplate(validatedTree) .withLog(s"can't extract proper `tpd.Template` from ${tree.name}") @@ -36,20 +35,21 @@ final case class PatchToString() extends PluginPhase { .withLog(s"couldn't patch ${tree.name} template into ${tree.name} typedef") } yield result - maybeNewTypeDef match + maybeNewTypeDef match { case Some(newTypeDef) => newTypeDef case None => report.warning( s""" - |Dang, couldn't patch properly ${tree.name} :( - |If you believe this is an error: please report the issue, along with a minimum reproducible example, - |at the following link: https://github.com/polentino/redacted/issues/new . - | - |Thank you 🙏 - |""".stripMargin, + |Dang, couldn't patch properly ${tree.name} :( + |If you believe this is an error: please report the issue, along with a minimum reproducible example, + |at the following link: https://github.com/polentino/redacted/issues/new . + | + |Thank you 🙏 + |""".stripMargin, tree.srcPos ) tree + } } } From ded70c650fb4a7e8b742e78a88f698a9c37d2ff2 Mon Sep 17 00:00:00 2001 From: Diego Casella Date: Fri, 26 Apr 2024 22:27:19 +0200 Subject: [PATCH 3/9] fix: wip, traverse & log redacted fields --- .../redacted/RedactedPluginComponent.scala | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala index 26fd8f3..97d7472 100644 --- a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala +++ b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala @@ -8,8 +8,35 @@ final class RedactedPluginComponent(val global: Global) extends PluginComponent import global._ override val phaseName: String = "redacted-plugin-component" override val runsAfter: List[String] = List("parser") + override val runsRightAfter: Option[String] = Some("parser") override def newPhase(prev: Phase): Phase = new StdPhase(prev) { - override def apply(unit: CompilationUnit): Unit = () + + override def apply(unit: CompilationUnit): Unit = { + global.reporter.echo(s"[!] Inspecting Compilation Unit '${unit.toString()}'") + new Traverser { + override def traverse(tree: global.Tree): Unit = { + val t = tree match { + case cd: ClassDef if cd.mods.hasFlag(Flag.CASE) => + global.reporter.echo(s"\t -> inspecting case class '${cd.name}'") + val annotations = cd.impl.body.flatMap { + case dd: DefDef if dd.name == termNames.CONSTRUCTOR => + dd.vparamss.flatMap { p => + p.filter { pp => + pp.mods.hasAnnotationNamed(TypeName("redacted")) + } + }.map(_.name) + case _ => Nil + } + if (annotations.nonEmpty) { + global.reporter.echo(s"\t\t -> annotations: $annotations") + } + cd + case c => c + } + super.traverse(t) + } + }.traverse(unit.body) + } } } From 9755dedf6802acb8d209fb0044d76c6c5e6b7e0e Mon Sep 17 00:00:00 2001 From: Diego Casella Date: Tue, 30 Apr 2024 18:06:05 +0200 Subject: [PATCH 4/9] fix: y u no patch :/ --- .../polentino/redacted/RedactedPlugin.scala | 2 +- .../redacted/RedactedPluginComponent.scala | 142 +++++++++++++++--- 2 files changed, 122 insertions(+), 22 deletions(-) diff --git a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPlugin.scala b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPlugin.scala index 94e4d5a..bc5fdb0 100644 --- a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPlugin.scala +++ b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPlugin.scala @@ -1,6 +1,6 @@ package io.github.polentino.redacted -import scala.tools.nsc.Global +import scala.tools.nsc._ import scala.tools.nsc.plugins.Plugin import scala.tools.nsc.plugins.PluginComponent diff --git a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala index 97d7472..951a11f 100644 --- a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala +++ b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala @@ -4,11 +4,11 @@ import scala.tools.nsc._ import scala.tools.nsc.plugins.PluginComponent import scala.tools.nsc.transform.TypingTransformers -final class RedactedPluginComponent(val global: Global) extends PluginComponent with TypingTransformers { +final class RedactedPluginComponent(val global: Global) extends PluginComponent { import global._ override val phaseName: String = "redacted-plugin-component" - override val runsAfter: List[String] = List("parser") - override val runsRightAfter: Option[String] = Some("parser") + override val runsAfter: List[String] = List("pickler") +// override val runsRightAfter: Option[String] = Some("typer") override def newPhase(prev: Phase): Phase = new StdPhase(prev) { @@ -16,27 +16,127 @@ final class RedactedPluginComponent(val global: Global) extends PluginComponent global.reporter.echo(s"[!] Inspecting Compilation Unit '${unit.toString()}'") new Traverser { override def traverse(tree: global.Tree): Unit = { - val t = tree match { - case cd: ClassDef if cd.mods.hasFlag(Flag.CASE) => - global.reporter.echo(s"\t -> inspecting case class '${cd.name}'") - val annotations = cd.impl.body.flatMap { - case dd: DefDef if dd.name == termNames.CONSTRUCTOR => - dd.vparamss.flatMap { p => - p.filter { pp => - pp.mods.hasAnnotationNamed(TypeName("redacted")) - } - }.map(_.name) - case _ => Nil - } - if (annotations.nonEmpty) { - global.reporter.echo(s"\t\t -> annotations: $annotations") - } - cd - case c => c + // Logic that decides whether to transform the tree or not + if (shouldTransform(tree)) { + global.reporter.echo(s"\t PATCHING ${tree.symbol.name}") + val p = tree match { + case cd: ClassDef => + val newImpl = transformTemplate(cd.impl) + cd.copy(impl = newImpl) + case c => c + } + super.traverse(p) + } else { + super.traverse(tree) } - super.traverse(t) + + // val t = tree match { +// case cd: ClassDef if cd.mods.hasFlag(Flag.CASE) => +// global.reporter.echo(s"\t -> inspecting case class '${cd.name}'") +// val annotations = cd.impl.body.flatMap { +// case dd: DefDef if dd.name == termNames.CONSTRUCTOR => +// dd.vparamss.flatMap { p => +// p.filter { pp => +// pp.mods.hasAnnotationNamed(TypeName("redacted")) +// } +// }.map(_.name) +// case _ => Nil +// } +// if (annotations.nonEmpty) { +// global.reporter.echo(s"\t\t -> annotations: $annotations") +// val modifiedBody = cd.impl.body.map { +// case dd: DefDef if dd.name == TermName("toString") => dd.copy(rhs = Literal(Constant("TEST"))) +// case other => other +// } +// val t = cd.copy(impl = cd.impl.copy(body = modifiedBody)) +// new Transformer {}.transform(t) +// } else { +// cd +// } +// case c => c +// } +// super.traverse(t) } }.traverse(unit.body) } } + + private def shouldTransform(tree: global.Tree) = tree match { + case cd: ClassDef if cd.mods.hasFlag(Flag.CASE) => + cd.impl.body.exists { + case dd: DefDef => hasAnnotation(dd, "redacted") + case _ => false + } + case _ => false + } + + private def getAnnotations(tree: global.Tree) = { + tree match { + case cd: ClassDef if cd.mods.hasFlag(Flag.CASE) => +// val annotations = cd.impl.body.flatMap { +// case dd: DefDef if dd.name == termNames.CONSTRUCTOR => +// dd.vparamss.flatMap { p => +// p.filter { pp => +// pp.mods.hasAnnotationNamed(TypeName("redacted")) +// } +// }.map(_.name) +// case _ => Nil +// } +// annotations.nonEmpty + + case _ => Nil + } + } + + private def transformTemplate(template: Template): Template = { + val newBody = template.body.map { + case dd: DefDef if dd.name == TermName("toString") => + global.reporter.echo("\t\t WE HAVE A TOSTRING!") +// if (hasAnnotation(dd, "redacted")) { + val newRhs = Literal(Constant("TEST")) + dd.copy(rhs = newRhs) +// } else { +// dd +// } + case other => other + } + template.copy(body = newBody) + } + + private def hasAnnotation(dd: DefDef, annotationName: String): Boolean = { + dd.vparamss.flatten.exists(_.symbol.annotations.exists(_.tree.tpe.typeSymbol.name.toString == annotationName)) + } + + private class MyTransformer(val global: Global) extends Transformer { + + override def transform(tree: Tree): Tree = { + tree match { + case cd: ClassDef if cd.mods.hasFlag(Flag.CASE) => + global.reporter.echo(s"\t -> inspecting case class '${cd.name}'") + val annotations = cd.impl.body.flatMap { + case dd: DefDef if dd.name == termNames.CONSTRUCTOR => + dd.vparamss.flatMap { p => + p.filter { pp => + pp.mods.hasAnnotationNamed(TypeName("redacted")) + } + }.map(_.name) + case _ => Nil + } + + global.reporter.echo(s"\t\t -> annotations: $annotations") + global.reporter.echo(s"\t\t -> body: $cd.impl.body") + val modifiedBody = cd.impl.body.map { + case dd: DefDef if dd.name == TermName("toString") => + global.reporter.echo(s"\t -> PATCHED!") + dd.copy(rhs = Literal(Constant("TEST"))) + case other => + global.reporter.echo(s"\t -> SKIPPING ${other.symbol}") + other + } + cd.copy(impl = cd.impl.copy(body = modifiedBody)) + + case _ => tree + } + } + } } From 0c57ba74f892cb25a795f5010bfcb4207b549f27 Mon Sep 17 00:00:00 2001 From: Diego Casella Date: Sun, 12 May 2024 18:15:44 +0200 Subject: [PATCH 5/9] fix: it works! --- .../polentino/redacted/RedactedPlugin.scala | 3 + .../redacted/RedactedPluginComponent.scala | 281 ++++++++++-------- .../redacted/phases/PatchToString.scala | 2 +- .../polentino/redacted/RedactedSpec.scala | 23 ++ 4 files changed, 189 insertions(+), 120 deletions(-) diff --git a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPlugin.scala b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPlugin.scala index bc5fdb0..0328b73 100644 --- a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPlugin.scala +++ b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPlugin.scala @@ -5,7 +5,10 @@ import scala.tools.nsc.plugins.Plugin import scala.tools.nsc.plugins.PluginComponent final class RedactedPlugin(override val global: Global) extends Plugin { + override val name: String = "redacted-plugin" + override val description: String = "Plugin to prevent leaking sensitive data when logging case classes" + override val components: List[PluginComponent] = List(new RedactedPluginComponent(global)) } diff --git a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala index 951a11f..bcaa433 100644 --- a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala +++ b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala @@ -1,141 +1,184 @@ package io.github.polentino.redacted -import scala.tools.nsc._ +import scala.tools.nsc.backend.jvm.GenBCode import scala.tools.nsc.plugins.PluginComponent -import scala.tools.nsc.transform.TypingTransformers +import scala.tools.nsc.transform.Transform +import scala.util.Success +import scala.tools.nsc.Global + +class RedactedPluginComponent(val global: Global) extends PluginComponent with Transform { + + override val phaseName: String = "patch-tostring-component" + + override val runsAfter: List[String] = List("parser") + + override val runsRightAfter: Option[String] = Some("parser") -final class RedactedPluginComponent(val global: Global) extends PluginComponent { import global._ - override val phaseName: String = "redacted-plugin-component" - override val runsAfter: List[String] = List("pickler") -// override val runsRightAfter: Option[String] = Some("typer") - - override def newPhase(prev: Phase): Phase = new StdPhase(prev) { - - override def apply(unit: CompilationUnit): Unit = { - global.reporter.echo(s"[!] Inspecting Compilation Unit '${unit.toString()}'") - new Traverser { - override def traverse(tree: global.Tree): Unit = { - // Logic that decides whether to transform the tree or not - if (shouldTransform(tree)) { - global.reporter.echo(s"\t PATCHING ${tree.symbol.name}") - val p = tree match { - case cd: ClassDef => - val newImpl = transformTemplate(cd.impl) - cd.copy(impl = newImpl) - case c => c - } - super.traverse(p) - } else { - super.traverse(tree) + + override protected def newTransformer(unit: CompilationUnit): Transformer = ToStringMaskerTransformer + + private object ToStringMaskerTransformer extends Transformer { + + private val TO_STRING_NAME = "toString" + private val redactedTypeName = TypeName("redacted") + + override def transform(tree: Tree): Tree = { + val transformedTree = super.transform(tree) + validate(transformedTree) match { + case None => transformedTree + case Some(validatedClassDef) => + val maybePatchedClassDef = for { + newToStringBody <- createToStringBody(validatedClassDef) + .withLog(s"couldn't create a valid toString body for ${validatedClassDef.name.decode}") + + newToStringMethod <- buildToStringMethod(newToStringBody) + .withLog(s"couldn't create a valid toString body for ${validatedClassDef.name.decode}") + + patchedClassDef <- patchCaseClass(validatedClassDef, newToStringMethod) + .withLog(s"couldn't create a valid toString body for ${validatedClassDef.name.decode}") + + } yield patchedClassDef + + maybePatchedClassDef match { + case Some(patchedClassDef) => patchedClassDef + case None => + reporter.warning( + tree.pos, + s""" + |Dang, couldn't patch properly ${tree.symbol.nameString} :( + |If you believe this is an error: please report the issue, along with a minimum reproducible example, + |at the following link: https://github.com/polentino/redacted/issues/new . + | + |Thank you 🙏 + |""".stripMargin + ) + tree } + } + } - // val t = tree match { -// case cd: ClassDef if cd.mods.hasFlag(Flag.CASE) => -// global.reporter.echo(s"\t -> inspecting case class '${cd.name}'") -// val annotations = cd.impl.body.flatMap { -// case dd: DefDef if dd.name == termNames.CONSTRUCTOR => -// dd.vparamss.flatMap { p => -// p.filter { pp => -// pp.mods.hasAnnotationNamed(TypeName("redacted")) -// } -// }.map(_.name) -// case _ => Nil -// } -// if (annotations.nonEmpty) { -// global.reporter.echo(s"\t\t -> annotations: $annotations") -// val modifiedBody = cd.impl.body.map { -// case dd: DefDef if dd.name == TermName("toString") => dd.copy(rhs = Literal(Constant("TEST"))) -// case other => other -// } -// val t = cd.copy(impl = cd.impl.copy(body = modifiedBody)) -// new Transformer {}.transform(t) -// } else { -// cd -// } -// case c => c -// } -// super.traverse(t) - } - }.traverse(unit.body) + /** Utility method that ensures the current tree being inspected is a case class with at least one parameter + * annotated with `@redacted`. + * @param tree + * the tree to be checked + * @return + * an option containing the validated `ClassDef`, or `None` + */ + private def validate(tree: Tree): Option[global.ClassDef] = for { + caseClassType <- validateTypeDef(tree) + _ <- getRedactedFields(caseClassType) + } yield caseClassType + + /** Utility method that checks whether the current tree being inspected corresponds to a case class. + * @param tree + * the tree to be checked + * @return + * an option containing the validated `ClassDef`, or `None` + */ + private def validateTypeDef(tree: Tree): Option[ClassDef] = tree match { + case classDef: ClassDef if classDef.mods.isCase => Some(classDef) + case _ => None } - } - private def shouldTransform(tree: global.Tree) = tree match { - case cd: ClassDef if cd.mods.hasFlag(Flag.CASE) => - cd.impl.body.exists { - case dd: DefDef => hasAnnotation(dd, "redacted") - case _ => false + /** Utility method that returns all ctor fields annotated with `@redacted` + * @param classDef + * the ClassDef to be checked + * @return + * an Option with the list of all params marked with `@redacted`, or `None` otherwise + */ + private def getRedactedFields(classDef: ClassDef): Option[List[ValDef]] = + classDef.impl.body.collectFirst { + case d: DefDef if d.name.decode == GenBCode.INSTANCE_CONSTRUCTOR_NAME => + d.vparamss.flatMap(_.filter(_.mods.hasAnnotationNamed(redactedTypeName))) + } match { + case some @ Some(values) if values.nonEmpty => some + case _ => None + } + + /** Utility method to generate a new `toString` definition based on the parameters marked with `@redacted`. + * @param classDef + * the ClassDef for which we need a dedicated `toString` method + * @return + * the body of the new `toString` method + */ + private def createToStringBody(classDef: ClassDef): scala.util.Try[Tree] = scala.util.Try { + val className = classDef.name.decode + val memberNames = getAllFields(classDef) + val classPrefix = (className + "(").toConstantLiteral + val classSuffix = ")".toConstantLiteral + val commaSymbol = ",".toConstantLiteral + val asterisksSymbol = "***".toConstantLiteral + val concatOperator = TermName("$plus") + + val fragments: List[Tree] = memberNames.map(m => + if (m.mods.hasAnnotationNamed(redactedTypeName)) asterisksSymbol + else Apply(Select(Ident(m.name), TO_STRING_NAME), Nil)) + + def buildToStringTree(fragments: List[Tree]): Tree = { + + def concatAll(l: List[Tree]): List[Tree] = l match { + case Nil => Nil + case head :: Nil => List(head) + case head :: tail => List(head, commaSymbol) ++ concatAll(tail) + } + + val res = concatAll(fragments).fold(classPrefix) { case (accumulator, fragment) => + Apply(Select(accumulator, concatOperator), List(fragment)) + } + Apply(Select(res, concatOperator), List(classSuffix)) } - case _ => false - } - private def getAnnotations(tree: global.Tree) = { - tree match { - case cd: ClassDef if cd.mods.hasFlag(Flag.CASE) => -// val annotations = cd.impl.body.flatMap { -// case dd: DefDef if dd.name == termNames.CONSTRUCTOR => -// dd.vparamss.flatMap { p => -// p.filter { pp => -// pp.mods.hasAnnotationNamed(TypeName("redacted")) -// } -// }.map(_.name) -// case _ => Nil -// } -// annotations.nonEmpty - - case _ => Nil + buildToStringTree(fragments) } - } - private def transformTemplate(template: Template): Template = { - val newBody = template.body.map { - case dd: DefDef if dd.name == TermName("toString") => - global.reporter.echo("\t\t WE HAVE A TOSTRING!") -// if (hasAnnotation(dd, "redacted")) { - val newRhs = Literal(Constant("TEST")) - dd.copy(rhs = newRhs) -// } else { -// dd -// } - case other => other + /** Returns all the fields in a case class ctor. + * @param classDef + * the `ClassDef` for which we want to get all if ctor field + * @return + * a list of all the `ValDef` + */ + private def getAllFields(classDef: ClassDef): List[ValDef] = + classDef.impl.body.collectFirst { + case d: DefDef if d.name.decode == GenBCode.INSTANCE_CONSTRUCTOR_NAME => d.vparamss.flatten + }.getOrElse(Nil) + + /** Build a new `toString` method definition containing the body passed as parameter. + * @param body + * the body of the newly created `toString` method + * @return + * the whole `toString` method definition + */ + private def buildToStringMethod(body: Tree): scala.util.Try[DefDef] = scala.util.Try { + DefDef(Modifiers(Flag.OVERRIDE), TermName(TO_STRING_NAME), Nil, Nil, TypeTree(), body) } - template.copy(body = newBody) - } - private def hasAnnotation(dd: DefDef, annotationName: String): Boolean = { - dd.vparamss.flatten.exists(_.symbol.annotations.exists(_.tree.tpe.typeSymbol.name.toString == annotationName)) - } + /** Utility method that adds a new method definition to an existing `ClassDef` body. + * @param classDef + * the class that needs to be patched + * @param newToStringMethod + * the new method that will be included in the `ClassDef` passed as first parameter + * @return + * the patched `ClassDef` + */ + private def patchCaseClass(classDef: ClassDef, newToStringMethod: Tree): scala.util.Try[ClassDef] = + scala.util.Try { + val newBody = classDef.impl.body :+ newToStringMethod + val newImpl = classDef.impl.copy(body = newBody) + classDef.copy(impl = newImpl) + } - private class MyTransformer(val global: Global) extends Transformer { + // utility extension classes - override def transform(tree: Tree): Tree = { - tree match { - case cd: ClassDef if cd.mods.hasFlag(Flag.CASE) => - global.reporter.echo(s"\t -> inspecting case class '${cd.name}'") - val annotations = cd.impl.body.flatMap { - case dd: DefDef if dd.name == termNames.CONSTRUCTOR => - dd.vparamss.flatMap { p => - p.filter { pp => - pp.mods.hasAnnotationNamed(TypeName("redacted")) - } - }.map(_.name) - case _ => Nil - } + private implicit class AstOps(s: String) { + def toConstantLiteral: Literal = Literal(Constant(s)) + } - global.reporter.echo(s"\t\t -> annotations: $annotations") - global.reporter.echo(s"\t\t -> body: $cd.impl.body") - val modifiedBody = cd.impl.body.map { - case dd: DefDef if dd.name == TermName("toString") => - global.reporter.echo(s"\t -> PATCHED!") - dd.copy(rhs = Literal(Constant("TEST"))) - case other => - global.reporter.echo(s"\t -> SKIPPING ${other.symbol}") - other - } - cd.copy(impl = cd.impl.copy(body = modifiedBody)) + private implicit class TryOps[Out](opt: scala.util.Try[Out]) { - case _ => tree + def withLog(message: String): Option[Out] = opt match { + case Success(value) => Some(value) + case _ => reporter.echo(message); None } } } diff --git a/plugin/src/main/scala-3/io/github/polentino/redacted/phases/PatchToString.scala b/plugin/src/main/scala-3/io/github/polentino/redacted/phases/PatchToString.scala index 0725f98..26e0868 100644 --- a/plugin/src/main/scala-3/io/github/polentino/redacted/phases/PatchToString.scala +++ b/plugin/src/main/scala-3/io/github/polentino/redacted/phases/PatchToString.scala @@ -54,5 +54,5 @@ final case class PatchToString() extends PluginPhase { } object PatchToString { - final val name: String = "PatchToString" + final val name: String = "patch-tostring-phase" } diff --git a/tests/src/test/scala/io/github/polentino/redacted/RedactedSpec.scala b/tests/src/test/scala/io/github/polentino/redacted/RedactedSpec.scala index 944e732..fe8714e 100644 --- a/tests/src/test/scala/io/github/polentino/redacted/RedactedSpec.scala +++ b/tests/src/test/scala/io/github/polentino/redacted/RedactedSpec.scala @@ -71,6 +71,29 @@ class RedactedSpec extends AnyFlatSpec with ScalaCheckPropertyChecks { } } + it should "not confuse the parameter of a method with the parameter of the main ctor" in { + case class TestWrongAnnotationPlacement(name: String, age: Int) { + + /** WRONG! */ + def toUpper(@redacted name: String): String = name.toUpperCase() + } + + forAll { (name: String, age: Int) => + val expected = s"TestWrongAnnotationPlacement($name,$age)" + val testing = TestWrongAnnotationPlacement(name, age) + val implicitToString = s"$testing" + val explicitToString = testing.toString + + val cp = new Checkpoint + cp { assert(implicitToString == expected) } + cp { assert(explicitToString == expected) } + cp { + assert(testing.name == name && testing.age == age) + } + cp.reportAll() + } + } + it should "work with nested case classes in case class" in { case class Inner(userId: String, @redacted balance: Int) case class Outer(inner: Inner) From 7f79f9b93a6b5625d3e00d33216fcd103d940e63 Mon Sep 17 00:00:00 2001 From: Diego Casella Date: Sun, 12 May 2024 22:29:07 +0200 Subject: [PATCH 6/9] fix: added test for curried case class ctor --- .../redacted/RedactedPluginComponent.scala | 7 ++----- .../polentino/redacted/helpers/AstOps.scala | 6 ++++-- .../redacted/helpers/PluginOps.scala | 2 +- .../polentino/redacted/RedactedSpec.scala | 19 +++++++++++++++++++ 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala index bcaa433..09bd079 100644 --- a/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala +++ b/plugin/src/main/scala-2/io/github/polentino/redacted/RedactedPluginComponent.scala @@ -90,10 +90,7 @@ class RedactedPluginComponent(val global: Global) extends PluginComponent with T private def getRedactedFields(classDef: ClassDef): Option[List[ValDef]] = classDef.impl.body.collectFirst { case d: DefDef if d.name.decode == GenBCode.INSTANCE_CONSTRUCTOR_NAME => - d.vparamss.flatMap(_.filter(_.mods.hasAnnotationNamed(redactedTypeName))) - } match { - case some @ Some(values) if values.nonEmpty => some - case _ => None + d.vparamss.headOption.fold(List.empty[ValDef])(v => v.filter(_.mods.hasAnnotationNamed(redactedTypeName))) } /** Utility method to generate a new `toString` definition based on the parameters marked with `@redacted`. @@ -140,7 +137,7 @@ class RedactedPluginComponent(val global: Global) extends PluginComponent with T */ private def getAllFields(classDef: ClassDef): List[ValDef] = classDef.impl.body.collectFirst { - case d: DefDef if d.name.decode == GenBCode.INSTANCE_CONSTRUCTOR_NAME => d.vparamss.flatten + case d: DefDef if d.name.decode == GenBCode.INSTANCE_CONSTRUCTOR_NAME => d.vparamss.headOption.getOrElse(Nil) }.getOrElse(Nil) /** Build a new `toString` method definition containing the body passed as parameter. diff --git a/plugin/src/main/scala-3/io/github/polentino/redacted/helpers/AstOps.scala b/plugin/src/main/scala-3/io/github/polentino/redacted/helpers/AstOps.scala index 3e50491..32bcb55 100644 --- a/plugin/src/main/scala-3/io/github/polentino/redacted/helpers/AstOps.scala +++ b/plugin/src/main/scala-3/io/github/polentino/redacted/helpers/AstOps.scala @@ -20,8 +20,10 @@ object AstOps { def redactedFields: List[String] = { val redactedType = redactedSymbol - symbol.primaryConstructor.paramSymss.flatten.collect { - case s if s.annotations.exists(_.matches(redactedType)) => s.name.toString + symbol.primaryConstructor.paramSymss.headOption.fold(List.empty[String]) { params => + params + .filter(_.annotations.exists(_.matches(redactedType))) + .map(_.name.toString) } } } diff --git a/plugin/src/main/scala-3/io/github/polentino/redacted/helpers/PluginOps.scala b/plugin/src/main/scala-3/io/github/polentino/redacted/helpers/PluginOps.scala index 9fb0e6a..3c1c42e 100644 --- a/plugin/src/main/scala-3/io/github/polentino/redacted/helpers/PluginOps.scala +++ b/plugin/src/main/scala-3/io/github/polentino/redacted/helpers/PluginOps.scala @@ -117,7 +117,7 @@ object PluginOps { */ def createToStringBody(tree: tpd.TypeDef)(using Context): Try[tpd.Tree] = Try { val className = tree.name.toString - val memberNames = tree.symbol.primaryConstructor.paramSymss.flatten + val memberNames = tree.symbol.primaryConstructor.paramSymss.headOption.getOrElse(Nil) val annotationSymbol = redactedSymbol val classPrefix = (className + "(").toConstantLiteral val classSuffix = ")".toConstantLiteral diff --git a/tests/src/test/scala/io/github/polentino/redacted/RedactedSpec.scala b/tests/src/test/scala/io/github/polentino/redacted/RedactedSpec.scala index fe8714e..0c582c8 100644 --- a/tests/src/test/scala/io/github/polentino/redacted/RedactedSpec.scala +++ b/tests/src/test/scala/io/github/polentino/redacted/RedactedSpec.scala @@ -94,6 +94,25 @@ class RedactedSpec extends AnyFlatSpec with ScalaCheckPropertyChecks { } } + it should "ignore `@redacted` annotation on curried parameters" in { + case class Curried(age: Int, @redacted name: String)(@redacted email: String) + + forAll { (age: Int, name: String, email: String) => + val expected = s"Curried($age,***)" + val testing = Curried(age, name)(email) + val implicitToString = s"$testing" + val explicitToString = testing.toString + + val cp = new Checkpoint + cp { assert(implicitToString == expected) } + cp { assert(explicitToString == expected) } + cp { + assert(testing.age == age && testing.name == name) + } + cp.reportAll() + } + } + it should "work with nested case classes in case class" in { case class Inner(userId: String, @redacted balance: Int) case class Outer(inner: Inner) From e645fe7cf9ca8a8a415f4448250f868b4b1d595a Mon Sep 17 00:00:00 2001 From: Diego Casella Date: Sun, 12 May 2024 22:37:06 +0200 Subject: [PATCH 7/9] fix: crosscompile for scala 2.12.x and updated 3.4.x --- build.sbt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index d30f5c7..9092d4f 100644 --- a/build.sbt +++ b/build.sbt @@ -6,13 +6,14 @@ val scalaCheckVersion = "3.2.17.0" // all LTS versions & latest minor ones val supportedScalaVersions = List( + "2.12.19", "2.13.13", "3.1.3", "3.2.2", "3.3.0", "3.3.1", "3.3.3", - "3.4.0" + "3.4.1" ) inThisBuild( From 7d80e6f7331f631890cdb91cbd031a957fff8c61 Mon Sep 17 00:00:00 2001 From: Diego Casella Date: Sun, 12 May 2024 22:45:58 +0200 Subject: [PATCH 8/9] chore: updated README.md --- README.md | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2a38bb4..1b903aa 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,8 @@ in your `build.sbt` file, add the following lines ```scala 3 val redactedVersion = // use latest version of the library -resolvers += DefaultMavenRepository, + resolvers += DefaultMavenRepository +, libraryDependencies ++= Seq( "io.github.polentino" %% "redacted" % redactedVersion cross CrossVersion.full, compilerPlugin("io.github.polentino" %% "redacted-plugin" % redactedVersion cross CrossVersion.full) @@ -131,6 +132,30 @@ println(wrapper) will print > Wrapper(id-1,***) +### Note on curried case classes + +While it is possible to write something like + +```scala 3 +case class Curried(id: String, @redacted name: String)(@redacted email: String) +``` + +the `toString` method that Scala compiler generates by default will print only the parameters in the primary +constructor, meaning that + +```scala 3 +val c = Curried(0, "Berfu")("berfu@gmail.com") +println(c) +``` + +will display + +```scala 3 +Curried(0,Berfu) +``` + +Therefore, the same behavior is being kept in the customized `toString` implementation. + ## How it works Given a case class with at least one field annotated with `@redacted`, i.e. @@ -157,7 +182,8 @@ implementation by selectively returning either the `***` string, or the value of ```scala 3 def toString(): String = - "(" + this. + "," + "***" + ... + ")" + "(" + this.< field not redacted > + "," + "***" + +...+")" ``` ## Improvements From a8063e854efd3189af62a388ffcc61644eb00011 Mon Sep 17 00:00:00 2001 From: Diego Casella Date: Sun, 12 May 2024 22:48:06 +0200 Subject: [PATCH 9/9] chore: bump version --- version.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.sbt b/version.sbt index eae5ef9..9283500 100644 --- a/version.sbt +++ b/version.sbt @@ -1 +1 @@ -ThisBuild / version := "0.4.0" +ThisBuild / version := "0.5.0"