diff --git a/CHANGELOG.md b/CHANGELOG.md
index ab1c71bd..fddcfb4e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/)
 and this project adheres to [Semantic Versioning](http://semver.org/).
 
+## [1.2.0] - 2022-03-31
+
+### Changed
+
+- `OverflowDbDriver` now takes a `DataFlowCacheConfig` argument that specifies
+data flow engine specific configurations.
+- `OverflowDbDriver::nodesReachableBy` renamed to `flowsBetween` and now takes functions
+as `sources` and `sinks` parameters.
+
 ## [1.1.19] - 2022-03-30
 
 ### Changed
 
diff --git a/build.sbt b/build.sbt
index 048f328e..c447d2d7 100644
--- a/build.sbt
+++ b/build.sbt
@@ -3,22 +3,23 @@ name := "Plume"
 inThisBuild(
   List(
     organization := "com.github.plume-oss",
-    version := "1.1.19",
+    version := "1.2.0",
     scalaVersion := "2.13.8",
     crossScalaVersions := Seq("2.13.8", "3.1.1"),
     resolvers ++= Seq(
       Resolver.mavenLocal,
       Resolver.mavenCentral,
-      Resolver.JCenterRepository
+      Resolver.JCenterRepository,
+      "Gradle Tooling" at "https://repo.gradle.org/gradle/libs-releases-local/"
     )
   )
 )
 
-val cpgVersion = "1.3.521"
-val joernVersion = "1.1.661"
+val cpgVersion = "1.3.523"
+val joernVersion = "1.1.678"
 val sootVersion = "4.3.0"
 val tinkerGraphVersion = "3.4.11"
-val neo4jVersion = "4.4.3"
+val neo4jVersion = "4.4.5"
 val apacheCodecVersion = "1.15"
 val apacheIoVersion = "2.11.0"
 val apacheLangVersion = "3.12.0"
@@ -43,7 +44,7 @@ Test / parallelExecution := false
 
 libraryDependencies ++= Seq(
   "io.shiftleft" %% "codepropertygraph" % cpgVersion,
-  "io.shiftleft" %% "semanticcpg" % cpgVersion,
+  "io.joern" %% "semanticcpg" % joernVersion,
   "io.joern" %% "dataflowengineoss" % joernVersion,
   "io.joern" %% "x2cpg" % joernVersion,
   "io.joern" %% "jimple2cpg" % joernVersion,
diff --git a/src/main/scala/com/github/plume/oss/domain/package.scala b/src/main/scala/com/github/plume/oss/domain/package.scala
index f69e2675..e5c3706a 100644
--- a/src/main/scala/com/github/plume/oss/domain/package.scala
+++ b/src/main/scala/com/github/plume/oss/domain/package.scala
@@ -12,7 +12,7 @@ import org.slf4j.LoggerFactory
 import overflowdb.traversal.jIteratortoTraversal
 
 import java.io.{File, FileInputStream, FileOutputStream}
-import java.nio.file.Path
+import java.nio.file.{Path, Paths}
 import java.util.concurrent.ConcurrentHashMap
 import scala.jdk.CollectionConverters
 import scala.jdk.CollectionConverters.ConcurrentMapHasAsScala
@@ -28,6 +28,13 @@ package object domain {
     .addModule(DefaultScalaModule)
     .build()
 
+  case class DataFlowCacheConfig(
+      dataFlowCacheFile: Option[Path] = Some(Paths.get("dataFlowCache.cbor")),
+      compressDataFlowCache: Boolean = true,
+      maxCallDepth: Int = 2,
+      maxCachedPaths: Int = 1_000
+  )
+
   /** Given an object and a path, will serialize the object to the given path.
     * @param o object to serialize.
     * @param p path to write serialized data to.
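For illustration, a minimal sketch of constructing the new configuration object and handing it to the driver (field names and defaults are taken from the `DataFlowCacheConfig` case class above; the overridden values are illustrative only):

    import java.nio.file.Paths
    import com.github.plume.oss.domain.DataFlowCacheConfig
    import com.github.plume.oss.drivers.OverflowDbDriver

    // Override only what differs from the defaults. Setting dataFlowCacheFile
    // to None disables persisting reachable-by results altogether.
    val cacheConfig = DataFlowCacheConfig(
      dataFlowCacheFile = Some(Paths.get("dataFlowCache.cbor")), // where the path cache is serialized
      compressDataFlowCache = true, // compress the serialized cache (stored with an .lz4 suffix)
      maxCallDepth = 2,             // forwarded to the data flow engine's EngineConfig
      maxCachedPaths = 1_000        // upper bound on cached paths retained between queries
    )
    val driver = OverflowDbDriver(cacheConfig = cacheConfig)
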
diff --git a/src/main/scala/com/github/plume/oss/drivers/OverflowDbDriver.scala b/src/main/scala/com/github/plume/oss/drivers/OverflowDbDriver.scala
index 11cce2a6..74a6141a 100644
--- a/src/main/scala/com/github/plume/oss/drivers/OverflowDbDriver.scala
+++ b/src/main/scala/com/github/plume/oss/drivers/OverflowDbDriver.scala
@@ -1,12 +1,7 @@
 package com.github.plume.oss.drivers
 
 import com.github.plume.oss.PlumeStatistics
-import com.github.plume.oss.domain.{
-  SerialReachableByResult,
-  deserializeCache,
-  deserializeResultTable,
-  serializeCache
-}
+import com.github.plume.oss.domain._
 import com.github.plume.oss.drivers.OverflowDbDriver.newOverflowGraph
 import com.github.plume.oss.passes.callgraph.PlumeDynamicCallLinker
 import com.github.plume.oss.util.BatchedUpdateUtil._
@@ -25,7 +20,7 @@ import overflowdb.traversal.{Traversal, jIteratortoTraversal}
 import overflowdb.{BatchedUpdate, Config, DetachedNodeData, Edge, Node}
 
 import java.io.{FileOutputStream, OutputStreamWriter, File => JFile}
-import java.nio.file.{Files, Path, Paths}
+import java.nio.file.Files
 import java.util.concurrent.ConcurrentHashMap
 import scala.collection.mutable
 import scala.io.{BufferedSource, Source}
@@ -36,9 +31,7 @@ import scala.util.{Failure, Success, Try, Using}
  * @param storageLocation where the database will serialize to and deserialize from.
  * @param heapPercentageThreshold the percentage of the JVM heap from when the database will begin swapping to disk.
  * @param serializationStatsEnabled enables saving of serialization statistics.
- * @param dataFlowCacheFile the path to the cache file where data-flow paths are saved to. If None then data flow
- *                          results will not be saved.
- * @param compressDataFlowCache whether to compress the serialized data-flow cache specified at dataFlowCacheFile.
+ * @param cacheConfig contains various configurations for using the data flow tracking capabilities of OverflowDB.
  */
 final case class OverflowDbDriver(
   storageLocation: Option[String] = Option(
@@ -46,8 +39,7 @@
   ),
   heapPercentageThreshold: Int = 80,
   serializationStatsEnabled: Boolean = false,
-  dataFlowCacheFile: Option[Path] = Some(Paths.get("dataFlowCache.cbor")),
-  compressDataFlowCache: Boolean = true
+  cacheConfig: DataFlowCacheConfig = DataFlowCacheConfig()
 ) extends IDriver {
 
   private val logger = LoggerFactory.getLogger(classOf[OverflowDbDriver])
@@ -74,13 +66,13 @@
   /** Reads the saved cache on the disk and retrieves it as a serializable object
     */
   private def fetchCacheFromDisk: Option[ConcurrentHashMap[Long, Vector[SerialReachableByResult]]] =
-    dataFlowCacheFile match {
+    cacheConfig.dataFlowCacheFile match {
       case Some(filePath) =>
         if (Files.isRegularFile(filePath))
           Some(
             PlumeStatistics.time(
               PlumeStatistics.TIME_RETRIEVING_CACHE,
-              { deserializeCache(filePath, compressDataFlowCache) }
+              { deserializeCache(filePath, cacheConfig.compressDataFlowCache) }
             )
           )
         else
@@ -96,10 +88,13 @@
 
   private implicit var context: EngineContext = EngineContext(
     Semantics.fromList(List()),
-    EngineConfig(initialTable = resultTable)
+    EngineConfig(
+      maxCallDepth = cacheConfig.maxCallDepth,
+      initialTable = resultTable
+    )
   )
 
-  private def saveDataflowCache(): Unit = dataFlowCacheFile match {
+  private def saveDataflowCache(): Unit = cacheConfig.dataFlowCacheFile match {
     case Some(filePath) if resultTable.isDefined && resultTable.get.table.nonEmpty =>
       PlumeStatistics.time(
         PlumeStatistics.TIME_STORING_CACHE, {
@@ -110,13 +105,13 @@
             t.put(n.id(), v.map(SerialReachableByResult.apply))
           }
           // Write to disk
-          serializeCache(t, filePath, compressDataFlowCache)
+          serializeCache(t, filePath, cacheConfig.compressDataFlowCache)
         }
       )
     case _ => // Do nothing
   }
 
-  /** Sets the context for the data-flow engine when performing [[nodesReachableBy]] queries.
+  /** Sets the context for the data-flow engine when performing [[flowsBetween]] queries.
    *
    * @param maxCallDepth the new method call depth.
    * @param methodSemantics the file containing method semantics for external methods.
@@ -179,7 +174,11 @@
 
   override def clear(): Unit = {
     cpg.graph.nodes.asScala.foreach(safeRemove)
-    dataFlowCacheFile match {
+    resultTable match {
+      case Some(table) => table.table.clear()
+      case None        =>
+    }
+    cacheConfig.dataFlowCacheFile match {
       case Some(filePath) => filePath.toFile.delete()
       case None           => // Do nothing
     }
@@ -378,17 +377,42 @@
    * @param sanitizers a set of full method names to filter paths out with.
    * @return the source nodes whose data flows to the given sinks uninterrupted.
    */
-  def nodesReachableBy(
-      source: Traversal[CfgNode],
-      sink: Traversal[CfgNode],
+  def flowsBetween(
+      source: () => Traversal[CfgNode],
+      sink: () => Traversal[CfgNode],
       sanitizers: Set[String] = Set.empty[String]
   ): List[ReachableByResult] =
     PlumeStatistics.time(
       PlumeStatistics.TIME_REACHABLE_BY_QUERYING, {
         import io.shiftleft.semanticcpg.language._
+        // Strip the cache of only nodes that will be used the most in this query to get fast starts/finishes
+        cacheConfig.dataFlowCacheFile match {
+          case Some(_) =>
+            val newCache = new ResultTable
+            val oldCache = resultTable.getOrElse(new ResultTable)
+            var currPathsInCache = 0
+            scala.util.Random
+              .shuffle(source().l ++ sink().l)
+              .flatMap { x =>
+                oldCache.get(x) match {
+                  case Some(paths) => Some((x, paths))
+                  case None        => None
+                }
+              }
+              .foreach { case (startOrEndNode, paths) =>
+                if (currPathsInCache + paths.size <= cacheConfig.maxCachedPaths) {
+                  currPathsInCache += paths.size
+                  newCache.add(startOrEndNode, paths)
+                }
+              }
+            oldCache.table.clear()
+            resultTable = Some(newCache)
+            setDataflowContext(context.config.maxCallDepth, context.semantics, resultTable)
+          case _ =>
+        }
 
-        val results: List[ReachableByResult] = sink
-          .reachableByDetailed(source)(context)
+        val results: List[ReachableByResult] = sink()
+          .reachableByDetailed(source())(context)
         captureDataflowCache(results)
         results
         // Remove a source/sink arguments referring to itself
@@ -408,16 +432,15 @@
     )
 
   private def captureDataflowCache(results: List[ReachableByResult]): Unit = {
-    dataFlowCacheFile match {
+    cacheConfig.dataFlowCacheFile match {
       case Some(_) =>
-        // Reload latest results to the query engine context
+        // Capture latest results
        resultTable = (results
          .map(_.table) ++ List(resultTable).flatten).distinct
          .reduceOption((a: ResultTable, b: ResultTable) => {
            b.table.foreach { case (k, v) => a.add(k, v) }
            a
          })
-        setDataflowContext(context.config.maxCallDepth, context.semantics, resultTable)
       case None => // Do nothing since no table means we aren't saving data and instead keeping memory low
     }
   }
diff --git a/src/test/scala/com/github/plume/oss/DiffTests.scala b/src/test/scala/com/github/plume/oss/DiffTests.scala
index e3613343..e322cbc3 100644
--- a/src/test/scala/com/github/plume/oss/DiffTests.scala
+++ b/src/test/scala/com/github/plume/oss/DiffTests.scala
@@ -64,7 +64,7 @@ class DiffTests extends AnyWordSpec with Matchers with BeforeAndAfterAll {
     driver.clear()
     driver.close()
     Paths.get(storage.get).toFile.delete()
-    driver.dataFlowCacheFile match {
+    driver.cacheConfig.dataFlowCacheFile match {
      case Some(jsonFile) => new File(jsonFile.toFile.getAbsolutePath + ".lz4").delete()
      case None           =>
    }
@@ -102,12 +102,12 @@ class DiffTests extends AnyWordSpec with Matchers with BeforeAndAfterAll {
     val sinkNodesId1 = driver.cpg.call(Operators.addition).id.l
 
     val r1 = driver
-      .nodesReachableBy(driver.cpg.parameter("a"), driver.cpg.call(Operators.addition))
+      .flowsBetween(() => driver.cpg.parameter("a"), () => driver.cpg.call(Operators.addition))
       .map(_.path.map(_.node.id()))
     val cH1 = QueryEngineStatistics.results()(QueryEngineStatistics.PATH_CACHE_HITS)
     val cM1 = QueryEngineStatistics.results()(QueryEngineStatistics.PATH_CACHE_MISSES)
     val hitRatio1 = cH1.toDouble / (cH1 + cM1) * 100
-    logger.info(s"Cache hit ratio $hitRatio1%")
+    logger.info(s"Cache hit ratio $hitRatio1% ($cH1 vs $cM1)")
     cH1 should be <= cM1
     QueryEngineStatistics.reset()
 
@@ -119,7 +119,7 @@ class DiffTests extends AnyWordSpec with Matchers with BeforeAndAfterAll {
     val sinkNodesId2 = driver.cpg.call(Operators.addition).id.l
 
     val r2 = driver
-      .nodesReachableBy(driver.cpg.parameter("a"), driver.cpg.call(Operators.addition))
+      .flowsBetween(() => driver.cpg.parameter("a"), () => driver.cpg.call(Operators.addition))
       .map(_.path.map(_.node.id()))
     val cH2 = QueryEngineStatistics.results()(QueryEngineStatistics.PATH_CACHE_HITS)
     val cM2 = QueryEngineStatistics.results()(QueryEngineStatistics.PATH_CACHE_MISSES)
@@ -131,7 +131,7 @@ class DiffTests extends AnyWordSpec with Matchers with BeforeAndAfterAll {
     r1 shouldBe r2
     // The cache should have a higher number of hits now from re-using the first query
     val hitRatio2 = cH2.toDouble / (cH2 + cM2) * 100
-    logger.info(s"Cache hit ratio $hitRatio2%")
+    logger.info(s"Cache hit ratio $hitRatio2% ($cH2 vs $cM2)")
     hitRatio2 should be >= hitRatio1
   }
 
diff --git a/src/test/scala/com/github/plume/oss/querying/DataFlowTests.scala b/src/test/scala/com/github/plume/oss/querying/DataFlowTests.scala
index eb1cce77..92e9e79c 100644
--- a/src/test/scala/com/github/plume/oss/querying/DataFlowTests.scala
+++ b/src/test/scala/com/github/plume/oss/querying/DataFlowTests.scala
@@ -47,7 +47,7 @@ class DataFlowTests extends Jimple2CpgFixture(Some(new OverflowDbDriver())) {
     val cpg = CPG(driver.cpg.graph)
 
     val r = driver
-      .nodesReachableBy(cpg.parameter("a"), cpg.call(".*"))
+      .flowsBetween(() => cpg.parameter("a"), () => cpg.call(".*"))
 
     val List(v1) = r.map(r => r.path.map(x => (x.node.method.name, x.node.code)))
     v1.head shouldBe ("foo", "int a")
@@ -58,7 +58,7 @@ class DataFlowTests extends Jimple2CpgFixture(Some(new OverflowDbDriver())) {
     val cpg = CPG(driver.cpg.graph)
 
     val r = driver
-      .nodesReachableBy(cpg.parameter("a"), cpg.call("bar"))
+      .flowsBetween(() => cpg.parameter("a"), () => cpg.call("bar"))
 
     val List(v1) = r.map(r => r.path.map(x => (x.node.method.name, x.node.code)))
     v1.head shouldBe ("foo", "int a")
@@ -69,7 +69,9 @@ class DataFlowTests extends Jimple2CpgFixture(Some(new OverflowDbDriver())) {
     val cpg = CPG(driver.cpg.graph)
 
     val r = driver
-      .nodesReachableBy(cpg.parameter("a"), cpg.call("println"))
+      .flowsBetween(() => cpg.parameter("a"), () => cpg.call("println"))
+
+    r.map(r => r.path.map(x => (x.node.method.name, x.node.code))).foreach(println)
 
     r.size shouldBe 2
 
@@ -87,10 +89,11 @@ class DataFlowTests extends Jimple2CpgFixture(Some(new OverflowDbDriver())) {
     def source = cpg.call("taint").argument
     def sink   = cpg.call("baz")
 
-    val r1 = driver.nodesReachableBy(source, sink)
+    val r1 = driver.flowsBetween(() => source, () => sink)
+    r1.map(r => r.path.map(x => (x.node.method.name, x.node.code))).foreach(println)
     r1.size shouldBe 1
 
-    val r2 = driver.nodesReachableBy(source, sink, Set("Foo.falseClean:int(int)"))
+    val r2 = driver.flowsBetween(() => source, () => sink, Set("Foo.falseClean:int(int)"))
     r2.size shouldBe 0
   }
 
diff --git a/src/test/scala/com/github/plume/oss/testfixtures/Jimple2CpgFixture.scala b/src/test/scala/com/github/plume/oss/testfixtures/Jimple2CpgFixture.scala
index 30bf2044..aaa2d1b8 100644
--- a/src/test/scala/com/github/plume/oss/testfixtures/Jimple2CpgFixture.scala
+++ b/src/test/scala/com/github/plume/oss/testfixtures/Jimple2CpgFixture.scala
@@ -3,6 +3,7 @@ package com.github.plume.oss.testfixtures
 import com.github.plume.oss.{Jimple2Cpg, PlumeStatistics}
 import com.github.plume.oss.drivers.OverflowDbDriver
 import com.github.plume.oss.JavaCompiler.compileJava
+import com.github.plume.oss.domain.DataFlowCacheConfig
 import io.joern.x2cpg.testfixtures.{CodeToCpgFixture, LanguageFrontend}
 import io.shiftleft.codepropertygraph.Cpg
 import org.slf4j.LoggerFactory
@@ -16,7 +17,7 @@ class PlumeFrontend(val _driver: Option[OverflowDbDriver]) extends LanguageFront
 
   private val logger = LoggerFactory.getLogger(classOf[PlumeFrontend])
 
   val driver: OverflowDbDriver = _driver match {
     case Some(d) => d
-    case None    => new OverflowDbDriver(dataFlowCacheFile = None)
+    case None    => new OverflowDbDriver(cacheConfig = DataFlowCacheConfig(dataFlowCacheFile = None))
   }
 
   override val fileSuffix: String = ".java"
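
Taken together, a hedged usage sketch of the renamed query API (the signature matches the `flowsBetween` definition in this diff; the source/sink traversals and the sanitizer value are borrowed from the tests above, and the binding names are illustrative):

    import io.shiftleft.semanticcpg.language._

    // Sources and sinks are now passed as functions (`() => Traversal[CfgNode]`)
    // so the driver can re-evaluate the traversals after trimming the path cache.
    val flows = driver.flowsBetween(
      source = () => driver.cpg.parameter("a"),
      sink = () => driver.cpg.call("println"),
      sanitizers = Set("Foo.falseClean:int(int)") // full method names whose paths are filtered out
    )
    flows.foreach(result => println(result.path.map(_.node.code)))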