Skip to content

Commit

Permalink
Merge pull request #199 from Consensys/196-determine-evaluation-conte…
Browse files Browse the repository at this point in the history
…xt-for-expression

Add `Contextual` interface
  • Loading branch information
DavePearce authored Jul 2, 2024
2 parents 59c74ca + 1b2933c commit d3e03e8
Show file tree
Hide file tree
Showing 17 changed files with 395 additions and 85 deletions.
58 changes: 51 additions & 7 deletions pkg/air/expr.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package air

import (
"math"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/schema"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/util"
)

Expand All @@ -14,12 +17,7 @@ import (
// trace expansion).
type Expr interface {
util.Boundable
// EvalAt evaluates this expression in a given tabular context. Observe that
// if this expression is *undefined* within this context then it returns
// "nil". An expression can be undefined for several reasons: firstly, if
// it accesses a row which does not exist (e.g. at index -1); secondly, if
// it accesses a column which does not exist.
EvalAt(int, trace.Trace) *fr.Element
schema.Evaluable

// String produces a string representing this as an S-Expression.
String() string
Expand All @@ -37,9 +35,19 @@ type Expr interface {
Equate(Expr) Expr
}

// ============================================================================
// Addition
// ============================================================================

// Add represents the sum over zero or more expressions.
type Add struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Add) Context(schema sc.Schema) (uint, bool) {
return sc.JoinContexts[Expr](p.Args, schema)
}

// Add two expressions together, producing a third.
func (p *Add) Add(other Expr) Expr { return &Add{Args: []Expr{p, other}} }

Expand All @@ -56,9 +64,19 @@ func (p *Add) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }
// direction (right).
func (p *Add) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }

// ============================================================================
// Subtraction
// ============================================================================

// Sub represents the subtraction over zero or more expressions.
type Sub struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Sub) Context(schema sc.Schema) (uint, bool) {
return sc.JoinContexts[Expr](p.Args, schema)
}

// Add two expressions together, producing a third.
func (p *Sub) Add(other Expr) Expr { return &Add{Args: []Expr{p, other}} }

Expand All @@ -75,9 +93,19 @@ func (p *Sub) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }
// direction (right).
func (p *Sub) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }

// ============================================================================
// Multiplication
// ============================================================================

// Mul represents the product over zero or more expressions.
type Mul struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Mul) Context(schema sc.Schema) (uint, bool) {
return sc.JoinContexts[Expr](p.Args, schema)
}

// Add two expressions together, producing a third.
func (p *Mul) Add(other Expr) Expr { return &Add{Args: []Expr{p, other}} }

Expand All @@ -94,6 +122,10 @@ func (p *Mul) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }
// direction (right).
func (p *Mul) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }

// ============================================================================
// Constant
// ============================================================================

// Constant represents a constant value within an expression.
type Constant struct{ Value *fr.Element }

Expand All @@ -120,6 +152,12 @@ func NewConstCopy(val *fr.Element) Expr {
return &Constant{&clone}
}

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Constant) Context(schema sc.Schema) (uint, bool) {
return math.MaxUint, true
}

// Add two expressions together, producing a third.
func (p *Constant) Add(other Expr) Expr { return &Add{Args: []Expr{p, other}} }

Expand Down Expand Up @@ -153,6 +191,12 @@ func NewColumnAccess(column uint, shift int) Expr {
return &ColumnAccess{column, shift}
}

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *ColumnAccess) Context(schema sc.Schema) (uint, bool) {
return schema.Columns().Nth(p.Column).Module(), true
}

// Add two expressions together, producing a third.
func (p *ColumnAccess) Add(other Expr) Expr { return &Add{Args: []Expr{p, other}} }

Expand Down
16 changes: 10 additions & 6 deletions pkg/air/gadgets/bits.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@ import (
// ApplyBinaryGadget adds a binarity constraint for a given column in the schema
// which enforces that all values in the given column are either 0 or 1. For a
// column X, this corresponds to the vanishing constraint X * (X-1) == 0.
func ApplyBinaryGadget(column uint, schema *air.Schema) {
func ApplyBinaryGadget(col uint, schema *air.Schema) {
// Identify target column
column := schema.Columns().Nth(col)
// Determine column name
name := schema.Columns().Nth(column).Name()
name := column.Name()
// Construct X
X := air.NewColumnAccess(column, 0)
X := air.NewColumnAccess(col, 0)
// Construct X-1
X_m1 := X.Sub(air.NewConst64(1))
// Construct X * (X-1)
X_X_m1 := X.Mul(X_m1)
// Done!
schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), nil, X_X_m1)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Module(), nil, X_X_m1)
}

