diff --git a/internal/internal_workflow_client.go b/internal/internal_workflow_client.go
index 31d98ad7b..4352ed6e1 100644
--- a/internal/internal_workflow_client.go
+++ b/internal/internal_workflow_client.go
@@ -1715,75 +1715,99 @@ func (w *workflowClientInterceptor) executeWorkflowWithOperation(
 			withStartOp,
 		},
 	}
-	multiResp, err := w.client.workflowService.ExecuteMultiOperation(ctx, &multiRequest)
 
-	var multiErr *serviceerror.MultiOperationExecution
-	if errors.As(err, &multiErr) {
-		if len(multiErr.OperationErrors()) != len(multiRequest.Operations) {
-			return nil, fmt.Errorf("%w: %v instead of %v operation errors",
-				errInvalidServerResponse, len(multiErr.OperationErrors()), len(multiRequest.Operations))
+	var startResp *workflowservice.StartWorkflowExecutionResponse
+	var updateResp *workflowservice.UpdateWorkflowExecutionResponse
+	for {
+		multiResp, err := func() (*workflowservice.ExecuteMultiOperationResponse, error) {
+			grpcCtx, cancel := newGRPCContext(ctx, grpcTimeout(pollUpdateTimeout), grpcLongPoll(true), defaultGrpcRetryParameters(ctx))
+			defer cancel()
+
+			multiResp, err := w.client.workflowService.ExecuteMultiOperation(grpcCtx, &multiRequest)
+			if err != nil {
+				if ctx.Err() != nil {
+					return nil, NewWorkflowUpdateServiceTimeoutOrCanceledError(err)
+				}
+				if status := serviceerror.ToStatus(err); status.Code() == codes.Canceled || status.Code() == codes.DeadlineExceeded {
+					return nil, NewWorkflowUpdateServiceTimeoutOrCanceledError(err)
+				}
+				return nil, err
+			}
+
+			return multiResp, nil
+		}()
+
+		var multiErr *serviceerror.MultiOperationExecution
+		if errors.As(err, &multiErr) {
+			if len(multiErr.OperationErrors()) != len(multiRequest.Operations) {
+				return nil, fmt.Errorf("%w: %v instead of %v operation errors",
+					errInvalidServerResponse, len(multiErr.OperationErrors()), len(multiRequest.Operations))
+			}
+
+			var abortedErr *serviceerror.MultiOperationAborted
+			startErr := errors.New("failed to start workflow")
+			for i, opReq := range multiRequest.Operations {
+				// if an operation error is of type MultiOperationAborted, it means it was only aborted because
+				// of another operation's error and is therefore not interesting or helpful
+				opErr := multiErr.OperationErrors()[i]
+
+				switch t := opReq.Operation.(type) {
+				case *workflowservice.ExecuteMultiOperationRequest_Operation_StartWorkflow:
+					if !errors.As(opErr, &abortedErr) {
+						startErr = opErr
+					}
+				case *workflowservice.ExecuteMultiOperationRequest_Operation_UpdateWorkflow:
+					if !errors.As(opErr, &abortedErr) {
+						startErr = fmt.Errorf("%w: %w", errInvalidWorkflowOperation, opErr)
+					}
+				default:
+					// this would only happen if a case statement for a newly added operation is missing above
+					return nil, fmt.Errorf("%w: %T", errUnsupportedOperation, t)
+				}
+			}
+			return nil, startErr
+		} else if err != nil {
+			return nil, err
+		}
+
+		if len(multiResp.Responses) != len(multiRequest.Operations) {
+			return nil, fmt.Errorf("%w: %v instead of %v operation results",
+				errInvalidServerResponse, len(multiResp.Responses), len(multiRequest.Operations))
 		}
 
-		var startErr error
-		var abortedErr *serviceerror.MultiOperationAborted
 		for i, opReq := range multiRequest.Operations {
-			// if an operation error is of type MultiOperationAborted, it means it was only aborted because
-			// of another operation's error and is therefore not interesting or helpful
-			opErr := multiErr.OperationErrors()[i]
+			resp := multiResp.Responses[i].Response
 
 			switch t := opReq.Operation.(type) {
 			case *workflowservice.ExecuteMultiOperationRequest_Operation_StartWorkflow:
-				if !errors.As(opErr, &abortedErr) {
-					startErr = opErr
+				if opResp, ok := resp.(*workflowservice.ExecuteMultiOperationResponse_Response_StartWorkflow); ok {
+					startResp = opResp.StartWorkflow
+				} else {
+					return nil, fmt.Errorf("%w: StartWorkflow response has the wrong type %T", errInvalidServerResponse, resp)
 				}
 			case *workflowservice.ExecuteMultiOperationRequest_Operation_UpdateWorkflow:
-				if !errors.As(opErr, &abortedErr) {
-					startErr = fmt.Errorf("%w: %w", errInvalidWorkflowOperation, opErr)
+				if opResp, ok := resp.(*workflowservice.ExecuteMultiOperationResponse_Response_UpdateWorkflow); ok {
+					updateResp = opResp.UpdateWorkflow
+				} else {
+					return nil, fmt.Errorf("%w: UpdateWorkflow response has the wrong type %T", errInvalidServerResponse, resp)
 				}
 			default:
 				// this would only happen if a case statement for a newly added operation is missing above
 				return nil, fmt.Errorf("%w: %T", errUnsupportedOperation, t)
 			}
 		}
-		return nil, startErr
-	} else if err != nil {
-		return nil, err
-	}
 
-	if len(multiResp.Responses) != len(multiRequest.Operations) {
-		return nil, fmt.Errorf("%w: %v instead of %v operation results",
-			errInvalidServerResponse, len(multiResp.Responses), len(multiRequest.Operations))
+		if w.updateIsDurable(updateResp) {
+			break
+		}
 	}
 
-	var startResp *workflowservice.StartWorkflowExecutionResponse
-	for i, opReq := range multiRequest.Operations {
-		resp := multiResp.Responses[i].Response
-
-		switch t := opReq.Operation.(type) {
-		case *workflowservice.ExecuteMultiOperationRequest_Operation_StartWorkflow:
-			if opResp, ok := resp.(*workflowservice.ExecuteMultiOperationResponse_Response_StartWorkflow); ok {
-				startResp = opResp.StartWorkflow
-			} else {
-				return nil, fmt.Errorf("%w: StartWorkflow response has the wrong type %T", errInvalidServerResponse, resp)
-			}
-		case *workflowservice.ExecuteMultiOperationRequest_Operation_UpdateWorkflow:
-			if opResp, ok := resp.(*workflowservice.ExecuteMultiOperationResponse_Response_UpdateWorkflow); ok {
-				handle, err := w.updateHandleFromResponse(
-					ctx,
-					enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_UNSPECIFIED,
-					opResp.UpdateWorkflow)
-				operation.(*UpdateWithStartWorkflowOperation).set(handle, err)
-				if err != nil {
-					return nil, fmt.Errorf("%w: %w", errInvalidWorkflowOperation, err)
-				}
-			} else {
-				return nil, fmt.Errorf("%w: UpdateWorkflow response has the wrong type %T", errInvalidServerResponse, resp)
-			}
-		default:
-			// this would only happen if a case statement for a newly added operation is missing above
-			return nil, fmt.Errorf("%w: %T", errUnsupportedOperation, t)
-		}
+	handle, err := w.updateHandleFromResponse(ctx, enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_UNSPECIFIED, updateResp)
+	operation.(*UpdateWithStartWorkflowOperation).set(handle, err)
+	if err != nil {
+		return nil, fmt.Errorf("%w: %w", errInvalidWorkflowOperation, err)
 	}
+
 	return startResp, nil
 }
 
@@ -2028,11 +2052,7 @@ func (w *workflowClientInterceptor) UpdateWorkflow(
 			}
 			return nil, err
 		}
-		// Once the update is past admitted we know it is durable
-		// Note: old server version may return UNSPECIFIED if the update request
-		// did not reach the desired lifecycle stage.
-		if resp.GetStage() != enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED &&
-			resp.GetStage() != enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_UNSPECIFIED {
+		if w.updateIsDurable(resp) {
 			break
 		}
 	}
@@ -2042,6 +2062,15 @@ func (w *workflowClientInterceptor) UpdateWorkflow(
 	return w.updateHandleFromResponse(ctx, desiredLifecycleStage, resp)
 }
 
+// updateIsDurable reports whether resp carries a stage other than ADMITTED or UNSPECIFIED.
+func (w *workflowClientInterceptor) updateIsDurable(resp *workflowservice.UpdateWorkflowExecutionResponse) bool {
+	// Once the update is past admitted we know it is durable
+	// Note: old server version may return UNSPECIFIED if the update request
+	// did not reach the desired lifecycle stage.
+	return resp.GetStage() != enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED &&
+		resp.GetStage() != enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_UNSPECIFIED
+}
+
 func createUpdateWorkflowInput(
 	options UpdateWorkflowOptions,
 ) (*ClientUpdateWorkflowInput, error) {
diff --git a/internal/internal_workflow_client_test.go b/internal/internal_workflow_client_test.go
index d242edda1..df3b644fc 100644
--- a/internal/internal_workflow_client_test.go
+++ b/internal/internal_workflow_client_test.go
@@ -976,6 +976,125 @@ func (s *workflowRunSuite) TestGetWorkflowNoExtantWorkflowAndNoRunId() {
 	s.Equal("", workflowRunNoRunID.GetRunID())
 }
 
+func (s *workflowRunSuite) TestExecuteWorkflowWithUpdate_Retry() {
+	// Register two ordered expectations. Chaining Return twice on a single
+	// expectation makes gomock use only the last one for a single call, so the
+	// retry path would never be exercised.
+	gomock.InOrder(
+		s.workflowServiceClient.EXPECT().
+			ExecuteMultiOperation(gomock.Any(), gomock.Any(), gomock.Any()).
+			Return(&workflowservice.ExecuteMultiOperationResponse{
+				Responses: []*workflowservice.ExecuteMultiOperationResponse_Response{
+					{
+						Response: &workflowservice.ExecuteMultiOperationResponse_Response_StartWorkflow{},
+					},
+					{
+						// 1st response: empty response, Update is not durable yet, client retries
+						Response: &workflowservice.ExecuteMultiOperationResponse_Response_UpdateWorkflow{},
+					},
+				},
+			}, nil),
+		s.workflowServiceClient.EXPECT().
+			ExecuteMultiOperation(gomock.Any(), gomock.Any(), gomock.Any()).
+			Return(&workflowservice.ExecuteMultiOperationResponse{
+				Responses: []*workflowservice.ExecuteMultiOperationResponse_Response{
+					{
+						Response: &workflowservice.ExecuteMultiOperationResponse_Response_StartWorkflow{
+							StartWorkflow: &workflowservice.StartWorkflowExecutionResponse{
+								RunId: "RUN_ID",
+							},
+						},
+					},
+					{
+						// 2nd response: non-empty response, Update is durable
+						Response: &workflowservice.ExecuteMultiOperationResponse_Response_UpdateWorkflow{
+							UpdateWorkflow: &workflowservice.UpdateWorkflowExecutionResponse{
+								Stage: enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED,
+							},
+						},
+					},
+				},
+			}, nil),
+	)
+
+	updOp := NewUpdateWithStartWorkflowOperation(
+		UpdateWorkflowOptions{
+			UpdateName:   "update",
+			WaitForStage: WorkflowUpdateStageCompleted,
+		})
+
+	_, err := s.workflowClient.ExecuteWorkflow(
+		context.Background(),
+		StartWorkflowOptions{
+			ID:                 workflowID,
+			TaskQueue:          taskqueue,
+			WithStartOperation: updOp,
+		}, workflowType,
+	)
+	s.NoError(err)
+}
+
+func (s *workflowRunSuite) TestExecuteWorkflowWithUpdate_Abort() {
+	tests := []struct {
+		name        string
+		expectedErr string
+		respFunc    func(ctx context.Context, in *workflowservice.ExecuteMultiOperationRequest, opts ...grpc.CallOption) (*workflowservice.ExecuteMultiOperationResponse, error)
+	}{
+		{
+			name:        "Timeout",
+			expectedErr: "context deadline exceeded",
+			respFunc: func(ctx context.Context, in *workflowservice.ExecuteMultiOperationRequest, opts ...grpc.CallOption) (*workflowservice.ExecuteMultiOperationResponse, error) {
+				<-ctx.Done()
+				return nil, ctx.Err()
+			},
+		},
+		{
+			name:        "Cancelled",
+			expectedErr: "was_cancelled",
+			respFunc: func(ctx context.Context, in *workflowservice.ExecuteMultiOperationRequest, opts ...grpc.CallOption) (*workflowservice.ExecuteMultiOperationResponse, error) {
+				return nil, serviceerror.NewCanceled("was_cancelled")
+			},
+		},
+		{
+			name:        "DeadlineExceeded",
+			expectedErr: "deadline_exceeded",
+			respFunc: func(ctx context.Context, in *workflowservice.ExecuteMultiOperationRequest, opts ...grpc.CallOption) (*workflowservice.ExecuteMultiOperationResponse, error) {
+				return nil, serviceerror.NewDeadlineExceeded("deadline_exceeded")
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		s.Run(tt.name, func() {
+			s.workflowServiceClient.EXPECT().
+				ExecuteMultiOperation(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(tt.respFunc)
+
+			updOp := NewUpdateWithStartWorkflowOperation(
+				UpdateWorkflowOptions{
+					UpdateName:   "update",
+					WaitForStage: WorkflowUpdateStageCompleted,
+				})
+
+			ctxWithTimeout, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+			defer cancel()
+
+			_, err := s.workflowClient.ExecuteWorkflow(
+				ctxWithTimeout,
+				StartWorkflowOptions{
+					ID:                 workflowID,
+					TaskQueue:          taskqueue,
+					WithStartOperation: updOp,
+				}, workflowType,
+			)
+
+			var expectedErr *WorkflowUpdateServiceTimeoutOrCanceledError
+			require.ErrorAs(s.T(), err, &expectedErr)
+			require.ErrorContains(s.T(), err, tt.expectedErr)
+		})
+	}
+}
+
 func (s *workflowRunSuite) TestExecuteWorkflowWithUpdate_NonMultiOperationError() {
 	s.workflowServiceClient.EXPECT().
 		ExecuteMultiOperation(gomock.Any(), gomock.Any(), gomock.Any()).