diff --git a/go.mod b/go.mod index 2b5a774f9..83d2b4e5e 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/gin-contrib/cors v1.7.2 github.com/gin-gonic/gin v1.10.0 + github.com/google/cel-go v0.17.8 github.com/google/go-cmp v0.6.0 github.com/jmespath-community/go-jmespath v1.1.2-0.20240117150817-e430401a2172 github.com/kyverno/pkg/ext v0.0.0-20240418121121-df8add26c55c @@ -62,7 +63,6 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/btree v1.1.2 // indirect - github.com/google/cel-go v0.17.8 // indirect github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/gxui v0.0.0-20151028112939-f85e0a97b3a4 // indirect diff --git a/pkg/apis/policy/v1alpha1/assertion_tree.go b/pkg/apis/policy/v1alpha1/assertion_tree.go index 8859541d3..6bb156998 100644 --- a/pkg/apis/policy/v1alpha1/assertion_tree.go +++ b/pkg/apis/policy/v1alpha1/assertion_tree.go @@ -1,10 +1,8 @@ package v1alpha1 import ( - "context" - "sync" - "github.com/kyverno/kyverno-json/pkg/core/assertion" + "github.com/kyverno/kyverno-json/pkg/core/templating" "k8s.io/apimachinery/pkg/util/json" ) @@ -13,24 +11,20 @@ import ( // +kubebuilder:validation:Type:="" // AssertionTree represents an assertion tree. type AssertionTree struct { - _tree any - _assertion func() (assertion.Assertion, error) + _tree any } func NewAssertionTree(value any) AssertionTree { return AssertionTree{ _tree: value, - _assertion: sync.OnceValues(func() (assertion.Assertion, error) { - return assertion.Parse(context.Background(), value) - }), } } -func (t *AssertionTree) Assertion() (assertion.Assertion, error) { +func (t *AssertionTree) Assertion(compiler templating.Compiler) (assertion.Assertion, error) { if t._tree == nil { return nil, nil } - return t._assertion() + return assertion.Parse(t._tree, compiler) } func (a *AssertionTree) MarshalJSON() ([]byte, error) { @@ -44,13 +38,9 @@ func (a *AssertionTree) UnmarshalJSON(data []byte) error { return err } a._tree = v - a._assertion = sync.OnceValues(func() (assertion.Assertion, error) { - return assertion.Parse(context.Background(), v) - }) return nil } func (in *AssertionTree) DeepCopyInto(out *AssertionTree) { out._tree = deepCopy(in._tree) - out._assertion = in._assertion } diff --git a/pkg/commands/jp/query/command.go b/pkg/commands/jp/query/command.go index a955c21b7..0720c1444 100644 --- a/pkg/commands/jp/query/command.go +++ b/pkg/commands/jp/query/command.go @@ -1,7 +1,6 @@ package query import ( - "context" "encoding/json" "errors" "fmt" @@ -11,7 +10,7 @@ import ( "github.com/jmespath-community/go-jmespath/pkg/parsing" "github.com/kyverno/kyverno-json/pkg/command" - "github.com/kyverno/kyverno-json/pkg/engine/template" + "github.com/kyverno/kyverno-json/pkg/core/templating" "github.com/spf13/cobra" "sigs.k8s.io/yaml" ) @@ -156,7 +155,8 @@ func loadInput(cmd *cobra.Command, file string) (any, error) { } func evaluate(input any, query string) (any, error) { - result, err := template.ExecuteJP(context.Background(), query, input, nil) + compiler := templating.NewCompiler(templating.CompilerOptions{}) + result, err := templating.ExecuteJP(query, input, nil, compiler) if err != nil { if syntaxError, ok := err.(parsing.SyntaxError); ok { return nil, fmt.Errorf("%s\n%s", syntaxError, syntaxError.HighlightLocation()) diff --git a/pkg/commands/scan/command_test.go b/pkg/commands/scan/command_test.go index 162e1d2c9..efd738a1c 100644 --- a/pkg/commands/scan/command_test.go +++ b/pkg/commands/scan/command_test.go @@ -24,6 +24,12 @@ func Test_Execute(t *testing.T) { policies: []string{"../../../test/commands/scan/foo-bar/policy.yaml"}, out: "../../../test/commands/scan/foo-bar/out.txt", wantErr: false, + }, { + name: "cel", + payload: "../../../test/commands/scan/cel/payload.yaml", + policies: []string{"../../../test/commands/scan/cel/policy.yaml"}, + out: "../../../test/commands/scan/cel/out.txt", + wantErr: false, }, { name: "wildcard", payload: "../../../test/commands/scan/wildcard/payload.json", diff --git a/pkg/commands/scan/options.go b/pkg/commands/scan/options.go index 0f091504e..5b0a3d81e 100644 --- a/pkg/commands/scan/options.go +++ b/pkg/commands/scan/options.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/kyverno/kyverno-json/pkg/apis/policy/v1alpha1" - "github.com/kyverno/kyverno-json/pkg/engine/template" + "github.com/kyverno/kyverno-json/pkg/core/templating" jsonengine "github.com/kyverno/kyverno-json/pkg/json-engine" "github.com/kyverno/kyverno-json/pkg/payload" "github.com/kyverno/kyverno-json/pkg/policy" @@ -76,8 +76,9 @@ func (c *options) run(cmd *cobra.Command, _ []string) error { return errors.New("payload is `null`") } out.println("Pre processing ...") + compiler := templating.NewCompiler(templating.CompilerOptions{}) for _, preprocessor := range c.preprocessors { - result, err := template.ExecuteJP(context.Background(), preprocessor, payload, nil) + result, err := templating.ExecuteJP(preprocessor, payload, nil, compiler) if err != nil { return err } diff --git a/pkg/core/assertion/assertion.go b/pkg/core/assertion/assertion.go index 5721567a1..b0322b149 100644 --- a/pkg/core/assertion/assertion.go +++ b/pkg/core/assertion/assertion.go @@ -1,58 +1,56 @@ package assertion import ( - "context" "errors" "fmt" "reflect" "sync" "github.com/jmespath-community/go-jmespath/pkg/binding" - "github.com/jmespath-community/go-jmespath/pkg/parsing" "github.com/kyverno/kyverno-json/pkg/core/expression" "github.com/kyverno/kyverno-json/pkg/core/projection" + "github.com/kyverno/kyverno-json/pkg/core/templating" "github.com/kyverno/kyverno-json/pkg/engine/match" - "github.com/kyverno/kyverno-json/pkg/engine/template" reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect" "k8s.io/apimachinery/pkg/util/validation/field" ) type Assertion interface { - Assert(context.Context, *field.Path, any, binding.Bindings, ...template.Option) (field.ErrorList, error) + Assert(*field.Path, any, binding.Bindings) (field.ErrorList, error) } -func Parse(ctx context.Context, assertion any) (node, error) { +func Parse(assertion any, compiler templating.Compiler) (node, error) { switch reflectutils.GetKind(assertion) { case reflect.Slice: - return parseSlice(ctx, assertion) + return parseSlice(assertion, compiler) case reflect.Map: - return parseMap(ctx, assertion) + return parseMap(assertion, compiler) default: - return parseScalar(ctx, assertion) + return parseScalar(assertion, compiler) } } // node implements the Assertion interface using a delegate func -type node func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) +type node func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) -func (n node) Assert(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { - return n(ctx, path, value, bindings, opts...) +func (n node) Assert(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) { + return n(path, value, bindings) } // parseSlice is the assertion represented by a slice. // it first compares the length of the analysed resource with the length of the descendants. // if lengths match all descendants are evaluated with their corresponding items. -func parseSlice(ctx context.Context, assertion any) (node, error) { +func parseSlice(assertion any, compiler templating.Compiler) (node, error) { var assertions []node valueOf := reflect.ValueOf(assertion) for i := 0; i < valueOf.Len(); i++ { - sub, err := Parse(ctx, valueOf.Index(i).Interface()) + sub, err := Parse(valueOf.Index(i).Interface(), compiler) if err != nil { return nil, err } assertions = append(assertions, sub) } - return func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { + return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) { var errs field.ErrorList if value == nil { errs = append(errs, field.Invalid(path, value, "value is null")) @@ -64,7 +62,7 @@ func parseSlice(ctx context.Context, assertion any) (node, error) { errs = append(errs, field.Invalid(path, value, "lengths of slices don't match")) } else { for i := range assertions { - if _errs, err := assertions[i].Assert(ctx, path.Index(i), valueOf.Index(i).Interface(), bindings, opts...); err != nil { + if _errs, err := assertions[i].Assert(path.Index(i), valueOf.Index(i).Interface(), bindings); err != nil { return nil, err } else { errs = append(errs, _errs...) @@ -78,7 +76,7 @@ func parseSlice(ctx context.Context, assertion any) (node, error) { // parseMap is the assertion represented by a map. // it is responsible for projecting the analysed resource and passing the result to the descendant -func parseMap(ctx context.Context, assertion any) (node, error) { +func parseMap(assertion any, compiler templating.Compiler) (node, error) { assertions := map[any]struct { projection.Projection node @@ -87,16 +85,16 @@ func parseMap(ctx context.Context, assertion any) (node, error) { for iter.Next() { key := iter.Key().Interface() value := iter.Value().Interface() - assertion, err := Parse(ctx, value) + assertion, err := Parse(value, compiler) if err != nil { return nil, err } entry := assertions[key] entry.node = assertion - entry.Projection = projection.Parse(key) + entry.Projection = projection.Parse(key, compiler) assertions[key] = entry } - return func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { + return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) { var errs field.ErrorList // if we assert against an empty object, value is expected to be not nil if len(assertions) == 0 { @@ -106,7 +104,7 @@ func parseMap(ctx context.Context, assertion any) (node, error) { return errs, nil } for k, v := range assertions { - projected, found, err := v.Projection.Handler(ctx, value, bindings, opts...) + projected, found, err := v.Projection.Handler(value, bindings) if err != nil { return nil, field.InternalError(path.Child(fmt.Sprint(k)), err) } else if !found { @@ -124,7 +122,7 @@ func parseMap(ctx context.Context, assertion any) (node, error) { if v.Projection.ForeachName != "" { bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(i)) } - if _errs, err := v.Assert(ctx, path.Child(fmt.Sprint(k)).Index(i), valueOf.Index(i).Interface(), bindings, opts...); err != nil { + if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Index(i), valueOf.Index(i).Interface(), bindings); err != nil { return nil, err } else { errs = append(errs, _errs...) @@ -138,7 +136,7 @@ func parseMap(ctx context.Context, assertion any) (node, error) { if v.Projection.ForeachName != "" { bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(key)) } - if _errs, err := v.Assert(ctx, path.Child(fmt.Sprint(k)).Key(fmt.Sprint(key)), iter.Value().Interface(), bindings, opts...); err != nil { + if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Key(fmt.Sprint(key)), iter.Value().Interface(), bindings); err != nil { return nil, err } else { errs = append(errs, _errs...) @@ -148,7 +146,7 @@ func parseMap(ctx context.Context, assertion any) (node, error) { return nil, field.TypeInvalid(path.Child(fmt.Sprint(k)), projected, "expected a slice or a map") } } else { - if _errs, err := v.Assert(ctx, path.Child(fmt.Sprint(k)), projected, bindings, opts...); err != nil { + if _errs, err := v.Assert(path.Child(fmt.Sprint(k)), projected, bindings); err != nil { return nil, err } else { errs = append(errs, _errs...) @@ -163,8 +161,8 @@ func parseMap(ctx context.Context, assertion any) (node, error) { // parseScalar is the assertion represented by a leaf. // it receives a value and compares it with an expected value. // the expected value can be the result of an expression. -func parseScalar(_ context.Context, assertion any) (node, error) { - var project func(ctx context.Context, value any, bindings binding.Bindings, opts ...template.Option) (any, error) +func parseScalar(assertion any, compiler templating.Compiler) (node, error) { + var project func(value any, bindings binding.Bindings) (any, error) switch typed := assertion.(type) { case string: expr := expression.Parse(typed) @@ -176,34 +174,39 @@ func parseScalar(_ context.Context, assertion any) (node, error) { } switch expr.Engine { case expression.EngineJP: - parse := sync.OnceValues(func() (parsing.ASTNode, error) { - parser := parsing.NewParser() - return parser.Parse(expr.Statement) + parse := sync.OnceValues(func() (templating.Program, error) { + return compiler.CompileJP(expr.Statement) }) - project = func(ctx context.Context, value any, bindings binding.Bindings, opts ...template.Option) (any, error) { - ast, err := parse() + project = func(value any, bindings binding.Bindings) (any, error) { + program, err := parse() if err != nil { return nil, err } - return template.ExecuteAST(ctx, ast, value, bindings, opts...) + return program(value, bindings) } case expression.EngineCEL: - return nil, errors.New("engine not supported") + project = func(value any, bindings binding.Bindings) (any, error) { + program, err := compiler.CompileCEL(expr.Statement) + if err != nil { + return nil, err + } + return program(value, bindings) + } default: assertion = expr.Statement } } - return func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { + return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) { expected := assertion if project != nil { - projected, err := project(ctx, value, bindings, opts...) + projected, err := project(value, bindings) if err != nil { return nil, field.InternalError(path, err) } expected = projected } var errs field.ErrorList - if match, err := match.Match(ctx, expected, value); err != nil { + if match, err := match.Match(expected, value); err != nil { return nil, field.InternalError(path, err) } else if !match { errs = append(errs, field.Invalid(path, value, expectValueMessage(expected))) diff --git a/pkg/core/assertion/assertion_test.go b/pkg/core/assertion/assertion_test.go index 0f31f7b41..2a2ca11ab 100644 --- a/pkg/core/assertion/assertion_test.go +++ b/pkg/core/assertion/assertion_test.go @@ -1,10 +1,10 @@ package assertion import ( - "context" "testing" "github.com/jmespath-community/go-jmespath/pkg/binding" + "github.com/kyverno/kyverno-json/pkg/core/templating" tassert "github.com/stretchr/testify/assert" "k8s.io/apimachinery/pkg/util/validation/field" ) @@ -48,9 +48,10 @@ func TestAssert(t *testing.T) { }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - parsed, err := Parse(context.TODO(), tt.assertion) + compiler := templating.NewCompiler(templating.CompilerOptions{}) + parsed, err := Parse(tt.assertion, compiler) tassert.NoError(t, err) - got, err := parsed.Assert(context.TODO(), nil, tt.value, tt.bindings) + got, err := parsed.Assert(nil, tt.value, tt.bindings) if tt.wantErr { tassert.Error(t, err) } else { diff --git a/pkg/core/message/message.go b/pkg/core/message/message.go index bfeec4a3e..268838e0a 100644 --- a/pkg/core/message/message.go +++ b/pkg/core/message/message.go @@ -1,7 +1,6 @@ package message import ( - "context" "fmt" "regexp" "strings" @@ -9,17 +8,17 @@ import ( "github.com/jmespath-community/go-jmespath/pkg/binding" "github.com/jmespath-community/go-jmespath/pkg/parsing" - "github.com/kyverno/kyverno-json/pkg/engine/template" + "github.com/kyverno/kyverno-json/pkg/core/templating/jp" ) var variable = regexp.MustCompile(`{{(.*?)}}`) type Message interface { Original() string - Format(any, binding.Bindings, ...template.Option) string + Format(any, binding.Bindings, ...jp.Option) string } -type substitution = func(string, any, binding.Bindings, ...template.Option) string +type substitution = func(string, any, binding.Bindings, ...jp.Option) string type message struct { original string @@ -30,7 +29,7 @@ func (m *message) Original() string { return m.original } -func (m *message) Format(value any, bindings binding.Bindings, opts ...template.Option) string { +func (m *message) Format(value any, bindings binding.Bindings, opts ...jp.Option) string { out := m.original for _, substitution := range m.substitutions { out = substitution(out, value, bindings, opts...) @@ -40,22 +39,22 @@ func (m *message) Format(value any, bindings binding.Bindings, opts ...template. func Parse(in string) *message { groups := variable.FindAllStringSubmatch(in, -1) - var substitutions []func(string, any, binding.Bindings, ...template.Option) string + var substitutions []func(string, any, binding.Bindings, ...jp.Option) string for _, group := range groups { statement := strings.TrimSpace(group[1]) parse := sync.OnceValues(func() (parsing.ASTNode, error) { parser := parsing.NewParser() return parser.Parse(statement) }) - evaluate := func(value any, bindings binding.Bindings, opts ...template.Option) (any, error) { + evaluate := func(value any, bindings binding.Bindings, opts ...jp.Option) (any, error) { ast, err := parse() if err != nil { return nil, err } - return template.ExecuteAST(context.TODO(), ast, value, bindings, opts...) + return jp.Execute(ast, value, bindings, opts...) } placeholder := group[0] - substitutions = append(substitutions, func(out string, value any, bindings binding.Bindings, opts ...template.Option) string { + substitutions = append(substitutions, func(out string, value any, bindings binding.Bindings, opts ...jp.Option) string { result, err := evaluate(value, bindings, opts...) if err != nil { out = strings.ReplaceAll(out, placeholder, fmt.Sprintf("ERR (%s - %s)", statement, err)) diff --git a/pkg/core/projection/projection.go b/pkg/core/projection/projection.go index b39b1ba1d..db6fee2c4 100644 --- a/pkg/core/projection/projection.go +++ b/pkg/core/projection/projection.go @@ -1,19 +1,17 @@ package projection import ( - "context" "errors" "reflect" "sync" "github.com/jmespath-community/go-jmespath/pkg/binding" - "github.com/jmespath-community/go-jmespath/pkg/parsing" "github.com/kyverno/kyverno-json/pkg/core/expression" - "github.com/kyverno/kyverno-json/pkg/engine/template" + "github.com/kyverno/kyverno-json/pkg/core/templating" reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect" ) -type Handler = func(ctx context.Context, value any, bindings binding.Bindings, opts ...template.Option) (any, bool, error) +type Handler = func(value any, bindings binding.Bindings) (any, bool, error) type Info struct { Foreach bool @@ -26,7 +24,7 @@ type Projection struct { Handler } -func Parse(in any) (projection Projection) { +func Parse(in any, compiler templating.Compiler) (projection Projection) { switch typed := in.(type) { case string: // 1. if we have a string, parse the expression @@ -38,22 +36,34 @@ func Parse(in any) (projection Projection) { // 3. compute the projection func switch expr.Engine { case expression.EngineJP: - parse := sync.OnceValues(func() (parsing.ASTNode, error) { - parser := parsing.NewParser() - return parser.Parse(expr.Statement) + parse := sync.OnceValues(func() (templating.Program, error) { + return compiler.CompileJP(expr.Statement) }) - projection.Handler = func(ctx context.Context, value any, bindings binding.Bindings, opts ...template.Option) (any, bool, error) { - ast, err := parse() + projection.Handler = func(value any, bindings binding.Bindings) (any, bool, error) { + program, err := parse() + if err != nil { + return nil, false, err + } + projected, err := program(value, bindings) if err != nil { return nil, false, err } - projected, err := template.ExecuteAST(ctx, ast, value, bindings, opts...) return projected, true, err } case expression.EngineCEL: - panic("engine not supported") + projection.Handler = func(value any, bindings binding.Bindings) (any, bool, error) { + program, err := compiler.CompileCEL(expr.Statement) + if err != nil { + return nil, false, err + } + projected, err := program(value, bindings) + if err != nil { + return nil, false, err + } + return projected, true, nil + } default: - projection.Handler = func(ctx context.Context, value any, bindings binding.Bindings, opts ...template.Option) (any, bool, error) { + projection.Handler = func(value any, bindings binding.Bindings) (any, bool, error) { if value == nil { return nil, false, nil } @@ -69,7 +79,7 @@ func Parse(in any) (projection Projection) { } default: // 1. compute the projection func - projection.Handler = func(ctx context.Context, value any, bindings binding.Bindings, opts ...template.Option) (any, bool, error) { + projection.Handler = func(value any, bindings binding.Bindings) (any, bool, error) { if value == nil { return nil, false, nil } diff --git a/pkg/core/projection/projection_test.go b/pkg/core/projection/projection_test.go index f2d73c628..62d4fa9a3 100644 --- a/pkg/core/projection/projection_test.go +++ b/pkg/core/projection/projection_test.go @@ -1,10 +1,10 @@ package projection import ( - "context" "testing" "github.com/jmespath-community/go-jmespath/pkg/binding" + "github.com/kyverno/kyverno-json/pkg/core/templating" tassert "github.com/stretchr/testify/assert" ) @@ -88,8 +88,9 @@ func TestProjection(t *testing.T) { }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - proj := Parse(tt.key) - got, found, err := proj.Handler(context.TODO(), tt.value, tt.bindings) + compiler := templating.NewCompiler(templating.CompilerOptions{}) + proj := Parse(tt.key, compiler) + got, found, err := proj.Handler(tt.value, tt.bindings) if tt.wantErr { tassert.Error(t, err) } else { diff --git a/pkg/core/templating/cel/cel.go b/pkg/core/templating/cel/cel.go new file mode 100644 index 000000000..fb2e31f2e --- /dev/null +++ b/pkg/core/templating/cel/cel.go @@ -0,0 +1,18 @@ +package cel + +import ( + "github.com/google/cel-go/cel" + "github.com/jmespath-community/go-jmespath/pkg/binding" +) + +func Execute(program cel.Program, value any, bindings binding.Bindings) (any, error) { + data := map[string]interface{}{ + "object": value, + "bindings": NewVal(bindings, BindingsType), + } + out, _, err := program.Eval(data) + if err != nil { + return nil, err + } + return out.Value(), nil +} diff --git a/pkg/core/templating/cel/env.go b/pkg/core/templating/cel/env.go new file mode 100644 index 000000000..e3e0be817 --- /dev/null +++ b/pkg/core/templating/cel/env.go @@ -0,0 +1,41 @@ +package cel + +import ( + "sync" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/jmespath-community/go-jmespath/pkg/binding" +) + +var ( + BindingsType = cel.OpaqueType("bindings") + DefaultEnv = sync.OnceValues(func() (*cel.Env, error) { + return cel.NewEnv( + cel.Variable("object", cel.DynType), + cel.Variable("bindings", BindingsType), + cel.Function("resolve", + cel.MemberOverload("bindings_resolve_string", + []*cel.Type{BindingsType, cel.StringType}, + cel.AnyType, + cel.BinaryBinding(func(lhs, rhs ref.Val) ref.Val { + bindings, ok := lhs.(Val[binding.Bindings]) + if !ok { + return types.ValOrErr(bindings, "invalid bindings type") + } + name, ok := rhs.(types.String) + if !ok { + return types.ValOrErr(name, "invalid name type") + } + value, err := binding.Resolve("$"+string(name), bindings.Unwrap()) + if err != nil { + return types.WrapErr(err) + } + return types.DefaultTypeAdapter.NativeToValue(value) + }), + ), + ), + ) + }) +) diff --git a/pkg/core/templating/cel/val.go b/pkg/core/templating/cel/val.go new file mode 100644 index 000000000..a21458440 --- /dev/null +++ b/pkg/core/templating/cel/val.go @@ -0,0 +1,48 @@ +package cel + +import ( + "reflect" + + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +type Val[T comparable] struct { + inner T + celType ref.Type +} + +func NewVal[T comparable](value T, celType ref.Type) Val[T] { + return Val[T]{ + inner: value, + celType: celType, + } +} + +func (w Val[T]) Unwrap() T { + return w.inner +} + +func (w Val[T]) Value() interface{} { + return w.Unwrap() +} + +func (w Val[T]) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { + panic("not required") +} + +func (w Val[T]) ConvertToType(typeVal ref.Type) ref.Val { + panic("not required") +} + +func (w Val[T]) Equal(other ref.Val) ref.Val { + o, ok := other.Value().(Val[T]) + if !ok { + return types.ValOrErr(other, "no such overload") + } + return types.Bool(o == w) +} + +func (w Val[T]) Type() ref.Type { + return w.celType +} diff --git a/pkg/core/templating/compiler.go b/pkg/core/templating/compiler.go new file mode 100644 index 000000000..7d9e1c042 --- /dev/null +++ b/pkg/core/templating/compiler.go @@ -0,0 +1,100 @@ +package templating + +import ( + "github.com/jmespath-community/go-jmespath/pkg/binding" + "github.com/jmespath-community/go-jmespath/pkg/interpreter" + "github.com/jmespath-community/go-jmespath/pkg/parsing" + "github.com/kyverno/kyverno-json/pkg/core/expression" + "github.com/kyverno/kyverno-json/pkg/core/templating/cel" + "github.com/kyverno/kyverno-json/pkg/core/templating/jp" + "k8s.io/apimachinery/pkg/util/validation/field" +) + +type CelOptions struct { + FunctionCaller interpreter.FunctionCaller +} + +type CompilerOptions struct { + Cel CelOptions + Jp []jp.Option +} + +type Compiler struct { + options CompilerOptions +} + +func NewCompiler(options CompilerOptions) Compiler { + return Compiler{ + options: options, + } +} + +type Program func(any, binding.Bindings) (any, error) + +func (c Compiler) Options() CompilerOptions { + return c.options +} + +func (c Compiler) CompileCEL(statement string) (Program, error) { + env, err := cel.DefaultEnv() + if err != nil { + return nil, err + } + ast, iss := env.Compile(statement) + if iss.Err() != nil { + return nil, iss.Err() + } + program, err := env.Program(ast) + if err != nil { + return nil, err + } + return func(value any, bindings binding.Bindings) (any, error) { + return cel.Execute(program, value, bindings) + }, nil +} + +func (c Compiler) CompileJP(statement string) (Program, error) { + parser := parsing.NewParser() + compiled, err := parser.Parse(statement) + if err != nil { + return nil, err + } + return func(value any, bindings binding.Bindings) (any, error) { + return jp.Execute(compiled, value, bindings, c.options.Jp...) + }, nil +} + +func (c Compiler) NewBinding(path *field.Path, value any, bindings binding.Bindings, template any) binding.Binding { + return jp.NewLazyBinding( + func() (any, error) { + switch typed := template.(type) { + case string: + expr := expression.Parse(typed) + if expr.Foreach { + return nil, field.Invalid(path.Child("variable"), typed, "foreach is not supported in context") + } + if expr.Binding != "" { + return nil, field.Invalid(path.Child("variable"), typed, "binding is not supported in context") + } + switch expr.Engine { + case expression.EngineJP: + projected, err := ExecuteJP(expr.Statement, value, bindings, c) + if err != nil { + return nil, field.InternalError(path.Child("variable"), err) + } + return projected, nil + case expression.EngineCEL: + projected, err := ExecuteCEL(expr.Statement, value, bindings, c) + if err != nil { + return nil, field.InternalError(path.Child("variable"), err) + } + return projected, nil + default: + return expr.Statement, nil + } + default: + return typed, nil + } + }, + ) +} diff --git a/pkg/jp/binding.go b/pkg/core/templating/jp/binding.go similarity index 100% rename from pkg/jp/binding.go rename to pkg/core/templating/jp/binding.go diff --git a/pkg/core/templating/jp/jp.go b/pkg/core/templating/jp/jp.go new file mode 100644 index 000000000..37a17ab74 --- /dev/null +++ b/pkg/core/templating/jp/jp.go @@ -0,0 +1,13 @@ +package jp + +import ( + "github.com/jmespath-community/go-jmespath/pkg/binding" + "github.com/jmespath-community/go-jmespath/pkg/interpreter" + "github.com/jmespath-community/go-jmespath/pkg/parsing" +) + +func Execute(ast parsing.ASTNode, value any, bindings binding.Bindings, opts ...Option) (any, error) { + o := buildOptions(opts...) + vm := interpreter.NewInterpreter(nil, bindings) + return vm.Execute(ast, value, interpreter.WithFunctionCaller(o.functionCaller)) +} diff --git a/pkg/engine/template/options.go b/pkg/core/templating/jp/options.go similarity index 97% rename from pkg/engine/template/options.go rename to pkg/core/templating/jp/options.go index 52cf408af..50fe93403 100644 --- a/pkg/engine/template/options.go +++ b/pkg/core/templating/jp/options.go @@ -1,4 +1,4 @@ -package template +package jp import ( "context" diff --git a/pkg/core/templating/templating.go b/pkg/core/templating/templating.go new file mode 100644 index 000000000..cbaea13b6 --- /dev/null +++ b/pkg/core/templating/templating.go @@ -0,0 +1,21 @@ +package templating + +import ( + "github.com/jmespath-community/go-jmespath/pkg/binding" +) + +func ExecuteJP(statement string, value any, bindings binding.Bindings, compiler Compiler) (any, error) { + program, err := compiler.CompileJP(statement) + if err != nil { + return nil, err + } + return program(value, bindings) +} + +func ExecuteCEL(statement string, value any, bindings binding.Bindings, compiler Compiler) (any, error) { + program, err := compiler.CompileCEL(statement) + if err != nil { + return nil, err + } + return program(value, bindings) +} diff --git a/pkg/engine/match/match.go b/pkg/engine/match/match.go index afb21f164..c433318bc 100644 --- a/pkg/engine/match/match.go +++ b/pkg/engine/match/match.go @@ -1,14 +1,13 @@ package match import ( - "context" "fmt" "reflect" reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect" ) -func Match(ctx context.Context, expected, actual any) (bool, error) { +func Match(expected, actual any) (bool, error) { if expected != nil { switch reflectutils.GetKind(expected) { case reflect.Slice: @@ -19,7 +18,7 @@ func Match(ctx context.Context, expected, actual any) (bool, error) { return false, nil } for i := 0; i < reflect.ValueOf(expected).Len(); i++ { - if inner, err := Match(ctx, reflect.ValueOf(expected).Index(i).Interface(), reflect.ValueOf(actual).Index(i).Interface()); err != nil { + if inner, err := Match(reflect.ValueOf(expected).Index(i).Interface(), reflect.ValueOf(actual).Index(i).Interface()); err != nil { return false, err } else if !inner { return false, nil @@ -36,7 +35,7 @@ func Match(ctx context.Context, expected, actual any) (bool, error) { if !actualValue.IsValid() { return false, nil } - if inner, err := Match(ctx, iter.Value().Interface(), actualValue.Interface()); err != nil { + if inner, err := Match(iter.Value().Interface(), actualValue.Interface()); err != nil { return false, err } else if !inner { return false, nil diff --git a/pkg/engine/template/binding.go b/pkg/engine/template/binding.go deleted file mode 100644 index 511da21ee..000000000 --- a/pkg/engine/template/binding.go +++ /dev/null @@ -1,41 +0,0 @@ -package template - -import ( - "context" - - "github.com/jmespath-community/go-jmespath/pkg/binding" - "github.com/kyverno/kyverno-json/pkg/core/expression" - "github.com/kyverno/kyverno-json/pkg/jp" - "k8s.io/apimachinery/pkg/util/validation/field" -) - -func NewContextBinding(path *field.Path, bindings binding.Bindings, value any, template any, opts ...Option) binding.Binding { - return jp.NewLazyBinding( - func() (any, error) { - switch typed := template.(type) { - case string: - expr := expression.Parse(typed) - if expr.Foreach { - return nil, field.Invalid(path.Child("variable"), typed, "foreach is not supported in context") - } - if expr.Binding != "" { - return nil, field.Invalid(path.Child("variable"), typed, "binding is not supported in context") - } - switch expr.Engine { - case expression.EngineJP: - projected, err := ExecuteJP(context.TODO(), expr.Statement, value, bindings, opts...) - if err != nil { - return nil, field.InternalError(path.Child("variable"), err) - } - return projected, nil - case expression.EngineCEL: - return nil, field.Invalid(path.Child("variable"), expr.Engine, "engine not supported") - default: - return expr.Statement, nil - } - default: - return typed, nil - } - }, - ) -} diff --git a/pkg/engine/template/template.go b/pkg/engine/template/template.go deleted file mode 100644 index 19a0aa6df..000000000 --- a/pkg/engine/template/template.go +++ /dev/null @@ -1,24 +0,0 @@ -package template - -import ( - "context" - - "github.com/jmespath-community/go-jmespath/pkg/binding" - "github.com/jmespath-community/go-jmespath/pkg/interpreter" - "github.com/jmespath-community/go-jmespath/pkg/parsing" -) - -func ExecuteJP(ctx context.Context, statement string, value any, bindings binding.Bindings, opts ...Option) (any, error) { - parser := parsing.NewParser() - compiled, err := parser.Parse(statement) - if err != nil { - return nil, err - } - return ExecuteAST(ctx, compiled, value, bindings, opts...) -} - -func ExecuteAST(ctx context.Context, ast parsing.ASTNode, value any, bindings binding.Bindings, opts ...Option) (any, error) { - o := buildOptions(opts...) - vm := interpreter.NewInterpreter(nil, bindings) - return vm.Execute(ast, value, interpreter.WithFunctionCaller(o.functionCaller)) -} diff --git a/pkg/json-engine/engine.go b/pkg/json-engine/engine.go index 862fae4ae..6d60bef98 100644 --- a/pkg/json-engine/engine.go +++ b/pkg/json-engine/engine.go @@ -7,9 +7,9 @@ import ( jpbinding "github.com/jmespath-community/go-jmespath/pkg/binding" "github.com/kyverno/kyverno-json/pkg/apis/policy/v1alpha1" + "github.com/kyverno/kyverno-json/pkg/core/templating" "github.com/kyverno/kyverno-json/pkg/engine" "github.com/kyverno/kyverno-json/pkg/engine/builder" - "github.com/kyverno/kyverno-json/pkg/engine/template" "github.com/kyverno/kyverno-json/pkg/matching" "k8s.io/apimachinery/pkg/util/validation/field" ) @@ -66,6 +66,7 @@ func New() engine.Engine[Request, Response] { resource any bindings jpbinding.Bindings } + compiler := templating.NewCompiler(templating.CompilerOptions{}) ruleEngine := builder. Function(func(ctx context.Context, r ruleRequest) []RuleResponse { bindings := r.bindings.Register("$rule", jpbinding.NewBinding(r.rule)) @@ -73,11 +74,11 @@ func New() engine.Engine[Request, Response] { var path *field.Path path = path.Child("context") for i, entry := range r.rule.Context { - bindings = bindings.Register("$"+entry.Name, template.NewContextBinding(path.Index(i), bindings, r.resource, entry.Variable.Value())) + bindings = bindings.Register("$"+entry.Name, compiler.NewBinding(path.Index(i), r.resource, bindings, entry.Variable.Value())) } identifier := "" if r.rule.Identifier != "" { - result, err := template.ExecuteJP(context.Background(), r.rule.Identifier, r.resource, bindings) + result, err := templating.ExecuteJP(r.rule.Identifier, r.resource, bindings, compiler) if err != nil { identifier = fmt.Sprintf("(error: %s)", err) } else { @@ -85,7 +86,7 @@ func New() engine.Engine[Request, Response] { } } if r.rule.Match != nil { - errs, err := matching.Match(ctx, nil, r.rule.Match, r.resource, bindings) + errs, err := matching.Match(nil, r.rule.Match, r.resource, bindings, compiler) if err != nil { return []RuleResponse{{ Rule: r.rule, @@ -100,7 +101,7 @@ func New() engine.Engine[Request, Response] { } } if r.rule.Exclude != nil { - errs, err := matching.Match(ctx, nil, r.rule.Exclude, r.resource, bindings) + errs, err := matching.Match(nil, r.rule.Exclude, r.resource, bindings, compiler) if err != nil { return []RuleResponse{{ Rule: r.rule, @@ -116,7 +117,7 @@ func New() engine.Engine[Request, Response] { } var feedback map[string]Feedback for _, f := range r.rule.Feedback { - result, err := template.ExecuteJP(context.Background(), f.Value, r.resource, bindings) + result, err := templating.ExecuteJP(f.Value, r.resource, bindings, compiler) if feedback == nil { feedback = map[string]Feedback{} } @@ -130,7 +131,7 @@ func New() engine.Engine[Request, Response] { } } } - violations, err := matching.MatchAssert(ctx, nil, r.rule.Assert, r.resource, bindings) + violations, err := matching.MatchAssert(nil, r.rule.Assert, r.resource, bindings, compiler) if err != nil { return []RuleResponse{{ Rule: r.rule, diff --git a/pkg/matching/match.go b/pkg/matching/match.go index 616141a63..01743a17a 100644 --- a/pkg/matching/match.go +++ b/pkg/matching/match.go @@ -1,12 +1,11 @@ package matching import ( - "context" "strings" "github.com/jmespath-community/go-jmespath/pkg/binding" "github.com/kyverno/kyverno-json/pkg/apis/policy/v1alpha1" - "github.com/kyverno/kyverno-json/pkg/engine/template" + "github.com/kyverno/kyverno-json/pkg/core/templating" "k8s.io/apimachinery/pkg/util/validation/field" ) @@ -38,8 +37,7 @@ func (r Results) Error() string { return strings.Join(lines, "\n") } -// func MatchAssert(ctx context.Context, path *field.Path, match *v1alpha1.Assert, actual any, bindings binding.Bindings, opts ...template.Option) ([]error, error) { -func MatchAssert(ctx context.Context, path *field.Path, match v1alpha1.Assert, actual any, bindings binding.Bindings, opts ...template.Option) ([]Result, error) { +func MatchAssert(path *field.Path, match v1alpha1.Assert, actual any, bindings binding.Bindings, compiler templating.Compiler) ([]Result, error) { if len(match.Any) == 0 && len(match.All) == 0 { return nil, field.Invalid(path, match, "an empty assert is not valid") } else { @@ -48,11 +46,11 @@ func MatchAssert(ctx context.Context, path *field.Path, match v1alpha1.Assert, a path := path.Child("any") for i, assertion := range match.Any { path := path.Index(i).Child("check") - parsed, err := assertion.Check.Assertion() + parsed, err := assertion.Check.Assertion(compiler) if err != nil { return fails, err } - checkFails, err := parsed.Assert(ctx, path, actual, bindings, opts...) + checkFails, err := parsed.Assert(path, actual, bindings) if err != nil { return fails, err } @@ -64,7 +62,7 @@ func MatchAssert(ctx context.Context, path *field.Path, match v1alpha1.Assert, a ErrorList: checkFails, } if assertion.Message != nil { - fail.Message = assertion.Message.Format(actual, bindings, opts...) + fail.Message = assertion.Message.Format(actual, bindings, compiler.Options().Jp...) } fails = append(fails, fail) } @@ -77,11 +75,11 @@ func MatchAssert(ctx context.Context, path *field.Path, match v1alpha1.Assert, a path := path.Child("all") for i, assertion := range match.All { path := path.Index(i).Child("check") - parsed, err := assertion.Check.Assertion() + parsed, err := assertion.Check.Assertion(compiler) if err != nil { return fails, err } - checkFails, err := parsed.Assert(ctx, path, actual, bindings, opts...) + checkFails, err := parsed.Assert(path, actual, bindings) if err != nil { return fails, err } @@ -90,7 +88,7 @@ func MatchAssert(ctx context.Context, path *field.Path, match v1alpha1.Assert, a ErrorList: checkFails, } if assertion.Message != nil { - fail.Message = assertion.Message.Format(actual, bindings, opts...) + fail.Message = assertion.Message.Format(actual, bindings, compiler.Options().Jp...) } fails = append(fails, fail) } @@ -101,20 +99,20 @@ func MatchAssert(ctx context.Context, path *field.Path, match v1alpha1.Assert, a } } -func Match(ctx context.Context, path *field.Path, match *v1alpha1.Match, actual any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { +func Match(path *field.Path, match *v1alpha1.Match, actual any, bindings binding.Bindings, compiler templating.Compiler) (field.ErrorList, error) { if match == nil || (len(match.Any) == 0 && len(match.All) == 0) { return nil, field.Invalid(path, match, "an empty match is not valid") } else { var errs field.ErrorList if len(match.Any) != 0 { - _errs, err := MatchAny(ctx, path.Child("any"), match.Any, actual, bindings, opts...) + _errs, err := MatchAny(path.Child("any"), match.Any, actual, bindings, compiler) if err != nil { return errs, err } errs = append(errs, _errs...) } if len(match.All) != 0 { - _errs, err := MatchAll(ctx, path.Child("all"), match.All, actual, bindings, opts...) + _errs, err := MatchAll(path.Child("all"), match.All, actual, bindings, compiler) if err != nil { return errs, err } @@ -124,15 +122,15 @@ func Match(ctx context.Context, path *field.Path, match *v1alpha1.Match, actual } } -func MatchAny(ctx context.Context, path *field.Path, assertions []v1alpha1.AssertionTree, actual any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { +func MatchAny(path *field.Path, assertions []v1alpha1.AssertionTree, actual any, bindings binding.Bindings, compiler templating.Compiler) (field.ErrorList, error) { var errs field.ErrorList for i, assertion := range assertions { path := path.Index(i) - assertion, err := assertion.Assertion() + assertion, err := assertion.Assertion(compiler) if err != nil { return errs, err } - _errs, err := assertion.Assert(ctx, path, actual, bindings, opts...) + _errs, err := assertion.Assert(path, actual, bindings) if err != nil { return errs, err } @@ -144,15 +142,15 @@ func MatchAny(ctx context.Context, path *field.Path, assertions []v1alpha1.Asser return errs, nil } -func MatchAll(ctx context.Context, path *field.Path, assertions []v1alpha1.AssertionTree, actual any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { +func MatchAll(path *field.Path, assertions []v1alpha1.AssertionTree, actual any, bindings binding.Bindings, compiler templating.Compiler) (field.ErrorList, error) { var errs field.ErrorList for i, assertion := range assertions { path := path.Index(i) - assertion, err := assertion.Assertion() + assertion, err := assertion.Assertion(compiler) if err != nil { return errs, err } - _errs, err := assertion.Assert(ctx, path, actual, bindings, opts...) + _errs, err := assertion.Assert(path, actual, bindings) if err != nil { return errs, err } diff --git a/pkg/policy/load_test.go b/pkg/policy/load_test.go index 80826316d..dd06dbef4 100644 --- a/pkg/policy/load_test.go +++ b/pkg/policy/load_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/kyverno/kyverno-json/pkg/apis/policy/v1alpha1" "github.com/stretchr/testify/assert" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -124,7 +123,7 @@ func TestLoad(t *testing.T) { } else { assert.NoError(t, err) } - assert.True(t, cmp.Equal(tt.want, got, cmp.AllowUnexported(v1alpha1.AssertionTree{}), cmpopts.IgnoreFields(v1alpha1.AssertionTree{}, "_assertion"))) + assert.True(t, cmp.Equal(tt.want, got, cmp.AllowUnexported(v1alpha1.AssertionTree{}))) }) } } diff --git a/pkg/server/playground/handler.go b/pkg/server/playground/handler.go index 86f0b828b..62b6ee0b4 100644 --- a/pkg/server/playground/handler.go +++ b/pkg/server/playground/handler.go @@ -8,7 +8,7 @@ import ( "github.com/gin-gonic/gin" "github.com/kyverno/kyverno-json/pkg/apis/policy/v1alpha1" - "github.com/kyverno/kyverno-json/pkg/engine/template" + "github.com/kyverno/kyverno-json/pkg/core/templating" jsonengine "github.com/kyverno/kyverno-json/pkg/json-engine" "github.com/kyverno/kyverno-json/pkg/server/model" "github.com/loopfz/gadgeto/tonic" @@ -34,7 +34,7 @@ func newHandler() (gin.HandlerFunc, error) { } // apply pre processors for _, preprocessor := range in.Preprocessors { - result, err := template.ExecuteJP(context.Background(), preprocessor, payload, nil) + result, err := templating.ExecuteJP(preprocessor, payload, nil, templating.NewCompiler(templating.CompilerOptions{})) if err != nil { return nil, fmt.Errorf("failed to execute prepocessor (%s) - %w", preprocessor, err) } diff --git a/pkg/server/scan/handler.go b/pkg/server/scan/handler.go index 0af2af14c..faee8e76a 100644 --- a/pkg/server/scan/handler.go +++ b/pkg/server/scan/handler.go @@ -8,7 +8,7 @@ import ( "github.com/gin-gonic/gin" "github.com/kyverno/kyverno-json/pkg/apis/policy/v1alpha1" - "github.com/kyverno/kyverno-json/pkg/engine/template" + "github.com/kyverno/kyverno-json/pkg/core/templating" jsonengine "github.com/kyverno/kyverno-json/pkg/json-engine" "github.com/kyverno/kyverno-json/pkg/server/model" "github.com/loopfz/gadgeto/tonic" @@ -26,7 +26,7 @@ func newHandler(policyProvider PolicyProvider) (gin.HandlerFunc, error) { payload := in.Payload // apply pre processors for _, preprocessor := range in.Preprocessors { - result, err := template.ExecuteJP(context.Background(), preprocessor, payload, nil) + result, err := templating.ExecuteJP(preprocessor, payload, nil, templating.NewCompiler(templating.CompilerOptions{})) if err != nil { return nil, fmt.Errorf("failed to execute prepocessor (%s) - %w", preprocessor, err) } diff --git a/test/commands/scan/cel/out.txt b/test/commands/scan/cel/out.txt new file mode 100644 index 000000000..7592126be --- /dev/null +++ b/test/commands/scan/cel/out.txt @@ -0,0 +1,6 @@ +Loading policies ... +Loading payload ... +Pre processing ... +Running ( evaluating 1 resource against 1 policy ) ... +- test / foo-bar-4 / PASSED +Done diff --git a/test/commands/scan/cel/payload.yaml b/test/commands/scan/cel/payload.yaml new file mode 100644 index 000000000..a48ec4748 --- /dev/null +++ b/test/commands/scan/cel/payload.yaml @@ -0,0 +1,2 @@ +foo: + bar: 4 diff --git a/test/commands/scan/cel/policy.yaml b/test/commands/scan/cel/policy.yaml new file mode 100644 index 000000000..066382249 --- /dev/null +++ b/test/commands/scan/cel/policy.yaml @@ -0,0 +1,18 @@ +apiVersion: json.kyverno.io/v1alpha1 +kind: ValidatingPolicy +metadata: + name: test +spec: + rules: + - name: foo-bar-4 + context: + - name: celFoo + variable: (cel; 4) + - name: jpFoo + variable: (jp; $celFoo) + - name: celFoo + variable: (cel; bindings.resolve('jpFoo')) + assert: + all: + - check: + (cel; object.foo.bar): (cel; bindings.resolve('celFoo'))