Skip to content

Commit

Permalink
Get rid of quantifiers in pointer cast proof helpers, improve naming
Browse files Browse the repository at this point in the history
  • Loading branch information
superaxander committed Sep 18, 2024
1 parent bede2fc commit df9ebc4
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 154 deletions.
9 changes: 9 additions & 0 deletions src/col/vct/col/ast/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1409,6 +1409,10 @@ final case class DerefPointer[G](pointer: Expr[G])(
val blame: Blame[PointerDerefError]
)(implicit val o: Origin)
extends Expr[G] with DerefPointerImpl[G]
final case class DerefPointerTyped[G](pointer: Expr[G], elementType: Type[G])(
val blame: Blame[PointerDerefError]
)(implicit val o: Origin)
extends Expr[G] with DerefPointerTypedImpl[G]
final case class RawDerefPointer[G](pointer: Expr[G])(
val blame: Blame[PointerDerefError]
)(implicit val o: Origin)
Expand Down Expand Up @@ -1766,6 +1770,11 @@ final case class PointerLocation[G](pointer: Expr[G])(
val blame: Blame[PointerLocationError]
)(implicit val o: Origin)
extends Location[G] with PointerLocationImpl[G]
final case class TypedPointerLocation[G](
pointer: Expr[G],
pointerType: Type[G],
)(val blame: Blame[PointerLocationError])(implicit val o: Origin)
extends Location[G] with TypedPointerLocationImpl[G]
final case class ByValueClassLocation[G](expr: Expr[G])(
val blame: Blame[PointerLocationError]
)(implicit val o: Origin)
Expand Down
13 changes: 13 additions & 0 deletions src/col/vct/col/ast/expr/heap/read/DerefPointerTypedImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package vct.col.ast.expr.heap.read

import vct.col.ast.ops.DerefPointerTypedOps
import vct.col.ast.{DerefPointerTyped, Type}
import vct.col.print._

trait DerefPointerTypedImpl[G] extends DerefPointerTypedOps[G] {
this: DerefPointerTyped[G] =>
override def t: Type[G] = elementType

override def precedence: Int = Precedence.PREFIX
override def layout(implicit ctx: Ctx): Doc = Text("*") <> assoc(pointer)
}
11 changes: 11 additions & 0 deletions src/col/vct/col/ast/family/location/TypedPointerLocationImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package vct.col.ast.family.location

import vct.col.ast.TypedPointerLocation
import vct.col.ast.ops.TypedPointerLocationOps
import vct.col.print.{Ctx, Doc, Text}

trait TypedPointerLocationImpl[G] extends TypedPointerLocationOps[G] {
this: TypedPointerLocation[G] =>
override def layout(implicit ctx: Ctx): Doc =
Text("(") <> pointerType <> ")" <> pointer
}
4 changes: 4 additions & 0 deletions src/col/vct/col/typerules/CoercingRewriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,8 @@ abstract class CoercingRewriter[Pre <: Generation]()
case deref @ Deref(obj, ref) => Deref(cls(obj), ref)(deref.blame)
case deref @ DerefHeapVariable(ref) => DerefHeapVariable(ref)(deref.blame)
case deref @ DerefPointer(p) => DerefPointer(pointer(p)._1)(deref.blame)
case deref @ DerefPointerTyped(p, t) =>
DerefPointerTyped(pointer(p)._1, t)(deref.blame)
case deref @ RawDerefPointer(p) =>
RawDerefPointer(pointer(p)._1)(deref.blame)
case Drop(xs, count) => Drop(seq(xs)._1, int(count))
Expand Down Expand Up @@ -2713,6 +2715,8 @@ abstract class CoercingRewriter[Pre <: Generation]()
ArrayLocation(array(arrayObj)._1, int(subscript))(a.blame)
case p @ PointerLocation(pointerExp) =>
PointerLocation(pointer(pointerExp)._1)(p.blame)
case p @ TypedPointerLocation(pointerExp, pointerType) =>
TypedPointerLocation(pointer(pointerExp)._1, pointerType)(p.blame)
case ByValueClassLocation(expr) => node
case PredicateLocation(inv) => PredicateLocation(inv)
case al @ AmbiguousLocation(expr) => AmbiguousLocation(expr)(al.blame)
Expand Down
203 changes: 127 additions & 76 deletions src/rewrite/vct/rewrite/ClassToRef.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import vct.col.origin._
import vct.result.VerificationError
import vct.col.util.AstBuildHelpers._
import hre.util.ScopedStack
import vct.col.print.Ctx
import vct.col.rewrite.error.{ExcludedByPassOrder, ExtraNode}
import vct.col.ref.Ref
import vct.col.resolve.ctx.Referrable
Expand Down Expand Up @@ -55,6 +56,8 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
private def This: Origin =
Origin(Seq(PreferredName(Seq("this")), LabelContext("classToRef")))

