From 4bb31380423a803889b4c4732a6114a1c8309f4c Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Mon, 3 Apr 2023 15:46:18 -0500
Subject: [PATCH 01/62] updated flyteidl to local to get ArrayNode

Signed-off-by: Daniel Rammer
---
 go.mod | 2 ++
 go.sum | 2 --
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/go.mod b/go.mod
index ca9d98ee2..9e6362dab 100644
--- a/go.mod
+++ b/go.mod
@@ -147,3 +147,5 @@ require (
 )
 
 replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d
+
+replace github.com/flyteorg/flyteidl => ../flyteidl
diff --git a/go.sum b/go.sum
index f84741c20..55a440130 100644
--- a/go.sum
+++ b/go.sum
@@ -260,8 +260,6 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
 github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
 github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
 github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
-github.com/flyteorg/flyteidl v1.3.14 h1:o5M0g/r6pXTPu5PEurbYxbQmuOu3hqqsaI2M6uvK0N8=
-github.com/flyteorg/flyteidl v1.3.14/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM=
 github.com/flyteorg/flyteplugins v1.0.45 h1:I/N4ehOxX6ln8DivyZ9gayp/UYiBcqoizBbG1hfwIXM=
 github.com/flyteorg/flyteplugins v1.0.45/go.mod h1:ztsonku5fKwyxcIg1k69PTiBVjRI6d3nK5DnC+iwx08=
 github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0=

From 33deb420d2e06ef9a8aa752af6332d5499a5afa1 Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Mon, 3 Apr 2023 16:25:17 -0500
Subject: [PATCH 02/62] added boilerplate to support ArrayNode

Signed-off-by: Daniel Rammer
---
 pkg/apis/flyteworkflow/v1alpha1/array.go      |  8 ++
 pkg/apis/flyteworkflow/v1alpha1/iface.go      | 20 +++++
 .../flyteworkflow/v1alpha1/node_status.go     | 48 ++++++++++++
 pkg/apis/flyteworkflow/v1alpha1/nodes.go      |  8 ++
 pkg/compiler/common/reader.go                 |  1 +
 pkg/compiler/transformers/k8s/node.go         |  2 +
 pkg/compiler/validators/interface.go          |  2 +
 pkg/controller/nodes/array/handler.go         | 76 +++++++++++++++++++
 pkg/controller/nodes/gate/handler.go          |  2 +-
 pkg/controller/nodes/handler/state.go         |  6 ++
 .../nodes/handler/transition_info.go          |  4 +
 pkg/controller/nodes/handler_factory.go       | 19 +++--
 pkg/controller/nodes/node_state_manager.go    | 16 ++++
 pkg/controller/nodes/transformers.go          |  6 ++
 14 files changed, 207 insertions(+), 11 deletions(-)
 create mode 100644 pkg/apis/flyteworkflow/v1alpha1/array.go
 create mode 100644 pkg/controller/nodes/array/handler.go

diff --git a/pkg/apis/flyteworkflow/v1alpha1/array.go b/pkg/apis/flyteworkflow/v1alpha1/array.go
new file mode 100644
index 000000000..91e86bd21
--- /dev/null
+++ b/pkg/apis/flyteworkflow/v1alpha1/array.go
@@ -0,0 +1,8 @@
+package v1alpha1
+
+import (
+)
+
+type ArrayNodeSpec struct {
+	// TODO @hamersaw - fill out evaluation
+}
diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go
index c52361b23..f57c39593 100644
--- a/pkg/apis/flyteworkflow/v1alpha1/iface.go
+++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go
@@ -45,6 +45,7 @@ const (
 	NodeKindBranch   NodeKind = "branch"   // A Branch node with conditions
 	NodeKindWorkflow NodeKind = "workflow" // Either an inline workflow or a remote workflow definition
 	NodeKindGate     NodeKind = "gate"     // A Gate node with a condition
+	NodeKindArray    NodeKind = "array"    // An array node with a subtask Node
 	NodeKindStart    NodeKind = "start"    // Start node is a special node
 	NodeKindEnd      NodeKind = "end"
 )
@@ -253,6 +254,10 @@ type ExecutableGateNode interface {
 	GetSleep() *core.SleepCondition
 }
 
+type ExecutableArrayNode interface {
+	// TODO @hamersaw - complete ExecutableArrayNode
+}
+
 type ExecutableWorkflowNodeStatus interface {
 	GetWorkflowNodePhase() WorkflowNodePhase
 	GetExecutionError() *core.ExecutionError
@@ -275,6 +280,16 @@ type MutableGateNodeStatus interface {
 	SetGateNodePhase(phase GateNodePhase)
 }
 
+type ExecutableArrayNodeStatus interface {
+	GetArrayNodePhase() ArrayNodePhase
+}
+
+type MutableArrayNodeStatus interface {
+	Mutable
+	ExecutableArrayNodeStatus
+	SetArrayNodePhase(phase ArrayNodePhase)
+}
+
 type Mutable interface {
 	IsDirty() bool
 }
@@ -310,6 +325,10 @@ type MutableNodeStatus interface {
 	GetGateNodeStatus() MutableGateNodeStatus
 	GetOrCreateGateNodeStatus() MutableGateNodeStatus
 	ClearGateNodeStatus()
+
+	GetArrayNodeStatus() MutableArrayNodeStatus
+	GetOrCreateArrayNodeStatus() MutableArrayNodeStatus
+	ClearArrayNodeStatus()
 }
 
 type ExecutionTimeInfo interface {
@@ -393,6 +412,7 @@ type ExecutableNode interface {
 	GetBranchNode() ExecutableBranchNode
 	GetWorkflowNode() ExecutableWorkflowNode
 	GetGateNode() ExecutableGateNode
+	GetArrayNode() ExecutableArrayNode
 	GetOutputAlias() []Alias
 	GetInputBindings() []*Binding
 	GetResources() *v1.ResourceRequirements
diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go
index 7aea3f2b8..758a72206 100644
--- a/pkg/apis/flyteworkflow/v1alpha1/node_status.go
+++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go
@@ -207,6 +207,30 @@ func (in *GateNodeStatus) SetGateNodePhase(phase GateNodePhase) {
 	}
 }
 
+type ArrayNodePhase int
+
+const (
+	ArrayNodePhaseUndefined ArrayNodePhase = iota
+	ArrayNodePhaseExecuting
+	// TODO @hamersaw - need more phases
+)
+
+type ArrayNodeStatus struct {
+	MutableStruct
+	Phase ArrayNodePhase `json:"phase,omitempty"`
+}
+
+func (in *ArrayNodeStatus) GetArrayNodePhase() ArrayNodePhase {
+	return in.Phase
+}
+
+func (in *ArrayNodeStatus) SetArrayNodePhase(phase ArrayNodePhase) {
+	if in.Phase != phase {
+		in.SetDirty()
+		in.Phase = phase
+	}
+}
+
 type NodeStatus struct {
 	MutableStruct
 	Phase NodePhase `json:"phase,omitempty"`
@@ -235,6 +259,7 @@ type NodeStatus struct {
 	TaskNodeStatus    *TaskNodeStatus    `json:",omitempty"`
 	DynamicNodeStatus *DynamicNodeStatus `json:"dynamicNodeStatus,omitempty"`
 	GateNodeStatus    *GateNodeStatus    `json:"gateNodeStatus,omitempty"`
+	ArrayNodeStatus   *ArrayNodeStatus   `json:"arrayNodeStatus,omitempty"`
 
 	// In case of Failing/Failed Phase, an execution error can be optionally associated with the Node
 	Error *ExecutionError `json:"error,omitempty"`
@@ -315,6 +340,13 @@ func (in *NodeStatus) GetGateNodeStatus() MutableGateNodeStatus {
 	return in.GateNodeStatus
 }
 
+func (in *NodeStatus) GetArrayNodeStatus() MutableArrayNodeStatus {
+	if in.ArrayNodeStatus == nil {
+		return nil
+	}
+	return in.ArrayNodeStatus
+}
+
 func (in NodeStatus) VisitNodeStatuses(visitor NodeStatusVisitFn) {
 	for n, s := range in.SubNodeStatus {
 		visitor(n, s)
@@ -353,6 +385,11 @@ func (in *NodeStatus) ClearGateNodeStatus() {
 	in.SetDirty()
 }
 
+func (in *NodeStatus) ClearArrayNodeStatus() {
+	in.ArrayNodeStatus = nil
+	in.SetDirty()
+}
+
 func (in *NodeStatus) GetLastUpdatedAt() *metav1.Time {
 	return in.LastUpdatedAt
 }
@@ -459,6 +496,17 @@ func (in *NodeStatus) GetOrCreateGateNodeStatus() MutableGateNodeStatus {
 	return in.GateNodeStatus
 }
 
+func (in *NodeStatus) GetOrCreateArrayNodeStatus() MutableArrayNodeStatus {
+	if in.ArrayNodeStatus == nil {
+		in.SetDirty()
+		in.ArrayNodeStatus = &ArrayNodeStatus{
+			MutableStruct: MutableStruct{},
+		}
+	}
+
+	return in.ArrayNodeStatus
+}
+
 func (in *NodeStatus) UpdatePhase(p NodePhase, occurredAt metav1.Time, reason string, err *core.ExecutionError) {
 	if in.Phase == p {
 		// We will not update the phase multiple times. This prevents the comparison from returning false positive
diff --git a/pkg/apis/flyteworkflow/v1alpha1/nodes.go b/pkg/apis/flyteworkflow/v1alpha1/nodes.go
index 682af365d..21c8b0261 100644
--- a/pkg/apis/flyteworkflow/v1alpha1/nodes.go
+++ b/pkg/apis/flyteworkflow/v1alpha1/nodes.go
@@ -101,6 +101,7 @@ type NodeSpec struct {
 	TaskRef       *TaskID            `json:"task,omitempty"`
 	WorkflowNode  *WorkflowNodeSpec  `json:"workflow,omitempty"`
 	GateNode      *GateNodeSpec      `json:"gate,omitempty"`
+	ArrayNode     *ArrayNodeSpec     `json:"array,omitempty"`
 	InputBindings []*Binding         `json:"inputBindings,omitempty"`
 	Config        *typesv1.ConfigMap `json:"config,omitempty"`
 	RetryStrategy *RetryStrategy     `json:"retry,omitempty"`
@@ -206,6 +207,13 @@ func (in *NodeSpec) GetGateNode() ExecutableGateNode {
 	return in.GateNode
 }
 
+func (in *NodeSpec) GetArrayNode() ExecutableArrayNode {
+	if in.ArrayNode == nil {
+		return nil
+	}
+	return in.ArrayNode
+}
+
 func (in *NodeSpec) GetTaskID() *TaskID {
 	return in.TaskRef
 }
diff --git a/pkg/compiler/common/reader.go b/pkg/compiler/common/reader.go
index d0ea36172..74ee17d40 100644
--- a/pkg/compiler/common/reader.go
+++ b/pkg/compiler/common/reader.go
@@ -41,6 +41,7 @@ type Node interface {
 	GetTask() Task
 	GetSubWorkflow() Workflow
 	GetGateNode() *core.GateNode
+	GetArrayNode() *core.ArrayNode
 }
 
 // An immutable task that represents the final output of the compiler.
diff --git a/pkg/compiler/transformers/k8s/node.go b/pkg/compiler/transformers/k8s/node.go
index f8d5947c9..1461be881 100644
--- a/pkg/compiler/transformers/k8s/node.go
+++ b/pkg/compiler/transformers/k8s/node.go
@@ -154,6 +154,8 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile
 			},
 		}
 	}
+	case *core.Node_ArrayNode:
+		// TODO @hamersaw - complete
 	default:
 		if n.GetId() == v1alpha1.StartNodeID {
 			nodeSpec.Kind = v1alpha1.NodeKindStart
diff --git a/pkg/compiler/validators/interface.go b/pkg/compiler/validators/interface.go
index 1ae7ecd5b..f5a11345d 100644
--- a/pkg/compiler/validators/interface.go
+++ b/pkg/compiler/validators/interface.go
@@ -153,6 +153,8 @@ func ValidateUnderlyingInterface(w c.WorkflowBuilder, node c.NodeBuilder, errs e
 		} else {
 			errs.Collect(errors.NewNoConditionFound(node.GetId()))
 		}
+	case *core.Node_ArrayNode:
+		// TODO @hamersaw complete
 	default:
 		errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Target"))
 	}
diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go
new file mode 100644
index 000000000..1ec0a9dfb
--- /dev/null
+++ b/pkg/controller/nodes/array/handler.go
@@ -0,0 +1,76 @@
+package array
+
+import (
+	"context"
+
+	"github.com/flyteorg/flytepropeller/pkg/controller/config"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
+
+	"github.com/flyteorg/flytestdlib/logger"
+	"github.com/flyteorg/flytestdlib/promutils"
+)
+
+//go:generate mockery -all -case=underscore
+
+// arrayNodeHandler is a handler implementation for processing array nodes
+type arrayNodeHandler struct {
+	metrics metrics
+}
+
+// metrics encapsulates the prometheus metrics for this handler
+type metrics struct {
+	scope promutils.Scope
+}
+
+// newMetrics initializes a new metrics struct
+func newMetrics(scope promutils.Scope) metrics {
+	return metrics{
+		scope: scope,
+	}
+}
+
+// Abort stops the array node defined in the NodeExecutionContext
+func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error {
+	return nil // TODO @hamersaw - implement abort
+}
+
+// Finalize completes the array node defined in the NodeExecutionContext
+func (a *arrayNodeHandler) Finalize(ctx context.Context, _ handler.NodeExecutionContext) error {
+	return nil // TODO @hamersaw - implement finalize
+}
+
+// FinalizeRequired defines whether or not this handler requires finalize to be called on
+// node completion
+func (a *arrayNodeHandler) FinalizeRequired() bool {
+	return false // TODO @hamersaw - implement finalize required
+}
+
+// Handle is responsible for transitioning and reporting node state to complete the node defined
+// by the NodeExecutionContext
+func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) {
+	arrayNode := nCtx.Node().GetArrayNode()
+	arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState()
+
+	// TODO @hamersaw - handle array node
+
+	// update array node status
+	if err := nCtx.NodeStateWriter().PutArrayNodeState(arrayNodeState); err != nil {
+		logger.Errorf(ctx, "failed to store ArrayNode state with err [%s]", err.Error())
+		return handler.UnknownTransition, err
+	}
+
+	return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), nil
+}
+
+// Setup handles any initialization requirements for this handler
+func (a *arrayNodeHandler) Setup(_ context.Context, _ handler.SetupContext) error {
+	return nil // TODO @hamersaw - implement setup
+}
+
+// New initializes a new arrayNodeHandler
+func New(eventConfig *config.EventConfig, scope promutils.Scope) handler.Node {
+	arrayScope := scope.NewSubScope("array")
+	return &arrayNodeHandler{
+		metrics: newMetrics(arrayScope),
+	}
+}
diff --git a/pkg/controller/nodes/gate/handler.go b/pkg/controller/nodes/gate/handler.go
index 31e21c4dc..f51c92a54 100644
--- a/pkg/controller/nodes/gate/handler.go
+++ b/pkg/controller/nodes/gate/handler.go
@@ -197,7 +197,7 @@ func (g *gateNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecution
 
 	// update gate node status
 	if err := nCtx.NodeStateWriter().PutGateNodeState(gateNodeState); err != nil {
-		logger.Errorf(ctx, "failed to store TaskNode state with err [%s]", err.Error())
+		logger.Errorf(ctx, "failed to store GateNode state with err [%s]", err.Error())
 		return handler.UnknownTransition, err
 	}
 
diff --git a/pkg/controller/nodes/handler/state.go b/pkg/controller/nodes/handler/state.go
index d40459099..91622f054 100644
--- a/pkg/controller/nodes/handler/state.go
+++ b/pkg/controller/nodes/handler/state.go
@@ -46,12 +46,17 @@ type GateNodeState struct {
 	StartedAt time.Time
 }
 
+type ArrayNodeState struct {
+	Phase v1alpha1.ArrayNodePhase
+}
+
 type NodeStateWriter interface {
 	PutTaskNodeState(s TaskNodeState) error
 	PutBranchNode(s BranchNodeState) error
 	PutDynamicNodeState(s DynamicNodeState) error
 	PutWorkflowNodeState(s WorkflowNodeState) error
 	PutGateNodeState(s GateNodeState) error
+	PutArrayNodeState(s ArrayNodeState) error
 }
 
 type NodeStateReader interface {
@@ -60,4 +65,5 @@ type NodeStateReader interface {
 	GetDynamicNodeState() DynamicNodeState
 	GetWorkflowNodeState() WorkflowNodeState
 	GetGateNodeState() GateNodeState
+	GetArrayNodeState() ArrayNodeState
 }
diff --git a/pkg/controller/nodes/handler/transition_info.go b/pkg/controller/nodes/handler/transition_info.go
index 5d302f4fa..0cce41ef4 100644
--- a/pkg/controller/nodes/handler/transition_info.go
+++ b/pkg/controller/nodes/handler/transition_info.go
@@ -52,6 +52,9 @@ type TaskNodeInfo struct {
 type GateNodeInfo struct {
 }
 
+type ArrayNodeInfo struct {
+}
+
 type OutputInfo struct {
 	OutputURI storage.DataReference
 	DeckURI   *storage.DataReference
@@ -65,6 +68,7 @@ type ExecutionInfo struct {
 	OutputInfo    *OutputInfo
 	TaskNodeInfo  *TaskNodeInfo
 	GateNodeInfo  *GateNodeInfo
+	ArrayNodeInfo *ArrayNodeInfo
 }
 
 type PhaseInfo struct {
diff --git a/pkg/controller/nodes/handler_factory.go b/pkg/controller/nodes/handler_factory.go
index e13143e6b..0bfa71d4a 100644
--- a/pkg/controller/nodes/handler_factory.go
+++ b/pkg/controller/nodes/handler_factory.go
@@ -5,28 +5,26 @@ import (
 
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
 
-	"github.com/flyteorg/flytepropeller/pkg/controller/config"
-
-	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery"
-
 	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog"
 
-	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/dynamic"
-
-	"github.com/flyteorg/flytestdlib/promutils"
-
-	"github.com/pkg/errors"
-
 	"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
+	"github.com/flyteorg/flytepropeller/pkg/controller/config"
 	"github.com/flyteorg/flytepropeller/pkg/controller/executors"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/array"
 	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/branch"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/dynamic"
 	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/end"
 	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate"
 	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery"
 	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/start"
 	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow"
 	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan"
 	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/task"
+
+	"github.com/flyteorg/flytestdlib/promutils"
+
+	"github.com/pkg/errors"
 )
 
 //go:generate mockery -name HandlerFactory -case=underscore
@@ -72,6 +70,7 @@ func NewHandlerFactory(ctx context.Context, executor executors.Node, workflowLau
 		v1alpha1.NodeKindTask:     dynamic.New(t, executor, launchPlanReader, eventConfig, scope),
 		v1alpha1.NodeKindWorkflow: subworkflow.New(executor, workflowLauncher, recoveryClient, eventConfig, scope),
 		v1alpha1.NodeKindGate:     gate.New(eventConfig, signalClient, scope),
+		v1alpha1.NodeKindArray:    array.New(scope),
 		v1alpha1.NodeKindStart:    start.New(),
 		v1alpha1.NodeKindEnd:      end.New(),
 	},
diff --git a/pkg/controller/nodes/node_state_manager.go b/pkg/controller/nodes/node_state_manager.go
index 73baf4dda..cd050cc5b 100644
--- a/pkg/controller/nodes/node_state_manager.go
+++ b/pkg/controller/nodes/node_state_manager.go
@@ -16,6 +16,7 @@ type nodeStateManager struct {
 	d *handler.DynamicNodeState
 	w *handler.WorkflowNodeState
 	g *handler.GateNodeState
+	a *handler.ArrayNodeState
 }
 
 func (n *nodeStateManager) PutTaskNodeState(s handler.TaskNodeState) error {
@@ -43,6 +44,11 @@ func (n *nodeStateManager) PutGateNodeState(s handler.GateNodeState) error {
 	return nil
 }
 
+func (n *nodeStateManager) PutArrayNodeState(s handler.ArrayNodeState) error {
+	n.a = &s
+	return nil
+}
+
 func (n nodeStateManager) GetTaskNodeState() handler.TaskNodeState {
 	tn := n.nodeStatus.GetTaskNodeStatus()
 	if tn != nil {
@@ -100,12 +106,22 @@ func (n nodeStateManager) GetGateNodeState() handler.GateNodeState {
 	return gs
 }
 
+func (n nodeStateManager) GetArrayNodeState() handler.ArrayNodeState {
+	an := n.nodeStatus.GetArrayNodeStatus()
+	as := handler.ArrayNodeState{}
+	if an != nil {
+		as.Phase = an.GetArrayNodePhase()
+	}
+	return as
+}
+
 func (n *nodeStateManager) clearNodeStatus() {
 	n.t = nil
 	n.b = nil
 	n.d = nil
 	n.w = nil
 	n.g = nil
+	n.a = nil
 	n.nodeStatus.ClearLastAttemptStartedAt()
 }
 
diff --git a/pkg/controller/nodes/transformers.go b/pkg/controller/nodes/transformers.go
index d34fc7a7a..c1b3952d4 100644
--- a/pkg/controller/nodes/transformers.go
+++ b/pkg/controller/nodes/transformers.go
@@ -275,4 +275,10 @@ func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n *nodeStateMa
 		t := s.GetOrCreateGateNodeStatus()
 		t.SetGateNodePhase(n.g.Phase)
 	}
+
+	// Update array node status
+	if n.a != nil {
+		t := s.GetOrCreateArrayNodeStatus()
+		t.SetArrayNodePhase(n.a.Phase)
+	}
 }
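Note: the boilerplate in PATCH 02 follows the same pattern every propeller node kind uses. Durable state lives on the CRD status (ArrayNodeStatus), a transient copy flows through the handler via NodeStateReader/NodeStateWriter, and UpdateNodeStatus copies it back at the end of the evaluation loop. Below is a minimal sketch of that read-modify-write contract, using only the interfaces this patch introduces; the phase-advance condition is illustrative, since the real Handle logic is still a TODO here.

    package example

    import (
    	"context"

    	"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
    	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
    )

    // handleArrayNode illustrates the contract: read the state decoded from the
    // FlyteWorkflow CRD, mutate the in-memory copy, and hand it back to the
    // writer so the executor persists it on the next status update.
    func handleArrayNode(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) {
    	state := nCtx.NodeStateReader().GetArrayNodeState()

    	// illustrative phase advance; the patch leaves the real conditions as a TODO
    	if state.Phase == v1alpha1.ArrayNodePhaseUndefined {
    		state.Phase = v1alpha1.ArrayNodePhaseExecuting
    	}

    	// state only reaches the CRD through the writer; skipping this call loses progress
    	if err := nCtx.NodeStateWriter().PutArrayNodeState(state); err != nil {
    		return handler.UnknownTransition, err
    	}

    	return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), nil
    }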
From bf0a1ffd83e891010a6e1abeb165b6ef191f7a21 Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Thu, 6 Apr 2023 10:14:58 -0500
Subject: [PATCH 03/62] pushing forward

Signed-off-by: Daniel Rammer
---
 pkg/apis/flyteworkflow/v1alpha1/iface.go      |   3 +
 .../flyteworkflow/v1alpha1/node_status.go     |  22 ++-
 pkg/compiler/transformers/k8s/node.go         |   5 +
 pkg/controller/executors/node.go              |   4 +
 pkg/controller/executors/node_lookup.go       |   1 +
 .../nodes/array/execution_context.go          |  38 +++++
 pkg/controller/nodes/array/handler.go         | 133 +++++++++++++++++-
 pkg/controller/nodes/array/node_lookup.go     |  40 ++++++
 pkg/controller/nodes/executor.go              |   8 +-
 pkg/controller/nodes/handler/state.go         |   7 +-
 pkg/controller/nodes/handler_factory.go       |   2 +-
 pkg/controller/nodes/node_exec_context.go     |   4 +-
 pkg/controller/nodes/node_state_manager.go    |   1 +
 pkg/controller/nodes/transformers.go          |   1 +
 14 files changed, 252 insertions(+), 17 deletions(-)
 create mode 100644 pkg/controller/nodes/array/execution_context.go
 create mode 100644 pkg/controller/nodes/array/node_lookup.go

diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go
index f57c39593..35feb2644 100644
--- a/pkg/apis/flyteworkflow/v1alpha1/iface.go
+++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go
@@ -13,6 +13,7 @@ import (
 	"k8s.io/apimachinery/pkg/types"
 
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
+	"github.com/flyteorg/flytestdlib/bitarray"
 	"github.com/flyteorg/flytestdlib/storage"
 )
 
@@ -282,12 +283,14 @@ type MutableGateNodeStatus interface {
 
 type ExecutableArrayNodeStatus interface {
 	GetArrayNodePhase() ArrayNodePhase
+	GetSubNodePhases() bitarray.CompactArray
 }
 
 type MutableArrayNodeStatus interface {
 	Mutable
 	ExecutableArrayNodeStatus
 	SetArrayNodePhase(phase ArrayNodePhase)
+	SetSubNodePhases(subNodePhases bitarray.CompactArray)
 }
diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go
index 758a72206..c70821b97 100644
--- a/pkg/apis/flyteworkflow/v1alpha1/node_status.go
+++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go
@@ -9,11 +9,11 @@ import (
 	"strconv"
 	"time"
 
+	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
+	"github.com/flyteorg/flytestdlib/bitarray"
 	"github.com/flyteorg/flytestdlib/storage"
-
 	"github.com/flyteorg/flytestdlib/logger"
-
-	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 )
 
@@ -210,14 +210,17 @@ func (in *GateNodeStatus) SetGateNodePhase(phase GateNodePhase) {
 type ArrayNodePhase int
 
 const (
-	ArrayNodePhaseUndefined ArrayNodePhase = iota
+	ArrayNodePhaseNone ArrayNodePhase = iota
 	ArrayNodePhaseExecuting
+	ArrayNodePhaseFailing
+	ArrayNodePhaseSucceeding
 	// TODO @hamersaw - need more phases
 )
 
 type ArrayNodeStatus struct {
 	MutableStruct
-	Phase ArrayNodePhase `json:"phase,omitempty"`
+	Phase         ArrayNodePhase        `json:"phase,omitempty"`
+	SubNodePhases bitarray.CompactArray `json:"subphase,omitempty"`
 }
 
 func (in *ArrayNodeStatus) GetArrayNodePhase() ArrayNodePhase {
@@ -231,6 +234,17 @@ func (in *ArrayNodeStatus) SetArrayNodePhase(phase ArrayNodePhase) {
 	}
 }
 
+func (in *ArrayNodeStatus) GetSubNodePhases() bitarray.CompactArray {
+	return in.SubNodePhases
+}
+
+func (in *ArrayNodeStatus) SetSubNodePhases(subNodePhases bitarray.CompactArray) {
+	if in.SubNodePhases != subNodePhases {
+		in.SetDirty()
+		in.SubNodePhases = subNodePhases
+	}
+}
+
 type NodeStatus struct {
 	MutableStruct
 	Phase NodePhase `json:"phase,omitempty"`
diff --git a/pkg/compiler/transformers/k8s/node.go b/pkg/compiler/transformers/k8s/node.go
index 1461be881..fbf7cc321 100644
--- a/pkg/compiler/transformers/k8s/node.go
+++ b/pkg/compiler/transformers/k8s/node.go
@@ -156,6 +156,11 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile
 	}
 	case *core.Node_ArrayNode:
 		// TODO @hamersaw - complete
+		nodeSpec.Kind = v1alpha1.NodeKindArray
+		nodeSpec.ArrayNode = &v1alpha1.ArrayNodeSpec{
+		}
+		//arrayNode := n.GetArrayNode()
+
 	default:
 		if n.GetId() == v1alpha1.StartNodeID {
 			nodeSpec.Kind = v1alpha1.NodeKindStart
diff --git a/pkg/controller/executors/node.go b/pkg/controller/executors/node.go
index a8f738c3e..3fbe9015c 100644
--- a/pkg/controller/executors/node.go
+++ b/pkg/controller/executors/node.go
@@ -7,6 +7,7 @@ import (
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
 
 	"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
+	//"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
 )
 
 //go:generate mockery -all -case=underscore
@@ -81,6 +82,9 @@ type Node interface {
 	FinalizeHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode) error
 
+	// TODO @docs
+	//NewNodeExecContext(ctx context.Context, executionContext ExecutionContext, nl NodeLookup, currentNodeID v1alpha1.NodeID) (handler.NodeExecutionContext, error)
+
 	// This method should be used to initialize Node executor
 	Initialize(ctx context.Context) error
 }
diff --git a/pkg/controller/executors/node_lookup.go b/pkg/controller/executors/node_lookup.go
index 381b832c0..66fc9bddf 100644
--- a/pkg/controller/executors/node_lookup.go
+++ b/pkg/controller/executors/node_lookup.go
@@ -12,6 +12,7 @@ type NodeLookup interface {
 	GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool)
 	GetNodeExecutionStatus(ctx context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus
+
 	// Lookup for upstream edges, find all node ids from which this node can be reached.
 	ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error)
 	// Lookup for downstream edges, find all node ids that can be reached from the given node id.
diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go
new file mode 100644
index 000000000..5893442e8
--- /dev/null
+++ b/pkg/controller/nodes/array/execution_context.go
@@ -0,0 +1,38 @@
+package array
+
+import (
+	//"github.com/flyteorg/flytepropeller/pkg/controller/executors"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
+)
+
+type arrayNodeExecutionContext struct {
+	handler.NodeExecutionContext
+}
+
+// TODO @hamersaw - overwrite everything
+/*
+inputReader
+taskRecorder
+nodeRecorder - need to add to nodeExecutionContext so we can override?!?!
+maxParallelism - looks like we need:
+	ExecutionConfig.GetMaxParallelism
+	ExecutionContext.IncrementMaxParallelism
+storage locations
+	dataPrefix
+
+add environment variables for maptask execution either:
+	(1) in arrayExecutionContext if we use separate for each
+	(2) in arrayNodeExecutionContext if we choose to use single DAG
+*/
+
+/*func newArrayExecutionContext(executionContext executors.ExecutionContext) executors.ExecutionContext {
+	return arrayExecutionContext{
+		ExecutionContext: executionContext,
+	}
+}*/
+
+func newArrayNodeExecutionContext(nodeExecutionContext handler.NodeExecutionContext) arrayNodeExecutionContext {
+	return arrayNodeExecutionContext{
+		NodeExecutionContext: nodeExecutionContext,
+	}
+}
diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go
index 1ec0a9dfb..979b2b808 100644
--- a/pkg/controller/nodes/array/handler.go
+++ b/pkg/controller/nodes/array/handler.go
@@ -2,10 +2,19 @@ package array
 
 import (
 	"context"
+	"fmt"
 
-	"github.com/flyteorg/flytepropeller/pkg/controller/config"
+	idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
+
+	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
+	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
+
+	"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
+	"github.com/flyteorg/flytepropeller/pkg/compiler/validators"
+	"github.com/flyteorg/flytepropeller/pkg/controller/executors"
 	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
 
+	"github.com/flyteorg/flytestdlib/bitarray"
 	"github.com/flyteorg/flytestdlib/logger"
 	"github.com/flyteorg/flytestdlib/promutils"
 )
@@ -15,6 +24,7 @@ import (
 // arrayNodeHandler is a handler implementation for processing array nodes
 type arrayNodeHandler struct {
 	metrics      metrics
+	nodeExecutor executors.Node
 }
 
 // metrics encapsulates the prometheus metrics for this handler
@@ -42,16 +52,130 @@ func (a *arrayNodeHandler) Finalize(ctx context.Context, _ handler.NodeExecution
 // FinalizeRequired defines whether or not this handler requires finalize to be called on
 // node completion
 func (a *arrayNodeHandler) FinalizeRequired() bool {
-	return false // TODO @hamersaw - implement finalize required
+	return true // TODO @hamersaw - implement finalize required
 }
 
 // Handle is responsible for transitioning and reporting node state to complete the node defined
 // by the NodeExecutionContext
 func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) {
-	arrayNode := nCtx.Node().GetArrayNode()
+	//arrayNode := nCtx.Node().GetArrayNode()
 	arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState()
 
 	// TODO @hamersaw - handle array node
+	// the big question right now is if we make a DAG with everything or call a separate DAG for each individual task
+	// need to do much more thinking on this - cleaner = a single DAG / maybe easier = DAG for each
+	//   single:
+	//     + can still add envVars - override in arrayNodeExecutionContext
+	//   each:
+	//     + add envVars on ExecutionContext
+	//     - need to manage
+
+	switch arrayNodeState.Phase {
+	case v1alpha1.ArrayNodePhaseNone:
+		// identify and validate array node input value lengths
+		literalMap, err := nCtx.InputReader().Get(ctx)
+		if err != nil {
+			return handler.UnknownTransition, err // TODO @hamersaw fail
+		}
+
+		size := -1
+		for _, variable := range literalMap.Literals {
+			literalType := validators.LiteralTypeForLiteral(variable)
+			switch literalType.Type.(type) {
+			case *idlcore.LiteralType_CollectionType:
+				collection := variable.GetCollection()
+				collectionLength := len(collection.Literals)
+
+				if size == -1 {
+					size = collectionLength
+				} else if size != collectionLength {
+					// TODO @hamersaw - return error
+				}
+			}
+		}
+
+		if size == -1 {
+			// TODO @hamersaw return
+		}
+
+		// initialize ArrayNode state
+		arrayNodeState.Phase = v1alpha1.ArrayNodePhaseExecuting
+		arrayNodeState.SubNodePhases, err = bitarray.NewCompactArray(uint(size), bitarray.Item(len(core.Phases)-1))
+		if err != nil {
+			// TODO @hamersaw fail
+		}
+		// TODO @hamersaw - init SystemFailures and RetryAttempts as well
+		//   do we want to abstract this? ie. arrayNodeState.GetStats(subNodeIndex) (phase, systemFailures, ...)
+
+		fmt.Printf("HAMERSAW - created SubNodePhases with length '%d:%d'\n", size, len(arrayNodeState.SubNodePhases.GetItems()))
+	case v1alpha1.ArrayNodePhaseExecuting:
+		// process array node subnodes
+		for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
+			nodePhase := v1alpha1.NodePhase(nodePhaseUint64)
+			fmt.Printf("HAMERSAW - TODO evaluating node '%d' in phase '%d'\n", i, nodePhase)
+
+			// TODO @hamersaw - fix
+			/*if nodes.IsTerminalNodePhase(nodePhase) {
+				continue
+			}*/
+
+			var inputReader io.InputReader
+			if nodePhase == v1alpha1.NodePhaseNotYetStarted { // TODO @hamersaw - need to do this for PhaseSucceeded as well?!?! to write cache outputs once fastcache is in
+				// create input readers and set nodePhase to Queued to skip resolving inputs but still allow cache lookups
+				// TODO @hamersaw - create input readers
+				nodePhase = v1alpha1.NodePhaseQueued
+			}
+
+			// wrap node lookup
+			subNodeID := fmt.Sprintf("%s-n%d", nCtx.NodeID(), i)
+			subNodeSpec := &v1alpha1.NodeSpec{
+				ID:   subNodeID,
+				Name: subNodeID,
+			} // TODO @hamersaw - compile this in ArrayNodeSpec?
+			subNodeStatus := &v1alpha1.NodeStatus{
+				Phase: nodePhase,
+				/*TaskNodeStatus: &v1alpha1.TaskNodeStatus{
+					Phase: nodePhase, // used for cache lookups - once fastcache is done we dont care about the TaskNodeStatus
+				},*/
+				// TODO @hamersaw - fill out systemFailures, retryAttempt etc
+			}
+
+			// TODO @hamersaw - can probably create a single arrayNodeLookup with all the subNodeIDs
+			arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, subNodeSpec, subNodeStatus)
+
+			// create base NodeExecutionContext
+			nodeExecutionContext, err := a.nodeExecutor.NewNodeExecutionContext(ctx, nCtx.ExecutionContext(), arrayNodeLookup, subNodeID)
+			if err != nil {
+				// TODO @hamersaw fail
+			}
+
+			// create new arrayNodeExecutionContext to override for array task execution
+			arrayNodeExecutionContext := newArrayNodeExecutionContext(nodeExecutionContext, inputReader)
+
+			// execute subNode through RecursiveNodeHandler
+			// TODO @hamersaw - either
+			//   (1) add func to create nodeExecutionContext to RecursiveNodeHandler
+			//   (2) create nodeExecutionContext before call to RecursiveNodeHandler
+			//       can do with small wrapper function call
+			nodeStatus, err := a.nodeExecutor.RecursiveNodeHandler(ctx, arrayNodeExecutionContext, &arrayNodeLookup, &arrayNodeLookup, subNodeSpec)
+			if err != nil {
+				// TODO @hamersaw fail
+			}
+
+			// handleNode / abort / finalize task nodeExecutionContext and Handler as parameters - THIS IS THE ENTRYPOINT WE'RE LOOKING FOR
+		}
+
+		// TODO @hamersaw - determine summary phases
+		arrayNodeState.Phase = v1alpha1.ArrayNodePhaseSucceeding
+	case v1alpha1.ArrayNodePhaseFailing:
+		// TODO @hamersaw - abort everything!
+	case v1alpha1.ArrayNodePhaseSucceeding:
+		// TODO @hamersaw - collect outputs
+		return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil
+	default:
+		// TODO @hamersaw - fail
+	}
 
 	// update array node status
 	if err := nCtx.NodeStateWriter().PutArrayNodeState(arrayNodeState); err != nil {
@@ -68,9 +192,10 @@ func (a *arrayNodeHandler) Setup(_ context.Context, _ handler.SetupContext) erro
 }
 
 // New initializes a new arrayNodeHandler
-func New(eventConfig *config.EventConfig, scope promutils.Scope) handler.Node {
+func New(nodeExecutor executors.Node, scope promutils.Scope) handler.Node {
 	arrayScope := scope.NewSubScope("array")
 	return &arrayNodeHandler{
 		metrics:      newMetrics(arrayScope),
+		nodeExecutor: nodeExecutor,
 	}
 }
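Note: the handler above leaves per-subnode input readers as a TODO. The usual map-task semantic is that element i of every collection-typed input goes to subtask i, with non-collection inputs broadcast unchanged. One plausible slicing helper under that assumption is sketched below; sliceInputs is a hypothetical name, not part of the patch.

    package example

    import (
    	idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
    )

    // sliceInputs builds the input map for sub-task i: element i of every
    // collection-typed input under the same variable name, scalars passed through.
    func sliceInputs(inputs *idlcore.LiteralMap, i int) *idlcore.LiteralMap {
    	out := &idlcore.LiteralMap{Literals: make(map[string]*idlcore.Literal, len(inputs.Literals))}
    	for name, literal := range inputs.Literals {
    		if c := literal.GetCollection(); c != nil && i < len(c.Literals) {
    			out.Literals[name] = c.Literals[i] // indexed element for this sub-task
    		} else {
    			out.Literals[name] = literal // broadcast non-collection inputs unchanged
    		}
    	}
    	return out
    }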
diff --git a/pkg/controller/nodes/array/node_lookup.go b/pkg/controller/nodes/array/node_lookup.go
new file mode 100644
index 000000000..d1ef8fe55
--- /dev/null
+++ b/pkg/controller/nodes/array/node_lookup.go
@@ -0,0 +1,40 @@
+package array
+
+import (
+	"context"
+
+	"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
+	"github.com/flyteorg/flytepropeller/pkg/controller/executors"
+)
+
+type arrayNodeLookup struct {
+	executors.NodeLookup
+	subNodeID     v1alpha1.NodeID
+	subNodeSpec   *v1alpha1.NodeSpec
+	subNodeStatus *v1alpha1.NodeStatus
+}
+
+func (a *arrayNodeLookup) GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) {
+	if nodeID == a.subNodeID {
+		return a.subNodeSpec, true
+	}
+
+	return a.NodeLookup.GetNode(nodeID)
+}
+
+func (a *arrayNodeLookup) GetNodeExecutionStatus(ctx context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus {
+	if id == a.subNodeID {
+		return a.subNodeStatus
+	}
+
+	return a.NodeLookup.GetNodeExecutionStatus(ctx, id)
+}
+
+func newArrayNodeLookup(nodeLookup executors.NodeLookup, subNodeID v1alpha1.NodeID, subNodeSpec *v1alpha1.NodeSpec, subNodeStatus *v1alpha1.NodeStatus) arrayNodeLookup {
+	return arrayNodeLookup{
+		NodeLookup:    nodeLookup,
+		subNodeID:     subNodeID,
+		subNodeSpec:   subNodeSpec,
+		subNodeStatus: subNodeStatus,
+	}
+}
diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go
index 3d0707469..cb3a1ceee 100644
--- a/pkg/controller/nodes/executor.go
+++ b/pkg/controller/nodes/executor.go
@@ -733,6 +733,8 @@ func (c *nodeExecutor) handleRetryableFailure(ctx context.Context, nCtx *nodeExe
 	nodeStatus.ClearTaskStatus()
 	nodeStatus.ClearWorkflowStatus()
 	nodeStatus.ClearDynamicNodeStatus()
+	nodeStatus.ClearGateNodeStatus()
+	nodeStatus.ClearArrayNodeStatus()
 	return executors.NodeStatusPending, nil
 }
 
@@ -998,7 +1000,7 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext exe
 		return executors.NodeStatusRunning, nil
 	}
 
-	nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl)
+	nCtx, err := c.NewNodeExecutionContext(ctx, execContext, nl, currentNode.GetID())
 	if err != nil {
 		// NodeExecution creation failure is a permanent fail / system error.
 		// Should a system failure always return an err?
@@ -1064,7 +1066,7 @@ func (c *nodeExecutor) FinalizeHandler(ctx context.Context, execContext executor
 		return err
 	}
 
-	nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl)
+	nCtx, err := c.NewNodeExecutionContext(ctx, execContext, nl, currentNode.GetID())
 	if err != nil {
 		return err
 	}
@@ -1123,7 +1125,7 @@ func (c *nodeExecutor) AbortHandler(ctx context.Context, execContext executors.E
 		return err
 	}
 
-	nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl)
+	nCtx, err := c.NewNodeExecutionContext(ctx, execContext, nl, currentNode.GetID())
 	if err != nil {
 		return err
 	}
diff --git a/pkg/controller/nodes/handler/state.go b/pkg/controller/nodes/handler/state.go
index 91622f054..e4cc851a2 100644
--- a/pkg/controller/nodes/handler/state.go
+++ b/pkg/controller/nodes/handler/state.go
@@ -5,9 +5,9 @@ import (
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
 
 	pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
-	"github.com/flyteorg/flytestdlib/storage"
-
 	"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
+	"github.com/flyteorg/flytestdlib/bitarray"
+	"github.com/flyteorg/flytestdlib/storage"
 )
 
 // This is the legacy state structure that gets translated to node status
@@ -47,7 +47,8 @@ type GateNodeState struct {
 }
 
 type ArrayNodeState struct {
-	Phase v1alpha1.ArrayNodePhase
+	Phase         v1alpha1.ArrayNodePhase
+	SubNodePhases bitarray.CompactArray
 }
 
 type NodeStateWriter interface {
diff --git a/pkg/controller/nodes/handler_factory.go b/pkg/controller/nodes/handler_factory.go
index 0bfa71d4a..d1df39cab 100644
--- a/pkg/controller/nodes/handler_factory.go
+++ b/pkg/controller/nodes/handler_factory.go
@@ -70,7 +70,7 @@ func NewHandlerFactory(ctx context.Context, executor executors.Node, workflowLau
 		v1alpha1.NodeKindTask:     dynamic.New(t, executor, launchPlanReader, eventConfig, scope),
 		v1alpha1.NodeKindWorkflow: subworkflow.New(executor, workflowLauncher, recoveryClient, eventConfig, scope),
 		v1alpha1.NodeKindGate:     gate.New(eventConfig, signalClient, scope),
-		v1alpha1.NodeKindArray:    array.New(scope),
+		v1alpha1.NodeKindArray:    array.New(executor, scope),
 		v1alpha1.NodeKindStart:    start.New(),
 		v1alpha1.NodeKindEnd:      end.New(),
 	},
diff --git a/pkg/controller/nodes/node_exec_context.go b/pkg/controller/nodes/node_exec_context.go
index 94a83040a..8914407c6 100644
--- a/pkg/controller/nodes/node_exec_context.go
+++ b/pkg/controller/nodes/node_exec_context.go
@@ -185,8 +185,8 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext
 	}
 }
 
-func (c *nodeExecutor) newNodeExecContextDefault(ctx context.Context, currentNodeID v1alpha1.NodeID,
-	executionContext executors.ExecutionContext, nl executors.NodeLookup) (*nodeExecContext, error) {
+func (c *nodeExecutor) NewNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext,
+	nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (*nodeExecContext, error) {
 	n, ok := nl.GetNode(currentNodeID)
 	if !ok {
 		return nil, fmt.Errorf("failed to find node with ID [%s] in execution [%s]", currentNodeID, executionContext.GetID())
diff --git a/pkg/controller/nodes/node_state_manager.go b/pkg/controller/nodes/node_state_manager.go
index cd050cc5b..f5600b98c 100644
--- a/pkg/controller/nodes/node_state_manager.go
+++ b/pkg/controller/nodes/node_state_manager.go
@@ -111,6 +111,7 @@ func (n nodeStateManager) GetArrayNodeState() handler.ArrayNodeState {
 	as := handler.ArrayNodeState{}
 	if an != nil {
 		as.Phase = an.GetArrayNodePhase()
+		as.SubNodePhases = an.GetSubNodePhases()
 	}
 	return as
 }
diff --git a/pkg/controller/nodes/transformers.go b/pkg/controller/nodes/transformers.go
index c1b3952d4..8576cac6d 100644
--- a/pkg/controller/nodes/transformers.go
+++ b/pkg/controller/nodes/transformers.go
@@ -280,5 +280,6 @@ func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n *nodeStateMa
 	if n.a != nil {
 		t := s.GetOrCreateArrayNodeStatus()
 		t.SetArrayNodePhase(n.a.Phase)
+		t.SetSubNodePhases(n.a.SubNodePhases)
 	}
 }
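Note: PATCH 03 sizes SubNodePhases with bitarray.NewCompactArray(uint(size), bitarray.Item(len(core.Phases)-1)), so each sub-node phase is stored in just enough bits to represent the largest phase value rather than a full machine word. That compactness is what lets a large fan-out stay representable inside the etcd-backed CRD. A small usage sketch of the flytestdlib API as it is called in the patch; the sizes below are illustrative.

    package example

    import (
    	"fmt"

    	"github.com/flyteorg/flytestdlib/bitarray"
    )

    func demoCompactArray() error {
    	// 1000 sub-nodes, max stored value 7 -> roughly 3 bits per entry
    	phases, err := bitarray.NewCompactArray(1000, bitarray.Item(7))
    	if err != nil {
    		return err
    	}

    	phases.SetItem(42, bitarray.Item(3)) // record sub-node 42 advancing to phase 3

    	for i, p := range phases.GetItems() { // same iteration the handler uses
    		if p != 0 {
    			fmt.Printf("sub-node %d is in phase %d\n", i, p)
    		}
    	}
    	return nil
    }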
From 6dc7fb1d416d229d0fa91c49e0e85ce2e35b6927 Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Thu, 6 Apr 2023 11:22:31 -0500
Subject: [PATCH 04/62] refactored node executor interfaces to fix dependency
 cycle

Signed-off-by: Daniel Rammer
---
 .../nodes/array/execution_context.go          |  6 +-
 pkg/controller/nodes/array/handler.go         | 31 ++++++----
 pkg/controller/nodes/array/node_executor.go   | 19 ++++++
 pkg/controller/nodes/branch/handler.go        | 30 +++++-----
 .../nodes/dynamic/dynamic_workflow.go         | 26 ++++----
 pkg/controller/nodes/dynamic/handler.go       | 60 +++++++++----------
 pkg/controller/nodes/dynamic/utils.go         |  4 +-
 pkg/controller/nodes/end/handler.go           |  7 ++-
 pkg/controller/nodes/gate/handler.go          |  7 ++-
 pkg/controller/nodes/handler/iface.go         | 15 ++++-
 .../{executors => nodes/interfaces}/node.go   | 21 ++++---
 .../node_exec_context.go                      | 19 +++---
 .../nodes/{handler => interfaces}/state.go    |  5 +-
 .../{handler => interfaces}/state_test.go     |  2 +-
 pkg/controller/nodes/start/handler.go         |  7 ++-
 pkg/controller/nodes/subworkflow/handler.go   | 12 ++--
 .../nodes/subworkflow/launchplan.go           | 30 +++++-----
 .../nodes/subworkflow/subworkflow.go          | 38 ++++++------
 pkg/controller/nodes/task/handler.go          | 44 +++++++-------
 pkg/controller/nodes/task/taskexec_context.go | 12 ++--
 pkg/controller/nodes/task/transformer.go      |  5 +-
 pkg/controller/workflow/executor.go           | 22 +++----
 22 files changed, 234 insertions(+), 188 deletions(-)
 create mode 100644 pkg/controller/nodes/array/node_executor.go
 rename pkg/controller/{executors => nodes/interfaces}/node.go (79%)
 rename pkg/controller/nodes/{handler => interfaces}/node_exec_context.go (93%)
 rename pkg/controller/nodes/{handler => interfaces}/state.go (98%)
 rename pkg/controller/nodes/{handler => interfaces}/state_test.go (97%)

diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go
index 5893442e8..56c2c2b5d 100644
--- a/pkg/controller/nodes/array/execution_context.go
+++ b/pkg/controller/nodes/array/execution_context.go
@@ -2,11 +2,11 @@ package array
 
 import (
 	//"github.com/flyteorg/flytepropeller/pkg/controller/executors"
-	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces"
 )
 
 type arrayNodeExecutionContext struct {
-	handler.NodeExecutionContext
+	interfaces.NodeExecutionContext
 }
 
 // TODO @hamersaw - overwrite everything
@@ -31,7 +31,7 @@ add environment variables for maptask execution either:
 	}
 }*/
 
-func newArrayNodeExecutionContext(nodeExecutionContext handler.NodeExecutionContext) arrayNodeExecutionContext {
+func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext) arrayNodeExecutionContext {
 	return arrayNodeExecutionContext{
 		NodeExecutionContext: nodeExecutionContext,
 	}
diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go
index 979b2b808..a2b9d470a 100644
--- a/pkg/controller/nodes/array/handler.go
+++ b/pkg/controller/nodes/array/handler.go
@@ -11,8 +11,8 @@ import (
 
 	"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
 	"github.com/flyteorg/flytepropeller/pkg/compiler/validators"
-	"github.com/flyteorg/flytepropeller/pkg/controller/executors"
 	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces"
 
 	"github.com/flyteorg/flytestdlib/bitarray"
 	"github.com/flyteorg/flytestdlib/logger"
@@ -24,7 +24,7 @@ import (
 // arrayNodeHandler is a handler implementation for processing array nodes
 type arrayNodeHandler struct {
 	metrics      metrics
-	nodeExecutor executors.Node
+	nodeExecutor interfaces.Node
 }
 
 // metrics encapsulates the prometheus metrics for this handler
@@ -40,12 +40,12 @@ func newMetrics(scope promutils.Scope) metrics {
 }
 
 // Abort stops the array node defined in the NodeExecutionContext
-func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error {
+func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error {
 	return nil // TODO @hamersaw - implement abort
 }
 
 // Finalize completes the array node defined in the NodeExecutionContext
-func (a *arrayNodeHandler) Finalize(ctx context.Context, _ handler.NodeExecutionContext) error {
+func (a *arrayNodeHandler) Finalize(ctx context.Context, _ interfaces.NodeExecutionContext) error {
 	return nil // TODO @hamersaw - implement finalize
 }
 
@@ -57,7 +57,7 @@ func (a *arrayNodeHandler) FinalizeRequired() bool {
 
 // Handle is responsible for transitioning and reporting node state to complete the node defined
 // by the NodeExecutionContext
-func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) {
+func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) {
 	//arrayNode := nCtx.Node().GetArrayNode()
 	arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState()
 
@@ -143,24 +143,33 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecutio
 			// TODO @hamersaw - can probably create a single arrayNodeLookup with all the subNodeIDs
 			arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, subNodeSpec, subNodeStatus)
 
-			// create base NodeExecutionContext
-			nodeExecutionContext, err := a.nodeExecutor.NewNodeExecutionContext(ctx, nCtx.ExecutionContext(), arrayNodeLookup, subNodeID)
+			// create arrayNodeExecutor
+			/*nodeExecutionContext, err := a.nodeExecutor.NewNodeExecutionContext(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, subNodeID)
 			if err != nil {
 				// TODO @hamersaw fail
 			}
 
 			// create new arrayNodeExecutionContext to override for array task execution
-			arrayNodeExecutionContext := newArrayNodeExecutionContext(nodeExecutionContext, inputReader)
+			arrayNodeExecutionContext := newArrayNodeExecutionContext(nodeExecutionContext, inputReader)*/
+			arrayNodeExecutor := newArrayNodeExecutor(a.nodeExecutor)
+
+			// execute subNode through RecursiveNodeHandler
+			nodeStatus, err := arrayNodeExecutor.RecursiveNodeHandler(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, &arrayNodeLookup, subNodeSpec)
+			if err != nil {
+				// TODO @hamersaw fail
+			}
+
+			fmt.Printf("HAMERSAW - node phase transition %d -> %d", nodePhase, nodeStatus.NodePhase)
 
 			// execute subNode through RecursiveNodeHandler
 			// TODO @hamersaw - either
 			//   (1) add func to create nodeExecutionContext to RecursiveNodeHandler
 			//   (2) create nodeExecutionContext before call to RecursiveNodeHandler
 			//       can do with small wrapper function call
-			nodeStatus, err := a.nodeExecutor.RecursiveNodeHandler(ctx, arrayNodeExecutionContext, &arrayNodeLookup, &arrayNodeLookup, subNodeSpec)
+			/*nodeStatus, err := a.nodeExecutor.RecursiveNodeHandler(ctx, arrayNodeExecutionContext, &arrayNodeLookup, &arrayNodeLookup, subNodeSpec)
 			if err != nil {
 				// TODO @hamersaw fail
-			}
+			}*/
 
 			// handleNode / abort / finalize task nodeExecutionContext and Handler as parameters - THIS IS THE ENTRYPOINT WE'RE LOOKING FOR
 		}
@@ -192,7 +201,7 @@ func (a *arrayNodeHandler) Setup(_ context.Context, _ handler.SetupContext) erro
 }
 
 // New initializes a new arrayNodeHandler
-func New(nodeExecutor executors.Node, scope promutils.Scope) handler.Node {
+func New(nodeExecutor interfaces.Node, scope promutils.Scope) handler.Node {
 	arrayScope := scope.NewSubScope("array")
 	return &arrayNodeHandler{
 		metrics:      newMetrics(arrayScope),
diff --git a/pkg/controller/nodes/array/node_executor.go b/pkg/controller/nodes/array/node_executor.go
new file mode 100644
index 000000000..81aa59754
--- /dev/null
+++ b/pkg/controller/nodes/array/node_executor.go
@@ -0,0 +1,19 @@
+package array
+
+import (
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces"
+)
+
+type arrayNodeExecutor struct {
+	interfaces.Node
+}
+
+/*
+TODO @hamersaw - override NewNodeExecutionContext function
+*/
+
+func newArrayNodeExecutor(nodeExecutor interfaces.Node) arrayNodeExecutor {
+	return arrayNodeExecutor{
+		Node: nodeExecutor,
+	}
+}
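Note: node_executor.go above is a thin decorator. arrayNodeExecutor embeds interfaces.Node, so it satisfies the whole interface by method promotion, and the remaining TODO is to shadow just NewNodeExecutionContext. A self-contained sketch of that embedding idiom follows; Executor, baseExecutor, and wrappedExecutor are illustrative names, not from the patch.

    package example

    import "fmt"

    // Executor stands in for an interface like interfaces.Node.
    type Executor interface {
    	Execute(nodeID string) error
    }

    type baseExecutor struct{}

    func (b baseExecutor) Execute(nodeID string) error {
    	fmt.Printf("executing %s\n", nodeID)
    	return nil
    }

    // wrappedExecutor embeds the interface: every method is promoted, so the
    // struct satisfies Executor, and redefining Execute shadows that one entry point.
    type wrappedExecutor struct {
    	Executor
    	prefix string
    }

    func (w wrappedExecutor) Execute(nodeID string) error {
    	// intercept, adjust the argument, then delegate to the wrapped implementation
    	return w.Executor.Execute(w.prefix + "-" + nodeID)
    }

    func demo() {
    	var e Executor = wrappedExecutor{Executor: baseExecutor{}, prefix: "array"}
    	_ = e.Execute("n0") // prints: executing array-n0
    }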
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + + stdErrors "github.com/flyteorg/flytestdlib/errors" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" ) type metrics struct { @@ -22,7 +24,7 @@ type metrics struct { } type branchHandler struct { - nodeExecutor executors.Node + nodeExecutor interfaces.Node m metrics eventConfig *config.EventConfig } @@ -36,7 +38,7 @@ func (b *branchHandler) Setup(ctx context.Context, _ handler.SetupContext) error return nil } -func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha1.ExecutableBranchNode, nCtx handler.NodeExecutionContext, nl executors.NodeLookup) (handler.Transition, error) { +func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha1.ExecutableBranchNode, nCtx interfaces.NodeExecutionContext, nl executors.NodeLookup) (handler.Transition, error) { if nCtx.NodeStateReader().GetBranchNode().FinalizedNodeID == nil { nodeInputs, err := nCtx.InputReader().Get(ctx) if err != nil { @@ -55,7 +57,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.IllegalStateError, errMsg, nil)), nil } - branchNodeState := handler.BranchNodeState{FinalizedNodeID: finalNodeID, Phase: v1alpha1.BranchNodeSuccess} + branchNodeState := interfaces.BranchNodeState{FinalizedNodeID: finalNodeID, Phase: v1alpha1.BranchNodeSuccess} err = nCtx.NodeStateWriter().PutBranchNode(branchNodeState) if err != nil { logger.Errorf(ctx, "Failed to store BranchNode state, err :%s", err.Error()) @@ -103,7 +105,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha return b.recurseDownstream(ctx, nCtx, nodeStatus, branchTakenNode) } -func (b *branchHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (b *branchHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { logger.Debug(ctx, "Starting Branch Node") branchNode := nCtx.Node().GetBranchNode() if branchNode == nil { @@ -115,7 +117,7 @@ func (b *branchHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionCo return b.HandleBranchNode(ctx, branchNode, nCtx, nl) } -func (b *branchHandler) getExecutionContextForDownstream(nCtx handler.NodeExecutionContext) (executors.ExecutionContext, error) { +func (b *branchHandler) getExecutionContextForDownstream(nCtx interfaces.NodeExecutionContext) (executors.ExecutionContext, error) { newParentInfo, err := common.CreateParentInfo(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeID(), nCtx.CurrentAttempt()) if err != nil { return nil, err @@ -123,7 +125,7 @@ func (b *branchHandler) getExecutionContextForDownstream(nCtx handler.NodeExecut return executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo), nil } -func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx handler.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) { +func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) { // TODO we should replace the call to RecursiveNodeHandler with a call to 
SingleNode Handler. The inputs are also already known ahead of time // There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc // The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node. @@ -167,7 +169,7 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx handler.Node return handler.DoTransition(handler.TransitionTypeEphemeral, phase), nil } -func (b *branchHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (b *branchHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { branch := nCtx.Node().GetBranchNode() if branch == nil { @@ -212,7 +214,7 @@ func (b *branchHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionCon return b.nodeExecutor.AbortHandler(ctx, execContext, dag, nCtx.ContextualNodeLookup(), branchTakenNode, reason) } -func (b *branchHandler) Finalize(ctx context.Context, nCtx handler.NodeExecutionContext) error { +func (b *branchHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExecutionContext) error { branch := nCtx.Node().GetBranchNode() if branch == nil { return errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "Invoked branch handler, for a non branch node.") @@ -256,7 +258,7 @@ func (b *branchHandler) Finalize(ctx context.Context, nCtx handler.NodeExecution return b.nodeExecutor.FinalizeHandler(ctx, execContext, dag, nCtx.ContextualNodeLookup(), branchTakenNode) } -func New(executor executors.Node, eventConfig *config.EventConfig, scope promutils.Scope) handler.Node { +func New(executor interfaces.Node, eventConfig *config.EventConfig, scope promutils.Scope) handler.Node { return &branchHandler{ nodeExecutor: executor, m: metrics{scope: scope}, diff --git a/pkg/controller/nodes/dynamic/dynamic_workflow.go b/pkg/controller/nodes/dynamic/dynamic_workflow.go index eb891aa27..bf9cf0287 100644 --- a/pkg/controller/nodes/dynamic/dynamic_workflow.go +++ b/pkg/controller/nodes/dynamic/dynamic_workflow.go @@ -5,23 +5,23 @@ import ( "fmt" "strconv" - "k8s.io/apimachinery/pkg/util/rand" - - node_common "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/compiler" "github.com/flyteorg/flytepropeller/pkg/compiler/common" "github.com/flyteorg/flytepropeller/pkg/compiler/transformers/k8s" "github.com/flyteorg/flytepropeller/pkg/controller/executors" + node_common "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" "github.com/flyteorg/flytepropeller/pkg/utils" "github.com/flyteorg/flytestdlib/errors" "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/storage" + + "k8s.io/apimachinery/pkg/util/rand" ) type dynamicWorkflowContext struct { @@ -36,7 +36,7 @@ type dynamicWorkflowContext struct { const dynamicWfNameTemplate = "dynamic_%s" func setEphemeralNodeExecutionStatusAttributes(ctx context.Context, djSpec *core.DynamicJobSpec, - nCtx 
handler.NodeExecutionContext, parentNodeStatus v1alpha1.ExecutableNodeStatus) error { + nCtx interfaces.NodeExecutionContext, parentNodeStatus v1alpha1.ExecutableNodeStatus) error { if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { return nil } @@ -77,7 +77,7 @@ func setEphemeralNodeExecutionStatusAttributes(ctx context.Context, djSpec *core } func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflowTemplate(ctx context.Context, djSpec *core.DynamicJobSpec, - nCtx handler.NodeExecutionContext, parentNodeStatus v1alpha1.ExecutableNodeStatus) (*core.WorkflowTemplate, error) { + nCtx interfaces.NodeExecutionContext, parentNodeStatus v1alpha1.ExecutableNodeStatus) (*core.WorkflowTemplate, error) { iface, err := underlyingInterface(ctx, nCtx.TaskReader()) if err != nil { @@ -127,7 +127,7 @@ func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflowTemplate(ctx context.Con }, nil } -func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) (dynamicWorkflowContext, error) { +func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext) (dynamicWorkflowContext, error) { t := d.metrics.buildDynamicWorkflow.Start(ctx) defer t.Stop() @@ -221,7 +221,7 @@ func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.C }, nil } -func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext, +func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext, djSpec *core.DynamicJobSpec, dynamicNodeStatus v1alpha1.ExecutableNodeStatus) (*core.CompiledWorkflowClosure, *v1alpha1.FlyteWorkflow, dynamicWorkflowContext, error) { wf, err := d.buildDynamicWorkflowTemplate(ctx, djSpec, nCtx, dynamicNodeStatus) if err != nil { @@ -265,7 +265,7 @@ func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflow(ctx context.Context, nC } func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, execContext executors.ExecutionContext, dynamicWorkflow v1alpha1.ExecutableWorkflow, nl executors.NodeLookup, - nCtx handler.NodeExecutionContext, prevState handler.DynamicNodeState) (handler.Transition, handler.DynamicNodeState, error) { + nCtx interfaces.NodeExecutionContext, prevState interfaces.DynamicNodeState) (handler.Transition, interfaces.DynamicNodeState, error) { state, err := d.nodeExecutor.RecursiveNodeHandler(ctx, execContext, dynamicWorkflow, nl, dynamicWorkflow.StartNode()) if err != nil { @@ -281,7 +281,7 @@ func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, // As we do not support Failure Node, we can just return failure in this case return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoDynamicRunning(nil)), - handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "Dynamic workflow failed", Error: state.Err}, + interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "Dynamic workflow failed", Error: state.Err}, nil } @@ -293,7 +293,7 @@ func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, endNodeStatus := dynamicNodeStatus.GetNodeExecutionStatus(ctx, v1alpha1.EndNodeID) if endNodeStatus == nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "MalformedDynamicWorkflow", "no end-node found in dynamic workflow", nil)), - 
handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "no end-node found in dynamic workflow"}, + interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "no end-node found in dynamic workflow"}, nil } @@ -301,7 +301,7 @@ func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, if metadata, err := nCtx.DataStore().Head(ctx, sourcePath); err == nil { if !metadata.Exists() { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRetryableFailure(core.ExecutionError_SYSTEM, "DynamicWorkflowOutputsNotFound", fmt.Sprintf(" is expected to produce outputs but no outputs file was written to %v.", sourcePath), nil)), - handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "DynamicWorkflow is expected to produce outputs but no outputs file was written"}, + interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "DynamicWorkflow is expected to produce outputs but no outputs file was written"}, nil } } else { @@ -313,7 +313,7 @@ func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "OutputsNotFound", fmt.Sprintf("Failed to copy subworkflow outputs from [%v] to [%v]. Error: %s", sourcePath, destinationPath, err.Error()), nil), - ), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "Failed to copy subworkflow outputs"}, + ), interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "Failed to copy subworkflow outputs"}, nil } diff --git a/pkg/controller/nodes/dynamic/handler.go b/pkg/controller/nodes/dynamic/handler.go index 1d35fb502..2cfa745e1 100644 --- a/pkg/controller/nodes/dynamic/handler.go +++ b/pkg/controller/nodes/dynamic/handler.go @@ -4,28 +4,26 @@ import ( "context" "time" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" - "github.com/flyteorg/flytepropeller/pkg/utils" - - "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" + "github.com/flyteorg/flytepropeller/pkg/utils" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" stdErrors "github.com/flyteorg/flytestdlib/errors" + "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/promutils" - - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytestdlib/promutils/labeled" ) //go:generate mockery -all -case=underscore @@ -60,12 +58,12 @@ 
func newMetrics(scope promutils.Scope) metrics { type dynamicNodeTaskNodeHandler struct { TaskNodeHandler metrics metrics - nodeExecutor executors.Node + nodeExecutor interfaces.Node lpReader launchplan.Reader eventConfig *config.EventConfig } -func (d dynamicNodeTaskNodeHandler) handleParentNode(ctx context.Context, prevState handler.DynamicNodeState, nCtx handler.NodeExecutionContext) (handler.Transition, handler.DynamicNodeState, error) { +func (d dynamicNodeTaskNodeHandler) handleParentNode(ctx context.Context, prevState interfaces.DynamicNodeState, nCtx interfaces.NodeExecutionContext) (handler.Transition, interfaces.DynamicNodeState, error) { // It seems parent node is still running, lets call handle for parent node trns, err := d.TaskNodeHandler.Handle(ctx, nCtx) if err != nil { @@ -87,7 +85,7 @@ func (d dynamicNodeTaskNodeHandler) handleParentNode(ctx context.Context, prevSt // directly to record, and then progress the dynamically generated workflow. logger.Infof(ctx, "future file detected, assuming dynamic node") // There is a futures file, so we need to continue running the node with the modified state - return trns.WithInfo(handler.PhaseInfoRunning(trns.Info().GetInfo())), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalizing}, nil + return trns.WithInfo(handler.PhaseInfoRunning(trns.Info().GetInfo())), interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalizing}, nil } } @@ -95,8 +93,8 @@ func (d dynamicNodeTaskNodeHandler) handleParentNode(ctx context.Context, prevSt return trns, prevState, nil } -func (d dynamicNodeTaskNodeHandler) produceDynamicWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) ( - handler.Transition, handler.DynamicNodeState, error) { +func (d dynamicNodeTaskNodeHandler) produceDynamicWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext) ( + handler.Transition, interfaces.DynamicNodeState, error) { // The first time this is called we go ahead and evaluate the dynamic node to build the workflow. We then cache // this workflow definition and send it to be persisted by flyteadmin so that users can observe the structure. 
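	// The user/system split applied just below hinges on flytestdlib's error-code
	// wrapping (stated here as an assumption about that package's behavior): an
	// error wrapped with utils.ErrorCodeUser satisfies
	// stdErrors.IsCausedBy(err, utils.ErrorCodeUser) and is surfaced as a terminal
	// USER failure, while any other error is returned unchanged and treated as a
	// system error by the engine.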
dCtx, err := d.buildContextualDynamicWorkflow(ctx, nCtx) @@ -104,9 +102,9 @@ func (d dynamicNodeTaskNodeHandler) produceDynamicWorkflow(ctx context.Context, if stdErrors.IsCausedBy(err, utils.ErrorCodeUser) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_USER, "DynamicWorkflowBuildFailed", err.Error(), nil), - ), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: err.Error()}, nil + ), interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: err.Error()}, nil } - return handler.Transition{}, handler.DynamicNodeState{}, err + return handler.Transition{}, interfaces.DynamicNodeState{}, err } taskNodeInfoMetadata := &event.TaskNodeMetadata{} if dCtx.subWorkflowClosure != nil && dCtx.subWorkflowClosure.Primary != nil && dCtx.subWorkflowClosure.Primary.Template != nil { @@ -117,7 +115,7 @@ func (d dynamicNodeTaskNodeHandler) produceDynamicWorkflow(ctx context.Context, } } - nextState := handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseExecuting} + nextState := interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseExecuting} return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoDynamicRunning(&handler.ExecutionInfo{ TaskNodeInfo: &handler.TaskNodeInfo{ TaskNodeMetadata: taskNodeInfoMetadata, @@ -125,16 +123,16 @@ func (d dynamicNodeTaskNodeHandler) produceDynamicWorkflow(ctx context.Context, })), nextState, nil } -func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, nCtx handler.NodeExecutionContext, prevState handler.DynamicNodeState) (handler.Transition, handler.DynamicNodeState, error) { +func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, nCtx interfaces.NodeExecutionContext, prevState interfaces.DynamicNodeState) (handler.Transition, interfaces.DynamicNodeState, error) { dCtx, err := d.buildContextualDynamicWorkflow(ctx, nCtx) if err != nil { if stdErrors.IsCausedBy(err, utils.ErrorCodeUser) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_USER, "DynamicWorkflowBuildFailed", err.Error(), nil), - ), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: err.Error()}, nil + ), interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: err.Error()}, nil } // Mostly a system error or unknown - return handler.Transition{}, handler.DynamicNodeState{}, err + return handler.Transition{}, interfaces.DynamicNodeState{}, err } trns, newState, err := d.progressDynamicWorkflow(ctx, dCtx.execContext, dCtx.subWorkflow, dCtx.nodeLookup, nCtx, prevState) @@ -160,10 +158,10 @@ func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, n if ee != nil { if ee.IsRecoverable { - return trns.WithInfo(handler.PhaseInfoRetryableFailureErr(ee.ExecutionError, trns.Info().GetInfo())), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: ee.ExecutionError.String()}, nil + return trns.WithInfo(handler.PhaseInfoRetryableFailureErr(ee.ExecutionError, trns.Info().GetInfo())), interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: ee.ExecutionError.String()}, nil } - return trns.WithInfo(handler.PhaseInfoFailureErr(ee.ExecutionError, trns.Info().GetInfo())), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: ee.ExecutionError.String()}, nil + return trns.WithInfo(handler.PhaseInfoFailureErr(ee.ExecutionError, trns.Info().GetInfo())), 
interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: ee.ExecutionError.String()}, nil } taskNodeInfoMetadata := &event.TaskNodeMetadata{CacheStatus: status.GetCacheStatus(), CatalogKey: status.GetMetadata()} trns = trns.WithInfo(trns.Info().WithInfo(&handler.ExecutionInfo{ @@ -182,7 +180,7 @@ func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, n // DynamicNodePhaseParentFinalized: The parent has node completed successfully and the generated dynamic sub workflow has been serialized and sent as an event. // DynamicNodePhaseExecuting: The parent node has completed and finalized successfully, the sub-nodes are being handled // DynamicNodePhaseFailing: one or more of sub-nodes have failed and the failure is being handled -func (d dynamicNodeTaskNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (d dynamicNodeTaskNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { ds := nCtx.NodeStateReader().GetDynamicNodeState() var err error var trns handler.Transition @@ -212,7 +210,7 @@ func (d dynamicNodeTaskNodeHandler) Handle(ctx context.Context, nCtx handler.Nod if err := d.finalizeParentNode(ctx, nCtx); err != nil { return handler.UnknownTransition, err } - newState = handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalized} + newState = interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalized} trns = handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(trns.Info().GetInfo())) case v1alpha1.DynamicNodePhaseParentFinalized: trns, newState, err = d.produceDynamicWorkflow(ctx, nCtx) @@ -235,7 +233,7 @@ func (d dynamicNodeTaskNodeHandler) Handle(ctx context.Context, nCtx handler.Nod return trns, nil } -func (d dynamicNodeTaskNodeHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (d dynamicNodeTaskNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { ds := nCtx.NodeStateReader().GetDynamicNodeState() switch ds.Phase { case v1alpha1.DynamicNodePhaseFailing: @@ -262,7 +260,7 @@ func (d dynamicNodeTaskNodeHandler) Abort(ctx context.Context, nCtx handler.Node } } -func (d dynamicNodeTaskNodeHandler) finalizeParentNode(ctx context.Context, nCtx handler.NodeExecutionContext) error { +func (d dynamicNodeTaskNodeHandler) finalizeParentNode(ctx context.Context, nCtx interfaces.NodeExecutionContext) error { logger.Infof(ctx, "Finalizing Parent node RetryAttempt [%d]", nCtx.CurrentAttempt()) if err := d.TaskNodeHandler.Finalize(ctx, nCtx); err != nil { logger.Errorf(ctx, "Failed to finalize Dynamic Nodes Parent.") @@ -272,7 +270,7 @@ func (d dynamicNodeTaskNodeHandler) finalizeParentNode(ctx context.Context, nCtx } // This is a weird method. We should always finalize before we set the dynamic parent node phase as complete? 
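// For reference, the phase progression that Handle implements (documented above),
// condensed into a standalone sketch. nextDynamicNodePhase is a hypothetical
// helper used only for illustration; it assumes the v1alpha1.DynamicNodePhase*
// constants and is not part of this handler.
func nextDynamicNodePhase(cur v1alpha1.DynamicNodePhase, futuresFileExists bool) v1alpha1.DynamicNodePhase {
	switch cur {
	case v1alpha1.DynamicNodePhaseNone:
		// the parent task is still running; a futures file marks it as a dynamic node
		if futuresFileExists {
			return v1alpha1.DynamicNodePhaseParentFinalizing
		}
		return v1alpha1.DynamicNodePhaseNone
	case v1alpha1.DynamicNodePhaseParentFinalizing:
		// the parent has been finalized; the compiled sub-workflow can now be produced
		return v1alpha1.DynamicNodePhaseParentFinalized
	case v1alpha1.DynamicNodePhaseParentFinalized:
		// the compiled workflow has been persisted and reported; sub-nodes execute next
		return v1alpha1.DynamicNodePhaseExecuting
	default:
		// Executing and Failing advance based on the recursive node handler's results
		return cur
	}
}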
-func (d dynamicNodeTaskNodeHandler) Finalize(ctx context.Context, nCtx handler.NodeExecutionContext) error { +func (d dynamicNodeTaskNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExecutionContext) error { errs := make([]error, 0, 2) ds := nCtx.NodeStateReader().GetDynamicNodeState() @@ -305,7 +303,7 @@ func (d dynamicNodeTaskNodeHandler) Finalize(ctx context.Context, nCtx handler.N return nil } -func New(underlying TaskNodeHandler, nodeExecutor executors.Node, launchPlanReader launchplan.Reader, eventConfig *config.EventConfig, scope promutils.Scope) handler.Node { +func New(underlying TaskNodeHandler, nodeExecutor interfaces.Node, launchPlanReader launchplan.Reader, eventConfig *config.EventConfig, scope promutils.Scope) handler.Node { return &dynamicNodeTaskNodeHandler{ TaskNodeHandler: underlying, diff --git a/pkg/controller/nodes/dynamic/utils.go b/pkg/controller/nodes/dynamic/utils.go index d08845856..6a38bafe4 100644 --- a/pkg/controller/nodes/dynamic/utils.go +++ b/pkg/controller/nodes/dynamic/utils.go @@ -9,11 +9,11 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/encoding" "github.com/flyteorg/flytepropeller/pkg/compiler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" ) // Constructs the expected interface of a given node. -func underlyingInterface(ctx context.Context, taskReader handler.TaskReader) (*core.TypedInterface, error) { +func underlyingInterface(ctx context.Context, taskReader interfaces.TaskReader) (*core.TypedInterface, error) { t, err := taskReader.Read(ctx) iface := &core.TypedInterface{} if err != nil { diff --git a/pkg/controller/nodes/end/handler.go b/pkg/controller/nodes/end/handler.go index 4f56ee840..7bd1286ed 100644 --- a/pkg/controller/nodes/end/handler.go +++ b/pkg/controller/nodes/end/handler.go @@ -9,6 +9,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" ) type endHandler struct { @@ -22,7 +23,7 @@ func (e endHandler) Setup(ctx context.Context, setupContext handler.SetupContext return nil } -func (e endHandler) Handle(ctx context.Context, executionContext handler.NodeExecutionContext) (handler.Transition, error) { +func (e endHandler) Handle(ctx context.Context, executionContext interfaces.NodeExecutionContext) (handler.Transition, error) { inputs, err := executionContext.InputReader().Get(ctx) if err != nil { return handler.UnknownTransition, err @@ -41,11 +42,11 @@ func (e endHandler) Handle(ctx context.Context, executionContext handler.NodeExe return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil } -func (e endHandler) Abort(_ context.Context, _ handler.NodeExecutionContext, _ string) error { +func (e endHandler) Abort(_ context.Context, _ interfaces.NodeExecutionContext, _ string) error { return nil } -func (e endHandler) Finalize(_ context.Context, _ handler.NodeExecutionContext) error { +func (e endHandler) Finalize(_ context.Context, _ interfaces.NodeExecutionContext) error { return nil } diff --git a/pkg/controller/nodes/gate/handler.go b/pkg/controller/nodes/gate/handler.go index f51c92a54..340cbab9f 100644 --- a/pkg/controller/nodes/gate/handler.go +++ b/pkg/controller/nodes/gate/handler.go @@ -13,6 +13,7 @@ import ( 
"github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/promutils" @@ -46,12 +47,12 @@ func newMetrics(scope promutils.Scope) metrics { } // Abort stops the gate node defined in the NodeExecutionContext -func (g *gateNodeHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (g *gateNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { return nil } // Finalize completes the gate node defined in the NodeExecutionContext -func (g *gateNodeHandler) Finalize(ctx context.Context, _ handler.NodeExecutionContext) error { +func (g *gateNodeHandler) Finalize(ctx context.Context, _ interfaces.NodeExecutionContext) error { return nil } @@ -63,7 +64,7 @@ func (g *gateNodeHandler) FinalizeRequired() bool { // Handle is responsible for transitioning and reporting node state to complete the node defined // by the NodeExecutionContext -func (g *gateNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (g *gateNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { gateNode := nCtx.Node().GetGateNode() gateNodeState := nCtx.NodeStateReader().GetGateNodeState() diff --git a/pkg/controller/nodes/handler/iface.go b/pkg/controller/nodes/handler/iface.go index d0b359171..c2de8af08 100644 --- a/pkg/controller/nodes/handler/iface.go +++ b/pkg/controller/nodes/handler/iface.go @@ -2,6 +2,9 @@ package handler import ( "context" + + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytestdlib/promutils" ) //go:generate mockery -all -case=underscore @@ -15,12 +18,18 @@ type Node interface { Setup(ctx context.Context, setupContext SetupContext) error // Core method that should handle this node - Handle(ctx context.Context, executionContext NodeExecutionContext) (Transition, error) + Handle(ctx context.Context, executionContext interfaces.NodeExecutionContext) (Transition, error) // This method should be invoked to indicate the node needs to be aborted. - Abort(ctx context.Context, executionContext NodeExecutionContext, reason string) error + Abort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) error // This method is always called before completing the node, if FinalizeRequired returns true. // It is guaranteed that Handle -> (happens before) -> Finalize. 
Abort -> finalize may be repeated multiple times - Finalize(ctx context.Context, executionContext NodeExecutionContext) error + Finalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) error +} + +type SetupContext interface { + EnqueueOwner() func(string) + OwnerKind() string + MetricsScope() promutils.Scope } diff --git a/pkg/controller/executors/node.go b/pkg/controller/nodes/interfaces/node.go similarity index 79% rename from pkg/controller/executors/node.go rename to pkg/controller/nodes/interfaces/node.go index 3fbe9015c..f713f6348 100644 --- a/pkg/controller/executors/node.go +++ b/pkg/controller/nodes/interfaces/node.go @@ -1,4 +1,4 @@ -package executors +package interfaces import ( "context" @@ -7,7 +7,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - //"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" ) //go:generate mockery -all -case=underscore @@ -68,22 +68,27 @@ func (p NodePhase) String() string { type Node interface { // This method is used specifically to set inputs for start node. This is because start node does not retrieve inputs // from predecessors, but the inputs are inputs to the workflow or inputs to the parent container (workflow) node. - SetInputsForStartNode(ctx context.Context, execContext ExecutionContext, dag DAGStructureWithStartNode, nl NodeLookup, inputs *core.LiteralMap) (NodeStatus, error) + SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, + nl executors.NodeLookup, inputs *core.LiteralMap) (NodeStatus, error) // This is the main entrypoint to execute a node. It recursively depth-first goes through all ready nodes and starts their execution // This returns either // - 1. It finds a blocking node (not ready, or running) // - 2. A node fails and hence the workflow will fail // - 3. The final/end node has completed and the workflow should be stopped - RecursiveNodeHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode) (NodeStatus, error) + RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, + nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (NodeStatus, error) // This aborts the given node. 
If the given node is complete then it recursively finds the running nodes and aborts them - AbortHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error + AbortHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, + nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error - FinalizeHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode) error + FinalizeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, + nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) error - // TODO @docs - //NewNodeExecContext(ctx context.Context, executionContext ExecutionContext, nl NodeLookup, currentNodeID v1alpha1.NodeID) (handler.NodeExecutionContext, error) + // TODO @hamersaw - docs + NewNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, + nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (NodeExecutionContext, error) // This method should be used to initialize Node executor Initialize(ctx context.Context) error diff --git a/pkg/controller/nodes/handler/node_exec_context.go b/pkg/controller/nodes/interfaces/node_exec_context.go similarity index 93% rename from pkg/controller/nodes/handler/node_exec_context.go rename to pkg/controller/nodes/interfaces/node_exec_context.go index 117358dab..db33c303c 100644 --- a/pkg/controller/nodes/handler/node_exec_context.go +++ b/pkg/controller/nodes/interfaces/node_exec_context.go @@ -1,20 +1,21 @@ -package handler +package interfaces import ( "context" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" - "github.com/flyteorg/flytestdlib/promutils" - "github.com/flyteorg/flytestdlib/storage" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/types" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flytepropeller/events" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/executors" + + "github.com/flyteorg/flytestdlib/storage" + + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" ) type TaskReader interface { @@ -23,12 +24,6 @@ type TaskReader interface { GetTaskID() *core.Identifier } -type SetupContext interface { - EnqueueOwner() func(string) - OwnerKind() string - MetricsScope() promutils.Scope -} - type NodeExecutionMetadata interface { GetOwnerID() types.NamespacedName GetNodeExecutionID() *core.NodeExecutionIdentifier diff --git a/pkg/controller/nodes/handler/state.go b/pkg/controller/nodes/interfaces/state.go similarity index 98% rename from pkg/controller/nodes/handler/state.go rename to pkg/controller/nodes/interfaces/state.go index e4cc851a2..b41cb556d 100644 --- a/pkg/controller/nodes/handler/state.go +++ b/pkg/controller/nodes/interfaces/state.go @@ -1,11 +1,14 @@ -package handler +package interfaces import ( "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flytestdlib/storage" ) diff --git 
a/pkg/controller/nodes/handler/state_test.go b/pkg/controller/nodes/interfaces/state_test.go similarity index 97% rename from pkg/controller/nodes/handler/state_test.go rename to pkg/controller/nodes/interfaces/state_test.go index 7e914422e..d7d9d62fd 100644 --- a/pkg/controller/nodes/handler/state_test.go +++ b/pkg/controller/nodes/interfaces/state_test.go @@ -1,4 +1,4 @@ -package handler +package interfaces import ( "bytes" diff --git a/pkg/controller/nodes/start/handler.go b/pkg/controller/nodes/start/handler.go index f1dda96f4..1fecdba25 100644 --- a/pkg/controller/nodes/start/handler.go +++ b/pkg/controller/nodes/start/handler.go @@ -4,6 +4,7 @@ import ( "context" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" ) type startHandler struct { @@ -17,15 +18,15 @@ func (s startHandler) Setup(ctx context.Context, setupContext handler.SetupConte return nil } -func (s startHandler) Handle(ctx context.Context, executionContext handler.NodeExecutionContext) (handler.Transition, error) { +func (s startHandler) Handle(ctx context.Context, executionContext interfaces.NodeExecutionContext) (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), nil } -func (s startHandler) Abort(ctx context.Context, executionContext handler.NodeExecutionContext, reason string) error { +func (s startHandler) Abort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) error { return nil } -func (s startHandler) Finalize(ctx context.Context, executionContext handler.NodeExecutionContext) error { +func (s startHandler) Finalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) error { return nil } diff --git a/pkg/controller/nodes/subworkflow/handler.go b/pkg/controller/nodes/subworkflow/handler.go index bf7b5c393..7ad2b3dc7 100644 --- a/pkg/controller/nodes/subworkflow/handler.go +++ b/pkg/controller/nodes/subworkflow/handler.go @@ -14,9 +14,9 @@ import ( "github.com/flyteorg/flytestdlib/promutils/labeled" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" ) @@ -44,7 +44,7 @@ func (w *workflowNodeHandler) Setup(_ context.Context, _ handler.SetupContext) e return nil } -func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { logger.Debug(ctx, "Starting workflow Node") invalidWFNodeError := func() (handler.Transition, error) { @@ -58,7 +58,7 @@ func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecu return transition, err } - workflowNodeState := handler.WorkflowNodeState{Phase: newPhase} + workflowNodeState := interfaces.WorkflowNodeState{Phase: newPhase} err = nCtx.NodeStateWriter().PutWorkflowNodeState(workflowNodeState) if err != nil { logger.Errorf(ctx, "Failed to store WorkflowNodeState, err :%s", err.Error()) @@ -112,7 +112,7 @@ func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecu return 
invalidWFNodeError() } -func (w *workflowNodeHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (w *workflowNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { wfNode := nCtx.Node().GetWorkflowNode() if wfNode.GetSubWorkflowRef() != nil { return w.subWfHandler.HandleAbort(ctx, nCtx, reason) @@ -124,12 +124,12 @@ func (w *workflowNodeHandler) Abort(ctx context.Context, nCtx handler.NodeExecut return nil } -func (w *workflowNodeHandler) Finalize(ctx context.Context, _ handler.NodeExecutionContext) error { +func (w *workflowNodeHandler) Finalize(ctx context.Context, _ interfaces.NodeExecutionContext) error { logger.Warnf(ctx, "Subworkflow finalize invoked. Nothing to be done") return nil } -func New(executor executors.Node, workflowLauncher launchplan.Executor, recoveryClient recovery.Client, eventConfig *config.EventConfig, scope promutils.Scope) handler.Node { +func New(executor interfaces.Node, workflowLauncher launchplan.Executor, recoveryClient recovery.Client, eventConfig *config.EventConfig, scope promutils.Scope) handler.Node { workflowScope := scope.NewSubScope("workflow") return &workflowNodeHandler{ subWfHandler: newSubworkflowHandler(executor, eventConfig), diff --git a/pkg/controller/nodes/subworkflow/launchplan.go b/pkg/controller/nodes/subworkflow/launchplan.go index ec972d9fd..a3d3f39fe 100644 --- a/pkg/controller/nodes/subworkflow/launchplan.go +++ b/pkg/controller/nodes/subworkflow/launchplan.go @@ -4,22 +4,22 @@ import ( "context" "fmt" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/storage" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/storage" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type launchPlanHandler struct { @@ -28,7 +28,7 @@ type launchPlanHandler struct { eventConfig *config.EventConfig } -func getParentNodeExecutionID(nCtx handler.NodeExecutionContext) (*core.NodeExecutionIdentifier, error) { +func getParentNodeExecutionID(nCtx interfaces.NodeExecutionContext) (*core.NodeExecutionIdentifier, error) { nodeExecID := &core.NodeExecutionIdentifier{ ExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().ExecutionId, } @@ -45,7 +45,7 @@ func getParentNodeExecutionID(nCtx handler.NodeExecutionContext) (*core.NodeExec return nodeExecID, nil } -func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { nodeInputs, 
err := nCtx.InputReader().Get(ctx) if err != nil { errMsg := fmt.Sprintf("Failed to read input. Error [%s]", err) @@ -122,7 +122,7 @@ func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, nCtx handler.No })), nil } -func GetChildWorkflowExecutionIDForExecution(parentNodeExecID *core.NodeExecutionIdentifier, nCtx handler.NodeExecutionContext) (*core.WorkflowExecutionIdentifier, error) { +func GetChildWorkflowExecutionIDForExecution(parentNodeExecID *core.NodeExecutionIdentifier, nCtx interfaces.NodeExecutionContext) (*core.WorkflowExecutionIdentifier, error) { // Handle launch plan if nCtx.ExecutionContext().GetDefinitionVersion() == v1alpha1.WorkflowDefinitionVersion0 { return GetChildWorkflowExecutionID( @@ -137,7 +137,7 @@ func GetChildWorkflowExecutionIDForExecution(parentNodeExecID *core.NodeExecutio ) } -func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { parentNodeExecutionID, err := getParentNodeExecutionID(nCtx) if err != nil { return handler.UnknownTransition, err @@ -226,7 +226,7 @@ func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx hand return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil } -func (l *launchPlanHandler) HandleAbort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (l *launchPlanHandler) HandleAbort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { parentNodeExecutionID, err := getParentNodeExecutionID(nCtx) if err != nil { return err diff --git a/pkg/controller/nodes/subworkflow/subworkflow.go b/pkg/controller/nodes/subworkflow/subworkflow.go index 74beeaf79..52626730c 100644 --- a/pkg/controller/nodes/subworkflow/subworkflow.go +++ b/pkg/controller/nodes/subworkflow/subworkflow.go @@ -4,28 +4,28 @@ import ( "context" "fmt" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/storage" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/storage" ) // Subworkflow handler handles inline subWorkflows type subworkflowHandler struct { - nodeExecutor executors.Node + nodeExecutor interfaces.Node eventConfig *config.EventConfig } // Helper method that extracts the SubWorkflow from the ExecutionContext -func GetSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) (v1alpha1.ExecutableSubWorkflow, error) { +func GetSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext) (v1alpha1.ExecutableSubWorkflow, error) { node := nCtx.Node() subID := *node.GetWorkflowNode().GetSubWorkflowRef() subWorkflow := nCtx.ExecutionContext().FindSubWorkflow(subID) @@ -36,7 +36,7 @@ func GetSubWorkflow(ctx 
context.Context, nCtx handler.NodeExecutionContext) (v1a } // Performs an additional step of passing in and setting the inputs, before handling the execution of a SubWorkflow. -func (s *subworkflowHandler) startAndHandleSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext, subWorkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { +func (s *subworkflowHandler) startAndHandleSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext, subWorkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { // Before starting the subworkflow, lets set the inputs for the Workflow. The inputs for a SubWorkflow are essentially // Copy of the inputs to the Node nodeInputs, err := nCtx.InputReader().Get(ctx) @@ -63,7 +63,7 @@ func (s *subworkflowHandler) startAndHandleSubWorkflow(ctx context.Context, nCtx } // Calls the recursive node executor to handle the SubWorkflow and translates the results after the success -func (s *subworkflowHandler) handleSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { +func (s *subworkflowHandler) handleSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { // The current node would end up becoming the parent for the sub workflow nodes. // This is done to track the lineage. For level zero, the CreateParentInfo will return nil newParentInfo, err := common.CreateParentInfo(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeID(), nCtx.CurrentAttempt()) @@ -77,7 +77,7 @@ func (s *subworkflowHandler) handleSubWorkflow(ctx context.Context, nCtx handler } if state.HasFailed() { - workflowNodeState := handler.WorkflowNodeState{ + workflowNodeState := interfaces.WorkflowNodeState{ Phase: v1alpha1.WorkflowNodePhaseFailing, Error: state.Err, } @@ -135,7 +135,7 @@ func (s *subworkflowHandler) handleSubWorkflow(ctx context.Context, nCtx handler return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil } -func (s *subworkflowHandler) getExecutionContextForDownstream(nCtx handler.NodeExecutionContext) (executors.ExecutionContext, error) { +func (s *subworkflowHandler) getExecutionContextForDownstream(nCtx interfaces.NodeExecutionContext) (executors.ExecutionContext, error) { newParentInfo, err := common.CreateParentInfo(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeID(), nCtx.CurrentAttempt()) if err != nil { return nil, err @@ -143,7 +143,7 @@ func (s *subworkflowHandler) getExecutionContextForDownstream(nCtx handler.NodeE return executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo), nil } -func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { +func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { originalError := nCtx.NodeStateReader().GetWorkflowNodeState().Error if subworkflow.GetOnFailureNode() != nil { execContext, err := s.getExecutionContextForDownstream(nCtx) @@ -155,7 +155,7 @@ func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context, return 
handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err } - if state.NodePhase == executors.NodePhaseRunning { + if state.NodePhase == interfaces.NodePhaseRunning { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil } @@ -185,7 +185,7 @@ func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context, originalError, nil)), nil } -func (s *subworkflowHandler) HandleFailingSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (s *subworkflowHandler) HandleFailingSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { subWorkflow, err := GetSubWorkflow(ctx, nCtx) if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.SubWorkflowExecutionFailed, err.Error(), nil)), nil @@ -213,7 +213,7 @@ func (s *subworkflowHandler) HandleFailingSubWorkflow(ctx context.Context, nCtx return s.HandleFailureNodeOfSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) } -func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { subWorkflow, err := GetSubWorkflow(ctx, nCtx) if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.SubWorkflowExecutionFailed, err.Error(), nil)), nil @@ -226,7 +226,7 @@ func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, nCtx handler. return s.startAndHandleSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) } -func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { subWorkflow, err := GetSubWorkflow(ctx, nCtx) if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.SubWorkflowExecutionFailed, err.Error(), nil)), nil @@ -237,7 +237,7 @@ func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, nCtx ha return s.handleSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) } -func (s *subworkflowHandler) HandleAbort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (s *subworkflowHandler) HandleAbort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { subWorkflow, err := GetSubWorkflow(ctx, nCtx) if err != nil { return err @@ -251,7 +251,7 @@ func (s *subworkflowHandler) HandleAbort(ctx context.Context, nCtx handler.NodeE return s.nodeExecutor.AbortHandler(ctx, execContext, subWorkflow, nodeLookup, subWorkflow.StartNode(), reason) } -func newSubworkflowHandler(nodeExecutor executors.Node, eventConfig *config.EventConfig) subworkflowHandler { +func newSubworkflowHandler(nodeExecutor interfaces.Node, eventConfig *config.EventConfig) subworkflowHandler { return subworkflowHandler{ nodeExecutor: nodeExecutor, eventConfig: eventConfig, diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go index 8c1ce6575..572261d62 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ -6,39 +6,39 @@ import ( 
"runtime/debug" "time" - eventsErr "github.com/flyteorg/flytepropeller/events/errors" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - - "github.com/flyteorg/flytepropeller/pkg/utils" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + pluginMachinery "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" pluginK8s "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" + + eventsErr "github.com/flyteorg/flytepropeller/events/errors" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" controllerConfig "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/resourcemanager" + rmConfig "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/resourcemanager/config" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" + "github.com/flyteorg/flytepropeller/pkg/utils" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/promutils/labeled" "github.com/flyteorg/flytestdlib/storage" - "github.com/golang/protobuf/ptypes" - regErrors "github.com/pkg/errors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/resourcemanager" - rmConfig "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/resourcemanager/config" + "github.com/golang/protobuf/ptypes" - "github.com/flyteorg/flytepropeller/pkg/controller/executors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" + regErrors "github.com/pkg/errors" ) const pluginContextKey = contextutils.Key("plugin") @@ -380,7 +380,7 @@ func (t Handler) fetchPluginTaskMetrics(pluginID, taskType string) (*taskMetrics return t.taskMetricsMap[metricNameKey], nil } -func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *taskExecutionContext, ts handler.TaskNodeState) (*pluginRequestedTransition, error) { +func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *taskExecutionContext, ts interfaces.TaskNodeState) (*pluginRequestedTransition, error) { pluginTrns := &pluginRequestedTransition{} trns, err := func() (trns pluginCore.Transition, err error) { @@ -533,7 +533,7 @@ func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *ta return pluginTrns, nil } -func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { +func (t Handler) Handle(ctx 
context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { ttype := nCtx.TaskReader().GetTaskType() ctx = contextutils.WithTaskType(ctx, ttype) p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig()) @@ -770,7 +770,7 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) } // STEP 6: Persist the plugin state - err = nCtx.NodeStateWriter().PutTaskNodeState(handler.TaskNodeState{ + err = nCtx.NodeStateWriter().PutTaskNodeState(interfaces.TaskNodeState{ PluginState: pluginTrns.pluginState, PluginStateVersion: pluginTrns.pluginStateVersion, PluginPhase: pluginTrns.pInfo.Phase(), @@ -791,7 +791,7 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) return pluginTrns.FinalTransition(ctx) } -func (t Handler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { +func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { currentPhase := nCtx.NodeStateReader().GetTaskNodeState().PluginPhase logger.Debugf(ctx, "Abort invoked with phase [%v]", currentPhase) @@ -856,7 +856,7 @@ func (t Handler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, r return nil } -func (t Handler) Finalize(ctx context.Context, nCtx handler.NodeExecutionContext) error { +func (t Handler) Finalize(ctx context.Context, nCtx interfaces.NodeExecutionContext) error { logger.Debugf(ctx, "Finalize invoked.") ttype := nCtx.TaskReader().GetTaskType() p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig()) diff --git a/pkg/controller/nodes/task/taskexec_context.go b/pkg/controller/nodes/task/taskexec_context.go index 9e8296b3c..2a6c59c0c 100644 --- a/pkg/controller/nodes/task/taskexec_context.go +++ b/pkg/controller/nodes/task/taskexec_context.go @@ -15,7 +15,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/resourcemanager" "github.com/flyteorg/flytepropeller/pkg/utils" "github.com/flyteorg/flytestdlib/logger" @@ -58,7 +58,7 @@ func (te taskExecutionID) GetGeneratedNameWith(minLength, maxLength int) (string } type taskExecutionMetadata struct { - handler.NodeExecutionMetadata + interfaces.NodeExecutionMetadata taskExecID taskExecutionID o pluginCore.TaskOverrides maxAttempts uint32 @@ -82,7 +82,7 @@ func (t taskExecutionMetadata) GetPlatformResources() *v1.ResourceRequirements { } type taskExecutionContext struct { - handler.NodeExecutionContext + interfaces.NodeExecutionContext tm taskExecutionMetadata rm resourcemanager.TaskResourceManager psm *pluginStateManager @@ -203,7 +203,7 @@ func convertTaskResourcesToRequirements(taskResources v1alpha1.TaskResources) *v // ComputeRawOutputPrefix constructs the output directory, where raw outputs of a task can be stored by the task. FlytePropeller may not have // access to this location and can be passed in per execution. 
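// As a concrete illustration of the naming scheme: encoding.FixedLengthUniqueIDForParts(length,
// []string{ownerID, nodeID, strconv.Itoa(attempt)}) is deterministic for the same parts,
// which is what lets ComputePreviousCheckpointPath below reconstruct the raw-output
// prefix of attempt N-1 while handling attempt N (assuming, as the code does, that the
// encoding package keeps this hashing stable).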
// the function also returns the uniqueID generated -func ComputeRawOutputPrefix(ctx context.Context, length int, nCtx handler.NodeExecutionContext, currentNodeUniqueID v1alpha1.NodeID, currentAttempt uint32) (io.RawOutputPaths, string, error) { +func ComputeRawOutputPrefix(ctx context.Context, length int, nCtx interfaces.NodeExecutionContext, currentNodeUniqueID v1alpha1.NodeID, currentAttempt uint32) (io.RawOutputPaths, string, error) { uniqueID, err := encoding.FixedLengthUniqueIDForParts(length, []string{nCtx.NodeExecutionMetadata().GetOwnerID().Name, currentNodeUniqueID, strconv.Itoa(int(currentAttempt))}) if err != nil { // SHOULD never really happen @@ -218,7 +218,7 @@ func ComputeRawOutputPrefix(ctx context.Context, length int, nCtx handler.NodeEx } // ComputePreviousCheckpointPath returns the checkpoint path for the previous attempt, if this is the first attempt then returns an empty path -func ComputePreviousCheckpointPath(ctx context.Context, length int, nCtx handler.NodeExecutionContext, currentNodeUniqueID v1alpha1.NodeID, currentAttempt uint32) (storage.DataReference, error) { +func ComputePreviousCheckpointPath(ctx context.Context, length int, nCtx interfaces.NodeExecutionContext, currentNodeUniqueID v1alpha1.NodeID, currentAttempt uint32) (storage.DataReference, error) { // If first attempt for this node execution, look for a checkpoint path in a prior execution if currentAttempt == 0 { return nCtx.NodeStateReader().GetTaskNodeState().PreviousNodeExecutionCheckpointURI, nil @@ -232,7 +232,7 @@ func ComputePreviousCheckpointPath(ctx context.Context, length int, nCtx handler return ioutils.ConstructCheckpointPath(nCtx.DataStore(), prevRawOutputPrefix.GetRawOutputPrefix()), nil } -func (t *Handler) newTaskExecutionContext(ctx context.Context, nCtx handler.NodeExecutionContext, plugin pluginCore.Plugin) (*taskExecutionContext, error) { +func (t *Handler) newTaskExecutionContext(ctx context.Context, nCtx interfaces.NodeExecutionContext, plugin pluginCore.Plugin) (*taskExecutionContext, error) { id := GetTaskExecutionIdentifier(nCtx) currentNodeUniqueID := nCtx.NodeID() diff --git a/pkg/controller/nodes/task/transformer.go b/pkg/controller/nodes/task/transformer.go index 6faa93f70..21edca723 100644 --- a/pkg/controller/nodes/task/transformer.go +++ b/pkg/controller/nodes/task/transformer.go @@ -12,6 +12,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/golang/protobuf/ptypes" timestamppb "github.com/golang/protobuf/ptypes/timestamp" @@ -75,7 +76,7 @@ type ToTaskExecutionEventInputs struct { EventConfig *config.EventConfig OutputWriter io.OutputFilePaths Info pluginCore.PhaseInfo - NodeExecutionMetadata handler.NodeExecutionMetadata + NodeExecutionMetadata interfaces.NodeExecutionMetadata ExecContext executors.ExecutionContext TaskType string PluginID string @@ -185,7 +186,7 @@ func ToTaskExecutionEvent(input ToTaskExecutionEventInputs) (*event.TaskExecutio return tev, nil } -func GetTaskExecutionIdentifier(nCtx handler.NodeExecutionContext) *core.TaskExecutionIdentifier { +func GetTaskExecutionIdentifier(nCtx interfaces.NodeExecutionContext) *core.TaskExecutionIdentifier { return &core.TaskExecutionIdentifier{ TaskId: nCtx.TaskReader().GetTaskID(), RetryAttempt: nCtx.CurrentAttempt(), diff --git a/pkg/controller/workflow/executor.go 
b/pkg/controller/workflow/executor.go index e3eac1e37..11a766213 100644 --- a/pkg/controller/workflow/executor.go +++ b/pkg/controller/workflow/executor.go @@ -5,23 +5,25 @@ import ( "fmt" "time" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/promutils" - "github.com/flyteorg/flytestdlib/promutils/labeled" - "github.com/flyteorg/flytestdlib/storage" - corev1 "k8s.io/api/core/v1" - "k8s.io/client-go/tools/record" "github.com/flyteorg/flytepropeller/events" eventsErr "github.com/flyteorg/flytepropeller/events/errors" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/workflow/errors" "github.com/flyteorg/flytepropeller/pkg/utils" + + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" + + corev1 "k8s.io/api/core/v1" + "k8s.io/client-go/tools/record" ) type workflowMetrics struct { @@ -64,7 +66,7 @@ type workflowExecutor struct { wfRecorder events.WorkflowEventRecorder k8sRecorder record.EventRecorder metadataPrefix storage.DataReference - nodeExecutor executors.Node + nodeExecutor interfaces.Node metrics *workflowMetrics eventConfig *config.EventConfig clusterID string @@ -495,7 +497,7 @@ func (c *workflowExecutor) cleanupRunningNodes(ctx context.Context, w v1alpha1.E } func NewExecutor(ctx context.Context, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, - k8sEventRecorder record.EventRecorder, metadataPrefix string, nodeExecutor executors.Node, eventConfig *config.EventConfig, + k8sEventRecorder record.EventRecorder, metadataPrefix string, nodeExecutor interfaces.Node, eventConfig *config.EventConfig, clusterID string, scope promutils.Scope) (executors.Workflow, error) { basePrefix := store.GetBaseContainerFQN(ctx) if metadataPrefix != "" { From 188a675941cf84d2a5677c4678929d2ec83df2ee Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 6 Apr 2023 11:49:33 -0500 Subject: [PATCH 05/62] refactoring almost complete Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 14 +----- pkg/controller/nodes/array/node_executor.go | 33 +++++++++++-- pkg/controller/nodes/executor.go | 13 +++--- pkg/controller/nodes/handler_factory.go | 3 +- pkg/controller/nodes/node_exec_context.go | 15 +++--- pkg/controller/nodes/node_state_manager.go | 52 ++++++++++----------- 6 files changed, 72 insertions(+), 58 deletions(-) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index a2b9d470a..5966ebec5 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -151,7 +151,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // create new arrayNodeExecutionContext to override for array task execution arrayNodeExecutionContext := newArrayNodeExecutionContext(nodeExecutionContext, inputReader)*/ - arrayNodeExecutor := newArrayNodeExecutor(a.nodeExecutor) + arrayNodeExecutor := newArrayNodeExecutor(a.nodeExecutor, subNodeID, inputReader) // execute subNode 
through RecursiveNodeHandler nodeStatus, err := arrayNodeExecutor.RecursiveNodeHandler(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, &arrayNodeLookup, subNodeSpec) @@ -160,18 +160,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } fmt.Printf("HAMERSAW - node phase transition %d -> %d", nodePhase, nodeStatus.NodePhase) - - // execute subNode through RecursiveNodeHandler - // TODO @hamersaw - either - // (1) add func to create nodeExecutionContext to RecursiveNodeHandler - // (2) create nodeExecutionContext before call to RecursiveNodeHandler - // can do with small wrapper function call - /*nodeStatus, err := a.nodeExecutor.RecursiveNodeHandler(ctx, arrayNodeExecutionContext, &arrayNodeLookup, &arrayNodeLookup, subNodeSpec) - if err != nil { - // TODO @hamersaw fail - }*/ - - // handleNode / abort / finalize task nodeExecutionContext and Handler as parameters - THIS IS THE ENTRYPOINT WE'RE LOOKING FOR } // TODO @hamersaw - determine summary phases diff --git a/pkg/controller/nodes/array/node_executor.go b/pkg/controller/nodes/array/node_executor.go index 81aa59754..fdee0e92d 100644 --- a/pkg/controller/nodes/array/node_executor.go +++ b/pkg/controller/nodes/array/node_executor.go @@ -1,19 +1,42 @@ package array import ( + "context" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" ) type arrayNodeExecutor struct { interfaces.Node + subNodeID v1alpha1.NodeID + inputReader io.InputReader } -/* -TODO @hamersaw - override NewNodeExecutionContext function -*/ +// TODO @hamersaw - docs +func (a *arrayNodeExecutor) NewNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, + nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (interfaces.NodeExecutionContext, error) { + + // create base NodeExecutionContext + nCtx, err := a.Node.NewNodeExecutionContext(ctx, executionContext, nl, currentNodeID) + if err != nil { + return nil, err + } + + if currentNodeID == a.subNodeID { + // TODO @hamersaw - overwrite NodeExecutionContext for ArrayNode execution + } + + return nCtx, nil +} -func newArrayNodeExecutor(nodeExecutor interfaces.Node) arrayNodeExecutor { +func newArrayNodeExecutor(nodeExecutor interfaces.Node, subNodeID v1alpha1.NodeID, inputReader io.InputReader) arrayNodeExecutor { return arrayNodeExecutor{ - Node: nodeExecutor, + Node: nodeExecutor, + subNodeID: subNodeID, + inputReader: inputReader, } } diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index cb3a1ceee..120848bda 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -34,6 +34,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" @@ -155,7 +156,7 @@ func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *eve return err } -func (c *nodeExecutor) recoverInputs(ctx context.Context, nCtx handler.NodeExecutionContext, +func (c 
*nodeExecutor) recoverInputs(ctx context.Context, nCtx interfaces.NodeExecutionContext, recovered *admin.NodeExecution, recoveredData *admin.NodeExecutionGetDataResponse) (*core.LiteralMap, error) { nodeInputs := recoveredData.FullInputs @@ -184,7 +185,7 @@ func (c *nodeExecutor) recoverInputs(ctx context.Context, nCtx handler.NodeExecu return nodeInputs, nil } -func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.PhaseInfo, error) { +func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.PhaseInfo, error) { fullyQualifiedNodeID := nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { // compute fully qualified node id (prefixed with parent id and retry attempt) to ensure uniqueness @@ -364,7 +365,7 @@ func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx handler.NodeExe // In this method we check if the queue is ready to be processed and if so, we prime it in Admin as queued // Before we start the node execution, we need to transition this Node status to Queued. // This is because a node execution has to exist before task/wf executions can start. -func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructure, nCtx handler.NodeExecutionContext) ( +func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext) ( handler.PhaseInfo, error) { logger.Debugf(ctx, "Node not yet started") // Query the nodes information to figure out if it can be executed. @@ -504,7 +505,7 @@ func (c *nodeExecutor) execute(ctx context.Context, h handler.Node, nCtx *nodeEx return phase, nil } -func (c *nodeExecutor) abort(ctx context.Context, h handler.Node, nCtx handler.NodeExecutionContext, reason string) error { +func (c *nodeExecutor) abort(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext, reason string) error { logger.Debugf(ctx, "Calling aborting & finalize") if err := h.Abort(ctx, nCtx, reason); err != nil { finalizeErr := h.Finalize(ctx, nCtx) @@ -517,7 +518,7 @@ func (c *nodeExecutor) abort(ctx context.Context, h handler.Node, nCtx handler.N return h.Finalize(ctx, nCtx) } -func (c *nodeExecutor) finalize(ctx context.Context, h handler.Node, nCtx handler.NodeExecutionContext) error { +func (c *nodeExecutor) finalize(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext) error { return h.Finalize(ctx, nCtx) } @@ -1210,7 +1211,7 @@ func (c *nodeExecutor) Initialize(ctx context.Context) error { func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, maxDatasetSize int64, defaultRawOutputPrefix storage.DataReference, kubeClient executors.Client, - catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, scope promutils.Scope) (executors.Node, error) { + catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, scope promutils.Scope) (interfaces.Node, error) { // TODO we may want to make this configurable. 
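	// A sketch, not part of this patch, of the wiring this signature change enables:
	// NewExecutor now hands back interfaces.Node instead of executors.Node, so a single
	// instance can be passed to workflow.NewExecutor (patched earlier in this series)
	// and wrapped by node handlers such as ArrayNode. Variable names below are
	// hypothetical; both constructor signatures are the ones shown in this series.
	//
	//	nodeExec, err := nodes.NewExecutor(ctx, nodeConfig, store, enQWorkflow, eventSink,
	//		workflowLauncher, launchPlanReader, maxDatasetSize, defaultRawOutputPrefix, kubeClient,
	//		catalogClient, recoveryClient, eventConfig, clusterID, signalClient, scope)
	//	if err != nil {
	//		return err
	//	}
	//	var _ interfaces.Node = nodeExec // compile-time check: the concrete executor satisfies the interface
	//	wfExec, err := workflow.NewExecutor(ctx, store, enQWorkflow, eventSink, k8sEventRecorder,
	//		metadataPrefix, nodeExec, eventConfig, clusterID, scope)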
shardSelector, err := ioutils.NewBase36PrefixShardSelector(ctx) diff --git a/pkg/controller/nodes/handler_factory.go b/pkg/controller/nodes/handler_factory.go index d1df39cab..a275b59a7 100644 --- a/pkg/controller/nodes/handler_factory.go +++ b/pkg/controller/nodes/handler_factory.go @@ -16,6 +16,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/nodes/end" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/start" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow" @@ -55,7 +56,7 @@ func (f handlerFactory) Setup(ctx context.Context, setup handler.SetupContext) e return nil } -func NewHandlerFactory(ctx context.Context, executor executors.Node, workflowLauncher launchplan.Executor, +func NewHandlerFactory(ctx context.Context, executor interfaces.Node, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, kubeClient executors.Client, client catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, scope promutils.Scope) (HandlerFactory, error) { diff --git a/pkg/controller/nodes/node_exec_context.go b/pkg/controller/nodes/node_exec_context.go index 8914407c6..ba0a72b47 100644 --- a/pkg/controller/nodes/node_exec_context.go +++ b/pkg/controller/nodes/node_exec_context.go @@ -16,6 +16,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/utils" ) @@ -57,8 +58,8 @@ func (e nodeExecMetadata) GetLabels() map[string]string { type nodeExecContext struct { store *storage.DataStore - tr handler.TaskReader - md handler.NodeExecutionMetadata + tr interfaces.TaskReader + md interfaces.NodeExecutionMetadata er events.TaskEventRecorder inputs io.InputReader node v1alpha1.ExecutableNode @@ -92,15 +93,15 @@ func (e nodeExecContext) EnqueueOwnerFunc() func() error { return e.enqueueOwner } -func (e nodeExecContext) TaskReader() handler.TaskReader { +func (e nodeExecContext) TaskReader() interfaces.TaskReader { return e.tr } -func (e nodeExecContext) NodeStateReader() handler.NodeStateReader { +func (e nodeExecContext) NodeStateReader() interfaces.NodeStateReader { return e.nsm } -func (e nodeExecContext) NodeStateWriter() handler.NodeStateWriter { +func (e nodeExecContext) NodeStateWriter() interfaces.NodeStateWriter { return e.nsm } @@ -132,7 +133,7 @@ func (e nodeExecContext) NodeStatus() v1alpha1.ExecutableNodeStatus { return e.nodeStatus } -func (e nodeExecContext) NodeExecutionMetadata() handler.NodeExecutionMetadata { +func (e nodeExecContext) NodeExecutionMetadata() interfaces.NodeExecutionMetadata { return e.md } @@ -142,7 +143,7 @@ func (e nodeExecContext) MaxDatasetSizeBytes() int64 { func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext executors.ExecutionContext, nl executors.NodeLookup, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, inputs io.InputReader, interruptible bool, interruptibleFailureThreshold uint32, - maxDatasetSize int64, er events.TaskEventRecorder, tr 
handler.TaskReader, nsm *nodeStateManager, + maxDatasetSize int64, er events.TaskEventRecorder, tr interfaces.TaskReader, nsm *nodeStateManager, enqueueOwner func() error, rawOutputPrefix storage.DataReference, outputShardSelector ioutils.ShardSelector) *nodeExecContext { md := nodeExecMetadata{ diff --git a/pkg/controller/nodes/node_state_manager.go b/pkg/controller/nodes/node_state_manager.go index f5600b98c..13a27e4da 100644 --- a/pkg/controller/nodes/node_state_manager.go +++ b/pkg/controller/nodes/node_state_manager.go @@ -6,53 +6,53 @@ import ( pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" ) type nodeStateManager struct { nodeStatus v1alpha1.ExecutableNodeStatus - t *handler.TaskNodeState - b *handler.BranchNodeState - d *handler.DynamicNodeState - w *handler.WorkflowNodeState - g *handler.GateNodeState - a *handler.ArrayNodeState + t *interfaces.TaskNodeState + b *interfaces.BranchNodeState + d *interfaces.DynamicNodeState + w *interfaces.WorkflowNodeState + g *interfaces.GateNodeState + a *interfaces.ArrayNodeState } -func (n *nodeStateManager) PutTaskNodeState(s handler.TaskNodeState) error { +func (n *nodeStateManager) PutTaskNodeState(s interfaces.TaskNodeState) error { n.t = &s return nil } -func (n *nodeStateManager) PutBranchNode(s handler.BranchNodeState) error { +func (n *nodeStateManager) PutBranchNode(s interfaces.BranchNodeState) error { n.b = &s return nil } -func (n *nodeStateManager) PutDynamicNodeState(s handler.DynamicNodeState) error { +func (n *nodeStateManager) PutDynamicNodeState(s interfaces.DynamicNodeState) error { n.d = &s return nil } -func (n *nodeStateManager) PutWorkflowNodeState(s handler.WorkflowNodeState) error { +func (n *nodeStateManager) PutWorkflowNodeState(s interfaces.WorkflowNodeState) error { n.w = &s return nil } -func (n *nodeStateManager) PutGateNodeState(s handler.GateNodeState) error { +func (n *nodeStateManager) PutGateNodeState(s interfaces.GateNodeState) error { n.g = &s return nil } -func (n *nodeStateManager) PutArrayNodeState(s handler.ArrayNodeState) error { +func (n *nodeStateManager) PutArrayNodeState(s interfaces.ArrayNodeState) error { n.a = &s return nil } -func (n nodeStateManager) GetTaskNodeState() handler.TaskNodeState { +func (n nodeStateManager) GetTaskNodeState() interfaces.TaskNodeState { tn := n.nodeStatus.GetTaskNodeStatus() if tn != nil { - return handler.TaskNodeState{ + return interfaces.TaskNodeState{ PluginPhase: pluginCore.Phase(tn.GetPhase()), PluginPhaseVersion: tn.GetPhaseVersion(), PluginStateVersion: tn.GetPluginStateVersion(), @@ -62,12 +62,12 @@ func (n nodeStateManager) GetTaskNodeState() handler.TaskNodeState { PreviousNodeExecutionCheckpointURI: tn.GetPreviousNodeExecutionCheckpointPath(), } } - return handler.TaskNodeState{} + return interfaces.TaskNodeState{} } -func (n nodeStateManager) GetBranchNode() handler.BranchNodeState { +func (n nodeStateManager) GetBranchNode() interfaces.BranchNodeState { bn := n.nodeStatus.GetBranchStatus() - bs := handler.BranchNodeState{} + bs := interfaces.BranchNodeState{} if bn != nil { bs.Phase = bn.GetPhase() bs.FinalizedNodeID = bn.GetFinalizedNode() @@ -75,9 +75,9 @@ func (n nodeStateManager) GetBranchNode() handler.BranchNodeState { return bs } -func (n nodeStateManager) GetDynamicNodeState() handler.DynamicNodeState { +func (n 
nodeStateManager) GetDynamicNodeState() interfaces.DynamicNodeState { dn := n.nodeStatus.GetDynamicNodeStatus() - ds := handler.DynamicNodeState{} + ds := interfaces.DynamicNodeState{} if dn != nil { ds.Phase = dn.GetDynamicNodePhase() ds.Reason = dn.GetDynamicNodeReason() @@ -87,9 +87,9 @@ func (n nodeStateManager) GetDynamicNodeState() handler.DynamicNodeState { return ds } -func (n nodeStateManager) GetWorkflowNodeState() handler.WorkflowNodeState { +func (n nodeStateManager) GetWorkflowNodeState() interfaces.WorkflowNodeState { wn := n.nodeStatus.GetWorkflowNodeStatus() - ws := handler.WorkflowNodeState{} + ws := interfaces.WorkflowNodeState{} if wn != nil { ws.Phase = wn.GetWorkflowNodePhase() ws.Error = wn.GetExecutionError() @@ -97,18 +97,18 @@ func (n nodeStateManager) GetWorkflowNodeState() handler.WorkflowNodeState { return ws } -func (n nodeStateManager) GetGateNodeState() handler.GateNodeState { +func (n nodeStateManager) GetGateNodeState() interfaces.GateNodeState { gn := n.nodeStatus.GetGateNodeStatus() - gs := handler.GateNodeState{} + gs := interfaces.GateNodeState{} if gn != nil { gs.Phase = gn.GetGateNodePhase() } return gs } -func (n nodeStateManager) GetArrayNodeState() handler.ArrayNodeState { +func (n nodeStateManager) GetArrayNodeState() interfaces.ArrayNodeState { an := n.nodeStatus.GetArrayNodeStatus() - as := handler.ArrayNodeState{} + as := interfaces.ArrayNodeState{} if an != nil { as.Phase = an.GetArrayNodePhase() as.SubNodePhases = an.GetSubNodePhases() From 6f379e75fa8f78eab6388fc12b428cbc01e2b9fe Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 6 Apr 2023 14:05:12 -0500 Subject: [PATCH 06/62] refactor complete Signed-off-by: Daniel Rammer --- pkg/controller/nodes/branch/handler.go | 8 +- pkg/controller/nodes/executor.go | 136 ++++++++++----------- pkg/controller/nodes/interfaces/state.go | 9 +- pkg/controller/nodes/node_exec_context.go | 5 +- pkg/controller/nodes/node_state_manager.go | 28 ++++- pkg/controller/nodes/transformers.go | 70 ++++++----- 6 files changed, 147 insertions(+), 109 deletions(-) diff --git a/pkg/controller/nodes/branch/handler.go b/pkg/controller/nodes/branch/handler.go index 91f822bda..50e47dd68 100644 --- a/pkg/controller/nodes/branch/handler.go +++ b/pkg/controller/nodes/branch/handler.go @@ -39,7 +39,7 @@ func (b *branchHandler) Setup(ctx context.Context, _ handler.SetupContext) error } func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha1.ExecutableBranchNode, nCtx interfaces.NodeExecutionContext, nl executors.NodeLookup) (handler.Transition, error) { - if nCtx.NodeStateReader().GetBranchNode().FinalizedNodeID == nil { + if nCtx.NodeStateReader().GetBranchNodeState().FinalizedNodeID == nil { nodeInputs, err := nCtx.InputReader().Get(ctx) if err != nil { errMsg := fmt.Sprintf("Failed to read input. 
Error [%s]", err) @@ -81,7 +81,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha } // If the branchNodestatus was already evaluated i.e, Node is in Running status - branchStatus := nCtx.NodeStateReader().GetBranchNode() + branchStatus := nCtx.NodeStateReader().GetBranchNodeState() userError := branchNode.GetElseFail() finalNodeID := branchStatus.FinalizedNodeID if finalNodeID == nil { @@ -177,7 +177,7 @@ func (b *branchHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecution } // If the branch was already evaluated i.e, Node is in Running status - branchNodeState := nCtx.NodeStateReader().GetBranchNode() + branchNodeState := nCtx.NodeStateReader().GetBranchNodeState() if branchNodeState.Phase == v1alpha1.BranchNodeNotYetEvaluated { logger.Errorf(ctx, "No node finalized through previous branch evaluation.") return nil @@ -221,7 +221,7 @@ func (b *branchHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExecut } // If the branch was already evaluated i.e, Node is in Running status - branchNodeState := nCtx.NodeStateReader().GetBranchNode() + branchNodeState := nCtx.NodeStateReader().GetBranchNodeState() if branchNodeState.Phase == v1alpha1.BranchNodeNotYetEvaluated { logger.Errorf(ctx, "No node finalized through previous branch evaluation.") return nil diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index 120848bda..997c07075 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -438,7 +438,7 @@ func isTimeoutExpired(queuedAt *metav1.Time, timeout time.Duration) bool { return false } -func (c *nodeExecutor) isEligibleForRetry(nCtx *nodeExecContext, nodeStatus v1alpha1.ExecutableNodeStatus, err *core.ExecutionError) (currentAttempt, maxAttempts uint32, isEligible bool) { +func (c *nodeExecutor) isEligibleForRetry(nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, err *core.ExecutionError) (currentAttempt, maxAttempts uint32, isEligible bool) { if err.Kind == core.ExecutionError_SYSTEM { currentAttempt = nodeStatus.GetSystemFailures() maxAttempts = c.maxNodeRetriesForSystemFailures @@ -454,7 +454,7 @@ func (c *nodeExecutor) isEligibleForRetry(nCtx *nodeExecContext, nodeStatus v1al return } -func (c *nodeExecutor) execute(ctx context.Context, h handler.Node, nCtx *nodeExecContext, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.PhaseInfo, error) { +func (c *nodeExecutor) execute(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.PhaseInfo, error) { logger.Debugf(ctx, "Executing node") defer logger.Debugf(ctx, "Node execution round complete") @@ -499,7 +499,7 @@ func (c *nodeExecutor) execute(ctx context.Context, h handler.Node, nCtx *nodeEx } // Retrying to clearing all status - nCtx.nsm.clearNodeStatus() + nCtx.NodeStateWriter().ClearNodeStatus() } return phase, nil @@ -522,27 +522,27 @@ func (c *nodeExecutor) finalize(ctx context.Context, h handler.Node, nCtx interf return h.Finalize(ctx, nCtx) } -func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executors.DAGStructure, nCtx *nodeExecContext, _ handler.Node) (executors.NodeStatus, error) { +func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, _ handler.Node) (interfaces.NodeStatus, error) { logger.Debugf(ctx, "Node not yet started, running pre-execute") defer logger.Debugf(ctx, "Node pre-execute completed") occurredAt 
:= time.Now() p, err := c.preExecute(ctx, dag, nCtx) if err != nil { logger.Errorf(ctx, "failed preExecute for node. Error: %s", err.Error()) - return executors.NodeStatusUndefined, err + return interfaces.NodeStatusUndefined, err } if p.GetPhase() == handler.EPhaseUndefined { - return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") + return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") } if p.GetPhase() == handler.EPhaseNotReady { - return executors.NodeStatusPending, nil + return interfaces.NodeStatusPending, nil } np, err := ToNodePhase(p.GetPhase()) if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") } nodeStatus := nCtx.NodeStatus() @@ -553,33 +553,33 @@ func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executor nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), p, nCtx.InputReader().GetInputPath().String(), nodeStatus, nCtx.ExecutionContext().GetEventVersion(), - nCtx.ExecutionContext().GetParentInfo(), nCtx.node, c.clusterID, nCtx.NodeStateReader().GetDynamicNodeState().Phase, + nCtx.ExecutionContext().GetParentInfo(), nCtx.Node(), c.clusterID, nCtx.NodeStateReader().GetDynamicNodeState().Phase, c.eventConfig) if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") } err = c.IdempotentRecordEvent(ctx, nev) if err != nil { logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return executors.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") } - UpdateNodeStatus(np, p, nCtx.nsm, nodeStatus) + UpdateNodeStatus(np, p, nCtx.NodeStateReader(), nodeStatus) c.RecordTransitionLatency(ctx, dag, nCtx.ContextualNodeLookup(), nCtx.Node(), nodeStatus) } if np == v1alpha1.NodePhaseQueued { - if nCtx.md.IsInterruptible() { + if nCtx.NodeExecutionMetadata().IsInterruptible() { c.metrics.InterruptibleNodesRunning.Inc(ctx) } - return executors.NodeStatusQueued, nil + return interfaces.NodeStatusQueued, nil } else if np == v1alpha1.NodePhaseSkipped { - return executors.NodeStatusSuccess, nil + return interfaces.NodeStatusSuccess, nil } - return executors.NodeStatusPending, nil + return interfaces.NodeStatusPending, nil } -func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx *nodeExecContext, h handler.Node) (executors.NodeStatus, error) { +func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx interfaces.NodeExecutionContext, h handler.Node) (interfaces.NodeStatus, error) { nodeStatus := nCtx.NodeStatus() currentPhase := nodeStatus.GetPhase() @@ -594,16 +594,16 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx *node p, err := c.execute(ctx, h, nCtx, nodeStatus) if err != nil { logger.Errorf(ctx, "failed Execute for node. 
Error: %s", err.Error()) - return executors.NodeStatusUndefined, err + return interfaces.NodeStatusUndefined, err } if p.GetPhase() == handler.EPhaseUndefined { - return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") + return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") } np, err := ToNodePhase(p.GetPhase()) if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") } // execErr in phase-info 'p' is only available if node has failed to execute, and the current phase at that time @@ -638,27 +638,27 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx *node } } } - finalStatus := executors.NodeStatusRunning + finalStatus := interfaces.NodeStatusRunning if np == v1alpha1.NodePhaseFailing && !h.FinalizeRequired() { logger.Infof(ctx, "Finalize not required, moving node to Failed") np = v1alpha1.NodePhaseFailed - finalStatus = executors.NodeStatusFailed(p.GetErr()) + finalStatus = interfaces.NodeStatusFailed(p.GetErr()) } if np == v1alpha1.NodePhaseTimingOut && !h.FinalizeRequired() { logger.Infof(ctx, "Finalize not required, moving node to TimedOut") np = v1alpha1.NodePhaseTimedOut - finalStatus = executors.NodeStatusTimedOut + finalStatus = interfaces.NodeStatusTimedOut } if np == v1alpha1.NodePhaseSucceeding && !h.FinalizeRequired() { logger.Infof(ctx, "Finalize not required, moving node to Succeeded") np = v1alpha1.NodePhaseSucceeded - finalStatus = executors.NodeStatusSuccess + finalStatus = interfaces.NodeStatusSuccess } if np == v1alpha1.NodePhaseRecovered { logger.Infof(ctx, "Finalize not required, moving node to Recovered") - finalStatus = executors.NodeStatusRecovered + finalStatus = interfaces.NodeStatusRecovered } // If it is retryable failure, we do no want to send any events, as the node is essentially still running @@ -669,10 +669,10 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx *node nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), p, nCtx.InputReader().GetInputPath().String(), nCtx.NodeStatus(), nCtx.ExecutionContext().GetEventVersion(), - nCtx.ExecutionContext().GetParentInfo(), nCtx.node, c.clusterID, nCtx.NodeStateReader().GetDynamicNodeState().Phase, + nCtx.ExecutionContext().GetParentInfo(), nCtx.Node(), c.clusterID, nCtx.NodeStateReader().GetDynamicNodeState().Phase, c.eventConfig) if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") } err = c.IdempotentRecordEvent(ctx, nev) @@ -698,11 +698,11 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx *node }) if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") } } else { logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return executors.NodeStatusUndefined, 
errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") } } @@ -714,15 +714,15 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx *node } } - UpdateNodeStatus(np, p, nCtx.nsm, nodeStatus) + UpdateNodeStatus(np, p, nCtx.NodeStateReader(), nodeStatus) return finalStatus, nil } -func (c *nodeExecutor) handleRetryableFailure(ctx context.Context, nCtx *nodeExecContext, h handler.Node) (executors.NodeStatus, error) { +func (c *nodeExecutor) handleRetryableFailure(ctx context.Context, nCtx interfaces.NodeExecutionContext, h handler.Node) (interfaces.NodeStatus, error) { nodeStatus := nCtx.NodeStatus() logger.Debugf(ctx, "node failed with retryable failure, aborting and finalizing, message: %s", nodeStatus.GetMessage()) if err := c.abort(ctx, h, nCtx, nodeStatus.GetMessage()); err != nil { - return executors.NodeStatusUndefined, err + return interfaces.NodeStatusUndefined, err } // NOTE: It is important to increment attempts only after abort has been called. Increment attempt mutates the state @@ -736,10 +736,10 @@ func (c *nodeExecutor) handleRetryableFailure(ctx context.Context, nCtx *nodeExe nodeStatus.ClearDynamicNodeStatus() nodeStatus.ClearGateNodeStatus() nodeStatus.ClearArrayNodeStatus() - return executors.NodeStatusPending, nil + return interfaces.NodeStatusPending, nil } -func (c *nodeExecutor) handleNode(ctx context.Context, dag executors.DAGStructure, nCtx *nodeExecContext, h handler.Node) (executors.NodeStatus, error) { +func (c *nodeExecutor) handleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h handler.Node) (interfaces.NodeStatus, error) { logger.Debugf(ctx, "Handling Node [%s]", nCtx.NodeID()) defer logger.Debugf(ctx, "Completed node [%s]", nCtx.NodeID()) @@ -753,7 +753,7 @@ func (c *nodeExecutor) handleNode(ctx context.Context, dag executors.DAGStructur if err != nil { return p, err } - if p.NodePhase == executors.NodePhaseQueued { + if p.NodePhase == interfaces.NodePhaseQueued { logger.Infof(ctx, "Node was queued, parallelism is now [%d]", nCtx.ExecutionContext().IncrementParallelism()) } return p, err @@ -762,35 +762,35 @@ func (c *nodeExecutor) handleNode(ctx context.Context, dag executors.DAGStructur if currentPhase == v1alpha1.NodePhaseFailing { logger.Debugf(ctx, "node failing") if err := c.abort(ctx, h, nCtx, "node failing"); err != nil { - return executors.NodeStatusUndefined, err + return interfaces.NodeStatusUndefined, err } nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailed, metav1.Now(), nodeStatus.GetMessage(), nodeStatus.GetExecutionError()) c.metrics.FailureDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) - if nCtx.md.IsInterruptible() { + if nCtx.NodeExecutionMetadata().IsInterruptible() { c.metrics.InterruptibleNodesTerminated.Inc(ctx) } - return executors.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + return interfaces.NodeStatusFailed(nodeStatus.GetExecutionError()), nil } if currentPhase == v1alpha1.NodePhaseTimingOut { logger.Debugf(ctx, "node timing out") if err := c.abort(ctx, h, nCtx, "node timed out"); err != nil { - return executors.NodeStatusUndefined, err + return interfaces.NodeStatusUndefined, err } nodeStatus.ClearSubNodeStatus() nodeStatus.UpdatePhase(v1alpha1.NodePhaseTimedOut, metav1.Now(), nodeStatus.GetMessage(), nodeStatus.GetExecutionError()) 
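	// A condensed view, not part of this patch, of the finalize-not-required fast paths
	// in handleQueuedOrRunningNode above: they implement one small decision table from
	// transient phase to terminal phase and final status. collapseTerminalPhase is a
	// hypothetical helper name; the phase/status pairs are the ones used there.
	//
	//	func collapseTerminalPhase(np v1alpha1.NodePhase, p handler.PhaseInfo) (v1alpha1.NodePhase, interfaces.NodeStatus) {
	//		switch np {
	//		case v1alpha1.NodePhaseFailing:
	//			return v1alpha1.NodePhaseFailed, interfaces.NodeStatusFailed(p.GetErr())
	//		case v1alpha1.NodePhaseTimingOut:
	//			return v1alpha1.NodePhaseTimedOut, interfaces.NodeStatusTimedOut
	//		case v1alpha1.NodePhaseSucceeding:
	//			return v1alpha1.NodePhaseSucceeded, interfaces.NodeStatusSuccess
	//		case v1alpha1.NodePhaseRecovered:
	//			return np, interfaces.NodeStatusRecovered // already terminal, only the status is relabeled
	//		default:
	//			return np, interfaces.NodeStatusRunning
	//		}
	//	}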
c.metrics.TimedOutFailure.Inc(ctx) - if nCtx.md.IsInterruptible() { + if nCtx.NodeExecutionMetadata().IsInterruptible() { c.metrics.InterruptibleNodesTerminated.Inc(ctx) } - return executors.NodeStatusTimedOut, nil + return interfaces.NodeStatusTimedOut, nil } if currentPhase == v1alpha1.NodePhaseSucceeding { logger.Debugf(ctx, "node succeeding") if err := c.finalize(ctx, h, nCtx); err != nil { - return executors.NodeStatusUndefined, err + return interfaces.NodeStatusUndefined, err } t := metav1.Now() @@ -805,10 +805,10 @@ func (c *nodeExecutor) handleNode(ctx context.Context, dag executors.DAGStructur c.metrics.SuccessDuration.Observe(ctx, started.Time, stopped.Time) nodeStatus.ClearSubNodeStatus() nodeStatus.UpdatePhase(v1alpha1.NodePhaseSucceeded, t, "completed successfully", nil) - if nCtx.md.IsInterruptible() { + if nCtx.NodeExecutionMetadata().IsInterruptible() { c.metrics.InterruptibleNodesTerminated.Inc(ctx) } - return executors.NodeStatusSuccess, nil + return interfaces.NodeStatusSuccess, nil } if currentPhase == v1alpha1.NodePhaseRetryableFailure { @@ -817,7 +817,7 @@ func (c *nodeExecutor) handleNode(ctx context.Context, dag executors.DAGStructur if currentPhase == v1alpha1.NodePhaseFailed { // This should never happen - return executors.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + return interfaces.NodeStatusFailed(nodeStatus.GetExecutionError()), nil } return c.handleQueuedOrRunningNode(ctx, nCtx, h) @@ -825,13 +825,13 @@ func (c *nodeExecutor) handleNode(ctx context.Context, dag executors.DAGStructur // The space search for the next node to execute is implemented like a DFS algorithm. handleDownstream visits all the nodes downstream from // the currentNode. Visit a node is the RecursiveNodeHandler. A visit may be partial, complete or may result in a failure. -func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { +func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (interfaces.NodeStatus, error) { logger.Debugf(ctx, "Handling downstream Nodes") // This node is success. Handle all downstream nodes downstreamNodes, err := dag.FromNode(currentNode.GetID()) if err != nil { logger.Debugf(ctx, "Error when retrieving downstream nodes, [%s]", err) - return executors.NodeStatusFailed(&core.ExecutionError{ + return interfaces.NodeStatusFailed(&core.ExecutionError{ Code: errors.BadSpecificationError, Message: fmt.Sprintf("failed to retrieve downstream nodes for [%s]", currentNode.GetID()), Kind: core.ExecutionError_SYSTEM, @@ -839,7 +839,7 @@ func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executo } if len(downstreamNodes) == 0 { logger.Debugf(ctx, "No downstream nodes found. 
Complete.") - return executors.NodeStatusComplete, nil + return interfaces.NodeStatusComplete, nil } // If any downstream node is failed, fail, all // Else if all are success then success @@ -847,11 +847,11 @@ func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executo allCompleted := true partialNodeCompletion := false onFailurePolicy := execContext.GetOnFailurePolicy() - stateOnComplete := executors.NodeStatusComplete + stateOnComplete := interfaces.NodeStatusComplete for _, downstreamNodeName := range downstreamNodes { downstreamNode, ok := nl.GetNode(downstreamNodeName) if !ok { - return executors.NodeStatusFailed(&core.ExecutionError{ + return interfaces.NodeStatusFailed(&core.ExecutionError{ Code: errors.BadSpecificationError, Message: fmt.Sprintf("failed to retrieve downstream node [%s] for [%s]", downstreamNodeName, currentNode.GetID()), Kind: core.ExecutionError_SYSTEM, @@ -860,7 +860,7 @@ func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executo state, err := c.RecursiveNodeHandler(ctx, execContext, dag, nl, downstreamNode) if err != nil { - return executors.NodeStatusUndefined, err + return interfaces.NodeStatusUndefined, err } if state.HasFailed() || state.HasTimedOut() { @@ -895,35 +895,35 @@ func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executo } if partialNodeCompletion { - return executors.NodeStatusSuccess, nil + return interfaces.NodeStatusSuccess, nil } - return executors.NodeStatusPending, nil + return interfaces.NodeStatusPending, nil } -func (c *nodeExecutor) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (executors.NodeStatus, error) { +func (c *nodeExecutor) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (interfaces.NodeStatus, error) { startNode := dag.StartNode() ctx = contextutils.WithNodeID(ctx, startNode.GetID()) if inputs == nil { logger.Infof(ctx, "No inputs for the workflow. Skipping storing inputs") - return executors.NodeStatusComplete, nil + return interfaces.NodeStatusComplete, nil } // StartNode is special. It does not have any processing step. It just takes the workflow (or subworkflow) inputs and converts to its own outputs nodeStatus := nl.GetNodeExecutionStatus(ctx, startNode.GetID()) if len(nodeStatus.GetDataDir()) == 0 { - return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, startNode.GetID(), "no data-dir set, cannot store inputs") + return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, startNode.GetID(), "no data-dir set, cannot store inputs") } outputFile := v1alpha1.GetOutputsFile(nodeStatus.GetOutputDir()) so := storage.Options{} if err := c.store.WriteProtobuf(ctx, outputFile, so, inputs); err != nil { logger.Errorf(ctx, "Failed to write protobuf (metadata). 
Error [%v]", err) - return executors.NodeStatusUndefined, errors.Wrapf(errors.CausedByError, startNode.GetID(), err, "Failed to store workflow inputs (as start node)") + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.CausedByError, startNode.GetID(), err, "Failed to store workflow inputs (as start node)") } - return executors.NodeStatusComplete, nil + return interfaces.NodeStatusComplete, nil } func canHandleNode(phase v1alpha1.NodePhase) bool { @@ -972,7 +972,7 @@ func IsMaxParallelismAchieved(ctx context.Context, currentNode v1alpha1.Executab // The recursive node-handler uses a modified depth-first type of algorithm to execute non-blocked nodes. func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) ( - executors.NodeStatus, error) { + interfaces.NodeStatus, error) { currentNodeCtx := contextutils.WithNodeID(ctx, currentNode.GetID()) nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) @@ -994,18 +994,18 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext exe // This is an optimization to avoid creating the nodeContext object in case the node has already been looked at. // If the overhead was zero, we would just do the isDirtyCheck after the nodeContext is created if nodeStatus.IsDirty() { - return executors.NodeStatusRunning, nil + return interfaces.NodeStatusRunning, nil } if IsMaxParallelismAchieved(ctx, currentNode, nodePhase, execContext) { - return executors.NodeStatusRunning, nil + return interfaces.NodeStatusRunning, nil } nCtx, err := c.NewNodeExecutionContext(ctx, execContext, nl, currentNode.GetID()) if err != nil { // NodeExecution creation failure is a permanent fail / system error. // Should a system failure always return an err? - return executors.NodeStatusFailed(&core.ExecutionError{ + return interfaces.NodeStatusFailed(&core.ExecutionError{ Code: "InternalError", Message: err.Error(), Kind: core.ExecutionError_SYSTEM, @@ -1015,7 +1015,7 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext exe // Now depending on the node type decide h, err := c.nodeHandlerFactory.GetHandler(nCtx.Node().GetKind()) if err != nil { - return executors.NodeStatusUndefined, err + return interfaces.NodeStatusUndefined, err } return c.handleNode(currentNodeCtx, dag, nCtx, h) @@ -1030,21 +1030,21 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext exe logger.Debugf(currentNodeCtx, "Node has failed, traversing downstream.") _, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode) if err != nil { - return executors.NodeStatusUndefined, err + return interfaces.NodeStatusUndefined, err } - return executors.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + return interfaces.NodeStatusFailed(nodeStatus.GetExecutionError()), nil } else if nodePhase == v1alpha1.NodePhaseTimedOut { logger.Debugf(currentNodeCtx, "Node has timed out, traversing downstream.") _, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode) if err != nil { - return executors.NodeStatusUndefined, err + return interfaces.NodeStatusUndefined, err } - return executors.NodeStatusTimedOut, nil + return interfaces.NodeStatusTimedOut, nil } - return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, currentNode.GetID(), + return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, currentNode.GetID(), "Should never reach here. 
Current Phase: %v", nodePhase) } diff --git a/pkg/controller/nodes/interfaces/state.go b/pkg/controller/nodes/interfaces/state.go index b41cb556d..e83bb8a65 100644 --- a/pkg/controller/nodes/interfaces/state.go +++ b/pkg/controller/nodes/interfaces/state.go @@ -61,13 +61,20 @@ type NodeStateWriter interface { PutWorkflowNodeState(s WorkflowNodeState) error PutGateNodeState(s GateNodeState) error PutArrayNodeState(s ArrayNodeState) error + ClearNodeStatus() } type NodeStateReader interface { + HasTaskNodeState() bool GetTaskNodeState() TaskNodeState - GetBranchNode() BranchNodeState + HasBranchNodeState() bool + GetBranchNodeState() BranchNodeState + HasDynamicNodeState() bool GetDynamicNodeState() DynamicNodeState + HasWorkflowNodeState() bool GetWorkflowNodeState() WorkflowNodeState + HasGateNodeState() bool GetGateNodeState() GateNodeState + HasArrayNodeState() bool GetArrayNodeState() ArrayNodeState } diff --git a/pkg/controller/nodes/node_exec_context.go b/pkg/controller/nodes/node_exec_context.go index ba0a72b47..27c3e82c7 100644 --- a/pkg/controller/nodes/node_exec_context.go +++ b/pkg/controller/nodes/node_exec_context.go @@ -15,7 +15,6 @@ import ( "github.com/flyteorg/flytepropeller/events" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/executors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/utils" ) @@ -187,13 +186,13 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext } func (c *nodeExecutor) NewNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, - nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (*nodeExecContext, error) { + nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (interfaces.NodeExecutionContext, error) { n, ok := nl.GetNode(currentNodeID) if !ok { return nil, fmt.Errorf("failed to find node with ID [%s] in execution [%s]", currentNodeID, executionContext.GetID()) } - var tr handler.TaskReader + var tr interfaces.TaskReader if n.GetKind() == v1alpha1.NodeKindTask { if n.GetTaskID() == nil { return nil, fmt.Errorf("bad state, no task-id defined for node [%s]", n.GetID()) diff --git a/pkg/controller/nodes/node_state_manager.go b/pkg/controller/nodes/node_state_manager.go index 13a27e4da..dd0e2c18c 100644 --- a/pkg/controller/nodes/node_state_manager.go +++ b/pkg/controller/nodes/node_state_manager.go @@ -49,6 +49,30 @@ func (n *nodeStateManager) PutArrayNodeState(s interfaces.ArrayNodeState) error return nil } +func (n *nodeStateManager) HasTaskNodeState() bool { + return n.t != nil +} + +func (n *nodeStateManager) HasBranchNodeState() bool { + return n.b != nil +} + +func (n *nodeStateManager) HasDynamicNodeState() bool { + return n.d != nil +} + +func (n *nodeStateManager) HasWorkflowNodeState() bool { + return n.w != nil +} + +func (n *nodeStateManager) HasGateNodeState() bool { + return n.g != nil +} + +func (n *nodeStateManager) HasArrayNodeState() bool { + return n.a != nil +} + func (n nodeStateManager) GetTaskNodeState() interfaces.TaskNodeState { tn := n.nodeStatus.GetTaskNodeStatus() if tn != nil { @@ -65,7 +89,7 @@ func (n nodeStateManager) GetTaskNodeState() interfaces.TaskNodeState { return interfaces.TaskNodeState{} } -func (n nodeStateManager) GetBranchNode() interfaces.BranchNodeState { +func (n nodeStateManager) GetBranchNodeState() interfaces.BranchNodeState { bn := 
n.nodeStatus.GetBranchStatus() bs := interfaces.BranchNodeState{} if bn != nil { @@ -116,7 +140,7 @@ func (n nodeStateManager) GetArrayNodeState() interfaces.ArrayNodeState { return as } -func (n *nodeStateManager) clearNodeStatus() { +func (n *nodeStateManager) ClearNodeStatus() { n.t = nil n.b = nil n.d = nil diff --git a/pkg/controller/nodes/transformers.go b/pkg/controller/nodes/transformers.go index 8576cac6d..44e2dffbb 100644 --- a/pkg/controller/nodes/transformers.go +++ b/pkg/controller/nodes/transformers.go @@ -6,19 +6,21 @@ import ( "strconv" "time" - "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flytestdlib/logger" + "github.com/golang/protobuf/ptypes" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) // This is used by flyteadmin to indicate that the events will now contain populated IsParent and IsDynamic bits. @@ -226,60 +228,66 @@ func ToK8sTime(t time.Time) v1.Time { return v1.Time{Time: t} } -func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n *nodeStateManager, s v1alpha1.ExecutableNodeStatus) { +func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n interfaces.NodeStateReader, s v1alpha1.ExecutableNodeStatus) { // We update the phase only if it is not already updated if np != s.GetPhase() { s.UpdatePhase(np, ToK8sTime(p.GetOccurredAt()), p.GetReason(), p.GetErr()) } // Update TaskStatus - if n.t != nil { + if n.HasTaskNodeState() { + nt := n.GetTaskNodeState() t := s.GetOrCreateTaskStatus() - t.SetPhaseVersion(n.t.PluginPhaseVersion) - t.SetPhase(int(n.t.PluginPhase)) - t.SetLastPhaseUpdatedAt(n.t.LastPhaseUpdatedAt) - t.SetPluginState(n.t.PluginState) - t.SetPluginStateVersion(n.t.PluginStateVersion) - t.SetBarrierClockTick(n.t.BarrierClockTick) - t.SetPreviousNodeExecutionCheckpointPath(n.t.PreviousNodeExecutionCheckpointURI) + t.SetPhaseVersion(nt.PluginPhaseVersion) + t.SetPhase(int(nt.PluginPhase)) + t.SetLastPhaseUpdatedAt(nt.LastPhaseUpdatedAt) + t.SetPluginState(nt.PluginState) + t.SetPluginStateVersion(nt.PluginStateVersion) + t.SetBarrierClockTick(nt.BarrierClockTick) + t.SetPreviousNodeExecutionCheckpointPath(nt.PreviousNodeExecutionCheckpointURI) } // Update dynamic node status - if n.d != nil { + if n.HasDynamicNodeState() { + nd := n.GetDynamicNodeState() t := s.GetOrCreateDynamicNodeStatus() - t.SetDynamicNodePhase(n.d.Phase) - t.SetDynamicNodeReason(n.d.Reason) - t.SetExecutionError(n.d.Error) + t.SetDynamicNodePhase(nd.Phase) + t.SetDynamicNodeReason(nd.Reason) + t.SetExecutionError(nd.Error) } // Update branch node status - if n.b != nil { + if n.HasBranchNodeState() { + nb := n.GetBranchNodeState() t := s.GetOrCreateBranchStatus() - if n.b.Phase == v1alpha1.BranchNodeError { + if nb.Phase == 
v1alpha1.BranchNodeError { t.SetBranchNodeError() - } else if n.b.FinalizedNodeID != nil { - t.SetBranchNodeSuccess(*n.b.FinalizedNodeID) + } else if nb.FinalizedNodeID != nil { + t.SetBranchNodeSuccess(*nb.FinalizedNodeID) } else { logger.Warnf(context.TODO(), "branch node status neither success nor error set") } } // Update workflow node status - if n.w != nil { + if n.HasWorkflowNodeState() { + nw := n.GetWorkflowNodeState() t := s.GetOrCreateWorkflowStatus() - t.SetWorkflowNodePhase(n.w.Phase) - t.SetExecutionError(n.w.Error) + t.SetWorkflowNodePhase(nw.Phase) + t.SetExecutionError(nw.Error) } // Update gate node status - if n.g != nil { + if n.HasGateNodeState() { + ng := n.GetGateNodeState() t := s.GetOrCreateGateNodeStatus() - t.SetGateNodePhase(n.g.Phase) + t.SetGateNodePhase(ng.Phase) } // Update array node status - if n.a != nil { + if n.HasArrayNodeState() { + na := n.GetArrayNodeState() t := s.GetOrCreateArrayNodeStatus() - t.SetArrayNodePhase(n.a.Phase) - t.SetSubNodePhases(n.a.SubNodePhases) + t.SetArrayNodePhase(na.Phase) + t.SetSubNodePhases(na.SubNodePhases) } } From caa9b07c831acb1c3a8408425ffbed0d5a6015ff Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Fri, 7 Apr 2023 12:01:14 -0500 Subject: [PATCH 07/62] supporting environment variables Signed-off-by: Daniel Rammer --- pkg/apis/flyteworkflow/v1alpha1/array.go | 7 ++- .../v1alpha1/execution_config.go | 2 + pkg/apis/flyteworkflow/v1alpha1/iface.go | 1 + pkg/compiler/transformers/k8s/node.go | 11 +++- .../nodes/array/execution_context.go | 56 +++++++++++++++---- pkg/controller/nodes/array/handler.go | 39 ++++++------- pkg/controller/nodes/array/input_reader.go | 37 ++++++++++++ pkg/controller/nodes/array/node_executor.go | 19 ++++--- pkg/controller/nodes/node_exec_context.go | 1 + pkg/controller/nodes/node_state_manager.go | 24 ++++++++ pkg/controller/nodes/task/taskexec_context.go | 14 +++-- 11 files changed, 168 insertions(+), 43 deletions(-) create mode 100644 pkg/controller/nodes/array/input_reader.go diff --git a/pkg/apis/flyteworkflow/v1alpha1/array.go b/pkg/apis/flyteworkflow/v1alpha1/array.go index 91e86bd21..0809c1a64 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/array.go +++ b/pkg/apis/flyteworkflow/v1alpha1/array.go @@ -4,5 +4,10 @@ import ( ) type ArrayNodeSpec struct { - // TODO @hamersaw - fill out evaluation + SubNodeSpec *NodeSpec + // TODO @hamersaw - fill out ArrayNodeSpec +} + +func (a *ArrayNodeSpec) GetSubNodeSpec() *NodeSpec { + return a.SubNodeSpec } diff --git a/pkg/apis/flyteworkflow/v1alpha1/execution_config.go b/pkg/apis/flyteworkflow/v1alpha1/execution_config.go index 878df4768..e56315c1c 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/execution_config.go +++ b/pkg/apis/flyteworkflow/v1alpha1/execution_config.go @@ -32,6 +32,8 @@ type ExecutionConfig struct { Interruptible *bool // Defines whether a workflow should skip all its cached results and re-compute its output, overwriting any already stored data. 
 	OverwriteCache bool
+	// TODO @hamersaw - docs
+	EnvironmentVariables map[string]string
 }
 
 type TaskPluginOverride struct {
diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go
index 35feb2644..14fd1ba52 100644
--- a/pkg/apis/flyteworkflow/v1alpha1/iface.go
+++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go
@@ -256,6 +256,7 @@ type ExecutableGateNode interface {
 }
 
 type ExecutableArrayNode interface {
+	GetSubNodeSpec() *NodeSpec
 	// TODO @hamersaw - complete ExecutableArrayNode
 }
 
diff --git a/pkg/compiler/transformers/k8s/node.go b/pkg/compiler/transformers/k8s/node.go
index fbf7cc321..d2dde8ef1 100644
--- a/pkg/compiler/transformers/k8s/node.go
+++ b/pkg/compiler/transformers/k8s/node.go
@@ -155,12 +155,19 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile
 			}
 		}
 	case *core.Node_ArrayNode:
+		arrayNode := n.GetArrayNode()
+
+		// build subNodeSpecs
+		subNodeSpecs, ok := buildNodeSpec(arrayNode.Node, tasks, errs)
+		if !ok {
+			return nil, ok
+		}
+
 		// TODO @hamersaw - complete
 		nodeSpec.Kind = v1alpha1.NodeKindArray
 		nodeSpec.ArrayNode = &v1alpha1.ArrayNodeSpec{
+			SubNodeSpec: subNodeSpecs[0],
 		}
-
-		//arrayNode := n.GetArrayNode()
 	default:
 		if n.GetId() == v1alpha1.StartNodeID {
 			nodeSpec.Kind = v1alpha1.NodeKindStart
diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go
index 56c2c2b5d..188518207 100644
--- a/pkg/controller/nodes/array/execution_context.go
+++ b/pkg/controller/nodes/array/execution_context.go
@@ -1,12 +1,61 @@
 package array
 
 import (
-	//"github.com/flyteorg/flytepropeller/pkg/controller/executors"
+	"strconv"
+
+	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
+
+	"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
+	"github.com/flyteorg/flytepropeller/pkg/controller/executors"
 	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces"
 )
 
+const (
+	FlyteK8sArrayIndexVarName string = "FLYTE_K8S_ARRAY_INDEX"
+	JobIndexVarName           string = "BATCH_JOB_ARRAY_INDEX_VAR_NAME"
+)
+
+type arrayExecutionContext struct {
+	executors.ExecutionContext
+	executionConfig v1alpha1.ExecutionConfig
+}
+
+func (a arrayExecutionContext) GetExecutionConfig() v1alpha1.ExecutionConfig {
+	return a.executionConfig
+}
+
+func newArrayExecutionContext(executionContext executors.ExecutionContext, subNodeIndex int) arrayExecutionContext {
+	executionConfig := executionContext.GetExecutionConfig()
+
+	// ExecutionConfig is returned by value, but the EnvironmentVariables map inside it
+	// is shared with the parent config and may be nil - copy it before injecting the
+	// per-subNode index so the writes below never panic or leak across subNodes
+	environmentVariables := make(map[string]string, len(executionConfig.EnvironmentVariables)+2)
+	for key, value := range executionConfig.EnvironmentVariables {
+		environmentVariables[key] = value
+	}
+	environmentVariables[JobIndexVarName] = FlyteK8sArrayIndexVarName
+	environmentVariables[FlyteK8sArrayIndexVarName] = strconv.Itoa(subNodeIndex)
+	executionConfig.EnvironmentVariables = environmentVariables
+
+	return arrayExecutionContext{
+		ExecutionContext: executionContext,
+		executionConfig:  executionConfig,
+	}
+}
+
 type arrayNodeExecutionContext struct {
 	interfaces.NodeExecutionContext
+	inputReader      io.InputReader
+	executionContext arrayExecutionContext
+}
+
+func (a arrayNodeExecutionContext) InputReader() io.InputReader {
+	return a.inputReader
+}
+
+func (a arrayNodeExecutionContext) ExecutionContext() executors.ExecutionContext {
+	return a.executionContext
 }
 
 // TODO @hamersaw - overwrite everything
@@ -17,22 +66,18 @@ nodeRecorder - need to add to nodeExecutionContext so we can override?!?!
 maxParallelism - looks like we need:
 	ExecutionConfig.GetMaxParallelism
 	ExecutionContext.IncrementMaxParallelism
 storage locations - dataPrefix?
 add environment variables for maptask execution
 either:
   (1) in arrayExecutionContext if we use separate for each
   (2) in arrayNodeExecutionContext if we choose to use single DAG
 */
-/*func newArrayExecutionContext(executionContext executors.ExecutionContext) executors.ExecutionContext {
-	return arrayExecutionContext{
-		ExecutionContext: executionContext,
-	}
-}*/
-
-func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext) arrayNodeExecutionContext {
+func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, subNodeIndex int) arrayNodeExecutionContext {
+	arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex)
 	return arrayNodeExecutionContext{
 		NodeExecutionContext: nodeExecutionContext,
+		inputReader:          inputReader,
+		executionContext:     arrayExecutionContext,
 	}
 }
diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go
index 5966ebec5..7889da60f 100644
--- a/pkg/controller/nodes/array/handler.go
+++ b/pkg/controller/nodes/array/handler.go
@@ -58,7 +58,7 @@ func (a *arrayNodeHandler) FinalizeRequired() bool {
 // Handle is responsible for transitioning and reporting node state to complete the node defined
 // by the NodeExecutionContext
 func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) {
-	//arrayNode := nCtx.Node().GetArrayNode()
+	arrayNode := nCtx.Node().GetArrayNode()
 	arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState()
 
 	// TODO @hamersaw - handle array node
@@ -122,44 +122,45 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 		var inputReader io.InputReader
 		if nodePhase == v1alpha1.NodePhaseNotYetStarted { // TODO @hamersaw - need to do this for PhaseSucceeded as well?!?! to write cache outputs once fastcache is in
 			// create input readers and set nodePhase to Queued to skip resolving inputs but still allow cache lookups
-			// TODO @hamersaw - create input readers
+			// TODO @hamersaw - actually create input readers
+			inputReader = newStaticInputReader(nCtx.InputReader(), nil)
 			nodePhase = v1alpha1.NodePhaseQueued
 		}
 
 		// wrap node lookup
+		subNodeSpec := *arrayNode.GetSubNodeSpec()
+
 		subNodeID := fmt.Sprintf("%s-n%d", nCtx.NodeID(), i)
-		subNodeSpec := &v1alpha1.NodeSpec{
+		subNodeSpec.ID = subNodeID
+		subNodeSpec.Name = subNodeID
+		/*subNodeSpec := &v1alpha1.NodeSpec{
 			ID:   subNodeID,
 			Name: subNodeID,
-		} // TODO @hamersaw - compile this in ArrayNodeSpec?
+ } // TODO @hamersaw - compile this in ArrayNodeSpec?*/ subNodeStatus := &v1alpha1.NodeStatus{ Phase: nodePhase, - /*TaskNodeStatus: &v1alpha1.TaskNodeStatus{ - Phase: nodePhase, // used for cache lookups - once fastcache is done we dont care about the TaskNodeStatus - },*/ + TaskNodeStatus: &v1alpha1.TaskNodeStatus{ + // TODO @hamersaw - to get caching working we need to set to Queued to force cache lookup + // once fastcache is done we dont care about the TaskNodeStatus + Phase: int(core.Phases[core.PhaseRunning]), + }, // TODO @hamersaw - fill out systemFailures, retryAttempt etc } // TODO @hamersaw - can probably create a single arrayNodeLookup with all the subNodeIDs - arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, subNodeSpec, subNodeStatus) + arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, &subNodeSpec, subNodeStatus) // create arrayNodeExecutor - /*nodeExecutionContext, err := a.nodeExecutor.NewNodeExecutionContext(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, subNodeID) - if err != nil { - // TODO @hamersaw fail - } - - // create new arrayNodeExecutionContext to override for array task execution - arrayNodeExecutionContext := newArrayNodeExecutionContext(nodeExecutionContext, inputReader)*/ - arrayNodeExecutor := newArrayNodeExecutor(a.nodeExecutor, subNodeID, inputReader) + arrayNodeExecutor := newArrayNodeExecutor(a.nodeExecutor, subNodeID, i, inputReader) // execute subNode through RecursiveNodeHandler - nodeStatus, err := arrayNodeExecutor.RecursiveNodeHandler(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, &arrayNodeLookup, subNodeSpec) + nodeStatus, err := arrayNodeExecutor.RecursiveNodeHandler(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec) if err != nil { // TODO @hamersaw fail } - fmt.Printf("HAMERSAW - node phase transition %d -> %d", nodePhase, nodeStatus.NodePhase) + //fmt.Printf("HAMERSAW - node phase transition %d -> %d", nodePhase, nodeStatus.NodePhase) + arrayNodeState.SubNodePhases.SetItem(i, uint64(nodeStatus.NodePhase)) } // TODO @hamersaw - determine summary phases @@ -168,7 +169,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu case v1alpha1.ArrayNodePhaseFailing: // TODO @hamersaw - abort everything! 
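		// One plausible shape, not part of this patch, for the abort TODO above: walk the
		// recorded subNode phases and abort anything non-terminal, rebuilding the
		// per-subNode executor and lookup exactly as the Executing loop does. This assumes
		// interfaces.Node exposes an AbortHandler mirroring RecursiveNodeHandler and that
		// SubNodePhases offers GetItems alongside the SetItem used above; subNodeID,
		// subNodeSpec, inputReader, and arrayNodeLookup would be built as earlier.
		//
		//	for i, subNodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
		//		subNodePhase := v1alpha1.NodePhase(subNodePhaseUint64)
		//		if subNodePhase == v1alpha1.NodePhaseSucceeded || subNodePhase == v1alpha1.NodePhaseFailed {
		//			continue // already terminal - nothing to abort
		//		}
		//
		//		arrayNodeExecutor := newArrayNodeExecutor(a.nodeExecutor, subNodeID, i, inputReader)
		//		if err := arrayNodeExecutor.AbortHandler(ctx, nCtx.ExecutionContext(), &arrayNodeLookup,
		//			&arrayNodeLookup, &subNodeSpec, "array node failing"); err != nil {
		//			return handler.UnknownTransition, err
		//		}
		//	}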
 	case v1alpha1.ArrayNodePhaseSucceeding:
-		// TODO @hamersaw - collect outputs
+		// TODO @hamersaw - collect outputs and write as List[]
 		return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil
 	default:
 		// TODO @hamersaw - fail
diff --git a/pkg/controller/nodes/array/input_reader.go b/pkg/controller/nodes/array/input_reader.go
new file mode 100644
index 000000000..540527b36
--- /dev/null
+++ b/pkg/controller/nodes/array/input_reader.go
@@ -0,0 +1,37 @@
+package array
+
+import (
+	"context"
+
+	idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
+	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
+)
+
+type staticInputReader struct {
+	io.InputFilePaths
+	input *idlcore.LiteralMap
+}
+
+func (i staticInputReader) Get(_ context.Context) (*idlcore.LiteralMap, error) {
+	return i.input, nil
+}
+
+func newStaticInputReader(inputPaths io.InputFilePaths, input *idlcore.LiteralMap) staticInputReader {
+	return staticInputReader{
+		InputFilePaths: inputPaths,
+		input:          input,
+	}
+}
+
+/*func ConstructStaticInputReaders(inputPaths io.InputFilePaths, inputs []*idlcore.Literal, inputName string) []io.InputReader {
+	inputReaders := make([]io.InputReader, 0, len(inputs))
+	for i := 0; i < len(inputs); i++ {
+		inputReaders = append(inputReaders, newStaticInputReader(inputPaths, &idlcore.LiteralMap{
+			Literals: map[string]*idlcore.Literal{
+				inputName: inputs[i],
+			},
+		}))
+	}
+
+	return inputReaders
+}*/
diff --git a/pkg/controller/nodes/array/node_executor.go b/pkg/controller/nodes/array/node_executor.go
index fdee0e92d..d9e5cf18b 100644
--- a/pkg/controller/nodes/array/node_executor.go
+++ b/pkg/controller/nodes/array/node_executor.go
@@ -2,6 +2,7 @@ package array
 
 import (
 	"context"
+	"fmt"
 
 	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
 
@@ -12,12 +13,13 @@ import (
 
 type arrayNodeExecutor struct {
 	interfaces.Node
-	subNodeID   v1alpha1.NodeID
-	inputReader io.InputReader
+	subNodeID    v1alpha1.NodeID
+	subNodeIndex int
+	inputReader  io.InputReader
 }
 
 // TODO @hamersaw - docs
-func (a *arrayNodeExecutor) NewNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext,
+func (a arrayNodeExecutor) NewNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext,
 	nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (interfaces.NodeExecutionContext, error) {
 
 	// create base NodeExecutionContext
@@ -26,17 +28,20 @@ func (a *arrayNodeExecutor) NewNodeExecutionContext(ctx context.Context, executi
 		return nil, err
 	}
 
+	fmt.Printf("HAMERSAW - currentNodeID %s subNodeID %s!\n", currentNodeID, a.subNodeID)
 	if currentNodeID == a.subNodeID {
 		// TODO @hamersaw - overwrite NodeExecutionContext for ArrayNode execution
+		nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.subNodeIndex)
 	}
 
 	return nCtx, nil
 }
 
-func newArrayNodeExecutor(nodeExecutor interfaces.Node, subNodeID v1alpha1.NodeID, inputReader io.InputReader) arrayNodeExecutor {
+func newArrayNodeExecutor(nodeExecutor interfaces.Node, subNodeID v1alpha1.NodeID, subNodeIndex int, inputReader io.InputReader) arrayNodeExecutor {
 	return arrayNodeExecutor{
-		Node:        nodeExecutor,
-		subNodeID:   subNodeID,
-		inputReader: inputReader,
+		Node:         nodeExecutor,
+		subNodeID:    subNodeID,
+		subNodeIndex: subNodeIndex,
+		inputReader:  inputReader,
 	}
 }
diff --git a/pkg/controller/nodes/node_exec_context.go b/pkg/controller/nodes/node_exec_context.go
index 27c3e82c7..f5faff70b 100644
--- a/pkg/controller/nodes/node_exec_context.go
+++ 
b/pkg/controller/nodes/node_exec_context.go @@ -227,6 +227,7 @@ func (c *nodeExecutor) NewNodeExecutionContext(ctx context.Context, executionCon rawOutputPrefix = storage.DataReference(executionContext.GetRawOutputDataConfig().OutputLocationPrefix) } + fmt.Printf("HAMERSAW - creating base NodeExecutionContext for %s\n", currentNodeID) return newNodeExecContext(ctx, c.store, executionContext, nl, n, s, ioutils.NewCachedInputReader( ctx, diff --git a/pkg/controller/nodes/node_state_manager.go b/pkg/controller/nodes/node_state_manager.go index dd0e2c18c..70ac7ac5f 100644 --- a/pkg/controller/nodes/node_state_manager.go +++ b/pkg/controller/nodes/node_state_manager.go @@ -74,6 +74,10 @@ func (n *nodeStateManager) HasArrayNodeState() bool { } func (n nodeStateManager) GetTaskNodeState() interfaces.TaskNodeState { + if n.t != nil { + return *n.t + } + tn := n.nodeStatus.GetTaskNodeStatus() if tn != nil { return interfaces.TaskNodeState{ @@ -90,6 +94,10 @@ func (n nodeStateManager) GetTaskNodeState() interfaces.TaskNodeState { } func (n nodeStateManager) GetBranchNodeState() interfaces.BranchNodeState { + if n.b != nil { + return *n.b + } + bn := n.nodeStatus.GetBranchStatus() bs := interfaces.BranchNodeState{} if bn != nil { @@ -100,6 +108,10 @@ func (n nodeStateManager) GetBranchNodeState() interfaces.BranchNodeState { } func (n nodeStateManager) GetDynamicNodeState() interfaces.DynamicNodeState { + if n.d != nil { + return *n.d + } + dn := n.nodeStatus.GetDynamicNodeStatus() ds := interfaces.DynamicNodeState{} if dn != nil { @@ -112,6 +124,10 @@ func (n nodeStateManager) GetDynamicNodeState() interfaces.DynamicNodeState { } func (n nodeStateManager) GetWorkflowNodeState() interfaces.WorkflowNodeState { + if n.w != nil { + return *n.w + } + wn := n.nodeStatus.GetWorkflowNodeStatus() ws := interfaces.WorkflowNodeState{} if wn != nil { @@ -122,6 +138,10 @@ func (n nodeStateManager) GetWorkflowNodeState() interfaces.WorkflowNodeState { } func (n nodeStateManager) GetGateNodeState() interfaces.GateNodeState { + if n.g != nil { + return *n.g + } + gn := n.nodeStatus.GetGateNodeStatus() gs := interfaces.GateNodeState{} if gn != nil { @@ -131,6 +151,10 @@ func (n nodeStateManager) GetGateNodeState() interfaces.GateNodeState { } func (n nodeStateManager) GetArrayNodeState() interfaces.ArrayNodeState { + if n.a != nil { + return *n.a + } + an := n.nodeStatus.GetArrayNodeStatus() as := interfaces.ArrayNodeState{} if an != nil { diff --git a/pkg/controller/nodes/task/taskexec_context.go b/pkg/controller/nodes/task/taskexec_context.go index 2a6c59c0c..a308647ac 100644 --- a/pkg/controller/nodes/task/taskexec_context.go +++ b/pkg/controller/nodes/task/taskexec_context.go @@ -59,10 +59,11 @@ func (te taskExecutionID) GetGeneratedNameWith(minLength, maxLength int) (string type taskExecutionMetadata struct { interfaces.NodeExecutionMetadata - taskExecID taskExecutionID - o pluginCore.TaskOverrides - maxAttempts uint32 - platformResources *v1.ResourceRequirements + taskExecID taskExecutionID + o pluginCore.TaskOverrides + maxAttempts uint32 + platformResources *v1.ResourceRequirements + environmentVariables map[string]string } func (t taskExecutionMetadata) GetTaskExecutionID() pluginCore.TaskExecutionID { @@ -81,6 +82,10 @@ func (t taskExecutionMetadata) GetPlatformResources() *v1.ResourceRequirements { return t.platformResources } +func (t taskExecutionMetadata) GetEnvironmentVariables() map[string]string { + return t.environmentVariables +} + type taskExecutionContext struct { interfaces.NodeExecutionContext 
tm taskExecutionMetadata @@ -289,6 +294,7 @@ func (t *Handler) newTaskExecutionContext(ctx context.Context, nCtx interfaces.N o: nCtx.Node(), maxAttempts: maxAttempts, platformResources: convertTaskResourcesToRequirements(nCtx.ExecutionContext().GetExecutionConfig().TaskResources), + environmentVariables: nCtx.ExecutionContext().GetExecutionConfig().EnvironmentVariables, }, rm: resourcemanager.GetTaskResourceManager( t.resourceManager, resourceNamespacePrefix, id), From 03dcc56b789537702ed0c8cfb9844494670a0e76 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 13 Apr 2023 09:31:47 -0500 Subject: [PATCH 08/62] minimum viable product Signed-off-by: Daniel Rammer --- .../testdata/array-node-inputs.yaml.golden | 13 ++ .../cmd/testdata/array-node.yaml.golden | 82 +++++++++ .../nodes/array/execution_context.go | 3 + pkg/controller/nodes/array/handler.go | 173 ++++++++++++++++-- pkg/controller/nodes/array/input_reader.go | 22 +++ pkg/controller/nodes/executor.go | 14 ++ pkg/controller/nodes/interfaces/node.go | 3 + pkg/controller/nodes/node_exec_context.go | 1 - .../nodes/task/k8s/plugin_manager.go | 2 - 9 files changed, 292 insertions(+), 21 deletions(-) create mode 100755 cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden create mode 100644 cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden diff --git a/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden b/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden new file mode 100755 index 000000000..42e176686 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden @@ -0,0 +1,13 @@ +literals: + "x": + collection: + literals: + - scalar: + primitive: + integer: "1" + - scalar: + primitive: + integer: "2" + - scalar: + primitive: + integer: "3" diff --git a/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden b/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden new file mode 100644 index 000000000..73eb67e81 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden @@ -0,0 +1,82 @@ +tasks: +- container: + args: + - "pyflyte-fast-execute" + - "--additional-distribution" + - "s3://my-s3-bucket/flytesnacks/development/TWUFWD6G3NTY7K2B6YJOCHXIKQ======/script_mode.tar.gz" + - "--dest-dir" + - "/root" + - "--" + - "pyflyte-map-execute" + - "--inputs" + - "{{.input}}" + - "--output-prefix" + - "{{.outputPrefix}}" + - "--raw-output-data-prefix" + - "{{.rawOutputDataPrefix}}" + - "--checkpoint-path" + - "{{.checkpointOutputPrefix}}" + - "--prev-checkpoint" + - "{{.prevCheckpointPrefix}}" + - "--resolver" + - "MapTaskResolver" + - "--" + - "vars" + - "" + - "resolver" + - "flytekit.core.python_auto_container.default_task_resolver" + - "task-module" + - "map-task" + - "task-name" + - "a_mappable_task" + image: cr.flyte.org/flyteorg/flytekit:py3.10-latest + resources: + limits: + - name: 1 + value: "1" + - name: 3 + value: "500Mi" + requests: + - name: 1 + value: "1" + - name: 3 + value: "300Mi" + id: + name: task-1 + metadata: + discoverable: false + cache_serializable: false + interface: + inputs: + variables: + a: + type: + simple: INTEGER + outputs: + variables: + x: + type: + simple: STRING +workflow: + id: + name: workflow-with-array-node + interface: + inputs: + variables: + x: + type: + collectionType: + simple: INTEGER + nodes: + - id: node-1 + inputs: + - binding: + promise: + node_id: start-node + var: x + var: a + arrayNode: + node: + taskNode: + referenceId: + name: task-1 diff --git a/pkg/controller/nodes/array/execution_context.go 
b/pkg/controller/nodes/array/execution_context.go index 188518207..803288567 100644 --- a/pkg/controller/nodes/array/execution_context.go +++ b/pkg/controller/nodes/array/execution_context.go @@ -26,6 +26,9 @@ func (a arrayExecutionContext) GetExecutionConfig() v1alpha1.ExecutionConfig { func newArrayExecutionContext(executionContext executors.ExecutionContext, subNodeIndex int) arrayExecutionContext { executionConfig := executionContext.GetExecutionConfig() + if executionConfig.EnvironmentVariables == nil { + executionConfig.EnvironmentVariables = make(map[string]string) + } executionConfig.EnvironmentVariables[JobIndexVarName] = FlyteK8sArrayIndexVarName executionConfig.EnvironmentVariables[FlyteK8sArrayIndexVarName] = strconv.Itoa(subNodeIndex) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 7889da60f..8c936fa62 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -1,22 +1,28 @@ package array import ( + "bytes" "context" "fmt" + "strconv" idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + //"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/compiler/validators" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/k8s" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/codex" "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/storage" ) //go:generate mockery -all -case=underscore @@ -70,6 +76,8 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // + add envVars on ExecutionContext // - need to manage + var inputs *idlcore.LiteralMap + switch arrayNodeState.Phase { case v1alpha1.ArrayNodePhaseNone: // identify and validate array node input value lengths @@ -107,7 +115,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // TODO @hamersaw - init SystemFailures and RetryAttempts as well // do we want to abstract this? ie. arrayNodeState.GetStats(subNodeIndex) (phase, systemFailures, ...) 
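The execution_context.go hunk above is what lets each subtask locate its slice of the work: the execution config advertises (via JobIndexVarName) which environment variable carries the array index, and then sets that variable to the sub-node's index. A self-contained sketch of that plumbing follows; only the two identifier names mirror the patch, and the string values here are illustrative:

package main

import (
	"fmt"
	"strconv"
)

// The indirection mirrors the patch: JobIndexVarName names the variable that
// holds the index, so different batch backends can point at their own variable.
const (
	JobIndexVarName           = "BATCH_JOB_ARRAY_INDEX_VAR_NAME"
	FlyteK8sArrayIndexVarName = "FLYTE_K8S_ARRAY_INDEX"
)

// withSubtaskIndex copies the environment and records which element of the
// input collection this subtask should process.
func withSubtaskIndex(env map[string]string, subNodeIndex int) map[string]string {
	out := make(map[string]string, len(env)+2)
	for k, v := range env {
		out[k] = v
	}
	out[JobIndexVarName] = FlyteK8sArrayIndexVarName
	out[FlyteK8sArrayIndexVarName] = strconv.Itoa(subNodeIndex)
	return out
}

func main() {
	env := withSubtaskIndex(map[string]string{"FOO": "bar"}, 2)
	indexVar := env[JobIndexVarName]
	fmt.Println(indexVar, "=", env[indexVar]) // FLYTE_K8S_ARRAY_INDEX = 2
}

The nil check the hunk adds before these writes also matters: EnvironmentVariables may never have been allocated, and writing to a nil map panics in Go.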
- fmt.Printf("HAMERSAW - created SubNodePhases with length '%d:%d'\n", size, len(arrayNodeState.SubNodePhases.GetItems())) + //fmt.Printf("HAMERSAW - created SubNodePhases with length '%d:%d'\n", size, len(arrayNodeState.SubNodePhases.GetItems())) case v1alpha1.ArrayNodePhaseExecuting: // process array node subnodes for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { @@ -115,15 +123,25 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu fmt.Printf("HAMERSAW - TODO evaluating node '%d' in phase '%d'\n", i, nodePhase) // TODO @hamersaw - fix + if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseFailed || nodePhase == v1alpha1.NodePhaseTimedOut || nodePhase == v1alpha1.NodePhaseSkipped { + continue + } + /*if nodes.IsTerminalNodePhase(nodePhase) { continue }*/ - var inputReader io.InputReader - if nodePhase == v1alpha1.NodePhaseNotYetStarted { // TODO @hamersaw - need to do this for PhaseSucceeded as well?!?! to write cache outputs once fastcache is in - // create input readers and set nodePhase to Queued to skip resolving inputs but still allow cache lookups - // TODO @hamersaw - actually create input readers - inputReader = newStaticInputReader(nCtx.InputReader(), nil) + // TODO @hamersaw - do we need to init input readers every time? + literalMap, err := constructLiteralMap(ctx, nCtx.InputReader(), i, inputs) + if err != nil { + logger.Errorf(ctx, "HAMERSAW - %+v", err) + // TODO @hamersaw - return err + } + + inputReader := newStaticInputReader(nCtx.InputReader(), &literalMap) + + if nodePhase == v1alpha1.NodePhaseNotYetStarted { + // set nodePhase to Queued to skip resolving inputs but still allow cache lookups nodePhase = v1alpha1.NodePhaseQueued } @@ -133,43 +151,162 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu subNodeID := fmt.Sprintf("%s-n%d", nCtx.NodeID(), i) subNodeSpec.ID = subNodeID subNodeSpec.Name = subNodeID - /*subNodeSpec := &v1alpha1.NodeSpec{ - ID: subNodeID, - Name: subNodeID, - } // TODO @hamersaw - compile this in ArrayNodeSpec?*/ + + // TODO @hamersaw - is this right?!?! it's certainly HACKY AF - maybe we persist pluginState.Phase and PluginPhase + pluginState := k8s.PluginState{ + } + if nodePhase == v1alpha1.NodePhaseQueued { + pluginState.Phase = k8s.PluginPhaseNotStarted + } else { + pluginState.Phase = k8s.PluginPhaseStarted + } + + buffer := make([]byte, 0, 256) + bufferWriter := bytes.NewBuffer(buffer) + + codec := codex.GobStateCodec{} + if err := codec.Encode(pluginState, bufferWriter); err != nil { + logger.Errorf(ctx, "HAMERSAW - %+v", err) + } + + // we set subDataDir and subOutputDir to the node dirs because flytekit automatically appends subtask + // index. however when we check completion status we need to manually append index - so in all cases + // where the node phase is not Queued (ie. task handler will launch task and init flytekit params) we + // append the subtask index. + var subDataDir, subOutputDir storage.DataReference + if nodePhase == v1alpha1.NodePhaseQueued { + subDataDir = nCtx.NodeStatus().GetDataDir() + subOutputDir = nCtx.NodeStatus().GetOutputDir() + } else { + subDataDir, err = nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetDataDir(), strconv.Itoa(i)) // TODO @hamersaw - constructOutputReference? 
+ if err != nil { + logger.Errorf(ctx, "HAMERSAW - %+v", err) + // TODO @hamersaw - return err + } + + subOutputDir, err = nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), strconv.Itoa(i)) // TODO @hamersaw - constructOutputReference? + if err != nil { + logger.Errorf(ctx, "HAMERSAW - %+v", err) + // TODO @hamersaw - return err + } + } + subNodeStatus := &v1alpha1.NodeStatus{ Phase: nodePhase, TaskNodeStatus: &v1alpha1.TaskNodeStatus{ // TODO @hamersaw - to get caching working we need to set to Queued to force cache lookup // once fastcache is done we dont care about the TaskNodeStatus Phase: int(core.Phases[core.PhaseRunning]), + PluginState: bufferWriter.Bytes(), }, + DataDir: subDataDir, + OutputDir: subOutputDir, // TODO @hamersaw - fill out systemFailures, retryAttempt etc } // TODO @hamersaw - can probably create a single arrayNodeLookup with all the subNodeIDs arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, &subNodeSpec, subNodeStatus) - // create arrayNodeExecutor - arrayNodeExecutor := newArrayNodeExecutor(a.nodeExecutor, subNodeID, i, inputReader) - // execute subNode through RecursiveNodeHandler - nodeStatus, err := arrayNodeExecutor.RecursiveNodeHandler(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec) + _, err = a.nodeExecutor.RecursiveNodeHandlerWithNodeContextModifier(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec, + func (nCtx interfaces.NodeExecutionContext) interfaces.NodeExecutionContext { + if nCtx.NodeID() == subNodeID { + return newArrayNodeExecutionContext(nCtx, inputReader, i) + } + + return nCtx + }) + if err != nil { + logger.Errorf(ctx, "HAMERSAW - %+v", err) // TODO @hamersaw fail } - //fmt.Printf("HAMERSAW - node phase transition %d -> %d", nodePhase, nodeStatus.NodePhase) - arrayNodeState.SubNodePhases.SetItem(i, uint64(nodeStatus.NodePhase)) + fmt.Printf("HAMERSAW - node phase transition %d -> %d\n", nodePhase, subNodeStatus.GetPhase()) + arrayNodeState.SubNodePhases.SetItem(i, uint64(subNodeStatus.GetPhase())) } // TODO @hamersaw - determine summary phases + succeeded := true + for _, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { + nodePhase := v1alpha1.NodePhase(nodePhaseUint64) + if nodePhase != v1alpha1.NodePhaseSucceeded { + succeeded = false + break + } + } - arrayNodeState.Phase = v1alpha1.ArrayNodePhaseSucceeding + if succeeded { + arrayNodeState.Phase = v1alpha1.ArrayNodePhaseSucceeding + } case v1alpha1.ArrayNodePhaseFailing: // TODO @hamersaw - abort everything! case v1alpha1.ArrayNodePhaseSucceeding: + outputLiterals := make(map[string]*idlcore.Literal) + + for i, _ := range arrayNodeState.SubNodePhases.GetItems() { + // initialize subNode reader + subDataDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetDataDir(), strconv.Itoa(i)) // TODO @hamersaw - constructOutputReference? + if err != nil { + logger.Errorf(ctx, "HAMERSAW - %+v", err) + // TODO @hamersaw - return err + } + + subOutputDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), strconv.Itoa(i)) // TODO @hamersaw - constructOutputReference? + if err != nil { + logger.Errorf(ctx, "HAMERSAW - %+v", err) + // TODO @hamersaw - return err + } + + // checkpoint paths are not computed here because this function is only called when writing + // existing cached outputs. if this functionality changes this will need to be revisited. 
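The hunk that follows gathers results: it reads each sub-node's outputs and appends them, in sub-node order, into one LiteralCollection per output variable. Stripped of the storage and protobuf machinery, that gather step reduces to the self-contained sketch below (plain strings stand in for literals):

package main

import "fmt"

// gatherOutputs merges per-subtask output maps into one map of output name ->
// ordered slice, mirroring how the handler grows a collection per output
// variable. Iterating subtasks in order keeps element i aligned with sub-node i.
func gatherOutputs(subtaskOutputs []map[string]string) map[string][]string {
	gathered := make(map[string][]string)
	for _, outputs := range subtaskOutputs {
		for name, value := range outputs {
			gathered[name] = append(gathered[name], value)
		}
	}
	return gathered
}

func main() {
	perSubtask := []map[string]string{
		{"x": "a"},
		{"x": "b"},
		{"x": "c"},
	}
	fmt.Println(gatherOutputs(perSubtask)) // map[x:[a b c]]
}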
+ outputPaths := ioutils.NewCheckpointRemoteFilePaths(ctx, nCtx.DataStore(), subOutputDir, ioutils.NewRawOutputPaths(ctx, subDataDir), "") + reader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, int64(999999999)) + + // read outputs + outputs, executionError, err := reader.Read(ctx) + if err != nil { + logger.Warnf(ctx, "Failed to read output for subtask [%v]. Error: %v", i, err) + //return workqueue.WorkStatusFailed, err // TODO @hamersaw -return error + } + + if executionError == nil && outputs != nil { + for name, literal := range outputs.GetLiterals() { + existingVal, found := outputLiterals[name] + var list *idlcore.LiteralCollection + if found { + list = existingVal.GetCollection() + } else { + list = &idlcore.LiteralCollection{ + Literals: make([]*idlcore.Literal, 0, len(arrayNodeState.SubNodePhases.GetItems())), + } + + existingVal = &idlcore.Literal{ + Value: &idlcore.Literal_Collection{ + Collection: list, + }, + } + } + + list.Literals = append(list.Literals, literal) + outputLiterals[name] = existingVal + } + } + } + // TODO @hamersaw - collect outputs and write as List[] + //fmt.Printf("HAMERSAW - final outputs %+v\n", idlcore.LiteralMap{Literals: outputLiterals}) + outputLiteralMap := &idlcore.LiteralMap{ + Literals: outputLiterals, + } + + outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + if err := nCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputLiteralMap); err != nil { + // TODO @hamersaw return error + //return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "WriteOutputsFailed", + // fmt.Sprintf("failed to write signal value to [%v] with error [%s]", outputFile, err.Error()), nil)), nil + } + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil default: // TODO @hamersaw - fail diff --git a/pkg/controller/nodes/array/input_reader.go b/pkg/controller/nodes/array/input_reader.go index 540527b36..8e5d99598 100644 --- a/pkg/controller/nodes/array/input_reader.go +++ b/pkg/controller/nodes/array/input_reader.go @@ -23,6 +23,28 @@ func newStaticInputReader(inputPaths io.InputFilePaths, input *idlcore.LiteralMa } } +func constructLiteralMap(ctx context.Context, inputReader io.InputReader, index int, inputs *idlcore.LiteralMap) (idlcore.LiteralMap, error) { + // TODO @hamersaw - read inputs + var err error + if inputs == nil { + inputs, err = inputReader.Get(ctx) + if err != nil { + return idlcore.LiteralMap{}, err + } + } + + literals := make(map[string]*idlcore.Literal) + for name, literal := range inputs.Literals { + if literalCollection := literal.GetCollection(); literalCollection != nil { + literals[name] = literalCollection.Literals[index] + } + } + + return idlcore.LiteralMap{ + Literals: literals, + }, nil +} + /*func ConstructStaticInputReaders(inputPaths io.InputFilePaths, inputs []*idlcore.Literal, inputName string) []io.InputReader { inputReaders := make([]io.InputReader, 0, len(inputs)) for i := 0; i < len(inputs); i++ { diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index 997c07075..ae081a087 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -967,12 +967,24 @@ func IsMaxParallelismAchieved(ctx context.Context, currentNode v1alpha1.Executab return false } + // RecursiveNodeHandler This is the entrypoint of executing a node in a workflow. A workflow consists of nodes, that are // nested within other nodes. 
The system follows an actor model, where the parent nodes control the execution of nested nodes // The recursive node-handler uses a modified depth-first type of algorithm to execute non-blocked nodes. func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) ( interfaces.NodeStatus, error) { + + return c.RecursiveNodeHandlerWithNodeContextModifier(ctx, execContext, dag, nl, currentNode, func (nCtx interfaces.NodeExecutionContext) interfaces.NodeExecutionContext { + return nCtx + }) +} + +// TODO @hamersaw +func (c *nodeExecutor) RecursiveNodeHandlerWithNodeContextModifier(ctx context.Context, execContext executors.ExecutionContext, + dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, + nCtxModifier func (interfaces.NodeExecutionContext) interfaces.NodeExecutionContext) ( + interfaces.NodeStatus, error) { currentNodeCtx := contextutils.WithNodeID(ctx, currentNode.GetID()) nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) @@ -1012,6 +1024,8 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext exe }), nil } + nCtx = nCtxModifier(nCtx) + // Now depending on the node type decide h, err := c.nodeHandlerFactory.GetHandler(nCtx.Node().GetKind()) if err != nil { diff --git a/pkg/controller/nodes/interfaces/node.go b/pkg/controller/nodes/interfaces/node.go index f713f6348..2974090cd 100644 --- a/pkg/controller/nodes/interfaces/node.go +++ b/pkg/controller/nodes/interfaces/node.go @@ -79,6 +79,9 @@ type Node interface { RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (NodeStatus, error) + RecursiveNodeHandlerWithNodeContextModifier(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, + nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, nCtxModifier func (NodeExecutionContext) NodeExecutionContext) (NodeStatus, error) + // This aborts the given node. 
If the given node is complete then it recursively finds the running nodes and aborts them AbortHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error diff --git a/pkg/controller/nodes/node_exec_context.go b/pkg/controller/nodes/node_exec_context.go index f5faff70b..27c3e82c7 100644 --- a/pkg/controller/nodes/node_exec_context.go +++ b/pkg/controller/nodes/node_exec_context.go @@ -227,7 +227,6 @@ func (c *nodeExecutor) NewNodeExecutionContext(ctx context.Context, executionCon rawOutputPrefix = storage.DataReference(executionContext.GetRawOutputDataConfig().OutputLocationPrefix) } - fmt.Printf("HAMERSAW - creating base NodeExecutionContext for %s\n", currentNodeID) return newNodeExecContext(ctx, c.store, executionContext, nl, n, s, ioutils.NewCachedInputReader( ctx, diff --git a/pkg/controller/nodes/task/k8s/plugin_manager.go b/pkg/controller/nodes/task/k8s/plugin_manager.go index 67b0356a3..42215254b 100644 --- a/pkg/controller/nodes/task/k8s/plugin_manager.go +++ b/pkg/controller/nodes/task/k8s/plugin_manager.go @@ -186,7 +186,6 @@ func (e *PluginManager) getPodEffectiveResourceLimits(ctx context.Context, pod * } func (e *PluginManager) LaunchResource(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (pluginsCore.Transition, error) { - tmpl, err := tCtx.TaskReader().Read(ctx) if err != nil { return pluginsCore.Transition{}, err @@ -249,7 +248,6 @@ func (e *PluginManager) LaunchResource(ctx context.Context, tCtx pluginsCore.Tas } func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, k8sPluginState *k8s.PluginState) (pluginsCore.Transition, error) { - o, err := e.plugin.BuildIdentityResource(ctx, tCtx.TaskExecutionMetadata()) if err != nil { logger.Errorf(ctx, "Failed to build the Resource with name: %v. Error: %v", tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) From 658abc784e666f89bb491afda6153f30811e36aa Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 13 Apr 2023 12:41:59 -0500 Subject: [PATCH 09/62] update print statements for debugging Signed-off-by: Daniel Rammer --- cmd/kubectl-flyte/cmd/create.go | 6 ++++++ .../cmd/testdata/array-node-inputs.yaml.golden | 3 --- pkg/controller/nodes/array/handler.go | 2 +- pkg/controller/nodes/array/input_reader.go | 1 - 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/cmd/kubectl-flyte/cmd/create.go b/cmd/kubectl-flyte/cmd/create.go index 91ea38255..2114aa902 100644 --- a/cmd/kubectl-flyte/cmd/create.go +++ b/cmd/kubectl-flyte/cmd/create.go @@ -212,6 +212,12 @@ func (c *CreateOpts) createWorkflowFromProto() error { } } + // TODO @hamersaw temp + flyteWf.ExecutionID.Project = "flytesnacks" + flyteWf.ExecutionID.Domain = "development" + flyteWf.Labels["project"] = "flytesnacks" + flyteWf.Labels["domain"] = "development" + if c.dryRun { fmt.Printf("Dry Run mode enabled. 
Printing the compiled workflow.\n") j, err := json.Marshal(flyteWf) diff --git a/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden b/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden index 42e176686..3f9d69172 100755 --- a/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden +++ b/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden @@ -8,6 +8,3 @@ literals: - scalar: primitive: integer: "2" - - scalar: - primitive: - integer: "3" diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 8c936fa62..125dcf80d 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -295,7 +295,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } // TODO @hamersaw - collect outputs and write as List[] - //fmt.Printf("HAMERSAW - final outputs %+v\n", idlcore.LiteralMap{Literals: outputLiterals}) + fmt.Printf("HAMERSAW - final outputs %+v\n", idlcore.LiteralMap{Literals: outputLiterals}) outputLiteralMap := &idlcore.LiteralMap{ Literals: outputLiterals, } diff --git a/pkg/controller/nodes/array/input_reader.go b/pkg/controller/nodes/array/input_reader.go index 8e5d99598..de6e6ef81 100644 --- a/pkg/controller/nodes/array/input_reader.go +++ b/pkg/controller/nodes/array/input_reader.go @@ -24,7 +24,6 @@ func newStaticInputReader(inputPaths io.InputFilePaths, input *idlcore.LiteralMa } func constructLiteralMap(ctx context.Context, inputReader io.InputReader, index int, inputs *idlcore.LiteralMap) (idlcore.LiteralMap, error) { - // TODO @hamersaw - read inputs var err error if inputs == nil { inputs, err = inputReader.Get(ctx) From e33af409e65d83334c886ad8c3509d254dd23123 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Tue, 25 Apr 2023 05:40:08 -0500 Subject: [PATCH 10/62] massive refactor fixing NodeExecutionContext override for ArrayNode Signed-off-by: Daniel Rammer --- .../nodes/array/execution_context.go | 40 + pkg/controller/nodes/array/handler.go | 10 +- ...{node_executor.go => node_executor.go.bak} | 0 pkg/controller/nodes/executor.go | 1911 +++++++++-------- pkg/controller/nodes/executor_test.go | 56 +- pkg/controller/nodes/handler/iface.go | 9 + pkg/controller/nodes/interfaces/node.go | 18 +- pkg/controller/nodes/node_exec_context.go | 2 +- .../nodes/node_exec_context_test.go | 4 +- pkg/controller/nodes/setup_context.go | 2 +- 10 files changed, 1070 insertions(+), 982 deletions(-) rename pkg/controller/nodes/array/{node_executor.go => node_executor.go.bak} (100%) diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go index 803288567..11aaf3693 100644 --- a/pkg/controller/nodes/array/execution_context.go +++ b/pkg/controller/nodes/array/execution_context.go @@ -1,6 +1,8 @@ package array import ( + "context" + "fmt" "strconv" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" @@ -75,3 +77,41 @@ func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionC executionContext: arrayExecutionContext, } } + + +type arrayNodeExecutionContextBuilder struct { + nCtxBuilder interfaces.NodeExecutionContextBuilder + subNodeID v1alpha1.NodeID + subNodeIndex int + inputReader io.InputReader +} + +func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, + nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (interfaces.NodeExecutionContext, error) { + + // create base NodeExecutionContext + nCtx, err := 
a.nCtxBuilder.BuildNodeExecutionContext(ctx, executionContext, nl, currentNodeID) + if err != nil { + return nil, err + } + + fmt.Println("HAMERSAW - currentNodeID %s subNodeID %s!\n", currentNodeID, a.subNodeID) + if currentNodeID == a.subNodeID { + // overwrite NodeExecutionContext for ArrayNode execution + nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.subNodeIndex) + } + + return nCtx, nil + +} + +func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder, subNodeID v1alpha1.NodeID, + subNodeIndex int, inputReader io.InputReader) interfaces.NodeExecutionContextBuilder { + + return &arrayNodeExecutionContextBuilder{ + nCtxBuilder: nCtxBuilder, + subNodeID: subNodeID, + subNodeIndex: subNodeIndex, + inputReader: inputReader, + } +} diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 125dcf80d..d26aa510e 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -208,14 +208,20 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, &subNodeSpec, subNodeStatus) // execute subNode through RecursiveNodeHandler - _, err = a.nodeExecutor.RecursiveNodeHandlerWithNodeContextModifier(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec, + /*_, err = a.nodeExecutor.RecursiveNodeHandlerWithNodeContextModifier(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec, func (nCtx interfaces.NodeExecutionContext) interfaces.NodeExecutionContext { if nCtx.NodeID() == subNodeID { return newArrayNodeExecutionContext(nCtx, inputReader, i) } return nCtx - }) + })*/ + + // TODO @hamersaw - move all construction of nCtx internal -> can build a single arrayNodeExecutor and use for everyone -> build differently based on index + // execute subNode through RecursiveNodeHandler + arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(), subNodeID, i, inputReader) + arrayNodeExecutor := a.nodeExecutor.WithNodeExecutionContextBuilder(arrayNodeExecutionContextBuilder) + _, err = arrayNodeExecutor.RecursiveNodeHandler(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec) if err != nil { logger.Errorf(ctx, "HAMERSAW - %+v", err) diff --git a/pkg/controller/nodes/array/node_executor.go b/pkg/controller/nodes/array/node_executor.go.bak similarity index 100% rename from pkg/controller/nodes/array/node_executor.go rename to pkg/controller/nodes/array/node_executor.go.bak diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index ae081a087..b39787fc0 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -86,1140 +86,1158 @@ type nodeMetrics struct { } // Implements the executors.Node interface -type nodeExecutor struct { - nodeHandlerFactory HandlerFactory +type recursiveNodeExecutor struct { + nodeExecutor handler.NodeExecutor + nCtxBuilder interfaces.NodeExecutionContextBuilder + enqueueWorkflow v1alpha1.EnqueueWorkflow + nodeHandlerFactory HandlerFactory store *storage.DataStore - nodeRecorder events.NodeEventRecorder - taskRecorder events.TaskEventRecorder metrics *nodeMetrics - maxDatasetSizeBytes int64 - outputResolver OutputResolver - defaultExecutionDeadline time.Duration - defaultActiveDeadline time.Duration - maxNodeRetriesForSystemFailures uint32 - interruptibleFailureThreshold uint32 - 
defaultDataSandbox storage.DataReference - shardSelector ioutils.ShardSelector - recoveryClient recovery.Client - eventConfig *config.EventConfig - clusterID string } -func (c *nodeExecutor) RecordTransitionLatency(ctx context.Context, dag executors.DAGStructure, nl executors.NodeLookup, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) { - if nodeStatus.GetPhase() == v1alpha1.NodePhaseNotYetStarted || nodeStatus.GetPhase() == v1alpha1.NodePhaseQueued { - // Log transition latency (The most recently finished parent node endAt time to this node's queuedAt time -now-) - t, err := GetParentNodeMaxEndTime(ctx, dag, nl, node) - if err != nil { - logger.Warnf(ctx, "Failed to record transition latency for node. Error: %s", err.Error()) - return - } - if !t.IsZero() { - c.metrics.TransitionLatency.Observe(ctx, t.Time, time.Now()) - } - } else if nodeStatus.GetPhase() == v1alpha1.NodePhaseRetryableFailure && nodeStatus.GetLastUpdatedAt() != nil { - c.metrics.TransitionLatency.Observe(ctx, nodeStatus.GetLastUpdatedAt().Time, time.Now()) +func (c *recursiveNodeExecutor) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (interfaces.NodeStatus, error) { + startNode := dag.StartNode() + ctx = contextutils.WithNodeID(ctx, startNode.GetID()) + if inputs == nil { + logger.Infof(ctx, "No inputs for the workflow. Skipping storing inputs") + return interfaces.NodeStatusComplete, nil } -} -func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *event.NodeExecutionEvent) error { - if nodeEvent == nil { - return fmt.Errorf("event recording attempt of Nil Node execution event") + // StartNode is special. It does not have any processing step. It just takes the workflow (or subworkflow) inputs and converts to its own outputs + nodeStatus := nl.GetNodeExecutionStatus(ctx, startNode.GetID()) + + if len(nodeStatus.GetDataDir()) == 0 { + return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, startNode.GetID(), "no data-dir set, cannot store inputs") } + outputFile := v1alpha1.GetOutputsFile(nodeStatus.GetOutputDir()) - if nodeEvent.Id == nil { - return fmt.Errorf("event recording attempt of with nil node Event ID") + so := storage.Options{} + if err := c.store.WriteProtobuf(ctx, outputFile, so, inputs); err != nil { + logger.Errorf(ctx, "Failed to write protobuf (metadata). Error [%v]", err) + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.CausedByError, startNode.GetID(), err, "Failed to store workflow inputs (as start node)") } - logger.Infof(ctx, "Recording NodeEvent [%s] phase[%s]", nodeEvent.GetId().String(), nodeEvent.Phase.String()) - err := c.nodeRecorder.RecordNodeEvent(ctx, nodeEvent, c.eventConfig) - if err != nil { - if nodeEvent.GetId().NodeId == v1alpha1.EndNodeID { - return nil - } + return interfaces.NodeStatusComplete, nil +} - if eventsErr.IsAlreadyExists(err) { - logger.Infof(ctx, "Node event phase: %s, nodeId %s already exist", - nodeEvent.Phase.String(), nodeEvent.GetId().NodeId) - return nil - } else if eventsErr.IsEventAlreadyInTerminalStateError(err) { - if IsTerminalNodePhase(nodeEvent.Phase) { - // Event was trying to record a different terminal phase for an already terminal event. ignoring. - logger.Infof(ctx, "Node event phase: %s, nodeId %s already in terminal phase. 
err: %s", - nodeEvent.Phase.String(), nodeEvent.GetId().NodeId, err.Error()) - return nil - } - logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return errors.Wrapf(errors.IllegalStateError, nodeEvent.Id.NodeId, err, "phase mis-match mismatch between propeller and control plane; Trying to record Node p: %s", nodeEvent.Phase) +func canHandleNode(phase v1alpha1.NodePhase) bool { + return phase == v1alpha1.NodePhaseNotYetStarted || + phase == v1alpha1.NodePhaseQueued || + phase == v1alpha1.NodePhaseRunning || + phase == v1alpha1.NodePhaseFailing || + phase == v1alpha1.NodePhaseTimingOut || + phase == v1alpha1.NodePhaseRetryableFailure || + phase == v1alpha1.NodePhaseSucceeding || + phase == v1alpha1.NodePhaseDynamicRunning +} + +// IsMaxParallelismAchieved checks if we have already achieved max parallelism. It returns true, if the desired max parallelism +// value is achieved, false otherwise +// MaxParallelism is defined as the maximum number of TaskNodes and LaunchPlans (together) that can be executed concurrently +// by one workflow execution. A setting of `0` indicates that it is disabled. +func IsMaxParallelismAchieved(ctx context.Context, currentNode v1alpha1.ExecutableNode, currentPhase v1alpha1.NodePhase, + execContext executors.ExecutionContext) bool { + maxParallelism := execContext.GetExecutionConfig().MaxParallelism + if maxParallelism == 0 { + logger.Debugf(ctx, "Parallelism control disabled") + return false + } + + if currentNode.GetKind() == v1alpha1.NodeKindTask || + (currentNode.GetKind() == v1alpha1.NodeKindWorkflow && currentNode.GetWorkflowNode() != nil && currentNode.GetWorkflowNode().GetLaunchPlanRefID() != nil) { + // If we are queued, let us see if we can proceed within the node parallelism bounds + if execContext.CurrentParallelism() >= maxParallelism { + logger.Infof(ctx, "Maximum Parallelism for task/launch-plan nodes achieved [%d] >= Max [%d], Round will be short-circuited.", execContext.CurrentParallelism(), maxParallelism) + return true } + // We know that Propeller goes through each workflow in a single thread, thus every node is really processed + // sequentially. So, we can continue - now that we know we are under the parallelism limits and increment the + // parallelism if the node, enters a running state + logger.Debugf(ctx, "Parallelism criteria not met, Current [%d], Max [%d]", execContext.CurrentParallelism(), maxParallelism) + } else { + logger.Debugf(ctx, "NodeKind: %s in status [%s]. Parallelism control is not applicable. Current Parallelism [%d]", + currentNode.GetKind().String(), currentPhase.String(), execContext.CurrentParallelism()) } - return err + return false } -func (c *nodeExecutor) recoverInputs(ctx context.Context, nCtx interfaces.NodeExecutionContext, - recovered *admin.NodeExecution, recoveredData *admin.NodeExecutionGetDataResponse) (*core.LiteralMap, error) { - nodeInputs := recoveredData.FullInputs - if nodeInputs != nil { - if err := c.store.WriteProtobuf(ctx, nCtx.InputReader().GetInputPath(), storage.Options{}, nodeInputs); err != nil { - c.metrics.InputsWriteFailure.Inc(ctx) - logger.Errorf(ctx, "Failed to move recovered inputs for Node. Error [%v]. InputsFile [%s]", err, nCtx.InputReader().GetInputPath()) - return nil, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, "Failed to store inputs for Node. InputsFile [%s]", nCtx.InputReader().GetInputPath()) +// RecursiveNodeHandler This is the entrypoint of executing a node in a workflow. 
A workflow consists of nodes, that are +// nested within other nodes. The system follows an actor model, where the parent nodes control the execution of nested nodes +// The recursive node-handler uses a modified depth-first type of algorithm to execute non-blocked nodes. +func (c *recursiveNodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, + dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) ( + interfaces.NodeStatus, error) { + + currentNodeCtx := contextutils.WithNodeID(ctx, currentNode.GetID()) + nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) + nodePhase := nodeStatus.GetPhase() + + if canHandleNode(nodePhase) { + // TODO Follow up Pull Request, + // 1. Rename this method to DAGTraversalHandleNode (accepts a DAGStructure along-with) the remaining arguments + // 2. Create a new method called HandleNode (part of the interface) (remaining all args as the previous method, but no DAGStructure + // 3. Additional both methods will receive inputs reader + // 4. The Downstream nodes handler will Resolve the Inputs + // 5. the method will delegate all other node handling to HandleNode. + // 6. Thus we can get rid of SetInputs for StartNode as well + logger.Debugf(currentNodeCtx, "Handling node Status [%v]", nodeStatus.GetPhase().String()) + + t := c.metrics.NodeExecutionTime.Start(ctx) + defer t.Stop() + + // This is an optimization to avoid creating the nodeContext object in case the node has already been looked at. + // If the overhead was zero, we would just do the isDirtyCheck after the nodeContext is created + if nodeStatus.IsDirty() { + return interfaces.NodeStatusRunning, nil } - } else if len(recovered.InputUri) > 0 { - // If the inputs are too large they won't be returned inline in the RecoverData call. We must fetch them before copying them. - nodeInputs = &core.LiteralMap{} - if recoveredData.FullInputs == nil { - if err := c.store.ReadProtobuf(ctx, storage.DataReference(recovered.InputUri), nodeInputs); err != nil { - return nil, errors.Wrapf(errors.InputsNotFoundError, nCtx.NodeID(), err, "failed to read data from dataDir [%v].", recovered.InputUri) - } + + if IsMaxParallelismAchieved(ctx, currentNode, nodePhase, execContext) { + return interfaces.NodeStatusRunning, nil } - if err := c.store.WriteProtobuf(ctx, nCtx.InputReader().GetInputPath(), storage.Options{}, nodeInputs); err != nil { - c.metrics.InputsWriteFailure.Inc(ctx) - logger.Errorf(ctx, "Failed to move recovered inputs for Node. Error [%v]. InputsFile [%s]", err, nCtx.InputReader().GetInputPath()) - return nil, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, "Failed to store inputs for Node. InputsFile [%s]", nCtx.InputReader().GetInputPath()) + nCtx, err := c.nCtxBuilder.BuildNodeExecutionContext(ctx, execContext, nl, currentNode.GetID()) + if err != nil { + // NodeExecution creation failure is a permanent fail / system error. + // Should a system failure always return an err? 
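On the question in the comment above: returning a failed NodeStatus instead of an error keeps the controller's reconcile loop healthy while still recording the failure, and tagging it SYSTEM (infrastructure fault) rather than USER (task or workflow fault) routes it to the right retry budget - note the separate maxNodeRetriesForSystemFailures field on the executor. A toy sketch of that two-budget model, with hypothetical names:

package main

import "fmt"

type ErrorKind int

const (
	ErrorKindUser ErrorKind = iota
	ErrorKindSystem
)

// shouldRetry sketches the split: system failures draw from their own budget so
// infrastructure flakiness does not consume the user's configured attempts.
func shouldRetry(kind ErrorKind, attempts, maxAttempts, sysFailures, maxSysFailures uint32) bool {
	if kind == ErrorKindSystem {
		return sysFailures < maxSysFailures
	}
	return attempts < maxAttempts
}

func main() {
	fmt.Println(shouldRetry(ErrorKindSystem, 3, 3, 1, 5)) // true: system budget remains
	fmt.Println(shouldRetry(ErrorKindUser, 3, 3, 1, 5))   // false: user attempts exhausted
}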
+ return interfaces.NodeStatusFailed(&core.ExecutionError{ + Code: "InternalError", + Message: err.Error(), + Kind: core.ExecutionError_SYSTEM, + }), nil } - } - return nodeInputs, nil -} + // Now depending on the node type decide + h, err := c.nodeHandlerFactory.GetHandler(nCtx.Node().GetKind()) + if err != nil { + return interfaces.NodeStatusUndefined, err + } -func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.PhaseInfo, error) { - fullyQualifiedNodeID := nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId - if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { - // compute fully qualified node id (prefixed with parent id and retry attempt) to ensure uniqueness - var err error - fullyQualifiedNodeID, err = common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId) + return c.nodeExecutor.HandleNode(currentNodeCtx, dag, nCtx, h) + + // TODO we can optimize skip state handling by iterating down the graph and marking all as skipped + // Currently we treat either Skip or Success the same way. In this approach only one node will be skipped + // at a time. As we iterate down, further nodes will be skipped + } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { + logger.Debugf(currentNodeCtx, "Node has [%v], traversing downstream.", nodePhase) + return c.handleDownstream(ctx, execContext, dag, nl, currentNode) + } else if nodePhase == v1alpha1.NodePhaseFailed { + logger.Debugf(currentNodeCtx, "Node has failed, traversing downstream.") + _, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode) if err != nil { - return handler.PhaseInfoUndefined, err + return interfaces.NodeStatusUndefined, err } - } - recovered, err := c.recoveryClient.RecoverNodeExecution(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) - if err != nil { - st, ok := status.FromError(err) - if !ok || st.Code() != codes.NotFound { - logger.Warnf(ctx, "Failed to recover node [%+v] with err [%+v]", nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) + return interfaces.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + } else if nodePhase == v1alpha1.NodePhaseTimedOut { + logger.Debugf(currentNodeCtx, "Node has timed out, traversing downstream.") + _, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode) + if err != nil { + return interfaces.NodeStatusUndefined, err } - // The node is not recoverable when it's not found in the parent execution - return handler.PhaseInfoUndefined, nil + + return interfaces.NodeStatusTimedOut, nil } - if recovered == nil { - logger.Warnf(ctx, "call to recover node [%+v] returned no error but also no node", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) - return handler.PhaseInfoUndefined, nil + + return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, currentNode.GetID(), + "Should never reach here. Current Phase: %v", nodePhase) +} + +// The space search for the next node to execute is implemented like a DFS algorithm. handleDownstream visits all the nodes downstream from +// the currentNode. Visit a node is the RecursiveNodeHandler. A visit may be partial, complete or may result in a failure. 
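Before the function body below, it may help to see its folding logic in isolation: visit every downstream node, fail fast unless the on-failure policy lets executable nodes finish, and stay pending while any child still runs. A self-contained sketch under those assumptions (plain enums instead of the real NodeStatus types):

package main

import "fmt"

type Status int

const (
	StatusComplete Status = iota
	StatusRunning
	StatusFailed
)

// summarize folds child statuses the way handleDownstream does: without the
// continue-on-failure policy a failure returns immediately; with it, running
// children are allowed to finish before the remembered failure is surfaced.
func summarize(children []Status, continueOnFailure bool) Status {
	sawFailure := false
	sawRunning := false
	for _, s := range children {
		switch s {
		case StatusFailed:
			if !continueOnFailure {
				return StatusFailed // fail fast
			}
			sawFailure = true
		case StatusRunning:
			sawRunning = true
		}
	}
	if sawRunning {
		return StatusRunning
	}
	if sawFailure {
		return StatusFailed
	}
	return StatusComplete
}

func main() {
	fmt.Println(summarize([]Status{StatusComplete, StatusRunning}, false) == StatusRunning) // true
	fmt.Println(summarize([]Status{StatusFailed, StatusRunning}, true) == StatusRunning)    // true: children finish first
}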
+func (c *recursiveNodeExecutor) handleDownstream(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (interfaces.NodeStatus, error) { + logger.Debugf(ctx, "Handling downstream Nodes") + // This node is success. Handle all downstream nodes + downstreamNodes, err := dag.FromNode(currentNode.GetID()) + if err != nil { + logger.Debugf(ctx, "Error when retrieving downstream nodes, [%s]", err) + return interfaces.NodeStatusFailed(&core.ExecutionError{ + Code: errors.BadSpecificationError, + Message: fmt.Sprintf("failed to retrieve downstream nodes for [%s]", currentNode.GetID()), + Kind: core.ExecutionError_SYSTEM, + }), nil } - if recovered.Closure == nil { - logger.Warnf(ctx, "Fetched node execution [%+v] data but was missing closure. Will not attempt to recover", - nCtx.NodeExecutionMetadata().GetNodeExecutionID()) - return handler.PhaseInfoUndefined, nil + if len(downstreamNodes) == 0 { + logger.Debugf(ctx, "No downstream nodes found. Complete.") + return interfaces.NodeStatusComplete, nil } - // A recoverable node execution should always be in a terminal phase - switch recovered.Closure.Phase { - case core.NodeExecution_SKIPPED: - return handler.PhaseInfoSkip(nil, "node execution recovery indicated original node was skipped"), nil - case core.NodeExecution_SUCCEEDED: - fallthrough - case core.NodeExecution_RECOVERED: - logger.Debugf(ctx, "Node [%+v] can be recovered. Proceeding to copy inputs and outputs", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) - default: - // The node execution may be partially recoverable through intra task checkpointing. Save the checkpoint - // uri in the task node state to pass to the task handler later on. - if metadata, ok := recovered.Closure.TargetMetadata.(*admin.NodeExecutionClosure_TaskNodeMetadata); ok { - state := nCtx.NodeStateReader().GetTaskNodeState() - state.PreviousNodeExecutionCheckpointURI = storage.DataReference(metadata.TaskNodeMetadata.CheckpointUri) - err = nCtx.NodeStateWriter().PutTaskNodeState(state) - if err != nil { - logger.Warn(ctx, "failed to save recovered checkpoint uri for [%+v]: [%+v]", - nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) - } + // If any downstream node is failed, fail, all + // Else if all are success then success + // Else if any one is running then Downstream is still running + allCompleted := true + partialNodeCompletion := false + onFailurePolicy := execContext.GetOnFailurePolicy() + stateOnComplete := interfaces.NodeStatusComplete + for _, downstreamNodeName := range downstreamNodes { + downstreamNode, ok := nl.GetNode(downstreamNodeName) + if !ok { + return interfaces.NodeStatusFailed(&core.ExecutionError{ + Code: errors.BadSpecificationError, + Message: fmt.Sprintf("failed to retrieve downstream node [%s] for [%s]", downstreamNodeName, currentNode.GetID()), + Kind: core.ExecutionError_SYSTEM, + }), nil } - // if this node is a dynamic task we attempt to recover the compiled workflow from instances where the parent - // task succeeded but the dynamic task did not complete. this is important to ensure correctness since node ids - // within the compiled closure may not be generated deterministically. 
- if recovered.Metadata != nil && recovered.Metadata.IsDynamic && len(recovered.Closure.DynamicJobSpecUri) > 0 { - // recover node inputs - recoveredData, err := c.recoveryClient.RecoverNodeExecutionData(ctx, - nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) - if err != nil || recoveredData == nil { - return handler.PhaseInfoUndefined, nil - } - - if _, err := c.recoverInputs(ctx, nCtx, recovered, recoveredData); err != nil { - return handler.PhaseInfoUndefined, err - } - - // copy previous DynamicJobSpec file - f, err := task.NewRemoteFutureFileReader(ctx, nCtx.NodeStatus().GetOutputDir(), nCtx.DataStore()) - if err != nil { - return handler.PhaseInfoUndefined, err - } - - dynamicJobSpecReference := storage.DataReference(recovered.Closure.DynamicJobSpecUri) - if err := nCtx.DataStore().CopyRaw(ctx, dynamicJobSpecReference, f.GetLoc(), storage.Options{}); err != nil { - return handler.PhaseInfoUndefined, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, - "failed to store dynamic job spec for node. source file [%s] destination file [%s]", dynamicJobSpecReference, f.GetLoc()) - } + state, err := c.RecursiveNodeHandler(ctx, execContext, dag, nl, downstreamNode) + if err != nil { + return interfaces.NodeStatusUndefined, err + } - // transition node phase to 'Running' and dynamic task phase to 'DynamicNodePhaseParentFinalized' - state := nCtx.NodeStateReader().GetDynamicNodeState() - state.Phase = v1alpha1.DynamicNodePhaseParentFinalized - if err := nCtx.NodeStateWriter().PutDynamicNodeState(state); err != nil { - return handler.PhaseInfoUndefined, errors.Wrapf(errors.UnknownError, nCtx.NodeID(), err, "failed to store dynamic node state") + if state.HasFailed() || state.HasTimedOut() { + logger.Debugf(ctx, "Some downstream node has failed. Failed: [%v]. TimedOut: [%v]. Error: [%s]", state.HasFailed(), state.HasTimedOut(), state.Err) + if onFailurePolicy == v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) { + // If the failure policy allows other nodes to continue running, do not exit the loop, + // Keep track of the last failed state in the loop since it'll be the one to return. + // TODO: If multiple nodes fail (which this mode allows), consolidate/summarize failure states in one. + stateOnComplete = state + } else { + return state, nil } - - return handler.PhaseInfoRunning(&handler.ExecutionInfo{}), nil + } else if !state.IsComplete() { + // A Failed/Timedout node is implicitly considered "complete" this means none of the downstream nodes from + // that node will ever be allowed to run. + // This else block, therefore, deals with all other states. IsComplete will return true if and only if this + // node as well as all of its downstream nodes have finished executing with success statuses. Otherwise we + // mark this node's state as not completed to ensure we will visit it again later. 
+ allCompleted = false } - logger.Debugf(ctx, "Node [%+v] phase [%v] is not recoverable", nCtx.NodeExecutionMetadata().GetNodeExecutionID(), recovered.Closure.Phase) - return handler.PhaseInfoUndefined, nil + if state.PartiallyComplete() { + // This implies that one of the downstream nodes has just succeeded and workflow is ready for propagation + // We do not propagate in current cycle to make it possible to store the state between transitions + partialNodeCompletion = true + } } - recoveredData, err := c.recoveryClient.RecoverNodeExecutionData(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) - if err != nil { - st, ok := status.FromError(err) - if !ok || st.Code() != codes.NotFound { - logger.Warnf(ctx, "Failed to attemptRecovery node execution data for [%+v] although back-end indicated node was recoverable with err [%+v]", - nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) - } - return handler.PhaseInfoUndefined, nil + if allCompleted { + logger.Debugf(ctx, "All downstream nodes completed") + return stateOnComplete, nil } - if recoveredData == nil { - logger.Warnf(ctx, "call to attemptRecovery node [%+v] data returned no error but also no data", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) - return handler.PhaseInfoUndefined, nil + + if partialNodeCompletion { + return interfaces.NodeStatusSuccess, nil } - // Copy inputs to this node's expected location - nodeInputs, err := c.recoverInputs(ctx, nCtx, recovered, recoveredData) - if err != nil { - return handler.PhaseInfoUndefined, err + return interfaces.NodeStatusPending, nil +} + +func (c *recursiveNodeExecutor) FinalizeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) error { + nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) + nodePhase := nodeStatus.GetPhase() + + if nodePhase == v1alpha1.NodePhaseNotYetStarted { + logger.Infof(ctx, "Node not yet started, will not finalize") + // Nothing to be aborted + return nil } - // Similarly, copy outputs' reference - so := storage.Options{} - var outputs = &core.LiteralMap{} - if recoveredData.FullOutputs != nil { - outputs = recoveredData.FullOutputs - } else if recovered.Closure.GetOutputData() != nil { - outputs = recovered.Closure.GetOutputData() - } else if len(recovered.Closure.GetOutputUri()) > 0 { - if err := c.store.ReadProtobuf(ctx, storage.DataReference(recovered.Closure.GetOutputUri()), outputs); err != nil { - return handler.PhaseInfoUndefined, errors.Wrapf(errors.InputsNotFoundError, nCtx.NodeID(), err, "failed to read output data [%v].", recovered.Closure.GetOutputUri()) + if canHandleNode(nodePhase) { + ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) + + // Now depending on the node type decide + h, err := c.nodeHandlerFactory.GetHandler(currentNode.GetKind()) + if err != nil { + return err + } + + nCtx, err := c.nCtxBuilder.BuildNodeExecutionContext(ctx, execContext, nl, currentNode.GetID()) + if err != nil { + return err + } + // Abort this node + err = c.nodeExecutor.Finalize(ctx, h, nCtx) + if err != nil { + return err } } else { - logger.Debugf(ctx, "No outputs found for recovered node [%+v]", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) - } + // Abort downstream nodes + downstreamNodes, err := dag.FromNode(currentNode.GetID()) + if err != nil { + logger.Debugf(ctx, "Error when retrieving downstream nodes. 
Error [%v]", err) + return nil + } - outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) - oi := &handler.OutputInfo{ - OutputURI: outputFile, - } + errs := make([]error, 0, len(downstreamNodes)) + for _, d := range downstreamNodes { + downstreamNode, ok := nl.GetNode(d) + if !ok { + return errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", d) + } - deckFile := storage.DataReference(recovered.Closure.GetDeckUri()) - if len(deckFile) > 0 { - metadata, err := nCtx.DataStore().Head(ctx, deckFile) - if err != nil { - logger.Errorf(ctx, "Failed to check the existence of deck file. Error: %v", err) - return handler.PhaseInfoUndefined, errors.Wrapf(errors.CausedByError, nCtx.NodeID(), err, "Failed to check the existence of deck file.") + if err := c.FinalizeHandler(ctx, execContext, dag, nl, downstreamNode); err != nil { + logger.Infof(ctx, "Failed to abort node [%v]. Error: %v", d, err) + errs = append(errs, err) + } } - if metadata.Exists() { - oi.DeckURI = &deckFile + if len(errs) > 0 { + return errors.ErrorCollection{Errors: errs} } - } - if err := c.store.WriteProtobuf(ctx, outputFile, so, outputs); err != nil { - logger.Errorf(ctx, "Failed to write protobuf (metadata). Error [%v]", err) - return handler.PhaseInfoUndefined, errors.Wrapf(errors.CausedByError, nCtx.NodeID(), err, "Failed to store recovered node execution outputs") + return nil } - info := &handler.ExecutionInfo{ - Inputs: nodeInputs, - OutputInfo: oi, + return nil +} + +func (c *recursiveNodeExecutor) AbortHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error { + nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) + nodePhase := nodeStatus.GetPhase() + + if nodePhase == v1alpha1.NodePhaseNotYetStarted { + logger.Infof(ctx, "Node not yet started, will not finalize") + // Nothing to be aborted + return nil } - if recovered.Closure.GetTaskNodeMetadata() != nil { - taskNodeInfo := &handler.TaskNodeInfo{ - TaskNodeMetadata: &event.TaskNodeMetadata{ - CatalogKey: recovered.Closure.GetTaskNodeMetadata().CatalogKey, - CacheStatus: recovered.Closure.GetTaskNodeMetadata().CacheStatus, - }, + if canHandleNode(nodePhase) { + ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) + + // Now depending on the node type decide + h, err := c.nodeHandlerFactory.GetHandler(currentNode.GetKind()) + if err != nil { + return err } - if recoveredData.DynamicWorkflow != nil { - taskNodeInfo.TaskNodeMetadata.DynamicWorkflow = &event.DynamicWorkflowNodeMetadata{ - Id: recoveredData.DynamicWorkflow.Id, - CompiledWorkflow: recoveredData.DynamicWorkflow.CompiledWorkflow, - } + + nCtx, err := c.nCtxBuilder.BuildNodeExecutionContext(ctx, execContext, nl, currentNode.GetID()) + if err != nil { + return err } - info.TaskNodeInfo = taskNodeInfo - } else if recovered.Closure.GetWorkflowNodeMetadata() != nil { - logger.Warnf(ctx, "Attempted to recover node") - info.WorkflowNodeInfo = &handler.WorkflowNodeInfo{ - LaunchedWorkflowID: recovered.Closure.GetWorkflowNodeMetadata().ExecutionId, + // Abort this node + err = c.nodeExecutor.Abort(ctx, h, nCtx, reason) + if err != nil { + return err + } + } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { + // Abort downstream nodes + downstreamNodes, err := dag.FromNode(currentNode.GetID()) + if err != nil { + 
logger.Debugf(ctx, "Error when retrieving downstream nodes. Error [%v]", err) + return nil } - } - return handler.PhaseInfoRecovered(info), nil -} - -// In this method we check if the queue is ready to be processed and if so, we prime it in Admin as queued -// Before we start the node execution, we need to transition this Node status to Queued. -// This is because a node execution has to exist before task/wf executions can start. -func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext) ( - handler.PhaseInfo, error) { - logger.Debugf(ctx, "Node not yet started") - // Query the nodes information to figure out if it can be executed. - predicatePhase, err := CanExecute(ctx, dag, nCtx.ContextualNodeLookup(), nCtx.Node()) - if err != nil { - logger.Debugf(ctx, "Node failed in CanExecute. Error [%s]", err) - return handler.PhaseInfoUndefined, err - } - if predicatePhase == PredicatePhaseReady { - // TODO: Performance problem, we maybe in a retry loop and do not need to resolve the inputs again. - // For now we will do this - node := nCtx.Node() - var nodeInputs *core.LiteralMap - if !node.IsStartNode() { - if nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier != nil { - phaseInfo, err := c.attemptRecovery(ctx, nCtx) - if err != nil || phaseInfo.GetPhase() != handler.EPhaseUndefined { - return phaseInfo, err - } - } - nodeStatus := nCtx.NodeStatus() - dataDir := nodeStatus.GetDataDir() - t := c.metrics.NodeInputGatherLatency.Start(ctx) - defer t.Stop() - // Can execute - var err error - nodeInputs, err = Resolve(ctx, c.outputResolver, nCtx.ContextualNodeLookup(), node.GetID(), node.GetInputBindings()) - // TODO we need to handle retryable, network errors here!! - if err != nil { - c.metrics.ResolutionFailure.Inc(ctx) - logger.Warningf(ctx, "Failed to resolve inputs for Node. Error [%v]", err) - return handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "BindingResolutionFailure", err.Error(), nil), nil + errs := make([]error, 0, len(downstreamNodes)) + for _, d := range downstreamNodes { + downstreamNode, ok := nl.GetNode(d) + if !ok { + return errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", d) } - if nodeInputs != nil { - inputsFile := v1alpha1.GetInputsFile(dataDir) - if err := c.store.WriteProtobuf(ctx, inputsFile, storage.Options{}, nodeInputs); err != nil { - c.metrics.InputsWriteFailure.Inc(ctx) - logger.Errorf(ctx, "Failed to store inputs for Node. Error [%v]. InputsFile [%s]", err, inputsFile) - return handler.PhaseInfoUndefined, errors.Wrapf( - errors.StorageError, node.GetID(), err, "Failed to store inputs for Node. InputsFile [%s]", inputsFile) - } + if err := c.AbortHandler(ctx, execContext, dag, nl, downstreamNode, reason); err != nil { + logger.Infof(ctx, "Failed to abort node [%v]. Error: %v", d, err) + errs = append(errs, err) } + } - logger.Debugf(ctx, "Node Data Directory [%s].", nodeStatus.GetDataDir()) + if len(errs) > 0 { + return errors.ErrorCollection{Errors: errs} } - return handler.PhaseInfoQueued("node queued", nodeInputs), nil + return nil + } else { + ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) + logger.Warnf(ctx, "Trying to abort a node in state [%s]", nodeStatus.GetPhase().String()) } - // Now that we have resolved the inputs, we can record as a transition latency. This is because we have completed - // all the overhead that we have to compute. 
Any failures after this will incur this penalty, but it could be due - // to various external reasons - like queuing, overuse of quota, plugin overhead etc. - logger.Debugf(ctx, "preExecute completed in phase [%s]", predicatePhase.String()) - if predicatePhase == PredicatePhaseSkip { - return handler.PhaseInfoSkip(nil, "Node Skipped as parent node was skipped"), nil + return nil +} + +func (c *recursiveNodeExecutor) Initialize(ctx context.Context) error { + logger.Infof(ctx, "Initializing Core Node Executor") + s := c.newSetupContext(ctx) + return c.nodeHandlerFactory.Setup(ctx, s) +} + +// TODO @hamersaw docs +func (c *recursiveNodeExecutor) GetNodeExecutionContextBuilder() interfaces.NodeExecutionContextBuilder { + return c.nCtxBuilder +} + +func (c *recursiveNodeExecutor) WithNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder) interfaces.Node { + return &recursiveNodeExecutor{ + nodeExecutor: c.nodeExecutor, + nCtxBuilder: nCtxBuilder, + // TODO @hamersaw fill out + enqueueWorkflow: c.enqueueWorkflow, + nodeHandlerFactory: c.nodeHandlerFactory, + store: c.store, + metrics: c.metrics, } +} - return handler.PhaseInfoNotReady("predecessor node not yet complete"), nil +// TODO @hamersaw nodeExecutor goes here and write docs +type nodeExecutor struct { + clusterID string + defaultActiveDeadline time.Duration + defaultDataSandbox storage.DataReference + defaultExecutionDeadline time.Duration + enqueueWorkflow v1alpha1.EnqueueWorkflow + eventConfig *config.EventConfig + interruptibleFailureThreshold uint32 + maxDatasetSizeBytes int64 + maxNodeRetriesForSystemFailures uint32 + metrics *nodeMetrics + nodeRecorder events.NodeEventRecorder + outputResolver OutputResolver + recoveryClient recovery.Client + shardSelector ioutils.ShardSelector + store *storage.DataStore + taskRecorder events.TaskEventRecorder } -func isTimeoutExpired(queuedAt *metav1.Time, timeout time.Duration) bool { - if !queuedAt.IsZero() && timeout != 0 { - deadline := queuedAt.Add(timeout) - if deadline.Before(time.Now()) { - return true +func (c *nodeExecutor) RecordTransitionLatency(ctx context.Context, dag executors.DAGStructure, nl executors.NodeLookup, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) { + if nodeStatus.GetPhase() == v1alpha1.NodePhaseNotYetStarted || nodeStatus.GetPhase() == v1alpha1.NodePhaseQueued { + // Log transition latency (The most recently finished parent node endAt time to this node's queuedAt time -now-) + t, err := GetParentNodeMaxEndTime(ctx, dag, nl, node) + if err != nil { + logger.Warnf(ctx, "Failed to record transition latency for node. 
Error: %s", err.Error()) + return } + if !t.IsZero() { + c.metrics.TransitionLatency.Observe(ctx, t.Time, time.Now()) + } + } else if nodeStatus.GetPhase() == v1alpha1.NodePhaseRetryableFailure && nodeStatus.GetLastUpdatedAt() != nil { + c.metrics.TransitionLatency.Observe(ctx, nodeStatus.GetLastUpdatedAt().Time, time.Now()) } - return false } -func (c *nodeExecutor) isEligibleForRetry(nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, err *core.ExecutionError) (currentAttempt, maxAttempts uint32, isEligible bool) { - if err.Kind == core.ExecutionError_SYSTEM { - currentAttempt = nodeStatus.GetSystemFailures() - maxAttempts = c.maxNodeRetriesForSystemFailures - isEligible = currentAttempt < c.maxNodeRetriesForSystemFailures - return +func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *event.NodeExecutionEvent) error { + if nodeEvent == nil { + return fmt.Errorf("event recording attempt of Nil Node execution event") } - currentAttempt = (nodeStatus.GetAttempts() + 1) - nodeStatus.GetSystemFailures() - if nCtx.Node().GetRetryStrategy() != nil && nCtx.Node().GetRetryStrategy().MinAttempts != nil { - maxAttempts = uint32(*nCtx.Node().GetRetryStrategy().MinAttempts) + if nodeEvent.Id == nil { + return fmt.Errorf("event recording attempt of with nil node Event ID") } - isEligible = currentAttempt < maxAttempts - return -} - -func (c *nodeExecutor) execute(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.PhaseInfo, error) { - logger.Debugf(ctx, "Executing node") - defer logger.Debugf(ctx, "Node execution round complete") - t, err := h.Handle(ctx, nCtx) + logger.Infof(ctx, "Recording NodeEvent [%s] phase[%s]", nodeEvent.GetId().String(), nodeEvent.Phase.String()) + err := c.nodeRecorder.RecordNodeEvent(ctx, nodeEvent, c.eventConfig) if err != nil { - return handler.PhaseInfoUndefined, err + if nodeEvent.GetId().NodeId == v1alpha1.EndNodeID { + return nil + } + + if eventsErr.IsAlreadyExists(err) { + logger.Infof(ctx, "Node event phase: %s, nodeId %s already exist", + nodeEvent.Phase.String(), nodeEvent.GetId().NodeId) + return nil + } else if eventsErr.IsEventAlreadyInTerminalStateError(err) { + if IsTerminalNodePhase(nodeEvent.Phase) { + // Event was trying to record a different terminal phase for an already terminal event. ignoring. + logger.Infof(ctx, "Node event phase: %s, nodeId %s already in terminal phase. 
err: %s", + nodeEvent.Phase.String(), nodeEvent.GetId().NodeId, err.Error()) + return nil + } + logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) + return errors.Wrapf(errors.IllegalStateError, nodeEvent.Id.NodeId, err, "phase mis-match mismatch between propeller and control plane; Trying to record Node p: %s", nodeEvent.Phase) + } } + return err +} - phase := t.Info() - // check for timeout for non-terminal phases - if !phase.GetPhase().IsTerminal() { - activeDeadline := c.defaultActiveDeadline - if nCtx.Node().GetActiveDeadline() != nil && *nCtx.Node().GetActiveDeadline() > 0 { - activeDeadline = *nCtx.Node().GetActiveDeadline() +func (c *nodeExecutor) recoverInputs(ctx context.Context, nCtx interfaces.NodeExecutionContext, + recovered *admin.NodeExecution, recoveredData *admin.NodeExecutionGetDataResponse) (*core.LiteralMap, error) { + + nodeInputs := recoveredData.FullInputs + if nodeInputs != nil { + if err := c.store.WriteProtobuf(ctx, nCtx.InputReader().GetInputPath(), storage.Options{}, nodeInputs); err != nil { + c.metrics.InputsWriteFailure.Inc(ctx) + logger.Errorf(ctx, "Failed to move recovered inputs for Node. Error [%v]. InputsFile [%s]", err, nCtx.InputReader().GetInputPath()) + return nil, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, "Failed to store inputs for Node. InputsFile [%s]", nCtx.InputReader().GetInputPath()) } - if isTimeoutExpired(nodeStatus.GetQueuedAt(), activeDeadline) { - logger.Infof(ctx, "Node has timed out; timeout configured: %v", activeDeadline) - return handler.PhaseInfoTimedOut(nil, fmt.Sprintf("task active timeout [%s] expired", activeDeadline.String())), nil + } else if len(recovered.InputUri) > 0 { + // If the inputs are too large they won't be returned inline in the RecoverData call. We must fetch them before copying them. + nodeInputs = &core.LiteralMap{} + if recoveredData.FullInputs == nil { + if err := c.store.ReadProtobuf(ctx, storage.DataReference(recovered.InputUri), nodeInputs); err != nil { + return nil, errors.Wrapf(errors.InputsNotFoundError, nCtx.NodeID(), err, "failed to read data from dataDir [%v].", recovered.InputUri) + } } - // Execution timeout is a retry-able error - executionDeadline := c.defaultExecutionDeadline - if nCtx.Node().GetExecutionDeadline() != nil && *nCtx.Node().GetExecutionDeadline() > 0 { - executionDeadline = *nCtx.Node().GetExecutionDeadline() + if err := c.store.WriteProtobuf(ctx, nCtx.InputReader().GetInputPath(), storage.Options{}, nodeInputs); err != nil { + c.metrics.InputsWriteFailure.Inc(ctx) + logger.Errorf(ctx, "Failed to move recovered inputs for Node. Error [%v]. InputsFile [%s]", err, nCtx.InputReader().GetInputPath()) + return nil, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, "Failed to store inputs for Node. 
InputsFile [%s]", nCtx.InputReader().GetInputPath()) } - if isTimeoutExpired(nodeStatus.GetLastAttemptStartedAt(), executionDeadline) { - logger.Infof(ctx, "Current execution for the node timed out; timeout configured: %v", executionDeadline) - executionErr := &core.ExecutionError{Code: "TimeoutExpired", Message: fmt.Sprintf("task execution timeout [%s] expired", executionDeadline.String()), Kind: core.ExecutionError_USER} - phase = handler.PhaseInfoRetryableFailureErr(executionErr, nil) + } + + return nodeInputs, nil +} + +func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.PhaseInfo, error) { + fullyQualifiedNodeID := nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId + if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { + // compute fully qualified node id (prefixed with parent id and retry attempt) to ensure uniqueness + var err error + fullyQualifiedNodeID, err = common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId) + if err != nil { + return handler.PhaseInfoUndefined, err } } - if phase.GetPhase() == handler.EPhaseRetryableFailure { - currentAttempt, maxAttempts, isEligible := c.isEligibleForRetry(nCtx, nodeStatus, phase.GetErr()) - if !isEligible { - return handler.PhaseInfoFailure( - core.ExecutionError_USER, - fmt.Sprintf("RetriesExhausted|%s", phase.GetErr().Code), - fmt.Sprintf("[%d/%d] currentAttempt done. Last Error: %s::%s", currentAttempt, maxAttempts, phase.GetErr().Kind.String(), phase.GetErr().Message), - phase.GetInfo(), - ), nil + recovered, err := c.recoveryClient.RecoverNodeExecution(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) + if err != nil { + st, ok := status.FromError(err) + if !ok || st.Code() != codes.NotFound { + logger.Warnf(ctx, "Failed to recover node [%+v] with err [%+v]", nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) + } + // The node is not recoverable when it's not found in the parent execution + return handler.PhaseInfoUndefined, nil + } + if recovered == nil { + logger.Warnf(ctx, "call to recover node [%+v] returned no error but also no node", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + return handler.PhaseInfoUndefined, nil + } + if recovered.Closure == nil { + logger.Warnf(ctx, "Fetched node execution [%+v] data but was missing closure. Will not attempt to recover", + nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + return handler.PhaseInfoUndefined, nil + } + // A recoverable node execution should always be in a terminal phase + switch recovered.Closure.Phase { + case core.NodeExecution_SKIPPED: + return handler.PhaseInfoSkip(nil, "node execution recovery indicated original node was skipped"), nil + case core.NodeExecution_SUCCEEDED: + fallthrough + case core.NodeExecution_RECOVERED: + logger.Debugf(ctx, "Node [%+v] can be recovered. Proceeding to copy inputs and outputs", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + default: + // The node execution may be partially recoverable through intra task checkpointing. Save the checkpoint + // uri in the task node state to pass to the task handler later on. 
+ if metadata, ok := recovered.Closure.TargetMetadata.(*admin.NodeExecutionClosure_TaskNodeMetadata); ok { + state := nCtx.NodeStateReader().GetTaskNodeState() + state.PreviousNodeExecutionCheckpointURI = storage.DataReference(metadata.TaskNodeMetadata.CheckpointUri) + err = nCtx.NodeStateWriter().PutTaskNodeState(state) + if err != nil { + logger.Warn(ctx, "failed to save recovered checkpoint uri for [%+v]: [%+v]", + nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) + } } - // Retrying to clearing all status - nCtx.NodeStateWriter().ClearNodeStatus() - } - - return phase, nil -} + // if this node is a dynamic task we attempt to recover the compiled workflow from instances where the parent + // task succeeded but the dynamic task did not complete. this is important to ensure correctness since node ids + // within the compiled closure may not be generated deterministically. + if recovered.Metadata != nil && recovered.Metadata.IsDynamic && len(recovered.Closure.DynamicJobSpecUri) > 0 { + // recover node inputs + recoveredData, err := c.recoveryClient.RecoverNodeExecutionData(ctx, + nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) + if err != nil || recoveredData == nil { + return handler.PhaseInfoUndefined, nil + } -func (c *nodeExecutor) abort(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext, reason string) error { - logger.Debugf(ctx, "Calling aborting & finalize") - if err := h.Abort(ctx, nCtx, reason); err != nil { - finalizeErr := h.Finalize(ctx, nCtx) - if finalizeErr != nil { - return errors.ErrorCollection{Errors: []error{err, finalizeErr}} - } - return err - } + if _, err := c.recoverInputs(ctx, nCtx, recovered, recoveredData); err != nil { + return handler.PhaseInfoUndefined, err + } - return h.Finalize(ctx, nCtx) -} + // copy previous DynamicJobSpec file + f, err := task.NewRemoteFutureFileReader(ctx, nCtx.NodeStatus().GetOutputDir(), nCtx.DataStore()) + if err != nil { + return handler.PhaseInfoUndefined, err + } -func (c *nodeExecutor) finalize(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext) error { - return h.Finalize(ctx, nCtx) -} + dynamicJobSpecReference := storage.DataReference(recovered.Closure.DynamicJobSpecUri) + if err := nCtx.DataStore().CopyRaw(ctx, dynamicJobSpecReference, f.GetLoc(), storage.Options{}); err != nil { + return handler.PhaseInfoUndefined, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, + "failed to store dynamic job spec for node. source file [%s] destination file [%s]", dynamicJobSpecReference, f.GetLoc()) + } -func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, _ handler.Node) (interfaces.NodeStatus, error) { - logger.Debugf(ctx, "Node not yet started, running pre-execute") - defer logger.Debugf(ctx, "Node pre-execute completed") - occurredAt := time.Now() - p, err := c.preExecute(ctx, dag, nCtx) - if err != nil { - logger.Errorf(ctx, "failed preExecute for node. 
Error: %s", err.Error()) - return interfaces.NodeStatusUndefined, err - } + // transition node phase to 'Running' and dynamic task phase to 'DynamicNodePhaseParentFinalized' + state := nCtx.NodeStateReader().GetDynamicNodeState() + state.Phase = v1alpha1.DynamicNodePhaseParentFinalized + if err := nCtx.NodeStateWriter().PutDynamicNodeState(state); err != nil { + return handler.PhaseInfoUndefined, errors.Wrapf(errors.UnknownError, nCtx.NodeID(), err, "failed to store dynamic node state") + } - if p.GetPhase() == handler.EPhaseUndefined { - return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") - } + return handler.PhaseInfoRunning(&handler.ExecutionInfo{}), nil + } - if p.GetPhase() == handler.EPhaseNotReady { - return interfaces.NodeStatusPending, nil + logger.Debugf(ctx, "Node [%+v] phase [%v] is not recoverable", nCtx.NodeExecutionMetadata().GetNodeExecutionID(), recovered.Closure.Phase) + return handler.PhaseInfoUndefined, nil } - np, err := ToNodePhase(p.GetPhase()) + recoveredData, err := c.recoveryClient.RecoverNodeExecutionData(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) if err != nil { - return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") - } - - nodeStatus := nCtx.NodeStatus() - if np != nodeStatus.GetPhase() { - // assert np == Queued! - logger.Infof(ctx, "Change in node state detected from [%s] -> [%s]", nodeStatus.GetPhase().String(), np.String()) - p = p.WithOccuredAt(occurredAt) - - nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), - p, nCtx.InputReader().GetInputPath().String(), nodeStatus, nCtx.ExecutionContext().GetEventVersion(), - nCtx.ExecutionContext().GetParentInfo(), nCtx.Node(), c.clusterID, nCtx.NodeStateReader().GetDynamicNodeState().Phase, - c.eventConfig) - if err != nil { - return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") - } - err = c.IdempotentRecordEvent(ctx, nev) - if err != nil { - logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") + st, ok := status.FromError(err) + if !ok || st.Code() != codes.NotFound { + logger.Warnf(ctx, "Failed to attemptRecovery node execution data for [%+v] although back-end indicated node was recoverable with err [%+v]", + nCtx.NodeExecutionMetadata().GetNodeExecutionID(), err) } - UpdateNodeStatus(np, p, nCtx.NodeStateReader(), nodeStatus) - c.RecordTransitionLatency(ctx, dag, nCtx.ContextualNodeLookup(), nCtx.Node(), nodeStatus) + return handler.PhaseInfoUndefined, nil } - - if np == v1alpha1.NodePhaseQueued { - if nCtx.NodeExecutionMetadata().IsInterruptible() { - c.metrics.InterruptibleNodesRunning.Inc(ctx) - } - return interfaces.NodeStatusQueued, nil - } else if np == v1alpha1.NodePhaseSkipped { - return interfaces.NodeStatusSuccess, nil + if recoveredData == nil { + logger.Warnf(ctx, "call to attemptRecovery node [%+v] data returned no error but also no data", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + return handler.PhaseInfoUndefined, nil } - return interfaces.NodeStatusPending, nil -} - -func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx interfaces.NodeExecutionContext, h handler.Node) 
(interfaces.NodeStatus, error) { - nodeStatus := nCtx.NodeStatus() - currentPhase := nodeStatus.GetPhase() - - // case v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning: - logger.Debugf(ctx, "node executing, current phase [%s]", currentPhase) - defer logger.Debugf(ctx, "node execution completed") - - // Since we reset node status inside execute for retryable failure, we use lastAttemptStartTime to carry that information - // across execute which is used to emit metrics - lastAttemptStartTime := nodeStatus.GetLastAttemptStartedAt() - - p, err := c.execute(ctx, h, nCtx, nodeStatus) + // Copy inputs to this node's expected location + nodeInputs, err := c.recoverInputs(ctx, nCtx, recovered, recoveredData) if err != nil { - logger.Errorf(ctx, "failed Execute for node. Error: %s", err.Error()) - return interfaces.NodeStatusUndefined, err + return handler.PhaseInfoUndefined, err } - if p.GetPhase() == handler.EPhaseUndefined { - return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") + // Similarly, copy outputs' reference + so := storage.Options{} + var outputs = &core.LiteralMap{} + if recoveredData.FullOutputs != nil { + outputs = recoveredData.FullOutputs + } else if recovered.Closure.GetOutputData() != nil { + outputs = recovered.Closure.GetOutputData() + } else if len(recovered.Closure.GetOutputUri()) > 0 { + if err := c.store.ReadProtobuf(ctx, storage.DataReference(recovered.Closure.GetOutputUri()), outputs); err != nil { + return handler.PhaseInfoUndefined, errors.Wrapf(errors.InputsNotFoundError, nCtx.NodeID(), err, "failed to read output data [%v].", recovered.Closure.GetOutputUri()) + } + } else { + logger.Debugf(ctx, "No outputs found for recovered node [%+v]", nCtx.NodeExecutionMetadata().GetNodeExecutionID()) } - np, err := ToNodePhase(p.GetPhase()) - if err != nil { - return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") + outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + oi := &handler.OutputInfo{ + OutputURI: outputFile, } - // execErr in phase-info 'p' is only available if node has failed to execute, and the current phase at that time - // will be v1alpha1.NodePhaseRunning - execErr := p.GetErr() - if execErr != nil && (currentPhase == v1alpha1.NodePhaseRunning || currentPhase == v1alpha1.NodePhaseQueued || - currentPhase == v1alpha1.NodePhaseDynamicRunning) { - endTime := time.Now() - startTime := endTime - if lastAttemptStartTime != nil { - startTime = lastAttemptStartTime.Time + deckFile := storage.DataReference(recovered.Closure.GetDeckUri()) + if len(deckFile) > 0 { + metadata, err := nCtx.DataStore().Head(ctx, deckFile) + if err != nil { + logger.Errorf(ctx, "Failed to check the existence of deck file. Error: %v", err) + return handler.PhaseInfoUndefined, errors.Wrapf(errors.CausedByError, nCtx.NodeID(), err, "Failed to check the existence of deck file.") } - if execErr.GetKind() == core.ExecutionError_SYSTEM { - nodeStatus.IncrementSystemFailures() - c.metrics.SystemErrorDuration.Observe(ctx, startTime, endTime) - } else if execErr.GetKind() == core.ExecutionError_USER { - c.metrics.UserErrorDuration.Observe(ctx, startTime, endTime) - } else { - c.metrics.UnknownErrorDuration.Observe(ctx, startTime, endTime) - } - // When a node fails, we fail the workflow. Independent of number of nodes succeeding/failing, whenever a first node fails, - // the entire workflow is failed. 
- if np == v1alpha1.NodePhaseFailing { - if execErr.GetKind() == core.ExecutionError_SYSTEM { - nodeStatus.IncrementSystemFailures() - c.metrics.PermanentSystemErrorDuration.Observe(ctx, startTime, endTime) - } else if execErr.GetKind() == core.ExecutionError_USER { - c.metrics.PermanentUserErrorDuration.Observe(ctx, startTime, endTime) - } else { - c.metrics.PermanentUnknownErrorDuration.Observe(ctx, startTime, endTime) - } + if metadata.Exists() { + oi.DeckURI = &deckFile } } - finalStatus := interfaces.NodeStatusRunning - if np == v1alpha1.NodePhaseFailing && !h.FinalizeRequired() { - logger.Infof(ctx, "Finalize not required, moving node to Failed") - np = v1alpha1.NodePhaseFailed - finalStatus = interfaces.NodeStatusFailed(p.GetErr()) - } - if np == v1alpha1.NodePhaseTimingOut && !h.FinalizeRequired() { - logger.Infof(ctx, "Finalize not required, moving node to TimedOut") - np = v1alpha1.NodePhaseTimedOut - finalStatus = interfaces.NodeStatusTimedOut + if err := c.store.WriteProtobuf(ctx, outputFile, so, outputs); err != nil { + logger.Errorf(ctx, "Failed to write protobuf (metadata). Error [%v]", err) + return handler.PhaseInfoUndefined, errors.Wrapf(errors.CausedByError, nCtx.NodeID(), err, "Failed to store recovered node execution outputs") } - if np == v1alpha1.NodePhaseSucceeding && !h.FinalizeRequired() { - logger.Infof(ctx, "Finalize not required, moving node to Succeeded") - np = v1alpha1.NodePhaseSucceeded - finalStatus = interfaces.NodeStatusSuccess - } - if np == v1alpha1.NodePhaseRecovered { - logger.Infof(ctx, "Finalize not required, moving node to Recovered") - finalStatus = interfaces.NodeStatusRecovered + info := &handler.ExecutionInfo{ + Inputs: nodeInputs, + OutputInfo: oi, } - // If it is retryable failure, we do no want to send any events, as the node is essentially still running - // Similarly if the phase has not changed from the last time, events do not need to be sent - if np != nodeStatus.GetPhase() && np != v1alpha1.NodePhaseRetryableFailure { - // assert np == skipped, succeeding, failing or recovered - logger.Infof(ctx, "Change in node state detected from [%s] -> [%s], (handler phase [%s])", nodeStatus.GetPhase().String(), np.String(), p.GetPhase().String()) - - nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), - p, nCtx.InputReader().GetInputPath().String(), nCtx.NodeStatus(), nCtx.ExecutionContext().GetEventVersion(), - nCtx.ExecutionContext().GetParentInfo(), nCtx.Node(), c.clusterID, nCtx.NodeStateReader().GetDynamicNodeState().Phase, - c.eventConfig) - if err != nil { - return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") - } - - err = c.IdempotentRecordEvent(ctx, nev) - if err != nil { - if eventsErr.IsTooLarge(err) { - // With large enough dynamic task fanouts the reported node event, which contains the compiled - // workflow closure, can exceed the gRPC message size limit. In this case we immediately - // transition the node to failing to abort the workflow. 
- np = v1alpha1.NodePhaseFailing - p = handler.PhaseInfoFailure(core.ExecutionError_USER, "NodeFailed", err.Error(), p.GetInfo()) - - err = c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{ - Id: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), - Phase: core.NodeExecution_FAILED, - OccurredAt: ptypes.TimestampNow(), - OutputResult: &event.NodeExecutionEvent_Error{ - Error: &core.ExecutionError{ - Code: "NodeFailed", - Message: err.Error(), - }, - }, - ReportedAt: ptypes.TimestampNow(), - }) - - if err != nil { - return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") - } - } else { - logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") - } + if recovered.Closure.GetTaskNodeMetadata() != nil { + taskNodeInfo := &handler.TaskNodeInfo{ + TaskNodeMetadata: &event.TaskNodeMetadata{ + CatalogKey: recovered.Closure.GetTaskNodeMetadata().CatalogKey, + CacheStatus: recovered.Closure.GetTaskNodeMetadata().CacheStatus, + }, } - - // We reach here only when transitioning from Queued to Running. In this case, the startedAt is not set. - if np == v1alpha1.NodePhaseRunning { - if nodeStatus.GetQueuedAt() != nil { - c.metrics.QueuingLatency.Observe(ctx, nodeStatus.GetQueuedAt().Time, time.Now()) + if recoveredData.DynamicWorkflow != nil { + taskNodeInfo.TaskNodeMetadata.DynamicWorkflow = &event.DynamicWorkflowNodeMetadata{ + Id: recoveredData.DynamicWorkflow.Id, + CompiledWorkflow: recoveredData.DynamicWorkflow.CompiledWorkflow, } } + info.TaskNodeInfo = taskNodeInfo + } else if recovered.Closure.GetWorkflowNodeMetadata() != nil { + logger.Warnf(ctx, "Attempted to recover node") + info.WorkflowNodeInfo = &handler.WorkflowNodeInfo{ + LaunchedWorkflowID: recovered.Closure.GetWorkflowNodeMetadata().ExecutionId, + } } - - UpdateNodeStatus(np, p, nCtx.NodeStateReader(), nodeStatus) - return finalStatus, nil + return handler.PhaseInfoRecovered(info), nil } -func (c *nodeExecutor) handleRetryableFailure(ctx context.Context, nCtx interfaces.NodeExecutionContext, h handler.Node) (interfaces.NodeStatus, error) { - nodeStatus := nCtx.NodeStatus() - logger.Debugf(ctx, "node failed with retryable failure, aborting and finalizing, message: %s", nodeStatus.GetMessage()) - if err := c.abort(ctx, h, nCtx, nodeStatus.GetMessage()); err != nil { - return interfaces.NodeStatusUndefined, err +// In this method we check if the queue is ready to be processed and if so, we prime it in Admin as queued +// Before we start the node execution, we need to transition this Node status to Queued. +// This is because a node execution has to exist before task/wf executions can start. +func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext) ( + handler.PhaseInfo, error) { + logger.Debugf(ctx, "Node not yet started") + // Query the nodes information to figure out if it can be executed. + predicatePhase, err := CanExecute(ctx, dag, nCtx.ContextualNodeLookup(), nCtx.Node()) + if err != nil { + logger.Debugf(ctx, "Node failed in CanExecute. Error [%s]", err) + return handler.PhaseInfoUndefined, err } - // NOTE: It is important to increment attempts only after abort has been called. Increment attempt mutates the state - // Attempt is used throughout the system to determine the idempotent resource version. 
- nodeStatus.IncrementAttempts() - nodeStatus.UpdatePhase(v1alpha1.NodePhaseRunning, metav1.Now(), "retrying", nil) - // We are going to retry in the next round, so we should clear all current state - nodeStatus.ClearSubNodeStatus() - nodeStatus.ClearTaskStatus() - nodeStatus.ClearWorkflowStatus() - nodeStatus.ClearDynamicNodeStatus() - nodeStatus.ClearGateNodeStatus() - nodeStatus.ClearArrayNodeStatus() - return interfaces.NodeStatusPending, nil -} - -func (c *nodeExecutor) handleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h handler.Node) (interfaces.NodeStatus, error) { - logger.Debugf(ctx, "Handling Node [%s]", nCtx.NodeID()) - defer logger.Debugf(ctx, "Completed node [%s]", nCtx.NodeID()) + if predicatePhase == PredicatePhaseReady { + // TODO: Performance problem, we maybe in a retry loop and do not need to resolve the inputs again. + // For now we will do this + node := nCtx.Node() + var nodeInputs *core.LiteralMap + if !node.IsStartNode() { + if nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier != nil { + phaseInfo, err := c.attemptRecovery(ctx, nCtx) + if err != nil || phaseInfo.GetPhase() != handler.EPhaseUndefined { + return phaseInfo, err + } + } + nodeStatus := nCtx.NodeStatus() + dataDir := nodeStatus.GetDataDir() + t := c.metrics.NodeInputGatherLatency.Start(ctx) + defer t.Stop() + // Can execute + var err error + nodeInputs, err = Resolve(ctx, c.outputResolver, nCtx.ContextualNodeLookup(), node.GetID(), node.GetInputBindings()) + // TODO we need to handle retryable, network errors here!! + if err != nil { + c.metrics.ResolutionFailure.Inc(ctx) + logger.Warningf(ctx, "Failed to resolve inputs for Node. Error [%v]", err) + return handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "BindingResolutionFailure", err.Error(), nil), nil + } - nodeStatus := nCtx.NodeStatus() - currentPhase := nodeStatus.GetPhase() + if nodeInputs != nil { + inputsFile := v1alpha1.GetInputsFile(dataDir) + if err := c.store.WriteProtobuf(ctx, inputsFile, storage.Options{}, nodeInputs); err != nil { + c.metrics.InputsWriteFailure.Inc(ctx) + logger.Errorf(ctx, "Failed to store inputs for Node. Error [%v]. InputsFile [%s]", err, inputsFile) + return handler.PhaseInfoUndefined, errors.Wrapf( + errors.StorageError, node.GetID(), err, "Failed to store inputs for Node. InputsFile [%s]", inputsFile) + } + } - // Optimization! 
- // If it is start node we directly move it to Queued without needing to run preExecute - if currentPhase == v1alpha1.NodePhaseNotYetStarted && !nCtx.Node().IsStartNode() { - p, err := c.handleNotYetStartedNode(ctx, dag, nCtx, h) - if err != nil { - return p, err - } - if p.NodePhase == interfaces.NodePhaseQueued { - logger.Infof(ctx, "Node was queued, parallelism is now [%d]", nCtx.ExecutionContext().IncrementParallelism()) + logger.Debugf(ctx, "Node Data Directory [%s].", nodeStatus.GetDataDir()) } - return p, err - } - if currentPhase == v1alpha1.NodePhaseFailing { - logger.Debugf(ctx, "node failing") - if err := c.abort(ctx, h, nCtx, "node failing"); err != nil { - return interfaces.NodeStatusUndefined, err - } - nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailed, metav1.Now(), nodeStatus.GetMessage(), nodeStatus.GetExecutionError()) - c.metrics.FailureDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) - if nCtx.NodeExecutionMetadata().IsInterruptible() { - c.metrics.InterruptibleNodesTerminated.Inc(ctx) - } - return interfaces.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + return handler.PhaseInfoQueued("node queued", nodeInputs), nil } - if currentPhase == v1alpha1.NodePhaseTimingOut { - logger.Debugf(ctx, "node timing out") - if err := c.abort(ctx, h, nCtx, "node timed out"); err != nil { - return interfaces.NodeStatusUndefined, err - } - - nodeStatus.ClearSubNodeStatus() - nodeStatus.UpdatePhase(v1alpha1.NodePhaseTimedOut, metav1.Now(), nodeStatus.GetMessage(), nodeStatus.GetExecutionError()) - c.metrics.TimedOutFailure.Inc(ctx) - if nCtx.NodeExecutionMetadata().IsInterruptible() { - c.metrics.InterruptibleNodesTerminated.Inc(ctx) - } - return interfaces.NodeStatusTimedOut, nil + // Now that we have resolved the inputs, we can record as a transition latency. This is because we have completed + // all the overhead that we have to compute. Any failures after this will incur this penalty, but it could be due + // to various external reasons - like queuing, overuse of quota, plugin overhead etc. 
+ logger.Debugf(ctx, "preExecute completed in phase [%s]", predicatePhase.String()) + if predicatePhase == PredicatePhaseSkip { + return handler.PhaseInfoSkip(nil, "Node Skipped as parent node was skipped"), nil } - if currentPhase == v1alpha1.NodePhaseSucceeding { - logger.Debugf(ctx, "node succeeding") - if err := c.finalize(ctx, h, nCtx); err != nil { - return interfaces.NodeStatusUndefined, err - } - t := metav1.Now() + return handler.PhaseInfoNotReady("predecessor node not yet complete"), nil +} - started := nodeStatus.GetStartedAt() - if started == nil { - started = &t - } - stopped := nodeStatus.GetStoppedAt() - if stopped == nil { - stopped = &t - } - c.metrics.SuccessDuration.Observe(ctx, started.Time, stopped.Time) - nodeStatus.ClearSubNodeStatus() - nodeStatus.UpdatePhase(v1alpha1.NodePhaseSucceeded, t, "completed successfully", nil) - if nCtx.NodeExecutionMetadata().IsInterruptible() { - c.metrics.InterruptibleNodesTerminated.Inc(ctx) +func isTimeoutExpired(queuedAt *metav1.Time, timeout time.Duration) bool { + if !queuedAt.IsZero() && timeout != 0 { + deadline := queuedAt.Add(timeout) + if deadline.Before(time.Now()) { + return true } - return interfaces.NodeStatusSuccess, nil } + return false +} - if currentPhase == v1alpha1.NodePhaseRetryableFailure { - return c.handleRetryableFailure(ctx, nCtx, h) +func (c *nodeExecutor) isEligibleForRetry(nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, err *core.ExecutionError) (currentAttempt, maxAttempts uint32, isEligible bool) { + if err.Kind == core.ExecutionError_SYSTEM { + currentAttempt = nodeStatus.GetSystemFailures() + maxAttempts = c.maxNodeRetriesForSystemFailures + isEligible = currentAttempt < c.maxNodeRetriesForSystemFailures + return } - if currentPhase == v1alpha1.NodePhaseFailed { - // This should never happen - return interfaces.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + currentAttempt = (nodeStatus.GetAttempts() + 1) - nodeStatus.GetSystemFailures() + if nCtx.Node().GetRetryStrategy() != nil && nCtx.Node().GetRetryStrategy().MinAttempts != nil { + maxAttempts = uint32(*nCtx.Node().GetRetryStrategy().MinAttempts) } - - return c.handleQueuedOrRunningNode(ctx, nCtx, h) + isEligible = currentAttempt < maxAttempts + return } -// The space search for the next node to execute is implemented like a DFS algorithm. handleDownstream visits all the nodes downstream from -// the currentNode. Visit a node is the RecursiveNodeHandler. A visit may be partial, complete or may result in a failure. -func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (interfaces.NodeStatus, error) { - logger.Debugf(ctx, "Handling downstream Nodes") - // This node is success. 
Handle all downstream nodes - downstreamNodes, err := dag.FromNode(currentNode.GetID()) +func (c *nodeExecutor) execute(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.PhaseInfo, error) { + logger.Debugf(ctx, "Executing node") + defer logger.Debugf(ctx, "Node execution round complete") + + t, err := h.Handle(ctx, nCtx) if err != nil { - logger.Debugf(ctx, "Error when retrieving downstream nodes, [%s]", err) - return interfaces.NodeStatusFailed(&core.ExecutionError{ - Code: errors.BadSpecificationError, - Message: fmt.Sprintf("failed to retrieve downstream nodes for [%s]", currentNode.GetID()), - Kind: core.ExecutionError_SYSTEM, - }), nil - } - if len(downstreamNodes) == 0 { - logger.Debugf(ctx, "No downstream nodes found. Complete.") - return interfaces.NodeStatusComplete, nil + return handler.PhaseInfoUndefined, err } - // If any downstream node is failed, fail, all - // Else if all are success then success - // Else if any one is running then Downstream is still running - allCompleted := true - partialNodeCompletion := false - onFailurePolicy := execContext.GetOnFailurePolicy() - stateOnComplete := interfaces.NodeStatusComplete - for _, downstreamNodeName := range downstreamNodes { - downstreamNode, ok := nl.GetNode(downstreamNodeName) - if !ok { - return interfaces.NodeStatusFailed(&core.ExecutionError{ - Code: errors.BadSpecificationError, - Message: fmt.Sprintf("failed to retrieve downstream node [%s] for [%s]", downstreamNodeName, currentNode.GetID()), - Kind: core.ExecutionError_SYSTEM, - }), nil - } - state, err := c.RecursiveNodeHandler(ctx, execContext, dag, nl, downstreamNode) - if err != nil { - return interfaces.NodeStatusUndefined, err + phase := t.Info() + // check for timeout for non-terminal phases + if !phase.GetPhase().IsTerminal() { + activeDeadline := c.defaultActiveDeadline + if nCtx.Node().GetActiveDeadline() != nil && *nCtx.Node().GetActiveDeadline() > 0 { + activeDeadline = *nCtx.Node().GetActiveDeadline() } - - if state.HasFailed() || state.HasTimedOut() { - logger.Debugf(ctx, "Some downstream node has failed. Failed: [%v]. TimedOut: [%v]. Error: [%s]", state.HasFailed(), state.HasTimedOut(), state.Err) - if onFailurePolicy == v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) { - // If the failure policy allows other nodes to continue running, do not exit the loop, - // Keep track of the last failed state in the loop since it'll be the one to return. - // TODO: If multiple nodes fail (which this mode allows), consolidate/summarize failure states in one. - stateOnComplete = state - } else { - return state, nil - } - } else if !state.IsComplete() { - // A Failed/Timedout node is implicitly considered "complete" this means none of the downstream nodes from - // that node will ever be allowed to run. - // This else block, therefore, deals with all other states. IsComplete will return true if and only if this - // node as well as all of its downstream nodes have finished executing with success statuses. Otherwise we - // mark this node's state as not completed to ensure we will visit it again later. 
- allCompleted = false + if isTimeoutExpired(nodeStatus.GetQueuedAt(), activeDeadline) { + logger.Infof(ctx, "Node has timed out; timeout configured: %v", activeDeadline) + return handler.PhaseInfoTimedOut(nil, fmt.Sprintf("task active timeout [%s] expired", activeDeadline.String())), nil } - if state.PartiallyComplete() { - // This implies that one of the downstream nodes has just succeeded and workflow is ready for propagation - // We do not propagate in current cycle to make it possible to store the state between transitions - partialNodeCompletion = true + // Execution timeout is a retry-able error + executionDeadline := c.defaultExecutionDeadline + if nCtx.Node().GetExecutionDeadline() != nil && *nCtx.Node().GetExecutionDeadline() > 0 { + executionDeadline = *nCtx.Node().GetExecutionDeadline() + } + if isTimeoutExpired(nodeStatus.GetLastAttemptStartedAt(), executionDeadline) { + logger.Infof(ctx, "Current execution for the node timed out; timeout configured: %v", executionDeadline) + executionErr := &core.ExecutionError{Code: "TimeoutExpired", Message: fmt.Sprintf("task execution timeout [%s] expired", executionDeadline.String()), Kind: core.ExecutionError_USER} + phase = handler.PhaseInfoRetryableFailureErr(executionErr, nil) } } - if allCompleted { - logger.Debugf(ctx, "All downstream nodes completed") - return stateOnComplete, nil - } + if phase.GetPhase() == handler.EPhaseRetryableFailure { + currentAttempt, maxAttempts, isEligible := c.isEligibleForRetry(nCtx, nodeStatus, phase.GetErr()) + if !isEligible { + return handler.PhaseInfoFailure( + core.ExecutionError_USER, + fmt.Sprintf("RetriesExhausted|%s", phase.GetErr().Code), + fmt.Sprintf("[%d/%d] currentAttempt done. Last Error: %s::%s", currentAttempt, maxAttempts, phase.GetErr().Kind.String(), phase.GetErr().Message), + phase.GetInfo(), + ), nil + } - if partialNodeCompletion { - return interfaces.NodeStatusSuccess, nil + // Retrying to clearing all status + nCtx.NodeStateWriter().ClearNodeStatus() } - return interfaces.NodeStatusPending, nil + return phase, nil } -func (c *nodeExecutor) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (interfaces.NodeStatus, error) { - startNode := dag.StartNode() - ctx = contextutils.WithNodeID(ctx, startNode.GetID()) - if inputs == nil { - logger.Infof(ctx, "No inputs for the workflow. Skipping storing inputs") - return interfaces.NodeStatusComplete, nil +func (c *nodeExecutor) Abort(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext, reason string) error { + logger.Debugf(ctx, "Calling aborting & finalize") + if err := h.Abort(ctx, nCtx, reason); err != nil { + finalizeErr := h.Finalize(ctx, nCtx) + if finalizeErr != nil { + return errors.ErrorCollection{Errors: []error{err, finalizeErr}} + } + return err } - // StartNode is special. It does not have any processing step. 
It just takes the workflow (or subworkflow) inputs and converts to its own outputs - nodeStatus := nl.GetNodeExecutionStatus(ctx, startNode.GetID()) - - if len(nodeStatus.GetDataDir()) == 0 { - return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, startNode.GetID(), "no data-dir set, cannot store inputs") + if err := h.Finalize(ctx, nCtx); err != nil { + return err } - outputFile := v1alpha1.GetOutputsFile(nodeStatus.GetOutputDir()) - so := storage.Options{} - if err := c.store.WriteProtobuf(ctx, outputFile, so, inputs); err != nil { - logger.Errorf(ctx, "Failed to write protobuf (metadata). Error [%v]", err) - return interfaces.NodeStatusUndefined, errors.Wrapf(errors.CausedByError, startNode.GetID(), err, "Failed to store workflow inputs (as start node)") + nodeExecutionID := &core.NodeExecutionIdentifier{ + ExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().ExecutionId, + NodeId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId, + } + if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { + currentNodeUniqueID, err := common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nodeExecutionID.NodeId) + if err != nil { + return err + } + nodeExecutionID.NodeId = currentNodeUniqueID } - return interfaces.NodeStatusComplete, nil + err := c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{ + Id: nodeExecutionID, + Phase: core.NodeExecution_ABORTED, + OccurredAt: ptypes.TimestampNow(), + OutputResult: &event.NodeExecutionEvent_Error{ + Error: &core.ExecutionError{ + Code: "NodeAborted", + Message: reason, + }, + }, + ProducerId: c.clusterID, + ReportedAt: ptypes.TimestampNow(), + }) + if err != nil && !eventsErr.IsNotFound(err) && !eventsErr.IsEventIncompatibleClusterError(err) { + if errors2.IsCausedBy(err, errors.IllegalStateError) { + logger.Debugf(ctx, "Failed to record abort event due to illegal state transition. Ignoring the error. Error: %v", err) + } else { + logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) + return errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") + } + } + return nil } -func canHandleNode(phase v1alpha1.NodePhase) bool { - return phase == v1alpha1.NodePhaseNotYetStarted || - phase == v1alpha1.NodePhaseQueued || - phase == v1alpha1.NodePhaseRunning || - phase == v1alpha1.NodePhaseFailing || - phase == v1alpha1.NodePhaseTimingOut || - phase == v1alpha1.NodePhaseRetryableFailure || - phase == v1alpha1.NodePhaseSucceeding || - phase == v1alpha1.NodePhaseDynamicRunning +func (c *nodeExecutor) Finalize(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext) error { + return h.Finalize(ctx, nCtx) } -// IsMaxParallelismAchieved checks if we have already achieved max parallelism. It returns true, if the desired max parallelism -// value is achieved, false otherwise -// MaxParallelism is defined as the maximum number of TaskNodes and LaunchPlans (together) that can be executed concurrently -// by one workflow execution. A setting of `0` indicates that it is disabled. 
-func IsMaxParallelismAchieved(ctx context.Context, currentNode v1alpha1.ExecutableNode, currentPhase v1alpha1.NodePhase, - execContext executors.ExecutionContext) bool { - maxParallelism := execContext.GetExecutionConfig().MaxParallelism - if maxParallelism == 0 { - logger.Debugf(ctx, "Parallelism control disabled") - return false +func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, _ handler.Node) (interfaces.NodeStatus, error) { + logger.Debugf(ctx, "Node not yet started, running pre-execute") + defer logger.Debugf(ctx, "Node pre-execute completed") + occurredAt := time.Now() + p, err := c.preExecute(ctx, dag, nCtx) + if err != nil { + logger.Errorf(ctx, "failed preExecute for node. Error: %s", err.Error()) + return interfaces.NodeStatusUndefined, err } - if currentNode.GetKind() == v1alpha1.NodeKindTask || - (currentNode.GetKind() == v1alpha1.NodeKindWorkflow && currentNode.GetWorkflowNode() != nil && currentNode.GetWorkflowNode().GetLaunchPlanRefID() != nil) { - // If we are queued, let us see if we can proceed within the node parallelism bounds - if execContext.CurrentParallelism() >= maxParallelism { - logger.Infof(ctx, "Maximum Parallelism for task/launch-plan nodes achieved [%d] >= Max [%d], Round will be short-circuited.", execContext.CurrentParallelism(), maxParallelism) - return true - } - // We know that Propeller goes through each workflow in a single thread, thus every node is really processed - // sequentially. So, we can continue - now that we know we are under the parallelism limits and increment the - // parallelism if the node, enters a running state - logger.Debugf(ctx, "Parallelism criteria not met, Current [%d], Max [%d]", execContext.CurrentParallelism(), maxParallelism) - } else { - logger.Debugf(ctx, "NodeKind: %s in status [%s]. Parallelism control is not applicable. Current Parallelism [%d]", - currentNode.GetKind().String(), currentPhase.String(), execContext.CurrentParallelism()) + if p.GetPhase() == handler.EPhaseUndefined { + return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") } - return false -} + if p.GetPhase() == handler.EPhaseNotReady { + return interfaces.NodeStatusPending, nil + } -// RecursiveNodeHandler This is the entrypoint of executing a node in a workflow. A workflow consists of nodes, that are -// nested within other nodes. The system follows an actor model, where the parent nodes control the execution of nested nodes -// The recursive node-handler uses a modified depth-first type of algorithm to execute non-blocked nodes. 
-func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, - dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) ( - interfaces.NodeStatus, error) { - - return c.RecursiveNodeHandlerWithNodeContextModifier(ctx, execContext, dag, nl, currentNode, func (nCtx interfaces.NodeExecutionContext) interfaces.NodeExecutionContext { - return nCtx - }) -} + np, err := ToNodePhase(p.GetPhase()) + if err != nil { + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") + } -// TODO @hamersaw -func (c *nodeExecutor) RecursiveNodeHandlerWithNodeContextModifier(ctx context.Context, execContext executors.ExecutionContext, - dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, - nCtxModifier func (interfaces.NodeExecutionContext) interfaces.NodeExecutionContext) ( - interfaces.NodeStatus, error) { + nodeStatus := nCtx.NodeStatus() + if np != nodeStatus.GetPhase() { + // assert np == Queued! + logger.Infof(ctx, "Change in node state detected from [%s] -> [%s]", nodeStatus.GetPhase().String(), np.String()) + p = p.WithOccuredAt(occurredAt) - currentNodeCtx := contextutils.WithNodeID(ctx, currentNode.GetID()) - nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) - nodePhase := nodeStatus.GetPhase() + nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), + p, nCtx.InputReader().GetInputPath().String(), nodeStatus, nCtx.ExecutionContext().GetEventVersion(), + nCtx.ExecutionContext().GetParentInfo(), nCtx.Node(), c.clusterID, nCtx.NodeStateReader().GetDynamicNodeState().Phase, + c.eventConfig) + if err != nil { + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") + } + err = c.IdempotentRecordEvent(ctx, nev) + if err != nil { + logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") + } + UpdateNodeStatus(np, p, nCtx.NodeStateReader(), nodeStatus) + c.RecordTransitionLatency(ctx, dag, nCtx.ContextualNodeLookup(), nCtx.Node(), nodeStatus) + } - if canHandleNode(nodePhase) { - // TODO Follow up Pull Request, - // 1. Rename this method to DAGTraversalHandleNode (accepts a DAGStructure along-with) the remaining arguments - // 2. Create a new method called HandleNode (part of the interface) (remaining all args as the previous method, but no DAGStructure - // 3. Additional both methods will receive inputs reader - // 4. The Downstream nodes handler will Resolve the Inputs - // 5. the method will delegate all other node handling to HandleNode. - // 6. Thus we can get rid of SetInputs for StartNode as well - logger.Debugf(currentNodeCtx, "Handling node Status [%v]", nodeStatus.GetPhase().String()) + if np == v1alpha1.NodePhaseQueued { + if nCtx.NodeExecutionMetadata().IsInterruptible() { + c.metrics.InterruptibleNodesRunning.Inc(ctx) + } + return interfaces.NodeStatusQueued, nil + } else if np == v1alpha1.NodePhaseSkipped { + return interfaces.NodeStatusSuccess, nil + } - t := c.metrics.NodeExecutionTime.Start(ctx) - defer t.Stop() + return interfaces.NodeStatusPending, nil +} - // This is an optimization to avoid creating the nodeContext object in case the node has already been looked at. 
- // If the overhead was zero, we would just do the isDirtyCheck after the nodeContext is created - if nodeStatus.IsDirty() { - return interfaces.NodeStatusRunning, nil - } +func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx interfaces.NodeExecutionContext, h handler.Node) (interfaces.NodeStatus, error) { + nodeStatus := nCtx.NodeStatus() + currentPhase := nodeStatus.GetPhase() - if IsMaxParallelismAchieved(ctx, currentNode, nodePhase, execContext) { - return interfaces.NodeStatusRunning, nil - } + // case v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning: + logger.Debugf(ctx, "node executing, current phase [%s]", currentPhase) + defer logger.Debugf(ctx, "node execution completed") - nCtx, err := c.NewNodeExecutionContext(ctx, execContext, nl, currentNode.GetID()) - if err != nil { - // NodeExecution creation failure is a permanent fail / system error. - // Should a system failure always return an err? - return interfaces.NodeStatusFailed(&core.ExecutionError{ - Code: "InternalError", - Message: err.Error(), - Kind: core.ExecutionError_SYSTEM, - }), nil - } + // Since we reset node status inside execute for retryable failure, we use lastAttemptStartTime to carry that information + // across execute which is used to emit metrics + lastAttemptStartTime := nodeStatus.GetLastAttemptStartedAt() - nCtx = nCtxModifier(nCtx) + p, err := c.execute(ctx, h, nCtx, nodeStatus) + if err != nil { + logger.Errorf(ctx, "failed Execute for node. Error: %s", err.Error()) + return interfaces.NodeStatusUndefined, err + } - // Now depending on the node type decide - h, err := c.nodeHandlerFactory.GetHandler(nCtx.Node().GetKind()) - if err != nil { - return interfaces.NodeStatusUndefined, err - } + if p.GetPhase() == handler.EPhaseUndefined { + return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") + } - return c.handleNode(currentNodeCtx, dag, nCtx, h) + np, err := ToNodePhase(p.GetPhase()) + if err != nil { + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") + } - // TODO we can optimize skip state handling by iterating down the graph and marking all as skipped - // Currently we treat either Skip or Success the same way. In this approach only one node will be skipped - // at a time. 
As we iterate down, further nodes will be skipped - } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { - logger.Debugf(currentNodeCtx, "Node has [%v], traversing downstream.", nodePhase) - return c.handleDownstream(ctx, execContext, dag, nl, currentNode) - } else if nodePhase == v1alpha1.NodePhaseFailed { - logger.Debugf(currentNodeCtx, "Node has failed, traversing downstream.") - _, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode) - if err != nil { - return interfaces.NodeStatusUndefined, err + // execErr in phase-info 'p' is only available if node has failed to execute, and the current phase at that time + // will be v1alpha1.NodePhaseRunning + execErr := p.GetErr() + if execErr != nil && (currentPhase == v1alpha1.NodePhaseRunning || currentPhase == v1alpha1.NodePhaseQueued || + currentPhase == v1alpha1.NodePhaseDynamicRunning) { + endTime := time.Now() + startTime := endTime + if lastAttemptStartTime != nil { + startTime = lastAttemptStartTime.Time } - return interfaces.NodeStatusFailed(nodeStatus.GetExecutionError()), nil - } else if nodePhase == v1alpha1.NodePhaseTimedOut { - logger.Debugf(currentNodeCtx, "Node has timed out, traversing downstream.") - _, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode) - if err != nil { - return interfaces.NodeStatusUndefined, err + if execErr.GetKind() == core.ExecutionError_SYSTEM { + nodeStatus.IncrementSystemFailures() + c.metrics.SystemErrorDuration.Observe(ctx, startTime, endTime) + } else if execErr.GetKind() == core.ExecutionError_USER { + c.metrics.UserErrorDuration.Observe(ctx, startTime, endTime) + } else { + c.metrics.UnknownErrorDuration.Observe(ctx, startTime, endTime) } - - return interfaces.NodeStatusTimedOut, nil + // When a node fails, we fail the workflow. Independent of number of nodes succeeding/failing, whenever a first node fails, + // the entire workflow is failed. + if np == v1alpha1.NodePhaseFailing { + if execErr.GetKind() == core.ExecutionError_SYSTEM { + nodeStatus.IncrementSystemFailures() + c.metrics.PermanentSystemErrorDuration.Observe(ctx, startTime, endTime) + } else if execErr.GetKind() == core.ExecutionError_USER { + c.metrics.PermanentUserErrorDuration.Observe(ctx, startTime, endTime) + } else { + c.metrics.PermanentUnknownErrorDuration.Observe(ctx, startTime, endTime) + } + } + } + finalStatus := interfaces.NodeStatusRunning + if np == v1alpha1.NodePhaseFailing && !h.FinalizeRequired() { + logger.Infof(ctx, "Finalize not required, moving node to Failed") + np = v1alpha1.NodePhaseFailed + finalStatus = interfaces.NodeStatusFailed(p.GetErr()) } - return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, currentNode.GetID(), - "Should never reach here. 
Current Phase: %v", nodePhase) -} - -func (c *nodeExecutor) FinalizeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) error { - nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) - nodePhase := nodeStatus.GetPhase() + if np == v1alpha1.NodePhaseTimingOut && !h.FinalizeRequired() { + logger.Infof(ctx, "Finalize not required, moving node to TimedOut") + np = v1alpha1.NodePhaseTimedOut + finalStatus = interfaces.NodeStatusTimedOut + } - if nodePhase == v1alpha1.NodePhaseNotYetStarted { - logger.Infof(ctx, "Node not yet started, will not finalize") - // Nothing to be aborted - return nil + if np == v1alpha1.NodePhaseSucceeding && !h.FinalizeRequired() { + logger.Infof(ctx, "Finalize not required, moving node to Succeeded") + np = v1alpha1.NodePhaseSucceeded + finalStatus = interfaces.NodeStatusSuccess + } + if np == v1alpha1.NodePhaseRecovered { + logger.Infof(ctx, "Finalize not required, moving node to Recovered") + finalStatus = interfaces.NodeStatusRecovered } - if canHandleNode(nodePhase) { - ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) + // If it is retryable failure, we do no want to send any events, as the node is essentially still running + // Similarly if the phase has not changed from the last time, events do not need to be sent + if np != nodeStatus.GetPhase() && np != v1alpha1.NodePhaseRetryableFailure { + // assert np == skipped, succeeding, failing or recovered + logger.Infof(ctx, "Change in node state detected from [%s] -> [%s], (handler phase [%s])", nodeStatus.GetPhase().String(), np.String(), p.GetPhase().String()) - // Now depending on the node type decide - h, err := c.nodeHandlerFactory.GetHandler(currentNode.GetKind()) + nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), + p, nCtx.InputReader().GetInputPath().String(), nCtx.NodeStatus(), nCtx.ExecutionContext().GetEventVersion(), + nCtx.ExecutionContext().GetParentInfo(), nCtx.Node(), c.clusterID, nCtx.NodeStateReader().GetDynamicNodeState().Phase, + c.eventConfig) if err != nil { - return err + return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") } - nCtx, err := c.NewNodeExecutionContext(ctx, execContext, nl, currentNode.GetID()) - if err != nil { - return err - } - // Abort this node - err = c.finalize(ctx, h, nCtx) - if err != nil { - return err - } - } else { - // Abort downstream nodes - downstreamNodes, err := dag.FromNode(currentNode.GetID()) + err = c.IdempotentRecordEvent(ctx, nev) if err != nil { - logger.Debugf(ctx, "Error when retrieving downstream nodes. Error [%v]", err) - return nil - } + if eventsErr.IsTooLarge(err) { + // With large enough dynamic task fanouts the reported node event, which contains the compiled + // workflow closure, can exceed the gRPC message size limit. In this case we immediately + // transition the node to failing to abort the workflow. 
+				np = v1alpha1.NodePhaseFailing
+				p = handler.PhaseInfoFailure(core.ExecutionError_USER, "NodeFailed", err.Error(), p.GetInfo())

-		errs := make([]error, 0, len(downstreamNodes))
-		for _, d := range downstreamNodes {
-			downstreamNode, ok := nl.GetNode(d)
-			if !ok {
-				return errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", d)
-			}
+				err = c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{
+					Id:         nCtx.NodeExecutionMetadata().GetNodeExecutionID(),
+					Phase:      core.NodeExecution_FAILED,
+					OccurredAt: ptypes.TimestampNow(),
+					OutputResult: &event.NodeExecutionEvent_Error{
+						Error: &core.ExecutionError{
+							Code:    "NodeFailed",
+							Message: err.Error(),
+						},
+					},
+					ReportedAt: ptypes.TimestampNow(),
+				})

-			if err := c.FinalizeHandler(ctx, execContext, dag, nl, downstreamNode); err != nil {
-				logger.Infof(ctx, "Failed to abort node [%v]. Error: %v", d, err)
-				errs = append(errs, err)
+				if err != nil {
+					return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event")
+				}
+			} else {
+				logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error())
+				return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event")
 			}
 		}

-		if len(errs) > 0 {
-			return errors.ErrorCollection{Errors: errs}
+		// We reach here only when transitioning from Queued to Running. In this case, the startedAt is not set.
+		if np == v1alpha1.NodePhaseRunning {
+			if nodeStatus.GetQueuedAt() != nil {
+				c.metrics.QueuingLatency.Observe(ctx, nodeStatus.GetQueuedAt().Time, time.Now())
+			}
 		}
-
-	return nil
+	UpdateNodeStatus(np, p, nCtx.NodeStateReader(), nodeStatus)
+	return finalStatus, nil
 }

-func (c *nodeExecutor) AbortHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error {
-	nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID())
-	nodePhase := nodeStatus.GetPhase()
-
-	if nodePhase == v1alpha1.NodePhaseNotYetStarted {
-		logger.Infof(ctx, "Node not yet started, will not finalize")
-		// Nothing to be aborted
-		return nil
+func (c *nodeExecutor) handleRetryableFailure(ctx context.Context, nCtx interfaces.NodeExecutionContext, h handler.Node) (interfaces.NodeStatus, error) {
+	nodeStatus := nCtx.NodeStatus()
+	logger.Debugf(ctx, "node failed with retryable failure, aborting and finalizing, message: %s", nodeStatus.GetMessage())
+	if err := c.Abort(ctx, h, nCtx, nodeStatus.GetMessage()); err != nil {
+		return interfaces.NodeStatusUndefined, err
 	}

-	if canHandleNode(nodePhase) {
-		ctx = contextutils.WithNodeID(ctx, currentNode.GetID())
+	// NOTE: It is important to increment attempts only after abort has been called, since incrementing the attempt mutates the state.
+	// The attempt is used throughout the system to determine the idempotent resource version.
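The ordering constraint in the NOTE above exists because the attempt number participates in the idempotent resource identifier. A minimal sketch of the idea, using a hypothetical buildResourceName helper (the real naming logic lives in the task plugins, not here):

package main

import "fmt"

// buildResourceName is a hypothetical stand-in for how a task plugin might derive
// an idempotent, per-attempt resource identifier from node ID and attempt number.
func buildResourceName(nodeID string, attempt uint32) string {
	return fmt.Sprintf("%s-%d", nodeID, attempt)
}

func main() {
	// Abort must target the resource created by the failed attempt ("n0-0");
	// incrementing first would make it address "n0-1", which never existed.
	fmt.Println(buildResourceName("n0", 0))
	fmt.Println(buildResourceName("n0", 1))
}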
+ nodeStatus.IncrementAttempts() + nodeStatus.UpdatePhase(v1alpha1.NodePhaseRunning, metav1.Now(), "retrying", nil) + // We are going to retry in the next round, so we should clear all current state + nodeStatus.ClearSubNodeStatus() + nodeStatus.ClearTaskStatus() + nodeStatus.ClearWorkflowStatus() + nodeStatus.ClearDynamicNodeStatus() + nodeStatus.ClearGateNodeStatus() + nodeStatus.ClearArrayNodeStatus() + return interfaces.NodeStatusPending, nil +} - // Now depending on the node type decide - h, err := c.nodeHandlerFactory.GetHandler(currentNode.GetKind()) - if err != nil { - return err - } +func (c *nodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h handler.Node) (interfaces.NodeStatus, error) { + logger.Debugf(ctx, "Handling Node [%s]", nCtx.NodeID()) + defer logger.Debugf(ctx, "Completed node [%s]", nCtx.NodeID()) - nCtx, err := c.NewNodeExecutionContext(ctx, execContext, nl, currentNode.GetID()) + nodeStatus := nCtx.NodeStatus() + currentPhase := nodeStatus.GetPhase() + + // Optimization! + // If it is start node we directly move it to Queued without needing to run preExecute + if currentPhase == v1alpha1.NodePhaseNotYetStarted && !nCtx.Node().IsStartNode() { + p, err := c.handleNotYetStartedNode(ctx, dag, nCtx, h) if err != nil { - return err + return p, err } - // Abort this node - err = c.abort(ctx, h, nCtx, reason) - if err != nil { - return err + if p.NodePhase == interfaces.NodePhaseQueued { + logger.Infof(ctx, "Node was queued, parallelism is now [%d]", nCtx.ExecutionContext().IncrementParallelism()) } - nodeExecutionID := &core.NodeExecutionIdentifier{ - ExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().ExecutionId, - NodeId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId, + return p, err + } + + if currentPhase == v1alpha1.NodePhaseFailing { + logger.Debugf(ctx, "node failing") + if err := c.Abort(ctx, h, nCtx, "node failing"); err != nil { + return interfaces.NodeStatusUndefined, err } - if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { - currentNodeUniqueID, err := common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nodeExecutionID.NodeId) - if err != nil { - return err - } - nodeExecutionID.NodeId = currentNodeUniqueID - } - - err = c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{ - Id: nodeExecutionID, - Phase: core.NodeExecution_ABORTED, - OccurredAt: ptypes.TimestampNow(), - OutputResult: &event.NodeExecutionEvent_Error{ - Error: &core.ExecutionError{ - Code: "NodeAborted", - Message: reason, - }, - }, - ProducerId: c.clusterID, - ReportedAt: ptypes.TimestampNow(), - }) - if err != nil && !eventsErr.IsNotFound(err) && !eventsErr.IsEventIncompatibleClusterError(err) { - if errors2.IsCausedBy(err, errors.IllegalStateError) { - logger.Debugf(ctx, "Failed to record abort event due to illegal state transition. Ignoring the error. 
Error: %v", err) - } else { - logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") - } + nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailed, metav1.Now(), nodeStatus.GetMessage(), nodeStatus.GetExecutionError()) + c.metrics.FailureDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) + if nCtx.NodeExecutionMetadata().IsInterruptible() { + c.metrics.InterruptibleNodesTerminated.Inc(ctx) } - } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { - // Abort downstream nodes - downstreamNodes, err := dag.FromNode(currentNode.GetID()) - if err != nil { - logger.Debugf(ctx, "Error when retrieving downstream nodes. Error [%v]", err) - return nil + return interfaces.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + } + + if currentPhase == v1alpha1.NodePhaseTimingOut { + logger.Debugf(ctx, "node timing out") + if err := c.Abort(ctx, h, nCtx, "node timed out"); err != nil { + return interfaces.NodeStatusUndefined, err } - errs := make([]error, 0, len(downstreamNodes)) - for _, d := range downstreamNodes { - downstreamNode, ok := nl.GetNode(d) - if !ok { - return errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", d) - } + nodeStatus.ClearSubNodeStatus() + nodeStatus.UpdatePhase(v1alpha1.NodePhaseTimedOut, metav1.Now(), nodeStatus.GetMessage(), nodeStatus.GetExecutionError()) + c.metrics.TimedOutFailure.Inc(ctx) + if nCtx.NodeExecutionMetadata().IsInterruptible() { + c.metrics.InterruptibleNodesTerminated.Inc(ctx) + } + return interfaces.NodeStatusTimedOut, nil + } - if err := c.AbortHandler(ctx, execContext, dag, nl, downstreamNode, reason); err != nil { - logger.Infof(ctx, "Failed to abort node [%v]. 
Error: %v", d, err) - errs = append(errs, err) - } + if currentPhase == v1alpha1.NodePhaseSucceeding { + logger.Debugf(ctx, "node succeeding") + if err := c.Finalize(ctx, h, nCtx); err != nil { + return interfaces.NodeStatusUndefined, err } + t := metav1.Now() - if len(errs) > 0 { - return errors.ErrorCollection{Errors: errs} + started := nodeStatus.GetStartedAt() + if started == nil { + started = &t + } + stopped := nodeStatus.GetStoppedAt() + if stopped == nil { + stopped = &t + } + c.metrics.SuccessDuration.Observe(ctx, started.Time, stopped.Time) + nodeStatus.ClearSubNodeStatus() + nodeStatus.UpdatePhase(v1alpha1.NodePhaseSucceeded, t, "completed successfully", nil) + if nCtx.NodeExecutionMetadata().IsInterruptible() { + c.metrics.InterruptibleNodesTerminated.Inc(ctx) } + return interfaces.NodeStatusSuccess, nil + } - return nil - } else { - ctx = contextutils.WithNodeID(ctx, currentNode.GetID()) - logger.Warnf(ctx, "Trying to abort a node in state [%s]", nodeStatus.GetPhase().String()) + if currentPhase == v1alpha1.NodePhaseRetryableFailure { + return c.handleRetryableFailure(ctx, nCtx, h) } - return nil -} + if currentPhase == v1alpha1.NodePhaseFailed { + // This should never happen + return interfaces.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + } -func (c *nodeExecutor) Initialize(ctx context.Context) error { - logger.Infof(ctx, "Initializing Core Node Executor") - s := c.newSetupContext(ctx) - return c.nodeHandlerFactory.Setup(ctx, s) + return c.handleQueuedOrRunningNode(ctx, nCtx, h) } func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, @@ -1234,44 +1252,55 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora } nodeScope := scope.NewSubScope("node") - exec := &nodeExecutor{ - store: store, + metrics := &nodeMetrics{ + Scope: nodeScope, + FailureDuration: labeled.NewStopWatch("failure_duration", "Indicates the total execution time of a failed workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + SuccessDuration: labeled.NewStopWatch("success_duration", "Indicates the total execution time of a successful workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + RecoveryDuration: labeled.NewStopWatch("recovery_duration", "Indicates the total execution time of a recovered workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + UserErrorDuration: labeled.NewStopWatch("user_error_duration", "Indicates the total execution time before user error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + SystemErrorDuration: labeled.NewStopWatch("system_error_duration", "Indicates the total execution time before system error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + UnknownErrorDuration: labeled.NewStopWatch("unknown_error_duration", "Indicates the total execution time before unknown error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + PermanentUserErrorDuration: labeled.NewStopWatch("perma_user_error_duration", "Indicates the total execution time before non recoverable user error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + PermanentSystemErrorDuration: labeled.NewStopWatch("perma_system_error_duration", "Indicates the total execution time before non recoverable system error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + PermanentUnknownErrorDuration: labeled.NewStopWatch("perma_unknown_error_duration", "Indicates 
the total execution time before non recoverable unknown error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + InputsWriteFailure: labeled.NewCounter("inputs_write_fail", "Indicates failure in writing node inputs to metastore", nodeScope), + TimedOutFailure: labeled.NewCounter("timeout_fail", "Indicates failure due to timeout", nodeScope), + InterruptedThresholdHit: labeled.NewCounter("interrupted_threshold", "Indicates the node interruptible disabled because it hit max failure count", nodeScope), + InterruptibleNodesRunning: labeled.NewCounter("interruptible_nodes_running", "number of interruptible nodes running", nodeScope), + InterruptibleNodesTerminated: labeled.NewCounter("interruptible_nodes_terminated", "number of interruptible nodes finished running", nodeScope), + ResolutionFailure: labeled.NewCounter("input_resolve_fail", "Indicates failure in resolving node inputs", nodeScope), + TransitionLatency: labeled.NewStopWatch("transition_latency", "Measures the latency between the last parent node stoppedAt time and current node's queued time.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + QueuingLatency: labeled.NewStopWatch("queueing_latency", "Measures the latency between the time a node's been queued to the time the handler reported the executable moved to running state", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + NodeExecutionTime: labeled.NewStopWatch("node_exec_latency", "Measures the time taken to execute one node, a node can be complex so it may encompass sub-node latency.", time.Microsecond, nodeScope, labeled.EmitUnlabeledMetric), + NodeInputGatherLatency: labeled.NewStopWatch("node_input_latency", "Measures the latency to aggregate inputs and check readiness of a node", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + } + + nodeExecutor := &nodeExecutor{ + clusterID: clusterID, + defaultActiveDeadline: nodeConfig.DefaultDeadlines.DefaultNodeActiveDeadline.Duration, + defaultDataSandbox: defaultRawOutputPrefix, + defaultExecutionDeadline: nodeConfig.DefaultDeadlines.DefaultNodeExecutionDeadline.Duration, enqueueWorkflow: enQWorkflow, - nodeRecorder: events.NewNodeEventRecorder(eventSink, nodeScope, store), - taskRecorder: events.NewTaskEventRecorder(eventSink, scope.NewSubScope("task"), store), + eventConfig: eventConfig, + interruptibleFailureThreshold: uint32(nodeConfig.InterruptibleFailureThreshold), maxDatasetSizeBytes: maxDatasetSize, - metrics: &nodeMetrics{ - Scope: nodeScope, - FailureDuration: labeled.NewStopWatch("failure_duration", "Indicates the total execution time of a failed workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - SuccessDuration: labeled.NewStopWatch("success_duration", "Indicates the total execution time of a successful workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - RecoveryDuration: labeled.NewStopWatch("recovery_duration", "Indicates the total execution time of a recovered workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - UserErrorDuration: labeled.NewStopWatch("user_error_duration", "Indicates the total execution time before user error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - SystemErrorDuration: labeled.NewStopWatch("system_error_duration", "Indicates the total execution time before system error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - UnknownErrorDuration: labeled.NewStopWatch("unknown_error_duration", "Indicates the total execution time before unknown error", 
time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - PermanentUserErrorDuration: labeled.NewStopWatch("perma_user_error_duration", "Indicates the total execution time before non recoverable user error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - PermanentSystemErrorDuration: labeled.NewStopWatch("perma_system_error_duration", "Indicates the total execution time before non recoverable system error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - PermanentUnknownErrorDuration: labeled.NewStopWatch("perma_unknown_error_duration", "Indicates the total execution time before non recoverable unknown error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - InputsWriteFailure: labeled.NewCounter("inputs_write_fail", "Indicates failure in writing node inputs to metastore", nodeScope), - TimedOutFailure: labeled.NewCounter("timeout_fail", "Indicates failure due to timeout", nodeScope), - InterruptedThresholdHit: labeled.NewCounter("interrupted_threshold", "Indicates the node interruptible disabled because it hit max failure count", nodeScope), - InterruptibleNodesRunning: labeled.NewCounter("interruptible_nodes_running", "number of interruptible nodes running", nodeScope), - InterruptibleNodesTerminated: labeled.NewCounter("interruptible_nodes_terminated", "number of interruptible nodes finished running", nodeScope), - ResolutionFailure: labeled.NewCounter("input_resolve_fail", "Indicates failure in resolving node inputs", nodeScope), - TransitionLatency: labeled.NewStopWatch("transition_latency", "Measures the latency between the last parent node stoppedAt time and current node's queued time.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - QueuingLatency: labeled.NewStopWatch("queueing_latency", "Measures the latency between the time a node's been queued to the time the handler reported the executable moved to running state", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - NodeExecutionTime: labeled.NewStopWatch("node_exec_latency", "Measures the time taken to execute one node, a node can be complex so it may encompass sub-node latency.", time.Microsecond, nodeScope, labeled.EmitUnlabeledMetric), - NodeInputGatherLatency: labeled.NewStopWatch("node_input_latency", "Measures the latency to aggregate inputs and check readiness of a node", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - }, - outputResolver: NewRemoteFileOutputResolver(store), - defaultExecutionDeadline: nodeConfig.DefaultDeadlines.DefaultNodeExecutionDeadline.Duration, - defaultActiveDeadline: nodeConfig.DefaultDeadlines.DefaultNodeActiveDeadline.Duration, maxNodeRetriesForSystemFailures: uint32(nodeConfig.MaxNodeRetriesOnSystemFailures), - interruptibleFailureThreshold: uint32(nodeConfig.InterruptibleFailureThreshold), - defaultDataSandbox: defaultRawOutputPrefix, - shardSelector: shardSelector, + metrics: metrics, + nodeRecorder: events.NewNodeEventRecorder(eventSink, nodeScope, store), + outputResolver: NewRemoteFileOutputResolver(store), recoveryClient: recoveryClient, - eventConfig: eventConfig, - clusterID: clusterID, + shardSelector: shardSelector, + store: store, + taskRecorder: events.NewTaskEventRecorder(eventSink, scope.NewSubScope("task"), store), + } + + exec := &recursiveNodeExecutor{ + nodeExecutor: nodeExecutor, + nCtxBuilder: nodeExecutor, + + enqueueWorkflow: enQWorkflow, + store: store, + metrics: metrics, } nodeHandlerFactory, err := NewHandlerFactory(ctx, exec, workflowLauncher, launchPlanReader, kubeClient, catalogClient, 
recoveryClient, eventConfig, clusterID, signalClient, nodeScope) exec.nodeHandlerFactory = nodeHandlerFactory diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index a707b4dfc..cccec87ef 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -146,7 +146,7 @@ func TestNodeExecutor_Initialize(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) hf := &mocks2.HandlerFactory{} exec.nodeHandlerFactory = hf @@ -160,7 +160,7 @@ func TestNodeExecutor_Initialize(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) hf := &mocks2.HandlerFactory{} exec.nodeHandlerFactory = hf @@ -183,7 +183,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) defaultNodeID := "n1" @@ -287,7 +287,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) // Node not yet started { @@ -693,7 +693,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) exec.nodeHandlerFactory = hf execContext := executors.NewExecutionContext(mockWf, mockWf, mockWf, nil, executors.InitializeControlFlow()) @@ -768,7 +768,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) exec.nodeHandlerFactory = hf called := false @@ -880,7 +880,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, 10, 
"s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) exec.nodeHandlerFactory = hf called := false @@ -944,7 +944,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) exec.nodeHandlerFactory = hf h := &nodeHandlerMocks.Node{} @@ -975,7 +975,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) exec.nodeHandlerFactory = hf h := &nodeHandlerMocks.Node{} @@ -1009,7 +1009,7 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) defaultNodeID := "n1" taskID := "tID" @@ -1120,7 +1120,7 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) defaultNodeID := "n1" taskID := taskID @@ -1236,7 +1236,7 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) // Node not yet started { tests := []struct { @@ -1389,7 +1389,7 @@ func Test_nodeExecutor_RecordTransitionLatency(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &nodeExecutor{ + c := &recursiveNodeExecutor{ nodeHandlerFactory: tt.fields.nodeHandlerFactory, enqueueWorkflow: tt.fields.enqueueWorkflow, store: tt.fields.store, @@ -1489,7 +1489,7 @@ func Test_nodeExecutor_timeout(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &nodeExecutor{defaultActiveDeadline: time.Second, defaultExecutionDeadline: time.Second} + c := &recursiveNodeExecutor{defaultActiveDeadline: time.Second, defaultExecutionDeadline: time.Second} handlerReturn := func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, 
tt.phaseInfo), tt.err } @@ -1544,7 +1544,7 @@ func Test_nodeExecutor_system_error(t *testing.T) { ns.On("ClearLastAttemptStartedAt").Return() - c := &nodeExecutor{} + c := &recursiveNodeExecutor{} h := &nodeHandlerMocks.Node{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -1575,7 +1575,7 @@ func Test_nodeExecutor_system_error(t *testing.T) { func Test_nodeExecutor_abort(t *testing.T) { ctx := context.Background() - exec := nodeExecutor{} + exec := recursiveNodeExecutor{} nCtx := &nodeExecContext{} t.Run("abort error calls finalize", func(t *testing.T) { @@ -1623,7 +1623,7 @@ func Test_nodeExecutor_abort(t *testing.T) { func TestNodeExecutor_AbortHandler(t *testing.T) { ctx := context.Background() - exec := nodeExecutor{} + exec := recursiveNodeExecutor{} t.Run("not-yet-started", func(t *testing.T) { id := "id" @@ -1658,7 +1658,7 @@ func TestNodeExecutor_AbortHandler(t *testing.T) { h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(nil) hf.OnGetHandlerMatch(v1alpha1.NodeKindStart).Return(h, nil) - nExec := nodeExecutor{ + nExec := recursiveNodeExecutor{ nodeRecorder: incompatibleClusterErr, nodeHandlerFactory: hf, } @@ -1680,7 +1680,7 @@ func TestNodeExecutor_AbortHandler(t *testing.T) { func TestNodeExecutor_FinalizeHandler(t *testing.T) { ctx := context.Background() - exec := nodeExecutor{} + exec := recursiveNodeExecutor{} t.Run("not-yet-started", func(t *testing.T) { id := "id" @@ -1846,7 +1846,7 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) - exec := execIface.(*nodeExecutor) + exec := execIface.(*recursiveNodeExecutor) defaultNodeID := "n1" taskID := taskID @@ -2013,7 +2013,7 @@ func Test_nodeExecutor_IdempotentRecordEvent(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &nodeExecutor{ + c := &recursiveNodeExecutor{ nodeRecorder: tt.rec, eventConfig: &config.EventConfig{ RawOutputPolicy: config.RawOutputPolicyReference, @@ -2139,7 +2139,7 @@ func TestRecover(t *testing.T) { } nCtx.OnDataStore().Return(storageClient) - executor := nodeExecutor{ + executor := recursiveNodeExecutor{ recoveryClient: recoveryClient, store: storageClient, eventConfig: &config.EventConfig{ @@ -2232,7 +2232,7 @@ func TestRecover(t *testing.T) { DynamicWorkflow: dynamicWorkflow, }, nil) - executor := nodeExecutor{ + executor := recursiveNodeExecutor{ recoveryClient: recoveryClient, store: storageClient, eventConfig: eventConfig, @@ -2305,7 +2305,7 @@ func TestRecover(t *testing.T) { nCtx.OnDataStore().Return(storageClient) - executor := nodeExecutor{ + executor := recursiveNodeExecutor{ recoveryClient: recoveryClient, store: storageClient, eventConfig: eventConfig, @@ -2369,7 +2369,7 @@ func TestRecover(t *testing.T) { } nCtx.OnDataStore().Return(storageClient) - executor := nodeExecutor{ + executor := recursiveNodeExecutor{ recoveryClient: recoveryClient, store: storageClient, eventConfig: eventConfig, @@ -2399,7 +2399,7 @@ func TestRecover(t *testing.T) { }, }, nil) - executor := nodeExecutor{ + executor := recursiveNodeExecutor{ recoveryClient: recoveryClient, } @@ -2454,7 +2454,7 @@ func TestRecover(t *testing.T) { } nCtx.OnDataStore().Return(storageClient) - executor := nodeExecutor{ + executor := recursiveNodeExecutor{ 
recoveryClient: recoveryClient, store: storageClient, eventConfig: eventConfig, @@ -2499,7 +2499,7 @@ func TestRecover(t *testing.T) { } nCtx.OnDataStore().Return(storageClient) - executor := nodeExecutor{ + executor := recursiveNodeExecutor{ recoveryClient: recoveryClient, store: storageClient, eventConfig: eventConfig, diff --git a/pkg/controller/nodes/handler/iface.go b/pkg/controller/nodes/handler/iface.go index c2de8af08..d85caef19 100644 --- a/pkg/controller/nodes/handler/iface.go +++ b/pkg/controller/nodes/handler/iface.go @@ -3,12 +3,21 @@ package handler import ( "context" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytestdlib/promutils" ) //go:generate mockery -all -case=underscore +// TODO @hamersaw - docs?!?1 +type NodeExecutor interface { + // TODO @hamersaw - BuildNodeExecutionContext should be here - removes need for another interface + HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h Node) (interfaces.NodeStatus, error) + Abort(ctx context.Context, h Node, nCtx interfaces.NodeExecutionContext, reason string) error + Finalize(ctx context.Context, h Node, nCtx interfaces.NodeExecutionContext) error +} + // Interface that should be implemented for a node type. type Node interface { // Method to indicate that finalize is required for this handler diff --git a/pkg/controller/nodes/interfaces/node.go b/pkg/controller/nodes/interfaces/node.go index 2974090cd..e279de1c7 100644 --- a/pkg/controller/nodes/interfaces/node.go +++ b/pkg/controller/nodes/interfaces/node.go @@ -79,9 +79,6 @@ type Node interface { RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (NodeStatus, error) - RecursiveNodeHandlerWithNodeContextModifier(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, - nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, nCtxModifier func (NodeExecutionContext) NodeExecutionContext) (NodeStatus, error) - // This aborts the given node. 
If the given node is complete then it recursively finds the running nodes and aborts them AbortHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error @@ -89,12 +86,19 @@ type Node interface { FinalizeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) error - // TODO @hamersaw - docs - NewNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, - nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (NodeExecutionContext, error) - // This method should be used to initialize Node executor Initialize(ctx context.Context) error + + // TODO @hamersaw - docs + GetNodeExecutionContextBuilder() NodeExecutionContextBuilder + WithNodeExecutionContextBuilder(NodeExecutionContextBuilder) Node +} + +// TODO @hamersaw - docs +type NodeExecutionContextBuilder interface { + //BuildNodeExecutionContext(execContext executors.ExecutionContext) NodeExecutionContext + BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, + nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (NodeExecutionContext, error) } // Helper struct to allow passing of status between functions diff --git a/pkg/controller/nodes/node_exec_context.go b/pkg/controller/nodes/node_exec_context.go index 27c3e82c7..681c32f48 100644 --- a/pkg/controller/nodes/node_exec_context.go +++ b/pkg/controller/nodes/node_exec_context.go @@ -185,7 +185,7 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext } } -func (c *nodeExecutor) NewNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, +func (c *nodeExecutor) BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (interfaces.NodeExecutionContext, error) { n, ok := nl.GetNode(currentNodeID) if !ok { diff --git a/pkg/controller/nodes/node_exec_context_test.go b/pkg/controller/nodes/node_exec_context_test.go index 48fa664a3..8ad294287 100644 --- a/pkg/controller/nodes/node_exec_context_test.go +++ b/pkg/controller/nodes/node_exec_context_test.go @@ -108,7 +108,7 @@ func Test_NodeContextDefault(t *testing.T) { SystemFailures: 0, }) - nodeExecutor := nodeExecutor{ + nodeExecutor := recursiveNodeExecutor{ interruptibleFailureThreshold: 0, maxDatasetSizeBytes: 0, defaultDataSandbox: "s3://bucket-a", @@ -133,7 +133,7 @@ func Test_NodeContextDefaultInterruptible(t *testing.T) { scope := promutils.NewTestScope() dataStore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, scope.NewSubScope("dataStore")) - nodeExecutor := nodeExecutor{ + nodeExecutor := recursiveNodeExecutor{ interruptibleFailureThreshold: 10, maxDatasetSizeBytes: 0, defaultDataSandbox: "s3://bucket-a", diff --git a/pkg/controller/nodes/setup_context.go b/pkg/controller/nodes/setup_context.go index ef192f453..c940447a7 100644 --- a/pkg/controller/nodes/setup_context.go +++ b/pkg/controller/nodes/setup_context.go @@ -26,7 +26,7 @@ func (s *setupContext) MetricsScope() promutils.Scope { return s.scope } -func (c *nodeExecutor) newSetupContext(_ context.Context) handler.SetupContext { +func (c *recursiveNodeExecutor) newSetupContext(_ context.Context) handler.SetupContext { return &setupContext{ enq: c.enqueueWorkflow, scope: c.metrics.Scope, From 
110e1eacdc514c404d47d46a7a6a0927f4cd7b18 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Wed, 26 Apr 2023 08:28:35 -0500 Subject: [PATCH 11/62] refactoring TODOs Signed-off-by: Daniel Rammer --- .../nodes/array/execution_context.go | 3 - pkg/controller/nodes/array/handler.go | 230 +++++++++--------- pkg/controller/nodes/array/input_reader.go | 26 +- .../nodes/array/node_executor.go.bak | 47 ---- pkg/controller/nodes/errors/codes.go | 1 + pkg/controller/nodes/handler_factory.go | 7 +- .../nodes/task/plugin_state_manager.go | 4 +- 7 files changed, 126 insertions(+), 192 deletions(-) delete mode 100644 pkg/controller/nodes/array/node_executor.go.bak diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go index 11aaf3693..a701d5216 100644 --- a/pkg/controller/nodes/array/execution_context.go +++ b/pkg/controller/nodes/array/execution_context.go @@ -2,7 +2,6 @@ package array import ( "context" - "fmt" "strconv" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" @@ -95,14 +94,12 @@ func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context return nil, err } - fmt.Println("HAMERSAW - currentNodeID %s subNodeID %s!\n", currentNodeID, a.subNodeID) if currentNodeID == a.subNodeID { // overwrite NodeExecutionContext for ArrayNode execution nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.subNodeIndex) } return nCtx, nil - } func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder, subNodeID v1alpha1.NodeID, diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index d26aa510e..60718c0da 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -9,13 +9,14 @@ import ( idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - //"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/compiler/validators" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/k8s" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/codex" @@ -29,8 +30,10 @@ import ( // arrayNodeHandler is a handle implementation for processing array nodes type arrayNodeHandler struct { - metrics metrics - nodeExecutor interfaces.Node + metrics metrics + nodeExecutor interfaces.Node + pluginStateBytesNotStarted []byte + pluginStateBytesStarted []byte } // metrics encapsulates the prometheus metrics for this handler @@ -52,13 +55,13 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut // Finalize completes the array node defined in the NodeExecutionContext func (a *arrayNodeHandler) Finalize(ctx context.Context, _ interfaces.NodeExecutionContext) error { - return nil // TODO @hamersaw - implement finalize + return nil // TODO @hamersaw - implement finalize - clear node data?!?! 
} // FinalizeRequired defines whether or not this handler requires finalize to be called on // node completion func (a *arrayNodeHandler) FinalizeRequired() bool { - return true // TODO @hamersaw - implement finalize required + return false } // Handle is responsible for transitioning and reporting node state to complete the node defined @@ -67,23 +70,12 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu arrayNode := nCtx.Node().GetArrayNode() arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState() - // TODO @hamersaw - handle array node - // the big question right now is if we make a DAG with everything or call a separate DAG for each individual task - // need to do much more thinking on this - cleaner = a single DAG / maybe easier = DAG for each - // single: - // + can still add envVars - override in ArrayNodeExectionContext - // each: - // + add envVars on ExecutionContext - // - need to manage - - var inputs *idlcore.LiteralMap - switch arrayNodeState.Phase { case v1alpha1.ArrayNodePhaseNone: // identify and validate array node input value lengths literalMap, err := nCtx.InputReader().Get(ctx) if err != nil { - return handler.UnknownTransition, err // TODO @hamersaw fail + return handler.UnknownTransition, err } size := -1 @@ -91,57 +83,63 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu literalType := validators.LiteralTypeForLiteral(variable) switch literalType.Type.(type) { case *idlcore.LiteralType_CollectionType: - collection := variable.GetCollection() - collectionLength := len(collection.Literals) + collectionLength := len(variable.GetCollection().Literals) if size == -1 { size = collectionLength } else if size != collectionLength { - // TODO @hamersaw - return error + return handler.DoTransition(handler.TransitionTypeEphemeral, + handler.PhaseInfoFailure(idlcore.ExecutionError_USER, errors.InvalidArrayLength, + fmt.Sprintf("input arrays have different lengths: expecting '%d' found '%d'", size, collectionLength), nil), + ), nil } } } if size == -1 { - // TODO @hamersaw return + return handler.DoTransition(handler.TransitionTypeEphemeral, + handler.PhaseInfoFailure(idlcore.ExecutionError_USER, errors.InvalidArrayLength, "no input array provided", nil), + ), nil } // initialize ArrayNode state - arrayNodeState.Phase = v1alpha1.ArrayNodePhaseExecuting arrayNodeState.SubNodePhases, err = bitarray.NewCompactArray(uint(size), bitarray.Item(len(core.Phases)-1)) if err != nil { - // TODO @hamersaw fail + return handler.UnknownTransition, err } + // TODO @hamersaw - init SystemFailures and RetryAttempts as well // do we want to abstract this? ie. arrayNodeState.GetStats(subNodeIndex) (phase, systemFailures, ...) 
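For context on the SubNodePhases bookkeeping initialized above, here is a minimal, self-contained sketch of the compact bit-array semantics, assuming the flytestdlib bitarray API exactly as this patch calls it (NewCompactArray, SetItem, GetItems). The max-value argument determines how many bits each entry occupies, which is what keeps the etcd-resident state small:

package main

import (
	"fmt"

	"github.com/flyteorg/flytestdlib/bitarray"
)

func main() {
	const size = 5     // number of sub nodes, i.e. the input collection length
	const maxPhase = 8 // highest phase ordinal that must be representable

	phases, err := bitarray.NewCompactArray(uint(size), bitarray.Item(maxPhase))
	if err != nil {
		panic(err)
	}

	// record that sub node 2 moved to phase 4
	phases.SetItem(2, 4)

	// iterate the same way the handler's Executing case does
	for i, p := range phases.GetItems() {
		fmt.Printf("sub node %d is in phase %d\n", i, p)
	}
}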
//fmt.Printf("HAMERSAW - created SubNodePhases with length '%d:%d'\n", size, len(arrayNodeState.SubNodePhases.GetItems())) + arrayNodeState.Phase = v1alpha1.ArrayNodePhaseExecuting case v1alpha1.ArrayNodePhaseExecuting: // process array node subnodes for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) - fmt.Printf("HAMERSAW - TODO evaluating node '%d' in phase '%d'\n", i, nodePhase) + //fmt.Printf("HAMERSAW - evaluating node '%d' in phase '%d'\n", i, nodePhase) - // TODO @hamersaw - fix + // TODO @hamersaw fix - do not process nodes in terminal state + //if nodes.IsTerminalNodePhase(nodePhase) { if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseFailed || nodePhase == v1alpha1.NodePhaseTimedOut || nodePhase == v1alpha1.NodePhaseSkipped { continue } - /*if nodes.IsTerminalNodePhase(nodePhase) { - continue - }*/ - - // TODO @hamersaw - do we need to init input readers every time? - literalMap, err := constructLiteralMap(ctx, nCtx.InputReader(), i, inputs) - if err != nil { - logger.Errorf(ctx, "HAMERSAW - %+v", err) - // TODO @hamersaw - return err + // initialize input reader if NodePhaseNotyetStarted or NodePhaseSucceeding for cache lookup and population + var inputLiteralMap *idlcore.LiteralMap + var err error + if nodePhase == v1alpha1.NodePhaseNotYetStarted || nodePhase == v1alpha1.NodePhaseSucceeding { + inputLiteralMap, err = constructLiteralMap(ctx, nCtx.InputReader(), i) + if err != nil { + return handler.UnknownTransition, err + } } - inputReader := newStaticInputReader(nCtx.InputReader(), &literalMap) + inputReader := newStaticInputReader(nCtx.InputReader(), inputLiteralMap) + // if node has not yet started we automatically set to NodePhaseQueued to skip input resolution if nodePhase == v1alpha1.NodePhaseNotYetStarted { - // set nodePhase to Queued to skip resolving inputs but still allow cache lookups + // TODO @hamersaw how does this work with fastcache? nodePhase = v1alpha1.NodePhaseQueued } @@ -152,21 +150,13 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu subNodeSpec.ID = subNodeID subNodeSpec.Name = subNodeID - // TODO @hamersaw - is this right?!?! it's certainly HACKY AF - maybe we persist pluginState.Phase and PluginPhase - pluginState := k8s.PluginState{ - } + // TODO @hamersaw - store task phase and use to mock plugin state + // TODO - if we want to support more plugin types we need to figure out the best way to store plugin state + // currently just mocking based on node phase -> which works for all k8s plugins + // we can not pre-allocated a bit array because max size is 256B and with 5k fanout node state = 1.28MB + pluginStateBytes := a.pluginStateBytesStarted if nodePhase == v1alpha1.NodePhaseQueued { - pluginState.Phase = k8s.PluginPhaseNotStarted - } else { - pluginState.Phase = k8s.PluginPhaseStarted - } - - buffer := make([]byte, 0, 256) - bufferWriter := bytes.NewBuffer(buffer) - - codec := codex.GobStateCodec{} - if err := codec.Encode(pluginState, bufferWriter); err != nil { - logger.Errorf(ctx, "HAMERSAW - %+v", err) + pluginStateBytes = a.pluginStateBytesNotStarted } // we set subDataDir and subOutputDir to the node dirs because flytekit automatically appends subtask @@ -175,20 +165,13 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // append the subtask index. 
var subDataDir, subOutputDir storage.DataReference if nodePhase == v1alpha1.NodePhaseQueued { - subDataDir = nCtx.NodeStatus().GetDataDir() - subOutputDir = nCtx.NodeStatus().GetOutputDir() + subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx) } else { - subDataDir, err = nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetDataDir(), strconv.Itoa(i)) // TODO @hamersaw - constructOutputReference? - if err != nil { - logger.Errorf(ctx, "HAMERSAW - %+v", err) - // TODO @hamersaw - return err - } + subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx, strconv.Itoa(i)) + } - subOutputDir, err = nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), strconv.Itoa(i)) // TODO @hamersaw - constructOutputReference? - if err != nil { - logger.Errorf(ctx, "HAMERSAW - %+v", err) - // TODO @hamersaw - return err - } + if err != nil { + return handler.UnknownTransition, err } subNodeStatus := &v1alpha1.NodeStatus{ @@ -197,38 +180,24 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // TODO @hamersaw - to get caching working we need to set to Queued to force cache lookup // once fastcache is done we dont care about the TaskNodeStatus Phase: int(core.Phases[core.PhaseRunning]), - PluginState: bufferWriter.Bytes(), + PluginState: pluginStateBytes, }, DataDir: subDataDir, OutputDir: subOutputDir, // TODO @hamersaw - fill out systemFailures, retryAttempt etc } - // TODO @hamersaw - can probably create a single arrayNodeLookup with all the subNodeIDs arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, &subNodeSpec, subNodeStatus) - // execute subNode through RecursiveNodeHandler - /*_, err = a.nodeExecutor.RecursiveNodeHandlerWithNodeContextModifier(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec, - func (nCtx interfaces.NodeExecutionContext) interfaces.NodeExecutionContext { - if nCtx.NodeID() == subNodeID { - return newArrayNodeExecutionContext(nCtx, inputReader, i) - } - - return nCtx - })*/ - - // TODO @hamersaw - move all construction of nCtx internal -> can build a single arrayNodeExecutor and use for everyone -> build differently based on index // execute subNode through RecursiveNodeHandler arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(), subNodeID, i, inputReader) arrayNodeExecutor := a.nodeExecutor.WithNodeExecutionContextBuilder(arrayNodeExecutionContextBuilder) _, err = arrayNodeExecutor.RecursiveNodeHandler(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec) - if err != nil { - logger.Errorf(ctx, "HAMERSAW - %+v", err) - // TODO @hamersaw fail + return handler.UnknownTransition, err } - fmt.Printf("HAMERSAW - node phase transition %d -> %d\n", nodePhase, subNodeStatus.GetPhase()) + //fmt.Printf("HAMERSAW - node phase transition %d -> %d\n", nodePhase, subNodeStatus.GetPhase()) arrayNodeState.SubNodePhases.SetItem(i, uint64(subNodeStatus.GetPhase())) } @@ -249,19 +218,11 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // TODO @hamersaw - abort everything! case v1alpha1.ArrayNodePhaseSucceeding: outputLiterals := make(map[string]*idlcore.Literal) - for i, _ := range arrayNodeState.SubNodePhases.GetItems() { // initialize subNode reader - subDataDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetDataDir(), strconv.Itoa(i)) // TODO @hamersaw - constructOutputReference? 
+ subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i)) if err != nil { - logger.Errorf(ctx, "HAMERSAW - %+v", err) - // TODO @hamersaw - return err - } - - subOutputDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), strconv.Itoa(i)) // TODO @hamersaw - constructOutputReference? - if err != nil { - logger.Errorf(ctx, "HAMERSAW - %+v", err) - // TODO @hamersaw - return err + return handler.UnknownTransition, err } // checkpoint paths are not computed here because this function is only called when writing @@ -270,47 +231,41 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu reader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, int64(999999999)) // read outputs - outputs, executionError, err := reader.Read(ctx) + outputs, executionErr, err := reader.Read(ctx) if err != nil { - logger.Warnf(ctx, "Failed to read output for subtask [%v]. Error: %v", i, err) - //return workqueue.WorkStatusFailed, err // TODO @hamersaw -return error + return handler.UnknownTransition, err + } else if executionErr != nil { + return handler.UnknownTransition, executionErr } - if executionError == nil && outputs != nil { - for name, literal := range outputs.GetLiterals() { - existingVal, found := outputLiterals[name] - var list *idlcore.LiteralCollection - if found { - list = existingVal.GetCollection() - } else { - list = &idlcore.LiteralCollection{ - Literals: make([]*idlcore.Literal, 0, len(arrayNodeState.SubNodePhases.GetItems())), - } - - existingVal = &idlcore.Literal{ - Value: &idlcore.Literal_Collection{ - Collection: list, + // copy individual subNode output literals into a collection of output literals + for name, literal := range outputs.GetLiterals() { + outputLiteral, exists := outputLiterals[name] + if !exists { + outputLiteral = &idlcore.Literal{ + Value: &idlcore.Literal_Collection{ + Collection: &idlcore.LiteralCollection{ + Literals: make([]*idlcore.Literal, 0, len(arrayNodeState.SubNodePhases.GetItems())), }, - } + }, } - list.Literals = append(list.Literals, literal) - outputLiterals[name] = existingVal + outputLiterals[name] = outputLiteral } + + collection := outputLiteral.GetCollection() + collection.Literals = append(collection.Literals, literal) } } - // TODO @hamersaw - collect outputs and write as List[] - fmt.Printf("HAMERSAW - final outputs %+v\n", idlcore.LiteralMap{Literals: outputLiterals}) outputLiteralMap := &idlcore.LiteralMap{ Literals: outputLiterals, } + //fmt.Printf("HAMERSAW - final outputs %+v\n", outputLiteralMap) outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) if err := nCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputLiteralMap); err != nil { - // TODO @hamersaw return error - //return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "WriteOutputsFailed", - // fmt.Sprintf("failed to write signal value to [%v] with error [%s]", outputFile, err.Error()), nil)), nil + return handler.UnknownTransition, err } return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil @@ -329,14 +284,53 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // Setup handles any initialization requirements for this handler func (a *arrayNodeHandler) Setup(_ context.Context, _ handler.SetupContext) error { - return nil // TODO @hamersaw - implement setup + return nil } // New initializes a new 
arrayNodeHandler -func New(nodeExecutor interfaces.Node, scope promutils.Scope) handler.Node { +func New(nodeExecutor interfaces.Node, scope promutils.Scope) (handler.Node, error) { + // create k8s PluginState byte mocks to reuse instead of creating for each subNode evaluation + pluginStateBytesNotStarted, err := bytesFromK8sPluginState(k8s.PluginState{Phase: k8s.PluginPhaseNotStarted}) + if err != nil { + return nil, err + } + + pluginStateBytesStarted, err := bytesFromK8sPluginState(k8s.PluginState{Phase: k8s.PluginPhaseStarted}) + if err != nil { + return nil, err + } + arrayScope := scope.NewSubScope("array") return &arrayNodeHandler{ - metrics: newMetrics(arrayScope), - nodeExecutor: nodeExecutor, + metrics: newMetrics(arrayScope), + nodeExecutor: nodeExecutor, + pluginStateBytesNotStarted: pluginStateBytesNotStarted, + pluginStateBytesStarted: pluginStateBytesStarted, + }, nil +} + +func bytesFromK8sPluginState(pluginState k8s.PluginState) ([]byte, error) { + buffer := make([]byte, 0, task.MaxPluginStateSizeBytes) + bufferWriter := bytes.NewBuffer(buffer) + + codec := codex.GobStateCodec{} + if err := codec.Encode(pluginState, bufferWriter); err != nil { + return nil, err } + + return bufferWriter.Bytes(), nil +} + +func constructOutputReferences(ctx context.Context, nCtx interfaces.NodeExecutionContext, postfix...string) (storage.DataReference, storage.DataReference, error) { + subDataDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetDataDir(), postfix...) + if err != nil { + return "", "", err + } + + subOutputDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), postfix...) + if err != nil { + return "", "", err + } + + return subDataDir, subOutputDir, nil } diff --git a/pkg/controller/nodes/array/input_reader.go b/pkg/controller/nodes/array/input_reader.go index de6e6ef81..4059db95d 100644 --- a/pkg/controller/nodes/array/input_reader.go +++ b/pkg/controller/nodes/array/input_reader.go @@ -23,13 +23,10 @@ func newStaticInputReader(inputPaths io.InputFilePaths, input *idlcore.LiteralMa } } -func constructLiteralMap(ctx context.Context, inputReader io.InputReader, index int, inputs *idlcore.LiteralMap) (idlcore.LiteralMap, error) { - var err error - if inputs == nil { - inputs, err = inputReader.Get(ctx) - if err != nil { - return idlcore.LiteralMap{}, err - } +func constructLiteralMap(ctx context.Context, inputReader io.InputReader, index int) (*idlcore.LiteralMap, error) { + inputs, err := inputReader.Get(ctx) + if err != nil { + return nil, err } literals := make(map[string]*idlcore.Literal) @@ -39,20 +36,7 @@ func constructLiteralMap(ctx context.Context, inputReader io.InputReader, index } } - return idlcore.LiteralMap{ + return &idlcore.LiteralMap{ Literals: literals, }, nil } - -/*func ConstructStaticInputReaders(inputPaths io.InputFilePaths, inputs []*idlcore.Literal, inputName string) []io.InputReader { - inputReaders := make([]io.InputReader, 0, len(inputs)) - for i := 0; i < len(inputs); i++ { - inputReaders = append(inputReaders, NewStaticInputReader(inputPaths, &idlcore.LiteralMap{ - Literals: map[string]*idlcore.Literal{ - inputName: inputs[i], - }, - })) - } - - return inputReaders -}*/ diff --git a/pkg/controller/nodes/array/node_executor.go.bak b/pkg/controller/nodes/array/node_executor.go.bak deleted file mode 100644 index d9e5cf18b..000000000 --- a/pkg/controller/nodes/array/node_executor.go.bak +++ /dev/null @@ -1,47 +0,0 @@ -package array - -import ( - "context" - "fmt" - - 
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" - - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/executors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" -) - -type arrayNodeExecutor struct { - interfaces.Node - subNodeID v1alpha1.NodeID - subNodeIndex int - inputReader io.InputReader -} - -// TODO @hamersaw - docs -func (a arrayNodeExecutor) NewNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, - nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (interfaces.NodeExecutionContext, error) { - - // create base NodeExecutionContext - nCtx, err := a.Node.NewNodeExecutionContext(ctx, executionContext, nl, currentNodeID) - if err != nil { - return nil, err - } - - fmt.Println("HAMERSAW - currentNodeID %s subNodeID %s!\n", currentNodeID, a.subNodeID) - if currentNodeID == a.subNodeID { - // TODO @hamersaw - overwrite NodeExecutionContext for ArrayNode execution - nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.subNodeIndex) - } - - return nCtx, nil -} - -func newArrayNodeExecutor(nodeExecutor interfaces.Node, subNodeID v1alpha1.NodeID, subNodeIndex int, inputReader io.InputReader) arrayNodeExecutor { - return arrayNodeExecutor{ - Node: nodeExecutor, - subNodeID: subNodeID, - subNodeIndex: subNodeIndex, - inputReader: inputReader, - } -} diff --git a/pkg/controller/nodes/errors/codes.go b/pkg/controller/nodes/errors/codes.go index df2be215c..30ded68b7 100644 --- a/pkg/controller/nodes/errors/codes.go +++ b/pkg/controller/nodes/errors/codes.go @@ -25,4 +25,5 @@ const ( StorageError ErrorCode = "StorageError" EventRecordingFailed ErrorCode = "EventRecordingFailed" CatalogCallFailed ErrorCode = "CatalogCallFailed" + InvalidArrayLength ErrorCode = "InvalidArrayLength" ) diff --git a/pkg/controller/nodes/handler_factory.go b/pkg/controller/nodes/handler_factory.go index a275b59a7..28ca7e3d0 100644 --- a/pkg/controller/nodes/handler_factory.go +++ b/pkg/controller/nodes/handler_factory.go @@ -65,13 +65,18 @@ func NewHandlerFactory(ctx context.Context, executor interfaces.Node, workflowLa return nil, err } + arrayHandler, err := array.New(executor, scope) + if err != nil { + return nil, err + } + f := &handlerFactory{ handlers: map[v1alpha1.NodeKind]handler.Node{ v1alpha1.NodeKindBranch: branch.New(executor, eventConfig, scope), v1alpha1.NodeKindTask: dynamic.New(t, executor, launchPlanReader, eventConfig, scope), v1alpha1.NodeKindWorkflow: subworkflow.New(executor, workflowLauncher, recoveryClient, eventConfig, scope), v1alpha1.NodeKindGate: gate.New(eventConfig, signalClient, scope), - v1alpha1.NodeKindArray: array.New(executor, scope), + v1alpha1.NodeKindArray: arrayHandler, v1alpha1.NodeKindStart: start.New(), v1alpha1.NodeKindEnd: end.New(), }, diff --git a/pkg/controller/nodes/task/plugin_state_manager.go b/pkg/controller/nodes/task/plugin_state_manager.go index 496f5387f..f68d7b58a 100644 --- a/pkg/controller/nodes/task/plugin_state_manager.go +++ b/pkg/controller/nodes/task/plugin_state_manager.go @@ -19,7 +19,7 @@ const ( const currentCodec = GobCodecVersion // TODO Configurable? 
-const maxPluginStateSizeBytes = 256 +const MaxPluginStateSizeBytes = 256 type stateCodec interface { Encode(interface{}, io.Writer) error @@ -38,7 +38,7 @@ type pluginStateManager struct { func (p *pluginStateManager) Put(stateVersion uint8, v interface{}) error { p.newStateVersion = stateVersion if v != nil { - buf := make([]byte, 0, maxPluginStateSizeBytes) + buf := make([]byte, 0, MaxPluginStateSizeBytes) p.newState = bytes.NewBuffer(buf) return p.codec.Encode(v, p.newState) } From 0a5c58a318b18ad4d16d8bd99bbcf7467fe91f01 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Wed, 26 Apr 2023 13:55:42 -0500 Subject: [PATCH 12/62] subnode retries working Signed-off-by: Daniel Rammer --- .../cmd/testdata/array-node.yaml.golden | 5 +- pkg/apis/flyteworkflow/v1alpha1/iface.go | 6 ++ .../flyteworkflow/v1alpha1/node_status.go | 44 ++++++++++- .../nodes/array/execution_context.go | 45 +++++------ pkg/controller/nodes/array/handler.go | 75 ++++++++++++++----- pkg/controller/nodes/interfaces/state.go | 7 +- pkg/controller/nodes/node_state_manager.go | 3 + pkg/controller/nodes/transformers.go | 3 + 8 files changed, 136 insertions(+), 52 deletions(-) diff --git a/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden b/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden index 73eb67e81..b42fd58b2 100644 --- a/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden +++ b/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden @@ -3,7 +3,7 @@ tasks: args: - "pyflyte-fast-execute" - "--additional-distribution" - - "s3://my-s3-bucket/flytesnacks/development/TWUFWD6G3NTY7K2B6YJOCHXIKQ======/script_mode.tar.gz" + - "s3://my-s3-bucket/flytesnacks/development/SMJBJX7BQJ6MCOABLKQT5VZXVY======/script_mode.tar.gz" - "--dest-dir" - "/root" - "--" @@ -77,6 +77,9 @@ workflow: var: a arrayNode: node: + metadata: + retries: + retries: 3 taskNode: referenceId: name: task-1 diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go index 14fd1ba52..40ae70e7f 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -285,6 +285,9 @@ type MutableGateNodeStatus interface { type ExecutableArrayNodeStatus interface { GetArrayNodePhase() ArrayNodePhase GetSubNodePhases() bitarray.CompactArray + GetSubNodeTaskPhases() bitarray.CompactArray + GetSubNodeRetryAttempts() bitarray.CompactArray + GetSubNodeSystemFailures() bitarray.CompactArray } type MutableArrayNodeStatus interface { @@ -292,6 +295,9 @@ type MutableArrayNodeStatus interface { ExecutableArrayNodeStatus SetArrayNodePhase(phase ArrayNodePhase) SetSubNodePhases(subNodePhases bitarray.CompactArray) + SetSubNodeTaskPhases(subNodeTaskPhases bitarray.CompactArray) + SetSubNodeRetryAttempts(subNodeRetryAttempts bitarray.CompactArray) + SetSubNodeSystemFailures(subNodeSystemFailures bitarray.CompactArray) } type Mutable interface { diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go index c70821b97..16d9801a1 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/node_status.go +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -219,8 +219,11 @@ const ( type ArrayNodeStatus struct { MutableStruct - Phase ArrayNodePhase `json:"phase,omitempty"` - SubNodePhases bitarray.CompactArray `json:"subphase,omitempty"` + Phase ArrayNodePhase `json:"phase,omitempty"` + SubNodePhases bitarray.CompactArray `json:"subphase,omitempty"` + SubNodeTaskPhases bitarray.CompactArray `json:"subtphase,omitempty"` + SubNodeRetryAttempts bitarray.CompactArray 
`json:"subattempts,omitempty"` + SubNodeSystemFailures bitarray.CompactArray `json:"subsysfailures,omitempty"` } func (in *ArrayNodeStatus) GetArrayNodePhase() ArrayNodePhase { @@ -245,6 +248,39 @@ func (in *ArrayNodeStatus) SetSubNodePhases(subNodePhases bitarray.CompactArray) } } +func (in *ArrayNodeStatus) GetSubNodeTaskPhases() bitarray.CompactArray { + return in.SubNodeTaskPhases +} + +func (in *ArrayNodeStatus) SetSubNodeTaskPhases(subNodeTaskPhases bitarray.CompactArray) { + if in.SubNodeTaskPhases != subNodeTaskPhases { + in.SetDirty() + in.SubNodeTaskPhases = subNodeTaskPhases + } +} + +func (in *ArrayNodeStatus) GetSubNodeRetryAttempts() bitarray.CompactArray { + return in.SubNodeRetryAttempts +} + +func (in *ArrayNodeStatus) SetSubNodeRetryAttempts(subNodeRetryAttempts bitarray.CompactArray) { + if in.SubNodeRetryAttempts != subNodeRetryAttempts { + in.SetDirty() + in.SubNodeRetryAttempts = subNodeRetryAttempts + } +} + +func (in *ArrayNodeStatus) GetSubNodeSystemFailures() bitarray.CompactArray { + return in.SubNodeSystemFailures +} + +func (in *ArrayNodeStatus) SetSubNodeSystemFailures(subNodeSystemFailures bitarray.CompactArray) { + if in.SubNodeSystemFailures != subNodeSystemFailures { + in.SetDirty() + in.SubNodeSystemFailures = subNodeSystemFailures + } +} + type NodeStatus struct { MutableStruct Phase NodePhase `json:"phase,omitempty"` @@ -286,7 +322,9 @@ func (in *NodeStatus) IsDirty() bool { (in.TaskNodeStatus != nil && in.TaskNodeStatus.IsDirty()) || (in.DynamicNodeStatus != nil && in.DynamicNodeStatus.IsDirty()) || (in.WorkflowNodeStatus != nil && in.WorkflowNodeStatus.IsDirty()) || - (in.BranchStatus != nil && in.BranchStatus.IsDirty()) + (in.BranchStatus != nil && in.BranchStatus.IsDirty()) || + (in.GateNodeStatus != nil && in.GateNodeStatus.IsDirty()) || + (in.ArrayNodeStatus != nil && in.ArrayNodeStatus.IsDirty()) if isDirty { return true } diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go index a701d5216..cd58b7321 100644 --- a/pkg/controller/nodes/array/execution_context.go +++ b/pkg/controller/nodes/array/execution_context.go @@ -43,6 +43,7 @@ type arrayNodeExecutionContext struct { interfaces.NodeExecutionContext inputReader io.InputReader executionContext arrayExecutionContext + nodeStatus *v1alpha1.NodeStatus } func (a arrayNodeExecutionContext) InputReader() io.InputReader { @@ -53,36 +54,27 @@ func (a arrayNodeExecutionContext) ExecutionContext() executors.ExecutionContext return a.executionContext } -// TODO @hamersaw - overwrite everything -/* -inputReader -taskRecorder -nodeRecorder - need to add to nodeExecutionContext so we can override?!?! -maxParallelism - looks like we need: - ExecutionConfig.GetMaxParallelism - ExecutionContext.IncrementMaxParallelism -storage locations - dataPrefix? 
- -add environment variables for maptask execution either: - (1) in arrayExecutionContext if we use separate for each - (2) in arrayNodeExectionContext if we choose to use single DAG -*/ - -func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, subNodeIndex int) arrayNodeExecutionContext { +func (a arrayNodeExecutionContext) NodeStatus() v1alpha1.ExecutableNodeStatus { + return a.nodeStatus +} + +func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, subNodeIndex int, nodeStatus *v1alpha1.NodeStatus) arrayNodeExecutionContext { arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex) return arrayNodeExecutionContext{ NodeExecutionContext: nodeExecutionContext, inputReader: inputReader, executionContext: arrayExecutionContext, + nodeStatus: nodeStatus, } } type arrayNodeExecutionContextBuilder struct { - nCtxBuilder interfaces.NodeExecutionContextBuilder - subNodeID v1alpha1.NodeID - subNodeIndex int - inputReader io.InputReader + nCtxBuilder interfaces.NodeExecutionContextBuilder + subNodeID v1alpha1.NodeID + subNodeIndex int + subNodeStatus *v1alpha1.NodeStatus + inputReader io.InputReader } func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, @@ -96,19 +88,20 @@ func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context if currentNodeID == a.subNodeID { // overwrite NodeExecutionContext for ArrayNode execution - nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.subNodeIndex) + nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.subNodeIndex, a.subNodeStatus) } return nCtx, nil } func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder, subNodeID v1alpha1.NodeID, - subNodeIndex int, inputReader io.InputReader) interfaces.NodeExecutionContextBuilder { + subNodeIndex int, subNodeStatus *v1alpha1.NodeStatus, inputReader io.InputReader) interfaces.NodeExecutionContextBuilder { return &arrayNodeExecutionContextBuilder{ - nCtxBuilder: nCtxBuilder, - subNodeID: subNodeID, - subNodeIndex: subNodeIndex, - inputReader: inputReader, + nCtxBuilder: nCtxBuilder, + subNodeID: subNodeID, + subNodeIndex: subNodeIndex, + subNodeStatus: subNodeStatus, + inputReader: inputReader, } } diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 60718c0da..8712a8550 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -103,13 +103,24 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } // initialize ArrayNode state - arrayNodeState.SubNodePhases, err = bitarray.NewCompactArray(uint(size), bitarray.Item(len(core.Phases)-1)) - if err != nil { - return handler.UnknownTransition, err + maxAttempts := task.DefaultMaxAttempts + subNodeSpec := *arrayNode.GetSubNodeSpec() + if subNodeSpec.GetRetryStrategy() != nil && subNodeSpec.GetRetryStrategy().MinAttempts != nil { + maxAttempts = *subNodeSpec.GetRetryStrategy().MinAttempts } - // TODO @hamersaw - init SystemFailures and RetryAttempts as well - // do we want to abstract this? ie. arrayNodeState.GetStats(subNodeIndex) (phase, systemFailures, ...) 
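The initialization loop that follows sizes four bit-packed arrays, one slot per subnode, where each slot is only as wide as its maximum value requires. A minimal sketch of the flytestdlib bitarray API as it is used in this patch; the fanout and maxValue literals are illustrative, not values taken from the code:

package main

import (
	"fmt"

	"github.com/flyteorg/flytestdlib/bitarray"
)

func main() {
	// a maxValue of 8 fits in 4 bits, so 5000 subnode phases cost roughly
	// 2.5KB here instead of the 40KB a plain []uint64 would need
	phases, err := bitarray.NewCompactArray(5000, bitarray.Item(8))
	if err != nil {
		panic(err)
	}

	phases.SetItem(42, uint64(3)) // record subnode 42 entering phase 3
	fmt.Println(phases.GetItem(42), len(phases.GetItems()))
}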
+		for _, item := range []struct{arrayReference *bitarray.CompactArray; maxValue int}{
+			{arrayReference: &arrayNodeState.SubNodePhases, maxValue: len(core.Phases)-1}, // TODO @hamersaw - maxValue is for task phases
+			{arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases)-1},
+			{arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: maxAttempts},
+			{arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: maxAttempts},
+		} {
+
+			*item.arrayReference, err = bitarray.NewCompactArray(uint(size), bitarray.Item(item.maxValue))
+			if err != nil {
+				return handler.UnknownTransition, err
+			}
+		}

 		//fmt.Printf("HAMERSAW - created SubNodePhases with length '%d:%d'\n", size, len(arrayNodeState.SubNodePhases.GetItems()))
 		arrayNodeState.Phase = v1alpha1.ArrayNodePhaseExecuting
 	case v1alpha1.ArrayNodePhaseExecuting:
 		// process array node subnodes
 		for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
 			nodePhase := v1alpha1.NodePhase(nodePhaseUint64)
+			taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(i))
+
 			//fmt.Printf("HAMERSAW - evaluating node '%d' in phase '%d'\n", i, nodePhase)
+			fmt.Printf("HAMERSAW - evaluating node '%d' in node phase '%d' task phase '%d'\n", i, nodePhase, taskPhase)

 			// TODO @hamersaw fix - do not process nodes in terminal state
 			//if nodes.IsTerminalNodePhase(nodePhase) {
-			if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseFailed || nodePhase == v1alpha1.NodePhaseTimedOut || nodePhase == v1alpha1.NodePhaseSkipped {
+			if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseFailed || nodePhase == v1alpha1.NodePhaseTimedOut || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered {
 				continue
 			}
@@ -155,9 +169,12 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 			// currently just mocking based on node phase -> which works for all k8s plugins
 			// we can not pre-allocated a bit array because max size is 256B and with 5k fanout node state = 1.28MB
 			pluginStateBytes := a.pluginStateBytesStarted
-			if nodePhase == v1alpha1.NodePhaseQueued {
+			//if nodePhase == v1alpha1.NodePhaseQueued || nodePhase == v1alpha1.NodePhaseRetryableFailure {
+			if taskPhase == int(core.PhaseUndefined) || taskPhase == int(core.PhaseRetryableFailure) {
 				pluginStateBytes = a.pluginStateBytesNotStarted
 			}
+			// TODO @hamersaw NEED TO FIGURE THIS ^^^ OUT when working with node retries
+			// Failed to find the Resource with name: flytesnacks-development/array-test-40-node-1-n1-1. Error: pods \"array-test-40-node-1-n1-1\" not found

 			// we set subDataDir and subOutputDir to the node dirs because flytekit automatically appends subtask
 			// index.
however when we check completion status we need to manually append index - so in all cases @@ -176,46 +193,63 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu subNodeStatus := &v1alpha1.NodeStatus{ Phase: nodePhase, + DataDir: subDataDir, + OutputDir: subOutputDir, + Attempts: uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(i)), + SystemFailures: uint32(arrayNodeState.SubNodeSystemFailures.GetItem(i)), TaskNodeStatus: &v1alpha1.TaskNodeStatus{ - // TODO @hamersaw - to get caching working we need to set to Queued to force cache lookup + // TODO @hamersaw - to get caching working we need to set to Undefined to force cache lookup // once fastcache is done we dont care about the TaskNodeStatus - Phase: int(core.Phases[core.PhaseRunning]), + Phase: taskPhase, PluginState: pluginStateBytes, }, - DataDir: subDataDir, - OutputDir: subOutputDir, - // TODO @hamersaw - fill out systemFailures, retryAttempt etc } arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, &subNodeSpec, subNodeStatus) // execute subNode through RecursiveNodeHandler - arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(), subNodeID, i, inputReader) + arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(), subNodeID, i, subNodeStatus, inputReader) arrayNodeExecutor := a.nodeExecutor.WithNodeExecutionContextBuilder(arrayNodeExecutionContextBuilder) _, err = arrayNodeExecutor.RecursiveNodeHandler(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec) if err != nil { return handler.UnknownTransition, err } - //fmt.Printf("HAMERSAW - node phase transition %d -> %d\n", nodePhase, subNodeStatus.GetPhase()) + //fmt.Printf("HAMERSAW - '%d' transition node phase %d -> %d task phase '%d' -> '%d'\n", i, + // nodePhase, subNodeStatus.GetPhase(), taskPhase, subNodeStatus.GetTaskNodeStatus().GetPhase()) + arrayNodeState.SubNodePhases.SetItem(i, uint64(subNodeStatus.GetPhase())) + if subNodeStatus.GetTaskNodeStatus() == nil { + // TODO @hamersaw during retries we clear the GetTaskNodeStatus - so resetting task phase + arrayNodeState.SubNodeTaskPhases.SetItem(i, uint64(0)) + } else { + arrayNodeState.SubNodeTaskPhases.SetItem(i, uint64(subNodeStatus.GetTaskNodeStatus().GetPhase())) + } + arrayNodeState.SubNodeRetryAttempts.SetItem(i, uint64(subNodeStatus.GetAttempts())) + arrayNodeState.SubNodeSystemFailures.SetItem(i, uint64(subNodeStatus.GetSystemFailures())) } // TODO @hamersaw - determine summary phases - succeeded := true + successCount := 0 + failedCount := 0 for _, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) - if nodePhase != v1alpha1.NodePhaseSucceeded { - succeeded = false - break + switch nodePhase { + case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRecovered: + successCount++ + case v1alpha1.NodePhaseFailed: + failedCount++ } } - if succeeded { + if failedCount > 0 { + arrayNodeState.Phase = v1alpha1.ArrayNodePhaseFailing + } else if successCount == len(arrayNodeState.SubNodePhases.GetItems()) { arrayNodeState.Phase = v1alpha1.ArrayNodePhaseSucceeding } case v1alpha1.ArrayNodePhaseFailing: // TODO @hamersaw - abort everything! 
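The success/failure tally above is the whole aggregation story: any failed subnode moves the ArrayNode to Failing, and only a complete set of successes (including recovered subnodes) moves it to Succeeding. A standalone sketch of that reduction; the phase constants are illustrative and do not match the real v1alpha1 iota ordering:

package main

import "fmt"

type nodePhase int

const (
	phaseExecuting nodePhase = iota
	phaseSucceeded
	phaseFailed
	phaseRecovered
)

// summarize collapses per-subnode phases into an overall verdict: any
// failure wins, otherwise every subnode must succeed to finish.
func summarize(phases []nodePhase) string {
	success := 0
	for _, p := range phases {
		switch p {
		case phaseSucceeded, phaseRecovered:
			success++
		case phaseFailed:
			return "failing"
		}
	}
	if success == len(phases) {
		return "succeeding"
	}
	return "executing"
}

func main() {
	fmt.Println(summarize([]nodePhase{phaseSucceeded, phaseRecovered}))  // succeeding
	fmt.Println(summarize([]nodePhase{phaseSucceeded, phaseFailed}))    // failing
	fmt.Println(summarize([]nodePhase{phaseSucceeded, phaseExecuting})) // executing
}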
+ fmt.Printf("HAMERSAW TODO - abort ArrayNode!\n") case v1alpha1.ArrayNodePhaseSucceeding: outputLiterals := make(map[string]*idlcore.Literal) for i, _ := range arrayNodeState.SubNodePhases.GetItems() { @@ -235,7 +269,8 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu if err != nil { return handler.UnknownTransition, err } else if executionErr != nil { - return handler.UnknownTransition, executionErr + // TODO @hamersaw handle executionErr + //return handler.UnknownTransition, executionErr } // copy individual subNode output literals into a collection of output literals diff --git a/pkg/controller/nodes/interfaces/state.go b/pkg/controller/nodes/interfaces/state.go index e83bb8a65..4edde397b 100644 --- a/pkg/controller/nodes/interfaces/state.go +++ b/pkg/controller/nodes/interfaces/state.go @@ -50,8 +50,11 @@ type GateNodeState struct { } type ArrayNodeState struct { - Phase v1alpha1.ArrayNodePhase - SubNodePhases bitarray.CompactArray + Phase v1alpha1.ArrayNodePhase + SubNodePhases bitarray.CompactArray + SubNodeTaskPhases bitarray.CompactArray + SubNodeRetryAttempts bitarray.CompactArray + SubNodeSystemFailures bitarray.CompactArray } type NodeStateWriter interface { diff --git a/pkg/controller/nodes/node_state_manager.go b/pkg/controller/nodes/node_state_manager.go index 70ac7ac5f..9a9582723 100644 --- a/pkg/controller/nodes/node_state_manager.go +++ b/pkg/controller/nodes/node_state_manager.go @@ -160,6 +160,9 @@ func (n nodeStateManager) GetArrayNodeState() interfaces.ArrayNodeState { if an != nil { as.Phase = an.GetArrayNodePhase() as.SubNodePhases = an.GetSubNodePhases() + as.SubNodeTaskPhases = an.GetSubNodeTaskPhases() + as.SubNodeRetryAttempts = an.GetSubNodeRetryAttempts() + as.SubNodeSystemFailures = an.GetSubNodeSystemFailures() } return as } diff --git a/pkg/controller/nodes/transformers.go b/pkg/controller/nodes/transformers.go index 44e2dffbb..5819419ae 100644 --- a/pkg/controller/nodes/transformers.go +++ b/pkg/controller/nodes/transformers.go @@ -289,5 +289,8 @@ func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n interfaces.N t := s.GetOrCreateArrayNodeStatus() t.SetArrayNodePhase(na.Phase) t.SetSubNodePhases(na.SubNodePhases) + t.SetSubNodeTaskPhases(na.SubNodeTaskPhases) + t.SetSubNodeRetryAttempts(na.SubNodeRetryAttempts) + t.SetSubNodeSystemFailures(na.SubNodeSystemFailures) } } From 7fee1e61fc797ba54bbe4ebbd9a6b880d2aab6de Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Wed, 26 Apr 2023 21:51:04 -0500 Subject: [PATCH 13/62] parallelism working Signed-off-by: Daniel Rammer --- .../cmd/testdata/array-node.yaml.golden | 1 + pkg/apis/flyteworkflow/v1alpha1/array.go | 5 ++ pkg/apis/flyteworkflow/v1alpha1/iface.go | 1 + pkg/compiler/transformers/k8s/node.go | 1 + .../nodes/array/execution_context.go | 69 ++++++++++++------- pkg/controller/nodes/array/handler.go | 26 ++++--- pkg/controller/nodes/executor.go | 2 + pkg/controller/nodes/task/handler.go | 14 +++- 8 files changed, 81 insertions(+), 38 deletions(-) diff --git a/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden b/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden index b42fd58b2..0b2e67500 100644 --- a/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden +++ b/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden @@ -76,6 +76,7 @@ workflow: var: x var: a arrayNode: + parallelism: 1 node: metadata: retries: diff --git a/pkg/apis/flyteworkflow/v1alpha1/array.go b/pkg/apis/flyteworkflow/v1alpha1/array.go index 0809c1a64..cdb3a59d9 100644 --- 
a/pkg/apis/flyteworkflow/v1alpha1/array.go +++ b/pkg/apis/flyteworkflow/v1alpha1/array.go @@ -5,9 +5,14 @@ import ( type ArrayNodeSpec struct { SubNodeSpec *NodeSpec + Parallelism uint32 // TODO @hamersaw - fill out ArrayNodeSpec } func (a *ArrayNodeSpec) GetSubNodeSpec() *NodeSpec { return a.SubNodeSpec } + +func (a *ArrayNodeSpec) GetParallelism() uint32 { + return a.Parallelism +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go index 40ae70e7f..2f5c8cba5 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -257,6 +257,7 @@ type ExecutableGateNode interface { type ExecutableArrayNode interface { GetSubNodeSpec() *NodeSpec + GetParallelism() uint32 // TODO @hamersaw - complete ExecutableArrayNode } diff --git a/pkg/compiler/transformers/k8s/node.go b/pkg/compiler/transformers/k8s/node.go index d2dde8ef1..42a7f4562 100644 --- a/pkg/compiler/transformers/k8s/node.go +++ b/pkg/compiler/transformers/k8s/node.go @@ -167,6 +167,7 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile nodeSpec.Kind = v1alpha1.NodeKindArray nodeSpec.ArrayNode = &v1alpha1.ArrayNodeSpec{ SubNodeSpec: subNodeSpecs[0], + Parallelism: arrayNode.Parallelism, } default: if n.GetId() == v1alpha1.StartNodeID { diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go index cd58b7321..f3091aab0 100644 --- a/pkg/controller/nodes/array/execution_context.go +++ b/pkg/controller/nodes/array/execution_context.go @@ -2,6 +2,7 @@ package array import ( "context" + "fmt" "strconv" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" @@ -18,49 +19,63 @@ const ( type arrayExecutionContext struct { executors.ExecutionContext - executionConfig v1alpha1.ExecutionConfig + executionConfig v1alpha1.ExecutionConfig + currentParallelism *uint32 } -func (a arrayExecutionContext) GetExecutionConfig() v1alpha1.ExecutionConfig { +func (a *arrayExecutionContext) GetExecutionConfig() v1alpha1.ExecutionConfig { return a.executionConfig } -func newArrayExecutionContext(executionContext executors.ExecutionContext, subNodeIndex int) arrayExecutionContext { +func (a *arrayExecutionContext) CurrentParallelism() uint32 { + return *a.currentParallelism +} + +func (a *arrayExecutionContext) IncrementParallelism() uint32 { + *a.currentParallelism = *a.currentParallelism+1 + fmt.Printf("HAMERSAW - increment parallelism %d\n", *a.currentParallelism) + return *a.currentParallelism +} + +func newArrayExecutionContext(executionContext executors.ExecutionContext, subNodeIndex int, currentParallelism *uint32, maxParallelism uint32) *arrayExecutionContext { executionConfig := executionContext.GetExecutionConfig() if executionConfig.EnvironmentVariables == nil { executionConfig.EnvironmentVariables = make(map[string]string) } executionConfig.EnvironmentVariables[JobIndexVarName] = FlyteK8sArrayIndexVarName executionConfig.EnvironmentVariables[FlyteK8sArrayIndexVarName] = strconv.Itoa(subNodeIndex) + + executionConfig.MaxParallelism = maxParallelism - return arrayExecutionContext{ - ExecutionContext: executionContext, - executionConfig: executionConfig, + return &arrayExecutionContext{ + ExecutionContext: executionContext, + executionConfig: executionConfig, + currentParallelism: currentParallelism, } } type arrayNodeExecutionContext struct { interfaces.NodeExecutionContext inputReader io.InputReader - executionContext arrayExecutionContext + executionContext 
*arrayExecutionContext nodeStatus *v1alpha1.NodeStatus } -func (a arrayNodeExecutionContext) InputReader() io.InputReader { +func (a *arrayNodeExecutionContext) InputReader() io.InputReader { return a.inputReader } -func (a arrayNodeExecutionContext) ExecutionContext() executors.ExecutionContext { +func (a *arrayNodeExecutionContext) ExecutionContext() executors.ExecutionContext { return a.executionContext } -func (a arrayNodeExecutionContext) NodeStatus() v1alpha1.ExecutableNodeStatus { +func (a *arrayNodeExecutionContext) NodeStatus() v1alpha1.ExecutableNodeStatus { return a.nodeStatus } -func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, subNodeIndex int, nodeStatus *v1alpha1.NodeStatus) arrayNodeExecutionContext { - arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex) - return arrayNodeExecutionContext{ +func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, subNodeIndex int, nodeStatus *v1alpha1.NodeStatus, currentParallelism *uint32, maxParallelism uint32) *arrayNodeExecutionContext { + arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex, currentParallelism, maxParallelism) + return &arrayNodeExecutionContext{ NodeExecutionContext: nodeExecutionContext, inputReader: inputReader, executionContext: arrayExecutionContext, @@ -70,11 +85,13 @@ func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionC type arrayNodeExecutionContextBuilder struct { - nCtxBuilder interfaces.NodeExecutionContextBuilder - subNodeID v1alpha1.NodeID - subNodeIndex int - subNodeStatus *v1alpha1.NodeStatus - inputReader io.InputReader + nCtxBuilder interfaces.NodeExecutionContextBuilder + subNodeID v1alpha1.NodeID + subNodeIndex int + subNodeStatus *v1alpha1.NodeStatus + inputReader io.InputReader + currentParallelism *uint32 + maxParallelism uint32 } func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, @@ -88,20 +105,22 @@ func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context if currentNodeID == a.subNodeID { // overwrite NodeExecutionContext for ArrayNode execution - nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.subNodeIndex, a.subNodeStatus) + nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.subNodeIndex, a.subNodeStatus, a.currentParallelism, a.maxParallelism) } return nCtx, nil } func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder, subNodeID v1alpha1.NodeID, - subNodeIndex int, subNodeStatus *v1alpha1.NodeStatus, inputReader io.InputReader) interfaces.NodeExecutionContextBuilder { + subNodeIndex int, subNodeStatus *v1alpha1.NodeStatus, inputReader io.InputReader, currentParallelism *uint32, maxParallelism uint32) interfaces.NodeExecutionContextBuilder { return &arrayNodeExecutionContextBuilder{ - nCtxBuilder: nCtxBuilder, - subNodeID: subNodeID, - subNodeIndex: subNodeIndex, - subNodeStatus: subNodeStatus, - inputReader: inputReader, + nCtxBuilder: nCtxBuilder, + subNodeID: subNodeID, + subNodeIndex: subNodeIndex, + subNodeStatus: subNodeStatus, + inputReader: inputReader, + currentParallelism: currentParallelism, + maxParallelism: maxParallelism, } } diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 8712a8550..73b674b6b 100644 --- 
a/pkg/controller/nodes/array/handler.go
+++ b/pkg/controller/nodes/array/handler.go
@@ -69,6 +69,7 @@ func (a *arrayNodeHandler) FinalizeRequired() bool {
 func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) {
 	arrayNode := nCtx.Node().GetArrayNode()
 	arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState()
+	fmt.Printf("HAMERSAW - executing ArrayNode\n")

 	switch arrayNodeState.Phase {
 	case v1alpha1.ArrayNodePhaseNone:
@@ -122,16 +123,17 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 			}
 		}

-		//fmt.Printf("HAMERSAW - created SubNodePhases with length '%d:%d'\n", size, len(arrayNodeState.SubNodePhases.GetItems()))
+		// transition ArrayNode to `ArrayNodePhaseExecuting`
 		arrayNodeState.Phase = v1alpha1.ArrayNodePhaseExecuting
 	case v1alpha1.ArrayNodePhaseExecuting:
 		// process array node subnodes
+		currentParallelism := uint32(0)
 		for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
+			fmt.Printf("HAMERSAW - current parallelism %d '%d' max %d \n", i, currentParallelism, arrayNode.GetParallelism())
 			nodePhase := v1alpha1.NodePhase(nodePhaseUint64)
 			taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(i))

-			//fmt.Printf("HAMERSAW - evaluating node '%d' in phase '%d'\n", i, nodePhase)
-			fmt.Printf("HAMERSAW - evaluating node '%d' in node phase '%d' task phase '%d'\n", i, nodePhase, taskPhase)
+			//fmt.Printf("HAMERSAW - evaluating node '%d' in node phase '%d' task phase '%d'\n", i, nodePhase, taskPhase)

 			// TODO @hamersaw fix - do not process nodes in terminal state
 			//if nodes.IsTerminalNodePhase(nodePhase) {
@@ -164,7 +166,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 			subNodeSpec.ID = subNodeID
 			subNodeSpec.Name = subNodeID

-			// TODO @hamersaw - store task phase and use to mock plugin state
 			// TODO - if we want to support more plugin types we need to figure out the best way to store plugin state
 			// currently just mocking based on node phase -> which works for all k8s plugins
 			// we can not pre-allocated a bit array because max size is 256B and with 5k fanout node state = 1.28MB
 			pluginStateBytes := a.pluginStateBytesStarted
 			//if nodePhase == v1alpha1.NodePhaseQueued || nodePhase == v1alpha1.NodePhaseRetryableFailure {
 			if taskPhase == int(core.PhaseUndefined) || taskPhase == int(core.PhaseRetryableFailure) {
 				pluginStateBytes = a.pluginStateBytesNotStarted
 			}
-			// TODO @hamersaw NEED TO FIGURE THIS ^^^ OUT when working with node retries
-			// Failed to find the Resource with name: flytesnacks-development/array-test-40-node-1-n1-1. Error: pods \"array-test-40-node-1-n1-1\" not found

 			// we set subDataDir and subOutputDir to the node dirs because flytekit automatically appends subtask
 			// index.
however when we check completion status we need to manually append index - so in all cases
 			// where the node phase is not Queued (ie. task handler will launch task and init flytekit params) we
 			// append the subtask index.
@@ -198,8 +197,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 				Attempts:       uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(i)),
 				SystemFailures: uint32(arrayNodeState.SubNodeSystemFailures.GetItem(i)),
 				TaskNodeStatus: &v1alpha1.TaskNodeStatus{
-					// TODO @hamersaw - to get caching working we need to set to Undefined to force cache lookup
-					// once fastcache is done we dont care about the TaskNodeStatus
 					Phase:       taskPhase,
 					PluginState: pluginStateBytes,
 				},
@@ -208,9 +205,16 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 			arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, &subNodeSpec, subNodeStatus)

 			// execute subNode through RecursiveNodeHandler
-			arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(), subNodeID, i, subNodeStatus, inputReader)
+			arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(),
+				subNodeID, i, subNodeStatus, inputReader, &currentParallelism, arrayNode.GetParallelism())
+			arrayExecutionContext := newArrayExecutionContext(nCtx.ExecutionContext(), i, &currentParallelism, arrayNode.GetParallelism())
+			/*arrayNodeExecutionContext, err := arrayNodeExecutionContextBuilder.BuildNodeExecutionContext(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, subNodeID)
+			if err != nil {
+				return handler.UnknownTransition, err
+			}*/
+
 			arrayNodeExecutor := a.nodeExecutor.WithNodeExecutionContextBuilder(arrayNodeExecutionContextBuilder)
-			_, err = arrayNodeExecutor.RecursiveNodeHandler(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec)
+			_, err = arrayNodeExecutor.RecursiveNodeHandler(ctx, arrayExecutionContext, &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec)
 			if err != nil {
 				return handler.UnknownTransition, err
 			}
@@ -229,7 +233,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 			arrayNodeState.SubNodeSystemFailures.SetItem(i, uint64(subNodeStatus.GetSystemFailures()))
 		}

-		// TODO @hamersaw - determine summary phases
+		// process phases of subNodes to determine overall `ArrayNode` phase
 		successCount := 0
 		failedCount := 0
 		for _, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go
index b39787fc0..00f8a1b8c 100644
--- a/pkg/controller/nodes/executor.go
+++ b/pkg/controller/nodes/executor.go
@@ -139,6 +139,7 @@ func canHandleNode(phase v1alpha1.NodePhase) bool {
 func IsMaxParallelismAchieved(ctx context.Context, currentNode v1alpha1.ExecutableNode, currentPhase v1alpha1.NodePhase,
 	execContext executors.ExecutionContext) bool {
 	maxParallelism := execContext.GetExecutionConfig().MaxParallelism
+	//fmt.Printf("HAMERSAW - maxParallelism %d\n", maxParallelism)
 	if maxParallelism == 0 {
 		logger.Debugf(ctx, "Parallelism control disabled")
 		return false
@@ -193,6 +194,7 @@ func (c *recursiveNodeExecutor) RecursiveNodeHandler(ctx context.Context, execCo
 		return interfaces.NodeStatusRunning, nil
 	}

+	fmt.Printf("HAMERSAW executing %s %+v\n", currentNode.GetID(), IsMaxParallelismAchieved(ctx, currentNode, nodePhase, execContext))
 	if IsMaxParallelismAchieved(ctx, currentNode, nodePhase, execContext) {
 		return interfaces.NodeStatusRunning, nil
 	}
diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go
index
572261d62..6fd7505fe 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ -554,6 +554,16 @@ func (t Handler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContex ts := nCtx.NodeStateReader().GetTaskNodeState() pluginTrns := &pluginRequestedTransition{} + + // TODO @hamersaw - does this introduce issues in cache hits?!?! + // need to make sure the plugin transition does not block other workflows from progressing + defer func() { + if pluginTrns != nil && !pluginTrns.pInfo.Phase().IsTerminal() { + eCtx := nCtx.ExecutionContext() + logger.Infof(ctx, "Parallelism now set to [%d].", eCtx.IncrementParallelism()) + } + }() + // We will start with the assumption that catalog is disabled pluginTrns.PopulateCacheInfo(catalog.NewFailedCatalogEntry(catalog.NewStatus(core.CatalogCacheStatus_CACHE_DISABLED, nil))) @@ -784,10 +794,10 @@ func (t Handler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContex return handler.UnknownTransition, err } - if !pluginTrns.pInfo.Phase().IsTerminal() { + /*if !pluginTrns.pInfo.Phase().IsTerminal() { eCtx := nCtx.ExecutionContext() logger.Infof(ctx, "Parallelism now set to [%d].", eCtx.IncrementParallelism()) - } + }*/ return pluginTrns.FinalTransition(ctx) } From dc0c1644af74cf6738672da0d5390a9f947d8e0c Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 27 Apr 2023 10:11:32 -0500 Subject: [PATCH 14/62] cache and cache_serialize working - first new functionality in maptask Signed-off-by: Daniel Rammer --- .../array-node-cache-serialize.yaml.golden | 91 +++++++++++++++++++ .../cmd/testdata/array-node-cache.yaml.golden | 91 +++++++++++++++++++ .../testdata/array-node-inputs.yaml.golden | 3 + .../nodes/array/execution_context.go | 2 - pkg/controller/nodes/array/handler.go | 27 +++--- pkg/controller/nodes/executor.go | 2 - 6 files changed, 196 insertions(+), 20 deletions(-) create mode 100644 cmd/kubectl-flyte/cmd/testdata/array-node-cache-serialize.yaml.golden create mode 100644 cmd/kubectl-flyte/cmd/testdata/array-node-cache.yaml.golden diff --git a/cmd/kubectl-flyte/cmd/testdata/array-node-cache-serialize.yaml.golden b/cmd/kubectl-flyte/cmd/testdata/array-node-cache-serialize.yaml.golden new file mode 100644 index 000000000..9eff40124 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/testdata/array-node-cache-serialize.yaml.golden @@ -0,0 +1,91 @@ +tasks: +- container: + args: + - "pyflyte-fast-execute" + - "--additional-distribution" + - "s3://my-s3-bucket/flytesnacks/development/SMJBJX7BQJ6MCOABLKQT5VZXVY======/script_mode.tar.gz" + - "--dest-dir" + - "/root" + - "--" + - "pyflyte-map-execute" + - "--inputs" + - "{{.input}}" + - "--output-prefix" + - "{{.outputPrefix}}" + - "--raw-output-data-prefix" + - "{{.rawOutputDataPrefix}}" + - "--checkpoint-path" + - "{{.checkpointOutputPrefix}}" + - "--prev-checkpoint" + - "{{.prevCheckpointPrefix}}" + - "--resolver" + - "MapTaskResolver" + - "--" + - "vars" + - "" + - "resolver" + - "flytekit.core.python_auto_container.default_task_resolver" + - "task-module" + - "map-task" + - "task-name" + - "a_mappable_task" + image: "array-node:ee1ba227aa95447d04bb1761691b4d97749642dc" + resources: + limits: + - name: 1 + value: "1" + - name: 3 + value: "500Mi" + requests: + - name: 1 + value: "1" + - name: 3 + value: "300Mi" + id: + name: task-1 + project: flytesnacks + domain: development + metadata: + discoverable: true + discovery_version: "1.0" + cache_serializable: true + interface: + inputs: + variables: + a: + type: + simple: INTEGER + outputs: + variables: + o0: + 
type: + simple: STRING +workflow: + id: + name: workflow-with-array-node + interface: + inputs: + variables: + x: + type: + collectionType: + simple: INTEGER + nodes: + - id: node-1 + inputs: + - binding: + promise: + node_id: start-node + var: x + var: a + arrayNode: + parallelism: 0 + node: + metadata: + retries: + retries: 3 + taskNode: + referenceId: + name: task-1 + project: flytesnacks + domain: development diff --git a/cmd/kubectl-flyte/cmd/testdata/array-node-cache.yaml.golden b/cmd/kubectl-flyte/cmd/testdata/array-node-cache.yaml.golden new file mode 100644 index 000000000..bb07a9dd5 --- /dev/null +++ b/cmd/kubectl-flyte/cmd/testdata/array-node-cache.yaml.golden @@ -0,0 +1,91 @@ +tasks: +- container: + args: + - "pyflyte-fast-execute" + - "--additional-distribution" + - "s3://my-s3-bucket/flytesnacks/development/SMJBJX7BQJ6MCOABLKQT5VZXVY======/script_mode.tar.gz" + - "--dest-dir" + - "/root" + - "--" + - "pyflyte-map-execute" + - "--inputs" + - "{{.input}}" + - "--output-prefix" + - "{{.outputPrefix}}" + - "--raw-output-data-prefix" + - "{{.rawOutputDataPrefix}}" + - "--checkpoint-path" + - "{{.checkpointOutputPrefix}}" + - "--prev-checkpoint" + - "{{.prevCheckpointPrefix}}" + - "--resolver" + - "MapTaskResolver" + - "--" + - "vars" + - "" + - "resolver" + - "flytekit.core.python_auto_container.default_task_resolver" + - "task-module" + - "map-task" + - "task-name" + - "a_mappable_task" + image: "array-node:ee1ba227aa95447d04bb1761691b4d97749642dc" + resources: + limits: + - name: 1 + value: "1" + - name: 3 + value: "500Mi" + requests: + - name: 1 + value: "1" + - name: 3 + value: "300Mi" + id: + name: task-1 + project: flytesnacks + domain: development + metadata: + discoverable: true + discovery_version: "1.0" + cache_serializable: false + interface: + inputs: + variables: + a: + type: + simple: INTEGER + outputs: + variables: + o0: + type: + simple: STRING +workflow: + id: + name: workflow-with-array-node + interface: + inputs: + variables: + x: + type: + collectionType: + simple: INTEGER + nodes: + - id: node-1 + inputs: + - binding: + promise: + node_id: start-node + var: x + var: a + arrayNode: + parallelism: 1 + node: + metadata: + retries: + retries: 3 + taskNode: + referenceId: + name: task-1 + project: flytesnacks + domain: development diff --git a/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden b/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden index 3f9d69172..42e176686 100755 --- a/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden +++ b/cmd/kubectl-flyte/cmd/testdata/array-node-inputs.yaml.golden @@ -8,3 +8,6 @@ literals: - scalar: primitive: integer: "2" + - scalar: + primitive: + integer: "3" diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go index f3091aab0..cc126796a 100644 --- a/pkg/controller/nodes/array/execution_context.go +++ b/pkg/controller/nodes/array/execution_context.go @@ -2,7 +2,6 @@ package array import ( "context" - "fmt" "strconv" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" @@ -33,7 +32,6 @@ func (a *arrayExecutionContext) CurrentParallelism() uint32 { func (a *arrayExecutionContext) IncrementParallelism() uint32 { *a.currentParallelism = *a.currentParallelism+1 - fmt.Printf("HAMERSAW - increment parallelism %d\n", *a.currentParallelism) return *a.currentParallelism } diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 73b674b6b..297644f5d 100644 --- a/pkg/controller/nodes/array/handler.go 
+++ b/pkg/controller/nodes/array/handler.go
@@ -69,7 +69,6 @@ func (a *arrayNodeHandler) FinalizeRequired() bool {
 func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) {
 	arrayNode := nCtx.Node().GetArrayNode()
 	arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState()
-	fmt.Printf("HAMERSAW - executing ArrayNode\n")

 	switch arrayNodeState.Phase {
 	case v1alpha1.ArrayNodePhaseNone:
@@ -129,7 +128,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 		// process array node subnodes
 		currentParallelism := uint32(0)
 		for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
-			fmt.Printf("HAMERSAW - current parallelism %d '%d' max %d \n", i, currentParallelism, arrayNode.GetParallelism())
 			nodePhase := v1alpha1.NodePhase(nodePhaseUint64)
 			taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(i))

@@ -141,14 +139,11 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 				continue
 			}

-			// initialize input reader if NodePhaseNotyetStarted or NodePhaseSucceeding for cache lookup and population
-			var inputLiteralMap *idlcore.LiteralMap
-			var err error
-			if nodePhase == v1alpha1.NodePhaseNotYetStarted || nodePhase == v1alpha1.NodePhaseSucceeding {
-				inputLiteralMap, err = constructLiteralMap(ctx, nCtx.InputReader(), i)
-				if err != nil {
-					return handler.UnknownTransition, err
-				}
+			// need to initialize the inputReader every time to ensure TaskHandler can access for cache lookups / population
+			// TODO @hamersaw - once fastcache is implemented this can be optimized to only initialize on NodePhaseUndefined and NodePhaseSucceeding
+			inputLiteralMap, err := constructLiteralMap(ctx, nCtx.InputReader(), i)
+			if err != nil {
+				return handler.UnknownTransition, err
 			}

 			inputReader := newStaticInputReader(nCtx.InputReader(), inputLiteralMap)
@@ -180,11 +175,15 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 			// where the node phase is not Queued (ie. task handler will launch task and init flytekit params) we
 			// append the subtask index.
var subDataDir, subOutputDir storage.DataReference
-			if nodePhase == v1alpha1.NodePhaseQueued {
+			/*if nodePhase == v1alpha1.NodePhaseQueued {
 				subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx)
 			} else {
 				subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx, strconv.Itoa(i))
-			}
+			}*/
+			// TODO @hamersaw - this is a problem because cache lookups happen in NodePhaseQueued
+			// so the cache hit items will be written to the wrong location
+			// can we just change how flytekit appends the index onto the location?!?
+			subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx, strconv.Itoa(i))

 			if err != nil {
 				return handler.UnknownTransition, err
@@ -208,10 +207,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 			arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(),
 				subNodeID, i, subNodeStatus, inputReader, &currentParallelism, arrayNode.GetParallelism())
 			arrayExecutionContext := newArrayExecutionContext(nCtx.ExecutionContext(), i, &currentParallelism, arrayNode.GetParallelism())
-			/*arrayNodeExecutionContext, err := arrayNodeExecutionContextBuilder.BuildNodeExecutionContext(ctx, nCtx.ExecutionContext(), &arrayNodeLookup, subNodeID)
-			if err != nil {
-				return handler.UnknownTransition, err
-			}*/

 			arrayNodeExecutor := a.nodeExecutor.WithNodeExecutionContextBuilder(arrayNodeExecutionContextBuilder)
 			_, err = arrayNodeExecutor.RecursiveNodeHandler(ctx, arrayExecutionContext, &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec)
diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go
index 00f8a1b8c..b39787fc0 100644
--- a/pkg/controller/nodes/executor.go
+++ b/pkg/controller/nodes/executor.go
@@ -139,7 +139,6 @@ func canHandleNode(phase v1alpha1.NodePhase) bool {
 func IsMaxParallelismAchieved(ctx context.Context, currentNode v1alpha1.ExecutableNode, currentPhase v1alpha1.NodePhase,
 	execContext executors.ExecutionContext) bool {
 	maxParallelism := execContext.GetExecutionConfig().MaxParallelism
-	//fmt.Printf("HAMERSAW - maxParallelism %d\n", maxParallelism)
 	if maxParallelism == 0 {
 		logger.Debugf(ctx, "Parallelism control disabled")
 		return false
@@ -194,7 +193,6 @@ func (c *recursiveNodeExecutor) RecursiveNodeHandler(ctx context.Context, execCo
 		return interfaces.NodeStatusRunning, nil
 	}

-	fmt.Printf("HAMERSAW executing %s %+v\n", currentNode.GetID(), IsMaxParallelismAchieved(ctx, currentNode, nodePhase, execContext))
 	if IsMaxParallelismAchieved(ctx, currentNode, nodePhase, execContext) {
 		return interfaces.NodeStatusRunning, nil
 	}

From 0f880cd0e586452a0dd8cb1a24bfa7109ddd9614 Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Fri, 28 Apr 2023 10:18:55 -0500
Subject: [PATCH 15/62] adding implementation notes

Signed-off-by: Daniel Rammer
---
 cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden | 2 +-
 pkg/controller/nodes/array/handler.go                 | 9 +++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden b/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden
index 0b2e67500..0b370417a 100644
--- a/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden
+++ b/cmd/kubectl-flyte/cmd/testdata/array-node.yaml.golden
@@ -29,7 +29,7 @@ tasks:
     - "map-task"
     - "task-name"
     - "a_mappable_task"
-    image: cr.flyte.org/flyteorg/flytekit:py3.10-latest
+    image: "array-node:ee1ba227aa95447d04bb1761691b4d97749642dc"
     resources:
       limits:
       - name: 1
diff --git a/pkg/controller/nodes/array/handler.go
b/pkg/controller/nodes/array/handler.go
index 297644f5d..2b4cb7402 100644
--- a/pkg/controller/nodes/array/handler.go
+++ b/pkg/controller/nodes/array/handler.go
@@ -151,6 +151,8 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 			// if node has not yet started we automatically set to NodePhaseQueued to skip input resolution
 			if nodePhase == v1alpha1.NodePhaseNotYetStarted {
 				// TODO @hamersaw how does this work with fastcache?
+				// to support fastcache we'll need to override the bindings to BindingScalars for the input resolution on the nCtx
+				// that way resolution is just reading a literal ... but does this still write a file then?!?
 				nodePhase = v1alpha1.NodePhaseQueued
 			}

@@ -174,8 +176,8 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 			// index. however when we check completion status we need to manually append index - so in all cases
 			// where the node phase is not Queued (ie. task handler will launch task and init flytekit params) we
 			// append the subtask index.
-			var subDataDir, subOutputDir storage.DataReference
-			/*if nodePhase == v1alpha1.NodePhaseQueued {
+			/*var subDataDir, subOutputDir storage.DataReference
+			if nodePhase == v1alpha1.NodePhaseQueued {
 				subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx)
 			} else {
 				subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx, strconv.Itoa(i))
@@ -183,8 +185,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 			// TODO @hamersaw - this is a problem because cache lookups happen in NodePhaseQueued
 			// so the cache hit items will be written to the wrong location
 			// can we just change how flytekit appends the index onto the location?!?
-			subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx, strconv.Itoa(i))
-
+			subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i))
 			if err != nil {
 				return handler.UnknownTransition, err
 			}

From e7406276e27558497ca644c2d3ada221ae26e505 Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Wed, 24 May 2023 07:56:27 -0500
Subject: [PATCH 16/62] removed eventing from subtasks

Signed-off-by: Daniel Rammer
---
 pkg/compiler/transformers/k8s/node.go         |  8 ++
 .../nodes/array/execution_context.go          | 28 +++++-
 pkg/controller/nodes/array/handler.go         |  4 +
 pkg/controller/nodes/executor.go              | 20 ++--
 pkg/controller/nodes/interfaces/node.go       |  1 -
 .../nodes/interfaces/node_exec_context.go     |  8 +-
 pkg/controller/nodes/node_exec_context.go     | 92 +++++++++++++++++--
 ...recorder.go => task_event_recorder.go.bak} |  0
 ...est.go => task_event_recorder_test.go.bak} |  0
 9 files changed, 139 insertions(+), 22 deletions(-)
 rename pkg/controller/nodes/{task_event_recorder.go => task_event_recorder.go.bak} (100%)
 rename pkg/controller/nodes/{task_event_recorder_test.go => task_event_recorder_test.go.bak} (100%)

diff --git a/pkg/compiler/transformers/k8s/node.go b/pkg/compiler/transformers/k8s/node.go
index 53175c38d..78be27ec7 100644
--- a/pkg/compiler/transformers/k8s/node.go
+++ b/pkg/compiler/transformers/k8s/node.go
@@ -167,6 +167,14 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile
 			SubNodeSpec: subNodeSpecs[0],
 			Parallelism: arrayNode.Parallelism,
 		}
+
+		// TODO @hamersaw hack - should not be necessary, should be set in flytekit
+		for _, binding := range nodeSpec.InputBindings {
+			switch b := binding.Binding.Binding.Value.(type) {
+			case *core.BindingData_Promise:
+				b.Promise.NodeId = "start-node"
+			}
+		}
 	default:
 		if n.GetId() ==
v1alpha1.StartNodeID { nodeSpec.Kind = v1alpha1.NodeKindStart diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go index cc126796a..80ff6d72c 100644 --- a/pkg/controller/nodes/array/execution_context.go +++ b/pkg/controller/nodes/array/execution_context.go @@ -4,9 +4,12 @@ import ( "context" "strconv" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" ) @@ -52,21 +55,36 @@ func newArrayExecutionContext(executionContext executors.ExecutionContext, subNo } } +type arrayEventRecorder struct {} + +func (a *arrayEventRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { + return nil +} + +func (a *arrayEventRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { + return nil +} + type arrayNodeExecutionContext struct { interfaces.NodeExecutionContext - inputReader io.InputReader + eventRecorder *arrayEventRecorder executionContext *arrayExecutionContext + inputReader io.InputReader nodeStatus *v1alpha1.NodeStatus } -func (a *arrayNodeExecutionContext) InputReader() io.InputReader { - return a.inputReader +func (a *arrayNodeExecutionContext) EventsRecorder() interfaces.EventRecorder { + return a.eventRecorder } func (a *arrayNodeExecutionContext) ExecutionContext() executors.ExecutionContext { return a.executionContext } +func (a *arrayNodeExecutionContext) InputReader() io.InputReader { + return a.inputReader +} + func (a *arrayNodeExecutionContext) NodeStatus() v1alpha1.ExecutableNodeStatus { return a.nodeStatus } @@ -75,13 +93,13 @@ func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionC arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex, currentParallelism, maxParallelism) return &arrayNodeExecutionContext{ NodeExecutionContext: nodeExecutionContext, - inputReader: inputReader, + eventRecorder: &arrayEventRecorder{}, executionContext: arrayExecutionContext, + inputReader: inputReader, nodeStatus: nodeStatus, } } - type arrayNodeExecutionContextBuilder struct { nCtxBuilder interfaces.NodeExecutionContextBuilder subNodeID v1alpha1.NodeID diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 2b4cb7402..2cce80c0b 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -109,6 +109,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu maxAttempts = *subNodeSpec.GetRetryStrategy().MinAttempts } + fmt.Printf("HAMERSAW - maxAttempts %d\n", maxAttempts) for _, item := range []struct{arrayReference *bitarray.CompactArray; maxValue int}{ {arrayReference: &arrayNodeState.SubNodePhases, maxValue: len(core.Phases)-1}, // TODO @hamersaw - maxValue is for task phases {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases)-1}, @@ -205,6 +206,8 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, &subNodeSpec, subNodeStatus) // execute subNode through RecursiveNodeHandler + // 
TODO @hamersaw - if recursiveNodeHandler is exported then can we just create a new one without needing the + // new GetNodeExecutionContextBuilder and WithNodeExecutionContextBuilder functions? arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(), subNodeID, i, subNodeStatus, inputReader, ¤tParallelism, arrayNode.GetParallelism()) arrayExecutionContext := newArrayExecutionContext(nCtx.ExecutionContext(), i, ¤tParallelism, arrayNode.GetParallelism()) @@ -225,6 +228,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } else { arrayNodeState.SubNodeTaskPhases.SetItem(i, uint64(subNodeStatus.GetTaskNodeStatus().GetPhase())) } + fmt.Printf("HAMERSAW - setting %d to %d\n", i, uint64(subNodeStatus.GetAttempts())) arrayNodeState.SubNodeRetryAttempts.SetItem(i, uint64(subNodeStatus.GetAttempts())) arrayNodeState.SubNodeSystemFailures.SetItem(i, uint64(subNodeStatus.GetSystemFailures())) } diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index bad2a9a96..72cf83326 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -502,7 +502,7 @@ func (c *nodeExecutor) RecordTransitionLatency(ctx context.Context, dag executor } } -func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *event.NodeExecutionEvent) error { +/*func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *event.NodeExecutionEvent) error { if nodeEvent == nil { return fmt.Errorf("event recording attempt of Nil Node execution event") } @@ -534,7 +534,7 @@ func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *eve } } return err -} +}*/ func (c *nodeExecutor) recoverInputs(ctx context.Context, nCtx interfaces.NodeExecutionContext, recovered *admin.NodeExecution, recoveredData *admin.NodeExecutionGetDataResponse) (*core.LiteralMap, error) { @@ -911,7 +911,8 @@ func (c *nodeExecutor) Abort(ctx context.Context, h handler.Node, nCtx interface nodeExecutionID.NodeId = currentNodeUniqueID } - err := c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{ + //err := c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{ + err := nCtx.EventsRecorder().RecordNodeEvent(ctx, &event.NodeExecutionEvent{ Id: nodeExecutionID, Phase: core.NodeExecution_ABORTED, OccurredAt: ptypes.TimestampNow(), @@ -923,7 +924,7 @@ func (c *nodeExecutor) Abort(ctx context.Context, h handler.Node, nCtx interface }, ProducerId: c.clusterID, ReportedAt: ptypes.TimestampNow(), - }) + }, c.eventConfig) if err != nil && !eventsErr.IsNotFound(err) && !eventsErr.IsEventIncompatibleClusterError(err) { if errors2.IsCausedBy(err, errors.IllegalStateError) { logger.Debugf(ctx, "Failed to record abort event due to illegal state transition. Ignoring the error. 
Error: %v", err) @@ -975,7 +976,8 @@ func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executor if err != nil { return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") } - err = c.IdempotentRecordEvent(ctx, nev) + //err = c.IdempotentRecordEvent(ctx, nev) + err = nCtx.EventsRecorder().RecordNodeEvent(ctx, nev, c.eventConfig) if err != nil { logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") @@ -1092,7 +1094,8 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx inter return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") } - err = c.IdempotentRecordEvent(ctx, nev) + //err = c.IdempotentRecordEvent(ctx, nev) + err = nCtx.EventsRecorder().RecordNodeEvent(ctx, nev, c.eventConfig) if err != nil { if eventsErr.IsTooLarge(err) { // With large enough dynamic task fanouts the reported node event, which contains the compiled @@ -1101,7 +1104,8 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx inter np = v1alpha1.NodePhaseFailing p = handler.PhaseInfoFailure(core.ExecutionError_USER, "NodeFailed", err.Error(), p.GetInfo()) - err = c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{ + //err = c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{ + err = nCtx.EventsRecorder().RecordNodeEvent(ctx, &event.NodeExecutionEvent{ Id: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), Phase: core.NodeExecution_FAILED, OccurredAt: ptypes.TimestampNow(), @@ -1112,7 +1116,7 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx inter }, }, ReportedAt: ptypes.TimestampNow(), - }) + }, c.eventConfig) if err != nil { return interfaces.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") diff --git a/pkg/controller/nodes/interfaces/node.go b/pkg/controller/nodes/interfaces/node.go index e279de1c7..0f1b56e22 100644 --- a/pkg/controller/nodes/interfaces/node.go +++ b/pkg/controller/nodes/interfaces/node.go @@ -96,7 +96,6 @@ type Node interface { // TODO @hamersaw - docs type NodeExecutionContextBuilder interface { - //BuildNodeExecutionContext(execContext executors.ExecutionContext) NodeExecutionContext BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (NodeExecutionContext, error) } diff --git a/pkg/controller/nodes/interfaces/node_exec_context.go b/pkg/controller/nodes/interfaces/node_exec_context.go index db33c303c..3b8afc384 100644 --- a/pkg/controller/nodes/interfaces/node_exec_context.go +++ b/pkg/controller/nodes/interfaces/node_exec_context.go @@ -24,6 +24,11 @@ type TaskReader interface { GetTaskID() *core.Identifier } +type EventRecorder interface { + events.TaskEventRecorder + events.NodeEventRecorder +} + type NodeExecutionMetadata interface { GetOwnerID() types.NamespacedName GetNodeExecutionID() *core.NodeExecutionIdentifier @@ -49,7 +54,8 @@ type NodeExecutionContext interface { DataStore() *storage.DataStore InputReader() io.InputReader - EventsRecorder() events.TaskEventRecorder + //EventsRecorder() events.TaskEventRecorder + EventsRecorder() EventRecorder NodeID() v1alpha1.NodeID Node() v1alpha1.ExecutableNode 
CurrentAttempt() uint32
diff --git a/pkg/controller/nodes/node_exec_context.go b/pkg/controller/nodes/node_exec_context.go
index 681c32f48..6506ae527 100644
--- a/pkg/controller/nodes/node_exec_context.go
+++ b/pkg/controller/nodes/node_exec_context.go
@@ -6,23 +6,90 @@ import (
	"strconv"

	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
-	"github.com/flyteorg/flytestdlib/storage"
-	"k8s.io/apimachinery/pkg/types"
+	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event"

	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"

	"github.com/flyteorg/flytepropeller/events"
+	eventsErr "github.com/flyteorg/flytepropeller/events/errors"
	"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
+	"github.com/flyteorg/flytepropeller/pkg/controller/config"
	"github.com/flyteorg/flytepropeller/pkg/controller/executors"
	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces"
+	nodeerrors "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors"
	"github.com/flyteorg/flytepropeller/pkg/utils"
+
+	"github.com/flyteorg/flytestdlib/logger"
+	"github.com/flyteorg/flytestdlib/storage"
+
+	"github.com/pkg/errors"
+
+	"k8s.io/apimachinery/pkg/types"
)

const NodeIDLabel = "node-id"
const TaskNameLabel = "task-name"
const NodeInterruptibleLabel = "interruptible"

+type eventRecorder struct {
+	taskEventRecorder events.TaskEventRecorder
+	nodeEventRecorder events.NodeEventRecorder
+}
+
+func (e eventRecorder) RecordTaskEvent(ctx context.Context, ev *event.TaskExecutionEvent, eventConfig *config.EventConfig) error {
+	if err := e.taskEventRecorder.RecordTaskEvent(ctx, ev, eventConfig); err != nil {
+		if eventsErr.IsAlreadyExists(err) {
+			logger.Warningf(ctx, "Failed to record taskEvent, error [%s]. Trying to record state: %s. Ignoring this error!", err.Error(), ev.Phase)
+			return nil
+		} else if eventsErr.IsEventAlreadyInTerminalStateError(err) {
+			if IsTerminalTaskPhase(ev.Phase) {
+				// The event is terminal and the value stored in flyteadmin is already terminal. This implies an aborted case, so ignoring.
+				logger.Warningf(ctx, "Failed to record taskEvent, error [%s]. Trying to record state: %s. Ignoring this error!", err.Error(), ev.Phase)
+				return nil
+			}
+			logger.Warningf(ctx, "Failed to record taskEvent in state: %s, error: %s", ev.Phase, err)
+			return errors.Wrapf(err, "failed to record task event, as it already exists in terminal state. Event state: %s", ev.Phase)
+		}
+		return err
+	}
+	return nil
+}
+
+func (e eventRecorder) RecordNodeEvent(ctx context.Context, nodeEvent *event.NodeExecutionEvent, eventConfig *config.EventConfig) error {
+	if nodeEvent == nil {
+		return fmt.Errorf("event recording attempted with a nil node execution event")
+	}
+
+	if nodeEvent.Id == nil {
+		return fmt.Errorf("event recording attempted with a nil node event ID")
+	}
+
+	logger.Infof(ctx, "Recording NodeEvent [%s] phase[%s]", nodeEvent.GetId().String(), nodeEvent.Phase.String())
+	err := e.nodeEventRecorder.RecordNodeEvent(ctx, nodeEvent, eventConfig)
+	if err != nil {
+		if nodeEvent.GetId().NodeId == v1alpha1.EndNodeID {
+			return nil
+		}
+
+		if eventsErr.IsAlreadyExists(err) {
+			logger.Infof(ctx, "Node event phase: %s, nodeId %s already exists",
+				nodeEvent.Phase.String(), nodeEvent.GetId().NodeId)
+			return nil
+		} else if eventsErr.IsEventAlreadyInTerminalStateError(err) {
+			if IsTerminalNodePhase(nodeEvent.Phase) {
+				// The event was trying to record a different terminal phase for an already terminal event; ignoring.
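+				// The record stored in flyteadmin is already terminal, so retrying this send can
+				// never succeed; since both sides agree the node has finished, the conflict is
+				// safe to swallow. Only a non-terminal event colliding with a terminal record
+				// falls through to the IllegalStateError below, because that indicates propeller's
+				// view has genuinely diverged from the control plane.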
+				logger.Infof(ctx, "Node event phase: %s, nodeId %s already in terminal phase. err: %s",
+					nodeEvent.Phase.String(), nodeEvent.GetId().NodeId, err.Error())
+				return nil
+			}
+			logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error())
+			return nodeerrors.Wrapf(nodeerrors.IllegalStateError, nodeEvent.Id.NodeId, err, "phase mismatch between propeller and control plane; trying to record node phase: %s", nodeEvent.Phase)
+		}
+	}
+	return err
+}
+
 type nodeExecMetadata struct {
	v1alpha1.Meta
	nodeExecID *core.NodeExecutionIdentifier
@@ -59,7 +126,8 @@ type nodeExecContext struct {
	store *storage.DataStore
	tr interfaces.TaskReader
	md interfaces.NodeExecutionMetadata
-	er events.TaskEventRecorder
+	eventRecorder interfaces.EventRecorder
+	//er events.TaskEventRecorder
	inputs io.InputReader
	node v1alpha1.ExecutableNode
	nodeStatus v1alpha1.ExecutableNodeStatus
@@ -112,8 +180,12 @@ func (e nodeExecContext) InputReader() io.InputReader {
	return e.inputs
}

-func (e nodeExecContext) EventsRecorder() events.TaskEventRecorder {
+/*func (e nodeExecContext) EventsRecorder() events.TaskEventRecorder {
	return e.er
+}*/
+
+func (e nodeExecContext) EventsRecorder() interfaces.EventRecorder {
+	return e.eventRecorder
}

func (e nodeExecContext) NodeID() v1alpha1.NodeID {
@@ -142,7 +214,7 @@ func (e nodeExecContext) MaxDatasetSizeBytes() int64 {

func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext executors.ExecutionContext, nl executors.NodeLookup, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, inputs io.InputReader, interruptible bool, interruptibleFailureThreshold uint32,
-	maxDatasetSize int64, er events.TaskEventRecorder, tr interfaces.TaskReader, nsm *nodeStateManager,
+	maxDatasetSize int64, taskEventRecorder events.TaskEventRecorder, nodeEventRecorder events.NodeEventRecorder, tr interfaces.TaskReader, nsm *nodeStateManager,
	enqueueOwner func() error, rawOutputPrefix storage.DataReference, outputShardSelector ioutils.ShardSelector) *nodeExecContext {

	md := nodeExecMetadata{
@@ -173,7 +245,11 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext
		node: node,
		nodeStatus: nodeStatus,
		inputs: inputs,
-		er: er,
+		eventRecorder: &eventRecorder{
+			taskEventRecorder: taskEventRecorder,
+			nodeEventRecorder: nodeEventRecorder,
+		},
+		//er: er,
		maxDatasetSizeBytes: maxDatasetSize,
		tr: tr,
		nsm: nsm,
@@ -243,7 +319,9 @@ func (c *nodeExecutor) BuildNodeExecutionContext(ctx context.Context, executionC
		interruptible,
		c.interruptibleFailureThreshold,
		c.maxDatasetSizeBytes,
-		&taskEventRecorder{TaskEventRecorder: c.taskRecorder},
+		//&taskEventRecorder{TaskEventRecorder: c.taskRecorder},
+		c.taskRecorder,
+		c.nodeRecorder,
		tr,
		newNodeStateManager(ctx, s),
		workflowEnqueuer,
diff --git a/pkg/controller/nodes/task_event_recorder.go b/pkg/controller/nodes/task_event_recorder.go.bak
similarity index 100%
rename from pkg/controller/nodes/task_event_recorder.go
rename to pkg/controller/nodes/task_event_recorder.go.bak
diff --git a/pkg/controller/nodes/task_event_recorder_test.go b/pkg/controller/nodes/task_event_recorder_test.go.bak
similarity index 100%
rename from pkg/controller/nodes/task_event_recorder_test.go
rename to pkg/controller/nodes/task_event_recorder_test.go.bak

From b8271ef82a97f34c617a545691a9ece11ee7359d Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Thu, 25 May 2023 18:26:18 -0500
Subject: [PATCH 17/62] adding correct requirements

Signed-off-by: Daniel Rammer

---
 pkg/compiler/requirements.go | 2 ++
 1 file
changed, 2 insertions(+) diff --git a/pkg/compiler/requirements.go b/pkg/compiler/requirements.go index ab1b11a05..b9f589ad7 100755 --- a/pkg/compiler/requirements.go +++ b/pkg/compiler/requirements.go @@ -86,5 +86,7 @@ func updateNodeRequirements(node *flyteNode, subWfs common.WorkflowIndex, taskId if elseNode := branchN.IfElse.GetElseNode(); elseNode != nil { updateNodeRequirements(elseNode, subWfs, taskIds, workflowIds, followSubworkflows, errs) } + } else if arrayNode := node.GetArrayNode(); arrayNode != nil { + updateNodeRequirements(arrayNode.Node, subWfs, taskIds, workflowIds, followSubworkflows, errs) } } From 45495919ebed373fff8c04de0f0aca890551b186 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 25 May 2023 23:01:17 -0500 Subject: [PATCH 18/62] working end-2-end with flytekit Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 13 +++++++------ pkg/controller/nodes/task/handler.go | 1 + 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 2cce80c0b..e3c1886a0 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -109,7 +109,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu maxAttempts = *subNodeSpec.GetRetryStrategy().MinAttempts } - fmt.Printf("HAMERSAW - maxAttempts %d\n", maxAttempts) for _, item := range []struct{arrayReference *bitarray.CompactArray; maxValue int}{ {arrayReference: &arrayNodeState.SubNodePhases, maxValue: len(core.Phases)-1}, // TODO @hamersaw - maxValue is for task phases {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases)-1}, @@ -177,19 +176,21 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // index. however when we check completion status we need to manually append index - so in all cases // where the node phase is not Queued (ie. task handler will launch task and init flytekit params) we // append the subtask index. 
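			// Illustration (schematic paths, an assumed layout rather than one taken from the
			// patch): for subtask index 2 of a node whose data dir is <prefix>/n0,
			// constructOutputReferences(ctx, nCtx) keeps the references at <prefix>/n0 so that
			// flytekit itself appends the "2", while constructOutputReferences(ctx, nCtx,
			// strconv.Itoa(2)) pins them to <prefix>/n0/2 before the task ever runs.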
- /*var subDataDir, subOutputDir storage.DataReference + var subDataDir, subOutputDir storage.DataReference if nodePhase == v1alpha1.NodePhaseQueued { subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx) } else { subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx, strconv.Itoa(i)) - }*/ + } // TODO @hamersaw - this is a problem because cache lookups happen in NodePhaseQueued // so the cache hit items will be written to the wrong location // can we just change flytekit appending the index onto the location?!?1 - subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i)) + /*subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i)) if err != nil { return handler.UnknownTransition, err - } + }*/ + + //fmt.Printf("HAMERSAW - subDataDir '%s' subOutputDir '%s'\n", subDataDir, subOutputDir) subNodeStatus := &v1alpha1.NodeStatus{ Phase: nodePhase, @@ -228,7 +229,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } else { arrayNodeState.SubNodeTaskPhases.SetItem(i, uint64(subNodeStatus.GetTaskNodeStatus().GetPhase())) } - fmt.Printf("HAMERSAW - setting %d to %d\n", i, uint64(subNodeStatus.GetAttempts())) arrayNodeState.SubNodeRetryAttempts.SetItem(i, uint64(subNodeStatus.GetAttempts())) arrayNodeState.SubNodeSystemFailures.SetItem(i, uint64(subNodeStatus.GetSystemFailures())) } @@ -238,6 +238,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu failedCount := 0 for _, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) + //fmt.Printf("HAMERSAW - node %d phase %d\n", i, nodePhase) switch nodePhase { case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRecovered: successCount++ diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go index 3b778291f..5f4493978 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ -174,6 +174,7 @@ func (p *pluginRequestedTransition) FinalTransition(ctx context.Context) (handle return handler.DoTransition(p.ttype, handler.PhaseInfoSuccess(&p.execInfo)), nil case pluginCore.PhaseRetryableFailure: logger.Debugf(ctx, "Transitioning to RetryableFailure") + fmt.Printf("HAMERSAW - %+v\n", p.pInfo.Err()) return handler.DoTransition(p.ttype, handler.PhaseInfoRetryableFailureErr(p.pInfo.Err(), &p.execInfo)), nil case pluginCore.PhasePermanentFailure: logger.Debugf(ctx, "Transitioning to Failure") From b4c6f3eb051b88a9936cf3947396306d440ec323 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Fri, 26 May 2023 08:37:12 -0500 Subject: [PATCH 19/62] reporting output directory on success Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index e3c1886a0..27a488363 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -308,7 +308,13 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu return handler.UnknownTransition, err } - return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess( + &handler.ExecutionInfo{ + OutputInfo: &handler.OutputInfo{ + OutputURI: outputFile, + }, + }, + )), nil default: // TODO @hamersaw - fail } From 
134f215588cab02647c1895c13db61e4c7f32411 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Tue, 30 May 2023 07:29:12 -0500 Subject: [PATCH 20/62] fixed output directory append Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 27a488363..18fe21ca9 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -176,19 +176,19 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // index. however when we check completion status we need to manually append index - so in all cases // where the node phase is not Queued (ie. task handler will launch task and init flytekit params) we // append the subtask index. - var subDataDir, subOutputDir storage.DataReference + /*var subDataDir, subOutputDir storage.DataReference if nodePhase == v1alpha1.NodePhaseQueued { subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx) } else { subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx, strconv.Itoa(i)) - } + }*/ // TODO @hamersaw - this is a problem because cache lookups happen in NodePhaseQueued // so the cache hit items will be written to the wrong location // can we just change flytekit appending the index onto the location?!?1 - /*subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i)) + subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i)) if err != nil { return handler.UnknownTransition, err - }*/ + } //fmt.Printf("HAMERSAW - subDataDir '%s' subOutputDir '%s'\n", subDataDir, subOutputDir) @@ -207,8 +207,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, &subNodeSpec, subNodeStatus) // execute subNode through RecursiveNodeHandler - // TODO @hamersaw - if recursiveNodeHandler is exported then can we just create a new one without needing the - // new GetNodeExecutionContextBuilder and WithNodeExecutionContextBuilder functions? arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(), subNodeID, i, subNodeStatus, inputReader, ¤tParallelism, arrayNode.GetParallelism()) arrayExecutionContext := newArrayExecutionContext(nCtx.ExecutionContext(), i, ¤tParallelism, arrayNode.GetParallelism()) @@ -255,6 +253,10 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu case v1alpha1.ArrayNodePhaseFailing: // TODO @hamersaw - abort everything! 
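// A sketch of the intended abort (patch 23/62 below wires it up this way): walk the
// recorded subnode phases, skip entries that are already terminal, rebuild each
// subnode's execution context, and delegate to the wrapped node executor, e.g.
//   for i := range arrayNodeState.SubNodePhases.GetItems() {
//       // skip subnodes already in a terminal phase, then:
//       if err := arrayNodeExecutor.AbortHandler(ctx, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, reason); err != nil {
//           messageCollector.Collect(i, err.Error())
//       }
//   }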
fmt.Printf("HAMERSAW TODO - abort ArrayNode!\n") + + /*return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure( + //TODO @hamersaw - fill out error message for all failed subtasks + )), nil*/ case v1alpha1.ArrayNodePhaseSucceeding: outputLiterals := make(map[string]*idlcore.Literal) for i, _ := range arrayNodeState.SubNodePhases.GetItems() { @@ -319,6 +321,9 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // TODO @hamersaw - fail } + // TODO @hamersaw - send task-level events + // this requires externalResources to emulate current maptasks + // update array node status if err := nCtx.NodeStateWriter().PutArrayNodeState(arrayNodeState); err != nil { logger.Errorf(ctx, "failed to store ArrayNode state with err [%s]", err.Error()) From d36cd44e08d45f90c1d399e84030e58756ccd44f Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Tue, 30 May 2023 09:28:22 -0500 Subject: [PATCH 21/62] mocking TaskTemplate interface to enable caching Signed-off-by: Daniel Rammer --- .../nodes/array/execution_context.go | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go index 80ff6d72c..2f06352c0 100644 --- a/pkg/controller/nodes/array/execution_context.go +++ b/pkg/controller/nodes/array/execution_context.go @@ -5,6 +5,7 @@ import ( "strconv" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" @@ -65,12 +66,44 @@ func (a *arrayEventRecorder) RecordTaskEvent(ctx context.Context, event *event.T return nil } +type arrayTaskReader struct { + interfaces.TaskReader +} + +func (a *arrayTaskReader) Read(ctx context.Context) (*core.TaskTemplate, error) { + taskTemplate, err := a.TaskReader.Read(ctx) + if err != nil { + return nil, err + } + + // convert output list variable to singular + outputVariables := make(map[string]*core.Variable) + for key, value := range taskTemplate.Interface.Outputs.Variables { + switch v := value.Type.Type.(type) { + case *core.LiteralType_CollectionType: + outputVariables[key] = &core.Variable{ + Type: v.CollectionType, + Description: value.Description, + } + default: + outputVariables[key] = value + } + } + + taskTemplate.Interface.Outputs = &core.VariableMap{ + Variables: outputVariables, + } + return taskTemplate, nil +} + + type arrayNodeExecutionContext struct { interfaces.NodeExecutionContext eventRecorder *arrayEventRecorder executionContext *arrayExecutionContext inputReader io.InputReader nodeStatus *v1alpha1.NodeStatus + taskReader interfaces.TaskReader } func (a *arrayNodeExecutionContext) EventsRecorder() interfaces.EventRecorder { @@ -89,6 +122,10 @@ func (a *arrayNodeExecutionContext) NodeStatus() v1alpha1.ExecutableNodeStatus { return a.nodeStatus } +func (a *arrayNodeExecutionContext) TaskReader() interfaces.TaskReader { + return a.taskReader +} + func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, subNodeIndex int, nodeStatus *v1alpha1.NodeStatus, currentParallelism *uint32, maxParallelism uint32) *arrayNodeExecutionContext { arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex, currentParallelism, maxParallelism) return &arrayNodeExecutionContext{ @@ -97,6 +134,7 @@ func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionC 
executionContext: arrayExecutionContext, inputReader: inputReader, nodeStatus: nodeStatus, + taskReader: &arrayTaskReader{nodeExecutionContext.TaskReader()}, } } From fa472d445fbd43a6e051fcd078f31ef42cff2cbe Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Tue, 30 May 2023 13:04:51 -0500 Subject: [PATCH 22/62] capture failure reasons Signed-off-by: Daniel Rammer --- pkg/apis/flyteworkflow/v1alpha1/iface.go | 2 + .../flyteworkflow/v1alpha1/node_status.go | 12 ++++ pkg/controller/nodes/array/handler.go | 57 ++++++++++++++++--- pkg/controller/nodes/interfaces/state.go | 1 + pkg/controller/nodes/node_state_manager.go | 1 + pkg/controller/nodes/transformers.go | 1 + 6 files changed, 65 insertions(+), 9 deletions(-) diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go index c4b4ec128..77c43e558 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -285,6 +285,7 @@ type MutableGateNodeStatus interface { type ExecutableArrayNodeStatus interface { GetArrayNodePhase() ArrayNodePhase + GetExecutionError() *core.ExecutionError GetSubNodePhases() bitarray.CompactArray GetSubNodeTaskPhases() bitarray.CompactArray GetSubNodeRetryAttempts() bitarray.CompactArray @@ -295,6 +296,7 @@ type MutableArrayNodeStatus interface { Mutable ExecutableArrayNodeStatus SetArrayNodePhase(phase ArrayNodePhase) + SetExecutionError(executionError *core.ExecutionError) SetSubNodePhases(subNodePhases bitarray.CompactArray) SetSubNodeTaskPhases(subNodeTaskPhases bitarray.CompactArray) SetSubNodeRetryAttempts(subNodeRetryAttempts bitarray.CompactArray) diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go index 0215fbdca..1954151e0 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/node_status.go +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -220,6 +220,7 @@ const ( type ArrayNodeStatus struct { MutableStruct Phase ArrayNodePhase `json:"phase,omitempty"` + ExecutionError *core.ExecutionError `json:"executionError,omitempty"` SubNodePhases bitarray.CompactArray `json:"subphase,omitempty"` SubNodeTaskPhases bitarray.CompactArray `json:"subtphase,omitempty"` SubNodeRetryAttempts bitarray.CompactArray `json:"subattempts,omitempty"` @@ -237,6 +238,17 @@ func (in *ArrayNodeStatus) SetArrayNodePhase(phase ArrayNodePhase) { } } +func (in *ArrayNodeStatus) GetExecutionError() *core.ExecutionError { + return in.ExecutionError +} + +func (in *ArrayNodeStatus) SetExecutionError(executionError *core.ExecutionError) { + if in.ExecutionError != executionError { + in.SetDirty() + in.ExecutionError = executionError + } +} + func (in *ArrayNodeStatus) GetSubNodePhases() bitarray.CompactArray { return in.SubNodePhases } diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 18fe21ca9..7fc70ad37 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -10,6 +10,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/errorcollector" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/compiler/validators" @@ -50,12 +51,26 @@ func newMetrics(scope promutils.Scope) metrics { // Abort stops the array node defined in the NodeExecutionContext func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx 
interfaces.NodeExecutionContext, reason string) error { - return nil // TODO @hamersaw - implement abort + arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState() + switch arrayNodeState.Phase { + case v1alpha1.ArrayNodePhaseExecuting: + fallthrough + case v1alpha1.ArrayNodePhaseFailing: + // TODO @hamersaw - implement abort + } + return nil } // Finalize completes the array node defined in the NodeExecutionContext -func (a *arrayNodeHandler) Finalize(ctx context.Context, _ interfaces.NodeExecutionContext) error { - return nil // TODO @hamersaw - implement finalize - clear node data?!?! +func (a *arrayNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExecutionContext) error { + arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState() + switch arrayNodeState.Phase { + case v1alpha1.ArrayNodePhaseExecuting: + fallthrough + case v1alpha1.ArrayNodePhaseFailing: + // TODO @hamersaw - implement finalize + } + return nil } // FinalizeRequired defines whether or not this handler requires finalize to be called on @@ -127,6 +142,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu case v1alpha1.ArrayNodePhaseExecuting: // process array node subnodes currentParallelism := uint32(0) + messageCollector := errorcollector.NewErrorMessageCollector() for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(i)) @@ -140,7 +156,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } // need to initialize the inputReader everytime to ensure TaskHandler can access for cache lookups / population - // TODO @hamersaw - once fastcache is implemented this can be optimized to only initialize on NodePhaseUndefined and NodePhaseSucceeding inputLiteralMap, err := constructLiteralMap(ctx, nCtx.InputReader(), i) if err != nil { return handler.UnknownTransition, err @@ -217,6 +232,11 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu return handler.UnknownTransition, err } + // capture subtask error if exists + if subNodeStatus.Error != nil { + messageCollector.Collect(i, subNodeStatus.Error.Message) + } + //fmt.Printf("HAMERSAW - '%d' transition node phase %d -> %d task phase '%d' -> '%d'\n", i, // nodePhase, subNodeStatus.GetPhase(), taskPhase, subNodeStatus.GetTaskNodeStatus().GetPhase()) @@ -234,29 +254,48 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // process phases of subNodes to determine overall `ArrayNode` phase successCount := 0 failedCount := 0 + failingCount := 0 for _, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) //fmt.Printf("HAMERSAW - node %d phase %d\n", i, nodePhase) switch nodePhase { case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRecovered: successCount++ + case v1alpha1.NodePhaseFailing: + failingCount++ case v1alpha1.NodePhaseFailed: failedCount++ } } + // if there is a failing node set the error message + if failingCount > 0 && arrayNodeState.Error == nil { + arrayNodeState.Error = &idlcore.ExecutionError{ + Message: messageCollector.Summary(512), // TODO @hamersaw - make configurable + } + } + if failedCount > 0 { arrayNodeState.Phase = v1alpha1.ArrayNodePhaseFailing } else if successCount == len(arrayNodeState.SubNodePhases.GetItems()) { arrayNodeState.Phase = v1alpha1.ArrayNodePhaseSucceeding } case v1alpha1.ArrayNodePhaseFailing: - // TODO 
@hamersaw - abort everything! - fmt.Printf("HAMERSAW TODO - abort ArrayNode!\n") + if err := a.Abort(ctx, nCtx, "TODO @hamersaw"); err != nil { + return handler.UnknownTransition, err + } - /*return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure( - //TODO @hamersaw - fill out error message for all failed subtasks - )), nil*/ + // fail with error (if provided) + if arrayNodeState.Error != nil { + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailureErr(arrayNodeState.Error, nil)), nil + } + + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure( + idlcore.ExecutionError_UNKNOWN, + "ArrayNodeFailing", + "Unknown reason", + nil, + )), nil case v1alpha1.ArrayNodePhaseSucceeding: outputLiterals := make(map[string]*idlcore.Literal) for i, _ := range arrayNodeState.SubNodePhases.GetItems() { diff --git a/pkg/controller/nodes/interfaces/state.go b/pkg/controller/nodes/interfaces/state.go index 5af278147..00021f212 100644 --- a/pkg/controller/nodes/interfaces/state.go +++ b/pkg/controller/nodes/interfaces/state.go @@ -51,6 +51,7 @@ type GateNodeState struct { type ArrayNodeState struct { Phase v1alpha1.ArrayNodePhase + Error *core.ExecutionError SubNodePhases bitarray.CompactArray SubNodeTaskPhases bitarray.CompactArray SubNodeRetryAttempts bitarray.CompactArray diff --git a/pkg/controller/nodes/node_state_manager.go b/pkg/controller/nodes/node_state_manager.go index f67961e7a..aa9006348 100644 --- a/pkg/controller/nodes/node_state_manager.go +++ b/pkg/controller/nodes/node_state_manager.go @@ -159,6 +159,7 @@ func (n nodeStateManager) GetArrayNodeState() interfaces.ArrayNodeState { as := interfaces.ArrayNodeState{} if an != nil { as.Phase = an.GetArrayNodePhase() + as.Error = an.GetExecutionError() as.SubNodePhases = an.GetSubNodePhases() as.SubNodeTaskPhases = an.GetSubNodeTaskPhases() as.SubNodeRetryAttempts = an.GetSubNodeRetryAttempts() diff --git a/pkg/controller/nodes/transformers.go b/pkg/controller/nodes/transformers.go index 2c210c40c..d55402305 100644 --- a/pkg/controller/nodes/transformers.go +++ b/pkg/controller/nodes/transformers.go @@ -288,6 +288,7 @@ func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n interfaces.N na := n.GetArrayNodeState() t := s.GetOrCreateArrayNodeStatus() t.SetArrayNodePhase(na.Phase) + t.SetExecutionError(na.Error) t.SetSubNodePhases(na.SubNodePhases) t.SetSubNodeTaskPhases(na.SubNodeTaskPhases) t.SetSubNodeRetryAttempts(na.SubNodeRetryAttempts) From a6e20c7c4e64eb94244268de2b04708cdc7ce30d Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Wed, 31 May 2023 08:44:40 -0500 Subject: [PATCH 23/62] wrapped up abort and finalize functionality Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 240 ++++++++++++++++---------- 1 file changed, 151 insertions(+), 89 deletions(-) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 7fc70ad37..e18db4216 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -14,6 +14,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/compiler/validators" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" @@ -51,32 +52,83 @@ func 
newMetrics(scope promutils.Scope) {

 // Abort stops the array node defined in the NodeExecutionContext
 func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error {
+	arrayNode := nCtx.Node().GetArrayNode()
	arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState()
+
+	messageCollector := errorcollector.NewErrorMessageCollector()
	switch arrayNodeState.Phase {
-	case v1alpha1.ArrayNodePhaseExecuting:
-		fallthrough
-	case v1alpha1.ArrayNodePhaseFailing:
-		// TODO @hamersaw - implement abort
+	case v1alpha1.ArrayNodePhaseExecuting, v1alpha1.ArrayNodePhaseFailing:
+		currentParallelism := uint32(0)
+		for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
+			nodePhase := v1alpha1.NodePhase(nodePhaseUint64)
+
+			// TODO @hamersaw fix - do not process nodes that haven't started or are in a terminal state
+			//if nodes.IsNotyetStarted(nodePhase) || nodes.IsTerminalNodePhase(nodePhase) {
+			if nodePhase == v1alpha1.NodePhaseNotYetStarted || nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseFailed || nodePhase == v1alpha1.NodePhaseTimedOut || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered {
+				continue
+			}
+
+			// create array contexts
+			arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, err :=
+				a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, &currentParallelism)
+			if err != nil {
+				return err
+			}
+
+			// abort subNode
+			err = arrayNodeExecutor.AbortHandler(ctx, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, reason)
+			if err != nil {
+				messageCollector.Collect(i, err.Error())
+			}
+		}
+	}
+
+	if messageCollector.Length() > 0 {
+		return fmt.Errorf("%s", messageCollector.Summary(512)) // TODO @hamersaw - make configurable
+	}
+
	return nil
}

 // Finalize completes the array node defined in the NodeExecutionContext
 func (a *arrayNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExecutionContext) error {
+	arrayNode := nCtx.Node().GetArrayNode()
	arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState()
+
+	messageCollector := errorcollector.NewErrorMessageCollector()
	switch arrayNodeState.Phase {
-	case v1alpha1.ArrayNodePhaseExecuting:
-		fallthrough
-	case v1alpha1.ArrayNodePhaseFailing:
-		// TODO @hamersaw - implement finalize
+	case v1alpha1.ArrayNodePhaseExecuting, v1alpha1.ArrayNodePhaseFailing, v1alpha1.ArrayNodePhaseSucceeding:
+		currentParallelism := uint32(0)
+		for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
+			nodePhase := v1alpha1.NodePhase(nodePhaseUint64)
+
+			// TODO @hamersaw fix - do not process nodes that haven't started
+			//if nodes.IsNotyetStarted(nodePhase) {
+			if nodePhase == v1alpha1.NodePhaseNotYetStarted {
+				continue
+			}
+
+			// create array contexts
+			arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, err :=
+				a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, &currentParallelism)
+			if err != nil {
+				return err
+			}
+
+			// finalize subNode
+			err = arrayNodeExecutor.FinalizeHandler(ctx, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec)
+			if err != nil {
+				messageCollector.Collect(i, err.Error())
+			}
+		}
+	}
+
+	if messageCollector.Length() > 0 {
+		return fmt.Errorf("%s", messageCollector.Summary(512)) // TODO @hamersaw - make configurable
+	}
+
	return nil
}

-// FinalizeRequired defines whether or not this handler requires finalize to be called on
-// node completion
+// FinalizeRequired defines whether or not this handler requires finalize to be
called on node +// completion func (a *arrayNodeHandler) FinalizeRequired() bool { - return false + // must return true because we can't determine if finalize is required for the subNode + return true } // Handle is responsible for transitioning and reporting node state to complete the node defined @@ -145,9 +197,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu messageCollector := errorcollector.NewErrorMessageCollector() for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) - taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(i)) - - //fmt.Printf("HAMERSAW - evaluating node '%d' in node phase '%d' task phase '%d'\n", i, nodePhase, taskPhase) // TODO @hamersaw fix - do not process nodes in terminal state //if nodes.IsTerminalNodePhase(nodePhase) { @@ -155,84 +204,17 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu continue } - // need to initialize the inputReader everytime to ensure TaskHandler can access for cache lookups / population - inputLiteralMap, err := constructLiteralMap(ctx, nCtx.InputReader(), i) - if err != nil { - return handler.UnknownTransition, err - } - - inputReader := newStaticInputReader(nCtx.InputReader(), inputLiteralMap) - - // if node has not yet started we automatically set to NodePhaseQueued to skip input resolution - if nodePhase == v1alpha1.NodePhaseNotYetStarted { - // TODO @hamersaw how does this work with fastcache? - // to supprt fastcache we'll need to override the bindings to BindingScalars for the input resolution on the nCtx - // that way we resolution is just reading a literal ... but does this still write a file then?!? - nodePhase = v1alpha1.NodePhaseQueued - } - - // wrap node lookup - subNodeSpec := *arrayNode.GetSubNodeSpec() - - subNodeID := fmt.Sprintf("%s-n%d", nCtx.NodeID(), i) - subNodeSpec.ID = subNodeID - subNodeSpec.Name = subNodeID + // create array contexts + arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, subNodeStatus, err := + a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, ¤tParallelism) - // TODO - if we want to support more plugin types we need to figure out the best way to store plugin state - // currently just mocking based on node phase -> which works for all k8s plugins - // we can not pre-allocated a bit array because max size is 256B and with 5k fanout node state = 1.28MB - pluginStateBytes := a.pluginStateBytesStarted - //if nodePhase == v1alpha1.NodePhaseQueued || nodePhase == v1alpha1.NodePhaseRetryableFailure { - if taskPhase == int(core.PhaseUndefined) || taskPhase == int(core.PhaseRetryableFailure) { - pluginStateBytes = a.pluginStateBytesNotStarted - } - - // we set subDataDir and subOutputDir to the node dirs because flytekit automatically appends subtask - // index. however when we check completion status we need to manually append index - so in all cases - // where the node phase is not Queued (ie. task handler will launch task and init flytekit params) we - // append the subtask index. 
- /*var subDataDir, subOutputDir storage.DataReference - if nodePhase == v1alpha1.NodePhaseQueued { - subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx) - } else { - subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx, strconv.Itoa(i)) - }*/ - // TODO @hamersaw - this is a problem because cache lookups happen in NodePhaseQueued - // so the cache hit items will be written to the wrong location - // can we just change flytekit appending the index onto the location?!?1 - subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i)) + // execute subNode + _, err = arrayNodeExecutor.RecursiveNodeHandler(ctx, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec) if err != nil { return handler.UnknownTransition, err } - //fmt.Printf("HAMERSAW - subDataDir '%s' subOutputDir '%s'\n", subDataDir, subOutputDir) - - subNodeStatus := &v1alpha1.NodeStatus{ - Phase: nodePhase, - DataDir: subDataDir, - OutputDir: subOutputDir, - Attempts: uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(i)), - SystemFailures: uint32(arrayNodeState.SubNodeSystemFailures.GetItem(i)), - TaskNodeStatus: &v1alpha1.TaskNodeStatus{ - Phase: taskPhase, - PluginState: pluginStateBytes, - }, - } - - arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, &subNodeSpec, subNodeStatus) - - // execute subNode through RecursiveNodeHandler - arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(), - subNodeID, i, subNodeStatus, inputReader, ¤tParallelism, arrayNode.GetParallelism()) - arrayExecutionContext := newArrayExecutionContext(nCtx.ExecutionContext(), i, ¤tParallelism, arrayNode.GetParallelism()) - - arrayNodeExecutor := a.nodeExecutor.WithNodeExecutionContextBuilder(arrayNodeExecutionContextBuilder) - _, err = arrayNodeExecutor.RecursiveNodeHandler(ctx, arrayExecutionContext, &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec) - if err != nil { - return handler.UnknownTransition, err - } - - // capture subtask error if exists + // capture subNode error if exists if subNodeStatus.Error != nil { messageCollector.Collect(i, subNodeStatus.Error.Message) } @@ -240,6 +222,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu //fmt.Printf("HAMERSAW - '%d' transition node phase %d -> %d task phase '%d' -> '%d'\n", i, // nodePhase, subNodeStatus.GetPhase(), taskPhase, subNodeStatus.GetTaskNodeStatus().GetPhase()) + // update subNode state arrayNodeState.SubNodePhases.SetItem(i, uint64(subNodeStatus.GetPhase())) if subNodeStatus.GetTaskNodeStatus() == nil { // TODO @hamersaw during retries we clear the GetTaskNodeStatus - so resetting task phase @@ -281,11 +264,11 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu arrayNodeState.Phase = v1alpha1.ArrayNodePhaseSucceeding } case v1alpha1.ArrayNodePhaseFailing: - if err := a.Abort(ctx, nCtx, "TODO @hamersaw"); err != nil { + if err := a.Abort(ctx, nCtx, "ArrayNodeFailing"); err != nil { return handler.UnknownTransition, err } - // fail with error (if provided) + // fail with reported error if one exists if arrayNodeState.Error != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailureErr(arrayNodeState.Error, nil)), nil } @@ -399,6 +382,85 @@ func New(nodeExecutor interfaces.Node, scope promutils.Scope) (handler.Node, err }, nil } +func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx 
interfaces.NodeExecutionContext, arrayNodeState *interfaces.ArrayNodeState, arrayNode v1alpha1.ExecutableArrayNode, subNodeIndex int, currentParallelism *uint32) (
+	interfaces.Node, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, *v1alpha1.NodeSpec, *v1alpha1.NodeStatus, error) {
+
+	nodePhase := v1alpha1.NodePhase(arrayNodeState.SubNodePhases.GetItem(subNodeIndex))
+	taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(subNodeIndex))
+
+	// need to initialize the inputReader every time to ensure the TaskHandler can access it for cache lookups / population
+	inputLiteralMap, err := constructLiteralMap(ctx, nCtx.InputReader(), subNodeIndex)
+	if err != nil {
+		return nil, nil, nil, nil, nil, nil, err
+	}
+
+	inputReader := newStaticInputReader(nCtx.InputReader(), inputLiteralMap)
+
+	// if the node has not yet started we automatically set it to NodePhaseQueued to skip input resolution
+	if nodePhase == v1alpha1.NodePhaseNotYetStarted {
+		// TODO @hamersaw how does this work with fastcache?
+		// to support fastcache we'll need to override the bindings to BindingScalars for the input resolution on the nCtx
+		// that way resolution is just reading a literal ... but does this still write a file then?!?
+		nodePhase = v1alpha1.NodePhaseQueued
+	}
+
+	// wrap node lookup
+	subNodeSpec := *arrayNode.GetSubNodeSpec()
+
+	subNodeID := fmt.Sprintf("%s-n%d", nCtx.NodeID(), subNodeIndex)
+	subNodeSpec.ID = subNodeID
+	subNodeSpec.Name = subNodeID
+
+	// TODO - if we want to support more plugin types we need to figure out the best way to store plugin state
+	// currently just mocking based on node phase -> which works for all k8s plugins
+	// we cannot pre-allocate a bit array because the max size is 256B and with a 5k fanout node state = 1.28MB
+	pluginStateBytes := a.pluginStateBytesStarted
+	//if nodePhase == v1alpha1.NodePhaseQueued || nodePhase == v1alpha1.NodePhaseRetryableFailure {
+	if taskPhase == int(core.PhaseUndefined) || taskPhase == int(core.PhaseRetryableFailure) {
+		pluginStateBytes = a.pluginStateBytesNotStarted
+	}
+
+	// we set subDataDir and subOutputDir to the node dirs because flytekit automatically appends the subtask
+	// index. however when we check completion status we need to manually append the index - so in all cases
+	// where the node phase is not Queued (ie. task handler will launch task and init flytekit params) we
+	// append the subtask index.
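+	// Note on the snapshot below: the subnode's transient NodeStatus is rebuilt on every
+	// evaluation from the CompactArrays in ArrayNodeState - phases come from
+	// SubNodePhases/SubNodeTaskPhases and counters from SubNodeRetryAttempts and
+	// SubNodeSystemFailures - so only a few bits per subtask are persisted between rounds
+	// rather than a full NodeStatus object per subnode.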
+ /*var subDataDir, subOutputDir storage.DataReference + if nodePhase == v1alpha1.NodePhaseQueued { + subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx) + } else { + subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx, strconv.Itoa(i)) + }*/ + // TODO @hamersaw - this is a problem because cache lookups happen in NodePhaseQueued + // so the cache hit items will be written to the wrong location + // can we just change flytekit appending the index onto the location?!?1 + subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(subNodeIndex)) + if err != nil { + return nil, nil, nil, nil, nil, nil, err + } + + subNodeStatus := &v1alpha1.NodeStatus{ + Phase: nodePhase, + DataDir: subDataDir, + OutputDir: subOutputDir, + Attempts: uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(subNodeIndex)), + SystemFailures: uint32(arrayNodeState.SubNodeSystemFailures.GetItem(subNodeIndex)), + TaskNodeStatus: &v1alpha1.TaskNodeStatus{ + Phase: taskPhase, + PluginState: pluginStateBytes, + }, + } + + arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, &subNodeSpec, subNodeStatus) + + arrayExecutionContext := newArrayExecutionContext(nCtx.ExecutionContext(), subNodeIndex, currentParallelism, arrayNode.GetParallelism()) + + arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(), + subNodeID, subNodeIndex, subNodeStatus, inputReader, currentParallelism, arrayNode.GetParallelism()) + arrayNodeExecutor := a.nodeExecutor.WithNodeExecutionContextBuilder(arrayNodeExecutionContextBuilder) + + return arrayNodeExecutor, arrayExecutionContext, &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec, subNodeStatus, nil +} + func bytesFromK8sPluginState(pluginState k8s.PluginState) ([]byte, error) { buffer := make([]byte, 0, task.MaxPluginStateSizeBytes) bufferWriter := bytes.NewBuffer(buffer) From f5a46b91547915f865552bd9ea389d6fed7e4ec4 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Wed, 31 May 2023 09:51:44 -0500 Subject: [PATCH 24/62] mocking initialization events Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 59 ++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index e18db4216..6bf435d7d 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -5,8 +5,10 @@ import ( "context" "fmt" "strconv" + "time" idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" @@ -14,6 +16,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/compiler/validators" + "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" @@ -26,6 +29,8 @@ import ( "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/storage" + + "github.com/golang/protobuf/ptypes" ) //go:generate mockery -all -case=underscore @@ -137,6 +142,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu 
arrayNode := nCtx.Node().GetArrayNode() arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState() + var externalResources []*event.ExternalResourceInfo switch arrayNodeState.Phase { case v1alpha1.ArrayNodePhaseNone: // identify and validate array node input value lengths @@ -189,6 +195,19 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } } + // initialize externalResources + externalResources = make([]*event.ExternalResourceInfo, 0, size) + for i := 0; i < size; i++ { + externalResources = append(externalResources, &event.ExternalResourceInfo{ + ExternalId: fmt.Sprintf("%s-%d", nCtx.NodeID, i), // TODO @hamersaw do better + //CacheStatus: cacheStatus, + Index: uint32(i), + Logs: nil, + RetryAttempt: 0, + Phase: idlcore.TaskExecution_QUEUED, + }) + } + // transition ArrayNode to `ArrayNodePhaseExecuting` arrayNodeState.Phase = v1alpha1.ArrayNodePhaseExecuting case v1alpha1.ArrayNodePhaseExecuting: @@ -343,8 +362,44 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // TODO @hamersaw - fail } - // TODO @hamersaw - send task-level events - // this requires externalResources to emulate current maptasks + // TODO @hamersaw - send task-level events - this requires externalResources to emulate current maptasks + if len(externalResources) > 0 { + occurredAt, err := ptypes.TimestampProto(time.Now()) + if err != nil { + return handler.UnknownTransition, err + } + + nodeExecutionId := nCtx.NodeExecutionMetadata().GetNodeExecutionID() + workflowExecutionId := nodeExecutionId.ExecutionId + taskExecutionEvent := &event.TaskExecutionEvent{ + TaskId: &idlcore.Identifier{ + ResourceType: idlcore.ResourceType_TASK, + Project: workflowExecutionId.Project, + Domain: workflowExecutionId.Domain, + Name: "foo", // TODO @hamersaw - do better + Version: "v1", // TODO @hamersaw - do better + }, + ParentNodeExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), + RetryAttempt: 0, // ArrayNode will never retry + Phase: 2, // TODO @hamersaw - determine node phase from ArrayNodePhase (ie. Queued, Running, Succeeded, Failed) + PhaseVersion: 0, // TODO @hamersaw - need to increment? 
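+				// Regarding the TODO above: flyteadmin ignores a task event that repeats an
+				// already-recorded (Phase, PhaseVersion) pair, so when only the external
+				// resources change while the phase stays the same this version must be bumped;
+				// patch 25/62 below persists a TaskPhaseVersion counter for exactly that purpose.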
+ OccurredAt: occurredAt, + Metadata: &event.TaskExecutionMetadata{ + ExternalResources: externalResources, + }, + TaskType: "k8s-array", + EventVersion: 1, + } + + // TODO @hamersaw - pass eventConfig correctly + if err := nCtx.EventsRecorder().RecordTaskEvent(ctx, taskExecutionEvent, &config.EventConfig{}); err != nil { + logger.Errorf(ctx, "ArrayNode event recording failed: [%s]", err.Error()) + //logger.Errorf(ctx, "event recording failed for Plugin [%s], eventPhase [%s], error :%s", p.GetID(), evInfo.Phase.String(), err.Error()) + // Check for idempotency + // Check for terminate state error + return handler.UnknownTransition, err + } + } // update array node status if err := nCtx.NodeStateWriter().PutArrayNodeState(arrayNodeState); err != nil { From 6dcbe530382a1f8ab17d374083e5e1247704dc83 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Wed, 31 May 2023 14:00:30 -0500 Subject: [PATCH 25/62] sending all events Signed-off-by: Daniel Rammer --- pkg/apis/flyteworkflow/v1alpha1/iface.go | 2 + .../flyteworkflow/v1alpha1/node_status.go | 12 +++++ .../nodes/array/execution_context.go | 28 +++++++--- pkg/controller/nodes/array/handler.go | 51 +++++++++++++------ pkg/controller/nodes/interfaces/state.go | 1 + pkg/controller/nodes/node_state_manager.go | 1 + pkg/controller/nodes/transformers.go | 1 + 7 files changed, 75 insertions(+), 21 deletions(-) diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go index 77c43e558..dfffbfd02 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -290,6 +290,7 @@ type ExecutableArrayNodeStatus interface { GetSubNodeTaskPhases() bitarray.CompactArray GetSubNodeRetryAttempts() bitarray.CompactArray GetSubNodeSystemFailures() bitarray.CompactArray + GetTaskPhaseVersion() uint32 } type MutableArrayNodeStatus interface { @@ -301,6 +302,7 @@ type MutableArrayNodeStatus interface { SetSubNodeTaskPhases(subNodeTaskPhases bitarray.CompactArray) SetSubNodeRetryAttempts(subNodeRetryAttempts bitarray.CompactArray) SetSubNodeSystemFailures(subNodeSystemFailures bitarray.CompactArray) + SetTaskPhaseVersion(taskPhaseVersion uint32) } type Mutable interface { diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go index 1954151e0..5a4d9fc8a 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/node_status.go +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -225,6 +225,7 @@ type ArrayNodeStatus struct { SubNodeTaskPhases bitarray.CompactArray `json:"subtphase,omitempty"` SubNodeRetryAttempts bitarray.CompactArray `json:"subattempts,omitempty"` SubNodeSystemFailures bitarray.CompactArray `json:"subsysfailures,omitempty"` + TaskPhaseVersion uint32 `json:"taskPhaseVersion,omitempty"` } func (in *ArrayNodeStatus) GetArrayNodePhase() ArrayNodePhase { @@ -293,6 +294,17 @@ func (in *ArrayNodeStatus) SetSubNodeSystemFailures(subNodeSystemFailures bitarr } } +func (in *ArrayNodeStatus) GetTaskPhaseVersion() uint32 { + return in.TaskPhaseVersion +} + +func (in *ArrayNodeStatus) SetTaskPhaseVersion(taskPhaseVersion uint32) { + if in.TaskPhaseVersion != taskPhaseVersion { + in.SetDirty() + in.TaskPhaseVersion = taskPhaseVersion + } +} + type NodeStatus struct { MutableStruct Phase NodePhase `json:"phase,omitempty"` diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go index 2f06352c0..c53e092e9 100644 --- a/pkg/controller/nodes/array/execution_context.go +++ 
b/pkg/controller/nodes/array/execution_context.go @@ -56,16 +56,29 @@ func newArrayExecutionContext(executionContext executors.ExecutionContext, subNo } } -type arrayEventRecorder struct {} +type arrayEventRecorder struct { + taskEvents []*event.TaskExecutionEvent +} func (a *arrayEventRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { return nil } func (a *arrayEventRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { + a.taskEvents = append(a.taskEvents, event) return nil } +func (a *arrayEventRecorder) Events() []*event.TaskExecutionEvent { + return a.taskEvents +} + +func newArrayEventRecorder() *arrayEventRecorder { + return &arrayEventRecorder{ + taskEvents: make([]*event.TaskExecutionEvent, 0), + } +} + type arrayTaskReader struct { interfaces.TaskReader } @@ -99,7 +112,7 @@ func (a *arrayTaskReader) Read(ctx context.Context) (*core.TaskTemplate, error) type arrayNodeExecutionContext struct { interfaces.NodeExecutionContext - eventRecorder *arrayEventRecorder + eventRecorder interfaces.EventRecorder executionContext *arrayExecutionContext inputReader io.InputReader nodeStatus *v1alpha1.NodeStatus @@ -126,11 +139,11 @@ func (a *arrayNodeExecutionContext) TaskReader() interfaces.TaskReader { return a.taskReader } -func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, subNodeIndex int, nodeStatus *v1alpha1.NodeStatus, currentParallelism *uint32, maxParallelism uint32) *arrayNodeExecutionContext { +func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, eventRecorder interfaces.EventRecorder, subNodeIndex int, nodeStatus *v1alpha1.NodeStatus, currentParallelism *uint32, maxParallelism uint32) *arrayNodeExecutionContext { arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex, currentParallelism, maxParallelism) return &arrayNodeExecutionContext{ NodeExecutionContext: nodeExecutionContext, - eventRecorder: &arrayEventRecorder{}, + eventRecorder: eventRecorder, executionContext: arrayExecutionContext, inputReader: inputReader, nodeStatus: nodeStatus, @@ -146,6 +159,7 @@ type arrayNodeExecutionContextBuilder struct { inputReader io.InputReader currentParallelism *uint32 maxParallelism uint32 + eventRecorder interfaces.EventRecorder } func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, @@ -159,14 +173,15 @@ func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context if currentNodeID == a.subNodeID { // overwrite NodeExecutionContext for ArrayNode execution - nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.subNodeIndex, a.subNodeStatus, a.currentParallelism, a.maxParallelism) + nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.eventRecorder, a.subNodeIndex, a.subNodeStatus, a.currentParallelism, a.maxParallelism) } return nCtx, nil } func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder, subNodeID v1alpha1.NodeID, - subNodeIndex int, subNodeStatus *v1alpha1.NodeStatus, inputReader io.InputReader, currentParallelism *uint32, maxParallelism uint32) interfaces.NodeExecutionContextBuilder { + subNodeIndex int, subNodeStatus *v1alpha1.NodeStatus, inputReader io.InputReader, eventRecorder interfaces.EventRecorder, + currentParallelism 
*uint32, maxParallelism uint32) interfaces.NodeExecutionContextBuilder { return &arrayNodeExecutionContextBuilder{ nCtxBuilder: nCtxBuilder, @@ -176,5 +191,6 @@ func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionCon inputReader: inputReader, currentParallelism: currentParallelism, maxParallelism: maxParallelism, + eventRecorder: eventRecorder, } } diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 6bf435d7d..6b0f3e3a0 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -74,7 +74,7 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut } // create array contexts - arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, err := + arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, _, err := a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, ¤tParallelism) // abort subNode @@ -111,7 +111,7 @@ func (a *arrayNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExe } // create array contexts - arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, err := + arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, _, err := a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, ¤tParallelism) // finalize subNode @@ -143,6 +143,8 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState() var externalResources []*event.ExternalResourceInfo + taskPhaseVersion := arrayNodeState.TaskPhaseVersion + switch arrayNodeState.Phase { case v1alpha1.ArrayNodePhaseNone: // identify and validate array node input value lengths @@ -200,7 +202,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu for i := 0; i < size; i++ { externalResources = append(externalResources, &event.ExternalResourceInfo{ ExternalId: fmt.Sprintf("%s-%d", nCtx.NodeID, i), // TODO @hamersaw do better - //CacheStatus: cacheStatus, Index: uint32(i), Logs: nil, RetryAttempt: 0, @@ -214,6 +215,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // process array node subnodes currentParallelism := uint32(0) messageCollector := errorcollector.NewErrorMessageCollector() + externalResources = make([]*event.ExternalResourceInfo, 0) for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) @@ -224,7 +226,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } // create array contexts - arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, subNodeStatus, err := + arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, subNodeStatus, arrayEventRecorder, err := a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, ¤tParallelism) // execute subNode @@ -238,6 +240,23 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu messageCollector.Collect(i, subNodeStatus.Error.Message) } + // process TaskExecutionEvents + for _, taskExecutionEvent := range arrayEventRecorder.Events() { + taskPhase := idlcore.TaskExecution_UNDEFINED + if taskNodeStatus := subNodeStatus.GetTaskNodeStatus(); taskNodeStatus != nil { + taskPhase = 
task.ToTaskEventPhase(core.Phase(taskNodeStatus.GetPhase())) + } + + fmt.Printf("HAMERSAW - processing event '%s' for node '%d' with phase '%d'\n", taskExecutionEvent.Phase.String(), i, taskPhase) + externalResources = append(externalResources, &event.ExternalResourceInfo{ + ExternalId: fmt.Sprintf("%s-%d", nCtx.NodeID, i), // TODO @hamersaw do better + Index: uint32(i), + Logs: taskExecutionEvent.Logs, + RetryAttempt: 0, + Phase: taskPhase, + }) + } + //fmt.Printf("HAMERSAW - '%d' transition node phase %d -> %d task phase '%d' -> '%d'\n", i, // nodePhase, subNodeStatus.GetPhase(), taskPhase, subNodeStatus.GetTaskNodeStatus().GetPhase()) @@ -376,13 +395,13 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu ResourceType: idlcore.ResourceType_TASK, Project: workflowExecutionId.Project, Domain: workflowExecutionId.Domain, - Name: "foo", // TODO @hamersaw - do better - Version: "v1", // TODO @hamersaw - do better + Name: "foo", // TODO @hamersaw - make it better + Version: "v1", // TODO @hamersaw - please }, ParentNodeExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), RetryAttempt: 0, // ArrayNode will never retry Phase: 2, // TODO @hamersaw - determine node phase from ArrayNodePhase (ie. Queued, Running, Succeeded, Failed) - PhaseVersion: 0, // TODO @hamersaw - need to increment? + PhaseVersion: taskPhaseVersion, OccurredAt: occurredAt, Metadata: &event.TaskExecutionMetadata{ ExternalResources: externalResources, @@ -394,11 +413,12 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // TODO @hamersaw - pass eventConfig correctly if err := nCtx.EventsRecorder().RecordTaskEvent(ctx, taskExecutionEvent, &config.EventConfig{}); err != nil { logger.Errorf(ctx, "ArrayNode event recording failed: [%s]", err.Error()) - //logger.Errorf(ctx, "event recording failed for Plugin [%s], eventPhase [%s], error :%s", p.GetID(), evInfo.Phase.String(), err.Error()) - // Check for idempotency - // Check for terminate state error return handler.UnknownTransition, err } + + // TODO @hamersaw - only need to increment if arrayNodeState.Phase does not change + // if it does we can reset to 0 + arrayNodeState.TaskPhaseVersion = taskPhaseVersion+1 } // update array node status @@ -438,7 +458,7 @@ func New(nodeExecutor interfaces.Node, scope promutils.Scope) (handler.Node, err } func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx interfaces.NodeExecutionContext, arrayNodeState *interfaces.ArrayNodeState, arrayNode v1alpha1.ExecutableArrayNode, subNodeIndex int, currentParallelism *uint32) ( - interfaces.Node, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, *v1alpha1.NodeSpec, *v1alpha1.NodeStatus, error) { + interfaces.Node, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, *v1alpha1.NodeSpec, *v1alpha1.NodeStatus, *arrayEventRecorder, error) { nodePhase := v1alpha1.NodePhase(arrayNodeState.SubNodePhases.GetItem(subNodeIndex)) taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(subNodeIndex)) @@ -446,7 +466,7 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter // need to initialize the inputReader everytime to ensure TaskHandler can access for cache lookups / population inputLiteralMap, err := constructLiteralMap(ctx, nCtx.InputReader(), subNodeIndex) if err != nil { - return nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, err } inputReader := newStaticInputReader(nCtx.InputReader(), inputLiteralMap) @@ 
-490,7 +510,7 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter // can we just change flytekit appending the index onto the location?!?1 subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(subNodeIndex)) if err != nil { - return nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, err } subNodeStatus := &v1alpha1.NodeStatus{ @@ -509,11 +529,12 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter arrayExecutionContext := newArrayExecutionContext(nCtx.ExecutionContext(), subNodeIndex, currentParallelism, arrayNode.GetParallelism()) + arrayEventRecorder := newArrayEventRecorder() arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(), - subNodeID, subNodeIndex, subNodeStatus, inputReader, currentParallelism, arrayNode.GetParallelism()) + subNodeID, subNodeIndex, subNodeStatus, inputReader, arrayEventRecorder, currentParallelism, arrayNode.GetParallelism()) arrayNodeExecutor := a.nodeExecutor.WithNodeExecutionContextBuilder(arrayNodeExecutionContextBuilder) - return arrayNodeExecutor, arrayExecutionContext, &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec, subNodeStatus, nil + return arrayNodeExecutor, arrayExecutionContext, &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec, subNodeStatus, arrayEventRecorder, nil } func bytesFromK8sPluginState(pluginState k8s.PluginState) ([]byte, error) { diff --git a/pkg/controller/nodes/interfaces/state.go b/pkg/controller/nodes/interfaces/state.go index 00021f212..bf753a23a 100644 --- a/pkg/controller/nodes/interfaces/state.go +++ b/pkg/controller/nodes/interfaces/state.go @@ -51,6 +51,7 @@ type GateNodeState struct { type ArrayNodeState struct { Phase v1alpha1.ArrayNodePhase + TaskPhaseVersion uint32 Error *core.ExecutionError SubNodePhases bitarray.CompactArray SubNodeTaskPhases bitarray.CompactArray diff --git a/pkg/controller/nodes/node_state_manager.go b/pkg/controller/nodes/node_state_manager.go index aa9006348..17f4113a3 100644 --- a/pkg/controller/nodes/node_state_manager.go +++ b/pkg/controller/nodes/node_state_manager.go @@ -164,6 +164,7 @@ func (n nodeStateManager) GetArrayNodeState() interfaces.ArrayNodeState { as.SubNodeTaskPhases = an.GetSubNodeTaskPhases() as.SubNodeRetryAttempts = an.GetSubNodeRetryAttempts() as.SubNodeSystemFailures = an.GetSubNodeSystemFailures() + as.TaskPhaseVersion = an.GetTaskPhaseVersion() } return as } diff --git a/pkg/controller/nodes/transformers.go b/pkg/controller/nodes/transformers.go index d55402305..b2ec54807 100644 --- a/pkg/controller/nodes/transformers.go +++ b/pkg/controller/nodes/transformers.go @@ -293,5 +293,6 @@ func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n interfaces.N t.SetSubNodeTaskPhases(na.SubNodeTaskPhases) t.SetSubNodeRetryAttempts(na.SubNodeRetryAttempts) t.SetSubNodeSystemFailures(na.SubNodeSystemFailures) + t.SetTaskPhaseVersion(na.TaskPhaseVersion) } } From 59df8426e6753a911e346f140ad2a73fc0197df5 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 1 Jun 2023 09:01:45 -0500 Subject: [PATCH 26/62] minor refactoring of debug prints and formatting Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 6b0f3e3a0..8258a3a80 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ 
-247,13 +247,13 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 taskPhase = task.ToTaskEventPhase(core.Phase(taskNodeStatus.GetPhase()))
 }

- fmt.Printf("HAMERSAW - processing event '%s' for node '%d' with phase '%d'\n", taskExecutionEvent.Phase.String(), i, taskPhase)
 externalResources = append(externalResources, &event.ExternalResourceInfo{
 ExternalId: fmt.Sprintf("%s-%d", nCtx.NodeID, i), // TODO @hamersaw do better
 Index: uint32(i),
 Logs: taskExecutionEvent.Logs,
 RetryAttempt: 0,
 Phase: taskPhase,
+ //CacheStatus: taskExecutionEvent.Metadata, // TODO @hamersaw - figure out how to get CacheStatus back
 })
 }
@@ -406,7 +406,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 Metadata: &event.TaskExecutionMetadata{
 ExternalResources: externalResources,
 },
- TaskType: "k8s-array",
+ TaskType: "k8s-array",
 EventVersion: 1,
 }

From 9a9d0f0f8899f56b1d1d817beb8317ad1f8f9306 Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Fri, 2 Jun 2023 14:14:55 -0500
Subject: [PATCH 27/62] intratask checkpointing working

Signed-off-by: Daniel Rammer
---
 pkg/compiler/transformers/k8s/node.go | 3 +++
 pkg/controller/nodes/array/handler.go | 9 ++++++---
 pkg/controller/nodes/task/handler.go | 1 -
 3 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/pkg/compiler/transformers/k8s/node.go b/pkg/compiler/transformers/k8s/node.go
index 78be27ec7..3392ac6c2 100644
--- a/pkg/compiler/transformers/k8s/node.go
+++ b/pkg/compiler/transformers/k8s/node.go
@@ -155,6 +155,9 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile
 case *core.Node_ArrayNode:
 arrayNode := n.GetArrayNode()

+ // since we set retries=1 on the node it's not using the task-level retries
+ arrayNode.Node.Metadata.Retries = nil // TODO @hamersaw - should probably set node-level retries to task in flytekit
+
 // build subNodeSpecs
 subNodeSpecs, ok := buildNodeSpec(arrayNode.Node, tasks, errs)
 if !ok {
diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go
index 8258a3a80..27d74d9a5 100644
--- a/pkg/controller/nodes/array/handler.go
+++ b/pkg/controller/nodes/array/handler.go
@@ -321,7 +321,8 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 outputLiterals := make(map[string]*idlcore.Literal)
 for i, _ := range arrayNodeState.SubNodePhases.GetItems() {
 // initialize subNode reader
- subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i))
+ currentAttempt := uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(i))
+ subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i), strconv.Itoa(int(currentAttempt)))
 if err != nil {
 return handler.UnknownTransition, err
 }
@@ -508,7 +509,8 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter
 // TODO @hamersaw - this is a problem because cache lookups happen in NodePhaseQueued
 // so the cache hit items will be written to the wrong location
 // can we just change flytekit appending the index onto the location?!?1
- subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(subNodeIndex))
+ currentAttempt := uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(subNodeIndex))
+ subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(subNodeIndex), strconv.Itoa(int(currentAttempt)))
 if err != nil {
 return nil, nil, nil, nil, nil, nil, nil, err
 }
@@ -517,7 +519,7 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter
 Phase: nodePhase,
 DataDir: subDataDir,
 OutputDir: subOutputDir,
- Attempts: uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(subNodeIndex)),
+ Attempts: currentAttempt,
 SystemFailures: uint32(arrayNodeState.SubNodeSystemFailures.GetItem(subNodeIndex)),
 TaskNodeStatus: &v1alpha1.TaskNodeStatus{
 Phase: taskPhase,
@@ -525,6 +527,7 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter
 },
 }

+ // initialize mocks
 arrayNodeLookup := newArrayNodeLookup(nCtx.ContextualNodeLookup(), subNodeID, &subNodeSpec, subNodeStatus)

 arrayExecutionContext := newArrayExecutionContext(nCtx.ExecutionContext(), subNodeIndex, currentParallelism, arrayNode.GetParallelism())
diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go
index 5f4493978..3b778291f 100644
--- a/pkg/controller/nodes/task/handler.go
+++ b/pkg/controller/nodes/task/handler.go
@@ -174,7 +174,6 @@ func (p *pluginRequestedTransition) FinalTransition(ctx context.Context) (handle
 return handler.DoTransition(p.ttype, handler.PhaseInfoSuccess(&p.execInfo)), nil
 case pluginCore.PhaseRetryableFailure:
 logger.Debugf(ctx, "Transitioning to RetryableFailure")
- fmt.Printf("HAMERSAW - %+v\n", p.pInfo.Err())
 return handler.DoTransition(p.ttype, handler.PhaseInfoRetryableFailureErr(p.pInfo.Err(), &p.execInfo)), nil
 case pluginCore.PhasePermanentFailure:
 logger.Debugf(ctx, "Transitioning to Failure")
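Patch 27 above threads the subnode's retry attempt into the output references, so a retried subtask writes under a fresh attempt-suffixed prefix instead of clobbering the outputs of attempt 0. A minimal sketch of the resulting layout (the string-joining helper is illustrative only; the real constructOutputReferences builds storage references through the DataStore rather than concatenating strings):

    package main

    import (
    	"fmt"
    	"strconv"
    )

    // joinRef loosely mimics how constructOutputReferences nests keys under
    // the ArrayNode's data directory.
    func joinRef(base string, keys ...string) string {
    	for _, k := range keys {
    		base = base + "/" + k
    	}
    	return base
    }

    func main() {
    	base := "s3://bucket/exec/array-node/data"
    	subNodeIndex, currentAttempt := 3, 1
    	// subnode 3, attempt 1 now writes under .../3/1 rather than .../3
    	fmt.Println(joinRef(base, strconv.Itoa(subNodeIndex), strconv.Itoa(currentAttempt)))
    }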
From 377497eb87c72bda5c288c12c2049b63855327da Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Tue, 6 Jun 2023 12:44:44 -0500
Subject: [PATCH 28/62] support for minSuccesses and minSuccessRatio

Signed-off-by: Daniel Rammer
---
 pkg/apis/flyteworkflow/v1alpha1/array.go | 15 ++-
 pkg/apis/flyteworkflow/v1alpha1/iface.go | 3 +-
 pkg/compiler/transformers/k8s/node.go | 7 +
 pkg/controller/nodes/array/handler.go | 159 ++++++++++++++++++-----
 4 files changed, 144 insertions(+), 40 deletions(-)

diff --git a/pkg/apis/flyteworkflow/v1alpha1/array.go b/pkg/apis/flyteworkflow/v1alpha1/array.go
index cdb3a59d9..8d47a8990 100644
--- a/pkg/apis/flyteworkflow/v1alpha1/array.go
+++ b/pkg/apis/flyteworkflow/v1alpha1/array.go
@@ -4,9 +4,10 @@ import (
 )

 type ArrayNodeSpec struct {
- SubNodeSpec *NodeSpec
- Parallelism uint32
- // TODO @hamersaw - fill out ArrayNodeSpec
+ SubNodeSpec *NodeSpec
+ Parallelism uint32
+ MinSuccesses *uint32
+ MinSuccessRatio *float32
 }

 func (a *ArrayNodeSpec) GetSubNodeSpec() *NodeSpec {
@@ -16,3 +17,11 @@ func (a *ArrayNodeSpec) GetParallelism() uint32 {
 return a.Parallelism
 }
+
+func (a *ArrayNodeSpec) GetMinSuccesses() *uint32 {
+ return a.MinSuccesses
+}
+
+func (a *ArrayNodeSpec) GetMinSuccessRatio() *float32 {
+ return a.MinSuccessRatio
+}
diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go
index dfffbfd02..fcf1467b0 100644
--- a/pkg/apis/flyteworkflow/v1alpha1/iface.go
+++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go
@@ -258,7 +258,8 @@ type ExecutableGateNode interface {
 type ExecutableArrayNode interface {
 GetSubNodeSpec() *NodeSpec
 GetParallelism() uint32
- // TODO @hamersaw - complete ExecutableArrayNode
+ GetMinSuccesses() *uint32
+ GetMinSuccessRatio() *float32
 }

 type ExecutableWorkflowNodeStatus interface {
diff --git a/pkg/compiler/transformers/k8s/node.go b/pkg/compiler/transformers/k8s/node.go
index 3392ac6c2..fab0bfcd2 100644
--- a/pkg/compiler/transformers/k8s/node.go
+++ b/pkg/compiler/transformers/k8s/node.go
@@ -171,6 +171,13 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile
 Parallelism: arrayNode.Parallelism,
 }

+ switch successCriteria := arrayNode.SuccessCriteria.(type) {
+ case *core.ArrayNode_MinSuccesses:
+ nodeSpec.ArrayNode.MinSuccesses = &successCriteria.MinSuccesses
+ case *core.ArrayNode_MinSuccessRatio:
+ nodeSpec.ArrayNode.MinSuccessRatio = &successCriteria.MinSuccessRatio
+ }
+
 // TODO @hamersaw hack - should not be necessary, should be set in flytekit
 for _, binding := range nodeSpec.InputBindings {
 switch b := binding.Binding.Binding.Value.(type) {
diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go
index 27d74d9a5..57e1ff92a 100644
--- a/pkg/controller/nodes/array/handler.go
+++ b/pkg/controller/nodes/array/handler.go
@@ -4,6 +4,7 @@ import (
 "bytes"
 "context"
 "fmt"
+ "math"
 "strconv"
 "time"

@@ -22,8 +23,8 @@ import (
 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces"
 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task"
- "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/k8s"
 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/codex"
+ "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/k8s"

 "github.com/flyteorg/flytestdlib/bitarray"
 "github.com/flyteorg/flytestdlib/logger"
@@ -33,6 +34,18 @@ import (
 "github.com/golang/protobuf/ptypes"
 )

+var (
+ nilLiteral = &idlcore.Literal{
+ Value: &idlcore.Literal_Scalar{
+ Scalar: &idlcore.Scalar{
+ Value: &idlcore.Scalar_NoneType{
+ NoneType: &idlcore.Void{},
+ },
+ },
+ },
+ }
+)
+
 //go:generate mockery -all -case=underscore

 // arrayNodeHandler is a handle implementation for processing array nodes
@@ -276,19 +289,30 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 successCount := 0
 failedCount := 0
 failingCount := 0
+ runningCount := 0
 for _, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
 nodePhase := v1alpha1.NodePhase(nodePhaseUint64)
 //fmt.Printf("HAMERSAW - node %d phase %d\n", i, nodePhase)

 switch nodePhase {
- case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRecovered:
+ case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRecovered: // TODO @hamersaw NodePhaseSkipped?
 successCount++
 case v1alpha1.NodePhaseFailing:
 failingCount++
- case v1alpha1.NodePhaseFailed:
+ case v1alpha1.NodePhaseFailed: // TODO @hamersaw NodePhaseTimedOut?
 failedCount++
+ default:
+ runningCount++
 }
 }

+ // calculate minimum number of successes to succeed the ArrayNode
+ minSuccesses := len(arrayNodeState.SubNodePhases.GetItems())
+ if arrayNode.GetMinSuccesses() != nil {
+ minSuccesses = int(*arrayNode.GetMinSuccesses())
+ } else if minSuccessRatio := arrayNode.GetMinSuccessRatio(); minSuccessRatio != nil {
+ minSuccesses = int(math.Ceil(float64(*minSuccessRatio) * float64(minSuccesses)))
+ }
+
 // if there is a failing node set the error message
 if failingCount > 0 && arrayNodeState.Error == nil {
 arrayNodeState.Error = &idlcore.ExecutionError{
@@ -296,11 +320,18 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 }
 }

- if failedCount > 0 {
+ if len(arrayNodeState.SubNodePhases.GetItems()) - failedCount < minSuccesses {
+ // no chance to reach the minimum number of successes
 arrayNodeState.Phase = v1alpha1.ArrayNodePhaseFailing
- } else if successCount == len(arrayNodeState.SubNodePhases.GetItems()) {
+ } else if successCount >= minSuccesses && runningCount == 0 {
+ // wait until all tasks have completed before declaring success
 arrayNodeState.Phase = v1alpha1.ArrayNodePhaseSucceeding
 }
+ /*if failedCount > 0 {
+ arrayNodeState.Phase = v1alpha1.ArrayNodePhaseFailing
+ } else if successCount == len(arrayNodeState.SubNodePhases.GetItems()) {
+ arrayNodeState.Phase = v1alpha1.ArrayNodePhaseSucceeding
+ }*/
 case v1alpha1.ArrayNodePhaseFailing:
 if err := a.Abort(ctx, nCtx, "ArrayNodeFailing"); err != nil {
 return handler.UnknownTransition, err
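The success criteria added in this hunk reduce to a small piece of arithmetic; a self-contained sketch under illustrative names (this is not the handler's API):

    package main

    import (
    	"fmt"
    	"math"
    )

    // minSuccesses mirrors the logic above: an absolute MinSuccesses wins,
    // otherwise MinSuccessRatio is applied to the subnode count, and with
    // neither set every subnode must succeed.
    func minSuccesses(n int, abs *uint32, ratio *float32) int {
    	if abs != nil {
    		return int(*abs)
    	}
    	if ratio != nil {
    		return int(math.Ceil(float64(*ratio) * float64(n)))
    	}
    	return n
    }

    func main() {
    	ratio := float32(0.9)
    	need := minSuccesses(10, nil, &ratio) // ceil(0.9 * 10) = 9
    	fmt.Println(need)

    	// the ArrayNode moves to ArrayNodePhaseFailing as soon as success
    	// becomes impossible: len(subnodes) - failedCount < minSuccesses
    	n, failed := 10, 2
    	fmt.Println(n-failed < need) // true -> failing
    }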
@@ -319,45 +350,83 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 )), nil
 case v1alpha1.ArrayNodePhaseSucceeding:
 outputLiterals := make(map[string]*idlcore.Literal)
- for i, _ := range arrayNodeState.SubNodePhases.GetItems() {
- // initialize subNode reader
- currentAttempt := uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(i))
- subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i), strconv.Itoa(int(currentAttempt)))
- if err != nil {
- return handler.UnknownTransition, err
- }
+ for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
+ nodePhase := v1alpha1.NodePhase(nodePhaseUint64)

- // checkpoint paths are not computed here because this function is only called when writing
- // existing cached outputs. if this functionality changes this will need to be revisited.
- outputPaths := ioutils.NewCheckpointRemoteFilePaths(ctx, nCtx.DataStore(), subOutputDir, ioutils.NewRawOutputPaths(ctx, subDataDir), "")
- reader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, int64(999999999))
+ if nodePhase != v1alpha1.NodePhaseSucceeded {
+ // retrieve output variables from task template
+ var outputVariables map[string]*idlcore.Variable
+ task, err := nCtx.ExecutionContext().GetTask(*arrayNode.GetSubNodeSpec().TaskRef)
+ if err != nil {
+ // Should never happen
+ return handler.UnknownTransition, err
+ }

- // read outputs
- outputs, executionErr, err := reader.Read(ctx)
- if err != nil {
- return handler.UnknownTransition, err
- } else if executionErr != nil {
- // TODO @hamersaw handle executionErr
- //return handler.UnknownTransition, executionErr
- }
+ if task.CoreTask() != nil && task.CoreTask().Interface != nil && task.CoreTask().Interface.Outputs != nil {
+ outputVariables = task.CoreTask().Interface.Outputs.Variables
+ }

- // copy individual subNode output literals into a collection of output literals
- for name, literal := range outputs.GetLiterals() {
- outputLiteral, exists := outputLiterals[name]
- if !exists {
- outputLiteral = &idlcore.Literal{
- Value: &idlcore.Literal_Collection{
- Collection: &idlcore.LiteralCollection{
- Literals: make([]*idlcore.Literal, 0, len(arrayNodeState.SubNodePhases.GetItems())),
+ // append nil literal for all output variables
+ for name, _ := range outputVariables {
+ appendLiteral(name, nilLiteral, outputLiterals, len(arrayNodeState.SubNodePhases.GetItems()))
+ /*// TODO @hamersaw - refactor because duplicated below
+ outputLiteral, exists := outputLiterals[name]
+ if !exists {
+ outputLiteral = &idlcore.Literal{
+ Value: &idlcore.Literal_Collection{
+ Collection: &idlcore.LiteralCollection{
+ Literals: make([]*idlcore.Literal, 0, len(arrayNodeState.SubNodePhases.GetItems())),
+ },
 },
- },
+ }
+
+ outputLiterals[name] = outputLiteral
 }

- outputLiterals[name] = outputLiteral
+ collection := outputLiteral.GetCollection()
+ collection.Literals = append(collection.Literals, nilLiteral)*/
+ }
+ } else {
+ // initialize subNode reader
+ currentAttempt := uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(i))
+ subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i), strconv.Itoa(int(currentAttempt)))
+ if err != nil {
+ return handler.UnknownTransition, err
 }

- collection := outputLiteral.GetCollection()
- collection.Literals = append(collection.Literals, literal)
+ // checkpoint paths are not computed here because this function is only called when writing
+ // existing cached outputs. if this functionality changes this will need to be revisited.
+ outputPaths := ioutils.NewCheckpointRemoteFilePaths(ctx, nCtx.DataStore(), subOutputDir, ioutils.NewRawOutputPaths(ctx, subDataDir), "")
+ reader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, int64(999999999))
+
+ // read outputs
+ outputs, executionErr, err := reader.Read(ctx)
+ if err != nil {
+ return handler.UnknownTransition, err
+ } else if executionErr != nil {
+ // TODO @hamersaw handle executionErr
+ //return handler.UnknownTransition, executionErr
+ }
+
+ // copy individual subNode output literals into a collection of output literals
+ for name, literal := range outputs.GetLiterals() {
+ appendLiteral(name, literal, outputLiterals, len(arrayNodeState.SubNodePhases.GetItems()))
+ /*outputLiteral, exists := outputLiterals[name]
+ if !exists {
+ outputLiteral = &idlcore.Literal{
+ Value: &idlcore.Literal_Collection{
+ Collection: &idlcore.LiteralCollection{
+ Literals: make([]*idlcore.Literal, 0, len(arrayNodeState.SubNodePhases.GetItems())),
+ },
+ },
+ }
+
+ outputLiterals[name] = outputLiteral
+ }
+
+ collection := outputLiteral.GetCollection()
+ collection.Literals = append(collection.Literals, literal)*/
+ }
+ }
 }
@@ -540,6 +609,24 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter
 return arrayNodeExecutor, arrayExecutionContext, &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec, subNodeStatus, arrayEventRecorder, nil
 }

+func appendLiteral(name string, literal *idlcore.Literal, outputLiterals map[string]*idlcore.Literal, length int) {
+ outputLiteral, exists := outputLiterals[name]
+ if !exists {
+ outputLiteral = &idlcore.Literal{
+ Value: &idlcore.Literal_Collection{
+ Collection: &idlcore.LiteralCollection{
+ Literals: make([]*idlcore.Literal, 0, length),
+ },
+ },
+ }
+
+ outputLiterals[name] = outputLiteral
+ }
+
+ collection := outputLiteral.GetCollection()
+ collection.Literals = append(collection.Literals, literal)
+}
+
 func bytesFromK8sPluginState(pluginState k8s.PluginState) ([]byte, error) {
 buffer := make([]byte, 0, task.MaxPluginStateSizeBytes)
 bufferWriter := bytes.NewBuffer(buffer)
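Patch 28 above fills the output slot of every non-successful subnode with a none-typed literal so collection indices stay aligned with subnode indices even under partial success. A runnable sketch of that shape, assuming the standard flyteidl pb-go types:

    package main

    import (
    	"fmt"

    	idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
    )

    // intLit builds a simple integer literal for the demonstration.
    func intLit(v int64) *idlcore.Literal {
    	return &idlcore.Literal{Value: &idlcore.Literal_Scalar{Scalar: &idlcore.Scalar{
    		Value: &idlcore.Scalar_Primitive{Primitive: &idlcore.Primitive{
    			Value: &idlcore.Primitive_Integer{Integer: v},
    		}},
    	}}}
    }

    func main() {
    	// the none-typed placeholder used for failed subnodes
    	nilLiteral := &idlcore.Literal{Value: &idlcore.Literal_Scalar{Scalar: &idlcore.Scalar{
    		Value: &idlcore.Scalar_NoneType{NoneType: &idlcore.Void{}},
    	}}}

    	// three subnodes where subnode 1 failed: its slot holds the none
    	// literal, so output[0] and output[2] still map to subnodes 0 and 2
    	collection := &idlcore.LiteralCollection{
    		Literals: []*idlcore.Literal{intLit(0), nilLiteral, intLit(2)},
    	}
    	fmt.Println(len(collection.Literals)) // 3
    }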
From 5cf3259e03a490f8e4dd2b88e62169b6acfcae84 Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Tue, 6 Jun 2023 17:01:06 -0500
Subject: [PATCH 29/62] setting node log ids correctly

Signed-off-by: Daniel Rammer
---
 pkg/controller/nodes/array/handler.go | 43 ++++++---------------------
 1 file changed, 6 insertions(+), 37 deletions(-)

diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go
index 57e1ff92a..83790561e 100644
--- a/pkg/controller/nodes/array/handler.go
+++ b/pkg/controller/nodes/array/handler.go
@@ -260,6 +260,11 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 taskPhase = task.ToTaskEventPhase(core.Phase(taskNodeStatus.GetPhase()))
 }

+ for _, log := range taskExecutionEvent.Logs {
+ // TODO @hamersaw - do we need to add retryattempt?
+ log.Name = fmt.Sprintf("-%d", log.Name, i)
+ }
+
 externalResources = append(externalResources, &event.ExternalResourceInfo{
 ExternalId: fmt.Sprintf("%s-%d", nCtx.NodeID, i), // TODO @hamersaw do better
 Index: uint32(i),
@@ -327,11 +332,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 // wait until all tasks have completed before declaring success
 arrayNodeState.Phase = v1alpha1.ArrayNodePhaseSucceeding
 }
- /*if failedCount > 0 {
- arrayNodeState.Phase = v1alpha1.ArrayNodePhaseFailing
- } else if successCount == len(arrayNodeState.SubNodePhases.GetItems()) {
- arrayNodeState.Phase = v1alpha1.ArrayNodePhaseSucceeding
- }*/
 case v1alpha1.ArrayNodePhaseFailing:
 if err := a.Abort(ctx, nCtx, "ArrayNodeFailing"); err != nil {
 return handler.UnknownTransition, err
@@ -369,22 +369,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 // append nil literal for all output variables
 for name, _ := range outputVariables {
 appendLiteral(name, nilLiteral, outputLiterals, len(arrayNodeState.SubNodePhases.GetItems()))
- /*// TODO @hamersaw - refactor because duplicated below
- outputLiteral, exists := outputLiterals[name]
- if !exists {
- outputLiteral = &idlcore.Literal{
- Value: &idlcore.Literal_Collection{
- Collection: &idlcore.LiteralCollection{
- Literals: make([]*idlcore.Literal, 0, len(arrayNodeState.SubNodePhases.GetItems())),
- },
- },
- }
-
- outputLiterals[name] = outputLiteral
- }
-
- collection := outputLiteral.GetCollection()
- collection.Literals = append(collection.Literals, nilLiteral)*/
 }
 } else {
 // initialize subNode reader
@@ -411,21 +395,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 // copy individual subNode output literals into a collection of output literals
 for name, literal := range outputs.GetLiterals() {
 appendLiteral(name, literal, outputLiterals, len(arrayNodeState.SubNodePhases.GetItems()))
- /*outputLiteral, exists := outputLiterals[name]
- if !exists {
- outputLiteral = &idlcore.Literal{
- Value: &idlcore.Literal_Collection{
- Collection: &idlcore.LiteralCollection{
- Literals: make([]*idlcore.Literal, 0, len(arrayNodeState.SubNodePhases.GetItems())),
- },
- },
- }
-
- outputLiterals[name] = outputLiteral
- }
-
- collection := outputLiteral.GetCollection()
- collection.Literals = append(collection.Literals, literal)*/
 }
 }
 }
@@ -465,7 +434,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 ResourceType: idlcore.ResourceType_TASK,
 Project: workflowExecutionId.Project,
 Domain: workflowExecutionId.Domain,
- Name: "foo", // TODO @hamersaw - make it better
+ Name: nCtx.NodeID(), //"foo", // TODO @hamersaw - make it better
 Version: "v1", // TODO @hamersaw - please
 },
 ParentNodeExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID(),
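The arrayEventRecorder these patches keep extending is a buffering decorator: subnode execution records events into it instead of emitting them, and the ArrayNode handler later drains the buffer into ExternalResourceInfo entries on its own task event. The pattern in isolation (a simplified event type stands in for *event.TaskExecutionEvent):

    package main

    import "fmt"

    // bufferingRecorder captures events rather than shipping them to admin,
    // letting a parent aggregate them into a single event afterwards.
    type bufferingRecorder struct {
    	events []string
    }

    func (b *bufferingRecorder) RecordTaskEvent(e string) error {
    	b.events = append(b.events, e) // swallow: nothing leaves the subnode
    	return nil
    }

    func (b *bufferingRecorder) Events() []string {
    	return b.events
    }

    func main() {
    	r := &bufferingRecorder{}
    	_ = r.RecordTaskEvent("subnode-0: RUNNING")
    	_ = r.RecordTaskEvent("subnode-0: SUCCEEDED")
    	fmt.Println(r.Events()) // drained by the ArrayNode handler
    }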
From 437dc91888375677852739a29abaa6f1f095a140 Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Tue, 6 Jun 2023 20:12:26 -0500
Subject: [PATCH 30/62] reporting cache status

Signed-off-by: Daniel Rammer
---
 .../nodes/array/execution_context.go | 9 ++++++++-
 pkg/controller/nodes/array/handler.go | 18 ++++++++++++++----
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go
index c53e092e9..fb8f54590 100644
--- a/pkg/controller/nodes/array/execution_context.go
+++ b/pkg/controller/nodes/array/execution_context.go
@@ -57,10 +57,12 @@ func newArrayExecutionContext(executionContext executors.ExecutionContext, subNo
 }

 type arrayEventRecorder struct {
+ nodeEvents []*event.NodeExecutionEvent
 taskEvents []*event.TaskExecutionEvent
 }

 func (a *arrayEventRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent, eventConfig *config.EventConfig) error {
+ a.nodeEvents = append(a.nodeEvents, event)
 return nil
 }

@@ -69,12 +71,17 @@ func (a *arrayEventRecorder) RecordTaskEvent(ctx context.Context, event *event.T
 return nil
 }

-func (a *arrayEventRecorder) Events() []*event.TaskExecutionEvent {
+func (a *arrayEventRecorder) NodeEvents() []*event.NodeExecutionEvent {
+ return a.nodeEvents
+}
+
+func (a *arrayEventRecorder) TaskEvents() []*event.TaskExecutionEvent {
 return a.taskEvents
 }

 func newArrayEventRecorder() *arrayEventRecorder {
 return &arrayEventRecorder{
+ nodeEvents: make([]*event.NodeExecutionEvent, 0),
 taskEvents: make([]*event.TaskExecutionEvent, 0),
 }
 }
diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go
index 83790561e..e81709d96 100644
--- a/pkg/controller/nodes/array/handler.go
+++ b/pkg/controller/nodes/array/handler.go
@@ -253,8 +253,18 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 messageCollector.Collect(i, subNodeStatus.Error.Message)
 }

- // process TaskExecutionEvents
- for _, taskExecutionEvent := range arrayEventRecorder.Events() {
+ // process events
+ cacheStatus := idlcore.CatalogCacheStatus_CACHE_DISABLED
+ for _, nodeExecutionEvent := range arrayEventRecorder.NodeEvents() {
+ switch target := nodeExecutionEvent.TargetMetadata.(type) {
+ case *event.NodeExecutionEvent_TaskNodeMetadata:
+ if target.TaskNodeMetadata != nil {
+ cacheStatus = target.TaskNodeMetadata.CacheStatus
+ }
+ }
+ }
+
+ for _, taskExecutionEvent := range arrayEventRecorder.TaskEvents() {
 taskPhase := idlcore.TaskExecution_UNDEFINED
 if taskNodeStatus := subNodeStatus.GetTaskNodeStatus(); taskNodeStatus != nil {
 taskPhase = task.ToTaskEventPhase(core.Phase(taskNodeStatus.GetPhase()))
@@ -262,7 +272,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu

 for _, log := range taskExecutionEvent.Logs {
 // TODO @hamersaw - do we need to add retryattempt?
- log.Name = fmt.Sprintf("-%d", log.Name, i) + log.Name = fmt.Sprintf("%s-%d", log.Name, i) } externalResources = append(externalResources, &event.ExternalResourceInfo{ @@ -271,7 +281,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu Logs: taskExecutionEvent.Logs, RetryAttempt: 0, Phase: taskPhase, - //CacheStatus: taskExecutionEvent.Metadata, // TODO @hamersaw - figure out how to get CacheStatus back + CacheStatus: cacheStatus, // TODO @hamersaw - figure out how to get CacheStatus back }) } From d2abcccd0215f14592932bd640c866214e2dac42 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Mon, 12 Jun 2023 10:49:33 -0500 Subject: [PATCH 31/62] correctly setting subnode abort phase Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 62 +++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 3 deletions(-) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index e81709d96..0099f530e 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -73,6 +73,7 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut arrayNode := nCtx.Node().GetArrayNode() arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState() + externalResources := make([]*event.ExternalResourceInfo, 0, len(arrayNodeState.SubNodePhases.GetItems())) messageCollector := errorcollector.NewErrorMessageCollector() switch arrayNodeState.Phase { case v1alpha1.ArrayNodePhaseExecuting, v1alpha1.ArrayNodePhaseFailing: @@ -94,6 +95,14 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut err = arrayNodeExecutor.AbortHandler(ctx, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, reason) if err != nil { messageCollector.Collect(i, err.Error()) + } else { + externalResources = append(externalResources, &event.ExternalResourceInfo{ + ExternalId: fmt.Sprintf("%s-%d", nCtx.NodeID, i), // TODO @hamersaw do better + Index: uint32(i), + Logs: nil, + RetryAttempt: 0, + Phase: idlcore.TaskExecution_ABORTED, + }) } } } @@ -102,6 +111,18 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut return fmt.Errorf(messageCollector.Summary(512)) // TODO @hamersaw - make configurable } + // TODO @hamersaw - update aborted state for subnodes + taskExecutionEvent, err := buildTaskExecutionEvent(ctx, nCtx, idlcore.TaskExecution_ABORTED, 0, externalResources) + if err != nil { + return err + } + + // TODO @hamersaw - pass eventConfig correctly + if err := nCtx.EventsRecorder().RecordTaskEvent(ctx, taskExecutionEvent, &config.EventConfig{}); err != nil { + logger.Errorf(ctx, "ArrayNode event recording failed: [%s]", err.Error()) + return err + } + return nil } @@ -281,7 +302,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu Logs: taskExecutionEvent.Logs, RetryAttempt: 0, Phase: taskPhase, - CacheStatus: cacheStatus, // TODO @hamersaw - figure out how to get CacheStatus back + CacheStatus: cacheStatus, }) } @@ -432,7 +453,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // TODO @hamersaw - send task-level events - this requires externalResources to emulate current maptasks if len(externalResources) > 0 { - occurredAt, err := ptypes.TimestampProto(time.Now()) + /*occurredAt, err := ptypes.TimestampProto(time.Now()) if err != nil { return handler.UnknownTransition, err } @@ -444,7 +465,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx 
interfaces.NodeExecu ResourceType: idlcore.ResourceType_TASK, Project: workflowExecutionId.Project, Domain: workflowExecutionId.Domain, - Name: nCtx.NodeID(), //"foo", // TODO @hamersaw - make it better + Name: nCtx.NodeID(), Version: "v1", // TODO @hamersaw - please }, ParentNodeExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), @@ -457,6 +478,12 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu }, TaskType: "k8s-array", EventVersion: 1, + }*/ + + // TODO @hamersaw - determine node phase from ArrayNodePhase (ie. Queued, Running, Succeeded, Failed) + taskExecutionEvent, err := buildTaskExecutionEvent(ctx, nCtx, idlcore.TaskExecution_RUNNING, taskPhaseVersion, externalResources) + if err != nil { + return handler.UnknownTransition, err } // TODO @hamersaw - pass eventConfig correctly @@ -606,6 +633,35 @@ func appendLiteral(name string, literal *idlcore.Literal, outputLiterals map[str collection.Literals = append(collection.Literals, literal) } +func buildTaskExecutionEvent(ctx context.Context, nCtx interfaces.NodeExecutionContext, taskPhase idlcore.TaskExecution_Phase, taskPhaseVersion uint32, externalResources []*event.ExternalResourceInfo) (*event.TaskExecutionEvent, error) { + occurredAt, err := ptypes.TimestampProto(time.Now()) + if err != nil { + return nil, err + } + + nodeExecutionId := nCtx.NodeExecutionMetadata().GetNodeExecutionID() + workflowExecutionId := nodeExecutionId.ExecutionId + return &event.TaskExecutionEvent{ + TaskId: &idlcore.Identifier{ + ResourceType: idlcore.ResourceType_TASK, + Project: workflowExecutionId.Project, + Domain: workflowExecutionId.Domain, + Name: nCtx.NodeID(), + Version: "v1", // TODO @hamersaw - please + }, + ParentNodeExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), + RetryAttempt: 0, // ArrayNode will never retry + Phase: taskPhase, + PhaseVersion: taskPhaseVersion, + OccurredAt: occurredAt, + Metadata: &event.TaskExecutionMetadata{ + ExternalResources: externalResources, + }, + TaskType: "k8s-array", + EventVersion: 1, + }, nil +} + func bytesFromK8sPluginState(pluginState k8s.PluginState) ([]byte, error) { buffer := make([]byte, 0, task.MaxPluginStateSizeBytes) bufferWriter := bytes.NewBuffer(buffer) From 1de3e16741cd03dc6b33b66bbbdd881a8c2808b3 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Mon, 12 Jun 2023 11:07:33 -0500 Subject: [PATCH 32/62] removing dead code Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 51 ++++++++------------------- 1 file changed, 15 insertions(+), 36 deletions(-) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 0099f530e..3f04e14e8 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -97,7 +97,7 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut messageCollector.Collect(i, err.Error()) } else { externalResources = append(externalResources, &event.ExternalResourceInfo{ - ExternalId: fmt.Sprintf("%s-%d", nCtx.NodeID, i), // TODO @hamersaw do better + ExternalId: buildSubNodeID(nCtx, i, 0), Index: uint32(i), Logs: nil, RetryAttempt: 0, @@ -111,7 +111,7 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut return fmt.Errorf(messageCollector.Summary(512)) // TODO @hamersaw - make configurable } - // TODO @hamersaw - update aborted state for subnodes + // update aborted state for subnodes taskExecutionEvent, err := buildTaskExecutionEvent(ctx, nCtx, idlcore.TaskExecution_ABORTED, 
0, externalResources) if err != nil { return err @@ -235,7 +235,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu externalResources = make([]*event.ExternalResourceInfo, 0, size) for i := 0; i < size; i++ { externalResources = append(externalResources, &event.ExternalResourceInfo{ - ExternalId: fmt.Sprintf("%s-%d", nCtx.NodeID, i), // TODO @hamersaw do better + ExternalId: buildSubNodeID(nCtx, i, 0), Index: uint32(i), Logs: nil, RetryAttempt: 0, @@ -285,6 +285,8 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } } + retryAttempt := subNodeStatus.GetAttempts() + for _, taskExecutionEvent := range arrayEventRecorder.TaskEvents() { taskPhase := idlcore.TaskExecution_UNDEFINED if taskNodeStatus := subNodeStatus.GetTaskNodeStatus(); taskNodeStatus != nil { @@ -297,18 +299,15 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } externalResources = append(externalResources, &event.ExternalResourceInfo{ - ExternalId: fmt.Sprintf("%s-%d", nCtx.NodeID, i), // TODO @hamersaw do better + ExternalId: buildSubNodeID(nCtx, i, retryAttempt), Index: uint32(i), Logs: taskExecutionEvent.Logs, - RetryAttempt: 0, + RetryAttempt: retryAttempt, Phase: taskPhase, CacheStatus: cacheStatus, }) } - //fmt.Printf("HAMERSAW - '%d' transition node phase %d -> %d task phase '%d' -> '%d'\n", i, - // nodePhase, subNodeStatus.GetPhase(), taskPhase, subNodeStatus.GetTaskNodeStatus().GetPhase()) - // update subNode state arrayNodeState.SubNodePhases.SetItem(i, uint64(subNodeStatus.GetPhase())) if subNodeStatus.GetTaskNodeStatus() == nil { @@ -451,35 +450,9 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // TODO @hamersaw - fail } - // TODO @hamersaw - send task-level events - this requires externalResources to emulate current maptasks + // if there were changes to subnode status externalResources will be populated and must be + // reported to admin through a TaskExecutionEvent. if len(externalResources) > 0 { - /*occurredAt, err := ptypes.TimestampProto(time.Now()) - if err != nil { - return handler.UnknownTransition, err - } - - nodeExecutionId := nCtx.NodeExecutionMetadata().GetNodeExecutionID() - workflowExecutionId := nodeExecutionId.ExecutionId - taskExecutionEvent := &event.TaskExecutionEvent{ - TaskId: &idlcore.Identifier{ - ResourceType: idlcore.ResourceType_TASK, - Project: workflowExecutionId.Project, - Domain: workflowExecutionId.Domain, - Name: nCtx.NodeID(), - Version: "v1", // TODO @hamersaw - please - }, - ParentNodeExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), - RetryAttempt: 0, // ArrayNode will never retry - Phase: 2, // TODO @hamersaw - determine node phase from ArrayNodePhase (ie. Queued, Running, Succeeded, Failed) - PhaseVersion: taskPhaseVersion, - OccurredAt: occurredAt, - Metadata: &event.TaskExecutionMetadata{ - ExternalResources: externalResources, - }, - TaskType: "k8s-array", - EventVersion: 1, - }*/ - // TODO @hamersaw - determine node phase from ArrayNodePhase (ie. Queued, Running, Succeeded, Failed) taskExecutionEvent, err := buildTaskExecutionEvent(ctx, nCtx, idlcore.TaskExecution_RUNNING, taskPhaseVersion, externalResources) if err != nil { @@ -662,6 +635,12 @@ func buildTaskExecutionEvent(ctx context.Context, nCtx interfaces.NodeExecutionC }, nil } +// TODO @hamersaw - what do we want for a subnode ID? 
+func buildSubNodeID(nCtx interfaces.NodeExecutionContext, index int, retryAttempt uint32) string { + return fmt.Sprintf("%s-n%d-%d", nCtx.NodeID, index, retryAttempt) +} + + func bytesFromK8sPluginState(pluginState k8s.PluginState) ([]byte, error) { buffer := make([]byte, 0, task.MaxPluginStateSizeBytes) bufferWriter := bytes.NewBuffer(buffer) From a66bb602052654932413ae7f72dc6b1773c095c9 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Mon, 12 Jun 2023 14:00:49 -0500 Subject: [PATCH 33/62] cleaned up most random TODO items Signed-off-by: Daniel Rammer --- events/event_recorder.go | 8 +- pkg/controller/nodes/array/handler.go | 112 ++++++++++++++---------- pkg/controller/nodes/handler_factory.go | 2 +- 3 files changed, 69 insertions(+), 53 deletions(-) diff --git a/events/event_recorder.go b/events/event_recorder.go index b07cc412b..7366cd1cf 100644 --- a/events/event_recorder.go +++ b/events/event_recorder.go @@ -13,7 +13,7 @@ import ( "github.com/golang/protobuf/proto" ) -const maxErrorMessageLength = 104857600 //100KB +const MaxErrorMessageLength = 104857600 //100KB const truncationIndicator = "... ..." type recordingMetrics struct { @@ -60,7 +60,7 @@ func (r *eventRecorder) sinkEvent(ctx context.Context, event proto.Message) erro func (r *eventRecorder) RecordNodeEvent(ctx context.Context, e *event.NodeExecutionEvent) error { if err, ok := e.GetOutputResult().(*event.NodeExecutionEvent_Error); ok { - truncateErrorMessage(err.Error, maxErrorMessageLength) + truncateErrorMessage(err.Error, MaxErrorMessageLength) } return r.sinkEvent(ctx, e) @@ -68,7 +68,7 @@ func (r *eventRecorder) RecordNodeEvent(ctx context.Context, e *event.NodeExecut func (r *eventRecorder) RecordTaskEvent(ctx context.Context, e *event.TaskExecutionEvent) error { if err, ok := e.GetOutputResult().(*event.TaskExecutionEvent_Error); ok { - truncateErrorMessage(err.Error, maxErrorMessageLength) + truncateErrorMessage(err.Error, MaxErrorMessageLength) } return r.sinkEvent(ctx, e) @@ -76,7 +76,7 @@ func (r *eventRecorder) RecordTaskEvent(ctx context.Context, e *event.TaskExecut func (r *eventRecorder) RecordWorkflowEvent(ctx context.Context, e *event.WorkflowExecutionEvent) error { if err, ok := e.GetOutputResult().(*event.WorkflowExecutionEvent_Error); ok { - truncateErrorMessage(err.Error, maxErrorMessageLength) + truncateErrorMessage(err.Error, MaxErrorMessageLength) } return r.sinkEvent(ctx, e) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 3f04e14e8..074c4bc7b 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -15,6 +15,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/errorcollector" + "github.com/flyteorg/flytepropeller/events" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/compiler/validators" "github.com/flyteorg/flytepropeller/pkg/controller/config" @@ -50,6 +51,7 @@ var ( // arrayNodeHandler is a handle implementation for processing array nodes type arrayNodeHandler struct { + eventConfig *config.EventConfig metrics metrics nodeExecutor interfaces.Node pluginStateBytesNotStarted []byte @@ -81,9 +83,8 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) - // TODO @hamersaw fix - do not process nodes that haven't 
started or are in a terminal state - //if nodes.IsNotyetStarted(nodePhase) || nodes.IsTerminalNodePhase(nodePhase) { - if nodePhase == v1alpha1.NodePhaseNotYetStarted || nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseFailed || nodePhase == v1alpha1.NodePhaseTimedOut || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { + // do not process nodes that have not started or are in a terminal state + if nodePhase == v1alpha1.NodePhaseNotYetStarted || isTerminalNodePhase(nodePhase) { continue } @@ -108,7 +109,7 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut } if messageCollector.Length() > 0 { - return fmt.Errorf(messageCollector.Summary(512)) // TODO @hamersaw - make configurable + return fmt.Errorf(messageCollector.Summary(events.MaxErrorMessageLength)) } // update aborted state for subnodes @@ -117,8 +118,7 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut return err } - // TODO @hamersaw - pass eventConfig correctly - if err := nCtx.EventsRecorder().RecordTaskEvent(ctx, taskExecutionEvent, &config.EventConfig{}); err != nil { + if err := nCtx.EventsRecorder().RecordTaskEvent(ctx, taskExecutionEvent, a.eventConfig); err != nil { logger.Errorf(ctx, "ArrayNode event recording failed: [%s]", err.Error()) return err } @@ -138,8 +138,7 @@ func (a *arrayNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExe for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) - // TODO @hamersaw fix - do not process nodes that haven't started - //if nodes.IsNotyetStarted(nodePhase) { + // do not process nodes that have not started if nodePhase == v1alpha1.NodePhaseNotYetStarted { continue } @@ -157,7 +156,7 @@ func (a *arrayNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExe } if messageCollector.Length() > 0 { - return fmt.Errorf(messageCollector.Summary(512)) // TODO @hamersaw - make configurable + return fmt.Errorf(messageCollector.Summary(events.MaxErrorMessageLength)) } return nil @@ -175,11 +174,12 @@ func (a *arrayNodeHandler) FinalizeRequired() bool { func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { arrayNode := nCtx.Node().GetArrayNode() arrayNodeState := nCtx.NodeStateReader().GetArrayNodeState() + currentArrayNodePhase := arrayNodeState.Phase var externalResources []*event.ExternalResourceInfo taskPhaseVersion := arrayNodeState.TaskPhaseVersion - switch arrayNodeState.Phase { + switch currentArrayNodePhase { case v1alpha1.ArrayNodePhaseNone: // identify and validate array node input value lengths literalMap, err := nCtx.InputReader().Get(ctx) @@ -219,7 +219,9 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } for _, item := range []struct{arrayReference *bitarray.CompactArray; maxValue int}{ - {arrayReference: &arrayNodeState.SubNodePhases, maxValue: len(core.Phases)-1}, // TODO @hamersaw - maxValue is for task phases + // we use NodePhaseRecovered for the `maxValue` of `SubNodePhases` because `Phase` is + // defined as an `iota` so it is impossible to programmatically get largest value + {arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)}, {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases)-1}, {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: maxAttempts}, {arrayReference: 
&arrayNodeState.SubNodeSystemFailures, maxValue: maxAttempts}, @@ -253,9 +255,8 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) - // TODO @hamersaw fix - do not process nodes in terminal state - //if nodes.IsTerminalNodePhase(nodePhase) { - if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseFailed || nodePhase == v1alpha1.NodePhaseTimedOut || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { + // do not process nodes in terminal state + if isTerminalNodePhase(nodePhase) { continue } @@ -294,7 +295,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } for _, log := range taskExecutionEvent.Logs { - // TODO @hamersaw - do we need to add retryattempt? + // TODO @hamersaw - do we need to add retryAttempt to log name? log.Name = fmt.Sprintf("%s-%d", log.Name, i) } @@ -311,7 +312,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // update subNode state arrayNodeState.SubNodePhases.SetItem(i, uint64(subNodeStatus.GetPhase())) if subNodeStatus.GetTaskNodeStatus() == nil { - // TODO @hamersaw during retries we clear the GetTaskNodeStatus - so resetting task phase + // resetting task phase because during retries we clear the GetTaskNodeStatus arrayNodeState.SubNodeTaskPhases.SetItem(i, uint64(0)) } else { arrayNodeState.SubNodeTaskPhases.SetItem(i, uint64(subNodeStatus.GetTaskNodeStatus().GetPhase())) @@ -327,13 +328,12 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu runningCount := 0 for _, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) - //fmt.Printf("HAMERSAW - node %d phase %d\n", i, nodePhase) switch nodePhase { - case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRecovered: // TODO @hamersaw NodePhaseSkipped? + case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRecovered, v1alpha1.NodePhaseSkipped: successCount++ case v1alpha1.NodePhaseFailing: failingCount++ - case v1alpha1.NodePhaseFailed: // TODO @hamersaw NodePhaseTimedOut? 
+ case v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseTimedOut:
 failedCount++
 default:
 runningCount++
@@ -348,10 +348,10 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 minSuccesses = int(math.Ceil(float64(*minSuccessRatio) * float64(minSuccesses)))
 }

- // if there is a failing node set the error message
+ // if there is a failing node set the error message if it has not been previously set
 if failingCount > 0 && arrayNodeState.Error == nil {
 arrayNodeState.Error = &idlcore.ExecutionError{
- Message: messageCollector.Summary(512), // TODO @hamersaw - make configurable
+ Message: messageCollector.Summary(events.MaxErrorMessageLength),
 }
 }
@@ -418,8 +418,8 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 if err != nil {
 return handler.UnknownTransition, err
 } else if executionErr != nil {
- // TODO @hamersaw handle executionErr
- //return handler.UnknownTransition, executionErr
+ return handler.UnknownTransition, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(),
+ "execution error ArrayNode output, bad state: %s", executionErr.String())
 }

 // copy individual subNode output literals into a collection of output literals
@@ -433,7 +433,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 Literals: outputLiterals,
 }

- //fmt.Printf("HAMERSAW - final outputs %+v\n", outputLiteralMap)
 outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir())
 if err := nCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputLiteralMap); err != nil {
 return handler.UnknownTransition, err
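The hunk that follows derives the reported task phase from the ArrayNode phase and starts maintaining a task phase version, because flyteadmin treats an event whose (phase, phase version) pair has not advanced as a duplicate. The version bookkeeping in isolation:

    package main

    import "fmt"

    // nextPhaseVersion mirrors the rule below: reset on a phase transition,
    // otherwise bump so a repeated event in the same phase is still processed.
    func nextPhaseVersion(phaseChanged bool, current uint32) uint32 {
    	if phaseChanged {
    		return 0
    	}
    	return current + 1
    }

    func main() {
    	v := uint32(0)
    	v = nextPhaseVersion(false, v) // same phase, new external resource info: 1
    	v = nextPhaseVersion(false, v) // 2
    	v = nextPhaseVersion(true, v)  // phase changed: back to 0
    	fmt.Println(v)
    }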
+ if currentArrayNodePhase != arrayNodeState.Phase {
+ arrayNodeState.TaskPhaseVersion = 0
+ } else {
+ arrayNodeState.TaskPhaseVersion = taskPhaseVersion + 1
+ }
+
+ taskExecutionEvent, err := buildTaskExecutionEvent(ctx, nCtx, taskPhase, taskPhaseVersion, externalResources)
 if err != nil {
 return handler.UnknownTransition, err
 }

- // TODO @hamersaw - pass eventConfig correctly
- if err := nCtx.EventsRecorder().RecordTaskEvent(ctx, taskExecutionEvent, &config.EventConfig{}); err != nil {
+ if err := nCtx.EventsRecorder().RecordTaskEvent(ctx, taskExecutionEvent, a.eventConfig); err != nil {
 logger.Errorf(ctx, "ArrayNode event recording failed: [%s]", err.Error())
 return handler.UnknownTransition, err
 }
-
- // TODO @hamersaw - only need to increment if arrayNodeState.Phase does not change
- // if it does we can reset to 0
- arrayNodeState.TaskPhaseVersion = taskPhaseVersion+1
 }

 // update array node status
@@ -485,7 +501,7 @@ func (a *arrayNodeHandler) Setup(_ context.Context, _ handler.SetupContext) erro
 }

 // New initializes a new arrayNodeHandler
-func New(nodeExecutor interfaces.Node, scope promutils.Scope) (handler.Node, error) {
+func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope promutils.Scope) (handler.Node, error) {
 // create k8s PluginState byte mocks to reuse instead of creating for each subNode evaluation
 pluginStateBytesNotStarted, err := bytesFromK8sPluginState(k8s.PluginState{Phase: k8s.PluginPhaseNotStarted})
 if err != nil {
@@ -499,6 +515,7 @@ func New(nodeExecutor interfaces.Node, scope promutils.Scope) (handler.Node, err
 arrayScope := scope.NewSubScope("array")

 return &arrayNodeHandler{
+ eventConfig: eventConfig,
 metrics: newMetrics(arrayScope),
 nodeExecutor: nodeExecutor,
 pluginStateBytesNotStarted: pluginStateBytesNotStarted,
@@ -522,9 +539,8 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter
 // if node has not yet started we automatically set to NodePhaseQueued to skip input resolution
 if nodePhase == v1alpha1.NodePhaseNotYetStarted {
- // TODO @hamersaw how does this work with fastcache?
- // to supprt fastcache we'll need to override the bindings to BindingScalars for the input resolution on the nCtx
- // that way we resolution is just reading a literal ... but does this still write a file then?!?
+ // TODO - to support fastcache we'll need to override the bindings to BindingScalars for the input resolution on the nCtx
+ // that way resolution is just reading a literal ... but does this still write a file then?!?
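+ // a rough sketch of that idea (hypothetical, nothing below implements it):
+ //   literal := inputLiteralMap.Literals[name] // subNode input resolved up front
+ //   binding.Binding = &idlcore.BindingData{Value: &idlcore.BindingData_Scalar{Scalar: literal.GetScalar()}}
+ // input resolution would then degenerate to reading the embedded literal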
 nodePhase = v1alpha1.NodePhaseQueued
 }
@@ -536,10 +552,9 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter
 subNodeSpec.Name = subNodeID

 // TODO - if we want to support more plugin types we need to figure out the best way to store plugin state
- // currently just mocking based on node phase -> which works for all k8s plugins
+ // currently just mocking based on node phase -> which works for all k8s plugins
 // we cannot pre-allocate a bit array because the max size is 256B and with 5k fanout the node state = 1.28MB
 pluginStateBytes := a.pluginStateBytesStarted
- //if nodePhase == v1alpha1.NodePhaseQueued || nodePhase == v1alpha1.NodePhaseRetryableFailure {
 if taskPhase == int(core.PhaseUndefined) || taskPhase == int(core.PhaseRetryableFailure) {
 pluginStateBytes = a.pluginStateBytesNotStarted
 }
@@ -554,9 +569,6 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter
 } else {
 subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx, strconv.Itoa(i))
 }*/
- // TODO @hamersaw - this is a problem because cache lookups happen in NodePhaseQueued
- // so the cache hit items will be written to the wrong location
- // can we just change flytekit appending the index onto the location?!?1
 currentAttempt := uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(subNodeIndex))
 subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(subNodeIndex), strconv.Itoa(int(currentAttempt)))
 if err != nil {
@@ -612,15 +624,15 @@ func buildTaskExecutionEvent(ctx context.Context, nCtx interfaces.NodeExecutionC
 return nil, err
 }

- nodeExecutionId := nCtx.NodeExecutionMetadata().GetNodeExecutionID()
- workflowExecutionId := nodeExecutionId.ExecutionId
+ nodeExecutionID := nCtx.NodeExecutionMetadata().GetNodeExecutionID()
+ workflowExecutionID := nodeExecutionID.ExecutionId
 return &event.TaskExecutionEvent{
 TaskId: &idlcore.Identifier{
 ResourceType: idlcore.ResourceType_TASK,
- Project: workflowExecutionId.Project,
- Domain: workflowExecutionId.Domain,
+ Project: workflowExecutionID.Project,
+ Domain: workflowExecutionID.Domain,
 Name: nCtx.NodeID(),
- Version: "v1", // TODO @hamersaw - please
+ Version: "v1", // this value is irrelevant but necessary for the identifier to be valid
 },
 ParentNodeExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID(),
 RetryAttempt: 0, // ArrayNode will never retry
@@ -640,7 +652,6 @@ func buildSubNodeID(nCtx interfaces.NodeExecutionContext, index int, retryAttemp
 return fmt.Sprintf("%s-n%d-%d", nCtx.NodeID, index, retryAttempt)
 }
-
 func bytesFromK8sPluginState(pluginState k8s.PluginState) ([]byte, error) {
 buffer := make([]byte, 0, task.MaxPluginStateSizeBytes)
 bufferWriter := bytes.NewBuffer(buffer)
@@ -666,3 +677,8 @@ func constructOutputReferences(ctx context.Context, nCtx interfaces.NodeExecutio
 return subDataDir, subOutputDir, nil
 }
+
+func isTerminalNodePhase(nodePhase v1alpha1.NodePhase) bool {
+ return nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseFailed || nodePhase == v1alpha1.NodePhaseTimedOut ||
+ nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered
+}
diff --git a/pkg/controller/nodes/handler_factory.go b/pkg/controller/nodes/handler_factory.go
index 28ca7e3d0..b3ed8b95e 100644
--- a/pkg/controller/nodes/handler_factory.go
+++ b/pkg/controller/nodes/handler_factory.go
@@ -65,7 +65,7 @@ func NewHandlerFactory(ctx context.Context, executor interfaces.Node, workflowLa
 return nil, err
 }

- arrayHandler, err :=
array.New(executor, scope) + arrayHandler, err := array.New(executor, eventConfig, scope) if err != nil { return nil, err } From 3609dd7427173db8edf77ad96555aa89d546f763 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Mon, 12 Jun 2023 14:59:16 -0500 Subject: [PATCH 34/62] refactored into new files Signed-off-by: Daniel Rammer --- .../nodes/array/execution_context.go | 154 ------------------ pkg/controller/nodes/array/handler.go | 99 +---------- pkg/controller/nodes/array/input_reader.go | 42 ----- .../nodes/array/node_execution_context.go | 151 +++++++++++++++++ .../array/node_execution_context_builder.go | 55 +++++++ pkg/controller/nodes/array/utils.go | 104 ++++++++++++ 6 files changed, 315 insertions(+), 290 deletions(-) delete mode 100644 pkg/controller/nodes/array/input_reader.go create mode 100644 pkg/controller/nodes/array/node_execution_context.go create mode 100644 pkg/controller/nodes/array/node_execution_context_builder.go create mode 100644 pkg/controller/nodes/array/utils.go diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go index fb8f54590..b5acd384f 100644 --- a/pkg/controller/nodes/array/execution_context.go +++ b/pkg/controller/nodes/array/execution_context.go @@ -1,18 +1,10 @@ package array import ( - "context" "strconv" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" ) const ( @@ -55,149 +47,3 @@ func newArrayExecutionContext(executionContext executors.ExecutionContext, subNo currentParallelism: currentParallelism, } } - -type arrayEventRecorder struct { - nodeEvents []*event.NodeExecutionEvent - taskEvents []*event.TaskExecutionEvent -} - -func (a *arrayEventRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { - a.nodeEvents = append(a.nodeEvents, event) - return nil -} - -func (a *arrayEventRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { - a.taskEvents = append(a.taskEvents, event) - return nil -} - -func (a *arrayEventRecorder) NodeEvents() []*event.NodeExecutionEvent { - return a.nodeEvents -} - -func (a *arrayEventRecorder) TaskEvents() []*event.TaskExecutionEvent { - return a.taskEvents -} - -func newArrayEventRecorder() *arrayEventRecorder { - return &arrayEventRecorder{ - nodeEvents: make([]*event.NodeExecutionEvent, 0), - taskEvents: make([]*event.TaskExecutionEvent, 0), - } -} - -type arrayTaskReader struct { - interfaces.TaskReader -} - -func (a *arrayTaskReader) Read(ctx context.Context) (*core.TaskTemplate, error) { - taskTemplate, err := a.TaskReader.Read(ctx) - if err != nil { - return nil, err - } - - // convert output list variable to singular - outputVariables := make(map[string]*core.Variable) - for key, value := range taskTemplate.Interface.Outputs.Variables { - switch v := value.Type.Type.(type) { - case *core.LiteralType_CollectionType: - outputVariables[key] = &core.Variable{ - Type: v.CollectionType, - Description: value.Description, - } - default: - outputVariables[key] = value - } - } - - taskTemplate.Interface.Outputs = &core.VariableMap{ - Variables: 
outputVariables, - } - return taskTemplate, nil -} - - -type arrayNodeExecutionContext struct { - interfaces.NodeExecutionContext - eventRecorder interfaces.EventRecorder - executionContext *arrayExecutionContext - inputReader io.InputReader - nodeStatus *v1alpha1.NodeStatus - taskReader interfaces.TaskReader -} - -func (a *arrayNodeExecutionContext) EventsRecorder() interfaces.EventRecorder { - return a.eventRecorder -} - -func (a *arrayNodeExecutionContext) ExecutionContext() executors.ExecutionContext { - return a.executionContext -} - -func (a *arrayNodeExecutionContext) InputReader() io.InputReader { - return a.inputReader -} - -func (a *arrayNodeExecutionContext) NodeStatus() v1alpha1.ExecutableNodeStatus { - return a.nodeStatus -} - -func (a *arrayNodeExecutionContext) TaskReader() interfaces.TaskReader { - return a.taskReader -} - -func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, eventRecorder interfaces.EventRecorder, subNodeIndex int, nodeStatus *v1alpha1.NodeStatus, currentParallelism *uint32, maxParallelism uint32) *arrayNodeExecutionContext { - arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex, currentParallelism, maxParallelism) - return &arrayNodeExecutionContext{ - NodeExecutionContext: nodeExecutionContext, - eventRecorder: eventRecorder, - executionContext: arrayExecutionContext, - inputReader: inputReader, - nodeStatus: nodeStatus, - taskReader: &arrayTaskReader{nodeExecutionContext.TaskReader()}, - } -} - -type arrayNodeExecutionContextBuilder struct { - nCtxBuilder interfaces.NodeExecutionContextBuilder - subNodeID v1alpha1.NodeID - subNodeIndex int - subNodeStatus *v1alpha1.NodeStatus - inputReader io.InputReader - currentParallelism *uint32 - maxParallelism uint32 - eventRecorder interfaces.EventRecorder -} - -func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, - nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (interfaces.NodeExecutionContext, error) { - - // create base NodeExecutionContext - nCtx, err := a.nCtxBuilder.BuildNodeExecutionContext(ctx, executionContext, nl, currentNodeID) - if err != nil { - return nil, err - } - - if currentNodeID == a.subNodeID { - // overwrite NodeExecutionContext for ArrayNode execution - nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.eventRecorder, a.subNodeIndex, a.subNodeStatus, a.currentParallelism, a.maxParallelism) - } - - return nCtx, nil -} - -func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder, subNodeID v1alpha1.NodeID, - subNodeIndex int, subNodeStatus *v1alpha1.NodeStatus, inputReader io.InputReader, eventRecorder interfaces.EventRecorder, - currentParallelism *uint32, maxParallelism uint32) interfaces.NodeExecutionContextBuilder { - - return &arrayNodeExecutionContextBuilder{ - nCtxBuilder: nCtxBuilder, - subNodeID: subNodeID, - subNodeIndex: subNodeIndex, - subNodeStatus: subNodeStatus, - inputReader: inputReader, - currentParallelism: currentParallelism, - maxParallelism: maxParallelism, - eventRecorder: eventRecorder, - } -} diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 074c4bc7b..9c206e3de 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -1,12 +1,10 @@ package array import ( - "bytes" "context" "fmt" "math" "strconv" - "time" idlcore 
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" @@ -24,15 +22,12 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/codex" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/k8s" "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/storage" - - "github.com/golang/protobuf/ptypes" ) var ( @@ -112,7 +107,7 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut return fmt.Errorf(messageCollector.Summary(events.MaxErrorMessageLength)) } - // update aborted state for subnodes + // update aborted state for subNodes taskExecutionEvent, err := buildTaskExecutionEvent(ctx, nCtx, idlcore.TaskExecution_ABORTED, 0, externalResources) if err != nil { return err @@ -248,7 +243,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // transition ArrayNode to `ArrayNodePhaseExecuting` arrayNodeState.Phase = v1alpha1.ArrayNodePhaseExecuting case v1alpha1.ArrayNodePhaseExecuting: - // process array node subnodes + // process array node subNodes currentParallelism := uint32(0) messageCollector := errorcollector.NewErrorMessageCollector() externalResources = make([]*event.ExternalResourceInfo, 0) @@ -295,8 +290,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } for _, log := range taskExecutionEvent.Logs { - // TODO @hamersaw - do we need to add retryAttempt to log name? - log.Name = fmt.Sprintf("%s-%d", log.Name, i) + log.Name = fmt.Sprintf("%s-%d", log.Name, i) // TODO @hamersaw - do we need to add retryAttempt to log name? } externalResources = append(externalResources, &event.ExternalResourceInfo{ @@ -449,7 +443,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu return handler.UnknownTransition, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "invalid ArrayNode phase %+v", arrayNodeState.Phase) } - // if there were changes to subnode status externalResources will be populated and must be + // if there were changes to subNode status externalResources will be populated and must be // reported to admin through a TaskExecutionEvent. if len(externalResources) > 0 { // determine task phase from ArrayNodePhase @@ -468,7 +462,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // need to increment taskPhaseVersion if arrayNodeState.Phase does not change, otherwise // reset to 0. by incrementing this always we report an event and ensure processing // everytime the ArrayNode is evaluated. if this overhead becomes too large, we will need - // to revisit and only increment when any subnode state changes. + // to revisit and only increment when any subNode state changes. 
if currentArrayNodePhase != arrayNodeState.Phase { arrayNodeState.TaskPhaseVersion = 0 } else { @@ -599,86 +593,3 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter return arrayNodeExecutor, arrayExecutionContext, &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec, subNodeStatus, arrayEventRecorder, nil } - -func appendLiteral(name string, literal *idlcore.Literal, outputLiterals map[string]*idlcore.Literal, length int) { - outputLiteral, exists := outputLiterals[name] - if !exists { - outputLiteral = &idlcore.Literal{ - Value: &idlcore.Literal_Collection{ - Collection: &idlcore.LiteralCollection{ - Literals: make([]*idlcore.Literal, 0, length), - }, - }, - } - - outputLiterals[name] = outputLiteral - } - - collection := outputLiteral.GetCollection() - collection.Literals = append(collection.Literals, literal) -} - -func buildTaskExecutionEvent(ctx context.Context, nCtx interfaces.NodeExecutionContext, taskPhase idlcore.TaskExecution_Phase, taskPhaseVersion uint32, externalResources []*event.ExternalResourceInfo) (*event.TaskExecutionEvent, error) { - occurredAt, err := ptypes.TimestampProto(time.Now()) - if err != nil { - return nil, err - } - - nodeExecutionID := nCtx.NodeExecutionMetadata().GetNodeExecutionID() - workflowExecutionID := nodeExecutionID.ExecutionId - return &event.TaskExecutionEvent{ - TaskId: &idlcore.Identifier{ - ResourceType: idlcore.ResourceType_TASK, - Project: workflowExecutionID.Project, - Domain: workflowExecutionID.Domain, - Name: nCtx.NodeID(), - Version: "v1", // this value is irrelevant but necessary for the identifier to be valid - }, - ParentNodeExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), - RetryAttempt: 0, // ArrayNode will never retry - Phase: taskPhase, - PhaseVersion: taskPhaseVersion, - OccurredAt: occurredAt, - Metadata: &event.TaskExecutionMetadata{ - ExternalResources: externalResources, - }, - TaskType: "k8s-array", - EventVersion: 1, - }, nil -} - -// TODO @hamersaw - what do we want for a subnode ID? -func buildSubNodeID(nCtx interfaces.NodeExecutionContext, index int, retryAttempt uint32) string { - return fmt.Sprintf("%s-n%d-%d", nCtx.NodeID, index, retryAttempt) -} - -func bytesFromK8sPluginState(pluginState k8s.PluginState) ([]byte, error) { - buffer := make([]byte, 0, task.MaxPluginStateSizeBytes) - bufferWriter := bytes.NewBuffer(buffer) - - codec := codex.GobStateCodec{} - if err := codec.Encode(pluginState, bufferWriter); err != nil { - return nil, err - } - - return bufferWriter.Bytes(), nil -} - -func constructOutputReferences(ctx context.Context, nCtx interfaces.NodeExecutionContext, postfix...string) (storage.DataReference, storage.DataReference, error) { - subDataDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetDataDir(), postfix...) - if err != nil { - return "", "", err - } - - subOutputDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), postfix...) 
- if err != nil { - return "", "", err - } - - return subDataDir, subOutputDir, nil -} - -func isTerminalNodePhase(nodePhase v1alpha1.NodePhase) bool { - return nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseFailed || nodePhase == v1alpha1.NodePhaseTimedOut || - nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered -} diff --git a/pkg/controller/nodes/array/input_reader.go b/pkg/controller/nodes/array/input_reader.go deleted file mode 100644 index 4059db95d..000000000 --- a/pkg/controller/nodes/array/input_reader.go +++ /dev/null @@ -1,42 +0,0 @@ -package array - -import ( - "context" - - idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" -) - -type staticInputReader struct { - io.InputFilePaths - input *idlcore.LiteralMap -} - -func (i staticInputReader) Get(_ context.Context) (*idlcore.LiteralMap, error) { - return i.input, nil -} - -func newStaticInputReader(inputPaths io.InputFilePaths, input *idlcore.LiteralMap) staticInputReader { - return staticInputReader{ - InputFilePaths: inputPaths, - input: input, - } -} - -func constructLiteralMap(ctx context.Context, inputReader io.InputReader, index int) (*idlcore.LiteralMap, error) { - inputs, err := inputReader.Get(ctx) - if err != nil { - return nil, err - } - - literals := make(map[string]*idlcore.Literal) - for name, literal := range inputs.Literals { - if literalCollection := literal.GetCollection(); literalCollection != nil { - literals[name] = literalCollection.Literals[index] - } - } - - return &idlcore.LiteralMap{ - Literals: literals, - }, nil -} diff --git a/pkg/controller/nodes/array/node_execution_context.go b/pkg/controller/nodes/array/node_execution_context.go new file mode 100644 index 000000000..2bd0005ea --- /dev/null +++ b/pkg/controller/nodes/array/node_execution_context.go @@ -0,0 +1,151 @@ +package array + +import ( + "context" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" +) + +type arrayEventRecorder struct { + nodeEvents []*event.NodeExecutionEvent + taskEvents []*event.TaskExecutionEvent +} + +func (a *arrayEventRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { + a.nodeEvents = append(a.nodeEvents, event) + return nil +} + +func (a *arrayEventRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { + a.taskEvents = append(a.taskEvents, event) + return nil +} + +func (a *arrayEventRecorder) NodeEvents() []*event.NodeExecutionEvent { + return a.nodeEvents +} + +func (a *arrayEventRecorder) TaskEvents() []*event.TaskExecutionEvent { + return a.taskEvents +} + +func newArrayEventRecorder() *arrayEventRecorder { + return &arrayEventRecorder{ + nodeEvents: make([]*event.NodeExecutionEvent, 0), + taskEvents: make([]*event.TaskExecutionEvent, 0), + } +} + +type staticInputReader struct { + io.InputFilePaths + input *core.LiteralMap +} + +func (i staticInputReader) Get(_ context.Context) (*core.LiteralMap, error) { + return i.input, nil +} + +func 
newStaticInputReader(inputPaths io.InputFilePaths, input *core.LiteralMap) staticInputReader { + return staticInputReader{ + InputFilePaths: inputPaths, + input: input, + } +} + +func constructLiteralMap(ctx context.Context, inputReader io.InputReader, index int) (*core.LiteralMap, error) { + inputs, err := inputReader.Get(ctx) + if err != nil { + return nil, err + } + + literals := make(map[string]*core.Literal) + for name, literal := range inputs.Literals { + if literalCollection := literal.GetCollection(); literalCollection != nil { + literals[name] = literalCollection.Literals[index] + } + } + + return &core.LiteralMap{ + Literals: literals, + }, nil +} + +type arrayTaskReader struct { + interfaces.TaskReader +} + +func (a *arrayTaskReader) Read(ctx context.Context) (*core.TaskTemplate, error) { + taskTemplate, err := a.TaskReader.Read(ctx) + if err != nil { + return nil, err + } + + // convert output list variable to singular + outputVariables := make(map[string]*core.Variable) + for key, value := range taskTemplate.Interface.Outputs.Variables { + switch v := value.Type.Type.(type) { + case *core.LiteralType_CollectionType: + outputVariables[key] = &core.Variable{ + Type: v.CollectionType, + Description: value.Description, + } + default: + outputVariables[key] = value + } + } + + taskTemplate.Interface.Outputs = &core.VariableMap{ + Variables: outputVariables, + } + return taskTemplate, nil +} + + +type arrayNodeExecutionContext struct { + interfaces.NodeExecutionContext + eventRecorder interfaces.EventRecorder + executionContext executors.ExecutionContext + inputReader io.InputReader + nodeStatus *v1alpha1.NodeStatus + taskReader interfaces.TaskReader +} + +func (a *arrayNodeExecutionContext) EventsRecorder() interfaces.EventRecorder { + return a.eventRecorder +} + +func (a *arrayNodeExecutionContext) ExecutionContext() executors.ExecutionContext { + return a.executionContext +} + +func (a *arrayNodeExecutionContext) InputReader() io.InputReader { + return a.inputReader +} + +func (a *arrayNodeExecutionContext) NodeStatus() v1alpha1.ExecutableNodeStatus { + return a.nodeStatus +} + +func (a *arrayNodeExecutionContext) TaskReader() interfaces.TaskReader { + return a.taskReader +} + +func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, eventRecorder interfaces.EventRecorder, subNodeIndex int, nodeStatus *v1alpha1.NodeStatus, currentParallelism *uint32, maxParallelism uint32) *arrayNodeExecutionContext { + arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex, currentParallelism, maxParallelism) + return &arrayNodeExecutionContext{ + NodeExecutionContext: nodeExecutionContext, + eventRecorder: eventRecorder, + executionContext: arrayExecutionContext, + inputReader: inputReader, + nodeStatus: nodeStatus, + taskReader: &arrayTaskReader{nodeExecutionContext.TaskReader()}, + } +} diff --git a/pkg/controller/nodes/array/node_execution_context_builder.go b/pkg/controller/nodes/array/node_execution_context_builder.go new file mode 100644 index 000000000..e8367e2eb --- /dev/null +++ b/pkg/controller/nodes/array/node_execution_context_builder.go @@ -0,0 +1,55 @@ +package array + +import ( + "context" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" +) + +type 
arrayNodeExecutionContextBuilder struct { + nCtxBuilder interfaces.NodeExecutionContextBuilder + subNodeID v1alpha1.NodeID + subNodeIndex int + subNodeStatus *v1alpha1.NodeStatus + inputReader io.InputReader + currentParallelism *uint32 + maxParallelism uint32 + eventRecorder interfaces.EventRecorder +} + +func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, + nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (interfaces.NodeExecutionContext, error) { + + // create base NodeExecutionContext + nCtx, err := a.nCtxBuilder.BuildNodeExecutionContext(ctx, executionContext, nl, currentNodeID) + if err != nil { + return nil, err + } + + if currentNodeID == a.subNodeID { + // overwrite NodeExecutionContext for ArrayNode execution + nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.eventRecorder, a.subNodeIndex, a.subNodeStatus, a.currentParallelism, a.maxParallelism) + } + + return nCtx, nil +} + +func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder, subNodeID v1alpha1.NodeID, + subNodeIndex int, subNodeStatus *v1alpha1.NodeStatus, inputReader io.InputReader, eventRecorder interfaces.EventRecorder, + currentParallelism *uint32, maxParallelism uint32) interfaces.NodeExecutionContextBuilder { + + return &arrayNodeExecutionContextBuilder{ + nCtxBuilder: nCtxBuilder, + subNodeID: subNodeID, + subNodeIndex: subNodeIndex, + subNodeStatus: subNodeStatus, + inputReader: inputReader, + currentParallelism: currentParallelism, + maxParallelism: maxParallelism, + eventRecorder: eventRecorder, + } +} diff --git a/pkg/controller/nodes/array/utils.go b/pkg/controller/nodes/array/utils.go new file mode 100644 index 000000000..fc8633535 --- /dev/null +++ b/pkg/controller/nodes/array/utils.go @@ -0,0 +1,104 @@ +package array + +import ( + "bytes" + "context" + "fmt" + "time" + + idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/codex" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/k8s" + + "github.com/flyteorg/flytestdlib/storage" + + "github.com/golang/protobuf/ptypes" +) + +func appendLiteral(name string, literal *idlcore.Literal, outputLiterals map[string]*idlcore.Literal, length int) { + outputLiteral, exists := outputLiterals[name] + if !exists { + outputLiteral = &idlcore.Literal{ + Value: &idlcore.Literal_Collection{ + Collection: &idlcore.LiteralCollection{ + Literals: make([]*idlcore.Literal, 0, length), + }, + }, + } + + outputLiterals[name] = outputLiteral + } + + collection := outputLiteral.GetCollection() + collection.Literals = append(collection.Literals, literal) +} + +func buildTaskExecutionEvent(ctx context.Context, nCtx interfaces.NodeExecutionContext, taskPhase idlcore.TaskExecution_Phase, taskPhaseVersion uint32, externalResources []*event.ExternalResourceInfo) (*event.TaskExecutionEvent, error) { + occurredAt, err := ptypes.TimestampProto(time.Now()) + if err != nil { + return nil, err + } + + nodeExecutionID := nCtx.NodeExecutionMetadata().GetNodeExecutionID() + workflowExecutionID := nodeExecutionID.ExecutionId + return &event.TaskExecutionEvent{ + TaskId: &idlcore.Identifier{ + ResourceType: 
idlcore.ResourceType_TASK,
+ Project: workflowExecutionID.Project,
+ Domain: workflowExecutionID.Domain,
+ Name: nCtx.NodeID(),
+ Version: "v1", // this value is irrelevant but necessary for the identifier to be valid
+ },
+ ParentNodeExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID(),
+ RetryAttempt: 0, // ArrayNode will never retry
+ Phase: taskPhase,
+ PhaseVersion: taskPhaseVersion,
+ OccurredAt: occurredAt,
+ Metadata: &event.TaskExecutionMetadata{
+ ExternalResources: externalResources,
+ },
+ TaskType: "k8s-array",
+ EventVersion: 1,
+ }, nil
+}
+
+// TODO @hamersaw - what do we want for a subNode ID?
+func buildSubNodeID(nCtx interfaces.NodeExecutionContext, index int, retryAttempt uint32) string {
+ return fmt.Sprintf("%s-n%d-%d", nCtx.NodeID(), index, retryAttempt)
+}
+
+func bytesFromK8sPluginState(pluginState k8s.PluginState) ([]byte, error) {
+ buffer := make([]byte, 0, task.MaxPluginStateSizeBytes)
+ bufferWriter := bytes.NewBuffer(buffer)
+
+ codec := codex.GobStateCodec{}
+ if err := codec.Encode(pluginState, bufferWriter); err != nil {
+ return nil, err
+ }
+
+ return bufferWriter.Bytes(), nil
+}
+
+func constructOutputReferences(ctx context.Context, nCtx interfaces.NodeExecutionContext, postfix ...string) (storage.DataReference, storage.DataReference, error) {
+ subDataDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetDataDir(), postfix...)
+ if err != nil {
+ return "", "", err
+ }
+
+ subOutputDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), postfix...)
+ if err != nil {
+ return "", "", err
+ }
+
+ return subDataDir, subOutputDir, nil
+}
+
+func isTerminalNodePhase(nodePhase v1alpha1.NodePhase) bool {
+ return nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseFailed || nodePhase == v1alpha1.NodePhaseTimedOut ||
+ nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered
+}

From fdf8d6a2dfcaa3f9b53e4968ad0ac0fd2e4c56fa Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Tue, 13 Jun 2023 09:53:21 -0500
Subject: [PATCH 35/62] refactoring for ArrayNode unit tests

Signed-off-by: Daniel Rammer

---
 .../v1alpha1/mocks/ExecutableArrayNode.go | 147 +++++
 .../mocks/ExecutableArrayNodeStatus.go | 243 +++++++++++
 .../v1alpha1/mocks/ExecutableNode.go | 34 ++
 .../v1alpha1/mocks/ExecutableNodeStatus.go | 73 ++++
 .../v1alpha1/mocks/MutableArrayNodeStatus.go | 310 ++++++++++++++
 .../v1alpha1/mocks/MutableNodeStatus.go | 73 ++++
 pkg/compiler/common/mocks/node.go | 34 ++
 pkg/compiler/common/mocks/node_builder.go | 34 ++
 pkg/controller/nodes/array/handler_test.go | 70 +++
 pkg/controller/nodes/array/utils.go | 4 +-
 .../nodes/dynamic/mocks/task_node_handler.go | 22 +-
 pkg/controller/nodes/handler/mocks/node.go | 22 +-
 .../nodes/handler/mocks/node_executor.go | 122 ++++++
 .../nodes/handler/mocks/node_state_reader.go | 173 --------
 .../nodes/interfaces/mocks/event_recorder.go | 82 ++++
 .../interfaces}/mocks/node.go | 90 +++-
 .../mocks/node_execution_context.go | 53 ++-
 .../mocks/node_execution_context_builder.go | 58 +++
 .../mocks/node_execution_metadata.go | 0
 .../interfaces/mocks/node_state_reader.go | 397 ++++++++++++++++++
 .../mocks/node_state_writer.go | 69 ++-
 .../mocks/task_reader.go | 0
 22 files changed, 1862 insertions(+), 248 deletions(-)
 create mode 100644 pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNode.go
 create mode 100644 pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNodeStatus.go
 create mode 100644
pkg/apis/flyteworkflow/v1alpha1/mocks/MutableArrayNodeStatus.go create mode 100644 pkg/controller/nodes/array/handler_test.go create mode 100644 pkg/controller/nodes/handler/mocks/node_executor.go delete mode 100644 pkg/controller/nodes/handler/mocks/node_state_reader.go create mode 100644 pkg/controller/nodes/interfaces/mocks/event_recorder.go rename pkg/controller/{executors => nodes/interfaces}/mocks/node.go (68%) rename pkg/controller/nodes/{handler => interfaces}/mocks/node_execution_context.go (90%) create mode 100644 pkg/controller/nodes/interfaces/mocks/node_execution_context_builder.go rename pkg/controller/nodes/{handler => interfaces}/mocks/node_execution_metadata.go (100%) create mode 100644 pkg/controller/nodes/interfaces/mocks/node_state_reader.go rename pkg/controller/nodes/{handler => interfaces}/mocks/node_state_writer.go (60%) rename pkg/controller/nodes/{handler => interfaces}/mocks/task_reader.go (100%) diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNode.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNode.go new file mode 100644 index 000000000..fb200ff06 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNode.go @@ -0,0 +1,147 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + v1alpha1 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mock "github.com/stretchr/testify/mock" +) + +// ExecutableArrayNode is an autogenerated mock type for the ExecutableArrayNode type +type ExecutableArrayNode struct { + mock.Mock +} + +type ExecutableArrayNode_GetMinSuccessRatio struct { + *mock.Call +} + +func (_m ExecutableArrayNode_GetMinSuccessRatio) Return(_a0 *float32) *ExecutableArrayNode_GetMinSuccessRatio { + return &ExecutableArrayNode_GetMinSuccessRatio{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNode) OnGetMinSuccessRatio() *ExecutableArrayNode_GetMinSuccessRatio { + c_call := _m.On("GetMinSuccessRatio") + return &ExecutableArrayNode_GetMinSuccessRatio{Call: c_call} +} + +func (_m *ExecutableArrayNode) OnGetMinSuccessRatioMatch(matchers ...interface{}) *ExecutableArrayNode_GetMinSuccessRatio { + c_call := _m.On("GetMinSuccessRatio", matchers...) + return &ExecutableArrayNode_GetMinSuccessRatio{Call: c_call} +} + +// GetMinSuccessRatio provides a mock function with given fields: +func (_m *ExecutableArrayNode) GetMinSuccessRatio() *float32 { + ret := _m.Called() + + var r0 *float32 + if rf, ok := ret.Get(0).(func() *float32); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*float32) + } + } + + return r0 +} + +type ExecutableArrayNode_GetMinSuccesses struct { + *mock.Call +} + +func (_m ExecutableArrayNode_GetMinSuccesses) Return(_a0 *uint32) *ExecutableArrayNode_GetMinSuccesses { + return &ExecutableArrayNode_GetMinSuccesses{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNode) OnGetMinSuccesses() *ExecutableArrayNode_GetMinSuccesses { + c_call := _m.On("GetMinSuccesses") + return &ExecutableArrayNode_GetMinSuccesses{Call: c_call} +} + +func (_m *ExecutableArrayNode) OnGetMinSuccessesMatch(matchers ...interface{}) *ExecutableArrayNode_GetMinSuccesses { + c_call := _m.On("GetMinSuccesses", matchers...) 
+ return &ExecutableArrayNode_GetMinSuccesses{Call: c_call} +} + +// GetMinSuccesses provides a mock function with given fields: +func (_m *ExecutableArrayNode) GetMinSuccesses() *uint32 { + ret := _m.Called() + + var r0 *uint32 + if rf, ok := ret.Get(0).(func() *uint32); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*uint32) + } + } + + return r0 +} + +type ExecutableArrayNode_GetParallelism struct { + *mock.Call +} + +func (_m ExecutableArrayNode_GetParallelism) Return(_a0 uint32) *ExecutableArrayNode_GetParallelism { + return &ExecutableArrayNode_GetParallelism{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNode) OnGetParallelism() *ExecutableArrayNode_GetParallelism { + c_call := _m.On("GetParallelism") + return &ExecutableArrayNode_GetParallelism{Call: c_call} +} + +func (_m *ExecutableArrayNode) OnGetParallelismMatch(matchers ...interface{}) *ExecutableArrayNode_GetParallelism { + c_call := _m.On("GetParallelism", matchers...) + return &ExecutableArrayNode_GetParallelism{Call: c_call} +} + +// GetParallelism provides a mock function with given fields: +func (_m *ExecutableArrayNode) GetParallelism() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +type ExecutableArrayNode_GetSubNodeSpec struct { + *mock.Call +} + +func (_m ExecutableArrayNode_GetSubNodeSpec) Return(_a0 *v1alpha1.NodeSpec) *ExecutableArrayNode_GetSubNodeSpec { + return &ExecutableArrayNode_GetSubNodeSpec{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNode) OnGetSubNodeSpec() *ExecutableArrayNode_GetSubNodeSpec { + c_call := _m.On("GetSubNodeSpec") + return &ExecutableArrayNode_GetSubNodeSpec{Call: c_call} +} + +func (_m *ExecutableArrayNode) OnGetSubNodeSpecMatch(matchers ...interface{}) *ExecutableArrayNode_GetSubNodeSpec { + c_call := _m.On("GetSubNodeSpec", matchers...) + return &ExecutableArrayNode_GetSubNodeSpec{Call: c_call} +} + +// GetSubNodeSpec provides a mock function with given fields: +func (_m *ExecutableArrayNode) GetSubNodeSpec() *v1alpha1.NodeSpec { + ret := _m.Called() + + var r0 *v1alpha1.NodeSpec + if rf, ok := ret.Get(0).(func() *v1alpha1.NodeSpec); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.NodeSpec) + } + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNodeStatus.go new file mode 100644 index 000000000..08de9e29c --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableArrayNodeStatus.go @@ -0,0 +1,243 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. 
+ +package mocks + +import ( + core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + bitarray "github.com/flyteorg/flytestdlib/bitarray" + + mock "github.com/stretchr/testify/mock" + + v1alpha1 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// ExecutableArrayNodeStatus is an autogenerated mock type for the ExecutableArrayNodeStatus type +type ExecutableArrayNodeStatus struct { + mock.Mock +} + +type ExecutableArrayNodeStatus_GetArrayNodePhase struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetArrayNodePhase) Return(_a0 v1alpha1.ArrayNodePhase) *ExecutableArrayNodeStatus_GetArrayNodePhase { + return &ExecutableArrayNodeStatus_GetArrayNodePhase{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetArrayNodePhase() *ExecutableArrayNodeStatus_GetArrayNodePhase { + c_call := _m.On("GetArrayNodePhase") + return &ExecutableArrayNodeStatus_GetArrayNodePhase{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetArrayNodePhaseMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetArrayNodePhase { + c_call := _m.On("GetArrayNodePhase", matchers...) + return &ExecutableArrayNodeStatus_GetArrayNodePhase{Call: c_call} +} + +// GetArrayNodePhase provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetArrayNodePhase() v1alpha1.ArrayNodePhase { + ret := _m.Called() + + var r0 v1alpha1.ArrayNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.ArrayNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.ArrayNodePhase) + } + + return r0 +} + +type ExecutableArrayNodeStatus_GetExecutionError struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetExecutionError) Return(_a0 *core.ExecutionError) *ExecutableArrayNodeStatus_GetExecutionError { + return &ExecutableArrayNodeStatus_GetExecutionError{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetExecutionError() *ExecutableArrayNodeStatus_GetExecutionError { + c_call := _m.On("GetExecutionError") + return &ExecutableArrayNodeStatus_GetExecutionError{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetExecutionErrorMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetExecutionError { + c_call := _m.On("GetExecutionError", matchers...) + return &ExecutableArrayNodeStatus_GetExecutionError{Call: c_call} +} + +// GetExecutionError provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetExecutionError() *core.ExecutionError { + ret := _m.Called() + + var r0 *core.ExecutionError + if rf, ok := ret.Get(0).(func() *core.ExecutionError); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.ExecutionError) + } + } + + return r0 +} + +type ExecutableArrayNodeStatus_GetSubNodePhases struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetSubNodePhases) Return(_a0 bitarray.CompactArray) *ExecutableArrayNodeStatus_GetSubNodePhases { + return &ExecutableArrayNodeStatus_GetSubNodePhases{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodePhases() *ExecutableArrayNodeStatus_GetSubNodePhases { + c_call := _m.On("GetSubNodePhases") + return &ExecutableArrayNodeStatus_GetSubNodePhases{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodePhasesMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetSubNodePhases { + c_call := _m.On("GetSubNodePhases", matchers...) 
+ return &ExecutableArrayNodeStatus_GetSubNodePhases{Call: c_call} +} + +// GetSubNodePhases provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetSubNodePhases() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type ExecutableArrayNodeStatus_GetSubNodeRetryAttempts struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetSubNodeRetryAttempts) Return(_a0 bitarray.CompactArray) *ExecutableArrayNodeStatus_GetSubNodeRetryAttempts { + return &ExecutableArrayNodeStatus_GetSubNodeRetryAttempts{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodeRetryAttempts() *ExecutableArrayNodeStatus_GetSubNodeRetryAttempts { + c_call := _m.On("GetSubNodeRetryAttempts") + return &ExecutableArrayNodeStatus_GetSubNodeRetryAttempts{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodeRetryAttemptsMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetSubNodeRetryAttempts { + c_call := _m.On("GetSubNodeRetryAttempts", matchers...) + return &ExecutableArrayNodeStatus_GetSubNodeRetryAttempts{Call: c_call} +} + +// GetSubNodeRetryAttempts provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetSubNodeRetryAttempts() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type ExecutableArrayNodeStatus_GetSubNodeSystemFailures struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetSubNodeSystemFailures) Return(_a0 bitarray.CompactArray) *ExecutableArrayNodeStatus_GetSubNodeSystemFailures { + return &ExecutableArrayNodeStatus_GetSubNodeSystemFailures{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodeSystemFailures() *ExecutableArrayNodeStatus_GetSubNodeSystemFailures { + c_call := _m.On("GetSubNodeSystemFailures") + return &ExecutableArrayNodeStatus_GetSubNodeSystemFailures{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodeSystemFailuresMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetSubNodeSystemFailures { + c_call := _m.On("GetSubNodeSystemFailures", matchers...) 
+ return &ExecutableArrayNodeStatus_GetSubNodeSystemFailures{Call: c_call} +} + +// GetSubNodeSystemFailures provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetSubNodeSystemFailures() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type ExecutableArrayNodeStatus_GetSubNodeTaskPhases struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetSubNodeTaskPhases) Return(_a0 bitarray.CompactArray) *ExecutableArrayNodeStatus_GetSubNodeTaskPhases { + return &ExecutableArrayNodeStatus_GetSubNodeTaskPhases{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodeTaskPhases() *ExecutableArrayNodeStatus_GetSubNodeTaskPhases { + c_call := _m.On("GetSubNodeTaskPhases") + return &ExecutableArrayNodeStatus_GetSubNodeTaskPhases{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetSubNodeTaskPhasesMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetSubNodeTaskPhases { + c_call := _m.On("GetSubNodeTaskPhases", matchers...) + return &ExecutableArrayNodeStatus_GetSubNodeTaskPhases{Call: c_call} +} + +// GetSubNodeTaskPhases provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetSubNodeTaskPhases() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type ExecutableArrayNodeStatus_GetTaskPhaseVersion struct { + *mock.Call +} + +func (_m ExecutableArrayNodeStatus_GetTaskPhaseVersion) Return(_a0 uint32) *ExecutableArrayNodeStatus_GetTaskPhaseVersion { + return &ExecutableArrayNodeStatus_GetTaskPhaseVersion{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableArrayNodeStatus) OnGetTaskPhaseVersion() *ExecutableArrayNodeStatus_GetTaskPhaseVersion { + c_call := _m.On("GetTaskPhaseVersion") + return &ExecutableArrayNodeStatus_GetTaskPhaseVersion{Call: c_call} +} + +func (_m *ExecutableArrayNodeStatus) OnGetTaskPhaseVersionMatch(matchers ...interface{}) *ExecutableArrayNodeStatus_GetTaskPhaseVersion { + c_call := _m.On("GetTaskPhaseVersion", matchers...) 
+ return &ExecutableArrayNodeStatus_GetTaskPhaseVersion{Call: c_call} +} + +// GetTaskPhaseVersion provides a mock function with given fields: +func (_m *ExecutableArrayNodeStatus) GetTaskPhaseVersion() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go index 5fbd946fa..a6f143207 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go @@ -51,6 +51,40 @@ func (_m *ExecutableNode) GetActiveDeadline() *time.Duration { return r0 } +type ExecutableNode_GetArrayNode struct { + *mock.Call +} + +func (_m ExecutableNode_GetArrayNode) Return(_a0 v1alpha1.ExecutableArrayNode) *ExecutableNode_GetArrayNode { + return &ExecutableNode_GetArrayNode{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableNode) OnGetArrayNode() *ExecutableNode_GetArrayNode { + c_call := _m.On("GetArrayNode") + return &ExecutableNode_GetArrayNode{Call: c_call} +} + +func (_m *ExecutableNode) OnGetArrayNodeMatch(matchers ...interface{}) *ExecutableNode_GetArrayNode { + c_call := _m.On("GetArrayNode", matchers...) + return &ExecutableNode_GetArrayNode{Call: c_call} +} + +// GetArrayNode provides a mock function with given fields: +func (_m *ExecutableNode) GetArrayNode() v1alpha1.ExecutableArrayNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableArrayNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableArrayNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableArrayNode) + } + } + + return r0 +} + type ExecutableNode_GetBranchNode struct { *mock.Call } diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go index 346680cfa..886aa217c 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go @@ -20,6 +20,11 @@ type ExecutableNodeStatus struct { mock.Mock } +// ClearArrayNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ClearArrayNodeStatus() { + _m.Called() +} + // ClearDynamicNodeStatus provides a mock function with given fields: func (_m *ExecutableNodeStatus) ClearDynamicNodeStatus() { _m.Called() @@ -50,6 +55,40 @@ func (_m *ExecutableNodeStatus) ClearWorkflowStatus() { _m.Called() } +type ExecutableNodeStatus_GetArrayNodeStatus struct { + *mock.Call +} + +func (_m ExecutableNodeStatus_GetArrayNodeStatus) Return(_a0 v1alpha1.MutableArrayNodeStatus) *ExecutableNodeStatus_GetArrayNodeStatus { + return &ExecutableNodeStatus_GetArrayNodeStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableNodeStatus) OnGetArrayNodeStatus() *ExecutableNodeStatus_GetArrayNodeStatus { + c_call := _m.On("GetArrayNodeStatus") + return &ExecutableNodeStatus_GetArrayNodeStatus{Call: c_call} +} + +func (_m *ExecutableNodeStatus) OnGetArrayNodeStatusMatch(matchers ...interface{}) *ExecutableNodeStatus_GetArrayNodeStatus { + c_call := _m.On("GetArrayNodeStatus", matchers...) 
+ return &ExecutableNodeStatus_GetArrayNodeStatus{Call: c_call} +} + +// GetArrayNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetArrayNodeStatus() v1alpha1.MutableArrayNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableArrayNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableArrayNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableArrayNodeStatus) + } + } + + return r0 +} + type ExecutableNodeStatus_GetAttempts struct { *mock.Call } @@ -384,6 +423,40 @@ func (_m *ExecutableNodeStatus) GetNodeExecutionStatus(ctx context.Context, id s return r0 } +type ExecutableNodeStatus_GetOrCreateArrayNodeStatus struct { + *mock.Call +} + +func (_m ExecutableNodeStatus_GetOrCreateArrayNodeStatus) Return(_a0 v1alpha1.MutableArrayNodeStatus) *ExecutableNodeStatus_GetOrCreateArrayNodeStatus { + return &ExecutableNodeStatus_GetOrCreateArrayNodeStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableNodeStatus) OnGetOrCreateArrayNodeStatus() *ExecutableNodeStatus_GetOrCreateArrayNodeStatus { + c_call := _m.On("GetOrCreateArrayNodeStatus") + return &ExecutableNodeStatus_GetOrCreateArrayNodeStatus{Call: c_call} +} + +func (_m *ExecutableNodeStatus) OnGetOrCreateArrayNodeStatusMatch(matchers ...interface{}) *ExecutableNodeStatus_GetOrCreateArrayNodeStatus { + c_call := _m.On("GetOrCreateArrayNodeStatus", matchers...) + return &ExecutableNodeStatus_GetOrCreateArrayNodeStatus{Call: c_call} +} + +// GetOrCreateArrayNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetOrCreateArrayNodeStatus() v1alpha1.MutableArrayNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableArrayNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableArrayNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableArrayNodeStatus) + } + } + + return r0 +} + type ExecutableNodeStatus_GetOrCreateBranchStatus struct { *mock.Call } diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableArrayNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableArrayNodeStatus.go new file mode 100644 index 000000000..e052187cc --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableArrayNodeStatus.go @@ -0,0 +1,310 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + bitarray "github.com/flyteorg/flytestdlib/bitarray" + + mock "github.com/stretchr/testify/mock" + + v1alpha1 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// MutableArrayNodeStatus is an autogenerated mock type for the MutableArrayNodeStatus type +type MutableArrayNodeStatus struct { + mock.Mock +} + +type MutableArrayNodeStatus_GetArrayNodePhase struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetArrayNodePhase) Return(_a0 v1alpha1.ArrayNodePhase) *MutableArrayNodeStatus_GetArrayNodePhase { + return &MutableArrayNodeStatus_GetArrayNodePhase{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetArrayNodePhase() *MutableArrayNodeStatus_GetArrayNodePhase { + c_call := _m.On("GetArrayNodePhase") + return &MutableArrayNodeStatus_GetArrayNodePhase{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetArrayNodePhaseMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetArrayNodePhase { + c_call := _m.On("GetArrayNodePhase", matchers...) 
+ return &MutableArrayNodeStatus_GetArrayNodePhase{Call: c_call} +} + +// GetArrayNodePhase provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetArrayNodePhase() v1alpha1.ArrayNodePhase { + ret := _m.Called() + + var r0 v1alpha1.ArrayNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.ArrayNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.ArrayNodePhase) + } + + return r0 +} + +type MutableArrayNodeStatus_GetExecutionError struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetExecutionError) Return(_a0 *core.ExecutionError) *MutableArrayNodeStatus_GetExecutionError { + return &MutableArrayNodeStatus_GetExecutionError{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetExecutionError() *MutableArrayNodeStatus_GetExecutionError { + c_call := _m.On("GetExecutionError") + return &MutableArrayNodeStatus_GetExecutionError{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetExecutionErrorMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetExecutionError { + c_call := _m.On("GetExecutionError", matchers...) + return &MutableArrayNodeStatus_GetExecutionError{Call: c_call} +} + +// GetExecutionError provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetExecutionError() *core.ExecutionError { + ret := _m.Called() + + var r0 *core.ExecutionError + if rf, ok := ret.Get(0).(func() *core.ExecutionError); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.ExecutionError) + } + } + + return r0 +} + +type MutableArrayNodeStatus_GetSubNodePhases struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetSubNodePhases) Return(_a0 bitarray.CompactArray) *MutableArrayNodeStatus_GetSubNodePhases { + return &MutableArrayNodeStatus_GetSubNodePhases{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodePhases() *MutableArrayNodeStatus_GetSubNodePhases { + c_call := _m.On("GetSubNodePhases") + return &MutableArrayNodeStatus_GetSubNodePhases{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodePhasesMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetSubNodePhases { + c_call := _m.On("GetSubNodePhases", matchers...) + return &MutableArrayNodeStatus_GetSubNodePhases{Call: c_call} +} + +// GetSubNodePhases provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetSubNodePhases() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type MutableArrayNodeStatus_GetSubNodeRetryAttempts struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetSubNodeRetryAttempts) Return(_a0 bitarray.CompactArray) *MutableArrayNodeStatus_GetSubNodeRetryAttempts { + return &MutableArrayNodeStatus_GetSubNodeRetryAttempts{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodeRetryAttempts() *MutableArrayNodeStatus_GetSubNodeRetryAttempts { + c_call := _m.On("GetSubNodeRetryAttempts") + return &MutableArrayNodeStatus_GetSubNodeRetryAttempts{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodeRetryAttemptsMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetSubNodeRetryAttempts { + c_call := _m.On("GetSubNodeRetryAttempts", matchers...) 
+ return &MutableArrayNodeStatus_GetSubNodeRetryAttempts{Call: c_call} +} + +// GetSubNodeRetryAttempts provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetSubNodeRetryAttempts() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type MutableArrayNodeStatus_GetSubNodeSystemFailures struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetSubNodeSystemFailures) Return(_a0 bitarray.CompactArray) *MutableArrayNodeStatus_GetSubNodeSystemFailures { + return &MutableArrayNodeStatus_GetSubNodeSystemFailures{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodeSystemFailures() *MutableArrayNodeStatus_GetSubNodeSystemFailures { + c_call := _m.On("GetSubNodeSystemFailures") + return &MutableArrayNodeStatus_GetSubNodeSystemFailures{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodeSystemFailuresMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetSubNodeSystemFailures { + c_call := _m.On("GetSubNodeSystemFailures", matchers...) + return &MutableArrayNodeStatus_GetSubNodeSystemFailures{Call: c_call} +} + +// GetSubNodeSystemFailures provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetSubNodeSystemFailures() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type MutableArrayNodeStatus_GetSubNodeTaskPhases struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetSubNodeTaskPhases) Return(_a0 bitarray.CompactArray) *MutableArrayNodeStatus_GetSubNodeTaskPhases { + return &MutableArrayNodeStatus_GetSubNodeTaskPhases{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodeTaskPhases() *MutableArrayNodeStatus_GetSubNodeTaskPhases { + c_call := _m.On("GetSubNodeTaskPhases") + return &MutableArrayNodeStatus_GetSubNodeTaskPhases{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetSubNodeTaskPhasesMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetSubNodeTaskPhases { + c_call := _m.On("GetSubNodeTaskPhases", matchers...) + return &MutableArrayNodeStatus_GetSubNodeTaskPhases{Call: c_call} +} + +// GetSubNodeTaskPhases provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetSubNodeTaskPhases() bitarray.CompactArray { + ret := _m.Called() + + var r0 bitarray.CompactArray + if rf, ok := ret.Get(0).(func() bitarray.CompactArray); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bitarray.CompactArray) + } + + return r0 +} + +type MutableArrayNodeStatus_GetTaskPhaseVersion struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_GetTaskPhaseVersion) Return(_a0 uint32) *MutableArrayNodeStatus_GetTaskPhaseVersion { + return &MutableArrayNodeStatus_GetTaskPhaseVersion{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnGetTaskPhaseVersion() *MutableArrayNodeStatus_GetTaskPhaseVersion { + c_call := _m.On("GetTaskPhaseVersion") + return &MutableArrayNodeStatus_GetTaskPhaseVersion{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnGetTaskPhaseVersionMatch(matchers ...interface{}) *MutableArrayNodeStatus_GetTaskPhaseVersion { + c_call := _m.On("GetTaskPhaseVersion", matchers...) 
+ return &MutableArrayNodeStatus_GetTaskPhaseVersion{Call: c_call} +} + +// GetTaskPhaseVersion provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) GetTaskPhaseVersion() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +type MutableArrayNodeStatus_IsDirty struct { + *mock.Call +} + +func (_m MutableArrayNodeStatus_IsDirty) Return(_a0 bool) *MutableArrayNodeStatus_IsDirty { + return &MutableArrayNodeStatus_IsDirty{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableArrayNodeStatus) OnIsDirty() *MutableArrayNodeStatus_IsDirty { + c_call := _m.On("IsDirty") + return &MutableArrayNodeStatus_IsDirty{Call: c_call} +} + +func (_m *MutableArrayNodeStatus) OnIsDirtyMatch(matchers ...interface{}) *MutableArrayNodeStatus_IsDirty { + c_call := _m.On("IsDirty", matchers...) + return &MutableArrayNodeStatus_IsDirty{Call: c_call} +} + +// IsDirty provides a mock function with given fields: +func (_m *MutableArrayNodeStatus) IsDirty() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// SetArrayNodePhase provides a mock function with given fields: phase +func (_m *MutableArrayNodeStatus) SetArrayNodePhase(phase v1alpha1.ArrayNodePhase) { + _m.Called(phase) +} + +// SetExecutionError provides a mock function with given fields: executionError +func (_m *MutableArrayNodeStatus) SetExecutionError(executionError *core.ExecutionError) { + _m.Called(executionError) +} + +// SetSubNodePhases provides a mock function with given fields: subNodePhases +func (_m *MutableArrayNodeStatus) SetSubNodePhases(subNodePhases bitarray.CompactArray) { + _m.Called(subNodePhases) +} + +// SetSubNodeRetryAttempts provides a mock function with given fields: subNodeRetryAttempts +func (_m *MutableArrayNodeStatus) SetSubNodeRetryAttempts(subNodeRetryAttempts bitarray.CompactArray) { + _m.Called(subNodeRetryAttempts) +} + +// SetSubNodeSystemFailures provides a mock function with given fields: subNodeSystemFailures +func (_m *MutableArrayNodeStatus) SetSubNodeSystemFailures(subNodeSystemFailures bitarray.CompactArray) { + _m.Called(subNodeSystemFailures) +} + +// SetSubNodeTaskPhases provides a mock function with given fields: subNodeTaskPhases +func (_m *MutableArrayNodeStatus) SetSubNodeTaskPhases(subNodeTaskPhases bitarray.CompactArray) { + _m.Called(subNodeTaskPhases) +} + +// SetTaskPhaseVersion provides a mock function with given fields: taskPhaseVersion +func (_m *MutableArrayNodeStatus) SetTaskPhaseVersion(taskPhaseVersion uint32) { + _m.Called(taskPhaseVersion) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go index 9bb0f59b2..56feb9c1b 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go @@ -18,6 +18,11 @@ type MutableNodeStatus struct { mock.Mock } +// ClearArrayNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) ClearArrayNodeStatus() { + _m.Called() +} + // ClearDynamicNodeStatus provides a mock function with given fields: func (_m *MutableNodeStatus) ClearDynamicNodeStatus() { _m.Called() @@ -48,6 +53,40 @@ func (_m *MutableNodeStatus) ClearWorkflowStatus() { _m.Called() } +type MutableNodeStatus_GetArrayNodeStatus struct { + *mock.Call +} + +func (_m 
MutableNodeStatus_GetArrayNodeStatus) Return(_a0 v1alpha1.MutableArrayNodeStatus) *MutableNodeStatus_GetArrayNodeStatus { + return &MutableNodeStatus_GetArrayNodeStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableNodeStatus) OnGetArrayNodeStatus() *MutableNodeStatus_GetArrayNodeStatus { + c_call := _m.On("GetArrayNodeStatus") + return &MutableNodeStatus_GetArrayNodeStatus{Call: c_call} +} + +func (_m *MutableNodeStatus) OnGetArrayNodeStatusMatch(matchers ...interface{}) *MutableNodeStatus_GetArrayNodeStatus { + c_call := _m.On("GetArrayNodeStatus", matchers...) + return &MutableNodeStatus_GetArrayNodeStatus{Call: c_call} +} + +// GetArrayNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetArrayNodeStatus() v1alpha1.MutableArrayNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableArrayNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableArrayNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableArrayNodeStatus) + } + } + + return r0 +} + type MutableNodeStatus_GetBranchStatus struct { *mock.Call } @@ -150,6 +189,40 @@ func (_m *MutableNodeStatus) GetGateNodeStatus() v1alpha1.MutableGateNodeStatus return r0 } +type MutableNodeStatus_GetOrCreateArrayNodeStatus struct { + *mock.Call +} + +func (_m MutableNodeStatus_GetOrCreateArrayNodeStatus) Return(_a0 v1alpha1.MutableArrayNodeStatus) *MutableNodeStatus_GetOrCreateArrayNodeStatus { + return &MutableNodeStatus_GetOrCreateArrayNodeStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableNodeStatus) OnGetOrCreateArrayNodeStatus() *MutableNodeStatus_GetOrCreateArrayNodeStatus { + c_call := _m.On("GetOrCreateArrayNodeStatus") + return &MutableNodeStatus_GetOrCreateArrayNodeStatus{Call: c_call} +} + +func (_m *MutableNodeStatus) OnGetOrCreateArrayNodeStatusMatch(matchers ...interface{}) *MutableNodeStatus_GetOrCreateArrayNodeStatus { + c_call := _m.On("GetOrCreateArrayNodeStatus", matchers...) + return &MutableNodeStatus_GetOrCreateArrayNodeStatus{Call: c_call} +} + +// GetOrCreateArrayNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetOrCreateArrayNodeStatus() v1alpha1.MutableArrayNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableArrayNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableArrayNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableArrayNodeStatus) + } + } + + return r0 +} + type MutableNodeStatus_GetOrCreateBranchStatus struct { *mock.Call } diff --git a/pkg/compiler/common/mocks/node.go b/pkg/compiler/common/mocks/node.go index 364a1921d..ea9a24df1 100644 --- a/pkg/compiler/common/mocks/node.go +++ b/pkg/compiler/common/mocks/node.go @@ -14,6 +14,40 @@ type Node struct { mock.Mock } +type Node_GetArrayNode struct { + *mock.Call +} + +func (_m Node_GetArrayNode) Return(_a0 *core.ArrayNode) *Node_GetArrayNode { + return &Node_GetArrayNode{Call: _m.Call.Return(_a0)} +} + +func (_m *Node) OnGetArrayNode() *Node_GetArrayNode { + c_call := _m.On("GetArrayNode") + return &Node_GetArrayNode{Call: c_call} +} + +func (_m *Node) OnGetArrayNodeMatch(matchers ...interface{}) *Node_GetArrayNode { + c_call := _m.On("GetArrayNode", matchers...) 
+	return &Node_GetArrayNode{Call: c_call}
+}
+
+// GetArrayNode provides a mock function with given fields:
+func (_m *Node) GetArrayNode() *core.ArrayNode {
+	ret := _m.Called()
+
+	var r0 *core.ArrayNode
+	if rf, ok := ret.Get(0).(func() *core.ArrayNode); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*core.ArrayNode)
+		}
+	}
+
+	return r0
+}
+
 type Node_GetBranchNode struct {
 	*mock.Call
 }
diff --git a/pkg/compiler/common/mocks/node_builder.go b/pkg/compiler/common/mocks/node_builder.go
index 44b320dc9..9ab750130 100644
--- a/pkg/compiler/common/mocks/node_builder.go
+++ b/pkg/compiler/common/mocks/node_builder.go
@@ -14,6 +14,40 @@ type NodeBuilder struct {
 	mock.Mock
 }
 
+type NodeBuilder_GetArrayNode struct {
+	*mock.Call
+}
+
+func (_m NodeBuilder_GetArrayNode) Return(_a0 *core.ArrayNode) *NodeBuilder_GetArrayNode {
+	return &NodeBuilder_GetArrayNode{Call: _m.Call.Return(_a0)}
+}
+
+func (_m *NodeBuilder) OnGetArrayNode() *NodeBuilder_GetArrayNode {
+	c_call := _m.On("GetArrayNode")
+	return &NodeBuilder_GetArrayNode{Call: c_call}
+}
+
+func (_m *NodeBuilder) OnGetArrayNodeMatch(matchers ...interface{}) *NodeBuilder_GetArrayNode {
+	c_call := _m.On("GetArrayNode", matchers...)
+	return &NodeBuilder_GetArrayNode{Call: c_call}
+}
+
+// GetArrayNode provides a mock function with given fields:
+func (_m *NodeBuilder) GetArrayNode() *core.ArrayNode {
+	ret := _m.Called()
+
+	var r0 *core.ArrayNode
+	if rf, ok := ret.Get(0).(func() *core.ArrayNode); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*core.ArrayNode)
+		}
+	}
+
+	return r0
+}
+
 type NodeBuilder_GetBranchNode struct {
 	*mock.Call
 }
diff --git a/pkg/controller/nodes/array/handler_test.go b/pkg/controller/nodes/array/handler_test.go
new file mode 100644
index 000000000..a268a1420
--- /dev/null
+++ b/pkg/controller/nodes/array/handler_test.go
@@ -0,0 +1,70 @@
+package array
+
+import (
+	"context"
+	"testing"
+
+	eventmocks "github.com/flyteorg/flytepropeller/events/mocks"
+	"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
+	execmocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks"
+	"github.com/flyteorg/flytepropeller/pkg/controller/config"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes"
+	gatemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
+	recoverymocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog"
+
+	"github.com/flyteorg/flytestdlib/promutils"
+	"github.com/flyteorg/flytestdlib/storage"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func createArrayNodeExecutor(t *testing.T, ctx context.Context, scope promutils.Scope) (handler.Node, error) {
+	// mock components
+	adminClient := launchplan.NewFailFastLaunchPlanExecutor()
+	dataStore, err := storage.NewDataStore(&storage.Config{
+		Type: storage.TypeMemory,
+	}, scope)
+	assert.NoError(t, err)
+	enqueueWorkflowFunc := func(workflowID v1alpha1.WorkflowID) {}
+	eventConfig := &config.EventConfig{}
+	mockEventSink := eventmocks.NewMockEventSink()
+	mockKubeClient := execmocks.NewFakeKubeClient()
+	mockRecoveryClient := &recoverymocks.Client{}
+	mockSignalClient := &gatemocks.SignalServiceClient{}
+	noopCatalogClient := catalog.NOOPCatalog{}
+
+	// create node executor
+ nodeExecutor, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, dataStore, enqueueWorkflowFunc, mockEventSink, adminClient, + adminClient, 10, "s3://bucket/", mockKubeClient, noopCatalogClient, mockRecoveryClient, eventConfig, "clusterID", mockSignalClient, scope) + assert.NoError(t, err) + + return New(nodeExecutor, eventConfig, scope) +} + +func TestAbort(t *testing.T) { + // TODO @hamersaw - complete +} + +func TestFinalize(t *testing.T) { + // TODO @hamersaw - complete +} + +func TestHandleArrayNodePhaseNone(t *testing.T) { + ctx := context.Background() + // TODO @hamersaw - complete +} + +func TestHandleArrayNodePhaseExecuting(t *testing.T) { + // TODO @hamersaw - complete +} + +func TestHandleArrayNodePhaseSucceeding(t *testing.T) { + // TODO @hamersaw - complete +} + +func TestHandleArrayNodePhaseFailing(t *testing.T) { + // TODO @hamersaw - complete +} diff --git a/pkg/controller/nodes/array/utils.go b/pkg/controller/nodes/array/utils.go index fc8633535..ad4f80837 100644 --- a/pkg/controller/nodes/array/utils.go +++ b/pkg/controller/nodes/array/utils.go @@ -67,9 +67,9 @@ func buildTaskExecutionEvent(ctx context.Context, nCtx interfaces.NodeExecutionC }, nil } -// TODO @hamersaw - what do we want for a subNode ID? func buildSubNodeID(nCtx interfaces.NodeExecutionContext, index int, retryAttempt uint32) string { - return fmt.Sprintf("%s-n%d-%d", nCtx.NodeID, index, retryAttempt) + // TODO @hamersaw - what do we want for a subNode ID? + return fmt.Sprintf("%s-n%d-%d", nCtx.NodeID(), index, retryAttempt) } func bytesFromK8sPluginState(pluginState k8s.PluginState) ([]byte, error) { diff --git a/pkg/controller/nodes/dynamic/mocks/task_node_handler.go b/pkg/controller/nodes/dynamic/mocks/task_node_handler.go index 49936c11d..fc69132e5 100644 --- a/pkg/controller/nodes/dynamic/mocks/task_node_handler.go +++ b/pkg/controller/nodes/dynamic/mocks/task_node_handler.go @@ -9,6 +9,8 @@ import ( handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + io "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" ioutils "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" @@ -31,7 +33,7 @@ func (_m TaskNodeHandler_Abort) Return(_a0 error) *TaskNodeHandler_Abort { return &TaskNodeHandler_Abort{Call: _m.Call.Return(_a0)} } -func (_m *TaskNodeHandler) OnAbort(ctx context.Context, executionContext handler.NodeExecutionContext, reason string) *TaskNodeHandler_Abort { +func (_m *TaskNodeHandler) OnAbort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) *TaskNodeHandler_Abort { c_call := _m.On("Abort", ctx, executionContext, reason) return &TaskNodeHandler_Abort{Call: c_call} } @@ -42,11 +44,11 @@ func (_m *TaskNodeHandler) OnAbortMatch(matchers ...interface{}) *TaskNodeHandle } // Abort provides a mock function with given fields: ctx, executionContext, reason -func (_m *TaskNodeHandler) Abort(ctx context.Context, executionContext handler.NodeExecutionContext, reason string) error { +func (_m *TaskNodeHandler) Abort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) error { ret := _m.Called(ctx, executionContext, reason) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.NodeExecutionContext, string) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext, string) error); ok { r0 = rf(ctx, executionContext, reason) } else { r0 = ret.Error(0) @@ 
-63,7 +65,7 @@ func (_m TaskNodeHandler_Finalize) Return(_a0 error) *TaskNodeHandler_Finalize { return &TaskNodeHandler_Finalize{Call: _m.Call.Return(_a0)} } -func (_m *TaskNodeHandler) OnFinalize(ctx context.Context, executionContext handler.NodeExecutionContext) *TaskNodeHandler_Finalize { +func (_m *TaskNodeHandler) OnFinalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) *TaskNodeHandler_Finalize { c_call := _m.On("Finalize", ctx, executionContext) return &TaskNodeHandler_Finalize{Call: c_call} } @@ -74,11 +76,11 @@ func (_m *TaskNodeHandler) OnFinalizeMatch(matchers ...interface{}) *TaskNodeHan } // Finalize provides a mock function with given fields: ctx, executionContext -func (_m *TaskNodeHandler) Finalize(ctx context.Context, executionContext handler.NodeExecutionContext) error { +func (_m *TaskNodeHandler) Finalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) error { ret := _m.Called(ctx, executionContext) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.NodeExecutionContext) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) error); ok { r0 = rf(ctx, executionContext) } else { r0 = ret.Error(0) @@ -127,7 +129,7 @@ func (_m TaskNodeHandler_Handle) Return(_a0 handler.Transition, _a1 error) *Task return &TaskNodeHandler_Handle{Call: _m.Call.Return(_a0, _a1)} } -func (_m *TaskNodeHandler) OnHandle(ctx context.Context, executionContext handler.NodeExecutionContext) *TaskNodeHandler_Handle { +func (_m *TaskNodeHandler) OnHandle(ctx context.Context, executionContext interfaces.NodeExecutionContext) *TaskNodeHandler_Handle { c_call := _m.On("Handle", ctx, executionContext) return &TaskNodeHandler_Handle{Call: c_call} } @@ -138,18 +140,18 @@ func (_m *TaskNodeHandler) OnHandleMatch(matchers ...interface{}) *TaskNodeHandl } // Handle provides a mock function with given fields: ctx, executionContext -func (_m *TaskNodeHandler) Handle(ctx context.Context, executionContext handler.NodeExecutionContext) (handler.Transition, error) { +func (_m *TaskNodeHandler) Handle(ctx context.Context, executionContext interfaces.NodeExecutionContext) (handler.Transition, error) { ret := _m.Called(ctx, executionContext) var r0 handler.Transition - if rf, ok := ret.Get(0).(func(context.Context, handler.NodeExecutionContext) handler.Transition); ok { + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) handler.Transition); ok { r0 = rf(ctx, executionContext) } else { r0 = ret.Get(0).(handler.Transition) } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, handler.NodeExecutionContext) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, interfaces.NodeExecutionContext) error); ok { r1 = rf(ctx, executionContext) } else { r1 = ret.Error(1) diff --git a/pkg/controller/nodes/handler/mocks/node.go b/pkg/controller/nodes/handler/mocks/node.go index e7e376606..52eba43b7 100644 --- a/pkg/controller/nodes/handler/mocks/node.go +++ b/pkg/controller/nodes/handler/mocks/node.go @@ -6,6 +6,8 @@ import ( context "context" handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + mock "github.com/stretchr/testify/mock" ) @@ -22,7 +24,7 @@ func (_m Node_Abort) Return(_a0 error) *Node_Abort { return &Node_Abort{Call: _m.Call.Return(_a0)} } -func (_m *Node) OnAbort(ctx context.Context, executionContext handler.NodeExecutionContext, reason string) *Node_Abort { +func 
(_m *Node) OnAbort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) *Node_Abort { c_call := _m.On("Abort", ctx, executionContext, reason) return &Node_Abort{Call: c_call} } @@ -33,11 +35,11 @@ func (_m *Node) OnAbortMatch(matchers ...interface{}) *Node_Abort { } // Abort provides a mock function with given fields: ctx, executionContext, reason -func (_m *Node) Abort(ctx context.Context, executionContext handler.NodeExecutionContext, reason string) error { +func (_m *Node) Abort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) error { ret := _m.Called(ctx, executionContext, reason) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.NodeExecutionContext, string) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext, string) error); ok { r0 = rf(ctx, executionContext, reason) } else { r0 = ret.Error(0) @@ -54,7 +56,7 @@ func (_m Node_Finalize) Return(_a0 error) *Node_Finalize { return &Node_Finalize{Call: _m.Call.Return(_a0)} } -func (_m *Node) OnFinalize(ctx context.Context, executionContext handler.NodeExecutionContext) *Node_Finalize { +func (_m *Node) OnFinalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) *Node_Finalize { c_call := _m.On("Finalize", ctx, executionContext) return &Node_Finalize{Call: c_call} } @@ -65,11 +67,11 @@ func (_m *Node) OnFinalizeMatch(matchers ...interface{}) *Node_Finalize { } // Finalize provides a mock function with given fields: ctx, executionContext -func (_m *Node) Finalize(ctx context.Context, executionContext handler.NodeExecutionContext) error { +func (_m *Node) Finalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) error { ret := _m.Called(ctx, executionContext) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.NodeExecutionContext) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) error); ok { r0 = rf(ctx, executionContext) } else { r0 = ret.Error(0) @@ -118,7 +120,7 @@ func (_m Node_Handle) Return(_a0 handler.Transition, _a1 error) *Node_Handle { return &Node_Handle{Call: _m.Call.Return(_a0, _a1)} } -func (_m *Node) OnHandle(ctx context.Context, executionContext handler.NodeExecutionContext) *Node_Handle { +func (_m *Node) OnHandle(ctx context.Context, executionContext interfaces.NodeExecutionContext) *Node_Handle { c_call := _m.On("Handle", ctx, executionContext) return &Node_Handle{Call: c_call} } @@ -129,18 +131,18 @@ func (_m *Node) OnHandleMatch(matchers ...interface{}) *Node_Handle { } // Handle provides a mock function with given fields: ctx, executionContext -func (_m *Node) Handle(ctx context.Context, executionContext handler.NodeExecutionContext) (handler.Transition, error) { +func (_m *Node) Handle(ctx context.Context, executionContext interfaces.NodeExecutionContext) (handler.Transition, error) { ret := _m.Called(ctx, executionContext) var r0 handler.Transition - if rf, ok := ret.Get(0).(func(context.Context, handler.NodeExecutionContext) handler.Transition); ok { + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) handler.Transition); ok { r0 = rf(ctx, executionContext) } else { r0 = ret.Get(0).(handler.Transition) } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, handler.NodeExecutionContext) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, interfaces.NodeExecutionContext) error); ok { r1 = rf(ctx, executionContext) } else { r1 = 
ret.Error(1) diff --git a/pkg/controller/nodes/handler/mocks/node_executor.go b/pkg/controller/nodes/handler/mocks/node_executor.go new file mode 100644 index 000000000..9aeca0cd3 --- /dev/null +++ b/pkg/controller/nodes/handler/mocks/node_executor.go @@ -0,0 +1,122 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + executors "github.com/flyteorg/flytepropeller/pkg/controller/executors" + handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + + mock "github.com/stretchr/testify/mock" +) + +// NodeExecutor is an autogenerated mock type for the NodeExecutor type +type NodeExecutor struct { + mock.Mock +} + +type NodeExecutor_Abort struct { + *mock.Call +} + +func (_m NodeExecutor_Abort) Return(_a0 error) *NodeExecutor_Abort { + return &NodeExecutor_Abort{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeExecutor) OnAbort(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext, reason string) *NodeExecutor_Abort { + c_call := _m.On("Abort", ctx, h, nCtx, reason) + return &NodeExecutor_Abort{Call: c_call} +} + +func (_m *NodeExecutor) OnAbortMatch(matchers ...interface{}) *NodeExecutor_Abort { + c_call := _m.On("Abort", matchers...) + return &NodeExecutor_Abort{Call: c_call} +} + +// Abort provides a mock function with given fields: ctx, h, nCtx, reason +func (_m *NodeExecutor) Abort(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext, reason string) error { + ret := _m.Called(ctx, h, nCtx, reason) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, handler.Node, interfaces.NodeExecutionContext, string) error); ok { + r0 = rf(ctx, h, nCtx, reason) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type NodeExecutor_Finalize struct { + *mock.Call +} + +func (_m NodeExecutor_Finalize) Return(_a0 error) *NodeExecutor_Finalize { + return &NodeExecutor_Finalize{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeExecutor) OnFinalize(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext) *NodeExecutor_Finalize { + c_call := _m.On("Finalize", ctx, h, nCtx) + return &NodeExecutor_Finalize{Call: c_call} +} + +func (_m *NodeExecutor) OnFinalizeMatch(matchers ...interface{}) *NodeExecutor_Finalize { + c_call := _m.On("Finalize", matchers...) + return &NodeExecutor_Finalize{Call: c_call} +} + +// Finalize provides a mock function with given fields: ctx, h, nCtx +func (_m *NodeExecutor) Finalize(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext) error { + ret := _m.Called(ctx, h, nCtx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, handler.Node, interfaces.NodeExecutionContext) error); ok { + r0 = rf(ctx, h, nCtx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type NodeExecutor_HandleNode struct { + *mock.Call +} + +func (_m NodeExecutor_HandleNode) Return(_a0 interfaces.NodeStatus, _a1 error) *NodeExecutor_HandleNode { + return &NodeExecutor_HandleNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *NodeExecutor) OnHandleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h handler.Node) *NodeExecutor_HandleNode { + c_call := _m.On("HandleNode", ctx, dag, nCtx, h) + return &NodeExecutor_HandleNode{Call: c_call} +} + +func (_m *NodeExecutor) OnHandleNodeMatch(matchers ...interface{}) *NodeExecutor_HandleNode { + c_call := _m.On("HandleNode", matchers...) 
+ return &NodeExecutor_HandleNode{Call: c_call} +} + +// HandleNode provides a mock function with given fields: ctx, dag, nCtx, h +func (_m *NodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h handler.Node) (interfaces.NodeStatus, error) { + ret := _m.Called(ctx, dag, nCtx, h) + + var r0 interfaces.NodeStatus + if rf, ok := ret.Get(0).(func(context.Context, executors.DAGStructure, interfaces.NodeExecutionContext, handler.Node) interfaces.NodeStatus); ok { + r0 = rf(ctx, dag, nCtx, h) + } else { + r0 = ret.Get(0).(interfaces.NodeStatus) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, executors.DAGStructure, interfaces.NodeExecutionContext, handler.Node) error); ok { + r1 = rf(ctx, dag, nCtx, h) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/controller/nodes/handler/mocks/node_state_reader.go b/pkg/controller/nodes/handler/mocks/node_state_reader.go deleted file mode 100644 index ef86e64cc..000000000 --- a/pkg/controller/nodes/handler/mocks/node_state_reader.go +++ /dev/null @@ -1,173 +0,0 @@ -// Code generated by mockery v1.0.1. DO NOT EDIT. - -package mocks - -import ( - handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - mock "github.com/stretchr/testify/mock" -) - -// NodeStateReader is an autogenerated mock type for the NodeStateReader type -type NodeStateReader struct { - mock.Mock -} - -type NodeStateReader_GetBranchNode struct { - *mock.Call -} - -func (_m NodeStateReader_GetBranchNode) Return(_a0 handler.BranchNodeState) *NodeStateReader_GetBranchNode { - return &NodeStateReader_GetBranchNode{Call: _m.Call.Return(_a0)} -} - -func (_m *NodeStateReader) OnGetBranchNode() *NodeStateReader_GetBranchNode { - c_call := _m.On("GetBranchNode") - return &NodeStateReader_GetBranchNode{Call: c_call} -} - -func (_m *NodeStateReader) OnGetBranchNodeMatch(matchers ...interface{}) *NodeStateReader_GetBranchNode { - c_call := _m.On("GetBranchNode", matchers...) - return &NodeStateReader_GetBranchNode{Call: c_call} -} - -// GetBranchNode provides a mock function with given fields: -func (_m *NodeStateReader) GetBranchNode() handler.BranchNodeState { - ret := _m.Called() - - var r0 handler.BranchNodeState - if rf, ok := ret.Get(0).(func() handler.BranchNodeState); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(handler.BranchNodeState) - } - - return r0 -} - -type NodeStateReader_GetDynamicNodeState struct { - *mock.Call -} - -func (_m NodeStateReader_GetDynamicNodeState) Return(_a0 handler.DynamicNodeState) *NodeStateReader_GetDynamicNodeState { - return &NodeStateReader_GetDynamicNodeState{Call: _m.Call.Return(_a0)} -} - -func (_m *NodeStateReader) OnGetDynamicNodeState() *NodeStateReader_GetDynamicNodeState { - c_call := _m.On("GetDynamicNodeState") - return &NodeStateReader_GetDynamicNodeState{Call: c_call} -} - -func (_m *NodeStateReader) OnGetDynamicNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetDynamicNodeState { - c_call := _m.On("GetDynamicNodeState", matchers...) 
- return &NodeStateReader_GetDynamicNodeState{Call: c_call} -} - -// GetDynamicNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetDynamicNodeState() handler.DynamicNodeState { - ret := _m.Called() - - var r0 handler.DynamicNodeState - if rf, ok := ret.Get(0).(func() handler.DynamicNodeState); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(handler.DynamicNodeState) - } - - return r0 -} - -type NodeStateReader_GetGateNodeState struct { - *mock.Call -} - -func (_m NodeStateReader_GetGateNodeState) Return(_a0 handler.GateNodeState) *NodeStateReader_GetGateNodeState { - return &NodeStateReader_GetGateNodeState{Call: _m.Call.Return(_a0)} -} - -func (_m *NodeStateReader) OnGetGateNodeState() *NodeStateReader_GetGateNodeState { - c_call := _m.On("GetGateNodeState") - return &NodeStateReader_GetGateNodeState{Call: c_call} -} - -func (_m *NodeStateReader) OnGetGateNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetGateNodeState { - c_call := _m.On("GetGateNodeState", matchers...) - return &NodeStateReader_GetGateNodeState{Call: c_call} -} - -// GetGateNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetGateNodeState() handler.GateNodeState { - ret := _m.Called() - - var r0 handler.GateNodeState - if rf, ok := ret.Get(0).(func() handler.GateNodeState); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(handler.GateNodeState) - } - - return r0 -} - -type NodeStateReader_GetTaskNodeState struct { - *mock.Call -} - -func (_m NodeStateReader_GetTaskNodeState) Return(_a0 handler.TaskNodeState) *NodeStateReader_GetTaskNodeState { - return &NodeStateReader_GetTaskNodeState{Call: _m.Call.Return(_a0)} -} - -func (_m *NodeStateReader) OnGetTaskNodeState() *NodeStateReader_GetTaskNodeState { - c_call := _m.On("GetTaskNodeState") - return &NodeStateReader_GetTaskNodeState{Call: c_call} -} - -func (_m *NodeStateReader) OnGetTaskNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetTaskNodeState { - c_call := _m.On("GetTaskNodeState", matchers...) - return &NodeStateReader_GetTaskNodeState{Call: c_call} -} - -// GetTaskNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetTaskNodeState() handler.TaskNodeState { - ret := _m.Called() - - var r0 handler.TaskNodeState - if rf, ok := ret.Get(0).(func() handler.TaskNodeState); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(handler.TaskNodeState) - } - - return r0 -} - -type NodeStateReader_GetWorkflowNodeState struct { - *mock.Call -} - -func (_m NodeStateReader_GetWorkflowNodeState) Return(_a0 handler.WorkflowNodeState) *NodeStateReader_GetWorkflowNodeState { - return &NodeStateReader_GetWorkflowNodeState{Call: _m.Call.Return(_a0)} -} - -func (_m *NodeStateReader) OnGetWorkflowNodeState() *NodeStateReader_GetWorkflowNodeState { - c_call := _m.On("GetWorkflowNodeState") - return &NodeStateReader_GetWorkflowNodeState{Call: c_call} -} - -func (_m *NodeStateReader) OnGetWorkflowNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetWorkflowNodeState { - c_call := _m.On("GetWorkflowNodeState", matchers...) 
- return &NodeStateReader_GetWorkflowNodeState{Call: c_call} -} - -// GetWorkflowNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetWorkflowNodeState() handler.WorkflowNodeState { - ret := _m.Called() - - var r0 handler.WorkflowNodeState - if rf, ok := ret.Get(0).(func() handler.WorkflowNodeState); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(handler.WorkflowNodeState) - } - - return r0 -} diff --git a/pkg/controller/nodes/interfaces/mocks/event_recorder.go b/pkg/controller/nodes/interfaces/mocks/event_recorder.go new file mode 100644 index 000000000..684419825 --- /dev/null +++ b/pkg/controller/nodes/interfaces/mocks/event_recorder.go @@ -0,0 +1,82 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + config "github.com/flyteorg/flytepropeller/pkg/controller/config" + + event "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + + mock "github.com/stretchr/testify/mock" +) + +// EventRecorder is an autogenerated mock type for the EventRecorder type +type EventRecorder struct { + mock.Mock +} + +type EventRecorder_RecordNodeEvent struct { + *mock.Call +} + +func (_m EventRecorder_RecordNodeEvent) Return(_a0 error) *EventRecorder_RecordNodeEvent { + return &EventRecorder_RecordNodeEvent{Call: _m.Call.Return(_a0)} +} + +func (_m *EventRecorder) OnRecordNodeEvent(ctx context.Context, _a1 *event.NodeExecutionEvent, eventConfig *config.EventConfig) *EventRecorder_RecordNodeEvent { + c_call := _m.On("RecordNodeEvent", ctx, _a1, eventConfig) + return &EventRecorder_RecordNodeEvent{Call: c_call} +} + +func (_m *EventRecorder) OnRecordNodeEventMatch(matchers ...interface{}) *EventRecorder_RecordNodeEvent { + c_call := _m.On("RecordNodeEvent", matchers...) + return &EventRecorder_RecordNodeEvent{Call: c_call} +} + +// RecordNodeEvent provides a mock function with given fields: ctx, _a1, eventConfig +func (_m *EventRecorder) RecordNodeEvent(ctx context.Context, _a1 *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { + ret := _m.Called(ctx, _a1, eventConfig) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *event.NodeExecutionEvent, *config.EventConfig) error); ok { + r0 = rf(ctx, _a1, eventConfig) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type EventRecorder_RecordTaskEvent struct { + *mock.Call +} + +func (_m EventRecorder_RecordTaskEvent) Return(_a0 error) *EventRecorder_RecordTaskEvent { + return &EventRecorder_RecordTaskEvent{Call: _m.Call.Return(_a0)} +} + +func (_m *EventRecorder) OnRecordTaskEvent(ctx context.Context, _a1 *event.TaskExecutionEvent, eventConfig *config.EventConfig) *EventRecorder_RecordTaskEvent { + c_call := _m.On("RecordTaskEvent", ctx, _a1, eventConfig) + return &EventRecorder_RecordTaskEvent{Call: c_call} +} + +func (_m *EventRecorder) OnRecordTaskEventMatch(matchers ...interface{}) *EventRecorder_RecordTaskEvent { + c_call := _m.On("RecordTaskEvent", matchers...) 
+ return &EventRecorder_RecordTaskEvent{Call: c_call} +} + +// RecordTaskEvent provides a mock function with given fields: ctx, _a1, eventConfig +func (_m *EventRecorder) RecordTaskEvent(ctx context.Context, _a1 *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { + ret := _m.Called(ctx, _a1, eventConfig) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *event.TaskExecutionEvent, *config.EventConfig) error); ok { + r0 = rf(ctx, _a1, eventConfig) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/controller/executors/mocks/node.go b/pkg/controller/nodes/interfaces/mocks/node.go similarity index 68% rename from pkg/controller/executors/mocks/node.go rename to pkg/controller/nodes/interfaces/mocks/node.go index 8f2d5cf0c..0413f7ebc 100644 --- a/pkg/controller/executors/mocks/node.go +++ b/pkg/controller/nodes/interfaces/mocks/node.go @@ -8,6 +8,8 @@ import ( core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" executors "github.com/flyteorg/flytepropeller/pkg/controller/executors" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + mock "github.com/stretchr/testify/mock" v1alpha1 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" @@ -82,6 +84,40 @@ func (_m *Node) FinalizeHandler(ctx context.Context, execContext executors.Execu return r0 } +type Node_GetNodeExecutionContextBuilder struct { + *mock.Call +} + +func (_m Node_GetNodeExecutionContextBuilder) Return(_a0 interfaces.NodeExecutionContextBuilder) *Node_GetNodeExecutionContextBuilder { + return &Node_GetNodeExecutionContextBuilder{Call: _m.Call.Return(_a0)} +} + +func (_m *Node) OnGetNodeExecutionContextBuilder() *Node_GetNodeExecutionContextBuilder { + c_call := _m.On("GetNodeExecutionContextBuilder") + return &Node_GetNodeExecutionContextBuilder{Call: c_call} +} + +func (_m *Node) OnGetNodeExecutionContextBuilderMatch(matchers ...interface{}) *Node_GetNodeExecutionContextBuilder { + c_call := _m.On("GetNodeExecutionContextBuilder", matchers...) 
+ return &Node_GetNodeExecutionContextBuilder{Call: c_call} +} + +// GetNodeExecutionContextBuilder provides a mock function with given fields: +func (_m *Node) GetNodeExecutionContextBuilder() interfaces.NodeExecutionContextBuilder { + ret := _m.Called() + + var r0 interfaces.NodeExecutionContextBuilder + if rf, ok := ret.Get(0).(func() interfaces.NodeExecutionContextBuilder); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interfaces.NodeExecutionContextBuilder) + } + } + + return r0 +} + type Node_Initialize struct { *mock.Call } @@ -118,7 +154,7 @@ type Node_RecursiveNodeHandler struct { *mock.Call } -func (_m Node_RecursiveNodeHandler) Return(_a0 executors.NodeStatus, _a1 error) *Node_RecursiveNodeHandler { +func (_m Node_RecursiveNodeHandler) Return(_a0 interfaces.NodeStatus, _a1 error) *Node_RecursiveNodeHandler { return &Node_RecursiveNodeHandler{Call: _m.Call.Return(_a0, _a1)} } @@ -133,14 +169,14 @@ func (_m *Node) OnRecursiveNodeHandlerMatch(matchers ...interface{}) *Node_Recur } // RecursiveNodeHandler provides a mock function with given fields: ctx, execContext, dag, nl, currentNode -func (_m *Node) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { +func (_m *Node) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (interfaces.NodeStatus, error) { ret := _m.Called(ctx, execContext, dag, nl, currentNode) - var r0 executors.NodeStatus - if rf, ok := ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, v1alpha1.ExecutableNode) executors.NodeStatus); ok { + var r0 interfaces.NodeStatus + if rf, ok := ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, v1alpha1.ExecutableNode) interfaces.NodeStatus); ok { r0 = rf(ctx, execContext, dag, nl, currentNode) } else { - r0 = ret.Get(0).(executors.NodeStatus) + r0 = ret.Get(0).(interfaces.NodeStatus) } var r1 error @@ -157,7 +193,7 @@ type Node_SetInputsForStartNode struct { *mock.Call } -func (_m Node_SetInputsForStartNode) Return(_a0 executors.NodeStatus, _a1 error) *Node_SetInputsForStartNode { +func (_m Node_SetInputsForStartNode) Return(_a0 interfaces.NodeStatus, _a1 error) *Node_SetInputsForStartNode { return &Node_SetInputsForStartNode{Call: _m.Call.Return(_a0, _a1)} } @@ -172,14 +208,14 @@ func (_m *Node) OnSetInputsForStartNodeMatch(matchers ...interface{}) *Node_SetI } // SetInputsForStartNode provides a mock function with given fields: ctx, execContext, dag, nl, inputs -func (_m *Node) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (executors.NodeStatus, error) { +func (_m *Node) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (interfaces.NodeStatus, error) { ret := _m.Called(ctx, execContext, dag, nl, inputs) - var r0 executors.NodeStatus - if rf, ok := ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.DAGStructureWithStartNode, executors.NodeLookup, *core.LiteralMap) executors.NodeStatus); ok { + var r0 interfaces.NodeStatus + if rf, ok := 
ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.DAGStructureWithStartNode, executors.NodeLookup, *core.LiteralMap) interfaces.NodeStatus); ok { r0 = rf(ctx, execContext, dag, nl, inputs) } else { - r0 = ret.Get(0).(executors.NodeStatus) + r0 = ret.Get(0).(interfaces.NodeStatus) } var r1 error @@ -191,3 +227,37 @@ func (_m *Node) SetInputsForStartNode(ctx context.Context, execContext executors return r0, r1 } + +type Node_WithNodeExecutionContextBuilder struct { + *mock.Call +} + +func (_m Node_WithNodeExecutionContextBuilder) Return(_a0 interfaces.Node) *Node_WithNodeExecutionContextBuilder { + return &Node_WithNodeExecutionContextBuilder{Call: _m.Call.Return(_a0)} +} + +func (_m *Node) OnWithNodeExecutionContextBuilder(_a0 interfaces.NodeExecutionContextBuilder) *Node_WithNodeExecutionContextBuilder { + c_call := _m.On("WithNodeExecutionContextBuilder", _a0) + return &Node_WithNodeExecutionContextBuilder{Call: c_call} +} + +func (_m *Node) OnWithNodeExecutionContextBuilderMatch(matchers ...interface{}) *Node_WithNodeExecutionContextBuilder { + c_call := _m.On("WithNodeExecutionContextBuilder", matchers...) + return &Node_WithNodeExecutionContextBuilder{Call: c_call} +} + +// WithNodeExecutionContextBuilder provides a mock function with given fields: _a0 +func (_m *Node) WithNodeExecutionContextBuilder(_a0 interfaces.NodeExecutionContextBuilder) interfaces.Node { + ret := _m.Called(_a0) + + var r0 interfaces.Node + if rf, ok := ret.Get(0).(func(interfaces.NodeExecutionContextBuilder) interfaces.Node); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interfaces.Node) + } + } + + return r0 +} diff --git a/pkg/controller/nodes/handler/mocks/node_execution_context.go b/pkg/controller/nodes/interfaces/mocks/node_execution_context.go similarity index 90% rename from pkg/controller/nodes/handler/mocks/node_execution_context.go rename to pkg/controller/nodes/interfaces/mocks/node_execution_context.go index 434f78caa..fcf130e30 100644 --- a/pkg/controller/nodes/handler/mocks/node_execution_context.go +++ b/pkg/controller/nodes/interfaces/mocks/node_execution_context.go @@ -3,9 +3,8 @@ package mocks import ( - events "github.com/flyteorg/flytepropeller/events" executors "github.com/flyteorg/flytepropeller/pkg/controller/executors" - handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" io "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" @@ -161,7 +160,7 @@ type NodeExecutionContext_EventsRecorder struct { *mock.Call } -func (_m NodeExecutionContext_EventsRecorder) Return(_a0 events.TaskEventRecorder) *NodeExecutionContext_EventsRecorder { +func (_m NodeExecutionContext_EventsRecorder) Return(_a0 interfaces.EventRecorder) *NodeExecutionContext_EventsRecorder { return &NodeExecutionContext_EventsRecorder{Call: _m.Call.Return(_a0)} } @@ -176,15 +175,15 @@ func (_m *NodeExecutionContext) OnEventsRecorderMatch(matchers ...interface{}) * } // EventsRecorder provides a mock function with given fields: -func (_m *NodeExecutionContext) EventsRecorder() events.TaskEventRecorder { +func (_m *NodeExecutionContext) EventsRecorder() interfaces.EventRecorder { ret := _m.Called() - var r0 events.TaskEventRecorder - if rf, ok := ret.Get(0).(func() events.TaskEventRecorder); ok { + var r0 interfaces.EventRecorder + if rf, ok := ret.Get(0).(func() interfaces.EventRecorder); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = 
ret.Get(0).(events.TaskEventRecorder) + r0 = ret.Get(0).(interfaces.EventRecorder) } } @@ -329,7 +328,7 @@ type NodeExecutionContext_NodeExecutionMetadata struct { *mock.Call } -func (_m NodeExecutionContext_NodeExecutionMetadata) Return(_a0 handler.NodeExecutionMetadata) *NodeExecutionContext_NodeExecutionMetadata { +func (_m NodeExecutionContext_NodeExecutionMetadata) Return(_a0 interfaces.NodeExecutionMetadata) *NodeExecutionContext_NodeExecutionMetadata { return &NodeExecutionContext_NodeExecutionMetadata{Call: _m.Call.Return(_a0)} } @@ -344,15 +343,15 @@ func (_m *NodeExecutionContext) OnNodeExecutionMetadataMatch(matchers ...interfa } // NodeExecutionMetadata provides a mock function with given fields: -func (_m *NodeExecutionContext) NodeExecutionMetadata() handler.NodeExecutionMetadata { +func (_m *NodeExecutionContext) NodeExecutionMetadata() interfaces.NodeExecutionMetadata { ret := _m.Called() - var r0 handler.NodeExecutionMetadata - if rf, ok := ret.Get(0).(func() handler.NodeExecutionMetadata); ok { + var r0 interfaces.NodeExecutionMetadata + if rf, ok := ret.Get(0).(func() interfaces.NodeExecutionMetadata); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(handler.NodeExecutionMetadata) + r0 = ret.Get(0).(interfaces.NodeExecutionMetadata) } } @@ -395,7 +394,7 @@ type NodeExecutionContext_NodeStateReader struct { *mock.Call } -func (_m NodeExecutionContext_NodeStateReader) Return(_a0 handler.NodeStateReader) *NodeExecutionContext_NodeStateReader { +func (_m NodeExecutionContext_NodeStateReader) Return(_a0 interfaces.NodeStateReader) *NodeExecutionContext_NodeStateReader { return &NodeExecutionContext_NodeStateReader{Call: _m.Call.Return(_a0)} } @@ -410,15 +409,15 @@ func (_m *NodeExecutionContext) OnNodeStateReaderMatch(matchers ...interface{}) } // NodeStateReader provides a mock function with given fields: -func (_m *NodeExecutionContext) NodeStateReader() handler.NodeStateReader { +func (_m *NodeExecutionContext) NodeStateReader() interfaces.NodeStateReader { ret := _m.Called() - var r0 handler.NodeStateReader - if rf, ok := ret.Get(0).(func() handler.NodeStateReader); ok { + var r0 interfaces.NodeStateReader + if rf, ok := ret.Get(0).(func() interfaces.NodeStateReader); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(handler.NodeStateReader) + r0 = ret.Get(0).(interfaces.NodeStateReader) } } @@ -429,7 +428,7 @@ type NodeExecutionContext_NodeStateWriter struct { *mock.Call } -func (_m NodeExecutionContext_NodeStateWriter) Return(_a0 handler.NodeStateWriter) *NodeExecutionContext_NodeStateWriter { +func (_m NodeExecutionContext_NodeStateWriter) Return(_a0 interfaces.NodeStateWriter) *NodeExecutionContext_NodeStateWriter { return &NodeExecutionContext_NodeStateWriter{Call: _m.Call.Return(_a0)} } @@ -444,15 +443,15 @@ func (_m *NodeExecutionContext) OnNodeStateWriterMatch(matchers ...interface{}) } // NodeStateWriter provides a mock function with given fields: -func (_m *NodeExecutionContext) NodeStateWriter() handler.NodeStateWriter { +func (_m *NodeExecutionContext) NodeStateWriter() interfaces.NodeStateWriter { ret := _m.Called() - var r0 handler.NodeStateWriter - if rf, ok := ret.Get(0).(func() handler.NodeStateWriter); ok { + var r0 interfaces.NodeStateWriter + if rf, ok := ret.Get(0).(func() interfaces.NodeStateWriter); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(handler.NodeStateWriter) + r0 = ret.Get(0).(interfaces.NodeStateWriter) } } @@ -563,7 +562,7 @@ type NodeExecutionContext_TaskReader struct { *mock.Call } 
-func (_m NodeExecutionContext_TaskReader) Return(_a0 handler.TaskReader) *NodeExecutionContext_TaskReader { +func (_m NodeExecutionContext_TaskReader) Return(_a0 interfaces.TaskReader) *NodeExecutionContext_TaskReader { return &NodeExecutionContext_TaskReader{Call: _m.Call.Return(_a0)} } @@ -578,15 +577,15 @@ func (_m *NodeExecutionContext) OnTaskReaderMatch(matchers ...interface{}) *Node } // TaskReader provides a mock function with given fields: -func (_m *NodeExecutionContext) TaskReader() handler.TaskReader { +func (_m *NodeExecutionContext) TaskReader() interfaces.TaskReader { ret := _m.Called() - var r0 handler.TaskReader - if rf, ok := ret.Get(0).(func() handler.TaskReader); ok { + var r0 interfaces.TaskReader + if rf, ok := ret.Get(0).(func() interfaces.TaskReader); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(handler.TaskReader) + r0 = ret.Get(0).(interfaces.TaskReader) } } diff --git a/pkg/controller/nodes/interfaces/mocks/node_execution_context_builder.go b/pkg/controller/nodes/interfaces/mocks/node_execution_context_builder.go new file mode 100644 index 000000000..d068f1902 --- /dev/null +++ b/pkg/controller/nodes/interfaces/mocks/node_execution_context_builder.go @@ -0,0 +1,58 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + executors "github.com/flyteorg/flytepropeller/pkg/controller/executors" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + + mock "github.com/stretchr/testify/mock" +) + +// NodeExecutionContextBuilder is an autogenerated mock type for the NodeExecutionContextBuilder type +type NodeExecutionContextBuilder struct { + mock.Mock +} + +type NodeExecutionContextBuilder_BuildNodeExecutionContext struct { + *mock.Call +} + +func (_m NodeExecutionContextBuilder_BuildNodeExecutionContext) Return(_a0 interfaces.NodeExecutionContext, _a1 error) *NodeExecutionContextBuilder_BuildNodeExecutionContext { + return &NodeExecutionContextBuilder_BuildNodeExecutionContext{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *NodeExecutionContextBuilder) OnBuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, nl executors.NodeLookup, currentNodeID string) *NodeExecutionContextBuilder_BuildNodeExecutionContext { + c_call := _m.On("BuildNodeExecutionContext", ctx, executionContext, nl, currentNodeID) + return &NodeExecutionContextBuilder_BuildNodeExecutionContext{Call: c_call} +} + +func (_m *NodeExecutionContextBuilder) OnBuildNodeExecutionContextMatch(matchers ...interface{}) *NodeExecutionContextBuilder_BuildNodeExecutionContext { + c_call := _m.On("BuildNodeExecutionContext", matchers...) 
+ return &NodeExecutionContextBuilder_BuildNodeExecutionContext{Call: c_call} +} + +// BuildNodeExecutionContext provides a mock function with given fields: ctx, executionContext, nl, currentNodeID +func (_m *NodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, nl executors.NodeLookup, currentNodeID string) (interfaces.NodeExecutionContext, error) { + ret := _m.Called(ctx, executionContext, nl, currentNodeID) + + var r0 interfaces.NodeExecutionContext + if rf, ok := ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.NodeLookup, string) interfaces.NodeExecutionContext); ok { + r0 = rf(ctx, executionContext, nl, currentNodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interfaces.NodeExecutionContext) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, executors.ExecutionContext, executors.NodeLookup, string) error); ok { + r1 = rf(ctx, executionContext, nl, currentNodeID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/controller/nodes/handler/mocks/node_execution_metadata.go b/pkg/controller/nodes/interfaces/mocks/node_execution_metadata.go similarity index 100% rename from pkg/controller/nodes/handler/mocks/node_execution_metadata.go rename to pkg/controller/nodes/interfaces/mocks/node_execution_metadata.go diff --git a/pkg/controller/nodes/interfaces/mocks/node_state_reader.go b/pkg/controller/nodes/interfaces/mocks/node_state_reader.go new file mode 100644 index 000000000..2f8191d62 --- /dev/null +++ b/pkg/controller/nodes/interfaces/mocks/node_state_reader.go @@ -0,0 +1,397 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + mock "github.com/stretchr/testify/mock" +) + +// NodeStateReader is an autogenerated mock type for the NodeStateReader type +type NodeStateReader struct { + mock.Mock +} + +type NodeStateReader_GetArrayNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_GetArrayNodeState) Return(_a0 interfaces.ArrayNodeState) *NodeStateReader_GetArrayNodeState { + return &NodeStateReader_GetArrayNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnGetArrayNodeState() *NodeStateReader_GetArrayNodeState { + c_call := _m.On("GetArrayNodeState") + return &NodeStateReader_GetArrayNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnGetArrayNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetArrayNodeState { + c_call := _m.On("GetArrayNodeState", matchers...) 
+ return &NodeStateReader_GetArrayNodeState{Call: c_call} +} + +// GetArrayNodeState provides a mock function with given fields: +func (_m *NodeStateReader) GetArrayNodeState() interfaces.ArrayNodeState { + ret := _m.Called() + + var r0 interfaces.ArrayNodeState + if rf, ok := ret.Get(0).(func() interfaces.ArrayNodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(interfaces.ArrayNodeState) + } + + return r0 +} + +type NodeStateReader_GetBranchNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_GetBranchNodeState) Return(_a0 interfaces.BranchNodeState) *NodeStateReader_GetBranchNodeState { + return &NodeStateReader_GetBranchNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnGetBranchNodeState() *NodeStateReader_GetBranchNodeState { + c_call := _m.On("GetBranchNodeState") + return &NodeStateReader_GetBranchNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnGetBranchNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetBranchNodeState { + c_call := _m.On("GetBranchNodeState", matchers...) + return &NodeStateReader_GetBranchNodeState{Call: c_call} +} + +// GetBranchNodeState provides a mock function with given fields: +func (_m *NodeStateReader) GetBranchNodeState() interfaces.BranchNodeState { + ret := _m.Called() + + var r0 interfaces.BranchNodeState + if rf, ok := ret.Get(0).(func() interfaces.BranchNodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(interfaces.BranchNodeState) + } + + return r0 +} + +type NodeStateReader_GetDynamicNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_GetDynamicNodeState) Return(_a0 interfaces.DynamicNodeState) *NodeStateReader_GetDynamicNodeState { + return &NodeStateReader_GetDynamicNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnGetDynamicNodeState() *NodeStateReader_GetDynamicNodeState { + c_call := _m.On("GetDynamicNodeState") + return &NodeStateReader_GetDynamicNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnGetDynamicNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetDynamicNodeState { + c_call := _m.On("GetDynamicNodeState", matchers...) + return &NodeStateReader_GetDynamicNodeState{Call: c_call} +} + +// GetDynamicNodeState provides a mock function with given fields: +func (_m *NodeStateReader) GetDynamicNodeState() interfaces.DynamicNodeState { + ret := _m.Called() + + var r0 interfaces.DynamicNodeState + if rf, ok := ret.Get(0).(func() interfaces.DynamicNodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(interfaces.DynamicNodeState) + } + + return r0 +} + +type NodeStateReader_GetGateNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_GetGateNodeState) Return(_a0 interfaces.GateNodeState) *NodeStateReader_GetGateNodeState { + return &NodeStateReader_GetGateNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnGetGateNodeState() *NodeStateReader_GetGateNodeState { + c_call := _m.On("GetGateNodeState") + return &NodeStateReader_GetGateNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnGetGateNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetGateNodeState { + c_call := _m.On("GetGateNodeState", matchers...) 
+ return &NodeStateReader_GetGateNodeState{Call: c_call} +} + +// GetGateNodeState provides a mock function with given fields: +func (_m *NodeStateReader) GetGateNodeState() interfaces.GateNodeState { + ret := _m.Called() + + var r0 interfaces.GateNodeState + if rf, ok := ret.Get(0).(func() interfaces.GateNodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(interfaces.GateNodeState) + } + + return r0 +} + +type NodeStateReader_GetTaskNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_GetTaskNodeState) Return(_a0 interfaces.TaskNodeState) *NodeStateReader_GetTaskNodeState { + return &NodeStateReader_GetTaskNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnGetTaskNodeState() *NodeStateReader_GetTaskNodeState { + c_call := _m.On("GetTaskNodeState") + return &NodeStateReader_GetTaskNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnGetTaskNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetTaskNodeState { + c_call := _m.On("GetTaskNodeState", matchers...) + return &NodeStateReader_GetTaskNodeState{Call: c_call} +} + +// GetTaskNodeState provides a mock function with given fields: +func (_m *NodeStateReader) GetTaskNodeState() interfaces.TaskNodeState { + ret := _m.Called() + + var r0 interfaces.TaskNodeState + if rf, ok := ret.Get(0).(func() interfaces.TaskNodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(interfaces.TaskNodeState) + } + + return r0 +} + +type NodeStateReader_GetWorkflowNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_GetWorkflowNodeState) Return(_a0 interfaces.WorkflowNodeState) *NodeStateReader_GetWorkflowNodeState { + return &NodeStateReader_GetWorkflowNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnGetWorkflowNodeState() *NodeStateReader_GetWorkflowNodeState { + c_call := _m.On("GetWorkflowNodeState") + return &NodeStateReader_GetWorkflowNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnGetWorkflowNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetWorkflowNodeState { + c_call := _m.On("GetWorkflowNodeState", matchers...) + return &NodeStateReader_GetWorkflowNodeState{Call: c_call} +} + +// GetWorkflowNodeState provides a mock function with given fields: +func (_m *NodeStateReader) GetWorkflowNodeState() interfaces.WorkflowNodeState { + ret := _m.Called() + + var r0 interfaces.WorkflowNodeState + if rf, ok := ret.Get(0).(func() interfaces.WorkflowNodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(interfaces.WorkflowNodeState) + } + + return r0 +} + +type NodeStateReader_HasArrayNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_HasArrayNodeState) Return(_a0 bool) *NodeStateReader_HasArrayNodeState { + return &NodeStateReader_HasArrayNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnHasArrayNodeState() *NodeStateReader_HasArrayNodeState { + c_call := _m.On("HasArrayNodeState") + return &NodeStateReader_HasArrayNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnHasArrayNodeStateMatch(matchers ...interface{}) *NodeStateReader_HasArrayNodeState { + c_call := _m.On("HasArrayNodeState", matchers...) 
+ return &NodeStateReader_HasArrayNodeState{Call: c_call} +} + +// HasArrayNodeState provides a mock function with given fields: +func (_m *NodeStateReader) HasArrayNodeState() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type NodeStateReader_HasBranchNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_HasBranchNodeState) Return(_a0 bool) *NodeStateReader_HasBranchNodeState { + return &NodeStateReader_HasBranchNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnHasBranchNodeState() *NodeStateReader_HasBranchNodeState { + c_call := _m.On("HasBranchNodeState") + return &NodeStateReader_HasBranchNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnHasBranchNodeStateMatch(matchers ...interface{}) *NodeStateReader_HasBranchNodeState { + c_call := _m.On("HasBranchNodeState", matchers...) + return &NodeStateReader_HasBranchNodeState{Call: c_call} +} + +// HasBranchNodeState provides a mock function with given fields: +func (_m *NodeStateReader) HasBranchNodeState() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type NodeStateReader_HasDynamicNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_HasDynamicNodeState) Return(_a0 bool) *NodeStateReader_HasDynamicNodeState { + return &NodeStateReader_HasDynamicNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnHasDynamicNodeState() *NodeStateReader_HasDynamicNodeState { + c_call := _m.On("HasDynamicNodeState") + return &NodeStateReader_HasDynamicNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnHasDynamicNodeStateMatch(matchers ...interface{}) *NodeStateReader_HasDynamicNodeState { + c_call := _m.On("HasDynamicNodeState", matchers...) + return &NodeStateReader_HasDynamicNodeState{Call: c_call} +} + +// HasDynamicNodeState provides a mock function with given fields: +func (_m *NodeStateReader) HasDynamicNodeState() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type NodeStateReader_HasGateNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_HasGateNodeState) Return(_a0 bool) *NodeStateReader_HasGateNodeState { + return &NodeStateReader_HasGateNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnHasGateNodeState() *NodeStateReader_HasGateNodeState { + c_call := _m.On("HasGateNodeState") + return &NodeStateReader_HasGateNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnHasGateNodeStateMatch(matchers ...interface{}) *NodeStateReader_HasGateNodeState { + c_call := _m.On("HasGateNodeState", matchers...) 
+ return &NodeStateReader_HasGateNodeState{Call: c_call} +} + +// HasGateNodeState provides a mock function with given fields: +func (_m *NodeStateReader) HasGateNodeState() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type NodeStateReader_HasTaskNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_HasTaskNodeState) Return(_a0 bool) *NodeStateReader_HasTaskNodeState { + return &NodeStateReader_HasTaskNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnHasTaskNodeState() *NodeStateReader_HasTaskNodeState { + c_call := _m.On("HasTaskNodeState") + return &NodeStateReader_HasTaskNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnHasTaskNodeStateMatch(matchers ...interface{}) *NodeStateReader_HasTaskNodeState { + c_call := _m.On("HasTaskNodeState", matchers...) + return &NodeStateReader_HasTaskNodeState{Call: c_call} +} + +// HasTaskNodeState provides a mock function with given fields: +func (_m *NodeStateReader) HasTaskNodeState() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type NodeStateReader_HasWorkflowNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_HasWorkflowNodeState) Return(_a0 bool) *NodeStateReader_HasWorkflowNodeState { + return &NodeStateReader_HasWorkflowNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnHasWorkflowNodeState() *NodeStateReader_HasWorkflowNodeState { + c_call := _m.On("HasWorkflowNodeState") + return &NodeStateReader_HasWorkflowNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnHasWorkflowNodeStateMatch(matchers ...interface{}) *NodeStateReader_HasWorkflowNodeState { + c_call := _m.On("HasWorkflowNodeState", matchers...) 
+ return &NodeStateReader_HasWorkflowNodeState{Call: c_call} +} + +// HasWorkflowNodeState provides a mock function with given fields: +func (_m *NodeStateReader) HasWorkflowNodeState() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} diff --git a/pkg/controller/nodes/handler/mocks/node_state_writer.go b/pkg/controller/nodes/interfaces/mocks/node_state_writer.go similarity index 60% rename from pkg/controller/nodes/handler/mocks/node_state_writer.go rename to pkg/controller/nodes/interfaces/mocks/node_state_writer.go index ec5359550..93334a42d 100644 --- a/pkg/controller/nodes/handler/mocks/node_state_writer.go +++ b/pkg/controller/nodes/interfaces/mocks/node_state_writer.go @@ -3,7 +3,7 @@ package mocks import ( - handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" mock "github.com/stretchr/testify/mock" ) @@ -12,6 +12,43 @@ type NodeStateWriter struct { mock.Mock } +// ClearNodeStatus provides a mock function with given fields: +func (_m *NodeStateWriter) ClearNodeStatus() { + _m.Called() +} + +type NodeStateWriter_PutArrayNodeState struct { + *mock.Call +} + +func (_m NodeStateWriter_PutArrayNodeState) Return(_a0 error) *NodeStateWriter_PutArrayNodeState { + return &NodeStateWriter_PutArrayNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateWriter) OnPutArrayNodeState(s interfaces.ArrayNodeState) *NodeStateWriter_PutArrayNodeState { + c_call := _m.On("PutArrayNodeState", s) + return &NodeStateWriter_PutArrayNodeState{Call: c_call} +} + +func (_m *NodeStateWriter) OnPutArrayNodeStateMatch(matchers ...interface{}) *NodeStateWriter_PutArrayNodeState { + c_call := _m.On("PutArrayNodeState", matchers...) 
+ return &NodeStateWriter_PutArrayNodeState{Call: c_call} +} + +// PutArrayNodeState provides a mock function with given fields: s +func (_m *NodeStateWriter) PutArrayNodeState(s interfaces.ArrayNodeState) error { + ret := _m.Called(s) + + var r0 error + if rf, ok := ret.Get(0).(func(interfaces.ArrayNodeState) error); ok { + r0 = rf(s) + } else { + r0 = ret.Error(0) + } + + return r0 +} + type NodeStateWriter_PutBranchNode struct { *mock.Call } @@ -20,7 +57,7 @@ func (_m NodeStateWriter_PutBranchNode) Return(_a0 error) *NodeStateWriter_PutBr return &NodeStateWriter_PutBranchNode{Call: _m.Call.Return(_a0)} } -func (_m *NodeStateWriter) OnPutBranchNode(s handler.BranchNodeState) *NodeStateWriter_PutBranchNode { +func (_m *NodeStateWriter) OnPutBranchNode(s interfaces.BranchNodeState) *NodeStateWriter_PutBranchNode { c_call := _m.On("PutBranchNode", s) return &NodeStateWriter_PutBranchNode{Call: c_call} } @@ -31,11 +68,11 @@ func (_m *NodeStateWriter) OnPutBranchNodeMatch(matchers ...interface{}) *NodeSt } // PutBranchNode provides a mock function with given fields: s -func (_m *NodeStateWriter) PutBranchNode(s handler.BranchNodeState) error { +func (_m *NodeStateWriter) PutBranchNode(s interfaces.BranchNodeState) error { ret := _m.Called(s) var r0 error - if rf, ok := ret.Get(0).(func(handler.BranchNodeState) error); ok { + if rf, ok := ret.Get(0).(func(interfaces.BranchNodeState) error); ok { r0 = rf(s) } else { r0 = ret.Error(0) @@ -52,7 +89,7 @@ func (_m NodeStateWriter_PutDynamicNodeState) Return(_a0 error) *NodeStateWriter return &NodeStateWriter_PutDynamicNodeState{Call: _m.Call.Return(_a0)} } -func (_m *NodeStateWriter) OnPutDynamicNodeState(s handler.DynamicNodeState) *NodeStateWriter_PutDynamicNodeState { +func (_m *NodeStateWriter) OnPutDynamicNodeState(s interfaces.DynamicNodeState) *NodeStateWriter_PutDynamicNodeState { c_call := _m.On("PutDynamicNodeState", s) return &NodeStateWriter_PutDynamicNodeState{Call: c_call} } @@ -63,11 +100,11 @@ func (_m *NodeStateWriter) OnPutDynamicNodeStateMatch(matchers ...interface{}) * } // PutDynamicNodeState provides a mock function with given fields: s -func (_m *NodeStateWriter) PutDynamicNodeState(s handler.DynamicNodeState) error { +func (_m *NodeStateWriter) PutDynamicNodeState(s interfaces.DynamicNodeState) error { ret := _m.Called(s) var r0 error - if rf, ok := ret.Get(0).(func(handler.DynamicNodeState) error); ok { + if rf, ok := ret.Get(0).(func(interfaces.DynamicNodeState) error); ok { r0 = rf(s) } else { r0 = ret.Error(0) @@ -84,7 +121,7 @@ func (_m NodeStateWriter_PutGateNodeState) Return(_a0 error) *NodeStateWriter_Pu return &NodeStateWriter_PutGateNodeState{Call: _m.Call.Return(_a0)} } -func (_m *NodeStateWriter) OnPutGateNodeState(s handler.GateNodeState) *NodeStateWriter_PutGateNodeState { +func (_m *NodeStateWriter) OnPutGateNodeState(s interfaces.GateNodeState) *NodeStateWriter_PutGateNodeState { c_call := _m.On("PutGateNodeState", s) return &NodeStateWriter_PutGateNodeState{Call: c_call} } @@ -95,11 +132,11 @@ func (_m *NodeStateWriter) OnPutGateNodeStateMatch(matchers ...interface{}) *Nod } // PutGateNodeState provides a mock function with given fields: s -func (_m *NodeStateWriter) PutGateNodeState(s handler.GateNodeState) error { +func (_m *NodeStateWriter) PutGateNodeState(s interfaces.GateNodeState) error { ret := _m.Called(s) var r0 error - if rf, ok := ret.Get(0).(func(handler.GateNodeState) error); ok { + if rf, ok := ret.Get(0).(func(interfaces.GateNodeState) error); ok { r0 = rf(s) } else { r0 = ret.Error(0) @@ 
-116,7 +153,7 @@ func (_m NodeStateWriter_PutTaskNodeState) Return(_a0 error) *NodeStateWriter_Pu return &NodeStateWriter_PutTaskNodeState{Call: _m.Call.Return(_a0)} } -func (_m *NodeStateWriter) OnPutTaskNodeState(s handler.TaskNodeState) *NodeStateWriter_PutTaskNodeState { +func (_m *NodeStateWriter) OnPutTaskNodeState(s interfaces.TaskNodeState) *NodeStateWriter_PutTaskNodeState { c_call := _m.On("PutTaskNodeState", s) return &NodeStateWriter_PutTaskNodeState{Call: c_call} } @@ -127,11 +164,11 @@ func (_m *NodeStateWriter) OnPutTaskNodeStateMatch(matchers ...interface{}) *Nod } // PutTaskNodeState provides a mock function with given fields: s -func (_m *NodeStateWriter) PutTaskNodeState(s handler.TaskNodeState) error { +func (_m *NodeStateWriter) PutTaskNodeState(s interfaces.TaskNodeState) error { ret := _m.Called(s) var r0 error - if rf, ok := ret.Get(0).(func(handler.TaskNodeState) error); ok { + if rf, ok := ret.Get(0).(func(interfaces.TaskNodeState) error); ok { r0 = rf(s) } else { r0 = ret.Error(0) @@ -148,7 +185,7 @@ func (_m NodeStateWriter_PutWorkflowNodeState) Return(_a0 error) *NodeStateWrite return &NodeStateWriter_PutWorkflowNodeState{Call: _m.Call.Return(_a0)} } -func (_m *NodeStateWriter) OnPutWorkflowNodeState(s handler.WorkflowNodeState) *NodeStateWriter_PutWorkflowNodeState { +func (_m *NodeStateWriter) OnPutWorkflowNodeState(s interfaces.WorkflowNodeState) *NodeStateWriter_PutWorkflowNodeState { c_call := _m.On("PutWorkflowNodeState", s) return &NodeStateWriter_PutWorkflowNodeState{Call: c_call} } @@ -159,11 +196,11 @@ func (_m *NodeStateWriter) OnPutWorkflowNodeStateMatch(matchers ...interface{}) } // PutWorkflowNodeState provides a mock function with given fields: s -func (_m *NodeStateWriter) PutWorkflowNodeState(s handler.WorkflowNodeState) error { +func (_m *NodeStateWriter) PutWorkflowNodeState(s interfaces.WorkflowNodeState) error { ret := _m.Called(s) var r0 error - if rf, ok := ret.Get(0).(func(handler.WorkflowNodeState) error); ok { + if rf, ok := ret.Get(0).(func(interfaces.WorkflowNodeState) error); ok { r0 = rf(s) } else { r0 = ret.Error(0) diff --git a/pkg/controller/nodes/handler/mocks/task_reader.go b/pkg/controller/nodes/interfaces/mocks/task_reader.go similarity index 100% rename from pkg/controller/nodes/handler/mocks/task_reader.go rename to pkg/controller/nodes/interfaces/mocks/task_reader.go From b150f1a6c104ebe70fa883675c2b828993d99e8e Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Tue, 13 Jun 2023 16:45:47 -0500 Subject: [PATCH 36/62] refactored for unit testing to allow creation of NodeExecutor in array package Signed-off-by: Daniel Rammer --- pkg/controller/controller.go | 12 +- pkg/controller/nodes/array/handler.go | 4 +- pkg/controller/nodes/array/handler_test.go | 34 ++-- pkg/controller/nodes/branch/handler.go | 4 +- pkg/controller/nodes/dynamic/handler.go | 4 +- .../nodes/dynamic/mocks/task_node_handler.go | 6 +- pkg/controller/nodes/end/handler.go | 4 +- pkg/controller/nodes/executor.go | 45 +++-- .../nodes/factory/handler_factory.go | 96 +++++++++ pkg/controller/nodes/gate/handler.go | 4 +- pkg/controller/nodes/handler/mocks/node.go | 184 ------------------ .../nodes/handler/transition_info_test.go | 3 +- .../nodes/handler/transition_test.go | 1 + pkg/controller/nodes/handler_factory.go | 86 -------- .../iface.go => interfaces/handler.go} | 18 +- .../nodes/interfaces/handler_factory.go | 14 ++ .../{ => interfaces}/mocks/handler_factory.go | 26 +-- .../mocks/node_executor.go | 22 +-- .../nodes/interfaces/mocks/node_handler.go | 184 
++++++++++++++++++ .../mocks/setup_context.go | 0 pkg/controller/nodes/setup_context.go | 4 +- pkg/controller/nodes/start/handler.go | 4 +- pkg/controller/nodes/subworkflow/handler.go | 4 +- pkg/controller/nodes/task/handler.go | 2 +- pkg/controller/nodes/task/setup_ctx.go | 10 +- 25 files changed, 409 insertions(+), 366 deletions(-) create mode 100644 pkg/controller/nodes/factory/handler_factory.go delete mode 100644 pkg/controller/nodes/handler/mocks/node.go delete mode 100644 pkg/controller/nodes/handler_factory.go rename pkg/controller/nodes/{handler/iface.go => interfaces/handler.go} (64%) create mode 100644 pkg/controller/nodes/interfaces/handler_factory.go rename pkg/controller/nodes/{ => interfaces}/mocks/handler_factory.go (61%) rename pkg/controller/nodes/{handler => interfaces}/mocks/node_executor.go (69%) create mode 100644 pkg/controller/nodes/interfaces/mocks/node_handler.go rename pkg/controller/nodes/{handler => interfaces}/mocks/setup_context.go (100%) diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index d7b8d9c28..57dcb9ba6 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -26,6 +26,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes" errors3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/factory" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" @@ -437,9 +438,16 @@ func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Inter controller.levelMonitor = NewResourceLevelMonitor(scope.NewSubScope("collector"), flyteworkflowInformer.Lister()) + recoveryClient := recovery.NewClient(adminClient) + nodeHandlerFactory, err := factory.NewHandlerFactory(ctx, launchPlanActor, launchPlanActor, + kubeClient, catalogClient, recoveryClient, &cfg.EventConfig, cfg.ClusterID, signalClient, scope) + if err != nil { + return nil, errors.Wrapf(err, "failed to create node handler factory") + } + nodeExecutor, err := nodes.NewExecutor(ctx, cfg.NodeConfig, store, controller.enqueueWorkflowForNodeUpdates, eventSink, - launchPlanActor, launchPlanActor, cfg.MaxDatasetSizeBytes, - storage.DataReference(cfg.DefaultRawOutputPrefix), kubeClient, catalogClient, recovery.NewClient(adminClient), &cfg.EventConfig, cfg.ClusterID, signalClient, scope) + launchPlanActor, launchPlanActor, cfg.MaxDatasetSizeBytes, storage.DataReference(cfg.DefaultRawOutputPrefix), kubeClient, + catalogClient, recoveryClient, &cfg.EventConfig, cfg.ClusterID, signalClient, nodeHandlerFactory, scope) if err != nil { return nil, errors.Wrapf(err, "Failed to create Controller.") } diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 9c206e3de..2b8de07e9 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -490,12 +490,12 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } // Setup handles any initialization requirements for this handler -func (a *arrayNodeHandler) Setup(_ context.Context, _ handler.SetupContext) error { +func (a *arrayNodeHandler) Setup(_ context.Context, _ interfaces.SetupContext) error { return nil } // New initializes a new arrayNodeHandler -func New(nodeExecutor interfaces.Node, eventConfig 
*config.EventConfig, scope promutils.Scope) (handler.Node, error) { +func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope promutils.Scope) (interfaces.NodeHandler, error) { // create k8s PluginState byte mocks to reuse instead of creating for each subNode evaluation pluginStateBytesNotStarted, err := bytesFromK8sPluginState(k8s.PluginState{Phase: k8s.PluginPhaseNotStarted}) if err != nil { diff --git a/pkg/controller/nodes/array/handler_test.go b/pkg/controller/nodes/array/handler_test.go index a268a1420..bdfb7f5a1 100644 --- a/pkg/controller/nodes/array/handler_test.go +++ b/pkg/controller/nodes/array/handler_test.go @@ -4,23 +4,26 @@ import ( "context" "testing" - "github.com/flyteorg/flytepropeller/events" eventmocks "github.com/flyteorg/flytepropeller/events/mocks" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" execmocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/config" - gatemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes" gatemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" recoverymocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/storage" + + "github.com/stretchr/testify/assert" + //"github.com/stretchr/testify/mock" ) -func createArrayNodeExecutor(t *testing.T, ctx context.Context, scope promutils.Scope) (handler.Node, error) { +func createArrayNodeHandler(t *testing.T, ctx context.Context, scope promutils.Scope) (interfaces.NodeHandler, error) { // mock components adminClient := launchplan.NewFailFastLaunchPlanExecutor() dataStore, err := storage.NewDataStore(&storage.Config{ @@ -28,19 +31,20 @@ func createArrayNodeExecutor(t *testing.T, ctx context.Context, scope promutils. 
}, scope)
 	assert.NoError(t, err)
 	enqueueWorkflowFunc := func(workflowID v1alpha1.WorkflowID) {}
-	eventConfig := &events.EventConfig{}
-	mockEventSink = eventmocks.NewMockEventSink()
-	mockKubeClient = execmocks.NewFakeKubeClient()
-	mockRecoveryClient = &recoverymocks.Client{}
-	mockSignalClient = &gatemocks.SignalServiceClient{}
-	noopCatalogClient = catalog.NOOPCatalog{}
-	scope := promutils.NewTestScope()
+	eventConfig := &config.EventConfig{}
+	mockEventSink := eventmocks.NewMockEventSink()
+	mockHandlerFactory := &mocks.HandlerFactory{}
+	mockKubeClient := execmocks.NewFakeKubeClient()
+	mockRecoveryClient := &recoverymocks.Client{}
+	mockSignalClient := &gatemocks.SignalServiceClient{}
+	noopCatalogClient := catalog.NOOPCatalog{}

 	// create node executor
 	nodeExecutor, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, dataStore, enqueueWorkflowFunc, mockEventSink, adminClient,
-		adminClient, 10, "s3://bucket/", mockKubeClient, noopCatalogClient, mockRecoveryClient, eventConfig, "clusterID", mockSignalClient, scope)
+		adminClient, 10, "s3://bucket/", mockKubeClient, noopCatalogClient, mockRecoveryClient, eventConfig, "clusterID", mockSignalClient, mockHandlerFactory, scope)
 	assert.NoError(t, err)

+	// return ArrayNodeHandler
 	return New(nodeExecutor, eventConfig, scope)
 }

@@ -54,6 +58,13 @@ func TestFinalize(t *testing.T) {

 func TestHandleArrayNodePhaseNone(t *testing.T) {
 	ctx := context.Background()
+	scope := promutils.NewTestScope()
+
+	// initialize ArrayNodeHandler
+	arrayNodeHandler, err := createArrayNodeHandler(t, ctx, scope)
+	assert.NoError(t, err)
+	_ = arrayNodeHandler // referenced so the test compiles until the TODO below exercises the handler
+
 	// TODO @hamersaw - complete
 }

diff --git a/pkg/controller/nodes/branch/handler.go b/pkg/controller/nodes/branch/handler.go
index 50e47dd68..3e4b2897e 100644
--- a/pkg/controller/nodes/branch/handler.go
+++ b/pkg/controller/nodes/branch/handler.go
@@ -33,7 +33,7 @@ func (b *branchHandler) FinalizeRequired() bool {
 	return false
 }

-func (b *branchHandler) Setup(ctx context.Context, _ handler.SetupContext) error {
+func (b *branchHandler) Setup(ctx context.Context, _ interfaces.SetupContext) error {
 	logger.Debugf(ctx, "BranchNode::Setup: nothing to do")
 	return nil
 }
@@ -258,7 +258,7 @@ func (b *branchHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExecut
 	return b.nodeExecutor.FinalizeHandler(ctx, execContext, dag, nCtx.ContextualNodeLookup(), branchTakenNode)
 }

-func New(executor interfaces.Node, eventConfig *config.EventConfig, scope promutils.Scope) handler.Node {
+func New(executor interfaces.Node, eventConfig *config.EventConfig, scope promutils.Scope) interfaces.NodeHandler {
 	return &branchHandler{
 		nodeExecutor: executor,
 		m:            metrics{scope: scope},
diff --git a/pkg/controller/nodes/dynamic/handler.go b/pkg/controller/nodes/dynamic/handler.go
index 2cfa745e1..22cca7c77 100644
--- a/pkg/controller/nodes/dynamic/handler.go
+++ b/pkg/controller/nodes/dynamic/handler.go
@@ -31,7 +31,7 @@ import (
 const dynamicNodeID = "dynamic-node"

 type TaskNodeHandler interface {
-	handler.Node
+	interfaces.NodeHandler
 	ValidateOutputAndCacheAdd(ctx context.Context, nodeID v1alpha1.NodeID, i io.InputReader, r io.OutputReader,
 		outputCommitter io.OutputWriter, executionConfig v1alpha1.ExecutionConfig,
 		tr ioutils.SimpleTaskReader, m catalog.Metadata) (catalog.Status, *io.ExecutionError, error)
@@ -303,7 +303,7 @@ func (d dynamicNodeTaskNodeHandler) Finalize(ctx context.Context, nCtx interface
 	return nil
 }

-func New(underlying TaskNodeHandler, nodeExecutor interfaces.Node, launchPlanReader launchplan.Reader, eventConfig
*config.EventConfig, scope promutils.Scope) handler.Node { +func New(underlying TaskNodeHandler, nodeExecutor interfaces.Node, launchPlanReader launchplan.Reader, eventConfig *config.EventConfig, scope promutils.Scope) interfaces.NodeHandler { return &dynamicNodeTaskNodeHandler{ TaskNodeHandler: underlying, diff --git a/pkg/controller/nodes/dynamic/mocks/task_node_handler.go b/pkg/controller/nodes/dynamic/mocks/task_node_handler.go index fc69132e5..e8d8cc6d7 100644 --- a/pkg/controller/nodes/dynamic/mocks/task_node_handler.go +++ b/pkg/controller/nodes/dynamic/mocks/task_node_handler.go @@ -168,7 +168,7 @@ func (_m TaskNodeHandler_Setup) Return(_a0 error) *TaskNodeHandler_Setup { return &TaskNodeHandler_Setup{Call: _m.Call.Return(_a0)} } -func (_m *TaskNodeHandler) OnSetup(ctx context.Context, setupContext handler.SetupContext) *TaskNodeHandler_Setup { +func (_m *TaskNodeHandler) OnSetup(ctx context.Context, setupContext interfaces.SetupContext) *TaskNodeHandler_Setup { c_call := _m.On("Setup", ctx, setupContext) return &TaskNodeHandler_Setup{Call: c_call} } @@ -179,11 +179,11 @@ func (_m *TaskNodeHandler) OnSetupMatch(matchers ...interface{}) *TaskNodeHandle } // Setup provides a mock function with given fields: ctx, setupContext -func (_m *TaskNodeHandler) Setup(ctx context.Context, setupContext handler.SetupContext) error { +func (_m *TaskNodeHandler) Setup(ctx context.Context, setupContext interfaces.SetupContext) error { ret := _m.Called(ctx, setupContext) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.SetupContext) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, interfaces.SetupContext) error); ok { r0 = rf(ctx, setupContext) } else { r0 = ret.Error(0) diff --git a/pkg/controller/nodes/end/handler.go b/pkg/controller/nodes/end/handler.go index 7bd1286ed..d77a7ab50 100644 --- a/pkg/controller/nodes/end/handler.go +++ b/pkg/controller/nodes/end/handler.go @@ -19,7 +19,7 @@ func (e endHandler) FinalizeRequired() bool { return false } -func (e endHandler) Setup(ctx context.Context, setupContext handler.SetupContext) error { +func (e endHandler) Setup(ctx context.Context, setupContext interfaces.SetupContext) error { return nil } @@ -50,6 +50,6 @@ func (e endHandler) Finalize(_ context.Context, _ interfaces.NodeExecutionContex return nil } -func New() handler.Node { +func New() interfaces.NodeHandler { return &endHandler{} } diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index 72cf83326..7781fb90d 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -87,11 +87,11 @@ type nodeMetrics struct { // Implements the executors.Node interface type recursiveNodeExecutor struct { - nodeExecutor handler.NodeExecutor + nodeExecutor interfaces.NodeExecutor nCtxBuilder interfaces.NodeExecutionContextBuilder enqueueWorkflow v1alpha1.EnqueueWorkflow - nodeHandlerFactory HandlerFactory + nodeHandlerFactory interfaces.HandlerFactory store *storage.DataStore metrics *nodeMetrics } @@ -446,7 +446,7 @@ func (c *recursiveNodeExecutor) AbortHandler(ctx context.Context, execContext ex func (c *recursiveNodeExecutor) Initialize(ctx context.Context) error { logger.Infof(ctx, "Initializing Core Node Executor") s := c.newSetupContext(ctx) - return c.nodeHandlerFactory.Setup(ctx, s) + return c.nodeHandlerFactory.Setup(ctx, c, s) } // TODO @hamersaw docs @@ -458,7 +458,6 @@ func (c *recursiveNodeExecutor) WithNodeExecutionContextBuilder(nCtxBuilder inte return &recursiveNodeExecutor{ nodeExecutor: c.nodeExecutor, 
nCtxBuilder: nCtxBuilder, - // TODO @hamersaw fill out enqueueWorkflow: c.enqueueWorkflow, nodeHandlerFactory: c.nodeHandlerFactory, store: c.store, @@ -834,7 +833,7 @@ func (c *nodeExecutor) isEligibleForRetry(nCtx interfaces.NodeExecutionContext, return } -func (c *nodeExecutor) execute(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.PhaseInfo, error) { +func (c *nodeExecutor) execute(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.PhaseInfo, error) { logger.Debugf(ctx, "Executing node") defer logger.Debugf(ctx, "Node execution round complete") @@ -885,7 +884,7 @@ func (c *nodeExecutor) execute(ctx context.Context, h handler.Node, nCtx interfa return phase, nil } -func (c *nodeExecutor) Abort(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext, reason string) error { +func (c *nodeExecutor) Abort(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, reason string) error { logger.Debugf(ctx, "Calling aborting & finalize") if err := h.Abort(ctx, nCtx, reason); err != nil { finalizeErr := h.Finalize(ctx, nCtx) @@ -936,11 +935,11 @@ func (c *nodeExecutor) Abort(ctx context.Context, h handler.Node, nCtx interface return nil } -func (c *nodeExecutor) Finalize(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext) error { +func (c *nodeExecutor) Finalize(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext) error { return h.Finalize(ctx, nCtx) } -func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, _ handler.Node) (interfaces.NodeStatus, error) { +func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, _ interfaces.NodeHandler) (interfaces.NodeStatus, error) { logger.Debugf(ctx, "Node not yet started, running pre-execute") defer logger.Debugf(ctx, "Node pre-execute completed") occurredAt := time.Now() @@ -998,7 +997,7 @@ func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executor return interfaces.NodeStatusPending, nil } -func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx interfaces.NodeExecutionContext, h handler.Node) (interfaces.NodeStatus, error) { +func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx interfaces.NodeExecutionContext, h interfaces.NodeHandler) (interfaces.NodeStatus, error) { nodeStatus := nCtx.NodeStatus() currentPhase := nodeStatus.GetPhase() @@ -1139,7 +1138,7 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx inter return finalStatus, nil } -func (c *nodeExecutor) handleRetryableFailure(ctx context.Context, nCtx interfaces.NodeExecutionContext, h handler.Node) (interfaces.NodeStatus, error) { +func (c *nodeExecutor) handleRetryableFailure(ctx context.Context, nCtx interfaces.NodeExecutionContext, h interfaces.NodeHandler) (interfaces.NodeStatus, error) { nodeStatus := nCtx.NodeStatus() logger.Debugf(ctx, "node failed with retryable failure, aborting and finalizing, message: %s", nodeStatus.GetMessage()) if err := c.Abort(ctx, h, nCtx, nodeStatus.GetMessage()); err != nil { @@ -1160,7 +1159,7 @@ func (c *nodeExecutor) handleRetryableFailure(ctx context.Context, nCtx interfac return interfaces.NodeStatusPending, nil } -func (c *nodeExecutor) HandleNode(ctx 
context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h handler.Node) (interfaces.NodeStatus, error) { +func (c *nodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h interfaces.NodeHandler) (interfaces.NodeStatus, error) { logger.Debugf(ctx, "Handling Node [%s]", nCtx.NodeID()) defer logger.Debugf(ctx, "Completed node [%s]", nCtx.NodeID()) @@ -1245,9 +1244,9 @@ func (c *nodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructur } func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, - workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, maxDatasetSize int64, - defaultRawOutputPrefix storage.DataReference, kubeClient executors.Client, - catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, scope promutils.Scope) (interfaces.Node, error) { + workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, maxDatasetSize int64, defaultRawOutputPrefix storage.DataReference, kubeClient executors.Client, + catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, + nodeHandlerFactory interfaces.HandlerFactory, scope promutils.Scope) (interfaces.Node, error) { // TODO we may want to make this configurable. shardSelector, err := ioutils.NewBase36PrefixShardSelector(ctx) @@ -1299,14 +1298,14 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora } exec := &recursiveNodeExecutor{ - nodeExecutor: nodeExecutor, - nCtxBuilder: nodeExecutor, - - enqueueWorkflow: enQWorkflow, - store: store, - metrics: metrics, - } - nodeHandlerFactory, err := NewHandlerFactory(ctx, exec, workflowLauncher, launchPlanReader, kubeClient, catalogClient, recoveryClient, eventConfig, clusterID, signalClient, nodeScope) - exec.nodeHandlerFactory = nodeHandlerFactory + nodeExecutor: nodeExecutor, + nCtxBuilder: nodeExecutor, + nodeHandlerFactory: nodeHandlerFactory, + enqueueWorkflow: enQWorkflow, + store: store, + metrics: metrics, + } + /*nodeHandlerFactory, err := NewHandlerFactory(ctx, exec, workflowLauncher, launchPlanReader, kubeClient, catalogClient, recoveryClient, eventConfig, clusterID, signalClient, nodeScope) + exec.nodeHandlerFactory = nodeHandlerFactory*/ return exec, err } diff --git a/pkg/controller/nodes/factory/handler_factory.go b/pkg/controller/nodes/factory/handler_factory.go new file mode 100644 index 000000000..8a3718fb3 --- /dev/null +++ b/pkg/controller/nodes/factory/handler_factory.go @@ -0,0 +1,96 @@ +package factory + +import ( + "context" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/array" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/branch" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/dynamic" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/end" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate" + 
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/start" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" + + "github.com/flyteorg/flytestdlib/promutils" + + "github.com/pkg/errors" +) + +type handlerFactory struct { + handlers map[v1alpha1.NodeKind]interfaces.NodeHandler + + workflowLauncher launchplan.Executor + launchPlanReader launchplan.Reader + kubeClient executors.Client + catalogClient catalog.Client + recoveryClient recovery.Client + eventConfig *config.EventConfig + clusterID string + signalClient service.SignalServiceClient + scope promutils.Scope +} + +func (f *handlerFactory) GetHandler(kind v1alpha1.NodeKind) (interfaces.NodeHandler, error) { + h, ok := f.handlers[kind] + if !ok { + return nil, errors.Errorf("Handler not registered for NodeKind [%v]", kind) + } + return h, nil +} + +func (f *handlerFactory) Setup(ctx context.Context, executor interfaces.Node, setup interfaces.SetupContext) error { + t, err := task.New(ctx, f.kubeClient, f.catalogClient, f.eventConfig, f.clusterID, f.scope) + if err != nil { + return err + } + + arrayHandler, err := array.New(executor, f.eventConfig, f.scope) + if err != nil { + return err + } + + f.handlers = map[v1alpha1.NodeKind]interfaces.NodeHandler{ + v1alpha1.NodeKindBranch: branch.New(executor, f.eventConfig, f.scope), + v1alpha1.NodeKindTask: dynamic.New(t, executor, f.launchPlanReader, f.eventConfig, f.scope), + v1alpha1.NodeKindWorkflow: subworkflow.New(executor, f.workflowLauncher, f.recoveryClient, f.eventConfig, f.scope), + v1alpha1.NodeKindGate: gate.New(f.eventConfig, f.signalClient, f.scope), + v1alpha1.NodeKindArray: arrayHandler, + v1alpha1.NodeKindStart: start.New(), + v1alpha1.NodeKindEnd: end.New(), + } + + for _, v := range f.handlers { + if err := v.Setup(ctx, setup); err != nil { + return err + } + } + return nil +} + +func NewHandlerFactory(ctx context.Context, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, + kubeClient executors.Client, catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, + clusterID string, signalClient service.SignalServiceClient, scope promutils.Scope) (interfaces.HandlerFactory, error) { + + return &handlerFactory{ + workflowLauncher: workflowLauncher, + launchPlanReader: launchPlanReader, + kubeClient: kubeClient, + catalogClient: catalogClient, + recoveryClient: recoveryClient, + eventConfig: eventConfig, + clusterID: clusterID, + signalClient: signalClient, + scope: scope, + }, nil +} diff --git a/pkg/controller/nodes/gate/handler.go b/pkg/controller/nodes/gate/handler.go index 340cbab9f..4b8d62720 100644 --- a/pkg/controller/nodes/gate/handler.go +++ b/pkg/controller/nodes/gate/handler.go @@ -206,12 +206,12 @@ func (g *gateNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecut } // Setup handles any initialization requirements for this handler -func (g *gateNodeHandler) Setup(_ context.Context, _ handler.SetupContext) error { +func (g *gateNodeHandler) Setup(_ context.Context, _ interfaces.SetupContext) error { return nil } // New initializes a new gateNodeHandler -func New(eventConfig *config.EventConfig, signalClient service.SignalServiceClient, scope promutils.Scope) handler.Node { +func 
New(eventConfig *config.EventConfig, signalClient service.SignalServiceClient, scope promutils.Scope) interfaces.NodeHandler { gateScope := scope.NewSubScope("gate") return &gateNodeHandler{ signalClient: signalClient, diff --git a/pkg/controller/nodes/handler/mocks/node.go b/pkg/controller/nodes/handler/mocks/node.go deleted file mode 100644 index 52eba43b7..000000000 --- a/pkg/controller/nodes/handler/mocks/node.go +++ /dev/null @@ -1,184 +0,0 @@ -// Code generated by mockery v1.0.1. DO NOT EDIT. - -package mocks - -import ( - context "context" - - handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" - - mock "github.com/stretchr/testify/mock" -) - -// Node is an autogenerated mock type for the Node type -type Node struct { - mock.Mock -} - -type Node_Abort struct { - *mock.Call -} - -func (_m Node_Abort) Return(_a0 error) *Node_Abort { - return &Node_Abort{Call: _m.Call.Return(_a0)} -} - -func (_m *Node) OnAbort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) *Node_Abort { - c_call := _m.On("Abort", ctx, executionContext, reason) - return &Node_Abort{Call: c_call} -} - -func (_m *Node) OnAbortMatch(matchers ...interface{}) *Node_Abort { - c_call := _m.On("Abort", matchers...) - return &Node_Abort{Call: c_call} -} - -// Abort provides a mock function with given fields: ctx, executionContext, reason -func (_m *Node) Abort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) error { - ret := _m.Called(ctx, executionContext, reason) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext, string) error); ok { - r0 = rf(ctx, executionContext, reason) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -type Node_Finalize struct { - *mock.Call -} - -func (_m Node_Finalize) Return(_a0 error) *Node_Finalize { - return &Node_Finalize{Call: _m.Call.Return(_a0)} -} - -func (_m *Node) OnFinalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) *Node_Finalize { - c_call := _m.On("Finalize", ctx, executionContext) - return &Node_Finalize{Call: c_call} -} - -func (_m *Node) OnFinalizeMatch(matchers ...interface{}) *Node_Finalize { - c_call := _m.On("Finalize", matchers...) - return &Node_Finalize{Call: c_call} -} - -// Finalize provides a mock function with given fields: ctx, executionContext -func (_m *Node) Finalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) error { - ret := _m.Called(ctx, executionContext) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) error); ok { - r0 = rf(ctx, executionContext) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -type Node_FinalizeRequired struct { - *mock.Call -} - -func (_m Node_FinalizeRequired) Return(_a0 bool) *Node_FinalizeRequired { - return &Node_FinalizeRequired{Call: _m.Call.Return(_a0)} -} - -func (_m *Node) OnFinalizeRequired() *Node_FinalizeRequired { - c_call := _m.On("FinalizeRequired") - return &Node_FinalizeRequired{Call: c_call} -} - -func (_m *Node) OnFinalizeRequiredMatch(matchers ...interface{}) *Node_FinalizeRequired { - c_call := _m.On("FinalizeRequired", matchers...) 
- return &Node_FinalizeRequired{Call: c_call} -} - -// FinalizeRequired provides a mock function with given fields: -func (_m *Node) FinalizeRequired() bool { - ret := _m.Called() - - var r0 bool - if rf, ok := ret.Get(0).(func() bool); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(bool) - } - - return r0 -} - -type Node_Handle struct { - *mock.Call -} - -func (_m Node_Handle) Return(_a0 handler.Transition, _a1 error) *Node_Handle { - return &Node_Handle{Call: _m.Call.Return(_a0, _a1)} -} - -func (_m *Node) OnHandle(ctx context.Context, executionContext interfaces.NodeExecutionContext) *Node_Handle { - c_call := _m.On("Handle", ctx, executionContext) - return &Node_Handle{Call: c_call} -} - -func (_m *Node) OnHandleMatch(matchers ...interface{}) *Node_Handle { - c_call := _m.On("Handle", matchers...) - return &Node_Handle{Call: c_call} -} - -// Handle provides a mock function with given fields: ctx, executionContext -func (_m *Node) Handle(ctx context.Context, executionContext interfaces.NodeExecutionContext) (handler.Transition, error) { - ret := _m.Called(ctx, executionContext) - - var r0 handler.Transition - if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) handler.Transition); ok { - r0 = rf(ctx, executionContext) - } else { - r0 = ret.Get(0).(handler.Transition) - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, interfaces.NodeExecutionContext) error); ok { - r1 = rf(ctx, executionContext) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -type Node_Setup struct { - *mock.Call -} - -func (_m Node_Setup) Return(_a0 error) *Node_Setup { - return &Node_Setup{Call: _m.Call.Return(_a0)} -} - -func (_m *Node) OnSetup(ctx context.Context, setupContext handler.SetupContext) *Node_Setup { - c_call := _m.On("Setup", ctx, setupContext) - return &Node_Setup{Call: c_call} -} - -func (_m *Node) OnSetupMatch(matchers ...interface{}) *Node_Setup { - c_call := _m.On("Setup", matchers...) 
- return &Node_Setup{Call: c_call} -} - -// Setup provides a mock function with given fields: ctx, setupContext -func (_m *Node) Setup(ctx context.Context, setupContext handler.SetupContext) error { - ret := _m.Called(ctx, setupContext) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.SetupContext) error); ok { - r0 = rf(ctx, setupContext) - } else { - r0 = ret.Error(0) - } - - return r0 -} diff --git a/pkg/controller/nodes/handler/transition_info_test.go b/pkg/controller/nodes/handler/transition_info_test.go index 579d16cb5..9f85f0628 100644 --- a/pkg/controller/nodes/handler/transition_info_test.go +++ b/pkg/controller/nodes/handler/transition_info_test.go @@ -4,9 +4,10 @@ import ( "testing" "github.com/flyteorg/flyteidl/clients/go/coreutils" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/golang/protobuf/proto" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/stretchr/testify/assert" ) diff --git a/pkg/controller/nodes/handler/transition_test.go b/pkg/controller/nodes/handler/transition_test.go index 32f79d2dc..61236531f 100644 --- a/pkg/controller/nodes/handler/transition_test.go +++ b/pkg/controller/nodes/handler/transition_test.go @@ -6,6 +6,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/storage" + "github.com/stretchr/testify/assert" ) diff --git a/pkg/controller/nodes/handler_factory.go b/pkg/controller/nodes/handler_factory.go deleted file mode 100644 index b3ed8b95e..000000000 --- a/pkg/controller/nodes/handler_factory.go +++ /dev/null @@ -1,86 +0,0 @@ -package nodes - -import ( - "context" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" - - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - "github.com/flyteorg/flytepropeller/pkg/controller/executors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/array" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/branch" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/dynamic" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/end" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/start" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task" - - "github.com/flyteorg/flytestdlib/promutils" - - "github.com/pkg/errors" -) - -//go:generate mockery -name HandlerFactory -case=underscore - -type HandlerFactory interface { - GetHandler(kind v1alpha1.NodeKind) (handler.Node, error) - Setup(ctx context.Context, setup handler.SetupContext) error -} - -type handlerFactory struct { - handlers map[v1alpha1.NodeKind]handler.Node -} - -func (f handlerFactory) GetHandler(kind v1alpha1.NodeKind) (handler.Node, error) { - h, ok := f.handlers[kind] - if !ok { - return nil, errors.Errorf("Handler not registered for NodeKind [%v]", kind) - } - return h, nil -} - -func (f handlerFactory) Setup(ctx context.Context, setup handler.SetupContext) error { - for _, v := range f.handlers { - if err := 
v.Setup(ctx, setup); err != nil { - return err - } - } - return nil -} - -func NewHandlerFactory(ctx context.Context, executor interfaces.Node, workflowLauncher launchplan.Executor, - launchPlanReader launchplan.Reader, kubeClient executors.Client, client catalog.Client, recoveryClient recovery.Client, - eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, scope promutils.Scope) (HandlerFactory, error) { - - t, err := task.New(ctx, kubeClient, client, eventConfig, clusterID, scope) - if err != nil { - return nil, err - } - - arrayHandler, err := array.New(executor, eventConfig, scope) - if err != nil { - return nil, err - } - - f := &handlerFactory{ - handlers: map[v1alpha1.NodeKind]handler.Node{ - v1alpha1.NodeKindBranch: branch.New(executor, eventConfig, scope), - v1alpha1.NodeKindTask: dynamic.New(t, executor, launchPlanReader, eventConfig, scope), - v1alpha1.NodeKindWorkflow: subworkflow.New(executor, workflowLauncher, recoveryClient, eventConfig, scope), - v1alpha1.NodeKindGate: gate.New(eventConfig, signalClient, scope), - v1alpha1.NodeKindArray: arrayHandler, - v1alpha1.NodeKindStart: start.New(), - v1alpha1.NodeKindEnd: end.New(), - }, - } - - return f, nil -} diff --git a/pkg/controller/nodes/handler/iface.go b/pkg/controller/nodes/interfaces/handler.go similarity index 64% rename from pkg/controller/nodes/handler/iface.go rename to pkg/controller/nodes/interfaces/handler.go index d85caef19..d2fd411cf 100644 --- a/pkg/controller/nodes/handler/iface.go +++ b/pkg/controller/nodes/interfaces/handler.go @@ -1,10 +1,10 @@ -package handler +package interfaces import ( "context" "github.com/flyteorg/flytepropeller/pkg/controller/executors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flytestdlib/promutils" ) @@ -13,13 +13,13 @@ import ( // TODO @hamersaw - docs?!?1 type NodeExecutor interface { // TODO @hamersaw - BuildNodeExecutionContext should be here - removes need for another interface - HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h Node) (interfaces.NodeStatus, error) - Abort(ctx context.Context, h Node, nCtx interfaces.NodeExecutionContext, reason string) error - Finalize(ctx context.Context, h Node, nCtx interfaces.NodeExecutionContext) error + HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx NodeExecutionContext, h NodeHandler) (NodeStatus, error) + Abort(ctx context.Context, h NodeHandler, nCtx NodeExecutionContext, reason string) error + Finalize(ctx context.Context, h NodeHandler, nCtx NodeExecutionContext) error } // Interface that should be implemented for a node type. -type Node interface { +type NodeHandler interface { // Method to indicate that finalize is required for this handler FinalizeRequired() bool @@ -27,14 +27,14 @@ type Node interface { Setup(ctx context.Context, setupContext SetupContext) error // Core method that should handle this node - Handle(ctx context.Context, executionContext interfaces.NodeExecutionContext) (Transition, error) + Handle(ctx context.Context, executionContext NodeExecutionContext) (handler.Transition, error) // This method should be invoked to indicate the node needs to be aborted. 
- Abort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) error + Abort(ctx context.Context, executionContext NodeExecutionContext, reason string) error // This method is always called before completing the node, if FinalizeRequired returns true. // It is guaranteed that Handle -> (happens before) -> Finalize. Abort -> finalize may be repeated multiple times - Finalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) error + Finalize(ctx context.Context, executionContext NodeExecutionContext) error } type SetupContext interface { diff --git a/pkg/controller/nodes/interfaces/handler_factory.go b/pkg/controller/nodes/interfaces/handler_factory.go new file mode 100644 index 000000000..a323d8bf8 --- /dev/null +++ b/pkg/controller/nodes/interfaces/handler_factory.go @@ -0,0 +1,14 @@ +package interfaces + +import ( + "context" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +//go:generate mockery -name HandlerFactory -case=underscore + +type HandlerFactory interface { + GetHandler(kind v1alpha1.NodeKind) (NodeHandler, error) + Setup(ctx context.Context, executor Node, setup SetupContext) error +} diff --git a/pkg/controller/nodes/mocks/handler_factory.go b/pkg/controller/nodes/interfaces/mocks/handler_factory.go similarity index 61% rename from pkg/controller/nodes/mocks/handler_factory.go rename to pkg/controller/nodes/interfaces/mocks/handler_factory.go index fffa3c818..ca851495f 100644 --- a/pkg/controller/nodes/mocks/handler_factory.go +++ b/pkg/controller/nodes/interfaces/mocks/handler_factory.go @@ -5,7 +5,7 @@ package mocks import ( context "context" - handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" mock "github.com/stretchr/testify/mock" v1alpha1 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" @@ -20,7 +20,7 @@ type HandlerFactory_GetHandler struct { *mock.Call } -func (_m HandlerFactory_GetHandler) Return(_a0 handler.Node, _a1 error) *HandlerFactory_GetHandler { +func (_m HandlerFactory_GetHandler) Return(_a0 interfaces.NodeHandler, _a1 error) *HandlerFactory_GetHandler { return &HandlerFactory_GetHandler{Call: _m.Call.Return(_a0, _a1)} } @@ -35,15 +35,15 @@ func (_m *HandlerFactory) OnGetHandlerMatch(matchers ...interface{}) *HandlerFac } // GetHandler provides a mock function with given fields: kind -func (_m *HandlerFactory) GetHandler(kind v1alpha1.NodeKind) (handler.Node, error) { +func (_m *HandlerFactory) GetHandler(kind v1alpha1.NodeKind) (interfaces.NodeHandler, error) { ret := _m.Called(kind) - var r0 handler.Node - if rf, ok := ret.Get(0).(func(v1alpha1.NodeKind) handler.Node); ok { + var r0 interfaces.NodeHandler + if rf, ok := ret.Get(0).(func(v1alpha1.NodeKind) interfaces.NodeHandler); ok { r0 = rf(kind) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(handler.Node) + r0 = ret.Get(0).(interfaces.NodeHandler) } } @@ -65,8 +65,8 @@ func (_m HandlerFactory_Setup) Return(_a0 error) *HandlerFactory_Setup { return &HandlerFactory_Setup{Call: _m.Call.Return(_a0)} } -func (_m *HandlerFactory) OnSetup(ctx context.Context, setup handler.SetupContext) *HandlerFactory_Setup { - c_call := _m.On("Setup", ctx, setup) +func (_m *HandlerFactory) OnSetup(ctx context.Context, executor interfaces.Node, setup interfaces.SetupContext) *HandlerFactory_Setup { + c_call := _m.On("Setup", ctx, executor, setup) return &HandlerFactory_Setup{Call: c_call} } @@ -75,13 +75,13 @@ func (_m 
*HandlerFactory) OnSetupMatch(matchers ...interface{}) *HandlerFactory_ return &HandlerFactory_Setup{Call: c_call} } -// Setup provides a mock function with given fields: ctx, setup -func (_m *HandlerFactory) Setup(ctx context.Context, setup handler.SetupContext) error { - ret := _m.Called(ctx, setup) +// Setup provides a mock function with given fields: ctx, executor, setup +func (_m *HandlerFactory) Setup(ctx context.Context, executor interfaces.Node, setup interfaces.SetupContext) error { + ret := _m.Called(ctx, executor, setup) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.SetupContext) error); ok { - r0 = rf(ctx, setup) + if rf, ok := ret.Get(0).(func(context.Context, interfaces.Node, interfaces.SetupContext) error); ok { + r0 = rf(ctx, executor, setup) } else { r0 = ret.Error(0) } diff --git a/pkg/controller/nodes/handler/mocks/node_executor.go b/pkg/controller/nodes/interfaces/mocks/node_executor.go similarity index 69% rename from pkg/controller/nodes/handler/mocks/node_executor.go rename to pkg/controller/nodes/interfaces/mocks/node_executor.go index 9aeca0cd3..72e99c906 100644 --- a/pkg/controller/nodes/handler/mocks/node_executor.go +++ b/pkg/controller/nodes/interfaces/mocks/node_executor.go @@ -6,8 +6,6 @@ import ( context "context" executors "github.com/flyteorg/flytepropeller/pkg/controller/executors" - handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" mock "github.com/stretchr/testify/mock" @@ -26,7 +24,7 @@ func (_m NodeExecutor_Abort) Return(_a0 error) *NodeExecutor_Abort { return &NodeExecutor_Abort{Call: _m.Call.Return(_a0)} } -func (_m *NodeExecutor) OnAbort(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext, reason string) *NodeExecutor_Abort { +func (_m *NodeExecutor) OnAbort(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, reason string) *NodeExecutor_Abort { c_call := _m.On("Abort", ctx, h, nCtx, reason) return &NodeExecutor_Abort{Call: c_call} } @@ -37,11 +35,11 @@ func (_m *NodeExecutor) OnAbortMatch(matchers ...interface{}) *NodeExecutor_Abor } // Abort provides a mock function with given fields: ctx, h, nCtx, reason -func (_m *NodeExecutor) Abort(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext, reason string) error { +func (_m *NodeExecutor) Abort(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, reason string) error { ret := _m.Called(ctx, h, nCtx, reason) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.Node, interfaces.NodeExecutionContext, string) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeHandler, interfaces.NodeExecutionContext, string) error); ok { r0 = rf(ctx, h, nCtx, reason) } else { r0 = ret.Error(0) @@ -58,7 +56,7 @@ func (_m NodeExecutor_Finalize) Return(_a0 error) *NodeExecutor_Finalize { return &NodeExecutor_Finalize{Call: _m.Call.Return(_a0)} } -func (_m *NodeExecutor) OnFinalize(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext) *NodeExecutor_Finalize { +func (_m *NodeExecutor) OnFinalize(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext) *NodeExecutor_Finalize { c_call := _m.On("Finalize", ctx, h, nCtx) return &NodeExecutor_Finalize{Call: c_call} } @@ -69,11 +67,11 @@ func (_m *NodeExecutor) OnFinalizeMatch(matchers ...interface{}) *NodeExecutor_F } // Finalize provides a mock function 
with given fields: ctx, h, nCtx -func (_m *NodeExecutor) Finalize(ctx context.Context, h handler.Node, nCtx interfaces.NodeExecutionContext) error { +func (_m *NodeExecutor) Finalize(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext) error { ret := _m.Called(ctx, h, nCtx) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, handler.Node, interfaces.NodeExecutionContext) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeHandler, interfaces.NodeExecutionContext) error); ok { r0 = rf(ctx, h, nCtx) } else { r0 = ret.Error(0) @@ -90,7 +88,7 @@ func (_m NodeExecutor_HandleNode) Return(_a0 interfaces.NodeStatus, _a1 error) * return &NodeExecutor_HandleNode{Call: _m.Call.Return(_a0, _a1)} } -func (_m *NodeExecutor) OnHandleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h handler.Node) *NodeExecutor_HandleNode { +func (_m *NodeExecutor) OnHandleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h interfaces.NodeHandler) *NodeExecutor_HandleNode { c_call := _m.On("HandleNode", ctx, dag, nCtx, h) return &NodeExecutor_HandleNode{Call: c_call} } @@ -101,18 +99,18 @@ func (_m *NodeExecutor) OnHandleNodeMatch(matchers ...interface{}) *NodeExecutor } // HandleNode provides a mock function with given fields: ctx, dag, nCtx, h -func (_m *NodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h handler.Node) (interfaces.NodeStatus, error) { +func (_m *NodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h interfaces.NodeHandler) (interfaces.NodeStatus, error) { ret := _m.Called(ctx, dag, nCtx, h) var r0 interfaces.NodeStatus - if rf, ok := ret.Get(0).(func(context.Context, executors.DAGStructure, interfaces.NodeExecutionContext, handler.Node) interfaces.NodeStatus); ok { + if rf, ok := ret.Get(0).(func(context.Context, executors.DAGStructure, interfaces.NodeExecutionContext, interfaces.NodeHandler) interfaces.NodeStatus); ok { r0 = rf(ctx, dag, nCtx, h) } else { r0 = ret.Get(0).(interfaces.NodeStatus) } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, executors.DAGStructure, interfaces.NodeExecutionContext, handler.Node) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, executors.DAGStructure, interfaces.NodeExecutionContext, interfaces.NodeHandler) error); ok { r1 = rf(ctx, dag, nCtx, h) } else { r1 = ret.Error(1) diff --git a/pkg/controller/nodes/interfaces/mocks/node_handler.go b/pkg/controller/nodes/interfaces/mocks/node_handler.go new file mode 100644 index 000000000..66bd61d27 --- /dev/null +++ b/pkg/controller/nodes/interfaces/mocks/node_handler.go @@ -0,0 +1,184 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. 
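+//
+// A minimal usage sketch (illustrative only, mirroring the patterns the handler
+// tests in this series use; the generated API follows below):
+//
+//	nodeHandler := &mocks.NodeHandler{}
+//	nodeHandler.OnFinalizeRequired().Return(false)
+//	nodeHandler.OnHandleMatch(mock.Anything, mock.Anything).Return(
+//		handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), nil)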
+ +package mocks + +import ( + context "context" + + handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + + mock "github.com/stretchr/testify/mock" +) + +// NodeHandler is an autogenerated mock type for the NodeHandler type +type NodeHandler struct { + mock.Mock +} + +type NodeHandler_Abort struct { + *mock.Call +} + +func (_m NodeHandler_Abort) Return(_a0 error) *NodeHandler_Abort { + return &NodeHandler_Abort{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeHandler) OnAbort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) *NodeHandler_Abort { + c_call := _m.On("Abort", ctx, executionContext, reason) + return &NodeHandler_Abort{Call: c_call} +} + +func (_m *NodeHandler) OnAbortMatch(matchers ...interface{}) *NodeHandler_Abort { + c_call := _m.On("Abort", matchers...) + return &NodeHandler_Abort{Call: c_call} +} + +// Abort provides a mock function with given fields: ctx, executionContext, reason +func (_m *NodeHandler) Abort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) error { + ret := _m.Called(ctx, executionContext, reason) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext, string) error); ok { + r0 = rf(ctx, executionContext, reason) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type NodeHandler_Finalize struct { + *mock.Call +} + +func (_m NodeHandler_Finalize) Return(_a0 error) *NodeHandler_Finalize { + return &NodeHandler_Finalize{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeHandler) OnFinalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) *NodeHandler_Finalize { + c_call := _m.On("Finalize", ctx, executionContext) + return &NodeHandler_Finalize{Call: c_call} +} + +func (_m *NodeHandler) OnFinalizeMatch(matchers ...interface{}) *NodeHandler_Finalize { + c_call := _m.On("Finalize", matchers...) + return &NodeHandler_Finalize{Call: c_call} +} + +// Finalize provides a mock function with given fields: ctx, executionContext +func (_m *NodeHandler) Finalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) error { + ret := _m.Called(ctx, executionContext) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) error); ok { + r0 = rf(ctx, executionContext) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type NodeHandler_FinalizeRequired struct { + *mock.Call +} + +func (_m NodeHandler_FinalizeRequired) Return(_a0 bool) *NodeHandler_FinalizeRequired { + return &NodeHandler_FinalizeRequired{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeHandler) OnFinalizeRequired() *NodeHandler_FinalizeRequired { + c_call := _m.On("FinalizeRequired") + return &NodeHandler_FinalizeRequired{Call: c_call} +} + +func (_m *NodeHandler) OnFinalizeRequiredMatch(matchers ...interface{}) *NodeHandler_FinalizeRequired { + c_call := _m.On("FinalizeRequired", matchers...) 
+ return &NodeHandler_FinalizeRequired{Call: c_call} +} + +// FinalizeRequired provides a mock function with given fields: +func (_m *NodeHandler) FinalizeRequired() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type NodeHandler_Handle struct { + *mock.Call +} + +func (_m NodeHandler_Handle) Return(_a0 handler.Transition, _a1 error) *NodeHandler_Handle { + return &NodeHandler_Handle{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *NodeHandler) OnHandle(ctx context.Context, executionContext interfaces.NodeExecutionContext) *NodeHandler_Handle { + c_call := _m.On("Handle", ctx, executionContext) + return &NodeHandler_Handle{Call: c_call} +} + +func (_m *NodeHandler) OnHandleMatch(matchers ...interface{}) *NodeHandler_Handle { + c_call := _m.On("Handle", matchers...) + return &NodeHandler_Handle{Call: c_call} +} + +// Handle provides a mock function with given fields: ctx, executionContext +func (_m *NodeHandler) Handle(ctx context.Context, executionContext interfaces.NodeExecutionContext) (handler.Transition, error) { + ret := _m.Called(ctx, executionContext) + + var r0 handler.Transition + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) handler.Transition); ok { + r0 = rf(ctx, executionContext) + } else { + r0 = ret.Get(0).(handler.Transition) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, interfaces.NodeExecutionContext) error); ok { + r1 = rf(ctx, executionContext) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type NodeHandler_Setup struct { + *mock.Call +} + +func (_m NodeHandler_Setup) Return(_a0 error) *NodeHandler_Setup { + return &NodeHandler_Setup{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeHandler) OnSetup(ctx context.Context, setupContext interfaces.SetupContext) *NodeHandler_Setup { + c_call := _m.On("Setup", ctx, setupContext) + return &NodeHandler_Setup{Call: c_call} +} + +func (_m *NodeHandler) OnSetupMatch(matchers ...interface{}) *NodeHandler_Setup { + c_call := _m.On("Setup", matchers...) 
+ return &NodeHandler_Setup{Call: c_call} +} + +// Setup provides a mock function with given fields: ctx, setupContext +func (_m *NodeHandler) Setup(ctx context.Context, setupContext interfaces.SetupContext) error { + ret := _m.Called(ctx, setupContext) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interfaces.SetupContext) error); ok { + r0 = rf(ctx, setupContext) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/controller/nodes/handler/mocks/setup_context.go b/pkg/controller/nodes/interfaces/mocks/setup_context.go similarity index 100% rename from pkg/controller/nodes/handler/mocks/setup_context.go rename to pkg/controller/nodes/interfaces/mocks/setup_context.go diff --git a/pkg/controller/nodes/setup_context.go b/pkg/controller/nodes/setup_context.go index c940447a7..d39863763 100644 --- a/pkg/controller/nodes/setup_context.go +++ b/pkg/controller/nodes/setup_context.go @@ -6,7 +6,7 @@ import ( "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" ) type setupContext struct { @@ -26,7 +26,7 @@ func (s *setupContext) MetricsScope() promutils.Scope { return s.scope } -func (c *recursiveNodeExecutor) newSetupContext(_ context.Context) handler.SetupContext { +func (c *recursiveNodeExecutor) newSetupContext(_ context.Context) interfaces.SetupContext { return &setupContext{ enq: c.enqueueWorkflow, scope: c.metrics.Scope, diff --git a/pkg/controller/nodes/start/handler.go b/pkg/controller/nodes/start/handler.go index 1fecdba25..a8535b8fd 100644 --- a/pkg/controller/nodes/start/handler.go +++ b/pkg/controller/nodes/start/handler.go @@ -14,7 +14,7 @@ func (s startHandler) FinalizeRequired() bool { return false } -func (s startHandler) Setup(ctx context.Context, setupContext handler.SetupContext) error { +func (s startHandler) Setup(ctx context.Context, setupContext interfaces.SetupContext) error { return nil } @@ -30,6 +30,6 @@ func (s startHandler) Finalize(ctx context.Context, executionContext interfaces. 
return nil } -func New() handler.Node { +func New() interfaces.NodeHandler { return &startHandler{} } diff --git a/pkg/controller/nodes/subworkflow/handler.go b/pkg/controller/nodes/subworkflow/handler.go index 7ad2b3dc7..4787509e8 100644 --- a/pkg/controller/nodes/subworkflow/handler.go +++ b/pkg/controller/nodes/subworkflow/handler.go @@ -40,7 +40,7 @@ func (w *workflowNodeHandler) FinalizeRequired() bool { return false } -func (w *workflowNodeHandler) Setup(_ context.Context, _ handler.SetupContext) error { +func (w *workflowNodeHandler) Setup(_ context.Context, _ interfaces.SetupContext) error { return nil } @@ -129,7 +129,7 @@ func (w *workflowNodeHandler) Finalize(ctx context.Context, _ interfaces.NodeExe return nil } -func New(executor interfaces.Node, workflowLauncher launchplan.Executor, recoveryClient recovery.Client, eventConfig *config.EventConfig, scope promutils.Scope) handler.Node { +func New(executor interfaces.Node, workflowLauncher launchplan.Executor, recoveryClient recovery.Client, eventConfig *config.EventConfig, scope promutils.Scope) interfaces.NodeHandler { workflowScope := scope.NewSubScope("workflow") return &workflowNodeHandler{ subWfHandler: newSubworkflowHandler(executor, eventConfig), diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go index 3b778291f..04fb00721 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ -227,7 +227,7 @@ func (t *Handler) setDefault(ctx context.Context, p pluginCore.Plugin) error { return nil } -func (t *Handler) Setup(ctx context.Context, sCtx handler.SetupContext) error { +func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error { tSCtx := t.newSetupContext(sCtx) // Create a new base resource negotiator diff --git a/pkg/controller/nodes/task/setup_ctx.go b/pkg/controller/nodes/task/setup_ctx.go index c788ffd4e..4277275a2 100644 --- a/pkg/controller/nodes/task/setup_ctx.go +++ b/pkg/controller/nodes/task/setup_ctx.go @@ -2,14 +2,16 @@ package task import ( pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytestdlib/promutils" - "k8s.io/apimachinery/pkg/types" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "k8s.io/apimachinery/pkg/types" ) type setupContext struct { - handler.SetupContext + interfaces.SetupContext kubeClient pluginCore.KubeClient secretManager pluginCore.SecretManager } @@ -29,7 +31,7 @@ func (s setupContext) EnqueueOwner() pluginCore.EnqueueOwner { } } -func (t *Handler) newSetupContext(sCtx handler.SetupContext) *setupContext { +func (t *Handler) newSetupContext(sCtx interfaces.SetupContext) *setupContext { return &setupContext{ SetupContext: sCtx, From b871b23a8a6a904d1754396c56e02f8dd3e74704 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Wed, 14 Jun 2023 09:38:07 -0500 Subject: [PATCH 37/62] first unit test for handling ArrayNodePhaseNone Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler_test.go | 146 ++++++++++++++++++++- 1 file changed, 143 insertions(+), 3 deletions(-) diff --git a/pkg/controller/nodes/array/handler_test.go b/pkg/controller/nodes/array/handler_test.go index bdfb7f5a1..49144afe3 100644 --- a/pkg/controller/nodes/array/handler_test.go +++ b/pkg/controller/nodes/array/handler_test.go @@ -4,23 +4,30 @@ import ( "context" "testing" + idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + eventmocks 
"github.com/flyteorg/flytepropeller/events/mocks" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - execmocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/config" + execmocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes" gatemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" recoverymocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" + pluginmocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" "github.com/flyteorg/flytestdlib/storage" "github.com/stretchr/testify/assert" - //"github.com/stretchr/testify/mock" + "github.com/stretchr/testify/mock" ) func createArrayNodeHandler(t *testing.T, ctx context.Context, scope promutils.Scope) (interfaces.NodeHandler, error) { @@ -48,6 +55,62 @@ func createArrayNodeHandler(t *testing.T, ctx context.Context, scope promutils.S return New(nodeExecutor, eventConfig, scope) } +func createNodeExecutionContext(t *testing.T, ctx context.Context, inputLiteralMap *idlcore.LiteralMap) (interfaces.NodeExecutionContext, error) { + nCtx := &mocks.NodeExecutionContext{} + + // EventsRecorder + eventRecorder := &mocks.EventRecorder{} + eventRecorder.OnRecordTaskEventMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) // TODO @hamersaw - should probably capture to validate + nCtx.OnEventsRecorder().Return(eventRecorder) + + // InputReader + nCtx.OnInputReader().Return( + newStaticInputReader( + &pluginmocks.InputFilePaths{}, + inputLiteralMap, + )) + + // Node + taskRef := "arrayNodeTaskID" + nCtx.OnNode().Return(&v1alpha1.NodeSpec{ + ID: "foo", + ArrayNode: &v1alpha1.ArrayNodeSpec{ + SubNodeSpec: &v1alpha1.NodeSpec{ + TaskRef: &taskRef, + }, + }, + }) + + // NodeExecutionMetadata + nodeExecutionMetadata := &mocks.NodeExecutionMetadata{} + nodeExecutionMetadata.OnGetNodeExecutionID().Return(&idlcore.NodeExecutionIdentifier{ + NodeId: "foo", + ExecutionId: &idlcore.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + }) + nCtx.OnNodeExecutionMetadata().Return(nodeExecutionMetadata) + + // NodeID + nCtx.OnNodeID().Return("foo") + + // NodeStateReader + nodeStateReader := &mocks.NodeStateReader{} + nodeStateReader.OnGetArrayNodeState().Return(interfaces.ArrayNodeState{ + Phase: v1alpha1.ArrayNodePhaseNone, + }) + nCtx.OnNodeStateReader().Return(nodeStateReader) + + // NodeStateWriter + nodeStateWriter := &mocks.NodeStateWriter{} + nodeStateWriter.OnPutArrayNodeStateMatch(mock.Anything, mock.Anything).Return(nil) // TODO @hamersaw - should probably capture to validate + nCtx.OnNodeStateWriter().Return(nodeStateWriter) + + return nCtx, nil +} + func TestAbort(t *testing.T) { // TODO @hamersaw - complete } @@ -64,7 +127,44 @@ func TestHandleArrayNodePhaseNone(t *testing.T) { arrayNodeHandler, err := createArrayNodeHandler(t, ctx, scope) assert.NoError(t, err) - // TODO 
@hamersaw - complete + tests := []struct { + name string + inputValues map[string][]int64 + expectedTransitionPhase handler.EPhase + }{ + { + name: "Success", + inputValues: map[string][]int64{ + "foo": []int64{1, 2}, + }, + expectedTransitionPhase: handler.EPhaseRunning, + }, + { + name: "FailureDifferentInputListLengths", + inputValues: map[string][]int64{ + "foo": []int64{1, 2}, + "bar": []int64{3}, + }, + expectedTransitionPhase: handler.EPhaseFailed, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // create NodeExecutionContext + literalMap := convertMapToArrayLiterals(test.inputValues) + nCtx, err := createNodeExecutionContext(t, ctx, literalMap) + assert.NoError(t, err) + + // evaluate node + transition, err := arrayNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + + // validate results + assert.Equal(t, test.expectedTransitionPhase, transition.Info().GetPhase()) + // TODO @hamersaw - validate TaskExecutionEvent and ArrayNodeState + }) + } } func TestHandleArrayNodePhaseExecuting(t *testing.T) { @@ -78,3 +178,43 @@ func TestHandleArrayNodePhaseSucceeding(t *testing.T) { func TestHandleArrayNodePhaseFailing(t *testing.T) { // TODO @hamersaw - complete } + +func init() { + labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) +} + +func convertMapToArrayLiterals(values map[string][]int64) *idlcore.LiteralMap { + literalMap := make(map[string]*idlcore.Literal) + for k, v := range values { + // create LiteralCollection + literalList := make([]*idlcore.Literal, len(v)) + for _, x := range v { + literalList = append(literalList, &idlcore.Literal{ + Value: &idlcore.Literal_Scalar{ + Scalar: &idlcore.Scalar{ + Value: &idlcore.Scalar_Primitive{ + Primitive: &idlcore.Primitive{ + Value: &idlcore.Primitive_Integer{ + Integer: x, + }, + }, + }, + }, + }, + }) + } + + // add LiteralCollection to map + literalMap[k] = &idlcore.Literal{ + Value: &idlcore.Literal_Collection{ + Collection: &idlcore.LiteralCollection{ + Literals: literalList, + }, + }, + } + } + + return &idlcore.LiteralMap{ + Literals: literalMap, + } +} From 49f4d32859228787ff2bdf1239a1ad92f5c7aee7 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Wed, 14 Jun 2023 16:03:05 -0500 Subject: [PATCH 38/62] most of executing unit tests completed Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 7 +- pkg/controller/nodes/array/handler_test.go | 279 +++++++++++++++++++-- 2 files changed, 259 insertions(+), 27 deletions(-) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 2b8de07e9..24ebc04d9 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -284,11 +284,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu retryAttempt := subNodeStatus.GetAttempts() for _, taskExecutionEvent := range arrayEventRecorder.TaskEvents() { - taskPhase := idlcore.TaskExecution_UNDEFINED - if taskNodeStatus := subNodeStatus.GetTaskNodeStatus(); taskNodeStatus != nil { - taskPhase = task.ToTaskEventPhase(core.Phase(taskNodeStatus.GetPhase())) - } - for _, log := range taskExecutionEvent.Logs { log.Name = fmt.Sprintf("%s-%d", log.Name, i) // TODO @hamersaw - do we need to add retryAttempt to log name? 
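+			// a retry-aware variant (a sketch, only if the TODO above is adopted) could reuse
+			// the retryAttempt captured above:
+			//   log.Name = fmt.Sprintf("%s-%d-%d", log.Name, i, retryAttempt)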
} @@ -298,7 +293,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu Index: uint32(i), Logs: taskExecutionEvent.Logs, RetryAttempt: retryAttempt, - Phase: taskPhase, + Phase: taskExecutionEvent.Phase, CacheStatus: cacheStatus, }) } diff --git a/pkg/controller/nodes/array/handler_test.go b/pkg/controller/nodes/array/handler_test.go index 49144afe3..6492cef24 100644 --- a/pkg/controller/nodes/array/handler_test.go +++ b/pkg/controller/nodes/array/handler_test.go @@ -2,9 +2,11 @@ package array import ( "context" + "fmt" "testing" idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" eventmocks "github.com/flyteorg/flytepropeller/events/mocks" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" @@ -19,8 +21,10 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" pluginmocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/promutils/labeled" @@ -30,17 +34,14 @@ import ( "github.com/stretchr/testify/mock" ) -func createArrayNodeHandler(t *testing.T, ctx context.Context, scope promutils.Scope) (interfaces.NodeHandler, error) { +func createArrayNodeHandler(t *testing.T, ctx context.Context, nodeHandler interfaces.NodeHandler, dataStore *storage.DataStore, scope promutils.Scope) (interfaces.NodeHandler, error) { // mock components adminClient := launchplan.NewFailFastLaunchPlanExecutor() - dataStore, err := storage.NewDataStore(&storage.Config{ - Type: storage.TypeMemory, - }, scope) - assert.NoError(t, err) enqueueWorkflowFunc := func(workflowID v1alpha1.WorkflowID) {} eventConfig := &config.EventConfig{} mockEventSink := eventmocks.NewMockEventSink() mockHandlerFactory := &mocks.HandlerFactory{} + mockHandlerFactory.OnGetHandlerMatch(mock.Anything).Return(nodeHandler, nil) mockKubeClient := execmocks.NewFakeKubeClient() mockRecoveryClient := &recoverymocks.Client{} mockSignalClient := &gatemocks.SignalServiceClient{} @@ -55,18 +56,45 @@ func createArrayNodeHandler(t *testing.T, ctx context.Context, scope promutils.S return New(nodeExecutor, eventConfig, scope) } -func createNodeExecutionContext(t *testing.T, ctx context.Context, inputLiteralMap *idlcore.LiteralMap) (interfaces.NodeExecutionContext, error) { +func createNodeExecutionContext(t *testing.T, ctx context.Context, dataStore *storage.DataStore, eventRecorder interfaces.EventRecorder, inputLiteralMap *idlcore.LiteralMap, arrayNodeState *interfaces.ArrayNodeState) (interfaces.NodeExecutionContext, error) { nCtx := &mocks.NodeExecutionContext{} + // ContextualNodeLookup + nodeLookup := &execmocks.NodeLookup{} + nCtx.OnContextualNodeLookup().Return(nodeLookup) + + // DataStore + nCtx.OnDataStore().Return(dataStore) + + // ExecutionContext + executionContext := &execmocks.ExecutionContext{} + executionContext.OnGetEventVersion().Return(1) + executionContext.OnGetExecutionConfig().Return( + v1alpha1.ExecutionConfig{ + }) + executionContext.OnGetExecutionID().Return( + v1alpha1.ExecutionID{ + &idlcore.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + }) + executionContext.OnGetLabels().Return(nil) + 
executionContext.OnGetRawOutputDataConfig().Return(v1alpha1.RawOutputDataConfig{}) + executionContext.OnIsInterruptible().Return(false) + executionContext.OnGetParentInfo().Return(nil) + nCtx.OnExecutionContext().Return(executionContext) + // EventsRecorder - eventRecorder := &mocks.EventRecorder{} - eventRecorder.OnRecordTaskEventMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) // TODO @hamersaw - should probably capture to validate nCtx.OnEventsRecorder().Return(eventRecorder) // InputReader + inputFilePaths := &pluginmocks.InputFilePaths{} + inputFilePaths.OnGetInputPath().Return(storage.DataReference("s3://bucket/input")) nCtx.OnInputReader().Return( newStaticInputReader( - &pluginmocks.InputFilePaths{}, + inputFilePaths, inputLiteralMap, )) @@ -98,16 +126,23 @@ func createNodeExecutionContext(t *testing.T, ctx context.Context, inputLiteralM // NodeStateReader nodeStateReader := &mocks.NodeStateReader{} - nodeStateReader.OnGetArrayNodeState().Return(interfaces.ArrayNodeState{ - Phase: v1alpha1.ArrayNodePhaseNone, - }) + nodeStateReader.OnGetArrayNodeState().Return(*arrayNodeState) nCtx.OnNodeStateReader().Return(nodeStateReader) // NodeStateWriter nodeStateWriter := &mocks.NodeStateWriter{} - nodeStateWriter.OnPutArrayNodeStateMatch(mock.Anything, mock.Anything).Return(nil) // TODO @hamersaw - should probably capture to validate + nodeStateWriter.OnPutArrayNodeStateMatch(mock.Anything, mock.Anything).Run( + func(args mock.Arguments) { + *arrayNodeState = args.Get(0).(interfaces.ArrayNodeState) + }, + ).Return(nil) nCtx.OnNodeStateWriter().Return(nodeStateWriter) + // NodeStatus + nCtx.OnNodeStatus().Return(&v1alpha1.NodeStatus{ + DataDir: storage.DataReference("s3://bucket/foo"), + }) + return nCtx, nil } @@ -122,22 +157,41 @@ func TestFinalize(t *testing.T) { func TestHandleArrayNodePhaseNone(t *testing.T) { ctx := context.Background() scope := promutils.NewTestScope() + dataStore, err := storage.NewDataStore(&storage.Config{ + Type: storage.TypeMemory, + }, scope) + assert.NoError(t, err) + nodeHandler := &mocks.NodeHandler{} // initialize ArrayNodeHandler - arrayNodeHandler, err := createArrayNodeHandler(t, ctx, scope) + arrayNodeHandler, err := createArrayNodeHandler(t, ctx, nodeHandler, dataStore, scope) assert.NoError(t, err) tests := []struct { - name string - inputValues map[string][]int64 - expectedTransitionPhase handler.EPhase + name string + inputValues map[string][]int64 + expectedArrayNodePhase v1alpha1.ArrayNodePhase + expectedTransitionPhase handler.EPhase + expectedExternalResourcePhases []idlcore.TaskExecution_Phase }{ { name: "Success", inputValues: map[string][]int64{ "foo": []int64{1, 2}, }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_QUEUED, idlcore.TaskExecution_QUEUED}, + }, + { + name: "SuccessMultipleInputs", + inputValues: map[string][]int64{ + "foo": []int64{1, 2, 3}, + "bar": []int64{4, 5, 6}, + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_QUEUED, idlcore.TaskExecution_QUEUED, idlcore.TaskExecution_QUEUED}, }, { name: "FailureDifferentInputListLengths", @@ -145,15 +199,21 @@ func TestHandleArrayNodePhaseNone(t *testing.T) { "foo": []int64{1, 2}, "bar": []int64{3}, }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseNone, expectedTransitionPhase: 
handler.EPhaseFailed, + expectedExternalResourcePhases: nil, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { // create NodeExecutionContext + eventRecorder := newArrayEventRecorder() literalMap := convertMapToArrayLiterals(test.inputValues) - nCtx, err := createNodeExecutionContext(t, ctx, literalMap) + arrayNodeState := &interfaces.ArrayNodeState{ + Phase: v1alpha1.ArrayNodePhaseNone, + } + nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, literalMap, arrayNodeState) assert.NoError(t, err) // evaluate node @@ -161,14 +221,191 @@ func TestHandleArrayNodePhaseNone(t *testing.T) { assert.NoError(t, err) // validate results + assert.Equal(t, test.expectedArrayNodePhase, arrayNodeState.Phase) assert.Equal(t, test.expectedTransitionPhase, transition.Info().GetPhase()) - // TODO @hamersaw - validate TaskExecutionEvent and ArrayNodeState + + if len(test.expectedExternalResourcePhases) > 0 { + assert.Equal(t, 1, len(eventRecorder.taskEvents)) + + externalResources := eventRecorder.taskEvents[0].Metadata.GetExternalResources() + assert.Equal(t, len(test.expectedExternalResourcePhases), len(externalResources)) + for i, expectedPhase := range test.expectedExternalResourcePhases { + assert.Equal(t, expectedPhase, externalResources[i].Phase) + } + } else { + assert.Equal(t, 0, len(eventRecorder.taskEvents)) + } }) } } func TestHandleArrayNodePhaseExecuting(t *testing.T) { - // TODO @hamersaw - complete + ctx := context.Background() + + // initailize universal variables + inputMap := map[string][]int64{ + "foo": []int64{0, 1}, + "bar": []int64{2, 3}, + } + literalMap := convertMapToArrayLiterals(inputMap) + + size := -1 + for _, v := range inputMap { + if size == -1 { + size = len(v) + } else if len(v) > size { // calculating size as largest input list + size = len(v) + } + } + + tests := []struct { + name string + subNodePhases []v1alpha1.NodePhase + subNodeTaskPhases []core.Phase + subNodeTransitions []handler.Transition + expectedArrayNodePhase v1alpha1.ArrayNodePhase + expectedTransitionPhase handler.EPhase + expectedExternalResourcePhases []idlcore.TaskExecution_Phase + }{ + { + name: "StartAllSubNodes", + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseQueued, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseUndefined, + core.PhaseUndefined, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING}, + }, + // TODO @hamersaw - concurrency -> only start one node + { + name: "AllSubNodeSuccedeed", + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseRunning, + v1alpha1.NodePhaseRunning, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseRunning, + core.PhaseRunning, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, + expectedTransitionPhase: handler.EPhaseRunning, + 
expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_SUCCEEDED, idlcore.TaskExecution_SUCCEEDED}, + }, + // TODO @hamersaw - recording failure message + { + name: "OneSubNodeFailed", + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseRunning, + v1alpha1.NodePhaseRunning, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseRunning, + core.PhaseRunning, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(0, "", "", &handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseFailing, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_FAILED, idlcore.TaskExecution_SUCCEEDED}, + }, + // TODO @hamersaw - min_successes / min_success_ratio + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + scope := promutils.NewTestScope() + dataStore, err := storage.NewDataStore(&storage.Config{ + Type: storage.TypeMemory, + }, scope) + assert.NoError(t, err) + + // configure SubNodePhases and SubNodeTaskPhases + arrayNodeState := &interfaces.ArrayNodeState{ + Phase: v1alpha1.ArrayNodePhaseExecuting, + } + for _, item := range []struct{arrayReference *bitarray.CompactArray; maxValue int}{ + {arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)}, + {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases)-1}, + {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: 1}, + {arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: 1}, + } { + + *item.arrayReference, err = bitarray.NewCompactArray(uint(size), bitarray.Item(item.maxValue)) + assert.NoError(t, err) + } + + for i, nodePhase := range test.subNodePhases { + arrayNodeState.SubNodePhases.SetItem(i, bitarray.Item(nodePhase)) + } + for i, taskPhase := range test.subNodeTaskPhases { + arrayNodeState.SubNodeTaskPhases.SetItem(i, bitarray.Item(taskPhase)) + } + + // create NodeExecutionContext + eventRecorder := newArrayEventRecorder() + nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, literalMap, arrayNodeState) + assert.NoError(t, err) + + // initialize ArrayNodeHandler + nodeHandler := &mocks.NodeHandler{} + nodeHandler.OnFinalizeRequired().Return(false) + for i, transition := range test.subNodeTransitions { + nodeID := fmt.Sprintf("%s-n%d", nCtx.NodeID(), i) + transitionPhase := test.expectedExternalResourcePhases[i] + + nodeHandler.OnHandleMatch(mock.Anything, mock.MatchedBy(func(arrayNCtx interfaces.NodeExecutionContext) bool { + return arrayNCtx.NodeID() == nodeID // match on NodeID using index to ensure each subNode is handled independently + })).Run( + func(args mock.Arguments) { + // mock sending TaskExecutionEvent from handler to show task state transition + taskExecutionEvent := &event.TaskExecutionEvent{ + Phase: transitionPhase, + } + + args.Get(1).(interfaces.NodeExecutionContext).EventsRecorder().RecordTaskEvent(ctx, taskExecutionEvent, &config.EventConfig{}) + }, + ).Return(transition, nil) + } + + arrayNodeHandler, err := createArrayNodeHandler(t, ctx, nodeHandler, dataStore, scope) + assert.NoError(t, err) + + // evaluate node + transition, err := arrayNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + + // validate results + assert.Equal(t, test.expectedArrayNodePhase, 
arrayNodeState.Phase) + assert.Equal(t, test.expectedTransitionPhase, transition.Info().GetPhase()) + + if len(test.expectedExternalResourcePhases) > 0 { + assert.Equal(t, 1, len(eventRecorder.taskEvents)) + + externalResources := eventRecorder.taskEvents[0].Metadata.GetExternalResources() + assert.Equal(t, len(test.expectedExternalResourcePhases), len(externalResources)) + for i, expectedPhase := range test.expectedExternalResourcePhases { + assert.Equal(t, expectedPhase, externalResources[i].Phase) + } + } else { + assert.Equal(t, 0, len(eventRecorder.taskEvents)) + } + }) + } } func TestHandleArrayNodePhaseSucceeding(t *testing.T) { @@ -187,7 +424,7 @@ func convertMapToArrayLiterals(values map[string][]int64) *idlcore.LiteralMap { literalMap := make(map[string]*idlcore.Literal) for k, v := range values { // create LiteralCollection - literalList := make([]*idlcore.Literal, len(v)) + literalList := make([]*idlcore.Literal, 0, len(v)) for _, x := range v { literalList = append(literalList, &idlcore.Literal{ Value: &idlcore.Literal_Scalar{ From 7767668d004fa82f32cb546d436a91a941d3332c Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Wed, 14 Jun 2023 16:38:16 -0500 Subject: [PATCH 39/62] finished executing unit tests Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler_test.go | 82 +++++++++++++++++----- 1 file changed, 66 insertions(+), 16 deletions(-) diff --git a/pkg/controller/nodes/array/handler_test.go b/pkg/controller/nodes/array/handler_test.go index 6492cef24..fdf1daad0 100644 --- a/pkg/controller/nodes/array/handler_test.go +++ b/pkg/controller/nodes/array/handler_test.go @@ -34,6 +34,18 @@ import ( "github.com/stretchr/testify/mock" ) +var ( + taskRef = "taskRef" + arrayNodeSpec = v1alpha1.NodeSpec{ + ID: "foo", + ArrayNode: &v1alpha1.ArrayNodeSpec{ + SubNodeSpec: &v1alpha1.NodeSpec{ + TaskRef: &taskRef, + }, + }, + } +) + func createArrayNodeHandler(t *testing.T, ctx context.Context, nodeHandler interfaces.NodeHandler, dataStore *storage.DataStore, scope promutils.Scope) (interfaces.NodeHandler, error) { // mock components adminClient := launchplan.NewFailFastLaunchPlanExecutor() @@ -56,7 +68,9 @@ func createArrayNodeHandler(t *testing.T, ctx context.Context, nodeHandler inter return New(nodeExecutor, eventConfig, scope) } -func createNodeExecutionContext(t *testing.T, ctx context.Context, dataStore *storage.DataStore, eventRecorder interfaces.EventRecorder, inputLiteralMap *idlcore.LiteralMap, arrayNodeState *interfaces.ArrayNodeState) (interfaces.NodeExecutionContext, error) { +func createNodeExecutionContext(t *testing.T, ctx context.Context, dataStore *storage.DataStore, eventRecorder interfaces.EventRecorder, + inputLiteralMap *idlcore.LiteralMap, arrayNodeSpec *v1alpha1.NodeSpec, arrayNodeState *interfaces.ArrayNodeState) (interfaces.NodeExecutionContext, error) { + nCtx := &mocks.NodeExecutionContext{} // ContextualNodeLookup @@ -99,15 +113,7 @@ func createNodeExecutionContext(t *testing.T, ctx context.Context, dataStore *st )) // Node - taskRef := "arrayNodeTaskID" - nCtx.OnNode().Return(&v1alpha1.NodeSpec{ - ID: "foo", - ArrayNode: &v1alpha1.ArrayNodeSpec{ - SubNodeSpec: &v1alpha1.NodeSpec{ - TaskRef: &taskRef, - }, - }, - }) + nCtx.OnNode().Return(arrayNodeSpec) // NodeExecutionMetadata nodeExecutionMetadata := &mocks.NodeExecutionMetadata{} @@ -213,7 +219,7 @@ func TestHandleArrayNodePhaseNone(t *testing.T) { arrayNodeState := &interfaces.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseNone, } - nCtx, err := createNodeExecutionContext(t, ctx, dataStore, 
eventRecorder, literalMap, arrayNodeState) + nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, literalMap, &arrayNodeSpec, arrayNodeState) assert.NoError(t, err) // evaluate node @@ -241,6 +247,7 @@ func TestHandleArrayNodePhaseNone(t *testing.T) { func TestHandleArrayNodePhaseExecuting(t *testing.T) { ctx := context.Background() + minSuccessRatio := float32(0.5) // initailize universal variables inputMap := map[string][]int64{ @@ -260,6 +267,8 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { tests := []struct { name string + parallelism int + minSuccessRatio *float32 subNodePhases []v1alpha1.NodePhase subNodeTaskPhases []core.Phase subNodeTransitions []handler.Transition @@ -285,9 +294,27 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING}, }, - // TODO @hamersaw - concurrency -> only start one node { - name: "AllSubNodeSuccedeed", + name: "StartOneSubNodeParallelism", + parallelism: 1, + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseQueued, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseUndefined, + core.PhaseUndefined, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_QUEUED}, + }, + { + name: "AllSubNodesSuccedeed", subNodePhases: []v1alpha1.NodePhase{ v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, @@ -304,7 +331,26 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_SUCCEEDED, idlcore.TaskExecution_SUCCEEDED}, }, - // TODO @hamersaw - recording failure message + { + name: "OneSubNodeSuccedeedMinSuccessRatio", + minSuccessRatio: &minSuccessRatio, + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseRunning, + v1alpha1.NodePhaseRunning, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseRunning, + core.PhaseRunning, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(0, "", "", &handler.ExecutionInfo{})), + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_SUCCEEDED, idlcore.TaskExecution_FAILED}, + }, + // TODO @hamersaw - recording failure message? 
{ name: "OneSubNodeFailed", subNodePhases: []v1alpha1.NodePhase{ @@ -323,7 +369,6 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_FAILED, idlcore.TaskExecution_SUCCEEDED}, }, - // TODO @hamersaw - min_successes / min_success_ratio } for _, test := range tests { @@ -358,7 +403,12 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { // create NodeExecutionContext eventRecorder := newArrayEventRecorder() - nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, literalMap, arrayNodeState) + + nodeSpec := arrayNodeSpec + nodeSpec.ArrayNode.Parallelism = uint32(test.parallelism) + nodeSpec.ArrayNode.MinSuccessRatio = test.minSuccessRatio + + nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, literalMap, &arrayNodeSpec, arrayNodeState) assert.NoError(t, err) // initialize ArrayNodeHandler From 73a26f4943a3a245d3dc8243f8c72ad9c4a7111e Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 15 Jun 2023 08:26:23 -0500 Subject: [PATCH 40/62] finished succeeding unit tests Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler_test.go | 150 +++++++++++++++++++-- 1 file changed, 140 insertions(+), 10 deletions(-) diff --git a/pkg/controller/nodes/array/handler_test.go b/pkg/controller/nodes/array/handler_test.go index fdf1daad0..2c41e5334 100644 --- a/pkg/controller/nodes/array/handler_test.go +++ b/pkg/controller/nodes/array/handler_test.go @@ -69,7 +69,8 @@ func createArrayNodeHandler(t *testing.T, ctx context.Context, nodeHandler inter } func createNodeExecutionContext(t *testing.T, ctx context.Context, dataStore *storage.DataStore, eventRecorder interfaces.EventRecorder, - inputLiteralMap *idlcore.LiteralMap, arrayNodeSpec *v1alpha1.NodeSpec, arrayNodeState *interfaces.ArrayNodeState) (interfaces.NodeExecutionContext, error) { + outputVariables []string, inputLiteralMap *idlcore.LiteralMap, arrayNodeSpec *v1alpha1.NodeSpec, + arrayNodeState *interfaces.ArrayNodeState) (interfaces.NodeExecutionContext, error) { nCtx := &mocks.NodeExecutionContext{} @@ -83,9 +84,7 @@ func createNodeExecutionContext(t *testing.T, ctx context.Context, dataStore *st // ExecutionContext executionContext := &execmocks.ExecutionContext{} executionContext.OnGetEventVersion().Return(1) - executionContext.OnGetExecutionConfig().Return( - v1alpha1.ExecutionConfig{ - }) + executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{}) executionContext.OnGetExecutionID().Return( v1alpha1.ExecutionID{ &idlcore.WorkflowExecutionIdentifier{ @@ -98,6 +97,22 @@ func createNodeExecutionContext(t *testing.T, ctx context.Context, dataStore *st executionContext.OnGetRawOutputDataConfig().Return(v1alpha1.RawOutputDataConfig{}) executionContext.OnIsInterruptible().Return(false) executionContext.OnGetParentInfo().Return(nil) + outputVariableMap := make(map[string]*idlcore.Variable) + for _, outputVariable := range outputVariables { + outputVariableMap[outputVariable] = &idlcore.Variable{} + } + executionContext.OnGetTaskMatch(taskRef).Return( + &v1alpha1.TaskSpec{ + &idlcore.TaskTemplate{ + Interface: &idlcore.TypedInterface{ + Outputs: &idlcore.VariableMap{ + Variables: outputVariableMap, + }, + }, + }, + }, + nil, + ) nCtx.OnExecutionContext().Return(executionContext) // EventsRecorder @@ -146,7 +161,8 @@ func createNodeExecutionContext(t *testing.T, ctx context.Context, dataStore *st // NodeStatus 
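+	// DataDir/OutputDir are fixture references only; the succeeding-phase test below
+	// writes per-subtask outputs under OutputDir and reads the aggregated outputs
+	// file back from the same location after Handle returns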
nCtx.OnNodeStatus().Return(&v1alpha1.NodeStatus{ - DataDir: storage.DataReference("s3://bucket/foo"), + DataDir: storage.DataReference("s3://bucket/data"), + OutputDir: storage.DataReference("s3://bucket/output"), }) return nCtx, nil @@ -219,7 +235,7 @@ func TestHandleArrayNodePhaseNone(t *testing.T) { arrayNodeState := &interfaces.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseNone, } - nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, literalMap, &arrayNodeSpec, arrayNodeState) + nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) assert.NoError(t, err) // evaluate node @@ -350,7 +366,6 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_SUCCEEDED, idlcore.TaskExecution_FAILED}, }, - // TODO @hamersaw - recording failure message? { name: "OneSubNodeFailed", subNodePhases: []v1alpha1.NodePhase{ @@ -379,7 +394,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { }, scope) assert.NoError(t, err) - // configure SubNodePhases and SubNodeTaskPhases + // initialize ArrayNodeState arrayNodeState := &interfaces.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseExecuting, } @@ -408,7 +423,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { nodeSpec.ArrayNode.Parallelism = uint32(test.parallelism) nodeSpec.ArrayNode.MinSuccessRatio = test.minSuccessRatio - nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, literalMap, &arrayNodeSpec, arrayNodeState) + nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) assert.NoError(t, err) // initialize ArrayNodeHandler @@ -459,7 +474,122 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { } func TestHandleArrayNodePhaseSucceeding(t *testing.T) { - // TODO @hamersaw - complete + ctx := context.Background() + scope := promutils.NewTestScope() + dataStore, err := storage.NewDataStore(&storage.Config{ + Type: storage.TypeMemory, + }, scope) + assert.NoError(t, err) + nodeHandler := &mocks.NodeHandler{} + valueOne := 1 + + // initialize ArrayNodeHandler + arrayNodeHandler, err := createArrayNodeHandler(t, ctx, nodeHandler, dataStore, scope) + assert.NoError(t, err) + + tests := []struct { + name string + outputVariable string + outputValues []*int + subNodePhases []v1alpha1.NodePhase + expectedArrayNodePhase v1alpha1.ArrayNodePhase + expectedTransitionPhase handler.EPhase + }{ + { + name: "Success", + outputValues: []*int{&valueOne, nil}, + outputVariable: "foo", + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseSucceeded, + v1alpha1.NodePhaseFailed, + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, + expectedTransitionPhase: handler.EPhaseSuccess, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // initialize ArrayNodeState + subNodePhases, err := bitarray.NewCompactArray(uint(len(test.subNodePhases)), bitarray.Item(v1alpha1.NodePhaseRecovered)) + assert.NoError(t, err) + for i, nodePhase := range test.subNodePhases { + subNodePhases.SetItem(i, bitarray.Item(nodePhase)) + } + + retryAttempts, err := bitarray.NewCompactArray(uint(len(test.subNodePhases)), bitarray.Item(1)) + assert.NoError(t, err) + + arrayNodeState := &interfaces.ArrayNodeState{ + Phase: v1alpha1.ArrayNodePhaseSucceeding, + SubNodePhases: subNodePhases, + SubNodeRetryAttempts: retryAttempts, + } + + // 
create NodeExecutionContext + eventRecorder := newArrayEventRecorder() + literalMap := &idlcore.LiteralMap{} + nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, []string{test.outputVariable}, literalMap, &arrayNodeSpec, arrayNodeState) + assert.NoError(t, err) + + // write mocked output files + for i, outputValue := range test.outputValues { + if outputValue == nil { + continue + } + + outputFile := storage.DataReference(fmt.Sprintf("s3://bucket/output/%d/0/outputs.pb", i)) + outputLiteralMap := &idlcore.LiteralMap{ + Literals: map[string]*idlcore.Literal{ + test.outputVariable: &idlcore.Literal{ + Value: &idlcore.Literal_Scalar{ + Scalar: &idlcore.Scalar{ + Value: &idlcore.Scalar_Primitive{ + Primitive: &idlcore.Primitive{ + Value: &idlcore.Primitive_Integer{ + Integer: int64(*outputValue), + }, + }, + }, + }, + }, + }, + }, + } + + err := nCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputLiteralMap) + assert.NoError(t, err) + } + + // evaluate node + transition, err := arrayNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + + // validate results + assert.Equal(t, test.expectedArrayNodePhase, arrayNodeState.Phase) + assert.Equal(t, test.expectedTransitionPhase, transition.Info().GetPhase()) + + // validate output file + var outputs idlcore.LiteralMap + outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + err = nCtx.DataStore().ReadProtobuf(ctx, outputFile, &outputs) + assert.NoError(t, err) + + assert.Len(t, outputs.GetLiterals(), 1) + + collection := outputs.GetLiterals()[test.outputVariable].GetCollection() + assert.NotNil(t, collection) + + assert.Len(t, collection.GetLiterals(), len(test.outputValues)) + for i, outputValue := range test.outputValues { + if outputValue == nil { + assert.NotNil(t, collection.GetLiterals()[i].GetScalar()) + } else { + assert.Equal(t, int64(*outputValue), collection.GetLiterals()[i].GetScalar().GetPrimitive().GetInteger()) + } + } + }) + } } func TestHandleArrayNodePhaseFailing(t *testing.T) { From 575deea941c52987acecf2a3e643bfed4c77f8ca Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 15 Jun 2023 09:10:46 -0500 Subject: [PATCH 41/62] wrote failing phase unit tests Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler_test.go | 74 +++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/pkg/controller/nodes/array/handler_test.go b/pkg/controller/nodes/array/handler_test.go index 2c41e5334..e2858a9b3 100644 --- a/pkg/controller/nodes/array/handler_test.go +++ b/pkg/controller/nodes/array/handler_test.go @@ -593,7 +593,79 @@ func TestHandleArrayNodePhaseSucceeding(t *testing.T) { } func TestHandleArrayNodePhaseFailing(t *testing.T) { - // TODO @hamersaw - complete + ctx := context.Background() + scope := promutils.NewTestScope() + dataStore, err := storage.NewDataStore(&storage.Config{ + Type: storage.TypeMemory, + }, scope) + assert.NoError(t, err) + + nodeHandler := &mocks.NodeHandler{} + nodeHandler.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + nodeHandler.OnFinalizeMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + + // initialize ArrayNodeHandler + arrayNodeHandler, err := createArrayNodeHandler(t, ctx, nodeHandler, dataStore, scope) + assert.NoError(t, err) + + tests := []struct { + name string + subNodePhases []v1alpha1.NodePhase + expectedArrayNodePhase v1alpha1.ArrayNodePhase + expectedTransitionPhase handler.EPhase + expectedAbortCalls int + }{ + { + name: "Success", + subNodePhases: 
[]v1alpha1.NodePhase{ + v1alpha1.NodePhaseRunning, + v1alpha1.NodePhaseSucceeded, + v1alpha1.NodePhaseFailed, + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseFailing, + expectedTransitionPhase: handler.EPhaseFailed, + expectedAbortCalls: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // initialize ArrayNodeState + arrayNodeState := &interfaces.ArrayNodeState{ + Phase: v1alpha1.ArrayNodePhaseFailing, + } + + for _, item := range []struct{arrayReference *bitarray.CompactArray; maxValue int}{ + {arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)}, + {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases)-1}, + {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: 1}, + {arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: 1}, + } { + + *item.arrayReference, err = bitarray.NewCompactArray(uint(len(test.subNodePhases)), bitarray.Item(item.maxValue)) + assert.NoError(t, err) + } + + for i, nodePhase := range test.subNodePhases { + arrayNodeState.SubNodePhases.SetItem(i, bitarray.Item(nodePhase)) + } + + // create NodeExecutionContext + eventRecorder := newArrayEventRecorder() + literalMap := &idlcore.LiteralMap{} + nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) + assert.NoError(t, err) + + // evaluate node + transition, err := arrayNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + + // validate results + assert.Equal(t, test.expectedArrayNodePhase, arrayNodeState.Phase) + assert.Equal(t, test.expectedTransitionPhase, transition.Info().GetPhase()) + nodeHandler.AssertNumberOfCalls(t, "Abort", test.expectedAbortCalls) + }) + } } func init() { From 6dc6b882d8a04428af49027b62ff3b49400bbbd3 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 15 Jun 2023 16:19:07 -0500 Subject: [PATCH 42/62] moving towards complete unit_test success Signed-off-by: Daniel Rammer --- pkg/controller/nodes/branch/handler_test.go | 50 ++- .../nodes/dynamic/dynamic_workflow_test.go | 18 +- pkg/controller/nodes/dynamic/handler_test.go | 101 ++--- pkg/controller/nodes/dynamic/utils_test.go | 2 +- pkg/controller/nodes/end/handler_test.go | 2 +- pkg/controller/nodes/executor_test.go | 380 +++++++++--------- pkg/controller/nodes/gate/handler_test.go | 5 +- .../nodes/node_exec_context_test.go | 20 +- .../nodes/subworkflow/handler_test.go | 26 +- .../nodes/subworkflow/launchplan_test.go | 2 +- .../nodes/subworkflow/subworkflow_test.go | 8 +- pkg/controller/nodes/task/handler_test.go | 89 ++-- pkg/controller/nodes/task/setup_ctx_test.go | 4 +- .../nodes/task/taskexec_context_test.go | 10 +- pkg/controller/nodes/task/transformer_test.go | 8 +- pkg/controller/workflow/executor_test.go | 11 +- 16 files changed, 389 insertions(+), 347 deletions(-) diff --git a/pkg/controller/nodes/branch/handler_test.go b/pkg/controller/nodes/branch/handler_test.go index 5711de5d4..a2271bc6a 100644 --- a/pkg/controller/nodes/branch/handler_test.go +++ b/pkg/controller/nodes/branch/handler_test.go @@ -26,7 +26,8 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/executors" execMocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + 
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" ) var eventConfig = &config.EventConfig{ @@ -34,27 +35,34 @@ var eventConfig = &config.EventConfig{ } type branchNodeStateHolder struct { - s handler.BranchNodeState + s interfaces.BranchNodeState } -func (t *branchNodeStateHolder) PutTaskNodeState(s handler.TaskNodeState) error { +func (t *branchNodeStateHolder) ClearNodeStatus() { +} + +func (t *branchNodeStateHolder) PutTaskNodeState(s interfaces.TaskNodeState) error { panic("not implemented") } -func (t *branchNodeStateHolder) PutBranchNode(s handler.BranchNodeState) error { +func (t *branchNodeStateHolder) PutBranchNode(s interfaces.BranchNodeState) error { t.s = s return nil } -func (t branchNodeStateHolder) PutWorkflowNodeState(s handler.WorkflowNodeState) error { +func (t branchNodeStateHolder) PutWorkflowNodeState(s interfaces.WorkflowNodeState) error { + panic("not implemented") +} + +func (t branchNodeStateHolder) PutDynamicNodeState(s interfaces.DynamicNodeState) error { panic("not implemented") } -func (t branchNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) error { +func (t branchNodeStateHolder) PutGateNodeState(s interfaces.GateNodeState) error { panic("not implemented") } -func (t branchNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { +func (t branchNodeStateHolder) PutArrayNodeState(s interfaces.ArrayNodeState) error { panic("not implemented") } @@ -71,7 +79,7 @@ func (parentInfo) CurrentAttempt() uint32 { func createNodeContext(phase v1alpha1.BranchNodePhase, childNodeID *v1alpha1.NodeID, n v1alpha1.ExecutableNode, inputs *core.LiteralMap, nl *execMocks.NodeLookup, eCtx executors.ExecutionContext) (*mocks.NodeExecutionContext, *branchNodeStateHolder) { - branchNodeState := handler.BranchNodeState{ + branchNodeState := interfaces.BranchNodeState{ FinalizedNodeID: childNodeID, Phase: phase, } @@ -119,7 +127,7 @@ func createNodeContext(phase v1alpha1.BranchNodePhase, childNodeID *v1alpha1.Nod nCtx.OnEnqueueOwnerFunc().Return(nil) nr := &mocks.NodeStateReader{} - nr.OnGetBranchNode().Return(handler.BranchNodeState{ + nr.OnGetBranchNodeState().Return(interfaces.BranchNodeState{ FinalizedNodeID: childNodeID, Phase: phase, }) @@ -151,7 +159,7 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { tests := []struct { name string - ns executors.NodeStatus + ns interfaces.NodeStatus err error nodeStatus *mocks2.ExecutableNodeStatus branchTakenNode v1alpha1.ExecutableNode @@ -160,17 +168,17 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { childPhase v1alpha1.NodePhase upstreamNodeID string }{ - {"upstreamNodeExists", executors.NodeStatusPending, nil, + {"upstreamNodeExists", interfaces.NodeStatusPending, nil, &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"}, - {"childNodeError", executors.NodeStatusUndefined, fmt.Errorf("err"), + {"childNodeError", interfaces.NodeStatusUndefined, fmt.Errorf("err"), &mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""}, - {"childPending", executors.NodeStatusPending, nil, + {"childPending", interfaces.NodeStatusPending, nil, &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""}, - {"childStillRunning", executors.NodeStatusRunning, nil, + {"childStillRunning", interfaces.NodeStatusRunning, nil, &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""}, - {"childFailure", 
executors.NodeStatusFailed(expectedError), nil, + {"childFailure", interfaces.NodeStatusFailed(expectedError), nil, &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""}, - {"childComplete", executors.NodeStatusComplete, nil, + {"childComplete", interfaces.NodeStatusComplete, nil, &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""}, } for _, test := range tests { @@ -188,7 +196,7 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { nCtx, _ := createNodeContext(v1alpha1.BranchNodeNotYetEvaluated, &childNodeID, n, nil, mockNodeLookup, eCtx) newParentInfo, _ := common.CreateParentInfo(parentInfo{}, nCtx.NodeID(), nCtx.CurrentAttempt()) expectedExecContext := executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo) - mockNodeExecutor := &execMocks.Node{} + mockNodeExecutor := &mocks.Node{} mockNodeExecutor.OnRecursiveNodeHandlerMatch( mock.Anything, // ctx mock.MatchedBy(func(e executors.ExecutionContext) bool { return assert.Equal(t, e, expectedExecContext) }), @@ -295,7 +303,7 @@ func TestBranchHandler_AbortNode(t *testing.T) { assert.NotNil(t, w) t.Run("NoBranchNode", func(t *testing.T) { - mockNodeExecutor := &execMocks.Node{} + mockNodeExecutor := &mocks.Node{} mockNodeExecutor.OnAbortHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("err")) @@ -308,7 +316,7 @@ func TestBranchHandler_AbortNode(t *testing.T) { }) t.Run("BranchNodeSuccess", func(t *testing.T) { - mockNodeExecutor := &execMocks.Node{} + mockNodeExecutor := &mocks.Node{} mockNodeLookup := &execMocks.NodeLookup{} mockNodeLookup.OnToNodeMatch(mock.Anything).Return(nil, nil) eCtx := &execMocks.ExecutionContext{} @@ -329,7 +337,7 @@ func TestBranchHandler_AbortNode(t *testing.T) { func TestBranchHandler_Initialize(t *testing.T) { ctx := context.TODO() - mockNodeExecutor := &execMocks.Node{} + mockNodeExecutor := &mocks.Node{} branch := New(mockNodeExecutor, eventConfig, promutils.NewTestScope()) assert.NoError(t, branch.Setup(ctx, nil)) } @@ -337,7 +345,7 @@ func TestBranchHandler_Initialize(t *testing.T) { // TODO incomplete test suite, add more func TestBranchHandler_HandleNode(t *testing.T) { ctx := context.TODO() - mockNodeExecutor := &execMocks.Node{} + mockNodeExecutor := &mocks.Node{} branch := New(mockNodeExecutor, eventConfig, promutils.NewTestScope()) childNodeID := "child" childDatadir := v1alpha1.DataReference("test") diff --git a/pkg/controller/nodes/dynamic/dynamic_workflow_test.go b/pkg/controller/nodes/dynamic/dynamic_workflow_test.go index 2e539a00f..6750e79e2 100644 --- a/pkg/controller/nodes/dynamic/dynamic_workflow_test.go +++ b/pkg/controller/nodes/dynamic/dynamic_workflow_test.go @@ -24,8 +24,8 @@ import ( mocks2 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" mocks4 "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" mocks6 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/dynamic/mocks" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" mocks5 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" ) @@ -135,7 +135,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t 
w.OnGetExecutionStatus().Return(ws) r := &mocks.NodeStateReader{} - r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{ + r.OnGetDynamicNodeState().Return(interfaces.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseExecuting, }) nCtx.OnNodeStateReader().Return(r) @@ -183,7 +183,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t }, }, nil) h := &mocks6.TaskNodeHandler{} - n := &mocks4.Node{} + n := &mocks.Node{} d := dynamicNodeTaskNodeHandler{ TaskNodeHandler: h, nodeExecutor: n, @@ -255,7 +255,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t }, }, nil) h := &mocks6.TaskNodeHandler{} - n := &mocks4.Node{} + n := &mocks.Node{} d := dynamicNodeTaskNodeHandler{ TaskNodeHandler: h, nodeExecutor: n, @@ -324,7 +324,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t }, }, nil) h := &mocks6.TaskNodeHandler{} - n := &mocks4.Node{} + n := &mocks.Node{} d := dynamicNodeTaskNodeHandler{ TaskNodeHandler: h, nodeExecutor: n, @@ -407,7 +407,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t }, }, nil) h := &mocks6.TaskNodeHandler{} - n := &mocks4.Node{} + n := &mocks.Node{} d := dynamicNodeTaskNodeHandler{ TaskNodeHandler: h, nodeExecutor: n, @@ -461,7 +461,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t mockLPLauncher := &mocks5.Reader{} h := &mocks6.TaskNodeHandler{} - n := &mocks4.Node{} + n := &mocks.Node{} d := dynamicNodeTaskNodeHandler{ TaskNodeHandler: h, nodeExecutor: n, @@ -550,7 +550,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t }, nil) h := &mocks6.TaskNodeHandler{} - n := &mocks4.Node{} + n := &mocks.Node{} d := dynamicNodeTaskNodeHandler{ TaskNodeHandler: h, nodeExecutor: n, diff --git a/pkg/controller/nodes/dynamic/handler_test.go b/pkg/controller/nodes/dynamic/handler_test.go index 27bc2d935..7fdb14dcd 100644 --- a/pkg/controller/nodes/dynamic/handler_test.go +++ b/pkg/controller/nodes/dynamic/handler_test.go @@ -28,35 +28,42 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" flyteMocks "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" - "github.com/flyteorg/flytepropeller/pkg/controller/executors" executorMocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/dynamic/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" ) type dynamicNodeStateHolder struct { - s handler.DynamicNodeState + s interfaces.DynamicNodeState } -func (t *dynamicNodeStateHolder) PutTaskNodeState(s handler.TaskNodeState) error { +func (t *dynamicNodeStateHolder) ClearNodeStatus() { +} + +func (t *dynamicNodeStateHolder) PutTaskNodeState(s interfaces.TaskNodeState) error { panic("not implemented") } -func (t dynamicNodeStateHolder) PutBranchNode(s handler.BranchNodeState) error { +func (t dynamicNodeStateHolder) PutBranchNode(s interfaces.BranchNodeState) error { panic("not implemented") } -func (t dynamicNodeStateHolder) PutWorkflowNodeState(s handler.WorkflowNodeState) error { +func (t dynamicNodeStateHolder) PutWorkflowNodeState(s 
interfaces.WorkflowNodeState) error { panic("not implemented") } -func (t *dynamicNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) error { +func (t *dynamicNodeStateHolder) PutDynamicNodeState(s interfaces.DynamicNodeState) error { t.s = s return nil } -func (t dynamicNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { +func (t dynamicNodeStateHolder) PutGateNodeState(s interfaces.GateNodeState) error { + panic("not implemented") +} + +func (t dynamicNodeStateHolder) PutArrayNodeState(s interfaces.ArrayNodeState) error { panic("not implemented") } @@ -141,7 +148,7 @@ func Test_dynamicNodeHandler_Handle_Parent(t *testing.T) { nCtx.OnDataStore().Return(dataStore) r := &nodeMocks.NodeStateReader{} - r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{}) + r.OnGetDynamicNodeState().Return(interfaces.DynamicNodeState{}) nCtx.OnNodeStateReader().Return(r) return nCtx } @@ -186,7 +193,7 @@ func Test_dynamicNodeHandler_Handle_Parent(t *testing.T) { } h := &mocks.TaskNodeHandler{} mockLPLauncher := &lpMocks.Reader{} - n := &executorMocks.Node{} + n := &nodeMocks.Node{} if tt.args.isErr { h.OnHandleMatch(mock.Anything, mock.Anything).Return(handler.UnknownTransition, fmt.Errorf("error")) } else { @@ -282,7 +289,7 @@ func Test_dynamicNodeHandler_Handle_ParentFinalize(t *testing.T) { nCtx.OnDataStore().Return(dataStore) r := &nodeMocks.NodeStateReader{} - r.On("GetDynamicNodeState").Return(handler.DynamicNodeState{ + r.On("GetDynamicNodeState").Return(interfaces.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseParentFinalizing, }) nCtx.OnNodeStateReader().Return(r) @@ -293,14 +300,14 @@ func Test_dynamicNodeHandler_Handle_ParentFinalize(t *testing.T) { t.Run("parent-finalize-success", func(t *testing.T) { nCtx := createNodeContext("test") s := &dynamicNodeStateHolder{ - s: handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalizing}, + s: interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalizing}, } nCtx.OnNodeStateWriter().Return(s) f, err := nCtx.DataStore().ConstructReference(context.TODO(), nCtx.NodeStatus().GetDataDir(), "futures.pb") assert.NoError(t, err) dj := &core.DynamicJobSpec{} mockLPLauncher := &lpMocks.Reader{} - n := &executorMocks.Node{} + n := &nodeMocks.Node{} assert.NoError(t, nCtx.DataStore().WriteProtobuf(context.TODO(), f, storage.Options{}, dj)) h := &mocks.TaskNodeHandler{} h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(nil) @@ -313,14 +320,14 @@ func Test_dynamicNodeHandler_Handle_ParentFinalize(t *testing.T) { t.Run("parent-finalize-error", func(t *testing.T) { nCtx := createNodeContext("test") s := &dynamicNodeStateHolder{ - s: handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalizing}, + s: interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalizing}, } nCtx.OnNodeStateWriter().Return(s) f, err := nCtx.DataStore().ConstructReference(context.TODO(), nCtx.NodeStatus().GetDataDir(), "futures.pb") assert.NoError(t, err) dj := &core.DynamicJobSpec{} mockLPLauncher := &lpMocks.Reader{} - n := &executorMocks.Node{} + n := &nodeMocks.Node{} assert.NoError(t, nCtx.DataStore().WriteProtobuf(context.TODO(), f, storage.Options{}, dj)) h := &mocks.TaskNodeHandler{} h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(fmt.Errorf("err")) @@ -506,7 +513,7 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { w.OnGetExecutionStatus().Return(ws) r := &nodeMocks.NodeStateReader{} - r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{ + 
r.OnGetDynamicNodeState().Return(interfaces.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseExecuting, }) nCtx.OnNodeStateReader().Return(r) @@ -546,7 +553,7 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { } type args struct { - s executors.NodeStatus + s interfaces.NodeStatus isErr bool dj *core.DynamicJobSpec validErr *io.ExecutionError @@ -565,15 +572,15 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { want want }{ {"error", args{isErr: true, dj: createDynamicJobSpec()}, want{isErr: true}}, - {"success", args{s: executors.NodeStatusSuccess, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"complete", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: true, validCacheStatus: &validCachePopulatedStatus}, want{p: handler.EPhaseSuccess, phase: v1alpha1.DynamicNodePhaseExecuting, info: execInfoWithTaskNodeMeta}}, - {"complete-no-outputs", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: false}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"complete-valid-error-retryable", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{IsRecoverable: true}, generateOutputs: true}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing, info: execInfoOutputOnly}}, - {"complete-valid-error", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}, generateOutputs: true}, want{p: handler.EPhaseFailed, phase: v1alpha1.DynamicNodePhaseFailing, info: execInfoOutputOnly}}, - {"failed", args{s: executors.NodeStatusFailed(&core.ExecutionError{}), dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"running", args{s: executors.NodeStatusRunning, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"running-valid-err", args{s: executors.NodeStatusRunning, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"queued", args{s: executors.NodeStatusQueued, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"success", args{s: interfaces.NodeStatusSuccess, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"complete", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: true, validCacheStatus: &validCachePopulatedStatus}, want{p: handler.EPhaseSuccess, phase: v1alpha1.DynamicNodePhaseExecuting, info: execInfoWithTaskNodeMeta}}, + {"complete-no-outputs", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: false}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, + {"complete-valid-error-retryable", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{IsRecoverable: true}, generateOutputs: true}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing, info: execInfoOutputOnly}}, + {"complete-valid-error", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}, generateOutputs: true}, want{p: handler.EPhaseFailed, phase: v1alpha1.DynamicNodePhaseFailing, info: 
execInfoOutputOnly}}, + {"failed", args{s: interfaces.NodeStatusFailed(&core.ExecutionError{}), dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseFailing}}, + {"running", args{s: interfaces.NodeStatusRunning, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"running-valid-err", args{s: interfaces.NodeStatusRunning, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"queued", args{s: interfaces.NodeStatusQueued, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -601,9 +608,9 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { } h.OnValidateOutputAndCacheAddMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(validCacheStatus, nil, nil) } - n := &executorMocks.Node{} + n := &nodeMocks.Node{} if tt.args.isErr { - n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(executors.NodeStatusUndefined, fmt.Errorf("error")) + n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(interfaces.NodeStatusUndefined, fmt.Errorf("error")) } else { n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.s, nil) } @@ -739,7 +746,7 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { w.OnGetExecutionStatus().Return(ws) r := &nodeMocks.NodeStateReader{} - r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{ + r.OnGetDynamicNodeState().Return(interfaces.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseExecuting, }) nCtx.OnNodeStateReader().Return(r) @@ -747,7 +754,7 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { } type args struct { - s executors.NodeStatus + s interfaces.NodeStatus isErr bool dj *core.DynamicJobSpec validErr *io.ExecutionError @@ -764,15 +771,15 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { want want }{ {"error", args{isErr: true, dj: createDynamicJobSpec()}, want{isErr: true}}, - {"success", args{s: executors.NodeStatusSuccess, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"complete", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: true}, want{p: handler.EPhaseSuccess, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"complete-no-outputs", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: false}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"complete-valid-error-retryable", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{IsRecoverable: true}, generateOutputs: true}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"complete-valid-error", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}, generateOutputs: true}, want{p: handler.EPhaseFailed, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"failed", args{s: 
executors.NodeStatusFailed(&core.ExecutionError{}), dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"running", args{s: executors.NodeStatusRunning, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"running-valid-err", args{s: executors.NodeStatusRunning, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"queued", args{s: executors.NodeStatusQueued, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"success", args{s: interfaces.NodeStatusSuccess, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"complete", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: true}, want{p: handler.EPhaseSuccess, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"complete-no-outputs", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: false}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, + {"complete-valid-error-retryable", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{IsRecoverable: true}, generateOutputs: true}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, + {"complete-valid-error", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}, generateOutputs: true}, want{p: handler.EPhaseFailed, phase: v1alpha1.DynamicNodePhaseFailing}}, + {"failed", args{s: interfaces.NodeStatusFailed(&core.ExecutionError{}), dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseFailing}}, + {"running", args{s: interfaces.NodeStatusRunning, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"running-valid-err", args{s: interfaces.NodeStatusRunning, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"queued", args{s: interfaces.NodeStatusQueued, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -792,9 +799,9 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { } else { h.OnValidateOutputAndCacheAddMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, &core.CatalogMetadata{ArtifactTag: &core.CatalogArtifactTag{Name: "name", ArtifactId: "id"}}), nil, nil) } - n := &executorMocks.Node{} + n := &nodeMocks.Node{} if tt.args.isErr { - n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(executors.NodeStatusUndefined, fmt.Errorf("error")) + n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(interfaces.NodeStatusUndefined, fmt.Errorf("error")) } else { n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.s, nil) } @@ -863,7 +870,7 @@ func 
TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { ctx := context.TODO() t.Run("dynamicnodephase-none", func(t *testing.T) { - s := handler.DynamicNodeState{ + s := interfaces.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseNone, Reason: "", } @@ -876,7 +883,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { mockLPLauncher := &lpMocks.Reader{} h := &mocks.TaskNodeHandler{} h.OnFinalize(ctx, nCtx).Return(nil) - n := &executorMocks.Node{} + n := &nodeMocks.Node{} d := New(h, n, mockLPLauncher, eventConfig, promutils.NewTestScope()) assert.NoError(t, d.Finalize(ctx, nCtx)) assert.NotZero(t, len(h.ExpectedCalls)) @@ -990,7 +997,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { w.OnGetExecutionStatus().Return(ws) r := &nodeMocks.NodeStateReader{} - r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{ + r.OnGetDynamicNodeState().Return(interfaces.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseExecuting, }) nCtx.OnNodeStateReader().Return(r) @@ -1007,7 +1014,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { mockLPLauncher := &lpMocks.Reader{} h := &mocks.TaskNodeHandler{} h.OnFinalize(ctx, nCtx).Return(nil) - n := &executorMocks.Node{} + n := &nodeMocks.Node{} n.OnFinalizeHandlerMatch(ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) d := New(h, n, mockLPLauncher, eventConfig, promutils.NewTestScope()) assert.NoError(t, d.Finalize(ctx, nCtx)) @@ -1028,7 +1035,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { mockLPLauncher := &lpMocks.Reader{} h := &mocks.TaskNodeHandler{} h.OnFinalize(ctx, nCtx).Return(fmt.Errorf("err")) - n := &executorMocks.Node{} + n := &nodeMocks.Node{} n.OnFinalizeHandlerMatch(ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) d := New(h, n, mockLPLauncher, eventConfig, promutils.NewTestScope()) assert.Error(t, d.Finalize(ctx, nCtx)) @@ -1049,7 +1056,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { mockLPLauncher := &lpMocks.Reader{} h := &mocks.TaskNodeHandler{} h.OnFinalize(ctx, nCtx).Return(nil) - n := &executorMocks.Node{} + n := &nodeMocks.Node{} n.OnFinalizeHandlerMatch(ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("err")) d := New(h, n, mockLPLauncher, eventConfig, promutils.NewTestScope()) assert.Error(t, d.Finalize(ctx, nCtx)) diff --git a/pkg/controller/nodes/dynamic/utils_test.go b/pkg/controller/nodes/dynamic/utils_test.go index 6afc3cb80..291d175ac 100644 --- a/pkg/controller/nodes/dynamic/utils_test.go +++ b/pkg/controller/nodes/dynamic/utils_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" diff --git a/pkg/controller/nodes/end/handler_test.go b/pkg/controller/nodes/end/handler_test.go index fd18841e3..d1d500d14 100644 --- a/pkg/controller/nodes/end/handler_test.go +++ b/pkg/controller/nodes/end/handler_test.go @@ -21,7 +21,7 @@ import ( mocks3 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" flyteassert 
"github.com/flyteorg/flytepropeller/pkg/utils/assert" ) diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index cccec87ef..9605bd73c 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -13,8 +13,6 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flytestdlib/contextutils" - mocks3 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flytepropeller/events" @@ -27,13 +25,14 @@ import ( mocks4 "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" gatemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - nodeHandlerMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" - mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + nodemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" recoveryMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" flyteassert "github.com/flyteorg/flytepropeller/pkg/utils/assert" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/promutils/labeled" "github.com/flyteorg/flytestdlib/storage" @@ -70,8 +69,10 @@ func TestSetInputsForStartNode(t *testing.T) { enQWf := func(workflowID v1alpha1.WorkflowID) {} adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) exec, err := NewExecutor(ctx, config.GetConfig().NodeConfig, mockStorage, enQWf, eventMocks.NewMockEventSink(), adminClient, - adminClient, 10, "s3://bucket/", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + adminClient, 10, "s3://bucket/", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) inputs := &core.LiteralMap{ Literals: map[string]*core.Literal{ @@ -87,7 +88,7 @@ func TestSetInputsForStartNode(t *testing.T) { } s, err := exec.SetInputsForStartNode(ctx, w, w, w, nil) assert.NoError(t, err) - assert.Equal(t, executors.NodeStatusComplete, s) + assert.Equal(t, interfaces.NodeStatusComplete, s) }) t.Run("WithInputs", func(t *testing.T) { @@ -99,7 +100,7 @@ func TestSetInputsForStartNode(t *testing.T) { } s, err := exec.SetInputsForStartNode(ctx, w, w, w, inputs) assert.NoError(t, err) - assert.Equal(t, executors.NodeStatusComplete, s) + assert.Equal(t, interfaces.NodeStatusComplete, s) actual := &core.LiteralMap{} if assert.NoError(t, mockStorage.ReadProtobuf(ctx, "s3://test-bucket/exec/start-node/data/0/outputs.pb", actual)) { flyteassert.EqualLiteralMap(t, inputs, actual) @@ -113,12 +114,12 @@ func TestSetInputsForStartNode(t *testing.T) { } s, err := exec.SetInputsForStartNode(ctx, w, w, w, inputs) assert.Error(t, err) - assert.Equal(t, executors.NodeStatusUndefined, s) + assert.Equal(t, interfaces.NodeStatusUndefined, s) }) failStorage := createFailingDatastore(t, testScope.NewSubScope("failing")) execFail, err := NewExecutor(ctx, 
config.GetConfig().NodeConfig, failStorage, enQWf, eventMocks.NewMockEventSink(), adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) t.Run("StorageFailure", func(t *testing.T) { w := createDummyBaseWorkflow(mockStorage) @@ -128,7 +129,7 @@ func TestSetInputsForStartNode(t *testing.T) { } s, err := execFail.SetInputsForStartNode(ctx, w, w, w, inputs) assert.Error(t, err) - assert.Equal(t, executors.NodeStatusUndefined, s) + assert.Equal(t, interfaces.NodeStatusUndefined, s) }) } @@ -143,30 +144,26 @@ func TestNodeExecutor_Initialize(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() t.Run("happy", func(t *testing.T) { + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) - hf := &mocks2.HandlerFactory{} - exec.nodeHandlerFactory = hf - - hf.On("Setup", mock.Anything, mock.Anything).Return(nil) - assert.NoError(t, exec.Initialize(ctx)) }) t.Run("error", func(t *testing.T) { + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("error")) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) - hf := &mocks2.HandlerFactory{} - exec.nodeHandlerFactory = hf - - hf.On("Setup", mock.Anything, mock.Anything).Return(fmt.Errorf("error")) - assert.Error(t, exec.Initialize(ctx)) }) } @@ -180,8 +177,10 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -230,30 +229,30 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { name string currentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase 
handlerReturn func() (handler.Transition, error) expectedError bool }{ // Starting at Queued - {"nys->success", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Transition, error) { + {"nys->success", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSucceeded, interfaces.NodePhaseSuccess, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil }, false}, - {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Transition, error) { + {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, interfaces.NodePhaseSuccess, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil }, false}, - {"nys->error", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, executors.NodePhaseUndefined, func() (handler.Transition, error) { + {"nys->error", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhaseUndefined, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("err") }, true}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(test.handlerReturn()) h.On("FinalizeRequired").Return(false) @@ -284,8 +283,10 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -335,22 +336,22 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { name string parentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase expectedError bool }{ - {"notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"retryable", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"failing", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"skipped", v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, executors.NodePhaseSuccess, false}, - {"success", 
v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseQueued, executors.NodePhaseQueued, false}, + {"notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"retryable", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"failing", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"skipped", v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, interfaces.NodePhaseSuccess, false}, + {"success", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseQueued, interfaces.NodePhaseQueued, false}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} hf.OnGetHandler(v1alpha1.NodeKindEnd).Return(h, nil) mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.parentNodePhase, 0) @@ -426,36 +427,36 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { name string currentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase handlerReturn func() (handler.Transition, error) expectedError bool }{ // Starting at Queued - {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Transition, error) { + {"queued->success", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeded, interfaces.NodePhaseSuccess, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil }, false}, - {"queued->failed", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Transition, error) { + {"queued->failed", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailed, interfaces.NodePhaseFailed, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_USER, "code", "mesage", nil)), nil }, false}, - {"queued->running", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, executors.NodePhasePending, func() (handler.Transition, error) { + {"queued->running", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil }, false}, - {"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhaseUndefined, func() (handler.Transition, error) { + {"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, interfaces.NodePhaseUndefined, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("err") }, true}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o 
interfaces.NodeExecutionContext) bool { return true }), ).Return(test.handlerReturn()) h.OnFinalizeRequired().Return(false) @@ -471,7 +472,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { } else { assert.NoError(t, err) } - if test.expectedPhase == executors.NodePhaseFailed { + if test.expectedPhase == interfaces.NodePhaseFailed { assert.NotNil(t, s.Err) } else { assert.Nil(t, s.Err) @@ -665,22 +666,23 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { currentNodePhase v1alpha1.NodePhase parentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase expectedError bool updateCalled bool }{ - {"notYetStarted->skipped", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseSkipped, executors.NodePhaseFailed, false, false}, - {"notYetStarted->skipped", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, executors.NodePhaseSuccess, false, true}, - {"notYetStarted->queued", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseQueued, executors.NodePhasePending, false, true}, + {"notYetStarted->skipped", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseSkipped, interfaces.NodePhaseFailed, false, false}, + {"notYetStarted->skipped", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, interfaces.NodePhaseSuccess, false, true}, + {"notYetStarted->queued", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseQueued, interfaces.NodePhasePending, false, true}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnHandleMatch( mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.UnknownTransition, fmt.Errorf("should not be called")) h.OnFinalizeRequired().Return(false) hf.OnGetHandler(v1alpha1.NodeKindTask).Return(h, nil) @@ -691,10 +693,9 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) - exec.nodeHandlerFactory = hf execContext := executors.NewExecutionContext(mockWf, mockWf, mockWf, nil, executors.InitializeControlFlow()) s, err := exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, startNode) @@ -714,7 +715,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { name string currentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase handlerReturn func() (handler.Transition, error) finalizeReturnErr bool expectedError bool @@ -722,54 +723,54 @@ func 
TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { eventPhase core.NodeExecution_Phase }{ // Starting at Queued - {"queued->running", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, executors.NodePhasePending, func() (handler.Transition, error) { + {"queued->running", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil }, true, false, true, core.NodeExecution_RUNNING}, - {"queued->queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhasePending, func() (handler.Transition, error) { + {"queued->queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoQueued("reason", &core.LiteralMap{})), nil }, true, false, false, core.NodeExecution_QUEUED}, - {"queued->failing", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailing, executors.NodePhasePending, func() (handler.Transition, error) { + {"queued->failing", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseFailing, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_USER, "code", "reason", nil)), nil }, true, false, true, core.NodeExecution_FAILED}, - {"failing->failed", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Transition, error) { + {"failing->failed", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseFailed, interfaces.NodePhaseFailed, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, false, false, false, core.NodeExecution_FAILED}, - {"failing->failed(error)", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseFailing, executors.NodePhaseUndefined, func() (handler.Transition, error) { + {"failing->failed(error)", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseFailing, interfaces.NodePhaseUndefined, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, true, true, false, core.NodeExecution_FAILING}, - {"queued->succeeding", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeding, executors.NodePhasePending, func() (handler.Transition, error) { + {"queued->succeeding", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseSucceeding, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil }, true, false, true, core.NodeExecution_SUCCEEDED}, - {"succeeding->success", v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseSucceeded, executors.NodePhaseSuccess, func() (handler.Transition, error) { + {"succeeding->success", v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseSucceeded, interfaces.NodePhaseSuccess, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, false, false, false, core.NodeExecution_SUCCEEDED}, - {"succeeding->success(error)", v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseSucceeding, executors.NodePhaseUndefined, func() (handler.Transition, error) { + {"succeeding->success(error)", v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseSucceeding, interfaces.NodePhaseUndefined, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, true, true, false, core.NodeExecution_SUCCEEDED}, - 
{"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, executors.NodePhaseUndefined, func() (handler.Transition, error) { + {"queued->error", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseQueued, interfaces.NodePhaseUndefined, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, true, true, false, core.NodeExecution_RUNNING}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) - exec.nodeHandlerFactory = hf called := false evRecorder := &eventMocks.NodeEventRecorder{} @@ -780,12 +781,14 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { return true }), mock.Anything).Return(nil) - exec.nodeRecorder = evRecorder + nodeExec, ok := exec.nodeExecutor.(*nodeExecutor) + assert.True(t, ok) + nodeExec.nodeRecorder = evRecorder - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(test.handlerReturn()) h.On("FinalizeRequired").Return(true) @@ -831,57 +834,57 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { name string currentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase handlerReturn func() (handler.Transition, error) expectedError bool eventRecorded bool eventPhase core.NodeExecution_Phase attempts int }{ - {"running->running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, executors.NodePhasePending, func() (handler.Transition, error) { + {"running->running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil }, false, false, core.NodeExecution_RUNNING, 0}, - {"running->retryableFailure", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, executors.NodePhasePending, + {"running->retryableFailure", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRetryableFailure(core.ExecutionError_USER, "x", "y", nil)), nil }, false, true, core.NodeExecution_FAILED, 0}, - {"retryablefailure->running", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseRunning, executors.NodePhasePending, func() (handler.Transition, error) { + {"retryablefailure->running", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseRunning, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.UnknownTransition, 
fmt.Errorf("should not be invoked") }, false, false, core.NodeExecution_RUNNING, 1}, - {"running->failing", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, executors.NodePhasePending, func() (handler.Transition, error) { + {"running->failing", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_USER, "code", "reason", nil)), nil }, false, true, core.NodeExecution_FAILED, 0}, - {"running->succeeding", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseSucceeding, executors.NodePhasePending, func() (handler.Transition, error) { + {"running->succeeding", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseSucceeding, interfaces.NodePhasePending, func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil }, false, true, core.NodeExecution_SUCCEEDED, 0}, - {"running->error", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, executors.NodePhaseUndefined, func() (handler.Transition, error) { + {"running->error", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseRunning, interfaces.NodePhaseUndefined, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, true, false, core.NodeExecution_RUNNING, 0}, - {"previously-failed", v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, func() (handler.Transition, error) { + {"previously-failed", v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseFailed, interfaces.NodePhaseFailed, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, false, false, core.NodeExecution_RUNNING, 0}, - {"previously-success", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSucceeded, executors.NodePhaseComplete, func() (handler.Transition, error) { + {"previously-success", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSucceeded, interfaces.NodePhaseComplete, func() (handler.Transition, error) { return handler.UnknownTransition, fmt.Errorf("error") }, false, false, core.NodeExecution_RUNNING, 0}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) - exec.nodeHandlerFactory = hf called := false evRecorder := &eventMocks.NodeEventRecorder{} @@ -891,12 +894,15 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { called = true return true }), mock.Anything).Return(nil) - exec.nodeRecorder = evRecorder - h := &nodeHandlerMocks.Node{} + nodeExec, ok := exec.nodeExecutor.(*nodeExecutor) + assert.True(t, ok) + nodeExec.nodeRecorder = evRecorder + + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) 
bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(test.handlerReturn()) h.On("FinalizeRequired").Return(true) if test.currentNodePhase == v1alpha1.NodePhaseRetryableFailure { @@ -938,19 +944,19 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { // Extinguished retries t.Run("retries-exhausted", func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) - exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRetryableFailure(core.ExecutionError_USER, "x", "y", nil)), nil) h.On("FinalizeRequired").Return(true) h.On("Finalize", mock.Anything, mock.Anything).Return(fmt.Errorf("error")) @@ -962,26 +968,26 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { s, err := exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, startNode) assert.NoError(t, err) - assert.Equal(t, executors.NodePhasePending.String(), s.NodePhase.String()) + assert.Equal(t, interfaces.NodePhasePending.String(), s.NodePhase.String()) assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) assert.Equal(t, v1alpha1.NodePhaseFailing.String(), mockNodeStatus.GetPhase().String()) }) // Remaining retries t.Run("retries-remaining", func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) - exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRetryableFailure(core.ExecutionError_USER, "x", "y", nil)), nil) h.On("FinalizeRequired").Return(true) h.On("Finalize", mock.Anything, 
mock.Anything).Return(fmt.Errorf("error")) @@ -992,7 +998,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { execContext := executors.NewExecutionContext(mockWf, mockWf, nil, nil, executors.InitializeControlFlow()) s, err := exec.RecursiveNodeHandler(ctx, execContext, mockWf, mockWf, startNode) assert.NoError(t, err) - assert.Equal(t, executors.NodePhasePending.String(), s.NodePhase.String()) + assert.Equal(t, interfaces.NodePhasePending.String(), s.NodePhase.String()) assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) assert.Equal(t, v1alpha1.NodePhaseFailing.String(), mockNodeStatus.GetPhase().String()) }) @@ -1006,8 +1012,10 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -1072,21 +1080,21 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { name string currentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase expectedError bool }{ - {"succeeded", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSucceeded, executors.NodePhaseComplete, false}, - {"failed", v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseFailed, executors.NodePhaseFailed, false}, + {"succeeded", v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSucceeded, interfaces.NodePhaseComplete, false}, + {"failed", v1alpha1.NodePhaseFailed, v1alpha1.NodePhaseFailed, interfaces.NodePhaseFailed, false}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.UnknownTransition, fmt.Errorf("should not be called")) h.On("FinalizeRequired").Return(true) h.On("Finalize", mock.Anything, mock.Anything).Return(fmt.Errorf("error")) @@ -1117,8 +1125,10 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := 
execIface.(*recursiveNodeExecutor) @@ -1183,25 +1193,25 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { name string parentNodePhase v1alpha1.NodePhase expectedNodePhase v1alpha1.NodePhase - expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase expectedError bool }{ - {"notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"retryable", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"failing", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"failing", v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseNotYetStarted, executors.NodePhasePending, false}, - {"skipped", v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, executors.NodePhaseSuccess, false}, + {"notYetStarted", v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"running", v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"queued", v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"retryable", v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"failing", v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"failing", v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseNotYetStarted, interfaces.NodePhasePending, false}, + {"skipped", v1alpha1.NodePhaseSkipped, v1alpha1.NodePhaseSkipped, interfaces.NodePhaseSuccess, false}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.UnknownTransition, fmt.Errorf("should not be called")) h.On("FinalizeRequired").Return(true) h.On("Finalize", mock.Anything, mock.Anything).Return(fmt.Errorf("error")) @@ -1233,8 +1243,10 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) // Node not yet started @@ -1244,22 +1256,22 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { parentNodePhase v1alpha1.BranchNodePhase currentNodePhase v1alpha1.NodePhase phaseUpdateExpected bool 
- expectedPhase executors.NodePhase + expectedPhase interfaces.NodePhase expectedError bool }{ - {"branchSuccess", v1alpha1.BranchNodeSuccess, v1alpha1.NodePhaseNotYetStarted, true, executors.NodePhaseQueued, false}, - {"branchNotYetDone", v1alpha1.BranchNodeNotYetEvaluated, v1alpha1.NodePhaseNotYetStarted, false, executors.NodePhasePending, false}, - {"branchError", v1alpha1.BranchNodeError, v1alpha1.NodePhaseNotYetStarted, false, executors.NodePhasePending, false}, + {"branchSuccess", v1alpha1.BranchNodeSuccess, v1alpha1.NodePhaseNotYetStarted, true, interfaces.NodePhaseQueued, false}, + {"branchNotYetDone", v1alpha1.BranchNodeNotYetEvaluated, v1alpha1.NodePhaseNotYetStarted, false, interfaces.NodePhasePending, false}, + {"branchError", v1alpha1.BranchNodeError, v1alpha1.NodePhaseNotYetStarted, false, interfaces.NodePhasePending, false}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnHandleMatch( mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.UnknownTransition, fmt.Errorf("should not be called")) h.OnFinalizeRequired().Return(true) h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(fmt.Errorf("error")) @@ -1345,7 +1357,6 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { func Test_nodeExecutor_RecordTransitionLatency(t *testing.T) { testScope := promutils.NewTestScope() type fields struct { - nodeHandlerFactory HandlerFactory enqueueWorkflow v1alpha1.EnqueueWorkflow store *storage.DataStore nodeRecorder events.NodeEventRecorder @@ -1389,8 +1400,7 @@ func Test_nodeExecutor_RecordTransitionLatency(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &recursiveNodeExecutor{ - nodeHandlerFactory: tt.fields.nodeHandlerFactory, + c := &nodeExecutor{ enqueueWorkflow: tt.fields.enqueueWorkflow, store: tt.fields.store, nodeRecorder: tt.fields.nodeRecorder, @@ -1489,22 +1499,18 @@ func Test_nodeExecutor_timeout(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &recursiveNodeExecutor{defaultActiveDeadline: time.Second, defaultExecutionDeadline: time.Second} + c := &nodeExecutor{defaultActiveDeadline: time.Second, defaultExecutionDeadline: time.Second} handlerReturn := func() (handler.Transition, error) { return handler.DoTransition(handler.TransitionTypeEphemeral, tt.phaseInfo), tt.err } - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handlerReturn()) h.On("FinalizeRequired").Return(true) h.On("Finalize", mock.Anything, mock.Anything).Return(nil) - hf := &mocks2.HandlerFactory{} - hf.On("GetHandler", v1alpha1.NodeKindStart).Return(h, nil) - c.nodeHandlerFactory = hf - mockNode := &mocks.ExecutableNode{} mockNode.On("GetID").Return("node") mockNode.On("GetBranchNode").Return(nil) @@ -1544,19 +1550,16 @@ func Test_nodeExecutor_system_error(t *testing.T) { ns.On("ClearLastAttemptStartedAt").Return() - c := &recursiveNodeExecutor{} - h := &nodeHandlerMocks.Node{} + c := 
&nodeExecutor{} + h := &nodemocks.NodeHandler{} h.On("Handle", mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.DoTransition(handler.TransitionTypeEphemeral, phaseInfo), nil) h.On("FinalizeRequired").Return(true) h.On("Finalize", mock.Anything, mock.Anything).Return(nil) - hf := &mocks2.HandlerFactory{} - hf.On("GetHandler", v1alpha1.NodeKindStart).Return(h, nil) - c.nodeHandlerFactory = hf c.maxNodeRetriesForSystemFailures = 2 mockNode := &mocks.ExecutableNode{} @@ -1575,11 +1578,11 @@ func Test_nodeExecutor_system_error(t *testing.T) { func Test_nodeExecutor_abort(t *testing.T) { ctx := context.Background() - exec := recursiveNodeExecutor{} + exec := nodeExecutor{} nCtx := &nodeExecContext{} t.Run("abort error calls finalize", func(t *testing.T) { - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("test error")) h.OnFinalizeRequired().Return(true) var called bool @@ -1587,13 +1590,13 @@ func Test_nodeExecutor_abort(t *testing.T) { called = true }).Return(nil) - err := exec.abort(ctx, h, nCtx, "testing") + err := exec.Abort(ctx, h, nCtx, "testing") assert.Equal(t, "test error", err.Error()) assert.True(t, called) }) t.Run("abort error calls finalize with error", func(t *testing.T) { - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("test error")) h.OnFinalizeRequired().Return(true) var called bool @@ -1601,13 +1604,13 @@ func Test_nodeExecutor_abort(t *testing.T) { called = true }).Return(errors.New("finalize error")) - err := exec.abort(ctx, h, nCtx, "testing") + err := exec.Abort(ctx, h, nCtx, "testing") assert.Equal(t, "0: test error\r\n1: finalize error\r\n", err.Error()) assert.True(t, called) }) t.Run("abort calls finalize when no errors", func(t *testing.T) { - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) h.OnFinalizeRequired().Return(true) var called bool @@ -1615,7 +1618,7 @@ func Test_nodeExecutor_abort(t *testing.T) { called = true }).Return(nil) - err := exec.abort(ctx, h, nCtx, "testing") + err := exec.Abort(ctx, h, nCtx, "testing") assert.NoError(t, err) assert.True(t, called) }) @@ -1651,15 +1654,17 @@ func TestNodeExecutor_AbortHandler(t *testing.T) { nl.OnGetNode(id).Return(n, true) incompatibleClusterErr := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnAbortMatch(mock.Anything, mock.Anything, "aborting").Return(nil) h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(nil) hf.OnGetHandlerMatch(v1alpha1.NodeKindStart).Return(h, nil) nExec := recursiveNodeExecutor{ - nodeRecorder: incompatibleClusterErr, + nodeExecutor: &nodeExecutor{ + nodeRecorder: incompatibleClusterErr, + }, nodeHandlerFactory: hf, } @@ -1843,8 +1848,10 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() + hf := &nodemocks.HandlerFactory{} + hf.On("Setup", mock.Anything, mock.Anything, 
mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -1921,12 +1928,12 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { cf := executors.InitializeControlFlow() eCtx := executors.NewExecutionContext(mockWf, mockWf, nil, nil, cf) - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnHandleMatch( mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil) h.OnFinalizeRequired().Return(false) @@ -1934,7 +1941,7 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { s, err := exec.RecursiveNodeHandler(ctx, eCtx, mockWf, mockWf, mockNode) assert.NoError(t, err) - assert.Equal(t, s.NodePhase.String(), executors.NodePhaseSuccess.String()) + assert.Equal(t, s.NodePhase.String(), interfaces.NodePhaseSuccess.String()) }) t.Run("parallelism-met", func(t *testing.T) { @@ -1945,7 +1952,7 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { s, err := exec.RecursiveNodeHandler(ctx, eCtx, mockWf, mockWf, mockNode) assert.NoError(t, err) - assert.Equal(t, s.NodePhase.String(), executors.NodePhaseRunning.String()) + assert.Equal(t, s.NodePhase.String(), interfaces.NodePhaseRunning.String()) }) t.Run("parallelism-met-not-yet-started", func(t *testing.T) { @@ -1956,7 +1963,7 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { s, err := exec.RecursiveNodeHandler(ctx, eCtx, mockWf, mockWf, mockNode) assert.NoError(t, err) - assert.Equal(t, s.NodePhase.String(), executors.NodePhaseRunning.String()) + assert.Equal(t, s.NodePhase.String(), interfaces.NodePhaseRunning.String()) }) t.Run("parallelism-disabled", func(t *testing.T) { @@ -1965,12 +1972,12 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { cf.IncrementParallelism() eCtx := executors.NewExecutionContext(mockWf, mockWf, nil, nil, cf) - hf := &mocks2.HandlerFactory{} + hf := &nodemocks.HandlerFactory{} exec.nodeHandlerFactory = hf - h := &nodeHandlerMocks.Node{} + h := &nodemocks.NodeHandler{} h.OnHandleMatch( mock.MatchedBy(func(ctx context.Context) bool { return true }), - mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + mock.MatchedBy(func(o interfaces.NodeExecutionContext) bool { return true }), ).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil) h.OnFinalizeRequired().Return(false) @@ -1978,7 +1985,7 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { s, err := exec.RecursiveNodeHandler(ctx, eCtx, mockWf, mockWf, mockNode) assert.NoError(t, err) - assert.Equal(t, s.NodePhase.String(), executors.NodePhaseSuccess.String()) + assert.Equal(t, s.NodePhase.String(), interfaces.NodePhaseSuccess.String()) }) } @@ -1993,7 +2000,8 
@@ func (f fakeNodeEventRecorder) RecordNodeEvent(ctx context.Context, event *event return nil } -func Test_nodeExecutor_IdempotentRecordEvent(t *testing.T) { +// TODO @hamersaw - fix IdempotentRecordEvent test -> move to NodeExecutionSomething +/*func Test_nodeExecutor_IdempotentRecordEvent(t *testing.T) { noErrRecorder := fakeNodeEventRecorder{} alreadyExistsError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} inTerminalError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, Cause: fmt.Errorf("err")}} @@ -2029,7 +2037,7 @@ func Test_nodeExecutor_IdempotentRecordEvent(t *testing.T) { } }) } -} +}*/ func TestRecover(t *testing.T) { recoveryID := &core.WorkflowExecutionIdentifier{ @@ -2087,7 +2095,7 @@ func TestRecover(t *testing.T) { }) execContext.OnGetEventVersion().Return(v1alpha1.EventVersion0) - nm := &nodeHandlerMocks.NodeExecutionMetadata{} + nm := &nodemocks.NodeExecutionMetadata{} nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ ExecutionId: wfExecID, NodeId: nodeID, @@ -2099,7 +2107,7 @@ func TestRecover(t *testing.T) { ns := &mocks.ExecutableNodeStatus{} ns.OnGetOutputDir().Return(storage.DataReference("out")) - nCtx := &nodeHandlerMocks.NodeExecutionContext{} + nCtx := &nodemocks.NodeExecutionContext{} nCtx.OnExecutionContext().Return(execContext) nCtx.OnNodeExecutionMetadata().Return(nm) nCtx.OnInputReader().Return(ir) @@ -2139,7 +2147,7 @@ func TestRecover(t *testing.T) { } nCtx.OnDataStore().Return(storageClient) - executor := recursiveNodeExecutor{ + executor := nodeExecutor{ recoveryClient: recoveryClient, store: storageClient, eventConfig: &config.EventConfig{ @@ -2156,7 +2164,7 @@ func TestRecover(t *testing.T) { dstDynamicJobSpecURI := "dst/foo/bar" // initialize node execution context - nCtx := &nodeHandlerMocks.NodeExecutionContext{} + nCtx := &nodemocks.NodeExecutionContext{} nCtx.OnExecutionContext().Return(execContext) nCtx.OnNodeExecutionMetadata().Return(nm) nCtx.OnInputReader().Return(ir) @@ -2182,13 +2190,13 @@ func TestRecover(t *testing.T) { nCtx.OnDataStore().Return(storageClient) - reader := &nodeHandlerMocks.NodeStateReader{} - reader.OnGetDynamicNodeState().Return(handler.DynamicNodeState{}) + reader := &nodemocks.NodeStateReader{} + reader.OnGetDynamicNodeState().Return(interfaces.DynamicNodeState{}) nCtx.OnNodeStateReader().Return(reader) - writer := &nodeHandlerMocks.NodeStateWriter{} + writer := &nodemocks.NodeStateWriter{} writer.OnPutDynamicNodeStateMatch(mock.Anything).Run(func(args mock.Arguments) { - state := args.Get(0).(handler.DynamicNodeState) + state := args.Get(0).(interfaces.DynamicNodeState) assert.Equal(t, v1alpha1.DynamicNodePhaseParentFinalized, state.Phase) }).Return(nil) nCtx.OnNodeStateWriter().Return(writer) @@ -2232,7 +2240,7 @@ func TestRecover(t *testing.T) { DynamicWorkflow: dynamicWorkflow, }, nil) - executor := recursiveNodeExecutor{ + executor := nodeExecutor{ recoveryClient: recoveryClient, store: storageClient, eventConfig: eventConfig, @@ -2305,7 +2313,7 @@ func TestRecover(t *testing.T) { nCtx.OnDataStore().Return(storageClient) - executor := recursiveNodeExecutor{ + executor := nodeExecutor{ recoveryClient: recoveryClient, store: storageClient, eventConfig: eventConfig, @@ -2369,7 +2377,7 @@ func TestRecover(t *testing.T) { } nCtx.OnDataStore().Return(storageClient) - executor := recursiveNodeExecutor{ + executor := nodeExecutor{ recoveryClient: recoveryClient, store: storageClient, eventConfig: 
eventConfig, @@ -2399,16 +2407,16 @@ func TestRecover(t *testing.T) { }, }, nil) - executor := recursiveNodeExecutor{ + executor := nodeExecutor{ recoveryClient: recoveryClient, } - reader := &nodeHandlerMocks.NodeStateReader{} - reader.OnGetTaskNodeState().Return(handler.TaskNodeState{}) + reader := &nodemocks.NodeStateReader{} + reader.OnGetTaskNodeState().Return(interfaces.TaskNodeState{}) nCtx.OnNodeStateReader().Return(reader) - writer := &nodeHandlerMocks.NodeStateWriter{} + writer := &nodemocks.NodeStateWriter{} writer.OnPutTaskNodeStateMatch(mock.Anything).Run(func(args mock.Arguments) { - state := args.Get(0).(handler.TaskNodeState) + state := args.Get(0).(interfaces.TaskNodeState) assert.Equal(t, state.PreviousNodeExecutionCheckpointURI.String(), "prev path") }).Return(nil) nCtx.OnNodeStateWriter().Return(writer) @@ -2454,7 +2462,7 @@ func TestRecover(t *testing.T) { } nCtx.OnDataStore().Return(storageClient) - executor := recursiveNodeExecutor{ + executor := nodeExecutor{ recoveryClient: recoveryClient, store: storageClient, eventConfig: eventConfig, @@ -2499,7 +2507,7 @@ func TestRecover(t *testing.T) { } nCtx.OnDataStore().Return(storageClient) - executor := recursiveNodeExecutor{ + executor := nodeExecutor{ recoveryClient: recoveryClient, store: storageClient, eventConfig: eventConfig, diff --git a/pkg/controller/nodes/gate/handler_test.go b/pkg/controller/nodes/gate/handler_test.go index ca7e8e9dd..0de925797 100644 --- a/pkg/controller/nodes/gate/handler_test.go +++ b/pkg/controller/nodes/gate/handler_test.go @@ -17,7 +17,8 @@ import ( executormocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils" @@ -122,7 +123,7 @@ func createNodeExecutionContext(gateNode *v1alpha1.GateNodeSpec) *nodeMocks.Node nCtx.OnInputReader().Return(inputReader) r := &nodeMocks.NodeStateReader{} - r.OnGetGateNodeState().Return(handler.GateNodeState{}) + r.OnGetGateNodeState().Return(interfaces.GateNodeState{}) nCtx.OnNodeStateReader().Return(r) w := &nodeMocks.NodeStateWriter{} diff --git a/pkg/controller/nodes/node_exec_context_test.go b/pkg/controller/nodes/node_exec_context_test.go index 8ad294287..8586e2e00 100644 --- a/pkg/controller/nodes/node_exec_context_test.go +++ b/pkg/controller/nodes/node_exec_context_test.go @@ -90,7 +90,7 @@ func Test_NodeContext(t *testing.T) { s, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) p := parentInfo{} execContext := executors.NewExecutionContext(w1, nil, nil, p, nil) - nCtx := newNodeExecContext(context.TODO(), s, execContext, w1, getTestNodeSpec(nil), nil, nil, false, 0, 2, nil, TaskReader{}, nil, nil, "s3://bucket", ioutils.NewConstantShardSelector([]string{"x"})) + nCtx := newNodeExecContext(context.TODO(), s, execContext, w1, getTestNodeSpec(nil), nil, nil, false, 0, 2, nil, nil, TaskReader{}, nil, nil, "s3://bucket", ioutils.NewConstantShardSelector([]string{"x"})) assert.Equal(t, "id", nCtx.NodeExecutionMetadata().GetLabels()["node-id"]) assert.Equal(t, "false", nCtx.NodeExecutionMetadata().GetLabels()["interruptible"]) 
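A note on the API shape these test changes exercise: node execution contexts are now built through the executor's BuildNodeExecutionContext method and read through accessors rather than struct fields. A minimal sketch, assuming the wiring shown in the surrounding diffs (mock setup and error handling elided):

    nCtx, err := nodeExecutor.BuildNodeExecutionContext(ctx, execContext, nodeLookup, "node-a")
    if err != nil {
        // the builder surfaces lookup/metadata errors directly
    }
    prefix := nCtx.RawOutputPrefix() // replaces direct reads of the old rawOutputPrefix field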
assert.Equal(t, "task-name", nCtx.NodeExecutionMetadata().GetLabels()["task-name"]) @@ -108,7 +108,7 @@ func Test_NodeContextDefault(t *testing.T) { SystemFailures: 0, }) - nodeExecutor := recursiveNodeExecutor{ + nodeExecutor := nodeExecutor{ interruptibleFailureThreshold: 0, maxDatasetSizeBytes: 0, defaultDataSandbox: "s3://bucket-a", @@ -118,14 +118,14 @@ func Test_NodeContextDefault(t *testing.T) { } p := parentInfo{} execContext := executors.NewExecutionContext(w1, w1, w1, p, nil) - nodeExecContext, err := nodeExecutor.newNodeExecContextDefault(context.Background(), "node-a", execContext, nodeLookup) + nodeExecContext, err := nodeExecutor.BuildNodeExecutionContext(context.Background(), execContext, nodeLookup, "node-a") assert.NoError(t, err) - assert.Equal(t, "s3://bucket-a", nodeExecContext.rawOutputPrefix.String()) + assert.Equal(t, "s3://bucket-a", nodeExecContext.RawOutputPrefix().String()) w1.RawOutputDataConfig.OutputLocationPrefix = "s3://bucket-b" - nodeExecContext, err = nodeExecutor.newNodeExecContextDefault(context.Background(), "node-a", execContext, nodeLookup) + nodeExecContext, err = nodeExecutor.BuildNodeExecutionContext(context.Background(), execContext, nodeLookup, "node-a") assert.NoError(t, err) - assert.Equal(t, "s3://bucket-b", nodeExecContext.rawOutputPrefix.String()) + assert.Equal(t, "s3://bucket-b", nodeExecContext.RawOutputPrefix().String()) } func Test_NodeContextDefaultInterruptible(t *testing.T) { @@ -133,7 +133,7 @@ func Test_NodeContextDefaultInterruptible(t *testing.T) { scope := promutils.NewTestScope() dataStore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, scope.NewSubScope("dataStore")) - nodeExecutor := recursiveNodeExecutor{ + nodeExecutor := nodeExecutor{ interruptibleFailureThreshold: 10, maxDatasetSizeBytes: 0, defaultDataSandbox: "s3://bucket-a", @@ -148,10 +148,10 @@ func Test_NodeContextDefaultInterruptible(t *testing.T) { } verifyNodeExecContext := func(t *testing.T, executionContext executors.ExecutionContext, nl executors.NodeLookup, shouldBeInterruptible bool) { - nodeExecContext, err := nodeExecutor.newNodeExecContextDefault(context.Background(), "node-a", executionContext, nl) + nodeExecContext, err := nodeExecutor.BuildNodeExecutionContext(context.Background(), executionContext, nl, "node-a") assert.NoError(t, err) - assert.Equal(t, shouldBeInterruptible, nodeExecContext.md.IsInterruptible()) - labels := nodeExecContext.md.GetLabels() + assert.Equal(t, shouldBeInterruptible, nodeExecContext.NodeExecutionMetadata().IsInterruptible()) + labels := nodeExecContext.NodeExecutionMetadata().GetLabels() assert.Contains(t, labels, NodeInterruptibleLabel) assert.Equal(t, strconv.FormatBool(shouldBeInterruptible), labels[NodeInterruptibleLabel]) } diff --git a/pkg/controller/nodes/subworkflow/handler_test.go b/pkg/controller/nodes/subworkflow/handler_test.go index 0599953f5..2a7bb7cfa 100644 --- a/pkg/controller/nodes/subworkflow/handler_test.go +++ b/pkg/controller/nodes/subworkflow/handler_test.go @@ -26,37 +26,45 @@ import ( mocks2 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" execMocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - mocks3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + mocks3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" 
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" ) type workflowNodeStateHolder struct { - s handler.WorkflowNodeState + s interfaces.WorkflowNodeState } var eventConfig = &config.EventConfig{ RawOutputPolicy: config.RawOutputPolicyReference, } -func (t *workflowNodeStateHolder) PutTaskNodeState(s handler.TaskNodeState) error { +func (t *workflowNodeStateHolder) ClearNodeStatus() { +} + +func (t *workflowNodeStateHolder) PutTaskNodeState(s interfaces.TaskNodeState) error { panic("not implemented") } -func (t workflowNodeStateHolder) PutBranchNode(s handler.BranchNodeState) error { +func (t workflowNodeStateHolder) PutBranchNode(s interfaces.BranchNodeState) error { panic("not implemented") } -func (t *workflowNodeStateHolder) PutWorkflowNodeState(s handler.WorkflowNodeState) error { +func (t *workflowNodeStateHolder) PutWorkflowNodeState(s interfaces.WorkflowNodeState) error { t.s = s return nil } -func (t workflowNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) error { +func (t workflowNodeStateHolder) PutDynamicNodeState(s interfaces.DynamicNodeState) error { + panic("not implemented") +} + +func (t workflowNodeStateHolder) PutGateNodeState(s interfaces.GateNodeState) error { panic("not implemented") } -func (t workflowNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { +func (t workflowNodeStateHolder) PutArrayNodeState(s interfaces.ArrayNodeState) error { panic("not implemented") } @@ -68,7 +76,7 @@ var wfExecID = &core.WorkflowExecutionIdentifier{ func createNodeContextWithVersion(phase v1alpha1.WorkflowNodePhase, n v1alpha1.ExecutableNode, s v1alpha1.ExecutableNodeStatus, version v1alpha1.EventVersion) *mocks3.NodeExecutionContext { - wfNodeState := handler.WorkflowNodeState{} + wfNodeState := interfaces.WorkflowNodeState{} state := &workflowNodeStateHolder{s: wfNodeState} nm := &mocks3.NodeExecutionMetadata{} @@ -101,7 +109,7 @@ func createNodeContextWithVersion(phase v1alpha1.WorkflowNodePhase, n v1alpha1.E nCtx.OnNodeStatus().Return(s) nr := &mocks3.NodeStateReader{} - nr.OnGetWorkflowNodeState().Return(handler.WorkflowNodeState{ + nr.OnGetWorkflowNodeState().Return(interfaces.WorkflowNodeState{ Phase: phase, }) nCtx.OnNodeStateReader().Return(nr) diff --git a/pkg/controller/nodes/subworkflow/launchplan_test.go b/pkg/controller/nodes/subworkflow/launchplan_test.go index 4c8f6806e..99d75e8a0 100644 --- a/pkg/controller/nodes/subworkflow/launchplan_test.go +++ b/pkg/controller/nodes/subworkflow/launchplan_test.go @@ -7,7 +7,7 @@ import ( "testing" mocks4 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" - mocks3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + mocks3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" diff --git a/pkg/controller/nodes/subworkflow/subworkflow_test.go b/pkg/controller/nodes/subworkflow/subworkflow_test.go index 50840776f..def32c198 100644 --- a/pkg/controller/nodes/subworkflow/subworkflow_test.go +++ b/pkg/controller/nodes/subworkflow/subworkflow_test.go @@ -13,7 +13,7 @@ import ( coreMocks "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" execMocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + 
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" ) func TestGetSubWorkflow(t *testing.T) { @@ -87,7 +87,7 @@ func Test_subworkflowHandler_HandleAbort(t *testing.T) { nCtx.OnNodeStatus().Return(ns) nCtx.OnNodeID().Return("n1") - nodeExec := &execMocks.Node{} + nodeExec := &mocks.Node{} s := newSubworkflowHandler(nodeExec, eventConfig) n := &coreMocks.ExecutableNode{} swf.OnGetID().Return("swf") @@ -120,7 +120,7 @@ func Test_subworkflowHandler_HandleAbort(t *testing.T) { nCtx.OnNodeID().Return("n1") nCtx.OnCurrentAttempt().Return(uint32(1)) - nodeExec := &execMocks.Node{} + nodeExec := &mocks.Node{} s := newSubworkflowHandler(nodeExec, eventConfig) n := &coreMocks.ExecutableNode{} swf.OnGetID().Return("swf") @@ -154,7 +154,7 @@ func Test_subworkflowHandler_HandleAbort(t *testing.T) { nCtx.OnNodeID().Return("n1") nCtx.OnCurrentAttempt().Return(uint32(1)) - nodeExec := &execMocks.Node{} + nodeExec := &mocks.Node{} s := newSubworkflowHandler(nodeExec, eventConfig) n := &coreMocks.ExecutableNode{} swf.OnGetID().Return("swf") diff --git a/pkg/controller/nodes/task/handler_test.go b/pkg/controller/nodes/task/handler_test.go index 6da4762cb..f002d60ae 100644 --- a/pkg/controller/nodes/task/handler_test.go +++ b/pkg/controller/nodes/task/handler_test.go @@ -10,9 +10,6 @@ import ( "github.com/golang/protobuf/proto" eventsErr "github.com/flyteorg/flytepropeller/events/errors" - mocks2 "github.com/flyteorg/flytepropeller/events/mocks" - - "github.com/flyteorg/flytepropeller/events" pluginK8sMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" @@ -49,7 +46,8 @@ import ( flyteMocks "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/codex" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/fakeplugins" @@ -352,37 +350,48 @@ func Test_task_ResolvePlugin(t *testing.T) { } } -type fakeBufferedTaskEventRecorder struct { +type fakeBufferedEventRecorder struct { evs []*event.TaskExecutionEvent } -func (f *fakeBufferedTaskEventRecorder) RecordTaskEvent(ctx context.Context, ev *event.TaskExecutionEvent, eventConfig *controllerConfig.EventConfig) error { +func (f *fakeBufferedEventRecorder) RecordTaskEvent(ctx context.Context, ev *event.TaskExecutionEvent, eventConfig *controllerConfig.EventConfig) error { f.evs = append(f.evs, ev) return nil } +func (f *fakeBufferedEventRecorder) RecordNodeEvent(ctx context.Context, ev *event.NodeExecutionEvent, eventConfig *controllerConfig.EventConfig) error { + return nil +} + type taskNodeStateHolder struct { - s handler.TaskNodeState + s interfaces.TaskNodeState } -func (t *taskNodeStateHolder) PutTaskNodeState(s handler.TaskNodeState) error { +func (t *taskNodeStateHolder) ClearNodeStatus() { +} + +func (t *taskNodeStateHolder) PutTaskNodeState(s interfaces.TaskNodeState) error { t.s = s return nil } -func (t taskNodeStateHolder) PutBranchNode(s handler.BranchNodeState) error { +func (t taskNodeStateHolder) PutBranchNode(s interfaces.BranchNodeState) error { + panic("not implemented") +} + 
+func (t taskNodeStateHolder) PutWorkflowNodeState(s interfaces.WorkflowNodeState) error { panic("not implemented") } -func (t taskNodeStateHolder) PutWorkflowNodeState(s handler.WorkflowNodeState) error { +func (t taskNodeStateHolder) PutDynamicNodeState(s interfaces.DynamicNodeState) error { panic("not implemented") } -func (t taskNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) error { +func (t taskNodeStateHolder) PutGateNodeState(s interfaces.GateNodeState) error { panic("not implemented") } -func (t taskNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { +func (t taskNodeStateHolder) PutArrayNodeState(s interfaces.ArrayNodeState) error { panic("not implemented") } @@ -399,7 +408,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { "foo": coreutils.MustMakeLiteral("bar"), }, } - createNodeContext := func(pluginPhase pluginCore.Phase, pluginVer uint32, pluginResp fakeplugins.NextPhaseState, recorder events.TaskEventRecorder, ttype string, s *taskNodeStateHolder, allowIncrementParallelism bool) *nodeMocks.NodeExecutionContext { + createNodeContext := func(pluginPhase pluginCore.Phase, pluginVer uint32, pluginResp fakeplugins.NextPhaseState, recorder interfaces.EventRecorder, ttype string, s *taskNodeStateHolder, allowIncrementParallelism bool) *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -510,7 +519,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { cod := codex.GobStateCodec{} assert.NoError(t, cod.Encode(pluginResp, st)) nr := &nodeMocks.NodeStateReader{} - nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ + nr.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ PluginState: st.Bytes(), PluginPhase: pluginPhase, PluginPhaseVersion: pluginVer, @@ -690,7 +699,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { state := &taskNodeStateHolder{} - ev := &fakeBufferedTaskEventRecorder{} + ev := &fakeBufferedEventRecorder{} nCtx := createNodeContext(tt.args.startingPluginPhase, uint32(tt.args.startingPluginPhaseVersion), tt.args.expectedState, ev, "test", state, tt.want.incrParallel) c := &pluginCatalogMocks.Client{} tk := Handler{ @@ -758,7 +767,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { func Test_task_Handle_Catalog(t *testing.T) { - createNodeContext := func(recorder events.TaskEventRecorder, ttype string, s *taskNodeStateHolder, overwriteCache bool) *nodeMocks.NodeExecutionContext { + createNodeContext := func(recorder interfaces.EventRecorder, ttype string, s *taskNodeStateHolder, overwriteCache bool) *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -858,7 +867,7 @@ func Test_task_Handle_Catalog(t *testing.T) { OutputExists: true, }, st)) nr := &nodeMocks.NodeStateReader{} - nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ + nr.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ PluginState: st.Bytes(), }) nCtx.OnNodeStateReader().Return(nr) @@ -950,7 +959,7 @@ func Test_task_Handle_Catalog(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { state := &taskNodeStateHolder{} - ev := &fakeBufferedTaskEventRecorder{} + ev := &fakeBufferedEventRecorder{} nCtx := createNodeContext(ev, "test", state, tt.args.catalogSkip) c := &pluginCatalogMocks.Client{} if tt.args.catalogFetch { @@ -1018,7 +1027,7 @@ func Test_task_Handle_Catalog(t *testing.T) { func Test_task_Handle_Reservation(t *testing.T) { - 
createNodeContext := func(recorder events.TaskEventRecorder, ttype string, s *taskNodeStateHolder, overwriteCache bool) *nodeMocks.NodeExecutionContext { + createNodeContext := func(recorder interfaces.EventRecorder, ttype string, s *taskNodeStateHolder, overwriteCache bool) *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -1206,7 +1215,7 @@ func Test_task_Handle_Reservation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { state := &taskNodeStateHolder{} - ev := &fakeBufferedTaskEventRecorder{} + ev := &fakeBufferedEventRecorder{} nCtx := createNodeContext(ev, "test", state, tt.args.catalogSkip) c := &pluginCatalogMocks.Client{} nr := &nodeMocks.NodeStateReader{} @@ -1216,7 +1225,7 @@ func Test_task_Handle_Reservation(t *testing.T) { Phase: pluginCore.PhaseSuccess, OutputExists: true, }, st)) - nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ + nr.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ PluginPhase: tt.args.pluginPhase, PluginState: st.Bytes(), }) @@ -1269,7 +1278,7 @@ func Test_task_Handle_Reservation(t *testing.T) { } func Test_task_Abort(t *testing.T) { - createNodeCtx := func(ev events.TaskEventRecorder) *nodeMocks.NodeExecutionContext { + createNodeCtx := func(ev interfaces.EventRecorder) *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -1346,7 +1355,7 @@ func Test_task_Abort(t *testing.T) { cod := codex.GobStateCodec{} assert.NoError(t, cod.Encode(test{A: a}, st)) nr := &nodeMocks.NodeStateReader{} - nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ + nr.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ PluginState: st.Bytes(), }) nCtx.OnNodeStateReader().Return(nr) @@ -1355,7 +1364,7 @@ func Test_task_Abort(t *testing.T) { noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope()) - incompatibleClusterEventsRecorder := mocks2.TaskEventRecorder{} + incompatibleClusterEventsRecorder := nodeMocks.EventRecorder{} incompatibleClusterEventsRecorder.OnRecordTaskEventMatch(mock.Anything, mock.Anything, mock.Anything).Return( &eventsErr.EventError{ Code: eventsErr.EventIncompatibleCusterError, @@ -1365,7 +1374,7 @@ func Test_task_Abort(t *testing.T) { defaultPluginCallback func() pluginCore.Plugin } type args struct { - ev events.TaskEventRecorder + ev interfaces.EventRecorder } tests := []struct { name string @@ -1391,7 +1400,7 @@ func Test_task_Abort(t *testing.T) { p.OnGetProperties().Return(pluginCore.PluginProperties{}) p.On("Abort", mock.Anything, mock.Anything).Return(nil) return p - }}, args{ev: &fakeBufferedTaskEventRecorder{}}, false, true}, + }}, args{ev: &fakeBufferedEventRecorder{}}, false, true}, {"abort-swallows-incompatible-cluster-err", fields{defaultPluginCallback: func() pluginCore.Plugin { p := &pluginCoreMocks.Plugin{} p.On("GetID").Return("id") @@ -1416,10 +1425,10 @@ func Test_task_Abort(t *testing.T) { c = 1 if !tt.wantErr { switch tt.args.ev.(type) { - case *fakeBufferedTaskEventRecorder: - assert.Len(t, tt.args.ev.(*fakeBufferedTaskEventRecorder).evs, 1) - case *mocks2.TaskEventRecorder: - assert.Len(t, tt.args.ev.(*mocks2.TaskEventRecorder).Calls, 1) + case *fakeBufferedEventRecorder: + assert.Len(t, tt.args.ev.(*fakeBufferedEventRecorder).evs, 1) + case *nodeMocks.EventRecorder: + assert.Len(t, tt.args.ev.(*nodeMocks.EventRecorder).Calls, 1) } } } @@ -1431,7 +1440,7 @@ func Test_task_Abort(t *testing.T) { } func Test_task_Abort_v1(t *testing.T) { 
- createNodeCtx := func(ev events.TaskEventRecorder) *nodeMocks.NodeExecutionContext { + createNodeCtx := func(ev interfaces.EventRecorder) *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -1508,7 +1517,7 @@ func Test_task_Abort_v1(t *testing.T) { cod := codex.GobStateCodec{} assert.NoError(t, cod.Encode(test{A: a}, st)) nr := &nodeMocks.NodeStateReader{} - nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ + nr.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ PluginState: st.Bytes(), }) nCtx.OnNodeStateReader().Return(nr) @@ -1517,7 +1526,7 @@ func Test_task_Abort_v1(t *testing.T) { noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope()) - incompatibleClusterEventsRecorder := mocks2.TaskEventRecorder{} + incompatibleClusterEventsRecorder := nodeMocks.EventRecorder{} incompatibleClusterEventsRecorder.OnRecordTaskEventMatch(mock.Anything, mock.Anything, mock.Anything).Return( &eventsErr.EventError{ Code: eventsErr.EventIncompatibleCusterError, @@ -1527,7 +1536,7 @@ func Test_task_Abort_v1(t *testing.T) { defaultPluginCallback func() pluginCore.Plugin } type args struct { - ev events.TaskEventRecorder + ev interfaces.EventRecorder } tests := []struct { name string @@ -1553,7 +1562,7 @@ func Test_task_Abort_v1(t *testing.T) { p.OnGetProperties().Return(pluginCore.PluginProperties{}) p.On("Abort", mock.Anything, mock.Anything).Return(nil) return p - }}, args{ev: &fakeBufferedTaskEventRecorder{}}, false, true}, + }}, args{ev: &fakeBufferedEventRecorder{}}, false, true}, {"abort-swallows-incompatible-cluster-err", fields{defaultPluginCallback: func() pluginCore.Plugin { p := &pluginCoreMocks.Plugin{} p.On("GetID").Return("id") @@ -1578,10 +1587,10 @@ func Test_task_Abort_v1(t *testing.T) { c = 1 if !tt.wantErr { switch tt.args.ev.(type) { - case *fakeBufferedTaskEventRecorder: - assert.Len(t, tt.args.ev.(*fakeBufferedTaskEventRecorder).evs, 1) - case *mocks2.TaskEventRecorder: - assert.Len(t, tt.args.ev.(*mocks2.TaskEventRecorder).Calls, 1) + case *fakeBufferedEventRecorder: + assert.Len(t, tt.args.ev.(*fakeBufferedEventRecorder).evs, 1) + case *nodeMocks.EventRecorder: + assert.Len(t, tt.args.ev.(*nodeMocks.EventRecorder).Calls, 1) } } } @@ -1694,7 +1703,7 @@ func Test_task_Finalize(t *testing.T) { cod := codex.GobStateCodec{} assert.NoError(t, cod.Encode(test{A: a}, st)) nr := &nodeMocks.NodeStateReader{} - nr.On("GetTaskNodeState").Return(handler.TaskNodeState{ + nr.On("GetTaskNodeState").Return(interfaces.TaskNodeState{ PluginState: st.Bytes(), }) nCtx.On("NodeStateReader").Return(nr) diff --git a/pkg/controller/nodes/task/setup_ctx_test.go b/pkg/controller/nodes/task/setup_ctx_test.go index 6b8d0a438..e987bbb24 100644 --- a/pkg/controller/nodes/task/setup_ctx_test.go +++ b/pkg/controller/nodes/task/setup_ctx_test.go @@ -4,13 +4,13 @@ import ( "testing" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytestdlib/promutils" "github.com/stretchr/testify/assert" ) type dummySetupCtx struct { - handler.SetupContext + interfaces.SetupContext testScopeName string } diff --git a/pkg/controller/nodes/task/taskexec_context_test.go b/pkg/controller/nodes/task/taskexec_context_test.go index bea082871..1bc368652 100644 --- a/pkg/controller/nodes/task/taskexec_context_test.go +++ 
b/pkg/controller/nodes/task/taskexec_context_test.go @@ -30,8 +30,8 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" flyteMocks "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/codex" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" ) @@ -113,7 +113,7 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { codex := codex.GobStateCodec{} assert.NoError(t, codex.Encode(test{A: a}, st)) nr := &nodeMocks.NodeStateReader{} - nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ + nr.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ PluginState: st.Bytes(), }) nCtx.OnNodeStateReader().Return(nr) @@ -432,7 +432,7 @@ func TestComputePreviousCheckpointPath(t *testing.T) { nCtx.OnDataStore().Return(ds) nCtx.OnNodeExecutionMetadata().Return(nm) reader := &nodeMocks.NodeStateReader{} - reader.OnGetTaskNodeState().Return(handler.TaskNodeState{}) + reader.OnGetTaskNodeState().Return(interfaces.TaskNodeState{}) nCtx.OnNodeStateReader().Return(reader) t.Run("attempt-0-nCtx", func(t *testing.T) { @@ -464,7 +464,7 @@ func TestComputePreviousCheckpointPath_Recovery(t *testing.T) { nCtx.OnDataStore().Return(ds) nCtx.OnNodeExecutionMetadata().Return(nm) reader := &nodeMocks.NodeStateReader{} - reader.OnGetTaskNodeState().Return(handler.TaskNodeState{ + reader.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ PreviousNodeExecutionCheckpointURI: storage.DataReference("s3://sandbox/x/prevname-n1-0/_flytecheckpoints"), }) nCtx.OnNodeStateReader().Return(reader) diff --git a/pkg/controller/nodes/task/transformer_test.go b/pkg/controller/nodes/task/transformer_test.go index f8a6f1289..d9400336e 100644 --- a/pkg/controller/nodes/task/transformer_test.go +++ b/pkg/controller/nodes/task/transformer_test.go @@ -24,7 +24,7 @@ import ( pluginMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - handlerMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + nodemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" ) const containerTaskType = "container" @@ -60,7 +60,7 @@ func TestToTaskExecutionEvent(t *testing.T) { const outputPath = "out" out.On("GetOutputPath").Return(storage.DataReference(outputPath)) - nodeExecutionMetadata := handlerMocks.NodeExecutionMetadata{} + nodeExecutionMetadata := nodemocks.NodeExecutionMetadata{} nodeExecutionMetadata.OnIsInterruptible().Return(true) mockExecContext := &mocks2.ExecutionContext{} @@ -158,7 +158,7 @@ func TestToTaskExecutionEvent(t *testing.T) { assert.EqualValues(t, resourcePoolInfo, tev.Metadata.ResourcePoolInfo) assert.Equal(t, testClusterID, tev.ProducerId) - defaultNodeExecutionMetadata := handlerMocks.NodeExecutionMetadata{} + defaultNodeExecutionMetadata := nodemocks.NodeExecutionMetadata{} defaultNodeExecutionMetadata.OnIsInterruptible().Return(false) tev, err = ToTaskExecutionEvent(ToTaskExecutionEventInputs{ TaskExecContext: tCtx, @@ -251,7 +251,7 @@ func TestToTaskExecutionEventWithParent(t *testing.T) { const outputPath = "out" 
out.On("GetOutputPath").Return(storage.DataReference(outputPath)) - nodeExecutionMetadata := handlerMocks.NodeExecutionMetadata{} + nodeExecutionMetadata := nodemocks.NodeExecutionMetadata{} nodeExecutionMetadata.OnIsInterruptible().Return(true) mockExecContext := &mocks2.ExecutionContext{} diff --git a/pkg/controller/workflow/executor_test.go b/pkg/controller/workflow/executor_test.go index 47cf83ebe..f4b9236b9 100644 --- a/pkg/controller/workflow/executor_test.go +++ b/pkg/controller/workflow/executor_test.go @@ -32,6 +32,7 @@ import ( eventsErr "github.com/flyteorg/flytepropeller/events/errors" eventMocks "github.com/flyteorg/flytepropeller/events/mocks" mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" + nodemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/fakeplugins" @@ -726,7 +727,7 @@ func TestWorkflowExecutor_HandleAbortedWorkflow(t *testing.T) { t.Run("user-initiated-fail", func(t *testing.T) { - nodeExec := &mocks2.Node{} + nodeExec := &nodemocks.Node{} wExec := &workflowExecutor{ nodeExecutor: nodeExec, metrics: newMetrics(promutils.NewTestScope()), @@ -756,7 +757,7 @@ func TestWorkflowExecutor_HandleAbortedWorkflow(t *testing.T) { t.Run("user-initiated-success", func(t *testing.T) { var evs []*event.WorkflowExecutionEvent - nodeExec := &mocks2.Node{} + nodeExec := &nodemocks.Node{} wfRecorder := &eventMocks.WorkflowEventRecorder{} wfRecorder.On("RecordWorkflowEvent", mock.Anything, mock.MatchedBy(func(ev *event.WorkflowExecutionEvent) bool { assert.Equal(t, testClusterID, ev.ProducerId) @@ -798,7 +799,7 @@ func TestWorkflowExecutor_HandleAbortedWorkflow(t *testing.T) { t.Run("user-initiated-attempts-exhausted", func(t *testing.T) { var evs []*event.WorkflowExecutionEvent - nodeExec := &mocks2.Node{} + nodeExec := &nodemocks.Node{} wfRecorder := &eventMocks.WorkflowEventRecorder{} wfRecorder.OnRecordWorkflowEventMatch(mock.Anything, mock.MatchedBy(func(ev *event.WorkflowExecutionEvent) bool { assert.Equal(t, testClusterID, ev.ProducerId) @@ -839,7 +840,7 @@ func TestWorkflowExecutor_HandleAbortedWorkflow(t *testing.T) { t.Run("failure-abort-success", func(t *testing.T) { var evs []*event.WorkflowExecutionEvent - nodeExec := &mocks2.Node{} + nodeExec := &nodemocks.Node{} wfRecorder := &eventMocks.WorkflowEventRecorder{} wfRecorder.OnRecordWorkflowEventMatch(mock.Anything, mock.MatchedBy(func(ev *event.WorkflowExecutionEvent) bool { assert.Equal(t, testClusterID, ev.ProducerId) @@ -877,7 +878,7 @@ func TestWorkflowExecutor_HandleAbortedWorkflow(t *testing.T) { t.Run("failure-abort-failed", func(t *testing.T) { - nodeExec := &mocks2.Node{} + nodeExec := &nodemocks.Node{} wExec := &workflowExecutor{ nodeExecutor: nodeExec, metrics: newMetrics(promutils.NewTestScope()), From a38ec17d0c0b34992acdadd781aee46892436864 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 15 Jun 2023 17:47:23 -0500 Subject: [PATCH 43/62] unit tests passing Signed-off-by: Daniel Rammer --- pkg/controller/nodes/executor.go | 79 ++++++++++++------------ pkg/controller/nodes/executor_test.go | 9 +-- pkg/controller/workflow/executor_test.go | 48 +++++++++++--- 3 files changed, 85 insertions(+), 51 deletions(-) diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index 7781fb90d..bfc5691c5 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -409,6 +409,45 @@ 
func (c *recursiveNodeExecutor) AbortHandler(ctx context.Context, execContext ex
 		if err != nil {
 			return err
 		}
+
+		// TODO @hamersaw - fix this; we shouldn't need to decompose the nodeExecutor just to send this event
+		if nodeExec, ok := c.nodeExecutor.(*nodeExecutor); ok {
+			nodeExecutionID := &core.NodeExecutionIdentifier{
+				ExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().ExecutionId,
+				NodeId:      nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId,
+			}
+			if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 {
+				currentNodeUniqueID, err := common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nodeExecutionID.NodeId)
+				if err != nil {
+					return err
+				}
+				nodeExecutionID.NodeId = currentNodeUniqueID
+			}
+
+			//err := c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{
+			err = nCtx.EventsRecorder().RecordNodeEvent(ctx, &event.NodeExecutionEvent{
+				Id:         nodeExecutionID,
+				Phase:      core.NodeExecution_ABORTED,
+				OccurredAt: ptypes.TimestampNow(),
+				OutputResult: &event.NodeExecutionEvent_Error{
+					Error: &core.ExecutionError{
+						Code:    "NodeAborted",
+						Message: reason,
+					},
+				},
+				ProducerId: nodeExec.clusterID,
+				ReportedAt: ptypes.TimestampNow(),
+			}, nodeExec.eventConfig)
+			if err != nil && !eventsErr.IsNotFound(err) && !eventsErr.IsEventIncompatibleClusterError(err) {
+				if errors2.IsCausedBy(err, errors.IllegalStateError) {
+					logger.Debugf(ctx, "Failed to record abort event due to illegal state transition. Ignoring the error. Error: %v", err)
+				} else {
+					logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error())
+					return errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event")
+				}
+			}
+		}
+
 		return nil
 	} else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered {
 		// Abort downstream nodes
 		downstreamNodes, err := dag.FromNode(currentNode.GetID())
@@ -894,45 +933,7 @@ func (c *nodeExecutor) Abort(ctx context.Context, h interfaces.NodeHandler, nCtx
 		return err
 	}
-	if err := h.Finalize(ctx, nCtx); err != nil {
-		return err
-	}
-
-	nodeExecutionID := &core.NodeExecutionIdentifier{
-		ExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().ExecutionId,
-		NodeId:      nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId,
-	}
-	if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 {
-		currentNodeUniqueID, err := common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nodeExecutionID.NodeId)
-		if err != nil {
-			return err
-		}
-		nodeExecutionID.NodeId = currentNodeUniqueID
-	}
-
-	//err := c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{
-	err := nCtx.EventsRecorder().RecordNodeEvent(ctx, &event.NodeExecutionEvent{
-		Id:         nodeExecutionID,
-		Phase:      core.NodeExecution_ABORTED,
-		OccurredAt: ptypes.TimestampNow(),
-		OutputResult: &event.NodeExecutionEvent_Error{
-			Error: &core.ExecutionError{
-				Code:    "NodeAborted",
-				Message: reason,
-			},
-		},
-		ProducerId: c.clusterID,
-		ReportedAt: ptypes.TimestampNow(),
-	}, c.eventConfig)
-	if err != nil && !eventsErr.IsNotFound(err) && !eventsErr.IsEventIncompatibleClusterError(err) {
-		if errors2.IsCausedBy(err, errors.IllegalStateError) {
-			logger.Debugf(ctx, "Failed to record abort event due to illegal state transition. Ignoring the error. 
Error: %v", err) - } else { - logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") - } - } - return nil + return h.Finalize(ctx, nCtx) } func (c *nodeExecutor) Finalize(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext) error { diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index 9605bd73c..f3f759881 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -1655,16 +1655,17 @@ func TestNodeExecutor_AbortHandler(t *testing.T) { incompatibleClusterErr := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} hf := &nodemocks.HandlerFactory{} - exec.nodeHandlerFactory = hf h := &nodemocks.NodeHandler{} h.OnAbortMatch(mock.Anything, mock.Anything, "aborting").Return(nil) h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(nil) hf.OnGetHandlerMatch(v1alpha1.NodeKindStart).Return(h, nil) + nodeExecutor := &nodeExecutor{ + nodeRecorder: incompatibleClusterErr, + } nExec := recursiveNodeExecutor{ - nodeExecutor: &nodeExecutor{ - nodeRecorder: incompatibleClusterErr, - }, + nodeExecutor: nodeExecutor, + nCtxBuilder: nodeExecutor, nodeHandlerFactory: hf, } diff --git a/pkg/controller/workflow/executor_test.go b/pkg/controller/workflow/executor_test.go index f4b9236b9..e947918f4 100644 --- a/pkg/controller/workflow/executor_test.go +++ b/pkg/controller/workflow/executor_test.go @@ -32,6 +32,8 @@ import ( eventsErr "github.com/flyteorg/flytepropeller/events/errors" eventMocks "github.com/flyteorg/flytepropeller/events/mocks" mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/factory" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" nodemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/fakeplugins" @@ -244,10 +246,13 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) { catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) recoveryClient := &recoveryMocks.Client{} - adminClient := launchplan.NewFailFastLaunchPlanExecutor() + + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + assert.NoError(t, err) + nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) assert.NoError(t, err) @@ -324,10 +329,13 @@ func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) recoveryClient := &recoveryMocks.Client{} - adminClient := 
launchplan.NewFailFastLaunchPlanExecutor() + + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + assert.NoError(t, err) + nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) @@ -389,8 +397,9 @@ func BenchmarkWorkflowExecutor(b *testing.B) { assert.NoError(b, err) recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() + handlerFactory := &nodemocks.HandlerFactory{} nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, scope) + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, scope) assert.NoError(b, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) @@ -490,8 +499,19 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { assert.NoError(t, err) recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() + + h := &nodemocks.NodeHandler{} + h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + h.OnHandleMatch(mock.Anything, mock.Anything).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil) + h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(nil) + h.OnFinalizeRequired().Return(false) + + handlerFactory := &nodemocks.HandlerFactory{} + handlerFactory.OnSetupMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + handlerFactory.OnGetHandlerMatch(mock.Anything).Return(h, nil) + nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) assert.NoError(t, err) @@ -586,8 +606,12 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { assert.NoError(t, err) adminClient := launchplan.NewFailFastLaunchPlanExecutor() recoveryClient := &recoveryMocks.Client{} + + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + assert.NoError(t, err) + nodeExec, err 
:= nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "metadata", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) assert.NoError(t, err) @@ -644,8 +668,16 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_EventFailure(t *testing.T) { recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() + h := &nodemocks.NodeHandler{} + h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + h.OnHandleMatch(mock.Anything, mock.Anything).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil) + h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(nil) + h.OnFinalizeRequired().Return(false) + handlerFactory := &nodemocks.HandlerFactory{} + handlerFactory.OnSetupMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + handlerFactory.OnGetHandlerMatch(mock.Anything).Return(h, nil) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, nodeEventSink, adminClient, adminClient, - maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) t.Run("EventAlreadyInTerminalStateError", func(t *testing.T) { From db1e3613a83494f124a24d63d6c8236195d069c1 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 15 Jun 2023 21:34:19 -0500 Subject: [PATCH 44/62] fixed lint issues Signed-off-by: Daniel Rammer --- pkg/apis/flyteworkflow/v1alpha1/array.go | 5 +- .../flyteworkflow/v1alpha1/node_status.go | 2 +- pkg/compiler/transformers/k8s/node.go | 6 +- .../nodes/array/execution_context.go | 3 +- pkg/controller/nodes/array/handler.go | 46 +++--- pkg/controller/nodes/array/handler_test.go | 134 +++++++++--------- .../nodes/array/node_execution_context.go | 3 +- .../array/node_execution_context_builder.go | 2 +- pkg/controller/nodes/array/node_lookup.go | 8 +- pkg/controller/nodes/array/utils.go | 6 +- pkg/controller/nodes/executor.go | 44 +++--- pkg/controller/nodes/executor_test.go | 16 +-- .../nodes/factory/handler_factory.go | 8 +- pkg/controller/nodes/node_exec_context.go | 13 +- 14 files changed, 151 insertions(+), 145 deletions(-) diff --git a/pkg/apis/flyteworkflow/v1alpha1/array.go b/pkg/apis/flyteworkflow/v1alpha1/array.go index 8d47a8990..6680e7410 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/array.go +++ b/pkg/apis/flyteworkflow/v1alpha1/array.go @@ -1,10 +1,7 @@ package v1alpha1 -import ( -) - type ArrayNodeSpec struct { - SubNodeSpec *NodeSpec + SubNodeSpec *NodeSpec Parallelism uint32 MinSuccesses *uint32 MinSuccessRatio *float32 diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go index 5a4d9fc8a..ea95c3b56 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/node_status.go +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -11,8 +11,8 @@ import ( 
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/bitarray" - "github.com/flyteorg/flytestdlib/storage" "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/storage" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) diff --git a/pkg/compiler/transformers/k8s/node.go b/pkg/compiler/transformers/k8s/node.go index fab0bfcd2..4fcbd6e14 100644 --- a/pkg/compiler/transformers/k8s/node.go +++ b/pkg/compiler/transformers/k8s/node.go @@ -172,10 +172,10 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile } switch successCriteria := arrayNode.SuccessCriteria.(type) { - case *core.ArrayNode_MinSuccesses: - nodeSpec.ArrayNode.MinSuccesses = &successCriteria.MinSuccesses; + case *core.ArrayNode_MinSuccesses: + nodeSpec.ArrayNode.MinSuccesses = &successCriteria.MinSuccesses case *core.ArrayNode_MinSuccessRatio: - nodeSpec.ArrayNode.MinSuccessRatio = &successCriteria.MinSuccessRatio; + nodeSpec.ArrayNode.MinSuccessRatio = &successCriteria.MinSuccessRatio } // TODO @hamersaw hack - should not be necessary, should be set in flytekit diff --git a/pkg/controller/nodes/array/execution_context.go b/pkg/controller/nodes/array/execution_context.go index b5acd384f..6731ec55b 100644 --- a/pkg/controller/nodes/array/execution_context.go +++ b/pkg/controller/nodes/array/execution_context.go @@ -27,7 +27,7 @@ func (a *arrayExecutionContext) CurrentParallelism() uint32 { } func (a *arrayExecutionContext) IncrementParallelism() uint32 { - *a.currentParallelism = *a.currentParallelism+1 + *a.currentParallelism = *a.currentParallelism + 1 return *a.currentParallelism } @@ -38,7 +38,6 @@ func newArrayExecutionContext(executionContext executors.ExecutionContext, subNo } executionConfig.EnvironmentVariables[JobIndexVarName] = FlyteK8sArrayIndexVarName executionConfig.EnvironmentVariables[FlyteK8sArrayIndexVarName] = strconv.Itoa(subNodeIndex) - executionConfig.MaxParallelism = maxParallelism return &arrayExecutionContext{ diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 24ebc04d9..a1d7c0223 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -86,6 +86,9 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut // create array contexts arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, _, err := a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, ¤tParallelism) + if err != nil { + return err + } // abort subNode err = arrayNodeExecutor.AbortHandler(ctx, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, reason) @@ -141,6 +144,9 @@ func (a *arrayNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExe // create array contexts arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, _, err := a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, ¤tParallelism) + if err != nil { + return err + } // finalize subNode err = arrayNodeExecutor.FinalizeHandler(ctx, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec) @@ -160,7 +166,7 @@ func (a *arrayNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExe // FinalizeRequired defines whether or not this handler requires finalize to be called on node // completion func (a *arrayNodeHandler) FinalizeRequired() bool { - // must return true because we can't determine if finalize is required for the subNode + // must 
// return true because we can't determine if finalize is required for the subNode
	return true
}

@@ -213,14 +219,17 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
		maxAttempts = *subNodeSpec.GetRetryStrategy().MinAttempts
	}

-	for _, item := range []struct{arrayReference *bitarray.CompactArray; maxValue int}{
+	for _, item := range []struct {
+		arrayReference *bitarray.CompactArray
+		maxValue int
+	}{
		// we use NodePhaseRecovered for the `maxValue` of `SubNodePhases` because `Phase` is
		// defined as an `iota` so it is impossible to programmatically get largest value
-		{arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)},
-		{arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases)-1},
-		{arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: maxAttempts},
-		{arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: maxAttempts},
-	} {
+		{arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)},
+		{arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases) - 1},
+		{arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: maxAttempts},
+		{arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: maxAttempts},
+	} {
	*item.arrayReference, err = bitarray.NewCompactArray(uint(size), bitarray.Item(item.maxValue))
	if err != nil {
@@ -258,6 +267,9 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
		// create array contexts
		arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, subNodeStatus, arrayEventRecorder, err := a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, &currentParallelism)
+		if err != nil {
+			return handler.UnknownTransition, err
+		}
		// execute subNode
		_, err = arrayNodeExecutor.RecursiveNodeHandler(ctx, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec)
@@ -344,7 +356,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
			}
		}

-		if len(arrayNodeState.SubNodePhases.GetItems()) - failedCount < minSuccesses {
+		if len(arrayNodeState.SubNodePhases.GetItems())-failedCount < minSuccesses {
			// no chance to reach the minimum number of successes
			arrayNodeState.Phase = v1alpha1.ArrayNodePhaseFailing
		} else if successCount >= minSuccesses && runningCount == 0 {
@@ -385,8 +397,8 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
				outputVariables = task.CoreTask().Interface.Outputs.Variables
			}

-			// append nil literal for all ouput variables
-			for name, _ := range outputVariables {
+			// append nil literal for all output variables
+			for name := range outputVariables {
				appendLiteral(name, nilLiteral, outputLiterals, len(arrayNodeState.SubNodePhases.GetItems()))
			}
		} else {
@@ -457,11 +469,11 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
	// need to increment taskPhaseVersion if arrayNodeState.Phase does not change, otherwise
	// reset to 0. by incrementing this always we report an event and ensure processing
	// every time the ArrayNode is evaluated. if this overhead becomes too large, we will need
-	// to revisit and only increment when any subNode state changes. 
+	// to revisit and only increment when any subNode state changes.
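	// (editor's note, not part of the original patch) the upshot of the comment
	// above is that the reported (phase, version) pair always advances, e.g.
	// RUNNING/0 -> RUNNING/1 -> RUNNING/2 -> SUCCEEDED/0, so an event repeating an
	// already-seen pair can safely be treated as a duplicate by consumers.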
if currentArrayNodePhase != arrayNodeState.Phase { arrayNodeState.TaskPhaseVersion = 0 } else { - arrayNodeState.TaskPhaseVersion = taskPhaseVersion+1 + arrayNodeState.TaskPhaseVersion = taskPhaseVersion + 1 } taskExecutionEvent, err := buildTaskExecutionEvent(ctx, nCtx, taskPhase, taskPhaseVersion, externalResources) @@ -565,13 +577,13 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter } subNodeStatus := &v1alpha1.NodeStatus{ - Phase: nodePhase, - DataDir: subDataDir, - OutputDir: subOutputDir, - Attempts: currentAttempt, + Phase: nodePhase, + DataDir: subDataDir, + OutputDir: subOutputDir, + Attempts: currentAttempt, SystemFailures: uint32(arrayNodeState.SubNodeSystemFailures.GetItem(subNodeIndex)), TaskNodeStatus: &v1alpha1.TaskNodeStatus{ - Phase: taskPhase, + Phase: taskPhase, PluginState: pluginStateBytes, }, } diff --git a/pkg/controller/nodes/array/handler_test.go b/pkg/controller/nodes/array/handler_test.go index e2858a9b3..e07cbaadb 100644 --- a/pkg/controller/nodes/array/handler_test.go +++ b/pkg/controller/nodes/array/handler_test.go @@ -35,7 +35,7 @@ import ( ) var ( - taskRef = "taskRef" + taskRef = "taskRef" arrayNodeSpec = v1alpha1.NodeSpec{ ID: "foo", ArrayNode: &v1alpha1.ArrayNodeSpec{ @@ -46,7 +46,7 @@ var ( } ) -func createArrayNodeHandler(t *testing.T, ctx context.Context, nodeHandler interfaces.NodeHandler, dataStore *storage.DataStore, scope promutils.Scope) (interfaces.NodeHandler, error) { +func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler interfaces.NodeHandler, dataStore *storage.DataStore, scope promutils.Scope) (interfaces.NodeHandler, error) { // mock components adminClient := launchplan.NewFailFastLaunchPlanExecutor() enqueueWorkflowFunc := func(workflowID v1alpha1.WorkflowID) {} @@ -68,9 +68,8 @@ func createArrayNodeHandler(t *testing.T, ctx context.Context, nodeHandler inter return New(nodeExecutor, eventConfig, scope) } -func createNodeExecutionContext(t *testing.T, ctx context.Context, dataStore *storage.DataStore, eventRecorder interfaces.EventRecorder, - outputVariables []string, inputLiteralMap *idlcore.LiteralMap, arrayNodeSpec *v1alpha1.NodeSpec, - arrayNodeState *interfaces.ArrayNodeState) (interfaces.NodeExecutionContext, error) { +func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder interfaces.EventRecorder, outputVariables []string, + inputLiteralMap *idlcore.LiteralMap, arrayNodeSpec *v1alpha1.NodeSpec, arrayNodeState *interfaces.ArrayNodeState) interfaces.NodeExecutionContext { nCtx := &mocks.NodeExecutionContext{} @@ -87,7 +86,7 @@ func createNodeExecutionContext(t *testing.T, ctx context.Context, dataStore *st executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{}) executionContext.OnGetExecutionID().Return( v1alpha1.ExecutionID{ - &idlcore.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: &idlcore.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", Name: "name", @@ -103,7 +102,7 @@ func createNodeExecutionContext(t *testing.T, ctx context.Context, dataStore *st } executionContext.OnGetTaskMatch(taskRef).Return( &v1alpha1.TaskSpec{ - &idlcore.TaskTemplate{ + TaskTemplate: &idlcore.TaskTemplate{ Interface: &idlcore.TypedInterface{ Outputs: &idlcore.VariableMap{ Variables: outputVariableMap, @@ -161,11 +160,11 @@ func createNodeExecutionContext(t *testing.T, ctx context.Context, dataStore *st // NodeStatus nCtx.OnNodeStatus().Return(&v1alpha1.NodeStatus{ - DataDir: storage.DataReference("s3://bucket/data"), + DataDir: 
storage.DataReference("s3://bucket/data"), OutputDir: storage.DataReference("s3://bucket/output"), }) - return nCtx, nil + return nCtx } func TestAbort(t *testing.T) { @@ -186,7 +185,7 @@ func TestHandleArrayNodePhaseNone(t *testing.T) { nodeHandler := &mocks.NodeHandler{} // initialize ArrayNodeHandler - arrayNodeHandler, err := createArrayNodeHandler(t, ctx, nodeHandler, dataStore, scope) + arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope) assert.NoError(t, err) tests := []struct { @@ -201,8 +200,8 @@ func TestHandleArrayNodePhaseNone(t *testing.T) { inputValues: map[string][]int64{ "foo": []int64{1, 2}, }, - expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, - expectedTransitionPhase: handler.EPhaseRunning, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_QUEUED, idlcore.TaskExecution_QUEUED}, }, { @@ -211,8 +210,8 @@ func TestHandleArrayNodePhaseNone(t *testing.T) { "foo": []int64{1, 2, 3}, "bar": []int64{4, 5, 6}, }, - expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, - expectedTransitionPhase: handler.EPhaseRunning, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_QUEUED, idlcore.TaskExecution_QUEUED, idlcore.TaskExecution_QUEUED}, }, { @@ -221,8 +220,8 @@ func TestHandleArrayNodePhaseNone(t *testing.T) { "foo": []int64{1, 2}, "bar": []int64{3}, }, - expectedArrayNodePhase: v1alpha1.ArrayNodePhaseNone, - expectedTransitionPhase: handler.EPhaseFailed, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseNone, + expectedTransitionPhase: handler.EPhaseFailed, expectedExternalResourcePhases: nil, }, } @@ -235,8 +234,7 @@ func TestHandleArrayNodePhaseNone(t *testing.T) { arrayNodeState := &interfaces.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseNone, } - nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) - assert.NoError(t, err) + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) // evaluate node transition, err := arrayNodeHandler.Handle(ctx, nCtx) @@ -306,12 +304,12 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), }, - expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, - expectedTransitionPhase: handler.EPhaseRunning, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING}, }, { - name: "StartOneSubNodeParallelism", + name: "StartOneSubNodeParallelism", parallelism: 1, subNodePhases: []v1alpha1.NodePhase{ v1alpha1.NodePhaseQueued, @@ -325,8 +323,8 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), }, - expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, - expectedTransitionPhase: handler.EPhaseRunning, 
+ expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_QUEUED}, }, { @@ -343,12 +341,12 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), }, - expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, - expectedTransitionPhase: handler.EPhaseRunning, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, + expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_SUCCEEDED, idlcore.TaskExecution_SUCCEEDED}, }, { - name: "OneSubNodeSuccedeedMinSuccessRatio", + name: "OneSubNodeSuccedeedMinSuccessRatio", minSuccessRatio: &minSuccessRatio, subNodePhases: []v1alpha1.NodePhase{ v1alpha1.NodePhaseRunning, @@ -362,8 +360,8 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(0, "", "", &handler.ExecutionInfo{})), }, - expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, - expectedTransitionPhase: handler.EPhaseRunning, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, + expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_SUCCEEDED, idlcore.TaskExecution_FAILED}, }, { @@ -380,8 +378,8 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(0, "", "", &handler.ExecutionInfo{})), handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), }, - expectedArrayNodePhase: v1alpha1.ArrayNodePhaseFailing, - expectedTransitionPhase: handler.EPhaseRunning, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseFailing, + expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_FAILED, idlcore.TaskExecution_SUCCEEDED}, }, } @@ -398,12 +396,15 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { arrayNodeState := &interfaces.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseExecuting, } - for _, item := range []struct{arrayReference *bitarray.CompactArray; maxValue int}{ - {arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)}, - {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases)-1}, - {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: 1}, - {arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: 1}, - } { + for _, item := range []struct { + arrayReference *bitarray.CompactArray + maxValue int + }{ + {arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)}, + {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases) - 1}, + {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: 1}, + {arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: 1}, + } { *item.arrayReference, err = bitarray.NewCompactArray(uint(size), bitarray.Item(item.maxValue)) assert.NoError(t, err) @@ -423,8 +424,7 @@ func 
TestHandleArrayNodePhaseExecuting(t *testing.T) { nodeSpec.ArrayNode.Parallelism = uint32(test.parallelism) nodeSpec.ArrayNode.MinSuccessRatio = test.minSuccessRatio - nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) - assert.NoError(t, err) + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) // initialize ArrayNodeHandler nodeHandler := &mocks.NodeHandler{} @@ -434,7 +434,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { transitionPhase := test.expectedExternalResourcePhases[i] nodeHandler.OnHandleMatch(mock.Anything, mock.MatchedBy(func(arrayNCtx interfaces.NodeExecutionContext) bool { - return arrayNCtx.NodeID() == nodeID // match on NodeID using index to ensure each subNode is handled independently + return arrayNCtx.NodeID() == nodeID // match on NodeID using index to ensure each subNode is handled independently })).Run( func(args mock.Arguments) { // mock sending TaskExecutionEvent from handler to show task state transition @@ -442,12 +442,13 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { Phase: transitionPhase, } - args.Get(1).(interfaces.NodeExecutionContext).EventsRecorder().RecordTaskEvent(ctx, taskExecutionEvent, &config.EventConfig{}) + err := args.Get(1).(interfaces.NodeExecutionContext).EventsRecorder().RecordTaskEvent(ctx, taskExecutionEvent, &config.EventConfig{}) + assert.NoError(t, err) }, ).Return(transition, nil) } - arrayNodeHandler, err := createArrayNodeHandler(t, ctx, nodeHandler, dataStore, scope) + arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope) assert.NoError(t, err) // evaluate node @@ -484,26 +485,26 @@ func TestHandleArrayNodePhaseSucceeding(t *testing.T) { valueOne := 1 // initialize ArrayNodeHandler - arrayNodeHandler, err := createArrayNodeHandler(t, ctx, nodeHandler, dataStore, scope) + arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope) assert.NoError(t, err) tests := []struct { - name string - outputVariable string - outputValues []*int - subNodePhases []v1alpha1.NodePhase - expectedArrayNodePhase v1alpha1.ArrayNodePhase - expectedTransitionPhase handler.EPhase + name string + outputVariable string + outputValues []*int + subNodePhases []v1alpha1.NodePhase + expectedArrayNodePhase v1alpha1.ArrayNodePhase + expectedTransitionPhase handler.EPhase }{ { - name: "Success", - outputValues: []*int{&valueOne, nil}, + name: "Success", + outputValues: []*int{&valueOne, nil}, outputVariable: "foo", subNodePhases: []v1alpha1.NodePhase{ v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseFailed, }, - expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, expectedTransitionPhase: handler.EPhaseSuccess, }, } @@ -521,16 +522,15 @@ func TestHandleArrayNodePhaseSucceeding(t *testing.T) { assert.NoError(t, err) arrayNodeState := &interfaces.ArrayNodeState{ - Phase: v1alpha1.ArrayNodePhaseSucceeding, - SubNodePhases: subNodePhases, + Phase: v1alpha1.ArrayNodePhaseSucceeding, + SubNodePhases: subNodePhases, SubNodeRetryAttempts: retryAttempts, } // create NodeExecutionContext eventRecorder := newArrayEventRecorder() literalMap := &idlcore.LiteralMap{} - nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, []string{test.outputVariable}, literalMap, &arrayNodeSpec, arrayNodeState) - assert.NoError(t, err) + nCtx := createNodeExecutionContext(dataStore, eventRecorder, 
[]string{test.outputVariable}, literalMap, &arrayNodeSpec, arrayNodeState) // write mocked output files for i, outputValue := range test.outputValues { @@ -605,7 +605,7 @@ func TestHandleArrayNodePhaseFailing(t *testing.T) { nodeHandler.OnFinalizeMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) // initialize ArrayNodeHandler - arrayNodeHandler, err := createArrayNodeHandler(t, ctx, nodeHandler, dataStore, scope) + arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope) assert.NoError(t, err) tests := []struct { @@ -622,9 +622,9 @@ func TestHandleArrayNodePhaseFailing(t *testing.T) { v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseFailed, }, - expectedArrayNodePhase: v1alpha1.ArrayNodePhaseFailing, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseFailing, expectedTransitionPhase: handler.EPhaseFailed, - expectedAbortCalls: 1, + expectedAbortCalls: 1, }, } @@ -635,12 +635,15 @@ func TestHandleArrayNodePhaseFailing(t *testing.T) { Phase: v1alpha1.ArrayNodePhaseFailing, } - for _, item := range []struct{arrayReference *bitarray.CompactArray; maxValue int}{ - {arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)}, - {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases)-1}, - {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: 1}, - {arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: 1}, - } { + for _, item := range []struct { + arrayReference *bitarray.CompactArray + maxValue int + }{ + {arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)}, + {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases) - 1}, + {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: 1}, + {arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: 1}, + } { *item.arrayReference, err = bitarray.NewCompactArray(uint(len(test.subNodePhases)), bitarray.Item(item.maxValue)) assert.NoError(t, err) @@ -653,8 +656,7 @@ func TestHandleArrayNodePhaseFailing(t *testing.T) { // create NodeExecutionContext eventRecorder := newArrayEventRecorder() literalMap := &idlcore.LiteralMap{} - nCtx, err := createNodeExecutionContext(t, ctx, dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) - assert.NoError(t, err) + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) // evaluate node transition, err := arrayNodeHandler.Handle(ctx, nCtx) diff --git a/pkg/controller/nodes/array/node_execution_context.go b/pkg/controller/nodes/array/node_execution_context.go index 2bd0005ea..af3ea42f7 100644 --- a/pkg/controller/nodes/array/node_execution_context.go +++ b/pkg/controller/nodes/array/node_execution_context.go @@ -3,8 +3,8 @@ package array import ( "context" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" @@ -108,7 +108,6 @@ func (a *arrayTaskReader) Read(ctx context.Context) (*core.TaskTemplate, error) return taskTemplate, nil } - type arrayNodeExecutionContext struct { interfaces.NodeExecutionContext eventRecorder interfaces.EventRecorder diff --git a/pkg/controller/nodes/array/node_execution_context_builder.go b/pkg/controller/nodes/array/node_execution_context_builder.go index e8367e2eb..de145b95a 100644 --- a/pkg/controller/nodes/array/node_execution_context_builder.go 
+++ b/pkg/controller/nodes/array/node_execution_context_builder.go @@ -39,7 +39,7 @@ func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context } func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder, subNodeID v1alpha1.NodeID, - subNodeIndex int, subNodeStatus *v1alpha1.NodeStatus, inputReader io.InputReader, eventRecorder interfaces.EventRecorder, + subNodeIndex int, subNodeStatus *v1alpha1.NodeStatus, inputReader io.InputReader, eventRecorder interfaces.EventRecorder, currentParallelism *uint32, maxParallelism uint32) interfaces.NodeExecutionContextBuilder { return &arrayNodeExecutionContextBuilder{ diff --git a/pkg/controller/nodes/array/node_lookup.go b/pkg/controller/nodes/array/node_lookup.go index d1ef8fe55..061b323af 100644 --- a/pkg/controller/nodes/array/node_lookup.go +++ b/pkg/controller/nodes/array/node_lookup.go @@ -32,9 +32,9 @@ func (a *arrayNodeLookup) GetNodeExecutionStatus(ctx context.Context, id v1alpha func newArrayNodeLookup(nodeLookup executors.NodeLookup, subNodeID v1alpha1.NodeID, subNodeSpec *v1alpha1.NodeSpec, subNodeStatus *v1alpha1.NodeStatus) arrayNodeLookup { return arrayNodeLookup{ - NodeLookup: nodeLookup, - subNodeID: subNodeID, - subNodeSpec: subNodeSpec, - subNodeStatus: subNodeStatus, + NodeLookup: nodeLookup, + subNodeID: subNodeID, + subNodeSpec: subNodeSpec, + subNodeStatus: subNodeStatus, } } diff --git a/pkg/controller/nodes/array/utils.go b/pkg/controller/nodes/array/utils.go index ad4f80837..7f330063d 100644 --- a/pkg/controller/nodes/array/utils.go +++ b/pkg/controller/nodes/array/utils.go @@ -38,7 +38,7 @@ func appendLiteral(name string, literal *idlcore.Literal, outputLiterals map[str collection.Literals = append(collection.Literals, literal) } -func buildTaskExecutionEvent(ctx context.Context, nCtx interfaces.NodeExecutionContext, taskPhase idlcore.TaskExecution_Phase, taskPhaseVersion uint32, externalResources []*event.ExternalResourceInfo) (*event.TaskExecutionEvent, error) { +func buildTaskExecutionEvent(_ context.Context, nCtx interfaces.NodeExecutionContext, taskPhase idlcore.TaskExecution_Phase, taskPhaseVersion uint32, externalResources []*event.ExternalResourceInfo) (*event.TaskExecutionEvent, error) { occurredAt, err := ptypes.TimestampProto(time.Now()) if err != nil { return nil, err @@ -55,7 +55,7 @@ func buildTaskExecutionEvent(ctx context.Context, nCtx interfaces.NodeExecutionC Version: "v1", // this value is irrelevant but necessary for the identifier to be valid }, ParentNodeExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), - RetryAttempt: 0, // ArrayNode will never retry + RetryAttempt: 0, // ArrayNode will never retry Phase: taskPhase, PhaseVersion: taskPhaseVersion, OccurredAt: occurredAt, @@ -84,7 +84,7 @@ func bytesFromK8sPluginState(pluginState k8s.PluginState) ([]byte, error) { return bufferWriter.Bytes(), nil } -func constructOutputReferences(ctx context.Context, nCtx interfaces.NodeExecutionContext, postfix...string) (storage.DataReference, storage.DataReference, error) { +func constructOutputReferences(ctx context.Context, nCtx interfaces.NodeExecutionContext, postfix ...string) (storage.DataReference, storage.DataReference, error) { subDataDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetDataDir(), postfix...) 
if err != nil { return "", "", err diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index bfc5691c5..fe2d2be6c 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -87,13 +87,12 @@ type nodeMetrics struct { // Implements the executors.Node interface type recursiveNodeExecutor struct { - nodeExecutor interfaces.NodeExecutor - nCtxBuilder interfaces.NodeExecutionContextBuilder - - enqueueWorkflow v1alpha1.EnqueueWorkflow - nodeHandlerFactory interfaces.HandlerFactory - store *storage.DataStore - metrics *nodeMetrics + nodeExecutor interfaces.NodeExecutor + nCtxBuilder interfaces.NodeExecutionContextBuilder + enqueueWorkflow v1alpha1.EnqueueWorkflow + nodeHandlerFactory interfaces.HandlerFactory + store *storage.DataStore + metrics *nodeMetrics } func (c *recursiveNodeExecutor) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (interfaces.NodeStatus, error) { @@ -162,7 +161,6 @@ func IsMaxParallelismAchieved(ctx context.Context, currentNode v1alpha1.Executab return false } - // RecursiveNodeHandler This is the entrypoint of executing a node in a workflow. A workflow consists of nodes, that are // nested within other nodes. The system follows an actor model, where the parent nodes control the execution of nested nodes // The recursive node-handler uses a modified depth-first type of algorithm to execute non-blocked nodes. @@ -495,12 +493,12 @@ func (c *recursiveNodeExecutor) GetNodeExecutionContextBuilder() interfaces.Node func (c *recursiveNodeExecutor) WithNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder) interfaces.Node { return &recursiveNodeExecutor{ - nodeExecutor: c.nodeExecutor, - nCtxBuilder: nCtxBuilder, - enqueueWorkflow: c.enqueueWorkflow, - nodeHandlerFactory: c.nodeHandlerFactory, - store: c.store, - metrics: c.metrics, + nodeExecutor: c.nodeExecutor, + nCtxBuilder: nCtxBuilder, + enqueueWorkflow: c.enqueueWorkflow, + nodeHandlerFactory: c.nodeHandlerFactory, + store: c.store, + metrics: c.metrics, } } @@ -515,12 +513,12 @@ type nodeExecutor struct { interruptibleFailureThreshold uint32 maxDatasetSizeBytes int64 maxNodeRetriesForSystemFailures uint32 - metrics *nodeMetrics + metrics *nodeMetrics nodeRecorder events.NodeEventRecorder outputResolver OutputResolver - recoveryClient recovery.Client + recoveryClient recovery.Client shardSelector ioutils.ShardSelector - store *storage.DataStore + store *storage.DataStore taskRecorder events.TaskEventRecorder } @@ -1284,18 +1282,18 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora defaultActiveDeadline: nodeConfig.DefaultDeadlines.DefaultNodeActiveDeadline.Duration, defaultDataSandbox: defaultRawOutputPrefix, defaultExecutionDeadline: nodeConfig.DefaultDeadlines.DefaultNodeExecutionDeadline.Duration, - enqueueWorkflow: enQWorkflow, + enqueueWorkflow: enQWorkflow, eventConfig: eventConfig, interruptibleFailureThreshold: uint32(nodeConfig.InterruptibleFailureThreshold), - maxDatasetSizeBytes: maxDatasetSize, + maxDatasetSizeBytes: maxDatasetSize, maxNodeRetriesForSystemFailures: uint32(nodeConfig.MaxNodeRetriesOnSystemFailures), - metrics: metrics, - nodeRecorder: events.NewNodeEventRecorder(eventSink, nodeScope, store), + metrics: metrics, + nodeRecorder: events.NewNodeEventRecorder(eventSink, nodeScope, store), outputResolver: NewRemoteFileOutputResolver(store), recoveryClient: 
recoveryClient, shardSelector: shardSelector, - store: store, - taskRecorder: events.NewTaskEventRecorder(eventSink, scope.NewSubScope("task"), store), + store: store, + taskRecorder: events.NewTaskEventRecorder(eventSink, scope.NewSubScope("task"), store), } exec := &recursiveNodeExecutor{ diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index f3f759881..43cb7fe33 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -1357,10 +1357,10 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { func Test_nodeExecutor_RecordTransitionLatency(t *testing.T) { testScope := promutils.NewTestScope() type fields struct { - enqueueWorkflow v1alpha1.EnqueueWorkflow - store *storage.DataStore - nodeRecorder events.NodeEventRecorder - metrics *nodeMetrics + enqueueWorkflow v1alpha1.EnqueueWorkflow + store *storage.DataStore + nodeRecorder events.NodeEventRecorder + metrics *nodeMetrics } type args struct { w v1alpha1.ExecutableWorkflow @@ -1401,10 +1401,10 @@ func Test_nodeExecutor_RecordTransitionLatency(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &nodeExecutor{ - enqueueWorkflow: tt.fields.enqueueWorkflow, - store: tt.fields.store, - nodeRecorder: tt.fields.nodeRecorder, - metrics: tt.fields.metrics, + enqueueWorkflow: tt.fields.enqueueWorkflow, + store: tt.fields.store, + nodeRecorder: tt.fields.nodeRecorder, + metrics: tt.fields.metrics, } c.RecordTransitionLatency(context.TODO(), tt.args.w, tt.args.w, tt.args.node, tt.args.nodeStatus) diff --git a/pkg/controller/nodes/factory/handler_factory.go b/pkg/controller/nodes/factory/handler_factory.go index 8a3718fb3..9ec00da7a 100644 --- a/pkg/controller/nodes/factory/handler_factory.go +++ b/pkg/controller/nodes/factory/handler_factory.go @@ -85,12 +85,12 @@ func NewHandlerFactory(ctx context.Context, workflowLauncher launchplan.Executor return &handlerFactory{ workflowLauncher: workflowLauncher, launchPlanReader: launchPlanReader, - kubeClient: kubeClient, + kubeClient: kubeClient, catalogClient: catalogClient, recoveryClient: recoveryClient, - eventConfig: eventConfig, + eventConfig: eventConfig, clusterID: clusterID, - signalClient: signalClient, - scope: scope, + signalClient: signalClient, + scope: scope, }, nil } diff --git a/pkg/controller/nodes/node_exec_context.go b/pkg/controller/nodes/node_exec_context.go index 6506ae527..41cb9831b 100644 --- a/pkg/controller/nodes/node_exec_context.go +++ b/pkg/controller/nodes/node_exec_context.go @@ -16,8 +16,8 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" nodeerrors "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/utils" "github.com/flyteorg/flytestdlib/logger" @@ -127,7 +127,6 @@ type nodeExecContext struct { tr interfaces.TaskReader md interfaces.NodeExecutionMetadata eventRecorder interfaces.EventRecorder - //er events.TaskEventRecorder inputs io.InputReader node v1alpha1.ExecutableNode nodeStatus v1alpha1.ExecutableNodeStatus @@ -240,11 +239,11 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext md.nodeLabels = nodeLabels return &nodeExecContext{ - md: md, - store: store, - node: 
node, - nodeStatus: nodeStatus, - inputs: inputs, + md: md, + store: store, + node: node, + nodeStatus: nodeStatus, + inputs: inputs, eventRecorder: &eventRecorder{ taskEventRecorder: taskEventRecorder, nodeEventRecorder: nodeEventRecorder, From e0f156d6d30673e62942842c3effa49acd8e9a15 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 15 Jun 2023 21:38:01 -0500 Subject: [PATCH 45/62] updated flyteidl dep Signed-off-by: Daniel Rammer --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index d95ee04f7..bd52dd8f8 100644 --- a/go.mod +++ b/go.mod @@ -148,4 +148,4 @@ require ( replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d -replace github.com/flyteorg/flyteidl => ../flyteidl +replace github.com/flyteorg/flyteidl => github.com/flyteorg/flyteidl v1.5.11-0.20230614183933-d56d4d37bf34 diff --git a/go.sum b/go.sum index 901198892..af66d661e 100644 --- a/go.sum +++ b/go.sum @@ -260,6 +260,8 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/flyteorg/flyteidl v1.5.11-0.20230614183933-d56d4d37bf34 h1:Gj5UKqJU+ozeTeYAvDWHiF4HSVufHW1W1ecymFfbbis= +github.com/flyteorg/flyteidl v1.5.11-0.20230614183933-d56d4d37bf34/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= github.com/flyteorg/flyteplugins v1.0.64 h1:t9S59C0s2nXsUEqKbQZJYHQzSfkhjJh1dcXlBlxhEUk= github.com/flyteorg/flyteplugins v1.0.64/go.mod h1:HHt4nKDKVwrZPKDsj99dNtDSIJL378xNotYMA3a/TFA= github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0= From 211ff29c9745a5b78abb78965c616751a7a57abb Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 29 Jun 2023 16:07:42 -0500 Subject: [PATCH 46/62] added unit tests for Abort Signed-off-by: Daniel Rammer --- .../flyteworkflow/v1alpha1/node_status.go | 1 - pkg/controller/nodes/array/handler_test.go | 94 ++++++++++++++++++- 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go index ea95c3b56..93078a1a8 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/node_status.go +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -214,7 +214,6 @@ const ( ArrayNodePhaseExecuting ArrayNodePhaseFailing ArrayNodePhaseSucceeding - // TODO @hamersaw - need more phases ) type ArrayNodeStatus struct { diff --git a/pkg/controller/nodes/array/handler_test.go b/pkg/controller/nodes/array/handler_test.go index e07cbaadb..57a7cc4b7 100644 --- a/pkg/controller/nodes/array/handler_test.go +++ b/pkg/controller/nodes/array/handler_test.go @@ -168,7 +168,99 @@ func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder inte } func TestAbort(t *testing.T) { - // TODO @hamersaw - complete + ctx := context.Background() + scope := promutils.NewTestScope() + dataStore, err := storage.NewDataStore(&storage.Config{ + Type: storage.TypeMemory, + }, scope) + assert.NoError(t, err) + + nodeHandler := &mocks.NodeHandler{} + nodeHandler.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + nodeHandler.OnFinalizeMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + + // initialize ArrayNodeHandler + 
arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope)
+	assert.NoError(t, err)
+
+	tests := []struct {
+		name string
+		inputMap map[string][]int64
+		subNodePhases []v1alpha1.NodePhase
+		subNodeTaskPhases []core.Phase
+		expectedExternalResourcePhases []idlcore.TaskExecution_Phase
+	}{
+		{
+			name: "Success",
+			inputMap: map[string][]int64{
+				"foo": []int64{0, 1, 2},
+			},
+			subNodePhases: []v1alpha1.NodePhase{v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted},
+			subNodeTaskPhases: []core.Phase{core.PhaseSuccess, core.PhaseRunning, core.PhaseUndefined},
+			expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_ABORTED},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			// initialize universal variables
+			literalMap := convertMapToArrayLiterals(test.inputMap)
+
+			size := -1
+			for _, v := range test.inputMap {
+				if size == -1 {
+					size = len(v)
+				} else if len(v) > size { // calculating size as largest input list
+					size = len(v)
+				}
+			}
+
+			// initialize ArrayNodeState
+			arrayNodeState := &interfaces.ArrayNodeState{
+				Phase: v1alpha1.ArrayNodePhaseFailing,
+			}
+			for _, item := range []struct {
+				arrayReference *bitarray.CompactArray
+				maxValue int
+			}{
+				{arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)},
+				{arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases) - 1},
+				{arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: 1},
+				{arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: 1},
+			} {
+
+				*item.arrayReference, err = bitarray.NewCompactArray(uint(size), bitarray.Item(item.maxValue))
+				assert.NoError(t, err)
+			}
+
+			for i, nodePhase := range test.subNodePhases {
+				arrayNodeState.SubNodePhases.SetItem(i, bitarray.Item(nodePhase))
+			}
+			for i, taskPhase := range test.subNodeTaskPhases {
+				arrayNodeState.SubNodeTaskPhases.SetItem(i, bitarray.Item(taskPhase))
+			}
+
+			// create NodeExecutionContext
+			eventRecorder := newArrayEventRecorder()
+			nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState)
+
+			// evaluate node
+			err := arrayNodeHandler.Abort(ctx, nCtx, "foo")
+			assert.NoError(t, err)
+
+			if len(test.expectedExternalResourcePhases) > 0 {
+				assert.Equal(t, 1, len(eventRecorder.taskEvents))
+
+				externalResources := eventRecorder.taskEvents[0].Metadata.GetExternalResources()
+				assert.Equal(t, len(test.expectedExternalResourcePhases), len(externalResources))
+				for i, expectedPhase := range test.expectedExternalResourcePhases {
+					assert.Equal(t, expectedPhase, externalResources[i].Phase)
+				}
+			} else {
+				assert.Equal(t, 0, len(eventRecorder.taskEvents))
+			}
+		})
+	}
}

From c6df3fdf4b38166f6b16b22c6470a8c58cdc2b16 Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Thu, 29 Jun 2023 16:27:23 -0500
Subject: [PATCH 47/62] adding unit test for Finalize

Signed-off-by: Daniel Rammer
---
 pkg/controller/nodes/array/handler.go      |  1 +
 pkg/controller/nodes/array/handler_test.go | 86 +++++++++++++++++++++-
 2 files changed, 86 insertions(+), 1 deletion(-)

diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go
index a1d7c0223..8123d2d04 100644
--- a/pkg/controller/nodes/array/handler.go
+++ b/pkg/controller/nodes/array/handler.go
@@ -564,6 +564,7 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter
	// index. however when we check completion status we need to manually append index - so in all cases
	// where the node phase is not Queued (i.e. task handler will launch task and init flytekit params) we
	// append the subtask index.
+	// TODO @hamersaw - verify this has been fixed in flytekit for arraynode implementation
	/*var subDataDir, subOutputDir storage.DataReference
	if nodePhase == v1alpha1.NodePhaseQueued {
		subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx)
diff --git a/pkg/controller/nodes/array/handler_test.go b/pkg/controller/nodes/array/handler_test.go
index 57a7cc4b7..07c908831 100644
--- a/pkg/controller/nodes/array/handler_test.go
+++ b/pkg/controller/nodes/array/handler_test.go
@@ -75,6 +75,7 @@ func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder inte
	// ContextualNodeLookup
	nodeLookup := &execmocks.NodeLookup{}
+	nodeLookup.OnFromNodeMatch(mock.Anything).Return(nil, nil)
	nCtx.OnContextualNodeLookup().Return(nodeLookup)

	// DataStore
@@ -248,6 +249,7 @@ func TestAbort(t *testing.T) {
			err := arrayNodeHandler.Abort(ctx, nCtx, "foo")
			assert.NoError(t, err)

+			nodeHandler.AssertNumberOfCalls(t, "Abort", len(test.expectedExternalResourcePhases))
			if len(test.expectedExternalResourcePhases) > 0 {
				assert.Equal(t, 1, len(eventRecorder.taskEvents))

@@ -264,7 +266,89 @@ func TestAbort(t *testing.T) {
}

func TestFinalize(t *testing.T) {
-	// TODO @hamersaw - complete
+	ctx := context.Background()
+	scope := promutils.NewTestScope()
+	dataStore, err := storage.NewDataStore(&storage.Config{
+		Type: storage.TypeMemory,
+	}, scope)
+	assert.NoError(t, err)
+
+	nodeHandler := &mocks.NodeHandler{}
+	nodeHandler.OnFinalizeMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)
+
+	// initialize ArrayNodeHandler
+	arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope)
+	assert.NoError(t, err)
+
+	tests := []struct {
+		name string
+		inputMap map[string][]int64
+		subNodePhases []v1alpha1.NodePhase
+		subNodeTaskPhases []core.Phase
+		expectedFinalizeCalls int
+	}{
+		{
+			name: "Success",
+			inputMap: map[string][]int64{
+				"foo": []int64{0, 1, 2},
+			},
+			subNodePhases: []v1alpha1.NodePhase{v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted},
+			subNodeTaskPhases: []core.Phase{core.PhaseSuccess, core.PhaseRunning, core.PhaseUndefined},
+			expectedFinalizeCalls: 1,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			// initialize universal variables
+			literalMap := convertMapToArrayLiterals(test.inputMap)
+
+			size := -1
+			for _, v := range test.inputMap {
+				if size == -1 {
+					size = len(v)
+				} else if len(v) > size { // calculating size as largest input list
+					size = len(v)
+				}
+			}
+
+			// initialize ArrayNodeState
+			arrayNodeState := &interfaces.ArrayNodeState{
+				Phase: v1alpha1.ArrayNodePhaseFailing,
+			}
+			for _, item := range []struct {
+				arrayReference *bitarray.CompactArray
+				maxValue int
+			}{
+				{arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)},
+				{arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases) - 1},
+				{arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: 1},
+				{arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: 1},
+			} {
+
+				*item.arrayReference, err = bitarray.NewCompactArray(uint(size), bitarray.Item(item.maxValue))
+				assert.NoError(t, err)
+			}
+
+			for i, nodePhase := range test.subNodePhases {
+				arrayNodeState.SubNodePhases.SetItem(i, bitarray.Item(nodePhase))
+			}
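+			// (editor's aside, not part of the original patch) CompactArray packs each
+			// item into just enough bits to hold maxValue, so the phases stored above
+			// survive the round trip these tests rely on; a minimal sketch:
+			//   arr, _ := bitarray.NewCompactArray(1, bitarray.Item(7)) // 1 item, 3 bits
+			//   arr.SetItem(0, bitarray.Item(5))
+			//   _ = arr.GetItem(0) // == bitarray.Item(5)
+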
for i, taskPhase := range test.subNodeTaskPhases { + arrayNodeState.SubNodeTaskPhases.SetItem(i, bitarray.Item(taskPhase)) + } + + // create NodeExecutionContext + eventRecorder := newArrayEventRecorder() + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) + + // evaluate node + err := arrayNodeHandler.Finalize(ctx, nCtx) + assert.NoError(t, err) + + // validate + nodeHandler.AssertNumberOfCalls(t, "Finalize", test.expectedFinalizeCalls) + }) + } } func TestHandleArrayNodePhaseNone(t *testing.T) { From 7c1931d062bd658800c16c61c0102bb87b678c16 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 29 Jun 2023 20:05:51 -0500 Subject: [PATCH 48/62] added utils unit tests Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/utils_test.go | 36 ++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 pkg/controller/nodes/array/utils_test.go diff --git a/pkg/controller/nodes/array/utils_test.go b/pkg/controller/nodes/array/utils_test.go new file mode 100644 index 000000000..2e3eaf6e6 --- /dev/null +++ b/pkg/controller/nodes/array/utils_test.go @@ -0,0 +1,36 @@ +package array + +import ( + "testing" + + idlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/stretchr/testify/assert" +) + +func TestAppendLiteral(t *testing.T) { + outputLiterals := make(map[string]*idlcore.Literal) + literalMaps := []map[string]*idlcore.Literal{ + map[string]*idlcore.Literal{ + "foo": nilLiteral, + "bar": nilLiteral, + }, + map[string]*idlcore.Literal{ + "foo": nilLiteral, + "bar": nilLiteral, + }, + } + + for _, m := range literalMaps { + for k, v := range m { + appendLiteral(k, v, outputLiterals, len(literalMaps)) + } + } + + for _, v := range outputLiterals { + collection, ok := v.Value.(*idlcore.Literal_Collection) + assert.True(t, ok) + + assert.Equal(t, 2, len(collection.Collection.Literals)) + } +} From 0df9572bd9a5045f33eadb200de06ad520c9a272 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 29 Jun 2023 20:22:20 -0500 Subject: [PATCH 49/62] moved state structs to handler package Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 2 +- pkg/controller/nodes/array/handler_test.go | 16 ++-- pkg/controller/nodes/branch/handler.go | 2 +- pkg/controller/nodes/branch/handler_test.go | 18 ++--- .../nodes/dynamic/dynamic_workflow.go | 10 +-- .../nodes/dynamic/dynamic_workflow_test.go | 4 +- pkg/controller/nodes/dynamic/handler.go | 24 +++--- pkg/controller/nodes/dynamic/handler_test.go | 30 +++---- pkg/controller/nodes/executor_test.go | 8 +- pkg/controller/nodes/gate/handler_test.go | 3 +- pkg/controller/nodes/handler/state.go | 60 ++++++++++++++ .../{interfaces => handler}/state_test.go | 2 +- .../interfaces/mocks/node_state_reader.go | 63 ++++++++------- .../interfaces/mocks/node_state_writer.go | 39 ++++----- pkg/controller/nodes/interfaces/state.go | 81 +++---------------- pkg/controller/nodes/node_state_manager.go | 52 ++++++------ pkg/controller/nodes/subworkflow/handler.go | 2 +- .../nodes/subworkflow/handler_test.go | 19 +++-- .../nodes/subworkflow/subworkflow.go | 2 +- pkg/controller/nodes/task/handler.go | 4 +- pkg/controller/nodes/task/handler_test.go | 26 +++--- .../nodes/task/taskexec_context_test.go | 8 +- 22 files changed, 240 insertions(+), 235 deletions(-) create mode 100644 pkg/controller/nodes/handler/state.go rename pkg/controller/nodes/{interfaces => handler}/state_test.go (97%) diff --git a/pkg/controller/nodes/array/handler.go 
b/pkg/controller/nodes/array/handler.go index 8123d2d04..3e566d7c3 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -524,7 +524,7 @@ func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope pr }, nil } -func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx interfaces.NodeExecutionContext, arrayNodeState *interfaces.ArrayNodeState, arrayNode v1alpha1.ExecutableArrayNode, subNodeIndex int, currentParallelism *uint32) ( +func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx interfaces.NodeExecutionContext, arrayNodeState *handler.ArrayNodeState, arrayNode v1alpha1.ExecutableArrayNode, subNodeIndex int, currentParallelism *uint32) ( interfaces.Node, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, *v1alpha1.NodeSpec, *v1alpha1.NodeStatus, *arrayEventRecorder, error) { nodePhase := v1alpha1.NodePhase(arrayNodeState.SubNodePhases.GetItem(subNodeIndex)) diff --git a/pkg/controller/nodes/array/handler_test.go b/pkg/controller/nodes/array/handler_test.go index 07c908831..49e975061 100644 --- a/pkg/controller/nodes/array/handler_test.go +++ b/pkg/controller/nodes/array/handler_test.go @@ -69,7 +69,7 @@ func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler inter } func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder interfaces.EventRecorder, outputVariables []string, - inputLiteralMap *idlcore.LiteralMap, arrayNodeSpec *v1alpha1.NodeSpec, arrayNodeState *interfaces.ArrayNodeState) interfaces.NodeExecutionContext { + inputLiteralMap *idlcore.LiteralMap, arrayNodeSpec *v1alpha1.NodeSpec, arrayNodeState *handler.ArrayNodeState) interfaces.NodeExecutionContext { nCtx := &mocks.NodeExecutionContext{} @@ -154,7 +154,7 @@ func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder inte nodeStateWriter := &mocks.NodeStateWriter{} nodeStateWriter.OnPutArrayNodeStateMatch(mock.Anything, mock.Anything).Run( func(args mock.Arguments) { - *arrayNodeState = args.Get(0).(interfaces.ArrayNodeState) + *arrayNodeState = args.Get(0).(handler.ArrayNodeState) }, ).Return(nil) nCtx.OnNodeStateWriter().Return(nodeStateWriter) @@ -217,7 +217,7 @@ func TestAbort(t *testing.T) { } // initialize ArrayNodeState - arrayNodeState := &interfaces.ArrayNodeState{ + arrayNodeState := &handler.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseFailing, } for _, item := range []struct { @@ -313,7 +313,7 @@ func TestFinalize(t *testing.T) { } // initialize ArrayNodeState - arrayNodeState := &interfaces.ArrayNodeState{ + arrayNodeState := &handler.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseFailing, } for _, item := range []struct { @@ -407,7 +407,7 @@ func TestHandleArrayNodePhaseNone(t *testing.T) { // create NodeExecutionContext eventRecorder := newArrayEventRecorder() literalMap := convertMapToArrayLiterals(test.inputValues) - arrayNodeState := &interfaces.ArrayNodeState{ + arrayNodeState := &handler.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseNone, } nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState) @@ -569,7 +569,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { assert.NoError(t, err) // initialize ArrayNodeState - arrayNodeState := &interfaces.ArrayNodeState{ + arrayNodeState := &handler.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseExecuting, } for _, item := range []struct { @@ -697,7 +697,7 @@ func TestHandleArrayNodePhaseSucceeding(t *testing.T) { retryAttempts, err 
:= bitarray.NewCompactArray(uint(len(test.subNodePhases)), bitarray.Item(1)) assert.NoError(t, err) - arrayNodeState := &interfaces.ArrayNodeState{ + arrayNodeState := &handler.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseSucceeding, SubNodePhases: subNodePhases, SubNodeRetryAttempts: retryAttempts, @@ -807,7 +807,7 @@ func TestHandleArrayNodePhaseFailing(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { // initialize ArrayNodeState - arrayNodeState := &interfaces.ArrayNodeState{ + arrayNodeState := &handler.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseFailing, } diff --git a/pkg/controller/nodes/branch/handler.go b/pkg/controller/nodes/branch/handler.go index 3e4b2897e..ed7324552 100644 --- a/pkg/controller/nodes/branch/handler.go +++ b/pkg/controller/nodes/branch/handler.go @@ -57,7 +57,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.IllegalStateError, errMsg, nil)), nil } - branchNodeState := interfaces.BranchNodeState{FinalizedNodeID: finalNodeID, Phase: v1alpha1.BranchNodeSuccess} + branchNodeState := handler.BranchNodeState{FinalizedNodeID: finalNodeID, Phase: v1alpha1.BranchNodeSuccess} err = nCtx.NodeStateWriter().PutBranchNode(branchNodeState) if err != nil { logger.Errorf(ctx, "Failed to store BranchNode state, err :%s", err.Error()) diff --git a/pkg/controller/nodes/branch/handler_test.go b/pkg/controller/nodes/branch/handler_test.go index a2271bc6a..774c4a459 100644 --- a/pkg/controller/nodes/branch/handler_test.go +++ b/pkg/controller/nodes/branch/handler_test.go @@ -35,34 +35,34 @@ var eventConfig = &config.EventConfig{ } type branchNodeStateHolder struct { - s interfaces.BranchNodeState + s handler.BranchNodeState } func (t *branchNodeStateHolder) ClearNodeStatus() { } -func (t *branchNodeStateHolder) PutTaskNodeState(s interfaces.TaskNodeState) error { +func (t *branchNodeStateHolder) PutTaskNodeState(s handler.TaskNodeState) error { panic("not implemented") } -func (t *branchNodeStateHolder) PutBranchNode(s interfaces.BranchNodeState) error { +func (t *branchNodeStateHolder) PutBranchNode(s handler.BranchNodeState) error { t.s = s return nil } -func (t branchNodeStateHolder) PutWorkflowNodeState(s interfaces.WorkflowNodeState) error { +func (t branchNodeStateHolder) PutWorkflowNodeState(s handler.WorkflowNodeState) error { panic("not implemented") } -func (t branchNodeStateHolder) PutDynamicNodeState(s interfaces.DynamicNodeState) error { +func (t branchNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) error { panic("not implemented") } -func (t branchNodeStateHolder) PutGateNodeState(s interfaces.GateNodeState) error { +func (t branchNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { panic("not implemented") } -func (t branchNodeStateHolder) PutArrayNodeState(s interfaces.ArrayNodeState) error { +func (t branchNodeStateHolder) PutArrayNodeState(s handler.ArrayNodeState) error { panic("not implemented") } @@ -79,7 +79,7 @@ func (parentInfo) CurrentAttempt() uint32 { func createNodeContext(phase v1alpha1.BranchNodePhase, childNodeID *v1alpha1.NodeID, n v1alpha1.ExecutableNode, inputs *core.LiteralMap, nl *execMocks.NodeLookup, eCtx executors.ExecutionContext) (*mocks.NodeExecutionContext, *branchNodeStateHolder) { - branchNodeState := interfaces.BranchNodeState{ + branchNodeState := handler.BranchNodeState{ FinalizedNodeID: childNodeID, Phase: phase, } 
@@ -127,7 +127,7 @@ func createNodeContext(phase v1alpha1.BranchNodePhase, childNodeID *v1alpha1.Nod nCtx.OnEnqueueOwnerFunc().Return(nil) nr := &mocks.NodeStateReader{} - nr.OnGetBranchNodeState().Return(interfaces.BranchNodeState{ + nr.OnGetBranchNodeState().Return(handler.BranchNodeState{ FinalizedNodeID: childNodeID, Phase: phase, }) diff --git a/pkg/controller/nodes/dynamic/dynamic_workflow.go b/pkg/controller/nodes/dynamic/dynamic_workflow.go index eb71c8520..89757c430 100644 --- a/pkg/controller/nodes/dynamic/dynamic_workflow.go +++ b/pkg/controller/nodes/dynamic/dynamic_workflow.go @@ -265,7 +265,7 @@ func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflow(ctx context.Context, nC } func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, execContext executors.ExecutionContext, dynamicWorkflow v1alpha1.ExecutableWorkflow, nl executors.NodeLookup, - nCtx interfaces.NodeExecutionContext, prevState interfaces.DynamicNodeState) (handler.Transition, interfaces.DynamicNodeState, error) { + nCtx interfaces.NodeExecutionContext, prevState handler.DynamicNodeState) (handler.Transition, handler.DynamicNodeState, error) { state, err := d.nodeExecutor.RecursiveNodeHandler(ctx, execContext, dynamicWorkflow, nl, dynamicWorkflow.StartNode()) if err != nil { @@ -281,7 +281,7 @@ func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, // As we do not support Failure Node, we can just return failure in this case return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoDynamicRunning(nil)), - interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "Dynamic workflow failed", Error: state.Err}, + handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "Dynamic workflow failed", Error: state.Err}, nil } @@ -293,7 +293,7 @@ func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, endNodeStatus := dynamicNodeStatus.GetNodeExecutionStatus(ctx, v1alpha1.EndNodeID) if endNodeStatus == nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "MalformedDynamicWorkflow", "no end-node found in dynamic workflow", nil)), - interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "no end-node found in dynamic workflow"}, + handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "no end-node found in dynamic workflow"}, nil } @@ -301,7 +301,7 @@ func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, if metadata, err := nCtx.DataStore().Head(ctx, sourcePath); err == nil { if !metadata.Exists() { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRetryableFailure(core.ExecutionError_SYSTEM, "DynamicWorkflowOutputsNotFound", fmt.Sprintf(" is expected to produce outputs but no outputs file was written to %v.", sourcePath), nil)), - interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "DynamicWorkflow is expected to produce outputs but no outputs file was written"}, + handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "DynamicWorkflow is expected to produce outputs but no outputs file was written"}, nil } } else { @@ -313,7 +313,7 @@ func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "OutputsNotFound", fmt.Sprintf("Failed to copy subworkflow outputs 
from [%v] to [%v]. Error: %s", sourcePath, destinationPath, err.Error()), nil), - ), interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "Failed to copy subworkflow outputs"}, + ), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "Failed to copy subworkflow outputs"}, nil } diff --git a/pkg/controller/nodes/dynamic/dynamic_workflow_test.go b/pkg/controller/nodes/dynamic/dynamic_workflow_test.go index 6750e79e2..88b317397 100644 --- a/pkg/controller/nodes/dynamic/dynamic_workflow_test.go +++ b/pkg/controller/nodes/dynamic/dynamic_workflow_test.go @@ -24,7 +24,7 @@ import ( mocks2 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" mocks4 "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" mocks6 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/dynamic/mocks" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" mocks5 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" ) @@ -135,7 +135,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t w.OnGetExecutionStatus().Return(ws) r := &mocks.NodeStateReader{} - r.OnGetDynamicNodeState().Return(interfaces.DynamicNodeState{ + r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseExecuting, }) nCtx.OnNodeStateReader().Return(r) diff --git a/pkg/controller/nodes/dynamic/handler.go b/pkg/controller/nodes/dynamic/handler.go index 22cca7c77..2ea98f741 100644 --- a/pkg/controller/nodes/dynamic/handler.go +++ b/pkg/controller/nodes/dynamic/handler.go @@ -63,7 +63,7 @@ type dynamicNodeTaskNodeHandler struct { eventConfig *config.EventConfig } -func (d dynamicNodeTaskNodeHandler) handleParentNode(ctx context.Context, prevState interfaces.DynamicNodeState, nCtx interfaces.NodeExecutionContext) (handler.Transition, interfaces.DynamicNodeState, error) { +func (d dynamicNodeTaskNodeHandler) handleParentNode(ctx context.Context, prevState handler.DynamicNodeState, nCtx interfaces.NodeExecutionContext) (handler.Transition, handler.DynamicNodeState, error) { // It seems parent node is still running, lets call handle for parent node trns, err := d.TaskNodeHandler.Handle(ctx, nCtx) if err != nil { @@ -85,7 +85,7 @@ func (d dynamicNodeTaskNodeHandler) handleParentNode(ctx context.Context, prevSt // directly to record, and then progress the dynamically generated workflow. logger.Infof(ctx, "future file detected, assuming dynamic node") // There is a futures file, so we need to continue running the node with the modified state - return trns.WithInfo(handler.PhaseInfoRunning(trns.Info().GetInfo())), interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalizing}, nil + return trns.WithInfo(handler.PhaseInfoRunning(trns.Info().GetInfo())), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalizing}, nil } } @@ -94,7 +94,7 @@ func (d dynamicNodeTaskNodeHandler) handleParentNode(ctx context.Context, prevSt } func (d dynamicNodeTaskNodeHandler) produceDynamicWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext) ( - handler.Transition, interfaces.DynamicNodeState, error) { + handler.Transition, handler.DynamicNodeState, error) { // The first time this is called we go ahead and evaluate the dynamic node to build the workflow. 
We then cache // this workflow definition and send it to be persisted by flyteadmin so that users can observe the structure. dCtx, err := d.buildContextualDynamicWorkflow(ctx, nCtx) @@ -102,9 +102,9 @@ func (d dynamicNodeTaskNodeHandler) produceDynamicWorkflow(ctx context.Context, if stdErrors.IsCausedBy(err, utils.ErrorCodeUser) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_USER, "DynamicWorkflowBuildFailed", err.Error(), nil), - ), interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: err.Error()}, nil + ), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: err.Error()}, nil } - return handler.Transition{}, interfaces.DynamicNodeState{}, err + return handler.Transition{}, handler.DynamicNodeState{}, err } taskNodeInfoMetadata := &event.TaskNodeMetadata{} if dCtx.subWorkflowClosure != nil && dCtx.subWorkflowClosure.Primary != nil && dCtx.subWorkflowClosure.Primary.Template != nil { @@ -115,7 +115,7 @@ func (d dynamicNodeTaskNodeHandler) produceDynamicWorkflow(ctx context.Context, } } - nextState := interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseExecuting} + nextState := handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseExecuting} return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoDynamicRunning(&handler.ExecutionInfo{ TaskNodeInfo: &handler.TaskNodeInfo{ TaskNodeMetadata: taskNodeInfoMetadata, @@ -123,16 +123,16 @@ func (d dynamicNodeTaskNodeHandler) produceDynamicWorkflow(ctx context.Context, })), nextState, nil } -func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, nCtx interfaces.NodeExecutionContext, prevState interfaces.DynamicNodeState) (handler.Transition, interfaces.DynamicNodeState, error) { +func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, nCtx interfaces.NodeExecutionContext, prevState handler.DynamicNodeState) (handler.Transition, handler.DynamicNodeState, error) { dCtx, err := d.buildContextualDynamicWorkflow(ctx, nCtx) if err != nil { if stdErrors.IsCausedBy(err, utils.ErrorCodeUser) { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_USER, "DynamicWorkflowBuildFailed", err.Error(), nil), - ), interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: err.Error()}, nil + ), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: err.Error()}, nil } // Mostly a system error or unknown - return handler.Transition{}, interfaces.DynamicNodeState{}, err + return handler.Transition{}, handler.DynamicNodeState{}, err } trns, newState, err := d.progressDynamicWorkflow(ctx, dCtx.execContext, dCtx.subWorkflow, dCtx.nodeLookup, nCtx, prevState) @@ -158,10 +158,10 @@ func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, n if ee != nil { if ee.IsRecoverable { - return trns.WithInfo(handler.PhaseInfoRetryableFailureErr(ee.ExecutionError, trns.Info().GetInfo())), interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: ee.ExecutionError.String()}, nil + return trns.WithInfo(handler.PhaseInfoRetryableFailureErr(ee.ExecutionError, trns.Info().GetInfo())), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: ee.ExecutionError.String()}, nil } - return trns.WithInfo(handler.PhaseInfoFailureErr(ee.ExecutionError, trns.Info().GetInfo())), interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: 
ee.ExecutionError.String()}, nil + return trns.WithInfo(handler.PhaseInfoFailureErr(ee.ExecutionError, trns.Info().GetInfo())), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: ee.ExecutionError.String()}, nil } taskNodeInfoMetadata := &event.TaskNodeMetadata{CacheStatus: status.GetCacheStatus(), CatalogKey: status.GetMetadata()} trns = trns.WithInfo(trns.Info().WithInfo(&handler.ExecutionInfo{ @@ -210,7 +210,7 @@ func (d dynamicNodeTaskNodeHandler) Handle(ctx context.Context, nCtx interfaces. if err := d.finalizeParentNode(ctx, nCtx); err != nil { return handler.UnknownTransition, err } - newState = interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalized} + newState = handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalized} trns = handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(trns.Info().GetInfo())) case v1alpha1.DynamicNodePhaseParentFinalized: trns, newState, err = d.produceDynamicWorkflow(ctx, nCtx) diff --git a/pkg/controller/nodes/dynamic/handler_test.go b/pkg/controller/nodes/dynamic/handler_test.go index 7fdb14dcd..ae0bb6912 100644 --- a/pkg/controller/nodes/dynamic/handler_test.go +++ b/pkg/controller/nodes/dynamic/handler_test.go @@ -36,34 +36,34 @@ import ( ) type dynamicNodeStateHolder struct { - s interfaces.DynamicNodeState + s handler.DynamicNodeState } func (t *dynamicNodeStateHolder) ClearNodeStatus() { } -func (t *dynamicNodeStateHolder) PutTaskNodeState(s interfaces.TaskNodeState) error { +func (t *dynamicNodeStateHolder) PutTaskNodeState(s handler.TaskNodeState) error { panic("not implemented") } -func (t dynamicNodeStateHolder) PutBranchNode(s interfaces.BranchNodeState) error { +func (t dynamicNodeStateHolder) PutBranchNode(s handler.BranchNodeState) error { panic("not implemented") } -func (t dynamicNodeStateHolder) PutWorkflowNodeState(s interfaces.WorkflowNodeState) error { +func (t dynamicNodeStateHolder) PutWorkflowNodeState(s handler.WorkflowNodeState) error { panic("not implemented") } -func (t *dynamicNodeStateHolder) PutDynamicNodeState(s interfaces.DynamicNodeState) error { +func (t *dynamicNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) error { t.s = s return nil } -func (t dynamicNodeStateHolder) PutGateNodeState(s interfaces.GateNodeState) error { +func (t dynamicNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { panic("not implemented") } -func (t dynamicNodeStateHolder) PutArrayNodeState(s interfaces.ArrayNodeState) error { +func (t dynamicNodeStateHolder) PutArrayNodeState(s handler.ArrayNodeState) error { panic("not implemented") } @@ -148,7 +148,7 @@ func Test_dynamicNodeHandler_Handle_Parent(t *testing.T) { nCtx.OnDataStore().Return(dataStore) r := &nodeMocks.NodeStateReader{} - r.OnGetDynamicNodeState().Return(interfaces.DynamicNodeState{}) + r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{}) nCtx.OnNodeStateReader().Return(r) return nCtx } @@ -289,7 +289,7 @@ func Test_dynamicNodeHandler_Handle_ParentFinalize(t *testing.T) { nCtx.OnDataStore().Return(dataStore) r := &nodeMocks.NodeStateReader{} - r.On("GetDynamicNodeState").Return(interfaces.DynamicNodeState{ + r.On("GetDynamicNodeState").Return(handler.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseParentFinalizing, }) nCtx.OnNodeStateReader().Return(r) @@ -300,7 +300,7 @@ func Test_dynamicNodeHandler_Handle_ParentFinalize(t *testing.T) { t.Run("parent-finalize-success", func(t *testing.T) { nCtx := createNodeContext("test") s := &dynamicNodeStateHolder{ - 
s: interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalizing}, + s: handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalizing}, } nCtx.OnNodeStateWriter().Return(s) f, err := nCtx.DataStore().ConstructReference(context.TODO(), nCtx.NodeStatus().GetDataDir(), "futures.pb") @@ -320,7 +320,7 @@ func Test_dynamicNodeHandler_Handle_ParentFinalize(t *testing.T) { t.Run("parent-finalize-error", func(t *testing.T) { nCtx := createNodeContext("test") s := &dynamicNodeStateHolder{ - s: interfaces.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalizing}, + s: handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseParentFinalizing}, } nCtx.OnNodeStateWriter().Return(s) f, err := nCtx.DataStore().ConstructReference(context.TODO(), nCtx.NodeStatus().GetDataDir(), "futures.pb") @@ -513,7 +513,7 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { w.OnGetExecutionStatus().Return(ws) r := &nodeMocks.NodeStateReader{} - r.OnGetDynamicNodeState().Return(interfaces.DynamicNodeState{ + r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseExecuting, }) nCtx.OnNodeStateReader().Return(r) @@ -746,7 +746,7 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { w.OnGetExecutionStatus().Return(ws) r := &nodeMocks.NodeStateReader{} - r.OnGetDynamicNodeState().Return(interfaces.DynamicNodeState{ + r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseExecuting, }) nCtx.OnNodeStateReader().Return(r) @@ -870,7 +870,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { ctx := context.TODO() t.Run("dynamicnodephase-none", func(t *testing.T) { - s := interfaces.DynamicNodeState{ + s := handler.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseNone, Reason: "", } @@ -997,7 +997,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { w.OnGetExecutionStatus().Return(ws) r := &nodeMocks.NodeStateReader{} - r.OnGetDynamicNodeState().Return(interfaces.DynamicNodeState{ + r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseExecuting, }) nCtx.OnNodeStateReader().Return(r) diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index 43cb7fe33..b7505ec92 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -2192,12 +2192,12 @@ func TestRecover(t *testing.T) { nCtx.OnDataStore().Return(storageClient) reader := &nodemocks.NodeStateReader{} - reader.OnGetDynamicNodeState().Return(interfaces.DynamicNodeState{}) + reader.OnGetDynamicNodeState().Return(handler.DynamicNodeState{}) nCtx.OnNodeStateReader().Return(reader) writer := &nodemocks.NodeStateWriter{} writer.OnPutDynamicNodeStateMatch(mock.Anything).Run(func(args mock.Arguments) { - state := args.Get(0).(interfaces.DynamicNodeState) + state := args.Get(0).(handler.DynamicNodeState) assert.Equal(t, v1alpha1.DynamicNodePhaseParentFinalized, state.Phase) }).Return(nil) nCtx.OnNodeStateWriter().Return(writer) @@ -2413,11 +2413,11 @@ func TestRecover(t *testing.T) { } reader := &nodemocks.NodeStateReader{} - reader.OnGetTaskNodeState().Return(interfaces.TaskNodeState{}) + reader.OnGetTaskNodeState().Return(handler.TaskNodeState{}) nCtx.OnNodeStateReader().Return(reader) writer := &nodemocks.NodeStateWriter{} writer.OnPutTaskNodeStateMatch(mock.Anything).Run(func(args mock.Arguments) { - state := args.Get(0).(interfaces.TaskNodeState) + state := args.Get(0).(handler.TaskNodeState) assert.Equal(t, 
state.PreviousNodeExecutionCheckpointURI.String(), "prev path") }).Return(nil) nCtx.OnNodeStateWriter().Return(writer) diff --git a/pkg/controller/nodes/gate/handler_test.go b/pkg/controller/nodes/gate/handler_test.go index 0de925797..b60b9db24 100644 --- a/pkg/controller/nodes/gate/handler_test.go +++ b/pkg/controller/nodes/gate/handler_test.go @@ -17,7 +17,6 @@ import ( executormocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flytestdlib/contextutils" @@ -123,7 +122,7 @@ func createNodeExecutionContext(gateNode *v1alpha1.GateNodeSpec) *nodeMocks.Node nCtx.OnInputReader().Return(inputReader) r := &nodeMocks.NodeStateReader{} - r.OnGetGateNodeState().Return(interfaces.GateNodeState{}) + r.OnGetGateNodeState().Return(handler.GateNodeState{}) nCtx.OnNodeStateReader().Return(r) w := &nodeMocks.NodeStateWriter{} diff --git a/pkg/controller/nodes/handler/state.go b/pkg/controller/nodes/handler/state.go new file mode 100644 index 000000000..5290546b9 --- /dev/null +++ b/pkg/controller/nodes/handler/state.go @@ -0,0 +1,60 @@ +package handler + +import ( + "time" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + + "github.com/flyteorg/flytestdlib/bitarray" + "github.com/flyteorg/flytestdlib/storage" +) + +// This is the legacy state structure that gets translated to node status +// TODO eventually we could just convert this to be binary node state encoded into the node status + +type TaskNodeState struct { + PluginPhase pluginCore.Phase + PluginPhaseVersion uint32 + PluginState []byte + PluginStateVersion uint32 + LastPhaseUpdatedAt time.Time + PreviousNodeExecutionCheckpointURI storage.DataReference + CleanupOnFailure bool +} + +type BranchNodeState struct { + FinalizedNodeID *v1alpha1.NodeID + Phase v1alpha1.BranchNodePhase +} + +type DynamicNodePhase uint8 + +type DynamicNodeState struct { + Phase v1alpha1.DynamicNodePhase + Reason string + Error *core.ExecutionError +} + +type WorkflowNodeState struct { + Phase v1alpha1.WorkflowNodePhase + Error *core.ExecutionError +} + +type GateNodeState struct { + Phase v1alpha1.GateNodePhase + StartedAt time.Time +} + +type ArrayNodeState struct { + Phase v1alpha1.ArrayNodePhase + TaskPhaseVersion uint32 + Error *core.ExecutionError + SubNodePhases bitarray.CompactArray + SubNodeTaskPhases bitarray.CompactArray + SubNodeRetryAttempts bitarray.CompactArray + SubNodeSystemFailures bitarray.CompactArray +} diff --git a/pkg/controller/nodes/interfaces/state_test.go b/pkg/controller/nodes/handler/state_test.go similarity index 97% rename from pkg/controller/nodes/interfaces/state_test.go rename to pkg/controller/nodes/handler/state_test.go index d7d9d62fd..7e914422e 100644 --- a/pkg/controller/nodes/interfaces/state_test.go +++ b/pkg/controller/nodes/handler/state_test.go @@ -1,4 +1,4 @@ -package interfaces +package handler import ( "bytes" diff --git a/pkg/controller/nodes/interfaces/mocks/node_state_reader.go b/pkg/controller/nodes/interfaces/mocks/node_state_reader.go index 2f8191d62..853eb5b67 100644 --- a/pkg/controller/nodes/interfaces/mocks/node_state_reader.go 
+++ b/pkg/controller/nodes/interfaces/mocks/node_state_reader.go @@ -3,7 +3,8 @@ package mocks import ( - interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + mock "github.com/stretchr/testify/mock" ) @@ -16,7 +17,7 @@ type NodeStateReader_GetArrayNodeState struct { *mock.Call } -func (_m NodeStateReader_GetArrayNodeState) Return(_a0 interfaces.ArrayNodeState) *NodeStateReader_GetArrayNodeState { +func (_m NodeStateReader_GetArrayNodeState) Return(_a0 handler.ArrayNodeState) *NodeStateReader_GetArrayNodeState { return &NodeStateReader_GetArrayNodeState{Call: _m.Call.Return(_a0)} } @@ -31,14 +32,14 @@ func (_m *NodeStateReader) OnGetArrayNodeStateMatch(matchers ...interface{}) *No } // GetArrayNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetArrayNodeState() interfaces.ArrayNodeState { +func (_m *NodeStateReader) GetArrayNodeState() handler.ArrayNodeState { ret := _m.Called() - var r0 interfaces.ArrayNodeState - if rf, ok := ret.Get(0).(func() interfaces.ArrayNodeState); ok { + var r0 handler.ArrayNodeState + if rf, ok := ret.Get(0).(func() handler.ArrayNodeState); ok { r0 = rf() } else { - r0 = ret.Get(0).(interfaces.ArrayNodeState) + r0 = ret.Get(0).(handler.ArrayNodeState) } return r0 @@ -48,7 +49,7 @@ type NodeStateReader_GetBranchNodeState struct { *mock.Call } -func (_m NodeStateReader_GetBranchNodeState) Return(_a0 interfaces.BranchNodeState) *NodeStateReader_GetBranchNodeState { +func (_m NodeStateReader_GetBranchNodeState) Return(_a0 handler.BranchNodeState) *NodeStateReader_GetBranchNodeState { return &NodeStateReader_GetBranchNodeState{Call: _m.Call.Return(_a0)} } @@ -63,14 +64,14 @@ func (_m *NodeStateReader) OnGetBranchNodeStateMatch(matchers ...interface{}) *N } // GetBranchNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetBranchNodeState() interfaces.BranchNodeState { +func (_m *NodeStateReader) GetBranchNodeState() handler.BranchNodeState { ret := _m.Called() - var r0 interfaces.BranchNodeState - if rf, ok := ret.Get(0).(func() interfaces.BranchNodeState); ok { + var r0 handler.BranchNodeState + if rf, ok := ret.Get(0).(func() handler.BranchNodeState); ok { r0 = rf() } else { - r0 = ret.Get(0).(interfaces.BranchNodeState) + r0 = ret.Get(0).(handler.BranchNodeState) } return r0 @@ -80,7 +81,7 @@ type NodeStateReader_GetDynamicNodeState struct { *mock.Call } -func (_m NodeStateReader_GetDynamicNodeState) Return(_a0 interfaces.DynamicNodeState) *NodeStateReader_GetDynamicNodeState { +func (_m NodeStateReader_GetDynamicNodeState) Return(_a0 handler.DynamicNodeState) *NodeStateReader_GetDynamicNodeState { return &NodeStateReader_GetDynamicNodeState{Call: _m.Call.Return(_a0)} } @@ -95,14 +96,14 @@ func (_m *NodeStateReader) OnGetDynamicNodeStateMatch(matchers ...interface{}) * } // GetDynamicNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetDynamicNodeState() interfaces.DynamicNodeState { +func (_m *NodeStateReader) GetDynamicNodeState() handler.DynamicNodeState { ret := _m.Called() - var r0 interfaces.DynamicNodeState - if rf, ok := ret.Get(0).(func() interfaces.DynamicNodeState); ok { + var r0 handler.DynamicNodeState + if rf, ok := ret.Get(0).(func() handler.DynamicNodeState); ok { r0 = rf() } else { - r0 = ret.Get(0).(interfaces.DynamicNodeState) + r0 = ret.Get(0).(handler.DynamicNodeState) } return r0 @@ -112,7 +113,7 @@ type NodeStateReader_GetGateNodeState struct { 
*mock.Call } -func (_m NodeStateReader_GetGateNodeState) Return(_a0 interfaces.GateNodeState) *NodeStateReader_GetGateNodeState { +func (_m NodeStateReader_GetGateNodeState) Return(_a0 handler.GateNodeState) *NodeStateReader_GetGateNodeState { return &NodeStateReader_GetGateNodeState{Call: _m.Call.Return(_a0)} } @@ -127,14 +128,14 @@ func (_m *NodeStateReader) OnGetGateNodeStateMatch(matchers ...interface{}) *Nod } // GetGateNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetGateNodeState() interfaces.GateNodeState { +func (_m *NodeStateReader) GetGateNodeState() handler.GateNodeState { ret := _m.Called() - var r0 interfaces.GateNodeState - if rf, ok := ret.Get(0).(func() interfaces.GateNodeState); ok { + var r0 handler.GateNodeState + if rf, ok := ret.Get(0).(func() handler.GateNodeState); ok { r0 = rf() } else { - r0 = ret.Get(0).(interfaces.GateNodeState) + r0 = ret.Get(0).(handler.GateNodeState) } return r0 @@ -144,7 +145,7 @@ type NodeStateReader_GetTaskNodeState struct { *mock.Call } -func (_m NodeStateReader_GetTaskNodeState) Return(_a0 interfaces.TaskNodeState) *NodeStateReader_GetTaskNodeState { +func (_m NodeStateReader_GetTaskNodeState) Return(_a0 handler.TaskNodeState) *NodeStateReader_GetTaskNodeState { return &NodeStateReader_GetTaskNodeState{Call: _m.Call.Return(_a0)} } @@ -159,14 +160,14 @@ func (_m *NodeStateReader) OnGetTaskNodeStateMatch(matchers ...interface{}) *Nod } // GetTaskNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetTaskNodeState() interfaces.TaskNodeState { +func (_m *NodeStateReader) GetTaskNodeState() handler.TaskNodeState { ret := _m.Called() - var r0 interfaces.TaskNodeState - if rf, ok := ret.Get(0).(func() interfaces.TaskNodeState); ok { + var r0 handler.TaskNodeState + if rf, ok := ret.Get(0).(func() handler.TaskNodeState); ok { r0 = rf() } else { - r0 = ret.Get(0).(interfaces.TaskNodeState) + r0 = ret.Get(0).(handler.TaskNodeState) } return r0 @@ -176,7 +177,7 @@ type NodeStateReader_GetWorkflowNodeState struct { *mock.Call } -func (_m NodeStateReader_GetWorkflowNodeState) Return(_a0 interfaces.WorkflowNodeState) *NodeStateReader_GetWorkflowNodeState { +func (_m NodeStateReader_GetWorkflowNodeState) Return(_a0 handler.WorkflowNodeState) *NodeStateReader_GetWorkflowNodeState { return &NodeStateReader_GetWorkflowNodeState{Call: _m.Call.Return(_a0)} } @@ -191,14 +192,14 @@ func (_m *NodeStateReader) OnGetWorkflowNodeStateMatch(matchers ...interface{}) } // GetWorkflowNodeState provides a mock function with given fields: -func (_m *NodeStateReader) GetWorkflowNodeState() interfaces.WorkflowNodeState { +func (_m *NodeStateReader) GetWorkflowNodeState() handler.WorkflowNodeState { ret := _m.Called() - var r0 interfaces.WorkflowNodeState - if rf, ok := ret.Get(0).(func() interfaces.WorkflowNodeState); ok { + var r0 handler.WorkflowNodeState + if rf, ok := ret.Get(0).(func() handler.WorkflowNodeState); ok { r0 = rf() } else { - r0 = ret.Get(0).(interfaces.WorkflowNodeState) + r0 = ret.Get(0).(handler.WorkflowNodeState) } return r0 diff --git a/pkg/controller/nodes/interfaces/mocks/node_state_writer.go b/pkg/controller/nodes/interfaces/mocks/node_state_writer.go index 93334a42d..46c0e2a38 100644 --- a/pkg/controller/nodes/interfaces/mocks/node_state_writer.go +++ b/pkg/controller/nodes/interfaces/mocks/node_state_writer.go @@ -3,7 +3,8 @@ package mocks import ( - interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + handler 
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + mock "github.com/stretchr/testify/mock" ) @@ -25,7 +26,7 @@ func (_m NodeStateWriter_PutArrayNodeState) Return(_a0 error) *NodeStateWriter_P return &NodeStateWriter_PutArrayNodeState{Call: _m.Call.Return(_a0)} } -func (_m *NodeStateWriter) OnPutArrayNodeState(s interfaces.ArrayNodeState) *NodeStateWriter_PutArrayNodeState { +func (_m *NodeStateWriter) OnPutArrayNodeState(s handler.ArrayNodeState) *NodeStateWriter_PutArrayNodeState { c_call := _m.On("PutArrayNodeState", s) return &NodeStateWriter_PutArrayNodeState{Call: c_call} } @@ -36,11 +37,11 @@ func (_m *NodeStateWriter) OnPutArrayNodeStateMatch(matchers ...interface{}) *No } // PutArrayNodeState provides a mock function with given fields: s -func (_m *NodeStateWriter) PutArrayNodeState(s interfaces.ArrayNodeState) error { +func (_m *NodeStateWriter) PutArrayNodeState(s handler.ArrayNodeState) error { ret := _m.Called(s) var r0 error - if rf, ok := ret.Get(0).(func(interfaces.ArrayNodeState) error); ok { + if rf, ok := ret.Get(0).(func(handler.ArrayNodeState) error); ok { r0 = rf(s) } else { r0 = ret.Error(0) @@ -57,7 +58,7 @@ func (_m NodeStateWriter_PutBranchNode) Return(_a0 error) *NodeStateWriter_PutBr return &NodeStateWriter_PutBranchNode{Call: _m.Call.Return(_a0)} } -func (_m *NodeStateWriter) OnPutBranchNode(s interfaces.BranchNodeState) *NodeStateWriter_PutBranchNode { +func (_m *NodeStateWriter) OnPutBranchNode(s handler.BranchNodeState) *NodeStateWriter_PutBranchNode { c_call := _m.On("PutBranchNode", s) return &NodeStateWriter_PutBranchNode{Call: c_call} } @@ -68,11 +69,11 @@ func (_m *NodeStateWriter) OnPutBranchNodeMatch(matchers ...interface{}) *NodeSt } // PutBranchNode provides a mock function with given fields: s -func (_m *NodeStateWriter) PutBranchNode(s interfaces.BranchNodeState) error { +func (_m *NodeStateWriter) PutBranchNode(s handler.BranchNodeState) error { ret := _m.Called(s) var r0 error - if rf, ok := ret.Get(0).(func(interfaces.BranchNodeState) error); ok { + if rf, ok := ret.Get(0).(func(handler.BranchNodeState) error); ok { r0 = rf(s) } else { r0 = ret.Error(0) @@ -89,7 +90,7 @@ func (_m NodeStateWriter_PutDynamicNodeState) Return(_a0 error) *NodeStateWriter return &NodeStateWriter_PutDynamicNodeState{Call: _m.Call.Return(_a0)} } -func (_m *NodeStateWriter) OnPutDynamicNodeState(s interfaces.DynamicNodeState) *NodeStateWriter_PutDynamicNodeState { +func (_m *NodeStateWriter) OnPutDynamicNodeState(s handler.DynamicNodeState) *NodeStateWriter_PutDynamicNodeState { c_call := _m.On("PutDynamicNodeState", s) return &NodeStateWriter_PutDynamicNodeState{Call: c_call} } @@ -100,11 +101,11 @@ func (_m *NodeStateWriter) OnPutDynamicNodeStateMatch(matchers ...interface{}) * } // PutDynamicNodeState provides a mock function with given fields: s -func (_m *NodeStateWriter) PutDynamicNodeState(s interfaces.DynamicNodeState) error { +func (_m *NodeStateWriter) PutDynamicNodeState(s handler.DynamicNodeState) error { ret := _m.Called(s) var r0 error - if rf, ok := ret.Get(0).(func(interfaces.DynamicNodeState) error); ok { + if rf, ok := ret.Get(0).(func(handler.DynamicNodeState) error); ok { r0 = rf(s) } else { r0 = ret.Error(0) @@ -121,7 +122,7 @@ func (_m NodeStateWriter_PutGateNodeState) Return(_a0 error) *NodeStateWriter_Pu return &NodeStateWriter_PutGateNodeState{Call: _m.Call.Return(_a0)} } -func (_m *NodeStateWriter) OnPutGateNodeState(s interfaces.GateNodeState) *NodeStateWriter_PutGateNodeState { +func (_m *NodeStateWriter) OnPutGateNodeState(s 
handler.GateNodeState) *NodeStateWriter_PutGateNodeState { c_call := _m.On("PutGateNodeState", s) return &NodeStateWriter_PutGateNodeState{Call: c_call} } @@ -132,11 +133,11 @@ func (_m *NodeStateWriter) OnPutGateNodeStateMatch(matchers ...interface{}) *Nod } // PutGateNodeState provides a mock function with given fields: s -func (_m *NodeStateWriter) PutGateNodeState(s interfaces.GateNodeState) error { +func (_m *NodeStateWriter) PutGateNodeState(s handler.GateNodeState) error { ret := _m.Called(s) var r0 error - if rf, ok := ret.Get(0).(func(interfaces.GateNodeState) error); ok { + if rf, ok := ret.Get(0).(func(handler.GateNodeState) error); ok { r0 = rf(s) } else { r0 = ret.Error(0) @@ -153,7 +154,7 @@ func (_m NodeStateWriter_PutTaskNodeState) Return(_a0 error) *NodeStateWriter_Pu return &NodeStateWriter_PutTaskNodeState{Call: _m.Call.Return(_a0)} } -func (_m *NodeStateWriter) OnPutTaskNodeState(s interfaces.TaskNodeState) *NodeStateWriter_PutTaskNodeState { +func (_m *NodeStateWriter) OnPutTaskNodeState(s handler.TaskNodeState) *NodeStateWriter_PutTaskNodeState { c_call := _m.On("PutTaskNodeState", s) return &NodeStateWriter_PutTaskNodeState{Call: c_call} } @@ -164,11 +165,11 @@ func (_m *NodeStateWriter) OnPutTaskNodeStateMatch(matchers ...interface{}) *Nod } // PutTaskNodeState provides a mock function with given fields: s -func (_m *NodeStateWriter) PutTaskNodeState(s interfaces.TaskNodeState) error { +func (_m *NodeStateWriter) PutTaskNodeState(s handler.TaskNodeState) error { ret := _m.Called(s) var r0 error - if rf, ok := ret.Get(0).(func(interfaces.TaskNodeState) error); ok { + if rf, ok := ret.Get(0).(func(handler.TaskNodeState) error); ok { r0 = rf(s) } else { r0 = ret.Error(0) @@ -185,7 +186,7 @@ func (_m NodeStateWriter_PutWorkflowNodeState) Return(_a0 error) *NodeStateWrite return &NodeStateWriter_PutWorkflowNodeState{Call: _m.Call.Return(_a0)} } -func (_m *NodeStateWriter) OnPutWorkflowNodeState(s interfaces.WorkflowNodeState) *NodeStateWriter_PutWorkflowNodeState { +func (_m *NodeStateWriter) OnPutWorkflowNodeState(s handler.WorkflowNodeState) *NodeStateWriter_PutWorkflowNodeState { c_call := _m.On("PutWorkflowNodeState", s) return &NodeStateWriter_PutWorkflowNodeState{Call: c_call} } @@ -196,11 +197,11 @@ func (_m *NodeStateWriter) OnPutWorkflowNodeStateMatch(matchers ...interface{}) } // PutWorkflowNodeState provides a mock function with given fields: s -func (_m *NodeStateWriter) PutWorkflowNodeState(s interfaces.WorkflowNodeState) error { +func (_m *NodeStateWriter) PutWorkflowNodeState(s handler.WorkflowNodeState) error { ret := _m.Called(s) var r0 error - if rf, ok := ret.Get(0).(func(interfaces.WorkflowNodeState) error); ok { + if rf, ok := ret.Get(0).(func(handler.WorkflowNodeState) error); ok { r0 = rf(s) } else { r0 = ret.Error(0) diff --git a/pkg/controller/nodes/interfaces/state.go b/pkg/controller/nodes/interfaces/state.go index bf753a23a..bdbcad2e1 100644 --- a/pkg/controller/nodes/interfaces/state.go +++ b/pkg/controller/nodes/interfaces/state.go @@ -1,85 +1,30 @@ package interfaces import ( - "time" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - - pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - - "github.com/flyteorg/flytestdlib/bitarray" - "github.com/flyteorg/flytestdlib/storage" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" ) -// This is the legacy state structure that gets translated to node status -// TODO 
eventually we could just convert this to be binary node state encoded into the node status - -type TaskNodeState struct { - PluginPhase pluginCore.Phase - PluginPhaseVersion uint32 - PluginState []byte - PluginStateVersion uint32 - LastPhaseUpdatedAt time.Time - PreviousNodeExecutionCheckpointURI storage.DataReference - CleanupOnFailure bool -} - -type BranchNodeState struct { - FinalizedNodeID *v1alpha1.NodeID - Phase v1alpha1.BranchNodePhase -} - -type DynamicNodePhase uint8 - -type DynamicNodeState struct { - Phase v1alpha1.DynamicNodePhase - Reason string - Error *core.ExecutionError -} - -type WorkflowNodeState struct { - Phase v1alpha1.WorkflowNodePhase - Error *core.ExecutionError -} - -type GateNodeState struct { - Phase v1alpha1.GateNodePhase - StartedAt time.Time -} - -type ArrayNodeState struct { - Phase v1alpha1.ArrayNodePhase - TaskPhaseVersion uint32 - Error *core.ExecutionError - SubNodePhases bitarray.CompactArray - SubNodeTaskPhases bitarray.CompactArray - SubNodeRetryAttempts bitarray.CompactArray - SubNodeSystemFailures bitarray.CompactArray -} - type NodeStateWriter interface { - PutTaskNodeState(s TaskNodeState) error - PutBranchNode(s BranchNodeState) error - PutDynamicNodeState(s DynamicNodeState) error - PutWorkflowNodeState(s WorkflowNodeState) error - PutGateNodeState(s GateNodeState) error - PutArrayNodeState(s ArrayNodeState) error + PutTaskNodeState(s handler.TaskNodeState) error + PutBranchNode(s handler.BranchNodeState) error + PutDynamicNodeState(s handler.DynamicNodeState) error + PutWorkflowNodeState(s handler.WorkflowNodeState) error + PutGateNodeState(s handler.GateNodeState) error + PutArrayNodeState(s handler.ArrayNodeState) error ClearNodeStatus() } type NodeStateReader interface { HasTaskNodeState() bool - GetTaskNodeState() TaskNodeState + GetTaskNodeState() handler.TaskNodeState HasBranchNodeState() bool - GetBranchNodeState() BranchNodeState + GetBranchNodeState() handler.BranchNodeState HasDynamicNodeState() bool - GetDynamicNodeState() DynamicNodeState + GetDynamicNodeState() handler.DynamicNodeState HasWorkflowNodeState() bool - GetWorkflowNodeState() WorkflowNodeState + GetWorkflowNodeState() handler.WorkflowNodeState HasGateNodeState() bool - GetGateNodeState() GateNodeState + GetGateNodeState() handler.GateNodeState HasArrayNodeState() bool - GetArrayNodeState() ArrayNodeState + GetArrayNodeState() handler.ArrayNodeState } diff --git a/pkg/controller/nodes/node_state_manager.go b/pkg/controller/nodes/node_state_manager.go index 17f4113a3..0d61e0ce3 100644 --- a/pkg/controller/nodes/node_state_manager.go +++ b/pkg/controller/nodes/node_state_manager.go @@ -6,45 +6,45 @@ import ( pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" ) type nodeStateManager struct { nodeStatus v1alpha1.ExecutableNodeStatus - t *interfaces.TaskNodeState - b *interfaces.BranchNodeState - d *interfaces.DynamicNodeState - w *interfaces.WorkflowNodeState - g *interfaces.GateNodeState - a *interfaces.ArrayNodeState + t *handler.TaskNodeState + b *handler.BranchNodeState + d *handler.DynamicNodeState + w *handler.WorkflowNodeState + g *handler.GateNodeState + a *handler.ArrayNodeState } -func (n *nodeStateManager) PutTaskNodeState(s interfaces.TaskNodeState) error { +func (n *nodeStateManager) PutTaskNodeState(s handler.TaskNodeState) error { 
n.t = &s return nil } -func (n *nodeStateManager) PutBranchNode(s interfaces.BranchNodeState) error { +func (n *nodeStateManager) PutBranchNode(s handler.BranchNodeState) error { n.b = &s return nil } -func (n *nodeStateManager) PutDynamicNodeState(s interfaces.DynamicNodeState) error { +func (n *nodeStateManager) PutDynamicNodeState(s handler.DynamicNodeState) error { n.d = &s return nil } -func (n *nodeStateManager) PutWorkflowNodeState(s interfaces.WorkflowNodeState) error { +func (n *nodeStateManager) PutWorkflowNodeState(s handler.WorkflowNodeState) error { n.w = &s return nil } -func (n *nodeStateManager) PutGateNodeState(s interfaces.GateNodeState) error { +func (n *nodeStateManager) PutGateNodeState(s handler.GateNodeState) error { n.g = &s return nil } -func (n *nodeStateManager) PutArrayNodeState(s interfaces.ArrayNodeState) error { +func (n *nodeStateManager) PutArrayNodeState(s handler.ArrayNodeState) error { n.a = &s return nil } @@ -73,14 +73,14 @@ func (n *nodeStateManager) HasArrayNodeState() bool { return n.a != nil } -func (n nodeStateManager) GetTaskNodeState() interfaces.TaskNodeState { +func (n nodeStateManager) GetTaskNodeState() handler.TaskNodeState { if n.t != nil { return *n.t } tn := n.nodeStatus.GetTaskNodeStatus() if tn != nil { - return interfaces.TaskNodeState{ + return handler.TaskNodeState{ PluginPhase: pluginCore.Phase(tn.GetPhase()), PluginPhaseVersion: tn.GetPhaseVersion(), PluginStateVersion: tn.GetPluginStateVersion(), @@ -90,16 +90,16 @@ func (n nodeStateManager) GetTaskNodeState() interfaces.TaskNodeState { CleanupOnFailure: tn.GetCleanupOnFailure(), } } - return interfaces.TaskNodeState{} + return handler.TaskNodeState{} } -func (n nodeStateManager) GetBranchNodeState() interfaces.BranchNodeState { +func (n nodeStateManager) GetBranchNodeState() handler.BranchNodeState { if n.b != nil { return *n.b } bn := n.nodeStatus.GetBranchStatus() - bs := interfaces.BranchNodeState{} + bs := handler.BranchNodeState{} if bn != nil { bs.Phase = bn.GetPhase() bs.FinalizedNodeID = bn.GetFinalizedNode() @@ -107,13 +107,13 @@ func (n nodeStateManager) GetBranchNodeState() interfaces.BranchNodeState { return bs } -func (n nodeStateManager) GetDynamicNodeState() interfaces.DynamicNodeState { +func (n nodeStateManager) GetDynamicNodeState() handler.DynamicNodeState { if n.d != nil { return *n.d } dn := n.nodeStatus.GetDynamicNodeStatus() - ds := interfaces.DynamicNodeState{} + ds := handler.DynamicNodeState{} if dn != nil { ds.Phase = dn.GetDynamicNodePhase() ds.Reason = dn.GetDynamicNodeReason() @@ -123,13 +123,13 @@ func (n nodeStateManager) GetDynamicNodeState() interfaces.DynamicNodeState { return ds } -func (n nodeStateManager) GetWorkflowNodeState() interfaces.WorkflowNodeState { +func (n nodeStateManager) GetWorkflowNodeState() handler.WorkflowNodeState { if n.w != nil { return *n.w } wn := n.nodeStatus.GetWorkflowNodeStatus() - ws := interfaces.WorkflowNodeState{} + ws := handler.WorkflowNodeState{} if wn != nil { ws.Phase = wn.GetWorkflowNodePhase() ws.Error = wn.GetExecutionError() @@ -137,26 +137,26 @@ func (n nodeStateManager) GetWorkflowNodeState() interfaces.WorkflowNodeState { return ws } -func (n nodeStateManager) GetGateNodeState() interfaces.GateNodeState { +func (n nodeStateManager) GetGateNodeState() handler.GateNodeState { if n.g != nil { return *n.g } gn := n.nodeStatus.GetGateNodeStatus() - gs := interfaces.GateNodeState{} + gs := handler.GateNodeState{} if gn != nil { gs.Phase = gn.GetGateNodePhase() } return gs } -func (n nodeStateManager) 
GetArrayNodeState() interfaces.ArrayNodeState { +func (n nodeStateManager) GetArrayNodeState() handler.ArrayNodeState { if n.a != nil { return *n.a } an := n.nodeStatus.GetArrayNodeStatus() - as := interfaces.ArrayNodeState{} + as := handler.ArrayNodeState{} if an != nil { as.Phase = an.GetArrayNodePhase() as.Error = an.GetExecutionError() diff --git a/pkg/controller/nodes/subworkflow/handler.go b/pkg/controller/nodes/subworkflow/handler.go index 4787509e8..5f478e398 100644 --- a/pkg/controller/nodes/subworkflow/handler.go +++ b/pkg/controller/nodes/subworkflow/handler.go @@ -58,7 +58,7 @@ func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeEx return transition, err } - workflowNodeState := interfaces.WorkflowNodeState{Phase: newPhase} + workflowNodeState := handler.WorkflowNodeState{Phase: newPhase} err = nCtx.NodeStateWriter().PutWorkflowNodeState(workflowNodeState) if err != nil { logger.Errorf(ctx, "Failed to store WorkflowNodeState, err :%s", err.Error()) diff --git a/pkg/controller/nodes/subworkflow/handler_test.go b/pkg/controller/nodes/subworkflow/handler_test.go index 4cc16215e..20e40fdc0 100644 --- a/pkg/controller/nodes/subworkflow/handler_test.go +++ b/pkg/controller/nodes/subworkflow/handler_test.go @@ -26,14 +26,13 @@ import ( mocks2 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" execMocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" mocks3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" ) type workflowNodeStateHolder struct { - s interfaces.WorkflowNodeState + s handler.WorkflowNodeState } var eventConfig = &config.EventConfig{ @@ -43,28 +42,28 @@ var eventConfig = &config.EventConfig{ func (t *workflowNodeStateHolder) ClearNodeStatus() { } -func (t *workflowNodeStateHolder) PutTaskNodeState(s interfaces.TaskNodeState) error { +func (t *workflowNodeStateHolder) PutTaskNodeState(s handler.TaskNodeState) error { panic("not implemented") } -func (t workflowNodeStateHolder) PutBranchNode(s interfaces.BranchNodeState) error { +func (t workflowNodeStateHolder) PutBranchNode(s handler.BranchNodeState) error { panic("not implemented") } -func (t *workflowNodeStateHolder) PutWorkflowNodeState(s interfaces.WorkflowNodeState) error { +func (t *workflowNodeStateHolder) PutWorkflowNodeState(s handler.WorkflowNodeState) error { t.s = s return nil } -func (t workflowNodeStateHolder) PutDynamicNodeState(s interfaces.DynamicNodeState) error { +func (t workflowNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) error { panic("not implemented") } -func (t workflowNodeStateHolder) PutGateNodeState(s interfaces.GateNodeState) error { +func (t workflowNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { panic("not implemented") } -func (t workflowNodeStateHolder) PutArrayNodeState(s interfaces.ArrayNodeState) error { +func (t workflowNodeStateHolder) PutArrayNodeState(s handler.ArrayNodeState) error { panic("not implemented") } @@ -76,7 +75,7 @@ var wfExecID = &core.WorkflowExecutionIdentifier{ func createNodeContextWithVersion(phase v1alpha1.WorkflowNodePhase, n v1alpha1.ExecutableNode, s v1alpha1.ExecutableNodeStatus, version v1alpha1.EventVersion) 
*mocks3.NodeExecutionContext { - wfNodeState := interfaces.WorkflowNodeState{} + wfNodeState := handler.WorkflowNodeState{} state := &workflowNodeStateHolder{s: wfNodeState} nm := &mocks3.NodeExecutionMetadata{} @@ -109,7 +108,7 @@ func createNodeContextWithVersion(phase v1alpha1.WorkflowNodePhase, n v1alpha1.E nCtx.OnNodeStatus().Return(s) nr := &mocks3.NodeStateReader{} - nr.OnGetWorkflowNodeState().Return(interfaces.WorkflowNodeState{ + nr.OnGetWorkflowNodeState().Return(handler.WorkflowNodeState{ Phase: phase, }) nCtx.OnNodeStateReader().Return(nr) diff --git a/pkg/controller/nodes/subworkflow/subworkflow.go b/pkg/controller/nodes/subworkflow/subworkflow.go index 52626730c..d0dad95b9 100644 --- a/pkg/controller/nodes/subworkflow/subworkflow.go +++ b/pkg/controller/nodes/subworkflow/subworkflow.go @@ -77,7 +77,7 @@ func (s *subworkflowHandler) handleSubWorkflow(ctx context.Context, nCtx interfa } if state.HasFailed() { - workflowNodeState := interfaces.WorkflowNodeState{ + workflowNodeState := handler.WorkflowNodeState{ Phase: v1alpha1.WorkflowNodePhaseFailing, Error: state.Err, } diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go index 04fb00721..b1890d3ba 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ -379,7 +379,7 @@ func (t Handler) fetchPluginTaskMetrics(pluginID, taskType string) (*taskMetrics return t.taskMetricsMap[metricNameKey], nil } -func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *taskExecutionContext, ts interfaces.TaskNodeState) (*pluginRequestedTransition, error) { +func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *taskExecutionContext, ts handler.TaskNodeState) (*pluginRequestedTransition, error) { pluginTrns := &pluginRequestedTransition{} trns, err := func() (trns pluginCore.Transition, err error) { @@ -750,7 +750,7 @@ func (t Handler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContex } // STEP 6: Persist the plugin state - err = nCtx.NodeStateWriter().PutTaskNodeState(interfaces.TaskNodeState{ + err = nCtx.NodeStateWriter().PutTaskNodeState(handler.TaskNodeState{ PluginState: pluginTrns.pluginState, PluginStateVersion: pluginTrns.pluginStateVersion, PluginPhase: pluginTrns.pInfo.Phase(), diff --git a/pkg/controller/nodes/task/handler_test.go b/pkg/controller/nodes/task/handler_test.go index f002d60ae..0a009cb65 100644 --- a/pkg/controller/nodes/task/handler_test.go +++ b/pkg/controller/nodes/task/handler_test.go @@ -364,34 +364,34 @@ func (f *fakeBufferedEventRecorder) RecordNodeEvent(ctx context.Context, ev *eve } type taskNodeStateHolder struct { - s interfaces.TaskNodeState + s handler.TaskNodeState } func (t *taskNodeStateHolder) ClearNodeStatus() { } -func (t *taskNodeStateHolder) PutTaskNodeState(s interfaces.TaskNodeState) error { +func (t *taskNodeStateHolder) PutTaskNodeState(s handler.TaskNodeState) error { t.s = s return nil } -func (t taskNodeStateHolder) PutBranchNode(s interfaces.BranchNodeState) error { +func (t taskNodeStateHolder) PutBranchNode(s handler.BranchNodeState) error { panic("not implemented") } -func (t taskNodeStateHolder) PutWorkflowNodeState(s interfaces.WorkflowNodeState) error { +func (t taskNodeStateHolder) PutWorkflowNodeState(s handler.WorkflowNodeState) error { panic("not implemented") } -func (t taskNodeStateHolder) PutDynamicNodeState(s interfaces.DynamicNodeState) error { +func (t taskNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) error { panic("not 
implemented") } -func (t taskNodeStateHolder) PutGateNodeState(s interfaces.GateNodeState) error { +func (t taskNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { panic("not implemented") } -func (t taskNodeStateHolder) PutArrayNodeState(s interfaces.ArrayNodeState) error { +func (t taskNodeStateHolder) PutArrayNodeState(s handler.ArrayNodeState) error { panic("not implemented") } @@ -519,7 +519,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { cod := codex.GobStateCodec{} assert.NoError(t, cod.Encode(pluginResp, st)) nr := &nodeMocks.NodeStateReader{} - nr.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ + nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ PluginState: st.Bytes(), PluginPhase: pluginPhase, PluginPhaseVersion: pluginVer, @@ -867,7 +867,7 @@ func Test_task_Handle_Catalog(t *testing.T) { OutputExists: true, }, st)) nr := &nodeMocks.NodeStateReader{} - nr.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ + nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ PluginState: st.Bytes(), }) nCtx.OnNodeStateReader().Return(nr) @@ -1225,7 +1225,7 @@ func Test_task_Handle_Reservation(t *testing.T) { Phase: pluginCore.PhaseSuccess, OutputExists: true, }, st)) - nr.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ + nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ PluginPhase: tt.args.pluginPhase, PluginState: st.Bytes(), }) @@ -1355,7 +1355,7 @@ func Test_task_Abort(t *testing.T) { cod := codex.GobStateCodec{} assert.NoError(t, cod.Encode(test{A: a}, st)) nr := &nodeMocks.NodeStateReader{} - nr.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ + nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ PluginState: st.Bytes(), }) nCtx.OnNodeStateReader().Return(nr) @@ -1517,7 +1517,7 @@ func Test_task_Abort_v1(t *testing.T) { cod := codex.GobStateCodec{} assert.NoError(t, cod.Encode(test{A: a}, st)) nr := &nodeMocks.NodeStateReader{} - nr.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ + nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ PluginState: st.Bytes(), }) nCtx.OnNodeStateReader().Return(nr) @@ -1703,7 +1703,7 @@ func Test_task_Finalize(t *testing.T) { cod := codex.GobStateCodec{} assert.NoError(t, cod.Encode(test{A: a}, st)) nr := &nodeMocks.NodeStateReader{} - nr.On("GetTaskNodeState").Return(interfaces.TaskNodeState{ + nr.On("GetTaskNodeState").Return(handler.TaskNodeState{ PluginState: st.Bytes(), }) nCtx.On("NodeStateReader").Return(nr) diff --git a/pkg/controller/nodes/task/taskexec_context_test.go b/pkg/controller/nodes/task/taskexec_context_test.go index 1bc368652..cd30e86b5 100644 --- a/pkg/controller/nodes/task/taskexec_context_test.go +++ b/pkg/controller/nodes/task/taskexec_context_test.go @@ -30,7 +30,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" flyteMocks "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/codex" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" @@ -113,7 +113,7 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { codex := codex.GobStateCodec{} assert.NoError(t, codex.Encode(test{A: a}, st)) nr := &nodeMocks.NodeStateReader{} - nr.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ + 
nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ PluginState: st.Bytes(), }) nCtx.OnNodeStateReader().Return(nr) @@ -432,7 +432,7 @@ func TestComputePreviousCheckpointPath(t *testing.T) { nCtx.OnDataStore().Return(ds) nCtx.OnNodeExecutionMetadata().Return(nm) reader := &nodeMocks.NodeStateReader{} - reader.OnGetTaskNodeState().Return(interfaces.TaskNodeState{}) + reader.OnGetTaskNodeState().Return(handler.TaskNodeState{}) nCtx.OnNodeStateReader().Return(reader) t.Run("attempt-0-nCtx", func(t *testing.T) { @@ -464,7 +464,7 @@ func TestComputePreviousCheckpointPath_Recovery(t *testing.T) { nCtx.OnDataStore().Return(ds) nCtx.OnNodeExecutionMetadata().Return(nm) reader := &nodeMocks.NodeStateReader{} - reader.OnGetTaskNodeState().Return(interfaces.TaskNodeState{ + reader.OnGetTaskNodeState().Return(handler.TaskNodeState{ PreviousNodeExecutionCheckpointURI: storage.DataReference("s3://sandbox/x/prevname-n1-0/_flytecheckpoints"), }) nCtx.OnNodeStateReader().Return(reader) From 4496efb68514b6103488eda9bf168f1a19f6fc7a Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Wed, 5 Jul 2023 12:07:56 -0500 Subject: [PATCH 50/62] added docs Signed-off-by: Daniel Rammer --- cmd/kubectl-flyte/cmd/create.go | 6 --- pkg/controller/nodes/executor.go | 8 ++-- pkg/controller/nodes/interfaces/handler.go | 2 +- pkg/controller/nodes/interfaces/node.go | 6 ++- pkg/controller/nodes/node_exec_context.go | 1 - .../nodes/node_exec_context_test.go | 39 +++++++++++++++++++ 6 files changed, 49 insertions(+), 13 deletions(-) diff --git a/cmd/kubectl-flyte/cmd/create.go b/cmd/kubectl-flyte/cmd/create.go index 2114aa902..91ea38255 100644 --- a/cmd/kubectl-flyte/cmd/create.go +++ b/cmd/kubectl-flyte/cmd/create.go @@ -212,12 +212,6 @@ func (c *CreateOpts) createWorkflowFromProto() error { } } - // TODO @hamersaw temp - flyteWf.ExecutionID.Project = "flytesnacks" - flyteWf.ExecutionID.Domain = "development" - flyteWf.Labels["project"] = "flytesnacks" - flyteWf.Labels["domain"] = "development" - if c.dryRun { fmt.Printf("Dry Run mode enabled. Printing the compiled workflow.\n") j, err := json.Marshal(flyteWf) diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index fe2d2be6c..c53b1ccc9 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -85,7 +85,8 @@ type nodeMetrics struct { NodeInputGatherLatency labeled.StopWatch } -// Implements the executors.Node interface +// recursiveNodeExecutor implements the executors.Node interface and is the starting point for +// executing any node in the workflow.
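// Editor's note, not part of this patch: the two context-builder methods
// documented in the hunks below are what let other components rewrap this
// executor with a customized interfaces.NodeExecutionContextBuilder while
// reusing all of the recursive execution machinery. A hypothetical sketch,
// where executor is any interfaces.Node and newSubNodeCtxBuilder is an assumed
// wrapper that decorates contexts before delegating to the original builder:
//
//	base := executor.GetNodeExecutionContextBuilder()
//	wrapped := executor.WithNodeExecutionContextBuilder(newSubNodeCtxBuilder(base))
//	// wrapped is a new interfaces.Node; every BuildNodeExecutionContext call
//	// now flows through the custom builder.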
type recursiveNodeExecutor struct { nodeExecutor interfaces.NodeExecutor nCtxBuilder interfaces.NodeExecutionContextBuilder @@ -486,11 +487,12 @@ func (c *recursiveNodeExecutor) Initialize(ctx context.Context) error { return c.nodeHandlerFactory.Setup(ctx, c, s) } -// TODO @hamersaw docs +// GetNodeExecutionContextBuilder returns the current NodeExecutionContextBuilder func (c *recursiveNodeExecutor) GetNodeExecutionContextBuilder() interfaces.NodeExecutionContextBuilder { return c.nCtxBuilder } +// WithNodeExecutionContextBuilder returns a new Node with the given NodeExecutionContextBuilder func (c *recursiveNodeExecutor) WithNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder) interfaces.Node { return &recursiveNodeExecutor{ nodeExecutor: c.nodeExecutor, @@ -502,7 +504,7 @@ func (c *recursiveNodeExecutor) WithNodeExecutionContextBuilder(nCtxBuilder inte } } -// TODO @hamersaw nodeExecutor goes here and write docs +// nodeExecutor implements the NodeExecutor interface and is responsible for executing a single node. type nodeExecutor struct { clusterID string defaultActiveDeadline time.Duration diff --git a/pkg/controller/nodes/interfaces/handler.go b/pkg/controller/nodes/interfaces/handler.go index d2fd411cf..a0b6ca2b8 100644 --- a/pkg/controller/nodes/interfaces/handler.go +++ b/pkg/controller/nodes/interfaces/handler.go @@ -10,7 +10,7 @@ import ( //go:generate mockery -all -case=underscore -// TODO @hamersaw - docs?!?1 +// NodeExecutor defines the interface for handling a single Flyte Node of any Node type. type NodeExecutor interface { // TODO @hamersaw - BuildNodeExecutionContext should be here - removes need for another interface HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx NodeExecutionContext, h NodeHandler) (NodeStatus, error) diff --git a/pkg/controller/nodes/interfaces/node.go b/pkg/controller/nodes/interfaces/node.go index 0f1b56e22..719689741 100644 --- a/pkg/controller/nodes/interfaces/node.go +++ b/pkg/controller/nodes/interfaces/node.go @@ -89,12 +89,14 @@ type Node interface { // This method should be used to initialize Node executor Initialize(ctx context.Context) error - // TODO @hamersaw - docs + // GetNodeExecutionContextBuilder returns the current NodeExecutionContextBuilder GetNodeExecutionContextBuilder() NodeExecutionContextBuilder + + // WithNodeExecutionContextBuilder returns a new Node with the given NodeExecutionContextBuilder WithNodeExecutionContextBuilder(NodeExecutionContextBuilder) Node } -// TODO @hamersaw - docs +// NodeExecutionContextBuilder defines how a NodeExecutionContext is built type NodeExecutionContextBuilder interface { BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, nl executors.NodeLookup, currentNodeID v1alpha1.NodeID) (NodeExecutionContext, error) diff --git a/pkg/controller/nodes/node_exec_context.go b/pkg/controller/nodes/node_exec_context.go index 41cb9831b..3822afd8c 100644 --- a/pkg/controller/nodes/node_exec_context.go +++ b/pkg/controller/nodes/node_exec_context.go @@ -318,7 +318,6 @@ func (c *nodeExecutor) BuildNodeExecutionContext(ctx context.Context, executionC interruptible, c.interruptibleFailureThreshold, c.maxDatasetSizeBytes, - //&taskEventRecorder{TaskEventRecorder: c.taskRecorder}, c.taskRecorder, c.nodeRecorder, tr, diff --git a/pkg/controller/nodes/node_exec_context_test.go b/pkg/controller/nodes/node_exec_context_test.go index 8586e2e00..7130883c6 100644 --- a/pkg/controller/nodes/node_exec_context_test.go +++ 
b/pkg/controller/nodes/node_exec_context_test.go @@ -284,3 +284,42 @@ func Test_NodeContextDefaultInterruptible(t *testing.T) { verifyNodeExecContext(t, execContext, nodeLookup, false) }) } + +// TODO @hamersaw - get working +/*func Test_NodeContext_IdempotentRecordEvent(t *testing.T) { + noErrRecorder := fakeNodeEventRecorder{} + alreadyExistsError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} + inTerminalError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, Cause: fmt.Errorf("err")}} + otherError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.ResourceExhausted, Cause: fmt.Errorf("err")}} + + tests := []struct { + name string + rec events.NodeEventRecorder + p core.NodeExecution_Phase + wantErr bool + }{ + {"aborted-success", noErrRecorder, core.NodeExecution_ABORTED, false}, + {"aborted-failure", otherError, core.NodeExecution_ABORTED, true}, + {"aborted-already", alreadyExistsError, core.NodeExecution_ABORTED, false}, + {"aborted-terminal", inTerminalError, core.NodeExecution_ABORTED, false}, + {"running-terminal", inTerminalError, core.NodeExecution_RUNNING, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &recursiveNodeExecutor{ + nodeRecorder: tt.rec, + eventConfig: &config.EventConfig{ + RawOutputPolicy: config.RawOutputPolicyReference, + }, + } + ev := &event.NodeExecutionEvent{ + Id: &core.NodeExecutionIdentifier{}, + Phase: tt.p, + ProducerId: "propeller", + } + if err := c.IdempotentRecordEvent(context.TODO(), ev); (err != nil) != tt.wantErr { + t.Errorf("IdempotentRecordEvent() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}*/ From b13ef67c870e83b723d9941e65efe8d10f3ba564 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Fri, 7 Jul 2023 10:36:47 -0500 Subject: [PATCH 51/62] cleaned up abort event reporting Signed-off-by: Daniel Rammer --- pkg/controller/nodes/executor.go | 90 +++++++++++----------- pkg/controller/nodes/interfaces/handler.go | 1 - pkg/controller/nodes/node_exec_context.go | 4 - 3 files changed, 44 insertions(+), 51 deletions(-) diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index c53b1ccc9..516f26d21 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -405,48 +405,7 @@ func (c *recursiveNodeExecutor) AbortHandler(ctx context.Context, execContext ex } // Abort this node err = c.nodeExecutor.Abort(ctx, h, nCtx, reason) - if err != nil { - return err - } - - // TODO @hamersaw - need to fix this shouldn't need to decompose nodeExecutor to send event - if nodeExec, ok := c.nodeExecutor.(*nodeExecutor); ok { - nodeExecutionID := &core.NodeExecutionIdentifier{ - ExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().ExecutionId, - NodeId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId, - } - if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { - currentNodeUniqueID, err := common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nodeExecutionID.NodeId) - if err != nil { - return err - } - nodeExecutionID.NodeId = currentNodeUniqueID - } - - //err := c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{ - err = nCtx.EventsRecorder().RecordNodeEvent(ctx, &event.NodeExecutionEvent{ - Id: nodeExecutionID, - Phase: core.NodeExecution_ABORTED, - OccurredAt: ptypes.TimestampNow(), - OutputResult: &event.NodeExecutionEvent_Error{ - Error: &core.ExecutionError{ - Code: "NodeAborted", - Message: 
reason, - }, - }, - ProducerId: nodeExec.clusterID, - ReportedAt: ptypes.TimestampNow(), - }, nodeExec.eventConfig) - if err != nil && !eventsErr.IsNotFound(err) && !eventsErr.IsEventIncompatibleClusterError(err) { - if errors2.IsCausedBy(err, errors.IllegalStateError) { - logger.Debugf(ctx, "Failed to record abort event due to illegal state transition. Ignoring the error. Error: %v", err) - } else { - logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") - } - } - } - return nil + return err } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { // Abort downstream nodes downstreamNodes, err := dag.FromNode(currentNode.GetID()) @@ -933,7 +892,49 @@ func (c *nodeExecutor) Abort(ctx context.Context, h interfaces.NodeHandler, nCtx return err } - return h.Finalize(ctx, nCtx) + if err := h.Finalize(ctx, nCtx); err != nil { + return err + } + + // only send event if node is in non-terminal phase + phase := nCtx.NodeStatus().GetPhase() + if phase != v1alpha1.NodePhaseNotYetStarted && canHandleNode(phase) { + nodeExecutionID := &core.NodeExecutionIdentifier{ + ExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().ExecutionId, + NodeId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId, + } + if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { + currentNodeUniqueID, err := common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nodeExecutionID.NodeId) + if err != nil { + return err + } + nodeExecutionID.NodeId = currentNodeUniqueID + } + + err := nCtx.EventsRecorder().RecordNodeEvent(ctx, &event.NodeExecutionEvent{ + Id: nodeExecutionID, + Phase: core.NodeExecution_ABORTED, + OccurredAt: ptypes.TimestampNow(), + OutputResult: &event.NodeExecutionEvent_Error{ + Error: &core.ExecutionError{ + Code: "NodeAborted", + Message: reason, + }, + }, + ProducerId: c.clusterID, + ReportedAt: ptypes.TimestampNow(), + }, c.eventConfig) + if err != nil && !eventsErr.IsNotFound(err) && !eventsErr.IsEventIncompatibleClusterError(err) { + if errors2.IsCausedBy(err, errors.IllegalStateError) { + logger.Debugf(ctx, "Failed to record abort event due to illegal state transition. Ignoring the error. 
Error: %v", err) + } else { + logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) + return errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") + } + } + } + + return nil } func (c *nodeExecutor) Finalize(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext) error { @@ -976,7 +977,6 @@ func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executor if err != nil { return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") } - //err = c.IdempotentRecordEvent(ctx, nev) err = nCtx.EventsRecorder().RecordNodeEvent(ctx, nev, c.eventConfig) if err != nil { logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) @@ -1094,7 +1094,6 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx inter return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") } - //err = c.IdempotentRecordEvent(ctx, nev) err = nCtx.EventsRecorder().RecordNodeEvent(ctx, nev, c.eventConfig) if err != nil { if eventsErr.IsTooLarge(err) { @@ -1104,7 +1103,6 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx inter np = v1alpha1.NodePhaseFailing p = handler.PhaseInfoFailure(core.ExecutionError_USER, "NodeFailed", err.Error(), p.GetInfo()) - //err = c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{ err = nCtx.EventsRecorder().RecordNodeEvent(ctx, &event.NodeExecutionEvent{ Id: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), Phase: core.NodeExecution_FAILED, diff --git a/pkg/controller/nodes/interfaces/handler.go b/pkg/controller/nodes/interfaces/handler.go index a0b6ca2b8..89eaac5c8 100644 --- a/pkg/controller/nodes/interfaces/handler.go +++ b/pkg/controller/nodes/interfaces/handler.go @@ -12,7 +12,6 @@ import ( // NodeExecutor defines the interface for handling a single Flyte Node of any Node type. 
type NodeExecutor interface { - // TODO @hamersaw - BuildNodeExecutionContext should be here - removes need for another interface HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx NodeExecutionContext, h NodeHandler) (NodeStatus, error) Abort(ctx context.Context, h NodeHandler, nCtx NodeExecutionContext, reason string) error Finalize(ctx context.Context, h NodeHandler, nCtx NodeExecutionContext) error diff --git a/pkg/controller/nodes/node_exec_context.go b/pkg/controller/nodes/node_exec_context.go index 3822afd8c..1303f245f 100644 --- a/pkg/controller/nodes/node_exec_context.go +++ b/pkg/controller/nodes/node_exec_context.go @@ -179,10 +179,6 @@ func (e nodeExecContext) InputReader() io.InputReader { return e.inputs } -/*func (e nodeExecContext) EventsRecorder() events.TaskEventRecorder { - return e.er -}*/ - func (e nodeExecContext) EventsRecorder() interfaces.EventRecorder { return e.eventRecorder } From 1c446b7daa132ecc99e635d28d16c841f64b5e6c Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Fri, 7 Jul 2023 14:48:38 -0500 Subject: [PATCH 52/62] fixed RecordNodeEvent unit tests Signed-off-by: Daniel Rammer --- pkg/controller/nodes/executor_test.go | 50 ------------------- .../nodes/node_exec_context_test.go | 40 +++++++++------ pkg/controller/nodes/task/handler.go | 4 +- 3 files changed, 27 insertions(+), 67 deletions(-) diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index b7505ec92..090a7f811 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -1990,56 +1990,6 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { }) } -type fakeNodeEventRecorder struct { - err error -} - -func (f fakeNodeEventRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { - if f.err != nil { - return f.err - } - return nil -} - -// TODO @hamersaw - fix IdempotentRecordEvent test -> move to NodeExecutionSomething -/*func Test_nodeExecutor_IdempotentRecordEvent(t *testing.T) { - noErrRecorder := fakeNodeEventRecorder{} - alreadyExistsError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} - inTerminalError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, Cause: fmt.Errorf("err")}} - otherError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.ResourceExhausted, Cause: fmt.Errorf("err")}} - - tests := []struct { - name string - rec events.NodeEventRecorder - p core.NodeExecution_Phase - wantErr bool - }{ - {"aborted-success", noErrRecorder, core.NodeExecution_ABORTED, false}, - {"aborted-failure", otherError, core.NodeExecution_ABORTED, true}, - {"aborted-already", alreadyExistsError, core.NodeExecution_ABORTED, false}, - {"aborted-terminal", inTerminalError, core.NodeExecution_ABORTED, false}, - {"running-terminal", inTerminalError, core.NodeExecution_RUNNING, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := &recursiveNodeExecutor{ - nodeRecorder: tt.rec, - eventConfig: &config.EventConfig{ - RawOutputPolicy: config.RawOutputPolicyReference, - }, - } - ev := &event.NodeExecutionEvent{ - Id: &core.NodeExecutionIdentifier{}, - Phase: tt.p, - ProducerId: "propeller", - } - if err := c.IdempotentRecordEvent(context.TODO(), ev); (err != nil) != tt.wantErr { - t.Errorf("IdempotentRecordEvent() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -}*/ - func TestRecover(t 
*testing.T) { recoveryID := &core.WorkflowExecutionIdentifier{ Project: "p", diff --git a/pkg/controller/nodes/node_exec_context_test.go b/pkg/controller/nodes/node_exec_context_test.go index 7130883c6..20b023e84 100644 --- a/pkg/controller/nodes/node_exec_context_test.go +++ b/pkg/controller/nodes/node_exec_context_test.go @@ -2,22 +2,29 @@ package nodes import ( "context" + "fmt" "strconv" "testing" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + + "github.com/flyteorg/flytepropeller/events" + eventsErr "github.com/flyteorg/flytepropeller/events/errors" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" + "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/promutils/labeled" "github.com/flyteorg/flytestdlib/storage" - "github.com/stretchr/testify/assert" - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/stretchr/testify/assert" ) type TaskReader struct{} @@ -28,6 +35,14 @@ func (t TaskReader) GetTaskID() *core.Identifier { return &core.Identifier{Project: "p", Domain: "d", Name: "task-name"} } +type fakeNodeEventRecorder struct { + err error +} + +func (f fakeNodeEventRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { + return f.err +} + type parentInfo struct { executors.ImmutableParentInfo } @@ -285,8 +300,7 @@ func Test_NodeContextDefaultInterruptible(t *testing.T) { }) } -// TODO @hamersaw - get working -/*func Test_NodeContext_IdempotentRecordEvent(t *testing.T) { +func Test_NodeContext_RecordNodeEvent(t *testing.T) { noErrRecorder := fakeNodeEventRecorder{} alreadyExistsError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} inTerminalError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, Cause: fmt.Errorf("err")}} @@ -306,20 +320,18 @@ func Test_NodeContextDefaultInterruptible(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &recursiveNodeExecutor{ - nodeRecorder: tt.rec, - eventConfig: &config.EventConfig{ - RawOutputPolicy: config.RawOutputPolicyReference, - }, + eventRecorder := &eventRecorder{ + nodeEventRecorder: tt.rec, } + ev := &event.NodeExecutionEvent{ Id: &core.NodeExecutionIdentifier{}, Phase: tt.p, ProducerId: "propeller", } - if err := c.IdempotentRecordEvent(context.TODO(), ev); (err != nil) != tt.wantErr { - t.Errorf("IdempotentRecordEvent() error = %v, wantErr %v", err, tt.wantErr) + if err := eventRecorder.RecordNodeEvent(context.TODO(), ev, &config.EventConfig{}); (err != nil) != tt.wantErr { + t.Errorf("RecordNodeEvent() error = %v, wantErr %v", err, tt.wantErr) } }) } -}*/ +} diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go index b1890d3ba..cc502ed40 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ 
-553,10 +553,8 @@ func (t Handler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContex ts := nCtx.NodeStateReader().GetTaskNodeState() pluginTrns := &pluginRequestedTransition{} - - // TODO @hamersaw - does this introduce issues in cache hits?!?! - // need to make sure the plugin transition does not block other workflows from progressing defer func() { + // increment parallelism if the final pluginTrns is not in a terminal state if pluginTrns != nil && !pluginTrns.pInfo.Phase().IsTerminal() { eCtx := nCtx.ExecutionContext() logger.Infof(ctx, "Parallelism now set to [%d].", eCtx.IncrementParallelism()) From 77824607c24464c281838e16298db9bcc309d73b Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Fri, 7 Jul 2023 14:57:06 -0500 Subject: [PATCH 53/62] removed taskEventRecorder from nodes package Signed-off-by: Daniel Rammer --- pkg/controller/nodes/executor_test.go | 2 +- .../nodes/node_exec_context_test.go | 56 ++++++++++++++--- .../nodes/task_event_recorder.go.bak | 36 ----------- .../nodes/task_event_recorder_test.go.bak | 60 ------------------- 4 files changed, 49 insertions(+), 105 deletions(-) delete mode 100644 pkg/controller/nodes/task_event_recorder.go.bak delete mode 100644 pkg/controller/nodes/task_event_recorder_test.go.bak diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index 090a7f811..1c0c0aa45 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -1652,7 +1652,7 @@ func TestNodeExecutor_AbortHandler(t *testing.T) { ns.OnGetDataDir().Return(storage.DataReference("s3:/foo")) nl.OnGetNodeExecutionStatusMatch(mock.Anything, id).Return(ns) nl.OnGetNode(id).Return(n, true) - incompatibleClusterErr := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} + incompatibleClusterErr := fakeEventRecorder{nodeErr: &eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} hf := &nodemocks.HandlerFactory{} h := &nodemocks.NodeHandler{} diff --git a/pkg/controller/nodes/node_exec_context_test.go b/pkg/controller/nodes/node_exec_context_test.go index 20b023e84..1816a667f 100644 --- a/pkg/controller/nodes/node_exec_context_test.go +++ b/pkg/controller/nodes/node_exec_context_test.go @@ -35,12 +35,17 @@ func (t TaskReader) GetTaskID() *core.Identifier { return &core.Identifier{Project: "p", Domain: "d", Name: "task-name"} } -type fakeNodeEventRecorder struct { - err error +type fakeEventRecorder struct { + nodeErr error + taskErr error } -func (f fakeNodeEventRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { - return f.err +func (f fakeEventRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { + return f.nodeErr +} + +func (f fakeEventRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { + return f.taskErr } type parentInfo struct { @@ -301,10 +306,10 @@ func Test_NodeContextDefaultInterruptible(t *testing.T) { } func Test_NodeContext_RecordNodeEvent(t *testing.T) { - noErrRecorder := fakeNodeEventRecorder{} - alreadyExistsError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} - inTerminalError := fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, Cause: fmt.Errorf("err")}} - otherError := 
fakeNodeEventRecorder{&eventsErr.EventError{Code: eventsErr.ResourceExhausted, Cause: fmt.Errorf("err")}} + noErrRecorder := fakeEventRecorder{} + alreadyExistsError := fakeEventRecorder{nodeErr: &eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} + inTerminalError := fakeEventRecorder{nodeErr: &eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, Cause: fmt.Errorf("err")}} + otherError := fakeEventRecorder{nodeErr: &eventsErr.EventError{Code: eventsErr.ResourceExhausted, Cause: fmt.Errorf("err")}} tests := []struct { name string @@ -335,3 +340,38 @@ func Test_NodeContext_RecordNodeEvent(t *testing.T) { }) } } + +func Test_NodeContext_RecordTaskEvent(t1 *testing.T) { + noErrRecorder := fakeEventRecorder{} + alreadyExistsError := fakeEventRecorder{taskErr: &eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} + inTerminalError := fakeEventRecorder{taskErr: &eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, Cause: fmt.Errorf("err")}} + otherError := fakeEventRecorder{taskErr: &eventsErr.EventError{Code: eventsErr.ResourceExhausted, Cause: fmt.Errorf("err")}} + + tests := []struct { + name string + rec events.TaskEventRecorder + p core.TaskExecution_Phase + wantErr bool + }{ + {"aborted-success", noErrRecorder, core.TaskExecution_ABORTED, false}, + {"aborted-failure", otherError, core.TaskExecution_ABORTED, true}, + {"aborted-already", alreadyExistsError, core.TaskExecution_ABORTED, false}, + {"aborted-terminal", inTerminalError, core.TaskExecution_ABORTED, false}, + {"running-terminal", inTerminalError, core.TaskExecution_RUNNING, true}, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + t := &eventRecorder{ + taskEventRecorder: tt.rec, + } + ev := &event.TaskExecutionEvent{ + Phase: tt.p, + } + if err := t.RecordTaskEvent(context.TODO(), ev, &config.EventConfig{ + RawOutputPolicy: config.RawOutputPolicyReference, + }); (err != nil) != tt.wantErr { + t1.Errorf("RecordTaskEvent() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/controller/nodes/task_event_recorder.go.bak b/pkg/controller/nodes/task_event_recorder.go.bak deleted file mode 100644 index ef3ec1e93..000000000 --- a/pkg/controller/nodes/task_event_recorder.go.bak +++ /dev/null @@ -1,36 +0,0 @@ -package nodes - -import ( - "context" - - "github.com/flyteorg/flytepropeller/events" - eventsErr "github.com/flyteorg/flytepropeller/events/errors" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flytestdlib/logger" - "github.com/pkg/errors" -) - -type taskEventRecorder struct { - events.TaskEventRecorder -} - -func (t taskEventRecorder) RecordTaskEvent(ctx context.Context, ev *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { - if err := t.TaskEventRecorder.RecordTaskEvent(ctx, ev, eventConfig); err != nil { - if eventsErr.IsAlreadyExists(err) { - logger.Warningf(ctx, "Failed to record taskEvent, error [%s]. Trying to record state: %s. Ignoring this error!", err.Error(), ev.Phase) - return nil - } else if eventsErr.IsEventAlreadyInTerminalStateError(err) { - if IsTerminalTaskPhase(ev.Phase) { - // Event is terminal and the stored value in flyteadmin is already terminal. This implies aborted case. So ignoring - logger.Warningf(ctx, "Failed to record taskEvent, error [%s]. Trying to record state: %s. 
Ignoring this error!", err.Error(), ev.Phase) - return nil - } - logger.Warningf(ctx, "Failed to record taskEvent in state: %s, error: %s", ev.Phase, err) - return errors.Wrapf(err, "failed to record task event, as it already exists in terminal state. Event state: %s", ev.Phase) - } - return err - } - return nil -} diff --git a/pkg/controller/nodes/task_event_recorder_test.go.bak b/pkg/controller/nodes/task_event_recorder_test.go.bak deleted file mode 100644 index 0f4da2037..000000000 --- a/pkg/controller/nodes/task_event_recorder_test.go.bak +++ /dev/null @@ -1,60 +0,0 @@ -package nodes - -import ( - "context" - "fmt" - "testing" - - "github.com/flyteorg/flytepropeller/events" - eventsErr "github.com/flyteorg/flytepropeller/events/errors" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" -) - -type fakeTaskEventsRecorder struct { - err error -} - -func (f fakeTaskEventsRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { - if f.err != nil { - return f.err - } - return nil -} - -func Test_taskEventRecorder_RecordTaskEvent(t1 *testing.T) { - noErrRecorder := fakeTaskEventsRecorder{} - alreadyExistsError := fakeTaskEventsRecorder{&eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}} - inTerminalError := fakeTaskEventsRecorder{&eventsErr.EventError{Code: eventsErr.EventAlreadyInTerminalStateError, Cause: fmt.Errorf("err")}} - otherError := fakeTaskEventsRecorder{&eventsErr.EventError{Code: eventsErr.ResourceExhausted, Cause: fmt.Errorf("err")}} - - tests := []struct { - name string - rec events.TaskEventRecorder - p core.TaskExecution_Phase - wantErr bool - }{ - {"aborted-success", noErrRecorder, core.TaskExecution_ABORTED, false}, - {"aborted-failure", otherError, core.TaskExecution_ABORTED, true}, - {"aborted-already", alreadyExistsError, core.TaskExecution_ABORTED, false}, - {"aborted-terminal", inTerminalError, core.TaskExecution_ABORTED, false}, - {"running-terminal", inTerminalError, core.TaskExecution_RUNNING, true}, - } - for _, tt := range tests { - t1.Run(tt.name, func(t1 *testing.T) { - t := taskEventRecorder{ - TaskEventRecorder: tt.rec, - } - ev := &event.TaskExecutionEvent{ - Phase: tt.p, - } - if err := t.RecordTaskEvent(context.TODO(), ev, &config.EventConfig{ - RawOutputPolicy: config.RawOutputPolicyReference, - }); (err != nil) != tt.wantErr { - t1.Errorf("RecordTaskEvent() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} From 837379dc1e6d4b0a632723706340fc968e0df42e Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Tue, 11 Jul 2023 13:46:41 -0500 Subject: [PATCH 54/62] adding interface checking for arraynode Signed-off-by: Daniel Rammer --- pkg/compiler/transformers/k8s/node.go | 2 +- pkg/compiler/validators/interface.go | 30 +++++- pkg/compiler/validators/interface_test.go | 107 ++++++++++++++++++++++ 3 files changed, 137 insertions(+), 2 deletions(-) diff --git a/pkg/compiler/transformers/k8s/node.go b/pkg/compiler/transformers/k8s/node.go index 4fcbd6e14..fb487efdd 100644 --- a/pkg/compiler/transformers/k8s/node.go +++ b/pkg/compiler/transformers/k8s/node.go @@ -156,7 +156,7 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile arrayNode := n.GetArrayNode() // since we set retries=1 on the node it's not using the task-level retries - arrayNode.Node.Metadata.Retries = nil // TODO @hamersaw - should probably set 
node-level retires to task in flytekit + //arrayNode.Node.Metadata.Retries = nil // TODO @hamersaw - should probably set node-level retires to task in flytekit // build subNodeSpecs subNodeSpecs, ok := buildNodeSpec(arrayNode.Node, tasks, errs) diff --git a/pkg/compiler/validators/interface.go b/pkg/compiler/validators/interface.go index f5a11345d..b11ff04cb 100644 --- a/pkg/compiler/validators/interface.go +++ b/pkg/compiler/validators/interface.go @@ -154,7 +154,35 @@ func ValidateUnderlyingInterface(w c.WorkflowBuilder, node c.NodeBuilder, errs e errs.Collect(errors.NewNoConditionFound(node.GetId())) } case *core.Node_ArrayNode: - // TODO @hamersaw complete + arrayNode := node.GetArrayNode() + underlyingNodeBuilder := w.GetOrCreateNodeBuilder(arrayNode.Node) + if underlyingIface, ok := ValidateUnderlyingInterface(w, underlyingNodeBuilder, errs.NewScope()); ok { + // wrap all input and output variables in a collection type + iface = &core.TypedInterface{ + Inputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, + Outputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, + } + + for name, binding := range underlyingIface.GetInputs().Variables { + iface.Inputs.Variables[name] = &core.Variable{ + Type: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: binding.GetType(), + }, + }, + } + } + + for name, binding := range underlyingIface.GetOutputs().Variables { + iface.Outputs.Variables[name] = &core.Variable{ + Type: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: binding.GetType(), + }, + }, + } + } + } default: errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Target")) } diff --git a/pkg/compiler/validators/interface_test.go b/pkg/compiler/validators/interface_test.go index 5580bd5c6..862764b84 100644 --- a/pkg/compiler/validators/interface_test.go +++ b/pkg/compiler/validators/interface_test.go @@ -374,6 +374,113 @@ func TestValidateUnderlyingInterface(t *testing.T) { assertNonEmptyInterface(t, iface, ifaceOk, errs) }) }) + + t.Run("ArrayNode", func(t *testing.T) { + // mock underlying task node + iface := &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "foo": { + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, + }, + }, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "bar": { + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_FLOAT, + }, + }, + }, + }, + }, + } + + taskNode := &core.Node{ + Id: "node_1", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{ + Name: "Task_1", + }, + }, + }, + }, + } + + task := mocks.Task{} + task.On("GetInterface").Return(iface) + + taskNodeBuilder := &mocks.NodeBuilder{} + taskNodeBuilder.On("GetCoreNode").Return(taskNode) + taskNodeBuilder.On("GetId").Return(taskNode.Id) + taskNodeBuilder.On("GetTaskNode").Return(taskNode.Target.(*core.Node_TaskNode).TaskNode) + taskNodeBuilder.On("GetInterface").Return(nil) + taskNodeBuilder.On("SetInterface", mock.AnythingOfType("*core.TypedInterface")).Return(nil) + + wfBuilder := mocks.WorkflowBuilder{} + wfBuilder.On("GetTask", mock.MatchedBy(func(id core.Identifier) bool { + return id.String() == (&core.Identifier{ + Name: "Task_1", + }).String() + })).Return(&task, true) + wfBuilder.On("GetOrCreateNodeBuilder", mock.MatchedBy(func(node *core.Node) bool { + return node.Id == "node_1" + 
})).Return(taskNodeBuilder) + + // mock array node + arrayNode := &core.Node{ + Id: "node_2", + Target: &core.Node_ArrayNode{ + ArrayNode: &core.ArrayNode{ + Node: taskNode, + }, + }, + } + + nodeBuilder := mocks.NodeBuilder{} + nodeBuilder.On("GetArrayNode").Return(arrayNode.Target.(*core.Node_ArrayNode).ArrayNode) + nodeBuilder.On("GetCoreNode").Return(arrayNode) + nodeBuilder.On("GetId").Return(arrayNode.Id) + nodeBuilder.On("GetInterface").Return(nil) + nodeBuilder.On("SetInterface", mock.Anything).Return() + + // compute arrayNode interface + errs := errors.NewCompileErrors() + arrayNodeIface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assertNonEmptyInterface(t, iface, ifaceOk, errs) + + // assert arrayNodeIFace is a collection wrapper of iface + assert.Len(t, arrayNodeIface.Inputs.Variables, len(iface.Inputs.Variables)) + for name, variable := range iface.Inputs.Variables { + arrayNodeVariable := arrayNodeIface.Inputs.Variables[name] + assert.NotNil(t, arrayNodeVariable) + + collectionType, ok := arrayNodeVariable.Type.GetType().(*core.LiteralType_CollectionType) + assert.True(t, ok) + + assert.Equal(t, variable.Type, collectionType.CollectionType) + } + + assert.Len(t, arrayNodeIface.Outputs.Variables, len(iface.Outputs.Variables)) + for name, variable := range iface.Outputs.Variables { + arrayNodeVariable := arrayNodeIface.Outputs.Variables[name] + assert.NotNil(t, arrayNodeVariable) + + collectionType, ok := arrayNodeVariable.Type.GetType().(*core.LiteralType_CollectionType) + assert.True(t, ok) + + assert.Equal(t, variable.Type, collectionType.CollectionType) + } + + }) } func matchIdentifier(id core.Identifier) interface{} { From 23d5312b63275f9e3174e34a2c973de154fd896f Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Tue, 11 Jul 2023 16:18:04 -0500 Subject: [PATCH 55/62] added transform unit test Signed-off-by: Daniel Rammer --- pkg/compiler/transformers/k8s/node.go | 2 +- pkg/compiler/transformers/k8s/node_test.go | 23 +++++++++++++++++++ .../nodes/node_exec_context_test.go | 2 +- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/pkg/compiler/transformers/k8s/node.go b/pkg/compiler/transformers/k8s/node.go index fb487efdd..5307988c3 100644 --- a/pkg/compiler/transformers/k8s/node.go +++ b/pkg/compiler/transformers/k8s/node.go @@ -164,7 +164,7 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile return nil, ok } - // TODO @hamersaw - complete + // build ArrayNode nodeSpec.Kind = v1alpha1.NodeKindArray nodeSpec.ArrayNode = &v1alpha1.ArrayNodeSpec{ SubNodeSpec: subNodeSpecs[0], diff --git a/pkg/compiler/transformers/k8s/node_test.go b/pkg/compiler/transformers/k8s/node_test.go index a9732d9d7..e879f2cb3 100644 --- a/pkg/compiler/transformers/k8s/node_test.go +++ b/pkg/compiler/transformers/k8s/node_test.go @@ -243,6 +243,29 @@ func TestBuildNodeSpec(t *testing.T) { mustBuild(t, n, 1, errs.NewScope()) }) + + t.Run("ArrayNode", func(t *testing.T) { + n.Node.Target = &core.Node_ArrayNode{ + ArrayNode: &core.ArrayNode{ + Node: &core.Node{ + Id: "foo", + Target: &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "ref_1"}, + }, + }, + }, + }, + Parallelism: 10, + SuccessCriteria: &core.ArrayNode_MinSuccessRatio{ + MinSuccessRatio: 0.5, + }, + }, + } + + mustBuild(t, n, 1, errs.NewScope()) + }) } func TestBuildTasks(t *testing.T) { diff --git a/pkg/controller/nodes/node_exec_context_test.go 
b/pkg/controller/nodes/node_exec_context_test.go
index 1816a667f..707ce33f4 100644
--- a/pkg/controller/nodes/node_exec_context_test.go
+++ b/pkg/controller/nodes/node_exec_context_test.go
@@ -328,7 +328,7 @@ func Test_NodeContext_RecordNodeEvent(t *testing.T) {
 			eventRecorder := &eventRecorder{
 				nodeEventRecorder: tt.rec,
 			}
-			
+
 			ev := &event.NodeExecutionEvent{
 				Id:    &core.NodeExecutionIdentifier{},
 				Phase: tt.p,

From 5dbc665ace89da87727435d54fa58e93893b8f7b Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Thu, 13 Jul 2023 12:20:38 -0500
Subject: [PATCH 56/62] fixed input bindings issue

Signed-off-by: Daniel Rammer

---
 pkg/compiler/transformers/k8s/node.go | 11 -----------
 pkg/compiler/validators/interface.go  | 28 +++-------------------------
 2 files changed, 3 insertions(+), 36 deletions(-)

diff --git a/pkg/compiler/transformers/k8s/node.go b/pkg/compiler/transformers/k8s/node.go
index 5307988c3..9bc0f608e 100644
--- a/pkg/compiler/transformers/k8s/node.go
+++ b/pkg/compiler/transformers/k8s/node.go
@@ -155,9 +155,6 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile
 	case *core.Node_ArrayNode:
 		arrayNode := n.GetArrayNode()

-		// since we set retries=1 on the node it's not using the task-level retries
-		//arrayNode.Node.Metadata.Retries = nil // TODO @hamersaw - should probably set node-level retires to task in flytekit
-
 		// build subNodeSpecs
 		subNodeSpecs, ok := buildNodeSpec(arrayNode.Node, tasks, errs)
 		if !ok {
@@ -177,14 +174,6 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile
 		case *core.ArrayNode_MinSuccessRatio:
 			nodeSpec.ArrayNode.MinSuccessRatio = &successCriteria.MinSuccessRatio
 		}
-
-		// TODO @hamersaw hack - should not be necessary, should be set in flytekit
-		for _, binding := range nodeSpec.InputBindings {
-			switch b := binding.Binding.Binding.Value.(type) {
-			case *core.BindingData_Promise:
-				b.Promise.NodeId = "start-node"
-			}
-		}
 	default:
 		if n.GetId() == v1alpha1.StartNodeID {
 			nodeSpec.Kind = v1alpha1.NodeKindStart
diff --git a/pkg/compiler/validators/interface.go b/pkg/compiler/validators/interface.go
index b11ff04cb..cdee66a45 100644
--- a/pkg/compiler/validators/interface.go
+++ b/pkg/compiler/validators/interface.go
@@ -157,31 +157,9 @@ func ValidateUnderlyingInterface(w c.WorkflowBuilder, node c.NodeBuilder, errs e
 		arrayNode := node.GetArrayNode()
 		underlyingNodeBuilder := w.GetOrCreateNodeBuilder(arrayNode.Node)
 		if underlyingIface, ok := ValidateUnderlyingInterface(w, underlyingNodeBuilder, errs.NewScope()); ok {
-			// wrap all input and output variables in a collection type
-			iface = &core.TypedInterface{
-				Inputs:  &core.VariableMap{Variables: map[string]*core.Variable{}},
-				Outputs: &core.VariableMap{Variables: map[string]*core.Variable{}},
-			}
-
-			for name, binding := range underlyingIface.GetInputs().Variables {
-				iface.Inputs.Variables[name] = &core.Variable{
-					Type: &core.LiteralType{
-						Type: &core.LiteralType_CollectionType{
-							CollectionType: binding.GetType(),
-						},
-					},
-				}
-			}
-
-			for name, binding := range underlyingIface.GetOutputs().Variables {
-				iface.Outputs.Variables[name] = &core.Variable{
-					Type: &core.LiteralType{
-						Type: &core.LiteralType_CollectionType{
-							CollectionType: binding.GetType(),
-						},
-					},
-				}
-			}
+			// ArrayNode interface should be inferred from the underlying node interface. flytekit
+			// will correctly wrap variables in collections as needed, leaving partials as is.
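
A minimal example of what this inference preserves (illustrative only): if the wrapped subnode declares an input foo: INTEGER, the ArrayNode's validated interface keeps that variable untouched, and any Collection(INTEGER) wrapping of upstream bindings is left to flytekit:

	// Hedged sketch: the subnode interface is reused verbatim rather than
	// rebuilt with collection-wrapped variable types.
	subIface := &core.TypedInterface{
		Inputs: &core.VariableMap{Variables: map[string]*core.Variable{
			"foo": {Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}},
		}},
	}
	iface := subIface // no per-variable wrapping at validation time
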
+ iface = underlyingIface } default: errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Target")) From 7914641b2426497b5a2069342ee56715664e33fc Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 13 Jul 2023 12:24:37 -0500 Subject: [PATCH 57/62] fixed unit tests Signed-off-by: Daniel Rammer --- pkg/compiler/validators/interface_test.go | 28 +++-------------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/pkg/compiler/validators/interface_test.go b/pkg/compiler/validators/interface_test.go index 862764b84..9a2183ebf 100644 --- a/pkg/compiler/validators/interface_test.go +++ b/pkg/compiler/validators/interface_test.go @@ -1,6 +1,7 @@ package validators import ( + "reflect" "testing" "time" @@ -455,31 +456,8 @@ func TestValidateUnderlyingInterface(t *testing.T) { // compute arrayNode interface errs := errors.NewCompileErrors() arrayNodeIface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) - assertNonEmptyInterface(t, iface, ifaceOk, errs) - - // assert arrayNodeIFace is a collection wrapper of iface - assert.Len(t, arrayNodeIface.Inputs.Variables, len(iface.Inputs.Variables)) - for name, variable := range iface.Inputs.Variables { - arrayNodeVariable := arrayNodeIface.Inputs.Variables[name] - assert.NotNil(t, arrayNodeVariable) - - collectionType, ok := arrayNodeVariable.Type.GetType().(*core.LiteralType_CollectionType) - assert.True(t, ok) - - assert.Equal(t, variable.Type, collectionType.CollectionType) - } - - assert.Len(t, arrayNodeIface.Outputs.Variables, len(iface.Outputs.Variables)) - for name, variable := range iface.Outputs.Variables { - arrayNodeVariable := arrayNodeIface.Outputs.Variables[name] - assert.NotNil(t, arrayNodeVariable) - - collectionType, ok := arrayNodeVariable.Type.GetType().(*core.LiteralType_CollectionType) - assert.True(t, ok) - - assert.Equal(t, variable.Type, collectionType.CollectionType) - } - + assertNonEmptyInterface(t, arrayNodeIface, ifaceOk, errs) + assert.True(t, reflect.DeepEqual(arrayNodeIface, iface)) }) } From d5eb484e4e2087be44798ace944a91a503168158 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Fri, 14 Jul 2023 10:41:38 -0500 Subject: [PATCH 58/62] fixed unit tests Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 2 +- pkg/controller/nodes/executor.go | 16 +++++++--------- pkg/controller/nodes/executor_test.go | 6 +++--- pkg/controller/nodes/interfaces/handler.go | 2 +- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 3e566d7c3..ff82ce55a 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -412,7 +412,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // checkpoint paths are not computed here because this function is only called when writing // existing cached outputs. if this functionality changes this will need to be revisited. 
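
This hunk surrounds the per-subtask output read; a hedged sketch of the aggregation such reads feed into, where gatherOutputs is a hypothetical map from output variable name to the literals collected per subtask:

	// Hypothetical aggregation: wrap each output variable's per-subtask
	// literals into a single LiteralCollection on the ArrayNode output map.
	outputLiterals := make(map[string]*idlcore.Literal, len(gatherOutputs))
	for name, literals := range gatherOutputs {
		outputLiterals[name] = &idlcore.Literal{
			Value: &idlcore.Literal_Collection{
				Collection: &idlcore.LiteralCollection{Literals: literals},
			},
		}
	}
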
outputPaths := ioutils.NewCheckpointRemoteFilePaths(ctx, nCtx.DataStore(), subOutputDir, ioutils.NewRawOutputPaths(ctx, subDataDir), "") - reader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, int64(999999999)) + reader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, int64(999999999)) // TODO @hamersaw - use max-output-size // read outputs outputs, executionErr, err := reader.Read(ctx) diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index 516f26d21..23fbafc46 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -404,8 +404,7 @@ func (c *recursiveNodeExecutor) AbortHandler(ctx context.Context, execContext ex return err } // Abort this node - err = c.nodeExecutor.Abort(ctx, h, nCtx, reason) - return err + return c.nodeExecutor.Abort(ctx, h, nCtx, reason, true) } else if nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered { // Abort downstream nodes downstreamNodes, err := dag.FromNode(currentNode.GetID()) @@ -882,7 +881,7 @@ func (c *nodeExecutor) execute(ctx context.Context, h interfaces.NodeHandler, nC return phase, nil } -func (c *nodeExecutor) Abort(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, reason string) error { +func (c *nodeExecutor) Abort(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, reason string, finalTransition bool) error { logger.Debugf(ctx, "Calling aborting & finalize") if err := h.Abort(ctx, nCtx, reason); err != nil { finalizeErr := h.Finalize(ctx, nCtx) @@ -896,9 +895,8 @@ func (c *nodeExecutor) Abort(ctx context.Context, h interfaces.NodeHandler, nCtx return err } - // only send event if node is in non-terminal phase - phase := nCtx.NodeStatus().GetPhase() - if phase != v1alpha1.NodePhaseNotYetStarted && canHandleNode(phase) { + // only send event if this is the final transition for this node + if finalTransition { nodeExecutionID := &core.NodeExecutionIdentifier{ ExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().ExecutionId, NodeId: nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId, @@ -1140,7 +1138,7 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx inter func (c *nodeExecutor) handleRetryableFailure(ctx context.Context, nCtx interfaces.NodeExecutionContext, h interfaces.NodeHandler) (interfaces.NodeStatus, error) { nodeStatus := nCtx.NodeStatus() logger.Debugf(ctx, "node failed with retryable failure, aborting and finalizing, message: %s", nodeStatus.GetMessage()) - if err := c.Abort(ctx, h, nCtx, nodeStatus.GetMessage()); err != nil { + if err := c.Abort(ctx, h, nCtx, nodeStatus.GetMessage(), false); err != nil { return interfaces.NodeStatusUndefined, err } @@ -1180,7 +1178,7 @@ func (c *nodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructur if currentPhase == v1alpha1.NodePhaseFailing { logger.Debugf(ctx, "node failing") - if err := c.Abort(ctx, h, nCtx, "node failing"); err != nil { + if err := c.Abort(ctx, h, nCtx, "node failing", false); err != nil { return interfaces.NodeStatusUndefined, err } nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailed, metav1.Now(), nodeStatus.GetMessage(), nodeStatus.GetExecutionError()) @@ -1193,7 +1191,7 @@ func (c *nodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructur if currentPhase == v1alpha1.NodePhaseTimingOut { logger.Debugf(ctx, "node timing out") - if err := c.Abort(ctx, h, 
nCtx, "node timed out"); err != nil { + if err := c.Abort(ctx, h, nCtx, "node timed out", false); err != nil { return interfaces.NodeStatusUndefined, err } diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index 1c0c0aa45..0d97ed921 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -1590,7 +1590,7 @@ func Test_nodeExecutor_abort(t *testing.T) { called = true }).Return(nil) - err := exec.Abort(ctx, h, nCtx, "testing") + err := exec.Abort(ctx, h, nCtx, "testing", false) assert.Equal(t, "test error", err.Error()) assert.True(t, called) }) @@ -1604,7 +1604,7 @@ func Test_nodeExecutor_abort(t *testing.T) { called = true }).Return(errors.New("finalize error")) - err := exec.Abort(ctx, h, nCtx, "testing") + err := exec.Abort(ctx, h, nCtx, "testing", false) assert.Equal(t, "0: test error\r\n1: finalize error\r\n", err.Error()) assert.True(t, called) }) @@ -1618,7 +1618,7 @@ func Test_nodeExecutor_abort(t *testing.T) { called = true }).Return(nil) - err := exec.Abort(ctx, h, nCtx, "testing") + err := exec.Abort(ctx, h, nCtx, "testing", false) assert.NoError(t, err) assert.True(t, called) }) diff --git a/pkg/controller/nodes/interfaces/handler.go b/pkg/controller/nodes/interfaces/handler.go index 89eaac5c8..16ef73274 100644 --- a/pkg/controller/nodes/interfaces/handler.go +++ b/pkg/controller/nodes/interfaces/handler.go @@ -13,7 +13,7 @@ import ( // NodeExecutor defines the interface for handling a single Flyte Node of any Node type. type NodeExecutor interface { HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx NodeExecutionContext, h NodeHandler) (NodeStatus, error) - Abort(ctx context.Context, h NodeHandler, nCtx NodeExecutionContext, reason string) error + Abort(ctx context.Context, h NodeHandler, nCtx NodeExecutionContext, reason string, finalTransition bool) error Finalize(ctx context.Context, h NodeHandler, nCtx NodeExecutionContext) error } From b044277250853965baab556c7bd5b208178426db Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Fri, 14 Jul 2023 10:49:20 -0500 Subject: [PATCH 59/62] go generate Signed-off-by: Daniel Rammer --- .../nodes/interfaces/mocks/node_executor.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pkg/controller/nodes/interfaces/mocks/node_executor.go b/pkg/controller/nodes/interfaces/mocks/node_executor.go index 72e99c906..e619c8fd7 100644 --- a/pkg/controller/nodes/interfaces/mocks/node_executor.go +++ b/pkg/controller/nodes/interfaces/mocks/node_executor.go @@ -24,8 +24,8 @@ func (_m NodeExecutor_Abort) Return(_a0 error) *NodeExecutor_Abort { return &NodeExecutor_Abort{Call: _m.Call.Return(_a0)} } -func (_m *NodeExecutor) OnAbort(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, reason string) *NodeExecutor_Abort { - c_call := _m.On("Abort", ctx, h, nCtx, reason) +func (_m *NodeExecutor) OnAbort(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, reason string, finalTransition bool) *NodeExecutor_Abort { + c_call := _m.On("Abort", ctx, h, nCtx, reason, finalTransition) return &NodeExecutor_Abort{Call: c_call} } @@ -34,13 +34,13 @@ func (_m *NodeExecutor) OnAbortMatch(matchers ...interface{}) *NodeExecutor_Abor return &NodeExecutor_Abort{Call: c_call} } -// Abort provides a mock function with given fields: ctx, h, nCtx, reason -func (_m *NodeExecutor) Abort(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, reason string) error 
{ - ret := _m.Called(ctx, h, nCtx, reason) +// Abort provides a mock function with given fields: ctx, h, nCtx, reason, finalTransition +func (_m *NodeExecutor) Abort(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, reason string, finalTransition bool) error { + ret := _m.Called(ctx, h, nCtx, reason, finalTransition) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeHandler, interfaces.NodeExecutionContext, string) error); ok { - r0 = rf(ctx, h, nCtx, reason) + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeHandler, interfaces.NodeExecutionContext, string, bool) error); ok { + r0 = rf(ctx, h, nCtx, reason, finalTransition) } else { r0 = ret.Error(0) } From 14fae22ec7e6d3084b7aadf9fdd9a0c9585769c9 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Mon, 17 Jul 2023 11:24:16 -0500 Subject: [PATCH 60/62] addressing random TODO Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 16 +++------------- pkg/controller/nodes/array/utils.go | 1 - pkg/controller/nodes/node_exec_context.go | 1 - 3 files changed, 3 insertions(+), 15 deletions(-) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index ff82ce55a..115049e16 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -297,7 +297,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu for _, taskExecutionEvent := range arrayEventRecorder.TaskEvents() { for _, log := range taskExecutionEvent.Logs { - log.Name = fmt.Sprintf("%s-%d", log.Name, i) // TODO @hamersaw - do we need to add retryAttempt to log name? + log.Name = fmt.Sprintf("%s-%d", log.Name, i) } externalResources = append(externalResources, &event.ExternalResourceInfo{ @@ -412,7 +412,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // checkpoint paths are not computed here because this function is only called when writing // existing cached outputs. if this functionality changes this will need to be revisited. outputPaths := ioutils.NewCheckpointRemoteFilePaths(ctx, nCtx.DataStore(), subOutputDir, ioutils.NewRawOutputPaths(ctx, subDataDir), "") - reader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, int64(999999999)) // TODO @hamersaw - use max-output-size + reader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, int64(nCtx.MaxDatasetSizeBytes())) // read outputs outputs, executionErr, err := reader.Read(ctx) @@ -560,17 +560,7 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter pluginStateBytes = a.pluginStateBytesNotStarted } - // we set subDataDir and subOutputDir to the node dirs because flytekit automatically appends subtask - // index. however when we check completion status we need to manually append index - so in all cases - // where the node phase is not Queued (ie. task handler will launch task and init flytekit params) we - // append the subtask index. 
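
As background for the comment block being removed above, a minimal sketch (values illustrative; the patch's actual helper is constructOutputReferences) of deriving a per-subtask reference by suffixing the node's data directory with the subtask index and retry attempt, using the flytestdlib storage API:

	// e.g. <node data dir>/3/1 for subtask index 3, retry attempt 1
	subDataDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetDataDir(), "3", "1")
	if err != nil {
		return err
	}
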
- // TODO @hamersaw - verify this has been fixed in flytekit for arraynode implementation - /*var subDataDir, subOutputDir storage.DataReference - if nodePhase == v1alpha1.NodePhaseQueued { - subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx) - } else { - subDataDir, subOutputDir, err = constructOutputReferences(ctx, nCtx, strconv.Itoa(i)) - }*/ + // construct output references currentAttempt := uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(subNodeIndex)) subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(subNodeIndex), strconv.Itoa(int(currentAttempt))) if err != nil { diff --git a/pkg/controller/nodes/array/utils.go b/pkg/controller/nodes/array/utils.go index 7f330063d..a0700e573 100644 --- a/pkg/controller/nodes/array/utils.go +++ b/pkg/controller/nodes/array/utils.go @@ -68,7 +68,6 @@ func buildTaskExecutionEvent(_ context.Context, nCtx interfaces.NodeExecutionCon } func buildSubNodeID(nCtx interfaces.NodeExecutionContext, index int, retryAttempt uint32) string { - // TODO @hamersaw - what do we want for a subNode ID? return fmt.Sprintf("%s-n%d-%d", nCtx.NodeID(), index, retryAttempt) } diff --git a/pkg/controller/nodes/node_exec_context.go b/pkg/controller/nodes/node_exec_context.go index 1303f245f..ba43d1ba7 100644 --- a/pkg/controller/nodes/node_exec_context.go +++ b/pkg/controller/nodes/node_exec_context.go @@ -244,7 +244,6 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext taskEventRecorder: taskEventRecorder, nodeEventRecorder: nodeEventRecorder, }, - //er: er, maxDatasetSizeBytes: maxDatasetSize, tr: tr, nsm: nsm, From bdbef61748bef96bc289ff439c135f03a891ccbb Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Wed, 19 Jul 2023 10:10:38 -0500 Subject: [PATCH 61/62] fixed unit tests Signed-off-by: Daniel Rammer --- pkg/controller/nodes/array/handler.go | 2 +- pkg/controller/nodes/array/handler_test.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 115049e16..2a7a22eb9 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -412,7 +412,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // checkpoint paths are not computed here because this function is only called when writing // existing cached outputs. if this functionality changes this will need to be revisited. 
 			outputPaths := ioutils.NewCheckpointRemoteFilePaths(ctx, nCtx.DataStore(), subOutputDir, ioutils.NewRawOutputPaths(ctx, subDataDir), "")
-			reader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, int64(nCtx.MaxDatasetSizeBytes()))
+			reader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, nCtx.MaxDatasetSizeBytes())

 			// read outputs
 			outputs, executionErr, err := reader.Read(ctx)
diff --git a/pkg/controller/nodes/array/handler_test.go b/pkg/controller/nodes/array/handler_test.go
index 49e975061..f3e6f8bd1 100644
--- a/pkg/controller/nodes/array/handler_test.go
+++ b/pkg/controller/nodes/array/handler_test.go
@@ -72,6 +72,7 @@ func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder inte
 	inputLiteralMap *idlcore.LiteralMap, arrayNodeSpec *v1alpha1.NodeSpec, arrayNodeState *handler.ArrayNodeState) interfaces.NodeExecutionContext {

 	nCtx := &mocks.NodeExecutionContext{}
+	nCtx.OnMaxDatasetSizeBytes().Return(9999999)

 	// ContextualNodeLookup
 	nodeLookup := &execmocks.NodeLookup{}

From 70eda6ab2739c1878c4e8e0dfc4be66b01eff473 Mon Sep 17 00:00:00 2001
From: Daniel Rammer
Date: Fri, 28 Jul 2023 10:18:12 -0500
Subject: [PATCH 62/62] addressing pr comments

Signed-off-by: Daniel Rammer

---
 pkg/controller/nodes/array/handler.go | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go
index 2a7a22eb9..19641cb93 100644
--- a/pkg/controller/nodes/array/handler.go
+++ b/pkg/controller/nodes/array/handler.go
@@ -136,8 +136,8 @@ func (a *arrayNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExe
 	for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
 		nodePhase := v1alpha1.NodePhase(nodePhaseUint64)

-		// do not process nodes that have not started
-		if nodePhase == v1alpha1.NodePhaseNotYetStarted {
+		// do not process nodes that have not started or are in a terminal state
+		if nodePhase == v1alpha1.NodePhaseNotYetStarted || isTerminalNodePhase(nodePhase) {
 			continue
 		}

@@ -524,6 +524,11 @@ func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope pr
 	}, nil
 }

+// buildArrayNodeContext creates a custom environment for executing an ArrayNode subnode. It is required because the
+// arrayNodeHandler reuses the same node execution entrypoint (i.e. recursiveNodeExecutor.RecursiveNodeHandler) but
+// needs different execution details. For example, it sets input values as a single item rather than a collection,
+// injects environment variables for flytekit map task execution, and aggregates eventing so that, rather than
+// tracking state for each subnode individually, a single event is sent for the whole ArrayNode.
 func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx interfaces.NodeExecutionContext, arrayNodeState *handler.ArrayNodeState,
 	arrayNode v1alpha1.ExecutableArrayNode, subNodeIndex int, currentParallelism *uint32) (
 	interfaces.Node, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, *v1alpha1.NodeSpec, *v1alpha1.NodeStatus, *arrayEventRecorder, error) {
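
To make one of the substitutions above concrete, a hedged sketch (not the patch's literal code) of presenting item subNodeIndex of each collection-typed input so the subnode sees scalar inputs; inputs is assumed to be the ArrayNode's resolved LiteralMap:

	// Hypothetical input remapping: collection inputs are unpacked to their
	// i-th element; non-collection (partial) inputs pass through unchanged.
	subNodeInputs := &idlcore.LiteralMap{Literals: make(map[string]*idlcore.Literal)}
	for name, literal := range inputs.Literals {
		if collection := literal.GetCollection(); collection != nil {
			subNodeInputs.Literals[name] = collection.Literals[subNodeIndex]
		} else {
			subNodeInputs.Literals[name] = literal
		}
	}
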