Skip to content

Commit

Permalink
Implement Exp expression at HIR level
Browse files Browse the repository at this point in the history
This puts through an exponentiation operator at the HIR and MIR levels.
This updates the parser accordingly, and also the `bin` file reader.  At
the AIR level, however, there is no change.  Instead, `Exp` nodes are
reduced to fixed-width multiplications.  Finally, an efficient algorithm
for computing exponents is included, along with tests against gnark.
  • Loading branch information
DavePearce committed Jul 5, 2024
1 parent 32ece2d commit f034b5d
Show file tree
Hide file tree
Showing 16 changed files with 279 additions and 1 deletion.
18 changes: 18 additions & 0 deletions pkg/binfile/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,24 @@ func (e *jsonExprFuncall) ToHir(schema *hir.Schema) hir.Expr {
return &hir.Mul{Args: args}
case "VectorSub", "Sub":
return &hir.Sub{Args: args}
case "Exp":
if len(args) != 2 {
panic(fmt.Sprintf("incorrect number of arguments for Exp (%d)", len(args)))
}

c, ok := args[1].(*hir.Constant)

if !ok {
panic(fmt.Sprintf("constant power expected for Exp, got %s", args[1].String()))
} else if !c.Val.IsUint64() {
panic("constant power too large for Exp")
}

var k big.Int
// Convert power to uint64
c.Val.BigInt(&k)
// Done
return &hir.Exp{Arg: args[0], Pow: k.Uint64()}
case "IfZero":
if len(args) == 2 {
return &hir.IfZero{Condition: args[0], TrueBranch: args[1], FalseBranch: nil}
Expand Down
12 changes: 12 additions & 0 deletions pkg/hir/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package hir
import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

// EvalAllAt evaluates a column access at a given row in a trace, which returns the
Expand Down Expand Up @@ -38,6 +39,17 @@ func (e *Mul) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
return evalExprsAt(k, tr, e.Args, fn)
}

// EvalAllAt evaluates a product at a given row in a trace by first evaluating all of
// its arguments at that row.
func (e *Exp) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
vals := e.Arg.EvalAllAt(k, tr)
for _, v := range vals {
util.Pow(v, e.Pow)
}

return vals
}

// EvalAllAt evaluates a conditional at a given row in a trace by first evaluating
// its condition at that row. If that condition is zero then the true branch
// (if applicable) is evaluated; otherwise if the condition is non-zero then
Expand Down
20 changes: 20 additions & 0 deletions pkg/hir/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,26 @@ func (p *Mul) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// ============================================================================
// Exponentiation
// ============================================================================

// Exp represents the a given value taken to a power.
type Exp struct {
Arg Expr
Pow uint64
}

// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *Exp) Bounds() util.Bounds { return p.Arg.Bounds() }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Exp) Context(schema sc.Schema) trace.Context {
return p.Arg.Context(schema)
}

// ============================================================================
// List
// ============================================================================
Expand Down
18 changes: 18 additions & 0 deletions pkg/hir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ func (e *ColumnAccess) LowerTo(schema *mir.Schema) []mir.Expr {
return lowerTo(e, schema)
}

// LowerTo lowers an exponent expression to the MIR level. This requires expanding
// the argument andn lowering it. Furthermore, conditionals are "lifted" to
// the top.
func (e *Exp) LowerTo(schema *mir.Schema) []mir.Expr {
return lowerTo(e, schema)
}

