Skip to content

Commit

Permalink
Merge pull request #328 from Consensys/130-add-test-generation-models
Browse files Browse the repository at this point in the history
feat: add test generation models #180
  • Loading branch information
DavePearce authored Oct 9, 2024
2 parents adbebee + 0a40e87 commit b687f2e
Show file tree
Hide file tree
Showing 23 changed files with 7,441 additions and 103 deletions.
234 changes: 234 additions & 0 deletions cmd/testgen/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
package main

import (
"fmt"
"os"
"path"
"strings"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/hir"
sc "github.com/consensys/go-corset/pkg/schema"
tr "github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/trace/json"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)

// main runs the root command, exiting with a failing status code if
// command execution reports an error.
func main() {
	if err := rootCmd.Execute(); err != nil {
		os.Exit(1)
	}
}

// init registers command-line flags on the root command prior to execution.
func init() {
	// NOTE(review): this flag is the cobra template's example flag and does
	// not appear to be read anywhere — confirm whether it is still needed.
	rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}

// rootCmd represents the base command when called without any subcommands
// rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{
	Use:   "testgen",
	Short: "Test generation utility for go-corset.",
	Run: func(cmd *cobra.Command, args []string) {
		// Exactly one argument (the model name) is required.
		if len(args) != 1 {
			fmt.Println(cmd.UsageString())
			os.Exit(1)
		}
		model := args[0]
		// Search registered models for one with a matching name.
		for _, m := range models {
			if m.Name != model {
				continue
			}
			// Read schema from its corresponding lisp file.
			schema := readSchemaFile(path.Join("testdata", fmt.Sprintf("%s.lisp", m.Name)))
			// Enumerate traces, splitting into accepted / rejected sets.
			valid, invalid := generateTestTraces(m, schema)
			// Write both sets out to disk.
			writeTestTraces(m, "accepts", schema, valid)
			writeTestTraces(m, "rejects", schema, invalid)
			os.Exit(0)
		}
		// No model matched the requested name.
		fmt.Printf("unknown model \"%s\"\n", model)
		os.Exit(1)
	},
}

// Model represents a hard-coded oracle for a given test.  The oracle is an
// independent implementation of the schema's constraints, against which
// enumerated traces can be classified as valid or invalid.
type Model struct {
	// Name of the model in question
	Name string
	// Predicate for determining which trace to accept
	Oracle func(sc.Schema, tr.Trace) bool
}

// models is the registry of available test models.  The command-line
// argument selects which of these to run, matched by name.
var models []Model = []Model{
	{"memory", memoryModel},
}

// generateTestTraces enumerates traces for the given schema and partitions
// them, via the model's oracle, into those which should be accepted and
// those which should be rejected (returned in that order).
//
// NOTE: This is really a temporary solution for now. It doesn't handle
// length multipliers. It doesn't allow for modules with different heights.
// It uses a fixed pool.
func generateTestTraces(model Model, schema sc.Schema) ([]tr.Trace, []tr.Trace) {
	// Fixed pool of field elements from which trace values are drawn.
	pool := []fr.Element{fr.NewElement(0), fr.NewElement(1), fr.NewElement(2)}
	enumerator := sc.NewTraceEnumerator(2, schema, pool)
	//
	valid := make([]tr.Trace, 0)
	invalid := make([]tr.Trace, 0)
	// Classify each enumerated trace according to the oracle's verdict.
	for enumerator.HasNext() {
		next := enumerator.Next()
		switch {
		case model.Oracle(schema, next):
			valid = append(valid, next)
		default:
			invalid = append(invalid, next)
		}
	}
	// Done
	return valid, invalid
}

