Skip to content

Commit

Permalink
Speed up VariableToPointer pass, enable comparing pointers with addre…
Browse files Browse the repository at this point in the history
…ss of, and change \pointer and \pointer_index to also give permission to fields if the pointed to type is a class
  • Loading branch information
superaxander committed Apr 2, 2024
1 parent 40f3dff commit 04d2357
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/main/vct/main/stages/Transformation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ case class SilverTransformation
EncodeChar,

CollectLocalDeclarations, // all decls in Scope
VariableToPointer, // should happen before ParBlockEncoder so it can distinguish between variables which can and can't altered in a parallel block
DesugarPermissionOperators, // no PointsTo, \pointer, etc.
ReadToValue, // resolve wildcard into fractional permission
VariableToPointer,
TrivialAddrOf,
DesugarCoalescingOperators, // no ?.
PinCollectionTypes, // no anonymous sequences, sets, etc.
Expand Down
36 changes: 32 additions & 4 deletions src/rewrite/vct/rewrite/DesugarPermissionOperators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ case class DesugarPermissionOperators[Pre <: Generation]() extends Rewriter[Pre]
}
}

def makeClassPerm(e: Expr[Post], cls: TClass[Pre], perm: Expr[Post]): Expr[Post] = {
implicit val o: Origin = e.o
cls.cls.decl.declarations.collect {
case f: InstanceField[Pre] =>
if (f.t.asClass.isDefined) {
Perm(FieldLocation[Post](e, succ(f)), perm) &*
makeClassPerm(Deref[Post](e, succ(f))(FramedPtrOffset), f.t.asClass.get, perm)
} else {
Perm[Post](FieldLocation(e, succ(f)), perm)
}
}.reduce {
(a,b) => a &* b
}
}

