Skip to content

Commit

Permalink
Merge pull request #210 from Consensys/204-solidify-notion-of-evaluat…
Browse files Browse the repository at this point in the history
…ion-context2

Solidify Notion of Evaluation Context
  • Loading branch information
DavePearce authored Jul 5, 2024
2 parents b2d1668 + ef5b612 commit 32ece2d
Show file tree
Hide file tree
Showing 40 changed files with 491 additions and 475 deletions.
20 changes: 9 additions & 11 deletions pkg/air/expr.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package air

import (
"math"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/schema"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

Expand All @@ -17,7 +15,7 @@ import (
// trace expansion).
type Expr interface {
util.Boundable
schema.Evaluable
sc.Evaluable

// String produces a string representing this as an S-Expression.
String() string
Expand All @@ -44,7 +42,7 @@ type Add struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Add) Context(schema sc.Schema) (uint, uint, bool) {
func (p *Add) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

Expand Down Expand Up @@ -73,7 +71,7 @@ type Sub struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Sub) Context(schema sc.Schema) (uint, uint, bool) {
func (p *Sub) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

Expand Down Expand Up @@ -102,7 +100,7 @@ type Mul struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Mul) Context(schema sc.Schema) (uint, uint, bool) {
func (p *Mul) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

Expand Down Expand Up @@ -154,8 +152,8 @@ func NewConstCopy(val *fr.Element) Expr {

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Constant) Context(schema sc.Schema) (uint, uint, bool) {
return math.MaxUint, math.MaxUint, true
func (p *Constant) Context(schema sc.Schema) trace.Context {
return trace.VoidContext()
}

// Add two expressions together, producing a third.
Expand Down Expand Up @@ -193,9 +191,9 @@ func NewColumnAccess(column uint, shift int) Expr {

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *ColumnAccess) Context(schema sc.Schema) (uint, uint, bool) {
func (p *ColumnAccess) Context(schema sc.Schema) trace.Context {
col := schema.Columns().Nth(p.Column)
return col.Module(), col.LengthMultiplier(), true
return col.Context()
}

// Add two expressions together, producing a third.
Expand Down
6 changes: 3 additions & 3 deletions pkg/air/gadgets/bits.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func ApplyBinaryGadget(col uint, schema *air.Schema) {
// Construct X * (X-1)
X_X_m1 := X.Mul(X_m1)
// Done!
schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Module(), column.LengthMultiplier(), nil, X_X_m1)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Context(), nil, X_X_m1)
}

// ApplyBitwidthGadget ensures all values in a given column fit within a given
Expand All @@ -45,7 +45,7 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
coefficient := fr.NewElement(1)
// Add decomposition assignment
index := schema.AddAssignment(
assignment.NewByteDecomposition(name, column.Module(), column.LengthMultiplier(), col, n))
assignment.NewByteDecomposition(name, column.Context(), col, n))
// Construct Columns
for i := uint(0); i < n; i++ {
// Create Column + Constraint
Expand All @@ -61,5 +61,5 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
X := air.NewColumnAccess(col, 0)
eq := X.Equate(sum)
// Construct column name
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Module(), column.LengthMultiplier(), nil, eq)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Context(), nil, eq)
}
4 changes: 2 additions & 2 deletions pkg/air/gadgets/column_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ func ApplyColumnSortGadget(col uint, sign bool, bitwidth uint, schema *air.Schem
}
// Add delta assignment
deltaIndex := schema.AddAssignment(
assignment.NewComputedColumn(column.Module(), deltaName, column.LengthMultiplier(), Xdiff))
assignment.NewComputedColumn(column.Context(), deltaName, Xdiff))
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
// Configure constraint: Delta[k] = X[k] - X[k-1]
Dk := air.NewColumnAccess(deltaIndex, 0)
schema.AddVanishingConstraint(deltaName, column.Module(), column.LengthMultiplier(), nil, Dk.Equate(Xdiff))
schema.AddVanishingConstraint(deltaName, column.Context(), nil, Dk.Equate(Xdiff))
}
8 changes: 4 additions & 4 deletions pkg/air/gadgets/expand.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,23 @@ func Expand(e air.Expr, schema *air.Schema) uint {
return ca.Column
}
// No optimisation, therefore expand using a computedcolumn
module, multiplier := sc.DetermineEnclosingModuleOfExpression(e, schema)
ctx := e.Context(schema)
// Determine computed column name
name := e.String()
// Look up column
index, ok := sc.ColumnIndexOf(schema, module, name)
index, ok := sc.ColumnIndexOf(schema, ctx.Module(), name)
// Add new column (if it does not already exist)
if !ok {
// Add computed column
index = schema.AddAssignment(assignment.NewComputedColumn(module, name, multiplier, e))
index = schema.AddAssignment(assignment.NewComputedColumn(ctx, name, e))
}
// Construct v == [e]
v := air.NewColumnAccess(index, 0)
// Construct 1 == e/e
eq_e_v := v.Equate(e)
// Ensure (e - v) == 0, where v is value of computed column.
c_name := fmt.Sprintf("[%s]", e.String())
schema.AddVanishingConstraint(c_name, module, multiplier, nil, eq_e_v)
schema.AddVanishingConstraint(c_name, ctx, nil, eq_e_v)
//
return index
}
17 changes: 9 additions & 8 deletions pkg/air/gadgets/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/consensys/go-corset/pkg/air"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
"github.com/consensys/go-corset/pkg/trace"
)

