Skip to content

Commit

Permalink
better agent chain (wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianliechti committed Oct 13, 2024
1 parent 4a648ee commit bccbb95
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# syntax=docker/dockerfile:1

FROM golang:1.22-alpine AS build
FROM golang:1-alpine AS build

WORKDIR /src

Expand Down
50 changes: 40 additions & 10 deletions pkg/chain/agent/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ func (c *Chain) Complete(ctx context.Context, messages []provider.Message, optio
var result *provider.Completion

inputOptions := &provider.CompleteOptions{
Stream: options.Stream,

Tools: to.Values(inputTools),

MaxTokens: options.MaxTokens,
Expand All @@ -105,8 +103,46 @@ func (c *Chain) Complete(ctx context.Context, messages []provider.Message, optio
Format: options.Format,
}

if len(options.Tools) > 0 {
inputOptions.Stream = nil
var lastToolCallID string
var lastToolCallName string

streamToolCalls := map[string]provider.ToolCall{}

stream := func(ctx context.Context, completion provider.Completion) error {
for _, t := range completion.Message.ToolCalls {
if t.ID != "" {
lastToolCallID = t.ID
}

if t.Name != "" {
lastToolCallName = t.Name
}

if lastToolCallName == "" {
continue
}

if _, found := agentTools[lastToolCallName]; !found {
call := streamToolCalls[lastToolCallID]
call.ID = lastToolCallID
call.Name = lastToolCallName
call.Arguments += t.Arguments

streamToolCalls[lastToolCallID] = call
}
}

if completion.Message.Content != "" || completion.Reason != "" {
completion.Message.ToolCalls = to.Values(streamToolCalls)

return options.Stream(ctx, completion)
}

return nil
}

if options.Stream != nil {
inputOptions.Stream = stream
}

for {
Expand Down Expand Up @@ -165,11 +201,5 @@ func (c *Chain) Complete(ctx context.Context, messages []provider.Message, optio
return nil, errors.New("unable to handle request")
}

if inputOptions.Stream == nil && options.Stream != nil {
if err := options.Stream(ctx, *result); err != nil {
return nil, err
}
}

return result, nil
}

0 comments on commit bccbb95

Please sign in to comment.