Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Check for None values in branch nodes #592

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1
github.com/fatih/color v1.13.0
github.com/flyteorg/flyteidl v1.5.13
github.com/flyteorg/flyteidl v1.5.16
github.com/flyteorg/flyteplugins v1.1.30
github.com/flyteorg/flytestdlib v1.0.24
github.com/ghodss/yaml v1.0.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/flyteorg/flyteidl v1.5.13 h1:IQ2Cw+u36ew3BPyRDAcHdzc/GyNEOXOxhKy9jbS4hbo=
github.com/flyteorg/flyteidl v1.5.13/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og=
github.com/flyteorg/flyteidl v1.5.16 h1:S70wD7K99nKHZxmo8U16Jjhy1kZwoBh5ZQhZf3/6MPU=
github.com/flyteorg/flyteidl v1.5.16/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og=
github.com/flyteorg/flyteplugins v1.1.30 h1:AVqS6Eb9Nr9Z3Mb3CtP04ffAVS9LMx5Q1Z7AyFFk/e0=
github.com/flyteorg/flyteplugins v1.1.30/go.mod h1:FujFQdL/f9r1HvFR81JCiNYusDy9F0lExhyoyMHXXbg=
github.com/flyteorg/flytestdlib v1.0.24 h1:jDvymcjlsTRCwOtxPapro0WZBe3isTz+T3Tiq+mZUuk=
Expand Down
9 changes: 8 additions & 1 deletion pkg/compiler/validators/condition.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ func validateOperand(node c.NodeBuilder, paramName string, operand *flyte.Operan
} else if operand.GetPrimitive() != nil {
// no validation
literalType = literalTypeForPrimitive(operand.GetPrimitive())
} else if operand.GetScalar().GetPrimitive() != nil {
literalType = literalTypeForPrimitive(operand.GetPrimitive())
} else if operand.GetScalar().GetNoneType() != nil {
literalType = &flyte.LiteralType{Type: &flyte.LiteralType_Simple{Simple: flyte.SimpleType_NONE}}
} else if len(operand.GetVar()) > 0 {
if node.GetInterface() != nil {
if param, paramOk := validateInputVar(node, operand.GetVar(), requireParamType, errs.NewScope()); paramOk {
Expand All @@ -41,7 +45,10 @@ func ValidateBooleanExpression(w c.WorkflowBuilder, node c.NodeBuilder, expr *fl
expr.GetComparison().GetRightValue(), requireParamType, errs.NewScope())
op2Type, op2Valid := validateOperand(node, "LeftValue",
expr.GetComparison().GetLeftValue(), requireParamType, errs.NewScope())
if op1Valid && op2Valid && op1Type != nil && op2Type != nil {
// Valid expression
// 1. Both operands are primitive types and have the same types.
// 2. One of the operands is the None type.
if op1Valid && op2Valid && op1Type != nil && op2Type != nil && op1Type.GetSimple() != flyte.SimpleType_NONE && op2Type.GetSimple() != flyte.SimpleType_NONE {
if op1Type.String() != op2Type.String() {
errs.Collect(errors.NewMismatchingTypesErr(node.GetId(), "RightValue",
op1Type.String(), op2Type.String()))
Expand Down
56 changes: 34 additions & 22 deletions pkg/controller/nodes/branch/comparator.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,21 @@ var perTypeComparators = map[string]comparators{
},
}

func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) {
lValueType := reflect.TypeOf(lValue.Value)
rValueType := reflect.TypeOf(rValue.Value)
func Evaluate(lValue *core.Scalar, rValue *core.Scalar, op core.ComparisonExpression_Operator) (bool, error) {
if lValue.GetNoneType() != nil || rValue.GetNoneType() != nil {
lIsNone := lValue.GetNoneType() != nil
rIsNone := rValue.GetNoneType() != nil
switch op {
case core.ComparisonExpression_EQ:
return lIsNone == rIsNone, nil
case core.ComparisonExpression_NEQ:
return lIsNone != rIsNone, nil
default:
return false, errors.Errorf(ErrorCodeMalformedBranch, "Comparison between nil and non-nil values with operator [%v] is not supported. lVal[%v]:rVal[%v]", op, lValue, rValue)
}
}
lValueType := reflect.TypeOf(lValue.GetPrimitive().Value)
rValueType := reflect.TypeOf(rValue.GetPrimitive().Value)
if lValueType != rValueType {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Comparison between different primitives types. lVal[%v]:rVal[%v]", lValueType, rValueType)
}
Expand All @@ -90,50 +102,50 @@ func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.Comparison
if isBoolean {
return false, errors.Errorf(ErrorCodeMalformedBranch, "[GT] not defined for boolean operands.")
}
return comps.gt(lValue, rValue), nil
return comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()), nil
case core.ComparisonExpression_GTE:
if isBoolean {
return false, errors.Errorf(ErrorCodeMalformedBranch, "[GTE] not defined for boolean operands.")
}
return comps.eq(lValue, rValue) || comps.gt(lValue, rValue), nil
return comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive()) || comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()), nil
case core.ComparisonExpression_LT:
if isBoolean {
return false, errors.Errorf(ErrorCodeMalformedBranch, "[LT] not defined for boolean operands.")
}
return !(comps.gt(lValue, rValue) || comps.eq(lValue, rValue)), nil
return !(comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()) || comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive())), nil
case core.ComparisonExpression_LTE:
if isBoolean {
return false, errors.Errorf(ErrorCodeMalformedBranch, "[LTE] not defined for boolean operands.")
}
return !comps.gt(lValue, rValue), nil
return !comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()), nil
case core.ComparisonExpression_EQ:
return comps.eq(lValue, rValue), nil
return comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive()), nil
case core.ComparisonExpression_NEQ:
return !comps.eq(lValue, rValue), nil
return !comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive()), nil
}
return false, errors.Errorf(ErrorCodeMalformedBranch, "Unsupported operator type in Propeller. System error.")
}

