Skip to content

Commit

Permalink
refactor: assertion parsing (kyverno#514)
Browse files Browse the repository at this point in the history
Signed-off-by: Charles-Edouard Brétéché <[email protected]>
  • Loading branch information
eddycharly authored Sep 23, 2024
1 parent 574b26e commit ae40afb
Showing 1 changed file with 117 additions and 142 deletions.
259 changes: 117 additions & 142 deletions pkg/core/assertion/assertion.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
package assertion

import (
"errors"
"fmt"
"reflect"
"sync"

"github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/kyverno/kyverno-json/pkg/core/compilers"
"github.com/kyverno/kyverno-json/pkg/core/expression"
"github.com/kyverno/kyverno-json/pkg/core/matching"
"github.com/kyverno/kyverno-json/pkg/core/projection"
reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect"
Expand All @@ -19,7 +16,7 @@ type Assertion interface {
Assert(*field.Path, any, binding.Bindings) (field.ErrorList, error)
}

func Parse(assertion any, compiler compilers.Compilers) (node, error) {
func Parse(assertion any, compiler compilers.Compilers) (Assertion, error) {
switch reflectutils.GetKind(assertion) {
case reflect.Slice:
return parseSlice(assertion, compiler)
Expand All @@ -30,18 +27,36 @@ func Parse(assertion any, compiler compilers.Compilers) (node, error) {
}
}

// node implements the Assertion interface using a delegate func
type node func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error)
// sliceNode is the assertion represented by a slice.
// it first compares the length of the analysed resource with the length of the descendants.
// if lengths match all descendants are evaluated with their corresponding items.
type sliceNode []Assertion

func (n node) Assert(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
return n(path, value, bindings)
func (node sliceNode) Assert(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
var errs field.ErrorList
if value == nil {
errs = append(errs, field.Invalid(path, value, "value is null"))
} else if reflectutils.GetKind(value) != reflect.Slice {
return nil, field.TypeInvalid(path, value, "expected a slice")
} else {
valueOf := reflect.ValueOf(value)
if valueOf.Len() != len(node) {
errs = append(errs, field.Invalid(path, value, "lengths of slices don't match"))
} else {
for i := range node {
if _errs, err := node[i].Assert(path.Index(i), valueOf.Index(i).Interface(), bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
}
}
}
}
return errs, nil
}

// parseSlice is the assertion represented by a slice.
// it first compares the length of the analysed resource with the length of the descendants.
// if lengths match all descendants are evaluated with their corresponding items.
func parseSlice(assertion any, compiler compilers.Compilers) (node, error) {
var assertions []node
func parseSlice(assertion any, compiler compilers.Compilers) (sliceNode, error) {
var assertions sliceNode
valueOf := reflect.ValueOf(assertion)
for i := 0; i < valueOf.Len(); i++ {
sub, err := Parse(valueOf.Index(i).Interface(), compiler)
Expand All @@ -50,37 +65,81 @@ func parseSlice(assertion any, compiler compilers.Compilers) (node, error) {
}
assertions = append(assertions, sub)
}
return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
var errs field.ErrorList
return assertions, nil
}

// mapNode is the assertion represented by a map.
// it is responsible for projecting the analysed resource and passing the result to the descendant
type mapNode map[any]struct {
projection.Projection
Assertion
}

func (node mapNode) Assert(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
var errs field.ErrorList
// if we assert against an empty object, value is expected to be not nil
if len(node) == 0 {
if value == nil {
errs = append(errs, field.Invalid(path, value, "value is null"))
} else if reflectutils.GetKind(value) != reflect.Slice {
return nil, field.TypeInvalid(path, value, "expected a slice")
errs = append(errs, field.Invalid(path, value, "invalid value, must not be null"))
}
return errs, nil
}
for k, v := range node {
projected, found, err := v.Projection.Handler(value, bindings)
if err != nil {
return nil, field.InternalError(path.Child(fmt.Sprint(k)), err)
} else if !found {
errs = append(errs, field.Required(path.Child(fmt.Sprint(k)), "field not found in the input object"))
} else {
valueOf := reflect.ValueOf(value)
if valueOf.Len() != len(assertions) {
errs = append(errs, field.Invalid(path, value, "lengths of slices don't match"))
} else {
for i := range assertions {
if _errs, err := assertions[i].Assert(path.Index(i), valueOf.Index(i).Interface(), bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
if v.Projection.Binding != "" {
bindings = bindings.Register("$"+v.Projection.Binding, binding.NewBinding(projected))
}
if v.Projection.Foreach {
projectedKind := reflectutils.GetKind(projected)
if projectedKind == reflect.Slice {
valueOf := reflect.ValueOf(projected)
for i := 0; i < valueOf.Len(); i++ {
bindings := bindings
if v.Projection.ForeachName != "" {
bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(i))
}
if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Index(i), valueOf.Index(i).Interface(), bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
}
}
} else if projectedKind == reflect.Map {
iter := reflect.ValueOf(projected).MapRange()
for iter.Next() {
key := iter.Key().Interface()
bindings := bindings
if v.Projection.ForeachName != "" {
bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(key))
}
if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Key(fmt.Sprint(key)), iter.Value().Interface(), bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
}
}
} else {
return nil, field.TypeInvalid(path.Child(fmt.Sprint(k)), projected, "expected a slice or a map")
}
} else {
if _errs, err := v.Assert(path.Child(fmt.Sprint(k)), projected, bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
}
}
}
return errs, nil
}, nil
}
return errs, nil
}

// parseMap is the assertion represented by a map.
// it is responsible for projecting the analysed resource and passing the result to the descendant
func parseMap(assertion any, compiler compilers.Compilers) (node, error) {
assertions := map[any]struct {
projection.Projection
node
}{}
func parseMap(assertion any, compiler compilers.Compilers) (mapNode, error) {
assertions := mapNode{}
iter := reflect.ValueOf(assertion).MapRange()
for iter.Next() {
key := iter.Key().Interface()
Expand All @@ -90,120 +149,36 @@ func parseMap(assertion any, compiler compilers.Compilers) (node, error) {
return nil, err
}
entry := assertions[key]
entry.node = assertion
entry.Assertion = assertion
entry.Projection = projection.ParseMapKey(key, compiler)
assertions[key] = entry
}
return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
var errs field.ErrorList
// if we assert against an empty object, value is expected to be not nil
if len(assertions) == 0 {
if value == nil {
errs = append(errs, field.Invalid(path, value, "invalid value, must not be null"))
}
return errs, nil
}
for k, v := range assertions {
projected, found, err := v.Projection.Handler(value, bindings)
if err != nil {
return nil, field.InternalError(path.Child(fmt.Sprint(k)), err)
} else if !found {
errs = append(errs, field.Required(path.Child(fmt.Sprint(k)), "field not found in the input object"))
} else {
if v.Projection.Binding != "" {
bindings = bindings.Register("$"+v.Projection.Binding, binding.NewBinding(projected))
}
if v.Projection.Foreach {
projectedKind := reflectutils.GetKind(projected)
if projectedKind == reflect.Slice {
valueOf := reflect.ValueOf(projected)
for i := 0; i < valueOf.Len(); i++ {
bindings := bindings
if v.Projection.ForeachName != "" {
bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(i))
}
if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Index(i), valueOf.Index(i).Interface(), bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
}
}
} else if projectedKind == reflect.Map {
iter := reflect.ValueOf(projected).MapRange()
for iter.Next() {
key := iter.Key().Interface()
bindings := bindings
if v.Projection.ForeachName != "" {
bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(key))
}
if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Key(fmt.Sprint(key)), iter.Value().Interface(), bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
}
}
} else {
return nil, field.TypeInvalid(path.Child(fmt.Sprint(k)), projected, "expected a slice or a map")
}
} else {
if _errs, err := v.Assert(path.Child(fmt.Sprint(k)), projected, bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
}
}
}
}
return errs, nil
}, nil
return assertions, nil
}