private var namingContext: Ctx = null

val byRefFieldSucc: SuccessionMap[Field[Pre], SilverField[Post]] =
SuccessionMap()
val byValFieldSucc: SuccessionMap[Field[Pre], ADTFunction[Post]] =
Expand Down Expand Up @@ -85,8 +88,11 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
val valueAsFunctions: mutable.Map[Type[Pre], ADTFunction[Post]] = mutable
.Map()

val castHelpers: SuccessionMap[Type[Pre], Procedure[Post]] = SuccessionMap()
val requiredCastHelpers: ScopedStack[mutable.Set[Type[Pre]]] = ScopedStack()
val castHelpers: SuccessionMap[(Type[Post], Type[Pre]), Procedure[Post]] =
SuccessionMap()
val requiredCastHelpers
: ScopedStack[mutable.Set[(Expr[Post], Type[Pre], Type[Pre])]] =
ScopedStack()

def typeNumber(cls: Class[Pre]): Int =
typeNumberStore.getOrElseUpdate(cls, typeNumberStore.size + 1)
Expand Down Expand Up @@ -194,9 +200,10 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
axiomType,
body = { a =>
InlinePattern(adtFunctionInvocation[Post](
valueAsFunctions
.getOrElseUpdate(oldT, makeValueAsFunction(oldT.toString, newT))
.ref,
valueAsFunctions.getOrElseUpdate(
oldT,
makeValueAsFunction(oldT.toStringWithContext(namingContext), newT),
).ref,
typeArgs = Some((valueAdt.ref(()), Seq(axiomType))),
args = Seq(a),
)) === Cast(
Expand All @@ -207,7 +214,8 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
))
}

override def dispatch(program: Program[Pre]): Program[Rewritten[Pre]] =
override def dispatch(program: Program[Pre]): Program[Rewritten[Pre]] = {
namingContext = Ctx().namesIn(program)
program.rewrite(declarations =
globalDeclarations.collect {
program.declarations.foreach(dispatch)
Expand All @@ -220,6 +228,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
}
}._1
)
}