override def dispatch(e: Expr[Pre]): Expr[Post] = {
implicit val o: Origin = e.o
e match {
Expand All @@ -79,19 +94,32 @@ case class DesugarPermissionOperators[Pre <: Generation]() extends Rewriter[Pre]
(const(0) <= row0 && row0 < dim0 && const(0) <= row1 && row1 < dim0) ==>
((ArraySubscript(mat, row0)(FramedArrIndex) === ArraySubscript(mat, row1)(FramedArrIndex)) ==> (row0 === row1))
))
case PermPointer(p, len, perm) =>
case PermPointer(p, len, perm) if p.t.asPointer.get.element.asClass.isDefined =>
(dispatch(p) !== Null()) &*
const(0) <= PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(len) &*
PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(len) <= PointerBlockLength(dispatch(p))(FramedPtrBlockLength) &*
starall(IteratedPtrInjective, TInt(), i =>
(const(0) <= i && i < dispatch(len)) ==>
Perm(PointerLocation(PointerAdd(dispatch(p), i)(FramedPtrOffset))(FramedPtrOffset), dispatch(perm)))
(const(0) <= i && i < dispatch(len)) ==> (
Perm(PointerLocation(PointerAdd(dispatch(p), i)(FramedPtrOffset))(FramedPtrOffset), dispatch(perm)) &* makeClassPerm(DerefPointer(PointerAdd(dispatch(p), i)(FramedPtrOffset))(FramedPtrOffset), p.t.asPointer.get.element.asClass.get, dispatch(perm)) ))
case PermPointer(p, len, perm) =>
(dispatch(p) !== Null()) &*
const(0) <= PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(len) &*
PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(len) <= PointerBlockLength(dispatch(p))(FramedPtrBlockLength) &*
starall(IteratedPtrInjective, TInt(), i =>
(const(0) <= i && i < dispatch(len)) ==>
Perm(PointerLocation(PointerAdd(dispatch(p), i)(FramedPtrOffset))(FramedPtrOffset), dispatch(perm)))
case PermPointerIndex(p, idx, perm) if p.t.asPointer.get.element.asClass.isDefined =>
(dispatch(p) !== Null()) &*
const(0) <= PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(idx) &*
PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(idx) < PointerBlockLength(dispatch(p))(FramedPtrBlockLength) &*
Perm(PointerLocation(PointerAdd(dispatch(p), dispatch(idx))(FramedPtrOffset))(FramedPtrOffset), dispatch(perm)) &*
makeClassPerm(DerefPointer(PointerAdd(dispatch(p), dispatch(idx))(FramedPtrOffset))(FramedPtrOffset), p.t.asPointer.get.element.asClass.get, dispatch(perm))
case PermPointerIndex(p, idx, perm) =>
(dispatch(p) !== Null()) &*
const(0) <= PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(idx) &*
PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(idx) < PointerBlockLength(dispatch(p))(FramedPtrBlockLength) &*
Perm(PointerLocation(PointerAdd(dispatch(p), dispatch(idx))(FramedPtrOffset))(FramedPtrOffset), dispatch(perm))
case other => rewriteDefault(other)
case other => other.rewriteDefault()
}
}
}
10 changes: 8 additions & 2 deletions src/rewrite/vct/rewrite/TrivialAddrOf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,20 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] {
case AddrOf(sub @ PointerSubscript(p, i)) =>
PointerAdd(dispatch(p), dispatch(i))(SubscriptErrorAddError(sub))(e.o)

case Eq(left, AddrOf(right)) if left.t.isInstanceOf[TPointer[Pre]] =>
Eq(PointerSubscript(dispatch(left), const[Post](0)(left.o))(PanicBlame("Size is > 0"))(left.o), dispatch(right))(e.o)

case Neq(left, AddrOf(right)) if left.t.isInstanceOf[TPointer[Pre]] =>
Neq(PointerSubscript(dispatch(left), const[Post](0)(left.o))(PanicBlame("Size is > 0"))(left.o), dispatch(right))(e.o)

case AddrOf(other) =>
throw UnsupportedLocation(other)
case assign@PreAssignExpression(target, AddrOf(value)) if value.t.isInstanceOf[TClass[Pre]] =>
implicit val o: Origin = assign.o
val (newPointer, newTarget, newValue) = rewriteAssign(target, value, assign.blame, assign.o)
val newAssign = PreAssignExpression(PointerSubscript(newTarget, const[Post](0))(PanicBlame("Should always be accessible")), newValue)(assign.blame)
With(newPointer, newAssign)
case other => rewriteDefault(other)
case other => other.rewriteDefault()
}

override def dispatch(s: Statement[Pre]): Statement[Post] = s match {
Expand All @@ -46,7 +52,7 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] {
val (newPointer, newTarget, newValue) = rewriteAssign(target, value, assign.blame, assign.o)
val newAssign = Assign(PointerSubscript(newTarget, const[Post](0))(PanicBlame("Should always be accessible")), newValue)(assign.blame)
Block(Seq(newPointer, newAssign))
case other => rewriteDefault(other)
case other => other.rewriteDefault()
}

