Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Lookup Constraints #201

Merged
merged 5 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions pkg/air/gadgets/expand.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package gadgets

import (
"fmt"

"github.com/consensys/go-corset/pkg/air"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
)

// Expand converts an arbitrary expression into a specific column index. In
// general, this means adding a computed column to hold the value of the
// arbitrary expression and returning its index. However, this can be optimised
// in the case the given expression is a direct column access by simply
// returning the accessed column index.
func Expand(e air.Expr, schema *air.Schema) uint {
//
if ca, ok := e.(*air.ColumnAccess); ok && ca.Shift == 0 {
// Optimisation possible
return ca.Column
}
// No optimisation, therefore expand using a computedcolumn
module := sc.DetermineEnclosingModuleOfExpression(e, schema)
// Determine computed column name
name := e.String()
// Look up column
index, ok := sc.ColumnIndexOf(schema, module, name)
// Add new column (if it does not already exist)
if !ok {
// Add computed column
index = schema.AddAssignment(assignment.NewComputedColumn(module, name, e))
}
// Construct v == [e]
v := air.NewColumnAccess(index, 0)
// Construct 1 == e/e
eq_e_v := v.Equate(e)
// Ensure (e - v) == 0, where v is value of computed column.
c_name := fmt.Sprintf("[%s]", e.String())
schema.AddVanishingConstraint(c_name, module, nil, eq_e_v)
//
return index
}
24 changes: 24 additions & 0 deletions pkg/air/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ import (
// DataColumn captures the essence of a data column at AIR level.
type DataColumn = *assignment.DataColumn

// LookupConstraint captures the essence of a lookup constraint at the HIR
// level. At the AIR level, lookup constraints are only permitted between
// columns (i.e. not arbitrary expressions).
type LookupConstraint = *constraint.LookupConstraint[*ColumnAccess]

// PropertyAssertion captures the notion of an arbitrary property which should
// hold for all acceptable traces. However, such a property is not enforced by
// the prover.
Expand Down Expand Up @@ -81,6 +86,25 @@ func (p *Schema) AddAssignment(c schema.Assignment) uint {
return index
}

// AddLookupConstraint appends a new lookup constraint.
func (p *Schema) AddLookupConstraint(handle string, source uint, target uint, sources []uint, targets []uint) {
if len(targets) != len(sources) {
panic("differeng number of target / source lookup columns")
}
// TODO: sanity source columns are in the same module, and likewise target
// columns (though they don't have to be in the same column together).
from := make([]Expr, len(sources))
into := make([]Expr, len(targets))
// Construct column accesses from column indices.
for i := 0; i < len(from); i++ {
from[i] = NewColumnAccess(sources[i], 0)
into[i] = NewColumnAccess(targets[i], 0)
}
//
p.constraints = append(p.constraints,
constraint.NewLookupConstraint(handle, source, target, from, into))
}

// AddPermutationConstraint appends a new permutation constraint which
// ensures that one column is a permutation of another.
func (p *Schema) AddPermutationConstraint(targets []uint, sources []uint) {
Expand Down
8 changes: 7 additions & 1 deletion pkg/cmd/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,13 @@ var debugCmd = &cobra.Command{

// Print out all declarations included in a given
func printSchema(schema schema.Schema) {
panic("todo")
for i := schema.Declarations(); i.HasNext(); {
fmt.Println(i.Next())
}

for i := schema.Constraints(); i.HasNext(); {
fmt.Println(i.Next())
}
}

func init() {
Expand Down
33 changes: 31 additions & 2 deletions pkg/hir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ func (p *Schema) LowerToMir() *mir.Schema {

func lowerConstraintToMir(c sc.Constraint, schema *mir.Schema) {
// Check what kind of constraint we have
if v, ok := c.(VanishingConstraint); ok {
if v, ok := c.(LookupConstraint); ok {
lowerLookupConstraint(v, schema)
} else if v, ok := c.(VanishingConstraint); ok {
mir_exprs := v.Constraint().Expr.LowerTo(schema)
// Add individual constraints arising
for _, mir_expr := range mir_exprs {
Expand All @@ -61,6 +63,33 @@ func lowerConstraintToMir(c sc.Constraint, schema *mir.Schema) {
}
}

func lowerLookupConstraint(c LookupConstraint, schema *mir.Schema) {
sources := c.Sources()
targets := c.Targets()
from := make([]mir.Expr, len(sources))
into := make([]mir.Expr, len(targets))
// Convert general expressions into unit expressions.
for i := 0; i < len(from); i++ {
from[i] = lowerUnitTo(sources[i], schema)
into[i] = lowerUnitTo(targets[i], schema)
}
//
schema.AddLookupConstraint(c.Handle(), c.SourceModule(), c.TargetModule(), from, into)
}

// Lower an expression which is expected to lower into a single expression.
// This will panic if the unit expression is malformed (i.e. does not lower
// into a single expression).
func lowerUnitTo(e UnitExpr, schema *mir.Schema) mir.Expr {
exprs := lowerTo(e.expr, schema)

if len(exprs) != 1 {
panic("invalid unitary expression")
}

return exprs[0]
}

// LowerTo lowers a sum expression to the MIR level. This requires expanding
// the arguments, then lowering them. Furthermore, conditionals are "lifted" to
// the top.
Expand Down Expand Up @@ -119,7 +148,7 @@ func (e *Sub) LowerTo(schema *mir.Schema) []mir.Expr {
}

// ============================================================================
// expandedLowerTo
// lowerTo
// ============================================================================

// Lowers a given expression to the MIR level. The expression is first expanded
Expand Down
132 changes: 125 additions & 7 deletions pkg/hir/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ type hirParser struct {
// Environment used during parsing to resolve column names into column
// indices.
env *Environment
// Global is used exclusively when parsing expressions to signal whether or
// not qualified column accesses are permitted (i.e. which include a
// module).
global bool
}

func newHirParser(srcmap *sexp.SourceMap[sexp.SExp]) *hirParser {
Expand All @@ -63,7 +67,7 @@ func newHirParser(srcmap *sexp.SourceMap[sexp.SExp]) *hirParser {
// Register top-level module (aka the prelude)
prelude := env.RegisterModule("")
// Construct parser
parser := &hirParser{p, prelude, env}
parser := &hirParser{p, prelude, env, false}
// Configure translator
p.AddSymbolRule(constantParserRule)
p.AddSymbolRule(columnAccessParserRule(parser))
Expand Down Expand Up @@ -97,6 +101,8 @@ func (p *hirParser) parseDeclaration(s sexp.SExp) error {
return p.parseAssertionDeclaration(e.Elements)
} else if e.Len() == 3 && e.MatchSymbols(1, "permute") {
return p.parseSortedPermutationDeclaration(e)
} else if e.Len() == 4 && e.MatchSymbols(1, "lookup") {
return p.parseLookupDeclaration(e)
}
}
// Error
Expand Down Expand Up @@ -215,10 +221,71 @@ func (p *hirParser) parseSortedPermutationDeclaration(l *sexp.List) error {
return nil
}

// Parse a lookup declaration
func (p *hirParser) parseLookupDeclaration(l *sexp.List) error {
handle := l.Elements[1].String()
// Target columns are (sorted) permutations of source columns.
sexpTargets := l.Elements[2].AsList()
// Source columns.
sexpSources := l.Elements[3].AsList()
// Sanity check number of target colunms matches number of source columns.
if sexpTargets.Len() != sexpSources.Len() {
return p.translator.SyntaxError(l, "lookup constraint requires matching number of source and target columns")
}
// Sanity check expressions have unitary form.
for i := 0; i < sexpTargets.Len(); i++ {
// Sanity check source and target expressions do not contain expression
// forms which are not permitted within a unitary expression.
if err := p.checkUnitExpr(sexpTargets.Get(i)); err != nil {
return err
}

if err := p.checkUnitExpr(sexpSources.Get(i)); err != nil {
return err
}
}
// Proceed with translation
targets := make([]UnitExpr, sexpTargets.Len())
sources := make([]UnitExpr, sexpSources.Len())
// Lookup expressions are permitted to make fully qualified accesses. This
// is because inter-module lookups are supported.
p.global = true
// Parse source / target expressions
for i := 0; i < len(targets); i++ {
target, err1 := p.translator.Translate(sexpTargets.Get(i))
source, err2 := p.translator.Translate(sexpSources.Get(i))

if err1 != nil {
return err1
} else if err2 != nil {
return err2
}
// Done
targets[i] = UnitExpr{target}
sources[i] = UnitExpr{source}
}
// Sanity check enclosing source and target modules
source, err1 := schema.DetermineEnclosingModuleOfExpressions(sources, p.env.schema)
target, err2 := schema.DetermineEnclosingModuleOfExpressions(targets, p.env.schema)
// Propagate errors
if err1 != nil {
return p.translator.SyntaxError(sexpSources.Get(int(source)), err1.Error())
} else if err2 != nil {
return p.translator.SyntaxError(sexpTargets.Get(int(target)), err2.Error())
}
// Finally add constraint
p.env.schema.AddLookupConstraint(handle, source, target, sources, targets)
// DOne
return nil
}

// Parse a property assertion
func (p *hirParser) parseAssertionDeclaration(elements []sexp.SExp) error {
handle := elements[1].String()

// Property assertions do not have global scope, hence qualified column
// accesses are not permitted.
p.global = false
// Translate
expr, err := p.translator.Translate(elements[2])
if err != nil {
return err
Expand All @@ -232,7 +299,10 @@ func (p *hirParser) parseAssertionDeclaration(elements []sexp.SExp) error {
// Parse a vanishing declaration
func (p *hirParser) parseVanishingDeclaration(elements []sexp.SExp, domain *int) error {
handle := elements[1].String()

// Vanishing constraints do not have global scope, hence qualified column
// accesses are not permitted.
p.global = false
// Translate
expr, err := p.translator.Translate(elements[2])
if err != nil {
return err
Expand Down Expand Up @@ -262,6 +332,34 @@ func (p *hirParser) parseType(term sexp.SExp) (sc.Type, error) {
return nil, p.translator.SyntaxError(symbol, "unknown type")
}

// Check that a given expression conforms to the requirements of a unitary
// expression. That is, it cannot contain an "if", "ifnot" or "begin"
// expression form.
func (p *hirParser) checkUnitExpr(term sexp.SExp) error {
l := term.AsList()

if l != nil && l.Len() > 0 {
if head := l.Get(0).AsSymbol(); head != nil {
switch head.Value {
case "if":
fallthrough
case "ifnot":
fallthrough
case "begin":
return p.translator.SyntaxError(term, "not permitted in lookup")
}
}
// Check arguments
for i := 0; i < l.Len(); i++ {
if err := p.checkUnitExpr(l.Get(i)); err != nil {
return err
}
}
}

return nil
}

func beginParserRule(args []Expr) (Expr, error) {
return &List{args}, nil
}
Expand All @@ -285,18 +383,38 @@ func constantParserRule(symbol string) (Expr, bool, error) {
func columnAccessParserRule(parser *hirParser) func(col string) (Expr, bool, error) {
// Returns a closure over the parser.
return func(col string) (Expr, bool, error) {
var ok bool
// Sanity check what we have
if !unicode.IsLetter(rune(col[0])) {
return nil, false, nil
}
// Look up column in the environment
i, ok := parser.env.LookupColumn(parser.module, col)
// Handle qualified accesses (where permitted)
module := parser.module
colname := col
// Attempt to split column name into module / column pair.
split := strings.Split(col, ".")
if parser.global && len(split) == 2 {
// Lookup module
if module, ok = parser.env.LookupModule(split[0]); !ok {
return nil, true, errors.New("unknown module")
}

colname = split[1]
} else if len(split) > 2 {
return nil, true, errors.New("malformed column access")
} else if len(split) == 2 {
return nil, true, errors.New("qualified column access not permitted here")
}
// Now lookup column in the appropriate module.
var cid uint
// Look up column in the environment using local scope.
cid, ok = parser.env.LookupColumn(module, colname)
// Check column was found
if !ok {
return nil, true, fmt.Errorf("unknown column %s", col)
return nil, true, errors.New("unknown column")
}
// Done
return &ColumnAccess{i, 0}, true, nil
return &ColumnAccess{cid, 0}, true, nil
}
}

Expand Down
Loading