override def dispatch(decl: Declaration[Pre]): Unit =
decl match {
Expand Down Expand Up @@ -396,7 +405,10 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
InlinePattern(adtFunctionInvocation[Post](
valueAsFunctions.getOrElseUpdate(
field.t,
makeValueAsFunction(field.t.toString, newT),
makeValueAsFunction(
field.t.toStringWithContext(namingContext),
newT,
),
).ref,
typeArgs = Some((valueAdt.ref(()), Seq(axiomType))),
args = Seq(a),
Expand Down Expand Up @@ -577,15 +589,15 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {

private def addCastConstraints(
expr: Expr[Pre],
totalHelpers: mutable.Set[Type[Pre]],
totalHelpers: mutable.Set[(Expr[Post], Type[Pre], Type[Pre])],
): Expr[Post] = {
val helpers: mutable.Set[Type[Pre]] = mutable.Set()
val helpers: mutable.Set[(Expr[Post], Type[Pre], Type[Pre])] = mutable.Set()
var result: Seq[Expr[Post]] = Nil
for (clause <- expr.unfoldStar) {
val newClause = requiredCastHelpers.having(helpers) { dispatch(clause) }
if (helpers.nonEmpty) {
result ++= helpers.map { t =>
unwrapCastConstraints(dispatch(t), t)(CastHelperOrigin)
result ++= helpers.map { case (e, t, _) =>
unwrapCastConstraints(dispatch(t), t, e)(CastHelperOrigin)
}.toSeq
totalHelpers.addAll(helpers)
helpers.clear()
Expand All @@ -597,7 +609,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {

// For loops add cast helpers before and as an invariant (since otherwise the contract might not be well-formed)
override def dispatch(node: LoopContract[Pre]): LoopContract[Post] = {
val helpers: mutable.Set[Type[Pre]] = mutable.Set()
val helpers: mutable.Set[(Expr[Post], Type[Pre], Type[Pre])] = mutable.Set()
node match {
case inv @ LoopInvariant(invariant, decreases) => {
val result =
Expand Down Expand Up @@ -631,7 +643,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
}

override def dispatch(stat: Statement[Pre]): Statement[Post] = {
val helpers: mutable.Set[Type[Pre]] = mutable.Set()
val helpers: mutable.Set[(Expr[Post], Type[Pre], Type[Pre])] = mutable.Set()
val result =
requiredCastHelpers.having(helpers) {
stat match {
Expand Down Expand Up @@ -685,15 +697,36 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
}

if (helpers.nonEmpty) {
Block(helpers.map { t =>
InvokeProcedure[Post](
castHelpers.getOrElseUpdate(t, makeCastHelper(t)).ref,
Nil,
Nil,
Nil,
Nil,
Nil,
)(TrueSatisfiable)(CastHelperOrigin)
// TODO: Branch/match here!
Block(helpers.map { case (ptr, t, oldPtrT) =>
ptr.t match {
case pt: TPointer[Post] =>
Branch(Seq((
Neq(ptr, Null()(stat.o))(stat.o),
InvokeProcedure[Post](
castHelpers
.getOrElseUpdate((ptr.t, t), makeCastHelper(t, oldPtrT)).ref,
Seq(Cast(ptr, TypeValue(TNonNullPointer(pt.element))(stat.o))(
stat.o
)),
Nil,
Nil,
Nil,
Nil,
)(TrueSatisfiable)(CastHelperOrigin),
)))(stat.o)
case _: TNonNullPointer[Post] =>
InvokeProcedure[Post](
castHelpers
.getOrElseUpdate((ptr.t, t), makeCastHelper(t, oldPtrT)).ref,
Seq(ptr),
Nil,
Nil,
Nil,
Nil,
)(TrueSatisfiable)(CastHelperOrigin)
case _ => ???
}
}.toSeq :+ result)(stat.o)
} else { result }
}
Expand All @@ -708,81 +741,96 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
case other => other.rewriteDefault()
}

private def unwrapCastConstraints(outerType: Type[Post], t: Type[Pre])(
implicit o: Origin
): Expr[Post] = {
private def unwrapCastConstraints(
outerType: Type[Post],
t: Type[Pre],
ptr: Expr[Post],
)(implicit o: Origin): Expr[Post] = {
val newT = dispatch(t)
val constraint =
forall[Post](
TNonNullPointer(outerType),
body = { p =>
PolarityDependent(
Greater(
CurPerm(PointerLocation(p)(PanicBlame(
"Referring to a non-null pointer should not cause any verification failures"
))),
NoPerm(),
) ==>
(InlinePattern(Cast(p, TypeValue(TNonNullPointer(newT)))) ===
adtFunctionInvocation(
valueAsFunctions
.getOrElseUpdate(t, makeValueAsFunction(t.toString, newT))
.ref,
typeArgs = Some((valueAdt.ref(()), Seq(outerType))),
args = Seq(DerefPointer(p)(PanicBlame(
"Pointer deref is safe since the permission is framed"
))),
)),
tt,
)
},
) &* forall[Post](
TNonNullPointer(outerType),
body = { p =>
PolarityDependent(
Greater(
CurPerm(PointerLocation(p)(PanicBlame(
var constraint: Expr[Post] = PolarityDependent(
Greater(
CurPerm(
TypedPointerLocation(ptr, TNonNullPointer(outerType))(PanicBlame(
"Referring to a non-null pointer should not cause any verification failures"
))
),
NoPerm(),
) ==>
(InlinePattern(Cast(ptr, TypeValue(TNonNullPointer(newT)))) ===
adtFunctionInvocation(
valueAsFunctions.getOrElseUpdate(
t,
makeValueAsFunction(t.toStringWithContext(namingContext), newT),
).ref,
typeArgs = Some((valueAdt.ref(()), Seq(outerType))),
args = Seq(DerefPointerTyped(ptr, outerType)(PanicBlame(
"Pointer deref is safe since the permission is framed"
))),
)),
tt,
)

if (newT != outerType) {
constraint =
constraint &* PolarityDependent(
Greater(
CurPerm(
TypedPointerLocation(ptr, TNonNullPointer(outerType))(PanicBlame(
"Referring to a non-null pointer should not cause any verification failures"
))),
NoPerm(),
) ==>
(InlinePattern(Cast(
Cast(p, TypeValue(TNonNullPointer(newT))),
TypeValue(TNonNullPointer(outerType)),
)) === p),
tt,
)
},
)
))
),
NoPerm(),
) ==>
(InlinePattern(Cast(
Cast(ptr, TypeValue(TNonNullPointer(newT))),
TypeValue(TNonNullPointer(outerType)),
)) === ptr),
tt,
)
}

if (t.isInstanceOf[TByValueClass[Pre]]) {
constraint &*
t.asInstanceOf[TByValueClass[Pre]].cls.decl.decls.collectFirst {
case field: InstanceField[Pre] =>
unwrapCastConstraints(outerType, field.t)
unwrapCastConstraints(outerType, field.t, ptr)
}.getOrElse(tt)
} else { constraint }
}

private def makeCastHelper(t: Type[Pre]): Procedure[Post] = {
implicit val o: Origin = CastHelperOrigin
.where(name = "constraints_" + t.toString)
private def makeCastHelper(
t: Type[Pre],
inType: Type[Pre],
): Procedure[Post] = {
implicit val o: Origin = CastHelperOrigin.where(name =
"constraints_" + t.toStringWithContext(namingContext) + "_from_" +
inType.toStringWithContext(namingContext)
)
val ptr =
new Variable[Post](TNonNullPointer(dispatch(inType)))(
o.where(name = "ptr")
)
globalDeclarations.declare(procedure(
AbstractApplicable,
TrueSatisfiable,
ensures = UnitAccountedPredicate(unwrapCastConstraints(dispatch(t), t)),
args = Seq(ptr),
ensures = UnitAccountedPredicate(
unwrapCastConstraints(dispatch(t), t, ptr.get)
),
))
}

private def addCastHelpers(
t: Type[Pre],
helpers: mutable.Set[Type[Pre]],
helpers: mutable.Set[(Expr[Post], Type[Pre], Type[Pre])],
ptr: Expr[Post],
oldPtrT: Type[Pre],
): Unit = {
t match {
case cls: TByValueClass[Pre] => {
helpers.add(t)
helpers.add((ptr, t, oldPtrT))
cls.cls.decl.decls.collectFirst { case field: InstanceField[Pre] =>
addCastHelpers(field.t, helpers)
addCastHelpers(field.t, helpers, ptr, oldPtrT)
}
}
case _ =>
Expand Down Expand Up @@ -902,13 +950,16 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
Nil,
Nil,
)(PanicBlame("instanceOf requires nothing"))(e.o)
// Is this a pointer cast between two different types:
case Cast(value, typeValue) if value.t.asPointer.isDefined => {
val newValue = dispatch(value)
// Keep pointer casts and add extra annotations
if (requiredCastHelpers.nonEmpty) {
addCastHelpers(value.t.asPointer.get.element, requiredCastHelpers.top)
val element = value.t.asPointer.get.element
addCastHelpers(element, requiredCastHelpers.top, newValue, element)
}

e.rewriteDefault()
Cast(newValue, dispatch(typeValue))(e.o)
}
case Cast(value, typeValue) =>
dispatch(
Expand Down
Loading

0 comments on commit df9ebc4

Please sign in to comment.