// TODO: AddressOff needs a more structured approach. Now you could assign a local structure to a pointer, and that pointer
Expand Down
48 changes: 20 additions & 28 deletions src/rewrite/vct/rewrite/VariableToPointer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,35 +26,38 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] {

import VariableToPointer._

var stageTwo = false
val addressedSet: mutable.Set[Node[Pre]] = new mutable.HashSet[Node[Pre]]()
val heapVariableMap: SuccessionMap[HeapVariable[Pre], HeapVariable[Post]] = SuccessionMap()
val variableMap: SuccessionMap[Variable[Pre], Variable[Post]] = SuccessionMap()
val fieldMap: SuccessionMap[InstanceField[Pre], InstanceField[Post]] = SuccessionMap()

override def dispatch(program: Program[Pre]): Program[Rewritten[Pre]] = {
super.dispatch(program)
stageTwo = true
addressedSet.addAll(program.collect {
case AddrOf(Local(Ref(v))) if !v.t.isInstanceOf[TClass[Pre]] => v
case AddrOf(DerefHeapVariable(Ref(v))) if !v.t.isInstanceOf[TClass[Pre]] => v
case AddrOf(Deref(_, Ref(f))) if !f.t.isInstanceOf[TClass[Pre]] => f
})
super.dispatch(program)
}

override def dispatch(decl: Declaration[Pre]): Unit = decl match {
case v: HeapVariable[Pre] if stageTwo && addressedSet.contains(v) => heapVariableMap(v) = globalDeclarations.declare(new HeapVariable(TPointer(dispatch(v.t)))(v.o))
case v: Variable[Pre] if stageTwo && addressedSet.contains(v) => variableMap(v) = variables.declare(new Variable(TPointer(dispatch(v.t)))(v.o))
case f: InstanceField[Pre] if stageTwo && addressedSet.contains(f) => fieldMap(f) = classDeclarations.declare(new InstanceField(TPointer(dispatch(f.t)), f.flags.map { it => dispatch(it) })(f.o))
case v: HeapVariable[Pre] if addressedSet.contains(v) => heapVariableMap(v) = globalDeclarations.declare(new HeapVariable(TPointer(dispatch(v.t)))(v.o))
case v: Variable[Pre] if addressedSet.contains(v) => variableMap(v) = variables.declare(new Variable(TPointer(dispatch(v.t)))(v.o))
case f: InstanceField[Pre] if addressedSet.contains(f) => fieldMap(f) = classDeclarations.declare(new InstanceField(TPointer(dispatch(f.t)), f.flags.map { it => dispatch(it) })(f.o))
case other => allScopes.anySucceed(other, other.rewriteDefault())
}

override def dispatch(stat: Statement[Pre]): Statement[Post] = {
implicit val o: Origin = stat.o
stat match {
case s: Scope[Pre] if stageTwo => s.rewrite(locals = variables.dispatch(s.locals), body = Block(s.locals.filter { local => addressedSet.contains(local) }.map { local =>
case s: Scope[Pre] => s.rewrite(locals = variables.dispatch(s.locals), body = Block(s.locals.filter { local => addressedSet.contains(local) }.map { local =>
implicit val o: Origin = local.o
Assign(Local[Post](variableMap.ref(local)), NewPointerArray(variableMap(local).t.asPointer.get.element, const(1))(PanicBlame("Size is > 0")))(PanicBlame("Initialisation should always succeed"))
} ++ Seq(dispatch(s.body))))
case i@Instantiate(cls, out) if stageTwo =>
case i@Instantiate(cls, out) =>
Block(Seq(i.rewriteDefault()) ++ cls.decl.declarations.flatMap {
case f: InstanceField[Pre] if addressedSet.contains(f) => Seq(Assign(Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame("Initialisation should always succeed")), NewPointerArray(fieldMap(f).t.asPointer.get.element, const(1))(PanicBlame("Size is > 0")))(PanicBlame("Initialisation should always succeed")))
case f: InstanceField[Pre] if addressedSet.contains(f) =>
Seq(Assign(Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame("Initialisation should always succeed")), NewPointerArray(fieldMap(f).t.asPointer.get.element, const(1))(PanicBlame("Size is > 0")))(PanicBlame("Initialisation should always succeed")))
case _ => Seq()
})
case other => other.rewriteDefault()
Expand All @@ -64,28 +67,18 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] {
override def dispatch(expr: Expr[Pre]): Expr[Post] = {
implicit val o: Origin = expr.o
expr match {
case AddrOf(Local(Ref(v))) if !stageTwo && !v.t.isInstanceOf[TClass[Pre]] =>
addressedSet.add(v)
expr.rewriteDefault()
case AddrOf(DerefHeapVariable(Ref(v))) if !stageTwo && !v.t.isInstanceOf[TClass[Pre]] =>
addressedSet.add(v)
expr.rewriteDefault()
case AddrOf(Deref(_, Ref(f))) if !stageTwo && !f.t.isInstanceOf[TClass[Pre]] =>
addressedSet.add(f)
expr.rewriteDefault()
// case AddrOf(inner) if !stageTwo =>
// throw UnsupportedAddrOf(inner)
case deref@DerefHeapVariable(Ref(v)) if stageTwo && addressedSet.contains(v) =>
case deref@DerefHeapVariable(Ref(v)) if addressedSet.contains(v) =>
DerefPointer(DerefHeapVariable[Post](heapVariableMap.ref(v))(deref.blame))(PanicBlame("Should always be accessible"))
case Local(Ref(v)) if stageTwo && addressedSet.contains(v) =>
case Local(Ref(v)) if addressedSet.contains(v) =>
DerefPointer(Local[Post](variableMap.ref(v)))(PanicBlame("Should always be accessible"))
case deref@Deref(obj, Ref(f)) if stageTwo && addressedSet.contains(f) =>
case deref@Deref(obj, Ref(f)) if addressedSet.contains(f) =>
DerefPointer(Deref[Post](dispatch(obj), fieldMap.ref(f))(deref.blame))(PanicBlame("Should always be accessible"))
case newObject@NewObject(Ref(cls)) if stageTwo =>
case newObject@NewObject(Ref(cls)) =>
val obj = new Variable[Post](TClass(succ(cls)))
ScopedExpr(Seq(obj), With(Block(
Seq(assignLocal(obj.get, newObject.rewriteDefault())) ++ cls.declarations.flatMap {
case f: InstanceField[Pre] if addressedSet.contains(f) => Seq(Assign(Deref[Post](obj.get, fieldMap.ref(f))(PanicBlame("Initialisation should always succeed")), NewPointerArray(fieldMap(f).t.asPointer.get.element, const(1))(PanicBlame("Size is > 0")))(PanicBlame("Initialisation should always succeed")))
case f: InstanceField[Pre] if addressedSet.contains(f) =>
Seq(Assign(Deref[Post](obj.get, fieldMap.ref(f))(PanicBlame("Initialisation should always succeed")), NewPointerArray(fieldMap(f).t.asPointer.get.element, const(1))(PanicBlame("Size is > 0")))(PanicBlame("Initialisation should always succeed")))
case _ => Seq()
}), obj.get))
case other => other.rewriteDefault()
Expand All @@ -95,9 +88,8 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] {
override def dispatch(loc: Location[Pre]): Location[Post] = {
implicit val o: Origin = loc.o
loc match {
case HeapVariableLocation(Ref(v)) if stageTwo && addressedSet.contains(v) => PointerLocation(DerefHeapVariable[Post](heapVariableMap.ref(v))(PanicBlame("Should always be accessible")))(PanicBlame("Should always be accessible"))
case FieldLocation(obj, Ref(f)) if stageTwo && addressedSet.contains(f) => PointerLocation(Deref[Post](dispatch(obj), fieldMap.ref(f))(PanicBlame("Should always be accessible")))(PanicBlame("Should always be accessible"))
case loc@PointerLocation(inner) if stageTwo && addressedSet.contains(inner) => PointerLocation(dispatch(inner))(loc.blame)
case HeapVariableLocation(Ref(v)) if addressedSet.contains(v) => PointerLocation(DerefHeapVariable[Post](heapVariableMap.ref(v))(PanicBlame("Should always be accessible")))(PanicBlame("Should always be accessible"))
case FieldLocation(obj, Ref(f)) if addressedSet.contains(f) => PointerLocation(Deref[Post](dispatch(obj), fieldMap.ref(f))(PanicBlame("Should always be accessible")))(PanicBlame("Should always be accessible"))
case other => other.rewriteDefault()
}
}
Expand Down

0 comments on commit 04d2357

Please sign in to comment.