// writeTestTraces writes the given traces to a file named after the model
// and extension (e.g. "testdata/memory.auto.accepts"), with one
// JSON-encoded trace per line.  Any I/O failure results in a panic.
func writeTestTraces(model Model, ext string, schema sc.Schema, traces []tr.Trace) {
	var sb strings.Builder
	// Construct filename
	filename := fmt.Sprintf("testdata/%s.auto.%s", model.Name, ext)
	// Generate lines.  The local is deliberately not named "json", as that
	// would shadow the imported json package.
	for _, trace := range traces {
		raw := traceToColumns(schema, trace)
		line := json.ToJsonString(raw)
		sb.WriteString(line)
		sb.WriteString("\n")
	}
	// Write the file
	if err := os.WriteFile(filename, []byte(sb.String()), 0644); err != nil {
		panic(err)
	}
	// Log what happened (logrus appends its own newline).
	log.Infof("Wrote %s", filename)
}

// Convert a trace into an array of raw columns, pairing each input column
// declared in the schema with its data from the given trace.
func traceToColumns(schema sc.Schema, trace tr.Trace) []tr.RawColumn {
	ncols := schema.InputColumns().Count()
	cols := make([]tr.RawColumn, ncols)
	i := 0
	// Convert each column
	for iter := schema.InputColumns(); iter.HasNext(); {
		sc_col := iter.Next()
		// Lookup the column data
		tr_col := findColumn(sc_col.Context().Module(), sc_col.Name(), schema, trace)
		// Determine module name
		mod := schema.Modules().Nth(sc_col.Context().Module())
		// Assign the raw column
		cols[i] = tr.RawColumn{Module: mod.Name(), Name: sc_col.Name(), Data: tr_col.Data()}
		//
		i++
	}
	//
	return cols
}

// readSchemaFile reads and parses a schema from the given (lisp) file.  On
// any failure (read or parse), the error is printed and the process exits
// with a failing status code.
func readSchemaFile(filename string) *hir.Schema {
	// Read schema file
	bytes, err := os.ReadFile(filename)
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	// Attempt to parse schema
	schema, err := hir.ParseSchemaString(string(bytes))
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	// Parsed successfully
	return schema
}

// findColumn looks up the trace column corresponding to the named schema
// column in the given module, panicking if no such column exists.
func findColumn(mod uint, col string, schema sc.Schema, trace tr.Trace) tr.Column {
	if cid, ok := sc.ColumnIndexOf(schema, mod, col); ok {
		return trace.Column(cid)
	}
	// Column not found --- this indicates an internal inconsistency.
	panic(fmt.Sprintf("unknown column \"%s\"", col))
}

// ============================================================================
// Models
// ============================================================================

// memoryModel is the hand-written oracle for the "memory" schema.  A trace
// is accepted iff every row satisfies the column type bounds, the PC
// "heartbeat" constraints, and reads return the value most recently written
// to the same address (zero if never written).
func memoryModel(schema sc.Schema, trace tr.Trace) bool {
	// Upper bounds for column type constraints (2^1, 2^8, 2^16, 2^32).
	TWO_1 := fr.NewElement(2)
	TWO_8 := fr.NewElement(256)
	TWO_16 := fr.NewElement(65536)
	TWO_32 := fr.NewElement(4294967296)
	//
	PC := findColumn(0, "PC", schema, trace).Data()
	RW := findColumn(0, "RW", schema, trace).Data()
	ADDR := findColumn(0, "ADDR", schema, trace).Data()
	VAL := findColumn(0, "VAL", schema, trace).Data()
	// Configure memory model.  Reading an address never written yields the
	// map's zero element, so uninitialised memory reads as zero.
	memory := make(map[fr.Element]fr.Element, 0)
	//
	for i := uint(0); i < PC.Len(); i++ {
		pc_i := PC.Get(i)
		rw_i := RW.Get(i)
		addr_i := ADDR.Get(i)
		val_i := VAL.Get(i)
		// Type constraints: PC < 2^16, RW < 2, ADDR < 2^32, VAL < 2^8.
		t_pc := pc_i.Cmp(&TWO_16) < 0
		t_rw := rw_i.Cmp(&TWO_1) < 0
		t_addr := addr_i.Cmp(&TWO_32) < 0
		t_val := val_i.Cmp(&TWO_8) < 0
		// Check type constraints
		if !(t_pc && t_rw && t_addr && t_val) {
			return false
		}
		// Heartbeat 1: PC must be zero on the first row.
		h1 := i != 0 || pc_i.IsZero()
		// Heartbeat 2: on later rows, PC either resets to zero or
		// increments the previous PC by exactly one.
		h2 := i == 0 || pc_i.IsZero() || isIncremented(PC.Get(i-1), pc_i)
		// Heartbeat 3: a zero PC on a non-initial row requires the previous
		// PC to equal it (i.e. also zero).  NOTE(review): this compares
		// fr.Element values with == while reads below use Cmp — this
		// assumes both elements share the same internal representation;
		// confirm this holds for all enumerated values.
		h3 := i == 0 || !pc_i.IsZero() || PC.Get(i-1) == pc_i
		// Heartbeat 4: rows where PC is zero must have RW, ADDR and VAL
		// all zero.
		h4 := !pc_i.IsZero() || (rw_i.IsZero() && addr_i.IsZero() && val_i.IsZero())
		// Check heartbeat constraints
		if !(h1 && h2 && h3 && h4) {
			return false
		}
		// Check reading / writing
		if rw_i.IsOne() {
			// Write
			memory[addr_i] = val_i
		} else {
			v := memory[addr_i]
			// Check read matches
			if v.Cmp(&val_i) != 0 {
				return false
			}
		}
	}
	// Success
	return true
}

