
Commit

fix(errors): report errors on streaming responses
philwinder committed Oct 15, 2024
1 parent 9a1380b commit 940b96d
Showing 4 changed files with 79 additions and 40 deletions.
28 changes: 26 additions & 2 deletions api/pkg/openai/helix_openai_client.go
@@ -144,6 +144,9 @@ func (c *InternalHelixServer) CreateChatCompletionStream(ctx context.Context, re
}

doneCh := make(chan struct{})
readyCh := make(chan struct{})
firstRun := true
var respError error

pr, pw := io.Pipe()

@@ -168,6 +171,18 @@ func (c *InternalHelixServer) CreateChatCompletionStream(ctx context.Context, re
return fmt.Errorf("error unmarshalling runner response: %w", err)
}

if runnerResp.Error != "" {
respError = fmt.Errorf("runner error: %s", runnerResp.Error)
}

// First chunk received, ready to return the stream or the error
// This MUST be done before the writeChunk call, otherwise it will block waiting for the
// reader to start
if firstRun {
close(readyCh)
firstRun = false
}

if runnerResp.StreamResponse != nil {
bts, err := json.Marshal(runnerResp.StreamResponse)
if err != nil {
@@ -180,7 +195,7 @@ func (c *InternalHelixServer) CreateChatCompletionStream(ctx context.Context, re
}
}

if runnerResp.Done {
if runnerResp.Done || runnerResp.Error != "" {
close(doneCh)

// Ensure the buffer gets EOF so it stops reading
@@ -212,7 +227,16 @@ func (c *InternalHelixServer) CreateChatCompletionStream(ctx context.Context, re
}()

// Initiate through our client
return client.CreateChatCompletionStream(ctx, request)
stream, err := client.CreateChatCompletionStream(ctx, request)

// Wait for the ready signal
<-readyCh

if respError != nil {
return nil, respError
}

return stream, err
}

// NewOpenAIStreamingAdapter returns a new OpenAI streaming adapter which allows
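
The change in this file introduces a readiness handshake: the goroutine that pumps runner responses into the io.Pipe closes readyCh as soon as the first chunk arrives, so CreateChatCompletionStream can surface an error carried by that first chunk instead of returning a stream that fails later. Below is a minimal, self-contained sketch of that pattern; the chunk type, openStream function, and the main demo are hypothetical stand-ins for types.RunnerLLMInferenceResponse and the surrounding Helix code, not the actual implementation, and errors arriving after the first chunk are left out for brevity.

package main

import (
	"fmt"
	"io"
)

// chunk stands in for types.RunnerLLMInferenceResponse in the diff.
type chunk struct {
	Data  string
	Error string
	Done  bool
}

// openStream pumps chunks into an io.Pipe, but only returns after the first
// chunk has been seen, so an error carried by that chunk is returned to the
// caller instead of a stream that fails later.
func openStream(chunks <-chan chunk) (io.Reader, error) {
	readyCh := make(chan struct{})
	firstRun := true
	var respError error

	pr, pw := io.Pipe()

	go func() {
		defer pw.Close()
		for c := range chunks {
			if firstRun {
				if c.Error != "" {
					respError = fmt.Errorf("runner error: %s", c.Error)
				}
				// Signal readiness before writing: a pipe write blocks until
				// a reader is attached, and the caller only starts reading
				// after openStream has returned the stream.
				close(readyCh)
				firstRun = false
				if respError != nil {
					return
				}
			}
			if _, err := pw.Write([]byte(c.Data)); err != nil {
				return
			}
			if c.Done {
				return
			}
		}
	}()

	// Wait for the first chunk before deciding whether to hand back the
	// stream or the error it carried.
	<-readyCh
	if respError != nil {
		return nil, respError
	}
	return pr, nil
}

func main() {
	chunks := make(chan chunk, 1)
	chunks <- chunk{Error: "model not found", Done: true}
	close(chunks)

	if _, err := openStream(chunks); err != nil {
		fmt.Println("surfaced before streaming:", err)
	}
}

As the diff's own comment notes, the ready signal must be closed before the first write to the pipe, because the write blocks until a reader appears, and the reader only appears once the stream has been handed back to the caller.
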
9 changes: 7 additions & 2 deletions api/pkg/openai/helix_openai_client_test.go
@@ -91,12 +91,17 @@ func (suite *HelixClientTestSuite) Test_CreateChatCompletion_ValidateQueue() {
// Request should be in the queue
time.Sleep(50 * time.Millisecond)

// The work has been scheduled immediately, so the queue should be empty
suite.srv.queueMu.Lock()
defer suite.srv.queueMu.Unlock()

suite.Len(suite.srv.queue, 1)
suite.Len(suite.srv.queue, 0)

req := suite.srv.queue[0]
// The request should now be given to the worker when it next asks for work
work, err := suite.srv.scheduler.WorkForRunner(runnerID, scheduler.WorkloadTypeLLMInferenceRequest, false, "")
suite.NoError(err)

req := work.LLMInferenceRequest()

suite.Equal(ownerID, req.OwnerID)
suite.Equal(sessionID, req.SessionID)
75 changes: 39 additions & 36 deletions api/pkg/openai/helix_openai_server.go
@@ -62,8 +62,36 @@ func (c *InternalHelixServer) GetNextLLMInferenceRequest(ctx context.Context, fi
c.queueMu.Lock()
defer c.queueMu.Unlock()

// Doing all the scheduling work here to avoid making too many changes at once. Schedule any
// requests that are currently in the queue.
// Default to requesting warm work
newWorkOnly := false

// Only get new work if the filter has a memory requirement (see runner/controller.go)
if filter.Memory != 0 {
newWorkOnly = true
}

// Now for this runner, get work
req, err := c.scheduler.WorkForRunner(runnerID, scheduler.WorkloadTypeLLMInferenceRequest, newWorkOnly, filter.ModelName)
if err != nil {
return nil, fmt.Errorf("error getting work for runner: %w", err)
}

if req != nil {
c.addSchedulingDecision(filter, runnerID, runnerID, req.LLMInferenceRequest().SessionID, req.LLMInferenceRequest().InteractionID)
log.Info().Str("runnerID", runnerID).Interface("filter", filter).Interface("req", req).Int("len(queue)", len(c.queue)).Msgf("🟠 helix_openai_server GetNextLLMInferenceRequest END")
return req.LLMInferenceRequest(), nil
}
return nil, nil

}

func (c *InternalHelixServer) enqueueRequest(req *types.RunnerLLMInferenceRequest) {
c.queueMu.Lock()
defer c.queueMu.Unlock()

c.queue = append(c.queue, req)

// Schedule any requests that are currently in the queue.
taken := 0
for _, req := range c.queue {
work, err := scheduler.NewLLMWorkload(req)
@@ -82,18 +110,22 @@ func (c *InternalHelixServer) GetNextLLMInferenceRequest(ctx context.Context, fi

// If we can't retry, write an error to the request and continue so it takes it off
// the queue
log.Warn().Err(err).Str("id", work.ID()).Msg("error scheduling work, removing from queue")

resp := &types.RunnerLLMInferenceResponse{
RequestID: req.RequestID,
OwnerID: req.OwnerID,
Error: err.Error(),
Done: true,
RequestID: req.RequestID,
OwnerID: req.OwnerID,
SessionID: req.SessionID,
InteractionID: req.InteractionID,
Error: err.Error(),
Done: true,
}
bts, err := json.Marshal(resp)
if err != nil {
log.Error().Err(err).Str("id", work.ID()).Msg("error marshalling runner response")
}

err = c.pubsub.Publish(ctx, pubsub.GetRunnerResponsesQueue(resp.OwnerID, resp.RequestID), bts)
err = c.pubsub.Publish(context.Background(), pubsub.GetRunnerResponsesQueue(req.OwnerID, req.RequestID), bts)
if err != nil {
log.Error().Err(err).Str("id", work.ID()).Msg("error publishing runner response")
}
@@ -102,35 +134,6 @@ func (c *InternalHelixServer) GetNextLLMInferenceRequest(ctx context.Context, fi
}
// Clear processed queue
c.queue = c.queue[taken:]

// Default to requesting warm work
newWorkOnly := false

// Only get new work if the filter has a memory requirement (see runner/controller.go)
if filter.Memory != 0 {
newWorkOnly = true
}

// Now for this runner, get work
req, err := c.scheduler.WorkForRunner(runnerID, scheduler.WorkloadTypeLLMInferenceRequest, newWorkOnly, filter.ModelName)
if err != nil {
return nil, fmt.Errorf("error getting work for runner: %w", err)
}

if req != nil {
c.addSchedulingDecision(filter, runnerID, runnerID, req.LLMInferenceRequest().SessionID, req.LLMInferenceRequest().InteractionID)
log.Info().Str("runnerID", runnerID).Interface("filter", filter).Interface("req", req).Int("len(queue)", len(c.queue)).Msgf("🟠 helix_openai_server GetNextLLMInferenceRequest END")
return req.LLMInferenceRequest(), nil
}
return nil, nil

}

func (c *InternalHelixServer) enqueueRequest(req *types.RunnerLLMInferenceRequest) {
c.queueMu.Lock()
defer c.queueMu.Unlock()

c.queue = append(c.queue, req)
}

// ProcessRunnerResponse is called on both partial streaming and full responses coming from the runner
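
In this file the scheduling work moves from GetNextLLMInferenceRequest into enqueueRequest, and a request that cannot be scheduled and cannot be retried is answered immediately: an error response carrying the session and interaction IDs is published on the request's response queue, so the waiting stream terminates instead of hanging. The sketch below shows that shape under stated assumptions: Publisher, inferenceRequest, inferenceResponse, the schedule callback, and the topic format are hypothetical stand-ins for the pubsub client, the types package, scheduler.Enqueue, and pubsub.GetRunnerResponsesQueue.

package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"sync"
)

// Publisher stands in for the pubsub client used in the diff.
type Publisher interface {
	Publish(ctx context.Context, topic string, payload []byte) error
}

// inferenceRequest and inferenceResponse stand in for the types package.
type inferenceRequest struct {
	RequestID, OwnerID, SessionID, InteractionID string
}

type inferenceResponse struct {
	RequestID, OwnerID, SessionID, InteractionID string
	Error                                        string
	Done                                         bool
}

type server struct {
	mu     sync.Mutex
	queue  []*inferenceRequest
	pubsub Publisher
	// schedule stands in for the scheduler; retryable reports whether the
	// request should stay queued and be retried later.
	schedule func(*inferenceRequest) (retryable bool, err error)
}

// enqueueRequest schedules queued work immediately. When scheduling fails and
// cannot be retried, the error is published on the request's response topic
// (with session and interaction IDs set) so the waiting stream terminates.
func (s *server) enqueueRequest(req *inferenceRequest) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.queue = append(s.queue, req)

	taken := 0
	for _, queued := range s.queue {
		if retryable, err := s.schedule(queued); err != nil {
			if retryable {
				break // leave this request and later ones in the queue
			}
			resp := &inferenceResponse{
				RequestID:     queued.RequestID,
				OwnerID:       queued.OwnerID,
				SessionID:     queued.SessionID,
				InteractionID: queued.InteractionID,
				Error:         err.Error(),
				Done:          true,
			}
			bts, merr := json.Marshal(resp)
			if merr != nil {
				log.Printf("marshalling response: %v", merr)
			} else {
				// Hypothetical topic format standing in for
				// pubsub.GetRunnerResponsesQueue(ownerID, requestID).
				topic := fmt.Sprintf("runner.responses.%s.%s", queued.OwnerID, queued.RequestID)
				if perr := s.pubsub.Publish(context.Background(), topic, bts); perr != nil {
					log.Printf("publishing response: %v", perr)
				}
			}
		}
		taken++
	}
	// Drop everything that was either scheduled or answered with an error.
	s.queue = s.queue[taken:]
}

type logPublisher struct{}

func (logPublisher) Publish(_ context.Context, topic string, payload []byte) error {
	log.Printf("publish %s: %s", topic, payload)
	return nil
}

func main() {
	s := &server{
		pubsub: logPublisher{},
		schedule: func(*inferenceRequest) (bool, error) {
			return false, fmt.Errorf("no runner has enough memory") // non-retryable
		},
	}
	s.enqueueRequest(&inferenceRequest{RequestID: "r1", OwnerID: "o1", SessionID: "s1", InteractionID: "i1"})
	fmt.Println("queue length after enqueue:", len(s.queue)) // 0: the error was reported instead
}

Requests that hit a retryable scheduling error stay at the front of the queue (the loop breaks before taken is advanced), mirroring the behaviour in the diff.
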
7 changes: 7 additions & 0 deletions api/pkg/server/session_handlers.go
@@ -428,6 +428,13 @@ func (s *HelixAPIServer) handleStreamingSession(ctx context.Context, user *types
// Call the LLM
stream, _, err := s.Controller.ChatCompletionStream(ctx, user, chatCompletionRequest, options)
if err != nil {
// Update last interaction
session.Interactions[len(session.Interactions)-1].Error = err.Error()
session.Interactions[len(session.Interactions)-1].Completed = time.Now()
session.Interactions[len(session.Interactions)-1].State = types.InteractionStateError
session.Interactions[len(session.Interactions)-1].Finished = true
s.Controller.WriteSession(session)

http.Error(rw, err.Error(), http.StatusInternalServerError)
return nil
}
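
Finally, handleStreamingSession now records the failure on the session itself before returning the HTTP error, so the interaction history shows why a turn produced no output. Below is a small sketch of that error path; Interaction, Session, and the store interface are stand-ins for the Helix types, and unlike the diff it also guards against an empty interaction list.

package main

import (
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"time"
)

// Stand-ins for the session types used in the diff.
type InteractionState string

const InteractionStateError InteractionState = "error"

type Interaction struct {
	Error     string
	Completed time.Time
	State     InteractionState
	Finished  bool
}

type Session struct {
	Interactions []Interaction
}

type sessionStore interface {
	WriteSession(*Session)
}

// failLastInteraction records a streaming failure on the most recent
// interaction and persists the session before the handler responds with 500.
func failLastInteraction(rw http.ResponseWriter, store sessionStore, session *Session, err error) {
	if n := len(session.Interactions); n > 0 {
		last := &session.Interactions[n-1]
		last.Error = err.Error()
		last.Completed = time.Now()
		last.State = InteractionStateError
		last.Finished = true
	}
	store.WriteSession(session)
	http.Error(rw, err.Error(), http.StatusInternalServerError)
}

type noopStore struct{}

func (noopStore) WriteSession(*Session) {}

func main() {
	rw := httptest.NewRecorder()
	session := &Session{Interactions: make([]Interaction, 1)}
	failLastInteraction(rw, noopStore{}, session, errors.New("stream failed"))
	fmt.Println(rw.Code, session.Interactions[0].State) // 500 error
}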
