From 94a40b2cd5c2814da28b1ca440aa081d020f4cfa Mon Sep 17 00:00:00 2001 From: David Baker Effendi Date: Fri, 4 Mar 2022 13:01:49 +0200 Subject: [PATCH] Fork Join Pass Update + New BulkTx(overflowdb.BatchUpdate) API (#240) * Moved diff pass into parent since it runs in serial * Basic infrastructure down, time to handle new applied diff graphs * Made sure nDiffT mutates the correct variable * OverflowDB new bulkTx handles create edges well * OverflowDB new bulkTx handles the rest * Handling method resolve failure on invoke expr * Working new diff graph tests for overflowdb * Working new diff for gremlin * Styling * Neo4j new bulkTx ready but untested * TigerGraph new bulkTx ready but untested * Updated changelog * Added to changelog * Fixed TG:bulkTx * See if this fixes things * See if this fixes things * Moved batched update helpers to BatchedUpdateUtil * Upgraded to latest CPG and Joern and accounted for those changes too * Quick blacklist bugfix * MethodStub blacklist fix --- CHANGELOG.md | 13 + build.sbt | 6 +- .../com/github/plume/oss/Jimple2Cpg.scala | 6 +- .../scala/com/github/plume/oss/Plume.scala | 2 +- .../plume/oss/drivers/GremlinDriver.scala | 78 ++++ .../github/plume/oss/drivers/IDriver.scala | 39 +- .../plume/oss/drivers/Neo4jDriver.scala | 126 +++++- .../plume/oss/drivers/OverflowDbDriver.scala | 33 +- .../plume/oss/drivers/TigerGraphDriver.scala | 82 +++- .../plume/oss/passes/PlumeCpgPass.scala | 224 ++++++++-- .../{concurrent => }/PlumeDiffPass.scala | 2 +- .../concurrent/PlumeCfgCreationPass.scala | 31 +- .../concurrent/PlumeConcurrentCpgPass.scala | 82 ++-- .../concurrent/PlumeConcurrentWriter.scala | 27 +- .../concurrent/PlumeContainsEdgePass.scala | 31 +- .../oss/passes/concurrent/PlumeHashPass.scala | 11 +- .../oss/passes/forkjoin/PlumeCdgPass.scala | 41 ++ .../forkjoin/PlumeCfgDominatorPass.scala | 41 ++ .../PlumeForkJoinParallelCpgPass.scala | 118 +++++ .../oss/passes/parallel/PlumeAstCreator.scala | 21 +- .../oss/passes/parallel/PlumeCdgPass.scala | 40 -- .../parallel/PlumeCfgDominatorPass.scala | 40 -- .../parallel/PlumeMethodStubCreator.scala | 156 +++++-- .../parallel/PlumeReachingDefPass.scala | 42 +- .../plume/oss/util/BatchedUpdateUtil.scala | 69 +++ .../plume/oss/util/ProgramHandlingUtil.scala | 2 +- .../oss/querying/MethodParameterTests.scala | 4 +- .../plume/oss/querying/MethodTests.scala | 2 +- .../oss/testfixtures/PlumeDriverFixture.scala | 411 +++++++++++++----- 29 files changed, 1339 insertions(+), 441 deletions(-) rename src/main/scala/com/github/plume/oss/passes/{concurrent => }/PlumeDiffPass.scala (97%) create mode 100644 src/main/scala/com/github/plume/oss/passes/forkjoin/PlumeCdgPass.scala create mode 100644 src/main/scala/com/github/plume/oss/passes/forkjoin/PlumeCfgDominatorPass.scala create mode 100644 src/main/scala/com/github/plume/oss/passes/forkjoin/PlumeForkJoinParallelCpgPass.scala delete mode 100644 src/main/scala/com/github/plume/oss/passes/parallel/PlumeCdgPass.scala delete mode 100644 src/main/scala/com/github/plume/oss/passes/parallel/PlumeCfgDominatorPass.scala create mode 100644 src/main/scala/com/github/plume/oss/util/BatchedUpdateUtil.scala diff --git a/CHANGELOG.md b/CHANGELOG.md index 910bf080..c8d1508d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased] +### Added + +- Overloaded `bulkTx` to handle new `overflowdb.BatchedUpdate` objects. + +### Fixed + +- Instance where dynamic `InvokeExpr::getMethod` would fail by using `getMethodRef` instead. + +### Changed + +- Updated passes to handle new `ForkJoinParallel` passes. +- Upgraded CPG and Joern versions to latest. + ## [1.0.16] - 2022-03-01 ### Added diff --git a/build.sbt b/build.sbt index fb667eeb..3f9c112a 100644 --- a/build.sbt +++ b/build.sbt @@ -14,8 +14,8 @@ inThisBuild( ) ) -val cpgVersion = "1.3.493" -val joernVersion = "1.1.502" +val cpgVersion = "1.3.509" +val joernVersion = "1.1.590" val sootVersion = "4.2.1" val tinkerGraphVersion = "3.4.8" val neo4jVersion = "4.4.3" @@ -36,11 +36,13 @@ lazy val NeptuneIntTest = config("nepTest") extend Test trapExit := false Test / fork := true +Test / parallelExecution := false libraryDependencies ++= Seq( "io.shiftleft" %% "codepropertygraph" % cpgVersion, "io.shiftleft" %% "semanticcpg" % cpgVersion, "io.joern" %% "dataflowengineoss" % joernVersion, + "io.joern" %% "x2cpg" % joernVersion, "io.shiftleft" %% "semanticcpg" % cpgVersion % Test classifier "tests", "org.soot-oss" % "soot" % sootVersion, "org.apache.tinkerpop" % "tinkergraph-gremlin" % tinkerGraphVersion, diff --git a/src/main/scala/com/github/plume/oss/Jimple2Cpg.scala b/src/main/scala/com/github/plume/oss/Jimple2Cpg.scala index 9ef7cc56..eb0bf0b3 100644 --- a/src/main/scala/com/github/plume/oss/Jimple2Cpg.scala +++ b/src/main/scala/com/github/plume/oss/Jimple2Cpg.scala @@ -5,17 +5,17 @@ import com.github.plume.oss.passes._ import com.github.plume.oss.passes.concurrent.{ PlumeCfgCreationPass, PlumeContainsEdgePass, - PlumeDiffPass, PlumeHashPass } +import com.github.plume.oss.passes.forkjoin.{PlumeCdgPass, PlumeCfgDominatorPass} import com.github.plume.oss.passes.parallel._ import com.github.plume.oss.util.ProgramHandlingUtil import com.github.plume.oss.util.ProgramHandlingUtil.{extractSourceFilesFromArchive, moveClassFiles} +import io.joern.x2cpg.SourceFiles +import io.joern.x2cpg.X2Cpg.newEmptyCpg import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.{NodeTypes, PropertyNames} import io.shiftleft.passes.CpgPassBase -import io.shiftleft.x2cpg.SourceFiles -import io.shiftleft.x2cpg.X2Cpg.newEmptyCpg import org.slf4j.LoggerFactory import soot.options.Options import soot.{G, PhaseOptions, Scene, SootClass} diff --git a/src/main/scala/com/github/plume/oss/Plume.scala b/src/main/scala/com/github/plume/oss/Plume.scala index 9fc84de6..6c48f1ea 100644 --- a/src/main/scala/com/github/plume/oss/Plume.scala +++ b/src/main/scala/com/github/plume/oss/Plume.scala @@ -10,7 +10,7 @@ import com.github.plume.oss.drivers.{ TinkerGraphDriver } import io.circe.Json -import io.shiftleft.x2cpg.{X2Cpg, X2CpgConfig} +import io.joern.x2cpg.{X2Cpg, X2CpgConfig} import scopt.OParser import java.io.InputStreamReader diff --git a/src/main/scala/com/github/plume/oss/drivers/GremlinDriver.scala b/src/main/scala/com/github/plume/oss/drivers/GremlinDriver.scala index a1b3f82b..38eec386 100644 --- a/src/main/scala/com/github/plume/oss/drivers/GremlinDriver.scala +++ b/src/main/scala/com/github/plume/oss/drivers/GremlinDriver.scala @@ -1,6 +1,7 @@ package com.github.plume.oss.drivers import com.github.plume.oss.PlumeStatistics +import com.github.plume.oss.util.BatchedUpdateUtil._ import io.shiftleft.codepropertygraph.generated.nodes.{AbstractNode, NewNode, StoredNode} import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, PropertyNames} import io.shiftleft.passes.AppliedDiffGraph @@ -18,6 +19,8 @@ import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.{ import org.apache.tinkerpop.gremlin.structure.{Edge, Graph, T, Vertex} import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerGraph import org.slf4j.{Logger, LoggerFactory} +import overflowdb.BatchedUpdate.AppliedDiff +import overflowdb.{BatchedUpdate, DetachedNodeData} import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable @@ -90,6 +93,81 @@ abstract class GremlinDriver(txMax: Int = 50) extends IDriver { .foreach { ops: Seq[Change] => bulkEdgeTx(g(), ops, dg) } } + override def bulkTx(dg: AppliedDiff): Unit = { + dg.getDiffGraph.iterator.asScala + .collect { + case c: BatchedUpdate.RemoveNode => c + case c: BatchedUpdate.SetNodeProperty => c + case c: DetachedNodeData => c + } + .grouped(txMax) + .foreach { changes => + var ptr: Option[GraphTraversal[Vertex, Vertex]] = None + changes.foreach { + case node: DetachedNodeData => + val nodeId = typedNodeId(idFromNodeData(node)) + val propMap = propertiesFromNodeData(node) + ptr match { + case Some(p) => + ptr = Some(p.addV(node.label).property(T.id, nodeId)) + serializeLists(propMap).foreach { case (k, v) => p.property(k, v) } + case None => + ptr = Some(g().addV(node.label).property(T.id, nodeId)) + serializeLists(propMap).foreach { case (k, v) => ptr.get.property(k, v) } + } + case c: BatchedUpdate.RemoveNode => + val nodeId = typedNodeId(c.node.id()) + ptr match { + case Some(p) => ptr = Some(p.V(nodeId).drop()) + case None => ptr = Some(g().V(nodeId).drop()) + } + case c: BatchedUpdate.SetNodeProperty => + val v = + if ( + c.label == PropertyNames.INHERITS_FROM_TYPE_FULL_NAME || c.label == PropertyNames.OVERLAYS + ) + c.value.toString.split(",") + else c.value + val nodeId = typedNodeId(c.node.id()) + ptr match { + case Some(p) => ptr = Some(p.V(nodeId).property(c.label, v)) + case None => ptr = Some(g().V(nodeId).property(c.label, v)) + } + } + // Commit transaction + ptr match { + case Some(p) => p.iterate() + case None => + } + } + dg.getDiffGraph.iterator.asScala + .collect { case c: BatchedUpdate.CreateEdge => c } + .grouped(txMax) + .foreach { changes => + var ptr: Option[GraphTraversal[Vertex, Edge]] = None + changes.foreach { c: BatchedUpdate.CreateEdge => + val srcId = typedNodeId(idFromNodeData(c.src)) + val dstId = typedNodeId(idFromNodeData(c.dst)) + ptr match { + case Some(p) => ptr = Some(p.V(srcId).addE(c.label).to(__.V(dstId))) + case None => ptr = Some(g().V(srcId).addE(c.label).to(__.V(dstId))) + } + Option(c.propertiesAndKeys) match { + case Some(edgeKeyValues) => + propertiesFromObjectArray(edgeKeyValues).foreach { case (k, v) => + ptr.get.property(k, v) + } + case None => + } + } + // Commit transaction + ptr match { + case Some(p) => p.iterate() + case None => + } + } + } + private def bulkNodeTx( g: GraphTraversalSource, ops: Seq[Change], diff --git a/src/main/scala/com/github/plume/oss/drivers/IDriver.scala b/src/main/scala/com/github/plume/oss/drivers/IDriver.scala index fc388b22..115e418e 100644 --- a/src/main/scala/com/github/plume/oss/drivers/IDriver.scala +++ b/src/main/scala/com/github/plume/oss/drivers/IDriver.scala @@ -1,39 +1,11 @@ package com.github.plume.oss.drivers -import io.shiftleft.codepropertygraph.generated.nodes.{ - AbstractNode, - Block, - Call, - ControlStructure, - FieldIdentifier, - File, - Identifier, - JumpTarget, - Literal, - Local, - Member, - MetaData, - Method, - MethodParameterIn, - MethodParameterOut, - MethodRef, - MethodReturn, - Modifier, - Namespace, - NamespaceBlock, - NewNode, - Return, - StoredNode, - Type, - TypeArgument, - TypeDecl, - TypeParameter, - TypeRef, - Unknown -} +import io.shiftleft.codepropertygraph.generated.nodes._ import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, PropertyNames} import io.shiftleft.passes.AppliedDiffGraph import org.slf4j.LoggerFactory +import overflowdb.BatchedUpdate.AppliedDiff +import overflowdb.{DetachedNodeData, DetachedNodeGeneric, Node} import scala.collection.mutable @@ -64,6 +36,11 @@ trait IDriver extends AutoCloseable { */ def bulkTx(dg: AppliedDiffGraph): Unit + /** Executes all changes contained within the given overflowdb.BatchedUpdate.AppliedDiff as a (or set of) + * bulk transaction(s). + */ + def bulkTx(dg: AppliedDiff): Unit + /** Given filenames, will remove related TYPE, TYPE_DECL, METHOD (with AST children), and NAMESPACE_BLOCK. */ def removeSourceFiles(filenames: String*): Unit diff --git a/src/main/scala/com/github/plume/oss/drivers/Neo4jDriver.scala b/src/main/scala/com/github/plume/oss/drivers/Neo4jDriver.scala index 27e6a5ed..8a534955 100644 --- a/src/main/scala/com/github/plume/oss/drivers/Neo4jDriver.scala +++ b/src/main/scala/com/github/plume/oss/drivers/Neo4jDriver.scala @@ -2,18 +2,21 @@ package com.github.plume.oss.drivers import com.github.plume.oss.PlumeStatistics import com.github.plume.oss.drivers.Neo4jDriver._ -import io.shiftleft.codepropertygraph.generated.nodes.NewNode +import com.github.plume.oss.util.BatchedUpdateUtil._ +import io.shiftleft.codepropertygraph.generated.nodes.{NewNode, StoredNode} import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, PropertyNames} import io.shiftleft.passes.AppliedDiffGraph import io.shiftleft.passes.DiffGraph.Change import org.neo4j.driver.{AuthTokens, GraphDatabase, Transaction, Value} import org.slf4j.LoggerFactory +import overflowdb.BatchedUpdate.AppliedDiff +import overflowdb.{BatchedUpdate, DetachedNodeData} import java.util import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable import scala.jdk.CollectionConverters -import scala.jdk.CollectionConverters.CollectionHasAsScala +import scala.jdk.CollectionConverters.{CollectionHasAsScala, IteratorHasAsScala} import scala.util.{Failure, Success, Try, Using} /** The driver used to connect to a remote Neo4j instance. Once can optionally call buildSchema to add indexes for @@ -108,22 +111,46 @@ final class Neo4jDriver( }) } ++ Map("id" -> id) + nodePropertiesToCypherQuery(pMap) + } + + private def nodePayload(n: DetachedNodeData): (util.Map[String, Object], String) = { + val pMap = propertiesFromNodeData(n).map { case (k, v) => + k -> (v match { + case x: String => x + case xs: Seq[_] => CollectionConverters.IterableHasAsJava(xs.toList).asJava + case x => x + }) + } ++ Map("id" -> idFromNodeData(n)) + + nodePropertiesToCypherQuery(pMap) + } + + private def nodePropertiesToCypherQuery(pMap: Map[String, Any]) = { val pString = pMap.map { case (k, _) => s"$k:$$$k" }.mkString(",") val jpMap = new util.HashMap[String, Object](pMap.size) pMap.foreach { case (x, y) => jpMap.put(x, y.asInstanceOf[Object]) } (jpMap, pString) } - private def bulkDeleteNode(ops: Seq[Change.RemoveNode]): Unit = + private def bulkDeleteNode(ops: Seq[Any]): Unit = Using.resource(driver.session()) { session => Using.resource(session.beginTransaction()) { tx => ops - .map { case Change.RemoveNode(nodeId) => + .collect { + case c: Change.RemoveNode => (c.nodeId, None) + case c: BatchedUpdate.RemoveNode => (c.node.id(), Some(c.node.label())) + } + .map { case (nodeId: Long, maybeLabel: Option[String]) => + val label = maybeLabel match { + case Some(value) => s":$value" + case None => "" + } ( - """ - |MATCH (n {id:$nodeId}) - |DETACH DELETE (n) - |""".stripMargin, + s""" + |MATCH (n$label {id:$$nodeId}) + |DETACH DELETE (n) + |""".stripMargin, nodeId ) } @@ -164,12 +191,37 @@ final class Neo4jDriver( } } - private def bulkNodeSetProperty(ops: Seq[Change.SetNodeProperty]): Unit = + private def bulkCreateNode(ops: Seq[DetachedNodeData]): Unit = + Using.resource(driver.session()) { session => + Using.resource(session.beginTransaction()) { tx => + ops + .map { c: DetachedNodeData => + val (params, pString) = nodePayload(c) + params -> s"MERGE (n:${c.label} {$pString})" + } + .foreach { case (params: util.Map[String, Object], query: String) => + Try(tx.run(query, params)) match { + case Failure(e) => + logger.error(s"Unable to write bulk create node transaction $query", e) + case Success(_) => + } + } + tx.commit() + } + } + + private def bulkNodeSetProperty(ops: Seq[Any]): Unit = Using.resource(driver.session()) { session => Using.resource(session.beginTransaction()) { tx => ops - .map { case Change.SetNodeProperty(node, k, v) => - val newV = v match { + .collect { + case c: BatchedUpdate.SetNodeProperty => + (c.label, c.value, c.node) + case c: Change.SetNodeProperty => + (c.key, c.value, c.node) + } + .map { case (key: String, value: Any, node: StoredNode) => + val newV = value match { case x: String => "\"" + x + "\"" case Seq() => IDriver.STRING_DEFAULT case xs: Seq[_] => @@ -184,7 +236,7 @@ final class Neo4jDriver( }, s""" |MATCH (n:${node.label} {id: $$nodeId}) - |SET n.$k = $$newV + |SET n.$key = $$newV |""".stripMargin ) } @@ -230,7 +282,7 @@ final class Neo4jDriver( /** This does not add edge properties as they are often not used in the CPG. */ - private def bulkCreateEdge(ops: Seq[Change.CreateEdge], dg: AppliedDiffGraph): Unit = { + private def bulkCreateEdge(ops: Seq[Change.CreateEdge], dg: AppliedDiffGraph): Unit = Using.resource(driver.session()) { session => Using.resource(session.beginTransaction()) { tx => ops.foreach { case Change.CreateEdge(src, dst, label, _) => @@ -255,7 +307,34 @@ final class Neo4jDriver( tx.commit() } } - } + + private def bulkCreateEdge(ops: Seq[BatchedUpdate.CreateEdge]): Unit = + Using.resource(driver.session()) { session => + Using.resource(session.beginTransaction()) { tx => + ops.foreach { c: BatchedUpdate.CreateEdge => + val srcLabel = labelFromNodeData(c.src) + val dstLabel = labelFromNodeData(c.dst) + val query = s""" + |MATCH (src:$srcLabel {id: $$srcId}), (dst:$dstLabel {id: $$dstId}) + |CREATE (src)-[:${c.label}]->(dst) + |""".stripMargin + Try( + tx.run( + query, + new util.HashMap[String, Object](2) { + put("srcId", idFromNodeData(c.src).asInstanceOf[Object]) + put("dstId", idFromNodeData(c.dst).asInstanceOf[Object]) + } + ) + ) match { + case Failure(e) => + logger.error(s"Unable to write bulk create edge transaction $query", e) + case Success(_) => + } + } + tx.commit() + } + } override def bulkTx(dg: AppliedDiffGraph): Unit = { // Node operations @@ -282,6 +361,25 @@ final class Neo4jDriver( .foreach(bulkCreateEdge(_, dg)) } + override def bulkTx(dg: AppliedDiff): Unit = { + dg.diffGraph.iterator.asScala + .collect { case x: DetachedNodeData => x } + .grouped(txMax) + .foreach(bulkCreateNode) + dg.diffGraph.iterator.asScala + .collect { case x: BatchedUpdate.SetNodeProperty => x } + .grouped(txMax) + .foreach(bulkNodeSetProperty) + dg.diffGraph.iterator.asScala + .collect { case x: BatchedUpdate.RemoveNode => x } + .grouped(txMax) + .foreach(bulkDeleteNode) + dg.diffGraph.iterator.asScala + .collect { case x: BatchedUpdate.CreateEdge => x } + .grouped(txMax) + .foreach(bulkCreateEdge) + } + /** Removes the namespace block with all of its AST children specified by the given FILENAME property. */ private def deleteNamespaceBlockWithAstChildrenByFilename(filename: String): Unit = diff --git a/src/main/scala/com/github/plume/oss/drivers/OverflowDbDriver.scala b/src/main/scala/com/github/plume/oss/drivers/OverflowDbDriver.scala index 89a96951..1ef77263 100644 --- a/src/main/scala/com/github/plume/oss/drivers/OverflowDbDriver.scala +++ b/src/main/scala/com/github/plume/oss/drivers/OverflowDbDriver.scala @@ -9,6 +9,7 @@ import com.github.plume.oss.domain.{ } import com.github.plume.oss.drivers.OverflowDbDriver.newOverflowGraph import com.github.plume.oss.passes.PlumeDynamicCallLinker +import com.github.plume.oss.util.BatchedUpdateUtil._ import io.joern.dataflowengineoss.language.toExtendedCfgNode import io.joern.dataflowengineoss.queryengine._ import io.joern.dataflowengineoss.semanticsloader.{Parser, Semantics} @@ -18,8 +19,9 @@ import io.shiftleft.codepropertygraph.{Cpg => CPG} import io.shiftleft.passes.AppliedDiffGraph import io.shiftleft.passes.DiffGraph.{Change, PackedProperties} import org.slf4j.LoggerFactory +import overflowdb.BatchedUpdate.AppliedDiff import overflowdb.traversal.{Traversal, jIteratortoTraversal} -import overflowdb.{Config, Node} +import overflowdb.{BatchedUpdate, Config, DetachedNodeData, Node} import java.io.{File => JFile} import java.nio.file.{Files, Path, Paths} @@ -160,12 +162,12 @@ final case class OverflowDbDriver( case Change.RemoveNode(nodeId) => cpg.graph.node(nodeId).remove() case Change.RemoveNodeProperty(nodeId, propertyKey) => - cpg.graph.nodes(nodeId).next().removeProperty(propertyKey) + cpg.graph.node(nodeId).removeProperty(propertyKey) case Change.CreateNode(node) => val newNode = cpg.graph.addNode(dg.nodeToGraphId(node), node.label) node.properties.foreach { case (k, v) => newNode.setProperty(k, v) } case Change.SetNodeProperty(node, key, value) => - cpg.graph.nodes(node.id()).next().setProperty(key, value) + cpg.graph.node(node.id()).setProperty(key, value) case _ => // do nothing } // Now that all nodes are in, connect/remove edges @@ -180,7 +182,7 @@ final case class OverflowDbDriver( val srcId: Long = id(src, dg).asInstanceOf[Long] val dstId: Long = id(dst, dg).asInstanceOf[Long] val e: overflowdb.Edge = - cpg.graph.nodes(srcId).next().addEdge(label, cpg.graph.nodes(dstId).next()) + cpg.graph.node(srcId).addEdge(label, cpg.graph.node(dstId)) PackedProperties.unpack(packedProperties).foreach { case (k: String, v: Any) => e.setProperty(k, v) } @@ -188,6 +190,27 @@ final case class OverflowDbDriver( } } + override def bulkTx(dg: AppliedDiff): Unit = { + dg.getDiffGraph.iterator.forEachRemaining { + case node: DetachedNodeData => + val id = idFromNodeData(node) + val newNode = cpg.graph.addNode(id, node.label) + propertiesFromNodeData(node).foreach { case (k, v) => newNode.setProperty(k, v) } + case c: BatchedUpdate.CreateEdge => + val srcId = idFromNodeData(c.src) + val dstId = idFromNodeData(c.dst) + val e = cpg.graph.node(srcId).addEdge(c.label, cpg.graph.node(dstId)) + Option(c.propertiesAndKeys) match { + case Some(edgeKeyValues) => + propertiesFromObjectArray(edgeKeyValues).foreach { case (k, v) => e.setProperty(k, v) } + case None => + } + case c: BatchedUpdate.RemoveNode => cpg.graph.node(c.node.id()).remove() + case c: BatchedUpdate.SetNodeProperty => + cpg.graph.node(c.node.id()).setProperty(c.label, c.value) + } + } + private def dfsDelete( n: Node, visitedNodes: mutable.Set[Node], @@ -285,7 +308,7 @@ final case class OverflowDbDriver( .foreach { c: Call => methodFullNameToNode.get(c.methodFullName) match { case Some(dstId) if cpg.graph.nodes(dstId.asInstanceOf[Long]).hasNext => - c.addEdge(EdgeTypes.CALL, cpg.graph.nodes(dstId.asInstanceOf[Long]).next()) + c.addEdge(EdgeTypes.CALL, cpg.graph.node(dstId.asInstanceOf[Long])) case _ => } } diff --git a/src/main/scala/com/github/plume/oss/drivers/TigerGraphDriver.scala b/src/main/scala/com/github/plume/oss/drivers/TigerGraphDriver.scala index 51e54509..69a5103f 100644 --- a/src/main/scala/com/github/plume/oss/drivers/TigerGraphDriver.scala +++ b/src/main/scala/com/github/plume/oss/drivers/TigerGraphDriver.scala @@ -2,6 +2,7 @@ package com.github.plume.oss.drivers import com.github.plume.oss.domain.TigerGraphResponse import com.github.plume.oss.drivers.TigerGraphDriver._ +import com.github.plume.oss.util.BatchedUpdateUtil._ import io.circe import io.circe.generic.auto._ import io.circe.syntax._ @@ -11,6 +12,8 @@ import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, PropertyN import io.shiftleft.passes.AppliedDiffGraph import io.shiftleft.passes.DiffGraph.Change import org.slf4j.LoggerFactory +import overflowdb.BatchedUpdate.AppliedDiff +import overflowdb.{BatchedUpdate, DetachedNodeData} import sttp.client3._ import sttp.client3.circe._ import sttp.model.{MediaType, Uri} @@ -19,7 +22,7 @@ import java.io.{ByteArrayOutputStream, IOException, PrintStream} import java.security.Permission import scala.collection.mutable import scala.concurrent.duration.{Duration, DurationInt} -import scala.jdk.CollectionConverters.CollectionHasAsScala +import scala.jdk.CollectionConverters.{CollectionHasAsScala, IteratorHasAsScala} import scala.util.{Failure, Success, Try} /** The driver used to communicate to a remote TigerGraph instance. One must build a schema on the first use of the database. @@ -120,8 +123,8 @@ final class TigerGraphDriver( } } - private def nodePayload(id: Long, n: NewNode): JsonObject = { - val attributes = n.properties.flatMap { case (k, v) => + private def nodePayload(id: Long, label: String, properties: Map[String, Any]): JsonObject = { + val attributes = properties.flatMap { case (k, v) => val vStr = v match { case xs: Seq[_] => xs.mkString(",") case x if x == null => IDriver.getPropertyDefault(k) @@ -135,7 +138,7 @@ final class TigerGraphDriver( JsonObject.fromMap( Map( - s"${n.label}_" -> JsonObject + s"${label}_" -> JsonObject .fromMap(Map(id.toString -> JsonObject.fromMap(attributes).asJson)) .asJson ) @@ -182,10 +185,32 @@ final class TigerGraphDriver( .foreach(bulkCreateEdge(_, dg)) } - private def bulkDeleteNode(ops: Seq[Change.RemoveNode]): Unit = { + override def bulkTx(dg: AppliedDiff): Unit = { + // Node operations + dg.getDiffGraph.iterator.asScala + .collect { case x: BatchedUpdate.RemoveNode => x } + .grouped(txMax) + .foreach(bulkDeleteNode) + dg.diffGraph.iterator.asScala + .collect { case x: DetachedNodeData => x } + .grouped(txMax) + .foreach(bulkCreateNode) + dg.diffGraph.iterator.asScala + .collect { case x: BatchedUpdate.SetNodeProperty => x } + .grouped(txMax) + .foreach(bulkNodeSetProperty) + // Edge operations + dg.diffGraph.iterator.asScala + .collect { case x: BatchedUpdate.CreateEdge => x } + .grouped(txMax) + .foreach(bulkCreateEdge) + } + + private def bulkDeleteNode(ops: Seq[Any]): Unit = { val ids = ops.flatMap { - case Change.RemoveNode(nodeId) => Some(nodeId) - case _ => None + case c: Change.RemoveNode => Some(c.nodeId) + case c: BatchedUpdate.RemoveNode => Some(c.node.id()) + case _ => None } get("query/cpg/v_delete", ids.map { i => "ids" -> i.toString }) } @@ -193,17 +218,33 @@ final class TigerGraphDriver( private def bulkCreateNode(ops: Seq[Change.CreateNode], dg: AppliedDiffGraph): Unit = { val payload = ops .flatMap { - case Change.CreateNode(node) => Some(nodePayload(id(node, dg).asInstanceOf[Long], node)) - case _ => None + case Change.CreateNode(node) => + Some(nodePayload(id(node, dg).asInstanceOf[Long], node.label(), node.properties)) + case _ => None + } + .reduce { (a: JsonObject, b: JsonObject) => a.deepMerge(b) } + post("graph/cpg", PayloadBody(vertices = payload)) + } + + private def bulkCreateNode(ops: Seq[DetachedNodeData]): Unit = { + val payload = ops + .flatMap { + case c: DetachedNodeData => + Some(nodePayload(idFromNodeData(c), c.label(), propertiesFromNodeData(c))) + case _ => None } .reduce { (a: JsonObject, b: JsonObject) => a.deepMerge(b) } post("graph/cpg", PayloadBody(vertices = payload)) } - private def bulkNodeSetProperty(ops: Seq[Change.SetNodeProperty]): Unit = { + private def bulkNodeSetProperty(ops: Seq[Any]): Unit = { val payload = ops + .collect { + case c: BatchedUpdate.SetNodeProperty => (c.label, c.value, c.node) + case c: Change.SetNodeProperty => (c.key, c.value, c.node) + } .flatMap { - case Change.SetNodeProperty(n: StoredNode, key, value) => + case (key, value, n: StoredNode) => jsonValue(value) match { case Some(v) => val kv = @@ -249,6 +290,25 @@ final class TigerGraphDriver( post("graph/cpg", PayloadBody(edges = payload)) } + private def bulkCreateEdge(ops: Seq[BatchedUpdate.CreateEdge]): Unit = { + val payload = ops + .flatMap { + case c: BatchedUpdate.CreateEdge => + Some( + edgePayload( + idFromNodeData(c.src), + labelFromNodeData(c.src), + idFromNodeData(c.dst), + labelFromNodeData(c.dst), + c.label + ) + ) + case _ => None + } + .reduce { (a: JsonObject, b: JsonObject) => a.deepMerge(b) } + post("graph/cpg", PayloadBody(edges = payload)) + } + override def removeSourceFiles(filenames: String*): Unit = get("query/cpg/delete_source_file", filenames.map(("filenames", _))) diff --git a/src/main/scala/com/github/plume/oss/passes/PlumeCpgPass.scala b/src/main/scala/com/github/plume/oss/passes/PlumeCpgPass.scala index d001a284..b2e52955 100644 --- a/src/main/scala/com/github/plume/oss/passes/PlumeCpgPass.scala +++ b/src/main/scala/com/github/plume/oss/passes/PlumeCpgPass.scala @@ -1,11 +1,16 @@ package com.github.plume.oss.passes import com.github.plume.oss.drivers.IDriver +import com.github.plume.oss.passes.forkjoin.PlumeForkJoinParallelCpgPass.{ + DiffGraphBuilder, + forkJoinSerializeAndStore +} +import com.github.plume.oss.util.BatchedUpdateUtil import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.PropertyNames -import io.shiftleft.codepropertygraph.generated.nodes.{AbstractNode, NewNode, StoredNode} +import io.shiftleft.codepropertygraph.generated.nodes.{AbstractNode, Method, NewNode, StoredNode} import io.shiftleft.passes.DiffGraph.Change -import io.shiftleft.passes.{DiffGraph, KeyPool} +import io.shiftleft.passes.{DiffGraph, ForkJoinParallelCpgPass, KeyPool} import io.shiftleft.semanticcpg.passes.base.{ FileCreationPass, MethodDecoratorPass, @@ -13,6 +18,22 @@ import io.shiftleft.semanticcpg.passes.base.{ TypeDeclStubCreator } import io.shiftleft.semanticcpg.passes.frontend.{MetaDataPass, TypeNodePass} +import overflowdb.traversal.jIteratortoTraversal +import overflowdb.{BatchedUpdate, DetachedNodeData, DetachedNodeGeneric, Node, NodeOrDetachedNode} + +abstract class PlumeSimpleCpgPass(cpg: Cpg, outName: String = "", keyPool: Option[KeyPool] = None) + extends ForkJoinParallelCpgPass[AnyRef](cpg, outName, keyPool) { + + def run(builder: overflowdb.BatchedUpdate.DiffGraphBuilder): Unit + + final override def generateParts(): Array[_ <: AnyRef] = Array[AnyRef](null) + + final override def runOnPart( + builder: overflowdb.BatchedUpdate.DiffGraphBuilder, + part: AnyRef + ): Unit = + run(builder) +} class PlumeMetaDataPass( cpg: Cpg, @@ -22,13 +43,25 @@ class PlumeMetaDataPass( ) extends MetaDataPass(cpg, language, keyPool) with PlumeCpgPassBase { - override def createAndApply(driver: IDriver): Unit = { - if (blacklist.isEmpty) { - withStartEndTimesLogged { - run() - .map(diffGraph => DiffGraph.Applier.applyDiff(diffGraph, cpg, undoable = false, keyPool)) - .foreach(driver.bulkTx) - } + def createAndApply(driver: IDriver): Unit = { + if (blacklist.isEmpty) // If not empty then do not generate duplicate meta data nodes + createApplySerializeAndStore(driver) // Apply to driver + } + + def createApplySerializeAndStore(driver: IDriver): Unit = { + try { + init() + forkJoinSerializeAndStore( + driver, + name, + cpg, + baseLogger, + generateParts(), + (builder: DiffGraphBuilder, part: Method) => runOnPart(builder, part), + keyPool + ) + } finally { + finish() } } @@ -38,12 +71,26 @@ class PlumeNamespaceCreator(cpg: Cpg, keyPool: Option[KeyPool], blacklist: Set[S extends NamespaceCreator(cpg) with PlumeCpgPassBase { - override def createAndApply(driver: IDriver): Unit = { - withStartEndTimesLogged { - run() - .map(dg => PlumeCpgPass.filterDiffGraph(dg, PropertyNames.NAME, blacklist)) - .map(diffGraph => DiffGraph.Applier.applyDiff(diffGraph, cpg, undoable = false, keyPool)) - .foreach(driver.bulkTx) + def createAndApply(driver: IDriver): Unit = { + createApplySerializeAndStore(driver) // Apply to driver + } + + def createApplySerializeAndStore(driver: IDriver): Unit = { + try { + init() + forkJoinSerializeAndStore( + driver, + name, + cpg, + baseLogger, + generateParts(), + (builder: DiffGraphBuilder, part: Method) => runOnPart(builder, part), + keyPool, + blacklist, + PropertyNames.NAME + ) + } finally { + finish() } } @@ -53,11 +100,24 @@ class PlumeFileCreationPass(cpg: Cpg, keyPool: Option[KeyPool]) extends FileCreationPass(cpg) with PlumeCpgPassBase { - override def createAndApply(driver: IDriver): Unit = { - withStartEndTimesLogged { - run() - .map(diffGraph => DiffGraph.Applier.applyDiff(diffGraph, cpg, undoable = false, keyPool)) - .foreach(driver.bulkTx) + def createAndApply(driver: IDriver): Unit = { + createApplySerializeAndStore(driver) // Apply to driver + } + + def createApplySerializeAndStore(driver: IDriver): Unit = { + try { + init() + forkJoinSerializeAndStore( + driver, + name, + cpg, + baseLogger, + generateParts(), + (builder: DiffGraphBuilder, part: Method) => runOnPart(builder, part), + keyPool + ) + } finally { + finish() } } @@ -71,49 +131,122 @@ class PlumeTypeNodePass( ) extends TypeNodePass(usedTypes, cpg, keyPool) with PlumeCpgPassBase { - override def createAndApply(driver: IDriver): Unit = { - withStartEndTimesLogged { - run() - .map(dg => PlumeCpgPass.filterDiffGraph(dg, PropertyNames.FULL_NAME, blacklist)) - .map(diffGraph => DiffGraph.Applier.applyDiff(diffGraph, cpg, undoable = false, keyPool)) - .foreach(driver.bulkTx) + def createAndApply(driver: IDriver): Unit = { + createApplySerializeAndStore(driver) // Apply to driver + } + + def createApplySerializeAndStore(driver: IDriver): Unit = { + try { + init() + forkJoinSerializeAndStore( + driver, + name, + cpg, + baseLogger, + generateParts(), + (builder: DiffGraphBuilder, part: Method) => runOnPart(builder, part), + keyPool, + blacklist, + PropertyNames.FULL_NAME + ) + } finally { + finish() } } + } class PlumeTypeDeclStubCreator(cpg: Cpg, keyPool: Option[KeyPool], blacklist: Set[String] = Set()) extends TypeDeclStubCreator(cpg) with PlumeCpgPassBase { - override def createAndApply(driver: IDriver): Unit = { - withStartEndTimesLogged { - run() - .map(dg => PlumeCpgPass.filterDiffGraph(dg, PropertyNames.FULL_NAME, blacklist)) - .map(diffGraph => DiffGraph.Applier.applyDiff(diffGraph, cpg, undoable = false, keyPool)) - .foreach(driver.bulkTx) + def createAndApply(driver: IDriver): Unit = { + createApplySerializeAndStore(driver) // Apply to driver + } + + def createApplySerializeAndStore(driver: IDriver): Unit = { + try { + init() + forkJoinSerializeAndStore( + driver, + name, + cpg, + baseLogger, + generateParts(), + (builder: DiffGraphBuilder, part: Method) => runOnPart(builder, part), + keyPool, + blacklist, + PropertyNames.FULL_NAME + ) + } finally { + finish() } } + } class PlumeMethodDecoratorPass(cpg: Cpg, keyPool: Option[KeyPool], blacklist: Set[String] = Set()) extends MethodDecoratorPass(cpg) with PlumeCpgPassBase { - override def createAndApply(driver: IDriver): Unit = { - withStartEndTimesLogged { - run() - .map(dg => - PlumeCpgPass - .filterDiffGraph(dg, PropertyNames.TYPE_FULL_NAME, blacklist, rejectAllOnFail = true) - ) - .map(diffGraph => DiffGraph.Applier.applyDiff(diffGraph, cpg, undoable = false, keyPool)) - .foreach(driver.bulkTx) + def createAndApply(driver: IDriver): Unit = { + createApplySerializeAndStore(driver) // Apply to driver + } + + def createApplySerializeAndStore(driver: IDriver): Unit = { + try { + init() + forkJoinSerializeAndStore( + driver, + name, + cpg, + baseLogger, + generateParts(), + (builder: DiffGraphBuilder, part: Method) => runOnPart(builder, part), + keyPool, + blacklist, + PropertyNames.TYPE_FULL_NAME, + blacklistRejectOnFail = true + ) + } finally { + finish() } } + } object PlumeCpgPass { + def filterBatchedDiffGraph( + dg: BatchedUpdate.DiffGraph, + key: String, + blacklist: Set[String], + rejectAllOnFail: Boolean = false + ): BatchedUpdate.DiffGraph = { + val newDg = new DiffGraphBuilder + dg.iterator.foreach { + case c: DetachedNodeData => + val properties = c match { + case generic: DetachedNodeGeneric => + BatchedUpdateUtil.propertiesFromObjectArray(generic.keyvalues) + case node: NewNode => node.properties + case _ => Map.empty[String, Object] + } + if (!blacklist.contains(properties.getOrElse(key, "").toString)) + newDg.addNode(c) + case c: BatchedUpdate.CreateEdge => + val srcProperty = getPropertyFromAbstractNode[String](c.src, key) + val dstProperty = getPropertyFromAbstractNode[String](c.dst, key) + if (!blacklist.contains(srcProperty) && !blacklist.contains(dstProperty)) { + newDg.addEdge(c.src, c.dst, c.label) + } else if (rejectAllOnFail) { + return (new DiffGraphBuilder).build() + } + case _ => + } + newDg.build() + } + def filterDiffGraph( dg: DiffGraph, key: String, @@ -146,4 +279,13 @@ object PlumeCpgPass { } } + private def getPropertyFromAbstractNode[T](node: NodeOrDetachedNode, key: String): T = { + node match { + case generic: DetachedNodeGeneric => + BatchedUpdateUtil.propertiesFromObjectArray(generic.keyvalues)(key).asInstanceOf[T] + case node: NewNode => node.properties(key).asInstanceOf[T] + case node: Node => node.property(key).asInstanceOf[T] + } + } + } diff --git a/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeDiffPass.scala b/src/main/scala/com/github/plume/oss/passes/PlumeDiffPass.scala similarity index 97% rename from src/main/scala/com/github/plume/oss/passes/concurrent/PlumeDiffPass.scala rename to src/main/scala/com/github/plume/oss/passes/PlumeDiffPass.scala index fe467c7b..37ebb0bf 100644 --- a/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeDiffPass.scala +++ b/src/main/scala/com/github/plume/oss/passes/PlumeDiffPass.scala @@ -1,4 +1,4 @@ -package com.github.plume.oss.passes.concurrent +package com.github.plume.oss.passes import com.github.plume.oss.drivers.IDriver import com.github.plume.oss.util.HashUtil diff --git a/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeCfgCreationPass.scala b/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeCfgCreationPass.scala index 6f7b5b1e..e02f3750 100644 --- a/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeCfgCreationPass.scala +++ b/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeCfgCreationPass.scala @@ -5,7 +5,6 @@ import com.github.plume.oss.passes.PlumeCpgPassBase import com.github.plume.oss.passes.concurrent.PlumeConcurrentCpgPass.concurrentCreateApply import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Method -import io.shiftleft.passes.DiffGraph import io.shiftleft.semanticcpg.passes.controlflow.CfgCreationPass object PlumeCfgCreationPass { @@ -20,17 +19,25 @@ class PlumeCfgCreationPass(cpg: Cpg) extends CfgCreationPass(cpg) with PlumeCpgP def createApplySerializeAndStore(driver: IDriver): Unit = { import PlumeCfgCreationPass.producerQueueCapacity - concurrentCreateApply[Method]( - producerQueueCapacity, - driver, - name, - baseLogger, - _ => init(), - _ => generateParts(), - cpg, - (x: DiffGraph.Builder, y: Method) => runOnPart(x, y), - _ => finish() - ) + try { + init() + concurrentCreateApply[Method]( + producerQueueCapacity, + driver, + name, + baseLogger, + generateParts(), + cpg, + (x: DiffGraphBuilder, y: Method) => runOnPart(x, y), + None, + (newDiff: Int) => { + nDiffT = newDiff + nDiffT + } + ) + } finally { + finish() + } } } diff --git a/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeConcurrentCpgPass.scala b/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeConcurrentCpgPass.scala index 69d01180..5743cde9 100644 --- a/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeConcurrentCpgPass.scala +++ b/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeConcurrentCpgPass.scala @@ -4,15 +4,16 @@ import com.github.plume.oss.drivers.IDriver import com.github.plume.oss.passes.concurrent.PlumeConcurrentCpgPass.concurrentCreateApply import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.AstNode -import io.shiftleft.passes.{ConcurrentWriterCpgPass, DiffGraph} -import org.slf4j.Logger +import io.shiftleft.passes.{ConcurrentWriterCpgPass, KeyPool} +import io.shiftleft.utils.ExecutionContextProvider +import org.slf4j.{Logger, MDC} +import overflowdb.BatchedUpdate.{DiffGraph, DiffGraphBuilder} import scala.collection.mutable -import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration.Duration -import scala.concurrent.{Await, Future} +import scala.concurrent.{Await, ExecutionContext, Future} -abstract class PlumeConcurrentCpgPass[T <: AstNode](cpg: Cpg) +abstract class PlumeConcurrentCpgPass[T <: AstNode](cpg: Cpg, keyPool: Option[KeyPool]) extends ConcurrentWriterCpgPass[T](cpg) { override def generateParts(): Array[_ <: AstNode] = Array[AstNode](null) @@ -23,20 +24,28 @@ abstract class PlumeConcurrentCpgPass[T <: AstNode](cpg: Cpg) def createApplySerializeAndStore(driver: IDriver): Unit = { import PlumeConcurrentCpgPass.producerQueueCapacity - concurrentCreateApply[T]( - producerQueueCapacity, - driver, - name, - baseLogger, - _ => init(), - _ => generateParts(), - cpg, - (x: DiffGraph.Builder, y: T) => runOnPart(x, y), - _ => finish() - ) + try { + init() + concurrentCreateApply[T]( + producerQueueCapacity, + driver, + name, + baseLogger, + generateParts(), + cpg, + (x: DiffGraphBuilder, y: T) => runOnPart(x, y), + keyPool, + (newDiff: Int) => { + nDiffT = newDiff + nDiffT + } + ) + } finally { + finish() + } } - override def runOnPart(builder: DiffGraph.Builder, part: T): Unit + override def runOnPart(builder: DiffGraphBuilder, part: T): Unit } object PlumeConcurrentCpgPass { @@ -47,53 +56,54 @@ object PlumeConcurrentCpgPass { driver: IDriver, name: String, baseLogger: Logger, - init: Unit => Unit, - generateParts: Unit => Array[_ <: AstNode], + parts: Array[_ <: AstNode], cpg: Cpg, - runOnPart: (DiffGraph.Builder, T) => Unit, - finish: Unit => Unit + runOnPart: (DiffGraphBuilder, T) => Unit, + keyPool: Option[KeyPool], + setDiff: Int => Int ): Unit = { baseLogger.info(s"Start of enhancement: $name") val nanosStart = System.nanoTime() var nParts = 0 - var nDiff = 0 - - init() - val parts = generateParts() + var nDiff = setDiff(0) + // init is called outside of this method nParts = parts.length val partIter = parts.iterator val completionQueue = mutable.ArrayDeque[Future[DiffGraph]]() - val writer = new PlumeConcurrentWriter(driver, cpg) - val writerThread = new Thread(writer) + val writer = + new PlumeConcurrentWriter(driver, cpg, baseLogger, keyPool, MDC.getCopyOfContextMap, setDiff) + val writerThread = new Thread(writer) writerThread.setName("Writer") writerThread.start() + implicit val ec: ExecutionContext = ExecutionContextProvider.getExecutionContext try { try { var done = false - while (!done) { + while (!done && writer.raisedException == null) { + if (writer.raisedException != null) + throw writer.raisedException + if (completionQueue.size < producerQueueCapacity && partIter.hasNext) { val next = partIter.next() completionQueue.append(Future.apply { - val builder = DiffGraph.newBuilder + val builder = new DiffGraphBuilder runOnPart(builder, next.asInstanceOf[T]) builder.build() }) } else if (completionQueue.nonEmpty) { val future = completionQueue.removeHead() val res = Await.result(future, Duration.Inf) - nDiff += res.size + nDiff = setDiff(nDiff + res.size) writer.queue.put(Some(res)) } else { done = true } } } finally { - try { - writer.queue.put(None) - writerThread.join() - } finally { - finish() - } + if (writer.raisedException == null) writer.queue.put(None) + writerThread.join() + if (writer.raisedException != null) + throw new RuntimeException("Failure in diffgraph application", writer.raisedException) } } finally { val nanosStop = System.nanoTime() diff --git a/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeConcurrentWriter.scala b/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeConcurrentWriter.scala index 62d09c2f..afbc5a5e 100644 --- a/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeConcurrentWriter.scala +++ b/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeConcurrentWriter.scala @@ -2,8 +2,9 @@ package com.github.plume.oss.passes.concurrent import com.github.plume.oss.drivers.IDriver import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.passes.{CpgPass, DiffGraph} -import org.slf4j.{Logger, LoggerFactory} +import io.shiftleft.passes.{CpgPass, KeyPool} +import org.slf4j.{Logger, LoggerFactory, MDC} +import overflowdb.BatchedUpdate.DiffGraph import java.util.concurrent.LinkedBlockingQueue @@ -13,14 +14,21 @@ object PlumeConcurrentWriter { class PlumeConcurrentWriter( driver: IDriver, cpg: Cpg, - baseLogger: Logger = LoggerFactory.getLogger(classOf[CpgPass]) + baseLogger: Logger = LoggerFactory.getLogger(classOf[CpgPass]), + keyPool: Option[KeyPool] = None, + mdc: java.util.Map[String, String], + setDiffT: Int => Int ) extends Runnable { val queue: LinkedBlockingQueue[Option[DiffGraph]] = new LinkedBlockingQueue[Option[DiffGraph]](PlumeConcurrentWriter.writerQueueCapacity) + @volatile var raisedException: Exception = null + override def run(): Unit = { try { + var nDiffT = setDiffT(0) + MDC.setContextMap(mdc) var terminate = false while (!terminate) { queue.take() match { @@ -28,13 +36,22 @@ class PlumeConcurrentWriter( baseLogger.debug("Shutting down WriterThread") terminate = true case Some(diffGraph) => - val appliedDiffGraph = - DiffGraph.Applier.applyDiff(diffGraph, cpg, undoable = false, None) + val appliedDiffGraph = overflowdb.BatchedUpdate + .applyDiff(cpg.graph, diffGraph, keyPool.orNull, null) + + nDiffT = setDiffT( + nDiffT + appliedDiffGraph + .transitiveModifications() + ) driver.bulkTx(appliedDiffGraph) } } } catch { case exception: InterruptedException => baseLogger.warn("Interrupted WriterThread", exception) + case exc: Exception => + raisedException = exc + queue.clear() + throw exc } } } diff --git a/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeContainsEdgePass.scala b/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeContainsEdgePass.scala index ce63c122..15217296 100644 --- a/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeContainsEdgePass.scala +++ b/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeContainsEdgePass.scala @@ -5,7 +5,6 @@ import com.github.plume.oss.passes.PlumeCpgPassBase import com.github.plume.oss.passes.concurrent.PlumeConcurrentCpgPass.concurrentCreateApply import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.AstNode -import io.shiftleft.passes.DiffGraph import io.shiftleft.semanticcpg.passes.base.ContainsEdgePass object PlumeContainsEdgePass { @@ -20,16 +19,24 @@ class PlumeContainsEdgePass(cpg: Cpg) extends ContainsEdgePass(cpg) with PlumeCp def createApplySerializeAndStore(driver: IDriver): Unit = { import PlumeContainsEdgePass.producerQueueCapacity - concurrentCreateApply[AstNode]( - producerQueueCapacity, - driver, - name, - baseLogger, - _ => init(), - _ => generateParts(), - cpg, - (x: DiffGraph.Builder, y: AstNode) => runOnPart(x, y), - _ => finish() - ) + try { + init() + concurrentCreateApply[AstNode]( + producerQueueCapacity, + driver, + name, + baseLogger, + generateParts(), + cpg, + (x: DiffGraphBuilder, y: AstNode) => runOnPart(x, y), + None, + (newDiff: Int) => { + nDiffT = newDiff + nDiffT + } + ) + } finally { + finish() + } } } diff --git a/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeHashPass.scala b/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeHashPass.scala index c3157962..76209026 100644 --- a/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeHashPass.scala +++ b/src/main/scala/com/github/plume/oss/passes/concurrent/PlumeHashPass.scala @@ -7,13 +7,14 @@ import io.shiftleft.codepropertygraph.generated.nodes.File import io.shiftleft.passes.DiffGraph import io.shiftleft.semanticcpg.language._ import org.slf4j.{Logger, LoggerFactory} +import overflowdb.BatchedUpdate.DiffGraphBuilder import java.io.{File => JFile} import scala.util.{Failure, Success, Try} /** Performs hash calculations on the files represented by the FILE nodes. */ -class PlumeHashPass(cpg: Cpg) extends PlumeConcurrentCpgPass[File](cpg) { +class PlumeHashPass(cpg: Cpg) extends PlumeConcurrentCpgPass[File](cpg, None) { import PlumeHashPass._ @@ -25,15 +26,15 @@ class PlumeHashPass(cpg: Cpg) extends PlumeConcurrentCpgPass[File](cpg) { /** Use the information in the given file node to find the local file and store its hash locally. */ - override def runOnPart(diffGraph: DiffGraph.Builder, part: File): Unit = { - val localDiff = DiffGraph.newBuilder + override def runOnPart(diffGraph: DiffGraphBuilder, part: File): Unit = { + val localDiff = new DiffGraphBuilder Try(HashUtil.getFileHash(new JFile(part.name))) match { case Failure(exception) => logger.warn(s"Unable to generate hash for file at ${part.name}", exception) case Success(fileHash) => - localDiff.addNodeProperty(part, PropertyNames.HASH, fileHash) + localDiff.setNodeProperty(part, PropertyNames.HASH, fileHash) } - diffGraph.moveFrom(localDiff) + diffGraph.absorb(localDiff) } } diff --git a/src/main/scala/com/github/plume/oss/passes/forkjoin/PlumeCdgPass.scala b/src/main/scala/com/github/plume/oss/passes/forkjoin/PlumeCdgPass.scala new file mode 100644 index 00000000..5eb6cd8d --- /dev/null +++ b/src/main/scala/com/github/plume/oss/passes/forkjoin/PlumeCdgPass.scala @@ -0,0 +1,41 @@ +package com.github.plume.oss.passes.forkjoin + +import com.github.plume.oss.drivers.IDriver +import com.github.plume.oss.passes.PlumeCpgPassBase +import com.github.plume.oss.passes.forkjoin.PlumeForkJoinParallelCpgPass.forkJoinSerializeAndStore +import com.github.plume.oss.passes.parallel.PlumeParallelCpgPass.{ + parallelEnqueue, + parallelWithWriter +} +import com.github.plume.oss.passes.parallel.PlumeParallelWriter +import io.shiftleft.codepropertygraph.Cpg +import io.shiftleft.codepropertygraph.generated.nodes.Method +import io.shiftleft.passes.KeyPool +import io.shiftleft.semanticcpg.passes.controlflow.codepencegraph.CdgPass + +class PlumeCdgPass(cpg: Cpg, keyPool: Option[KeyPool] = None) + extends CdgPass(cpg) + with PlumeCpgPassBase { + + def createAndApply(driver: IDriver): Unit = { + createApplySerializeAndStore(driver) // Apply to driver + } + + def createApplySerializeAndStore(driver: IDriver): Unit = { + try { + init() + forkJoinSerializeAndStore( + driver, + name, + cpg, + baseLogger, + generateParts(), + (builder: DiffGraphBuilder, part: Method) => runOnPart(builder, part), + keyPool + ) + } finally { + finish() + } + } + +} diff --git a/src/main/scala/com/github/plume/oss/passes/forkjoin/PlumeCfgDominatorPass.scala b/src/main/scala/com/github/plume/oss/passes/forkjoin/PlumeCfgDominatorPass.scala new file mode 100644 index 00000000..4dc52652 --- /dev/null +++ b/src/main/scala/com/github/plume/oss/passes/forkjoin/PlumeCfgDominatorPass.scala @@ -0,0 +1,41 @@ +package com.github.plume.oss.passes.forkjoin + +import com.github.plume.oss.drivers.IDriver +import com.github.plume.oss.passes.PlumeCpgPassBase +import com.github.plume.oss.passes.forkjoin.PlumeForkJoinParallelCpgPass.forkJoinSerializeAndStore +import com.github.plume.oss.passes.parallel.PlumeParallelCpgPass.{ + parallelEnqueue, + parallelWithWriter +} +import com.github.plume.oss.passes.parallel.PlumeParallelWriter +import io.shiftleft.codepropertygraph.Cpg +import io.shiftleft.codepropertygraph.generated.nodes.Method +import io.shiftleft.passes.KeyPool +import io.shiftleft.semanticcpg.passes.controlflow.cfgdominator.CfgDominatorPass + +class PlumeCfgDominatorPass(cpg: Cpg, keyPool: Option[KeyPool] = None) + extends CfgDominatorPass(cpg) + with PlumeCpgPassBase { + + def createAndApply(driver: IDriver): Unit = { + createApplySerializeAndStore(driver) // Apply to driver + } + + def createApplySerializeAndStore(driver: IDriver): Unit = { + try { + init() + forkJoinSerializeAndStore( + driver, + name, + cpg, + baseLogger, + generateParts(), + (builder: DiffGraphBuilder, part: Method) => runOnPart(builder, part), + keyPool + ) + } finally { + finish() + } + } + +} diff --git a/src/main/scala/com/github/plume/oss/passes/forkjoin/PlumeForkJoinParallelCpgPass.scala b/src/main/scala/com/github/plume/oss/passes/forkjoin/PlumeForkJoinParallelCpgPass.scala new file mode 100644 index 00000000..bdefc8ed --- /dev/null +++ b/src/main/scala/com/github/plume/oss/passes/forkjoin/PlumeForkJoinParallelCpgPass.scala @@ -0,0 +1,118 @@ +package com.github.plume.oss.passes.forkjoin + +import com.github.plume.oss.drivers.IDriver +import com.github.plume.oss.passes.PlumeCpgPass +import com.github.plume.oss.passes.forkjoin.PlumeForkJoinParallelCpgPass.forkJoinSerializeAndStore +import io.shiftleft.codepropertygraph.Cpg +import io.shiftleft.passes.{ForkJoinParallelCpgPass, KeyPool} +import org.slf4j.Logger + +import java.util.function.{BiConsumer, Supplier} + +object PlumeForkJoinParallelCpgPass { + + type DiffGraphBuilder = overflowdb.BatchedUpdate.DiffGraphBuilder + + def forkJoinSerializeAndStore[T]( + driver: IDriver, + name: String, + cpg: Cpg, + baseLogger: Logger, + parts: Array[_ <: AnyRef], + runOnPart: (DiffGraphBuilder, T) => Unit, + keyPool: Option[KeyPool], + blacklist: Set[String] = Set.empty, + blacklistProperty: String = "", + blacklistRejectOnFail: Boolean = false + ): Unit = { + baseLogger.info(s"Start of pass: $name") + val nanosStart = System.nanoTime() + var nParts = 0 + var nanosBuilt = -1L + var nDiff = -1 + var nDiffT = -1 + try { + nParts = parts.length + val diffGraph = nParts match { + case 0 => (new DiffGraphBuilder).build() + case 1 => + val builder = new DiffGraphBuilder + runOnPart(builder, parts(0).asInstanceOf[T]) + builder.build() + case _ => + java.util.Arrays + .stream(parts) + .parallel() + .collect( + new Supplier[DiffGraphBuilder] { + override def get(): DiffGraphBuilder = + new DiffGraphBuilder + }, + new BiConsumer[DiffGraphBuilder, AnyRef] { + override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit = + runOnPart(builder, part.asInstanceOf[T]) + }, + new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] { + override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder) + : Unit = + leftBuilder.absorb(rightBuilder) + } + ) + .build() + } + nanosBuilt = System.nanoTime() + nDiff = diffGraph.size() + val diffToCommit = + if (blacklist.nonEmpty) + PlumeCpgPass + .filterBatchedDiffGraph(diffGraph, blacklistProperty, blacklist, blacklistRejectOnFail) + else + diffGraph + + val appliedDiffGraph = overflowdb.BatchedUpdate + .applyDiff(cpg.graph, diffToCommit, keyPool.orNull, null) + driver.bulkTx(appliedDiffGraph) + nDiffT = appliedDiffGraph.transitiveModifications() + + } catch { + case exc: Exception => + baseLogger.error(s"Pass ${name} failed", exc) + throw exc + } finally { + val nanosStop = System.nanoTime() + val fracRun = + if (nanosBuilt == -1) 100.0 + else (nanosBuilt - nanosStart) * 100.0 / (nanosStop - nanosStart + 1) + baseLogger.info( + f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms (${fracRun}%.0f%% on mutations). ${nDiff}%d + ${nDiffT - nDiff}%d changes commited from ${nParts}%d parts." + ) + } + } + +} + +abstract class PlumeForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, keyPool: Option[KeyPool] = None) + extends ForkJoinParallelCpgPass[T](cpg, keyPool = keyPool) { + + def createAndApply(driver: IDriver): Unit = { + createApplySerializeAndStore(driver) // Apply to driver + } + + def createApplySerializeAndStore(driver: IDriver): Unit = { + try { + init() + forkJoinSerializeAndStore( + driver, + name, + cpg, + baseLogger, + generateParts(), + (builder: DiffGraphBuilder, part: T) => runOnPart(builder, part), + keyPool + ) + } finally { + finish() + } + } + +} diff --git a/src/main/scala/com/github/plume/oss/passes/parallel/PlumeAstCreator.scala b/src/main/scala/com/github/plume/oss/passes/parallel/PlumeAstCreator.scala index e090b5dd..a5782760 100644 --- a/src/main/scala/com/github/plume/oss/passes/parallel/PlumeAstCreator.scala +++ b/src/main/scala/com/github/plume/oss/passes/parallel/PlumeAstCreator.scala @@ -3,7 +3,7 @@ package com.github.plume.oss.passes.parallel import io.shiftleft.codepropertygraph.generated.nodes._ import io.shiftleft.codepropertygraph.generated._ import io.shiftleft.passes.DiffGraph -import io.shiftleft.x2cpg.Ast +import io.joern.x2cpg.Ast import org.slf4j.LoggerFactory import soot.jimple._ import soot.tagkit.Host @@ -378,7 +378,7 @@ class PlumeAstCreator(filename: String, global: Global) { } private def astForInvokeExpr(invokeExpr: InvokeExpr, order: Int, parentUnit: soot.Unit): Ast = { - val method = invokeExpr.getMethod + val method = invokeExpr.getMethodRef val dispatchType = invokeExpr match { case _ if method.isConstructor => DispatchTypes.STATIC_DISPATCH case _: DynamicInvokeExpr => DispatchTypes.DYNAMIC_DISPATCH @@ -386,7 +386,7 @@ class PlumeAstCreator(filename: String, global: Global) { case _ => DispatchTypes.STATIC_DISPATCH } val signature = - s"${method.getReturnType.toQuotedString}(${(for (i <- 0 until method.getParameterCount) + s"${method.getReturnType.toQuotedString}(${(for (i <- 0 until method.getParameterTypes().size()) yield method.getParameterType(i).toQuotedString).mkString(",")})" val thisAsts = Seq(createThisNode(method, NewIdentifier())) @@ -506,7 +506,10 @@ class PlumeAstCreator(filename: String, global: Global) { ) } - private def createThisNode(method: SootMethod, builder: NewNode): Ast = { + private def createThisNode(method: SootMethod, builder: NewNode): Ast = + createThisNode(method.makeRef(), builder) + + private def createThisNode(method: SootMethodRef, builder: NewNode): Ast = { if (!method.isStatic || method.isConstructor) { val parentType = registerType(method.getDeclaringClass.getType.toQuotedString) Ast( @@ -521,7 +524,7 @@ class PlumeAstCreator(filename: String, global: Global) { case x: NewMethodParameterIn => x.name("this") .code("this") - .lineNumber(line(method)) + .lineNumber(line(method.tryResolve())) .typeFullName(parentType) .order(0) .evaluationStrategy(EvaluationStrategies.BY_SHARING) @@ -911,11 +914,15 @@ class PlumeAstCreator(filename: String, global: Global) { object PlumeAstCreator { def line(node: Host): Option[Integer] = { - Option(node.getJavaSourceStartLineNumber) + if (node == null) None + else if (node.getJavaSourceStartLineNumber == -1) None + else Option(node.getJavaSourceStartLineNumber) } def column(node: Host): Option[Integer] = { - Option(node.getJavaSourceStartColumnNumber) + if (node == null) None + else if (node.getJavaSourceStartColumnNumber == -1) None + else Option(node.getJavaSourceStartColumnNumber) } def withOrder[T, X](nodeList: java.util.List[T])(f: (T, Int) => X): Seq[X] = { diff --git a/src/main/scala/com/github/plume/oss/passes/parallel/PlumeCdgPass.scala b/src/main/scala/com/github/plume/oss/passes/parallel/PlumeCdgPass.scala deleted file mode 100644 index 6df0f348..00000000 --- a/src/main/scala/com/github/plume/oss/passes/parallel/PlumeCdgPass.scala +++ /dev/null @@ -1,40 +0,0 @@ -package com.github.plume.oss.passes.parallel - -import com.github.plume.oss.drivers.IDriver -import com.github.plume.oss.passes.PlumeCpgPassBase -import com.github.plume.oss.passes.parallel.PlumeParallelCpgPass.{ - parallelEnqueue, - parallelWithWriter -} -import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes.Method -import io.shiftleft.passes.KeyPool -import io.shiftleft.semanticcpg.passes.controlflow.codepencegraph.CdgPass - -class PlumeCdgPass(cpg: Cpg, keyPools: Option[Iterator[KeyPool]] = None) - extends CdgPass(cpg) - with PlumeCpgPassBase { - - override def createAndApply(driver: IDriver): Unit = { - withWriter(driver) { writer => - enqueueInParallel(writer) - } - } - - private def withWriter[X](driver: IDriver)(f: PlumeParallelWriter => Unit): Unit = - parallelWithWriter[X](driver, f, cpg, baseLogger) - - private def enqueueInParallel(writer: PlumeParallelWriter): Unit = - withStartEndTimesLogged { - init() - parallelEnqueue[Method]( - baseLogger, - name, - writer, - (part: Method) => runOnPart(part), - keyPools, - partIterator - ) - } - -} diff --git a/src/main/scala/com/github/plume/oss/passes/parallel/PlumeCfgDominatorPass.scala b/src/main/scala/com/github/plume/oss/passes/parallel/PlumeCfgDominatorPass.scala deleted file mode 100644 index c3a11528..00000000 --- a/src/main/scala/com/github/plume/oss/passes/parallel/PlumeCfgDominatorPass.scala +++ /dev/null @@ -1,40 +0,0 @@ -package com.github.plume.oss.passes.parallel - -import com.github.plume.oss.drivers.IDriver -import com.github.plume.oss.passes.PlumeCpgPassBase -import com.github.plume.oss.passes.parallel.PlumeParallelCpgPass.{ - parallelEnqueue, - parallelWithWriter -} -import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes.Method -import io.shiftleft.passes.KeyPool -import io.shiftleft.semanticcpg.passes.controlflow.cfgdominator.CfgDominatorPass - -class PlumeCfgDominatorPass(cpg: Cpg, keyPools: Option[Iterator[KeyPool]] = None) - extends CfgDominatorPass(cpg) - with PlumeCpgPassBase { - - override def createAndApply(driver: IDriver): Unit = { - withWriter(driver) { writer => - enqueueInParallel(writer) - } - } - - private def withWriter[X](driver: IDriver)(f: PlumeParallelWriter => Unit): Unit = - parallelWithWriter(driver, f, cpg, baseLogger) - - private def enqueueInParallel(writer: PlumeParallelWriter): Unit = - withStartEndTimesLogged { - init() - parallelEnqueue[Method]( - baseLogger, - name, - writer, - (part: Method) => runOnPart(part), - keyPools, - partIterator - ) - } - -} diff --git a/src/main/scala/com/github/plume/oss/passes/parallel/PlumeMethodStubCreator.scala b/src/main/scala/com/github/plume/oss/passes/parallel/PlumeMethodStubCreator.scala index 1ef3c1f3..d6ea4593 100644 --- a/src/main/scala/com/github/plume/oss/passes/parallel/PlumeMethodStubCreator.scala +++ b/src/main/scala/com/github/plume/oss/passes/parallel/PlumeMethodStubCreator.scala @@ -1,63 +1,149 @@ package com.github.plume.oss.passes.parallel import com.github.plume.oss.drivers.IDriver -import com.github.plume.oss.passes.parallel.PlumeParallelCpgPass.{ - parallelEnqueue, - parallelItWithKeyPools, - parallelWithWriter -} -import com.github.plume.oss.passes.{IncrementalKeyPool, PlumeCpgPassBase} +import com.github.plume.oss.passes.forkjoin.PlumeForkJoinParallelCpgPass.forkJoinSerializeAndStore +import com.github.plume.oss.passes.{IncrementalKeyPool, PlumeCpgPassBase, PlumeSimpleCpgPass} import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.passes.{DiffGraph, KeyPool, ParallelIteratorExecutor} -import io.shiftleft.semanticcpg.passes.base.{MethodStubCreator, NameAndSignature} +import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.{EdgeTypes, EvaluationStrategies, NodeTypes} +import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.passes.base.NameAndSignature +import overflowdb.BatchedUpdate + +import scala.collection.mutable +import scala.util.Try class PlumeMethodStubCreator( cpg: Cpg, keyPool: Option[IncrementalKeyPool], blacklist: Set[String] = Set() -) extends MethodStubCreator(cpg) +) extends PlumeSimpleCpgPass(cpg, keyPool = keyPool) with PlumeCpgPassBase { - var keyPools: Option[Iterator[KeyPool]] = None + // Since the method fullNames for fuzzyc are not unique, we do not have + // a 1to1 relation and may overwrite some values. This is ok for now. + private val methodFullNameToNode = mutable.LinkedHashMap[String, MethodBase]() + private val methodToParameterCount = mutable.LinkedHashMap[NameAndSignature, Int]() + + private def filter(name: NameAndSignature): Boolean = { + val methodTypeName = name.fullName.replace(s".${name.name}:${name.signature}", "") + !(blacklist.contains(methodTypeName) || (blacklist.nonEmpty && methodTypeName == "")) + } - override def init(): Unit = { - super.init() - keyPool match { - case Some(value) => keyPools = Option(value.split(partIterator.size)) - case None => + override def run(dstGraph: BatchedUpdate.DiffGraphBuilder): Unit = { + for (method <- cpg.method) { + methodFullNameToNode.put(method.fullName, method) + } + + for (call <- cpg.call) { + methodToParameterCount.put( + NameAndSignature(call.name, call.signature, call.methodFullName), + call.argument.size + ) } + + for ( + (NameAndSignature(name, signature, fullName), parameterCount) <- methodToParameterCount + if !methodFullNameToNode.contains(fullName) + ) { + if (filter(NameAndSignature(name, signature, fullName))) + createMethodStub(name, fullName, signature, parameterCount, dstGraph) + } + } + + override def finish(): Unit = { + methodFullNameToNode.clear() + methodToParameterCount.clear() + super.finish() } - // Do not create stubs for methods that exist - override def runOnPart(part: (NameAndSignature, Int)): Iterator[DiffGraph] = { - val methodTypeName = part._1.fullName.replace(s".${part._1.name}:${part._1.signature}", "") - if (blacklist.contains(methodTypeName) || (blacklist.nonEmpty && methodTypeName == "")) { - Iterator() + private def addLineNumberInfo(methodNode: NewMethod, fullName: String): NewMethod = { + val s = fullName.split(":") + if (s.size == 5 && Try { s(1).toInt }.isSuccess && Try { s(2).toInt }.isSuccess) { + val filename = s(0) + val lineNumber = s(1).toInt + val lineNumberEnd = s(2).toInt + methodNode + .filename(filename) + .lineNumber(lineNumber) + .lineNumberEnd(lineNumberEnd) } else { - super.runOnPart(part) + methodNode } } - override def createAndApply(driver: IDriver): Unit = { - withWriter(driver) { writer => - enqueueInParallel(writer) + private def createMethodStub( + name: String, + fullName: String, + signature: String, + parameterCount: Int, + dstGraph: DiffGraphBuilder + ): MethodBase = { + + val methodNode = addLineNumberInfo( + NewMethod() + .name(name) + .fullName(fullName) + .isExternal(true) + .signature(signature) + .astParentType(NodeTypes.NAMESPACE_BLOCK) + .astParentFullName("") + .order(0), + fullName + ) + + dstGraph.addNode(methodNode) + + (1 to parameterCount).foreach { parameterOrder => + val nameAndCode = s"p$parameterOrder" + val param = NewMethodParameterIn() + .code(nameAndCode) + .order(parameterOrder) + .name(nameAndCode) + .evaluationStrategy(EvaluationStrategies.BY_VALUE) + .typeFullName("ANY") + + dstGraph.addNode(param) + dstGraph.addEdge(methodNode, param, EdgeTypes.AST) } + + val methodReturn = NewMethodReturn() + .code("RET") + .evaluationStrategy(EvaluationStrategies.BY_VALUE) + .typeFullName("ANY") + + dstGraph.addNode(methodReturn) + dstGraph.addEdge(methodNode, methodReturn, EdgeTypes.AST) + + val blockNode = NewBlock() + .order(1) + .argumentIndex(1) + .typeFullName("ANY") + + dstGraph.addNode(blockNode) + dstGraph.addEdge(methodNode, blockNode, EdgeTypes.AST) + + methodNode } - private def withWriter[X](driver: IDriver)(f: PlumeParallelWriter => Unit): Unit = - parallelWithWriter[X](driver, f, cpg, baseLogger) + def createAndApply(driver: IDriver): Unit = { + createApplySerializeAndStore(driver) // Apply to driver + } - private def enqueueInParallel(writer: PlumeParallelWriter): Unit = - withStartEndTimesLogged { + def createApplySerializeAndStore(driver: IDriver): Unit = { + try { init() - parallelEnqueue( - baseLogger, + forkJoinSerializeAndStore( + driver, name, - writer, - (x: (NameAndSignature, Int)) => runOnPart(x), - keyPools, - partIterator + cpg, + baseLogger, + generateParts(), + (builder: DiffGraphBuilder, part: Method) => runOnPart(builder, part), + keyPool ) + } finally { + finish() } - + } } diff --git a/src/main/scala/com/github/plume/oss/passes/parallel/PlumeReachingDefPass.scala b/src/main/scala/com/github/plume/oss/passes/parallel/PlumeReachingDefPass.scala index cae93649..8766297c 100644 --- a/src/main/scala/com/github/plume/oss/passes/parallel/PlumeReachingDefPass.scala +++ b/src/main/scala/com/github/plume/oss/passes/parallel/PlumeReachingDefPass.scala @@ -2,10 +2,7 @@ package com.github.plume.oss.passes.parallel import com.github.plume.oss.drivers.IDriver import com.github.plume.oss.passes.PlumeCpgPassBase -import com.github.plume.oss.passes.parallel.PlumeParallelCpgPass.{ - parallelEnqueue, - parallelWithWriter -} +import com.github.plume.oss.passes.forkjoin.PlumeForkJoinParallelCpgPass.forkJoinSerializeAndStore import io.joern.dataflowengineoss.passes.reachingdef.ReachingDefPass import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Method @@ -14,36 +11,35 @@ import io.shiftleft.semanticcpg.language._ class PlumeReachingDefPass( cpg: Cpg, - keyPools: Option[Iterator[KeyPool]] = None, + keyPool: Option[KeyPool] = None, unchangedTypes: Set[String] = Set.empty[String] ) extends ReachingDefPass(cpg) with PlumeCpgPassBase { - override def createAndApply(driver: IDriver): Unit = { - withWriter(driver) { writer => - enqueueInParallel(writer) - } - } - - override def partIterator: Iterator[Method] = cpg.method.internal.iterator.filterNot { m => + override def generateParts(): Array[Method] = cpg.method.internal.iterator.filterNot { m => val typeFullName = m.fullName.substring(0, m.fullName.lastIndexOf('.')) unchangedTypes.contains(typeFullName) - } + }.toArray - private def withWriter[X](driver: IDriver)(f: PlumeParallelWriter => Unit): Unit = - parallelWithWriter[X](driver, f, cpg, baseLogger) + def createAndApply(driver: IDriver): Unit = { + createApplySerializeAndStore(driver) // Apply to driver + } - private def enqueueInParallel(writer: PlumeParallelWriter): Unit = - withStartEndTimesLogged { + def createApplySerializeAndStore(driver: IDriver): Unit = { + try { init() - parallelEnqueue[Method]( - baseLogger, + forkJoinSerializeAndStore( + driver, name, - writer, - (part: Method) => runOnPart(part), - keyPools, - partIterator + cpg, + baseLogger, + generateParts(), + (builder: DiffGraphBuilder, part: Method) => runOnPart(builder, part), + keyPool ) + } finally { + finish() } + } } diff --git a/src/main/scala/com/github/plume/oss/util/BatchedUpdateUtil.scala b/src/main/scala/com/github/plume/oss/util/BatchedUpdateUtil.scala new file mode 100644 index 00000000..00ed1adf --- /dev/null +++ b/src/main/scala/com/github/plume/oss/util/BatchedUpdateUtil.scala @@ -0,0 +1,69 @@ +package com.github.plume.oss.util + +import io.shiftleft.codepropertygraph.generated.nodes.NewNode +import overflowdb.{DetachedNodeData, DetachedNodeGeneric, Node} + +import scala.collection.mutable + +/** Tools to extract information from new BatchedUpdate API. + */ +object BatchedUpdateUtil { + + private def idFromRefOrId(refOrId: Object): Long = { + refOrId match { + case n: Node => n.id() + case i: java.lang.Long => i.longValue() + } + } + + /** By determines what kind of node object is given, will extract its label. + * @param data either detached node data or node object. + * @return the node ID. + */ + def labelFromNodeData(data: Any): String = + data match { + case generic: DetachedNodeGeneric => generic.label() + case node: NewNode => node.label() + case node: Node => node.label() + } + + /** By determines what kind of node object is given, will extract its ID. + * @param data either detached node data or node object. + * @return the node ID. + */ + def idFromNodeData(data: Any): Long = + data match { + case generic: DetachedNodeGeneric => idFromRefOrId(generic.getRefOrId) + case node: NewNode => idFromRefOrId(node.getRefOrId) + case node: Node => node.id() + } + + /** Extracts properties from detached node data. + * @param data node data from which to determine properties from. + * @return a map of key-value pairs. + */ + def propertiesFromNodeData(data: DetachedNodeData): Map[String, Any] = { + data match { + case generic: DetachedNodeGeneric => propertiesFromObjectArray(generic.keyvalues) + case node: NewNode => node.properties + case _ => Map.empty[String, Any] + } + } + + /** Extracts a property key-value pairs as a map from an object array. + * @param arr the object array where key-values are stored as pairs. + * @return a map of key-value pairs. + */ + def propertiesFromObjectArray(arr: Array[Object]): Map[String, Any] = { + val props = mutable.HashMap.empty[String, Object] + for { + i <- arr.indices by 2 + } { + val key = arr(i).asInstanceOf[String] + val value = arr(i + 1) + props.put(key, value) + } + props.toMap + } + +} diff --git a/src/main/scala/com/github/plume/oss/util/ProgramHandlingUtil.scala b/src/main/scala/com/github/plume/oss/util/ProgramHandlingUtil.scala index 3d6bd0da..97556099 100644 --- a/src/main/scala/com/github/plume/oss/util/ProgramHandlingUtil.scala +++ b/src/main/scala/com/github/plume/oss/util/ProgramHandlingUtil.scala @@ -1,6 +1,6 @@ package com.github.plume.oss.util -import io.shiftleft.x2cpg.SourceFiles +import io.joern.x2cpg.SourceFiles import org.objectweb.asm.ClassReader.SKIP_CODE import org.objectweb.asm.{ClassReader, ClassVisitor, Opcodes} import org.slf4j.LoggerFactory diff --git a/src/test/scala/com/github/plume/oss/querying/MethodParameterTests.scala b/src/test/scala/com/github/plume/oss/querying/MethodParameterTests.scala index cdc24f81..74d2b62c 100644 --- a/src/test/scala/com/github/plume/oss/querying/MethodParameterTests.scala +++ b/src/test/scala/com/github/plume/oss/querying/MethodParameterTests.scala @@ -30,7 +30,7 @@ class MethodParameterTests extends Jimple2CpgFixture { x.code shouldBe "int param1" x.typeFullName shouldBe "int" x.lineNumber shouldBe Some(3) - x.columnNumber shouldBe Some(-1) + x.columnNumber shouldBe None x.order shouldBe 1 x.evaluationStrategy shouldBe EvaluationStrategies.BY_VALUE @@ -38,7 +38,7 @@ class MethodParameterTests extends Jimple2CpgFixture { y.code shouldBe "java.lang.Object param2" y.typeFullName shouldBe "java.lang.Object" y.lineNumber shouldBe Some(3) - y.columnNumber shouldBe Some(-1) + y.columnNumber shouldBe None y.order shouldBe 2 y.evaluationStrategy shouldBe EvaluationStrategies.BY_REFERENCE } diff --git a/src/test/scala/com/github/plume/oss/querying/MethodTests.scala b/src/test/scala/com/github/plume/oss/querying/MethodTests.scala index 613b4d6e..b61fd344 100644 --- a/src/test/scala/com/github/plume/oss/querying/MethodTests.scala +++ b/src/test/scala/com/github/plume/oss/querying/MethodTests.scala @@ -26,7 +26,7 @@ class MethodTests extends Jimple2CpgFixture { x.filename.startsWith(File.separator) shouldBe true x.filename.endsWith(".class") shouldBe true x.lineNumber shouldBe Some(2) - x.columnNumber shouldBe Some(-1) + x.columnNumber shouldBe None } // "should return correct number of lines" in { diff --git a/src/test/scala/com/github/plume/oss/testfixtures/PlumeDriverFixture.scala b/src/test/scala/com/github/plume/oss/testfixtures/PlumeDriverFixture.scala index ec90e943..032cd597 100644 --- a/src/test/scala/com/github/plume/oss/testfixtures/PlumeDriverFixture.scala +++ b/src/test/scala/com/github/plume/oss/testfixtures/PlumeDriverFixture.scala @@ -1,6 +1,7 @@ package com.github.plume.oss.testfixtures import com.github.plume.oss.drivers.IDriver +import com.github.plume.oss.passes.forkjoin.PlumeForkJoinParallelCpgPass.DiffGraphBuilder import io.shiftleft.codepropertygraph.generated.NodeTypes._ import io.shiftleft.codepropertygraph.generated.nodes._ import io.shiftleft.codepropertygraph.generated.{Cpg, DispatchTypes, EdgeTypes} @@ -9,7 +10,9 @@ import io.shiftleft.codepropertygraph.generated.PropertyNames._ import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import overflowdb.{BatchedUpdate, DetachedNodeData, DetachedNodeGeneric, Node} +import scala.jdk.CollectionConverters.IteratorHasAsScala import scala.language.postfixOps class PlumeDriverFixture(val driver: IDriver) @@ -28,123 +31,263 @@ class PlumeDriverFixture(val driver: IDriver) driver.clear() } - "should reflect node additions in bulk transactions" in { - val cpg = Cpg.empty - val keyPool = new IntervalKeyPool(1, 1000) - val diffGraph = DiffGraph.newBuilder - // Create some nodes - diffGraph.addNode(m1).addNode(b1) - val adg = - DiffGraph.Applier.applyDiff(diffGraph.build(), cpg.graph, undoable = false, Option(keyPool)) - driver.bulkTx(adg) - val List(m: Map[String, Any]) = - driver.propertyFromNodes(METHOD, NAME, ORDER, DYNAMIC_TYPE_HINT_FULL_NAME) - m.get(NAME) shouldBe Some("foo") - m.get(ORDER) shouldBe Some(1) - val List(b: Map[String, Any]) = driver.propertyFromNodes(BLOCK, ORDER) - b.get(ORDER) shouldBe Some(1) + "io.shiftleft.passes.DiffGraph based changes" should { + + "should reflect node additions in bulk transactions" in { + val cpg = Cpg.empty + val keyPool = new IntervalKeyPool(1, 1000) + val diffGraph = DiffGraph.newBuilder + // Create some nodes + diffGraph.addNode(m1).addNode(b1) + val adg = + DiffGraph.Applier.applyDiff(diffGraph.build(), cpg.graph, undoable = false, Option(keyPool)) + driver.bulkTx(adg) + val List(m: Map[String, Any]) = + driver.propertyFromNodes(METHOD, NAME, ORDER, DYNAMIC_TYPE_HINT_FULL_NAME) + m.get(NAME) shouldBe Some("foo") + m.get(ORDER) shouldBe Some(1) + val List(b: Map[String, Any]) = driver.propertyFromNodes(BLOCK, ORDER) + b.get(ORDER) shouldBe Some(1) + } + + "should reflect node subtractions in bulk transactions" in { + val cpg = Cpg.empty + val keyPool = new IntervalKeyPool(1, 1000) + val diffGraph1 = DiffGraph.newBuilder + val diffGraph2 = DiffGraph.newBuilder + // Create some nodes + diffGraph1.addNode(m1).addNode(b1) + val adg1 = + DiffGraph.Applier.applyDiff( + diffGraph1.build(), + cpg.graph, + undoable = false, + Option(keyPool) + ) + driver.bulkTx(adg1) + + val List(m: Map[String, Any]) = driver.propertyFromNodes(METHOD, NAME, ORDER) + m.get(NAME) shouldBe Some("foo") + m.get(ORDER) shouldBe Some(1) + val List(b: Map[String, Any]) = driver.propertyFromNodes(BLOCK, ORDER) + b.get(ORDER) shouldBe Some(1) + + // Remove one node + diffGraph2.removeNode(m.getOrElse("id", -1L).toString.toLong) + val adg2 = + DiffGraph.Applier.applyDiff( + diffGraph2.build(), + cpg.graph, + undoable = false, + Option(keyPool) + ) + driver.bulkTx(adg2) + + driver.propertyFromNodes(METHOD) shouldBe List() + } + + "should reflect edge additions in bulk transactions" in { + val cpg = Cpg.empty + val keyPool = new IntervalKeyPool(1, 1000) + val diffGraph1 = DiffGraph.newBuilder + val diffGraph2 = DiffGraph.newBuilder + // Create some nodes + diffGraph1.addNode(m1).addNode(b1) + val adg1 = + DiffGraph.Applier.applyDiff( + diffGraph1.build(), + cpg.graph, + undoable = false, + Option(keyPool) + ) + driver.bulkTx(adg1) + + val List(m: Map[String, Any]) = driver.propertyFromNodes(METHOD, NAME, ORDER) + m.get(NAME) shouldBe Some("foo") + m.get(ORDER) shouldBe Some(1) + val List(b: Map[String, Any]) = driver.propertyFromNodes(BLOCK, ORDER) + b.get(ORDER) shouldBe Some(1) + + // Add an edge + diffGraph2.addEdge( + cpg.graph.nodes(m.getOrElse("id", -1L).toString.toLong).next().asInstanceOf[AbstractNode], + cpg.graph.nodes(b.getOrElse("id", -1L).toString.toLong).next().asInstanceOf[AbstractNode], + EdgeTypes.AST + ) + val adg2 = + DiffGraph.Applier.applyDiff( + diffGraph2.build(), + cpg.graph, + undoable = false, + Option(keyPool) + ) + driver.bulkTx(adg2) + + driver.exists( + m.getOrElse("id", -1L).toString.toLong, + b.getOrElse("id", -1L).toString.toLong, + EdgeTypes.AST + ) shouldBe true + driver.exists( + b.getOrElse("id", -1L).toString.toLong, + m.getOrElse("id", -1L).toString.toLong, + EdgeTypes.AST + ) shouldBe false + } + + "should reflect edge removal in bulk transactions" in { + val cpg = Cpg.empty + val keyPool = new IntervalKeyPool(1, 1000) + val diffGraph1 = DiffGraph.newBuilder + val diffGraph2 = DiffGraph.newBuilder + // Create some nodes and an edge + diffGraph1.addNode(m1).addNode(b1).addEdge(m1, b1, EdgeTypes.AST) + val adg1 = + DiffGraph.Applier.applyDiff( + diffGraph1.build(), + cpg.graph, + undoable = false, + Option(keyPool) + ) + driver.bulkTx(adg1) + + val List(m: Map[String, Any]) = driver.propertyFromNodes(METHOD, NAME, ORDER) + m.get(NAME) shouldBe Some("foo") + m.get(ORDER) shouldBe Some(1) + val List(b: Map[String, Any]) = driver.propertyFromNodes(BLOCK, ORDER) + b.get(ORDER) shouldBe Some(1) + + driver.exists( + m.getOrElse("id", -1L).toString.toLong, + b.getOrElse("id", -1L).toString.toLong, + EdgeTypes.AST + ) shouldBe true + + diffGraph2.removeEdge( + cpg.graph.node(m.getOrElse("id", -1L).toString.toLong).outE(EdgeTypes.AST).next() + ) + val adg2 = + DiffGraph.Applier.applyDiff( + diffGraph2.build(), + cpg.graph, + undoable = false, + Option(keyPool) + ) + driver.bulkTx(adg2) + + driver.exists( + m.getOrElse("id", -1L).toString.toLong, + b.getOrElse("id", -1L).toString.toLong, + EdgeTypes.AST + ) shouldBe false + } } - "should reflect node subtractions in bulk transactions" in { - val cpg = Cpg.empty - val keyPool = new IntervalKeyPool(1, 1000) - val diffGraph1 = DiffGraph.newBuilder - val diffGraph2 = DiffGraph.newBuilder - // Create some nodes - diffGraph1.addNode(m1).addNode(b1) - val adg1 = - DiffGraph.Applier.applyDiff(diffGraph1.build(), cpg.graph, undoable = false, Option(keyPool)) - driver.bulkTx(adg1) - - val List(m: Map[String, Any]) = driver.propertyFromNodes(METHOD, NAME, ORDER) - m.get(NAME) shouldBe Some("foo") - m.get(ORDER) shouldBe Some(1) - val List(b: Map[String, Any]) = driver.propertyFromNodes(BLOCK, ORDER) - b.get(ORDER) shouldBe Some(1) - - // Remove one node - diffGraph2.removeNode(m.getOrElse("id", -1L).toString.toLong) - val adg2 = - DiffGraph.Applier.applyDiff(diffGraph2.build(), cpg.graph, undoable = false, Option(keyPool)) - driver.bulkTx(adg2) - - driver.propertyFromNodes(METHOD) shouldBe List() + def nodeToNodeCreate(n: NewNode): DetachedNodeGeneric = { + val props: Array[Object] = n.properties.flatMap { case (k, v) => + Iterable(k.asInstanceOf[Object], v.asInstanceOf[Object]) + }.toArray + new DetachedNodeGeneric(n.label(), props: _*) } - "should reflect edge additions in bulk transactions" in { - val cpg = Cpg.empty - val keyPool = new IntervalKeyPool(1, 1000) - val diffGraph1 = DiffGraph.newBuilder - val diffGraph2 = DiffGraph.newBuilder - // Create some nodes - diffGraph1.addNode(m1).addNode(b1) - val adg1 = - DiffGraph.Applier.applyDiff(diffGraph1.build(), cpg.graph, undoable = false, Option(keyPool)) - driver.bulkTx(adg1) - - val List(m: Map[String, Any]) = driver.propertyFromNodes(METHOD, NAME, ORDER) - m.get(NAME) shouldBe Some("foo") - m.get(ORDER) shouldBe Some(1) - val List(b: Map[String, Any]) = driver.propertyFromNodes(BLOCK, ORDER) - b.get(ORDER) shouldBe Some(1) - - // Add an edge - diffGraph2.addEdge( - cpg.graph.nodes(m.getOrElse("id", -1L).toString.toLong).next().asInstanceOf[AbstractNode], - cpg.graph.nodes(b.getOrElse("id", -1L).toString.toLong).next().asInstanceOf[AbstractNode], - EdgeTypes.AST - ) - val adg2 = - DiffGraph.Applier.applyDiff(diffGraph2.build(), cpg.graph, undoable = false, Option(keyPool)) - driver.bulkTx(adg2) - - driver.exists( - m.getOrElse("id", -1L).toString.toLong, - b.getOrElse("id", -1L).toString.toLong, - EdgeTypes.AST - ) shouldBe true - driver.exists( - b.getOrElse("id", -1L).toString.toLong, - m.getOrElse("id", -1L).toString.toLong, - EdgeTypes.AST - ) shouldBe false - } - - "should reflect edge removal in bulk transactions" in { - val cpg = Cpg.empty - val keyPool = new IntervalKeyPool(1, 1000) - val diffGraph1 = DiffGraph.newBuilder - val diffGraph2 = DiffGraph.newBuilder - // Create some nodes and an edge - diffGraph1.addNode(m1).addNode(b1).addEdge(m1, b1, EdgeTypes.AST) - val adg1 = - DiffGraph.Applier.applyDiff(diffGraph1.build(), cpg.graph, undoable = false, Option(keyPool)) - driver.bulkTx(adg1) - - val List(m: Map[String, Any]) = driver.propertyFromNodes(METHOD, NAME, ORDER) - m.get(NAME) shouldBe Some("foo") - m.get(ORDER) shouldBe Some(1) - val List(b: Map[String, Any]) = driver.propertyFromNodes(BLOCK, ORDER) - b.get(ORDER) shouldBe Some(1) - - driver.exists( - m.getOrElse("id", -1L).toString.toLong, - b.getOrElse("id", -1L).toString.toLong, - EdgeTypes.AST - ) shouldBe true - - diffGraph2.removeEdge( - cpg.graph.node(m.getOrElse("id", -1L).toString.toLong).outE(EdgeTypes.AST).next() - ) - val adg2 = - DiffGraph.Applier.applyDiff(diffGraph2.build(), cpg.graph, undoable = false, Option(keyPool)) - driver.bulkTx(adg2) - - driver.exists( - m.getOrElse("id", -1L).toString.toLong, - b.getOrElse("id", -1L).toString.toLong, - EdgeTypes.AST - ) shouldBe false + "overflowdb.BatchedUpdate.DiffGraph based changes" should { + + "should reflect node additions in bulk transactions" in { + val cpg = Cpg.empty + val keyPool = new IntervalKeyPool(1, 1000) + val diffGraph = new DiffGraphBuilder + // Create some nodes + diffGraph.addNode(nodeToNodeCreate(m1)).addNode(nodeToNodeCreate(b1)) + val adg = BatchedUpdate.applyDiff(cpg.graph, diffGraph.build(), keyPool, null) + driver.bulkTx(adg) + val List(m: Map[String, Any]) = + driver.propertyFromNodes(METHOD, NAME, ORDER, DYNAMIC_TYPE_HINT_FULL_NAME) + m.get(NAME) shouldBe Some("foo") + m.get(ORDER) shouldBe Some(1) + val List(b: Map[String, Any]) = driver.propertyFromNodes(BLOCK, ORDER) + b.get(ORDER) shouldBe Some(1) + } + + "should reflect node subtractions in bulk transactions" in { + val cpg = Cpg.empty + val keyPool = new IntervalKeyPool(1, 1000) + val diffGraph1 = new DiffGraphBuilder + val diffGraph2 = new DiffGraphBuilder + // Create some nodes + diffGraph1.addNode(nodeToNodeCreate(m1)).addNode(b1) + val adg1 = BatchedUpdate.applyDiff(cpg.graph, diffGraph1.build(), keyPool, null) + driver.bulkTx(adg1) + + val List(m: Map[String, Any]) = driver.propertyFromNodes(METHOD, NAME, ORDER) + m.get(NAME) shouldBe Some("foo") + m.get(ORDER) shouldBe Some(1) + val List(b: Map[String, Any]) = driver.propertyFromNodes(BLOCK, ORDER) + b.get(ORDER) shouldBe Some(1) + adg1.diffGraph.iterator.asScala + .collectFirst { + case c: DetachedNodeGeneric + if c.getRefOrId.asInstanceOf[Node].id() == m.getOrElse("id", -1L).toString.toLong => + c + } match { + case Some(mToCheck) => + // Remove one node + diffGraph2.removeNode(mToCheck.getRefOrId.asInstanceOf[Node]) + val adg2 = BatchedUpdate.applyDiff(cpg.graph, diffGraph2.build(), keyPool, null) + driver.bulkTx(adg2) + case None => fail("Unable to extract removed method node") + } + driver.propertyFromNodes(METHOD) shouldBe List() + } + + "should reflect edge additions in bulk transactions" in { + val cpg = Cpg.empty + val keyPool = new IntervalKeyPool(1, 1000) + val diffGraph1 = new DiffGraphBuilder + val diffGraph2 = new DiffGraphBuilder + // Create some nodes + diffGraph1.addNode(nodeToNodeCreate(m1)).addNode(b1) + val adg1 = BatchedUpdate.applyDiff(cpg.graph, diffGraph1.build(), keyPool, null) + driver.bulkTx(adg1) + val List(m: Map[String, Any]) = driver.propertyFromNodes(METHOD, NAME, ORDER) + m.get(NAME) shouldBe Some("foo") + m.get(ORDER) shouldBe Some(1) + val List(b: Map[String, Any]) = driver.propertyFromNodes(BLOCK, ORDER) + b.get(ORDER) shouldBe Some(1) + + // Add an edge + val srcNode = adg1.diffGraph.iterator.asScala + .collectFirst { + case c: DetachedNodeGeneric + if c.getRefOrId.asInstanceOf[Node].id() == m.getOrElse("id", -1L).toString.toLong => + c + } match { + case Some(src) => src + case None => fail("Unable to extract method node") + } + val dstNode = adg1.diffGraph.iterator.asScala + .collectFirst { + case c: NewBlock + if c.getRefOrId().asInstanceOf[Node].id() == b.getOrElse("id", -1L).toString.toLong => + c + } match { + case Some(dst) => dst + case None => fail("Unable to extract block node") + } + diffGraph2.addEdge(srcNode, dstNode, EdgeTypes.AST) + val adg2 = BatchedUpdate.applyDiff(cpg.graph, diffGraph2.build(), keyPool, null) + driver.bulkTx(adg2) + + driver.exists( + m.getOrElse("id", -1L).toString.toLong, + b.getOrElse("id", -1L).toString.toLong, + EdgeTypes.AST + ) shouldBe true + driver.exists( + b.getOrElse("id", -1L).toString.toLong, + m.getOrElse("id", -1L).toString.toLong, + EdgeTypes.AST + ) shouldBe false + } } "should accurately report which IDs have been taken" in { @@ -260,6 +403,48 @@ class PlumeDriverFixture(val driver: IDriver) if (driver.isConnected) driver.close() } + private def createSimpleGraph(dg: DiffGraphBuilder): Unit = { + dg.addNode(meta) + .addNode(f1) + .addNode(f2) + .addNode(td1) + .addNode(td2) + .addNode(t1) + .addNode(t2) + .addNode(n1) + .addNode(n2) + .addNode(m1) + .addNode(m2) + .addNode(m3) + .addNode(b1) + .addNode(c1) + .addNode(c2) + .addNode(li1) + .addNode(l1) + .addNode(i1) + .addEdge(m1, f1, EdgeTypes.SOURCE_FILE) + .addEdge(m2, f2, EdgeTypes.SOURCE_FILE) + .addEdge(m3, f1, EdgeTypes.SOURCE_FILE) + .addEdge(td1, f1, EdgeTypes.SOURCE_FILE) + .addEdge(td2, f2, EdgeTypes.SOURCE_FILE) + .addEdge(n1, f1, EdgeTypes.SOURCE_FILE) + .addEdge(n2, f2, EdgeTypes.SOURCE_FILE) + .addEdge(t1, td1, EdgeTypes.REF) + .addEdge(t2, td2, EdgeTypes.REF) + .addEdge(n1, td1, EdgeTypes.AST) + .addEdge(n2, td2, EdgeTypes.AST) + .addEdge(td1, m1, EdgeTypes.AST) + .addEdge(td1, m3, EdgeTypes.AST) + .addEdge(td2, m2, EdgeTypes.AST) + .addEdge(m1, b1, EdgeTypes.AST) + .addEdge(b1, c1, EdgeTypes.AST) + .addEdge(b1, c2, EdgeTypes.AST) + .addEdge(b1, l1, EdgeTypes.AST) + .addEdge(c1, li1, EdgeTypes.AST) + .addEdge(c1, i1, EdgeTypes.AST) + .addEdge(m1, c1, EdgeTypes.CFG) + } + private def createSimpleGraph(dg: DiffGraph.Builder): Unit = { dg.addNode(meta) .addNode(f1)