// ApplyBitwidthGadget ensures all values in a given column fit within a given
Expand All @@ -33,11 +35,13 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
} else if nbits == 0 {
panic("zero bitwidth constraint encountered")
}
// Identify target column
column := schema.Columns().Nth(col)
// Calculate how many bytes required.
n := nbits / 8
es := make([]air.Expr, n)
fr256 := fr.NewElement(256)
name := schema.Columns().Nth(col).Name()
name := column.Name()
coefficient := fr.NewElement(1)
// Add decomposition assignment
index := schema.AddAssignment(assignment.NewByteDecomposition(name, col, n))
Expand All @@ -56,5 +60,5 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
X := air.NewColumnAccess(col, 0)
eq := X.Equate(sum)
// Construct column name
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), nil, eq)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Module(), nil, eq)
}
6 changes: 4 additions & 2 deletions pkg/air/gadgets/column_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ import (
// computation.
func ApplyColumnSortGadget(col uint, sign bool, bitwidth uint, schema *air.Schema) {
var deltaName string
// Identify target column
column := schema.Columns().Nth(col)
// Determine column name
name := schema.Columns().Nth(col).Name()
name := column.Name()
// Configure computation
Xk := air.NewColumnAccess(col, 0)
Xkm1 := air.NewColumnAccess(col, -1)
Expand All @@ -40,5 +42,5 @@ func ApplyColumnSortGadget(col uint, sign bool, bitwidth uint, schema *air.Schem
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
// Configure constraint: Delta[k] = X[k] - X[k-1]
Dk := air.NewColumnAccess(deltaIndex, 0)
schema.AddVanishingConstraint(deltaName, nil, Dk.Equate(Xdiff))
schema.AddVanishingConstraint(deltaName, column.Module(), nil, Dk.Equate(Xdiff))
}
17 changes: 10 additions & 7 deletions pkg/air/gadgets/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/consensys/go-corset/pkg/air"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
)

Expand All @@ -26,22 +27,24 @@ import (
// ensure it is positive. The delta column is constrained to a given bitwidth,
// with constraints added as necessary to ensure this.
func ApplyLexicographicSortingGadget(columns []uint, signs []bool, bitwidth uint, schema *air.Schema) {
// Check preconditions
ncols := len(columns)
// Check preconditions
if ncols != len(signs) {
panic("Inconsistent number of columns and signs for lexicographic sort.")
}
// Determine enclosing module for this gadget.
module := sc.DetermineEnclosingModuleOfColumns(columns, schema)
// Construct a unique prefix for this sort.
prefix := constructLexicographicSortingPrefix(columns, signs, schema)
// Add trace computation
deltaIndex := schema.AddAssignment(assignment.NewLexicographicSort(prefix, columns, signs, bitwidth))
// Construct selecto bits.
addLexicographicSelectorBits(prefix, deltaIndex, columns, schema)
addLexicographicSelectorBits(prefix, module, deltaIndex, columns, schema)
// Construct delta terms
constraint := constructLexicographicDeltaConstraint(deltaIndex, columns, signs)
// Add delta constraint
deltaName := fmt.Sprintf("%s:delta", prefix)
schema.AddVanishingConstraint(deltaName, nil, constraint)
schema.AddVanishingConstraint(deltaName, module, nil, constraint)
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
}
Expand Down Expand Up @@ -73,7 +76,7 @@ func constructLexicographicSortingPrefix(columns []uint, signs []bool, schema *a
//
// NOTE: this implementation differs from the original corset which used an
// additional "Eq" bit to help ensure at most one selector bit was enabled.
func addLexicographicSelectorBits(prefix string, deltaIndex uint, columns []uint, schema *air.Schema) {
func addLexicographicSelectorBits(prefix string, module uint, deltaIndex uint, columns []uint, schema *air.Schema) {
ncols := uint(len(columns))
// Calculate column index of first selector bit
bitIndex := deltaIndex + 1
Expand All @@ -97,7 +100,7 @@ func addLexicographicSelectorBits(prefix string, deltaIndex uint, columns []uint
pterms[i] = air.NewColumnAccess(bitIndex+i, 0)
pDiff := air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1))
pName := fmt.Sprintf("%s:%d:a", prefix, i)
schema.AddVanishingConstraint(pName, nil, air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff))
schema.AddVanishingConstraint(pName, module, nil, air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff))
// (∀j<i.Bj=0) ∧ Bi=1 ==> C[k]≠C[k-1]
qDiff := Normalise(air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1)), schema)
qName := fmt.Sprintf("%s:%d:b", prefix, i)
Expand All @@ -109,14 +112,14 @@ func addLexicographicSelectorBits(prefix string, deltaIndex uint, columns []uint
constraint = air.NewConst64(1).Sub(&air.Add{Args: qterms}).Mul(constraint)
}

