Skip to content

Commit

Permalink
Caching simplification results
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoeilers committed Aug 28, 2024
1 parent 6326764 commit e0db674
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 62 deletions.
3 changes: 2 additions & 1 deletion src/main/scala/viper/silver/ast/Expression.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import viper.silver.verifier.{ConsistencyError, VerificationResult}

/** Expressions. */
sealed trait Exp extends Hashable with Typed with Positioned with Infoed with TransformableErrors with PrettyExpression {
var simplified: Option[Exp] = None
lazy val isPure = Expressions.isPure(this)
def isHeapDependent(p: Program) = Expressions.isHeapDependent(this, p)
def isTopLevelHeapDependent(p: Program) = Expressions.isTopLevelHeapDependent(this, p)
Expand Down Expand Up @@ -47,7 +48,7 @@ sealed trait Exp extends Hashable with Typed with Positioned with Infoed with Tr
*/
// lazy val proofObligations = Expressions.proofObligations(this)

override def toString() = {
override lazy val toString = {
// Carbon relies on expression pretty-printing resulting in a string without line breaks,
// so for the special case of directly converting an expression to a string, we remove all line breaks
// the pretty printer might have inserted.
Expand Down
140 changes: 81 additions & 59 deletions src/main/scala/viper/silver/ast/utility/Simplifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,142 +23,144 @@ object Simplifier {
* might be transformed to terminating expression.
*/
def simplify[N <: Node](n: N, assumeWelldefinedness: Boolean = false): N = {
/* Always simplify children first, then treat parent. */
StrategyBuilder.Slim[Node]({

val simplifySingle: PartialFunction[Node, Node] = {
// expression simplifications
case root @ Not(BoolLit(literal)) =>
case root: Exp if root.simplified.isDefined =>
root.simplified.get
case root@Not(BoolLit(literal)) =>
BoolLit(!literal)(root.pos, root.info)
case Not(Not(single)) => single
case root @ Not(EqCmp(a, b)) => NeCmp(a, b)(root.pos, root.info)
case root @ Not(NeCmp(a, b)) => EqCmp(a, b)(root.pos, root.info)
case root @ Not(GtCmp(a, b)) => LeCmp(a, b)(root.pos, root.info)
case root @ Not(GeCmp(a, b)) => LtCmp(a, b)(root.pos, root.info)
case root @ Not(LtCmp(a, b)) => GeCmp(a, b)(root.pos, root.info)
case root @ Not(LeCmp(a, b)) => GtCmp(a, b)(root.pos, root.info)
case root@Not(EqCmp(a, b)) => NeCmp(a, b)(root.pos, root.info)
case root@Not(NeCmp(a, b)) => EqCmp(a, b)(root.pos, root.info)
case root@Not(GtCmp(a, b)) => LeCmp(a, b)(root.pos, root.info)
case root@Not(GeCmp(a, b)) => LtCmp(a, b)(root.pos, root.info)
case root@Not(LtCmp(a, b)) => GeCmp(a, b)(root.pos, root.info)
case root@Not(LeCmp(a, b)) => GtCmp(a, b)(root.pos, root.info)

case And(TrueLit(), right) => right
case And(left, TrueLit()) => left
case root @ And(FalseLit(), _) => FalseLit()(root.pos, root.info)
case root @ And(_, FalseLit()) => FalseLit()(root.pos, root.info)
case root@And(FalseLit(), _) => FalseLit()(root.pos, root.info)
case root@And(_, FalseLit()) => FalseLit()(root.pos, root.info)

case Or(FalseLit(), right) => right
case Or(left, FalseLit()) => left
case root @ Or(TrueLit(), _) => TrueLit()(root.pos, root.info)
case root @ Or(_, TrueLit()) => TrueLit()(root.pos, root.info)
case root@Or(TrueLit(), _) => TrueLit()(root.pos, root.info)
case root@Or(_, TrueLit()) => TrueLit()(root.pos, root.info)

case root @ Implies(FalseLit(), _) => TrueLit()(root.pos, root.info)
case Implies(_, tl @ TrueLit()) if assumeWelldefinedness => tl
case root@Implies(FalseLit(), _) => TrueLit()(root.pos, root.info)
case Implies(_, tl@TrueLit()) if assumeWelldefinedness => tl
case Implies(TrueLit(), consequent) => consequent
case root @ Implies(FalseLit(), _) => TrueLit()(root.pos, root.info)
case root @ Implies(l1, Implies(l2, r)) => Implies(And(l1, l2)(), r)(root.pos, root.info)
case root@Implies(FalseLit(), _) => TrueLit()(root.pos, root.info)
case root@Implies(l1, Implies(l2, r)) => Implies(And(l1, l2)(), r)(root.pos, root.info)

// TODO: Consider checking if Expressions.proofObligations(left) is empty (requires adding the program as parameter).
case root @ EqCmp(left, right) if assumeWelldefinedness && left == right => TrueLit()(root.pos, root.info)
case root @ EqCmp(BoolLit(left), BoolLit(right)) =>
case root@EqCmp(left, right) if assumeWelldefinedness && left == right => TrueLit()(root.pos, root.info)
case root@EqCmp(BoolLit(left), BoolLit(right)) =>
BoolLit(left == right)(root.pos, root.info)
case root @ EqCmp(FalseLit(), right) => Not(right)(root.pos, root.info)
case root @ EqCmp(left, FalseLit()) => Not(left)(root.pos, root.info)
case root@EqCmp(FalseLit(), right) => Not(right)(root.pos, root.info)
case root@EqCmp(left, FalseLit()) => Not(left)(root.pos, root.info)
case EqCmp(TrueLit(), right) => right
case EqCmp(left, TrueLit()) => left
case root @ EqCmp(IntLit(left), IntLit(right)) =>
case root@EqCmp(IntLit(left), IntLit(right)) =>
BoolLit(left == right)(root.pos, root.info)
case root @ EqCmp(NullLit(), NullLit()) => TrueLit()(root.pos, root.info)
case root@EqCmp(NullLit(), NullLit()) => TrueLit()(root.pos, root.info)

case root @ NeCmp(BoolLit(left), BoolLit(right)) =>
case root@NeCmp(BoolLit(left), BoolLit(right)) =>
BoolLit(left != right)(root.pos, root.info)
case NeCmp(FalseLit(), right) => right
case NeCmp(left, FalseLit()) => left
case root @ NeCmp(TrueLit(), right) => Not(right)(root.pos, root.info)
case root @ NeCmp(left, TrueLit()) => Not(left)(root.pos, root.info)
case root @ NeCmp(IntLit(left), IntLit(right)) =>
case root@NeCmp(TrueLit(), right) => Not(right)(root.pos, root.info)
case root@NeCmp(left, TrueLit()) => Not(left)(root.pos, root.info)
case root@NeCmp(IntLit(left), IntLit(right)) =>
BoolLit(left != right)(root.pos, root.info)
case root @ NeCmp(NullLit(), NullLit()) =>
case root@NeCmp(NullLit(), NullLit()) =>
FalseLit()(root.pos, root.info)
case root @ NeCmp(left, right) if assumeWelldefinedness && left == right => FalseLit()(root.pos, root.info)
case root@NeCmp(left, right) if assumeWelldefinedness && left == right => FalseLit()(root.pos, root.info)

case CondExp(TrueLit(), ifTrue, _) => ifTrue
case CondExp(FalseLit(), _, ifFalse) => ifFalse
case CondExp(_, ifTrue, ifFalse) if assumeWelldefinedness && ifTrue == ifFalse =>
ifTrue
case root @ CondExp(condition, FalseLit(), TrueLit()) =>
case root@CondExp(condition, FalseLit(), TrueLit()) =>
Not(condition)(root.pos, root.info)
case CondExp(condition, TrueLit(), FalseLit()) => condition
case root @ CondExp(condition, FalseLit(), ifFalse) =>
case root@CondExp(condition, FalseLit(), ifFalse) =>
And(Not(condition)(), ifFalse)(root.pos, root.info)
case root @ CondExp(condition, TrueLit(), ifFalse) =>
case root@CondExp(condition, TrueLit(), ifFalse) =>
if (ifFalse.isPure) {
Or(condition, ifFalse)(root.pos, root.info)
} else {
Implies(Not(condition)(), ifFalse)(root.pos, root.info)
}
case root @ CondExp(condition, ifTrue, FalseLit()) =>
case root@CondExp(condition, ifTrue, FalseLit()) =>
And(condition, ifTrue)(root.pos, root.info)
case root @ CondExp(condition, ifTrue, TrueLit()) =>
case root@CondExp(condition, ifTrue, TrueLit()) =>
Implies(condition, ifTrue)(root.pos, root.info)

case root @ Forall(_, _, BoolLit(literal)) =>
case root@Forall(_, _, BoolLit(literal)) =>
BoolLit(literal)(root.pos, root.info)
case root @ Exists(_, _, BoolLit(literal)) =>
case root@Exists(_, _, BoolLit(literal)) =>
BoolLit(literal)(root.pos, root.info)

case root @ Minus(IntLit(literal)) => IntLit(-literal)(root.pos, root.info)
case root@Minus(IntLit(literal)) => IntLit(-literal)(root.pos, root.info)
case Minus(Minus(single)) => single

case PermMinus(PermMinus(single)) => single
case PermMul(fst, FullPerm()) => fst
case PermMul(FullPerm(), snd) => snd
case PermMul(_, np @ NoPerm()) if assumeWelldefinedness => np
case PermMul(np @ NoPerm(), _) if assumeWelldefinedness => np
case PermMul(_, np@NoPerm()) if assumeWelldefinedness => np
case PermMul(np@NoPerm(), _) if assumeWelldefinedness => np

case root @ PermGeCmp(a, b) if assumeWelldefinedness && a == b => TrueLit()(root.pos, root.info)
case root @ PermLeCmp(a, b) if assumeWelldefinedness && a == b => TrueLit()(root.pos, root.info)
case root @ PermGtCmp(a, b) if assumeWelldefinedness && a == b => FalseLit()(root.pos, root.info)
case root @ PermLtCmp(a, b) if assumeWelldefinedness && a == b => FalseLit()(root.pos, root.info)
case root@PermGeCmp(a, b) if assumeWelldefinedness && a == b => TrueLit()(root.pos, root.info)
case root@PermLeCmp(a, b) if assumeWelldefinedness && a == b => TrueLit()(root.pos, root.info)
case root@PermGtCmp(a, b) if assumeWelldefinedness && a == b => FalseLit()(root.pos, root.info)
case root@PermLtCmp(a, b) if assumeWelldefinedness && a == b => FalseLit()(root.pos, root.info)

case root @ PermGtCmp(AnyPermLiteral(a, b), AnyPermLiteral(c, d)) =>
case root@PermGtCmp(AnyPermLiteral(a, b), AnyPermLiteral(c, d)) =>
BoolLit(Rational(a, b) > Rational(c, d))(root.pos, root.info)
case root @ PermGeCmp(AnyPermLiteral(a, b), AnyPermLiteral(c, d)) =>
case root@PermGeCmp(AnyPermLiteral(a, b), AnyPermLiteral(c, d)) =>
BoolLit(Rational(a, b) >= Rational(c, d))(root.pos, root.info)
case root @ PermLtCmp(AnyPermLiteral(a, b), AnyPermLiteral(c, d)) =>
case root@PermLtCmp(AnyPermLiteral(a, b), AnyPermLiteral(c, d)) =>
BoolLit(Rational(a, b) < Rational(c, d))(root.pos, root.info)
case root @ PermLeCmp(AnyPermLiteral(a, b), AnyPermLiteral(c, d)) =>
case root@PermLeCmp(AnyPermLiteral(a, b), AnyPermLiteral(c, d)) =>
BoolLit(Rational(a, b) <= Rational(c, d))(root.pos, root.info)
case DebugPermMin(e0@AnyPermLiteral(a, b), e1@AnyPermLiteral(c, d)) =>
if (Rational(a, b) < Rational(c, d)) {
e0
} else {
e1
}
case root @ PermSub(AnyPermLiteral(a, b), AnyPermLiteral(c, d)) =>
case root@PermSub(AnyPermLiteral(a, b), AnyPermLiteral(c, d)) =>
val diff = Rational(a, b) - Rational(c, d)
FractionalPerm(IntLit(diff.numerator)(root.pos, root.info), IntLit(diff.denominator)(root.pos, root.info))(root.pos, root.info)
case root @ PermAdd(AnyPermLiteral(a, b), AnyPermLiteral(c, d)) =>
case root@PermAdd(AnyPermLiteral(a, b), AnyPermLiteral(c, d)) =>
val sum = Rational(a, b) + Rational(c, d)
FractionalPerm(IntLit(sum.numerator)(root.pos, root.info), IntLit(sum.denominator)(root.pos, root.info))(root.pos, root.info)

case root @ GeCmp(IntLit(left), IntLit(right)) =>
case root@GeCmp(IntLit(left), IntLit(right)) =>
BoolLit(left >= right)(root.pos, root.info)
case root @ GtCmp(IntLit(left), IntLit(right)) =>
case root@GtCmp(IntLit(left), IntLit(right)) =>
BoolLit(left > right)(root.pos, root.info)
case root @ LeCmp(IntLit(left), IntLit(right)) =>
case root@LeCmp(IntLit(left), IntLit(right)) =>
BoolLit(left <= right)(root.pos, root.info)
case root @ LtCmp(IntLit(left), IntLit(right)) =>
case root@LtCmp(IntLit(left), IntLit(right)) =>
BoolLit(left < right)(root.pos, root.info)

case root @ Add(IntLit(left), IntLit(right)) =>
case root@Add(IntLit(left), IntLit(right)) =>
IntLit(left + right)(root.pos, root.info)
case root @ Sub(IntLit(left), IntLit(right)) =>
case root@Sub(IntLit(left), IntLit(right)) =>
IntLit(left - right)(root.pos, root.info)
case root @ Mul(IntLit(left), IntLit(right)) =>
case root@Mul(IntLit(left), IntLit(right)) =>
IntLit(left * right)(root.pos, root.info)
/* In the general case, Viper uses the SMT division and modulo. Scala's division is not in-sync with SMT division.
For nonnegative dividends and divisors, all used division and modulo definitions coincide. So, in order to not
not make any assumptions on the SMT division, division and modulo are simplified only if the dividend and divisor
are nonnegative. Also see Carbon PR #448.
*/
case root @ Div(IntLit(left), IntLit(right)) if left >= bigIntZero && right > bigIntZero =>
case root@Div(IntLit(left), IntLit(right)) if left >= bigIntZero && right > bigIntZero =>
IntLit(left / right)(root.pos, root.info)
case root @ Mod(IntLit(left), IntLit(right)) if left >= bigIntZero && right > bigIntZero =>
case root@Mod(IntLit(left), IntLit(right)) if left >= bigIntZero && right > bigIntZero =>
IntLit(left % right)(root.pos, root.info)

// statement simplifications
Expand All @@ -169,7 +171,27 @@ object Simplifier {
case If(_, EmptyStmt, EmptyStmt) => EmptyStmt // remove empty If clause
case If(TrueLit(), thn, _) => thn // remove trivial If conditions
case If(FalseLit(), _, els) => els // remove trivial If conditions
}, Traverse.BottomUp) execute n
}

val simplifyAndCache = new PartialFunction[Node, Node] {
def apply(n: Node): Node = {
val simplified = simplifySingle.applyOrElse(n, (nn: Node) => nn)
n match {
case e: Exp =>
e.simplified = Some(simplified.asInstanceOf[Exp])
simplified.asInstanceOf[Exp].simplified = Some(simplified.asInstanceOf[Exp])
case _ =>
}
simplified
}

def isDefinedAt(n: Node): Boolean = n.isInstanceOf[Exp] || simplifySingle.isDefinedAt(n)
}

/* Always simplify children first, then treat parent. */
StrategyBuilder.Slim[Node](simplifyAndCache, Traverse.BottomUp).recurseFunc({
case e: Exp if e.simplified.isDefined => Nil
}) execute n
}

private val bigIntZero = BigInt(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ trait HasExtraVars {
*/
trait Rewritable extends Product {

def children: Seq[Any] = productIterator.toList
lazy val children: Seq[Any] = productIterator.toSeq

def withChildren(children: Seq[Any], pos: Option[(Position, Position)] = None, forceRewrite: Boolean = false): this.type = {
if (!forceRewrite && this.children == children && !pos.isDefined)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ case class DecreasesSpecification(tuple: Option[DecreasesTuple],
star: Option[DecreasesStar]) extends Info {

// The comment of this metadata are the provided decreases clauses
override lazy val comment: Seq[String] = (tuple ++ wildcard ++ star).map(_.toString()).toSeq
override lazy val comment: Seq[String] = (tuple ++ wildcard ++ star).map(_.toString).toSeq
override val isCached: Boolean = false

/**
Expand Down

0 comments on commit e0db674

Please sign in to comment.