// ============================================================================
// Helpers
// ============================================================================

// isIncremented checks whether a given element is the previous element plus
// one, i.e. whether after - before == 1.
func isIncremented(before fr.Element, after fr.Element) bool {
	// Compute the difference between the two elements.
	var diff fr.Element
	diff.Sub(&after, &before)
	// A valid increment leaves a difference of exactly one.
	return diff.IsOne()
}
15 changes: 14 additions & 1 deletion pkg/air/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type PermutationConstraint = *constraint.PermutationConstraint
// PropertyAssertion captures the notion of an arbitrary property which should
// hold for all acceptable traces. However, such a property is not enforced by
// the prover.
type PropertyAssertion = *schema.PropertyAssertion[constraint.ZeroTest[schema.Evaluable]]
type PropertyAssertion = *schema.PropertyAssertion[schema.Testable]

// Schema for AIR traces which is parameterised on a notion of computation as
// permissible in computed columns.
Expand Down Expand Up @@ -134,6 +134,11 @@ func (p *Schema) AddPermutationConstraint(targets []uint, sources []uint) {
p.constraints = append(p.constraints, constraint.NewPermutationConstraint(targets, sources))
}

// AddPropertyAssertion appends a new property assertion.
func (p *Schema) AddPropertyAssertion(handle string, context trace.Context, assertion schema.Testable) {
p.assertions = append(p.assertions, schema.NewPropertyAssertion(handle, context, assertion))
}

// AddVanishingConstraint appends a new vanishing constraint.
func (p *Schema) AddVanishingConstraint(handle string, context trace.Context, domain *int, expr Expr) {
if context.Module() >= uint(len(p.modules)) {
Expand Down Expand Up @@ -162,6 +167,14 @@ func (p *Schema) InputColumns() util.Iterator[schema.Column] {
func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() })
}

// Assertions returns an iterator over the property assertions of this
// schema. These are properties which should hold true for any valid trace
// (though, of course, may not hold true for an invalid trace).
func (p *Schema) Assertions() util.Iterator[schema.Constraint] {
properties := util.NewArrayIterator(p.assertions)
return util.NewCastIterator[PropertyAssertion, schema.Constraint](properties)
}

