diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/expressions/BinaryOperator.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/expressions/BinaryOperator.kt index c090750d26..06cce9f422 100644 --- a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/expressions/BinaryOperator.kt +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/expressions/BinaryOperator.kt @@ -26,7 +26,8 @@ package de.fraunhofer.aisec.cpg.graph.statements.expressions import de.fraunhofer.aisec.cpg.frontends.TranslationException -import de.fraunhofer.aisec.cpg.graph.* +import de.fraunhofer.aisec.cpg.graph.ArgumentHolder +import de.fraunhofer.aisec.cpg.graph.HasOverloadedOperation import de.fraunhofer.aisec.cpg.graph.edges.ast.astEdgeOf import de.fraunhofer.aisec.cpg.graph.edges.unwrapping import de.fraunhofer.aisec.cpg.graph.types.HasType @@ -81,6 +82,7 @@ open class BinaryOperator : .append("lhs", lhs.name) .append("rhs", rhs.name) .append("operatorCode", operatorCode) + .append("location", location) .toString() } diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/ExpressionHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/ExpressionHandler.kt index 3c7763836c..6071ee35b5 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/ExpressionHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/ExpressionHandler.kt @@ -63,6 +63,7 @@ class ExpressionHandler(lang: LLVMIRLanguageFrontend) : LLVMPoisonValueValueKind -> { newReference("poison", frontend.typeOf(value), rawNode = value) } + LLVMConstantTargetNoneValueKind -> newLiteral(null, unknownType(), rawNode = value) LLVMConstantTokenNoneValueKind -> newLiteral(null, unknownType(), rawNode = value) LLVMUndefValueValueKind -> initializeAsUndef(frontend.typeOf(value), value) LLVMConstantAggregateZeroValueKind -> initializeAsZero(frontend.typeOf(value), value) @@ -88,45 +89,28 @@ class ExpressionHandler(lang: LLVMIRLanguageFrontend) : ) } else -> { - log.info( - "Not handling value kind {} in handleValue yet. Falling back to the legacy way. Please change", - kind - ) - val cpgType = frontend.typeOf(value) - - // old stuff from getOperandValue, needs to be refactored to the when above - // TODO also move the other stuff to the expression handler - return when { - LLVMIsConstant(value) != 1 -> { - val operandName: String = - if ( - LLVMIsAGlobalAlias(value) != null || - LLVMIsGlobalConstant(value) == 1 - ) { - val aliasee = LLVMAliasGetAliasee(value) - LLVMPrintValueToString(aliasee) - .string // Already resolve the aliasee of the constant - } else { - // TODO This does not return the actual constant but only a string - // representation - LLVMPrintValueToString(value).string - } - newLiteral(operandName, cpgType, rawNode = value) - } - LLVMIsUndef(value) == 1 -> { - newReference("undef", cpgType, rawNode = value) - } - LLVMIsPoison(value) == 1 -> { - newReference("poison", cpgType, rawNode = value) - } - else -> { - log.error("Unknown expression {}", kind) - newProblemExpression( - "Unknown expression $kind", - ProblemNode.ProblemType.TRANSLATION, - rawNode = value - ) - } + // old stuff from getOperandValue, needs to be refactored to the `when` above + return if (LLVMIsConstant(value) != 1) { + log.info("Update handling value kind {} to the new way", kind) + var printVal = + if (LLVMIsAGlobalAlias(value) != null || LLVMIsGlobalConstant(value) == 1) { + // Already resolve the aliasee of the constant + LLVMAliasGetAliasee(value) + } else { + value + } + newLiteral( + LLVMPrintValueToString(printVal).string, + frontend.typeOf(value), + rawNode = value + ) + } else { + log.error("Unknown expression {}", kind) + newProblemExpression( + "Unknown expression $kind", + ProblemNode.ProblemType.TRANSLATION, + rawNode = value + ) } } } @@ -210,65 +194,42 @@ class ExpressionHandler(lang: LLVMIRLanguageFrontend) : * regular expression. */ private fun handleConstantExprValueKind(value: LLVMValueRef): Expression { - val expr = - when (val kind = LLVMGetConstOpcode(value)) { - LLVMGetElementPtr -> handleGetElementPtr(value) - LLVMSelect -> handleSelect(value) - LLVMTrunc, - LLVMZExt, - LLVMSExt, - LLVMFPToUI, - LLVMFPToSI, - LLVMUIToFP, - LLVMSIToFP, - LLVMFPTrunc, - LLVMFPExt, - LLVMPtrToInt, - LLVMIntToPtr, - LLVMBitCast, - LLVMAddrSpaceCast -> handleCastInstruction(value) - LLVMAdd, - LLVMFAdd -> - frontend.statementHandler.handleBinaryOperator(value, "+", false) as? Expression - ?: newProblemExpression( - "Wrong type of constant binary operation +", - ProblemNode.ProblemType.TRANSLATION, - rawNode = value - ) - LLVMSub, - LLVMFSub -> - frontend.statementHandler.handleBinaryOperator(value, "-", false) as? Expression - ?: newProblemExpression( - "Wrong type of constant binary operation -", - ProblemNode.ProblemType.TRANSLATION, - rawNode = value - ) - LLVMAShr -> - frontend.statementHandler.handleBinaryOperator(value, ">>", false) - as? Expression - ?: newProblemExpression( - "Wrong type of constant binary operation >>", - ProblemNode.ProblemType.TRANSLATION, - rawNode = value - ) - LLVMICmp -> - frontend.statementHandler.handleIntegerComparison(value) as? Expression - ?: newProblemExpression( - "Wrong type of constant comparison", - ProblemNode.ProblemType.TRANSLATION, - rawNode = value - ) - else -> { - log.error("Not handling constant expression of opcode {} yet", kind) - newProblemExpression( - "Not handling constant expression of opcode $kind yet", - ProblemNode.ProblemType.TRANSLATION, - rawNode = value - ) - } + return when (val kind = LLVMGetConstOpcode(value)) { + LLVMGetElementPtr -> handleGetElementPtr(value) + LLVMSelect -> handleSelect(value) + LLVMTrunc, + LLVMZExt, + LLVMSExt, + LLVMFPToUI, + LLVMFPToSI, + LLVMUIToFP, + LLVMSIToFP, + LLVMFPTrunc, + LLVMFPExt, + LLVMPtrToInt, + LLVMIntToPtr, + LLVMBitCast, + LLVMAddrSpaceCast -> handleCastInstruction(value) + LLVMAdd, + LLVMFAdd -> frontend.statementHandler.handleBinaryOperator(value, "+", false) + LLVMSub, + LLVMFSub -> frontend.statementHandler.handleBinaryOperator(value, "-", false) + LLVMMul, + LLVMFMul -> frontend.statementHandler.handleBinaryOperator(value, "*", false) + LLVMShl -> frontend.statementHandler.handleBinaryOperator(value, "<<", false) + LLVMLShr, + LLVMAShr -> frontend.statementHandler.handleBinaryOperator(value, ">>", false) + LLVMXor -> frontend.statementHandler.handleBinaryOperator(value, "^", false) + LLVMICmp -> frontend.statementHandler.handleIntegerComparison(value) + else -> { + log.error("Not handling constant expression of opcode {} yet", kind) + newProblemExpression( + "Not handling constant expression of opcode $kind yet", + ProblemNode.ProblemType.TRANSLATION, + rawNode = value + ) } - - return expr + } } /** diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt index f9d8e8c7a1..520b58dd97 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandler.kt @@ -69,9 +69,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : return declarationOrNot(frontend.expressionHandler.handleCastInstruction(instr), instr) } - val opcode = instr.opCode - - when (opcode) { + return when (val opcode = instr.opCode) { LLVMRet -> { val ret = newReturnStatement(rawNode = instr) @@ -80,141 +78,151 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : ret.returnValue = frontend.getOperandValueAtIndex(instr, 0) } - return ret + ret } LLVMBr -> { - return handleBrStatement(instr) + handleBrStatement(instr) } LLVMSwitch -> { - return handleSwitchStatement(instr) + handleSwitchStatement(instr) } LLVMIndirectBr -> { - return handleIndirectbrStatement(instr) + handleIndirectbrStatement(instr) } LLVMCall, LLVMInvoke -> { - return handleFunctionCall(instr) + handleFunctionCall(instr) } LLVMUnreachable -> { // Does nothing - return newEmptyStatement(rawNode = instr) + newEmptyStatement(rawNode = instr) } LLVMCallBr -> { // Maps to a call but also to a goto statement? Barely used => not relevant - log.error("Cannot parse callbr instruction yet") + newProblemExpression( + "Cannot handle callbr instruction yet", + ProblemNode.ProblemType.TRANSLATION, + rawNode = instr + ) } LLVMFNeg -> { val fneg = newUnaryOperator("-", postfix = false, prefix = true, rawNode = instr) fneg.input = frontend.getOperandValueAtIndex(instr, 0) - return fneg + + val decl = declarationOrNot(fneg, instr) + (decl as? DeclarationStatement)?.let { + // cache binding + frontend.bindingsCache[instr.symbolName] = + decl.singleDeclaration as VariableDeclaration + } + + decl } LLVMAlloca -> { - return handleAlloca(instr) + handleAlloca(instr) } LLVMLoad -> { - return handleLoad(instr) + handleLoad(instr) } LLVMStore -> { - return handleStore(instr) + handleStore(instr) } LLVMExtractValue, LLVMGetElementPtr -> { - return declarationOrNot( - frontend.expressionHandler.handleGetElementPtr(instr), - instr - ) + declarationOrNot(frontend.expressionHandler.handleGetElementPtr(instr), instr) } LLVMICmp -> { - return handleIntegerComparison(instr) + declarationOrNot(handleIntegerComparison(instr), instr) } LLVMFCmp -> { - return handleFloatComparison(instr) + handleFloatComparison(instr) } LLVMPHI -> { frontend.phiList.add(instr) - return newEmptyStatement(rawNode = instr) + newEmptyStatement(rawNode = instr) } LLVMSelect -> { - return declarationOrNot(frontend.expressionHandler.handleSelect(instr), instr) + declarationOrNot(frontend.expressionHandler.handleSelect(instr), instr) } LLVMUserOp1, LLVMUserOp2 -> { log.info( "userop instruction is not a real instruction. Replacing it with empty statement" ) - return newEmptyStatement(rawNode = instr) + newEmptyStatement(rawNode = instr) } LLVMVAArg -> { - return handleVaArg(instr) + handleVaArg(instr) } LLVMExtractElement -> { - return handleExtractelement(instr) + handleExtractelement(instr) } LLVMInsertElement -> { - return handleInsertelement(instr) + handleInsertelement(instr) } LLVMShuffleVector -> { - return handleShufflevector(instr) + handleShufflevector(instr) } LLVMInsertValue -> { - return handleInsertValue(instr) + handleInsertValue(instr) } LLVMFreeze -> { - return handleFreeze(instr) + handleFreeze(instr) } LLVMFence -> { - return handleFence(instr) + handleFence(instr) } LLVMAtomicCmpXchg -> { - return handleAtomiccmpxchg(instr) + handleAtomiccmpxchg(instr) } LLVMAtomicRMW -> { - return handleAtomicrmw(instr) + handleAtomicrmw(instr) } LLVMResume -> { // Resumes propagation of an existing (in-flight) exception whose unwinding was // interrupted with a landingpad instruction. - return newThrowExpression(rawNode = instr).apply { + newThrowExpression(rawNode = instr).apply { exception = newProblemExpression("We don't know the exception while parsing this node.") } } LLVMLandingPad -> { - return handleLandingpad(instr) + handleLandingpad(instr) } LLVMCleanupRet -> { // End of the cleanup basic block(s) // Jump to a label where handling the exception will unwind to next (e.g. a // catchswitch statement) - return handleCatchret(instr) + handleCatchret(instr) } LLVMCatchRet -> { // Catch (caught by catchpad instruction) is over. // Jumps to a label where the "normal" function logic continues - return handleCatchret(instr) + handleCatchret(instr) } LLVMCatchPad -> { // Actually handles the exception. - return handleCatchpad(instr) + handleCatchpad(instr) } LLVMCleanupPad -> { // Beginning of the cleanup basic block(s). // We should model this as the beginning of a catch block - return handleCleanuppad(instr) + handleCleanuppad(instr) } LLVMCatchSwitch -> { // Marks the beginning of a "real" catch block // Jumps to one of the handlers specified or to the default handler (if specified) - return handleCatchswitch(instr) + handleCatchswitch(instr) + } + else -> { + log.error("Not handling instruction opcode {} yet", opcode) + newProblemExpression( + "Not handling instruction opcode $opcode yet", + ProblemNode.ProblemType.TRANSLATION, + rawNode = instr + ) } } - - log.error("Not handling instruction opcode {} yet", opcode) - return newProblemExpression( - "Not handling instruction opcode $opcode yet", - ProblemNode.ProblemType.TRANSLATION, - rawNode = instr - ) } /** @@ -228,10 +236,8 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : val unwindDest = if (instr.opCode == LLVMCatchRet) { LLVMGetOperand(instr, 1) - } else if (LLVMGetUnwindDest(instr) != null) { - LLVMBasicBlockAsValue(LLVMGetUnwindDest(instr)) } else { - null + LLVMGetUnwindDest(instr)?.let { LLVMBasicBlockAsValue(it) } } val name = Name( @@ -242,13 +248,9 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : } ) return if (unwindDest != null) { // For "unwind to caller", the destination is null - val gotoStatement = assembleGotoStatement(instr, unwindDest) - gotoStatement.name = name - gotoStatement + assembleGotoStatement(instr, unwindDest).apply { this.name = name } } else { - val emptyStatement = newEmptyStatement(rawNode = instr) - emptyStatement.name = name - emptyStatement + newEmptyStatement(rawNode = instr).apply { this.name = name } } } @@ -423,57 +425,60 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : /** Handles all kinds of instructions which are an arithmetic or logical binary instruction. */ private fun handleBinaryInstruction(instr: LLVMValueRef): Statement { - when (instr.opCode) { - LLVMAdd, - LLVMFAdd -> { - return handleBinaryOperator(instr, "+", false) - } - LLVMSub, - LLVMFSub -> { - return handleBinaryOperator(instr, "-", false) - } - LLVMMul, - LLVMFMul -> { - return handleBinaryOperator(instr, "*", false) - } - LLVMUDiv -> { - return handleBinaryOperator(instr, "/", true) - } - LLVMSDiv, - LLVMFDiv -> { - return handleBinaryOperator(instr, "/", false) - } - LLVMURem -> { - return handleBinaryOperator(instr, "%", true) - } - LLVMSRem, - LLVMFRem -> { - return handleBinaryOperator(instr, "%", false) - } - LLVMShl -> { - return handleBinaryOperator(instr, "<<", false) - } - LLVMLShr -> { - return handleBinaryOperator(instr, ">>", true) - } - LLVMAShr -> { - return handleBinaryOperator(instr, ">>", false) - } - LLVMAnd -> { - return handleBinaryOperator(instr, "&", false) - } - LLVMOr -> { - return handleBinaryOperator(instr, "|", false) - } - LLVMXor -> { - return handleBinaryOperator(instr, "^", false) + val binaryOperator = + when (instr.opCode) { + LLVMAdd, + LLVMFAdd -> { + handleBinaryOperator(instr, "+", false) + } + LLVMSub, + LLVMFSub -> { + handleBinaryOperator(instr, "-", false) + } + LLVMMul, + LLVMFMul -> { + handleBinaryOperator(instr, "*", false) + } + LLVMUDiv -> { + handleBinaryOperator(instr, "/", true) + } + LLVMSDiv, + LLVMFDiv -> { + handleBinaryOperator(instr, "/", false) + } + LLVMURem -> { + handleBinaryOperator(instr, "%", true) + } + LLVMSRem, + LLVMFRem -> { + handleBinaryOperator(instr, "%", false) + } + LLVMShl -> { + handleBinaryOperator(instr, "<<", false) + } + LLVMLShr -> { + handleBinaryOperator(instr, ">>", true) + } + LLVMAShr -> { + handleBinaryOperator(instr, ">>", false) + } + LLVMAnd -> { + handleBinaryOperator(instr, "&", false) + } + LLVMOr -> { + handleBinaryOperator(instr, "|", false) + } + LLVMXor -> { + handleBinaryOperator(instr, "^", false) + } + else -> + newProblemExpression( + "No opcode found for binary operator", + ProblemNode.ProblemType.TRANSLATION, + rawNode = instr + ) } - } - return newProblemExpression( - "Not opcode found for binary operator", - ProblemNode.ProblemType.TRANSLATION, - rawNode = instr - ) + return declarationOrNot(binaryOperator, instr) } /** @@ -527,7 +532,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : * Handles the [`icmp`](https://llvm.org/docs/LangRef.html#icmp-instruction) instruction for * comparing integer values. */ - fun handleIntegerComparison(instr: LLVMValueRef): Statement { + fun handleIntegerComparison(instr: LLVMValueRef): Expression { var unsigned = false val cmpPred = when (LLVMGetICmpPredicate(instr)) { @@ -608,7 +613,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : else -> "unknown" } - return handleBinaryOperator(instr, cmpPred, false, unordered) + return declarationOrNot(handleBinaryOperator(instr, cmpPred, false, unordered), instr) } /** @@ -736,12 +741,8 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : condition.lhs = undefCheck val poisonCheck = newBinaryOperator("!=", rawNode = instr) poisonCheck.lhs = operand - poisonCheck.rhs = - newLiteral( - "POISON", - operand.type, - rawNode = instr - ) // This could be e.g. NAN. Not sure for complex types + // This could be e.g. NAN. Not sure for complex types + poisonCheck.rhs = newReference("poison", operand.type, rawNode = instr) condition.rhs = poisonCheck // Call to a dummy function "llvm.freeze" which would fill the undef or poison values @@ -752,7 +753,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : newCallExpression(llvmInternalRef("llvm.freeze"), "llvm.freeze", false, rawNode = instr) callExpression.addArgument(operand) - // res = (arg != undef && arg != poison) ? arg : llvm.freeze(in) + // res = (arg != undef && arg != poison) ? arg : llvm.freeze(arg) val conditional = newConditionalExpression( condition, @@ -976,7 +977,11 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : exchOp.rhs = mutableListOf(conditional) } else -> { - throw TranslationException("LLVMAtomicRMWBinOp $operation not supported") + newProblemExpression( + "LLVMAtomicRMWBinOp $operation not supported", + ProblemNode.ProblemType.TRANSLATION, + rawNode = instr + ) } } @@ -1027,6 +1032,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : // Get the label of the goto statement. val gotoStatement = assembleGotoStatement(instr, LLVMGetOperand(instr, idx)) caseStatements.statements += gotoStatement + caseStatements.statements += newBreakStatement().implicit() idx++ } @@ -1525,7 +1531,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : op: String, unsigned: Boolean, unordered: Boolean = false - ): Statement { + ): Expression { val op1 = frontend.getOperandValueAtIndex(instr, 0) val op2 = frontend.getOperandValueAtIndex(instr, 1) @@ -1598,17 +1604,7 @@ class StatementHandler(lang: LLVMIRLanguageFrontend) : } } - val declOp = if (unordered) binOpUnordered else binaryOperator - val decl = - declOp?.let { declarationOrNot(it, instr) } - ?: newProblemExpression("Could not parse declaration") - - (decl as? DeclarationStatement)?.let { - // cache binding - frontend.bindingsCache[instr.symbolName] = decl.singleDeclaration as VariableDeclaration - } - - return decl + return binOpUnordered ?: binaryOperator } /** diff --git a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/CompressLLVMPass.kt b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/CompressLLVMPass.kt index 12247af000..84743e6299 100644 --- a/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/CompressLLVMPass.kt +++ b/cpg-language-llvm/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/CompressLLVMPass.kt @@ -56,7 +56,7 @@ class CompressLLVMPass(ctx: TranslationContext) : ComponentPass(ctx) { // Enforce the order: First IfStatements, then SwitchStatements, then the rest. This // prevents to treat the final goto in the case or default statement as a normal // compound - // statement which would lead to inlining the instructions BB but we want to keep the BB + // statement which would lead to inlining the instructions BB, but we want to keep the BB // inside a Block. for (node in flatAST.sortedBy { n -> @@ -72,76 +72,74 @@ class CompressLLVMPass(ctx: TranslationContext) : ComponentPass(ctx) { handleIfStatement(node, gotosToReplace) } is SwitchStatement -> { - // Iterate over all statements in a body of the switch/case and replace a goto - // statement if it is the only one jumping to the target - val caseBodyStatements = node.statement as Block - val newStatements = caseBodyStatements.statements.toMutableList() - for (i in 0 until newStatements.size) { - val subStatement = - (newStatements[i] as? GotoStatement)?.targetLabel?.subStatement - if ( - newStatements[i] in gotosToReplace && - newStatements[i] !in (subStatement?.astChildren ?: listOf()) - ) { - subStatement?.let { newStatements[i] = it } - } - } - (node.statement as Block).statements = newStatements + handleSwitchStatement(node, gotosToReplace) } is TryStatement -> { handleTryStatement(node) } is Block -> { - // Get the last statement in a Block and replace a goto statement - // iff it is the only one jumping to the target - val goto = node.statements.lastOrNull() - if ( - goto != null && - goto in gotosToReplace && - node !in - SubgraphWalker.flattenAST( - (goto as GotoStatement).targetLabel?.subStatement - ) - ) { - val subStatement = goto.targetLabel?.subStatement - val newStatements = node.statements.dropLast(1).toMutableList() - newStatements.addAll((subStatement as Block).statements) - node.statements = newStatements - } + handleBlock(node, gotosToReplace) } } } } + /** + * Get the last statement in a [Block] and replace a goto statement iff it is the only one + * jumping to the target + */ + private fun handleBlock(node: Block, gotosToReplace: List) { + val goto = node.statements.lastOrNull() as? GotoStatement ?: return + val gotoSubstatement = goto.targetLabel?.subStatement as? Block ?: return + if (goto in gotosToReplace && node !in gotoSubstatement.allChildren()) { + val newStatements = node.statements.dropLast(1).toMutableList() + newStatements.addAll(gotoSubstatement.statements) + node.statements = newStatements + } + } + + /** + * Iterates over all statements in a body of the switch/case and replace a goto statement if it + * is the only one jumping to the target + */ + private fun handleSwitchStatement(node: SwitchStatement, gotosToReplace: List) { + val caseBodyStatements = node.statement as? Block ?: return + val newStatements = caseBodyStatements.statements.toMutableList() + for (i in 0 until newStatements.size) { + val subStatement = (newStatements[i] as? GotoStatement)?.targetLabel?.subStatement + if ( + newStatements[i] in gotosToReplace && + newStatements[i] !in (subStatement?.astChildren ?: listOf()) + ) { + subStatement?.let { newStatements[i] = it } + } + } + caseBodyStatements.statements = newStatements + } + + /** + * Replace the then-statement and else-statement with the basic block it jumps to iff we found + * that its goto statement is the only one jumping to the target + */ private fun handleIfStatement(node: IfStatement, gotosToReplace: List) { - // Replace the then-statement with the basic block it jumps to iff we found that - // its goto statement is the only one jumping to the target - if ( - node.thenStatement in gotosToReplace && - node !in - SubgraphWalker.flattenAST( - (node.thenStatement as GotoStatement).targetLabel?.subStatement - ) - ) { - node.thenStatement = (node.thenStatement as GotoStatement).targetLabel?.subStatement + + // Replace the then-statement + val thenGoto = (node.thenStatement as? GotoStatement)?.targetLabel?.subStatement + if (node.thenStatement in gotosToReplace && node !in thenGoto.allChildren()) { + node.thenStatement = thenGoto } - // Replace the else-statement with the basic block it jumps to iff we found that - // its goto statement is the only one jumping to the target - if ( - node.elseStatement in gotosToReplace && - node !in - SubgraphWalker.flattenAST( - (node.elseStatement as GotoStatement).targetLabel?.subStatement - ) - ) { - node.elseStatement = (node.elseStatement as GotoStatement).targetLabel?.subStatement + // Replace the else-statement + val elseGoto = (node.elseStatement as? GotoStatement)?.targetLabel?.subStatement + if (node.elseStatement in gotosToReplace && node !in elseGoto.allChildren()) { + node.elseStatement = elseGoto } } private fun handleTryStatement(node: TryStatement) { - when { - node.catchClauses.size == 1 && - node.catchClauses[0].body?.statements?.get(0) is CatchClause -> { + val firstCatch = node.catchClauses.singleOrNull() + val firstStatement = firstCatch?.body?.statements?.get(0) + when (firstStatement) { + is CatchClause -> { /* Initially, we expect only a single catch clause which contains all the logic. * The first statement of the clause should have been a `landingpad` instruction * which has been translated to a CatchClause. We get this clause and set it as the @@ -149,36 +147,24 @@ class CompressLLVMPass(ctx: TranslationContext) : ComponentPass(ctx) { */ val catchClauses = mutableListOf() - val caseBody = node.catchClauses[0].body + val caseBody = firstCatch.body // This is the most generic one - val clauseToAdd = caseBody?.statements?.get(0) as CatchClause - catchClauses.add(clauseToAdd) - caseBody.statements = caseBody.statements.drop(1).toMutableList() + catchClauses.add(firstStatement) + caseBody?.statements = caseBody.statements.drop(1).toMutableList() catchClauses[0].body = caseBody - if (node.catchClauses[0].parameter != null) { - catchClauses[0].parameter = node.catchClauses[0].parameter + if (firstCatch.parameter != null) { + catchClauses[0].parameter = firstCatch.parameter } node.catchClauses = catchClauses - - fixThrowExpressionsForCatch(node.catchClauses[0]) } - node.catchClauses.size == 1 && - node.catchClauses[0].body?.statements?.get(0) is Block -> { + is Block -> { // A compound statement which is wrapped in the catchClause. We can simply move - // it - // one layer up and make - // the compound statement the body of the catch clause. - val innerCompound = node.catchClauses[0].body?.statements?.get(0) as? Block - innerCompound?.statements?.let { node.catchClauses[0].body?.statements = it } - fixThrowExpressionsForCatch(node.catchClauses[0]) - } - node.catchClauses.isNotEmpty() -> { - for (catch in node.catchClauses) { - fixThrowExpressionsForCatch(catch) - } + // it one layer up and make the compound statement the body of the catch clause. + firstCatch.body?.statements = firstStatement.statements } } + node.catchClauses.forEach(::fixThrowExpressionsForCatch) } /** @@ -192,30 +178,38 @@ class CompressLLVMPass(ctx: TranslationContext) : ComponentPass(ctx) { n.exception is ProblemExpression } if (reachableThrowNodes.isNotEmpty()) { - if (catch.parameter == null) { - val error = - newVariableDeclaration( - "e_${catch.name}", - UnknownType.getUnknownType(catch.language), - true, - ) - error.language = catch.language - catch.parameter = error + val catchParameter = + catch.parameter + ?: newVariableDeclaration( + "e_${catch.name}", + UnknownType.getUnknownType(catch.language), + implicitInitializerAllowed = true, + ) + .apply { + language = catch.language + catch.parameter = this + } + .implicit() + + reachableThrowNodes.forEach { + it.exception = + newReference(catchParameter.name, catchParameter.type) + .apply { + language = catch.language + refersTo = catch.parameter + } + .implicit() } - val exceptionReference = - newReference( - catch.parameter?.name, - catch.parameter?.type ?: UnknownType.getUnknownType(catch.language), - ) - exceptionReference.language = catch.language - exceptionReference.refersTo = catch.parameter - reachableThrowNodes.forEach { n -> n.exception = exceptionReference } } } - /** Iterates through all nodes which are reachable from the catch clause */ + /** + * Iterates through all nodes which are reachable from the catch clause. Note: When reaching a + * `TryStatement`, we do not follow the path further. This is why we can't use the `allChildren` + * extension. + */ private fun getAllChildrenRecursively(node: CatchClause?): Set { - if (node == null) return LinkedHashSet() + if (node == null) return setOf() val worklist: Queue = LinkedList() worklist.add(node.body) val alreadyChecked = LinkedHashSet() @@ -225,7 +219,7 @@ class CompressLLVMPass(ctx: TranslationContext) : ComponentPass(ctx) { // We exclude sub-try statements as they would mess up with the results val toAdd = currentNode.astChildren.filter { n -> - n !is TryStatement && !alreadyChecked.contains(n) && !worklist.contains(n) + n !is TryStatement && n !in alreadyChecked && n !in worklist } worklist.addAll(toAdd) } diff --git a/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/ExpressionHandlerTest.kt b/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/ExpressionHandlerTest.kt new file mode 100644 index 0000000000..20c895485f --- /dev/null +++ b/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/ExpressionHandlerTest.kt @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2024, Fraunhofer AISEC. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * $$$$$$\ $$$$$$$\ $$$$$$\ + * $$ __$$\ $$ __$$\ $$ __$$\ + * $$ / \__|$$ | $$ |$$ / \__| + * $$ | $$$$$$$ |$$ |$$$$\ + * $$ | $$ ____/ $$ |\_$$ | + * $$ | $$\ $$ | $$ | $$ | + * \$$$$$ |$$ | \$$$$$ | + * \______/ \__| \______/ + * + */ +package de.fraunhofer.aisec.cpg.frontends.llvm + +import de.fraunhofer.aisec.cpg.graph.* +import de.fraunhofer.aisec.cpg.graph.statements.expressions.BinaryOperator +import de.fraunhofer.aisec.cpg.graph.statements.expressions.CallExpression +import de.fraunhofer.aisec.cpg.graph.statements.expressions.CastExpression +import de.fraunhofer.aisec.cpg.test.analyzeAndGetFirstTU +import de.fraunhofer.aisec.cpg.test.assertLiteralValue +import de.fraunhofer.aisec.cpg.test.assertLocalName +import de.fraunhofer.aisec.cpg.test.assertRefersTo +import java.nio.file.Path +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlin.test.assertNotNull + +class ExpressionHandlerTest { + @Test + fun testConstantFloat() { + val topLevel = Path.of("src", "test", "resources", "llvm") + val tu = + analyzeAndGetFirstTU( + listOf(topLevel.resolve("floatingpoint_const.ll").toFile()), + topLevel, + true + ) { + it.registerLanguage() + } + + val globalX = tu.variables["x"] + assertNotNull(globalX) + assertLiteralValue(1.25, globalX.initializer) + + val a = tu.variables["a"] + assertNotNull(a) + val aInit = a.initializer + assertIs(aInit) + assertLiteralValue(1.25, aInit.lhs) + assertLiteralValue(1.0, aInit.rhs) + } + + @Test + fun testConstantExpr() { + val topLevel = Path.of("src", "test", "resources", "llvm") + val tu = + analyzeAndGetFirstTU( + listOf(topLevel.resolve("integer_const.ll").toFile()), + topLevel, + true + ) { + it.registerLanguage() + } + + val globalX = tu.variables["x"] + assertNotNull(globalX) + + val aInitCall = tu.variables["a"]?.initializer + assertIs(aInitCall) + assertLocalName("foo", aInitCall) + val argumentA = aInitCall.arguments.singleOrNull() + assertIs(argumentA) + assertEquals("+", argumentA.operatorCode) + val argumentAX = argumentA.lhs + assertIs(argumentAX) + assertRefersTo(argumentAX.expression, globalX) + assertLiteralValue(5L, argumentA.rhs) + + val bInitCall = tu.variables["b"]?.initializer + assertIs(bInitCall) + assertLocalName("foo", bInitCall) + val argumentB = bInitCall.arguments.singleOrNull() + assertIs(argumentB) + assertEquals("-", argumentB.operatorCode) + val argumenBtX = argumentB.lhs + assertIs(argumenBtX) + assertRefersTo(argumenBtX.expression, globalX) + assertLiteralValue(5L, argumentB.rhs) + + val cInitCall = tu.variables["c"]?.initializer + assertIs(cInitCall) + assertLocalName("foo", cInitCall) + val argumentC = cInitCall.arguments.singleOrNull() + assertIs(argumentC) + assertEquals("*", argumentC.operatorCode) + val argumentCX = argumentC.lhs + assertIs(argumentCX) + assertRefersTo(argumentCX.expression, globalX) + assertLiteralValue(5L, argumentC.rhs) + + val dInitCall = tu.variables["d"]?.initializer + assertIs(dInitCall) + assertLocalName("foo", dInitCall) + val argumentD = dInitCall.arguments.singleOrNull() + assertIs(argumentD) + assertEquals("<<", argumentD.operatorCode) + val argumentDX = argumentD.lhs + assertIs(argumentDX) + assertRefersTo(argumentDX.expression, globalX) + assertLiteralValue(5L, argumentD.rhs) + + val eInitCall = tu.variables["e"]?.initializer + assertIs(eInitCall) + assertLocalName("foo", eInitCall) + val argumentE = eInitCall.arguments.singleOrNull() + assertIs(argumentE) + assertEquals(">>", argumentE.operatorCode) + val argumentEX = argumentE.lhs + assertIs(argumentEX) + assertRefersTo(argumentEX.expression, globalX) + assertLiteralValue(5L, argumentE.rhs) + + val fInitCall = tu.variables["f"]?.initializer + assertIs(fInitCall) + assertLocalName("foo", fInitCall) + val argumentF = fInitCall.arguments.singleOrNull() + assertIs(argumentF) + assertEquals("^", argumentF.operatorCode) + val argumentFX = argumentF.lhs + assertIs(argumentFX) + assertRefersTo(argumentFX.expression, globalX) + assertLiteralValue(5L, argumentF.rhs) + + val gInitCall = tu.variables["g"]?.initializer + assertIs(gInitCall) + assertLocalName("foo1", gInitCall) + val argumentG = gInitCall.arguments.singleOrNull() + assertIs(argumentG) + assertEquals("==", argumentG.operatorCode) + val argumentGX = argumentG.lhs + assertIs(argumentGX) + assertRefersTo(argumentGX.expression, globalX) + assertLiteralValue(5L, argumentG.rhs) + } +} diff --git a/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/LLVMIRLanguageFrontendTest.kt b/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/LLVMIRLanguageFrontendTest.kt index dbeeb0ad41..89315e6599 100644 --- a/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/LLVMIRLanguageFrontendTest.kt +++ b/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/LLVMIRLanguageFrontendTest.kt @@ -26,6 +26,7 @@ package de.fraunhofer.aisec.cpg.frontends.llvm import de.fraunhofer.aisec.cpg.* +import de.fraunhofer.aisec.cpg.frontends.TranslationException import de.fraunhofer.aisec.cpg.graph.* import de.fraunhofer.aisec.cpg.graph.declarations.VariableDeclaration import de.fraunhofer.aisec.cpg.graph.statements.* @@ -35,8 +36,29 @@ import de.fraunhofer.aisec.cpg.test.* import java.nio.file.Path import kotlin.test.* import kotlin.test.Test +import org.junit.jupiter.api.assertThrows class LLVMIRLanguageFrontendTest { + @Test + fun testExceptionBrokenFile() { + val topLevel = Path.of("src", "test", "resources", "llvm") + + val frontend = + LLVMIRLanguageFrontend( + LLVMIRLanguage(), + TranslationContext( + TranslationConfiguration.builder().build(), + ScopeManager(), + TypeManager() + ) + ) + val exception = + assertThrows { + frontend.parse(topLevel.resolve("main-broken.ll").toFile()) + } + assertTrue(exception.message?.startsWith("Could not parse IR: ") == true) + } + @Test fun test1() { val topLevel = Path.of("src", "test", "resources", "llvm") @@ -88,50 +110,6 @@ class LLVMIRLanguageFrontendTest { assertLiteralValue(0L, xInit.initializers[3]) } - @Test - fun testIntegerOps() { - val topLevel = Path.of("src", "test", "resources", "llvm") - val tu = - analyzeAndGetFirstTU( - listOf(topLevel.resolve("integer_ops.ll").toFile()), - topLevel, - true - ) { - it.registerLanguage() - } - - assertEquals(2, tu.declarations.size) - - val main = tu.functions["main"] - assertNotNull(main) - assertLocalName("i32", main.type) - - val rand = tu.functions["rand"] - assertNotNull(rand) - assertNull(rand.body) - - val xDeclaration = tu.variables["x"] - assertNotNull(xDeclaration) - - val call = xDeclaration.initializer - assertIs(call) - assertLocalName("rand", call) - assertContains(call.invokes, rand) - assertEquals(0, call.arguments.size) - - val xorStatement = main.bodyOrNull(3) - assertNotNull(xorStatement) - - val xorDeclaration = xorStatement.singleDeclaration - assertIs(xorDeclaration) - assertLocalName("a", xorDeclaration) - assertEquals("i32", xorDeclaration.type.typeName) - - val xor = xorDeclaration.initializer - assertIs(xor) - assertEquals("^", xor.operatorCode) - } - @Test fun testIdentifiedStruct() { val topLevel = Path.of("src", "test", "resources", "llvm") @@ -369,55 +347,6 @@ class LLVMIRLanguageFrontendTest { assertEquals(" ret i32 1", ifRet.statements[0].code) } - @Test - fun testAtomicrmw() { - val topLevel = Path.of("src", "test", "resources", "llvm") - val tu = - analyzeAndGetFirstTU( - listOf(topLevel.resolve("atomicrmw.ll").toFile()), - topLevel, - true - ) { - it.registerLanguage() - } - - val foo = tu.functions["foo"] - assertNotNull(foo) - - val atomicrmwStatement = foo.bodyOrNull() - assertNotNull(atomicrmwStatement) - - // Check that the value is assigned to - val declaration = atomicrmwStatement.statements[0].declarations[0] - assertIs(declaration) - assertLocalName("old", declaration) - assertLocalName("i32", declaration.type) - val initializer = declaration.initializer - assertIs(initializer) - assertEquals("*", initializer.operatorCode) - assertLocalName("ptr", initializer.input) - - // Check that the replacement equals *ptr = *ptr + 1 - val replacement = atomicrmwStatement.statements[1] - assertIs(replacement) - assertEquals(1, replacement.lhs.size) - assertEquals(1, replacement.rhs.size) - assertEquals("=", replacement.operatorCode) - val replacementLhs = replacement.lhs.first() - assertIs(replacementLhs) - assertEquals("*", replacementLhs.operatorCode) - assertLocalName("ptr", replacementLhs.input) - // Check that the rhs is equal to *ptr + 1 - val add = replacement.rhs.first() - assertIs(add) - assertEquals("+", add.operatorCode) - val addLhs = add.lhs - assertIs(addLhs) - assertEquals("*", addLhs.operatorCode) - assertLocalName("ptr", addLhs.input) - assertLiteralValue(1L, add.rhs) - } - @Test fun testCmpxchg() { val topLevel = Path.of("src", "test", "resources", "llvm") @@ -433,7 +362,7 @@ class LLVMIRLanguageFrontendTest { val foo = tu.functions["foo"] assertNotNull(foo) - val cmpxchgStatement = foo.bodyOrNull(1) + val cmpxchgStatement = foo.bodyOrNull(10) assertNotNull(cmpxchgStatement) assertEquals(2, cmpxchgStatement.statements.size) @@ -484,8 +413,8 @@ class LLVMIRLanguageFrontendTest { assertEquals("*", thenExprLhs.operatorCode) assertLocalName("ptr", thenExprLhs.input) assertIs(thenExpr.rhs.first()) - assertLocalName("old", thenExpr.rhs.first()) - assertRefersTo(thenExpr.rhs.first(), tu.variables["old"]) + assertLocalName("old1", thenExpr.rhs.first()) + assertRefersTo(thenExpr.rhs.first(), tu.variables["old1"]) } @Test diff --git a/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandlerTest.kt b/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandlerTest.kt new file mode 100644 index 0000000000..ce0a7901f2 --- /dev/null +++ b/cpg-language-llvm/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/llvm/StatementHandlerTest.kt @@ -0,0 +1,991 @@ +/* + * Copyright (c) 2024, Fraunhofer AISEC. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * $$$$$$\ $$$$$$$\ $$$$$$\ + * $$ __$$\ $$ __$$\ $$ __$$\ + * $$ / \__|$$ | $$ |$$ / \__| + * $$ | $$$$$$$ |$$ |$$$$\ + * $$ | $$ ____/ $$ |\_$$ | + * $$ | $$\ $$ | $$ | $$ | + * \$$$$$ |$$ | \$$$$$ | + * \______/ \__| \______/ + * + */ +package de.fraunhofer.aisec.cpg.frontends.llvm + +import de.fraunhofer.aisec.cpg.graph.bodyOrNull +import de.fraunhofer.aisec.cpg.graph.declarations.VariableDeclaration +import de.fraunhofer.aisec.cpg.graph.functions +import de.fraunhofer.aisec.cpg.graph.get +import de.fraunhofer.aisec.cpg.graph.statements.BreakStatement +import de.fraunhofer.aisec.cpg.graph.statements.CaseStatement +import de.fraunhofer.aisec.cpg.graph.statements.DeclarationStatement +import de.fraunhofer.aisec.cpg.graph.statements.GotoStatement +import de.fraunhofer.aisec.cpg.graph.statements.SwitchStatement +import de.fraunhofer.aisec.cpg.graph.statements.expressions.AssignExpression +import de.fraunhofer.aisec.cpg.graph.statements.expressions.BinaryOperator +import de.fraunhofer.aisec.cpg.graph.statements.expressions.Block +import de.fraunhofer.aisec.cpg.graph.statements.expressions.CallExpression +import de.fraunhofer.aisec.cpg.graph.statements.expressions.CastExpression +import de.fraunhofer.aisec.cpg.graph.statements.expressions.ConditionalExpression +import de.fraunhofer.aisec.cpg.graph.statements.expressions.Literal +import de.fraunhofer.aisec.cpg.graph.statements.expressions.ProblemExpression +import de.fraunhofer.aisec.cpg.graph.statements.expressions.UnaryOperator +import de.fraunhofer.aisec.cpg.graph.variables +import de.fraunhofer.aisec.cpg.test.analyzeAndGetFirstTU +import de.fraunhofer.aisec.cpg.test.assertFullName +import de.fraunhofer.aisec.cpg.test.assertLiteralValue +import de.fraunhofer.aisec.cpg.test.assertLocalName +import de.fraunhofer.aisec.cpg.test.assertRefersTo +import java.nio.file.Path +import kotlin.test.Test +import kotlin.test.assertContains +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlin.test.assertNotNull +import kotlin.test.assertNull + +class StatementHandlerTest { + + @Test + fun testIntegerOps() { + val topLevel = Path.of("src", "test", "resources", "llvm") + val tu = + analyzeAndGetFirstTU( + listOf(topLevel.resolve("integer_ops.ll").toFile()), + topLevel, + true + ) { + it.registerLanguage() + } + + assertEquals(2, tu.declarations.size) + + val main = tu.functions["main"] + assertNotNull(main) + assertLocalName("i32", main.type) + + val rand = tu.functions["rand"] + assertNotNull(rand) + assertNull(rand.body) + + val xDeclaration = tu.variables["x"] + assertNotNull(xDeclaration) + + val call = xDeclaration.initializer + assertIs(call) + assertLocalName("rand", call) + assertContains(call.invokes, rand) + assertEquals(0, call.arguments.size) + + val mulStatement = main.bodyOrNull(2) + assertNotNull(mulStatement) + + val mulDeclaration = mulStatement.singleDeclaration + assertIs(mulDeclaration) + assertLocalName("a", mulDeclaration) + assertEquals("i32", mulDeclaration.type.typeName) + + val mul = mulDeclaration.initializer + assertIs(mul) + assertEquals("*", mul.operatorCode) + + val addStatement = main.bodyOrNull(3) + assertNotNull(addStatement) + + val addDeclaration = addStatement.singleDeclaration + assertIs(addDeclaration) + assertLocalName("b", addDeclaration) + assertEquals("i32", addDeclaration.type.typeName) + + val add = addDeclaration.initializer + assertIs(add) + assertEquals("+", add.operatorCode) + + val subStatement = main.bodyOrNull(4) + assertNotNull(subStatement) + + val subDeclaration = subStatement.singleDeclaration + assertIs(subDeclaration) + assertLocalName("c", subDeclaration) + assertEquals("i32", subDeclaration.type.typeName) + + val sub = subDeclaration.initializer + assertIs(sub) + assertEquals("-", sub.operatorCode) + + val divStatement = main.bodyOrNull(5) + assertNotNull(divStatement) + + val divDeclaration = divStatement.singleDeclaration + assertIs(divDeclaration) + assertLocalName("d", divDeclaration) + assertEquals("i32", divDeclaration.type.typeName) + + val div = divDeclaration.initializer + assertIs(div) + assertEquals("/", div.operatorCode) + + val remStatement = main.bodyOrNull(6) + assertNotNull(remStatement) + + val remDeclaration = remStatement.singleDeclaration + assertIs(remDeclaration) + assertLocalName("e", remDeclaration) + assertEquals("i32", remDeclaration.type.typeName) + + val rem = remDeclaration.initializer + assertIs(rem) + assertEquals("%", rem.operatorCode) + + val xorStatement = main.bodyOrNull(7) + assertNotNull(xorStatement) + + val xorDeclaration = xorStatement.singleDeclaration + assertIs(xorDeclaration) + assertLocalName("f", xorDeclaration) + assertEquals("i32", xorDeclaration.type.typeName) + + val xor = xorDeclaration.initializer + assertIs(xor) + assertEquals("^", xor.operatorCode) + + val udivStatement = main.bodyOrNull(8) + assertNotNull(udivStatement) + + val udivDeclaration = udivStatement.singleDeclaration + assertIs(udivDeclaration) + assertLocalName("g", udivDeclaration) + assertEquals("i32", udivDeclaration.type.typeName) + + val udiv = udivDeclaration.initializer + assertIs(udiv) + assertEquals("/", udiv.operatorCode) + + val uremStatement = main.bodyOrNull(9) + assertNotNull(uremStatement) + + val uremDeclaration = uremStatement.singleDeclaration + assertIs(uremDeclaration) + assertLocalName("h", uremDeclaration) + assertEquals("i32", uremDeclaration.type.typeName) + + val urem = uremDeclaration.initializer + assertIs(urem) + assertEquals("%", urem.operatorCode) + + val shlStatement = main.bodyOrNull(10) + assertNotNull(shlStatement) + + val shlDeclaration = shlStatement.singleDeclaration + assertIs(shlDeclaration) + assertLocalName("i", shlDeclaration) + assertEquals("i32", shlDeclaration.type.typeName) + + val shl = shlDeclaration.initializer + assertIs(shl) + assertEquals("<<", shl.operatorCode) + + val lshrStatement = main.bodyOrNull(11) + assertNotNull(lshrStatement) + + val lshrDeclaration = lshrStatement.singleDeclaration + assertIs(lshrDeclaration) + assertLocalName("j", lshrDeclaration) + assertEquals("i32", lshrDeclaration.type.typeName) + + val lshr = lshrDeclaration.initializer + assertIs(lshr) + assertEquals(">>", lshr.operatorCode) + + val ashrStatement = main.bodyOrNull(12) + assertNotNull(ashrStatement) + + val ashrDeclaration = ashrStatement.singleDeclaration + assertIs(ashrDeclaration) + assertLocalName("k", ashrDeclaration) + assertEquals("i32", ashrDeclaration.type.typeName) + + val ashr = ashrDeclaration.initializer + assertIs(ashr) + assertEquals(">>", ashr.operatorCode) + } + + @Test + fun testFloatingpoingOps() { + val topLevel = Path.of("src", "test", "resources", "llvm") + val tu = + analyzeAndGetFirstTU( + listOf(topLevel.resolve("floatingpoint_ops.ll").toFile()), + topLevel, + true + ) { + it.registerLanguage() + } + + assertEquals(2, tu.declarations.size) + + val main = tu.functions["main"] + assertNotNull(main) + assertLocalName("half", main.type) + + val rand = tu.functions["rand"] + assertNotNull(rand) + assertNull(rand.body) + + val xDeclaration = tu.variables["x"] + assertNotNull(xDeclaration) + + val call = xDeclaration.initializer + assertIs(call) + assertLocalName("rand", call) + assertContains(call.invokes, rand) + assertEquals(0, call.arguments.size) + + val fmulStatement = main.bodyOrNull(2) + assertNotNull(fmulStatement) + + val fmulDeclaration = fmulStatement.singleDeclaration + assertIs(fmulDeclaration) + assertLocalName("a", fmulDeclaration) + assertEquals("half", fmulDeclaration.type.typeName) + + val fmul = fmulDeclaration.initializer + assertIs(fmul) + assertEquals("*", fmul.operatorCode) + + val faddStatement = main.bodyOrNull(3) + assertNotNull(faddStatement) + + val faddDeclaration = faddStatement.singleDeclaration + assertIs(faddDeclaration) + assertLocalName("b", faddDeclaration) + assertEquals("half", faddDeclaration.type.typeName) + + val fadd = faddDeclaration.initializer + assertIs(fadd) + assertEquals("+", fadd.operatorCode) + + val fsubStatement = main.bodyOrNull(4) + assertNotNull(fsubStatement) + + val fsubDeclaration = fsubStatement.singleDeclaration + assertIs(fsubDeclaration) + assertLocalName("c", fsubDeclaration) + assertEquals("half", fsubDeclaration.type.typeName) + + val fsub = fsubDeclaration.initializer + assertIs(fsub) + assertEquals("-", fsub.operatorCode) + + val fdivStatement = main.bodyOrNull(5) + assertNotNull(fdivStatement) + + val fdivDeclaration = fdivStatement.singleDeclaration + assertIs(fdivDeclaration) + assertLocalName("d", fdivDeclaration) + assertEquals("half", fdivDeclaration.type.typeName) + + val fdiv = fdivDeclaration.initializer + assertIs(fdiv) + assertEquals("/", fdiv.operatorCode) + + val fremStatement = main.bodyOrNull(6) + assertNotNull(fremStatement) + + val fremDeclaration = fremStatement.singleDeclaration + assertIs(fremDeclaration) + assertLocalName("e", fremDeclaration) + assertEquals("half", fremDeclaration.type.typeName) + + val frem = fremDeclaration.initializer + assertIs(frem) + assertEquals("%", frem.operatorCode) + + val fnegStatement = main.bodyOrNull(7) + assertNotNull(fnegStatement) + + val fnegDeclaration = fnegStatement.singleDeclaration + assertIs(fnegDeclaration) + assertLocalName("f", fnegDeclaration) + assertEquals("half", fnegDeclaration.type.typeName) + + val fneg = fnegDeclaration.initializer + assertIs(fneg) + assertEquals("-", fneg.operatorCode) + } + + @Test + fun testIntegerComparisons() { + val topLevel = Path.of("src", "test", "resources", "llvm") + val tu = + analyzeAndGetFirstTU( + listOf(topLevel.resolve("integer_comparisons.ll").toFile()), + topLevel, + true + ) { + it.registerLanguage() + } + + assertEquals(2, tu.declarations.size) + + val main = tu.functions["main"] + assertNotNull(main) + assertLocalName("i32", main.type) + + val rand = tu.functions["rand"] + assertNotNull(rand) + assertNull(rand.body) + + val xDeclaration = tu.variables["x"] + assertNotNull(xDeclaration) + + val call = xDeclaration.initializer + assertIs(call) + assertLocalName("rand", call) + assertContains(call.invokes, rand) + assertEquals(0, call.arguments.size) + + val cmpEqStatement = main.bodyOrNull(1) + assertNotNull(cmpEqStatement) + + val cmpEqDeclaration = cmpEqStatement.singleDeclaration + assertIs(cmpEqDeclaration) + assertLocalName("a", cmpEqDeclaration) + assertEquals("i1", cmpEqDeclaration.type.typeName) + + val cmpEq = cmpEqDeclaration.initializer + assertIs(cmpEq) + assertEquals("==", cmpEq.operatorCode) + + val cmpNeqStatement = main.bodyOrNull(2) + assertNotNull(cmpNeqStatement) + + val cmpNeqDeclaration = cmpNeqStatement.singleDeclaration + assertIs(cmpNeqDeclaration) + assertLocalName("b", cmpNeqDeclaration) + assertEquals("i1", cmpNeqDeclaration.type.typeName) + + val cmpNeq = cmpNeqDeclaration.initializer + assertIs(cmpNeq) + assertEquals("!=", cmpNeq.operatorCode) + + val cmpUgtStatement = main.bodyOrNull(3) + assertNotNull(cmpUgtStatement) + + val cmpUgtDeclaration = cmpUgtStatement.singleDeclaration + assertIs(cmpUgtDeclaration) + assertLocalName("c", cmpUgtDeclaration) + assertEquals("i1", cmpUgtDeclaration.type.typeName) + + val cmpUgt = cmpUgtDeclaration.initializer + assertIs(cmpUgt) + assertEquals(">", cmpUgt.operatorCode) + + val cmpUltStatement = main.bodyOrNull(4) + assertNotNull(cmpUltStatement) + + val cmpUltDeclaration = cmpUltStatement.singleDeclaration + assertIs(cmpUltDeclaration) + assertLocalName("d", cmpUltDeclaration) + assertEquals("i1", cmpUltDeclaration.type.typeName) + + val cmpUlt = cmpUltDeclaration.initializer + assertIs(cmpUlt) + assertEquals("<", cmpUlt.operatorCode) + + val cmpUgeStatement = main.bodyOrNull(5) + assertNotNull(cmpUgeStatement) + + val cmpUgeDeclaration = cmpUgeStatement.singleDeclaration + assertIs(cmpUgeDeclaration) + assertLocalName("e", cmpUgeDeclaration) + assertEquals("i1", cmpUgeDeclaration.type.typeName) + + val cmpUge = cmpUgeDeclaration.initializer + assertIs(cmpUge) + assertEquals(">=", cmpUge.operatorCode) + + val cmpUleStatement = main.bodyOrNull(6) + assertNotNull(cmpUleStatement) + + val cmpUleDeclaration = cmpUleStatement.singleDeclaration + assertIs(cmpUleDeclaration) + assertLocalName("f", cmpUleDeclaration) + assertEquals("i1", cmpUleDeclaration.type.typeName) + + val cmpUle = cmpUleDeclaration.initializer + assertIs(cmpUle) + assertEquals("<=", cmpUle.operatorCode) + + val cmpSgtStatement = main.bodyOrNull(7) + assertNotNull(cmpSgtStatement) + + val cmpSgtDeclaration = cmpSgtStatement.singleDeclaration + assertIs(cmpSgtDeclaration) + assertLocalName("g", cmpSgtDeclaration) + assertEquals("i1", cmpSgtDeclaration.type.typeName) + + val cmpSgt = cmpSgtDeclaration.initializer + assertIs(cmpSgt) + assertEquals(">", cmpSgt.operatorCode) + + val cmpSltStatement = main.bodyOrNull(8) + assertNotNull(cmpSltStatement) + + val cmpSltDeclaration = cmpSltStatement.singleDeclaration + assertIs(cmpSltDeclaration) + assertLocalName("h", cmpSltDeclaration) + assertEquals("i1", cmpSltDeclaration.type.typeName) + + val cmpSlt = cmpSltDeclaration.initializer + assertIs(cmpSlt) + assertEquals("<", cmpSlt.operatorCode) + + val cmpSgeStatement = main.bodyOrNull(9) + assertNotNull(cmpSgeStatement) + + val cmpSgeDeclaration = cmpSgeStatement.singleDeclaration + assertIs(cmpSgeDeclaration) + assertLocalName("i", cmpSgeDeclaration) + assertEquals("i1", cmpSgeDeclaration.type.typeName) + + val cmpSge = cmpSgeDeclaration.initializer + assertIs(cmpSge) + assertEquals(">=", cmpSge.operatorCode) + + val cmpSleStatement = main.bodyOrNull(10) + assertNotNull(cmpSleStatement) + + val cmpSleDeclaration = cmpSleStatement.singleDeclaration + assertIs(cmpSleDeclaration) + assertLocalName("j", cmpSleDeclaration) + assertEquals("i1", cmpSleDeclaration.type.typeName) + + val cmpSle = cmpSleDeclaration.initializer + assertIs(cmpSle) + assertEquals("<=", cmpSle.operatorCode) + } + + @Test + fun testFloatingpointComparisons() { + val topLevel = Path.of("src", "test", "resources", "llvm") + val tu = + analyzeAndGetFirstTU( + listOf(topLevel.resolve("floatingpoint_comparisons.ll").toFile()), + topLevel, + true + ) { + it.registerLanguage() + } + + // main, rand and inferred dummy function "isunordered" + assertEquals(3, tu.declarations.size) + + val main = tu.functions["main"] + assertNotNull(main) + assertLocalName("half", main.type) + + val rand = tu.functions["rand"] + assertNotNull(rand) + assertNull(rand.body) + + val xDeclaration = tu.variables["x"] + assertNotNull(xDeclaration) + val yDeclaration = tu.variables["y"] + assertNotNull(yDeclaration) + + val call = xDeclaration.initializer + assertIs(call) + assertLocalName("rand", call) + assertContains(call.invokes, rand) + assertEquals(0, call.arguments.size) + + val cmpOeqStatement = main.bodyOrNull(2) + assertNotNull(cmpOeqStatement) + + val cmpOeqDeclaration = cmpOeqStatement.singleDeclaration + assertIs(cmpOeqDeclaration) + assertLocalName("a", cmpOeqDeclaration) + assertEquals("i1", cmpOeqDeclaration.type.typeName) + + val cmpOeq = cmpOeqDeclaration.initializer + assertIs(cmpOeq) + assertEquals("==", cmpOeq.operatorCode) + + val cmpOneStatement = main.bodyOrNull(3) + assertNotNull(cmpOneStatement) + + val cmpOneDeclaration = cmpOneStatement.singleDeclaration + assertIs(cmpOneDeclaration) + assertLocalName("b", cmpOneDeclaration) + assertEquals("i1", cmpOneDeclaration.type.typeName) + + val cmpOne = cmpOneDeclaration.initializer + assertIs(cmpOne) + assertEquals("!=", cmpOne.operatorCode) + + val cmpOgtStatement = main.bodyOrNull(4) + assertNotNull(cmpOgtStatement) + + val cmpOgtDeclaration = cmpOgtStatement.singleDeclaration + assertIs(cmpOgtDeclaration) + assertLocalName("c", cmpOgtDeclaration) + assertEquals("i1", cmpOgtDeclaration.type.typeName) + + val cmpOgt = cmpOgtDeclaration.initializer + assertIs(cmpOgt) + assertEquals(">", cmpOgt.operatorCode) + + val cmpOltStatement = main.bodyOrNull(5) + assertNotNull(cmpOltStatement) + + val cmpOltDeclaration = cmpOltStatement.singleDeclaration + assertIs(cmpOltDeclaration) + assertLocalName("d", cmpOltDeclaration) + assertEquals("i1", cmpOltDeclaration.type.typeName) + + val cmpOlt = cmpOltDeclaration.initializer + assertIs(cmpOlt) + assertEquals("<", cmpOlt.operatorCode) + + val cmpOgeStatement = main.bodyOrNull(6) + assertNotNull(cmpOgeStatement) + + val cmpOgeDeclaration = cmpOgeStatement.singleDeclaration + assertIs(cmpOgeDeclaration) + assertLocalName("e", cmpOgeDeclaration) + assertEquals("i1", cmpOgeDeclaration.type.typeName) + + val cmpOge = cmpOgeDeclaration.initializer + assertIs(cmpOge) + assertEquals(">=", cmpOge.operatorCode) + + val cmpOleStatement = main.bodyOrNull(7) + assertNotNull(cmpOleStatement) + + val cmpOleDeclaration = cmpOleStatement.singleDeclaration + assertIs(cmpOleDeclaration) + assertLocalName("f", cmpOleDeclaration) + assertEquals("i1", cmpOleDeclaration.type.typeName) + + val cmpOle = cmpOleDeclaration.initializer + assertIs(cmpOle) + assertEquals("<=", cmpOle.operatorCode) + + val cmpUgtStatement = main.bodyOrNull(8) + assertNotNull(cmpUgtStatement) + + val cmpUgtDeclaration = cmpUgtStatement.singleDeclaration + assertIs(cmpUgtDeclaration) + assertLocalName("g", cmpUgtDeclaration) + assertEquals("i1", cmpUgtDeclaration.type.typeName) + + val cmpUgtOr = cmpUgtDeclaration.initializer + assertIs(cmpUgtOr) + assertEquals("||", cmpUgtOr.operatorCode) + val cmpUgt = cmpUgtOr.rhs + assertIs(cmpUgt) + assertEquals(">", cmpUgt.operatorCode) + + val cmpUltStatement = main.bodyOrNull(9) + assertNotNull(cmpUltStatement) + + val cmpUltDeclaration = cmpUltStatement.singleDeclaration + assertIs(cmpUltDeclaration) + assertLocalName("h", cmpUltDeclaration) + assertEquals("i1", cmpUltDeclaration.type.typeName) + + val cmpUltOr = cmpUltDeclaration.initializer + assertIs(cmpUltOr) + assertEquals("||", cmpUltOr.operatorCode) + val cmpUlt = cmpUltOr.rhs + assertIs(cmpUlt) + assertEquals("<", cmpUlt.operatorCode) + + val cmpUgeStatement = main.bodyOrNull(10) + assertNotNull(cmpUgeStatement) + + val cmpUgeDeclaration = cmpUgeStatement.singleDeclaration + assertIs(cmpUgeDeclaration) + assertLocalName("i", cmpUgeDeclaration) + assertEquals("i1", cmpUgeDeclaration.type.typeName) + + val cmpUgeOr = cmpUgeDeclaration.initializer + assertIs(cmpUgeOr) + assertEquals("||", cmpUgeOr.operatorCode) + val cmpUge = cmpUgeOr.rhs + assertIs(cmpUge) + assertEquals(">=", cmpUge.operatorCode) + + val cmpUleStatement = main.bodyOrNull(11) + assertNotNull(cmpUleStatement) + + val cmpUleDeclaration = cmpUleStatement.singleDeclaration + assertIs(cmpUleDeclaration) + assertLocalName("j", cmpUleDeclaration) + assertEquals("i1", cmpUleDeclaration.type.typeName) + + val cmpUleOr = cmpUleDeclaration.initializer + assertIs(cmpUleOr) + assertEquals("||", cmpUleOr.operatorCode) + val cmpUle = cmpUleOr.rhs + assertIs(cmpUle) + assertEquals("<=", cmpUle.operatorCode) + + val cmpUeqStatement = main.bodyOrNull(12) + assertNotNull(cmpUeqStatement) + + val cmpUeqDeclaration = cmpUeqStatement.singleDeclaration + assertIs(cmpUeqDeclaration) + assertLocalName("k", cmpUeqDeclaration) + assertEquals("i1", cmpUeqDeclaration.type.typeName) + + val cmpUeqOr = cmpUeqDeclaration.initializer + assertIs(cmpUeqOr) + assertEquals("||", cmpUeqOr.operatorCode) + val cmpUeq = cmpUeqOr.rhs + assertIs(cmpUeq) + assertEquals("==", cmpUeq.operatorCode) + + val cmpUneStatement = main.bodyOrNull(13) + assertNotNull(cmpUneStatement) + + val cmpUneDeclaration = cmpUneStatement.singleDeclaration + assertIs(cmpUneDeclaration) + assertLocalName("l", cmpUneDeclaration) + assertEquals("i1", cmpUneDeclaration.type.typeName) + + val cmpUneOr = cmpUneDeclaration.initializer + assertIs(cmpUneOr) + assertEquals("||", cmpUneOr.operatorCode) + val cmpUne = cmpUneOr.rhs + assertIs(cmpUne) + assertEquals("!=", cmpUne.operatorCode) + + val cmpOrdStatement = main.bodyOrNull(14) + assertNotNull(cmpOrdStatement) + + val cmpOrdDeclaration = cmpOrdStatement.singleDeclaration + assertIs(cmpOrdDeclaration) + assertLocalName("m", cmpOrdDeclaration) + assertEquals("i1", cmpOrdDeclaration.type.typeName) + + val cmpOrdNeg = cmpOrdDeclaration.initializer + assertIs(cmpOrdNeg) + assertEquals("!", cmpOrdNeg.operatorCode) + val cmpOrd = cmpOrdNeg.input + assertIs(cmpOrd) + assertLocalName("isunordered", cmpOrd) + assertRefersTo(cmpOrd.arguments[0], xDeclaration) + assertRefersTo(cmpOrd.arguments[1], yDeclaration) + + val cmpUnoStatement = main.bodyOrNull(15) + assertNotNull(cmpUnoStatement) + + val cmpUnoDeclaration = cmpUnoStatement.singleDeclaration + assertIs(cmpUnoDeclaration) + assertLocalName("n", cmpUnoDeclaration) + assertEquals("i1", cmpUnoDeclaration.type.typeName) + + val cmpUno = cmpUnoDeclaration.initializer + assertIs(cmpUno) + assertLocalName("isunordered", cmpUno) + assertRefersTo(cmpUno.arguments[0], xDeclaration) + assertRefersTo(cmpUno.arguments[1], yDeclaration) + } + + @Test + fun testFreeze() { + val topLevel = Path.of("src", "test", "resources", "llvm") + val tu = + analyzeAndGetFirstTU(listOf(topLevel.resolve("freeze.ll").toFile()), topLevel, true) { + it.registerLanguage() + } + + val main = tu.functions["main"] + assertNotNull(main) + + val mainBody = main.body + assertIs(mainBody) + val wDeclaration = main.variables["w"] + assertNotNull(wDeclaration) + + val freezeInstructionDeclaration = mainBody.statements[3] + // We expect something like this: x = (w != undef && w != poison) ? w : llvm.freeze(w) + assertIs(freezeInstructionDeclaration) + val xDeclaration = freezeInstructionDeclaration.singleDeclaration + assertIs(xDeclaration) + assertLocalName("x", xDeclaration) + assertEquals("i32", xDeclaration.type.typeName) + + val freezeInstruction = xDeclaration.initializer + assertIs(freezeInstruction) + val condition = freezeInstruction.condition + assertIs(condition) + assertEquals("&&", condition.operatorCode) + + val undefCheck = condition.lhs + assertIs(undefCheck) + assertEquals("!=", undefCheck.operatorCode) + assertRefersTo(undefCheck.lhs, wDeclaration) + // undef is modeled as null + assertLiteralValue(null, undefCheck.rhs) + + val poisonCheck = condition.rhs + assertIs(poisonCheck) + assertEquals("!=", poisonCheck.operatorCode) + assertRefersTo(poisonCheck.lhs, wDeclaration) + // poison is modeled as a reference "poison" + assertLocalName("poison", poisonCheck.rhs) + + assertRefersTo(freezeInstruction.thenExpression, wDeclaration) + + val elseExpression = freezeInstruction.elseExpression + assertIs(elseExpression) + assertFullName("llvm.freeze", elseExpression) + assertEquals(1, elseExpression.arguments.size) + assertRefersTo(elseExpression.arguments.firstOrNull(), wDeclaration) + } + + @Test + fun testAtomicrmw() { + val topLevel = Path.of("src", "test", "resources", "llvm") + val tu = + analyzeAndGetFirstTU( + listOf(topLevel.resolve("atomicrmw.ll").toFile()), + topLevel, + true + ) { + it.registerLanguage() + } + + val foo = tu.functions["foo"] + assertNotNull(foo) + + val fooBody = foo.body + assertIs(fooBody) + + val atomicrmwAddStatement = fooBody[0] + assertIs(atomicrmwAddStatement) + checkAtomicRmwBinaryOpReplacement(atomicrmwAddStatement, "+", "old1") + + val atomicrmwSubStatement = fooBody[1] + assertIs(atomicrmwSubStatement) + checkAtomicRmwBinaryOpReplacement(atomicrmwSubStatement, "-", "old2") + + val atomicrmwAndStatement = fooBody[2] + assertIs(atomicrmwAndStatement) + checkAtomicRmwBinaryOpReplacement(atomicrmwAndStatement, "&", "old3") + + val atomicrmwOrStatement = fooBody[3] + assertIs(atomicrmwOrStatement) + checkAtomicRmwBinaryOpReplacement(atomicrmwOrStatement, "|", "old4") + + val atomicrmwXorStatement = fooBody[4] + assertIs(atomicrmwXorStatement) + checkAtomicRmwBinaryOpReplacement(atomicrmwXorStatement, "^", "old5") + + // This one is not wrapped in a block and does not have the declaration statement! + // Check that the replacement equals *ptr = ~(*ptr | 1) + val replacementNand = fooBody[5] + assertIs(replacementNand) + assertEquals(1, replacementNand.lhs.size) + assertEquals(1, replacementNand.rhs.size) + assertEquals("=", replacementNand.operatorCode) + val replacementNandLhs = replacementNand.lhs.first() + assertIs(replacementNandLhs) + assertEquals("*", replacementNandLhs.operatorCode) + assertLocalName("ptr", replacementNandLhs.input) + // Check that the rhs is equal to ~(*ptr | 1) + val unaryOp = replacementNand.rhs.first() + assertIs(unaryOp) + assertEquals("~", unaryOp.operatorCode) + val binOp = unaryOp.input + assertIs(binOp) + assertEquals("|", binOp.operatorCode) + val binOpLhs = binOp.lhs + assertIs(binOpLhs) + assertEquals("*", binOpLhs.operatorCode) + assertLocalName("ptr", binOpLhs.input) + assertLiteralValue(1L, binOp.rhs) + + val atomicrmwMinStatement = fooBody[6] + assertIs(atomicrmwMinStatement) + checkAtomicRmwMinMax(atomicrmwMinStatement, "<", "old7", false) + + val atomicrmwMaxStatement = fooBody[7] + assertIs(atomicrmwMaxStatement) + checkAtomicRmwMinMax(atomicrmwMaxStatement, ">", "old8", false) + + val atomicrmwUminStatement = fooBody[8] + assertIs(atomicrmwUminStatement) + checkAtomicRmwMinMax(atomicrmwUminStatement, "<", "old9", true) + + val atomicrmwUmaxStatement = fooBody[9] + assertIs(atomicrmwUmaxStatement) + checkAtomicRmwMinMax(atomicrmwUmaxStatement, ">", "old10", true) + } + + // We expect *ptr = (*ptr 1) ? *ptr : 1 + private fun checkAtomicRmwMinMax( + atomicrmwStatement: Block, + cmp: String, + variableName: String, + requiresUintCast: Boolean + ) { + // Check that the value is assigned to + val declaration = atomicrmwStatement.statements[0].declarations[0] + assertIs(declaration) + assertLocalName(variableName, declaration) + assertLocalName("i32", declaration.type) + val initializer = declaration.initializer + assertIs(initializer) + assertEquals("*", initializer.operatorCode) + assertLocalName("ptr", initializer.input) + + // Check that the replacement equals *ptr = (*ptr 1) ? *ptr : 1 + val replacement = atomicrmwStatement.statements[1] + assertIs(replacement) + assertEquals(1, replacement.lhs.size) + assertEquals(1, replacement.rhs.size) + assertEquals("=", replacement.operatorCode) + val replacementLhs = replacement.lhs.first() + assertIs(replacementLhs) + assertEquals("*", replacementLhs.operatorCode) + assertLocalName("ptr", replacementLhs.input) + + // Check that the rhs is equal to (*ptr 1) ? *ptr : 1 + val conditionalExpression = replacement.rhs.first() + assertIs(conditionalExpression) + val condition = conditionalExpression.condition + assertIs(condition) + assertEquals(cmp, condition.operatorCode) + var cmpLhs = condition.lhs + if (requiresUintCast) { + assertIs(cmpLhs) + assertEquals("ui32", cmpLhs.castType.typeName) + cmpLhs = cmpLhs.expression + } + assertIs(cmpLhs) + assertEquals("*", cmpLhs.operatorCode) + assertLocalName("ptr", cmpLhs.input) + var cmpRhs = condition.rhs + if (requiresUintCast) { + assertIs(cmpRhs) + assertEquals("ui32", cmpRhs.castType.typeName) + cmpRhs = cmpRhs.expression + } + assertLiteralValue(1L, cmpRhs) + val thenExpression = conditionalExpression.thenExpression + assertIs(thenExpression) + assertEquals("*", thenExpression.operatorCode) + assertLocalName("ptr", thenExpression.input) + assertLiteralValue(1L, conditionalExpression.elseExpression) + } + + private fun checkAtomicRmwBinaryOpReplacement( + atomicrmwStatement: Block, + operator: String, + variableName: String + ) { + // Check that the value is assigned to + val declaration = atomicrmwStatement.statements[0].declarations[0] + assertIs(declaration) + assertLocalName(variableName, declaration) + assertLocalName("i32", declaration.type) + val initializer = declaration.initializer + assertIs(initializer) + assertEquals("*", initializer.operatorCode) + assertLocalName("ptr", initializer.input) + + // Check that the replacement equals *ptr = *ptr 1 + val replacement = atomicrmwStatement.statements[1] + + assertIs(replacement) + assertEquals(1, replacement.lhs.size) + assertEquals(1, replacement.rhs.size) + assertEquals("=", replacement.operatorCode) + val replacementLhs = replacement.lhs.first() + assertIs(replacementLhs) + assertEquals("*", replacementLhs.operatorCode) + assertLocalName("ptr", replacementLhs.input) + // Check that the rhs is equal to *ptr + 1 + val binOp = replacement.rhs.first() + assertIs(binOp) + assertEquals(operator, binOp.operatorCode) + val binOpLhs = binOp.lhs + assertIs(binOpLhs) + assertEquals("*", binOpLhs.operatorCode) + assertLocalName("ptr", binOpLhs.input) + assertLiteralValue(1L, binOp.rhs) + } + + @Test + fun testCallBr() { + val topLevel = Path.of("src", "test", "resources", "llvm") + val tu = + analyzeAndGetFirstTU(listOf(topLevel.resolve("callbr.ll").toFile()), topLevel, true) { + it.registerLanguage() + } + + val main = tu.functions["main"] + assertNotNull(main) + + val mainBody = main.body + assertIs(mainBody) + val callBrInstruction = mainBody.statements[3] + assertIs(callBrInstruction) + } + + @Test + fun testIndirectBr() { + val topLevel = Path.of("src", "test", "resources", "llvm") + val tu = + analyzeAndGetFirstTU( + listOf(topLevel.resolve("indirectbr.ll").toFile()), + topLevel, + false + ) { + it.registerLanguage() + } + + val foo = tu.functions["foo"] + assertNotNull(foo) + + val fooBody = foo.body + assertIs(fooBody) + val indirectbrInstruction = fooBody.statements[0] + assertIs(indirectbrInstruction) + assertRefersTo(indirectbrInstruction.selector, foo.parameters.single()) + val jumps = indirectbrInstruction.statement + assertIs(jumps) + val caseBB1 = jumps.statements[0] + assertIs(caseBB1) + assertIs>(caseBB1.caseExpression) + val jumpBB1 = jumps.statements[1] + assertIs(jumpBB1) + assertEquals("bb1", jumpBB1.targetLabel?.label) + assertIs(jumps.statements[2]) + + val caseBB2 = jumps.statements[3] + assertIs(caseBB2) + assertIs>(caseBB2.caseExpression) + val jumpBB2 = jumps.statements[4] + assertIs(jumpBB2) + assertEquals("bb2", jumpBB2.targetLabel?.label) + assertIs(jumps.statements[5]) + } +} diff --git a/cpg-language-llvm/src/test/resources/llvm/atomicrmw.ll b/cpg-language-llvm/src/test/resources/llvm/atomicrmw.ll index f73d2da3c8..22343cc67b 100644 --- a/cpg-language-llvm/src/test/resources/llvm/atomicrmw.ll +++ b/cpg-language-llvm/src/test/resources/llvm/atomicrmw.ll @@ -1,9 +1,18 @@ define i32 @foo(i32* %ptr) nounwind uwtable readnone optsize ssp { - %old = atomicrmw add i32* %ptr, i32 1 acquire + %old1 = atomicrmw add i32* %ptr, i32 1 acquire + %old2 = atomicrmw sub i32* %ptr, i32 1 acquire + %old3 = atomicrmw and i32* %ptr, i32 1 acquire + %old4 = atomicrmw or i32* %ptr, i32 1 acquire + %old5 = atomicrmw xor i32* %ptr, i32 1 acquire + atomicrmw nand i32* %ptr, i32 1 acquire + %old7 = atomicrmw min i32* %ptr, i32 1 acquire + %old8 = atomicrmw max i32* %ptr, i32 1 acquire + %old9 = atomicrmw umin i32* %ptr, i32 1 acquire + %old10 = atomicrmw umax i32* %ptr, i32 1 acquire - %val_success = cmpxchg i32* %ptr, i32 5, i32 %old acq_rel monotonic ; yields { i32, i1 } + %val_success = cmpxchg i32* %ptr, i32 5, i32 %old1 acq_rel monotonic ; yields { i32, i1 } %value_loaded = extractvalue { i32, i1 } %val_success, 1 - ret i32 %old + ret i32 %old1 } \ No newline at end of file diff --git a/cpg-language-llvm/src/test/resources/llvm/callbr.ll b/cpg-language-llvm/src/test/resources/llvm/callbr.ll new file mode 100644 index 0000000000..b1ddc8c630 --- /dev/null +++ b/cpg-language-llvm/src/test/resources/llvm/callbr.ll @@ -0,0 +1,16 @@ + +define i32 @main() { ; i32()* + %y = alloca i32 + store i32 3, i32* %y + %x = load i32, i32* %y + ; "asm goto" without output constraints. + callbr void asm "", "r,!i"(i32 %x) + to label %fallthrough [label %indirect] + +fallthrough: ; This is the fallthrough target + %b = add i32 %x, 5 + ret i32 %b + +indirect: + ret i32 %x +} \ No newline at end of file diff --git a/cpg-language-llvm/src/test/resources/llvm/floatingpoint_comparisons.ll b/cpg-language-llvm/src/test/resources/llvm/floatingpoint_comparisons.ll new file mode 100644 index 0000000000..597aa67276 --- /dev/null +++ b/cpg-language-llvm/src/test/resources/llvm/floatingpoint_comparisons.ll @@ -0,0 +1,26 @@ +; External declaration of the rand function +declare half @rand() nounwind + +; Definition of main function +define half @main() { ; i32()* + %x = call half @rand() + %y = call half @rand() + + %a = fcmp oeq half %x, %y + %b = fcmp one half %x, %y + %c = fcmp ogt half %x, %y + %d = fcmp olt half %x, %y + %e = fcmp oge half %x, %y + %f = fcmp ole half %x, %y + %g = fcmp ugt half %x, %y + %h = fcmp ult half %x, %y + %i = fcmp uge half %x, %y + %j = fcmp ule half %x, %y + %k = fcmp ueq half %x, %y + %l = fcmp une half %x, %y + %m = fcmp ord half %x, %y + %n = fcmp uno half %x, %y + + ret half %x +} + diff --git a/cpg-language-llvm/src/test/resources/llvm/floatingpoint_const.ll b/cpg-language-llvm/src/test/resources/llvm/floatingpoint_const.ll new file mode 100644 index 0000000000..17efc91625 --- /dev/null +++ b/cpg-language-llvm/src/test/resources/llvm/floatingpoint_const.ll @@ -0,0 +1,9 @@ +; declare i32 constant @x +@x = private constant half 1.25 + +; Definition of main function +define half @main() { ; half()* + %a = fadd half 1.25, 1.0 + + ret half %a +} diff --git a/cpg-language-llvm/src/test/resources/llvm/floatingpoint_ops.ll b/cpg-language-llvm/src/test/resources/llvm/floatingpoint_ops.ll new file mode 100644 index 0000000000..dcfe553f4b --- /dev/null +++ b/cpg-language-llvm/src/test/resources/llvm/floatingpoint_ops.ll @@ -0,0 +1,16 @@ +; External declaration of the rand function +declare half @rand() nounwind + +; Definition of main function +define half @main() { ; half()* + %x = call half @rand() + %y = call half @rand() + %a = fmul half %y, %x + %b = fadd half %a, %x + %c = fsub half %a, %b + %d = fdiv half %a, %x + %e = frem half %a, %x + %f = fneg half %e + + ret half %b +} \ No newline at end of file diff --git a/cpg-language-llvm/src/test/resources/llvm/freeze.ll b/cpg-language-llvm/src/test/resources/llvm/freeze.ll new file mode 100644 index 0000000000..39d90133ca --- /dev/null +++ b/cpg-language-llvm/src/test/resources/llvm/freeze.ll @@ -0,0 +1,7 @@ +define i32 @main() { ; i32()* + %ptr = alloca i32 + store i32 undef, i32* %ptr + %w = load i32, i32* %ptr + %x = freeze i32 %w + ret i32 %x +} \ No newline at end of file diff --git a/cpg-language-llvm/src/test/resources/llvm/indirectbr.ll b/cpg-language-llvm/src/test/resources/llvm/indirectbr.ll new file mode 100644 index 0000000000..07218a69fc --- /dev/null +++ b/cpg-language-llvm/src/test/resources/llvm/indirectbr.ll @@ -0,0 +1,9 @@ +define i32 @foo(i32* %addr) nounwind uwtable readnone optsize ssp { + indirectbr i32* %addr, [ label %bb1, label %bb2 ] + +bb1: + ret i32 1 + +bb2: + ret i32 2 +} \ No newline at end of file diff --git a/cpg-language-llvm/src/test/resources/llvm/integer_comparisons.ll b/cpg-language-llvm/src/test/resources/llvm/integer_comparisons.ll new file mode 100644 index 0000000000..2f0205b7d3 --- /dev/null +++ b/cpg-language-llvm/src/test/resources/llvm/integer_comparisons.ll @@ -0,0 +1,21 @@ +; External declaration of the rand function +declare i32 @rand() nounwind + +; Definition of main function +define i32 @main() { ; i32()* + %x = call i32 @rand() + + %a = icmp eq i32 %x, 10 + %b = icmp ne i32 %x, 10 + %c = icmp ugt i32 %x, 10 + %d = icmp ult i32 %x, 10 + %e = icmp uge i32 %x, 10 + %f = icmp ule i32 %x, 10 + %g = icmp sgt i32 %x, 10 + %h = icmp slt i32 %x, 10 + %i = icmp sge i32 %x, 10 + %j = icmp sle i32 %x, 10 + + ret i32 %x +} + diff --git a/cpg-language-llvm/src/test/resources/llvm/integer_const.ll b/cpg-language-llvm/src/test/resources/llvm/integer_const.ll new file mode 100644 index 0000000000..dbd8d2554e --- /dev/null +++ b/cpg-language-llvm/src/test/resources/llvm/integer_const.ll @@ -0,0 +1,25 @@ +; External declaration of the rand function +declare i32 @rand() nounwind + +; External declaration of the foo function +declare i32 @foo(i32) nounwind + +; External declaration of the foo function +declare i32 @foo1(i1) nounwind + +; declare i32 constant @x +@x = private constant i32 5 + +; Definition of main function +define i32 @main() { ; i32()* + %a = call i32 @foo(i32 add(i32 ptrtoint (i32* @x to i32), i32 5)) + %b = call i32 @foo(i32 sub(i32 ptrtoint (i32* @x to i32), i32 5)) + %c = call i32 @foo(i32 mul(i32 ptrtoint (i32* @x to i32), i32 5)) + %d = call i32 @foo(i32 shl(i32 ptrtoint (i32* @x to i32), i32 5)) + %e = call i32 @foo(i32 lshr(i32 ptrtoint (i32* @x to i32), i32 5)) + %f = call i32 @foo(i32 xor(i32 ptrtoint (i32* @x to i32), i32 5)) + %g = call i32 @foo1(i1 icmp eq (i32 ptrtoint (i32* @x to i32), i32 5)) + + ret i32 %a +} + diff --git a/cpg-language-llvm/src/test/resources/llvm/integer_ops.ll b/cpg-language-llvm/src/test/resources/llvm/integer_ops.ll index bd006a5327..5c0ba9a4bd 100644 --- a/cpg-language-llvm/src/test/resources/llvm/integer_ops.ll +++ b/cpg-language-llvm/src/test/resources/llvm/integer_ops.ll @@ -5,10 +5,17 @@ declare i32 @rand() nounwind define i32 @main() { ; i32()* %x = call i32 @rand() %y = call i32 @rand() - %z = mul i32 %y, 32768 - %a = xor i32 %z, %x + %a = mul i32 %y, 32768 %b = add i32 %a, 5 + %c = sub i32 %a, %b + %d = sdiv i32 %a, %x + %e = srem i32 %a, %x + %f = xor i32 %a, %x + %g = udiv i32 %a, %x + %h = urem i32 %a, %x + %i = shl i32 %a, %x + %j = lshr i32 %a, %x + %k = ashr i32 %a, %x ret i32 %b } - diff --git a/cpg-language-llvm/src/test/resources/llvm/main-broken.ll b/cpg-language-llvm/src/test/resources/llvm/main-broken.ll new file mode 100644 index 0000000000..57cb2124aa --- /dev/null +++ b/cpg-language-llvm/src/test/resources/llvm/main-broken.ll @@ -0,0 +1,19 @@ +; Declare the string constant as a global constant. +@.str = prvate unnamed_addr constant [13 x i8] c"hello world\0A\00" + +; External declaration of the puts function +declare i32 @puts(i8* nocapture) nounwind + +; Definition of main function +define i32 @main() { ; i32()* + ; Convert [13 x i8]* to i8*... + %cast210 = getelementptr [13 x i8], [13 x i8]* @.str, i64 0, i64 0 + + ; Call puts function to write out the string to stdout. + call i32 @puts(i8* %cast210) + ret i32 0 +} + +; Named metadata +!0 = !{i32 42, null, !"string"} +!foo = !{!0}