// parseScalar is the assertion represented by a leaf.
// scalarNode is the assertion represented by a leaf.
// it receives a value and compares it with an expected value.
// the expected value can be the result of an expression.
func parseScalar(assertion any, compiler compilers.Compilers) (node, error) {
var project func(value any, bindings binding.Bindings) (any, error)
switch typed := assertion.(type) {
case string:
expr := expression.Parse(typed)
if expr.Foreach {
return nil, errors.New("foreach is not supported on the RHS")
}
if expr.Binding != "" {
return nil, errors.New("binding is not supported on the RHS")
}
if compiler := compiler.Compiler(expr.Compiler); compiler != nil {
parse := sync.OnceValues(func() (compilers.Program, error) {
return compiler.Compile(expr.Statement)
})
project = func(value any, bindings binding.Bindings) (any, error) {
program, err := parse()
if err != nil {
return nil, err
}
return program(value, bindings)
}
} else {
assertion = expr.Statement
}
type scalarNode projection.ScalarHandler

func (node scalarNode) Assert(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
var errs field.ErrorList
if projected, err := node(value, bindings); err != nil {
return nil, field.InternalError(path, err)
} else if match, err := matching.Match(projected, value); err != nil {
return nil, field.InternalError(path, err)
} else if !match {
errs = append(errs, field.Invalid(path, value, expectValueMessage(projected)))
}
return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
expected := assertion
if project != nil {
projected, err := project(value, bindings)
if err != nil {
return nil, field.InternalError(path, err)
}
expected = projected
}
var errs field.ErrorList
if match, err := matching.Match(expected, value); err != nil {
return nil, field.InternalError(path, err)
} else if !match {
errs = append(errs, field.Invalid(path, value, expectValueMessage(expected)))
}
return errs, nil
}, nil
return errs, nil
}

func parseScalar(in any, compiler compilers.Compilers) (scalarNode, error) {
proj, err := projection.ParseScalar(in, compiler)
if err != nil {
return nil, err
}
return proj, err
}

func expectValueMessage(value any) string {
Expand Down

0 comments on commit ae40afb

Please sign in to comment.