// ApplyLexicographicSortingGadget Add sorting constraints for a sequence of one
Expand All @@ -33,19 +34,19 @@ func ApplyLexicographicSortingGadget(columns []uint, signs []bool, bitwidth uint
panic("Inconsistent number of columns and signs for lexicographic sort.")
}
// Determine enclosing module for this gadget.
module, multiplier := sc.DetermineEnclosingModuleOfColumns(columns, schema)
ctx := sc.ContextOfColumns(columns, schema)
// Construct a unique prefix for this sort.
prefix := constructLexicographicSortingPrefix(columns, signs, schema)
// Add trace computation
deltaIndex := schema.AddAssignment(
assignment.NewLexicographicSort(prefix, module, multiplier, columns, signs, bitwidth))
assignment.NewLexicographicSort(prefix, ctx, columns, signs, bitwidth))
// Construct selecto bits.
addLexicographicSelectorBits(prefix, module, multiplier, deltaIndex, columns, schema)
addLexicographicSelectorBits(prefix, ctx, deltaIndex, columns, schema)
// Construct delta terms
constraint := constructLexicographicDeltaConstraint(deltaIndex, columns, signs)
// Add delta constraint
deltaName := fmt.Sprintf("%s:delta", prefix)
schema.AddVanishingConstraint(deltaName, module, multiplier, nil, constraint)
schema.AddVanishingConstraint(deltaName, ctx, nil, constraint)
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
}
Expand Down Expand Up @@ -77,7 +78,7 @@ func constructLexicographicSortingPrefix(columns []uint, signs []bool, schema *a
//
// NOTE: this implementation differs from the original corset which used an
// additional "Eq" bit to help ensure at most one selector bit was enabled.
func addLexicographicSelectorBits(prefix string, module uint, multiplier uint,
func addLexicographicSelectorBits(prefix string, context trace.Context,
deltaIndex uint, columns []uint, schema *air.Schema) {
ncols := uint(len(columns))
// Calculate column index of first selector bit
Expand All @@ -102,7 +103,7 @@ func addLexicographicSelectorBits(prefix string, module uint, multiplier uint,
pterms[i] = air.NewColumnAccess(bitIndex+i, 0)
pDiff := air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1))
pName := fmt.Sprintf("%s:%d:a", prefix, i)
schema.AddVanishingConstraint(pName, module, multiplier,
schema.AddVanishingConstraint(pName, context,
nil, air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff))
// (∀j<i.Bj=0) ∧ Bi=1 ==> C[k]≠C[k-1]
qDiff := Normalise(air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1)), schema)
Expand All @@ -115,14 +116,14 @@ func addLexicographicSelectorBits(prefix string, module uint, multiplier uint,
constraint = air.NewConst64(1).Sub(&air.Add{Args: qterms}).Mul(constraint)
}

