Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start with python match statement #1801

Merged
merged 23 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import org.neo4j.ogm.annotation.Relationship

/**
* Represents a Java or C++ switch statement of the `switch (selector) {...}` that can include case
* and default statements. Break statements break out of the switch and labeled breaks in JAva are
* and default statements. Break statements break out of the switch and labeled breaks in Java are
* handled properly.
*/
class SwitchStatement : Statement(), BranchingNode {
Expand All @@ -51,7 +51,7 @@ class SwitchStatement : Statement(), BranchingNode {

@Relationship(value = "SELECTOR_DECLARATION")
var selectorDeclarationEdge = astOptionalEdgeOf<Declaration>()
/** C++ allows to use a declaration instead of a expression as selector */
/** C++ allows to use a declaration instead of an expression as selector */
var selectorDeclaration by unwrapping(SwitchStatement::selectorDeclarationEdge)

@Relationship(value = "STATEMENT") var statementEdge = astOptionalEdgeOf<Statement>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,34 @@
}
}

/**
 * Joins the [nodes] into a right-nested chain of [BinaryOperator]s which all use the
 * [operatorCode].
 *
 * The first element of [nodes] is the `lhs` of the root operator; every following element sits
 * one level further down in the `rhs`. For `[a, b, c]` the result is `a op (b op c)`.
 *
 * The innermost operator (the one joining the last two elements) is created with [rawNode]. All
 * enclosing operators are created without a raw node. If [isImplicit] is `true`, they are marked
 * as implicit and, if [rawNode] is not `null`, they are given the code and location of
 * [rawNode].
 *
 * @param operatorCode the operator code used for every operator in the chain
 * @param nodes the operands in source order; must contain at least two elements
 * @param rawNode the Python AST node the chain is derived from, if any
 * @param isImplicit whether the enclosing operators are marked as implicit
 * @return the root of the operator chain
 * @throws IllegalArgumentException if [nodes] contains fewer than two elements
 */
fun joinListWithBinOp(
    operatorCode: String,
    nodes: List<Expression>,
    rawNode: Python.AST.AST? = null,
    isImplicit: Boolean = true
): BinaryOperator {
    // A binary operator needs two operands. Without this check, shorter lists fail below with an
    // index/no-such-element error that does not name the actual problem.
    require(nodes.size >= 2) {
        "Joining expressions with \"$operatorCode\" requires at least two operands but got ${nodes.size}"
    }

    // Start with the last two operands ...
    val lastTwo = newBinaryOperator(operatorCode, rawNode = rawNode)
    lastTwo.rhs = nodes.last()
    lastTwo.lhs = nodes[nodes.size - 2]

    // ... then keep prepending the previous ones (right to left) until the list is finished.
    return nodes.subList(0, nodes.size - 2).foldRight(lastTwo) { operand, nested ->
        val parent = newBinaryOperator(operatorCode)
        when {
            isImplicit && rawNode != null ->
                parent.implicit(
                    code = frontend.codeOf(rawNode),
                    location = frontend.locationOf(rawNode)
                )
            isImplicit -> parent.implicit()
        }
        parent.rhs = nested
        parent.lhs = operand
        parent
    }
}

