Skip to content

Commit

Permalink
Trying to implement type switches
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto committed Aug 11, 2023
1 parent 32fd851 commit bc9866d
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 37 deletions.
18 changes: 18 additions & 0 deletions cpg-language-go/src/main/golang/lib/cpg/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,24 @@ func GetSwitchStmtBody(ptr unsafe.Pointer) unsafe.Pointer {
return save(stmt.Body)
}

//export GetTypeSwitchStmtInit
func GetTypeSwitchStmtInit(ptr unsafe.Pointer) unsafe.Pointer {
stmt := restore[*ast.TypeSwitchStmt](ptr)
return save(stmt.Init)
}

//export GetTypeSwitchStmtAssign
func GetTypeSwitchStmtAssign(ptr unsafe.Pointer) unsafe.Pointer {
stmt := restore[*ast.TypeSwitchStmt](ptr)
return save(stmt.Assign)
}

//export GetTypeSwitchStmtBody
func GetTypeSwitchStmtBody(ptr unsafe.Pointer) unsafe.Pointer {
stmt := restore[*ast.TypeSwitchStmt](ptr)
return save(stmt.Body)
}

func restore[T any](ptr unsafe.Pointer) T {
return pointer.Restore(ptr).(T)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ class DeclarationHandler(frontend: GoLanguageFrontend) :
// marked as AST and in Go a method is not part of the struct's AST but is
// declared outside. In the future, we need to differentiate between just the
// associated members of the class and the pure AST nodes declared in the
// struct
// itself
// struct itself
if (record != null) {
method.recordDeclaration = record
record.addMethod(method)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,18 +345,34 @@ class ExpressionHandler(frontend: GoLanguageFrontend) :

private fun handleTypeAssertExpr(
typeAssertExpr: GoStandardLibrary.Ast.TypeAssertExpr
): CastExpression {
val cast = newCastExpression(rawNode = typeAssertExpr)

// Parse the inner expression
cast.expression =
handle(typeAssertExpr.x) ?: newProblemExpression("missing inner expression")

// The type can be null, but only in certain circumstances, i.e, a type switch (which we do
// not support yet)
typeAssertExpr.type?.let { cast.castType = frontend.typeOf(it) }

return cast
): Expression {
// This can either be a regular type assertion, which we handle as a cast expression or the
// "special" type assertion `.(type)`, which is used in a type switch to retrieve the type
// of the variable. In this case we treat it as a special unary operator.
if (typeAssertExpr.type == null) {
val op =
newUnaryOperator(
".(type)",
postfix = true,
prefix = false,
rawNode = typeAssertExpr
)
op.input = handle(typeAssertExpr.x)

return op
} else {
val cast = newCastExpression(rawNode = typeAssertExpr)

// Parse the inner expression
cast.expression = handle(typeAssertExpr.x)

// The type can be null, but only in certain circumstances, i.e, a type switch (which we
// do
// not support yet)
typeAssertExpr.type?.let { cast.castType = frontend.typeOf(it) }

return cast
}
}

private fun handleUnaryExpr(unaryExpr: GoStandardLibrary.Ast.UnaryExpr): UnaryOperator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ interface GoStandardLibrary : Library {
"*ast.ReturnStmt" -> ReturnStmt(nativeValue)
"*ast.SendStmt" -> SendStmt(nativeValue)
"*ast.SwitchStmt" -> SwitchStmt(nativeValue)
"*ast.TypeSwitchStmt" -> TypeSwitchStmt(nativeValue)
else -> super.fromNative(nativeValue, context)
}
}
Expand Down Expand Up @@ -754,6 +755,23 @@ interface GoStandardLibrary : Library {
}
}

class TypeSwitchStmt(p: Pointer? = Pointer.NULL) : Stmt(p) {
val init: Stmt?
get() {
return INSTANCE.GetTypeSwitchStmtInit(this)
}

val assign: Stmt
get() {
return INSTANCE.GetTypeSwitchStmtAssign(this)
}

val body: BlockStmt
get() {
return INSTANCE.GetTypeSwitchStmtBody(this)
}
}

class Position(p: Pointer? = Pointer.NULL) : GoObject(p) {
val line: Int
get() {
Expand Down Expand Up @@ -1035,6 +1053,12 @@ interface GoStandardLibrary : Library {

fun GetSwitchStmtBody(stmt: Ast.SwitchStmt): Ast.BlockStmt

fun GetTypeSwitchStmtInit(typeSwitchStmt: Ast.TypeSwitchStmt): Ast.Stmt?

fun GetTypeSwitchStmtAssign(typeSwitchStmt: Ast.TypeSwitchStmt): Ast.Stmt

fun GetTypeSwitchStmtBody(typeSwitchStmt: Ast.TypeSwitchStmt): Ast.BlockStmt

fun GetNumGenDeclSpecs(genDecl: Ast.GenDecl): Int

fun GetGenDeclSpec(genDecl: Ast.GenDecl, i: Int): Ast.Spec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class StatementHandler(frontend: GoLanguageFrontend) :
is GoStandardLibrary.Ast.ReturnStmt -> handleReturnStmt(stmt)
is GoStandardLibrary.Ast.SendStmt -> handleSendStmt(stmt)
is GoStandardLibrary.Ast.SwitchStmt -> handleSwitchStmt(stmt)
is GoStandardLibrary.Ast.TypeSwitchStmt -> handleTypeSwitchStmt(stmt)
else -> handleNotSupported(stmt, stmt.goType)
}
}
Expand Down Expand Up @@ -115,7 +116,10 @@ class StatementHandler(frontend: GoLanguageFrontend) :
return compound
}

