Skip to content

Commit

Permalink
Implement basic pointer casts
Browse files Browse the repository at this point in the history
  • Loading branch information
superaxander committed Aug 21, 2024
1 parent 98dde38 commit 42aca99
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 178 deletions.
301 changes: 245 additions & 56 deletions src/rewrite/vct/rewrite/ClassToRef.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ case object ClassToRef extends RewriterBuilder {
private def InstanceOfOrigin: Origin =
Origin(Seq(PreferredName(Seq("subtype")), LabelContext("classToRef")))

private def ValueAdtOrigin: Origin =
Origin(Seq(PreferredName(Seq("Value")), LabelContext("classToRef")))

private def CastHelperOrigin: Origin = Origin(Seq(LabelContext("classToRef")))

case class InstanceNullPreconditionFailed(
inner: Blame[InstanceNull],
inv: InvokingNode[_],
Expand Down Expand Up @@ -73,6 +78,15 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
val typeOf: SuccessionMap[Unit, Function[Post]] = SuccessionMap()
val instanceOf: SuccessionMap[Unit, Function[Post]] = SuccessionMap()

val valueAdt: SuccessionMap[Unit, AxiomaticDataType[Post]] = SuccessionMap()
val valueAdtTypeArgument: Variable[Post] =
new Variable(TType(TAnyValue()))(ValueAdtOrigin.where(name = "V"))
val valueAsFunctions: mutable.Map[Type[Pre], ADTFunction[Post]] = mutable
.Map()

val castHelpers: SuccessionMap[Type[Pre], Procedure[Post]] = SuccessionMap()
val castHelperCalls: ScopedStack[mutable.Set[Statement[Post]]] = ScopedStack()

def typeNumber(cls: Class[Pre]): Int =
typeNumberStore.getOrElseUpdate(cls, typeNumberStore.size + 1)

Expand Down Expand Up @@ -141,15 +155,68 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
)
}

private def makeValueAdt: AxiomaticDataType[Post] = {
new AxiomaticDataType[Post](
valueAsFunctions.values.toSeq,
Seq(valueAdtTypeArgument),
)(ValueAdtOrigin)
}

// TODO: Also generate value as axioms for arrays once those are properly supported for C/CPP/LLVM
private def makeValueAsFunction(
typeName: String,
t: Type[Post],
): ADTFunction[Post] = {
new ADTFunction[Post](
Seq(new Variable(TVar[Post](valueAdtTypeArgument.ref))(
ValueAdtOrigin.where(name = "v")
)),
TNonNullPointer(t),
)(ValueAdtOrigin.where(name = "value_as_" + typeName))
}

private def unwrapValueAs(
axiomType: TAxiomatic[Post],
oldT: Type[Pre],
newT: Type[Post],
fieldRef: Ref[Post, ADTFunction[Post]],
)(implicit o: Origin): Seq[ADTAxiom[Post]] = {
(oldT match {
case t: TByValueClass[Pre] => {
// TODO: If there are no fields we should ignore the first field and add the axioms for the second field
t.cls.decl.decls.collectFirst({ case field: InstanceField[Pre] =>
unwrapValueAs(axiomType, field.t, dispatch(field.t), fieldRef)
}).getOrElse(Nil)
}
case _ => Nil
}) :+ new ADTAxiom[Post](forall(
axiomType,
body = { a =>
InlinePattern(adtFunctionInvocation[Post](
valueAsFunctions
.getOrElseUpdate(oldT, makeValueAsFunction(oldT.toString, newT))
.ref,
typeArgs = Some((valueAdt.ref(()), Seq(axiomType))),
args = Seq(a),
)) === Cast(
adtFunctionInvocation(fieldRef, args = Seq(a)),
TypeValue(TNonNullPointer(newT)),
)
},
))
}

override def dispatch(program: Program[Pre]): Program[Rewritten[Pre]] =
program.rewrite(declarations =
globalDeclarations.collect {
program.declarations.foreach(dispatch)
implicit val o: Origin = TypeOfOrigin
typeOf(()) = makeTypeOf
globalDeclarations.declare(typeOf(()))
instanceOf(()) = makeInstanceOf
globalDeclarations.declare(instanceOf(()))
if (valueAsFunctions.nonEmpty) {
globalDeclarations.declare(valueAdt.getOrElseUpdate((), makeValueAdt))
}
}._1
)