schema.AddVanishingConstraint(qName, nil, constraint)
schema.AddVanishingConstraint(qName, module, nil, constraint)
}

sum := &air.Add{Args: terms}
// (sum = 0) ∨ (sum = 1)
constraint := sum.Mul(sum.Equate(air.NewConst64(1)))
name := fmt.Sprintf("%s:xor", prefix)
schema.AddVanishingConstraint(name, nil, constraint)
schema.AddVanishingConstraint(name, module, nil, constraint)
}

// Construct the lexicographic delta constraint. This states that the delta
Expand Down
12 changes: 10 additions & 2 deletions pkg/air/gadgets/normalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ func Normalise(e air.Expr, schema *air.Schema) air.Expr {
// column which holds the multiplicative inverse. Constraints are also added to
// ensure it really holds the inverted value.
func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
// Determine enclosing module.
module := sc.DetermineEnclosingModuleOfExpression(e, schema)
// Construct inverse computation
ie := &Inverse{Expr: e}
// Determine computed column name
Expand All @@ -52,10 +54,10 @@ func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
inv_e_implies_one_e_e := inv_e.Mul(one_e_e)
// Ensure (e != 0) ==> (1 == e/e)
l_name := fmt.Sprintf("[%s <=]", ie.String())
schema.AddVanishingConstraint(l_name, nil, e_implies_one_e_e)
schema.AddVanishingConstraint(l_name, module, nil, e_implies_one_e_e)
// Ensure (e/e != 0) ==> (1 == e/e)
r_name := fmt.Sprintf("[%s =>]", ie.String())
schema.AddVanishingConstraint(r_name, nil, inv_e_implies_one_e_e)
schema.AddVanishingConstraint(r_name, module, nil, inv_e_implies_one_e_e)
// Done
return air.NewColumnAccess(index, 0)
}
Expand All @@ -77,6 +79,12 @@ func (e *Inverse) EvalAt(k int, tbl tr.Trace) *fr.Element {
// direction (right).
func (e *Inverse) Bounds() util.Bounds { return e.Expr.Bounds() }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (e *Inverse) Context(schema sc.Schema) (uint, bool) {
return e.Expr.Context(schema)
}

func (e *Inverse) String() string {
return fmt.Sprintf("(inv %s)", e.Expr)
}
4 changes: 2 additions & 2 deletions pkg/air/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ func (p *Schema) AddPermutationConstraint(targets []uint, sources []uint) {
}

// AddVanishingConstraint appends a new vanishing constraint.
func (p *Schema) AddVanishingConstraint(handle string, domain *int, expr Expr) {
func (p *Schema) AddVanishingConstraint(handle string, module uint, domain *int, expr Expr) {
p.constraints = append(p.constraints,
constraint.NewVanishingConstraint(handle, domain, constraint.ZeroTest[Expr]{Expr: expr}))
constraint.NewVanishingConstraint(handle, module, domain, constraint.ZeroTest[Expr]{Expr: expr}))
}

// AddRangeConstraint appends a new range constraint.
Expand Down
5 changes: 4 additions & 1 deletion pkg/binfile/constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package binfile

import (
"github.com/consensys/go-corset/pkg/hir"
sc "github.com/consensys/go-corset/pkg/schema"
)

// JsonConstraint аn enumeration of constraint forms. Exactly one of these fields
Expand Down Expand Up @@ -43,8 +44,10 @@ func (e jsonConstraint) addToSchema(schema *hir.Schema) {
expr := e.Vanishes.Expr.ToHir(schema)
// Translate Domain
domain := e.Vanishes.Domain.toHir()
// Determine enclosing module
module := sc.DetermineEnclosingModuleOfExpression(expr, schema)
// Construct the vanishing constraint
schema.AddVanishingConstraint(e.Vanishes.Handle, domain, expr)
schema.AddVanishingConstraint(e.Vanishes.Handle, module, domain, expr)
} else if e.Permutation == nil {
// Catch all
panic("Unknown JSON constraint encountered")
Expand Down
Loading

0 comments on commit d3e03e8

Please sign in to comment.