From 4a8e4999ce59a2f079dd0e3c2fafaa615a88156d Mon Sep 17 00:00:00 2001 From: Marco Eilers Date: Mon, 28 Oct 2024 00:59:29 +0100 Subject: [PATCH] More term simplification especially for MCE terms --- src/main/resources/preamble.smt2 | 4 + .../decider/TermToSMTLib2Converter.scala | 1 + .../scala/decider/TermToZ3APIConverter.scala | 9 +- src/main/scala/interfaces/state/Chunks.scala | 5 +- src/main/scala/reporting/Converter.scala | 1 + src/main/scala/rules/Brancher.scala | 18 ++- src/main/scala/rules/Executor.scala | 2 +- .../rules/MoreCompleteExhaleSupporter.scala | 34 ++++-- src/main/scala/state/Chunks.scala | 4 + src/main/scala/state/Heap.scala | 7 ++ src/main/scala/state/State.scala | 9 ++ src/main/scala/state/Store.scala | 6 + src/main/scala/state/Terms.scala | 115 ++++++++++++++++-- src/main/scala/state/Utils.scala | 1 + 14 files changed, 182 insertions(+), 34 deletions(-) diff --git a/src/main/resources/preamble.smt2 b/src/main/resources/preamble.smt2 index c08c806a4..46dbaabbb 100644 --- a/src/main/resources/preamble.smt2 +++ b/src/main/resources/preamble.smt2 @@ -39,6 +39,10 @@ (define-fun $Perm.min ((p1 $Perm) (p2 $Perm)) Real (ite (<= p1 p2) p1 p2)) +; max function for permissions +(define-fun $Perm.max ((p1 $Perm) (p2 $Perm)) Real + (ite (>= p1 p2) p1 p2)) + ; --- Sort wrappers --- ; Sort wrappers are no longer part of the static preamble. Instead, they are diff --git a/src/main/scala/decider/TermToSMTLib2Converter.scala b/src/main/scala/decider/TermToSMTLib2Converter.scala index d876d1db7..966315d9c 100644 --- a/src/main/scala/decider/TermToSMTLib2Converter.scala +++ b/src/main/scala/decider/TermToSMTLib2Converter.scala @@ -212,6 +212,7 @@ class TermToSMTLib2Converter case PermIntDiv(t0, t1) => renderBinaryOp("/", renderAsReal(t0), renderAsReal(t1)) case PermPermDiv(t0, t1) => renderBinaryOp("/", renderAsReal(t0), renderAsReal(t1)) case PermMin(t0, t1) => renderBinaryOp("$Perm.min", render(t0), render(t1)) + case PermMax(t0, t1) => renderBinaryOp("$Perm.max", render(t0), render(t1)) case IsValidPermVar(v) => parens(text("$Perm.isValidVar") <+> render(v)) case IsReadPermVar(v) => parens(text("$Perm.isReadVar") <+> render(v)) diff --git a/src/main/scala/decider/TermToZ3APIConverter.scala b/src/main/scala/decider/TermToZ3APIConverter.scala index d39dad3c1..9a064b73e 100644 --- a/src/main/scala/decider/TermToZ3APIConverter.scala +++ b/src/main/scala/decider/TermToZ3APIConverter.scala @@ -332,14 +332,15 @@ class TermToZ3APIConverter case PermIntDiv(t0, t1) => ctx.mkDiv(convertToReal(t0), convertToReal(t1)) case PermPermDiv(t0, t1) => ctx.mkDiv(convertToReal(t0), convertToReal(t1)) case PermMin(t0, t1) => { - /* - (define-fun $Perm.min ((p1 $Perm) (p2 $Perm)) Real - (ite (<= p1 p2) p1 p2)) - */ val e0 = convert(t0).asInstanceOf[ArithExpr] val e1 = convert(t1).asInstanceOf[ArithExpr] ctx.mkITE(ctx.mkLe(e0, e1), e0, e1) } + case PermMax(t0, t1) => { + val e0 = convert(t0).asInstanceOf[ArithExpr] + val e1 = convert(t1).asInstanceOf[ArithExpr] + ctx.mkITE(ctx.mkGe(e0, e1), e0, e1) + } case IsValidPermVar(v) => { /* (define-fun $Perm.isValidVar ((p $Perm)) Bool diff --git a/src/main/scala/interfaces/state/Chunks.scala b/src/main/scala/interfaces/state/Chunks.scala index 329750248..fd5e1b35a 100644 --- a/src/main/scala/interfaces/state/Chunks.scala +++ b/src/main/scala/interfaces/state/Chunks.scala @@ -10,8 +10,9 @@ import viper.silicon.resources.ResourceID import viper.silicon.state.terms.{Term, Var} import viper.silver.ast -trait Chunk - +trait Chunk { + def addEquality(t1: Term, t2: Term): Chunk = this +} trait ChunkIdentifer trait GeneralChunk extends Chunk { diff --git a/src/main/scala/reporting/Converter.scala b/src/main/scala/reporting/Converter.scala index 3c3a34c43..53f7eae3f 100644 --- a/src/main/scala/reporting/Converter.scala +++ b/src/main/scala/reporting/Converter.scala @@ -437,6 +437,7 @@ object Converter { case PermLess(_, _) => None case PermAtMost(_, _) => None case PermMin(_, _) => None + case PermMax(_, _) => None case _ => None } } diff --git a/src/main/scala/rules/Brancher.scala b/src/main/scala/rules/Brancher.scala index 444c61181..eb73cb548 100644 --- a/src/main/scala/rules/Brancher.scala +++ b/src/main/scala/rules/Brancher.scala @@ -14,7 +14,7 @@ import viper.silicon.decider.PathConditionStack import viper.silicon.interfaces.{Unreachable, VerificationResult} import viper.silicon.reporting.condenseToViperResult import viper.silicon.state.State -import viper.silicon.state.terms.{FunctionDecl, MacroDecl, Not, Term} +import viper.silicon.state.terms.{BuiltinEquals, FunctionDecl, MacroDecl, Not, Term} import viper.silicon.verifier.Verifier import viper.silver.ast import viper.silver.reporter.BranchFailureMessage @@ -158,7 +158,12 @@ object brancher extends BranchingRules { } } - val result = fElse(v1.stateConsolidator(s1).consolidateOptionally(s1, v1), v1) + val s1p = condition match { + case Not(BuiltinEquals(p0, p1)) => + s1.addEquality(p0, p1) + case _ => s1 + } + val result = fElse(v1.stateConsolidator(s1p).consolidateOptionally(s1p, v1), v1) if (wasElseExecutedOnDifferentVerifier) { v1.decider.resetProverOptions() v1.decider.setProverOptions(proverArgsOfElseBranchDecider) @@ -197,7 +202,14 @@ object brancher extends BranchingRules { v1.decider.prover.comment(s"[then-branch: $cnt | $condition]") v1.decider.setCurrentBranchCondition(condition, conditionExp) - fThen(v1.stateConsolidator(s1).consolidateOptionally(s1, v1), v1) + val s1p = condition match { + case BuiltinEquals(p0, p1) => + s1.addEquality(p0, p1) + case _ => s1 + } + val s1pp = v1.stateConsolidator(s1p).consolidateOptionally(s1p, v1) + + fThen(s1pp, v1) }) } else { Unreachable() diff --git a/src/main/scala/rules/Executor.scala b/src/main/scala/rules/Executor.scala index a19e85d65..63497f37c 100644 --- a/src/main/scala/rules/Executor.scala +++ b/src/main/scala/rules/Executor.scala @@ -720,7 +720,7 @@ object executor extends ExecutionRules { private def ssaifyRhs(rhs: Term, rhsExp: ast.Exp, rhsExpNew: Option[ast.Exp], name: String, typ: ast.Type, v: Verifier, s : State): (Term, Option[ast.Exp]) = { rhs match { - case _: Var | _: Literal => + case t if t.isDefinitelyNonTriggering => (rhs, rhsExpNew) case _ => diff --git a/src/main/scala/rules/MoreCompleteExhaleSupporter.scala b/src/main/scala/rules/MoreCompleteExhaleSupporter.scala index 1739866e1..0a497f7cb 100644 --- a/src/main/scala/rules/MoreCompleteExhaleSupporter.scala +++ b/src/main/scala/rules/MoreCompleteExhaleSupporter.scala @@ -274,20 +274,24 @@ object moreCompleteExhaleSupporter extends SymbolicExecutionRules { val newChunks = ListBuffer[NonQuantifiedChunk]() var moreNeeded = true - val definiteAlias = chunkSupporter.findChunk[NonQuantifiedChunk](relevantChunks, id, args, v).filter(c => - v.decider.check(IsPositive(c.perm), Verifier.config.checkTimeout()) - ) - - val sortFunction: (NonQuantifiedChunk, NonQuantifiedChunk) => Boolean = (ch1, ch2) => { - // The definitive alias and syntactic aliases should get priority, since it is always - // possible to consume from them - definiteAlias.contains(ch1) || !definiteAlias.contains(ch2) && ch1.args == args + val (sortedChunks, checkedDefiniteAlias) = if (relevantChunks.size < 2) { + (relevantChunks, None) + } else { + val definiteAlias = chunkSupporter.findChunk[NonQuantifiedChunk](relevantChunks, id, args, v).filter(c => + v.decider.check(IsPositive(c.perm), Verifier.config.checkTimeout()) + ) + val sortFunction: (NonQuantifiedChunk, NonQuantifiedChunk) => Boolean = (ch1, ch2) => { + // The definitive alias and syntactic aliases should get priority, since it is always + // possible to consume from them + definiteAlias.contains(ch1) || !definiteAlias.contains(ch2) && ch1.args == args + } + (relevantChunks.sortWith(sortFunction), Some(definiteAlias)) } val additionalArgs = s.relevantQuantifiedVariables.map(_._1) var currentFunctionRecorder = s.functionRecorder - relevantChunks.sortWith(sortFunction) foreach { ch => + sortedChunks foreach { ch => if (moreNeeded) { val eqHelper = ch.args.zip(args).map { case (t1, t2) => t1 === t2 } val eq = And(eqHelper) @@ -295,13 +299,14 @@ object moreCompleteExhaleSupporter extends SymbolicExecutionRules { val takenTerm = Ite(eq, PermMin(ch.perm, pNeeded), NoPerm) val pTakenExp = permsExp.map(pe => ast.CondExp(eqExp.get, buildMinExp(Seq(ch.permExp.get, pNeededExp.get), ast.Perm), ast.NoPerm()(pe.pos, pe.info, pe.errT))(eqExp.get.pos, eqExp.get.info, eqExp.get.errT)) - val pTaken = if (takenTerm.isInstanceOf[PermLiteral] || s.functionRecorder != NoopFunctionRecorder || Verifier.config.useFlyweight) { + val pTaken = if (true) { //(takenTerm.isInstanceOf[PermLiteral] || s.functionRecorder != NoopFunctionRecorder || Verifier.config.useFlyweight) { // ME: When using Z3 via API, it is beneficial to not use macros, since macro-terms will *always* be different // (leading to new terms that have to be translated), whereas without macros, we can usually use a term // that already exists. // During function verification, we should not define macros, since they could contain result, which is not // defined elsewhere. // Also, we don't introduce a macro if the term is a straightforward literal. + // ME: Trying to never use macros to get more simplification. takenTerm } else { val pTakenArgs = additionalArgs @@ -320,11 +325,13 @@ object moreCompleteExhaleSupporter extends SymbolicExecutionRules { pNeeded = PermMinus(pNeeded, pTaken) pNeededExp = permsExp.map(pe => ast.PermSub(pNeededExp.get, pTakenExp.get)(pe.pos, pe.info, pe.errT)) - if (!v.decider.check(IsNonPositive(newChunk.perm), Verifier.config.splitTimeout())) { + val newChunkHasNoPerm = IsNonPositive(newChunk.perm) + if (newChunkHasNoPerm == False || !v.decider.check(newChunkHasNoPerm, Verifier.config.splitTimeout())) { newChunks.append(newChunk) } - moreNeeded = !v.decider.check(pNeeded === NoPerm, Verifier.config.splitTimeout()) + val noMoreNeeded = pNeeded === NoPerm + moreNeeded = noMoreNeeded == False || !v.decider.check(noMoreNeeded, Verifier.config.splitTimeout()) } else { newChunks.append(ch) } @@ -342,8 +349,9 @@ object moreCompleteExhaleSupporter extends SymbolicExecutionRules { val newHeap = Heap(allChunks) val s0 = s.copy(functionRecorder = currentFunctionRecorder) + val checkedDefiniteValue = checkedDefiniteAlias.map(_.map(_.snap)) - summarise(s0, relevantChunks.toSeq, resource, args, argsExp, Some(definiteAlias.map(_.snap)), v)((s1, snap, _, _, _, v1) => { + summarise(s0, relevantChunks.toSeq, resource, args, argsExp, checkedDefiniteValue, v)((s1, snap, _, _, _, v1) => { val condSnap = if (v1.decider.check(IsPositive(perms), Verifier.config.checkTimeout())) { snap } else { diff --git a/src/main/scala/state/Chunks.scala b/src/main/scala/state/Chunks.scala index e986faded..329c0e7fc 100644 --- a/src/main/scala/state/Chunks.scala +++ b/src/main/scala/state/Chunks.scala @@ -56,6 +56,10 @@ case class BasicChunk(resourceID: BaseID, case FieldID => s"${args.head}.$id -> $snap # $perm" case PredicateID => s"$id($snap; ${args.mkString(",")}) # $perm" } + + override def addEquality(t1: Term, t2: Term) = { + BasicChunk(resourceID, id, args.map(_.replace(t1, t2)), argsExp, snap.replace(t1, t2), perm.replace(t1, t2), permExp) + } } sealed trait QuantifiedBasicChunk extends QuantifiedChunk { diff --git a/src/main/scala/state/Heap.scala b/src/main/scala/state/Heap.scala index 302de3dc3..8fc79e9ce 100644 --- a/src/main/scala/state/Heap.scala +++ b/src/main/scala/state/Heap.scala @@ -7,12 +7,14 @@ package viper.silicon.state import viper.silicon.interfaces.state.Chunk +import viper.silicon.state.terms.Term trait Heap { def values: Iterable[Chunk] def +(chunk: Chunk): Heap def +(other: Heap): Heap def -(chunk: Chunk): Heap + def addEquality(t1: Term, t2: Term): Heap } trait HeapFactory[H <: Heap] { @@ -38,4 +40,9 @@ final class ListBackedHeap private[state] (chunks: Vector[Chunk]) new ListBackedHeap(prefix ++ suffix.tail) } + + def addEquality(t1: Term, t2: Term) = { + val newChunks = chunks.map(_.addEquality(t1, t2)) + new ListBackedHeap(newChunks) + } } diff --git a/src/main/scala/state/State.scala b/src/main/scala/state/State.scala index 25f9e5fdd..4237c647d 100644 --- a/src/main/scala/state/State.scala +++ b/src/main/scala/state/State.scala @@ -106,6 +106,15 @@ final case class State(g: Store = Store(), def cycles(m: ast.Member) = visited.count(_ == m) + def addEquality(t1: Term, t2: Term): State = { + if (t1 == t2) { + this + } else { + val newState = copy(g = g.addEquality(t1, t2), h = h.addEquality(t1, t2)) + newState + } + } + def setConstrainable(arps: Iterable[Var], constrainable: Boolean) = { val newConstrainableARPs = if (constrainable) constrainableARPs ++ arps diff --git a/src/main/scala/state/Store.scala b/src/main/scala/state/Store.scala index 2777026d1..dd665f2f0 100644 --- a/src/main/scala/state/Store.scala +++ b/src/main/scala/state/Store.scala @@ -20,6 +20,7 @@ trait Store { def getExp(key: ast.AbstractLocalVar): Option[ast.Exp] def +(kv: (ast.AbstractLocalVar, (Term, Option[ast.Exp]))): Store def +(other: Store): Store + def addEquality(t1: Term, t2: Term): Store } trait StoreFactory[ST <: Store] { @@ -51,4 +52,9 @@ final class MapBackedStore private[state] (map: Map[ast.AbstractLocalVar, (Term, } def +(entry: (ast.AbstractLocalVar, (Term, Option[ast.Exp]))) = new MapBackedStore(map + entry) def +(other: Store) = new MapBackedStore(map ++ other.values) + + def addEquality(t1: Term, t2: Term) = { + val newMap = map.map { case (k, (v, ve)) => (k, (v.replace(t1, t2), ve)) } + new MapBackedStore(newMap) + } } diff --git a/src/main/scala/state/Terms.scala b/src/main/scala/state/Terms.scala index 0e60b98b9..513cbbfcd 100644 --- a/src/main/scala/state/Terms.scala +++ b/src/main/scala/state/Terms.scala @@ -264,6 +264,8 @@ class Var private[terms] (val id: Identifier, val sort: Sort, val isWildcard: Bo override lazy val toString = id.toString def copy(id: Identifier = id, sort: Sort = sort, isWildcard: Boolean = isWildcard) = Var(id, sort, isWildcard) + + override val isDefinitelyNonTriggering: Boolean = true } object Var extends CondFlyweightFactory[(Identifier, Sort, Boolean), Var, Var] { @@ -412,6 +414,8 @@ sealed trait Term extends Node { case other => Vector(other) } } + + val isDefinitelyNonTriggering: Boolean = false } trait UnaryOp[E] { @@ -526,10 +530,12 @@ trait ConditionalFlyweight[T, V] { self: AnyRef => trait ConditionalFlyweightBinaryOp[T] extends ConditionalFlyweight[(Term, Term), T] with BinaryOp[Term] with Term { override val equalityDefiningMembers = (p0, p1) + override val isDefinitelyNonTriggering: Boolean = p0.isDefinitelyNonTriggering && p1.isDefinitelyNonTriggering } trait ConditionalFlyweightUnaryOp[T] extends ConditionalFlyweight[Term, T] with UnaryOp[Term] with Term { override val equalityDefiningMembers = p + override val isDefinitelyNonTriggering: Boolean = p.isDefinitelyNonTriggering } /** @@ -590,7 +596,9 @@ trait GeneralCondFlyweightFactory[IF, T <: IF, U, V <: U with ConditionalFlyweig /* Literals */ -sealed trait Literal extends Term +sealed trait Literal extends Term { + override val isDefinitelyNonTriggering: Boolean = true +} case object Unit extends SnapshotTerm with Literal { override lazy val toString = "_" @@ -1068,9 +1076,33 @@ object Ite extends CondFlyweightTermFactory[(Term, Term, Term), Ite] { case (False, _, e2) => e2 case (e0, True, False) => e0 case (e0, False, True) => Not(e0) + case (c, e1, e2) => + val eqs = getEqualities(c) + if (eqs.nonEmpty) { + val eqMap: scala.collection.immutable.Map[Term, Term] = eqs.map(eq => eq.p0 -> eq.p1).toMap + val eqFalseMap: scala.collection.immutable.Map[Term, Term] = eqs.flatMap(eq => { + Seq(eq -> False, eq.flip() -> False) + }).toMap + createIfNonExistent(c, replace(e1, eqMap), replace(e2, eqFalseMap)) + } else { + createIfNonExistent(v0) + } case _ => createIfNonExistent(v0) } + def replace(t: Term, rps: scala.collection.immutable.Map[Term, Term]): Term = { + assert(rps.nonEmpty) + t.transform { + case trm if rps.contains(trm) => rps(trm) + }() + } + + def getEqualities(t: Term): Seq[Equals] = t match { + case eq@Equals(_, _) => Seq(eq) + case And(ts) => ts.flatMap(getEqualities) + case _ => Seq() + } + override def actualCreate(args: (Term, Term, Term)): Ite = new Ite(args._1, args._2, args._3) } @@ -1078,9 +1110,12 @@ object Ite extends CondFlyweightTermFactory[(Term, Term, Term), Ite] { sealed trait ComparisonTerm extends BooleanTerm -sealed trait Equals extends ComparisonTerm with BinaryOp[Term] { override val op = "==" } +sealed trait Equals extends ComparisonTerm with BinaryOp[Term] { + override val op = "==" + def flip(): Equals +} -object Equals extends ((Term, Term) => BooleanTerm) { +object Equals extends ((Term, Term) => Term) { def apply(e0: Term, e1: Term) = { assert(e0.sort == e1.sort, s"Expected both operands to be of the same sort, but found ${e0.sort} ($e0) and ${e1.sort} ($e1).") @@ -1122,11 +1157,15 @@ object Equals extends ((Term, Term) => BooleanTerm) { } /* Represents built-in equality, e.g., '=' in SMT-LIB */ -class BuiltinEquals private[terms] (val p0: Term, val p1: Term) extends ConditionalFlyweightBinaryOp[BuiltinEquals] with Equals +class BuiltinEquals private[terms] (val p0: Term, val p1: Term) extends ConditionalFlyweightBinaryOp[BuiltinEquals] with Equals { + override def flip() = BuiltinEquals.createIfNonExistent(p1, p0) +} -object BuiltinEquals extends CondFlyweightFactory[(Term, Term), BooleanTerm, BuiltinEquals] { +object BuiltinEquals extends CondFlyweightTermFactory[(Term, Term), BuiltinEquals] { override def apply(v0: (Term, Term)) = v0 match { - case (v0: Var, v1: Var) if v0 == v1 => True + case (p0, p1) if p0 == p1 && p0.isDefinitelyNonTriggering && p1.isDefinitelyNonTriggering => True + case (p0, Ite(c, t1, t2)) => Ite(c, BuiltinEquals(p0, t1), BuiltinEquals(p0, t2)) + case (Ite(c, t1, t2), p0) => Ite(c, BuiltinEquals(t1, p0), BuiltinEquals(t2, p0)) case (p0: PermLiteral, p1: PermLiteral) => // NOTE: The else-case (False) is only justified because permission literals are stored in a normal form // such that two literals are semantically equivalent iff they are syntactically equivalent. @@ -1139,7 +1178,7 @@ object BuiltinEquals extends CondFlyweightFactory[(Term, Term), BooleanTerm, Bui /* Custom equality that (potentially) needs to be axiomatised. */ class CustomEquals private[terms] (val p0: Term, val p1: Term) extends ConditionalFlyweightBinaryOp[CustomEquals] with Equals { - + override def flip() = CustomEquals.createIfNonExistent(p1, p0) override val op = "===" } @@ -1403,6 +1442,11 @@ object PermPlus extends CondFlyweightTermFactory[(Term, Term), PermPlus] { case (FractionPerm(n1, d1), FractionPerm(n2, d2)) if d1 == d2 => FractionPerm(Plus(n1, n2), d1) case (PermMinus(t00, t01), t1) if t01 == t1 => t00 case (t0, PermMinus(t10, t11)) if t11 == t0 => t10 + case (Ite(c, t1, t2), t3) => Ite(c, PermPlus(t1, t3), PermPlus(t2, t3)) + case (t1, Ite(c, t2, t3)) => Ite(c, PermPlus(t1, t2), PermPlus(t1, t3)) + case (PermMin(t0, t1), t2) => PermMin(PermPlus(t0, t2), PermPlus(t1, t2)) + case (PermMax(t0, t1), t2) => PermMax(PermPlus(t0, t2), PermPlus(t1, t2)) + case (t0, PermMax(t1, t2)) => PermMax(PermPlus(t0, t1), PermPlus(t0, t2)) case (_, _) => createIfNonExistent(v0) } @@ -1428,9 +1472,21 @@ object PermMinus extends CondFlyweightTermFactory[(Term, Term), PermMinus] { case (t0, NoPerm) => t0 case (p0, p1) if p0 == p1 => NoPerm case (p0: PermLiteral, p1: PermLiteral) => FractionPermLiteral(p0.literal - p1.literal) - case (p0, PermMinus(p1, p2)) if p0 == p1 => p2 + case (p0, PermMinus(p1, p2)) => + if (p0 == p1) { + p2 + } else { + PermPlus(PermMinus(p0, p1), p2) + } + case (PermMinus(t0, t1), t2) => PermMinus(t0, PermPlus(t1, t2)) case (PermPlus(p0, p1), p2) if p0 == p2 => p1 case (PermPlus(p0, p1), p2) if p1 == p2 => p0 + case (Ite(c, t1, t2), t3) => Ite(c, PermMinus(t1, t3), PermMinus(t2, t3)) + case (t1, Ite(c, t2, t3)) => Ite(c, PermMinus(t1, t2), PermMinus(t1, t3)) + case (PermMin(p0, p1), p2) => PermMin(PermMinus(p0, p2), PermMinus(p1, p2)) + case (t0, PermMin(t1, t2)) => PermMax(PermMinus(t0, t1), PermMinus(t0, t2)) + case (PermMax(p0, p1), p2) => PermMax(PermMinus(p0, p2), PermMinus(p1, p2)) + case (t0, PermMax(t1, t2)) => PermMin(PermMinus(t0, t1), PermMinus(t0, t2)) case (_, _) => createIfNonExistent(v0) } @@ -1454,10 +1510,9 @@ object PermLess extends CondFlyweightTermFactory[(Term, Term), PermLess] { case (p0: PermLiteral, p1: PermLiteral) => if (p0.literal < p1.literal) True else False case (t0, Ite(tCond, tIf, tElse)) => - /* The pattern p0 < b ? p1 : p2 arises very often in the context of quantified permissions. - * Pushing the comparisons into the ite allows further simplifications. - */ Ite(tCond, PermLess(t0, tIf), PermLess(t0, tElse)) + case (Ite(tCond, tIf, tElse), t0) => + Ite(tCond, PermLess(tIf, t0), PermLess(tElse, t0)) case _ => createIfNonExistent(v0) } @@ -1476,6 +1531,10 @@ object PermAtMost extends CondFlyweightTermFactory[(Term, Term), PermAtMost] { override def apply(v0: (Term, Term)) = v0 match { case (p0: PermLiteral, p1: PermLiteral) => if (p0.literal <= p1.literal) True else False case (t0, t1) if t0 == t1 => True + case (t0, Ite(tCond, tIf, tElse)) => + Ite(tCond, PermAtMost(t0, tIf), PermAtMost(t0, tElse)) + case (Ite(tCond, tIf, tElse), t0) => + Ite(tCond, PermAtMost(tIf, t0), PermAtMost(tElse, t0)) case _ => createIfNonExistent(v0) } @@ -1496,12 +1555,40 @@ object PermMin extends CondFlyweightTermFactory[(Term, Term), PermMin] { override def apply(v0: (Term, Term)) = v0 match { case (t0, t1) if t0 == t1 => t0 case (p0: PermLiteral, p1: PermLiteral) => if (p0.literal > p1.literal) p1 else p0 + case (t0, Ite(tCond, tIf, tElse)) => + Ite(tCond, PermMin(t0, tIf), PermMin(t0, tElse)) + case (Ite(tCond, tIf, tElse), t0) => + Ite(tCond, PermMin(tIf, t0), PermMin(tElse, t0)) case _ => createIfNonExistent(v0) } override def actualCreate(args: (Term, Term)): PermMin = new PermMin(args._1, args._2) } +class PermMax private[terms] (val p0: Term, val p1: Term) extends Permissions + with BinaryOp[Term] + with ConditionalFlyweightBinaryOp[PermMax] { + + utils.assertSort(p0, "Permission 1st", sorts.Perm) + utils.assertSort(p1, "Permission 2nd", sorts.Perm) + + override lazy val toString = s"max ($p0, $p1)" +} + +object PermMax extends CondFlyweightTermFactory[(Term, Term), PermMax] { + override def apply(v0: (Term, Term)) = v0 match { + case (t0, t1) if t0 == t1 => t0 + case (p0: PermLiteral, p1: PermLiteral) => if (p0.literal < p1.literal) p1 else p0 + case (t0, Ite(tCond, tIf, tElse)) => + Ite(tCond, PermMax(t0, tIf), PermMax(t0, tElse)) + case (Ite(tCond, tIf, tElse), t0) => + Ite(tCond, PermMax(tIf, t0), PermMax(tElse, t0)) + case _ => createIfNonExistent(v0) + } + + override def actualCreate(args: (Term, Term)): PermMax = new PermMax(args._1, args._2) +} + /* Sequences */ sealed trait SeqTerm extends Term { @@ -2120,6 +2207,8 @@ class Combine(val p0: Term, val p1: Term) extends SnapshotTerm utils.assertSort(p1, "second operand", sorts.Snap) override lazy val toString = s"($p0, $p1)" + + override val isDefinitelyNonTriggering: Boolean = p0.isDefinitelyNonTriggering && p1.isDefinitelyNonTriggering } object Combine extends CondFlyweightTermFactory[(Term, Term), Combine] { @@ -2133,6 +2222,8 @@ class First(val p: Term) extends SnapshotTerm /*with PossibleTrigger*/ { utils.assertSort(p, "term", sorts.Snap) + + override val isDefinitelyNonTriggering: Boolean = p.isDefinitelyNonTriggering } object First extends CondFlyweightTermFactory[Term, First] { @@ -2149,6 +2240,8 @@ class Second(val p: Term) extends SnapshotTerm /*with PossibleTrigger*/ { utils.assertSort(p, "term", sorts.Snap) + + override val isDefinitelyNonTriggering: Boolean = p.isDefinitelyNonTriggering } object Second extends CondFlyweightTermFactory[Term, Second] { diff --git a/src/main/scala/state/Utils.scala b/src/main/scala/state/Utils.scala index 0c9368f2b..ca0cb5713 100644 --- a/src/main/scala/state/Utils.scala +++ b/src/main/scala/state/Utils.scala @@ -203,6 +203,7 @@ package object utils { case PermLess(p0, p1) => PermLess(go(p0), go(p1)) case PermAtMost(p0, p1) => PermAtMost(go(p0), go(p1)) case PermMin(p0, p1) => PermMin(go(p0), go(p1)) + case PermMax(p0, p1) => PermMax(go(p0), go(p1)) case App(f, ts) => App(f, ts map go) case SeqRanged(t0, t1) => SeqRanged(go(t0), go(t1)) case SeqSingleton(t) => SeqSingleton(go(t))