From 875164a4e6def2abea3735d91c5340c8cb2b4101 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Wed, 11 Dec 2024 09:42:53 +0100 Subject: [PATCH] Persisting edge properties --- .../de/fraunhofer/aisec/cpg/v2/Persistence.kt | 177 +++++++++--------- 1 file changed, 92 insertions(+), 85 deletions(-) diff --git a/cpg-neo4j/src/main/kotlin/de/fraunhofer/aisec/cpg/v2/Persistence.kt b/cpg-neo4j/src/main/kotlin/de/fraunhofer/aisec/cpg/v2/Persistence.kt index 46929daa33..26664a7533 100644 --- a/cpg-neo4j/src/main/kotlin/de/fraunhofer/aisec/cpg/v2/Persistence.kt +++ b/cpg-neo4j/src/main/kotlin/de/fraunhofer/aisec/cpg/v2/Persistence.kt @@ -28,12 +28,17 @@ package de.fraunhofer.aisec.cpg.v2 import de.fraunhofer.aisec.cpg.TranslationResult import de.fraunhofer.aisec.cpg.graph.Name import de.fraunhofer.aisec.cpg.graph.Node +import de.fraunhofer.aisec.cpg.graph.Persistable import de.fraunhofer.aisec.cpg.graph.edges.Edge import de.fraunhofer.aisec.cpg.graph.edges.allEdges +import de.fraunhofer.aisec.cpg.graph.edges.flows.DependenceType +import de.fraunhofer.aisec.cpg.graph.edges.flows.Granularity import de.fraunhofer.aisec.cpg.graph.nodes +import de.fraunhofer.aisec.cpg.helpers.Benchmark +import de.fraunhofer.aisec.cpg.helpers.neo4j.DataflowGranularityConverter import de.fraunhofer.aisec.cpg.helpers.neo4j.NameConverter import de.fraunhofer.aisec.cpg.helpers.neo4j.SimpleNameConverter -import kotlin.collections.joinToString +import kotlin.collections.plusAssign import kotlin.reflect.KClass import kotlin.reflect.KProperty1 import kotlin.reflect.full.createType @@ -41,7 +46,7 @@ import kotlin.reflect.full.memberProperties import kotlin.reflect.full.withNullability import kotlin.uuid.Uuid import org.neo4j.driver.GraphDatabase -import org.neo4j.driver.TransactionContext +import org.neo4j.ogm.typeconversion.CompositeAttributeConverter import org.slf4j.LoggerFactory /** @@ -53,134 +58,136 @@ val dbUser = "neo4j" val dbPassword = "password" val neo4jSession by lazy { - GraphDatabase.driver(dbUri, org.neo4j.driver.AuthTokens.basic(dbUser, dbPassword)).session() + val driver = GraphDatabase.driver(dbUri, org.neo4j.driver.AuthTokens.basic(dbUser, dbPassword)) + driver.session() } val labelCache: MutableMap, Set> = mutableMapOf() -val schemaPropertiesCache: MutableMap, Map>> = +val schemaPropertiesCache: + MutableMap, Map>> = mutableMapOf() val log = LoggerFactory.getLogger("Persistence") +val edgeChunkSize = 10000 +val nodeChunkSize = 10000 + fun TranslationResult.persist() { val nodes = this@persist.nodes val edges = this@persist.allEdges>() + val b = Benchmark(Persistable::class.java, "Persisting translation result") + log.info("Persisting {} nodes", nodes.size) nodes.persist() log.info("Persisting {} edges", edges.size) edges.persist() + + b.stop() } private fun List.persist() { - val groups = groupBy { it::class } - groups.forEach { - it.value.chunked(10000).forEach { chunk -> - log.info("Processing ${chunk.size} nodes of type ${it.key}") - - val params = mapOf("props" to chunk.map { it.properties() }) - val start = System.currentTimeMillis() - neo4jSession.executeWrite { tx -> - tx.run( - "UNWIND \$props AS map CREATE (n:${it.key.labels.joinToString("&")}) SET n=map", - params - ) - .consume() - } - log.info( - "Time Taken to process and save ${chunk.size} records to Neo4j Batch Insert took ${System.currentTimeMillis() - start} ms" - ) + this.chunked(nodeChunkSize).map { chunk -> + val b = Benchmark(Persistable::class.java, "Persisting chunk of ${chunk.size} nodes") + val params = + mapOf("props" to chunk.map { mapOf("labels" to it::class.labels) + it.properties() }) + neo4jSession.executeWrite { tx -> + tx.run( + """ + UNWIND ${"$"}props AS map + WITH map, apoc.map.removeKeys(map, ['labels']) AS properties + CALL apoc.create.node(map.labels, properties) YIELD node + RETURN node + """, + params + ) + .consume() } + b.stop() } } private fun Collection>.persist() { + // Create an index for the "id" field of node, because we are "MATCH"ing on it in the edge + // creation. We need to wait for this to be finished neo4jSession.executeWrite { tx -> tx.run("CREATE INDEX IF NOT EXISTS FOR (n:Node) ON (n.id)").consume() } - val groups = groupBy { it.label } - groups.forEach { - it.value.chunked(10000).forEach { chunk -> - log.info("Processing ${chunk.size} edges of type ${it.key}") - - val params = - mapOf( - "props" to - chunk.map { - mapOf( - "startId" to it.start.id.toString(), - "endId" to it.end.id.toString(), - ) - } - ) - val start = System.currentTimeMillis() - neo4jSession.executeWrite { tx -> - tx.run( - """ + this.chunked(edgeChunkSize).map { chunk -> + val b = Benchmark(Persistable::class.java, "Persisting chunk of ${chunk.size} edges") + val params = + mapOf( + "props" to + chunk.map { + mapOf( + "startId" to it.start.id.toString(), + "endId" to it.end.id.toString(), + "type" to it.label + ) + it.properties() + } + ) + neo4jSession.executeWrite { tx -> + tx.run( + """ UNWIND ${'$'}props AS map MATCH (s:Node {id: map.startId}) MATCH (e:Node {id: map.endId}) - CREATE (s)-[r:${it.key} {}]->(e) + WITH s, e, map, apoc.map.removeKeys(map, ['startId', 'endId', 'type']) AS properties + CALL apoc.create.relationship(s, map.type, properties, e) YIELD rel + RETURN rel """ - .trimIndent(), - params - ) - .consume() - } - log.info( - "Time Taken to process and save ${chunk.size} records to Neo4j Batch Insert took ${System.currentTimeMillis() - start} ms" - ) + .trimIndent(), + params + ) + .consume() } + b.stop() } } /** - * Returns the node's properties. This DOES NOT include relationships, but only properties directly - * attached to the node. + * Returns the [Persistable]'s properties. This DOES NOT include relationships, but only properties + * directly attached to the node/edge. */ -fun Node.properties(): Map { +fun Persistable.properties(): Map { val properties = mutableMapOf() - for (entry in schemaProperties) { + for (entry in this::class.schemaProperties) { val value = entry.value.call(this) if (value == null) { continue } - // TODO: generalize conversions - if (value is Name && entry.key == "name") { - properties += NameConverter().toGraphProperties(value) - } else if (value is Name) { - properties.put(entry.key, SimpleNameConverter().toGraphProperty(value)) - } else if (value is Uuid) { - properties.put(entry.key, value.toString()) - } else { - properties.put(entry.key, value) - } + value.convert(entry.key, properties) } return properties } -context(TransactionContext) -fun Edge<*>.persist() { - this@TransactionContext.run( - "MATCH (start { id: \$startId }), (end { id: \$endId } ) MERGE (start)-[r:${label} { }]->(end)", - mapOf("startId" to this.start.id.toString(), "endId" to this.end.id.toString()) - ) - .consume() -} - -val Node.labels: Set - get() { - val klazz = this::class - - // Check, if we already computed the labels for this node's class - return labelCache.computeIfAbsent(klazz) { setOf("Node", klazz.simpleName!!) } +/** + * Runs any conversions that are necessary by [CompositeAttributeConverter] and + * [org.neo4j.ogm.typeconversion.AttributeConverter]. Since both of these classes are Neo4J OGM + * classes, we need to find new base types at some point. + */ +fun Any.convert(originalKey: String, properties: MutableMap) { + // TODO: generalize conversions + if (this is Name && originalKey == "name") { + properties += NameConverter().toGraphProperties(this) + } else if (this is Name) { + properties.put(originalKey, SimpleNameConverter().toGraphProperty(this)) + } else if (this is Granularity) { + properties += DataflowGranularityConverter().toGraphProperties(this) + } else if (this is Enum<*>) { + properties.put(originalKey, this.name) + } else if (this is Uuid) { + properties.put(originalKey, this.toString()) + } else { + properties.put(originalKey, this) } +} val KClass.labels: Set get() { @@ -188,7 +195,7 @@ val KClass.labels: Set return labelCache.computeIfAbsent(this) { setOf("Node", this.simpleName!!) } } -val primitiveTypes = +val propertyTypes = setOf( String::class.createType(), Int::class.createType(), @@ -196,18 +203,18 @@ val primitiveTypes = Boolean::class.createType(), Name::class.createType(), Uuid::class.createType(), + Granularity::class.createType(), + DependenceType::class.createType(), ) -val Node.schemaProperties: Map> +val KClass.schemaProperties: Map> get() { - val klazz = this::class - // Check, if we already computed the labels for this node's class - return schemaPropertiesCache.computeIfAbsent(klazz) { - val schema = mutableMapOf>() + return schemaPropertiesCache.computeIfAbsent(this) { + val schema = mutableMapOf>() val properties = it.memberProperties for (property in properties) { - if (property.returnType.withNullability(false) in primitiveTypes) { + if (property.returnType.withNullability(false) in propertyTypes) { schema.put(property.name, property) } }