Expand Down Expand Up @@ -309,18 +376,49 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
case cls: ByValueClass[Pre] =>
implicit val o: Origin = cls.o
val axiomType = TAxiomatic[Post](byValClassSucc.ref(cls), Nil)
val classType = cls.classType(Nil)
var valueAsAxioms: Seq[ADTAxiom[Post]] = Seq()
val (fieldFunctions, fieldInverses, fieldTypes) =
cls.decls.collect { case field: Field[Pre] =>
val newT = TNonNullPointer(dispatch(field.t))
val newT = dispatch(field.t)
val nonnullT = TNonNullPointer(newT)
byValFieldSucc(field) =
new ADTFunction[Post](
Seq(new Variable(axiomType)(field.o)),
newT,
nonnullT,
)(field.o)
if (valueAsAxioms.isEmpty) {
// This is the first field
valueAsAxioms =
valueAsAxioms :+ new ADTAxiom[Post](forall(
axiomType,
body = { a =>
InlinePattern(adtFunctionInvocation[Post](
valueAsFunctions.getOrElseUpdate(
field.t,
makeValueAsFunction(field.t.toString, newT),
).ref,
typeArgs = Some((valueAdt.ref(()), Seq(axiomType))),
args = Seq(a),
)) === adtFunctionInvocation(
byValFieldSucc.ref(field),
args = Seq(a),
)
},
))

valueAsAxioms =
valueAsAxioms ++ unwrapValueAs(
axiomType,
field.t,
newT,
byValFieldSucc.ref(field),
)
}
(
byValFieldSucc(field),
new ADTFunction[Post](
Seq(new Variable(newT)(field.o)),
Seq(new Variable(nonnullT)(field.o)),
axiomType,
)(
field.o.copy(
Expand All @@ -331,7 +429,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
.getOrElse("unknown")
)
),
newT,
nonnullT,
)
}.unzip3
val constructor =
Expand Down Expand Up @@ -432,7 +530,8 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
byValClassSucc(cls) =
new AxiomaticDataType[Post](
Seq(indexFunction, injectivityAxiom) ++ destructorAxioms ++
indexAxioms ++ fieldFunctions ++ fieldInverses,
indexAxioms ++ fieldFunctions ++ fieldInverses ++
valueAsAxioms,
Nil,
)
globalDeclarations.succeed(cls, byValClassSucc(cls))
Expand Down Expand Up @@ -467,55 +566,63 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
}
}

override def dispatch(stat: Statement[Pre]): Statement[Post] =
stat match {
case Instantiate(Ref(cls), Local(Ref(v))) =>
instantiate(cls, succ(v))(stat.o)
case inv @ InvokeMethod(
obj,
Ref(method),
args,
outArgs,
typeArgs,
givenMap,
yields,
) =>
InvokeProcedure[Post](
ref = methodSucc.ref(method),
args = dispatch(obj) +: args.map(dispatch),
outArgs = outArgs.map(dispatch),
typeArgs = typeArgs.map(dispatch),
givenMap = givenMap.map { case (Ref(v), e) =>
(succ(v), dispatch(e))
},
yields = yields.map { case (e, Ref(v)) => (dispatch(e), succ(v)) },
)(PreBlameSplit.left(
InstanceNullPreconditionFailed(inv.blame, inv),
PreBlameSplit
.left(PanicBlame("incorrect instance method type?"), inv.blame),
))(inv.o)
case inv @ InvokeConstructor(
Ref(cons),
_,
out,
args,
outArgs,
typeArgs,
givenMap,
yields,
) =>
InvokeProcedure[Post](
ref = consSucc.ref(cons),
args = args.map(dispatch),
outArgs = dispatch(out) +: outArgs.map(dispatch),
typeArgs = typeArgs.map(dispatch),
givenMap = givenMap.map { case (Ref(v), e) =>
(succ(v), dispatch(e))
},
yields = yields.map { case (e, Ref(v)) => (dispatch(e), succ(v)) },
)(inv.blame)(inv.o)
case other => super.dispatch(other)
}
override def dispatch(stat: Statement[Pre]): Statement[Post] = {
val helpers: mutable.Set[Statement[Post]] = mutable.Set()
val result =
castHelperCalls.having(helpers) {
stat match {
case Instantiate(Ref(cls), Local(Ref(v))) =>
instantiate(cls, succ(v))(stat.o)
case inv @ InvokeMethod(
obj,
Ref(method),
args,
outArgs,
typeArgs,
givenMap,
yields,
) =>
InvokeProcedure[Post](
ref = methodSucc.ref(method),
args = dispatch(obj) +: args.map(dispatch),
outArgs = outArgs.map(dispatch),
typeArgs = typeArgs.map(dispatch),
givenMap = givenMap.map { case (Ref(v), e) =>
(succ(v), dispatch(e))
},
yields = yields.map { case (e, Ref(v)) => (dispatch(e), succ(v)) },
)(PreBlameSplit.left(
InstanceNullPreconditionFailed(inv.blame, inv),
PreBlameSplit
.left(PanicBlame("incorrect instance method type?"), inv.blame),
))(inv.o)
case inv @ InvokeConstructor(
Ref(cons),
_,
out,
args,
outArgs,
typeArgs,
givenMap,
yields,
) =>
InvokeProcedure[Post](
ref = consSucc.ref(cons),
args = args.map(dispatch),
outArgs = dispatch(out) +: outArgs.map(dispatch),
typeArgs = typeArgs.map(dispatch),
givenMap = givenMap.map { case (Ref(v), e) =>
(succ(v), dispatch(e))
},
yields = yields.map { case (e, Ref(v)) => (dispatch(e), succ(v)) },
)(inv.blame)(inv.o)
case other => super.dispatch(other)
}
}

if (helpers.nonEmpty) { Block(helpers.toSeq :+ result)(stat.o) }
else { result }
}

