Skip to content

Commit

Permalink
Added registerReplacement to ScopedWalker (#1403)
Browse files Browse the repository at this point in the history
In some frontends (e.g., the Go frontend) we are replacing certain nodes in an extra pass, while iterating with the `ScopedWalker`. One such usecase is to replace calls with casts. Previously, when doing so, chained replacement of calls were a problem, since the original replaced value was stil registered in the walker and was used to determine further AST children to "walk" on, instead of the new node.

This adds a function to register such replacements with the walker.

Co-authored-by: Konrad Weiss <[email protected]>
  • Loading branch information
oxisto and konradweiss authored Jan 10, 2024
1 parent 203cf40 commit 360186c
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,10 @@ object SubgraphWalker {
* place where usual graph manipulation will happen. The current node is the single argument
* passed to the function
*/
private val onNodeVisit: MutableList<Consumer<Node>> = ArrayList()
private val onNodeVisit2: MutableList<BiConsumer<Node, Node?>> = ArrayList()
private val onNodeVisit: MutableList<Consumer<Node>> = mutableListOf()
private val onNodeVisit2: MutableList<BiConsumer<Node, Node?>> = mutableListOf()

private val replacements = mutableMapOf<Node, Node>()

/**
* The callback that is designed to tell the user when we leave the current scope. The
Expand Down Expand Up @@ -305,24 +307,33 @@ object SubgraphWalker {
val seen: MutableSet<Node> = LinkedHashSet()
todo?.push(Pair<Node, Node?>(root, null))
while ((todo as ArrayDeque<Pair<Node, Node?>>).isNotEmpty()) {
val (current, parent) = (todo as ArrayDeque<Pair<Node, Node?>>).pop()
var (current, parent) = (todo as ArrayDeque<Pair<Node, Node?>>).pop()
if (
(backlog as ArrayDeque<Node>).isNotEmpty() &&
(backlog as ArrayDeque<Node>).peek() == current
) {
val exiting = (backlog as ArrayDeque<Node>).pop()
onNodeExit.forEach(Consumer { c: Consumer<Node> -> c.accept(exiting) })
} else {
// re-place the current node as a marker for the above check to find out when we
// need to exit a scope
(todo as ArrayDeque<Pair<Node, Node?>>).push(Pair(current, parent))
onNodeVisit.forEach(Consumer { c: Consumer<Node> -> c.accept(current) })
onNodeVisit2.forEach(
Consumer { c: BiConsumer<Node, Node?> -> c.accept(current, parent) }
)

// Check if we have a replacement node
val toReplace = replacements[current]
if (toReplace != null) {
current = toReplace
replacements.remove(toReplace)
}

val unseenChildren =
strategy(current).asSequence().filter { it !in seen }.toMutableList()

// re-place the current node as a marker for the above check to find out when we
// need to exit a scope
(todo as ArrayDeque<Pair<Node, Node?>>).push(Pair(current, parent))

seen.addAll(unseenChildren)
unseenChildren.asReversed().forEach { child: Node ->
(todo as ArrayDeque<Pair<Node, Node?>>).push(Pair(child, current))
Expand All @@ -332,6 +343,15 @@ object SubgraphWalker {
}
}

/**
* Sometimes during walking the graph, we are replacing the current node. This causes
* problems, that the walker still assumes the old node. Calling this function will ensure
* that the walker knows about the new node.
*/
fun registerReplacement(from: Node, to: Node) {
replacements[from] = to
}

fun registerOnNodeVisit(callback: Consumer<Node>) {
onNodeVisit.add(callback)
}
Expand Down Expand Up @@ -408,6 +428,10 @@ object SubgraphWalker {
)
}

fun registerReplacement(from: Node, to: Node) {
walker?.registerReplacement(from, to)
}

/**
* Wraps [IterativeGraphWalker] to handle declaration scopes.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,15 @@ import de.fraunhofer.aisec.cpg.passes.order.ExecuteBefore
@ExecuteBefore(DFGPass::class)
class GoExtraPass(ctx: TranslationContext) : ComponentPass(ctx) {

private lateinit var walker: SubgraphWalker.ScopedWalker

override fun accept(component: Component) {
// Add built-int functions, but only if one of the components contains a GoLanguage
if (component.translationUnits.any { it.language is GoLanguage }) {
component.translationUnits += addBuiltIn()
}

val walker = SubgraphWalker.ScopedWalker(scopeManager)
walker = SubgraphWalker.ScopedWalker(scopeManager)
walker.registerHandler { _, parent, node ->
when (node) {
is CallExpression -> handleCall(node, parent)
Expand Down Expand Up @@ -467,6 +469,9 @@ class GoExtraPass(ctx: TranslationContext) : ComponentPass(ctx) {
)
} else {
call.disconnectFromGraph()

// Make sure to inform the walker about our change
walker.registerReplacement(call, cast)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,28 @@ class GoLanguageFrontendTest : BaseTest() {
assertInvokes(call, elem)
}

@Test
fun testChainedCast() {
val topLevel = Path.of("src", "test", "resources", "golang")
val tu =
analyze(
listOf(
topLevel.resolve("cast.go").toFile(),
),
topLevel,
true
) {
it.registerLanguage<GoLanguage>()
}
assertNotNull(tu)

assertEquals(0, tu.calls.size)
assertEquals(
listOf("string", "error", "p.myError"),
tu.casts.map { it.castType.name.toString() }
)
}

@Test
fun testComplexResolution() {
val topLevel = Path.of("src", "test", "resources", "golang", "complex_resolution")
Expand Down
9 changes: 9 additions & 0 deletions cpg-language-go/src/test/resources/golang/cast.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package p

type myError string

func (err myError) Error() string {
return string(err)
}

var s = error(myError("abc"))

0 comments on commit 360186c

Please sign in to comment.