Skip to content

Commit

Permalink
Persisting edge properties
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto committed Dec 11, 2024
1 parent 0097423 commit 875164a
Showing 1 changed file with 92 additions and 85 deletions.
177 changes: 92 additions & 85 deletions cpg-neo4j/src/main/kotlin/de/fraunhofer/aisec/cpg/v2/Persistence.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,25 @@ package de.fraunhofer.aisec.cpg.v2
import de.fraunhofer.aisec.cpg.TranslationResult
import de.fraunhofer.aisec.cpg.graph.Name
import de.fraunhofer.aisec.cpg.graph.Node
import de.fraunhofer.aisec.cpg.graph.Persistable
import de.fraunhofer.aisec.cpg.graph.edges.Edge
import de.fraunhofer.aisec.cpg.graph.edges.allEdges
import de.fraunhofer.aisec.cpg.graph.edges.flows.DependenceType
import de.fraunhofer.aisec.cpg.graph.edges.flows.Granularity
import de.fraunhofer.aisec.cpg.graph.nodes
import de.fraunhofer.aisec.cpg.helpers.Benchmark
import de.fraunhofer.aisec.cpg.helpers.neo4j.DataflowGranularityConverter
import de.fraunhofer.aisec.cpg.helpers.neo4j.NameConverter
import de.fraunhofer.aisec.cpg.helpers.neo4j.SimpleNameConverter
import kotlin.collections.joinToString
import kotlin.collections.plusAssign
import kotlin.reflect.KClass
import kotlin.reflect.KProperty1
import kotlin.reflect.full.createType
import kotlin.reflect.full.memberProperties
import kotlin.reflect.full.withNullability
import kotlin.uuid.Uuid
import org.neo4j.driver.GraphDatabase
import org.neo4j.driver.TransactionContext
import org.neo4j.ogm.typeconversion.CompositeAttributeConverter
import org.slf4j.LoggerFactory

/**
Expand All @@ -53,161 +58,163 @@ val dbUser = "neo4j"
// NOTE(review): hard-coded default credential; presumably intended for a local development
// database only — confirm before pointing this at a shared instance.
val dbPassword = "password"

/**
 * Neo4j session used by the persistence helpers in this file. It is opened lazily on first access
 * from a driver built with [dbUri], [dbUser] and [dbPassword].
 *
 * Exactly one driver is created. Neither the driver nor the session is closed here.
 */
val neo4jSession by lazy {
    val driver = GraphDatabase.driver(dbUri, org.neo4j.driver.AuthTokens.basic(dbUser, dbPassword))
    driver.session()
}

/** Cache of the label set computed for each node class; filled by the `labels` property. */
val labelCache: MutableMap<KClass<out Node>, Set<String>> = mutableMapOf()

/**
 * Cache of the persistable properties of each [Persistable] class. The inner map goes from the
 * property name to its reflective accessor; filled by the `schemaProperties` property.
 */
val schemaPropertiesCache:
    MutableMap<KClass<out Persistable>, Map<String, KProperty1<out Persistable, *>>> =
    mutableMapOf()

/** Logger shared by the persistence helpers. */
val log = LoggerFactory.getLogger("Persistence")

/** Number of edges that are written to the database within one write transaction. */
val edgeChunkSize = 10000

/** Number of nodes that are written to the database within one write transaction. */
val nodeChunkSize = 10000

/**
 * Writes this [TranslationResult] to the database: all nodes first, then all edges. Node and edge
 * counts are logged and the overall run is timed with a [Benchmark].
 */
fun TranslationResult.persist() {
    // Collect everything up-front so that the benchmark below only covers the actual writes
    val graphNodes = this.nodes
    val graphEdges = this.allEdges<Edge<*>>()

    val timer = Benchmark(Persistable::class.java, "Persisting translation result")

    graphNodes.also { log.info("Persisting {} nodes", it.size) }.persist()
    graphEdges.also { log.info("Persisting {} edges", it.size) }.persist()

    timer.stop()
}

/**
 * Persists all nodes of this list in chunks of [nodeChunkSize]. Every chunk is written in its own
 * write transaction and timed with a [Benchmark].
 *
 * Each node is sent as a single map that holds its [properties] plus a `labels` entry. The query
 * removes `labels` from that map again and hands both to `apoc.create.node`, so the APOC
 * procedures have to be available in the target database.
 */
private fun List<Node>.persist() {
    // forEach rather than map: nothing is done with a per-chunk result
    this.chunked(nodeChunkSize).forEach { chunk ->
        val b = Benchmark(Persistable::class.java, "Persisting chunk of ${chunk.size} nodes")
        val params =
            mapOf("props" to chunk.map { mapOf("labels" to it::class.labels) + it.properties() })
        neo4jSession.executeWrite { tx ->
            tx.run(
                    """
                    UNWIND ${'$'}props AS map
                    WITH map, apoc.map.removeKeys(map, ['labels']) AS properties
                    CALL apoc.create.node(map.labels, properties) YIELD node
                    RETURN node
                    """
                        .trimIndent(),
                    params
                )
                .consume()
        }
        b.stop()
    }
}

/**
 * Persists all edges of this collection in chunks of [edgeChunkSize]. Every chunk is written in its
 * own write transaction and timed with a [Benchmark].
 *
 * Each edge is sent as a single map that holds its [properties] plus the helper entries `startId`,
 * `endId` and `type`. The query looks up both end nodes by their `id`, strips the helper entries
 * from the map and creates the relationship through `apoc.create.relationship`, so the APOC
 * procedures have to be available in the target database.
 */
