Skip to content

Commit

Permalink
Categorize let bindings so no longer missing alias patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
chengluyu committed Mar 15, 2023
1 parent fef2c0c commit 8ac73ee
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 114 deletions.
13 changes: 7 additions & 6 deletions shared/src/main/scala/mlscript/ucs/Clause.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ abstract class Clause {
/**
* Local interleaved let bindings declared before this condition.
*/
var bindings: Ls[(Bool, Var, Term)] = Nil
var bindings: Ls[LetBinding] = Nil

/**
* Locations of terms that build this `Clause`.
Expand Down Expand Up @@ -56,18 +56,19 @@ object Clause {
override def toString(): String = s"«$test»" + bindingsToString
}

final case class Binding(name: Var, term: Term)(
/**
* @param isField whether this binding is extracting a class field
*/
final case class Binding(name: Var, term: Term, isField: Bool)(
override val locations: Ls[Loc]
) extends Clause {
override def toString(): String = s"«$name = $term»" + bindingsToString
}

def showBindings(bindings: Ls[(Bool, Var, Term)]): Str =
def showBindings(bindings: Ls[LetBinding]): Str =
bindings match {
case Nil => ""
case bindings => bindings.map {
case (_, Var(name), _) => name
}.mkString("(", ", ", ")")
case bindings => bindings.map(_.name.name).mkString("(", ", ", ")")
}

def showClauses(clauses: Iterable[Clause]): Str = clauses.mkString("", " and ", "")
Expand Down
6 changes: 3 additions & 3 deletions shared/src/main/scala/mlscript/ucs/Conjunction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import scala.annotation.tailrec
/**
* A `Conjunction` represents a list of `Clause`s.
*/
final case class Conjunction(clauses: Ls[Clause], trailingBindings: Ls[(Bool, Var, Term)]) {
final case class Conjunction(clauses: Ls[Clause], trailingBindings: Ls[LetBinding]) {
/**
* Concatenate two `Conjunction` together.
*
Expand Down Expand Up @@ -51,7 +51,7 @@ final case class Conjunction(clauses: Ls[Clause], trailingBindings: Ls[(Bool, Va
* @param suffix the list of clauses to append to this conjunction
* @return a new conjunction with clauses from `this` and `suffix`
*/
def +(lastBinding: (Bool, Var, Term)): Conjunction =
def +(lastBinding: LetBinding): Conjunction =
Conjunction(clauses, trailingBindings :+ lastBinding)

def separate(expectedScrutinee: Scrutinee): Opt[(MatchClass \/ MatchLiteral, Conjunction)] = {
Expand Down Expand Up @@ -87,7 +87,7 @@ final case class Conjunction(clauses: Ls[Clause], trailingBindings: Ls[(Bool, Va
* @param interleavedLets the buffer of let bindings in the current context
* @return idential to `conditions`
*/
def withBindings(implicit interleavedLets: Buffer[(Bool, Var, Term)]): Conjunction = {
def withBindings(implicit interleavedLets: Buffer[LetBinding]): Conjunction = {
clauses match {
case Nil => Conjunction(Nil, interleavedLets.toList ::: trailingBindings)
case head :: _ =>
Expand Down
152 changes: 96 additions & 56 deletions shared/src/main/scala/mlscript/ucs/Desugarer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class Desugarer extends TypeDefs { self: Typer =>
private def destructSubPatterns(scrutinee: Scrutinee, subPatterns: Iterable[Var -> Term])
(implicit ctx: Ctx, raise: Raise, aliasMap: FieldAliasMap): Ls[Clause] = {
subPatterns.iterator.flatMap[Clause] { case (subScrutinee, subPattern) =>
destructPattern(makeScrutinee(subScrutinee, scrutinee.matchRootLoc), subPattern)
destructPattern(makeScrutinee(subScrutinee, scrutinee.matchRootLoc), subPattern, false)
}.toList
}

Expand Down Expand Up @@ -130,6 +130,7 @@ class Desugarer extends TypeDefs { self: Typer =>
*
* @param scrutinee the scrutinee of the pattern matching
* @param pattern the pattern we will destruct
* @param isTopLevel whether this pattern just follows the `is` operator
* @param raise the `Raise` function
* @param aliasMap the field alias map
* @param matchRootLoc the location of the root of the pattern matching
Expand All @@ -140,7 +141,7 @@ class Desugarer extends TypeDefs { self: Typer =>
* do not contain interleaved let bindings.
*/
private def destructPattern
(scrutinee: Scrutinee, pattern: Term)
(scrutinee: Scrutinee, pattern: Term, isTopLevel: Bool)
(implicit ctx: Ctx,
raise: Raise,
aliasMap: FieldAliasMap,
Expand Down Expand Up @@ -178,7 +179,7 @@ class Desugarer extends TypeDefs { self: Typer =>
// This case handles name binding.
// x is a
case bindingVar @ Var(bindingName) if bindingName.headOption.exists(_.isLower) =>
Clause.Binding(bindingVar, scrutinee.term)(scrutinee.term.toLoc.toList ::: bindingVar.toLoc.toList) :: Nil
Clause.Binding(bindingVar, scrutinee.term, !isTopLevel)(scrutinee.term.toLoc.toList ::: bindingVar.toLoc.toList) :: Nil
// This case handles simple class tests.
// x is A
case classNameVar @ Var(className) =>
Expand Down Expand Up @@ -246,12 +247,12 @@ class Desugarer extends TypeDefs { self: Typer =>
msg"Cannot find operator `$op` in the context"
}, opVar.toLoc)
case S(td) if td.positionals.length === 2 =>
val (subPatterns, bindings) = desugarPositionals(
val (subPatterns, fields) = desugarPositionals(
scrutinee,
lhs :: rhs :: Nil,
td.positionals
)
val clause = Clause.MatchClass(scrutinee, opVar, bindings)(collectLocations(scrutinee.term))
val clause = Clause.MatchClass(scrutinee, opVar, fields)(collectLocations(scrutinee.term))
printlnUCS(s"Build a Clause.MatchClass from $scrutinee where operator is $opVar")
clause :: destructSubPatterns(scrutinee, subPatterns)
case S(td) =>
Expand Down Expand Up @@ -293,7 +294,7 @@ class Desugarer extends TypeDefs { self: Typer =>
def desugarIf
(body: IfBody, fallback: Opt[Term])
(implicit ctx: Ctx, raise: Raise)
: Ls[Conjunction -> Term] = {
: Ls[Conjunction -> Term] = traceUCS(s"[desugarIf] with fallback $fallback") {
// We allocate temporary variable names for nested patterns.
// This prevents aliasing problems.
implicit val scrutineeFieldAliasMap: FieldAliasMap = MutMap.empty
Expand All @@ -317,7 +318,7 @@ class Desugarer extends TypeDefs { self: Typer =>
// This is an inline `x is Class` match test.
val inlineMatchLoc = isApp.toLoc
val inlineScrutinee = makeScrutinee(scrutinee, inlineMatchLoc)
destructPattern(inlineScrutinee, pattern)(ctx, raise, scrutineeFieldAliasMap)
destructPattern(inlineScrutinee, pattern, true)(ctx, raise, scrutineeFieldAliasMap)
case test =>
val clause = Clause.BooleanTest(test)(collectLocations(test))
Iterable.single(clause)
Expand All @@ -339,7 +340,7 @@ class Desugarer extends TypeDefs { self: Typer =>
body: IfBody \/ Statement,
partialPattern: PartialTerm,
collectedConditions: Conjunction,
)(implicit interleavedLets: Buffer[(Bool, Var, Term)]): Unit =
)(implicit interleavedLets: Buffer[LetBinding]): Unit = traceUCS("[desugarMatchBranch]") {
body match {
// This case handles default branches. For example,
// if x is
Expand All @@ -360,7 +361,7 @@ class Desugarer extends TypeDefs { self: Typer =>
// B(...) and ... then ... // Case 2: more conjunctions
case L(IfThen(patTest, consequent)) =>
val (patternPart, extraTestOpt) = separatePattern(patTest)
val clauses = destructPattern(scrutinee, partialPattern.addTerm(patternPart).term)
val clauses = destructPattern(scrutinee, partialPattern.addTerm(patternPart).term, true)
val conditions = collectedConditions + Conjunction(clauses, Nil).withBindings
printlnUCS(s"result conditions: " + Clause.showClauses(conditions.clauses))
extraTestOpt match {
Expand All @@ -380,7 +381,7 @@ class Desugarer extends TypeDefs { self: Typer =>
// B(...) then ...
case L(IfOpApp(patLhs, Var("and"), consequent)) =>
val (pattern, optTests) = separatePattern(patLhs)
val patternConditions = destructPattern(scrutinee, pattern)
val patternConditions = destructPattern(scrutinee, pattern, true)
val tailTestConditions = optTests.fold(Nil: Ls[Clause])(x => desugarConditions(splitAnd(x)))
val conditions =
collectedConditions + Conjunction(patternConditions ::: tailTestConditions, Nil).withBindings
Expand All @@ -391,7 +392,7 @@ class Desugarer extends TypeDefs { self: Typer =>
// The pattern is completed. There is also a conjunction.
// So, we need to separate the pattern from remaining parts.
case (pattern, S(extraTests)) =>
val patternConditions = destructPattern(scrutinee, pattern)
val patternConditions = destructPattern(scrutinee, pattern, true)
val extraConditions = desugarConditions(splitAnd(extraTests))
val conditions =
collectedConditions + Conjunction(patternConditions ::: extraConditions, Nil).withBindings
Expand All @@ -413,7 +414,7 @@ class Desugarer extends TypeDefs { self: Typer =>
desugarMatchBranch(scrutinee, L(consequent), partialPattern2.addOp(op), collectedConditions)
}
case (patternPart, S(extraTests)) =>
val patternConditions = destructPattern(scrutinee, partialPattern.addTerm(patternPart).term)
val patternConditions = destructPattern(scrutinee, partialPattern.addTerm(patternPart).term, true)
val testTerms = splitAnd(extraTests)
val middleConditions = desugarConditions(testTerms.init)
val conditions =
Expand All @@ -433,16 +434,18 @@ class Desugarer extends TypeDefs { self: Typer =>
TODO("please add this rare case to test files")
// This case handles interleaved lets.
case R(NuFunDef(S(isRec), nameVar, _, L(term))) =>
interleavedLets += ((isRec, nameVar, term))
interleavedLets += (LetBinding(LetBinding.Kind.InterleavedLet, isRec, nameVar, term))
// Other statements are considered to be ill-formed.
case R(statement) => throw new DesugaringException({
msg"Illegal interleaved statement ${statement.toString}"
}, statement.toLoc)
}
}(_ => "[desugarMatchBranch]")

def desugarIfBody
(body: IfBody, expr: PartialTerm, acc: Conjunction)
(implicit interleavedLets: Buffer[(Bool, Var, Term)])
: Unit = {
(implicit interleavedLets: Buffer[LetBinding])
: Unit = traceUCS("[desugarIfBody]") {
body match {
case IfOpsApp(exprPart, opsRhss) =>
val exprStart = expr.addTerm(exprPart)
Expand Down Expand Up @@ -502,20 +505,21 @@ class Desugarer extends TypeDefs { self: Typer =>
case L(subBody) => desugarIfBody(subBody, expr, acc)
case R(NuFunDef(S(isRec), nameVar, _, L(term))) =>
printlnUCS(s"Found interleaved binding ${nameVar.name}")
interleavedLets += ((isRec, nameVar, term))
interleavedLets += LetBinding(LetBinding.Kind.InterleavedLet, isRec, nameVar, term)
case R(_) =>
throw new Error("unexpected statements at desugarIfBody")
}
}
}
}(_ => "[desugarIfBody]")

// Top-level interleaved let bindings.
val interleavedLets = Buffer.empty[(Bool, Var, Term)]
val interleavedLets = Buffer.empty[LetBinding]
desugarIfBody(body, PartialTerm.Empty, Conjunction.empty)(interleavedLets)
// Add the fallback case to conjunctions if there is any.
fallback.foreach { branches += Conjunction.empty -> _ }
Clause.print(printlnUCS, branches)
branches.toList
}
}(r => s"[desugarIf] produces ${r.size} branch(es)")

import MutCaseOf.{MutCase, IfThenElse, Match, MissingCase, Consequent}

Expand Down Expand Up @@ -628,7 +632,7 @@ class Desugarer extends TypeDefs { self: Typer =>
checkExhaustive(branch.consequent, S(t))
}
}
}()
}(_ => s"[checkExhaustive] ${t.describe}")

def summarizePatterns(t: MutCaseOf)(implicit ctx: Ctx, raise: Raise): ExhaustivenessMap = traceUCS("[summarizePatterns]") {
val m = MutMap.empty[Str \/ Int, MutMap[Var, MutCase]]
Expand Down Expand Up @@ -670,45 +674,81 @@ class Desugarer extends TypeDefs { self: Typer =>
printlnUCS(s"- $scrutinee => " + patterns.keysIterator.mkString(", "))
}
Map.from(m.iterator.map { case (key, patternMap) => key -> Map.from(patternMap) })
}()
}(_ => "[summarizePatterns]")

protected def constructTerm(m: MutCaseOf)(implicit ctx: Ctx): Term = {
def rec(m: MutCaseOf)(implicit defs: Set[Var]): Term = m match {
case Consequent(term) => term
case Match(scrutinee, branches, wildcard) =>
def rec2(xs: Ls[MutCase]): CaseBranches =
xs match {
case MutCase.Constructor(className -> fields, cases) :: next =>
// TODO: expand bindings here
val consequent = rec(cases)(defs ++ fields.iterator.map(_._2))
Case(className, mkLetFromFields(scrutinee, fields.toList, consequent), rec2(next))
case MutCase.Literal(literal, cases) :: next =>
val consequent = rec(cases)
Case(literal, consequent, rec2(next))
case Nil =>
wildcard.fold[CaseBranches](NoCases) { rec(_) |> Wildcard }
/**
* Make a term from a mutable case tree.
* This should be called after exhaustiveness checking.
*
* @param m the mutable case tree
* @param ctx the context
* @return the case expression
*/
protected def constructTerm(m: MutCaseOf)(implicit ctx: Ctx): Term = traceUCS("[constructTerm]") {
/**
* Reconstruct case branches.
*/
def rec2(xs: Ls[MutCase])(
implicit defs: Set[Var], scrutinee: Scrutinee, wildcard: Option[MutCaseOf]
): CaseBranches = {
xs match {
case MutCase.Constructor(className -> fields, cases) :: next =>
printlnUCS(s"• Constructor pattern: $className(${fields.iterator.map(x => s"${x._1} -> ${x._2}").mkString(", ")})")
// TODO: expand bindings here
val consequent = rec(cases)(defs ++ fields.iterator.map(_._2))
Case(className, mkLetFromFields(scrutinee, fields.toList, consequent), rec2(next))
case MutCase.Literal(literal, cases) :: next =>
printlnUCS(s"• Literal pattern: $literal")
Case(literal, rec(cases), rec2(next))
case Nil =>
wildcard match {
case None =>
printlnUCS("• No wildcard branch")
NoCases
case Some(value) =>
printlnUCS("• Wildcard branch")
Wildcard(rec(value))
}
val cases = rec2(branches.toList)
val resultTerm = scrutinee.local match {
case N => CaseOf(scrutinee.term, cases)
case S(aliasVar) => Let(false, aliasVar, scrutinee.term, CaseOf(aliasVar, cases))
}
// Collect let bindings from case branches.
val bindings = branches.iterator.flatMap(_.consequent.getBindings).toList
mkBindings(bindings, resultTerm, defs)
case MissingCase =>
import Message.MessageContext
throw new DesugaringException(msg"missing a default branch", N)
case IfThenElse(condition, whenTrue, whenFalse) =>
val falseBody = mkBindings(whenFalse.getBindings.toList, rec(whenFalse)(defs ++ whenFalse.getBindings.iterator.map(_._2)), defs)
val trueBody = mkBindings(whenTrue.getBindings.toList, rec(whenTrue)(defs ++ whenTrue.getBindings.iterator.map(_._2)), defs)
val falseBranch = Wildcard(falseBody)
val trueBranch = Case(Var("true"), trueBody, falseBranch)
CaseOf(condition, trueBranch)
}
}
val term = rec(m)(Set.from(m.getBindings.iterator.map(_._2)))
/**
* Reconstruct the entire match.
*/
def rec(m: MutCaseOf)(implicit defs: Set[Var]): Term = traceUCS(s"[rec] ${m.describe} -| {${defs.mkString(", ")}}") {
m match {
case Consequent(term) => term
case Match(scrutinee, branches, wildcard) =>
val cases = traceUCS("• For each case branch"){
rec2(branches.toList)(defs, scrutinee, wildcard)
}(_ => "• End for each")
val resultTerm = scrutinee.local match {
case N => CaseOf(scrutinee.term, cases)
case S(aliasVar) => Let(false, aliasVar, scrutinee.term, CaseOf(aliasVar, cases))
}
// Collect interleaved let bindings from case branches.
val bindings = branches.iterator.flatMap(_.consequent.getBindings).filter {
_.kind === LetBinding.Kind.InterleavedLet
}.toList
printlnUCS("• Collect interleaved let bindings from case branches")
bindings.foreach { case LetBinding(_, _, name, value) =>
printlnUCS(s" - $name = $value")
}
mkBindings(bindings, resultTerm, defs)
case MissingCase =>
import Message.MessageContext
throw new DesugaringException(msg"missing a default branch", N)
case IfThenElse(condition, whenTrue, whenFalse) =>
val falseBody = mkBindings(whenFalse.getBindings.toList, rec(whenFalse)(defs ++ whenFalse.getBindings.iterator.map(_.name)), defs)
val trueBody = mkBindings(whenTrue.getBindings.toList, rec(whenTrue)(defs ++ whenTrue.getBindings.iterator.map(_.name)), defs)
val falseBranch = Wildcard(falseBody)
val trueBranch = Case(Var("true"), trueBody, falseBranch)
CaseOf(condition, trueBranch)
}
}()
val term = rec(m)(Set.from(m.getBindings.iterator.map(_.name)))
// Create immutable map from the mutable map.
mkBindings(m.getBindings.toList, term, Set.empty)
}
}(_ => "[constructTerm]")

/**
* Generate a chain of field selection to the given scrutinee.
Expand All @@ -726,7 +766,7 @@ class Desugarer extends TypeDefs { self: Typer =>
// Check if the scrutinee is a `Var` and its name conflicts with
// one of the positionals. If so, we create an alias and extract
// fields by selecting the alias.
case Var(scrutineeName) if alias == scrutineeName =>
case Var(scrutineeName) if alias === scrutineeName =>
val scrutineeAlias = Var(freshName)
Let(
false,
Expand Down
40 changes: 40 additions & 0 deletions shared/src/main/scala/mlscript/ucs/LetBinding.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package mlscript.ucs

import mlscript._
import mlscript.utils._
import mlscript.utils.shorthands._
import scala.collection.immutable.Set
import scala.collection.mutable.{Set => MutSet, Buffer}

case class LetBinding(val kind: LetBinding.Kind, val recursive: Bool, val name: Var, val term: Term)

object LetBinding {
sealed abstract class Kind

object Kind {
case object ScrutineeAlias extends Kind
case object FieldExtraction extends Kind
case object InterleavedLet extends Kind
}
}

trait WithBindings { this: MutCaseOf =>
private val bindingsSet: MutSet[LetBinding] = MutSet.empty
private val bindings: Buffer[LetBinding] = Buffer.empty

def addBindings(newBindings: IterableOnce[LetBinding]): Unit = {
newBindings.iterator.foreach {
case binding if bindingsSet.contains(binding) => ()
case binding =>
bindingsSet += binding
bindings += binding
}
}

def getBindings: Iterable[LetBinding] = bindings

def withBindings(newBindings: IterableOnce[LetBinding]): MutCaseOf = {
addBindings(newBindings)
this
}
}
Loading

0 comments on commit 8ac73ee

Please sign in to comment.