func Evaluate1(lValue *core.Primitive, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) {
if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive.")
func Evaluate1(lValue *core.Scalar, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) {
if rValue.GetScalar() == nil || (rValue.GetScalar().GetPrimitive() == nil && rValue.GetScalar().GetNoneType() == nil) {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable [%v] is non primitive", rValue)
}
return Evaluate(lValue, rValue.GetScalar().GetPrimitive(), op)
return Evaluate(lValue, rValue.GetScalar(), op)
}

func Evaluate2(lValue *core.Literal, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) {
if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.")
func Evaluate2(lValue *core.Literal, rValue *core.Scalar, op core.ComparisonExpression_Operator) (bool, error) {
if lValue.GetScalar() == nil || (lValue.GetScalar().GetPrimitive() == nil && lValue.GetScalar().GetNoneType() == nil) {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable [%v] is non primitive.", lValue)
}
return Evaluate(lValue.GetScalar().GetPrimitive(), rValue, op)
return Evaluate(lValue.GetScalar(), rValue, op)
}

func EvaluateLiterals(lValue *core.Literal, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) {
if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.")
if lValue.GetScalar() == nil || (lValue.GetScalar().GetPrimitive() == nil && lValue.GetScalar().GetNoneType() == nil) {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable [%v] is non primitive.", lValue)
}
if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive")
if rValue.GetScalar() == nil || (rValue.GetScalar().GetPrimitive() == nil && rValue.GetScalar().GetNoneType() == nil) {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable [%v] is non primitive", rValue)
}
return Evaluate(lValue.GetScalar().GetPrimitive(), rValue.GetScalar().GetPrimitive(), op)
return Evaluate(lValue.GetScalar(), rValue.GetScalar(), op)
}
24 changes: 12 additions & 12 deletions pkg/controller/nodes/branch/comparator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
)

func TestEvaluate_int(t *testing.T) {
p1 := coreutils.MustMakePrimitive(1)
p2 := coreutils.MustMakePrimitive(2)
p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(1)}}
p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(2)}}
{
// p1 > p2 = false
b, err := Evaluate(p1, p2, core.ComparisonExpression_GT)
Expand Down Expand Up @@ -82,8 +82,8 @@ func TestEvaluate_int(t *testing.T) {
}

func TestEvaluate_float(t *testing.T) {
p1 := coreutils.MustMakePrimitive(1.0)
p2 := coreutils.MustMakePrimitive(2.0)
p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(1)}}
p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(2)}}
{
// p1 > p2 = false
b, err := Evaluate(p1, p2, core.ComparisonExpression_GT)
Expand Down Expand Up @@ -153,8 +153,8 @@ func TestEvaluate_float(t *testing.T) {
}

func TestEvaluate_string(t *testing.T) {
p1 := coreutils.MustMakePrimitive("a")
p2 := coreutils.MustMakePrimitive("b")
p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive("a")}}
p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive("b")}}
{
// p1 > p2 = false
b, err := Evaluate(p1, p2, core.ComparisonExpression_GT)
Expand Down Expand Up @@ -224,8 +224,8 @@ func TestEvaluate_string(t *testing.T) {
}

func TestEvaluate_datetime(t *testing.T) {
p1 := coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 00, 00, time.UTC))
p2 := coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 01, 00, time.UTC))
p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 00, 00, time.UTC))}}
p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 01, 00, time.UTC))}}
{
// p1 > p2 = false
b, err := Evaluate(p1, p2, core.ComparisonExpression_GT)
Expand Down Expand Up @@ -295,8 +295,8 @@ func TestEvaluate_datetime(t *testing.T) {
}

func TestEvaluate_duration(t *testing.T) {
p1 := coreutils.MustMakePrimitive(10 * time.Second)
p2 := coreutils.MustMakePrimitive(11 * time.Second)
p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(10 * time.Second)}}
p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(11 * time.Second)}}
{
// p1 > p2 = false
b, err := Evaluate(p1, p2, core.ComparisonExpression_GT)
Expand Down Expand Up @@ -366,8 +366,8 @@ func TestEvaluate_duration(t *testing.T) {
}

func TestEvaluate_boolean(t *testing.T) {
p1 := coreutils.MustMakePrimitive(true)
p2 := coreutils.MustMakePrimitive(false)
p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(true)}}
p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(false)}}
f := func(op core.ComparisonExpression_Operator) {
// GT/LT = false
msg := fmt.Sprintf("Evaluating: [%s]", op.String())
Expand Down
42 changes: 32 additions & 10 deletions pkg/controller/nodes/branch/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,53 @@ const ErrorCodeFailedFetchOutputs = "FailedFetchOutputs"
func EvaluateComparison(expr *core.ComparisonExpression, nodeInputs *core.LiteralMap) (bool, error) {
var lValue *core.Literal
var rValue *core.Literal
var lPrim *core.Primitive
var rPrim *core.Primitive
var lPrim *core.Scalar
var rPrim *core.Scalar

if expr.GetLeftValue().GetPrimitive() == nil {
if nodeInputs == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
if expr.GetLeftValue().GetScalar().GetNoneType() != nil {
lValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: expr.GetLeftValue().GetScalar()}}
} else if expr.GetLeftValue().GetScalar().GetUnion() != nil {
lValue = expr.GetLeftValue().GetScalar().GetUnion().GetValue()
} else {
if nodeInputs == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
}
input := nodeInputs.Literals[expr.GetLeftValue().GetVar()]
if input.GetScalar().GetUnion().GetValue() != nil {
lValue = input.GetScalar().GetUnion().GetValue()
} else {
lValue = input
}
}
lValue = nodeInputs.Literals[expr.GetLeftValue().GetVar()]
if lValue == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
}
} else {
lPrim = expr.GetLeftValue().GetPrimitive()
lPrim = &core.Scalar{Value: &core.Scalar_Primitive{Primitive: expr.GetLeftValue().GetPrimitive()}}
}