private fun Collection<Edge<*>>.persist() {
    // Create an index for the "id" field of node, because we are "MATCH"ing on it in the edge
    // creation. We need to wait for this to be finished
    neo4jSession.executeWrite { tx ->
        tx.run("CREATE INDEX IF NOT EXISTS FOR (n:Node) ON (n.id)").consume()
    }

    // forEach rather than map: nothing is done with a per-chunk result
    this.chunked(edgeChunkSize).forEach { chunk ->
        val b = Benchmark(Persistable::class.java, "Persisting chunk of ${chunk.size} edges")
        val params =
            mapOf(
                "props" to
                    chunk.map {
                        mapOf(
                            "startId" to it.start.id.toString(),
                            "endId" to it.end.id.toString(),
                            "type" to it.label
                        ) + it.properties()
                    }
            )
        neo4jSession.executeWrite { tx ->
            tx.run(
                    """
                    UNWIND ${'$'}props AS map
                    MATCH (s:Node {id: map.startId})
                    MATCH (e:Node {id: map.endId})
                    WITH s, e, map, apoc.map.removeKeys(map, ['startId', 'endId', 'type']) AS properties
                    CALL apoc.create.relationship(s, map.type, properties, e) YIELD rel
                    RETURN rel
                    """
                        .trimIndent(),
                    params
                )
                .consume()
        }
        b.stop()
    }
}

/**
 * Returns the [Persistable]'s properties. This DOES NOT include relationships, but only properties
 * directly attached to the node/edge.
 *
 * Properties whose current value is `null` are left out. Every other value is handed to [convert],
 * which decides under which key(s) and in which form it is stored in the result.
 */
fun Persistable.properties(): Map<String, Any?> {
    val properties = mutableMapOf<String, Any?>()
    for ((key, property) in this::class.schemaProperties) {
        // Unset properties are skipped rather than written as explicit nulls
        val value = property.call(this) ?: continue

        value.convert(key, properties)
    }

    return properties
}

/**
 * Persists this single edge within the surrounding [TransactionContext]. The query matches the two
 * nodes whose `id` equals the id of the edge's start and end and `MERGE`s a relationship typed with
 * the edge's label between them.
 */
context(TransactionContext)
fun Edge<*>.persist() {
    val query =
        "MATCH (start { id: \$startId }), (end { id: \$endId } ) MERGE (start)-[r:${label} { }]->(end)"
    val parameters =
        mapOf("startId" to this.start.id.toString(), "endId" to this.end.id.toString())

    this@TransactionContext.run(query, parameters).consume()
}

/**
 * The labels of this node: `Node` plus the simple name of the node's class. Delegates to the
 * class-level `labels` property of [KClass], so both share the same [labelCache] entry.
 */
val Node.labels: Set<String>
    get() = this::class.labels
/**
 * Stores this value in [properties], applying the conversions that would otherwise be done by a
 * [CompositeAttributeConverter] or an [org.neo4j.ogm.typeconversion.AttributeConverter]. Both are
 * Neo4J OGM classes, so new base types need to be found for them at some point.
 *
 * A [Name] under the key `name` and a [Granularity] are expanded by their composite converters and
 * merged into [properties]; all other values are stored under [originalKey], with enums reduced to
 * their constant name and [Uuid]s to their string form.
 */
fun Any.convert(originalKey: String, properties: MutableMap<String, Any?>) {
    // TODO: generalize conversions
    when (this) {
        is Name ->
            if (originalKey == "name") {
                properties += NameConverter().toGraphProperties(this)
            } else {
                properties[originalKey] = SimpleNameConverter().toGraphProperty(this)
            }
        // Must stay ahead of the generic enum branch so that it keeps its dedicated converter
        is Granularity -> properties += DataflowGranularityConverter().toGraphProperties(this)
        is Enum<*> -> properties[originalKey] = this.name
        is Uuid -> properties[originalKey] = this.toString()
        else -> properties[originalKey] = this
    }
}

/**
 * The labels used for nodes of this class: `Node` plus the class's simple name. The set is
 * computed once per class and then served from [labelCache]. Classes without a simple name fail on
 * the non-null assertion.
 */
val KClass<out Node>.labels: Set<String>
    get() = labelCache.computeIfAbsent(this) { nodeClass -> setOf("Node", nodeClass.simpleName!!) }

/**
 * Types whose values are persisted as properties. `schemaProperties` only picks up member
 * properties whose return type (ignoring nullability) is contained in this set.
 */
val propertyTypes =
    setOf(
        String::class.createType(),
        Int::class.createType(),
        Long::class.createType(),
        Boolean::class.createType(),
        Name::class.createType(),
        Uuid::class.createType(),
        Granularity::class.createType(),
        DependenceType::class.createType(),
    )

val Node.schemaProperties: Map<String, KProperty1<out Node, *>>
val KClass<out Persistable>.schemaProperties: Map<String, KProperty1<out Persistable, *>>
get() {
val klazz = this::class

// Check, if we already computed the labels for this node's class
return schemaPropertiesCache.computeIfAbsent(klazz) {
val schema = mutableMapOf<String, KProperty1<out Node, *>>()
return schemaPropertiesCache.computeIfAbsent(this) {
val schema = mutableMapOf<String, KProperty1<out Persistable, *>>()
val properties = it.memberProperties
for (property in properties) {
if (property.returnType.withNullability(false) in primitiveTypes) {
if (property.returnType.withNullability(false) in propertyTypes) {
schema.put(property.name, property)
}
}
Expand Down

0 comments on commit 875164a

Please sign in to comment.