private fun handleStarred(node: Python.AST.Starred): Expression {
val unaryOp = newUnaryOperator("*", postfix = false, prefix = false, rawNode = node)
unaryOp.input = handle(node.value)
Expand Down Expand Up @@ -297,18 +325,12 @@
rawNode = node
)
} else {
// Start with the last two operands, then keep prepending the previous ones until the
// list is finished.
val lastTwo = newBinaryOperator(op, rawNode = node)
lastTwo.rhs = handle(node.values.last())
lastTwo.lhs = handle(node.values[node.values.size - 2])
return node.values.subList(0, node.values.size - 2).foldRight(lastTwo) { newVal, start
->
val nextValue = newBinaryOperator(op, rawNode = node)
nextValue.rhs = start
nextValue.lhs = handle(newVal)
nextValue
}
joinListWithBinOp(
operatorCode = op,
nodes = node.values.map(::handle),
rawNode = node,
isImplicit = true
)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,12 @@ interface Python {
* ```
*/
class MatchSingleton(pyObject: PyObject) : BasePattern(pyObject) {
val value: Any by lazy { "value" of pyObject }
/**
* [value] is not optional. We have to make it nullable though because the value will be
* set to `null` if the case matches on `None`. This is known behavior of jep (similar
* to literals/constants).
*/
val value: Any? by lazy { "value" of pyObject }
KuechA marked this conversation as resolved.
Show resolved Hide resolved
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
is Python.AST.Global -> handleGlobal(node)
is Python.AST.Nonlocal -> handleNonLocal(node)
is Python.AST.Raise -> handleRaise(node)
is Python.AST.Match,
is Python.AST.Match -> handleMatch(node)
is Python.AST.TryStar ->
newProblemExpression(
"The statement of class ${node.javaClass} is not supported yet",
Expand All @@ -86,6 +86,103 @@
}
}

/**
 * Translates a pattern which can be used by a `match_case` into an [Expression]. The patterns
 * are translated to traditional comparisons and logical expressions which could also be seen in
 * the condition of an if-statement:
 * - [Python.AST.MatchValue] becomes an implicit `==` comparison of the selector and the value.
 * - [Python.AST.MatchSingleton] becomes an implicit `===` comparison of the selector and the
 *   singleton.
 * - [Python.AST.MatchOr] becomes the translated alternatives joined by `or`.
 *
 * All other patterns are not supported yet and result in a problem expression which is created
 * with the pattern as its raw node.
 *
 * @param node the pattern to translate
 * @param selector the name used to create a new reference to the match subject for each
 *   comparison
 * @return the expression representing the check performed by the pattern
 */
private fun handlePattern(node: Python.AST.BasePattern, selector: String): Expression {
    return when (node) {
        is Python.AST.MatchValue ->
            newBinaryOperator("==", node).implicit().apply {
                this.lhs = newReference(selector)
                this.rhs = frontend.expressionHandler.handle(node.value)
            }
        is Python.AST.MatchSingleton ->
            newBinaryOperator("===", node).implicit().apply {
                this.lhs = newReference(selector)
                this.rhs =
                    when (val value = node.value) {
                        is Python.AST.BaseExpr -> frontend.expressionHandler.handle(value)
                        // Matching on `None` arrives here as `null`
                        null -> newLiteral(value = null, rawNode = node)
                        else ->
                            newProblemExpression(
                                "Can't handle ${value::class} in value of Python.AST.MatchSingleton yet",
                                rawNode = node
                            )
                    }
            }
        is Python.AST.MatchOr ->
            frontend.expressionHandler.joinListWithBinOp(
                "or",
                node.patterns.map { handlePattern(it, selector) },
                node
            )
        // Known pattern kinds which are not translated yet
        is Python.AST.MatchSequence,
        is Python.AST.MatchMapping,
        is Python.AST.MatchClass,
        is Python.AST.MatchStar,
        is Python.AST.MatchAs ->
            newProblemExpression("Cannot handle of type ${node::class} yet", rawNode = node)
        else -> newProblemExpression("Cannot handle of type ${node::class} yet", rawNode = node)
    }
}

/**
 * Translates a [`match_case`](https://docs.python.org/3/library/ast.html#ast.match_case) into a
 * flat list of statements: first the [CaseStatement], then the translated statements of the
 * [Python.AST.match_case.body] and finally an implicit break statement.
 *
 * The [CaseStatement.caseExpression] is built from the [Python.AST.match_case.pattern]. If a
 * [Python.AST.match_case.guard] is present, the expression is an implicit `and` BinaryOperator
 * whose `lhs` is the translated pattern and whose `rhs` is the translated guard. This is in line
 * with [PEP 634](https://peps.python.org/pep-0634/).
 *
 * @param node the case to translate
 * @param selector the name of the match subject, handed on to the pattern translation
 * @return the statements representing this case, in the order described above
 */
private fun handleCase(node: Python.AST.match_case, selector: String): List<Statement> {
    return buildList {
        // The case statement comes first
        add(
            newCaseStatement(node).apply {
                this.caseExpression =
                    when (val guard = node.guard) {
                        null -> handlePattern(node.pattern, selector)
                        else ->
                            newBinaryOperator("and")
                                .implicit(
                                    code = frontend.codeOf(node),
                                    location = frontend.locationOf(node)
                                )
                                .also { conjunction ->
                                    conjunction.lhs = handlePattern(node.pattern, selector)
                                    conjunction.rhs = frontend.expressionHandler.handle(guard)
                                }
                    }
            }
        )

        // It is followed by the body of the case
        addAll(node.body.map(::handle))

        // Currently, the EOG pass requires a break statement to work as expected. For this
        // reason, an implicit break statement terminates the statements of the case.
        add(
            newBreakStatement()
                .implicit(code = frontend.codeOf(node), location = frontend.locationOf(node))
        )
    }
}

/**
 * Translates a Python [`Match`](https://docs.python.org/3/library/ast.html#ast.Match) into a
 * [SwitchStatement].
 *
 * The translated subject of the match becomes the selector. The statements of all cases are
 * collected, in order, in a single implicit block which becomes the statement of the switch.
 *
 * @param node the match statement to translate
 * @return the resulting [SwitchStatement]
 */
private fun handleMatch(node: Python.AST.Match): Statement {
    return newSwitchStatement(node).apply {
        val subject = frontend.expressionHandler.handle(node.subject)
        this.selector = subject

        // All cases share one implicit block; each case appends its own statements to it.
        val body = newBlock().implicit()
        for (case in node.cases) {
            body.statements += handleCase(case, subject.name.localName)
        }
        this.statement = body
    }
}

/**
* Translates a Python [`Raise`](https://docs.python.org/3/library/ast.html#ast.Raise) into a
* [ThrowStatement].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,80 @@ class StatementHandlerTest : BaseTest() {
assertNotNull(result)
}

/**
 * Checks the translation of `match.py`. Every case of the fixture has a single statement in its
 * body, so each case occupies three consecutive entries of the switch body: the case statement
 * at index `3k`, its body statement at `3k + 1` and the break statement at `3k + 2`.
 */
@Test
fun testMatch() {
    analyzeFile("match.py")

    val matcher = result.functions["matcher"]
    assertNotNull(matcher)

    val match = matcher.switches.singleOrNull()
    assertNotNull(match)

    // The selector is a reference to the only parameter `x`
    assertLocalName("x", match.selector)
    assertIs<Reference>(match.selector)
    val x = matcher.parameters.singleOrNull()
    assertNotNull(x)
    assertRefersTo(match.selector, x)

    val body = match.statement
    assertIs<Block>(body)

    // case None
    val noneCase = body[0]
    assertIs<CaseStatement>(noneCase)
    val noneCheck = noneCase.caseExpression
    assertIs<BinaryOperator>(noneCheck)
    assertNotNull(noneCheck)
    assertEquals("===", noneCheck.operatorCode)
    assertRefersTo(noneCheck.lhs, x)
    val noneLiteral = noneCheck.rhs
    assertIs<Literal<*>>(noneLiteral)
    assertNull(noneLiteral.value)
    assertIs<BreakStatement>(body[2])

    // case "value"
    val valueCase = body[3]
    assertIs<CaseStatement>(valueCase)
    val valueCheck = valueCase.caseExpression
    assertIs<BinaryOperator>(valueCheck)
    assertNotNull(valueCheck)
    assertEquals("==", valueCheck.operatorCode)
    assertRefersTo(valueCheck.lhs, x)
    assertLiteralValue("value", valueCheck.rhs)
    assertIs<BreakStatement>(body[5])

    // case [x] if x>0
    val guardedCase = body[6]
    assertIs<CaseStatement>(guardedCase)
    val guardedCheck = guardedCase.caseExpression
    assertIs<BinaryOperator>(guardedCheck)
    assertEquals("and", guardedCheck.operatorCode)
    val guard = guardedCheck.rhs
    assertIs<BinaryOperator>(guard)
    assertEquals(">", guard.operatorCode)
    assertRefersTo(guard.lhs, x)
    assertLiteralValue(0L, guard.rhs)
    assertIs<BreakStatement>(body[8])

    // For the next six cases, only the case/break layout is checked
    for (caseIndex in 9..24 step 3) {
        assertIs<CaseStatement>(body[caseIndex])
        assertIs<BreakStatement>(body[caseIndex + 2])
    }

    // case [x] | [y]
    val orCase = body[27]
    assertIs<CaseStatement>(orCase)
    val orCheck = orCase.caseExpression
    assertIs<BinaryOperator>(orCheck)
    assertNotNull(orCheck)
    assertEquals("or", orCheck.operatorCode)
    assertIs<BreakStatement>(body[29])
}

@Test
fun testTry() {
val tu =
Expand Down
22 changes: 22 additions & 0 deletions cpg-language-python/src/test/resources/python/match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
def matcher(x):
match x:
case None:
print("singleton" + x)
case "value":
print("value" + x)
case [x] if x>0:
print(x)
case [1, 2]:
print("sequence" + x)
case [1, 2, *rest]:
print("star" + x)
case [*_]:
print("star2" + x)
case {1: _, 2: _}:
print("mapping" + x)
case Point2D(0, 0):
print("class" + x)
case [x] as y:
print("as" + y)
case [x] | [y]:
maximiliankaul marked this conversation as resolved.
Show resolved Hide resolved
print("or" + x)
maximiliankaul marked this conversation as resolved.
Show resolved Hide resolved
Loading