schema.AddVanishingConstraint(qName, module, multiplier, nil, constraint)
schema.AddVanishingConstraint(qName, context, nil, constraint)
}

sum := &air.Add{Args: terms}
// (sum = 0) ∨ (sum = 1)
constraint := sum.Mul(sum.Equate(air.NewConst64(1)))
name := fmt.Sprintf("%s:xor", prefix)
schema.AddVanishingConstraint(name, module, multiplier, nil, constraint)
schema.AddVanishingConstraint(name, context, nil, constraint)
}

// Construct the lexicographic delta constraint. This states that the delta
Expand Down
12 changes: 6 additions & 6 deletions pkg/air/gadgets/normalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@ func Normalise(e air.Expr, schema *air.Schema) air.Expr {
// ensure it really holds the inverted value.
func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
// Determine enclosing module.
module, multiplier := sc.DetermineEnclosingModuleOfExpression(e, schema)
ctx := e.Context(schema)
// Construct inverse computation
ie := &Inverse{Expr: e}
// Determine computed column name
name := ie.String()
// Look up column
index, ok := sc.ColumnIndexOf(schema, module, name)
index, ok := sc.ColumnIndexOf(schema, ctx.Module(), name)
// Add new column (if it does not already exist)
if !ok {
// Add computed column
index = schema.AddAssignment(assignment.NewComputedColumn(module, name, multiplier, ie))
index = schema.AddAssignment(assignment.NewComputedColumn(ctx, name, ie))
}

// Construct 1/e
Expand All @@ -54,10 +54,10 @@ func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
inv_e_implies_one_e_e := inv_e.Mul(one_e_e)
// Ensure (e != 0) ==> (1 == e/e)
l_name := fmt.Sprintf("[%s <=]", ie.String())
schema.AddVanishingConstraint(l_name, module, multiplier, nil, e_implies_one_e_e)
schema.AddVanishingConstraint(l_name, ctx, nil, e_implies_one_e_e)
// Ensure (e/e != 0) ==> (1 == e/e)
r_name := fmt.Sprintf("[%s =>]", ie.String())
schema.AddVanishingConstraint(r_name, module, multiplier, nil, inv_e_implies_one_e_e)
schema.AddVanishingConstraint(r_name, ctx, nil, inv_e_implies_one_e_e)
// Done
return air.NewColumnAccess(index, 0)
}
Expand All @@ -81,7 +81,7 @@ func (e *Inverse) Bounds() util.Bounds { return e.Expr.Bounds() }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (e *Inverse) Context(schema sc.Schema) (uint, uint, bool) {
func (e *Inverse) Context(schema sc.Schema) tr.Context {
return e.Expr.Context(schema)
}

Expand Down
23 changes: 12 additions & 11 deletions pkg/air/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
"github.com/consensys/go-corset/pkg/schema/constraint"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

Expand Down Expand Up @@ -64,14 +65,14 @@ func (p *Schema) AddModule(name string) uint {

// AddColumn appends a new data column whose values must be provided by the
// user.
func (p *Schema) AddColumn(module uint, name string, datatype schema.Type) uint {
if module >= uint(len(p.modules)) {
panic(fmt.Sprintf("invalid module index (%d)", module))
func (p *Schema) AddColumn(context trace.Context, name string, datatype schema.Type) uint {
if context.Module() >= uint(len(p.modules)) {
panic(fmt.Sprintf("invalid module index (%d)", context.Module()))
}

// NOTE: the air level has no ability to enforce the type specified for a
// given column.
p.inputs = append(p.inputs, assignment.NewDataColumn(module, name, datatype))
p.inputs = append(p.inputs, assignment.NewDataColumn(context, name, datatype))
// Calculate column index
return uint(len(p.inputs) - 1)
}
Expand All @@ -87,8 +88,8 @@ func (p *Schema) AddAssignment(c schema.Assignment) uint {
}

// AddLookupConstraint appends a new lookup constraint.
func (p *Schema) AddLookupConstraint(handle string, source uint, source_multiplier uint,
target uint, target_multiplier uint, sources []uint, targets []uint) {
func (p *Schema) AddLookupConstraint(handle string, source trace.Context,
target trace.Context, sources []uint, targets []uint) {
if len(targets) != len(sources) {
panic("differeng number of target / source lookup columns")
}
Expand All @@ -103,7 +104,7 @@ func (p *Schema) AddLookupConstraint(handle string, source uint, source_multipli
}
//
p.constraints = append(p.constraints,
constraint.NewLookupConstraint(handle, source, source_multiplier, target, target_multiplier, from, into))
constraint.NewLookupConstraint(handle, source, target, from, into))
}

// AddPermutationConstraint appends a new permutation constraint which
Expand All @@ -114,13 +115,13 @@ func (p *Schema) AddPermutationConstraint(targets []uint, sources []uint) {
}

// AddVanishingConstraint appends a new vanishing constraint.
func (p *Schema) AddVanishingConstraint(handle string, module uint, multiplier uint, domain *int, expr Expr) {
if module >= uint(len(p.modules)) {
panic(fmt.Sprintf("invalid module index (%d)", module))
func (p *Schema) AddVanishingConstraint(handle string, context trace.Context, domain *int, expr Expr) {
if context.Module() >= uint(len(p.modules)) {
panic(fmt.Sprintf("invalid module index (%d)", context.Module()))
}
// TODO: sanity check expression enclosed by module
p.constraints = append(p.constraints,
constraint.NewVanishingConstraint(handle, module, multiplier, domain, constraint.ZeroTest[Expr]{Expr: expr}))
constraint.NewVanishingConstraint(handle, context, domain, constraint.ZeroTest[Expr]{Expr: expr}))
}

// AddRangeConstraint appends a new range constraint.
Expand Down
17 changes: 10 additions & 7 deletions pkg/binfile/computation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/consensys/go-corset/pkg/hir"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
"github.com/consensys/go-corset/pkg/trace"
)

type jsonComputationSet struct {
Expand All @@ -27,7 +28,6 @@ type jsonSortedComputation struct {
// =============================================================================

func (e jsonComputationSet) addToSchema(schema *hir.Schema) {
var multiplier uint
//
for _, c := range e.Computations {
if c.Sorted != nil {
Expand All @@ -44,6 +44,8 @@ func (e jsonComputationSet) addToSchema(schema *hir.Schema) {
// Convert target refs into columns
targets := make([]sc.Column, len(targetRefs))
//
ctx := trace.VoidContext()
//
for i, targetRef := range targetRefs {
src_cid, src_mid := sourceRefs[i].resolve(schema)
_, dst_mid := targetRef.resolve(schema)
Expand All @@ -53,20 +55,21 @@ func (e jsonComputationSet) addToSchema(schema *hir.Schema) {
}
// Determine type of source column
ith := schema.Columns().Nth(src_cid)
ctx = ctx.Join(ith.Context())
// Sanity check we have a sensible type here.
if ith.Type().AsUint() == nil {
panic(fmt.Sprintf("source column %s has field type", sourceRefs[i]))
} else if i == 0 {
multiplier = ith.LengthMultiplier()
} else if multiplier != ith.LengthMultiplier() {
panic(fmt.Sprintf("source column %s has inconsistent length multiplier", sourceRefs[i]))
} else if ctx.IsConflicted() {
panic(fmt.Sprintf("source column %s has conflicted evaluation context", sourceRefs[i]))
} else if ctx.IsVoid() {
panic(fmt.Sprintf("source column %s has void evaluation context", sourceRefs[i]))
}

sources[i] = src_cid
targets[i] = sc.NewColumn(ith.Module(), targetRef.column, multiplier, ith.Type())
targets[i] = sc.NewColumn(ctx, targetRef.column, ith.Type())
}
// Finally, add the sorted permutation assignment
schema.AddAssignment(assignment.NewSortedPermutation(module, multiplier, targets, c.Sorted.Signs, sources))
schema.AddAssignment(assignment.NewSortedPermutation(ctx, targets, c.Sorted.Signs, sources))
}
}
}
Loading

0 comments on commit 32ece2d

Please sign in to comment.