diff --git a/src/main/vct/main/stages/Transformation.scala b/src/main/vct/main/stages/Transformation.scala index d7fa210d3..9b0535e01 100644 --- a/src/main/vct/main/stages/Transformation.scala +++ b/src/main/vct/main/stages/Transformation.scala @@ -208,9 +208,9 @@ case class SilverTransformation EncodeChar, CollectLocalDeclarations, // all decls in Scope + VariableToPointer, // should happen before ParBlockEncoder so it can distinguish between variables which can and can't altered in a parallel block DesugarPermissionOperators, // no PointsTo, \pointer, etc. ReadToValue, // resolve wildcard into fractional permission - VariableToPointer, TrivialAddrOf, DesugarCoalescingOperators, // no ?. PinCollectionTypes, // no anonymous sequences, sets, etc. diff --git a/src/rewrite/vct/rewrite/DesugarPermissionOperators.scala b/src/rewrite/vct/rewrite/DesugarPermissionOperators.scala index 0d30ddc8d..f2e21d5f9 100644 --- a/src/rewrite/vct/rewrite/DesugarPermissionOperators.scala +++ b/src/rewrite/vct/rewrite/DesugarPermissionOperators.scala @@ -54,6 +54,21 @@ case class DesugarPermissionOperators[Pre <: Generation]() extends Rewriter[Pre] } } + def makeClassPerm(e: Expr[Post], cls: TClass[Pre], perm: Expr[Post]): Expr[Post] = { + implicit val o: Origin = e.o + cls.cls.decl.declarations.collect { + case f: InstanceField[Pre] => + if (f.t.asClass.isDefined) { + Perm(FieldLocation[Post](e, succ(f)), perm) &* + makeClassPerm(Deref[Post](e, succ(f))(FramedPtrOffset), f.t.asClass.get, perm) + } else { + Perm[Post](FieldLocation(e, succ(f)), perm) + } + }.reduce { + (a,b) => a &* b + } + } + override def dispatch(e: Expr[Pre]): Expr[Post] = { implicit val o: Origin = e.o e match { @@ -79,19 +94,32 @@ case class DesugarPermissionOperators[Pre <: Generation]() extends Rewriter[Pre] (const(0) <= row0 && row0 < dim0 && const(0) <= row1 && row1 < dim0) ==> ((ArraySubscript(mat, row0)(FramedArrIndex) === ArraySubscript(mat, row1)(FramedArrIndex)) ==> (row0 === row1)) )) - case PermPointer(p, len, perm) => + case PermPointer(p, len, perm) if p.t.asPointer.get.element.asClass.isDefined => (dispatch(p) !== Null()) &* const(0) <= PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(len) &* PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(len) <= PointerBlockLength(dispatch(p))(FramedPtrBlockLength) &* starall(IteratedPtrInjective, TInt(), i => - (const(0) <= i && i < dispatch(len)) ==> - Perm(PointerLocation(PointerAdd(dispatch(p), i)(FramedPtrOffset))(FramedPtrOffset), dispatch(perm))) + (const(0) <= i && i < dispatch(len)) ==> ( + Perm(PointerLocation(PointerAdd(dispatch(p), i)(FramedPtrOffset))(FramedPtrOffset), dispatch(perm)) &* makeClassPerm(DerefPointer(PointerAdd(dispatch(p), i)(FramedPtrOffset))(FramedPtrOffset), p.t.asPointer.get.element.asClass.get, dispatch(perm)) )) + case PermPointer(p, len, perm) => + (dispatch(p) !== Null()) &* + const(0) <= PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(len) &* + PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(len) <= PointerBlockLength(dispatch(p))(FramedPtrBlockLength) &* + starall(IteratedPtrInjective, TInt(), i => + (const(0) <= i && i < dispatch(len)) ==> + Perm(PointerLocation(PointerAdd(dispatch(p), i)(FramedPtrOffset))(FramedPtrOffset), dispatch(perm))) + case PermPointerIndex(p, idx, perm) if p.t.asPointer.get.element.asClass.isDefined => + (dispatch(p) !== Null()) &* + const(0) <= PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(idx) &* + PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(idx) < PointerBlockLength(dispatch(p))(FramedPtrBlockLength) &* + Perm(PointerLocation(PointerAdd(dispatch(p), dispatch(idx))(FramedPtrOffset))(FramedPtrOffset), dispatch(perm)) &* + makeClassPerm(DerefPointer(PointerAdd(dispatch(p), dispatch(idx))(FramedPtrOffset))(FramedPtrOffset), p.t.asPointer.get.element.asClass.get, dispatch(perm)) case PermPointerIndex(p, idx, perm) => (dispatch(p) !== Null()) &* const(0) <= PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(idx) &* PointerBlockOffset(dispatch(p))(FramedPtrBlockOffset) + dispatch(idx) < PointerBlockLength(dispatch(p))(FramedPtrBlockLength) &* Perm(PointerLocation(PointerAdd(dispatch(p), dispatch(idx))(FramedPtrOffset))(FramedPtrOffset), dispatch(perm)) - case other => rewriteDefault(other) + case other => other.rewriteDefault() } } } diff --git a/src/rewrite/vct/rewrite/TrivialAddrOf.scala b/src/rewrite/vct/rewrite/TrivialAddrOf.scala index f1a05454f..a94bffd3c 100644 --- a/src/rewrite/vct/rewrite/TrivialAddrOf.scala +++ b/src/rewrite/vct/rewrite/TrivialAddrOf.scala @@ -30,6 +30,12 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] { case AddrOf(sub @ PointerSubscript(p, i)) => PointerAdd(dispatch(p), dispatch(i))(SubscriptErrorAddError(sub))(e.o) + case Eq(left, AddrOf(right)) if left.t.isInstanceOf[TPointer[Pre]] => + Eq(PointerSubscript(dispatch(left), const[Post](0)(left.o))(PanicBlame("Size is > 0"))(left.o), dispatch(right))(e.o) + + case Neq(left, AddrOf(right)) if left.t.isInstanceOf[TPointer[Pre]] => + Neq(PointerSubscript(dispatch(left), const[Post](0)(left.o))(PanicBlame("Size is > 0"))(left.o), dispatch(right))(e.o) + case AddrOf(other) => throw UnsupportedLocation(other) case assign@PreAssignExpression(target, AddrOf(value)) if value.t.isInstanceOf[TClass[Pre]] => @@ -37,7 +43,7 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] { val (newPointer, newTarget, newValue) = rewriteAssign(target, value, assign.blame, assign.o) val newAssign = PreAssignExpression(PointerSubscript(newTarget, const[Post](0))(PanicBlame("Should always be accessible")), newValue)(assign.blame) With(newPointer, newAssign) - case other => rewriteDefault(other) + case other => other.rewriteDefault() } override def dispatch(s: Statement[Pre]): Statement[Post] = s match { @@ -46,7 +52,7 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] { val (newPointer, newTarget, newValue) = rewriteAssign(target, value, assign.blame, assign.o) val newAssign = Assign(PointerSubscript(newTarget, const[Post](0))(PanicBlame("Should always be accessible")), newValue)(assign.blame) Block(Seq(newPointer, newAssign)) - case other => rewriteDefault(other) + case other => other.rewriteDefault() } // TODO: AddressOff needs a more structured approach. Now you could assign a local structure to a pointer, and that pointer diff --git a/src/rewrite/vct/rewrite/VariableToPointer.scala b/src/rewrite/vct/rewrite/VariableToPointer.scala index 751d201a2..3cd675d0c 100644 --- a/src/rewrite/vct/rewrite/VariableToPointer.scala +++ b/src/rewrite/vct/rewrite/VariableToPointer.scala @@ -26,35 +26,38 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { import VariableToPointer._ - var stageTwo = false val addressedSet: mutable.Set[Node[Pre]] = new mutable.HashSet[Node[Pre]]() val heapVariableMap: SuccessionMap[HeapVariable[Pre], HeapVariable[Post]] = SuccessionMap() val variableMap: SuccessionMap[Variable[Pre], Variable[Post]] = SuccessionMap() val fieldMap: SuccessionMap[InstanceField[Pre], InstanceField[Post]] = SuccessionMap() override def dispatch(program: Program[Pre]): Program[Rewritten[Pre]] = { - super.dispatch(program) - stageTwo = true + addressedSet.addAll(program.collect { + case AddrOf(Local(Ref(v))) if !v.t.isInstanceOf[TClass[Pre]] => v + case AddrOf(DerefHeapVariable(Ref(v))) if !v.t.isInstanceOf[TClass[Pre]] => v + case AddrOf(Deref(_, Ref(f))) if !f.t.isInstanceOf[TClass[Pre]] => f + }) super.dispatch(program) } override def dispatch(decl: Declaration[Pre]): Unit = decl match { - case v: HeapVariable[Pre] if stageTwo && addressedSet.contains(v) => heapVariableMap(v) = globalDeclarations.declare(new HeapVariable(TPointer(dispatch(v.t)))(v.o)) - case v: Variable[Pre] if stageTwo && addressedSet.contains(v) => variableMap(v) = variables.declare(new Variable(TPointer(dispatch(v.t)))(v.o)) - case f: InstanceField[Pre] if stageTwo && addressedSet.contains(f) => fieldMap(f) = classDeclarations.declare(new InstanceField(TPointer(dispatch(f.t)), f.flags.map { it => dispatch(it) })(f.o)) + case v: HeapVariable[Pre] if addressedSet.contains(v) => heapVariableMap(v) = globalDeclarations.declare(new HeapVariable(TPointer(dispatch(v.t)))(v.o)) + case v: Variable[Pre] if addressedSet.contains(v) => variableMap(v) = variables.declare(new Variable(TPointer(dispatch(v.t)))(v.o)) + case f: InstanceField[Pre] if addressedSet.contains(f) => fieldMap(f) = classDeclarations.declare(new InstanceField(TPointer(dispatch(f.t)), f.flags.map { it => dispatch(it) })(f.o)) case other => allScopes.anySucceed(other, other.rewriteDefault()) } override def dispatch(stat: Statement[Pre]): Statement[Post] = { implicit val o: Origin = stat.o stat match { - case s: Scope[Pre] if stageTwo => s.rewrite(locals = variables.dispatch(s.locals), body = Block(s.locals.filter { local => addressedSet.contains(local) }.map { local => + case s: Scope[Pre] => s.rewrite(locals = variables.dispatch(s.locals), body = Block(s.locals.filter { local => addressedSet.contains(local) }.map { local => implicit val o: Origin = local.o Assign(Local[Post](variableMap.ref(local)), NewPointerArray(variableMap(local).t.asPointer.get.element, const(1))(PanicBlame("Size is > 0")))(PanicBlame("Initialisation should always succeed")) } ++ Seq(dispatch(s.body)))) - case i@Instantiate(cls, out) if stageTwo => + case i@Instantiate(cls, out) => Block(Seq(i.rewriteDefault()) ++ cls.decl.declarations.flatMap { - case f: InstanceField[Pre] if addressedSet.contains(f) => Seq(Assign(Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame("Initialisation should always succeed")), NewPointerArray(fieldMap(f).t.asPointer.get.element, const(1))(PanicBlame("Size is > 0")))(PanicBlame("Initialisation should always succeed"))) + case f: InstanceField[Pre] if addressedSet.contains(f) => + Seq(Assign(Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame("Initialisation should always succeed")), NewPointerArray(fieldMap(f).t.asPointer.get.element, const(1))(PanicBlame("Size is > 0")))(PanicBlame("Initialisation should always succeed"))) case _ => Seq() }) case other => other.rewriteDefault() @@ -64,28 +67,18 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(expr: Expr[Pre]): Expr[Post] = { implicit val o: Origin = expr.o expr match { - case AddrOf(Local(Ref(v))) if !stageTwo && !v.t.isInstanceOf[TClass[Pre]] => - addressedSet.add(v) - expr.rewriteDefault() - case AddrOf(DerefHeapVariable(Ref(v))) if !stageTwo && !v.t.isInstanceOf[TClass[Pre]] => - addressedSet.add(v) - expr.rewriteDefault() - case AddrOf(Deref(_, Ref(f))) if !stageTwo && !f.t.isInstanceOf[TClass[Pre]] => - addressedSet.add(f) - expr.rewriteDefault() - // case AddrOf(inner) if !stageTwo => - // throw UnsupportedAddrOf(inner) - case deref@DerefHeapVariable(Ref(v)) if stageTwo && addressedSet.contains(v) => + case deref@DerefHeapVariable(Ref(v)) if addressedSet.contains(v) => DerefPointer(DerefHeapVariable[Post](heapVariableMap.ref(v))(deref.blame))(PanicBlame("Should always be accessible")) - case Local(Ref(v)) if stageTwo && addressedSet.contains(v) => + case Local(Ref(v)) if addressedSet.contains(v) => DerefPointer(Local[Post](variableMap.ref(v)))(PanicBlame("Should always be accessible")) - case deref@Deref(obj, Ref(f)) if stageTwo && addressedSet.contains(f) => + case deref@Deref(obj, Ref(f)) if addressedSet.contains(f) => DerefPointer(Deref[Post](dispatch(obj), fieldMap.ref(f))(deref.blame))(PanicBlame("Should always be accessible")) - case newObject@NewObject(Ref(cls)) if stageTwo => + case newObject@NewObject(Ref(cls)) => val obj = new Variable[Post](TClass(succ(cls))) ScopedExpr(Seq(obj), With(Block( Seq(assignLocal(obj.get, newObject.rewriteDefault())) ++ cls.declarations.flatMap { - case f: InstanceField[Pre] if addressedSet.contains(f) => Seq(Assign(Deref[Post](obj.get, fieldMap.ref(f))(PanicBlame("Initialisation should always succeed")), NewPointerArray(fieldMap(f).t.asPointer.get.element, const(1))(PanicBlame("Size is > 0")))(PanicBlame("Initialisation should always succeed"))) + case f: InstanceField[Pre] if addressedSet.contains(f) => + Seq(Assign(Deref[Post](obj.get, fieldMap.ref(f))(PanicBlame("Initialisation should always succeed")), NewPointerArray(fieldMap(f).t.asPointer.get.element, const(1))(PanicBlame("Size is > 0")))(PanicBlame("Initialisation should always succeed"))) case _ => Seq() }), obj.get)) case other => other.rewriteDefault() @@ -95,9 +88,8 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(loc: Location[Pre]): Location[Post] = { implicit val o: Origin = loc.o loc match { - case HeapVariableLocation(Ref(v)) if stageTwo && addressedSet.contains(v) => PointerLocation(DerefHeapVariable[Post](heapVariableMap.ref(v))(PanicBlame("Should always be accessible")))(PanicBlame("Should always be accessible")) - case FieldLocation(obj, Ref(f)) if stageTwo && addressedSet.contains(f) => PointerLocation(Deref[Post](dispatch(obj), fieldMap.ref(f))(PanicBlame("Should always be accessible")))(PanicBlame("Should always be accessible")) - case loc@PointerLocation(inner) if stageTwo && addressedSet.contains(inner) => PointerLocation(dispatch(inner))(loc.blame) + case HeapVariableLocation(Ref(v)) if addressedSet.contains(v) => PointerLocation(DerefHeapVariable[Post](heapVariableMap.ref(v))(PanicBlame("Should always be accessible")))(PanicBlame("Should always be accessible")) + case FieldLocation(obj, Ref(f)) if addressedSet.contains(f) => PointerLocation(Deref[Post](dispatch(obj), fieldMap.ref(f))(PanicBlame("Should always be accessible")))(PanicBlame("Should always be accessible")) case other => other.rewriteDefault() } }