Skip to content

Commit

Permalink
Fixes to type propagation of arithmetic expressions (#1449)
Browse files Browse the repository at this point in the history
* Enable type propagation of modulo BinaryOperators

* Enable type propagation of short circuit BinaryOperators, parsing of float literals, add test

* Enable type propagation of logical shift BinaryOperators, adapt test

* Use Kotlin extension API

* Replace FQN by import

* Make use of LanguageProvider

* Use with block
  • Loading branch information
robinmaisch authored Mar 8, 2024
1 parent 18060e4 commit 2da3c9e
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ abstract class Language<T : LanguageFrontend<*, *>> : Node() {
} else {
rhs
}
lhs is BooleanType && rhs is BooleanType -> lhs
else -> unknownType()
}
}
Expand All @@ -147,14 +148,13 @@ abstract class Language<T : LanguageFrontend<*, *>> : Node() {
* programming languages.
*/
open fun propagateTypeOfBinaryOperation(operation: BinaryOperator): Type {
if (operation.operatorCode == "==" || operation.operatorCode == "===") {
// A comparison, so we return the type "boolean"
return this.builtInTypes.values.firstOrNull { it is BooleanType }
?: this.builtInTypes.values.firstOrNull { it.name.localName.startsWith("bool") }
?: unknownType()
}

return when (operation.operatorCode) {
"==",
"===" ->
// A comparison, so we return the type "boolean"
this.builtInTypes.values.firstOrNull { it is BooleanType }
?: this.builtInTypes.values.firstOrNull { it.name.localName.startsWith("bool") }
?: unknownType()
"+" ->
if (operation.lhs.type is StringType) {
// string + anything => string
Expand All @@ -167,12 +167,16 @@ abstract class Language<T : LanguageFrontend<*, *>> : Node() {
}
"-",
"*",
"/" -> arithmeticOpTypePropagation(operation.lhs.type, operation.rhs.type)
"/",
"%",
"&",
"&&",
"|",
"^",
"||",
"^" -> arithmeticOpTypePropagation(operation.lhs.type, operation.rhs.type)
"<<",
">>" ->
">>",
">>>" ->
if (operation.lhs.type.isPrimitive && operation.rhs.type.isPrimitive) {
// primitive type 1 OP primitive type 2 => primitive type 1
operation.lhs.type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ import com.github.javaparser.resolution.declarations.ResolvedMethodDeclaration
import de.fraunhofer.aisec.cpg.frontends.Handler
import de.fraunhofer.aisec.cpg.frontends.HandlerInterface
import de.fraunhofer.aisec.cpg.graph.*
import de.fraunhofer.aisec.cpg.graph.declarations.MethodDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.RecordDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.*
import de.fraunhofer.aisec.cpg.graph.statements.*
import de.fraunhofer.aisec.cpg.graph.statements.expressions.*
import de.fraunhofer.aisec.cpg.graph.statements.expressions.AssignExpression
Expand Down Expand Up @@ -462,7 +461,8 @@ class ExpressionHandler(lang: JavaLanguageFrontend) :
is DoubleLiteralExpr ->
newLiteral(
literalExpr.asDoubleLiteralExpr().asDouble(),
this.primitiveType("double"),
if (literalExpr.value.endsWith("f", true)) this.primitiveType("float")
else this.primitiveType("double"),
rawNode = expr
)
is LongLiteralExpr ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -619,10 +619,7 @@ internal class JavaLanguageFrontendTest : BaseTest() {
it.registerLanguage(JavaLanguage())
}
val tu =
findByUniqueName(
result.components.flatMap { it.translationUnits },
"src/test/resources/fix-328/Cat.java"
)
findByUniqueName(result.components.flatMap { it.translationUnits }, file1.toString())
val namespace = tu.getDeclarationAs(0, NamespaceDeclaration::class.java)
assertNotNull(namespace)

Expand All @@ -649,7 +646,7 @@ internal class JavaLanguageFrontendTest : BaseTest() {
this.declarationHandler =
object : DeclarationHandler(this@MyJavaLanguageFrontend) {
override fun handleClassOrInterfaceDeclaration(
classInterDecl: ClassOrInterfaceDeclaration
classInterDecl: ClassOrInterfaceDeclaration,
): RecordDeclaration {
// take the original class and replace the name
val declaration =
Expand Down Expand Up @@ -800,4 +797,52 @@ internal class JavaLanguageFrontendTest : BaseTest() {
assertNotNull(jArg)
assertContains(jArg.prevDFG, loopVariable)
}

@Test
fun testArithmeticOperators() {
val file = File("src/test/resources/Issue1444.java")

val result =
TestUtils.analyze(listOf(file), file.parentFile.toPath(), true) {
it.registerLanguage(JavaLanguage())
}
val record = result.records["Operators"]
assertNotNull(record)
assertFalse { record.methods.isEmpty() }

val mainMethod = record.methods["main"]

val expressionLists = mainMethod.mcalls
assertEquals(6, expressionLists.size)

assertNotNull(mainMethod)

with(mainMethod) {
val intOperationsList = expressionLists[0]
assertEquals(14, intOperationsList.arguments.size)
assertTrue { intOperationsList.arguments.all { it.type == primitiveType("int") } }

val longOperationsList = expressionLists[1]
assertEquals(14, longOperationsList.arguments.size)
assertTrue { longOperationsList.arguments.all { it.type == primitiveType("long") } }

val floatOperationsList = expressionLists[2]
assertEquals(7, floatOperationsList.arguments.size)
assertTrue { floatOperationsList.arguments.all { it.type == primitiveType("float") } }

val doubleOperationsList = expressionLists[3]
assertEquals(7, doubleOperationsList.arguments.size)
assertTrue { doubleOperationsList.arguments.all { it.type == primitiveType("double") } }

val booleanOperationsList = expressionLists[4]
assertEquals(6, booleanOperationsList.arguments.size)
assertTrue {
booleanOperationsList.arguments.all { it.type == primitiveType("boolean") }
}

val stringOperationsList = expressionLists[5]
assertEquals(6, stringOperationsList.arguments.size)
assertTrue { stringOperationsList.arguments.all { it.type == primitiveType("String") } }
}
}
}
83 changes: 83 additions & 0 deletions cpg-language-java/src/test/resources/Issue1444.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
public class Operators {

public static void main(String[] args) {
// results should be type IntegerType("int")
List.of(
1 + 2,
3 - 4,
5 * 6,
7 / 8,
9 % 10,
11 << 12,
13 >> 14,
14 >>> 14,
15 ^ 16,
17 & 18,
19 | 20,
+21,
-22,
~23
);

// results should be type IntegerType("long")
List.of(
1L + 2,
3 - 4L,
5L * 6,
7 / 8L,
9L % 10,
11L << 12,
13L >> 14,
14L >>> 14,
15 ^ 16L,
17L & 18,
19 | 20L,
+21L,
-22L,
~23L
);

// results should be type FloatingPointType("float")
List.of(
1.f + 2,
3 - 4.f,
5.f * 6,
7 / 8.f,
9.f % 10,
+21.f,
-22.f
);

// results should be type FloatingPointType("long")
List.of(
1.f + 2.d,
3 - 4.d,
5.d * 6.f,
7.d / 8.f,
9.f % 10.d,
+21.d,
-22.d
);

// results should be type BooleanType
List.of(
true && false,
true & true,
false || true,
true | true,
false ^ true,
!false
);

// result should be type StringType
List.of(
"1" + 2,
3 + "4" ,
"5" + true,
'7' + "8",
"9" + null,
new ArrayList<Object>() + "12"
);
}

}

0 comments on commit 2da3c9e

Please sign in to comment.