Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Interleaving Constraints #205

Merged
merged 3 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions pkg/air/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type Add struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Add) Context(schema sc.Schema) (uint, bool) {
func (p *Add) Context(schema sc.Schema) (uint, uint, bool) {
return sc.JoinContexts[Expr](p.Args, schema)
}

Expand Down Expand Up @@ -73,7 +73,7 @@ type Sub struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Sub) Context(schema sc.Schema) (uint, bool) {
func (p *Sub) Context(schema sc.Schema) (uint, uint, bool) {
return sc.JoinContexts[Expr](p.Args, schema)
}

Expand Down Expand Up @@ -102,7 +102,7 @@ type Mul struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Mul) Context(schema sc.Schema) (uint, bool) {
func (p *Mul) Context(schema sc.Schema) (uint, uint, bool) {
return sc.JoinContexts[Expr](p.Args, schema)
}

Expand Down Expand Up @@ -154,8 +154,8 @@ func NewConstCopy(val *fr.Element) Expr {

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Constant) Context(schema sc.Schema) (uint, bool) {
return math.MaxUint, true
func (p *Constant) Context(schema sc.Schema) (uint, uint, bool) {
return math.MaxUint, math.MaxUint, true
}

// Add two expressions together, producing a third.
Expand Down Expand Up @@ -193,8 +193,9 @@ func NewColumnAccess(column uint, shift int) Expr {

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *ColumnAccess) Context(schema sc.Schema) (uint, bool) {
return schema.Columns().Nth(p.Column).Module(), true
func (p *ColumnAccess) Context(schema sc.Schema) (uint, uint, bool) {
col := schema.Columns().Nth(p.Column)
return col.Module(), col.LengthMultiplier(), true
}

// Add two expressions together, producing a third.
Expand Down
7 changes: 4 additions & 3 deletions pkg/air/gadgets/bits.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func ApplyBinaryGadget(col uint, schema *air.Schema) {
// Construct X * (X-1)
X_X_m1 := X.Mul(X_m1)
// Done!
schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Module(), nil, X_X_m1)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Module(), column.LengthMultiplier(), nil, X_X_m1)
}

// ApplyBitwidthGadget ensures all values in a given column fit within a given
Expand All @@ -44,7 +44,8 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
name := column.Name()
coefficient := fr.NewElement(1)
// Add decomposition assignment
index := schema.AddAssignment(assignment.NewByteDecomposition(name, column.Module(), col, n))
index := schema.AddAssignment(
assignment.NewByteDecomposition(name, column.Module(), column.LengthMultiplier(), col, n))
// Construct Columns
for i := uint(0); i < n; i++ {
// Create Column + Constraint
Expand All @@ -60,5 +61,5 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
X := air.NewColumnAccess(col, 0)
eq := X.Equate(sum)
// Construct column name
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Module(), nil, eq)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Module(), column.LengthMultiplier(), nil, eq)
}
5 changes: 3 additions & 2 deletions pkg/air/gadgets/column_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ func ApplyColumnSortGadget(col uint, sign bool, bitwidth uint, schema *air.Schem
deltaName = fmt.Sprintf("-%s", name)
}
// Add delta assignment
deltaIndex := schema.AddAssignment(assignment.NewComputedColumn(column.Module(), deltaName, Xdiff))
deltaIndex := schema.AddAssignment(
assignment.NewComputedColumn(column.Module(), deltaName, column.LengthMultiplier(), Xdiff))
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
// Configure constraint: Delta[k] = X[k] - X[k-1]
Dk := air.NewColumnAccess(deltaIndex, 0)
schema.AddVanishingConstraint(deltaName, column.Module(), nil, Dk.Equate(Xdiff))
schema.AddVanishingConstraint(deltaName, column.Module(), column.LengthMultiplier(), nil, Dk.Equate(Xdiff))
}
6 changes: 3 additions & 3 deletions pkg/air/gadgets/expand.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,23 @@ func Expand(e air.Expr, schema *air.Schema) uint {
return ca.Column
}
// No optimisation, therefore expand using a computedcolumn
module := sc.DetermineEnclosingModuleOfExpression(e, schema)
module, multiplier := sc.DetermineEnclosingModuleOfExpression(e, schema)
// Determine computed column name
name := e.String()
// Look up column
index, ok := sc.ColumnIndexOf(schema, module, name)
// Add new column (if it does not already exist)
if !ok {
// Add computed column
index = schema.AddAssignment(assignment.NewComputedColumn(module, name, e))
index = schema.AddAssignment(assignment.NewComputedColumn(module, name, multiplier, e))
}
// Construct v == [e]
v := air.NewColumnAccess(index, 0)
// Construct 1 == e/e
eq_e_v := v.Equate(e)
// Ensure (e - v) == 0, where v is value of computed column.
c_name := fmt.Sprintf("[%s]", e.String())
schema.AddVanishingConstraint(c_name, module, nil, eq_e_v)
schema.AddVanishingConstraint(c_name, module, multiplier, nil, eq_e_v)
//
return index
}
19 changes: 11 additions & 8 deletions pkg/air/gadgets/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,19 @@ func ApplyLexicographicSortingGadget(columns []uint, signs []bool, bitwidth uint
panic("Inconsistent number of columns and signs for lexicographic sort.")
}
// Determine enclosing module for this gadget.
module := sc.DetermineEnclosingModuleOfColumns(columns, schema)
module, multiplier := sc.DetermineEnclosingModuleOfColumns(columns, schema)
// Construct a unique prefix for this sort.
prefix := constructLexicographicSortingPrefix(columns, signs, schema)
// Add trace computation
deltaIndex := schema.AddAssignment(assignment.NewLexicographicSort(prefix, module, columns, signs, bitwidth))
deltaIndex := schema.AddAssignment(
assignment.NewLexicographicSort(prefix, module, multiplier, columns, signs, bitwidth))
// Construct selecto bits.
addLexicographicSelectorBits(prefix, module, deltaIndex, columns, schema)
addLexicographicSelectorBits(prefix, module, multiplier, deltaIndex, columns, schema)
// Construct delta terms
constraint := constructLexicographicDeltaConstraint(deltaIndex, columns, signs)
// Add delta constraint
deltaName := fmt.Sprintf("%s:delta", prefix)
schema.AddVanishingConstraint(deltaName, module, nil, constraint)
schema.AddVanishingConstraint(deltaName, module, multiplier, nil, constraint)
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
}
Expand Down Expand Up @@ -76,7 +77,8 @@ func constructLexicographicSortingPrefix(columns []uint, signs []bool, schema *a
//
// NOTE: this implementation differs from the original corset which used an
// additional "Eq" bit to help ensure at most one selector bit was enabled.
func addLexicographicSelectorBits(prefix string, module uint, deltaIndex uint, columns []uint, schema *air.Schema) {
func addLexicographicSelectorBits(prefix string, module uint, multiplier uint,
deltaIndex uint, columns []uint, schema *air.Schema) {
ncols := uint(len(columns))
// Calculate column index of first selector bit
bitIndex := deltaIndex + 1
Expand All @@ -100,7 +102,8 @@ func addLexicographicSelectorBits(prefix string, module uint, deltaIndex uint, c
pterms[i] = air.NewColumnAccess(bitIndex+i, 0)
pDiff := air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1))
pName := fmt.Sprintf("%s:%d:a", prefix, i)
schema.AddVanishingConstraint(pName, module, nil, air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff))
schema.AddVanishingConstraint(pName, module, multiplier,
nil, air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff))
// (∀j<i.Bj=0) ∧ Bi=1 ==> C[k]≠C[k-1]
qDiff := Normalise(air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1)), schema)
qName := fmt.Sprintf("%s:%d:b", prefix, i)
Expand All @@ -112,14 +115,14 @@ func addLexicographicSelectorBits(prefix string, module uint, deltaIndex uint, c
constraint = air.NewConst64(1).Sub(&air.Add{Args: qterms}).Mul(constraint)
}