override def dispatch(node: ApplyAnyPredicate[Pre]): ApplyAnyPredicate[Post] =
node match {
Expand All @@ -527,6 +634,77 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
case other => other.rewriteDefault()
}

private def unwrapCastConstraints(outerType: Type[Post], t: Type[Pre])(
implicit o: Origin
): Expr[Post] = {
val newT = dispatch(t)
val constraint = forall[Post](
TNonNullPointer(outerType),
body = { p =>
PolarityDependent(
Greater(
CurPerm(PointerLocation(p)(PanicBlame(
"Referring to a non-null pointer should not cause any verification failures"
))),
NoPerm(),
) ==>
(InlinePattern(Cast(p, TypeValue(TNonNullPointer(newT)))) ===
adtFunctionInvocation(
valueAsFunctions
.getOrElseUpdate(t, makeValueAsFunction(t.toString, newT))
.ref,
typeArgs = Some((valueAdt.ref(()), Seq(outerType))),
args = Seq(DerefPointer(p)(PanicBlame(
"Pointer deref is safe since the permission is framed"
))),
)),
tt,
)
},
)

if (t.isInstanceOf[TByValueClass[Pre]]) {
constraint &*
t.asInstanceOf[TByValueClass[Pre]].cls.decl.decls.collectFirst {
case field: InstanceField[Pre] =>
unwrapCastConstraints(outerType, field.t)
}.getOrElse(tt)
} else { constraint }
}

private def makeCastHelper(t: Type[Pre]): Procedure[Post] = {
implicit val o: Origin = CastHelperOrigin
.where(name = "constraints_" + t.toString)
globalDeclarations.declare(procedure(
AbstractApplicable,
TrueSatisfiable,
ensures = UnitAccountedPredicate(unwrapCastConstraints(dispatch(t), t)),
))
}

private def addCastHelpers(t: Type[Pre], calls: mutable.Set[Statement[Post]])(
implicit o: Origin
): Unit = {
t match {
case cls: TByValueClass[Pre] => {
calls.add(
InvokeProcedure[Post](
castHelpers.getOrElseUpdate(t, makeCastHelper(t)).ref,
Nil,
Nil,
Nil,
Nil,
Nil,
)(TrueSatisfiable)(o)
)
cls.cls.decl.decls.collectFirst { case field: InstanceField[Pre] =>
addCastHelpers(field.t, calls)
}
}
case _ =>
}
}

override def dispatch(e: Expr[Pre]): Expr[Post] =
e match {
case inv @ MethodInvocation(
Expand Down Expand Up @@ -640,7 +818,18 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] {
Nil,
Nil,
)(PanicBlame("instanceOf requires nothing"))(e.o)
case Cast(value, typeValue) if value.t.asPointer.isEmpty =>
case Cast(value, typeValue) if value.t.asPointer.isDefined => {
// Keep pointer casts and add extra annotations
// TODO: Check if we need to get rid of the pointer add's here since in my testing that broke some of the reasoning
if (castHelperCalls.nonEmpty) {
addCastHelpers(value.t.asPointer.get.element, castHelperCalls.top)(
e.o
)
}

e.rewriteDefault()
}
case Cast(value, typeValue) =>
dispatch(
value
) // Discard for now, should assert instanceOf(value, typeValue)
Expand Down
Loading

0 comments on commit 42aca99

Please sign in to comment.