private fun handleCaseClause(caseClause: GoStandardLibrary.Ast.CaseClause): Statement {
private fun handleCaseClause(
caseClause: GoStandardLibrary.Ast.CaseClause,
typeSwitchName: Name? = null
): Statement {
val case =
if (caseClause.list.isEmpty()) {
newDefaultStatement(rawNode = caseClause)
Expand All @@ -136,6 +140,17 @@ class StatementHandler(frontend: GoLanguageFrontend) :
// Add the case statement
block += case

// TODO(oxisto): We somehow need to create a shadowed variable here and assign it to the
// scope
if (typeSwitchName != null) {
val typeSwitchType = frontend.typeOf(caseClause.list[0])
log.warn(
"Variable {} should have the type {} now; we cannot do that yet",
typeSwitchName,
typeSwitchType
)
}

for (s in caseClause.body) {
block += handle(s)
}
Expand Down Expand Up @@ -286,9 +301,7 @@ class StatementHandler(frontend: GoLanguageFrontend) :
val expr = frontend.expressionHandler.handle(results[0])

// TODO: parse more than one result expression
if (expr != null) {
`return`.returnValue = expr
}
`return`.returnValue = expr
} else {
// TODO: connect result statement to result variables
}
Expand Down Expand Up @@ -316,19 +329,45 @@ class StatementHandler(frontend: GoLanguageFrontend) :
handle(switchStmt.body) as? CompoundStatement
?: return newProblemExpression("missing switch body")

// Because of the way we parse the statements, the case statement turns out to be the last
// statement. However, we need it to be the first statement, so we need to switch first and
// last items
/*val statements = block.statements.toMutableList()
val tmp = statements.first()
statements[0] = block.statements.last()
statements[(statements.size - 1).coerceAtLeast(0)] = tmp
block.statements = statements*/

switch.statement = block

frontend.scopeManager.leaveScope(switch)

return switch
}

private fun handleTypeSwitchStmt(
typeSwitchStmt: GoStandardLibrary.Ast.TypeSwitchStmt
): SwitchStatement {
val switch = newSwitchStatement(rawNode = typeSwitchStmt)

frontend.scopeManager.enterScope(switch)

typeSwitchStmt.init?.let { switch.initializerStatement = handle(it) }

val assign = frontend.statementHandler.handle(typeSwitchStmt.assign)
val variableName =
if (assign is AssignExpression) {
switch.selector = assign.rhs.singleOrNull()
assign.lhs.singleOrNull()?.name
} else {
null
}

val body = newCompoundStatement(rawNode = typeSwitchStmt.body)

frontend.scopeManager.enterScope(body)

for (c in typeSwitchStmt.body.list.filterIsInstance<GoStandardLibrary.Ast.CaseClause>()) {
handleCaseClause(c, variableName)
}

frontend.scopeManager.leaveScope(body)

switch.statement = body

frontend.scopeManager.leaveScope(switch)

return switch
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,21 @@ class StatementTest {
// Its call expression should connect to the return statement
op.input.prevEOG.all { it is ReturnStatement }
}

@Test
fun testTypeSwitch() {
val topLevel = Path.of("src", "test", "resources", "golang")
val tu =
analyzeAndGetFirstTU(
listOf(topLevel.resolve("type_assert.go").toFile()),
topLevel,
true
) {
it.registerLanguage<GoLanguage>()
}
assertNotNull(tu)

val main = tu.functions["main"]
assertNotNull(main)
}
}
27 changes: 17 additions & 10 deletions cpg-language-go/src/test/resources/golang/type_assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,26 @@ package main

import "fmt"

type MyStruct struct {}
type MyStruct struct{}
type MyInterface interface {
MyFunc()
MyFunc()
}

func (MyStruct) MyFunc() {}

func main () {
var f MyInterface = MyStruct{}
var s = f.(MyStruct)
func main() {
var f MyInterface = MyStruct{}
var s = f.(MyStruct)

fmt.Printf("%+v", s)

fmt.Printf("%+v", s)
var _ = MyInterface(s)
var _ = interface{}(s)
var _ = any(s)

var _ = MyInterface(s)
var _ = interface{}(s)
var _ = any(s)
}
switch v := f.(type) {
case MyStruct:
var s2 = v
fmt.Printf("%+v", s2)
}
}

0 comments on commit bc9866d

Please sign in to comment.