if expr.GetRightValue().GetPrimitive() == nil {
if nodeInputs == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
if expr.GetRightValue().GetScalar().GetNoneType() != nil {
rValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: expr.GetRightValue().GetScalar()}}
} else if expr.GetRightValue().GetScalar().GetUnion() != nil {
rValue = expr.GetRightValue().GetScalar().GetUnion().GetValue()
} else {
if nodeInputs == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
}
input := nodeInputs.Literals[expr.GetRightValue().GetVar()]
if input.GetScalar().GetUnion().GetValue() != nil {
rValue = input.GetScalar().GetUnion().GetValue()
} else {
rValue = input
}
}
rValue = nodeInputs.Literals[expr.GetRightValue().GetVar()]
if rValue == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetRightValue().GetVar())
}
} else {
rPrim = expr.GetRightValue().GetPrimitive()
rPrim = &core.Scalar{Value: &core.Scalar_Primitive{Primitive: expr.GetRightValue().GetPrimitive()}}
}

if lValue != nil && rValue != nil {
Expand Down
84 changes: 84 additions & 0 deletions pkg/controller/nodes/branch/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ func createUnaryConjunction(l *core.ComparisonExpression, op core.ConjunctionExp
}
}

func getNoneOperand() *core.Operand {
return &core.Operand{
Val: &core.Operand_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_NoneType{NoneType: &core.Void{}},
},
},
}
}

