From bc9866dbe3c9244657d2100368006af29d31a291 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Fri, 11 Aug 2023 15:00:51 +0200 Subject: [PATCH] Trying to implement type switches --- .../src/main/golang/lib/cpg/main.go | 18 +++++ .../frontends/golang/DeclarationHandler.kt | 3 +- .../cpg/frontends/golang/ExpressionHandler.kt | 40 ++++++++---- .../cpg/frontends/golang/GoStandardLibrary.kt | 24 +++++++ .../cpg/frontends/golang/StatementHandler.kt | 65 +++++++++++++++---- .../cpg/frontends/golang/StatementTest.kt | 17 +++++ .../src/test/resources/golang/type_assert.go | 27 +++++--- 7 files changed, 157 insertions(+), 37 deletions(-) diff --git a/cpg-language-go/src/main/golang/lib/cpg/main.go b/cpg-language-go/src/main/golang/lib/cpg/main.go index 0198373e0d7..9ab12a03618 100644 --- a/cpg-language-go/src/main/golang/lib/cpg/main.go +++ b/cpg-language-go/src/main/golang/lib/cpg/main.go @@ -909,6 +909,24 @@ func GetSwitchStmtBody(ptr unsafe.Pointer) unsafe.Pointer { return save(stmt.Body) } +//export GetTypeSwitchStmtInit +func GetTypeSwitchStmtInit(ptr unsafe.Pointer) unsafe.Pointer { + stmt := restore[*ast.TypeSwitchStmt](ptr) + return save(stmt.Init) +} + +//export GetTypeSwitchStmtAssign +func GetTypeSwitchStmtAssign(ptr unsafe.Pointer) unsafe.Pointer { + stmt := restore[*ast.TypeSwitchStmt](ptr) + return save(stmt.Assign) +} + +//export GetTypeSwitchStmtBody +func GetTypeSwitchStmtBody(ptr unsafe.Pointer) unsafe.Pointer { + stmt := restore[*ast.TypeSwitchStmt](ptr) + return save(stmt.Body) +} + func restore[T any](ptr unsafe.Pointer) T { return pointer.Restore(ptr).(T) } diff --git a/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/DeclarationHandler.kt b/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/DeclarationHandler.kt index 5123ec3d583..fe50c2a253e 100644 --- a/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/DeclarationHandler.kt +++ b/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/DeclarationHandler.kt @@ -98,8 +98,7 @@ class DeclarationHandler(frontend: GoLanguageFrontend) : // marked as AST and in Go a method is not part of the struct's AST but is // declared outside. In the future, we need to differentiate between just the // associated members of the class and the pure AST nodes declared in the - // struct - // itself + // struct itself if (record != null) { method.recordDeclaration = record record.addMethod(method) diff --git a/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/ExpressionHandler.kt b/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/ExpressionHandler.kt index 1c87e5f3f62..80319b59c55 100644 --- a/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/ExpressionHandler.kt +++ b/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/ExpressionHandler.kt @@ -345,18 +345,34 @@ class ExpressionHandler(frontend: GoLanguageFrontend) : private fun handleTypeAssertExpr( typeAssertExpr: GoStandardLibrary.Ast.TypeAssertExpr - ): CastExpression { - val cast = newCastExpression(rawNode = typeAssertExpr) - - // Parse the inner expression - cast.expression = - handle(typeAssertExpr.x) ?: newProblemExpression("missing inner expression") - - // The type can be null, but only in certain circumstances, i.e, a type switch (which we do - // not support yet) - typeAssertExpr.type?.let { cast.castType = frontend.typeOf(it) } - - return cast + ): Expression { + // This can either be a regular type assertion, which we handle as a cast expression or the + // "special" type assertion `.(type)`, which is used in a type switch to retrieve the type + // of the variable. In this case we treat it as a special unary operator. + if (typeAssertExpr.type == null) { + val op = + newUnaryOperator( + ".(type)", + postfix = true, + prefix = false, + rawNode = typeAssertExpr + ) + op.input = handle(typeAssertExpr.x) + + return op + } else { + val cast = newCastExpression(rawNode = typeAssertExpr) + + // Parse the inner expression + cast.expression = handle(typeAssertExpr.x) + + // The type can be null, but only in certain circumstances, i.e, a type switch (which we + // do + // not support yet) + typeAssertExpr.type?.let { cast.castType = frontend.typeOf(it) } + + return cast + } } private fun handleUnaryExpr(unaryExpr: GoStandardLibrary.Ast.UnaryExpr): UnaryOperator { diff --git a/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/GoStandardLibrary.kt b/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/GoStandardLibrary.kt index 0592820e984..c9acec5d89d 100644 --- a/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/GoStandardLibrary.kt +++ b/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/GoStandardLibrary.kt @@ -543,6 +543,7 @@ interface GoStandardLibrary : Library { "*ast.ReturnStmt" -> ReturnStmt(nativeValue) "*ast.SendStmt" -> SendStmt(nativeValue) "*ast.SwitchStmt" -> SwitchStmt(nativeValue) + "*ast.TypeSwitchStmt" -> TypeSwitchStmt(nativeValue) else -> super.fromNative(nativeValue, context) } } @@ -754,6 +755,23 @@ interface GoStandardLibrary : Library { } } + class TypeSwitchStmt(p: Pointer? = Pointer.NULL) : Stmt(p) { + val init: Stmt? + get() { + return INSTANCE.GetTypeSwitchStmtInit(this) + } + + val assign: Stmt + get() { + return INSTANCE.GetTypeSwitchStmtAssign(this) + } + + val body: BlockStmt + get() { + return INSTANCE.GetTypeSwitchStmtBody(this) + } + } + class Position(p: Pointer? = Pointer.NULL) : GoObject(p) { val line: Int get() { @@ -1035,6 +1053,12 @@ interface GoStandardLibrary : Library { fun GetSwitchStmtBody(stmt: Ast.SwitchStmt): Ast.BlockStmt + fun GetTypeSwitchStmtInit(typeSwitchStmt: Ast.TypeSwitchStmt): Ast.Stmt? + + fun GetTypeSwitchStmtAssign(typeSwitchStmt: Ast.TypeSwitchStmt): Ast.Stmt + + fun GetTypeSwitchStmtBody(typeSwitchStmt: Ast.TypeSwitchStmt): Ast.BlockStmt + fun GetNumGenDeclSpecs(genDecl: Ast.GenDecl): Int fun GetGenDeclSpec(genDecl: Ast.GenDecl, i: Int): Ast.Spec diff --git a/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/StatementHandler.kt b/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/StatementHandler.kt index e99a824fba3..c6dbad6bf02 100644 --- a/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/StatementHandler.kt +++ b/cpg-language-go/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/StatementHandler.kt @@ -53,6 +53,7 @@ class StatementHandler(frontend: GoLanguageFrontend) : is GoStandardLibrary.Ast.ReturnStmt -> handleReturnStmt(stmt) is GoStandardLibrary.Ast.SendStmt -> handleSendStmt(stmt) is GoStandardLibrary.Ast.SwitchStmt -> handleSwitchStmt(stmt) + is GoStandardLibrary.Ast.TypeSwitchStmt -> handleTypeSwitchStmt(stmt) else -> handleNotSupported(stmt, stmt.goType) } } @@ -115,7 +116,10 @@ class StatementHandler(frontend: GoLanguageFrontend) : return compound } - private fun handleCaseClause(caseClause: GoStandardLibrary.Ast.CaseClause): Statement { + private fun handleCaseClause( + caseClause: GoStandardLibrary.Ast.CaseClause, + typeSwitchName: Name? = null + ): Statement { val case = if (caseClause.list.isEmpty()) { newDefaultStatement(rawNode = caseClause) @@ -136,6 +140,17 @@ class StatementHandler(frontend: GoLanguageFrontend) : // Add the case statement block += case + // TODO(oxisto): We somehow need to create a shadowed variable here and assign it to the + // scope + if (typeSwitchName != null) { + val typeSwitchType = frontend.typeOf(caseClause.list[0]) + log.warn( + "Variable {} should have the type {} now; we cannot do that yet", + typeSwitchName, + typeSwitchType + ) + } + for (s in caseClause.body) { block += handle(s) } @@ -286,9 +301,7 @@ class StatementHandler(frontend: GoLanguageFrontend) : val expr = frontend.expressionHandler.handle(results[0]) // TODO: parse more than one result expression - if (expr != null) { - `return`.returnValue = expr - } + `return`.returnValue = expr } else { // TODO: connect result statement to result variables } @@ -316,19 +329,45 @@ class StatementHandler(frontend: GoLanguageFrontend) : handle(switchStmt.body) as? CompoundStatement ?: return newProblemExpression("missing switch body") - // Because of the way we parse the statements, the case statement turns out to be the last - // statement. However, we need it to be the first statement, so we need to switch first and - // last items - /*val statements = block.statements.toMutableList() - val tmp = statements.first() - statements[0] = block.statements.last() - statements[(statements.size - 1).coerceAtLeast(0)] = tmp - block.statements = statements*/ - switch.statement = block frontend.scopeManager.leaveScope(switch) return switch } + + private fun handleTypeSwitchStmt( + typeSwitchStmt: GoStandardLibrary.Ast.TypeSwitchStmt + ): SwitchStatement { + val switch = newSwitchStatement(rawNode = typeSwitchStmt) + + frontend.scopeManager.enterScope(switch) + + typeSwitchStmt.init?.let { switch.initializerStatement = handle(it) } + + val assign = frontend.statementHandler.handle(typeSwitchStmt.assign) + val variableName = + if (assign is AssignExpression) { + switch.selector = assign.rhs.singleOrNull() + assign.lhs.singleOrNull()?.name + } else { + null + } + + val body = newCompoundStatement(rawNode = typeSwitchStmt.body) + + frontend.scopeManager.enterScope(body) + + for (c in typeSwitchStmt.body.list.filterIsInstance()) { + handleCaseClause(c, variableName) + } + + frontend.scopeManager.leaveScope(body) + + switch.statement = body + + frontend.scopeManager.leaveScope(switch) + + return switch + } } diff --git a/cpg-language-go/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/StatementTest.kt b/cpg-language-go/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/StatementTest.kt index 3e698bc9054..a5b8232faf1 100644 --- a/cpg-language-go/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/StatementTest.kt +++ b/cpg-language-go/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/golang/StatementTest.kt @@ -110,4 +110,21 @@ class StatementTest { // Its call expression should connect to the return statement op.input.prevEOG.all { it is ReturnStatement } } + + @Test + fun testTypeSwitch() { + val topLevel = Path.of("src", "test", "resources", "golang") + val tu = + analyzeAndGetFirstTU( + listOf(topLevel.resolve("type_assert.go").toFile()), + topLevel, + true + ) { + it.registerLanguage() + } + assertNotNull(tu) + + val main = tu.functions["main"] + assertNotNull(main) + } } diff --git a/cpg-language-go/src/test/resources/golang/type_assert.go b/cpg-language-go/src/test/resources/golang/type_assert.go index 4f42ac27989..3b9a809f990 100644 --- a/cpg-language-go/src/test/resources/golang/type_assert.go +++ b/cpg-language-go/src/test/resources/golang/type_assert.go @@ -2,19 +2,26 @@ package main import "fmt" -type MyStruct struct {} +type MyStruct struct{} type MyInterface interface { - MyFunc() + MyFunc() } + func (MyStruct) MyFunc() {} -func main () { - var f MyInterface = MyStruct{} - var s = f.(MyStruct) +func main() { + var f MyInterface = MyStruct{} + var s = f.(MyStruct) + + fmt.Printf("%+v", s) - fmt.Printf("%+v", s) + var _ = MyInterface(s) + var _ = interface{}(s) + var _ = any(s) - var _ = MyInterface(s) - var _ = interface{}(s) - var _ = any(s) -} \ No newline at end of file + switch v := f.(type) { + case MyStruct: + var s2 = v + fmt.Printf("%+v", s2) + } +}