// Assignments returns an array over the assignments of this schema. That
// is, the subset of declarations whose trace values can be computed from
// the inputs.
Expand Down
48 changes: 26 additions & 22 deletions pkg/cmd/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,22 @@ var checkCmd = &cobra.Command{
cfg.parallelExpansion = !getFlag(cmd, "sequential")
cfg.batchSize = getUint(cmd, "batch")
cfg.ansiEscapes = getFlag(cmd, "ansi-escapes")
//
stats := util.NewPerfStats()
// TODO: support true ranges
cfg.padding.Left = cfg.padding.Right
if !cfg.hir && !cfg.mir && !cfg.air {
// If IR not specified default to running all.
cfg.hir, cfg.mir, cfg.air = true, true, true
}
//
stats := util.NewPerfStats()
// Parse constraints
hirSchema = readSchemaFile(args[1])
//
stats.Log("Reading constraints file")
// Parse trace file
columns := readTraceFile(args[0])
//
stats.Log("Reading trace files")
stats.Log("Reading trace file")
// Go!
if !checkTraceWithLowering(columns, hirSchema, cfg) {
os.Exit(1)
Expand Down Expand Up @@ -107,28 +113,17 @@ type checkConfig struct {
// Check a given trace is consistently accepted (or rejected) at the different
// IR levels.
func checkTraceWithLowering(cols []tr.RawColumn, schema *hir.Schema, cfg checkConfig) bool {
hir := cfg.hir
mir := cfg.mir
air := cfg.air

if !hir && !mir && !air {
// If IR not specified default to running all.
hir = true
mir = true
air = true
}
//
res := true
// Process individually
if hir {
if cfg.hir {
res = checkTrace("HIR", cols, schema, cfg)
}

if mir {
if cfg.mir {
res = checkTrace("MIR", cols, schema.LowerToMir(), cfg) && res
}

if air {
if cfg.air {
res = checkTrace("AIR", cols, schema.LowerToMir().LowerToAir(), cfg) && res
}

Expand Down Expand Up @@ -159,11 +154,16 @@ func checkTrace(ir string, cols []tr.RawColumn, schema sc.Schema, cfg checkConfi
// Check trace
stats.Log("Validating trace")
stats = util.NewPerfStats()
//
// Check constraints
if errs := sc.Accepts(cfg.batchSize, schema, trace); len(errs) > 0 {
reportFailures(ir, errs, trace, cfg)
return false
}
// Check assertions
if errs := sc.Asserts(cfg.batchSize, schema, trace); len(errs) > 0 {
reportFailures(ir, errs, trace, cfg)
return false
}

stats.Log("Checking constraints")
}
Expand Down Expand Up @@ -238,15 +238,19 @@ func reportFailures(ir string, failures []sc.Failure, trace tr.Trace, cfg checkC
// Print a human-readable report detailing the given failure
func reportFailure(failure sc.Failure, trace tr.Trace, cfg checkConfig) {
if f, ok := failure.(*constraint.VanishingFailure); ok {
reportVanishingFailure(f, trace, cfg)
cells := f.RequiredCells(trace)
reportConstraintFailure("constraint", f.Handle(), cells, trace, cfg)
} else if f, ok := failure.(*sc.AssertionFailure); ok {
cells := f.RequiredCells(trace)
reportConstraintFailure("assertion", f.Handle(), cells, trace, cfg)
}
}

// Print a human-readable report detailing the given failure with a vanishing constraint.
func reportVanishingFailure(failure *constraint.VanishingFailure, trace tr.Trace, cfg checkConfig) {
func reportConstraintFailure(kind string, handle string, cells *util.AnySortedSet[tr.CellRef],
trace tr.Trace, cfg checkConfig) {
var start uint = math.MaxUint
// Determine all (input) cells involved in evaluating the given constraint
cells := failure.RequiredCells(trace)
end := uint(0)
// Determine row bounds
for _, c := range cells.ToArray() {
Expand All @@ -271,7 +275,7 @@ func reportVanishingFailure(failure *constraint.VanishingFailure, trace tr.Trace
return cells.Contains(cell)
})
// Print out report
fmt.Printf("failing constraint %s:\n", failure.Handle())
fmt.Printf("failing %s %s:\n", kind, handle)
tp.Print(trace)
fmt.Println()
}
Expand Down
19 changes: 0 additions & 19 deletions pkg/cmd/compute.go

This file was deleted.

Loading

0 comments on commit b687f2e

Please sign in to comment.