Skip to content

Commit

Permalink
Preparing dynamic invoke resolving for Go
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto committed Sep 30, 2023
1 parent a7fb181 commit 4a7dd5b
Show file tree
Hide file tree
Showing 13 changed files with 87 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -659,13 +659,11 @@ class ScopeManager : ScopeProvider {
call: CallExpression,
startScope: Scope? = currentScope
): List<FunctionDeclaration> {
val language = call.language

val (scope, name) = extractScope(call, startScope)

val func =
resolve<FunctionDeclaration>(scope) {
it.name.lastPartsMatch(name) && it.hasSignature(call.signature, call.arguments)
it.name.lastPartsMatch(name) && it.hasSignature(call)
}

return func
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import de.fraunhofer.aisec.cpg.graph.edge.PropertyEdge.Companion.propertyEqualsL
import de.fraunhofer.aisec.cpg.graph.edge.PropertyEdgeDelegate
import de.fraunhofer.aisec.cpg.graph.statements.Statement
import de.fraunhofer.aisec.cpg.graph.statements.expressions.Block
import de.fraunhofer.aisec.cpg.graph.statements.expressions.CallExpression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.Expression
import de.fraunhofer.aisec.cpg.graph.types.Type
import de.fraunhofer.aisec.cpg.isDerivedFrom
Expand Down Expand Up @@ -111,6 +112,10 @@ open class FunctionDeclaration : ValueDeclaration(), DeclarationHolder, Resoluti
targetFunctionDeclaration.signatureTypes == signatureTypes
}

fun hasSignature(call: CallExpression): Boolean {
return hasSignature(call.signature, call.arguments)
}

