diff --git a/pkg/air/expr.go b/pkg/air/expr.go
index de2dc07..ae45acf 100644
--- a/pkg/air/expr.go
+++ b/pkg/air/expr.go
@@ -44,7 +44,7 @@ type Add struct{ Args []Expr }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Add) Context(schema sc.Schema) (uint, bool) {
+func (p *Add) Context(schema sc.Schema) (uint, uint, bool) {
 	return sc.JoinContexts[Expr](p.Args, schema)
 }
 
@@ -73,7 +73,7 @@ type Sub struct{ Args []Expr }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Sub) Context(schema sc.Schema) (uint, bool) {
+func (p *Sub) Context(schema sc.Schema) (uint, uint, bool) {
 	return sc.JoinContexts[Expr](p.Args, schema)
 }
 
@@ -102,7 +102,7 @@ type Mul struct{ Args []Expr }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Mul) Context(schema sc.Schema) (uint, bool) {
+func (p *Mul) Context(schema sc.Schema) (uint, uint, bool) {
 	return sc.JoinContexts[Expr](p.Args, schema)
 }
 
@@ -154,8 +154,8 @@ func NewConstCopy(val *fr.Element) Expr {
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Constant) Context(schema sc.Schema) (uint, bool) {
-	return math.MaxUint, true
+func (p *Constant) Context(schema sc.Schema) (uint, uint, bool) {
+	return math.MaxUint, math.MaxUint, true
 }
 
 // Add two expressions together, producing a third.
@@ -193,8 +193,9 @@ func NewColumnAccess(column uint, shift int) Expr {
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *ColumnAccess) Context(schema sc.Schema) (uint, bool) {
-	return schema.Columns().Nth(p.Column).Module(), true
+func (p *ColumnAccess) Context(schema sc.Schema) (uint, uint, bool) {
+	col := schema.Columns().Nth(p.Column)
+	return col.Module(), col.LengthMultiplier(), true
 }
 
 // Add two expressions together, producing a third.
diff --git a/pkg/air/gadgets/bits.go b/pkg/air/gadgets/bits.go
index 2902b78..8d8845d 100644
--- a/pkg/air/gadgets/bits.go
+++ b/pkg/air/gadgets/bits.go
@@ -23,7 +23,7 @@ func ApplyBinaryGadget(col uint, schema *air.Schema) {
 	// Construct X * (X-1)
 	X_X_m1 := X.Mul(X_m1)
 	// Done!
-	schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Module(), nil, X_X_m1)
+	schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Module(), column.LengthMultiplier(), nil, X_X_m1)
 }
 
 // ApplyBitwidthGadget ensures all values in a given column fit within a given
@@ -44,7 +44,8 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
 	name := column.Name()
 	coefficient := fr.NewElement(1)
 	// Add decomposition assignment
-	index := schema.AddAssignment(assignment.NewByteDecomposition(name, column.Module(), col, n))
+	index := schema.AddAssignment(
+		assignment.NewByteDecomposition(name, column.Module(), column.LengthMultiplier(), col, n))
 	// Construct Columns
 	for i := uint(0); i < n; i++ {
 		// Create Column + Constraint
@@ -60,5 +61,5 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
 	X := air.NewColumnAccess(col, 0)
 	eq := X.Equate(sum)
 	// Construct column name
-	schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Module(), nil, eq)
+	schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Module(), column.LengthMultiplier(), nil, eq)
 }
diff --git a/pkg/air/gadgets/column_sort.go b/pkg/air/gadgets/column_sort.go
index 1eab748..de08f8b 100644
--- a/pkg/air/gadgets/column_sort.go
+++ b/pkg/air/gadgets/column_sort.go
@@ -37,10 +37,11 @@ func ApplyColumnSortGadget(col uint, sign bool, bitwidth uint, schema *air.Schem
 		deltaName = fmt.Sprintf("-%s", name)
 	}
 	// Add delta assignment
-	deltaIndex := schema.AddAssignment(assignment.NewComputedColumn(column.Module(), deltaName, Xdiff))
+	deltaIndex := schema.AddAssignment(
+		assignment.NewComputedColumn(column.Module(), deltaName, column.LengthMultiplier(), Xdiff))
 	// Add necessary bitwidth constraints
 	ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
 	// Configure constraint: Delta[k] = X[k] - X[k-1]
 	Dk := air.NewColumnAccess(deltaIndex, 0)
-	schema.AddVanishingConstraint(deltaName, column.Module(), nil, Dk.Equate(Xdiff))
+	schema.AddVanishingConstraint(deltaName, column.Module(), column.LengthMultiplier(), nil, Dk.Equate(Xdiff))
 }
diff --git a/pkg/air/gadgets/expand.go b/pkg/air/gadgets/expand.go
index 6028a7c..be975b5 100644
--- a/pkg/air/gadgets/expand.go
+++ b/pkg/air/gadgets/expand.go
@@ -20,7 +20,7 @@ func Expand(e air.Expr, schema *air.Schema) uint {
 		return ca.Column
 	}
 	// No optimisation, therefore expand using a computed column
-	module := sc.DetermineEnclosingModuleOfExpression(e, schema)
+	module, multiplier := sc.DetermineEnclosingModuleOfExpression(e, schema)
 	// Determine computed column name
 	name := e.String()
 	// Look up column
@@ -28,7 +28,7 @@ func Expand(e air.Expr, schema *air.Schema) uint {
 	// Add new column (if it does not already exist)
 	if !ok {
 		// Add computed column
-		index = schema.AddAssignment(assignment.NewComputedColumn(module, name, e))
+		index = schema.AddAssignment(assignment.NewComputedColumn(module, name, multiplier, e))
 	}
 	// Construct v == [e]
 	v := air.NewColumnAccess(index, 0)
@@ -36,7 +36,7 @@ func Expand(e air.Expr, schema *air.Schema) uint {
 	eq_e_v := v.Equate(e)
 	// Ensure (e - v) == 0, where v is the value of the computed column.
 	c_name := fmt.Sprintf("[%s]", e.String())
-	schema.AddVanishingConstraint(c_name, module, nil, eq_e_v)
+	schema.AddVanishingConstraint(c_name, module, multiplier, nil, eq_e_v)
 	//
 	return index
 }
diff --git a/pkg/air/gadgets/lexicographic_sort.go b/pkg/air/gadgets/lexicographic_sort.go
index 3315e12..f12610b 100644
--- a/pkg/air/gadgets/lexicographic_sort.go
+++ b/pkg/air/gadgets/lexicographic_sort.go
@@ -33,18 +33,19 @@ func ApplyLexicographicSortingGadget(columns []uint, signs []bool, bitwidth uint
 		panic("Inconsistent number of columns and signs for lexicographic sort.")
 	}
 	// Determine enclosing module for this gadget.
-	module := sc.DetermineEnclosingModuleOfColumns(columns, schema)
+	module, multiplier := sc.DetermineEnclosingModuleOfColumns(columns, schema)
 	// Construct a unique prefix for this sort.
 	prefix := constructLexicographicSortingPrefix(columns, signs, schema)
 	// Add trace computation
-	deltaIndex := schema.AddAssignment(assignment.NewLexicographicSort(prefix, module, columns, signs, bitwidth))
+	deltaIndex := schema.AddAssignment(
+		assignment.NewLexicographicSort(prefix, module, multiplier, columns, signs, bitwidth))
 	// Construct selector bits.
-	addLexicographicSelectorBits(prefix, module, deltaIndex, columns, schema)
+	addLexicographicSelectorBits(prefix, module, multiplier, deltaIndex, columns, schema)
 	// Construct delta terms
 	constraint := constructLexicographicDeltaConstraint(deltaIndex, columns, signs)
 	// Add delta constraint
 	deltaName := fmt.Sprintf("%s:delta", prefix)
-	schema.AddVanishingConstraint(deltaName, module, nil, constraint)
+	schema.AddVanishingConstraint(deltaName, module, multiplier, nil, constraint)
 	// Add necessary bitwidth constraints
 	ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
 }
@@ -76,7 +77,8 @@ func constructLexicographicSortingPrefix(columns []uint, signs []bool, schema *a
 //
 // NOTE: this implementation differs from the original corset which used an
 // additional "Eq" bit to help ensure at most one selector bit was enabled.
-func addLexicographicSelectorBits(prefix string, module uint, deltaIndex uint, columns []uint, schema *air.Schema) {
+func addLexicographicSelectorBits(prefix string, module uint, multiplier uint,
+	deltaIndex uint, columns []uint, schema *air.Schema) {
 	ncols := uint(len(columns))
 	// Calculate column index of first selector bit
 	bitIndex := deltaIndex + 1
@@ -100,7 +102,8 @@ func addLexicographicSelectorBits(prefix string, module uint, deltaIndex uint, c
 		pterms[i] = air.NewColumnAccess(bitIndex+i, 0)
 		pDiff := air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1))
 		pName := fmt.Sprintf("%s:%d:a", prefix, i)
-		schema.AddVanishingConstraint(pName, module, nil, air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff))
+		schema.AddVanishingConstraint(pName, module, multiplier,
+			nil, air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff))
 		// (∀j<i.Bj=0) ∧ Bi=1 ==> C[k]≠C[k-1]
 		qDiff := Normalise(air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1)), schema)
 		qName := fmt.Sprintf("%s:%d:b", prefix, i)
@@ -112,14 +115,14 @@ func addLexicographicSelectorBits(prefix string, module uint, deltaIndex uint, c
 			constraint = air.NewConst64(1).Sub(&air.Add{Args: qterms}).Mul(constraint)
 		}
 
-		schema.AddVanishingConstraint(qName, module, nil, constraint)
+		schema.AddVanishingConstraint(qName, module, multiplier, nil, constraint)
 	}
 
 	sum := &air.Add{Args: terms}
 	// (sum = 0) ∨ (sum = 1)
 	constraint := sum.Mul(sum.Equate(air.NewConst64(1)))
 	name := fmt.Sprintf("%s:xor", prefix)
-	schema.AddVanishingConstraint(name, module, nil, constraint)
+	schema.AddVanishingConstraint(name, module, multiplier, nil, constraint)
 }
 
 // Construct the lexicographic delta constraint. This states that the delta
diff --git a/pkg/air/gadgets/normalisation.go b/pkg/air/gadgets/normalisation.go
index 72d7646..6f6d489 100644
--- a/pkg/air/gadgets/normalisation.go
+++ b/pkg/air/gadgets/normalisation.go
@@ -29,7 +29,7 @@ func Normalise(e air.Expr, schema *air.Schema) air.Expr {
 // ensure it really holds the inverted value.
 func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
 	// Determine enclosing module.
-	module := sc.DetermineEnclosingModuleOfExpression(e, schema)
+	module, multiplier := sc.DetermineEnclosingModuleOfExpression(e, schema)
 	// Construct inverse computation
 	ie := &Inverse{Expr: e}
 	// Determine computed column name
@@ -39,7 +39,7 @@ func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
 	// Add new column (if it does not already exist)
 	if !ok {
 		// Add computed column
-		index = schema.AddAssignment(assignment.NewComputedColumn(module, name, ie))
+		index = schema.AddAssignment(assignment.NewComputedColumn(module, name, multiplier, ie))
 	}
 	// Construct 1/e
@@ -54,10 +54,10 @@
 	inv_e_implies_one_e_e := inv_e.Mul(one_e_e)
 	// Ensure (e != 0) ==> (1 == e/e)
 	l_name := fmt.Sprintf("[%s <=]", ie.String())
-	schema.AddVanishingConstraint(l_name, module, nil, e_implies_one_e_e)
+	schema.AddVanishingConstraint(l_name, module, multiplier, nil, e_implies_one_e_e)
 	// Ensure (e/e != 0) ==> (1 == e/e)
 	r_name := fmt.Sprintf("[%s =>]", ie.String())
-	schema.AddVanishingConstraint(r_name, module, nil, inv_e_implies_one_e_e)
+	schema.AddVanishingConstraint(r_name, module, multiplier, nil, inv_e_implies_one_e_e)
 	// Done
 	return air.NewColumnAccess(index, 0)
 }
@@ -81,7 +81,7 @@ func (e *Inverse) Bounds() util.Bounds { return e.Expr.Bounds() }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (e *Inverse) Context(schema sc.Schema) (uint, bool) {
+func (e *Inverse) Context(schema sc.Schema) (uint, uint, bool) {
 	return e.Expr.Context(schema)
 }
diff --git a/pkg/air/schema.go b/pkg/air/schema.go
index 2212247..c14c958 100644
--- a/pkg/air/schema.go
+++ b/pkg/air/schema.go
@@ -87,7 +87,8 @@ func (p *Schema) AddAssignment(c schema.Assignment) uint {
 }
 
 // AddLookupConstraint appends a new lookup constraint.
-func (p *Schema) AddLookupConstraint(handle string, source uint, target uint, sources []uint, targets []uint) {
+func (p *Schema) AddLookupConstraint(handle string, source uint, source_multiplier uint,
+	target uint, target_multiplier uint, sources []uint, targets []uint) {
 	if len(targets) != len(sources) {
 		panic("differing number of target / source lookup columns")
 	}
@@ -102,7 +103,7 @@ func (p *Schema) AddLookupConstraint(handle string, source uint, target uint, so
 	}
 	//
 	p.constraints = append(p.constraints,
-		constraint.NewLookupConstraint(handle, source, target, from, into))
+		constraint.NewLookupConstraint(handle, source, source_multiplier, target, target_multiplier, from, into))
 }
 
 // AddPermutationConstraint appends a new permutation constraint which
@@ -113,13 +114,13 @@ func (p *Schema) AddPermutationConstraint(targets []uint, sources []uint) {
 }
 
 // AddVanishingConstraint appends a new vanishing constraint.
-func (p *Schema) AddVanishingConstraint(handle string, module uint, domain *int, expr Expr) {
+func (p *Schema) AddVanishingConstraint(handle string, module uint, multiplier uint, domain *int, expr Expr) {
 	if module >= uint(len(p.modules)) {
 		panic(fmt.Sprintf("invalid module index (%d)", module))
 	}
 	// TODO: sanity check expression enclosed by module
 	p.constraints = append(p.constraints,
-		constraint.NewVanishingConstraint(handle, module, domain, constraint.ZeroTest[Expr]{Expr: expr}))
+		constraint.NewVanishingConstraint(handle, module, multiplier, domain, constraint.ZeroTest[Expr]{Expr: expr}))
 }
 
 // AddRangeConstraint appends a new range constraint.
diff --git a/pkg/binfile/computation.go b/pkg/binfile/computation.go
index e1c3d11..2a95b73 100644
--- a/pkg/binfile/computation.go
+++ b/pkg/binfile/computation.go
@@ -5,6 +5,7 @@ import (
 
 	"github.com/consensys/go-corset/pkg/hir"
 	sc "github.com/consensys/go-corset/pkg/schema"
+	"github.com/consensys/go-corset/pkg/schema/assignment"
 )
 
 type jsonComputationSet struct {
@@ -26,6 +27,8 @@ type jsonSortedComputation struct {
 // =============================================================================
 
 func (e jsonComputationSet) addToSchema(schema *hir.Schema) {
+	var multiplier uint
+	//
 	for _, c := range e.Computations {
 		if c.Sorted != nil {
 			targetRefs := asColumnRefs(c.Sorted.Tos)
@@ -53,13 +56,17 @@
 			// Sanity check we have a sensible type here.
 			if ith.Type().AsUint() == nil {
 				panic(fmt.Sprintf("source column %s has field type", sourceRefs[i]))
+			} else if i == 0 {
+				multiplier = ith.LengthMultiplier()
+			} else if multiplier != ith.LengthMultiplier() {
+				panic(fmt.Sprintf("source column %s has inconsistent length multiplier", sourceRefs[i]))
 			}
 
 			sources[i] = src_cid
-			targets[i] = sc.NewColumn(ith.Module(), targetRef.column, ith.Type())
+			targets[i] = sc.NewColumn(ith.Module(), targetRef.column, multiplier, ith.Type())
 		}
-		// Finally, add the permutation column
-		schema.AddPermutationColumns(module, targets, c.Sorted.Signs, sources)
+		// Finally, add the sorted permutation assignment
+		schema.AddAssignment(assignment.NewSortedPermutation(module, multiplier, targets, c.Sorted.Signs, sources))
 		}
 	}
 }
diff --git a/pkg/binfile/constraint.go b/pkg/binfile/constraint.go
index eb91ef7..b1bd1c6 100644
--- a/pkg/binfile/constraint.go
+++ b/pkg/binfile/constraint.go
@@ -45,9 +45,9 @@ func (e jsonConstraint) addToSchema(schema *hir.Schema) {
 		// Translate Domain
 		domain := e.Vanishes.Domain.toHir()
 		// Determine enclosing module
-		module := sc.DetermineEnclosingModuleOfExpression(expr, schema)
+		module, multiplier := sc.DetermineEnclosingModuleOfExpression(expr, schema)
 		// Construct the vanishing constraint
-		schema.AddVanishingConstraint(e.Vanishes.Handle, module, domain, expr)
+		schema.AddVanishingConstraint(e.Vanishes.Handle, module, multiplier, domain, expr)
 	} else if e.Permutation == nil {
 		// Catch all
 		panic("Unknown JSON constraint encountered")
diff --git a/pkg/hir/environment.go b/pkg/hir/environment.go
index 2a8935f..c3644a4 100644
--- a/pkg/hir/environment.go
+++ b/pkg/hir/environment.go
@@ -3,6 +3,7 @@ package hir
 import (
 	"fmt"
 
+	"github.com/consensys/go-corset/pkg/schema"
 	sc "github.com/consensys/go-corset/pkg/schema"
 )
 
@@ -68,17 +69,18 @@ func (p *Environment) AddDataColumn(module uint, column string, datatype sc.Type
 	return cid
 }
 
-// AddPermutationColumns registers a new permutation within a given module. Observe that
-// this will panic if any of the target columns already exists, or the source
-// columns don't exist.
-func (p *Environment) AddPermutationColumns(module uint, targets []sc.Column, signs []bool, sources []uint) {
+// AddAssignment appends a new assignment (i.e. set of computed columns) to be
+// used during trace expansion for this schema. Computed columns are introduced
+// by the process of lowering from HIR / MIR to AIR.
+func (p *Environment) AddAssignment(decl schema.Assignment) {
 	// Update schema
-	p.schema.AddPermutationColumns(module, targets, signs, sources)
+	index := p.schema.AddAssignment(decl)
 	// Update cache
-	for _, col := range targets {
-		cid := uint(len(p.columns))
-		cref := columnRef{module, col.Name()}
-		p.columns[cref] = cid
+	for i := decl.Columns(); i.HasNext(); {
+		ith := i.Next()
+		cref := columnRef{ith.Module(), ith.Name()}
+		p.columns[cref] = index
+		index++
 	}
 }
diff --git a/pkg/hir/expr.go b/pkg/hir/expr.go
index 460b300..531aef4 100644
--- a/pkg/hir/expr.go
+++ b/pkg/hir/expr.go
@@ -50,7 +50,7 @@ func (p *Add) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Add) Context(schema sc.Schema) (uint, bool) {
+func (p *Add) Context(schema sc.Schema) (uint, uint, bool) {
 	return sc.JoinContexts[Expr](p.Args, schema)
 }
 
@@ -67,7 +67,7 @@ func (p *Sub) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Sub) Context(schema sc.Schema) (uint, bool) {
+func (p *Sub) Context(schema sc.Schema) (uint, uint, bool) {
 	return sc.JoinContexts[Expr](p.Args, schema)
 }
 
@@ -84,7 +84,7 @@ func (p *Mul) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Mul) Context(schema sc.Schema) (uint, bool) {
+func (p *Mul) Context(schema sc.Schema) (uint, uint, bool) {
 	return sc.JoinContexts[Expr](p.Args, schema)
 }
 
@@ -101,7 +101,7 @@ func (p *List) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *List) Context(schema sc.Schema) (uint, bool) {
+func (p *List) Context(schema sc.Schema) (uint, uint, bool) {
 	return sc.JoinContexts[Expr](p.Args, schema)
 }
 
@@ -118,8 +118,8 @@ func (p *Constant) Bounds() util.Bounds { return util.EMPTY_BOUND }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Constant) Context(schema sc.Schema) (uint, bool) {
-	return math.MaxUint, true
+func (p *Constant) Context(schema sc.Schema) (uint, uint, bool) {
+	return math.MaxUint, math.MaxUint, true
 }
 
 // ============================================================================
@@ -157,8 +157,18 @@ func (p *IfZero) Bounds() util.Bounds {
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *IfZero) Context(schema sc.Schema) (uint, bool) {
-	args := []Expr{p.Condition, p.TrueBranch, p.FalseBranch}
+func (p *IfZero) Context(schema sc.Schema) (uint, uint, bool) {
+	if p.TrueBranch != nil && p.FalseBranch != nil {
+		args := []Expr{p.Condition, p.TrueBranch, p.FalseBranch}
+		return sc.JoinContexts[Expr](args, schema)
+	} else if p.TrueBranch != nil {
+		// FalseBranch == nil
+		args := []Expr{p.Condition, p.TrueBranch}
+		return sc.JoinContexts[Expr](args, schema)
+	}
+	// TrueBranch == nil
+	args := []Expr{p.Condition, p.FalseBranch}
 	return sc.JoinContexts[Expr](args, schema)
 }
 
@@ -178,7 +188,7 @@ func (p *Normalise) Bounds() util.Bounds {
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Normalise) Context(schema sc.Schema) (uint, bool) {
+func (p *Normalise) Context(schema sc.Schema) (uint, uint, bool) {
 	return p.Arg.Context(schema)
 }
 
@@ -210,6 +220,7 @@ func (p *ColumnAccess) Bounds() util.Bounds {
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *ColumnAccess) Context(schema sc.Schema) (uint, bool) {
-	return schema.Columns().Nth(p.Column).Module(), true
+func (p *ColumnAccess) Context(schema sc.Schema) (uint, uint, bool) {
+	col := schema.Columns().Nth(p.Column)
+	return col.Module(), col.LengthMultiplier(), true
 }
diff --git a/pkg/hir/lower.go b/pkg/hir/lower.go
index 150b434..09b3091 100644
--- a/pkg/hir/lower.go
+++ b/pkg/hir/lower.go
@@ -23,10 +23,9 @@ func (p *Schema) LowerToMir() *mir.Schema {
 		col := input.(DataColumn)
 		mirSchema.AddDataColumn(col.Module(), col.Name(), col.Type())
 	}
-	// Lower permutations
+	// Lower assignments (nothing to do here)
 	for _, asn := range p.assignments {
-		col := asn.(Permutation)
-		mirSchema.AddPermutationColumns(col.Module(), col.Targets(), col.Signs(), col.Sources())
+		mirSchema.AddAssignment(asn)
 	}
 	// Lower constraints
 	for _, c := range p.constraints {
@@ -52,7 +51,7 @@ func lowerConstraintToMir(c sc.Constraint, schema *mir.Schema) {
 		mir_exprs := v.Constraint().Expr.LowerTo(schema)
 		// Add individual constraints arising
 		for _, mir_expr := range mir_exprs {
-			schema.AddVanishingConstraint(v.Handle(), v.Module(), v.Domain(), mir_expr)
+			schema.AddVanishingConstraint(v.Handle(), v.Module(), v.LengthMultiplier(), v.Domain(), mir_expr)
 		}
 	} else if v, ok := c.(*constraint.TypeConstraint); ok {
 		schema.AddTypeConstraint(v.Target(), v.Type())
@@ -74,7 +73,9 @@ func lowerLookupConstraint(c LookupConstraint, schema *mir.Schema) {
 		into[i] = lowerUnitTo(targets[i], schema)
 	}
 	//
-	schema.AddLookupConstraint(c.Handle(), c.SourceModule(), c.TargetModule(), from, into)
+	src_mod, src_mul := c.SourceContext()
+	dst_mod, dst_mul := c.TargetContext()
+	schema.AddLookupConstraint(c.Handle(), src_mod, src_mul, dst_mod, dst_mul, from, into)
 }
 
 // Lower an expression which is expected to lower into a single expression.
diff --git a/pkg/hir/parser.go b/pkg/hir/parser.go
index 32c28ff..9ed0204 100644
--- a/pkg/hir/parser.go
+++ b/pkg/hir/parser.go
@@ -10,6 +10,7 @@ import (
 	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
 	"github.com/consensys/go-corset/pkg/schema"
 	sc "github.com/consensys/go-corset/pkg/schema"
+	"github.com/consensys/go-corset/pkg/schema/assignment"
 	"github.com/consensys/go-corset/pkg/sexp"
 )
 
@@ -103,6 +104,8 @@ func (p *hirParser) parseDeclaration(s sexp.SExp) error {
 			return p.parseSortedPermutationDeclaration(e)
 		} else if e.Len() == 4 && e.MatchSymbols(1, "lookup") {
 			return p.parseLookupDeclaration(e)
+		} else if e.Len() == 3 && e.MatchSymbols(1, "interleave") {
+			return p.parseInterleavingDeclaration(e)
 		}
 	}
 	// Error
@@ -161,6 +164,7 @@ func (p *hirParser) parseColumnDeclaration(l *sexp.List) error {
 
 // Parse a sorted permutation declaration
 func (p *hirParser) parseSortedPermutationDeclaration(l *sexp.List) error {
+	var multiplier uint
 	// Target columns are (sorted) permutations of source columns.
 	sexpTargets := l.Elements[1].AsList()
 	// Source columns.
@@ -210,13 +214,23 @@ func (p *hirParser) parseSortedPermutationDeclaration(l *sexp.List) error {
 			// No, it doesn't.
 			return p.translator.SyntaxError(sexpTargets.Get(i), fmt.Sprintf("duplicate column %s", targetName))
 		}
+		// Check multiplier calculation
+		sourceCol := p.env.schema.Columns().Nth(sourceIndex)
+		if i == 0 {
+			// First time around, multiplier is determined by the first source column.
+			multiplier = sourceCol.LengthMultiplier()
+		} else if sourceCol.LengthMultiplier() != multiplier {
+			// In all other cases, multiplier must match that of the first source column.
+			return p.translator.SyntaxError(sexpSources.Get(i), "inconsistent length multiplier")
+		}
 		// Copy over column name
 		sources[i] = sourceIndex
 		// FIXME: determine source column type
-		targets[i] = schema.NewColumn(p.module, targetName, &schema.FieldType{})
+		targets[i] = schema.NewColumn(p.module, targetName, multiplier, &schema.FieldType{})
 	}
 	//
-	p.env.AddPermutationColumns(p.module, targets, signs, sources)
+	p.env.AddAssignment(assignment.NewSortedPermutation(p.module, multiplier, targets, signs, sources))
 	//
 	return nil
 }
 
@@ -265,8 +279,8 @@ func (p *hirParser) parseLookupDeclaration(l *sexp.List) error {
 		sources[i] = UnitExpr{source}
 	}
 	// Sanity check enclosing source and target modules
-	source, err1 := schema.DetermineEnclosingModuleOfExpressions(sources, p.env.schema)
-	target, err2 := schema.DetermineEnclosingModuleOfExpressions(targets, p.env.schema)
+	source, src_multiplier, err1 := schema.DetermineEnclosingModuleOfExpressions(sources, p.env.schema)
+	target, target_multiplier, err2 := schema.DetermineEnclosingModuleOfExpressions(targets, p.env.schema)
 	// Propagate errors
 	if err1 != nil {
 		return p.translator.SyntaxError(sexpSources.Get(int(source)), err1.Error())
@@ -274,8 +288,55 @@
 	} else if err2 != nil {
 		return p.translator.SyntaxError(sexpTargets.Get(int(target)), err2.Error())
 	}
 	// Finally add constraint
-	p.env.schema.AddLookupConstraint(handle, source, target, sources, targets)
-	// DOne
+	p.env.schema.AddLookupConstraint(handle, source, src_multiplier, target, target_multiplier, sources, targets)
+	// Done
+	return nil
+}
+
+// Parse an interleaving declaration
+func (p *hirParser) parseInterleavingDeclaration(l *sexp.List) error {
+	var multiplier uint
+	// Target column is an interleaving of the source columns.
+	sexpTarget := l.Elements[1].AsSymbol()
+	// Source columns.
+	sexpSources := l.Elements[2].AsList()
+	// Sanity checks.
+	if sexpTarget == nil {
+		return p.translator.SyntaxError(l, "column name expected")
+	} else if sexpSources == nil {
+		return p.translator.SyntaxError(l, "source column list expected")
+	}
+	// Construct and check source columns
+	sources := make([]uint, sexpSources.Len())
+
+	for i := 0; i < sexpSources.Len(); i++ {
+		ith := sexpSources.Get(i)
+		col := ith.AsSymbol()
+		// Sanity check a symbol was found
+		if col == nil {
+			return p.translator.SyntaxError(ith, "column name expected")
+		}
+		// Attempt to lookup the column
+		cid, ok := p.env.LookupColumn(p.module, col.Value)
+		// Check it exists
+		if !ok {
+			return p.translator.SyntaxError(ith, "unknown column")
+		}
+		// Check multiplier calculation
+		sourceCol := p.env.schema.Columns().Nth(cid)
+		if i == 0 {
+			// First time around, multiplier is determined by the first source column.
+			multiplier = sourceCol.LengthMultiplier()
+		} else if sourceCol.LengthMultiplier() != multiplier {
+			// In all other cases, multiplier must match that of the first source column.
+			return p.translator.SyntaxError(sexpSources.Get(i), "inconsistent length multiplier")
+		}
+		// Assign
+		sources[i] = cid
+	}
+	// Add assignment
+	p.env.AddAssignment(assignment.NewInterleaving(p.module, sexpTarget.Value, multiplier, sources))
+	// Done
 	return nil
 }
 
@@ -307,8 +368,12 @@ func (p *hirParser) parseVanishingDeclaration(elements []sexp.SExp, domain *int)
 	if err != nil {
 		return err
 	}
-
-	p.env.schema.AddVanishingConstraint(handle, p.module, domain, expr)
+	// TODO: improve error reporting here, since the following will just panic if
+	// the evaluation context is inconsistent (and, since we know the enclosing
+	// module is consistent, this should only happen if the length multipliers
+	// are inconsistent).
+	_, multiplier := schema.DetermineEnclosingModuleOfExpression(expr, p.env.schema)
+	p.env.schema.AddVanishingConstraint(handle, p.module, multiplier, domain, expr)
 
 	return nil
 }
diff --git a/pkg/hir/schema.go b/pkg/hir/schema.go
index d8b9d87..4dc7b19 100644
--- a/pkg/hir/schema.go
+++ b/pkg/hir/schema.go
@@ -78,7 +78,8 @@ func (p *Schema) AddDataColumn(module uint, name string, base sc.Type) {
 }
 
 // AddLookupConstraint appends a new lookup constraint.
-func (p *Schema) AddLookupConstraint(handle string, source uint, target uint, sources []UnitExpr, targets []UnitExpr) {
+func (p *Schema) AddLookupConstraint(handle string, source uint, source_multiplier uint, target uint,
+	target_multiplier uint, sources []UnitExpr, targets []UnitExpr) {
 	if len(targets) != len(sources) {
 		panic("differing number of target / source lookup columns")
 	}
@@ -87,30 +88,27 @@ func (p *Schema) AddLookupConstraint(handle string, source uint, target uint, so
 
 	// Finally add constraint
 	p.constraints = append(p.constraints,
-		constraint.NewLookupConstraint(handle, source, target, sources, targets))
+		constraint.NewLookupConstraint(handle, source, source_multiplier, target, target_multiplier, sources, targets))
 }
 
-// AddPermutationColumns introduces a permutation of one or more
-// existing columns. Specifically, this introduces one or more
-// computed columns which represent a (sorted) permutation of the
-// source columns. Each source column is associated with a "sign"
-// which indicates the direction of sorting (i.e. ascending versus
-// descending).
-func (p *Schema) AddPermutationColumns(module uint, targets []sc.Column, signs []bool, sources []uint) {
-	if module >= uint(len(p.modules)) {
-		panic(fmt.Sprintf("invalid module index (%d)", module))
-	}
+// AddAssignment appends a new assignment (i.e. set of computed columns) to be
+// used during trace expansion for this schema. Computed columns are introduced
+// by the process of lowering from HIR / MIR to AIR.
+func (p *Schema) AddAssignment(c schema.Assignment) uint {
+	index := p.Columns().Count()
+	p.assignments = append(p.assignments, c)
 
-	p.assignments = append(p.assignments, assignment.NewSortedPermutation(module, targets, signs, sources))
+	return index
 }
 
 // AddVanishingConstraint appends a new vanishing constraint.
-func (p *Schema) AddVanishingConstraint(handle string, module uint, domain *int, expr Expr) {
+func (p *Schema) AddVanishingConstraint(handle string, module uint, multiplier uint, domain *int, expr Expr) {
 	if module >= uint(len(p.modules)) {
 		panic(fmt.Sprintf("invalid module index (%d)", module))
 	}
 
-	p.constraints = append(p.constraints, constraint.NewVanishingConstraint(handle, module, domain, ZeroArrayTest{expr}))
+	p.constraints = append(p.constraints,
+		constraint.NewVanishingConstraint(handle, module, multiplier, domain, ZeroArrayTest{expr}))
 }
 
 // AddTypeConstraint appends a new type constraint.
diff --git a/pkg/hir/util.go b/pkg/hir/util.go
index bdabbfb..86c8463 100644
--- a/pkg/hir/util.go
+++ b/pkg/hir/util.go
@@ -46,7 +46,7 @@ func (p ZeroArrayTest) Bounds() util.Bounds {
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p ZeroArrayTest) Context(schema sc.Schema) (uint, bool) {
+func (p ZeroArrayTest) Context(schema sc.Schema) (uint, uint, bool) {
 	return p.Expr.Context(schema)
 }
 
@@ -86,6 +86,6 @@ func (e UnitExpr) Bounds() util.Bounds {
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (e UnitExpr) Context(schema sc.Schema) (uint, bool) {
+func (e UnitExpr) Context(schema sc.Schema) (uint, uint, bool) {
 	return e.expr.Context(schema)
 }
diff --git a/pkg/mir/expr.go b/pkg/mir/expr.go
index 336a3f7..c27647c 100644
--- a/pkg/mir/expr.go
+++ b/pkg/mir/expr.go
@@ -40,7 +40,7 @@ func (p *Add) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Add) Context(schema sc.Schema) (uint, bool) {
+func (p *Add) Context(schema sc.Schema) (uint, uint, bool) {
 	return sc.JoinContexts[Expr](p.Args, schema)
 }
 
@@ -57,7 +57,7 @@ func (p *Sub) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Sub) Context(schema sc.Schema) (uint, bool) {
+func (p *Sub) Context(schema sc.Schema) (uint, uint, bool) {
 	return sc.JoinContexts[Expr](p.Args, schema)
 }
 
@@ -74,7 +74,7 @@ func (p *Mul) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Mul) Context(schema sc.Schema) (uint, bool) {
+func (p *Mul) Context(schema sc.Schema) (uint, uint, bool) {
 	return sc.JoinContexts[Expr](p.Args, schema)
 }
 
@@ -91,8 +91,8 @@ func (p *Constant) Bounds() util.Bounds { return util.EMPTY_BOUND }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Constant) Context(schema sc.Schema) (uint, bool) {
-	return math.MaxUint, true
+func (p *Constant) Context(schema sc.Schema) (uint, uint, bool) {
+	return math.MaxUint, math.MaxUint, true
 }
 
 // ============================================================================
@@ -109,7 +109,7 @@ func (p *Normalise) Bounds() util.Bounds { return p.Arg.Bounds() }
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *Normalise) Context(schema sc.Schema) (uint, bool) {
+func (p *Normalise) Context(schema sc.Schema) (uint, uint, bool) {
 	return p.Arg.Context(schema)
 }
 
@@ -141,6 +141,7 @@ func (p *ColumnAccess) Bounds() util.Bounds {
 
 // Context determines the evaluation context (i.e. enclosing module) for this
 // expression.
-func (p *ColumnAccess) Context(schema sc.Schema) (uint, bool) {
-	return schema.Columns().Nth(p.Column).Module(), true
+func (p *ColumnAccess) Context(schema sc.Schema) (uint, uint, bool) {
+	col := schema.Columns().Nth(p.Column)
+	return col.Module(), col.LengthMultiplier(), true
 }
diff --git a/pkg/mir/lower.go b/pkg/mir/lower.go
index 6f6e304..fc4e860 100644
--- a/pkg/mir/lower.go
+++ b/pkg/mir/lower.go
@@ -25,12 +25,12 @@ func (p *Schema) LowerToAir() *air.Schema {
 	// Essentially to reflect the fact that these columns have been added above
 	// before others. Realistically, the overall design of this process is a
 	// bit broken right now.
-	for _, perm := range p.assignments {
-		airSchema.AddAssignment(perm.(Permutation))
+	for _, assign := range p.assignments {
+		airSchema.AddAssignment(assign)
 	}
-	// Lower permutations columns
-	for _, perm := range p.assignments {
-		lowerPermutationToAir(perm.(Permutation), p, airSchema)
+	// Now, lower assignments.
+	for _, assign := range p.assignments {
+		lowerAssignmentToAir(assign, p, airSchema)
 	}
 	// Lower vanishing constraints
 	for _, c := range p.constraints {
@@ -40,6 +40,19 @@ func (p *Schema) LowerToAir() *air.Schema {
 	return airSchema
 }
 
+// Lower an assignment to the AIR level.
+func lowerAssignmentToAir(c sc.Assignment, mirSchema *Schema, airSchema *air.Schema) {
+	if v, ok := c.(Permutation); ok {
+		lowerPermutationToAir(v, mirSchema, airSchema)
+	} else if _, ok := c.(Interleaving); ok {
+		// Nothing to do for interleavings, as they can be passed
+		// directly down to the AIR level.
+		return
+	} else {
+		panic("unknown assignment")
+	}
+}
+
 // Lower a constraint to the AIR level.
 func lowerConstraintToAir(c sc.Constraint, schema *air.Schema) {
 	// Check what kind of constraint we have
@@ -47,7 +60,7 @@ func lowerConstraintToAir(c sc.Constraint, schema *air.Schema) {
 		lowerLookupConstraintToAir(v, schema)
 	} else if v, ok := c.(VanishingConstraint); ok {
 		air_expr := v.Constraint().Expr.LowerTo(schema)
-		schema.AddVanishingConstraint(v.Handle(), v.Module(), v.Domain(), air_expr)
+		schema.AddVanishingConstraint(v.Handle(), v.Module(), v.LengthMultiplier(), v.Domain(), air_expr)
 	} else if v, ok := c.(*constraint.TypeConstraint); ok {
 		if t := v.Type().AsUint(); t != nil {
 			// Yes, a constraint is implied. Now, decide whether to use a range
@@ -89,7 +102,9 @@ func lowerLookupConstraintToAir(c LookupConstraint, schema *air.Schema) {
 		sources[i] = air_gadgets.Expand(source, schema)
 	}
 	// Finally add the constraint
-	schema.AddLookupConstraint(c.Handle(), c.SourceModule(), c.TargetModule(), sources, targets)
+	src_mod, src_mul := c.SourceContext()
+	dst_mod, dst_mul := c.TargetContext()
+	schema.AddLookupConstraint(c.Handle(), src_mod, src_mul, dst_mod, dst_mul, sources, targets)
 }
 
 // Lower a permutation to the AIR level. This has quite a few
diff --git a/pkg/mir/schema.go b/pkg/mir/schema.go
index 7e63ce5..385af0b 100644
--- a/pkg/mir/schema.go
+++ b/pkg/mir/schema.go
@@ -29,6 +29,9 @@ type PropertyAssertion = *schema.PropertyAssertion[constraint.ZeroTest[Expr]]
 
 // Permutation captures the notion of a (sorted) permutation at the MIR level.
 type Permutation = *assignment.SortedPermutation
 
+// Interleaving captures the notion of an interleaving at the MIR level.
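+// Since interleavings are assignments rather than constraints, they are
+// passed directly down to the AIR level during lowering (see lowerAssignmentToAir).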
+type Interleaving = *assignment.Interleaving
+
 // Schema for MIR traces
 type Schema struct {
 	// The modules of the schema
@@ -74,39 +77,36 @@ func (p *Schema) AddDataColumn(module uint, name string, base schema.Type) {
 	p.inputs = append(p.inputs, assignment.NewDataColumn(module, name, base))
 }
 
+// AddAssignment appends a new assignment (i.e. set of computed columns) to be
+// used during trace expansion for this schema. Computed columns are introduced
+// by the process of lowering from HIR / MIR to AIR.
+func (p *Schema) AddAssignment(c schema.Assignment) uint {
+	index := p.Columns().Count()
+	p.assignments = append(p.assignments, c)
+
+	return index
+}
+
 // AddLookupConstraint appends a new lookup constraint.
-func (p *Schema) AddLookupConstraint(handle string, source uint, target uint, sources []Expr, targets []Expr) {
+func (p *Schema) AddLookupConstraint(handle string, source uint, source_multiplier uint, target uint,
+	target_multiplier uint, sources []Expr, targets []Expr) {
 	if len(targets) != len(sources) {
 		panic("differing number of target / source lookup columns")
 	}
 	// TODO: sanity check source columns are in the same module, and likewise
 	// target columns (though they don't have to be in the same module).
 	p.constraints = append(p.constraints,
-		constraint.NewLookupConstraint(handle, source, target, sources, targets))
-}
-
-// AddPermutationColumns introduces a permutation of one or more
-// existing columns. Specifically, this introduces one or more
-// computed columns which represent a (sorted) permutation of the
-// source columns. Each source column is associated with a "sign"
-// which indicates the direction of sorting (i.e. ascending versus
-// descending).
-func (p *Schema) AddPermutationColumns(module uint, targets []schema.Column, signs []bool, sources []uint) {
-	if module >= uint(len(p.modules)) {
-		panic(fmt.Sprintf("invalid module index (%d)", module))
-	}
-
-	p.assignments = append(p.assignments, assignment.NewSortedPermutation(module, targets, signs, sources))
+		constraint.NewLookupConstraint(handle, source, source_multiplier, target, target_multiplier, sources, targets))
 }
 
 // AddVanishingConstraint appends a new vanishing constraint.
-func (p *Schema) AddVanishingConstraint(handle string, module uint, domain *int, expr Expr) {
+func (p *Schema) AddVanishingConstraint(handle string, module uint, multiplier uint, domain *int, expr Expr) {
 	if module >= uint(len(p.modules)) {
 		panic(fmt.Sprintf("invalid module index (%d)", module))
 	}
 
 	p.constraints = append(p.constraints,
-		constraint.NewVanishingConstraint(handle, module, domain, constraint.ZeroTest[Expr]{Expr: expr}))
+		constraint.NewVanishingConstraint(handle, module, multiplier, domain, constraint.ZeroTest[Expr]{Expr: expr}))
 }
 
 // AddTypeConstraint appends a new type constraint.
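Aside: the joining rule that the signature changes above implement — an evaluation context is now a (module, length multiplier) pair, where constants carry the empty context math.MaxUint, and two contexts combine only when both components agree — can be sketched as follows. This is an illustrative sketch only: the real helper is the generic sc.JoinContexts, whose implementation is not part of this diff, and the joinContext name and package below are hypothetical.

package sketch

import "math"

// joinContext combines two evaluation contexts, where math.MaxUint denotes
// "no context" (as returned by constant expressions). Joining fails when
// either the modules or the length multipliers conflict.
func joinContext(mod1, mul1, mod2, mul2 uint) (uint, uint, bool) {
	if mod1 == math.MaxUint {
		// First context is empty; adopt the second.
		return mod2, mul2, true
	} else if mod2 == math.MaxUint {
		// Second context is empty; adopt the first.
		return mod1, mul1, true
	} else if mod1 != mod2 || mul1 != mul2 {
		// Conflicting modules or length multipliers.
		return math.MaxUint, math.MaxUint, false
	}
	// Contexts agree.
	return mod1, mul1, true
}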
diff --git a/pkg/schema/assignment/byte_decomposition.go b/pkg/schema/assignment/byte_decomposition.go
index 2bb4c85..811e45d 100644
--- a/pkg/schema/assignment/byte_decomposition.go
+++ b/pkg/schema/assignment/byte_decomposition.go
@@ -19,7 +19,7 @@ type ByteDecomposition struct {
 }
 
 // NewByteDecomposition creates a new byte decomposition
-func NewByteDecomposition(prefix string, module uint, source uint, width uint) *ByteDecomposition {
+func NewByteDecomposition(prefix string, module uint, multiplier uint, source uint, width uint) *ByteDecomposition {
 	if width == 0 {
 		panic("zero byte decomposition encountered")
 	}
@@ -30,7 +30,7 @@ func NewByteDecomposition(prefix string, module uint, source uint, width uint) *
 
 	for i := uint(0); i < width; i++ {
 		name := fmt.Sprintf("%s:%d", prefix, i)
-		targets[i] = schema.NewColumn(module, name, U8)
+		targets[i] = schema.NewColumn(module, name, multiplier, U8)
 	}
 	// Done
 	return &ByteDecomposition{source, targets}
@@ -66,10 +66,10 @@ func (p *ByteDecomposition) ExpandTrace(tr trace.Trace) error {
 	columns := tr.Columns()
 	// Calculate how many bytes required.
 	n := len(p.targets)
-	// Identify target column
-	target := columns.Get(p.source)
+	// Identify source column
+	source := columns.Get(p.source)
 	// Extract column data to decompose
-	data := columns.Get(p.source).Data()
+	data := source.Data()
 	// Construct byte column data
 	cols := make([][]*fr.Element, n)
 	// Initialise columns
@@ -84,11 +84,11 @@ func (p *ByteDecomposition) ExpandTrace(tr trace.Trace) error {
 		}
 	}
 	// Determine padding values
-	padding := decomposeIntoBytes(target.Padding(), n)
+	padding := decomposeIntoBytes(source.Padding(), n)
 	// Finally, add byte columns to trace
 	for i := 0; i < n; i++ {
 		ith := p.targets[i]
-		columns.Add(trace.NewFieldColumn(ith.Module(), ith.Name(), cols[i], padding[i]))
+		columns.Add(trace.NewFieldColumn(ith.Module(), ith.Name(), ith.LengthMultiplier(), cols[i], padding[i]))
 	}
 	// Done
 	return nil
diff --git a/pkg/schema/assignment/computed_column.go b/pkg/schema/assignment/computed_column.go
index 99072f3..f173491 100644
--- a/pkg/schema/assignment/computed_column.go
+++ b/pkg/schema/assignment/computed_column.go
@@ -17,10 +17,7 @@ import (
 // give rise to "trace expansion". That is where the initial trace provided by
 // the user is expanded by determining the value of all computed columns.
 type ComputedColumn[E sc.Evaluable] struct {
-	// Module in which to locate new column
-	module uint
-	// Name of the new column
-	name string
+	target schema.Column
 	// The computation which accepts a given trace and computes
 	// the value of this column at a given row.
 	expr E
@@ -29,18 +26,20 @@ type ComputedColumn[E sc.Evaluable] struct {
 // NewComputedColumn constructs a new computed column with a given name and
 // determining expression. More specifically, that expression is used to
 // compute the values for this column during trace expansion.
-func NewComputedColumn[E sc.Evaluable](module uint, name string, expr E) *ComputedColumn[E] {
-	return &ComputedColumn[E]{module, name, expr}
+func NewComputedColumn[E sc.Evaluable](module uint, name string, multiplier uint, expr E) *ComputedColumn[E] {
+	// FIXME: determine the computed column's type?
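+	// For now, the computed column is given the (unconstrained) field type.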
+	column := schema.NewColumn(module, name, multiplier, &schema.FieldType{})
+	return &ComputedColumn[E]{column, expr}
 }
 
 // nolint:revive
 func (p *ComputedColumn[E]) String() string {
-	return fmt.Sprintf("(compute %s %s)", p.name, any(p.expr))
+	return fmt.Sprintf("(compute %s %s)", p.Name(), any(p.expr))
 }
 
 // Name returns the name of this computed column.
 func (p *ComputedColumn[E]) Name() string {
-	return p.name
+	return p.target.Name()
 }
 
 // ============================================================================
@@ -50,8 +49,7 @@ func (p *ComputedColumn[E]) Name() string {
 
 // Columns returns the columns declared by this computed column.
 func (p *ComputedColumn[E]) Columns() util.Iterator[schema.Column] {
-	// TODO: figure out appropriate type for computed column
-	column := schema.NewColumn(p.module, p.name, &schema.FieldType{})
-	return util.NewUnitIterator[schema.Column](column)
+	return util.NewUnitIterator[schema.Column](p.target)
 }
 
 // IsComputed Determines whether or not this declaration is computed (which it
@@ -81,11 +79,15 @@ func (p *ComputedColumn[E]) RequiredSpillage() uint {
 func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error {
 	columns := tr.Columns()
 	// Check whether a column already exists with the given name.
-	if tr.Columns().HasColumn(p.name) {
-		return fmt.Errorf("Computed column already exists ({%s})", p.name)
+	if columns.HasColumn(p.Name()) {
+		return fmt.Errorf("column already exists ({%s})", p.Name())
 	}
-
-	data := make([]*fr.Element, tr.Modules().Get(p.module).Height())
+	// Extract length multiplier
+	multiplier := p.target.LengthMultiplier()
+	// Determine multiplied height
+	height := tr.Modules().Get(p.target.Module()).Height() * multiplier
+	// Make space for computed data
+	data := make([]*fr.Element, height)
 	// Expand the trace
 	for i := 0; i < len(data); i++ {
 		val := p.expr.EvalAt(i, tr)
@@ -101,7 +103,7 @@ func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error {
 	// the padding value for *this* column.
 	padding := p.expr.EvalAt(-1, tr)
 	// Column needs to be expanded.
-	columns.Add(trace.NewFieldColumn(p.module, p.name, data, padding))
+	columns.Add(trace.NewFieldColumn(p.target.Module(), p.Name(), multiplier, data, padding))
 	// Done
 	return nil
 }
diff --git a/pkg/schema/assignment/data_column.go b/pkg/schema/assignment/data_column.go
index f1615eb..36799b6 100644
--- a/pkg/schema/assignment/data_column.go
+++ b/pkg/schema/assignment/data_column.go
@@ -54,7 +54,8 @@ func (c *DataColumn) String() string {
 
 // Columns returns the columns declared by this data column.
 func (p *DataColumn) Columns() util.Iterator[schema.Column] {
-	column := schema.NewColumn(p.module, p.name, p.datatype)
+	// Data columns always have a length multiplier of 1.
+	column := schema.NewColumn(p.module, p.name, 1, p.datatype)
 	return util.NewUnitIterator[schema.Column](column)
 }
diff --git a/pkg/schema/assignment/interleave.go b/pkg/schema/assignment/interleave.go
new file mode 100644
index 0000000..0c20a2a
--- /dev/null
+++ b/pkg/schema/assignment/interleave.go
@@ -0,0 +1,115 @@
+package assignment
+
+import (
+	"fmt"
+
+	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
+	"github.com/consensys/go-corset/pkg/schema"
+	"github.com/consensys/go-corset/pkg/trace"
+	tr "github.com/consensys/go-corset/pkg/trace"
+	"github.com/consensys/go-corset/pkg/util"
+)
+
+// Interleaving generates a new column by interleaving two or more existing
+// columns. For example, say Z interleaves X and Y (in that order) and we have
+// a trace X=[1,2], Y=[3,4]. Then, the interleaved column Z has the values
+// Z=[1,3,2,4].
+type Interleaving struct {
+	// Module where this interleaving is located.
+	module uint
+	// The new (interleaved) column
+	target schema.Column
+	// The source columns
+	sources []uint
+}
+
+// NewInterleaving constructs a new interleaving assignment.
+func NewInterleaving(module uint, name string, multiplier uint, sources []uint) *Interleaving {
+	// Update multiplier to account for the interleaving width
+	multiplier = multiplier * uint(len(sources))
+	// FIXME: determine interleaving type
+	target := schema.NewColumn(module, name, multiplier, &schema.FieldType{})
+
+	return &Interleaving{module, target, sources}
+}
+
+// Module returns the module which encloses this interleaving.
+func (p *Interleaving) Module() uint {
+	return p.module
+}
+
+// Sources returns the columns used by this interleaving to define the new
+// (interleaved) column.
+func (p *Interleaving) Sources() []uint {
+	return p.sources
+}
+
+// ============================================================================
+// Declaration Interface
+// ============================================================================
+
+// Columns returns the column declared by this interleaving.
+func (p *Interleaving) Columns() util.Iterator[schema.Column] {
+	return util.NewUnitIterator(p.target)
+}
+
+// IsComputed Determines whether or not this declaration is computed (which an
+// interleaved column is by definition).
+func (p *Interleaving) IsComputed() bool {
+	return true
+}
+
+// ============================================================================
+// Assignment Interface
+// ============================================================================
+
+// RequiredSpillage returns the minimum amount of spillage required to ensure
+// valid traces are accepted in the presence of arbitrary padding.
+func (p *Interleaving) RequiredSpillage() uint {
+	return uint(0)
+}
+
+// ExpandTrace expands a given trace to include the columns specified by a given
+// Interleaving. This requires copying the data in the source columns to create
+// the interleaved column.
+func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
+	columns := tr.Columns()
+	// Ensure target column doesn't exist
+	for i := p.Columns(); i.HasNext(); {
+		name := i.Next().Name()
+		// Sanity check no column already exists with this name.
+		if columns.HasColumn(name) {
+			return fmt.Errorf("column already exists ({%s})", name)
+		}
+	}
+	// Determine interleaving width
+	width := uint(len(p.sources))
+	// The following division always produces a whole value because the length
+	// multiplier already includes the width as a factor.
+	multiplier := p.target.LengthMultiplier() / width
+	// Determine module height (as this can be used to determine the height of
+	// the interleaved column)
+	height := tr.Modules().Get(p.module).Height() * multiplier
+	// Construct empty array
+	data := make([]*fr.Element, height*width)
+	// Offset gives the position of the next source column within the interleaving.
+	offset := uint(0)
+	// Copy interleaved data
+	for i := uint(0); i < width; i++ {
+		// Lookup source column
+		col := tr.Columns().Get(p.sources[i])
+		// Copy over
+		for j := uint(0); j < height; j++ {
+			data[offset+(j*width)] = col.Get(int(j))
+		}
+
+		offset++
+	}
+	// Padding for the entire column is determined by the padding for the first
+	// column in the interleaving.
+	padding := columns.Get(p.sources[0]).Padding()
+	// Column needs to be expanded.
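+	// Note that multiplier*width equals the length multiplier of the target column.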
+	columns.Add(trace.NewFieldColumn(p.module, p.target.Name(), multiplier*width, data, padding))
+	//
+	return nil
+}
diff --git a/pkg/schema/assignment/lexicographic_sort.go b/pkg/schema/assignment/lexicographic_sort.go
index 117f37b..7dc7ec4 100644
--- a/pkg/schema/assignment/lexicographic_sort.go
+++ b/pkg/schema/assignment/lexicographic_sort.go
@@ -17,6 +17,8 @@ type LexicographicSort struct {
 	// Module in which source and target columns are to be located. All target
 	// and source columns should be contained within this module.
 	module uint
+	// Length multiplier for all columns in this gadget
+	multiplier uint
 	// The target columns to be filled. The first entry is for the delta
 	// column, and the remaining n entries are for the selector columns.
 	targets []schema.Column
@@ -27,17 +29,19 @@ type LexicographicSort struct {
 }
 
 // NewLexicographicSort constructs a new LexicographicSorting assignment.
-func NewLexicographicSort(prefix string, module uint, sources []uint, signs []bool, bitwidth uint) *LexicographicSort {
+func NewLexicographicSort(prefix string, module uint, multiplier uint,
+	sources []uint, signs []bool, bitwidth uint) *LexicographicSort {
+	//
 	targets := make([]schema.Column, len(sources)+1)
 	// Create delta column
-	targets[0] = schema.NewColumn(module, fmt.Sprintf("%s:delta", prefix), schema.NewUintType(bitwidth))
+	targets[0] = schema.NewColumn(module, fmt.Sprintf("%s:delta", prefix), multiplier, schema.NewUintType(bitwidth))
 	// Create selector columns
 	for i := range sources {
 		ithName := fmt.Sprintf("%s:%d", prefix, i)
-		targets[1+i] = schema.NewColumn(module, ithName, schema.NewUintType(1))
+		targets[1+i] = schema.NewColumn(module, ithName, multiplier, schema.NewUintType(1))
 	}
 
-	return &LexicographicSort{module, targets, sources, signs, bitwidth}
+	return &LexicographicSort{module, multiplier, targets, sources, signs, bitwidth}
 }
 
// ============================================================================
@@ -73,8 +77,10 @@ func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error {
 	one := fr.NewElement(1)
 	// Exact number of columns involved in the sort
 	ncols := len(p.sources)
+	//
+	multiplier := p.multiplier
 	// Determine how many rows to be constrained.
-	nrows := tr.Modules().Get(p.module).Height()
+	nrows := tr.Modules().Get(p.module).Height() * multiplier
 	// Initialise new data columns
 	delta := make([]*fr.Element, nrows)
 	bit := make([][]*fr.Element, ncols)
@@ -113,11 +119,11 @@ func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error {
 	}
 	// Add delta column data
 	first := p.targets[0]
-	columns.Add(trace.NewFieldColumn(first.Module(), first.Name(), delta, &zero))
+	columns.Add(trace.NewFieldColumn(first.Module(), first.Name(), multiplier, delta, &zero))
 	// Add bit column data
 	for i := 0; i < ncols; i++ {
 		ith := p.targets[1+i]
-		columns.Add(trace.NewFieldColumn(ith.Module(), ith.Name(), bit[i], &zero))
+		columns.Add(trace.NewFieldColumn(ith.Module(), ith.Name(), multiplier, bit[i], &zero))
 	}
 	// Done.
 	return nil
diff --git a/pkg/schema/assignment/sorted_permutation.go b/pkg/schema/assignment/sorted_permutation.go
index 7e95516..a9d18ae 100644
--- a/pkg/schema/assignment/sorted_permutation.go
+++ b/pkg/schema/assignment/sorted_permutation.go
@@ -14,6 +14,8 @@ import (
 // existing columns.
 type SortedPermutation struct {
 	module uint
+	// Length multiplier
+	multiplier uint
 	// The new (sorted) columns
 	targets []schema.Column
 	// The sorting criteria
@@ -23,7 +25,8 @@ type SortedPermutation struct {
 }
 
 // NewSortedPermutation creates a new sorted permutation
-func NewSortedPermutation(module uint, targets []schema.Column, signs []bool, sources []uint) *SortedPermutation {
+func NewSortedPermutation(module uint, multiplier uint, targets []schema.Column,
+	signs []bool, sources []uint) *SortedPermutation {
 	if len(targets) != len(signs) || len(signs) != len(sources) {
 		panic("target and source column widths must match")
 	}
@@ -31,10 +34,12 @@ func NewSortedPermutation(module uint, targets []schema.Column, signs []bool, so
 	for _, c := range targets {
 		if c.Module() != module {
 			panic("inconsistent target modules")
+		} else if c.LengthMultiplier() != multiplier {
+			panic("inconsistent length multipliers for target columns")
 		}
 	}
 
-	return &SortedPermutation{module, targets, signs, sources}
+	return &SortedPermutation{module, multiplier, targets, signs, sources}
 }
 
 // Module returns the module which encloses this sorted permutation.
@@ -125,8 +130,8 @@ func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error {
 	for i := p.Columns(); i.HasNext(); {
 		name := i.Next().Name()
 		// Sanity check no column already exists with this name.
-		if tr.Columns().HasColumn(name) {
-			panic("target column already exists")
+		if columns.HasColumn(name) {
+			return fmt.Errorf("column already exists ({%s})", name)
 		}
 	}
 
@@ -145,13 +150,11 @@ func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error {
 
 	// Physically add the columns
 	index := 0
-	for i := p.Columns(); i.HasNext(); {
+	for i := p.Columns(); i.HasNext(); index++ {
 		ith := i.Next()
 		dstColName := ith.Name()
 		srcCol := tr.Columns().Get(p.sources[index])
-		columns.Add(trace.NewFieldColumn(ith.Module(), dstColName, cols[index], srcCol.Padding()))
-
-		index++
+		columns.Add(trace.NewFieldColumn(ith.Module(), dstColName, p.multiplier, cols[index], srcCol.Padding()))
 	}
 	//
 	return nil
diff --git a/pkg/schema/constraint/lookup.go b/pkg/schema/constraint/lookup.go
index ff3a8a4..c3b3499 100644
--- a/pkg/schema/constraint/lookup.go
+++ b/pkg/schema/constraint/lookup.go
@@ -28,8 +28,14 @@ type LookupConstraint[E schema.Evaluable] struct {
 	handle string
 	// Enclosing module for source columns.
 	source uint
+	// Length multiplier partly determines the evaluation context for source
+	// expressions.
+	source_multiplier uint
 	// Enclosing module for target columns.
 	target uint
+	// Length multiplier partly determines the evaluation context for target
+	// expressions.
+	target_multiplier uint
 	// Source rows represent the subset of rows.
 	sources []E
 	// Target rows represent the set of rows.
 	targets []E
 }
 
 // NewLookupConstraint creates a new lookup constraint with a given handle.
-func NewLookupConstraint[E schema.Evaluable](handle string, source uint, target uint,
-	sources []E, targets []E) *LookupConstraint[E] {
+func NewLookupConstraint[E schema.Evaluable](handle string, source uint, source_multiplier uint,
+	target uint, target_multiplier uint, sources []E, targets []E) *LookupConstraint[E] {
 	if len(targets) != len(sources) {
 		panic("differing number of target / source lookup columns")
 	}
 
-	return &LookupConstraint[E]{handle, source, target, sources, targets}
+	return &LookupConstraint[E]{handle, source, source_multiplier, target, target_multiplier, sources, targets}
 }
 
 // Handle returns the handle for this lookup constraint which is simply an
@@ -54,14 +60,14 @@ func (p *LookupConstraint[E]) Handle() string {
 	return p.handle
 }
 
-// SourceModule returns the module in which all source columns are located.
-func (p *LookupConstraint[E]) SourceModule() uint {
-	return p.source
+// SourceContext returns the context (i.e. module and length multiplier) in
+// which all source columns are located.
+func (p *LookupConstraint[E]) SourceContext() (uint, uint) {
+	return p.source, p.source_multiplier
 }
 
-// TargetModule returns the module in which all target columns are located.
-func (p *LookupConstraint[E]) TargetModule() uint {
-	return p.target
+// TargetContext returns the context (i.e. module and length multiplier) in
+// which all target columns are located.
+func (p *LookupConstraint[E]) TargetContext() (uint, uint) {
+	return p.target, p.target_multiplier
 }
 
 // Sources returns the source expressions which are used to lookup into the
@@ -82,8 +88,8 @@ func (p *LookupConstraint[E]) Targets() []E {
 //nolint:revive
 func (p *LookupConstraint[E]) Accepts(tr trace.Trace) error {
 	// Determine height of enclosing module for source columns
-	src_height := tr.Modules().Get(p.source).Height()
-	tgt_height := tr.Modules().Get(p.target).Height()
+	src_height := tr.Modules().Get(p.source).Height() * p.source_multiplier
+	tgt_height := tr.Modules().Get(p.target).Height() * p.target_multiplier
 	// Go through every row of the source columns checking they are present in
 	// the target columns.
 	//
diff --git a/pkg/schema/constraint/range.go b/pkg/schema/constraint/range.go
index 5ccf65d..c0c7036 100644
--- a/pkg/schema/constraint/range.go
+++ b/pkg/schema/constraint/range.go
@@ -38,7 +38,7 @@ func NewRangeConstraint(column uint, bound *fr.Element) *RangeConstraint {
 // every row of a table. If so, return nil otherwise return an error.
 func (p *RangeConstraint) Accepts(tr trace.Trace) error {
 	column := tr.Columns().Get(p.column)
-	height := tr.Modules().Get(column.Module()).Height()
+	height := column.Height()
 	// Iterate all rows of the module
 	for k := 0; k < int(height); k++ {
 		// Get the value on the kth row
diff --git a/pkg/schema/constraint/type.go b/pkg/schema/constraint/type.go
index 93f9eda..45f14de 100644
--- a/pkg/schema/constraint/type.go
+++ b/pkg/schema/constraint/type.go
@@ -40,8 +40,8 @@ func (p *TypeConstraint) Type() schema.Type {
 
 // every row of a table. If so, return nil otherwise return an error.
func (p *TypeConstraint) Accepts(tr trace.Trace) error { column := tr.Columns().Get(p.column) - // Determine height of enclosing module - height := tr.Modules().Get(column.Module()).Height() + // Determine height + height := column.Height() // Iterate every row for k := 0; k < int(height); k++ { // Get the value on the kth row diff --git a/pkg/schema/constraint/vanishing.go b/pkg/schema/constraint/vanishing.go index 701f568..ef841df 100644 --- a/pkg/schema/constraint/vanishing.go +++ b/pkg/schema/constraint/vanishing.go @@ -31,7 +31,7 @@ func (p ZeroTest[E]) Bounds() util.Bounds { // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p ZeroTest[E]) Context(schema sc.Schema) (uint, bool) { +func (p ZeroTest[E]) Context(schema sc.Schema) (uint, uint, bool) { return p.Expr.Context(schema) } @@ -55,6 +55,9 @@ type VanishingConstraint[T schema.Testable] struct { // Enclosing module for this assertion. This restricts the constraint to // access only columns from within this module. module uint + // Length multiplier. This is used to the column's actual height as a + // multipler of the enclosing module's height. + multiplier uint // Indicates (when nil) a global constraint that applies to all rows. // Otherwise, indicates a local constraint which applies to the specific row // given here. @@ -65,9 +68,9 @@ type VanishingConstraint[T schema.Testable] struct { } // NewVanishingConstraint constructs a new vanishing constraint! -func NewVanishingConstraint[T schema.Testable](handle string, module uint, +func NewVanishingConstraint[T schema.Testable](handle string, module uint, multiplier uint, domain *int, constraint T) *VanishingConstraint[T] { - return &VanishingConstraint[T]{handle, module, domain, constraint} + return &VanishingConstraint[T]{handle, module, multiplier, domain, constraint} } // Handle returns the handle associated with this constraint. @@ -96,6 +99,13 @@ func (p *VanishingConstraint[T]) Module() uint { return p.module } +// LengthMultiplier returns the length multiplier used by this vanishing +// constraint. This should match the evaluation context of the vanishing +// expression. +func (p *VanishingConstraint[T]) LengthMultiplier() uint { + return p.multiplier +} + // Accepts checks whether a vanishing constraint evaluates to zero on every row // of a table. If so, return nil otherwise return an error. // @@ -103,7 +113,7 @@ func (p *VanishingConstraint[T]) Module() uint { func (p *VanishingConstraint[T]) Accepts(tr trace.Trace) error { if p.domain == nil { // Global Constraint - return HoldsGlobally(p.handle, p.module, p.constraint, tr) + return HoldsGlobally(p.handle, p.module, p.multiplier, p.constraint, tr) } // Local constraint var start uint @@ -122,9 +132,9 @@ func (p *VanishingConstraint[T]) Accepts(tr trace.Trace) error { // HoldsGlobally checks whether a given expression vanishes (i.e. evaluates to // zero) for all rows of a trace. If not, report an appropriate error. 
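The scaled iteration space is the essential change here. A small standalone sketch, not part of the patch and with illustrative names: a constraint whose context has multiplier f must now vanish on Height(m)*f rows, which is exactly what the updated HoldsGlobally below computes.

package main

import "fmt"

// holdsGlobally mirrors the height computation in HoldsGlobally: the number of
// rows checked is the module height scaled by the constraint's multiplier. The
// predicate stands in for evaluating the vanishing expression on row k.
func holdsGlobally(moduleHeight, multiplier uint, vanishesAt func(k uint) bool) error {
    height := moduleHeight * multiplier
    for k := uint(0); k < height; k++ {
        if !vanishesAt(k) {
            return fmt.Errorf("constraint does not hold at row %d", k)
        }
    }
    return nil
}

func main() {
    // A constraint over an interleaved (x2) context in a module of height 4
    // is checked on 8 rows.
    err := holdsGlobally(4, 2, func(k uint) bool { return true })
    fmt.Println(err) // <nil>
}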
@@ -122,9 +132,9 @@ func (p *VanishingConstraint[T]) Accepts(tr trace.Trace) error {
 // HoldsGlobally checks whether a given expression vanishes (i.e. evaluates to
 // zero) for all rows of a trace. If not, report an appropriate error.
-func HoldsGlobally[T schema.Testable](handle string, module uint, constraint T, tr trace.Trace) error {
+func HoldsGlobally[T schema.Testable](handle string, module uint, multiplier uint, constraint T, tr trace.Trace) error {
     // Determine effective height of enclosing module
-    height := tr.Modules().Get(module).Height()
+    height := tr.Modules().Get(module).Height() * multiplier
     // Determine well-definedness bounds for this constraint
     bounds := constraint.Bounds()
     // Sanity check enough rows
diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go
index 5bf1fc1..d1cc96f 100644
--- a/pkg/schema/schema.go
+++ b/pkg/schema/schema.go
@@ -73,16 +73,18 @@ type Constraint interface {
 // require a single context. This interface is separated from Evaluable (and
 // Testable) because HIR expressions do not implement Evaluable.
 type Contextual interface {
-    // Context returns the evaluation context (i.e. enclosing module) for this
-    // constraint. Every testable constraint must have a single evaluation
-    // context. This function therefore attempts to determine what that is, or
-    // return false to signal an error. There are several failure modes which
-    // need to be considered. Firstly, if the expression has no enclosing
-    // module (e.g. because it is a constant expression) then it will return
-    // 'math.MaxUint` to signal this. Secondly, if the expression has multiple
-    // (i.e. conflicting) enclosing modules then it will return false to signal
-    // this.
-    Context(Schema) (uint, bool)
+    // Context returns the evaluation context (i.e. enclosing module + length
+    // multiplier) for this constraint. Every expression must have a single
+    // evaluation context. This function therefore attempts to determine what
+    // that is, or returns false to signal an error. There are several failure
+    // modes which need to be considered. Firstly, if the expression has no
+    // enclosing module (e.g. because it is a constant expression) then it will
+    // return `math.MaxUint` to signal this. Secondly, if the expression has
+    // multiple (i.e. conflicting) enclosing modules then it will return false
+    // to signal this. Likewise, the expression could have a single enclosing
+    // module but multiple conflicting length multipliers, in which case it
+    // also returns false.
+    Context(Schema) (uint, uint, bool)
 }

 // Evaluable captures something which can be evaluated on a given table row to
@@ -130,13 +132,16 @@ type Column struct {
     module uint
     // Returns the name of this column
     name string
+    // Length multiplier. This determines the column's actual height as a
+    // multiple of the enclosing module's height.
+    multiplier uint
     // Returns the expected type of data in this column
     datatype Type
 }

 // NewColumn constructs a new column
-func NewColumn(module uint, name string, datatype Type) Column {
-    return Column{module, name, datatype}
+func NewColumn(module uint, name string, multiplier uint, datatype Type) Column {
+    return Column{module, name, multiplier, datatype}
 }

 // Module returns the index of the module which contains this column
@@ -149,6 +154,12 @@ func (p Column) Name() string {
     return p.name
 }

+// LengthMultiplier determines the column's actual height as a multiple of the
+// enclosing module's height.
+func (p Column) LengthMultiplier() uint {
+    return p.multiplier
+}
+
 // Type returns the expected type of data in this column
 func (p Column) Type() Type {
     return p.datatype
diff --git a/pkg/schema/schemas.go b/pkg/schema/schemas.go
index 6db181d..00e21ae 100644
--- a/pkg/schema/schemas.go
+++ b/pkg/schema/schemas.go
@@ -7,41 +7,44 @@ import (
     tr "github.com/consensys/go-corset/pkg/trace"
 )

-// JoinContexts combines one or more contexts together. There are a number of
-// scenarios. The simple path is when each expression has the same evaluation
-// context (in which case this is returned). Its also possible one or more
-// expressions have no evaluation context (signaled by math.MaxUint) and this
-// can be ignored. Finally, we might have two expressions with conflicting
+// JoinContexts combines one or more evaluation contexts together. There are a
+// number of scenarios. The simple path is when each expression has the same
+// evaluation context (in which case this is returned). It's also possible one
+// or more expressions have no evaluation context (signaled by math.MaxUint)
+// and this can be ignored. Finally, we might have two expressions with
+// conflicting
 // evaluation contexts, and this clearly signals an error.
-func JoinContexts[E Contextual](args []E, schema Schema) (uint, bool) {
-    var ctx uint = math.MaxUint
+func JoinContexts[E Contextual](args []E, schema Schema) (uint, uint, bool) {
+    var mid uint = math.MaxUint
+
+    var multiplier = uint(1)
     //
     for _, e := range args {
-        c, b := e.Context(schema)
+        c, m, b := e.Context(schema)
         if !b {
             // Indicates conflict detected upstream, therefore propagate this
             // down.
-            return 0, false
-        } else if ctx == math.MaxUint {
+            return 0, 0, false
+        } else if mid == math.MaxUint {
             // No evaluation context determined yet, therefore can overwrite
             // with whatever we got. Observe that this might still actually
             // be math.MaxUint.
-            ctx = c
-        } else if c != ctx && c != math.MaxUint {
+            mid = c
+            multiplier = m
+        } else if c != math.MaxUint && (c != mid || m != multiplier) {
             // This indicates a conflict is detected, therefore we must
             // propagate this down.
-            return 0, false
+            return 0, 0, false
         }
     }
     // If we get here, then no conflicts were detected.
-    return ctx, true
+    return mid, multiplier, true
 }

 // DetermineEnclosingModuleOfExpression determines (and checks) the enclosing
 // module for a given expression. The expectation is that there is a single
 // enclosing module, and this function will panic if that does not hold.
-func DetermineEnclosingModuleOfExpression[E Contextual](expr E, schema Schema) uint {
-    if mid, ok := expr.Context(schema); ok && mid != math.MaxUint {
-        return mid
+func DetermineEnclosingModuleOfExpression[E Contextual](expr E, schema Schema) (uint, uint) {
+    if mid, multiplier, ok := expr.Context(schema); ok && mid != math.MaxUint {
+        return mid, multiplier
     }
     //
     panic("expression has no evaluation context")
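Restated outside the codebase, the join rules above amount to the following sketch, not part of the patch, where math.MaxUint again stands for "no context":

package main

import (
    "fmt"
    "math"
)

type ctx struct{ module, multiplier uint }

// join returns the combined context of a and b, or false on conflict,
// mirroring the pairwise logic inside JoinContexts.
func join(a, b ctx) (ctx, bool) {
    switch {
    case a.module == math.MaxUint:
        return b, true // a imposes no constraint
    case b.module == math.MaxUint:
        return a, true // b imposes no constraint
    case a == b:
        return a, true // identical contexts join trivially
    default:
        return ctx{}, false // different module or multiplier: conflict
    }
}

func main() {
    a := ctx{0, 1} // a plain column in module 0
    z := ctx{0, 2} // an interleaved (x2) column, also in module 0
    _, ok := join(a, z)
    // Same module but different multipliers is now a conflict too.
    fmt.Println(ok) // false
}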
@@ -50,52 +53,58 @@ func DetermineEnclosingModuleOfExpression[E Contextual](expr E, schema Schema) u
 // DetermineEnclosingModuleOfExpressions determines (and checks) the enclosing
 // module for a given set of expressions. The expectation is that there is a
 // single enclosing module, and this function will return an error if that
 // does not hold.
-func DetermineEnclosingModuleOfExpressions[E Contextual](exprs []E, schema Schema) (uint, error) {
+func DetermineEnclosingModuleOfExpressions[E Contextual](exprs []E, schema Schema) (uint, uint, error) {
     // Sanity check input
     if len(exprs) == 0 {
         panic("cannot determine enclosing module for empty expression array")
     }
     // Determine first
-    mid, ok := exprs[0].Context(schema)
+    mid, multiplier, ok := exprs[0].Context(schema)
     // Sanity check this made sense
     if !ok {
-        return 0, errors.New("conflicting enclosing modules")
+        return 0, 0, errors.New("conflicting enclosing modules")
     }
     // Check rest against this
     for i := 1; i < len(exprs); i++ {
-        m, ok := exprs[i].Context(schema)
+        m, f, ok := exprs[i].Context(schema)
         if !ok {
-            return uint(i), errors.New("conflicting enclosing modules")
+            return uint(i), 0, errors.New("conflicting enclosing modules")
         } else if mid == math.MaxUint {
             mid = m
+            multiplier = f
         } else if m != math.MaxUint && m != mid {
-            return uint(i), errors.New("conflicting enclosing modules")
+            return uint(i), 0, errors.New("conflicting enclosing modules")
+        } else if m != math.MaxUint && f != multiplier {
+            return uint(i), 0, errors.New("conflicting length multipliers")
         }
     }
     // success
-    return mid, nil
+    return mid, multiplier, nil
 }

 // DetermineEnclosingModuleOfColumns determines (and checks) the enclosing module for a
 // given set of columns. The expectation is that there is a single enclosing
 // module, and this function will panic if that does not hold.
-func DetermineEnclosingModuleOfColumns(cols []uint, schema Schema) uint {
+func DetermineEnclosingModuleOfColumns(cols []uint, schema Schema) (uint, uint) {
+    head := schema.Columns().Nth(cols[0])
     // First, determine module of first column.
-    mid := schema.Columns().Nth(cols[0]).Module()
+    mid := head.Module()
+    multiplier := head.LengthMultiplier()
     // Second, check other columns in the same module.
     //
     // NOTE: this could potentially be made more efficient by checking the
     // columns of the module for the first column.
     for i := 1; i < len(cols); i++ {
-        col := cols[i]
-        if mid != schema.Columns().Nth(col).Module() {
+        col := schema.Columns().Nth(cols[i])
+        if mid != col.Module() {
             // This is an internal failure which should be prevented by upstream
             // checking (e.g. in the parser).
panic("columns have different enclosing module") + } else if multiplier != col.LengthMultiplier() { + panic("columns have different length multipliers") } } // Done - return mid + return mid, multiplier } // RequiredSpillage returns the minimum amount of spillage required to ensure diff --git a/pkg/test/ir_test.go b/pkg/test/ir_test.go index 0a1ec2b..c1d2804 100644 --- a/pkg/test/ir_test.go +++ b/pkg/test/ir_test.go @@ -346,6 +346,26 @@ func Test_Lookup_06(t *testing.T) { Check(t, "lookup_06") } +// =================================================================== +// Interleaving +// =================================================================== + +func Test_Interleave_01(t *testing.T) { + Check(t, "interleave_01") +} + +func Test_Interleave_02(t *testing.T) { + Check(t, "interleave_02") +} + +func Test_Interleave_03(t *testing.T) { + Check(t, "interleave_03") +} + +func Test_Interleave_04(t *testing.T) { + Check(t, "interleave_04") +} + // =================================================================== // Complex Tests // =================================================================== diff --git a/pkg/trace/array_trace.go b/pkg/trace/array_trace.go index 40d83e7..b655a4d 100644 --- a/pkg/trace/array_trace.go +++ b/pkg/trace/array_trace.go @@ -1,6 +1,7 @@ package trace import ( + "fmt" "strings" "github.com/consensys/go-corset/pkg/util" @@ -103,9 +104,10 @@ type arrayTraceColumnSet struct { // Add a new column to this column set. func (p arrayTraceColumnSet) Add(column Column) uint { m := &p.trace.modules[column.Module()] - // Sanity check height - if column.Height() != m.Height() { - panic("invalid column height") + // Sanity check effective height + if column.Height() != (column.LengthMultiplier() * m.Height()) { + panic(fmt.Sprintf("invalid column height for %s: %d vs %d*%d", column.Name(), + column.Height(), m.Height(), column.LengthMultiplier())) } // Proceed index := uint(len(p.trace.columns)) diff --git a/pkg/trace/builder.go b/pkg/trace/builder.go index e288fc7..9e3b053 100644 --- a/pkg/trace/builder.go +++ b/pkg/trace/builder.go @@ -52,8 +52,9 @@ func (p *Builder) Add(name string, padding *fr.Element, data []*fr.Element) erro return err } } - // register new column - return p.registerColumn(NewFieldColumn(mid, colname, data, padding)) + // Register new column. Observe that user-provided columns always have a + // factor of 1. + return p.registerColumn(NewFieldColumn(mid, colname, 1, data, padding)) } // HasModule checks whether a given module has already been registered with this diff --git a/pkg/trace/bytes_column.go b/pkg/trace/bytes_column.go index bdbc4de..d0526f2 100644 --- a/pkg/trace/bytes_column.go +++ b/pkg/trace/bytes_column.go @@ -19,6 +19,8 @@ type BytesColumn struct { width uint8 // The number of data elements in this column. length uint + // Length multiplier (i.e. of length) + multiplier uint // The data stored in this column (as bytes). bytes []byte // Value to be used when padding this column @@ -26,9 +28,14 @@ type BytesColumn struct { } // NewBytesColumn constructs a new BytesColumn from its constituent parts. 
 // NewBytesColumn constructs a new BytesColumn from its constituent parts.
-func NewBytesColumn(module uint, name string, width uint8, length uint,
+func NewBytesColumn(module uint, name string, width uint8, length uint, multiplier uint,
     bytes []byte, padding *fr.Element) *BytesColumn {
-    return &BytesColumn{module, name, width, length, bytes, padding}
+    // Sanity check data length
+    if length%multiplier != 0 {
+        panic("data length is not a multiple of the length multiplier")
+    }
+
+    return &BytesColumn{module, name, width, length, multiplier, bytes, padding}
 }

 // Module returns the enclosing module of this column
@@ -51,6 +58,14 @@ func (p *BytesColumn) Height() uint {
     return p.length
 }

+// LengthMultiplier is a multiplier of the enclosing module's height used to
+// determine this column's height. For example, if the multiplier is 2 then the
+// height of this column must always be a multiple of 2, etc. This affects
+// padding also, as we must pad to this multiplier.
+func (p *BytesColumn) LengthMultiplier() uint {
+    return p.multiplier
+}
+
 // Padding returns the value which will be used for padding this column.
 func (p *BytesColumn) Padding() *fr.Element {
     return p.padding
@@ -74,6 +89,7 @@ func (p *BytesColumn) Clone() Column {
     clone.name = p.name
     clone.length = p.length
     clone.width = p.width
+    clone.multiplier = p.multiplier
     clone.padding = p.padding
     // NOTE: the following is safe as we never actually mutate the underlying
     // bytes array.
@@ -109,6 +125,8 @@ func (p *BytesColumn) Data() []*fr.Element {
 // Pad this column with n copies of the column's padding value.
 func (p *BytesColumn) Pad(n uint) {
+    // Apply the length multiplier
+    n = n * p.multiplier
     // Computing padding length (in bytes)
     padding_len := n * uint(p.width)
     // Access bytes to use for padding
diff --git a/pkg/trace/field_column.go b/pkg/trace/field_column.go
index 5bcedd2..199ed5c 100644
--- a/pkg/trace/field_column.go
+++ b/pkg/trace/field_column.go
@@ -16,6 +16,8 @@ type FieldColumn struct {
     module uint
     // Holds the name of this column
     name string
+    // Length multiplier (i.e. of the data array)
+    multiplier uint
     // Holds the raw data making up this column
     data []*fr.Element
     // Value to be used when padding this column
@@ -23,8 +25,13 @@
 }

 // NewFieldColumn constructs a FieldColumn with the given name, data and padding.
-func NewFieldColumn(module uint, name string, data []*fr.Element, padding *fr.Element) *FieldColumn {
-    return &FieldColumn{module, name, data, padding}
+func NewFieldColumn(module uint, name string, multiplier uint, data []*fr.Element, padding *fr.Element) *FieldColumn {
+    // Sanity check data length
+    if uint(len(data))%multiplier != 0 {
+        panic("data length is not a multiple of the length multiplier")
+    }
+    // Done
+    return &FieldColumn{module, name, multiplier, data, padding}
 }

 // Module returns the enclosing module of this column
@@ -48,6 +55,13 @@ func (p *FieldColumn) Height() uint {
     return uint(len(p.data))
 }

+// LengthMultiplier is a multiplier which must be a factor of the height. For
+// example, if the multiplier is 2 then the height must always be a multiple of
+// 2, etc. This affects padding also, as we must pad to a multiple of this
+// multiplier.
+func (p *FieldColumn) LengthMultiplier() uint {
+    return p.multiplier
+}
+
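Why Pad scales its argument: padding the enclosing module by n rows must grow a multiplier-f column by n*f values, otherwise the height invariant above would break. A small sketch, not part of the patch, that only tracks lengths rather than data:

package main

import "fmt"

// paddedHeight mirrors the scaling now applied in Pad: a request to pad the
// enclosing module by n rows grows the column by n * multiplier values.
func paddedHeight(height, n, multiplier uint) uint {
    return height + n*multiplier
}

func main() {
    // A multiplier-2 column of height 8 sits in a module of height 4. Padding
    // the module by 3 rows grows the column to 8 + 3*2 = 14 values, preserving
    // height == moduleHeight * multiplier (i.e. 7 * 2).
    fmt.Println(paddedHeight(8, 3, 2)) // 14
}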
 // Padding returns the value which will be used for padding this column.
 func (p *FieldColumn) Padding() *fr.Element {
     return p.padding
@@ -75,6 +89,7 @@ func (p *FieldColumn) Clone() Column {
     clone := new(FieldColumn)
     clone.module = p.module
     clone.name = p.name
+    clone.multiplier = p.multiplier
     clone.padding = p.padding
     // NOTE: the following is safe as we never actually mutate the underlying
     // data array.
@@ -85,6 +100,8 @@ func (p *FieldColumn) Clone() Column {
 // Pad this column with n copies of the column's padding value.
 func (p *FieldColumn) Pad(n uint) {
+    // Apply the length multiplier
+    n = n * p.multiplier
     // Allocate sufficient memory
     ndata := make([]*fr.Element, uint(len(p.data))+n)
     // Copy over the data
diff --git a/pkg/trace/trace.go b/pkg/trace/trace.go
index ecf1186..180ea5f 100644
--- a/pkg/trace/trace.go
+++ b/pkg/trace/trace.go
@@ -45,6 +45,11 @@ type Column interface {
     Get(row int) *fr.Element
     // Return the height (i.e. number of rows) of this column.
     Height() uint
+    // Returns the length multiplier (which must be a factor of the height).
+    // For example, if the multiplier is 2 then the height must always be a
+    // multiple of 2, etc. This affects padding also, as we must pad to a
+    // multiple of this multiplier.
+    LengthMultiplier() uint
     // Get the module index of the enclosing module.
     Module() uint
     // Get the name of this column
diff --git a/testdata/interleave_01.accepts b/testdata/interleave_01.accepts
new file mode 100644
index 0000000..d4efd81
--- /dev/null
+++ b/testdata/interleave_01.accepts
@@ -0,0 +1,5 @@
+{ "X": [], "Y": [] }
+{ "X": [0], "Y": [0] }
+{ "X": [0,0], "Y": [0,0] }
+{ "X": [0,0,0], "Y": [0,0,0] }
+{ "X": [0,0,0,0], "Y": [0,0,0,0] }
diff --git a/testdata/interleave_01.lisp b/testdata/interleave_01.lisp
new file mode 100644
index 0000000..6bfd104
--- /dev/null
+++ b/testdata/interleave_01.lisp
@@ -0,0 +1,4 @@
+(column X)
+(column Y)
+(interleave Z (X Y))
+(vanish c1 Z)
diff --git a/testdata/interleave_01.rejects b/testdata/interleave_01.rejects
new file mode 100644
index 0000000..5b49bd9
--- /dev/null
+++ b/testdata/interleave_01.rejects
@@ -0,0 +1,13 @@
+{ "X": [1], "Y": [0] }
+{ "X": [0], "Y": [1] }
+{ "X": [0,0], "Y": [0,1] }
+{ "X": [0,0], "Y": [1,0] }
+{ "X": [0,1], "Y": [0,0] }
+{ "X": [1,0], "Y": [0,0] }
+;;
+{ "X": [1,0], "Y": [0,1] }
+{ "X": [0,1], "Y": [0,1] }
+{ "X": [1,0], "Y": [1,0] }
+{ "X": [0,1], "Y": [1,0] }
+{ "X": [1,0], "Y": [0,1] }
+{ "X": [0,1], "Y": [0,1] }
diff --git a/testdata/interleave_02.accepts b/testdata/interleave_02.accepts
new file mode 100644
index 0000000..6c9dedc
--- /dev/null
+++ b/testdata/interleave_02.accepts
@@ -0,0 +1,16 @@
+{ "X": [], "Y": [] }
+{ "X": [0], "Y": [0] }
+{ "X": [1], "Y": [1] }
+{ "X": [1], "Y": [2] }
+;;
+{ "X": [0,0], "Y": [0,0] }
+{ "X": [0,0], "Y": [0,1] }
+{ "X": [0,1], "Y": [0,1] }
+{ "X": [0,1], "Y": [1,1] }
+{ "X": [1,1], "Y": [1,1] }
+{ "X": [1,1], "Y": [1,2] }
+{ "X": [1,2], "Y": [1,2] }
+{ "X": [1,2], "Y": [2,2] }
+{ "X": [1,2], "Y": [2,3] }
+{ "X": [1,3], "Y": [2,3] }
+{ "X": [1,3], "Y": [2,4] }
diff --git a/testdata/interleave_02.lisp b/testdata/interleave_02.lisp
new file mode 100644
index 0000000..56f9492
--- /dev/null
+++ b/testdata/interleave_02.lisp
@@ -0,0 +1,5 @@
+(column X)
+(column Y)
+(interleave Z (X Y))
+;; Z[k]+1 == Z[k+1] || Z[k] == Z[k+1]
+(vanish c1 (* (- (+ 1 Z) (shift Z 1)) (- Z (shift Z 1))))
diff --git a/testdata/interleave_02.rejects b/testdata/interleave_02.rejects
new file mode 100644
index 0000000..1a9da8a
--- /dev/null
+++ b/testdata/interleave_02.rejects
@@ -0,0 +1,16 @@
+{ "X": [1], "Y": [0] }
+{ "X": [0], "Y": [2] }
+{ "X": [2], "Y": [2] }
+;;
+{ "X": [0,0], "Y": [1,0] }
+{ "X": [0,1], "Y": [0,0] }
+{ "X": [1,0], "Y": [0,0] }
+{ "X": [1,0], "Y": [1,0] }
+{ "X": [1,1], "Y": [1,0] }
+{ "X": [1,2], "Y": [1,1] }
+{ "X": [1,1], "Y": [2,1] }
+{ "X": [2,1], "Y": [1,1] }
+{ "X": [1,2], "Y": [2,1] }
+{ "X": [2,1], "Y": [2,1] }
+{ "X": [2,2], "Y": [2,1] }
+{ "X": [1,1], "Y": [1,3] }
diff --git a/testdata/interleave_03.accepts b/testdata/interleave_03.accepts
new file mode 100644
index 0000000..d4efd81
--- /dev/null
+++ b/testdata/interleave_03.accepts
@@ -0,0 +1,5 @@
+{ "X": [], "Y": [] }
+{ "X": [0], "Y": [0] }
+{ "X": [0,0], "Y": [0,0] }
+{ "X": [0,0,0], "Y": [0,0,0] }
+{ "X": [0,0,0,0], "Y": [0,0,0,0] }
diff --git a/testdata/interleave_03.lisp b/testdata/interleave_03.lisp
new file mode 100644
index 0000000..e6d2ac3
--- /dev/null
+++ b/testdata/interleave_03.lisp
@@ -0,0 +1,6 @@
+(column X)
+(column Y)
+(interleave A (X Y))
+(interleave B (X Y))
+(interleave Z (A B))
+(vanish c1 Z)
diff --git a/testdata/interleave_03.rejects b/testdata/interleave_03.rejects
new file mode 100644
index 0000000..5b49bd9
--- /dev/null
+++ b/testdata/interleave_03.rejects
@@ -0,0 +1,13 @@
+{ "X": [1], "Y": [0] }
+{ "X": [0], "Y": [1] }
+{ "X": [0,0], "Y": [0,1] }
+{ "X": [0,0], "Y": [1,0] }
+{ "X": [0,1], "Y": [0,0] }
+{ "X": [1,0], "Y": [0,0] }
+;;
+{ "X": [1,0], "Y": [0,1] }
+{ "X": [0,1], "Y": [0,1] }
+{ "X": [1,0], "Y": [1,0] }
+{ "X": [0,1], "Y": [1,0] }
+{ "X": [1,0], "Y": [0,1] }
+{ "X": [0,1], "Y": [0,1] }
diff --git a/testdata/interleave_04.accepts b/testdata/interleave_04.accepts
new file mode 100644
index 0000000..08cfc8c
--- /dev/null
+++ b/testdata/interleave_04.accepts
@@ -0,0 +1,20 @@
+{ "X": [], "Y": [], "Z": [] }
+{ "X": [0], "Y": [0], "Z": [0] }
+{ "X": [1], "Y": [2], "Z": [1] }
+;;
+{ "X": [1,2], "Y": [3,4], "Z": [0,0] }
+{ "X": [1,2], "Y": [3,4], "Z": [0,1] }
+{ "X": [1,2], "Y": [3,4], "Z": [1,0] }
+{ "X": [1,2], "Y": [3,4], "Z": [1,1] }
+{ "X": [2,1], "Y": [3,4], "Z": [0,0] }
+{ "X": [2,1], "Y": [3,4], "Z": [0,1] }
+{ "X": [2,1], "Y": [3,4], "Z": [1,0] }
+{ "X": [2,1], "Y": [3,4], "Z": [1,1] }
+{ "X": [1,2], "Y": [4,3], "Z": [0,0] }
+{ "X": [1,2], "Y": [4,3], "Z": [0,1] }
+{ "X": [1,2], "Y": [4,3], "Z": [1,0] }
+{ "X": [1,2], "Y": [4,3], "Z": [1,1] }
+{ "X": [2,1], "Y": [4,3], "Z": [0,0] }
+{ "X": [2,1], "Y": [4,3], "Z": [0,1] }
+{ "X": [2,1], "Y": [4,3], "Z": [1,0] }
+{ "X": [2,1], "Y": [4,3], "Z": [1,1] }
diff --git a/testdata/interleave_04.lisp b/testdata/interleave_04.lisp
new file mode 100644
index 0000000..464fbac
--- /dev/null
+++ b/testdata/interleave_04.lisp
@@ -0,0 +1,5 @@
+(column X)
+(column Y)
+(column Z)
+(interleave A (X Y))
+(lookup l1 (A) (Z))
diff --git a/testdata/interleave_04.rejects b/testdata/interleave_04.rejects
new file mode 100644
index 0000000..7512767
--- /dev/null
+++ b/testdata/interleave_04.rejects
@@ -0,0 +1,17 @@
+{ "X": [1], "Y": [2], "Z": [3] }
+;;
+{ "X": [1,2], "Y": [2,1], "Z": [0,3] }
+{ "X": [1,2], "Y": [2,1], "Z": [3,0] }
+;;
+{ "X": [1,1,2], "Y": [4,4,5], "Z": [0,1,3] }
+{ "X": [1,1,2], "Y": [4,4,5], "Z": [1,0,3] }
+{ "X": [1,1,2], "Y": [4,4,5], "Z": [0,3,1] }
+{ "X": [1,1,2], "Y": [4,4,5], "Z": [1,3,0] }
+{ "X": [1,2,1], "Y": [4,5,4], "Z": [0,1,3] }
+{ "X": [1,2,1], "Y": [4,5,4], "Z": [1,0,3] }
+{ "X": [1,2,1], "Y": [4,5,4], "Z": [0,3,1] }
+{ "X": [1,2,1], "Y": [4,5,4], "Z": [1,3,0] }
+{ "X": [2,1,1], "Y": [5,4,4], "Z": [0,1,3] }
+{ "X": [2,1,1], "Y": [5,4,4], "Z": [1,0,3] }
+{ "X": [2,1,1], "Y": [5,4,4], "Z": [0,3,1] }
+{ "X": [2,1,1], "Y": [5,4,4], "Z": [1,3,0] }
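These test vectors assume the usual row mapping for interleaving, which the patch itself does not spell out: interleaving X and Y yields a column Z in the same module with Z[2k] = X[k] and Z[2k+1] = Y[k], hence Z's length multiplier of 2. A standalone sketch, inferred from the vectors above and not part of the patch:

package main

import "fmt"

// interleave builds the assumed Z from X and Y: rows of X and Y alternate,
// so the result is twice as long as either input.
func interleave(xs, ys []uint) []uint {
    if len(xs) != len(ys) {
        panic("interleaved columns must have the same height")
    }
    zs := make([]uint, 2*len(xs))
    for k := range xs {
        zs[2*k] = xs[k]
        zs[2*k+1] = ys[k]
    }
    return zs
}

func main() {
    // From interleave_02.accepts: X=[1,2], Y=[1,2] gives Z=[1,1,2,2], which
    // satisfies Z[k]+1 == Z[k+1] || Z[k] == Z[k+1] on every row.
    fmt.Println(interleave([]uint{1, 2}, []uint{1, 2})) // [1 1 2 2]
}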