schema.AddVanishingConstraint(qName, module, nil, constraint)
schema.AddVanishingConstraint(qName, module, multiplier, nil, constraint)
}

sum := &air.Add{Args: terms}
// (sum = 0) ∨ (sum = 1)
constraint := sum.Mul(sum.Equate(air.NewConst64(1)))
name := fmt.Sprintf("%s:xor", prefix)
schema.AddVanishingConstraint(name, module, nil, constraint)
schema.AddVanishingConstraint(name, module, multiplier, nil, constraint)
}

// Construct the lexicographic delta constraint. This states that the delta
Expand Down
10 changes: 5 additions & 5 deletions pkg/air/gadgets/normalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func Normalise(e air.Expr, schema *air.Schema) air.Expr {
// ensure it really holds the inverted value.
func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
// Determine enclosing module.
module := sc.DetermineEnclosingModuleOfExpression(e, schema)
module, multiplier := sc.DetermineEnclosingModuleOfExpression(e, schema)
// Construct inverse computation
ie := &Inverse{Expr: e}
// Determine computed column name
Expand All @@ -39,7 +39,7 @@ func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
// Add new column (if it does not already exist)
if !ok {
// Add computed column
index = schema.AddAssignment(assignment.NewComputedColumn(module, name, ie))
index = schema.AddAssignment(assignment.NewComputedColumn(module, name, multiplier, ie))
}

// Construct 1/e
Expand All @@ -54,10 +54,10 @@ func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
inv_e_implies_one_e_e := inv_e.Mul(one_e_e)
// Ensure (e != 0) ==> (1 == e/e)
l_name := fmt.Sprintf("[%s <=]", ie.String())
schema.AddVanishingConstraint(l_name, module, nil, e_implies_one_e_e)
schema.AddVanishingConstraint(l_name, module, multiplier, nil, e_implies_one_e_e)
// Ensure (e/e != 0) ==> (1 == e/e)
r_name := fmt.Sprintf("[%s =>]", ie.String())
schema.AddVanishingConstraint(r_name, module, nil, inv_e_implies_one_e_e)
schema.AddVanishingConstraint(r_name, module, multiplier, nil, inv_e_implies_one_e_e)
// Done
return air.NewColumnAccess(index, 0)
}
Expand All @@ -81,7 +81,7 @@ func (e *Inverse) Bounds() util.Bounds { return e.Expr.Bounds() }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (e *Inverse) Context(schema sc.Schema) (uint, bool) {
func (e *Inverse) Context(schema sc.Schema) (uint, uint, bool) {
return e.Expr.Context(schema)
}

Expand Down
9 changes: 5 additions & 4 deletions pkg/air/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ func (p *Schema) AddAssignment(c schema.Assignment) uint {
}

// AddLookupConstraint appends a new lookup constraint.
func (p *Schema) AddLookupConstraint(handle string, source uint, target uint, sources []uint, targets []uint) {
func (p *Schema) AddLookupConstraint(handle string, source uint, source_multiplier uint,
target uint, target_multiplier uint, sources []uint, targets []uint) {
if len(targets) != len(sources) {
panic("differeng number of target / source lookup columns")
}
Expand All @@ -102,7 +103,7 @@ func (p *Schema) AddLookupConstraint(handle string, source uint, target uint, so
}
//
p.constraints = append(p.constraints,
constraint.NewLookupConstraint(handle, source, target, from, into))
constraint.NewLookupConstraint(handle, source, source_multiplier, target, target_multiplier, from, into))
}

// AddPermutationConstraint appends a new permutation constraint which
Expand All @@ -113,13 +114,13 @@ func (p *Schema) AddPermutationConstraint(targets []uint, sources []uint) {
}

// AddVanishingConstraint appends a new vanishing constraint.
func (p *Schema) AddVanishingConstraint(handle string, module uint, domain *int, expr Expr) {
func (p *Schema) AddVanishingConstraint(handle string, module uint, multiplier uint, domain *int, expr Expr) {
if module >= uint(len(p.modules)) {
panic(fmt.Sprintf("invalid module index (%d)", module))
}
// TODO: sanity check expression enclosed by module
p.constraints = append(p.constraints,
constraint.NewVanishingConstraint(handle, module, domain, constraint.ZeroTest[Expr]{Expr: expr}))
constraint.NewVanishingConstraint(handle, module, multiplier, domain, constraint.ZeroTest[Expr]{Expr: expr}))
}

// AddRangeConstraint appends a new range constraint.
Expand Down
13 changes: 10 additions & 3 deletions pkg/binfile/computation.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/consensys/go-corset/pkg/hir"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
)