// TODO: Documentation required. It's not completely clear what this method is supposed to do.
fun hasSignature(
targetSignature: List<Type>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class LambdaExpression : Expression(), HasType.TypeObserver {

// We should only propagate a function type, coming from our declared function
if (newType is FunctionType) {
// TODO(oxisto): We should discuss at some point, whether we should actually return
// a FunctionType instead of a FunctionPointerType
// Propagate a pointer reference to the function
this.type = newType.pointer()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ open class SymbolResolver(ctx: TranslationContext) : ComponentPass(ctx) {
// are running into the special case explained above. Otherwise, we abort here (for
// now).
wouldResolveTo = scopeManager.resolveReference(current, current.scope)
if (wouldResolveTo !is VariableDeclaration) {
if (wouldResolveTo !is VariableDeclaration && wouldResolveTo !is ParameterDeclaration) {
return
}
}
Expand Down Expand Up @@ -441,8 +441,13 @@ open class SymbolResolver(ctx: TranslationContext) : ComponentPass(ctx) {
}

protected fun handleCallExpression(curClass: RecordDeclaration?, call: CallExpression) {
// Function pointers are handled by extra pass, so we are not resolving them here
if (call.callee?.type is FunctionPointerType) {
// Dynamic function invokes (such as function pointers) are handled by extra pass, so we are
// not resolving them here. In this case, our callee refers to a variable rather than a
// function.
if (
(call.callee as? Reference)?.refersTo is VariableDeclaration ||
(call.callee as? Reference)?.refersTo is ParameterDeclaration
) {
return
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import de.fraunhofer.aisec.cpg.graph.types.*
import de.fraunhofer.aisec.cpg.helpers.Benchmark
import de.fraunhofer.aisec.cpg.helpers.Util
import de.fraunhofer.aisec.cpg.passes.CXXExtraPass
import de.fraunhofer.aisec.cpg.passes.FunctionPointerCallResolver
import de.fraunhofer.aisec.cpg.passes.DynamicInvokeResolver
import de.fraunhofer.aisec.cpg.passes.order.RegisterExtraPass
import de.fraunhofer.aisec.cpg.sarif.PhysicalLocation
import de.fraunhofer.aisec.cpg.sarif.Region
Expand Down Expand Up @@ -78,7 +78,7 @@ import org.slf4j.LoggerFactory
* ad [GPPLanguage]). This enables us (to some degree) to deal with the finer difference between C
* and C++ code.
*/
@RegisterExtraPass(FunctionPointerCallResolver::class)
@RegisterExtraPass(DynamicInvokeResolver::class)
@RegisterExtraPass(CXXExtraPass::class)
class CXXLanguageFrontend(language: Language<CXXLanguageFrontend>, ctx: TranslationContext) :
LanguageFrontend<IASTNode, IASTTypeId>(language, ctx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,14 @@ import de.fraunhofer.aisec.cpg.graph.types.FunctionPointerType
import de.fraunhofer.aisec.cpg.helpers.IdentitySet
import de.fraunhofer.aisec.cpg.helpers.SubgraphWalker.ScopedWalker
import de.fraunhofer.aisec.cpg.passes.order.DependsOn
import de.fraunhofer.aisec.cpg.passes.order.RequiredFrontend
import java.util.*
import java.util.function.Consumer

/**
* This [Pass] is responsible for resolving function pointer calls, i.e., [CallExpression] nodes
* that contain a reference/pointer to a function and are being "called". This pass is intentionally
* split from the [CallResolver] because it depends on DFG edges. This split allows the
* [CallResolver] to be run before any DFG passes, which in turn allow us to also populate DFG
* split from the [SymbolResolver] because it depends on DFG edges. This split allows the
* [SymbolResolver] to be run before any DFG passes, which in turn allow us to also populate DFG
* passes for inferred functions.
*
* This pass is currently only run for the [CXXLanguageFrontend], however, in the future we might
Expand All @@ -54,8 +53,7 @@ import java.util.function.Consumer
*/
@DependsOn(SymbolResolver::class)
@DependsOn(DFGPass::class)
@RequiredFrontend(CXXLanguageFrontend::class)
class FunctionPointerCallResolver(ctx: TranslationContext) : ComponentPass(ctx) {
class DynamicInvokeResolver(ctx: TranslationContext) : ComponentPass(ctx) {
private lateinit var walker: ScopedWalker
private var inferDfgForUnresolvedCalls = false

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1537,7 +1537,7 @@ internal class CXXLanguageFrontendTest : BaseTest() {
it.registerPass<EvaluationOrderGraphPass>() // creates EOG
it.registerPass<TypeResolver>()
it.registerPass<ControlFlowSensitiveDFGPass>()
it.registerPass<FunctionPointerCallResolver>()
it.registerPass<DynamicInvokeResolver>()
it.registerPass<FilenameMapper>()
}

Expand Down Expand Up @@ -1582,7 +1582,7 @@ internal class CXXLanguageFrontendTest : BaseTest() {
it.registerPass<DFGPass>()
it.registerPass<EvaluationOrderGraphPass>() // creates EOG
it.registerPass<TypeResolver>()
it.registerPass<FunctionPointerCallResolver>()
it.registerPass<DynamicInvokeResolver>()
it.registerPass<ControlFlowSensitiveDFGPass>()
it.registerPass<FilenameMapper>()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ class GoLanguage :
return true
}

// This makes lambda expression works, as long as we have the dedicated a
// FunctionPointerType
if (type is FunctionPointerType && superType is FunctionType) {
return type == superType.reference(PointerType.PointerOrigin.POINTER)
}

// the unsafe.IntegerType is a fake type in the unsafe package, that accepts any integer
// type
if (type is IntegerType && superType == primitiveType("unsafe.IntegerType")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ import de.fraunhofer.aisec.cpg.frontends.golang.funcTypeName
import de.fraunhofer.aisec.cpg.frontends.golang.isOverlay
import de.fraunhofer.aisec.cpg.frontends.golang.underlyingType
import de.fraunhofer.aisec.cpg.graph.*
import de.fraunhofer.aisec.cpg.graph.declarations.IncludeDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.NamespaceDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.TranslationUnitDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.VariableDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.*
import de.fraunhofer.aisec.cpg.graph.scopes.Scope
import de.fraunhofer.aisec.cpg.graph.statements.DeclarationStatement
import de.fraunhofer.aisec.cpg.graph.statements.ForEachStatement
Expand Down Expand Up @@ -146,11 +143,7 @@ class GoExtraPass(ctx: TranslationContext) : ComponentPass(ctx), ScopeProvider {
val len = newFunctionDeclaration("len", localNameOnly = true)
len.parameters = listOf(newParameterDeclaration("v", autoType()))
len.returnTypes = listOf(primitiveType("int"))
len.type =
typeManager.registerType(
FunctionType(funcTypeName(len.signatureTypes, len.returnTypes))
)
scopeManager.addDeclaration(len)
addBuiltInFunction(len)

/**
* ```go
Expand All @@ -164,28 +157,46 @@ class GoExtraPass(ctx: TranslationContext) : ComponentPass(ctx), ScopeProvider {
newParameterDeclaration("elems", autoType(), variadic = true),
)
append.returnTypes = listOf(autoType().array())
append.type =
typeManager.registerType(
FunctionType(funcTypeName(append.signatureTypes, append.returnTypes))
)
scopeManager.addDeclaration(append)
addBuiltInFunction(append)

/**
* ```go
* func panic(v any)
* ```
*/
val panic = newFunctionDeclaration("panic", localNameOnly = true)
panic.parameters = listOf(newParameterDeclaration("v", primitiveType("any")))
addBuiltInFunction(panic)

/**
* ```go
* func recover() any
* ```
*/
val recover = newFunctionDeclaration("panic", localNameOnly = true)
panic.returnTypes = listOf(primitiveType("any"))
addBuiltInFunction(recover)

val error = newRecordDeclaration("error", "interface")
scopeManager.enterScope(error)

val errorFunc = newMethodDeclaration("Error", recordDeclaration = error)
errorFunc.returnTypes = listOf(primitiveType("string"))
errorFunc.type =
typeManager.registerType(
FunctionType(funcTypeName(errorFunc.signatureTypes, errorFunc.returnTypes))
)
scopeManager.addDeclaration(errorFunc)
addBuiltInFunction(errorFunc)

scopeManager.leaveScope(error)
builtin
}
}

private fun addBuiltInFunction(func: FunctionDeclaration) {
func.type =
typeManager.registerType(
FunctionType(funcTypeName(func.signatureTypes, func.returnTypes))
)
scopeManager.addDeclaration(func)
}

/**
* handleInitializerListExpression changes the references of keys in a [KeyValueExpression] to
* include the object it is creating as a parent name.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1104,13 +1104,26 @@ class GoLanguageFrontendTest : BaseTest() {

assertNotNull(result)

val calls = result.calls
// All calls except the one to "funcy" (which is a dynamic invoke) should be resolved to
// non-inferred functions
val calls = result.calls.filter { it.name.localName != "funcy" }
calls.forEach {
assertTrue(it.invokes.isNotEmpty())
it.invokes.forEach { func -> assertFalse(func.isInferred) }
}

val refs = result.refs
val func = result.functions[""]
assertNotNull(func)

// For funcy, we should be able to find the DFG back to the lamba function
val funcy = result.calls["funcy"]
assertNotNull(funcy)

val path = funcy.callee?.followPrevDFGEdgesUntilHit { it == func }
assertNotNull(path)
assertTrue(path.fulfilled.isNotEmpty())

val refs = result.refs.filter { it.name.localName != "_" }
refs.forEach { assertNotNull(it.refersTo) }
}
}
1 change: 1 addition & 0 deletions cpg-language-go/src/test/resources/golang/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ func main() {

go c.MyFunc()

// In Go, numeric literals can be used as any numeric type
sixtyfour(1)
}

Expand Down
11 changes: 11 additions & 0 deletions cpg-language-go/src/test/resources/golang/complex_resolution.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ func main() {
var some = NewSomething()

some.self.self.self.Do()

_ = doFuncy(func() error {
return nil
})
}

func Func(args ...int) {}

type Funcy func() error

func doFuncy(funcy Funcy) (err error) {
err = funcy()
return
}
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ class TypeScriptLanguageFrontendTest {
assertNotNull(usersComponent)
assertLocalName("Users", usersComponent)
assertEquals(1, usersComponent.constructors.size)
assertEquals(/*2*/ 3 /* because of a dummy node */, usersComponent.methods.size)
assertEquals(2, usersComponent.methods.size)
assertEquals(/*0*/ 2 /* because of dummy nodes */, usersComponent.fields.size)

val render = usersComponent.methods["render"]
Expand Down

0 comments on commit 4a7dd5b

Please sign in to comment.