// LowerTo lowers a product expression to the MIR level. This requires expanding
// the arguments, then lowering them. Furthermore, conditionals are "lifted" to
// the top.
Expand Down Expand Up @@ -182,6 +189,8 @@ func lowerCondition(e Expr, schema *mir.Schema) mir.Expr {
return lowerConditions(p.Args, schema)
} else if p, ok := e.(*Normalise); ok {
return lowerCondition(p.Arg, schema)
} else if p, ok := e.(*Exp); ok {
return lowerCondition(p.Arg, schema)
} else if p, ok := e.(*IfZero); ok {
return lowerIfZeroCondition(p, schema)
} else if p, ok := e.(*Sub); ok {
Expand Down Expand Up @@ -248,6 +257,8 @@ func lowerBody(e Expr, schema *mir.Schema) mir.Expr {
return &mir.ColumnAccess{Column: p.Column, Shift: p.Shift}
} else if p, ok := e.(*Mul); ok {
return &mir.Mul{Args: lowerBodies(p.Args, schema)}
} else if p, ok := e.(*Exp); ok {
return &mir.Exp{Arg: lowerBody(p.Arg, schema), Pow: p.Pow}
} else if p, ok := e.(*Normalise); ok {
return &mir.Normalise{Arg: lowerBody(p.Arg, schema)}
} else if p, ok := e.(*IfZero); ok {
Expand Down Expand Up @@ -306,6 +317,13 @@ func expand(e Expr) []Expr {
ees = append(ees, expand(arg)...)
}

return ees
} else if p, ok := e.(*Exp); ok {
ees := expand(p.Arg)
for i, ee := range ees {
ees[i] = &Exp{ee, p.Pow}
}

return ees
} else if p, ok := e.(*Normalise); ok {
ees := expand(p.Arg)
Expand Down
21 changes: 21 additions & 0 deletions pkg/hir/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package hir
import (
"errors"
"fmt"
"math/big"
"strconv"
"strings"
"unicode"
Expand Down Expand Up @@ -77,6 +78,7 @@ func newHirParser(srcmap *sexp.SourceMap[sexp.SExp]) *hirParser {
p.AddRecursiveRule("-", subParserRule)
p.AddRecursiveRule("*", mulParserRule)
p.AddRecursiveRule("~", normParserRule)
p.AddRecursiveRule("^", powParserRule)
p.AddRecursiveRule("if", ifParserRule)
p.AddRecursiveRule("ifnot", ifNotParserRule)
p.AddRecursiveRule("begin", beginParserRule)
Expand Down Expand Up @@ -543,6 +545,25 @@ func shiftParserRule(parser *hirParser) func(string, string) (Expr, error) {
}
}

func powParserRule(args []Expr) (Expr, error) {
var k big.Int

if len(args) != 2 {
return nil, errors.New("incorrect number of arguments")
}

c, ok := args[1].(*Constant)
if !ok {
return nil, errors.New("expected constant power")
} else if !c.Val.IsUint64() {
return nil, errors.New("constant power too large")
}
// Convert power to uint64
c.Val.BigInt(&k)
// Done
return &Exp{Arg: args[0], Pow: k.Uint64()}, nil
}

func normParserRule(args []Expr) (Expr, error) {
if len(args) != 1 {
return nil, errors.New("incorrect number of arguments")
Expand Down
4 changes: 4 additions & 0 deletions pkg/hir/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ func (e *Mul) String() string {
return naryString("*", e.Args)
}

func (e *Exp) String() string {
return fmt.Sprintf("(^ %s %d)", e.Arg, e.Pow)
}

func (e *Normalise) String() string {
return fmt.Sprintf("(~ %s)", e.Arg)
}
Expand Down
12 changes: 12 additions & 0 deletions pkg/mir/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mir
import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

// EvalAt evaluates a column access at a given row in a trace, which returns the
Expand Down Expand Up @@ -38,6 +39,17 @@ func (e *Mul) EvalAt(k int, tr trace.Trace) *fr.Element {
return evalExprsAt(k, tr, e.Args, fn)
}

// EvalAt evaluates a product at a given row in a trace by first evaluating all of
// its arguments at that row.
func (e *Exp) EvalAt(k int, tr trace.Trace) *fr.Element {
// Check whether argument evaluates to zero or not.
val := e.Arg.EvalAt(k, tr)
// Compute exponent
util.Pow(val, e.Pow)
// Done
return val
}

// EvalAt evaluates the normalisation of some expression by first evaluating
// that expression. Then, zero is returned if the result is zero; otherwise one
// is returned.
Expand Down
20 changes: 20 additions & 0 deletions pkg/mir/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,26 @@ func (p *Mul) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// ============================================================================
// Exponentiation
// ============================================================================

// Exp represents the a given value taken to a power.
type Exp struct {
Arg Expr
Pow uint64
}

// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *Exp) Bounds() util.Bounds { return p.Arg.Bounds() }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Exp) Context(schema sc.Schema) trace.Context {
return p.Arg.Context(schema)
}

// ============================================================================
// Constant
// ============================================================================
Expand Down
16 changes: 16 additions & 0 deletions pkg/mir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ func (e *Mul) LowerTo(schema *air.Schema) air.Expr {
return &air.Mul{Args: lowerExprs(e.Args, schema)}
}

// LowerTo lowers an exponent expression to the AIR level by lowering the
// argument, and then constructing a multiplication. This is because the AIR
// level does not support an explicit exponent operator.
func (e *Exp) LowerTo(schema *air.Schema) air.Expr {
// Lower the expression being raised
le := e.Arg.LowerTo(schema)
// Multiply it out k times
es := make([]air.Expr, e.Pow)
//
for i := uint64(0); i < e.Pow; i++ {
es[i] = le
}
// Done
return &air.Mul{Args: es}
}

// LowerTo lowers a normalise expression to the AIR level by "compiling it out"
// using a computed column.
func (p *Normalise) LowerTo(schema *air.Schema) air.Expr {
Expand Down
4 changes: 4 additions & 0 deletions pkg/mir/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ func (e *Normalise) String() string {
return fmt.Sprintf("(~ %s)", e.Arg)
}

func (e *Exp) String() string {
return fmt.Sprintf("(^ %s %d)", e.Arg, e.Pow)
}

func naryString(operator string, exprs []Expr) string {
// This should be generalised and moved into common?
var rs string
Expand Down
8 changes: 7 additions & 1 deletion pkg/test/ir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ func Test_Basic_09(t *testing.T) {
Check(t, "basic_09")
}

func Test_Basic_10(t *testing.T) {
Check(t, "basic_10")
}

// ===================================================================
// Domain Tests
// ===================================================================
Expand Down Expand Up @@ -420,12 +424,14 @@ func TestSlow_Mxp(t *testing.T) {

// Determines the maximum amount of padding to use when testing. Specifically,
// every trace is tested with varying amounts of padding upto this value.
const MAX_PADDING uint = 1
const MAX_PADDING uint = 5

// For a given set of constraints, check that all traces which we
// expect to be accepted are accepted, and all traces that we expect
// to be rejected are rejected.
func Check(t *testing.T, test string) {
// Enable testing each trace in parallel
t.Parallel()
// Read constraints file
bytes, err := os.ReadFile(fmt.Sprintf("%s/%s.lisp", TestDir, test))
// Check test file read ok
Expand Down
71 changes: 71 additions & 0 deletions pkg/test/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package test

import (
"math/big"
"testing"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/util"
)

const POW_BASE_MAX uint = 65536
const POW_BASE_INC uint = 8

func Test_Pow_01(t *testing.T) {
PowCheckLoop(t, 0)
}

func Test_Pow_02(t *testing.T) {
PowCheckLoop(t, 1)
}

func Test_Pow_03(t *testing.T) {
PowCheckLoop(t, 2)
}

func Test_Pow_04(t *testing.T) {
PowCheckLoop(t, 3)
}

func Test_Pow_05(t *testing.T) {
PowCheckLoop(t, 4)
}

func Test_Pow_06(t *testing.T) {
PowCheckLoop(t, 5)
}

func Test_Pow_07(t *testing.T) {
PowCheckLoop(t, 6)
}

func Test_Pow_08(t *testing.T) {
PowCheckLoop(t, 7)
}

func PowCheckLoop(t *testing.T, first uint) {
// Enable parallel testing
t.Parallel()
// Run through the loop
for i := first; i < POW_BASE_MAX; i += POW_BASE_INC {
for j := uint64(0); j < 256; j++ {
PowCheck(t, i, j)
}
}
}

// Check pow computed correctly. This is done by comparing against the existing
// gnark function.
func PowCheck(t *testing.T, base uint, pow uint64) {
k := big.NewInt(int64(pow))
v1 := fr.NewElement(uint64(base))
v2 := fr.NewElement(uint64(base))
// V1 computed using our optimised method
util.Pow(&v1, pow)
// V2 computed using existing gnark function
v2.Exp(v2, k)
// Final sanity check
if v1.Cmp(&v2) != 0 {
t.Errorf("Pow(%d,%d)=%s (not %s)", base, pow, v1.String(), v2.String())
}
}
22 changes: 22 additions & 0 deletions pkg/util/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,25 @@ func ToFieldElements(ints []*big.Int) []*fr.Element {
// Done.
return elements
}

// Pow takes a given value to the power n.
func Pow(val *fr.Element, n uint64) {
if n == 0 {
val.SetOne()
} else if n > 1 {
m := n / 2
// Check for odd case
if n%2 == 1 {
var tmp fr.Element
// Clone value
tmp.Set(val)
Pow(val, m)
val.Square(val)
val.Mul(val, &tmp)
} else {
// Even case is easy
Pow(val, m)
val.Square(val)
}
}
}
Loading

0 comments on commit f034b5d

Please sign in to comment.