Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
This puts through a number of minor refactorings which are now possible,
and removes some code which is no longer used.
  • Loading branch information
DavePearce committed Jun 26, 2024
1 parent 9f13cdc commit 5588872
Show file tree
Hide file tree
Showing 22 changed files with 81 additions and 123 deletions.
2 changes: 1 addition & 1 deletion pkg/air/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
// value at that row of the column in question or nil is that row is
// out-of-bounds.
func (e *ColumnAccess) EvalAt(k int, tr trace.Trace) *fr.Element {
val := tr.ColumnByIndex(e.Column).Get(k + e.Shift)
val := tr.Column(e.Column).Get(k + e.Shift)

var clone fr.Element
// Clone original value
Expand Down
6 changes: 3 additions & 3 deletions pkg/cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func filterColumns(tr trace.Trace, prefix string) trace.Trace {
ntr := trace.EmptyArrayTrace()
//
for i := uint(0); i < tr.Width(); i++ {
ith := tr.ColumnByIndex(i)
ith := tr.Column(i)
if strings.HasPrefix(ith.Name(), prefix) {
ntr.Add(ith)
}
Expand All @@ -73,7 +73,7 @@ func listColumns(tr trace.Trace) {
tbl := util.NewTablePrinter(3, tr.Width())

for i := uint(0); i < tr.Width(); i++ {
ith := tr.ColumnByIndex(i)
ith := tr.Column(i)
elems := fmt.Sprintf("%d rows", ith.Height())
bytes := fmt.Sprintf("%d bytes", ith.Width()*ith.Height())
tbl.SetRow(i, ith.Name(), elems, bytes)
Expand All @@ -93,7 +93,7 @@ func printTrace(start uint, end uint, max_width uint, tr trace.Trace) {
}

for i := uint(0); i < tr.Width(); i++ {
ith := tr.ColumnByIndex(i)
ith := tr.Column(i)
tbl.Set(0, i+1, ith.Name())

if start < ith.Height() {
Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func readTraceFile(filename string) trace.Trace {
//
switch ext {
case ".json":
tr, err = trace.ParseJsonTrace(bytes)
tr, err = json.FromBytes(bytes)
if err == nil {
return tr
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/hir/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
// value at that row of the column in question or nil is that row is
// out-of-bounds.
func (e *ColumnAccess) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
val := tr.ColumnByIndex(e.Column).Get(k + e.Shift)
val := tr.Column(e.Column).Get(k + e.Shift)

var clone fr.Element
// Clone original value
Expand Down
2 changes: 1 addition & 1 deletion pkg/mir/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
// value at that row of the column in question or nil is that row is
// out-of-bounds.
func (e *ColumnAccess) EvalAt(k int, tr trace.Trace) *fr.Element {
val := tr.ColumnByIndex(e.Column).Get(k + e.Shift)
val := tr.Column(e.Column).Get(k + e.Shift)

var clone fr.Element
// Clone original value
Expand Down
4 changes: 2 additions & 2 deletions pkg/schema/alignment.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func alignWith(expand bool, p tr.Trace, schema Schema) error {
return fmt.Errorf("trace missing column %s", schemaName)
}

traceName := p.ColumnByIndex(index).Name()
traceName := p.Column(index).Name()
// Check alignment
if traceName != schemaName {
// Not aligned --- so fix
Expand All @@ -74,7 +74,7 @@ func alignWith(expand bool, p tr.Trace, schema Schema) error {
unknowns := make([]string, n)
// Determine names of unknown columns.
for i := index; i < ncols; i++ {
unknowns[i-index] = p.ColumnByIndex(i).Name()
unknowns[i-index] = p.Column(i).Name()
}
//
return fmt.Errorf("trace contains unknown columns: %v", unknowns)
Expand Down
10 changes: 5 additions & 5 deletions pkg/schema/assignment/byte_decomposition.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/schema"
tr "github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

Expand Down Expand Up @@ -62,13 +62,13 @@ func (p *ByteDecomposition) IsComputed() bool {
// ExpandTrace expands a given trace to include the columns specified by a given
// ByteDecomposition. This requires computing the value of each byte column in
// the decomposition.
func (p *ByteDecomposition) ExpandTrace(tr tr.Trace) error {
func (p *ByteDecomposition) ExpandTrace(tr trace.Trace) error {
// Calculate how many bytes required.
n := len(p.targets)
// Identify target column
target := tr.ColumnByIndex(p.source)
target := tr.Column(p.source)
// Extract column data to decompose
data := tr.ColumnByIndex(p.source).Data()
data := tr.Column(p.source).Data()
// Construct byte column data
cols := make([][]*fr.Element, n)
// Initialise columns
Expand All @@ -86,7 +86,7 @@ func (p *ByteDecomposition) ExpandTrace(tr tr.Trace) error {
padding := decomposeIntoBytes(target.Padding(), n)
// Finally, add byte columns to trace
for i := 0; i < n; i++ {
tr.AddColumn(p.targets[i].Name(), cols[i], padding[i])
tr.Add(trace.NewFieldColumn(p.targets[i].Name(), cols[i], padding[i]))
}
// Done
return nil
Expand Down
6 changes: 3 additions & 3 deletions pkg/schema/assignment/computed.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/schema"
sc "github.com/consensys/go-corset/pkg/schema"
tr "github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

Expand Down Expand Up @@ -75,7 +75,7 @@ func (p *ComputedColumn[E]) RequiredSpillage() uint {
// ExpandTrace attempts to a new column to the trace which contains the result
// of evaluating a given expression on each row. If the column already exists,
// then an error is flagged.
func (p *ComputedColumn[E]) ExpandTrace(tr tr.Trace) error {
func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error {
if tr.HasColumn(p.name) {
return fmt.Errorf("Computed column already exists ({%s})", p.name)
}
Expand All @@ -96,7 +96,7 @@ func (p *ComputedColumn[E]) ExpandTrace(tr tr.Trace) error {
// the padding value for *this* column.
padding := p.expr.EvalAt(-1, tr)
// Colunm needs to be expanded.
tr.AddColumn(p.name, data, padding)
tr.Add(trace.NewFieldColumn(p.name, data, padding))
// Done
return nil
}
8 changes: 4 additions & 4 deletions pkg/schema/assignment/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error {
delta[i] = &zero
// Decide which row is the winner (if any)
for j := 0; j < ncols; j++ {
prev := tr.ColumnByIndex(p.sources[j]).Get(i - 1)
curr := tr.ColumnByIndex(p.sources[j]).Get(i)
prev := tr.Column(p.sources[j]).Get(i - 1)
curr := tr.Column(p.sources[j]).Get(i)

if !set && prev != nil && prev.Cmp(curr) != 0 {
var diff fr.Element
Expand All @@ -108,11 +108,11 @@ func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error {
}
}
// Add delta column data
tr.AddColumn(p.targets[0].Name(), delta, &zero)
tr.Add(trace.NewFieldColumn(p.targets[0].Name(), delta, &zero))
// Add bit column data
for i := 0; i < ncols; i++ {
bitName := p.targets[1+i].Name()
tr.AddColumn(bitName, bit[i], &zero)
tr.Add(trace.NewFieldColumn(bitName, bit[i], &zero))
}
// Done.
return nil
Expand Down
7 changes: 4 additions & 3 deletions pkg/schema/assignment/sorted_permutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/trace"
tr "github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)
Expand Down Expand Up @@ -119,7 +120,7 @@ func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error {
for i := 0; i < len(p.sources); i++ {
src := p.sources[i]
// Read column data to initialise permutation.
data := tr.ColumnByIndex(src).Data()
data := tr.Column(src).Data()
// Copy column data to initialise permutation.
cols[i] = make([]*fr.Element, len(data))
copy(cols[i], data)
Expand All @@ -131,8 +132,8 @@ func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error {

for i := p.Columns(); i.HasNext(); {
dstColName := i.Next().Name()
srcCol := tr.ColumnByIndex(p.sources[index])
tr.AddColumn(dstColName, cols[index], srcCol.Padding())
srcCol := tr.Column(p.sources[index])
tr.Add(trace.NewFieldColumn(dstColName, cols[index], srcCol.Padding()))

index++
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/schema/constraint/permutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func sliceColumns(columns []uint, tr trace.Trace) [][]*fr.Element {
cols := make([][]*fr.Element, len(columns))
// Slice out the data
for i, n := range columns {
nth := tr.ColumnByIndex(n)
nth := tr.Column(n)
cols[i] = nth.Data()
}
// Done
Expand Down
2 changes: 1 addition & 1 deletion pkg/schema/constraint/range.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewRangeConstraint(column uint, bound *fr.Element) *RangeConstraint {
// Accepts checks whether a range constraint evaluates to zero on
// every row of a table. If so, return nil otherwise return an error.
func (p *RangeConstraint) Accepts(tr trace.Trace) error {
column := tr.ColumnByIndex(p.column)
column := tr.Column(p.column)
for k := 0; k < int(tr.Height()); k++ {
// Get the value on the kth row
kth := column.Get(k)
Expand Down
2 changes: 1 addition & 1 deletion pkg/schema/constraint/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (p *TypeConstraint) Type() schema.Type {
// Accepts checks whether a range constraint evaluates to zero on
// every row of a table. If so, return nil otherwise return an error.
func (p *TypeConstraint) Accepts(tr trace.Trace) error {
column := tr.ColumnByIndex(p.column)
column := tr.Column(p.column)
for k := 0; k < int(tr.Height()); k++ {
// Get the value on the kth row
kth := column.Get(k)
Expand Down
25 changes: 0 additions & 25 deletions pkg/schema/schemas.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package schema

import (
"fmt"

tr "github.com/consensys/go-corset/pkg/trace"
)

Expand Down Expand Up @@ -69,26 +67,3 @@ func ColumnIndexOf(schema Schema, name string) (uint, bool) {
return c.Name() == name
})
}

// ColumnByName returns the column with the matching name, or panics if no such
// column exists.
func ColumnByName(schema Schema, name string) Column {
var col Column
// Attempt to determine the index of this column
_, ok := schema.Columns().Find(func(c Column) bool {
col = c
return c.Name() == name
})
// If we found it, then done.
if ok {
return col
}
// Otherwise panic.
panic(fmt.Sprintf("unknown column %s", name))
}

// HasColumn checks whether a column of the given name is declared within the schema.
func HasColumn(schema Schema, name string) bool {
_, ok := ColumnIndexOf(schema, name)
return ok
}
3 changes: 2 additions & 1 deletion pkg/test/ir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/consensys/go-corset/pkg/schema"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/trace/json"
)

// Determines the (relative) location of the test directory. That is
Expand Down Expand Up @@ -469,7 +470,7 @@ func ReadTracesFile(name string, ext string) []*trace.ArrayTrace {
for i, line := range lines {
// Parse input line as JSON
if line != "" && !strings.HasPrefix(line, ";;") {
tr, err := trace.ParseJsonTrace([]byte(line))
tr, err := json.FromBytes([]byte(line))
if err != nil {
msg := fmt.Sprintf("%s.%s:%d: %s", name, ext, i+1, err)
panic(msg)
Expand Down
30 changes: 3 additions & 27 deletions pkg/trace/array_trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,24 +67,11 @@ func (p *ArrayTrace) Columns() []Column {
return p.columns
}

// ColumnByIndex looks up a column based on its index.
func (p *ArrayTrace) ColumnByIndex(index uint) Column {
// Column looks up a column based on its index.
func (p *ArrayTrace) Column(index uint) Column {
return p.columns[index]
}

// ColumnByName looks up a column based on its name. If the column doesn't
// exist, then nil is returned.
func (p *ArrayTrace) ColumnByName(name string) Column {
for _, c := range p.columns {
if name == c.Name() {
// Matched column
return c
}
}

return nil
}

// HasColumn checks whether the trace has a given named column (or not).
func (p *ArrayTrace) HasColumn(name string) bool {
_, ok := p.ColumnIndex(name)
Expand Down Expand Up @@ -120,18 +107,7 @@ func (p *ArrayTrace) Add(column Column) {

// AddColumn adds a new column of data to this trace.
func (p *ArrayTrace) AddColumn(name string, data []*fr.Element, padding *fr.Element) {
// Sanity check the column does not already exist.
if p.HasColumn(name) {
panic("column already exists")
}
// Construct new column
column := FieldColumn{name, data, padding}
// Append it
p.columns = append(p.columns, &column)
// Update maximum height
if uint(len(data)) > p.height {
p.height = uint(len(data))
}
p.Add(&FieldColumn{name, data, padding})
}

// Height determines the maximum height of any column within this trace.
Expand Down
5 changes: 5 additions & 0 deletions pkg/trace/field_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ type FieldColumn struct {
padding *fr.Element
}

// NewFieldColumn constructs a FieldColumn with the give name, data and padding.
func NewFieldColumn(name string, data []*fr.Element, padding *fr.Element) *FieldColumn {
return &FieldColumn{name, data, padding}
}

// Name returns the name of the given column.
func (p *FieldColumn) Name() string {
return p.name
Expand Down
35 changes: 35 additions & 0 deletions pkg/trace/json/reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package json

import (
"encoding/json"
"math/big"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

// FromBytes parses a trace expressed in JSON notation. For example, {"X":
// [0], "Y": [1]} is a trace containing one row of data each for two columns "X"
// and "Y".
func FromBytes(bytes []byte) (*trace.ArrayTrace, error) {
var zero fr.Element = fr.NewElement((0))

var rawData map[string][]*big.Int
// Unmarshall
jsonErr := json.Unmarshal(bytes, &rawData)
if jsonErr != nil {
return nil, jsonErr
}

trace := trace.EmptyArrayTrace()

for name, rawInts := range rawData {
// Translate raw bigints into raw field elements
rawElements := util.ToFieldElements(rawInts)
// Add new column to the trace
trace.AddColumn(name, rawElements, &zero)
}
// Done.
return trace, nil
}
2 changes: 1 addition & 1 deletion pkg/trace/json/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func ToJsonString(tr trace.Trace) string {
builder.WriteString(", ")
}

ith := tr.ColumnByIndex(i)
ith := tr.Column(i)

builder.WriteString("\"")
builder.WriteString(ith.Name())
Expand Down
4 changes: 2 additions & 2 deletions pkg/trace/lt/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func WriteBytes(tr trace.Trace, buf io.Writer) error {
}
// Write header information
for i := uint(0); i < ncols; i++ {
col := tr.ColumnByIndex(i)
col := tr.Column(i)
// Write name length
nameBytes := []byte(col.Name())
nameLen := uint16(len(nameBytes))
Expand All @@ -62,7 +62,7 @@ func WriteBytes(tr trace.Trace, buf io.Writer) error {
}
// Write column data information
for i := uint(0); i < ncols; i++ {
col := tr.ColumnByIndex(i)
col := tr.Column(i)
if err := col.Write(buf); err != nil {
return err
}
Expand Down
Loading

0 comments on commit 5588872

Please sign in to comment.