Skip to content

Commit

Permalink
Merge pull request #713 from viperproject/meilers_chained_comp
Browse files Browse the repository at this point in the history
Support for chained comparisons
  • Loading branch information
marcoeilers authored Jul 13, 2023
2 parents 5babb18 + eb6bb5e commit d769c1e
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 47 deletions.
75 changes: 38 additions & 37 deletions src/main/scala/viper/silver/parser/FastParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@ package viper.silver.parser

import java.net.URL
import java.nio.file.{Files, Path, Paths}

import viper.silver.ast.{FilePosition, LabelledOld, LineCol, NoPosition, Position, SourcePosition}
import viper.silver.parser.FastParserCompanion.{LW, LeadingWhitespace}
import viper.silver.plugin.{ParserPluginTemplate, SilverPluginManager}
import viper.silver.verifier.ParseError
import viper.silver.verifier.{ParseError, ParseWarning}

import scala.collection.{immutable, mutable}

case class ParseException(msg: String, pos: (Position, Position)) extends Exception

case class SuffixedExpressionGenerator[E <: PExp](func: PExp => E) extends (PExp => PExp) {
case class SuffixedExpressionGenerator[+E <: PExp](func: PExp => E) extends (PExp => PExp) {
override def apply(v1: PExp): E = func(v1)
}

Expand Down Expand Up @@ -125,6 +124,7 @@ class FastParser {
var _line_offset: Array[Int] = null
/** The file we are currently parsing (for creating positions later). */
var _file: Path = null
var _warnings: Seq[ParseWarning] = Seq()

def parse(s: String, f: Path, plugins: Option[SilverPluginManager] = None) = {
// Strategy to handle imports
Expand Down Expand Up @@ -390,9 +390,8 @@ class FastParser {

def quoted[$: P, T](p: => P[T]) = "\"" ~ p ~ "\""

def foldPExp[E <: PExp](e: PExp, es: Seq[SuffixedExpressionGenerator[E]]): E =
es.foldLeft(e) { (t, a) => a(t)
}.asInstanceOf[E]
def foldPExp[E <: PExp](e: E, es: Seq[SuffixedExpressionGenerator[E]]): E =
es.foldLeft(e) { (t, a) => a(t) }

def isFieldAccess(obj: Any) = {
obj.isInstanceOf[PFieldAccess]
Expand Down Expand Up @@ -481,49 +480,47 @@ class FastParser {
def exp[$: P]: P[PExp] = P(iteExpr)

def suffix[$: P]: P[SuffixedExpressionGenerator[PExp]] =
P(FP("." ~ idnuse).map { case (pos, id) => SuffixedExpressionGenerator[PExp]((e: PExp) => PFieldAccess(e, id)(pos)) } |
FP("[" ~ Pass ~ ".." ~/ exp ~ "]").map { case (pos, n) => SuffixedExpressionGenerator[PExp]((e: PExp) => PSeqTake(e, n)(pos)) } |
FP("[" ~ exp ~ ".." ~ Pass ~ "]").map { case (pos, n) => SuffixedExpressionGenerator[PExp]((e: PExp) => PSeqDrop(e, n)(pos)) } |
FP("[" ~ exp ~ ".." ~ exp ~ "]").map { case (pos, (n, m)) => SuffixedExpressionGenerator[PExp]((e: PExp) => PSeqDrop(PSeqTake(e, m)(pos), n)(pos)) } |
FP("[" ~ exp ~ "]").map { case (pos, e1) => SuffixedExpressionGenerator[PExp]((e0: PExp) => PLookup(e0, e1)(pos)) } |
FP("[" ~ exp ~ ":=" ~ exp ~ "]").map { case (pos, (i, v)) => SuffixedExpressionGenerator[PExp]((e: PExp) => PUpdate(e, i, v)(pos)) })

/*
Maps:
def suffix[$: P]: P[SuffixedExpressionGenerator[PExp]] =
P(FP("." ~ idnuse).map { case (pos, id) => SuffixedExpressionGenerator[PExp]((e: PExp) => {
PFieldAccess(e, id)(pos)
}) } |
FP("[" ~ Pass ~ ".." ~/ exp ~ "]").map { case (pos, n) => SuffixedExpressionGenerator[PExp]((e: PExp) => PSeqTake(e, n)(pos)) } |
FP("[" ~ exp ~ ".." ~ Pass ~ "]").map { case (pos, n) => SuffixedExpressionGenerator[PExp]((e: PExp) => PSeqDrop(e, n)(pos)) } |
FP("[" ~ exp ~ ".." ~ exp ~ "]").map { case (pos, (n, m)) => SuffixedExpressionGenerator[PExp]((e: PExp) => PSeqDrop(PSeqTake(e, m)(), n)(pos)) } |
FP("[" ~ exp ~ "]").map { case (pos, e1) => SuffixedExpressionGenerator[PExp]((e0: PExp) => PSeqIndex(e0, e1)(pos)) } |
FP("[" ~ exp ~ ":=" ~ exp ~ "]").map { case (pos, (i, v)) => SuffixedExpressionGenerator[PExp]((e: PExp) => PSeqUpdate(e, i, v)(pos)) })
*/
P(FP("." ~ idnuse).map { case (pos, id) => SuffixedExpressionGenerator[PExp](e => PFieldAccess(e, id)(e.pos._1, pos._2)) } |
FP("[" ~ Pass ~ ".." ~/ exp ~ "]").map { case (pos, n) => SuffixedExpressionGenerator[PExp](e => PSeqTake(e, n)(e.pos._1, pos._2)) } |
FP("[" ~ exp ~ ".." ~ Pass ~ "]").map { case (pos, n) => SuffixedExpressionGenerator[PExp](e => PSeqDrop(e, n)(e.pos._1, pos._2)) } |
FP("[" ~ exp ~ ".." ~ exp ~ "]").map { case (pos, (n, m)) => SuffixedExpressionGenerator[PExp](e => PSeqDrop(PSeqTake(e, m)(e.pos._1, pos._2), n)(e.pos._1, pos._2)) } |
FP("[" ~ exp ~ "]").map { case (pos, e1) => SuffixedExpressionGenerator[PExp](e0 => PLookup(e0, e1)(e0.pos._1, pos._2)) } |
FP("[" ~ exp ~ ":=" ~ exp ~ "]").map { case (pos, (i, v)) => SuffixedExpressionGenerator[PExp](e => PUpdate(e, i, v)(e.pos._1, pos._2)) })

def suffixExpr[$: P]: P[PExp] = P((atom ~~~ suffix.lw.rep).map { case (fac, ss) => foldPExp[PExp](fac, ss) })
def suffixExpr[$: P]: P[PExp] = P((atom ~~~ suffix.lw.rep).map { case (fac, ss) => foldPExp(fac, ss) })

def termOp[$: P]: P[String] = P(StringIn("*", "/", "\\", "%").!)

def term[$: P]: P[PExp] = P((suffixExpr ~~~ termd.lw.rep).map { case (a, ss) => foldPExp[PExp](a, ss) })
def term[$: P]: P[PExp] = P((suffixExpr ~~~ termd.lw.rep).map { case (a, ss) => foldPExp(a, ss) })

def termd[$: P]: P[SuffixedExpressionGenerator[PExp]] = FP(termOp ~ suffixExpr).map { case (pos, (op, id)) => SuffixedExpressionGenerator[PExp]((e: PExp) => PBinExp(e, op, id)(pos)) }
def termd[$: P]: P[SuffixedExpressionGenerator[PBinExp]] = FP(termOp ~ suffixExpr).map { case (pos, (op, id)) => SuffixedExpressionGenerator(e => PBinExp(e, op, id)(e.pos._1, pos._2)) }

def sumOp[$: P]: P[String] = P(StringIn("++", "+", "-").! | keyword("union").! | keyword("intersection").! | keyword("setminus").! | keyword("subset").!)

def sum[$: P]: P[PExp] = P((term ~~~ sumd.lw.rep).map { case (a, ss) => foldPExp[PBinExp](a, ss) })
def sum[$: P]: P[PExp] = P((term ~~~ sumd.lw.rep).map { case (a, ss) => foldPExp(a, ss) })

def sumd[$: P]: P[SuffixedExpressionGenerator[PBinExp]] = FP(sumOp ~ term).map { case (pos, (op, id)) => SuffixedExpressionGenerator[PBinExp]((e: PExp) => PBinExp(e, op, id)(pos)) }
def sumd[$: P]: P[SuffixedExpressionGenerator[PBinExp]] = FP(sumOp ~ term).map { case (pos, (op, id)) => SuffixedExpressionGenerator(e => PBinExp(e, op, id)(e.pos._1, pos._2)) }

def cmpOp[$: P] = P(StringIn("<=", ">=", "<", ">").! | keyword("in").!)

def cmpExp[$: P]: P[PExp] = FP(sum ~~~ (cmpOp ~ cmpExp).lw.?).map {
case (pos, (a, b)) => b match {
case Some(c) => PBinExp(a, c._1, c._2)(pos)
case None => a
}
val cmpOps = Set("<=", ">=", "<", ">", "in")

def cmpd[$: P]: P[PExp => SuffixedExpressionGenerator[PBinExp]] = FP(cmpOp ~ sum).map {
case (pos, (op, right)) => chainComp(op, right, pos)
}

def chainComp(op: String, right: PExp, pos: (FilePosition, FilePosition))(from: PExp) = SuffixedExpressionGenerator(_ match {
case left@PBinExp(_, op0, middle) if cmpOps.contains(op0) && left != from =>
PBinExp(left, "&&", PBinExp(middle, op, right)(middle.pos._1, pos._2))(left.pos._1, pos._2)
case left@PBinExp(_, "&&", PBinExp(_, op0, middle)) if cmpOps.contains(op0) && left != from =>
PBinExp(left, "&&", PBinExp(middle, op, right)(middle.pos._1, pos._2))(left.pos._1, pos._2)
case left => PBinExp(left, op, right)(left.pos._1, pos._2)
})

def cmpExp[$: P]: P[PExp] = P((sum ~~~ cmpd.lw.rep).map {
case (from, others) => foldPExp(from, others.map(_(from)))
})

def eqOp[$: P] = P(StringIn("==", "!=").!)

def eqExp[$: P]: P[PExp] = FP(cmpExp ~~~ (eqOp ~ eqExp).lw.?).map {
Expand Down Expand Up @@ -631,11 +628,15 @@ class FastParser {
case (pos, (keyType, valueType)) => PMapType(keyType, valueType)(pos)
}


/** Only for call-like macros, `idnuse`-like ones are parsed by `domainTyp`. */
def macroType[$: P] : P[PMacroType] = funcApp.map(PMacroType(_))

def primitiveTyp[$: P]: P[PPrimitiv] = P(FP(keyword("Rational")).map{ case (pos, _) => PPrimitiv("Perm")(pos)}
| FP((StringIn("Int", "Bool", "Perm", "Ref") ~~ !identContinues).!).map{ case (pos, name) => PPrimitiv(name)(pos)})
def primitiveTyp[$: P]: P[PPrimitiv] = P(FP(keyword("Rational")).map {
case (pos, _) =>
_warnings = _warnings :+ ParseWarning("Rational is deprecated, use Perm instead", SourcePosition(_file, pos._1.line, pos._1.column))
PPrimitiv("Perm")(pos)
} | FP((StringIn("Int", "Bool", "Perm", "Ref") ~~ !identContinues).!).map { case (pos, name) => PPrimitiv(name)(pos) })
/* Maps:
lazy val primitiveTyp: P[PType] = P(keyword("Rational").map(_ => PPrimitiv("Perm"))
| (StringIn("Int", "Bool", "Perm", "Ref") ~~ !identContinues).!.map(PPrimitiv))
Expand Down
55 changes: 55 additions & 0 deletions src/test/resources/all/basic/comparisons.vpr
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Any copyright is dedicated to the Public Domain.
// http://creativecommons.org/publicdomain/zero/1.0/

method chain(i1: Int, i2: Int, i3: Int, i4: Int)
requires i1 < i2 <= i3 > i4
{
assert i1 < i2
assert i2 <= i3
assert i3 > i4
//:: ExpectedOutput(assert.failed:assertion.false)
assert i2 < i3
}

method chain1In(s1: Seq[Int], s2: Seq[Seq[Int]], s3: Seq[Seq[Seq[Int]]], s4: Set[Seq[Seq[Seq[Int]]]])
requires s1 in s2 in s3 in s4
{
assert s1 in s2
assert s3 in s4
assert s2 in s3
//:: ExpectedOutput(assert.failed:assertion.false)
assert 3 in s1
}


method chainEq(i1: Int, i2: Int, i3: Int, i4: Int)
requires i1 < i2 == i3 > i4
{
//:: ExpectedOutput(assert.failed:assertion.false)
assert i1 < i2
}

method nonChain(i1: Int, i2: Int, i3: Int, i4: Int)
requires i1 < i2 && i3 > i4
{
assert i1 < i2
assert i3 > i4
//:: ExpectedOutput(assert.failed:assertion.false)
assert i2 <= i3
}

method chainParen(i1: Int, i2: Multiset[Int], i3: Int, i4: Int)
requires (i1 in i2) <= i3 > i4
{
assert i3 > i4
assume (i1 in i2) == 3
assert 3 <= i3
}

method chainParen2(i1: Int, i2: Int, i3: Int, i4: Multiset[Int])
requires i1 < i2 <= (i3 in i4)
{
assert i1 < i2
assume (i3 in i4) == 3
assert i2 <= 3
}
26 changes: 26 additions & 0 deletions src/test/resources/all/basic/comparisons_fail.vpr
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Any copyright is dedicated to the Public Domain.
// http://creativecommons.org/publicdomain/zero/1.0/

method chain(i1: Int, i2: Ref, i3: Int, i4: Int)
//:: ExpectedOutput(typechecker.error)
requires i1 < i2 <= i3 > i4
{
}

method chain1In(s1: Seq[Int], s2: Seq[Seq[Int]], s3: Seq[Seq[Seq[Ref]]], s4: Set[Seq[Seq[Seq[Int]]]])
//:: ExpectedOutput(typechecker.error)
requires s1 in s2 in s3 in s4
{
}

method chainParen(i1: Int, i2: Int, i3: Int, i4: Int)
//:: ExpectedOutput(typechecker.error)
requires (i1 < i2) <= i3 > i4
{
}

method chainParen2(i1: Int, i2: Int, i3: Int, i4: Int)
//:: ExpectedOutput(typechecker.error)
requires i1 < i2 <= (i3 > i4)
{
}
46 changes: 36 additions & 10 deletions src/test/scala/AstPositionsTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class AstPositionsTests extends AnyFunSuite {
|method sum(x: Ref, g: Set[Ref])
| returns (res: Bool)
| ensures false && 4/x.foo > 0
| ensures 3 < 4 < 5
|{
| inhale forall r: Ref :: r in g ==> acc(r.foo)
| assert acc(x.foo)
Expand Down Expand Up @@ -82,7 +83,7 @@ class AstPositionsTests extends AnyFunSuite {
val m: Method = res.methods(0)
m.pos match {
case spos: AbstractSourcePosition => {
assert(spos.start.line === 2 && spos.end.get.line === 10)
assert(spos.start.line === 2 && spos.end.get.line === 11)
assert(spos.start.column === 1 && spos.end.get.column === 2)
}
case _ =>
Expand All @@ -99,22 +100,47 @@ class AstPositionsTests extends AnyFunSuite {
case _ =>
fail("method args must have start and end positions set")
}
// Check position of method post
assert(m.posts.length === 1)
val post: Exp = m.posts(0).asInstanceOf[BinExp].args(1);
post.pos match {
// Check positions of method posts
assert(m.posts.length === 2)
val post1: Exp = m.posts(0).asInstanceOf[BinExp].args(1);
post1.pos match {
case spos: AbstractSourcePosition => {
assert(spos.start.line === 4 && spos.end.get.line === 4)
assert(spos.start.column === 20 && spos.end.get.column === 31)
}
case _ =>
fail("method posts must have start and end positions set")
}
val post2: BinExp = m.posts(1).asInstanceOf[BinExp]
post2.pos match {
case spos: AbstractSourcePosition => {
assert(spos.start.line === 5 && spos.end.get.line === 5)
assert(spos.start.column === 11 && spos.end.get.column === 20)
}
case _ =>
fail("method posts must have start and end positions set")
}
post2.left.pos match {
case spos: AbstractSourcePosition => {
assert(spos.start.line === 5 && spos.end.get.line === 5)
assert(spos.start.column === 11 && spos.end.get.column === 16)
}
case _ =>
fail("method posts must have start and end positions set")
}
post2.right.pos match {
case spos: AbstractSourcePosition => {
assert(spos.start.line === 5 && spos.end.get.line === 5)
assert(spos.start.column === 15 && spos.end.get.column === 20)
}
case _ =>
fail("method posts must have start and end positions set")
}
// Check position of body
val block = m.body.get
block.pos match {
case spos: AbstractSourcePosition => {
assert(spos.start.line === 5 && spos.end.get.line === 10)
assert(spos.start.line === 6 && spos.end.get.line === 11)
assert(spos.start.column === 1 && spos.end.get.column === 2)
}
case _ =>
Expand All @@ -125,7 +151,7 @@ class AstPositionsTests extends AnyFunSuite {
val forall: Stmt = block.ss(0)
forall.pos match {
case spos: AbstractSourcePosition => {
assert(spos.start.line === 6 && spos.end.get.line === 6)
assert(spos.start.line === 7 && spos.end.get.line === 7)
assert(spos.start.column === 3 && spos.end.get.column === 48)
}
case _ =>
Expand All @@ -135,7 +161,7 @@ class AstPositionsTests extends AnyFunSuite {
val assert_exp: Exp = block.ss(1).asInstanceOf[Assert].exp
assert_exp.pos match {
case spos: AbstractSourcePosition => {
assert(spos.start.line === 7 && spos.end.get.line === 7)
assert(spos.start.line === 8 && spos.end.get.line === 8)
assert(spos.start.column === 10 && spos.end.get.column === 20)
}
case _ =>
Expand All @@ -146,7 +172,7 @@ class AstPositionsTests extends AnyFunSuite {
val m2: Method = res.methods(1);
m2.pos match {
case spos: AbstractSourcePosition => {
assert(spos.start.line === 11 && spos.end.get.line === 12)
assert(spos.start.line === 12 && spos.end.get.line === 13)
assert(spos.start.column === 1 && spos.end.get.column === 26)
}
case _ =>
Expand All @@ -156,7 +182,7 @@ class AstPositionsTests extends AnyFunSuite {
val pre: Exp = m2.pres(0);
pre.pos match {
case spos: AbstractSourcePosition => {
assert(spos.start.line === 12 && spos.end.get.line === 12)
assert(spos.start.line === 13 && spos.end.get.line === 13)
assert(spos.start.column === 13 && spos.end.get.column === 26)
}
case _ =>
Expand Down

0 comments on commit d769c1e

Please sign in to comment.