type jsonComputationSet struct {
Expand All @@ -26,6 +27,8 @@ type jsonSortedComputation struct {
// =============================================================================

func (e jsonComputationSet) addToSchema(schema *hir.Schema) {
var multiplier uint
//
for _, c := range e.Computations {
if c.Sorted != nil {
targetRefs := asColumnRefs(c.Sorted.Tos)
Expand Down Expand Up @@ -53,13 +56,17 @@ func (e jsonComputationSet) addToSchema(schema *hir.Schema) {
// Sanity check we have a sensible type here.
if ith.Type().AsUint() == nil {
panic(fmt.Sprintf("source column %s has field type", sourceRefs[i]))
} else if i == 0 {
multiplier = ith.LengthMultiplier()
} else if multiplier != ith.LengthMultiplier() {
panic(fmt.Sprintf("source column %s has inconsistent length multiplier", sourceRefs[i]))
}

sources[i] = src_cid
targets[i] = sc.NewColumn(ith.Module(), targetRef.column, ith.Type())
targets[i] = sc.NewColumn(ith.Module(), targetRef.column, multiplier, ith.Type())
}
// Finally, add the permutation column
schema.AddPermutationColumns(module, targets, c.Sorted.Signs, sources)
// Finally, add the sorted permutation assignment
schema.AddAssignment(assignment.NewSortedPermutation(module, multiplier, targets, c.Sorted.Signs, sources))
}
}
}
4 changes: 2 additions & 2 deletions pkg/binfile/constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ func (e jsonConstraint) addToSchema(schema *hir.Schema) {
// Translate Domain
domain := e.Vanishes.Domain.toHir()
// Determine enclosing module
module := sc.DetermineEnclosingModuleOfExpression(expr, schema)
module, multiplier := sc.DetermineEnclosingModuleOfExpression(expr, schema)
// Construct the vanishing constraint
schema.AddVanishingConstraint(e.Vanishes.Handle, module, domain, expr)
schema.AddVanishingConstraint(e.Vanishes.Handle, module, multiplier, domain, expr)
} else if e.Permutation == nil {
// Catch all
panic("Unknown JSON constraint encountered")
Expand Down
20 changes: 11 additions & 9 deletions pkg/hir/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package hir
import (
"fmt"

"github.com/consensys/go-corset/pkg/schema"
sc "github.com/consensys/go-corset/pkg/schema"
)

Expand Down Expand Up @@ -68,17 +69,18 @@ func (p *Environment) AddDataColumn(module uint, column string, datatype sc.Type
return cid
}

// AddPermutationColumns registers a new permutation within a given module. Observe that
// this will panic if any of the target columns already exists, or the source
// columns don't exist.
func (p *Environment) AddPermutationColumns(module uint, targets []sc.Column, signs []bool, sources []uint) {
// AddAssignment appends a new assignment (i.e. set of computed columns) to be
// used during trace expansion for this schema. Computed columns are introduced
// by the process of lowering from HIR / MIR to AIR.
func (p *Environment) AddAssignment(decl schema.Assignment) {
// Update schema
p.schema.AddPermutationColumns(module, targets, signs, sources)
index := p.schema.AddAssignment(decl)
// Update cache
for _, col := range targets {
cid := uint(len(p.columns))
cref := columnRef{module, col.Name()}
p.columns[cref] = cid
for i := decl.Columns(); i.HasNext(); {
ith := i.Next()
cref := columnRef{ith.Module(), ith.Name()}
p.columns[cref] = index
index++
}
}

Expand Down
Loading