From 2b0474a0fcf70b3c0c04f1bcd8aa02556d90f898 Mon Sep 17 00:00:00 2001 From: KuechA <31155350+KuechA@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:13:04 +0100 Subject: [PATCH] Start with python `match` statement (#1801) * Start with python match statement * fix bug, add test * More testing * Add implicit break * Review feedback * nullable MatchSingleton: comment and handling --- .../cpg/graph/statements/SwitchStatement.kt | 4 +- .../cpg/frontends/python/ExpressionHandler.kt | 41 ++- .../aisec/cpg/frontends/python/Python.kt | 7 +- .../cpg/frontends/python/StatementHandler.kt | 123 +++++++- .../python/statementHandler/MatchTest.kt | 295 ++++++++++++++++++ .../src/test/resources/python/match.py | 54 ++++ 6 files changed, 498 insertions(+), 26 deletions(-) create mode 100644 cpg-language-python/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/python/statementHandler/MatchTest.kt create mode 100644 cpg-language-python/src/test/resources/python/match.py diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/SwitchStatement.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/SwitchStatement.kt index e677cc8911..9d955a5057 100644 --- a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/SwitchStatement.kt +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/SwitchStatement.kt @@ -36,7 +36,7 @@ import org.neo4j.ogm.annotation.Relationship /** * Represents a Java or C++ switch statement of the `switch (selector) {...}` that can include case - * and default statements. Break statements break out of the switch and labeled breaks in JAva are + * and default statements. Break statements break out of the switch and labeled breaks in Java are * handled properly. */ class SwitchStatement : Statement(), BranchingNode { @@ -51,7 +51,7 @@ class SwitchStatement : Statement(), BranchingNode { @Relationship(value = "SELECTOR_DECLARATION") var selectorDeclarationEdge = astOptionalEdgeOf() - /** C++ allows to use a declaration instead of a expression as selector */ + /** C++ allows to use a declaration instead of an expression as selector */ var selectorDeclaration by unwrapping(SwitchStatement::selectorDeclarationEdge) @Relationship(value = "STATEMENT") var statementEdge = astOptionalEdgeOf() diff --git a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandler.kt b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandler.kt index ac844775a2..62778bc208 100644 --- a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandler.kt +++ b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandler.kt @@ -237,19 +237,24 @@ class ExpressionHandler(frontend: PythonLanguageFrontend) : * where the first element in [nodes] is the lhs of the root of the tree of binary operators. * The last operands are further down the tree. */ - private fun joinListWithBinOp( + internal fun joinListWithBinOp( operatorCode: String, nodes: List, - rawNode: Python.AST.AST? = null + rawNode: Python.AST.AST? = null, + isImplicit: Boolean = true ): BinaryOperator { - val lastTwo = newBinaryOperator(operatorCode, rawNode = rawNode) - lastTwo.rhs = nodes.last() - lastTwo.lhs = nodes[nodes.size - 2] + val lastTwo = + newBinaryOperator(operatorCode = operatorCode, rawNode = rawNode).apply { + rhs = nodes.last() + lhs = nodes[nodes.size - 2] + this.isImplicit = isImplicit + } return nodes.subList(0, nodes.size - 2).foldRight(lastTwo) { newVal, start -> - val nextValue = newBinaryOperator(operatorCode) - nextValue.rhs = start - nextValue.lhs = newVal - nextValue + newBinaryOperator(operatorCode = operatorCode, rawNode = rawNode).apply { + rhs = start + lhs = newVal + this.isImplicit = isImplicit + } } } @@ -297,18 +302,12 @@ class ExpressionHandler(frontend: PythonLanguageFrontend) : rawNode = node ) } else { - // Start with the last two operands, then keep prepending the previous ones until the - // list is finished. - val lastTwo = newBinaryOperator(op, rawNode = node) - lastTwo.rhs = handle(node.values.last()) - lastTwo.lhs = handle(node.values[node.values.size - 2]) - return node.values.subList(0, node.values.size - 2).foldRight(lastTwo) { newVal, start - -> - val nextValue = newBinaryOperator(op, rawNode = node) - nextValue.rhs = start - nextValue.lhs = handle(newVal) - nextValue - } + joinListWithBinOp( + operatorCode = op, + nodes = node.values.map(::handle), + rawNode = node, + isImplicit = true + ) } } diff --git a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/Python.kt b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/Python.kt index 774bd91c0b..93eafe9a1e 100644 --- a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/Python.kt +++ b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/Python.kt @@ -1153,7 +1153,12 @@ interface Python { * ``` */ class MatchSingleton(pyObject: PyObject) : BasePattern(pyObject) { - val value: Any by lazy { "value" of pyObject } + /** + * [value] is not optional. We have to make it nullable though because the value will be + * set to `null` if the case matches on `None`. This is known behavior of jep (similar + * to literals/constants). + */ + val value: Any? by lazy { "value" of pyObject } } /** diff --git a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/StatementHandler.kt b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/StatementHandler.kt index cdb66e7103..dfb16d943b 100644 --- a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/StatementHandler.kt +++ b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/StatementHandler.kt @@ -77,15 +77,134 @@ class StatementHandler(frontend: PythonLanguageFrontend) : is Python.AST.Global -> handleGlobal(node) is Python.AST.Nonlocal -> handleNonLocal(node) is Python.AST.Raise -> handleRaise(node) - is Python.AST.Match, + is Python.AST.Match -> handleMatch(node) is Python.AST.TryStar -> newProblemExpression( - "The statement of class ${node.javaClass} is not supported yet", + problem = "The statement of class ${node.javaClass} is not supported yet", rawNode = node ) } } + /** + * Translates a pattern which can be used by a `match_case`. There are various options available + * and all of them are translated to traditional comparisons and logical expressions which could + * also be seen in the condition of an if-statement. + */ + private fun handlePattern(node: Python.AST.BasePattern, subject: String): Expression { + return when (node) { + is Python.AST.MatchValue -> + newBinaryOperator(operatorCode = "==", rawNode = node).implicit().apply { + this.lhs = newReference(name = subject) + this.rhs = frontend.expressionHandler.handle(ctx = node.value) + } + is Python.AST.MatchSingleton -> + newBinaryOperator(operatorCode = "===", rawNode = node).implicit().apply { + this.lhs = newReference(name = subject) + this.rhs = + when (val value = node.value) { + is Python.AST.BaseExpr -> frontend.expressionHandler.handle(ctx = value) + null -> newLiteral(value = null, rawNode = node) + else -> + newProblemExpression( + problem = + "Can't handle ${value::class} in value of Python.AST.MatchSingleton yet" + ) + } + } + is Python.AST.MatchOr -> + frontend.expressionHandler.joinListWithBinOp( + operatorCode = "or", + nodes = node.patterns.map { handlePattern(node = it, subject = subject) }, + rawNode = node, + isImplicit = false + ) + is Python.AST.MatchSequence, + is Python.AST.MatchMapping, + is Python.AST.MatchClass, + is Python.AST.MatchStar, + is Python.AST.MatchAs -> + newProblemExpression( + problem = "Cannot handle of type ${node::class} yet", + rawNode = node + ) + else -> + newProblemExpression( + problem = "Cannot handle of type ${node::class} yet", + rawNode = node + ) + } + } + + /** + * Translates a [`match_case`](https://docs.python.org/3/library/ast.html#ast.match_case) to a + * [Block] which holds the [CaseStatement] and then all other statements of the + * [Python.AST.match_case.body]. + * + * The [CaseStatement] is generated by the [Python.AST.match_case.pattern] and, if available, + * [Python.AST.match_case.guard]. A `guard` is modeled with an `AND` BinaryOperator in the + * [CaseStatement.caseExpression]. Its `lhs` is the normal pattern and the `rhs` is the guard. + * This is in line with [PEP 634](https://peps.python.org/pep-0634/). + */ + private fun handleMatchCase(node: Python.AST.match_case, subject: String): List { + val statements = mutableListOf() + // First, we add the CaseStatement. A `MatchAs` without a `pattern` implies + // it's a default statement. + // We have to handle this here since we do not want to generate the CaseStatement in this + // case. + val pattern = node.pattern + val guard = node.guard + statements += + if (pattern is Python.AST.MatchAs && pattern.pattern == null) { + newDefaultStatement(rawNode = pattern) + } else if (guard != null) { + newCaseStatement(rawNode = node).apply { + this.caseExpression = + newBinaryOperator(operatorCode = "and") + .implicit( + code = frontend.codeOf(astNode = node), + location = frontend.locationOf(astNode = node) + ) + .apply { + this.lhs = handlePattern(node = node.pattern, subject = subject) + this.rhs = frontend.expressionHandler.handle(ctx = guard) + } + } + } else { + newCaseStatement(rawNode = node).apply { + this.caseExpression = handlePattern(node = node.pattern, subject = subject) + } + } + // Now, we add the remaining body. + statements += node.body.map(::handle) + // Currently, the EOG pass requires a break statement to work as expected. For this reason, + // we insert an implicit break statement at the end of the block. + statements += + newBreakStatement() + .implicit( + code = frontend.codeOf(astNode = node), + location = frontend.locationOf(astNode = node) + ) + return statements + } + + /** + * Translates a Python [`Match`](https://docs.python.org/3/library/ast.html#ast.Match) into a + * [SwitchStatement]. + */ + private fun handleMatch(node: Python.AST.Match): SwitchStatement = + newSwitchStatement(rawNode = node).apply { + val subject = frontend.expressionHandler.handle(ctx = node.subject) + this.selector = subject + + this.statement = + node.cases.fold(initial = newBlock().implicit()) { block, case -> + block.statements += + handleMatchCase(node = case, subject = subject.name.localName) + block + } + } + /** * Translates a Python [`Raise`](https://docs.python.org/3/library/ast.html#ast.Raise) into a * [ThrowExpression]. diff --git a/cpg-language-python/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/python/statementHandler/MatchTest.kt b/cpg-language-python/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/python/statementHandler/MatchTest.kt new file mode 100644 index 0000000000..aca1ef4b65 --- /dev/null +++ b/cpg-language-python/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/python/statementHandler/MatchTest.kt @@ -0,0 +1,295 @@ +/* + * Copyright (c) 2024, Fraunhofer AISEC. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * $$$$$$\ $$$$$$$\ $$$$$$\ + * $$ __$$\ $$ __$$\ $$ __$$\ + * $$ / \__|$$ | $$ |$$ / \__| + * $$ | $$$$$$$ |$$ |$$$$\ + * $$ | $$ ____/ $$ |\_$$ | + * $$ | $$\ $$ | $$ | $$ | + * \$$$$$ |$$ | \$$$$$ | + * \______/ \__| \______/ + * + */ +package de.fraunhofer.aisec.cpg.frontends.python.statementHandler + +import de.fraunhofer.aisec.cpg.TranslationResult +import de.fraunhofer.aisec.cpg.frontends.python.PythonLanguage +import de.fraunhofer.aisec.cpg.graph.functions +import de.fraunhofer.aisec.cpg.graph.get +import de.fraunhofer.aisec.cpg.graph.statements.BreakStatement +import de.fraunhofer.aisec.cpg.graph.statements.CaseStatement +import de.fraunhofer.aisec.cpg.graph.statements.DefaultStatement +import de.fraunhofer.aisec.cpg.graph.statements.expressions.BinaryOperator +import de.fraunhofer.aisec.cpg.graph.statements.expressions.Block +import de.fraunhofer.aisec.cpg.graph.statements.expressions.CallExpression +import de.fraunhofer.aisec.cpg.graph.statements.expressions.Literal +import de.fraunhofer.aisec.cpg.graph.statements.expressions.ProblemExpression +import de.fraunhofer.aisec.cpg.graph.statements.expressions.Reference +import de.fraunhofer.aisec.cpg.graph.switches +import de.fraunhofer.aisec.cpg.test.analyze +import de.fraunhofer.aisec.cpg.test.assertLiteralValue +import de.fraunhofer.aisec.cpg.test.assertLocalName +import de.fraunhofer.aisec.cpg.test.assertRefersTo +import java.nio.file.Path +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.TestInstance + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class MatchTest { + private lateinit var topLevel: Path + private lateinit var result: TranslationResult + + @BeforeAll + fun setup() { + topLevel = Path.of("src", "test", "resources", "python") + result = + analyze(listOf(topLevel.resolve("match.py").toFile()), topLevel, true) { + it.registerLanguage() + } + assertNotNull(result) + } + + @Test + fun testMatchSingleton() { + val func = result.functions["matchSingleton"] + assertNotNull(func) + + val switchStatement = func.switches.singleOrNull() + assertNotNull(switchStatement) + + assertLocalName("x", switchStatement.selector) + assertIs(switchStatement.selector) + val paramX = func.parameters.singleOrNull() + assertNotNull(paramX) + assertRefersTo(switchStatement.selector, paramX) + + val statementBlock = switchStatement.statement + assertIs(statementBlock) + assertEquals(3, statementBlock.statements.size) + val caseSingleton = statementBlock[0] + assertIs(caseSingleton) + val singletonCheck = caseSingleton.caseExpression + assertIs(singletonCheck) + assertEquals("===", singletonCheck.operatorCode) + assertRefersTo(singletonCheck.lhs, paramX) + val singletonRhs = singletonCheck.rhs + assertIs>(singletonRhs) + assertNull(singletonRhs.value) + assertIs(statementBlock[2]) + } + + @Test + fun testMatchValue() { + val func = result.functions["matchValue"] + assertNotNull(func) + + val switchStatement = func.switches.singleOrNull() + assertNotNull(switchStatement) + + assertLocalName("x", switchStatement.selector) + assertIs(switchStatement.selector) + val paramX = func.parameters.singleOrNull() + assertNotNull(paramX) + assertRefersTo(switchStatement.selector, paramX) + + val statementBlock = switchStatement.statement + assertIs(statementBlock) + assertEquals(3, statementBlock.statements.size) + val caseValue = statementBlock[0] + assertIs(caseValue) + val valueCheck = caseValue.caseExpression + assertIs(valueCheck) + assertEquals("==", valueCheck.operatorCode) + assertRefersTo(valueCheck.lhs, paramX) + assertLiteralValue("value", valueCheck.rhs) + assertIs(statementBlock[2]) + } + + @Test + fun testMatchOr() { + val func = result.functions["matchOr"] + assertNotNull(func) + + val switchStatement = func.switches.singleOrNull() + assertNotNull(switchStatement) + + assertLocalName("x", switchStatement.selector) + assertIs(switchStatement.selector) + val paramX = func.parameters.singleOrNull() + assertNotNull(paramX) + assertRefersTo(switchStatement.selector, paramX) + + val statementBlock = switchStatement.statement + assertIs(statementBlock) + assertEquals(3, statementBlock.statements.size) + val caseOr = statementBlock[0] + assertIs(caseOr) + val orExpr = caseOr.caseExpression + assertIs(orExpr) + assertEquals("or", orExpr.operatorCode) + assertIs(orExpr.lhs) + assertIs(orExpr.rhs) + assertIs(statementBlock[2]) + } + + @Test + fun testMatchDefault() { + val func = result.functions["matchDefault"] + assertNotNull(func) + + val switchStatement = func.switches.singleOrNull() + assertNotNull(switchStatement) + + assertLocalName("x", switchStatement.selector) + assertIs(switchStatement.selector) + val paramX = func.parameters.singleOrNull() + assertNotNull(paramX) + assertRefersTo(switchStatement.selector, paramX) + + val statementBlock = switchStatement.statement + assertIs(statementBlock) + assertEquals(3, statementBlock.statements.size) + val caseDefault = statementBlock[0] + assertIs(caseDefault) + assertIs(statementBlock[2]) + } + + @Test + fun testMatchGuard() { + val func = result.functions["matchAnd"] + assertNotNull(func) + + val switchStatement = func.switches.singleOrNull() + assertNotNull(switchStatement) + + assertLocalName("x", switchStatement.selector) + assertIs(switchStatement.selector) + val paramX = func.parameters.singleOrNull() + assertNotNull(paramX) + assertRefersTo(switchStatement.selector, paramX) + + val statementBlock = switchStatement.statement + assertIs(statementBlock) + val caseAnd = statementBlock[0] + assertIs(caseAnd) + val andExpr = caseAnd.caseExpression + assertIs(andExpr) + assertEquals("and", andExpr.operatorCode) + val andRhs = andExpr.rhs + assertIs(andRhs) + assertEquals(">", andRhs.operatorCode) + assertRefersTo(andRhs.lhs, paramX) + assertLiteralValue(0L, andRhs.rhs) + assertIs(statementBlock[2]) + } + + @Test + fun testMatchCombined() { + val func = result.functions["matcher"] + assertNotNull(func) + + val switchStatement = func.switches.singleOrNull() + assertNotNull(switchStatement) + + assertLocalName("x", switchStatement.selector) + assertIs(switchStatement.selector) + val paramX = func.parameters.singleOrNull() + assertNotNull(paramX) + assertRefersTo(switchStatement.selector, paramX) + + val statementBlock = switchStatement.statement + assertIs(statementBlock) + val caseSingleton = statementBlock[0] + assertIs(caseSingleton) + val singletonCheck = caseSingleton.caseExpression + assertIs(singletonCheck) + assertEquals("===", singletonCheck.operatorCode) + assertRefersTo(singletonCheck.lhs, paramX) + val singletonRhs = singletonCheck.rhs + assertIs>(singletonRhs) + assertNull(singletonRhs.value) + assertIs(statementBlock[2]) + + val caseValue = statementBlock[3] + assertIs(caseValue) + val valueCheck = caseValue.caseExpression + assertIs(valueCheck) + assertEquals("==", valueCheck.operatorCode) + assertRefersTo(valueCheck.lhs, paramX) + assertLiteralValue("value", valueCheck.rhs) + assertIs(statementBlock[5]) + + val caseAnd = statementBlock[6] + assertIs(caseAnd) + val andExpr = caseAnd.caseExpression + assertIs(andExpr) + assertEquals("and", andExpr.operatorCode) + val andRhs = andExpr.rhs + assertIs(andRhs) + assertEquals(">", andRhs.operatorCode) + assertRefersTo(andRhs.lhs, paramX) + assertLiteralValue(0L, andRhs.rhs) + assertIs(statementBlock[8]) + + assertIs(statementBlock[9]) + assertIs(statementBlock[11]) + assertIs(statementBlock[12]) + assertIs(statementBlock[14]) + assertIs(statementBlock[15]) + assertIs(statementBlock[17]) + assertIs(statementBlock[18]) + assertIs(statementBlock[20]) + assertIs(statementBlock[21]) + assertIs(statementBlock[23]) + assertIs(statementBlock[24]) + assertIs(statementBlock[26]) + + val caseOr = statementBlock[27] + assertIs(caseOr) + val orExpr = caseOr.caseExpression + assertIs(orExpr) + assertEquals("or", orExpr.operatorCode) + assertIs(orExpr.lhs) + assertIs(orExpr.rhs) + assertIs(statementBlock[29]) + + val caseDefault = statementBlock[30] + assertIs(caseDefault) + assertIs(statementBlock[32]) + } + + @Test + fun testMatch2() { + val func = result.functions["match_weird"] + assertNotNull(func) + + val switchStatement = func.switches.singleOrNull() + assertNotNull(switchStatement) + + assertIs(switchStatement.selector) + + val statementBlock = switchStatement.statement + assertIs(statementBlock) + val case = statementBlock[0] + assertIs(case) + assertIs(case.caseExpression) + } +} diff --git a/cpg-language-python/src/test/resources/python/match.py b/cpg-language-python/src/test/resources/python/match.py new file mode 100644 index 0000000000..21e7884ba1 --- /dev/null +++ b/cpg-language-python/src/test/resources/python/match.py @@ -0,0 +1,54 @@ +def matcher(x): + match x: + case None: + print("singleton" + x) + case "value": + print("value" + x) + case [x] if x>0: + print(x) + case [1, 2]: + print("sequence" + x) + case [1, 2, *rest]: + print("star" + x) + case [*_]: + print("star2" + x) + case {1: _, 2: _}: + print("mapping" + x) + case Point2D(0, 0): + print("class" + x) + case [x] as y: + print("as" + y) + case "xyz" | "abc": + print("or" + x) + case _: + print("Default match") + +def matchSingleton(x): + match x: + case None: + print("singleton" + x) + +def matchValue(x): + match x: + case "value": + print("value" + x) + +def matchOr(x): + match x: + case "xyz" | "abc": + print("or" + x) + +def matchAnd(x): + match x: + case [x] if x>0: + print(x) + +def matchDefault(x): + match x: + case _: + print("Default match") + +def match_weird(): + match command.split(): + case ["go", ("north" | "south" | "east" | "west") as direction]: + current_room = current_room.neighbor(direction) \ No newline at end of file