From bb2b6fa170b5fd4e5fbdbd31f190daf9138c948c Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Tue, 2 Jan 2024 12:41:44 +0100 Subject: [PATCH] Added `registerReplacement` to `ScopedWalker` In some frontends (e.g., the Go frontend) we are replacing certain nodes in an extra pass, while iterating with the `ScopedWalker`. One such usecase is to replace calls with casts. Previously, when doing so, chained replacement of calls were a problem, since the original replaced value was stil registered in the walker and was used to determine further AST children to "walk" on, instead of the new node. This adds a function to register such replacements with the walker. --- .../aisec/cpg/helpers/SubgraphWalker.kt | 36 +++++++++++++++---- .../aisec/cpg/passes/GoExtraPass.kt | 7 +++- .../golang/GoLanguageFrontendTest.kt | 22 ++++++++++++ .../src/test/resources/golang/cast.go | 9 +++++ 4 files changed, 67 insertions(+), 7 deletions(-) create mode 100644 cpg-language-go/src/test/resources/golang/cast.go diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/helpers/SubgraphWalker.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/helpers/SubgraphWalker.kt index 58ceb9ab9d..c8ea8f02f8 100644 --- a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/helpers/SubgraphWalker.kt +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/helpers/SubgraphWalker.kt @@ -268,8 +268,10 @@ object SubgraphWalker { * place where usual graph manipulation will happen. The current node is the single argument * passed to the function */ - private val onNodeVisit: MutableList> = ArrayList() - private val onNodeVisit2: MutableList> = ArrayList() + private val onNodeVisit: MutableList> = mutableListOf() + private val onNodeVisit2: MutableList> = mutableListOf() + + private val replacements = mutableMapOf() /** * The callback that is designed to tell the user when we leave the current scope. The @@ -305,7 +307,7 @@ object SubgraphWalker { val seen: MutableSet = LinkedHashSet() todo?.push(Pair(root, null)) while ((todo as ArrayDeque>).isNotEmpty()) { - val (current, parent) = (todo as ArrayDeque>).pop() + var (current, parent) = (todo as ArrayDeque>).pop() if ( (backlog as ArrayDeque).isNotEmpty() && (backlog as ArrayDeque).peek() == current @@ -313,16 +315,25 @@ object SubgraphWalker { val exiting = (backlog as ArrayDeque).pop() onNodeExit.forEach(Consumer { c: Consumer -> c.accept(exiting) }) } else { - // re-place the current node as a marker for the above check to find out when we - // need to exit a scope - (todo as ArrayDeque>).push(Pair(current, parent)) onNodeVisit.forEach(Consumer { c: Consumer -> c.accept(current) }) onNodeVisit2.forEach( Consumer { c: BiConsumer -> c.accept(current, parent) } ) + + // Check if we have a replacement node + val toReplace = replacements[current] + if (toReplace != null) { + current = toReplace + replacements.remove(toReplace) + } + val unseenChildren = strategy(current).asSequence().filter { it !in seen }.toMutableList() + // re-place the current node as a marker for the above check to find out when we + // need to exit a scope + (todo as ArrayDeque>).push(Pair(current, parent)) + seen.addAll(unseenChildren) unseenChildren.asReversed().forEach { child: Node -> (todo as ArrayDeque>).push(Pair(child, current)) @@ -332,6 +343,15 @@ object SubgraphWalker { } } + /** + * Sometimes during walking the graph, we are replacing the current node. This causes + * problems, that the walker still assumes the old node. Calling this function will ensure + * that the walker knows about the new node. + */ + fun registerReplacement(from: Node, to: Node) { + replacements[from] = to + } + fun registerOnNodeVisit(callback: Consumer) { onNodeVisit.add(callback) } @@ -408,6 +428,10 @@ object SubgraphWalker { ) } + fun registerReplacement(from: Node, to: Node) { + walker?.registerReplacement(from, to) + } + /** * Wraps [IterativeGraphWalker] to handle declaration scopes. * diff --git a/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/GoExtraPass.kt b/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/GoExtraPass.kt index a287d7dfac..1111bf3ca2 100644 --- a/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/GoExtraPass.kt +++ b/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/GoExtraPass.kt @@ -110,13 +110,15 @@ import de.fraunhofer.aisec.cpg.passes.order.ExecuteBefore @ExecuteBefore(DFGPass::class) class GoExtraPass(ctx: TranslationContext) : ComponentPass(ctx) { + private lateinit var walker: SubgraphWalker.ScopedWalker + override fun accept(component: Component) { // Add built-int functions, but only if one of the components contains a GoLanguage if (component.translationUnits.any { it.language is GoLanguage }) { component.translationUnits += addBuiltIn() } - val walker = SubgraphWalker.ScopedWalker(scopeManager) + walker = SubgraphWalker.ScopedWalker(scopeManager) walker.registerHandler { _, parent, node -> when (node) { is CallExpression -> handleCall(node, parent) @@ -467,6 +469,9 @@ class GoExtraPass(ctx: TranslationContext) : ComponentPass(ctx) { ) } else { call.disconnectFromGraph() + + // Make sure to inform the walker about our change + walker.registerReplacement(call, cast) } } diff --git a/cpg-language-go/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/GoLanguageFrontendTest.kt b/cpg-language-go/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/GoLanguageFrontendTest.kt index 5d56ceefdd..90b3dbb513 100644 --- a/cpg-language-go/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/GoLanguageFrontendTest.kt +++ b/cpg-language-go/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/GoLanguageFrontendTest.kt @@ -1089,6 +1089,28 @@ class GoLanguageFrontendTest : BaseTest() { assertInvokes(call, elem) } + @Test + fun testChainedCast() { + val topLevel = Path.of("src", "test", "resources", "golang") + val tu = + analyze( + listOf( + topLevel.resolve("cast.go").toFile(), + ), + topLevel, + true + ) { + it.registerLanguage() + } + assertNotNull(tu) + + assertEquals(0, tu.calls.size) + assertEquals( + listOf("string", "error", "p.myError"), + tu.casts.map { it.castType.name.toString() } + ) + } + @Test fun testComplexResolution() { val topLevel = Path.of("src", "test", "resources", "golang", "complex_resolution") diff --git a/cpg-language-go/src/test/resources/golang/cast.go b/cpg-language-go/src/test/resources/golang/cast.go new file mode 100644 index 0000000000..2ea3fad7d6 --- /dev/null +++ b/cpg-language-go/src/test/resources/golang/cast.go @@ -0,0 +1,9 @@ +package p + +type myError string + +func (err myError) Error() string { + return string(err) +} + +var s = error(myError("abc"))