func TestEvaluateComparison(t *testing.T) {
t.Run("ComparePrimitives", func(t *testing.T) {
// Compare primitives
Expand Down Expand Up @@ -100,6 +110,80 @@ func TestEvaluateComparison(t *testing.T) {
assert.NoError(t, err)
assert.False(t, v)
})
t.Run("CompareNoneAndLiteral", func(t *testing.T) {
// Compare lVal -> None and rVal -> literal
exp := &core.ComparisonExpression{
LeftValue: getNoneOperand(),
Operator: core.ComparisonExpression_EQ,
RightValue: &core.Operand{
Val: &core.Operand_Primitive{
Primitive: coreutils.MustMakePrimitive(1),
},
},
}
v, err := EvaluateComparison(exp, nil)
assert.NoError(t, err)
assert.False(t, v)
})
t.Run("CompareLiteralAndNone", func(t *testing.T) {
// Compare lVal -> literal and rVal -> None
exp := &core.ComparisonExpression{
LeftValue: &core.Operand{
Val: &core.Operand_Primitive{
Primitive: coreutils.MustMakePrimitive(1),
},
},
Operator: core.ComparisonExpression_NEQ,
RightValue: getNoneOperand(),
}
v, err := EvaluateComparison(exp, nil)
assert.NoError(t, err)
assert.True(t, v)
})
t.Run("CompareUnionLiteralAndNone", func(t *testing.T) {
// Compare lVal -> literal and rVal -> None
exp := &core.ComparisonExpression{
LeftValue: &core.Operand{
Val: &core.Operand_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Union{
Union: &core.Union{
Value: &core.Literal{
Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(1)}}},
},
},
},
},
},
},
Operator: core.ComparisonExpression_NEQ,
RightValue: getNoneOperand(),
}
v, err := EvaluateComparison(exp, nil)
assert.NoError(t, err)
assert.True(t, v)
})
t.Run("CompareNoneAndNone", func(t *testing.T) {
// Compare lVal -> None and rVal -> None
exp := &core.ComparisonExpression{
LeftValue: getNoneOperand(),
Operator: core.ComparisonExpression_EQ,
RightValue: getNoneOperand(),
}
v, err := EvaluateComparison(exp, nil)
assert.NoError(t, err)
assert.True(t, v)
})
t.Run("CompareNoneAndNoneWithError", func(t *testing.T) {
// Compare lVal -> None and rVal -> None
exp := &core.ComparisonExpression{
LeftValue: getNoneOperand(),
Operator: core.ComparisonExpression_GTE,
RightValue: getNoneOperand(),
}
_, err := EvaluateComparison(exp, nil)
assert.Error(t, err)
})
t.Run("CompareLiteralAndPrimitive", func(t *testing.T) {

// Compare lVal -> literal and rVal -> primitive
Expand Down
Loading