Skip to content

Commit

Permalink
Resolving Go's embedded struct methods using the regular SymbolResolver
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto committed Oct 1, 2023
1 parent 10cf78c commit f10a3f4
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import de.fraunhofer.aisec.cpg.frontends.*
import de.fraunhofer.aisec.cpg.graph.*
import de.fraunhofer.aisec.cpg.graph.declarations.*
import de.fraunhofer.aisec.cpg.graph.scopes.NameScope
import de.fraunhofer.aisec.cpg.graph.scopes.StructureDeclarationScope
import de.fraunhofer.aisec.cpg.graph.statements.expressions.*
import de.fraunhofer.aisec.cpg.graph.types.*
import de.fraunhofer.aisec.cpg.helpers.SubgraphWalker.ScopedWalker
Expand Down Expand Up @@ -790,9 +791,18 @@ open class SymbolResolver(ctx: TranslationContext) : ComponentPass(ctx) {
ctx
)
} else {
recordDeclaration.methods.filter {
it.name.lastPartsMatch(name) && it.hasSignature(call.signature)
// We should not directly access the "methods" property of the record declaration,
// because depending on the programming language, this only may hold methods that are
// declared directly within the original type declaration, but not ones that are
// declared "outside" (e.g, like it is possible in Go and C++). Instead, we should
// retrieve the scope of the record and look for appropriate declarations.
val scope = scopeManager.lookupScope(recordDeclaration) as? StructureDeclarationScope

// Filter the value declarations for an appropriate method
scope?.valueDeclarations?.filterIsInstance<MethodDeclaration>()?.filter {
it.name.lastPartsMatch(name) && it.hasSignature(call)
}
?: listOf()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,11 @@
*/
package de.fraunhofer.aisec.cpg.frontends.golang

import de.fraunhofer.aisec.cpg.TranslationContext
import de.fraunhofer.aisec.cpg.frontends.*
import de.fraunhofer.aisec.cpg.graph.declarations.FunctionDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.RecordDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.TranslationUnitDeclaration
import de.fraunhofer.aisec.cpg.graph.primitiveType
import de.fraunhofer.aisec.cpg.graph.statements.expressions.CallExpression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.Literal
import de.fraunhofer.aisec.cpg.graph.types.*
import de.fraunhofer.aisec.cpg.passes.SymbolResolver
import org.neo4j.ogm.annotation.Transient

/** The Go language. */
Expand All @@ -44,8 +39,7 @@ class GoLanguage :
HasGenerics,
HasStructs,
HasFirstClassFunctions,
HasAnonymousIdentifier,
HasComplexCallResolution {
HasAnonymousIdentifier {
override val fileExtensions = listOf("go")
override val namespaceDelimiter = "."
@Transient override val frontend = GoLanguageFrontend::class
Expand Down Expand Up @@ -200,45 +194,4 @@ class GoLanguage :

return false
}

override fun refineNormalCallResolution(
call: CallExpression,
ctx: TranslationContext,
currentTU: TranslationUnitDeclaration
): List<FunctionDeclaration> {
return ctx.scopeManager.resolveFunction(call)
}

override fun refineMethodCallResolution(
curClass: RecordDeclaration?,
possibleContainingTypes: Set<Type>,
call: CallExpression,
ctx: TranslationContext,
currentTU: TranslationUnitDeclaration,
callResolver: SymbolResolver
): List<FunctionDeclaration> {
return ctx.scopeManager.resolveFunction(call)
}

override fun refineInvocationCandidatesFromRecord(
recordDeclaration: RecordDeclaration,
call: CallExpression,
name: String,
ctx: TranslationContext
): List<FunctionDeclaration> {
var list =
recordDeclaration.methods.filter {
it.name.lastPartsMatch(name) && it.hasSignature(call.signature)
}

// If the list is empty, we also need to consider the embedded records
if (list.isEmpty()) {
list =
recordDeclaration.embeddedStructs
.flatMap { it.methods }
.filter { it.name.lastPartsMatch(name) && it.hasSignature(call.signature) }
}

return list
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,6 @@ fun funcTypeName(paramTypes: List<Type>, returnTypes: List<Type>): String {
val RecordDeclaration.embeddedStructs: List<RecordDeclaration>
get() {
return this.fields
.filter { it.name.localName == it.type.root.name.localName }
.filter { "embedded" in it.modifiers }
.mapNotNull { it.type.root.recordDeclaration }
}
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,17 @@ class SpecificationHandler(frontend: GoLanguageFrontend) :

// A field can also have no name, which means that it is embedded. In this case, it
// can be accessed by the local name of its type and therefore we name the field
// accordingly
val fieldName =
// accordingly. We use the modifiers property to denote that this is an embedded
// field, so we can easily retrieve them later
val (fieldName, modifiers) =
if (field.names.isEmpty()) {
// Retrieve the root type local name
type.root.name.localName
Pair(type.root.name.localName, listOf("embedded"))
} else {
field.names[0].name
Pair(field.names[0].name, listOf())
}

val decl = newFieldDeclaration(fieldName, type, rawNode = field)
val decl = newFieldDeclaration(fieldName, type, modifiers, rawNode = field)
frontend.scopeManager.addDeclaration(decl)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@
package de.fraunhofer.aisec.cpg.passes

import de.fraunhofer.aisec.cpg.TranslationContext
import de.fraunhofer.aisec.cpg.frontends.golang.GoLanguage
import de.fraunhofer.aisec.cpg.frontends.golang.funcTypeName
import de.fraunhofer.aisec.cpg.frontends.golang.isOverlay
import de.fraunhofer.aisec.cpg.frontends.golang.underlyingType
import de.fraunhofer.aisec.cpg.frontends.golang.*
import de.fraunhofer.aisec.cpg.graph.*
import de.fraunhofer.aisec.cpg.graph.declarations.*
import de.fraunhofer.aisec.cpg.graph.scopes.Scope
Expand Down Expand Up @@ -95,14 +92,19 @@ import de.fraunhofer.aisec.cpg.passes.order.ExecuteBefore
* This is also possible with more complex types, such as interfaces or aliased types, as long as
* they are compatible. Because types in the same package can be defined in multiple files, we
* cannot decide during the frontend run. Therefore, we need to execute this pass before the
* [CallResolver] and convert certain [CallExpression] nodes into a [CastExpression].
* [SymbolResolver] and convert certain [CallExpression] nodes into a [CastExpression].
*
* ## Adjust Names of Keys in Key Value Expressions to FQN
*
* This pass also adjusts the names of keys in a [KeyValueExpression], which is part of an
* [InitializerListExpression] to a fully-qualified name that contains the name of the [ObjectType]
* that the expression is creating. This way we can resolve the static references to the field to
* the actual field.
*
* ## Add Methods of Embedded Structs to the Record's Scope
*
* This pass also adds methods of [RecordDeclaration.embeddedStructs] into the scope of the
* [RecordDeclaration] itself, so that it can be resolved using the regular [SymbolResolver].
*/
@ExecuteBefore(SymbolResolver::class)
@ExecuteBefore(EvaluationOrderGraphPass::class)
Expand All @@ -123,6 +125,7 @@ class GoExtraPass(ctx: TranslationContext) : ComponentPass(ctx), ScopeProvider {
when (node) {
is CallExpression -> handleCall(node, parent)
is IncludeDeclaration -> handleInclude(node)
is RecordDeclaration -> handleRecordDeclaration(node)
is AssignExpression -> handleAssign(node)
is ForEachStatement -> handleForEachStatement(node)
is InitializerListExpression -> handleInitializerListExpression(node)
Expand All @@ -134,6 +137,42 @@ class GoExtraPass(ctx: TranslationContext) : ComponentPass(ctx), ScopeProvider {
}
}

/**
* This function adds methods of [RecordDeclaration.embeddedStructs] into the scope of the
* struct itself, so we can resolve method calls of embedded structs.
*
* For example, if a struct embeds another struct (see https://go.dev/ref/spec#Struct_types), we
* can call any methods of the embedded struct on the one that embeds it:
* ```go
* type MyTime struct {
* time.Time
* }
*
* func main() {
* var t = MyTime{Time: time.Now()}
* t.Add(-5*time.Second)
* }
* ```
*/
private fun handleRecordDeclaration(record: RecordDeclaration) {
// We are only interest in structs, not interfaces
if (record.kind != "struct") {
return
}

// Enter our record's scope
scopeManager.enterScope(record)

// Loop through the embedded struct and add their methods to the record's scope.
for (method in record.embeddedStructs.flatMap { it.methods }) {
// Add it to the scope, but do NOT add it to the underlying AST field (methods),
// otherwise we would duplicate the method in the AST
scopeManager.addDeclaration(method, addToAST = false)
}

scopeManager.leaveScope(record)
}

private fun addBuiltIn(): TranslationUnitDeclaration {
val builtin = newTranslationUnitDeclaration("builtin.go")
builtin.language = GoLanguage()
Expand Down Expand Up @@ -192,7 +231,11 @@ class GoExtraPass(ctx: TranslationContext) : ComponentPass(ctx), ScopeProvider {
private fun addBuiltInFunction(func: FunctionDeclaration) {
func.type =
typeManager.registerType(
FunctionType(funcTypeName(func.signatureTypes, func.returnTypes))
FunctionType(
funcTypeName(func.signatureTypes, func.returnTypes),
func.signatureTypes,
func.returnTypes
)
)
scopeManager.addDeclaration(func)
}
Expand Down

0 comments on commit f10a3f4

Please sign in to comment.