From d32ae9898b6e7cd97481274a0a13810dac2af51a Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Thu, 18 Jul 2024 21:27:01 +0200 Subject: [PATCH 1/9] Converting LLVM frontend to astParent style --- .../cpg/frontends/llvm/DeclarationHandler.kt | 173 ++++---- .../cpg/frontends/llvm/ExpressionHandler.kt | 138 ++++--- .../cpg/frontends/llvm/StatementHandler.kt | 377 ++++++++---------- 3 files changed, 323 insertions(+), 365 deletions(-) diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/DeclarationHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/DeclarationHandler.kt index 948340673c..41aa972732 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/DeclarationHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/DeclarationHandler.kt @@ -74,19 +74,17 @@ class DeclarationHandler(lang: LLVMIRLanguageFrontend) : // the pointer type val type = frontend.typeOf(valueRef) - val variableDeclaration = newVariableDeclaration(name, type, false, rawNode = valueRef) - - // cache binding - frontend.bindingsCache[valueRef.symbolName] = variableDeclaration - - val size = LLVMGetNumOperands(valueRef) - // the first operand (if it exists) is an initializer - if (size > 0) { - val expr = frontend.expressionHandler.handle(LLVMGetOperand(valueRef, 0)) - variableDeclaration.initializer = expr + return newVariableDeclaration(name, type, false, rawNode = valueRef).withChildren { + // cache binding + frontend.bindingsCache[valueRef.symbolName] = it + + val size = LLVMGetNumOperands(valueRef) + // the first operand (if it exists) is an initializer + if (size > 0) { + val expr = frontend.expressionHandler.handle(LLVMGetOperand(valueRef, 0)) + it.initializer = expr + } } - - return variableDeclaration } /** @@ -97,77 +95,75 @@ class DeclarationHandler(lang: LLVMIRLanguageFrontend) : */ private fun handleFunction(func: LLVMValueRef): FunctionDeclaration { val name = LLVMGetValueName(func) - val functionDeclaration = newFunctionDeclaration(name.string, rawNode = func) - - // return types are a bit tricky, because the type of the function is a pointer to the - // function type, which then has the return type in it - val funcPtrType = LLVMTypeOf(func) - val funcType = LLVMGetElementType(funcPtrType) - val returnType = LLVMGetReturnType(funcType) + return newFunctionDeclaration(name.string, rawNode = func).withChildren(hasScope = true) { + // return types are a bit tricky, because the type of the function is a pointer to the + // function type, which then has the return type in it + val funcPtrType = LLVMTypeOf(func) + val funcType = LLVMGetElementType(funcPtrType) + val returnType = LLVMGetReturnType(funcType) - functionDeclaration.type = frontend.typeOf(returnType) + it.type = frontend.typeOf(returnType) - frontend.scopeManager.enterScope(functionDeclaration) + var param = LLVMGetFirstParam(func) + while (param != null) { + val namePair = frontend.getNameOf(param) + val paramName = namePair.first + val paramSymbolName = namePair.second - var param = LLVMGetFirstParam(func) - while (param != null) { - val namePair = frontend.getNameOf(param) - val paramName = namePair.first - val paramSymbolName = namePair.second + val type = frontend.typeOf(param) - val type = frontend.typeOf(param) + // TODO: support variardic + val decl = newParameterDeclaration(paramName, type, false, rawNode = param) - // TODO: support variardic - val decl = newParameterDeclaration(paramName, type, false, rawNode = param) + frontend.scopeManager.addDeclaration(decl) + frontend.bindingsCache[paramSymbolName] = decl - frontend.scopeManager.addDeclaration(decl) - frontend.bindingsCache[paramSymbolName] = decl - - param = LLVMGetNextParam(param) - } + param = LLVMGetNextParam(param) + } - var bb = LLVMGetFirstBasicBlock(func) - while (bb != null) { - val stmt = frontend.statementHandler.handle(bb) - - // Notice: we have one fundamental challenge here. Basic blocks in LLVM have a flat - // hierarchy, meaning that a function has a list of basic blocks, of which one can - // be unlabeled and is considered to be the entry. All other blocks need to have - // labels and can be reached by branching or jump instructions. If all blocks are - // labeled, then the first one is considered to be the entry. - // - // For our translation into the CPG we translate a basic block into a compound - // statement, i.e. a list of statements. However, in the CPG structure, a function - // definition does not have an entry, which specifies the first block, but it has a - // *body*, which comprises *all* statements within the abstract syntax tree of - // that function, hierarchically organized by compound statements. To emulate that, we - // take the first basic block as our body and add subsequent blocks as statements to - // the body. More specifically, we use the CPG node LabelStatement, which denotes the - // use of a label. Its property substatement contains the original basic block, parsed - // as a compound statement - - // Take the entry block as our body - if (LLVMGetEntryBasicBlock(func) == bb && stmt is Block) { - functionDeclaration.body = stmt - } else if (LLVMGetEntryBasicBlock(func) == bb) { - functionDeclaration.body = newBlock() - if (stmt != null) { - (functionDeclaration.body as Block).addStatement(stmt) - } - } else { - // add the label statement, containing this basic block as a compound statement to - // our body (if we have none, which we should) - if (stmt != null) { - (functionDeclaration.body as? Block)?.addStatement(stmt) + var bb = LLVMGetFirstBasicBlock(func) + while (bb != null) { + val stmt = frontend.statementHandler.handle(bb) + + // Notice: we have one fundamental challenge here. Basic blocks in LLVM have a flat + // hierarchy, meaning that a function has a list of basic blocks, of which one can + // be unlabeled and is considered to be the entry. All other blocks need to have + // labels and can be reached by branching or jump instructions. If all blocks are + // labeled, then the first one is considered to be the entry. + // + // For our translation into the CPG we translate a basic block into a compound + // statement, i.e. a list of statements. However, in the CPG structure, a function + // definition does not have an entry, which specifies the first block, but it has a + // *body*, which comprises *all* statements within the abstract syntax tree of + // that function, hierarchically organized by compound statements. To emulate that, + // we + // take the first basic block as our body and add subsequent blocks as statements to + // the body. More specifically, we use the CPG node LabelStatement, which denotes + // the + // use of a label. Its property substatement contains the original basic block, + // parsed + // as a compound statement + + // Take the entry block as our body + if (LLVMGetEntryBasicBlock(func) == bb && stmt is Block) { + it.body = stmt + } else if (LLVMGetEntryBasicBlock(func) == bb) { + it.body = newBlock() + if (stmt != null) { + (it.body as Block).addStatement(stmt) + } + } else { + // add the label statement, containing this basic block as a compound statement + // to + // our body (if we have none, which we should) + if (stmt != null) { + (it.body as? Block)?.addStatement(stmt) + } } - } - bb = LLVMGetNextBasicBlock(bb) + bb = LLVMGetNextBasicBlock(bb) + } } - - frontend.scopeManager.leaveScope(functionDeclaration) - - return functionDeclaration } /** @@ -206,25 +202,30 @@ class DeclarationHandler(lang: LLVMIRLanguageFrontend) : return record } - record = newRecordDeclaration(name, "struct") - - val size = LLVMCountStructElementTypes(typeRef) + // We need to create a new record declaration but DO NOT create a new scope (yet) + record = + newRecordDeclaration(name, "struct").withChildren { + val size = LLVMCountStructElementTypes(typeRef) - for (i in 0 until size) { - val a = LLVMStructGetTypeAtIndex(typeRef, i) - val fieldType = frontend.typeOf(a, alreadyVisited) + for (i in 0 until size) { + val a = LLVMStructGetTypeAtIndex(typeRef, i) + val fieldType = frontend.typeOf(a, alreadyVisited) - // there are no names, so we need to invent some dummy ones for easier reading - val fieldName = "field_$i" + // there are no names, so we need to invent some dummy ones for easier reading + val fieldName = "field_$i" - frontend.scopeManager.enterScope(record) + // we need to enter the record's scope for each field individually, otherwise + // the above call to typeOf will be inside the record scope and then things go + // wrong + frontend.scopeManager.enterScope(it) - val field = newFieldDeclaration(fieldName, fieldType, listOf(), null, false) + val field = newFieldDeclaration(fieldName, fieldType, listOf(), null, false) - frontend.scopeManager.addDeclaration(field) + frontend.scopeManager.addDeclaration(field) - frontend.scopeManager.leaveScope(record) - } + frontend.scopeManager.leaveScope(it) + } + } // add it to the global scope frontend.scopeManager.addDeclaration(record) diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/ExpressionHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/ExpressionHandler.kt index a0132951f7..ce2f509a7e 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/ExpressionHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/ExpressionHandler.kt @@ -282,20 +282,19 @@ class ExpressionHandler(lang: LLVMIRLanguageFrontend) : // retrieve the type val type = frontend.typeOf(value) - val expr: ConstructExpression = newConstructExpression(frontend.codeOf(value)) - // map the construct expression to the record declaration of the type - expr.instantiates = (type as? ObjectType)?.recordDeclaration - - // loop through the operands - for (i in 0 until LLVMGetNumOperands(value)) { - // and handle them as expressions themselves - val arg = this.handle(LLVMGetOperand(value, i)) - if (arg != null) { - expr.addArgument(arg) + return newConstructExpression(frontend.codeOf(value)).withChildren { + // map the construct expression to the record declaration of the type + it.instantiates = (type as? ObjectType)?.recordDeclaration + + // loop through the operands + for (i in 0 until LLVMGetNumOperands(value)) { + // and handle them as expressions themselves + val arg = this.handle(LLVMGetOperand(value, i)) + if (arg != null) { + it.addArgument(arg) + } } } - - return expr } /** @@ -315,25 +314,25 @@ class ExpressionHandler(lang: LLVMIRLanguageFrontend) : } val arrayType = LLVMTypeOf(valueRef) - val list = newInitializerListExpression(frontend.typeOf(valueRef), rawNode = valueRef) - val length = - if (LLVMIsAConstantDataArray(valueRef) != null) { - LLVMGetArrayLength(arrayType) - } else { - LLVMGetVectorSize(arrayType) - } - - val initializers = mutableListOf() + return newInitializerListExpression(frontend.typeOf(valueRef), rawNode = valueRef) + .withChildren { + val length = + if (LLVMIsAConstantDataArray(valueRef) != null) { + LLVMGetArrayLength(arrayType) + } else { + LLVMGetVectorSize(arrayType) + } - for (i in 0 until length) { - val expr = handle(LLVMGetAggregateElement(valueRef, i)) as Expression + val initializers = mutableListOf() - initializers += expr - } + for (i in 0 until length) { + val expr = handle(LLVMGetAggregateElement(valueRef, i)) as Expression - list.initializers = initializers + initializers += expr + } - return list + it.initializers = initializers + } } /** @@ -348,20 +347,16 @@ class ExpressionHandler(lang: LLVMIRLanguageFrontend) : ) { newLiteral(null, type, rawNode = value) } else { - val expr: ConstructExpression = - newConstructExpression(frontend.codeOf(value), rawNode = value) - // map the construct expression to the record declaration of the type - expr.instantiates = (type as? ObjectType)?.recordDeclaration - if (expr.instantiates == null) return expr - - // loop through the operands - for (field in (expr.instantiates as RecordDeclaration).fields) { - // and handle them as expressions themselves - val arg = initializeAsUndef(field.type, value) - expr.addArgument(arg) + newConstructExpression(frontend.codeOf(value), rawNode = value).withChildren { + // map the construct expression to the record declaration of the type + it.instantiates = (type as? ObjectType)?.recordDeclaration + // loop through the operands + for (field in (it.instantiates as? RecordDeclaration)?.fields ?: listOf()) { + // and handle them as expressions themselves + val arg = initializeAsUndef(field.type, value) + it.addArgument(arg) + } } - - expr } } @@ -376,20 +371,17 @@ class ExpressionHandler(lang: LLVMIRLanguageFrontend) : ) { newLiteral(0, type, rawNode = value) } else { - val expr: ConstructExpression = - newConstructExpression(frontend.codeOf(value), rawNode = value) - // map the construct expression to the record declaration of the type - expr.instantiates = (type as? ObjectType)?.recordDeclaration - if (expr.instantiates == null) return expr - - // loop through the operands - for (field in (expr.instantiates as RecordDeclaration).fields) { - // and handle them as expressions themselves - val arg = initializeAsZero(field.type, value) - expr.addArgument(arg) + newConstructExpression(frontend.codeOf(value), rawNode = value).withChildren { + // map the construct expression to the record declaration of the type + it.instantiates = (type as? ObjectType)?.recordDeclaration + + // loop through the operands + for (field in (it.instantiates as? RecordDeclaration)?.fields ?: listOf()) { + // and handle them as expressions themselves + val arg = initializeAsZero(field.type, value) + it.addArgument(arg) + } } - - expr } } @@ -464,10 +456,12 @@ class ExpressionHandler(lang: LLVMIRLanguageFrontend) : // check, if the current base type is a pointer -> then we need to handle this as an // array access if (baseType is PointerType) { - val arrayExpr = newSubscriptExpression() - arrayExpr.arrayExpression = base - arrayExpr.name = Name(index.toString()) - arrayExpr.subscriptExpression = operand + val arrayExpr = + newSubscriptExpression().withChildren { + it.arrayExpression = base + it.name = Name(index.toString()) + it.subscriptExpression = operand + } expr = arrayExpr // deference the type to get the new base type @@ -527,6 +521,7 @@ class ExpressionHandler(lang: LLVMIRLanguageFrontend) : log.info("{}", expr) // the current expression is the new base + // TODO: we need to manually push this to the AST stack here base = expr } } @@ -534,8 +529,10 @@ class ExpressionHandler(lang: LLVMIRLanguageFrontend) : // since getelementpr returns the *address*, whereas extractvalue returns a *value*, we need // to do a final unary & operation if (isGetElementPtr) { - val ref = newUnaryOperator("&", postfix = false, prefix = true) - ref.input = expr + val ref = + newUnaryOperator("&", postfix = false, prefix = true).withChildren { + it.input = expr + } expr = ref } @@ -546,26 +543,25 @@ class ExpressionHandler(lang: LLVMIRLanguageFrontend) : * Handles the [`select`](https://llvm.org/docs/LangRef.html#i-select) instruction, which * behaves like a [ConditionalExpression]. */ - fun handleSelect(instr: LLVMValueRef): Expression { - val cond = frontend.getOperandValueAtIndex(instr, 0) - val value1 = frontend.getOperandValueAtIndex(instr, 1) - val value2 = frontend.getOperandValueAtIndex(instr, 2) + fun handleSelect(instr: LLVMValueRef) = + newConditionalExpression().withChildren { + val cond = frontend.getOperandValueAtIndex(instr, 0) + val value1 = frontend.getOperandValueAtIndex(instr, 1) + val value2 = frontend.getOperandValueAtIndex(instr, 2) - return newConditionalExpression(value1.type).withChildren { it.condition = cond it.thenExpression = value1 it.elseExpression = value2 + it.type = value1.type } - } /** * Handles all kinds of instructions which are a * [cast instruction](https://llvm.org/docs/LangRef.html#conversion-operations). */ - fun handleCastInstruction(instr: LLVMValueRef): Expression { - val castExpr = newCastExpression(rawNode = instr) - castExpr.castType = frontend.typeOf(instr) - castExpr.expression = frontend.getOperandValueAtIndex(instr, 0) - return castExpr - } + fun handleCastInstruction(instr: LLVMValueRef) = + newCastExpression(rawNode = instr).withChildren { + it.castType = frontend.typeOf(instr) + it.expression = frontend.getOperandValueAtIndex(instr, 0) + } } diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt index 3d0c8ea8f4..ad71b55421 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt @@ -71,148 +71,105 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : } val opcode = instr.opCode - - when (opcode) { + return when (opcode) { LLVMRet -> { - val ret = newReturnStatement(rawNode = instr) - - val numOps = LLVMGetNumOperands(instr) - if (numOps != 0) { - ret.returnValue = frontend.getOperandValueAtIndex(instr, 0) + newReturnStatement(rawNode = instr).withChildren { + val numOps = LLVMGetNumOperands(instr) + if (numOps != 0) { + it.returnValue = frontend.getOperandValueAtIndex(instr, 0) + } } - - return ret - } - LLVMBr -> { - return handleBrStatement(instr) - } - LLVMSwitch -> { - return handleSwitchStatement(instr) - } - LLVMIndirectBr -> { - return handleIndirectbrStatement(instr) } + LLVMBr -> handleBrStatement(instr) + LLVMSwitch -> handleSwitchStatement(instr) + LLVMIndirectBr -> handleIndirectbrStatement(instr) LLVMCall, - LLVMInvoke -> { - return handleFunctionCall(instr) - } + LLVMInvoke -> handleFunctionCall(instr) LLVMUnreachable -> { // Does nothing - return newEmptyStatement(rawNode = instr) + newEmptyStatement(rawNode = instr) } LLVMCallBr -> { // Maps to a call but also to a goto statement? Barely used => not relevant log.error("Cannot parse callbr instruction yet") + newProblemExpression("Cannot parse callbr instruction yet") } LLVMFNeg -> { - val fneg = newUnaryOperator("-", postfix = false, prefix = true, rawNode = instr) - fneg.input = frontend.getOperandValueAtIndex(instr, 0) - return fneg - } - LLVMAlloca -> { - return handleAlloca(instr) - } - LLVMLoad -> { - return handleLoad(instr) - } - LLVMStore -> { - return handleStore(instr) + newUnaryOperator("-", postfix = false, prefix = true, rawNode = instr) + .withChildren { it.input = frontend.getOperandValueAtIndex(instr, 0) } } + LLVMAlloca -> handleAlloca(instr) + LLVMLoad -> handleLoad(instr) + LLVMStore -> handleStore(instr) LLVMExtractValue, LLVMGetElementPtr -> { - return declarationOrNot( - frontend.expressionHandler.handleGetElementPtr(instr), - instr - ) - } - LLVMICmp -> { - return handleIntegerComparison(instr) - } - LLVMFCmp -> { - return handleFloatComparison(instr) + declarationOrNot(frontend.expressionHandler.handleGetElementPtr(instr), instr) } + LLVMICmp -> handleIntegerComparison(instr) + LLVMFCmp -> handleFloatComparison(instr) LLVMPHI -> { frontend.phiList.add(instr) - return newEmptyStatement(rawNode = instr) + newEmptyStatement(rawNode = instr) } LLVMSelect -> { - return declarationOrNot(frontend.expressionHandler.handleSelect(instr), instr) + declarationOrNot(frontend.expressionHandler.handleSelect(instr), instr) } LLVMUserOp1, LLVMUserOp2 -> { log.info( "userop instruction is not a real instruction. Replacing it with empty statement" ) - return newEmptyStatement(rawNode = instr) - } - LLVMVAArg -> { - return handleVaArg(instr) - } - LLVMExtractElement -> { - return handleExtractelement(instr) - } - LLVMInsertElement -> { - return handleInsertelement(instr) - } - LLVMShuffleVector -> { - return handleShufflevector(instr) - } - LLVMInsertValue -> { - return handleInsertValue(instr) - } - LLVMFreeze -> { - return handleFreeze(instr) - } - LLVMFence -> { - return handleFence(instr) - } - LLVMAtomicCmpXchg -> { - return handleAtomiccmpxchg(instr) - } - LLVMAtomicRMW -> { - return handleAtomicrmw(instr) - } + newEmptyStatement(rawNode = instr) + } + LLVMVAArg -> handleVaArg(instr) + LLVMExtractElement -> handleExtractelement(instr) + LLVMInsertElement -> handleInsertelement(instr) + LLVMShuffleVector -> handleShufflevector(instr) + LLVMInsertValue -> handleInsertValue(instr) + LLVMFreeze -> handleFreeze(instr) + LLVMFence -> handleFence(instr) + LLVMAtomicCmpXchg -> handleAtomiccmpxchg(instr) + LLVMAtomicRMW -> handleAtomicrmw(instr) LLVMResume -> { // Resumes propagation of an existing (in-flight) exception whose unwinding was // interrupted with a landingpad instruction. - return newUnaryOperator("throw", postfix = false, prefix = true, rawNode = instr) - } - LLVMLandingPad -> { - return handleLandingpad(instr) + newUnaryOperator("throw", postfix = false, prefix = true, rawNode = instr) } + LLVMLandingPad -> handleLandingpad(instr) LLVMCleanupRet -> { // End of the cleanup basic block(s) // Jump to a label where handling the exception will unwind to next (e.g. a // catchswitch statement) - return handleCatchret(instr) + handleCatchret(instr) } LLVMCatchRet -> { // Catch (caught by catchpad instruction) is over. // Jumps to a label where the "normal" function logic continues - return handleCatchret(instr) + handleCatchret(instr) } LLVMCatchPad -> { // Actually handles the exception. - return handleCatchpad(instr) + handleCatchpad(instr) } LLVMCleanupPad -> { // Beginning of the cleanup basic block(s). // We should model this as the beginning of a catch block - return handleCleanuppad(instr) + handleCleanuppad(instr) } LLVMCatchSwitch -> { // Marks the beginning of a "real" catch block // Jumps to one of the handlers specified or to the default handler (if specified) - return handleCatchswitch(instr) + handleCatchswitch(instr) + } + else -> { + log.error("Not handling instruction opcode {} yet", opcode) + newProblemExpression( + "Not handling instruction opcode $opcode yet", + ProblemNode.ProblemType.TRANSLATION, + rawNode = instr + ) } } - - log.error("Not handling instruction opcode {} yet", opcode) - return newProblemExpression( - "Not handling instruction opcode $opcode yet", - ProblemNode.ProblemType.TRANSLATION, - rawNode = instr - ) } /** @@ -240,13 +197,9 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : } ) return if (unwindDest != null) { // For "unwind to caller", the destination is null - val gotoStatement = assembleGotoStatement(instr, unwindDest) - gotoStatement.name = name - gotoStatement + assembleGotoStatement(instr, unwindDest).withChildren { it.name = name } } else { - val emptyStatement = newEmptyStatement(rawNode = instr) - emptyStatement.name = name - emptyStatement + newEmptyStatement(rawNode = instr).withChildren { it.name = name } } } @@ -263,87 +216,87 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : private fun handleCatchswitch(instr: LLVMValueRef): Statement { val numOps = LLVMGetNumOperands(instr) - val parent = frontend.getOperandValueAtIndex(instr, 0) - - val compoundStatement = newBlock(rawNode = instr) - - val dummyCall = - newCallExpression( - llvmInternalRef("llvm.catchswitch"), - "llvm.catchswitch", - false, - rawNode = instr - ) - dummyCall.addArgument(parent, "parent") + return newBlock(rawNode = instr).withChildren { block -> + val dummyCall = + newCallExpression( + llvmInternalRef("llvm.catchswitch"), + "llvm.catchswitch", + false, + rawNode = instr + ) + .withChildren { + val parent = frontend.getOperandValueAtIndex(instr, 0) + it.addArgument(parent, "parent") + } - val tokenGeneration = declarationOrNot(dummyCall, instr) as DeclarationStatement - compoundStatement.addStatement(tokenGeneration) + val tokenGeneration = declarationOrNot(dummyCall, instr) as DeclarationStatement + block.addStatement(tokenGeneration) - val ifStatement = newIfStatement(rawNode = instr) - var currentIfStatement: IfStatement? = null - var idx = 1 - while (idx < numOps) { - if (currentIfStatement == null) { - currentIfStatement = ifStatement - } else { - val newIf = newIfStatement(rawNode = instr) - currentIfStatement.elseStatement = newIf - currentIfStatement = newIf - } + val ifStatement = newIfStatement(rawNode = instr) + var currentIfStatement: IfStatement? = null + var idx = 1 + while (idx < numOps) { + if (currentIfStatement == null) { + currentIfStatement = ifStatement + } else { + val newIf = newIfStatement(rawNode = instr) + currentIfStatement.elseStatement = newIf + currentIfStatement = newIf + } - // For each of the handlers, we get the first instruction and insert a statement - // case llvm.matchesCatchpad(parent, args), where args are used to determine if - // this handler accepts the object thrown. - val bbTarget = LLVMGetOperand(instr, idx) + // For each of the handlers, we get the first instruction and insert a statement + // case llvm.matchesCatchpad(parent, args), where args are used to determine if + // this handler accepts the object thrown. + val bbTarget = LLVMGetOperand(instr, idx) - val catchpad = LLVMGetFirstInstruction(LLVMValueAsBasicBlock(bbTarget)) - val catchOps = LLVMGetNumArgOperands(catchpad) + val catchpad = LLVMGetFirstInstruction(LLVMValueAsBasicBlock(bbTarget)) + val catchOps = LLVMGetNumArgOperands(catchpad) - val matchesCatchpad = - newCallExpression( - llvmInternalRef("llvm.matchesCatchpad"), - "llvm.matchesCatchpad", - false, - rawNode = instr - ) + val matchesCatchpad = + newCallExpression( + llvmInternalRef("llvm.matchesCatchpad"), + "llvm.matchesCatchpad", + false, + rawNode = instr + ) - val parentCatchSwitch = LLVMGetParentCatchSwitch(catchpad) - val catchswitch = frontend.expressionHandler.handle(parentCatchSwitch) as Expression - matchesCatchpad.addArgument(catchswitch, "parentCatchswitch") + val parentCatchSwitch = LLVMGetParentCatchSwitch(catchpad) + val catchswitch = frontend.expressionHandler.handle(parentCatchSwitch) as Expression + matchesCatchpad.addArgument(catchswitch, "parentCatchswitch") - for (i in 0 until catchOps) { - val arg = frontend.getOperandValueAtIndex(catchpad, i) - matchesCatchpad.addArgument(arg, "args_$i") - } + for (i in 0 until catchOps) { + val arg = frontend.getOperandValueAtIndex(catchpad, i) + matchesCatchpad.addArgument(arg, "args_$i") + } - currentIfStatement.condition = matchesCatchpad + currentIfStatement.condition = matchesCatchpad - // Get the label of the goto statement. - val gotoStatement = assembleGotoStatement(instr, bbTarget) - currentIfStatement.thenStatement = gotoStatement + // Get the label of the goto statement. + val gotoStatement = assembleGotoStatement(instr, bbTarget) + currentIfStatement.thenStatement = gotoStatement - idx++ - } + idx++ + } - val unwindDest = LLVMGetUnwindDest(instr) - if (unwindDest != null) { // For "unwind to caller", the destination is null - val gotoStatement = assembleGotoStatement(instr, LLVMBasicBlockAsValue(unwindDest)) - if (currentIfStatement == null) { - currentIfStatement = ifStatement + val unwindDest = LLVMGetUnwindDest(instr) + if (unwindDest != null) { // For "unwind to caller", the destination is null + val gotoStatement = assembleGotoStatement(instr, LLVMBasicBlockAsValue(unwindDest)) + if (currentIfStatement == null) { + currentIfStatement = ifStatement + } + currentIfStatement.elseStatement = gotoStatement + } else { + // "unwind to caller". As we don't know where the control flow continues, + // the best model would be to throw the exception again. Here, we only know + // that we will throw something here but we don't know what. We have to fix + // that later once we know in which catch-block this statement is executed. + val throwOperation = + newUnaryOperator("throw", postfix = false, prefix = true, rawNode = instr) + currentIfStatement?.elseStatement = throwOperation } - currentIfStatement.elseStatement = gotoStatement - } else { - // "unwind to caller". As we don't know where the control flow continues, - // the best model would be to throw the exception again. Here, we only know - // that we will throw something here but we don't know what. We have to fix - // that later once we know in which catch-block this statement is executed. - val throwOperation = - newUnaryOperator("throw", postfix = false, prefix = true, rawNode = instr) - currentIfStatement?.elseStatement = throwOperation - } - compoundStatement.addStatement(ifStatement) - return compoundStatement + block.addStatement(ifStatement) + } } /** @@ -499,12 +452,10 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : val dereference = newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) dereference.input = frontend.getOperandValueAtIndex(instr, 1) - return newAssignExpression( - "=", - listOf(dereference), - listOf(frontend.getOperandValueAtIndex(instr, 0)), - rawNode = instr - ) + return newAssignExpression("=", rawNode = instr).withChildren { + it.lhs = listOf(dereference) + it.rhs = listOf(frontend.getOperandValueAtIndex(instr, 0)) + } } /** @@ -703,12 +654,15 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : } } - val compoundStatement = newBlock(rawNode = instr) - val assignment = newAssignExpression("=", listOf(base), listOf(valueToSet), rawNode = instr) - compoundStatement.addStatement(copy) - compoundStatement.addStatement(assignment) - - return compoundStatement + return newBlock(rawNode = instr).withChildren { block -> + val assignment = + newAssignExpression("=", rawNode = instr).withChildren { + it.lhs = listOf(base) + it.rhs = listOf(valueToSet) + } + block.addStatement(copy) + block.addStatement(assignment) + } } /** @@ -750,11 +704,13 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : // res = (arg != undef && arg != poison) ? arg : llvm.freeze(in) val conditional = newConditionalExpression( - condition, - operand, - callExpression, - operand.type, - ) + operand.type, + ) + .withChildren { + it.condition = condition + it.thenExpression = operand + it.elseExpression = callExpression + } return declarationOrNot(conditional, instr) } @@ -842,11 +798,16 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : ptrDerefAssign.input = frontend.getOperandValueAtIndex(instr, 0) val assignment = - newAssignExpression("=", listOf(ptrDerefAssign), listOf(value), rawNode = instr) + newAssignExpression("=", rawNode = instr).withChildren { + it.lhs = listOf(ptrDerefAssign) + it.rhs = listOf(value) + } - val ifStatement = newIfStatement(rawNode = instr) - ifStatement.condition = cmpExpr - ifStatement.thenStatement = assignment + val ifStatement = + newIfStatement(rawNode = instr).withChildren { + it.condition = cmpExpr + it.thenStatement = assignment + } compoundStatement.addStatement(ifStatement) @@ -933,11 +894,13 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : ptrDerefConditional.input = frontend.getOperandValueAtIndex(instr, 0) val conditional = newConditionalExpression( - condition, - ptrDerefConditional, - value, - ty, - ) + ty, + ) + .withChildren { + it.condition = condition + it.thenExpression = ptrDerefConditional + it.elseExpression = value + } exchOp.rhs = listOf(conditional) } LLVMAtomicRMWBinOpUMax, @@ -963,11 +926,13 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : ptrDerefConditional.input = frontend.getOperandValueAtIndex(instr, 0) val conditional = newConditionalExpression( - condition, - ptrDerefConditional, - value, - ty, - ) + ty, + ) + .withChildren { + it.condition = condition + it.thenExpression = ptrDerefConditional + it.elseExpression = value + } exchOp.rhs = listOf(conditional) } else -> { @@ -1243,12 +1208,10 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : arrayExpr.subscriptExpression = frontend.getOperandValueAtIndex(instr, 2) val assignExpr = - newAssignExpression( - "=", - listOf(arrayExpr), - listOf(frontend.getOperandValueAtIndex(instr, 1)), - rawNode = instr - ) + newAssignExpression("=", rawNode = instr).withChildren { + it.lhs = listOf(arrayExpr) + it.rhs = listOf(frontend.getOperandValueAtIndex(instr, 1)) + } compoundStatement.addStatement(assignExpr) return compoundStatement @@ -1424,12 +1387,10 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : for ((l, r) in labelMap) { // Now, we iterate over all the basic blocks and add an assign statement. val assignment = - newAssignExpression( - "=", - listOf(newReference(varName, type, rawNode = instr)), - listOf(r), - rawNode = instr - ) + newAssignExpression("=", rawNode = instr).withChildren { + it.lhs = listOf(newReference(varName, type, rawNode = instr)) + it.rhs = listOf(r) + } (assignment.lhs.first() as Reference).type = type (assignment.lhs.first() as Reference).refersTo = declaration flatAST.add(assignment) From d6deeb5c9074088a9afe5ea3bd1603994a4462ac Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Thu, 18 Jul 2024 22:03:28 +0200 Subject: [PATCH 2/9] More LLVM --- .../cpg/frontends/llvm/StatementHandler.kt | 85 ++++++++++--------- 1 file changed, 47 insertions(+), 38 deletions(-) diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt index ad71b55421..a6689adf4b 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt @@ -236,6 +236,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : var currentIfStatement: IfStatement? = null var idx = 1 while (idx < numOps) { + // TODO: somehow change the AST stack here if (currentIfStatement == null) { currentIfStatement = ifStatement } else { @@ -254,21 +255,22 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : val matchesCatchpad = newCallExpression( - llvmInternalRef("llvm.matchesCatchpad"), - "llvm.matchesCatchpad", - false, - rawNode = instr - ) - - val parentCatchSwitch = LLVMGetParentCatchSwitch(catchpad) - val catchswitch = frontend.expressionHandler.handle(parentCatchSwitch) as Expression - matchesCatchpad.addArgument(catchswitch, "parentCatchswitch") - - for (i in 0 until catchOps) { - val arg = frontend.getOperandValueAtIndex(catchpad, i) - matchesCatchpad.addArgument(arg, "args_$i") - } - + llvmInternalRef("llvm.matchesCatchpad"), + "llvm.matchesCatchpad", + false, + rawNode = instr + ) + .withChildren { + val parentCatchSwitch = LLVMGetParentCatchSwitch(catchpad) + val catchswitch = + frontend.expressionHandler.handle(parentCatchSwitch) as Expression + it.addArgument(catchswitch, "parentCatchswitch") + + for (i in 0 until catchOps) { + val arg = frontend.getOperandValueAtIndex(catchpad, i) + it.addArgument(arg, "args_$i") + } + } currentIfStatement.condition = matchesCatchpad // Get the label of the goto statement. @@ -311,17 +313,19 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : val dummyCall = newCallExpression( - llvmInternalRef("llvm.cleanuppad"), - "llvm.cleanuppad", - false, - rawNode = instr - ) - dummyCall.addArgument(catchswitch, "parentCatchswitch") + llvmInternalRef("llvm.cleanuppad"), + "llvm.cleanuppad", + false, + rawNode = instr + ) + .withChildren { + it.addArgument(catchswitch, "parentCatchswitch") - for (i in 1 until numOps) { - val arg = frontend.getOperandValueAtIndex(instr, i) - dummyCall.addArgument(arg, "args_${i-1}") - } + for (i in 1 until numOps) { + val arg = frontend.getOperandValueAtIndex(instr, i) + it.addArgument(arg, "args_${i - 1}") + } + } return declarationOrNot(dummyCall, instr) } @@ -1417,21 +1421,26 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : // if it is still empty, we probably do not have a left side return if (lhs != "") { - val decl = - newVariableDeclaration(lhs, frontend.typeOf(valueRef), false, rawNode = valueRef) - decl.initializer = rhs - - // add the declaration to the current scope - frontend.scopeManager.addDeclaration(decl) - - // add it to our bindings cache - frontend.bindingsCache[symbolName] = decl - // Since the declaration statement only contains the single declaration, we can use the // same raw node, so we end up with the same code and location - val declStatement = newDeclarationStatement(rawNode = valueRef) - declStatement.singleDeclaration = decl - declStatement + newDeclarationStatement(rawNode = valueRef).withChildren { + it.singleDeclaration = + newVariableDeclaration( + lhs, + frontend.typeOf(valueRef), + false, + rawNode = valueRef + ) + .withChildren { decl -> + decl.initializer = rhs + + // add the declaration to the current scope + frontend.scopeManager.addDeclaration(decl) + + // add it to our bindings cache + frontend.bindingsCache[symbolName] = decl + } + } } else { rhs } From c15a0a8bbd8a01e93a9e45c07f5254469870c875 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Thu, 18 Jul 2024 22:52:20 +0200 Subject: [PATCH 3/9] more conversion --- .../cpg/frontends/llvm/StatementHandler.kt | 429 +++++++++--------- 1 file changed, 215 insertions(+), 214 deletions(-) diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt index a6689adf4b..0bcc47185d 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt @@ -307,11 +307,8 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : * parent and the args as arguments. */ @FunctionReplacement(["llvm.cleanuppad"], "cleanuppad") - private fun handleCleanuppad(instr: LLVMValueRef): Statement { - val numOps = LLVMGetNumArgOperands(instr) - val catchswitch = frontend.getOperandValueAtIndex(instr, 0) - - val dummyCall = + private fun handleCleanuppad(instr: LLVMValueRef) = + declarationOrNot( newCallExpression( llvmInternalRef("llvm.cleanuppad"), "llvm.cleanuppad", @@ -319,15 +316,18 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : rawNode = instr ) .withChildren { + val numOps = LLVMGetNumArgOperands(instr) + val catchswitch = frontend.getOperandValueAtIndex(instr, 0) + it.addArgument(catchswitch, "parentCatchswitch") for (i in 1 until numOps) { val arg = frontend.getOperandValueAtIndex(instr, i) it.addArgument(arg, "args_${i - 1}") } - } - return declarationOrNot(dummyCall, instr) - } + }, + instr + ) /** * We simulate a [`catchpad`](https://llvm.org/docs/LangRef.html#catchpad-instruction) @@ -335,26 +335,21 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : * catchswitch and the args as arguments. */ @FunctionReplacement(["llvm.catchpad"], "catchpad") - private fun handleCatchpad(instr: LLVMValueRef): Statement { - val numOps = LLVMGetNumArgOperands(instr) - val parentCatchSwitch = LLVMGetParentCatchSwitch(instr) - val catchswitch = frontend.expressionHandler.handle(parentCatchSwitch) as Expression - - val dummyCall = - newCallExpression( - llvmInternalRef("llvm.catchpad"), - "llvm.catchpad", - false, - rawNode = instr - ) - dummyCall.addArgument(catchswitch, "parentCatchswitch") - - for (i in 0 until numOps) { - val arg = frontend.getOperandValueAtIndex(instr, i) - dummyCall.addArgument(arg, "args_$i") - } - return declarationOrNot(dummyCall, instr) - } + private fun handleCatchpad(instr: LLVMValueRef) = + newCallExpression(llvmInternalRef("llvm.catchpad"), "llvm.catchpad", false, rawNode = instr) + .withChildren { + val numOps = LLVMGetNumArgOperands(instr) + val parentCatchSwitch = LLVMGetParentCatchSwitch(instr) + val catchswitch = frontend.expressionHandler.handle(parentCatchSwitch) as Expression + + it.addArgument(catchswitch, "parentCatchswitch") + + for (i in 0 until numOps) { + val arg = frontend.getOperandValueAtIndex(instr, i) + it.addArgument(arg, "args_$i") + } + } + .declareIfNecessary(instr) /** * Handles the [`va_arg`](https://llvm.org/docs/LangRef.html#va-arg-instruction) instruction. It @@ -362,70 +357,45 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : * function takes two arguments: the vararg-list and the type of the return value. */ @FunctionReplacement(["llvm.va_arg"], "va_arg") - private fun handleVaArg(instr: LLVMValueRef): Statement { - val callExpr = - newCallExpression(llvmInternalRef("llvm.va_arg"), "llvm.va_arg", false, rawNode = instr) - val operandName = frontend.getOperandValueAtIndex(instr, 0) - callExpr.addArgument(operandName) - val expectedType = frontend.typeOf(instr) - val typeLiteral = newLiteral(expectedType, expectedType, rawNode = instr) - callExpr.addArgument(typeLiteral) // TODO: Is this correct?? - return declarationOrNot(callExpr, instr) - } + private fun handleVaArg(instr: LLVMValueRef) = + newCallExpression(llvmInternalRef("llvm.va_arg"), "llvm.va_arg", false, rawNode = instr) + .withChildren { + val operandName = frontend.getOperandValueAtIndex(instr, 0) + it.addArgument(operandName) + val expectedType = frontend.typeOf(instr) + val typeLiteral = newLiteral(expectedType, expectedType, rawNode = instr) + it.addArgument(typeLiteral) // TODO: Is this correct?? + } + .declareIfNecessary(instr) /** Handles all kinds of instructions which are an arithmetic or logical binary instruction. */ private fun handleBinaryInstruction(instr: LLVMValueRef): Statement { - when (instr.opCode) { + return when (instr.opCode) { LLVMAdd, - LLVMFAdd -> { - return handleBinaryOperator(instr, "+", false) - } + LLVMFAdd -> handleBinaryOperator(instr, "+", false) LLVMSub, - LLVMFSub -> { - return handleBinaryOperator(instr, "-", false) - } + LLVMFSub -> handleBinaryOperator(instr, "-", false) LLVMMul, - LLVMFMul -> { - return handleBinaryOperator(instr, "*", false) - } - LLVMUDiv -> { - return handleBinaryOperator(instr, "/", true) - } + LLVMFMul -> handleBinaryOperator(instr, "*", false) + LLVMUDiv -> handleBinaryOperator(instr, "/", true) LLVMSDiv, - LLVMFDiv -> { - return handleBinaryOperator(instr, "/", false) - } - LLVMURem -> { - return handleBinaryOperator(instr, "%", true) - } + LLVMFDiv -> handleBinaryOperator(instr, "/", false) + LLVMURem -> handleBinaryOperator(instr, "%", true) LLVMSRem, - LLVMFRem -> { - return handleBinaryOperator(instr, "%", false) - } - LLVMShl -> { - return handleBinaryOperator(instr, "<<", false) - } - LLVMLShr -> { - return handleBinaryOperator(instr, ">>", true) - } - LLVMAShr -> { - return handleBinaryOperator(instr, ">>", false) - } - LLVMAnd -> { - return handleBinaryOperator(instr, "&", false) - } - LLVMOr -> { - return handleBinaryOperator(instr, "|", false) - } - LLVMXor -> { - return handleBinaryOperator(instr, "^", false) - } + LLVMFRem -> handleBinaryOperator(instr, "%", false) + LLVMShl -> handleBinaryOperator(instr, "<<", false) + LLVMLShr -> handleBinaryOperator(instr, ">>", true) + LLVMAShr -> handleBinaryOperator(instr, ">>", false) + LLVMAnd -> handleBinaryOperator(instr, "&", false) + LLVMOr -> handleBinaryOperator(instr, "|", false) + LLVMXor -> handleBinaryOperator(instr, "^", false) + else -> + newProblemExpression( + "Not opcode found for binary operator", + ProblemNode.ProblemType.TRANSLATION, + rawNode = instr + ) } - return newProblemExpression( - "Not opcode found for binary operator", - ProblemNode.ProblemType.TRANSLATION, - rawNode = instr - ) } /** @@ -433,45 +403,42 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : * which allocates a defined block of memory. The closest what we have in the graph is the * [NewArrayExpression], which creates a fixed sized array, i.e., a block of memory. */ - private fun handleAlloca(instr: LLVMValueRef): Statement { - val array = newNewArrayExpression(rawNode = instr) - - array.type = frontend.typeOf(instr) - - // LLVM is quite forthcoming here. in case the optional length parameter is omitted in the - // source code, it will automatically be set to 1 - val size = frontend.getOperandValueAtIndex(instr, 0) + private fun handleAlloca(instr: LLVMValueRef) = + newNewArrayExpression(rawNode = instr) + .withChildren { + it.type = frontend.typeOf(instr) - array.addDimension(size) + // LLVM is quite forthcoming here. in case the optional length parameter is omitted + // in the source code, it will automatically be set to 1 + val size = frontend.getOperandValueAtIndex(instr, 0) - return declarationOrNot(array, instr) - } + it.addDimension(size) + } + .declareIfNecessary(instr) /** * Handles the [`store`](https://llvm.org/docs/LangRef.html#store-instruction) instruction. It * stores a particular value at a pointer address. This is the rough equivalent to an assignment * of a de-referenced pointer in C like `*a = 1`. */ - private fun handleStore(instr: LLVMValueRef): Statement { - val dereference = newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) - dereference.input = frontend.getOperandValueAtIndex(instr, 1) - - return newAssignExpression("=", rawNode = instr).withChildren { - it.lhs = listOf(dereference) + private fun handleStore(instr: LLVMValueRef) = + newAssignExpression("=", rawNode = instr).withChildren { + it.lhs = + listOf( + newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) + .withChildren { it.input = frontend.getOperandValueAtIndex(instr, 1) } + ) it.rhs = listOf(frontend.getOperandValueAtIndex(instr, 0)) } - } /** * Handles the [`load`](https://llvm.org/docs/LangRef.html#load-instruction) instruction, which * is basically just a pointer de-reference. */ - private fun handleLoad(instr: LLVMValueRef): Statement { - val ref = newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) - ref.input = frontend.getOperandValueAtIndex(instr, 0) - - return declarationOrNot(ref, instr) - } + private fun handleLoad(instr: LLVMValueRef) = + newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) + .withChildren { it.input = frontend.getOperandValueAtIndex(instr, 0) } + .declareIfNecessary(instr) /** * Handles the [`icmp`](https://llvm.org/docs/LangRef.html#icmp-instruction) instruction for @@ -678,45 +645,49 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : * to each data type. */ @FunctionReplacement(["llvm.freeze"], "freeze") - private fun handleFreeze(instr: LLVMValueRef): Statement { - val operand = frontend.getOperandValueAtIndex(instr, 0) - - // condition: arg != undef && arg != poison - val condition = newBinaryOperator("&&", rawNode = instr) - val undefCheck = newBinaryOperator("!=", rawNode = instr) - undefCheck.lhs = operand - undefCheck.rhs = newLiteral(null, operand.type, rawNode = instr) - condition.lhs = undefCheck - val poisonCheck = newBinaryOperator("!=", rawNode = instr) - poisonCheck.lhs = operand - poisonCheck.rhs = - newLiteral( - "POISON", - operand.type, - rawNode = instr - ) // This could be e.g. NAN. Not sure for complex types - condition.rhs = poisonCheck + private fun handleFreeze(instr: LLVMValueRef) = + // res = (arg != undef && arg != poison) ? arg : llvm.freeze(in) + newConditionalExpression() + .withChildren { + val operand = frontend.getOperandValueAtIndex(instr, 0) + + // condition: arg != undef && arg != poison + it.condition = + newBinaryOperator("&&", rawNode = instr).withChildren { + it.lhs = + newBinaryOperator("!=", rawNode = instr).withChildren { + it.lhs = operand + it.rhs = newLiteral(null, operand.type, rawNode = instr) + } + it.rhs = + newBinaryOperator("!=", rawNode = instr).withChildren { + it.lhs = operand + it.rhs = + newLiteral( + "POISON", + operand.type, + rawNode = instr + ) // This could be e.g. NAN. Not sure for complex types + } + } - // Call to a dummy function "llvm.freeze" which would fill the undef or poison values - // randomly. - // The implementation of this function would depend on the data type (e.g. for integers, it - // could be rand()) - val callExpression = - newCallExpression(llvmInternalRef("llvm.freeze"), "llvm.freeze", false, rawNode = instr) - callExpression.addArgument(operand) + it.thenExpression = operand + it.type = operand.type - // res = (arg != undef && arg != poison) ? arg : llvm.freeze(in) - val conditional = - newConditionalExpression( - operand.type, - ) - .withChildren { - it.condition = condition - it.thenExpression = operand - it.elseExpression = callExpression - } - return declarationOrNot(conditional, instr) - } + // Call to a dummy function "llvm.freeze" which would fill the undef or poison + // values randomly. + // The implementation of this function would depend on the data type (e.g. for + // integers, it could be rand()) + it.elseExpression = + newCallExpression( + llvmInternalRef("llvm.freeze"), + "llvm.freeze", + false, + rawNode = instr + ) + .withChildren { it.addArgument(operand) } + } + .declareIfNecessary(instr) /** * Handles the [`freeze`](https://llvm.org/docs/LangRef.html#fence-instruction) instruction. @@ -729,19 +700,24 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : @FunctionReplacement(["llvm.fence"], "fence") private fun handleFence(instr: LLVMValueRef): Statement { val instrString = frontend.codeOf(instr) - val callExpression = - newCallExpression(llvmInternalRef("llvm.fence"), "llvm.fence", false, rawNode = instr) - val ordering = newLiteral(LLVMGetOrdering(instr), primitiveType("i32"), rawNode = instr) - callExpression.addArgument(ordering, "ordering") - if (instrString?.contains("syncscope") == true) { - val syncscope = instrString.split("\"")[1] - callExpression.addArgument( - newLiteral(syncscope, objectType("String"), rawNode = instr), - "syncscope" + return newCallExpression( + llvmInternalRef("llvm.fence"), + "llvm.fence", + false, + rawNode = instr ) - } - - return callExpression + .withChildren { + val ordering = + newLiteral(LLVMGetOrdering(instr), primitiveType("i32"), rawNode = instr) + it.addArgument(ordering, "ordering") + if (instrString?.contains("syncscope") == true) { + val syncscope = instrString.split("\"")[1] + it.addArgument( + newLiteral(syncscope, objectType("String"), rawNode = instr), + "syncscope" + ) + } + } } /** @@ -756,68 +732,89 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : * Returns a [Block] with those two instructions or, if `lhs` doesn't exist, only the if-then * statement. */ - private fun handleAtomiccmpxchg(instr: LLVMValueRef): Statement { - val compoundStatement = newBlock(rawNode = instr) - compoundStatement.name = Name("atomiccmpxchg") - val ptr = frontend.getOperandValueAtIndex(instr, 0) - val cmp = frontend.getOperandValueAtIndex(instr, 1) - val value = frontend.getOperandValueAtIndex(instr, 2) - - val ptrDerefCmp = newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) - ptrDerefCmp.input = ptr - - val cmpExpr = newBinaryOperator("==", rawNode = instr) - cmpExpr.lhs = ptrDerefCmp - cmpExpr.rhs = cmp - - val lhs = LLVMGetValueName(instr).string - if (lhs != "") { - // we need to create a crazy struct here. the target type can be found here - val targetType = frontend.typeOf(instr) + private fun handleAtomiccmpxchg(instr: LLVMValueRef) = + newBlock(rawNode = instr).withChildren { block -> + block.name = Name("atomiccmpxchg") - // construct it - val construct = newConstructExpression("") - construct.instantiates = (targetType as? ObjectType)?.recordDeclaration - - val ptrDerefConstruct = - newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) - ptrDerefConstruct.input = frontend.getOperandValueAtIndex(instr, 0) - - val ptrDerefCmpConstruct = - newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) - ptrDerefCmpConstruct.input = frontend.getOperandValueAtIndex(instr, 0) - - val cmpExprConstruct = newBinaryOperator("==", rawNode = instr) - cmpExprConstruct.lhs = ptrDerefCmpConstruct - cmpExprConstruct.rhs = frontend.getOperandValueAtIndex(instr, 1) - - construct.addArgument(ptrDerefConstruct) - construct.addArgument(cmpExprConstruct) + val lhs = LLVMGetValueName(instr).string + // lhs = {*pointer, *pointer == cmp} // A struct of {T, i1} + if (lhs != "") { + // we need to create a crazy struct here. the target type can be found here + val targetType = frontend.typeOf(instr) - val decl = declarationOrNot(construct, instr) - compoundStatement.addStatement(decl) + // construct it + block += + newConstructExpression("") + .withChildren { + it.instantiates = (targetType as? ObjectType)?.recordDeclaration + + // arguments + it += + newUnaryOperator( + "*", + postfix = false, + prefix = true, + rawNode = instr + ) + .withChildren { + // pointer + it.input = frontend.getOperandValueAtIndex(instr, 0) + } + it += + newBinaryOperator("==", rawNode = instr).withChildren { + it.lhs = + newUnaryOperator( + "*", + postfix = false, + prefix = true, + rawNode = instr + ) + .withChildren { + // pointer + it.input = frontend.getOperandValueAtIndex(instr, 0) + } + // cmp + it.rhs = frontend.getOperandValueAtIndex(instr, 1) + } + } + .declareIfNecessary(instr) + } + + // if(*pointer == cmp) { *pointer = new } + block += + newIfStatement(rawNode = instr).withChildren { + it.condition = + newBinaryOperator("==", rawNode = instr).withChildren { + it.lhs = + newUnaryOperator( + "*", + postfix = false, + prefix = true, + rawNode = instr + ) + .withChildren { + // pointer + it.input = frontend.getOperandValueAtIndex(instr, 0) + } + // cmp + it.rhs = frontend.getOperandValueAtIndex(instr, 1) + } + it.thenStatement = + newAssignExpression("=", rawNode = instr).withChildren { + it.lhs = + listOf( + newUnaryOperator("*", false, true, rawNode = instr) + .withChildren { + // pointer + it.input = frontend.getOperandValueAtIndex(instr, 0) + } + ) + // new + it.rhs = listOf(frontend.getOperandValueAtIndex(instr, 2)) + } + } } - val ptrDerefAssign = newUnaryOperator("*", false, true, rawNode = instr) - ptrDerefAssign.input = frontend.getOperandValueAtIndex(instr, 0) - - val assignment = - newAssignExpression("=", rawNode = instr).withChildren { - it.lhs = listOf(ptrDerefAssign) - it.rhs = listOf(value) - } - - val ifStatement = - newIfStatement(rawNode = instr).withChildren { - it.condition = cmpExpr - it.thenStatement = assignment - } - - compoundStatement.addStatement(ifStatement) - - return compoundStatement - } - /** * Parses the `atomicrmw` instruction. It returns either a single [Statement] or a [Block] if * the value is assigned to another variable. @@ -1636,4 +1633,8 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : private fun llvmInternalRef(name: String): Reference { return newReference(name) } + + fun Expression.declareIfNecessary(instr: LLVMValueRef): Statement { + return declarationOrNot(this, instr) + } } From a3b8df960368f23de8e4ac6b114240372c61d5a4 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Thu, 18 Jul 2024 23:03:37 +0200 Subject: [PATCH 4/9] more --- .../cpg/frontends/llvm/StatementHandler.kt | 227 ++++++++++-------- 1 file changed, 121 insertions(+), 106 deletions(-) diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt index 0bcc47185d..9b12912d7f 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt @@ -822,125 +822,140 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : private fun handleAtomicrmw(instr: LLVMValueRef): Statement { val lhs = LLVMGetValueName(instr).string val operation = LLVMGetAtomicRMWBinOp(instr) - val ptr = frontend.getOperandValueAtIndex(instr, 0) val value = frontend.getOperandValueAtIndex(instr, 1) val ty = value.type - val exchOp = newAssignExpression("=", rawNode = instr) - exchOp.name = Name("atomicrmw") - val ptrDeref = newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) - ptrDeref.input = ptr + val exchOp = newAssignExpression("=", rawNode = instr).withChildren { + it.name = Name("atomicrmw") - val ptrDerefExch = newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) - ptrDerefExch.input = frontend.getOperandValueAtIndex(instr, 0) - exchOp.lhs = listOf(ptrDerefExch) + val ptrDeref = newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) + ptrDeref.input = frontend.getOperandValueAtIndex(instr, 0) - when (operation) { - LLVMAtomicRMWBinOpXchg -> { - exchOp.rhs = listOf(value) - } - LLVMAtomicRMWBinOpFAdd, - LLVMAtomicRMWBinOpAdd -> { - val binaryOperator = newBinaryOperator("+", rawNode = instr) - binaryOperator.lhs = ptrDeref - binaryOperator.rhs = value - exchOp.rhs = listOf(binaryOperator) - } - LLVMAtomicRMWBinOpFSub, - LLVMAtomicRMWBinOpSub -> { - val binaryOperator = newBinaryOperator("-", rawNode = instr) - binaryOperator.lhs = ptrDeref - binaryOperator.rhs = value - exchOp.rhs = listOf(binaryOperator) - } - LLVMAtomicRMWBinOpAnd -> { - val binaryOperator = newBinaryOperator("&", rawNode = instr) - binaryOperator.lhs = ptrDeref - binaryOperator.rhs = value - exchOp.rhs = listOf(binaryOperator) - } - LLVMAtomicRMWBinOpNand -> { - val binaryOperator = newBinaryOperator("|", rawNode = instr) - binaryOperator.lhs = ptrDeref - binaryOperator.rhs = value - val unaryOperator = newUnaryOperator("~", false, true, rawNode = instr) - unaryOperator.input = binaryOperator - exchOp.rhs = listOf(unaryOperator) - } - LLVMAtomicRMWBinOpOr -> { - val binaryOperator = newBinaryOperator("|", rawNode = instr) - binaryOperator.lhs = ptrDeref - binaryOperator.rhs = value - exchOp.rhs = listOf(binaryOperator) - } - LLVMAtomicRMWBinOpXor -> { - val binaryOperator = newBinaryOperator("^", rawNode = instr) - binaryOperator.lhs = ptrDeref - binaryOperator.rhs = value - exchOp.rhs = listOf(binaryOperator) - } - LLVMAtomicRMWBinOpMax, - LLVMAtomicRMWBinOpMin -> { - val operatorCode = - if (operation == LLVMAtomicRMWBinOpMin) { - "<" - } else { - ">" - } - val condition = newBinaryOperator(operatorCode, rawNode = instr) - condition.lhs = ptrDeref - condition.rhs = value - - val ptrDerefConditional = newUnaryOperator("*", false, true, rawNode = instr) - ptrDerefConditional.input = frontend.getOperandValueAtIndex(instr, 0) - val conditional = - newConditionalExpression( + val ptrDerefExch = newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) + ptrDerefExch.input = frontend.getOperandValueAtIndex(instr, 0) + it.lhs = listOf(ptrDerefExch) + + when (operation) { + LLVMAtomicRMWBinOpXchg -> { + it.rhs = listOf(value) + } + + LLVMAtomicRMWBinOpFAdd, + LLVMAtomicRMWBinOpAdd + -> { + val binaryOperator = newBinaryOperator("+", rawNode = instr) + binaryOperator.lhs = ptrDeref + binaryOperator.rhs = value + it.rhs = listOf(binaryOperator) + } + + LLVMAtomicRMWBinOpFSub, + LLVMAtomicRMWBinOpSub + -> { + val binaryOperator = newBinaryOperator("-", rawNode = instr) + binaryOperator.lhs = ptrDeref + binaryOperator.rhs = value + it.rhs = listOf(binaryOperator) + } + + LLVMAtomicRMWBinOpAnd -> { + val binaryOperator = newBinaryOperator("&", rawNode = instr) + binaryOperator.lhs = ptrDeref + binaryOperator.rhs = value + it.rhs = listOf(binaryOperator) + } + + LLVMAtomicRMWBinOpNand -> { + val binaryOperator = newBinaryOperator("|", rawNode = instr) + binaryOperator.lhs = ptrDeref + binaryOperator.rhs = value + val unaryOperator = newUnaryOperator("~", false, true, rawNode = instr) + unaryOperator.input = binaryOperator + it.rhs = listOf(unaryOperator) + } + + LLVMAtomicRMWBinOpOr -> { + val binaryOperator = newBinaryOperator("|", rawNode = instr) + binaryOperator.lhs = ptrDeref + binaryOperator.rhs = value + it.rhs = listOf(binaryOperator) + } + + LLVMAtomicRMWBinOpXor -> { + val binaryOperator = newBinaryOperator("^", rawNode = instr) + binaryOperator.lhs = ptrDeref + binaryOperator.rhs = value + it.rhs = listOf(binaryOperator) + } + + LLVMAtomicRMWBinOpMax, + LLVMAtomicRMWBinOpMin + -> { + val operatorCode = + if (operation == LLVMAtomicRMWBinOpMin) { + "<" + } else { + ">" + } + val condition = newBinaryOperator(operatorCode, rawNode = instr) + condition.lhs = ptrDeref + condition.rhs = value + + val ptrDerefConditional = newUnaryOperator("*", false, true, rawNode = instr) + ptrDerefConditional.input = frontend.getOperandValueAtIndex(instr, 0) + val conditional = + newConditionalExpression( ty, ) - .withChildren { - it.condition = condition - it.thenExpression = ptrDerefConditional - it.elseExpression = value - } - exchOp.rhs = listOf(conditional) - } - LLVMAtomicRMWBinOpUMax, - LLVMAtomicRMWBinOpUMin -> { - val operatorCode = - if (operation == LLVMAtomicRMWBinOpUMin) { - "<" - } else { - ">" - } - val condition = newBinaryOperator(operatorCode, rawNode = instr) - val castExprLhs = newCastExpression(rawNode = instr) - castExprLhs.castType = objectType("u${ty.name}") - castExprLhs.expression = ptrDeref - condition.lhs = castExprLhs + .withChildren { + it.condition = condition + it.thenExpression = ptrDerefConditional + it.elseExpression = value + } + it.rhs = listOf(conditional) + } - val castExprRhs = newCastExpression(rawNode = instr) - castExprRhs.castType = objectType("u${ty.name}") - castExprRhs.expression = value - condition.rhs = castExprRhs - - val ptrDerefConditional = newUnaryOperator("*", false, true, rawNode = instr) - ptrDerefConditional.input = frontend.getOperandValueAtIndex(instr, 0) - val conditional = - newConditionalExpression( + LLVMAtomicRMWBinOpUMax, + LLVMAtomicRMWBinOpUMin + -> { + val operatorCode = + if (operation == LLVMAtomicRMWBinOpUMin) { + "<" + } else { + ">" + } + val condition = newBinaryOperator(operatorCode, rawNode = instr) + val castExprLhs = newCastExpression(rawNode = instr) + castExprLhs.castType = objectType("u${ty.name}") + castExprLhs.expression = ptrDeref + condition.lhs = castExprLhs + + val castExprRhs = newCastExpression(rawNode = instr) + castExprRhs.castType = objectType("u${ty.name}") + castExprRhs.expression = value + condition.rhs = castExprRhs + + val ptrDerefConditional = newUnaryOperator("*", false, true, rawNode = instr) + ptrDerefConditional.input = frontend.getOperandValueAtIndex(instr, 0) + val conditional = + newConditionalExpression( ty, ) - .withChildren { - it.condition = condition - it.thenExpression = ptrDerefConditional - it.elseExpression = value - } - exchOp.rhs = listOf(conditional) - } - else -> { - throw TranslationException("LLVMAtomicRMWBinOp $operation not supported") + .withChildren { + it.condition = condition + it.thenExpression = ptrDerefConditional + it.elseExpression = value + } + it.rhs = listOf(conditional) + } + + else -> { + throw TranslationException("LLVMAtomicRMWBinOp $operation not supported") + } } } + // TODO: too complicated because actually exchOp must be inside the block then :( return if (lhs != "") { // set lhs = *ptr, then perform the replacement val compoundStatement = newBlock(rawNode = instr) From 181e14522c9b8fa46916d3c5995c3b205312b459 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Tue, 23 Jul 2024 15:39:51 +0200 Subject: [PATCH 5/9] and more --- .../fraunhofer/aisec/cpg/graph/NodeBuilder.kt | 22 +- .../cpg/frontends/llvm/StatementHandler.kt | 277 +++++++++--------- 2 files changed, 164 insertions(+), 135 deletions(-) diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/NodeBuilder.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/NodeBuilder.kt index 70ed30caf6..170afc3854 100644 --- a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/NodeBuilder.kt +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/NodeBuilder.kt @@ -33,15 +33,12 @@ import de.fraunhofer.aisec.cpg.graph.NodeBuilder.log import de.fraunhofer.aisec.cpg.graph.declarations.Declaration import de.fraunhofer.aisec.cpg.graph.declarations.TranslationUnitDeclaration import de.fraunhofer.aisec.cpg.graph.scopes.Scope -import de.fraunhofer.aisec.cpg.graph.statements.expressions.* -import de.fraunhofer.aisec.cpg.graph.types.* import de.fraunhofer.aisec.cpg.helpers.getCodeOfSubregion import de.fraunhofer.aisec.cpg.passes.inference.IsImplicitProvider import de.fraunhofer.aisec.cpg.passes.inference.IsInferredProvider import de.fraunhofer.aisec.cpg.sarif.PhysicalLocation import de.fraunhofer.aisec.cpg.sarif.Region import java.net.URI -import java.util.* import kotlin.collections.ArrayDeque import org.slf4j.LoggerFactory @@ -420,6 +417,25 @@ fun T.withChildren( return this } +/** + * This function can be used to set the [Node.astParent] of this node to the current node on the + * [AstStackProvider]'s stack. This is particularly useful if the node was created outside of the + * [withChildren] lambda (for example, because it is used in multiple when-branches). + * + * Example: + * ```kotlin + * val binaryOperator = newBinaryOperator("|", rawNode = instr).withChildren { + * it.lhs = ptrDeref.withParent() + * it.rhs = value.withParent() + * } + * ``` + */ +context(AstStackProvider) +fun T.withParent(): T { + this.astParent = (this@AstStackProvider).astStack.lastOrNull() + return this +} + context(ContextProvider) fun T.declare(): T { val scopeManager = diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt index 9b12912d7f..ae577e9fb0 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt @@ -825,148 +825,161 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : val value = frontend.getOperandValueAtIndex(instr, 1) val ty = value.type - val exchOp = newAssignExpression("=", rawNode = instr).withChildren { - it.name = Name("atomicrmw") - - val ptrDeref = newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) - ptrDeref.input = frontend.getOperandValueAtIndex(instr, 0) - - val ptrDerefExch = newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) - ptrDerefExch.input = frontend.getOperandValueAtIndex(instr, 0) - it.lhs = listOf(ptrDerefExch) - - when (operation) { - LLVMAtomicRMWBinOpXchg -> { - it.rhs = listOf(value) - } - - LLVMAtomicRMWBinOpFAdd, - LLVMAtomicRMWBinOpAdd - -> { - val binaryOperator = newBinaryOperator("+", rawNode = instr) - binaryOperator.lhs = ptrDeref - binaryOperator.rhs = value - it.rhs = listOf(binaryOperator) - } - - LLVMAtomicRMWBinOpFSub, - LLVMAtomicRMWBinOpSub - -> { - val binaryOperator = newBinaryOperator("-", rawNode = instr) - binaryOperator.lhs = ptrDeref - binaryOperator.rhs = value - it.rhs = listOf(binaryOperator) - } - - LLVMAtomicRMWBinOpAnd -> { - val binaryOperator = newBinaryOperator("&", rawNode = instr) - binaryOperator.lhs = ptrDeref - binaryOperator.rhs = value - it.rhs = listOf(binaryOperator) - } - - LLVMAtomicRMWBinOpNand -> { - val binaryOperator = newBinaryOperator("|", rawNode = instr) - binaryOperator.lhs = ptrDeref - binaryOperator.rhs = value - val unaryOperator = newUnaryOperator("~", false, true, rawNode = instr) - unaryOperator.input = binaryOperator - it.rhs = listOf(unaryOperator) - } + val exchOp = + newAssignExpression("=", rawNode = instr).withChildren { + it.name = Name("atomicrmw") - LLVMAtomicRMWBinOpOr -> { - val binaryOperator = newBinaryOperator("|", rawNode = instr) - binaryOperator.lhs = ptrDeref - binaryOperator.rhs = value - it.rhs = listOf(binaryOperator) - } + val ptrDeref = + newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) + ptrDeref.input = frontend.getOperandValueAtIndex(instr, 0) - LLVMAtomicRMWBinOpXor -> { - val binaryOperator = newBinaryOperator("^", rawNode = instr) - binaryOperator.lhs = ptrDeref - binaryOperator.rhs = value - it.rhs = listOf(binaryOperator) - } + val ptrDerefExch = + newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) + ptrDerefExch.input = frontend.getOperandValueAtIndex(instr, 0) + it.lhs = listOf(ptrDerefExch) - LLVMAtomicRMWBinOpMax, - LLVMAtomicRMWBinOpMin - -> { - val operatorCode = - if (operation == LLVMAtomicRMWBinOpMin) { - "<" - } else { - ">" - } - val condition = newBinaryOperator(operatorCode, rawNode = instr) - condition.lhs = ptrDeref - condition.rhs = value - - val ptrDerefConditional = newUnaryOperator("*", false, true, rawNode = instr) - ptrDerefConditional.input = frontend.getOperandValueAtIndex(instr, 0) - val conditional = - newConditionalExpression( - ty, - ) - .withChildren { - it.condition = condition - it.thenExpression = ptrDerefConditional - it.elseExpression = value + when (operation) { + LLVMAtomicRMWBinOpXchg -> { + it.rhs = listOf(value.withParent()) + } + LLVMAtomicRMWBinOpFAdd, + LLVMAtomicRMWBinOpAdd -> { + val binaryOperator = + newBinaryOperator("+", rawNode = instr).withChildren { + it.lhs = ptrDeref.withParent() + it.rhs = value.withParent() } - it.rhs = listOf(conditional) - } - - LLVMAtomicRMWBinOpUMax, - LLVMAtomicRMWBinOpUMin - -> { - val operatorCode = - if (operation == LLVMAtomicRMWBinOpUMin) { - "<" - } else { - ">" - } - val condition = newBinaryOperator(operatorCode, rawNode = instr) - val castExprLhs = newCastExpression(rawNode = instr) - castExprLhs.castType = objectType("u${ty.name}") - castExprLhs.expression = ptrDeref - condition.lhs = castExprLhs - - val castExprRhs = newCastExpression(rawNode = instr) - castExprRhs.castType = objectType("u${ty.name}") - castExprRhs.expression = value - condition.rhs = castExprRhs - - val ptrDerefConditional = newUnaryOperator("*", false, true, rawNode = instr) - ptrDerefConditional.input = frontend.getOperandValueAtIndex(instr, 0) - val conditional = - newConditionalExpression( - ty, - ) - .withChildren { - it.condition = condition - it.thenExpression = ptrDerefConditional - it.elseExpression = value + it.rhs = listOf(binaryOperator) + } + LLVMAtomicRMWBinOpFSub, + LLVMAtomicRMWBinOpSub -> { + val binaryOperator = + newBinaryOperator("-", rawNode = instr).withChildren { + it.lhs = ptrDeref.withParent() + it.rhs = value.withParent() } - it.rhs = listOf(conditional) - } - - else -> { - throw TranslationException("LLVMAtomicRMWBinOp $operation not supported") + it.rhs = listOf(binaryOperator) + } + LLVMAtomicRMWBinOpAnd -> { + val binaryOperator = + newBinaryOperator("&", rawNode = instr).withChildren { + it.lhs = ptrDeref.withParent() + it.rhs = value.withParent() + } + it.rhs = listOf(binaryOperator) + } + LLVMAtomicRMWBinOpNand -> { + val unaryOperator = + newUnaryOperator("~", false, true, rawNode = instr).withChildren { + it.input = + newBinaryOperator("|", rawNode = instr).withChildren { + it.lhs = ptrDeref.withParent() + it.rhs = value.withParent() + } + } + it.rhs = listOf(unaryOperator) + } + LLVMAtomicRMWBinOpOr -> { + val binaryOperator = + newBinaryOperator("|", rawNode = instr).withChildren { + it.lhs = ptrDeref.withParent() + it.rhs = value.withParent() + } + it.rhs = listOf(binaryOperator) + } + LLVMAtomicRMWBinOpXor -> { + val binaryOperator = + newBinaryOperator("^", rawNode = instr).withChildren { + it.lhs = ptrDeref.withParent() + it.rhs = value.withParent() + } + it.rhs = listOf(binaryOperator) + } + LLVMAtomicRMWBinOpMax, + LLVMAtomicRMWBinOpMin -> { + val operatorCode = + if (operation == LLVMAtomicRMWBinOpMin) { + "<" + } else { + ">" + } + val conditional = + newConditionalExpression( + ty, + ) + .withChildren { + it.condition = + newBinaryOperator(operatorCode, rawNode = instr) + .withChildren { + it.lhs = ptrDeref.withParent() + it.rhs = value.withParent() + } + it.thenExpression = + newUnaryOperator("*", false, true, rawNode = instr) + .withChildren { + it.input = frontend.getOperandValueAtIndex(instr, 0) + } + it.elseExpression = + frontend.getOperandValueAtIndex(instr, 1) // value + } + it.rhs = listOf(conditional) + } + LLVMAtomicRMWBinOpUMax, + LLVMAtomicRMWBinOpUMin -> { + val operatorCode = + if (operation == LLVMAtomicRMWBinOpUMin) { + "<" + } else { + ">" + } + val conditional = + newConditionalExpression( + ty, + ) + .withChildren { + it.condition = + newBinaryOperator(operatorCode, rawNode = instr) + .withChildren { + it.lhs = + newCastExpression(rawNode = instr) + .withChildren { + it.castType = objectType("u${ty.name}") + it.expression = ptrDeref.withParent() + } + it.rhs = + newCastExpression(rawNode = instr) + .withChildren { + it.castType = objectType("u${ty.name}") + it.expression = value.withParent() + } + } + it.thenExpression = + newUnaryOperator("*", false, true, rawNode = instr) + .withChildren { + it.input = frontend.getOperandValueAtIndex(instr, 0) + } + it.elseExpression = + frontend.getOperandValueAtIndex(instr, 1) // value + } + it.rhs = listOf(conditional) + } + else -> { + throw TranslationException("LLVMAtomicRMWBinOp $operation not supported") + } } } - } - // TODO: too complicated because actually exchOp must be inside the block then :( return if (lhs != "") { // set lhs = *ptr, then perform the replacement - val compoundStatement = newBlock(rawNode = instr) - - val ptrDerefAssignment = - newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) - ptrDerefAssignment.input = frontend.getOperandValueAtIndex(instr, 0) - - compoundStatement.statements = - listOf(declarationOrNot(ptrDerefAssignment, instr), exchOp) - compoundStatement + newBlock(rawNode = instr).withChildren { + it.statements = + listOf( + newUnaryOperator("*", postfix = false, prefix = true, rawNode = instr) + .withChildren { it.input = frontend.getOperandValueAtIndex(instr, 0) } + .declareIfNecessary(instr), + exchOp.withParent() + ) + } } else { // only perform the replacement exchOp From 774a4510d200e5946b3b7bfd7f7247c8a32a78b0 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Tue, 23 Jul 2024 16:02:49 +0200 Subject: [PATCH 6/9] Playing with register operands --- .../cpg/frontends/llvm/StatementHandler.kt | 142 +++++++++++------- 1 file changed, 85 insertions(+), 57 deletions(-) diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt index ae577e9fb0..6bedec3927 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt @@ -50,6 +50,30 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : map.put(LLVMBasicBlockRef::class.java) { handleBasicBlock(it as LLVMBasicBlockRef) } } + var instructionOperandMap = mutableMapOf>() + + fun LLVMValueRef.registerOperands(vararg pairs: Pair) { + instructionOperandMap[this.opCode] = mapOf(*pairs) + } + + fun LLVMValueRef.operand(operand: String): LLVMValueRef? { + var idx = instructionOperandMap[this.opCode]?.get(operand) + if (idx == null) { + throw TranslationException("unknown operand $operand for ${this.opCode}") + } + + return LLVMGetOperand(this, idx) + } + + fun LLVMValueRef.operandValue(operand: String): Expression { + var idx = instructionOperandMap[this.opCode]?.get(operand) + if (idx == null) { + throw TranslationException("unknown operand $operand for ${this.opCode}") + } + + return frontend.getOperandValueAtIndex(this, idx) + } + /** * Handles the parsing of * [instructions](https://llvm.org/docs/LangRef.html#instruction-reference). Instructions are @@ -997,48 +1021,52 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : "Indirectbr statement without address and at least one target" ) - val address = frontend.getOperandValueAtIndex(instr, 0) - - val switchStatement = newSwitchStatement(rawNode = instr) - switchStatement.selector = address - - val caseStatements = newBlock(rawNode = instr) - - var idx = 1 - while (idx < numOps) { - // The case statement is derived from the address of the label which we can jump to - val caseBBAddress = LLVMValueAsBasicBlock(LLVMGetOperand(instr, idx)).address() - val caseStatement = newCaseStatement(rawNode = instr) - caseStatement.caseExpression = - newLiteral(caseBBAddress, primitiveType("i64"), rawNode = instr) - caseStatements.addStatement(caseStatement) + return newSwitchStatement(rawNode = instr).withChildren { + val address = frontend.getOperandValueAtIndex(instr, 0) + it.selector = address + it.statement = + newBlock(rawNode = instr).withChildren { + var idx = 1 + while (idx < numOps) { + // The case statement is derived from the address of the label which we can + // jump to + val caseBBAddress = + LLVMValueAsBasicBlock(LLVMGetOperand(instr, idx)).address() + val caseStatement = + newCaseStatement(rawNode = instr).withChildren { + it.caseExpression = + newLiteral(caseBBAddress, primitiveType("i64"), rawNode = instr) + } + it.addStatement(caseStatement) - // Get the label of the goto statement. - val gotoStatement = assembleGotoStatement(instr, LLVMGetOperand(instr, idx)) - caseStatements.addStatement(gotoStatement) - idx++ + // Get the label of the goto statement. + val gotoStatement = assembleGotoStatement(instr, LLVMGetOperand(instr, idx)) + it.addStatement(gotoStatement) + idx++ + } + } } - - switchStatement.statement = caseStatements - - return switchStatement } /** Handles a [`br`](https://llvm.org/docs/LangRef.html#br-instruction) instruction. */ private fun handleBrStatement(instr: LLVMValueRef): Statement { - if (LLVMGetNumOperands(instr) == 3) { - // if(op) then {goto label1} else {goto label2} - val ifStatement = newIfStatement(rawNode = instr) - val condition = frontend.getOperandValueAtIndex(instr, 0) - ifStatement.condition = condition + instr.registerOperands ( + "cond" to 0, + "iftrue" to 1, + "iffalse" to 2, + ) - // Get the label of the "else" branch - ifStatement.elseStatement = assembleGotoStatement(instr, LLVMGetOperand(instr, 1)) + if (LLVMGetNumOperands(instr) == 3) { + // if(cond) then {goto iftrue} else {goto iffalse} + return newIfStatement(rawNode = instr).withChildren { + it.condition = instr.operandValue("cond") - // Get the label of the "if" branch - ifStatement.thenStatement = assembleGotoStatement(instr, LLVMGetOperand(instr, 2)) + // Get the label of the "else" branch + it.elseStatement = assembleGotoStatement(instr, instr.operand("iftrue")) - return ifStatement + // Get the label of the "if" branch + it.thenStatement = assembleGotoStatement(instr, instr.operand("iffalse")) + } } else if (LLVMGetNumOperands(instr) == 1) { // goto defaultLocation return assembleGotoStatement(instr, LLVMGetOperand(instr, 0)) @@ -1610,32 +1638,32 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : * already been processed or uses the listeners to generate the relation once the label * statement has been processed. */ - private fun assembleGotoStatement(instr: LLVMValueRef, bbTarget: LLVMValueRef): GotoStatement { - val goto = newGotoStatement(rawNode = instr) - val assigneeTargetLabel = BiConsumer { _: Any, to: Node -> - if (to is LabelStatement) { - goto.targetLabel = to - } else if (goto.targetLabel != to) { - log.error("$to is not a LabelStatement") + private fun assembleGotoStatement(instr: LLVMValueRef, bbTarget: LLVMValueRef?): GotoStatement { + return newGotoStatement(rawNode = instr).withChildren { goto -> + val assigneeTargetLabel = BiConsumer { _: Any, to: Node -> + if (to is LabelStatement) { + goto.targetLabel = to + } else if (goto.targetLabel != to) { + log.error("$to is not a LabelStatement") + } + } + val bb: LLVMBasicBlockRef = LLVMValueAsBasicBlock(bbTarget) + val labelName = LLVMGetBasicBlockName(bb).string + goto.labelName = labelName + + val label = newLabelStatement().withChildren { it -> it.name = Name(labelName) } + // If the bound AST node is/or was transformed into a CPG node the cpg node is bound + // to the CPG goto statement + frontend.registerObjectListener(label, assigneeTargetLabel) + if (goto.targetLabel == null) { + // If the Label AST node could not be resolved, the matching is done based on label + // names of CPG nodes using the predicate listeners + frontend.registerPredicateListener( + { _: Any?, to: Any? -> (to is LabelStatement && to.label == goto.labelName) }, + assigneeTargetLabel + ) } } - val bb: LLVMBasicBlockRef = LLVMValueAsBasicBlock(bbTarget) - val labelName = LLVMGetBasicBlockName(bb).string - goto.labelName = labelName - val label = newLabelStatement() - label.name = Name(labelName) - // If the bound AST node is/or was transformed into a CPG node the cpg node is bound - // to the CPG goto statement - frontend.registerObjectListener(label, assigneeTargetLabel) - if (goto.targetLabel == null) { - // If the Label AST node could not be resolved, the matching is done based on label - // names of CPG nodes using the predicate listeners - frontend.registerPredicateListener( - { _: Any?, to: Any? -> (to is LabelStatement && to.label == goto.labelName) }, - assigneeTargetLabel - ) - } - return goto } /** Returns the name of the given basic block. */ From f843c5d9d686bf76d213f5e1ed25431bc463f96e Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Tue, 23 Jul 2024 22:31:39 +0200 Subject: [PATCH 7/9] loosing my sanity --- .../cpg/frontends/llvm/StatementHandler.kt | 153 +++++++++--------- 1 file changed, 74 insertions(+), 79 deletions(-) diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt index 6bedec3927..71016eb861 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt @@ -52,24 +52,22 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : var instructionOperandMap = mutableMapOf>() - fun LLVMValueRef.registerOperands(vararg pairs: Pair) { - instructionOperandMap[this.opCode] = mapOf(*pairs) + private infix fun Int.usesOperands(operands: Map) { + instructionOperandMap[this] = operands } - fun LLVMValueRef.operand(operand: String): LLVMValueRef? { - var idx = instructionOperandMap[this.opCode]?.get(operand) - if (idx == null) { - throw TranslationException("unknown operand $operand for ${this.opCode}") - } + private fun LLVMValueRef.operand(operand: String): LLVMValueRef? { + val idx = + instructionOperandMap[this.opCode]?.get(operand) + ?: throw TranslationException("unknown operand $operand for ${this.opCode}") return LLVMGetOperand(this, idx) } - fun LLVMValueRef.operandValue(operand: String): Expression { - var idx = instructionOperandMap[this.opCode]?.get(operand) - if (idx == null) { - throw TranslationException("unknown operand $operand for ${this.opCode}") - } + private fun LLVMValueRef.operandValue(operand: String): Expression { + val idx = + instructionOperandMap[this.opCode]?.get(operand) + ?: throw TranslationException("unknown operand $operand for ${this.opCode}") return frontend.getOperandValueAtIndex(this, idx) } @@ -1050,11 +1048,12 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : /** Handles a [`br`](https://llvm.org/docs/LangRef.html#br-instruction) instruction. */ private fun handleBrStatement(instr: LLVMValueRef): Statement { - instr.registerOperands ( - "cond" to 0, - "iftrue" to 1, - "iffalse" to 2, - ) + LLVMBr usesOperands + mapOf( + "cond" to 0, + "iftrue" to 1, + "iffalse" to 2, + ) if (LLVMGetNumOperands(instr) == 3) { // if(cond) then {goto iftrue} else {goto iffalse} @@ -1084,38 +1083,35 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : * Returns a [SwitchStatement]. */ private fun handleSwitchStatement(instr: LLVMValueRef): Statement { + LLVMSwitch usesOperands mapOf("value" to 0, "defaultdest" to 1) + val numOps = LLVMGetNumOperands(instr) if (numOps < 2 || numOps % 2 != 0) throw TranslationException("Switch statement without operand and default branch") - val operand = frontend.getOperandValueAtIndex(instr, 0) + return newSwitchStatement(rawNode = instr).withChildren { + it.selector = instr.operandValue("value") + it.statement = + newBlock(rawNode = instr).withChildren { block -> + var idx = 2 + while (idx < numOps) { + // Get the comparison value and add it to the CaseStatement + block += + newCaseStatement(rawNode = instr).withChildren { + it.caseExpression = frontend.getOperandValueAtIndex(instr, idx) + } + idx++ + // Get the "case" statements and add it to the CaseStatement + block += assembleGotoStatement(instr, LLVMGetOperand(instr, idx)) + idx++ + } - val switchStatement = newSwitchStatement(rawNode = instr) - switchStatement.selector = operand - - val caseStatements = newBlock(rawNode = instr) - - var idx = 2 - while (idx < numOps) { - // Get the comparison value and add it to the CaseStatement - val caseStatement = newCaseStatement(rawNode = instr) - caseStatement.caseExpression = frontend.getOperandValueAtIndex(instr, idx) - caseStatements.addStatement(caseStatement) - idx++ - // Get the "case" statements and add it to the CaseStatement - val gotoStatement = assembleGotoStatement(instr, LLVMGetOperand(instr, idx)) - caseStatements.addStatement(gotoStatement) - idx++ + // Get the label of the "default" branch + block += newDefaultStatement(rawNode = instr) + val defaultGoto = assembleGotoStatement(instr, instr.operand("defaultdest")) + block += defaultGoto + } } - - // Get the label of the "default" branch - caseStatements.addStatement(newDefaultStatement(rawNode = instr)) - val defaultGoto = assembleGotoStatement(instr, LLVMGetOperand(instr, 1)) - caseStatements.addStatement(defaultGoto) - - switchStatement.statement = caseStatements - - return switchStatement } /** @@ -1154,46 +1150,45 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : ) } - val callee = newReference(calledFuncName, frontend.typeOf(calledFunc), rawNode = calledFunc) - - val callExpr = newCallExpression(callee, calledFuncName, false, rawNode = instr) - - while (idx < max) { - val operandName = frontend.getOperandValueAtIndex(instr, idx) - callExpr.addArgument(operandName) - idx++ - } + val call = + newCallExpression(null, calledFuncName, false, rawNode = instr).withChildren { call -> + call.callee = + newReference(calledFuncName, frontend.typeOf(calledFunc), rawNode = calledFunc) + while (idx < max) { + val arg = frontend.getOperandValueAtIndex(instr, idx) + call.arguments += arg + idx++ + } + } - if (instr.opCode == LLVMInvoke) { + return if (instr.opCode == LLVMInvoke) { // For the "invoke" instruction, the call is surrounded by a try statement which also // contains a goto statement after the call. - val tryStatement = newTryStatement(rawNode = instr) - frontend.scopeManager.enterScope(tryStatement) - val tryBlock = newBlock(rawNode = instr) - tryBlock.addStatement(declarationOrNot(callExpr, instr)) - tryBlock.addStatement(tryContinue) - tryStatement.tryBlock = tryBlock - frontend.scopeManager.leaveScope(tryStatement) - - val catchClause = newCatchClause(rawNode = instr) - catchClause.name = Name(gotoCatch.labelName) - catchClause.parameter = - newVariableDeclaration( - "e_${gotoCatch.labelName}", - unknownType(), - true, - rawNode = instr - ) + newTryStatement(rawNode = instr).withChildren(hasScope = true) { tryStatement -> + tryStatement.tryBlock = newBlock(rawNode = instr).withChildren { + it += declarationOrNot(call.withParent(), instr) + it += tryContinue.withParent() + } - val catchBlockStatement = newBlock(rawNode = instr) - catchBlockStatement.addStatement(gotoCatch) - catchClause.body = catchBlockStatement - tryStatement.catchClauses = mutableListOf(catchClause) + val catchClause = newCatchClause(rawNode = instr).withChildren { + it.name = Name(gotoCatch.labelName) + it.parameter = + newVariableDeclaration( + "e_${gotoCatch.labelName}", + unknownType(), + true, + rawNode = instr + ) - return tryStatement + it.body = newBlock(rawNode = instr).withChildren { + it.addStatement(gotoCatch) + } + } + tryStatement.catchClauses = mutableListOf(catchClause) + } + } else { + call.declareIfNecessary(instr) } - - return declarationOrNot(callExpr, instr) } /** @@ -1464,8 +1459,8 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : /** * Most instructions in LLVM have a variable assignment as part of their instruction. Since LLVM * IR is SSA, we need to declare a new variable in this case, which is named according to - * [valueRef]. In case the variable assignment is optional, and we directly return the - * [Expression] associated with the instruction. + * [valueRef]. In case the variable assignment is optional, we directly return the [Expression] + * associated with the instruction. */ private fun declarationOrNot(rhs: Expression, valueRef: LLVMValueRef): Statement { val namePair = frontend.getNameOf(valueRef) From 99c610962f4f52234941ad75a08a1ded48e06623 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Tue, 23 Jul 2024 23:01:03 +0200 Subject: [PATCH 8/9] some LLVM test cases actually work --- .../fraunhofer/aisec/cpg/graph/NodeBuilder.kt | 14 +- .../cpg/frontends/llvm/StatementHandler.kt | 234 +++++++++--------- 2 files changed, 129 insertions(+), 119 deletions(-) diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/NodeBuilder.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/NodeBuilder.kt index 170afc3854..9e37987b26 100644 --- a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/NodeBuilder.kt +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/NodeBuilder.kt @@ -420,13 +420,19 @@ fun T.withChildren( /** * This function can be used to set the [Node.astParent] of this node to the current node on the * [AstStackProvider]'s stack. This is particularly useful if the node was created outside of the - * [withChildren] lambda (for example, because it is used in multiple when-branches). + * [withChildren] lambda. This is a usual pattern if the node is optionally wrapped in something + * else. * * Example: * ```kotlin - * val binaryOperator = newBinaryOperator("|", rawNode = instr).withChildren { - * it.lhs = ptrDeref.withParent() - * it.rhs = value.withParent() + * val expr = newReference("p") + * + * return if (isDeref) { + * newUnaryOperator("*").withChildren { + * it.input = expr.withParent() + * } + * } else { + * expr * } * ``` */ diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt index 71016eb861..9dc8eca1d0 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt @@ -1165,25 +1165,26 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : // For the "invoke" instruction, the call is surrounded by a try statement which also // contains a goto statement after the call. newTryStatement(rawNode = instr).withChildren(hasScope = true) { tryStatement -> - tryStatement.tryBlock = newBlock(rawNode = instr).withChildren { - it += declarationOrNot(call.withParent(), instr) - it += tryContinue.withParent() - } - - val catchClause = newCatchClause(rawNode = instr).withChildren { - it.name = Name(gotoCatch.labelName) - it.parameter = - newVariableDeclaration( - "e_${gotoCatch.labelName}", - unknownType(), - true, - rawNode = instr - ) + tryStatement.tryBlock = + newBlock(rawNode = instr).withChildren { + it += declarationOrNot(call.withParent(), instr) + it += tryContinue.withParent() + } - it.body = newBlock(rawNode = instr).withChildren { - it.addStatement(gotoCatch) + val catchClause = + newCatchClause(rawNode = instr).withChildren { + it.name = Name(gotoCatch.labelName) + it.parameter = + newVariableDeclaration( + "e_${gotoCatch.labelName}", + unknownType(), + true, + rawNode = instr + ) + + it.body = + newBlock(rawNode = instr).withChildren { it.addStatement(gotoCatch) } } - } tryStatement.catchClauses = mutableListOf(catchClause) } } else { @@ -1480,7 +1481,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : rawNode = valueRef ) .withChildren { decl -> - decl.initializer = rhs + decl.initializer = rhs.withParent() // add the declaration to the current scope frontend.scopeManager.addDeclaration(decl) @@ -1499,31 +1500,34 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : * [LabelStatement] if the basic block has a label. */ private fun handleBasicBlock(bb: LLVMBasicBlockRef): Statement { - val compound = newBlock(rawNode = bb) - - var instr = LLVMGetFirstInstruction(bb) - while (instr != null) { - log.debug("Parsing {}", frontend.codeOf(instr)) + val block = + newBlock(rawNode = bb).withChildren { block -> + var instr = LLVMGetFirstInstruction(bb) + while (instr != null) { + log.debug("Parsing {}", frontend.codeOf(instr)) + + val stmt = frontend.statementHandler.handle(instr) + if (stmt != null) { + block += stmt + } - val stmt = frontend.statementHandler.handle(instr) - if (stmt != null) { - compound.addStatement(stmt) + instr = LLVMGetNextInstruction(instr) + } } - instr = LLVMGetNextInstruction(instr) - } - val labelName = getBasicBlockName(bb) + return if (labelName != "") { + val labelStatement = + newLabelStatement().withChildren { + it.name = Name(labelName) + it.label = labelName + it.subStatement = block.withParent() + } - if (labelName != "") { - val labelStatement = newLabelStatement() - labelStatement.name = Name(labelName) - labelStatement.label = labelName - labelStatement.subStatement = compound - - return labelStatement + labelStatement + } else { + block } - return compound } /** @@ -1543,89 +1547,89 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : unsigned: Boolean, unordered: Boolean = false ): Statement { - val op1 = frontend.getOperandValueAtIndex(instr, 0) - val op2 = frontend.getOperandValueAtIndex(instr, 1) + instr.opCode usesOperands mapOf("op1" to 0, "op2" to 1) - val binaryOperator: Expression - var binOpUnordered: BinaryOperator? = null - - if (op == "uno") { - // Unordered comparison operand => Replace with a call to isunordered(x, y) - // Resulting statement: i1 lhs = isordered(op1, op2) - binaryOperator = - newCallExpression( - llvmInternalRef("isunordered"), - "isunordered", - false, - rawNode = instr - ) - binaryOperator.addArgument(op1) - binaryOperator.addArgument(op2) - } else if (op == "ord") { - // Ordered comparison operand => Replace with !isunordered(x, y) - // Resulting statement: i1 lhs = !isordered(op1, op2) - val unorderedCall = - newCallExpression( - llvmInternalRef("isunordered"), - "isunordered", - false, - rawNode = instr - ) - unorderedCall.addArgument(op1) - unorderedCall.addArgument(op2) - binaryOperator = newUnaryOperator("!", false, true, rawNode = instr) - binaryOperator.input = unorderedCall - } else { - // Resulting statement: lhs = op1 op2. - binaryOperator = newBinaryOperator(op, rawNode = instr) - - if (unsigned) { - val op1Type = "u${op1.type.name}" - val castExprLhs = newCastExpression(rawNode = instr) - castExprLhs.castType = objectType(op1Type) - castExprLhs.expression = op1 - binaryOperator.lhs = castExprLhs - - val op2Type = "u${op2.type.name}" - val castExprRhs = newCastExpression(rawNode = instr) - castExprRhs.castType = objectType(op2Type) - castExprRhs.expression = op2 - binaryOperator.rhs = castExprRhs - } else { - binaryOperator.lhs = op1 - binaryOperator.rhs = op2 + var expr = + when (op) { + "uno" -> { + // Unordered comparison operand => Replace with a call to isunordered(x, y) + // Resulting statement: i1 lhs = isordered(op1, op2) + newCallExpression( + llvmInternalRef("isunordered"), + "isunordered", + false, + rawNode = instr + ) + .withChildren { + it.arguments += instr.operandValue("op1") + it.arguments += instr.operandValue("op2") + } + } + "ord" -> { + // Ordered comparison operand => Replace with !isunordered(x, y) + // Resulting statement: i1 lhs = !isordered(op1, op2) + newUnaryOperator("!", false, true, rawNode = instr).withChildren { + it.input = + newCallExpression( + llvmInternalRef("isunordered"), + "isunordered", + false, + rawNode = instr + ) + .withChildren { + it.arguments += instr.operandValue("op1") + it.arguments += instr.operandValue("op2") + } + } + } + else -> { + // Resulting statement: lhs = op1 op2. + newBinaryOperator(op, rawNode = instr).withChildren { + if (unsigned) { + it.lhs = + newCastExpression(rawNode = instr).withChildren { + val op1 = instr.operandValue("op1") + val op1Type = "u${op1.type.name}" + it.castType = objectType(op1Type) + it.expression = op1 + } + it.rhs = + newCastExpression(rawNode = instr).withChildren { + val op2 = instr.operandValue("op2") + val op2Type = "u${op2.type.name}" + it.castType = objectType(op2Type) + it.expression = op2 + } + } else { + it.lhs = instr.operandValue("op1") + it.rhs = instr.operandValue("op2") + } + } + } } - if (unordered) { + return if (unordered) { // Special case for floating point comparisons which check if a value is "unordered // or ". // Statement is then lhs = isunordered(op1, op2) || (op1 op2) - binOpUnordered = newBinaryOperator("||", rawNode = instr) - binOpUnordered.rhs = binaryOperator - val unorderedCall = - newCallExpression( - llvmInternalRef("isunordered"), - "isunordered", - false, - rawNode = instr - ) - unorderedCall.addArgument(op1) - unorderedCall.addArgument(op2) - binOpUnordered.lhs = unorderedCall + newBinaryOperator("||", rawNode = instr).withChildren { + it.rhs = expr.withParent() + it.lhs = + newCallExpression( + llvmInternalRef("isunordered"), + "isunordered", + false, + rawNode = instr + ) + .withChildren { + it.arguments += instr.operandValue("op1") + it.arguments += instr.operandValue("op2") + } + } + } else { + expr } - } - - val declOp = if (unordered) binOpUnordered else binaryOperator - val decl = - declOp?.let { declarationOrNot(it, instr) } - ?: newProblemExpression("Could not parse declaration") - - (decl as? DeclarationStatement)?.let { - // cache binding - frontend.bindingsCache[instr.symbolName] = decl.singleDeclaration as VariableDeclaration - } - - return decl + .declareIfNecessary(instr) } /** @@ -1646,7 +1650,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : val labelName = LLVMGetBasicBlockName(bb).string goto.labelName = labelName - val label = newLabelStatement().withChildren { it -> it.name = Name(labelName) } + val label = newLabelStatement().withChildren { it.name = Name(labelName) } // If the bound AST node is/or was transformed into a CPG node the cpg node is bound // to the CPG goto statement frontend.registerObjectListener(label, assigneeTargetLabel) @@ -1685,7 +1689,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : return newReference(name) } - fun Expression.declareIfNecessary(instr: LLVMValueRef): Statement { + private fun Expression.declareIfNecessary(instr: LLVMValueRef): Statement { return declarationOrNot(this, instr) } } From 211b3f66b04cca368817d2e884ce90ecf418140f Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Tue, 23 Jul 2024 23:59:59 +0200 Subject: [PATCH 9/9] More LLVM stuff --- .../cpg/frontends/llvm/DeclarationHandler.kt | 30 +- .../cpg/frontends/llvm/StatementHandler.kt | 289 +++++++++--------- 2 files changed, 164 insertions(+), 155 deletions(-) diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/DeclarationHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/DeclarationHandler.kt index 41aa972732..8ac43573bf 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/DeclarationHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/DeclarationHandler.kt @@ -136,28 +136,28 @@ class DeclarationHandler(lang: LLVMIRLanguageFrontend) : // definition does not have an entry, which specifies the first block, but it has a // *body*, which comprises *all* statements within the abstract syntax tree of // that function, hierarchically organized by compound statements. To emulate that, - // we - // take the first basic block as our body and add subsequent blocks as statements to - // the body. More specifically, we use the CPG node LabelStatement, which denotes - // the - // use of a label. Its property substatement contains the original basic block, - // parsed - // as a compound statement + // we take the first basic block as our body and add subsequent blocks as statements + // to the body. More specifically, we use the CPG node LabelStatement, which denotes + // the use of a label. Its property sub-statement contains the original basic block, + // parsed as a compound statement // Take the entry block as our body if (LLVMGetEntryBasicBlock(func) == bb && stmt is Block) { it.body = stmt } else if (LLVMGetEntryBasicBlock(func) == bb) { - it.body = newBlock() - if (stmt != null) { - (it.body as Block).addStatement(stmt) - } + it.body = + newBlock().withChildren { block -> + if (stmt != null) { + block += stmt.withParent() + } + } } else { // add the label statement, containing this basic block as a compound statement - // to - // our body (if we have none, which we should) - if (stmt != null) { - (it.body as? Block)?.addStatement(stmt) + // to our body (if we have one, which we should) + (it.body as? Block)?.withChildren { block -> + if (stmt != null) { + block += stmt.withParent() + } } } diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt index 9dc8eca1d0..c93792e792 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt @@ -89,7 +89,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : if (LLVMIsABinaryOperator(instr) != null) { return handleBinaryInstruction(instr) } else if (LLVMIsACastInst(instr) != null) { - return declarationOrNot(frontend.expressionHandler.handleCastInstruction(instr), instr) + return frontend.expressionHandler.handleCastInstruction(instr).declareIfNecessary(instr) } val opcode = instr.opCode @@ -125,7 +125,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : LLVMStore -> handleStore(instr) LLVMExtractValue, LLVMGetElementPtr -> { - declarationOrNot(frontend.expressionHandler.handleGetElementPtr(instr), instr) + frontend.expressionHandler.handleGetElementPtr(instr).declareIfNecessary(instr) } LLVMICmp -> handleIntegerComparison(instr) LLVMFCmp -> handleFloatComparison(instr) @@ -134,7 +134,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : newEmptyStatement(rawNode = instr) } LLVMSelect -> { - declarationOrNot(frontend.expressionHandler.handleSelect(instr), instr) + frontend.expressionHandler.handleSelect(instr).declareIfNecessary(instr) } LLVMUserOp1, LLVMUserOp2 -> { @@ -251,7 +251,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : it.addArgument(parent, "parent") } - val tokenGeneration = declarationOrNot(dummyCall, instr) as DeclarationStatement + val tokenGeneration = dummyCall.declareIfNecessary(instr) as DeclarationStatement block.addStatement(tokenGeneration) val ifStatement = newIfStatement(rawNode = instr) @@ -330,26 +330,24 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : */ @FunctionReplacement(["llvm.cleanuppad"], "cleanuppad") private fun handleCleanuppad(instr: LLVMValueRef) = - declarationOrNot( - newCallExpression( - llvmInternalRef("llvm.cleanuppad"), - "llvm.cleanuppad", - false, - rawNode = instr - ) - .withChildren { - val numOps = LLVMGetNumArgOperands(instr) - val catchswitch = frontend.getOperandValueAtIndex(instr, 0) + newCallExpression( + llvmInternalRef("llvm.cleanuppad"), + "llvm.cleanuppad", + false, + rawNode = instr + ) + .withChildren { + val numOps = LLVMGetNumArgOperands(instr) + val catchswitch = frontend.getOperandValueAtIndex(instr, 0) - it.addArgument(catchswitch, "parentCatchswitch") + it.addArgument(catchswitch, "parentCatchswitch") - for (i in 1 until numOps) { - val arg = frontend.getOperandValueAtIndex(instr, i) - it.addArgument(arg, "args_${i - 1}") - } - }, - instr - ) + for (i in 1 until numOps) { + val arg = frontend.getOperandValueAtIndex(instr, i) + it.addArgument(arg, "args_${i - 1}") + } + } + .declareIfNecessary(instr) /** * We simulate a [`catchpad`](https://llvm.org/docs/LangRef.html#catchpad-instruction) @@ -574,7 +572,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : rawNode = instr ) if (operand !is ConstructExpression) { - copy = declarationOrNot(operand, instr) + copy = operand.declareIfNecessary(instr) if (copy is DeclarationStatement) { base = newReference( @@ -592,15 +590,16 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : if (base is ConstructExpression) { if (idx == numOps - 1) { base.setArgument(index, valueToSet) - return declarationOrNot(operand, instr) + return operand.declareIfNecessary(instr) } base = base.arguments[index] } else if (baseType is PointerType) { - val arrayExpr = newSubscriptExpression() - arrayExpr.arrayExpression = base - arrayExpr.name = Name(index.toString()) - arrayExpr.subscriptExpression = operand - expr = arrayExpr + expr = + newSubscriptExpression().withChildren { + it.arrayExpression = base + it.name = Name(index.toString()) + it.subscriptExpression = operand + } // deference the type to get the new base type baseType = baseType.dereference() @@ -1167,7 +1166,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : newTryStatement(rawNode = instr).withChildren(hasScope = true) { tryStatement -> tryStatement.tryBlock = newBlock(rawNode = instr).withChildren { - it += declarationOrNot(call.withParent(), instr) + it += call.withParent().declareIfNecessary(instr) it += tryContinue.withParent() } @@ -1247,7 +1246,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : val compoundStatement = newBlock(rawNode = instr) // TODO: Probably we should make a proper copy of the array - val newArrayDecl = declarationOrNot(frontend.getOperandValueAtIndex(instr, 0), instr) + val newArrayDecl = frontend.getOperandValueAtIndex(instr, 0).declareIfNecessary(instr) compoundStatement.addStatement(newArrayDecl) val decl = newArrayDecl.declarations[0] as? VariableDeclaration @@ -1275,11 +1274,12 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : * instruction which is modeled as access to an array at a given index. */ private fun handleExtractelement(instr: LLVMValueRef): Statement { - val arrayExpr = newSubscriptExpression(rawNode = instr) - arrayExpr.arrayExpression = frontend.getOperandValueAtIndex(instr, 0) - arrayExpr.subscriptExpression = frontend.getOperandValueAtIndex(instr, 1) + LLVMExtractElement usesOperands mapOf("val" to 0, "idx" to 1) - return declarationOrNot(arrayExpr, instr) + return newSubscriptExpression(rawNode = instr).withChildren { + it.arrayExpression = instr.operandValue("val") + it.subscriptExpression = instr.operandValue("idx") + } } /** @@ -1291,71 +1291,80 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : * barely used and also the features of LLVM are very limited in that scenario. */ private fun handleShufflevector(instr: LLVMValueRef): Statement { - val list = newInitializerListExpression(frontend.typeOf(instr), rawNode = instr) - val elementType = frontend.typeOf(instr).dereference() + LLVMShuffleVector usesOperands mapOf("v1" to 0, "v2" to 1) - val initializers = mutableListOf() + return newInitializerListExpression(frontend.typeOf(instr), rawNode = instr) + .withChildren { + val elementType = frontend.typeOf(instr).dereference() - // Get the first vector and its length. The length is 0 if it's an undef value. - val array1 = frontend.getOperandValueAtIndex(instr, 0) - val array1Length = - if (array1 is Literal<*> && array1.value == null) { - 0 - } else { - LLVMGetVectorSize(LLVMTypeOf(LLVMGetOperand(instr, 0))) - } + val initializers = mutableListOf() - // Get the second vector and its length. The length is 0 if it's an undef value. - val array2 = frontend.getOperandValueAtIndex(instr, 1) - val array2Length = - if (array2 is Literal<*> && array2.value == null) { - 0 - } else { - LLVMGetVectorSize(LLVMTypeOf(LLVMGetOperand(instr, 1))) - } + // Get the first vector and its length. The length is 0 if it's an undef value. + val array1 = instr.operandValue("op1") + val array1Length = + if (array1 is Literal<*> && array1.value == null) { + 0 + } else { + LLVMGetVectorSize(LLVMTypeOf(instr.operand("op1"))) + } - // Get the number of mask elements. They determine the ordering of the elements. - val indices = LLVMGetNumMaskElements(instr) - - // Get the respective elements depending on the mask and put them into an initializer for - // the resulting vector. - // If a vector is an initializer itself (i.e., a constant array), we directly put the values - // in the new initializer. - // Otherwise, we use the array as a variable. - for (idx in 0 until indices) { - val idxInt = LLVMGetMaskValue(instr, idx) - if (idxInt < array1Length) { - if (array1 is InitializerListExpression) { - initializers += array1.initializers[idxInt] - } else if (array1 is Literal<*> && array1.value == null) { - initializers += newLiteral(null, elementType, rawNode = instr) - } else { - val arrayExpr = newSubscriptExpression(rawNode = instr) - arrayExpr.arrayExpression = frontend.getOperandValueAtIndex(instr, 0) - arrayExpr.subscriptExpression = - newLiteral(idxInt, primitiveType("i32"), rawNode = instr) - initializers += arrayExpr - } - } else if (idxInt < array1Length + array2Length) { - if (array2 is InitializerListExpression) { - initializers += array2.initializers[idxInt - array1Length] - } else if (array2 is Literal<*> && array2.value == null) { - initializers += newLiteral(null, elementType, rawNode = instr) - } else { - val arrayExpr = newSubscriptExpression(rawNode = instr) - arrayExpr.arrayExpression = frontend.getOperandValueAtIndex(instr, 1) - arrayExpr.subscriptExpression = - newLiteral(idxInt - array1Length, primitiveType("i32"), rawNode = instr) - initializers += arrayExpr + // Get the second vector and its length. The length is 0 if it's an undef value. + val array2 = instr.operandValue("op2") + val array2Length = + if (array2 is Literal<*> && array2.value == null) { + 0 + } else { + LLVMGetVectorSize(LLVMTypeOf(instr.operand("op2"))) + } + + // Get the number of mask elements. They determine the ordering of the elements. + val indices = LLVMGetNumMaskElements(instr) + + // Get the respective elements depending on the mask and put them into an + // initializer for + // the resulting vector. + // If a vector is an initializer itself (i.e., a constant array), we directly put + // the values + // in the new initializer. + // Otherwise, we use the array as a variable. + for (idx in 0 until indices) { + val idxInt = LLVMGetMaskValue(instr, idx) + if (idxInt < array1Length) { + if (array1 is InitializerListExpression) { + initializers += array1.initializers[idxInt].withParent() + } else if (array1 is Literal<*> && array1.value == null) { + initializers += newLiteral(null, elementType, rawNode = instr) + } else { + initializers += + newSubscriptExpression(rawNode = instr).withChildren { + it.arrayExpression = instr.operandValue("op1") + it.subscriptExpression = + newLiteral(idxInt, primitiveType("i32"), rawNode = instr) + } + } + } else if (idxInt < array1Length + array2Length) { + if (array2 is InitializerListExpression) { + initializers += array2.initializers[idxInt - array1Length].withParent() + } else if (array2 is Literal<*> && array2.value == null) { + initializers += newLiteral(null, elementType, rawNode = instr) + } else { + initializers += + newSubscriptExpression(rawNode = instr).withChildren { + it.arrayExpression = instr.operandValue("op2") + it.subscriptExpression = + newLiteral( + idxInt - array1Length, + primitiveType("i32"), + rawNode = instr + ) + } + } + } else { + initializers += newLiteral(null, elementType, rawNode = instr) + } } - } else { - initializers += newLiteral(null, elementType, rawNode = instr) } - } - - list.initializers = initializers - - return declarationOrNot(list, instr) + .declareIfNecessary(instr) } /** @@ -1394,7 +1403,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : // We only have a single pair, so we insert a declaration in that one BB. val (key, value) = labelMap.entries.elementAt(0) val basicBlock = key.subStatement as? Block - val decl = declarationOrNot(value, instr) + val decl = value.declareIfNecessary(instr) flatAST.addAll(SubgraphWalker.flattenAST(decl)) val mutableStatements = basicBlock?.statements?.toMutableList() mutableStatements?.add(basicBlock.statements.size - 1, decl) @@ -1422,17 +1431,19 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : val varName = instr.name val type = frontend.typeOf(instr) val declaration = newVariableDeclaration(varName, type, false, rawNode = instr) - declaration.type = type - flatAST.add(declaration) // add the declaration to the current scope frontend.scopeManager.addDeclaration(declaration) + // add it to our bindings cache frontend.bindingsCache[instr.symbolName] = declaration - val declStatement = newDeclarationStatement(rawNode = instr) - declStatement.singleDeclaration = declaration + val declStatement = + newDeclarationStatement(rawNode = instr).withChildren { + it.singleDeclaration = declaration.withParent() + } + val mutableFunctionStatements = firstBB.statements.toMutableList() mutableFunctionStatements.add(0, declStatement) firstBB.statements = mutableFunctionStatements @@ -1457,44 +1468,6 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : } } - /** - * Most instructions in LLVM have a variable assignment as part of their instruction. Since LLVM - * IR is SSA, we need to declare a new variable in this case, which is named according to - * [valueRef]. In case the variable assignment is optional, we directly return the [Expression] - * associated with the instruction. - */ - private fun declarationOrNot(rhs: Expression, valueRef: LLVMValueRef): Statement { - val namePair = frontend.getNameOf(valueRef) - val lhs = namePair.first - val symbolName = namePair.second - - // if it is still empty, we probably do not have a left side - return if (lhs != "") { - // Since the declaration statement only contains the single declaration, we can use the - // same raw node, so we end up with the same code and location - newDeclarationStatement(rawNode = valueRef).withChildren { - it.singleDeclaration = - newVariableDeclaration( - lhs, - frontend.typeOf(valueRef), - false, - rawNode = valueRef - ) - .withChildren { decl -> - decl.initializer = rhs.withParent() - - // add the declaration to the current scope - frontend.scopeManager.addDeclaration(decl) - - // add it to our bindings cache - frontend.bindingsCache[symbolName] = decl - } - } - } else { - rhs - } - } - /** * Handles a basic block and returns a [Block] comprised of the statements of this block or a * [LabelStatement] if the basic block has a label. @@ -1549,7 +1522,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : ): Statement { instr.opCode usesOperands mapOf("op1" to 0, "op2" to 1) - var expr = + val expr = when (op) { "uno" -> { // Unordered comparison operand => Replace with a call to isunordered(x, y) @@ -1689,7 +1662,43 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : return newReference(name) } - private fun Expression.declareIfNecessary(instr: LLVMValueRef): Statement { - return declarationOrNot(this, instr) + /** + * Most instructions in LLVM have a variable assignment as part of their instruction. Since LLVM + * IR is SSA, we need to declare a new variable in this case, which is named according to + * [valueRef]. In case the variable assignment is optional, we directly return the [Expression] + * associated with the instruction. + */ + private fun Expression.declareIfNecessary(valueRef: LLVMValueRef): Statement { + val namePair = frontend.getNameOf(valueRef) + val lhs = namePair.first + val symbolName = namePair.second + + return with(this@StatementHandler) { + // if it is still empty, we probably do not have a left side + return@with if (lhs != "") { + // Since the declaration statement only contains the single declaration, we can use + // the same raw node, so we end up with the same code and location + newDeclarationStatement(rawNode = valueRef).withChildren { + it.singleDeclaration = + newVariableDeclaration( + lhs, + frontend.typeOf(valueRef), + false, + rawNode = valueRef + ) + .withChildren { decl -> + decl.initializer = this@declareIfNecessary.withParent() + + // add the declaration to the current scope + frontend.scopeManager.addDeclaration(decl) + + // add it to our bindings cache + frontend.bindingsCache[symbolName] = decl + } + } + } else { + this@declareIfNecessary + } + } } }