Skip to content

Commit

Permalink
Merge pull request #430 from helixml/fix/api-tools-improvements
Browse files Browse the repository at this point in the history
API tools improvements (August edition)
  • Loading branch information
lukemarsden authored Aug 28, 2024
2 parents a7eaaed + 138e741 commit 1415e21
Show file tree
Hide file tree
Showing 16 changed files with 201 additions and 134 deletions.
11 changes: 3 additions & 8 deletions api/pkg/controller/inference.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,9 @@ func (c *Controller) evaluateToolUsage(ctx context.Context, user *types.User, re
return nil, false, nil
}

lastMessage := getLastMessage(req)
history := types.HistoryFromChatCompletionRequest(req)

resp, err := c.ToolsPlanner.RunAction(ctx, vals.SessionID, vals.InteractionID, selectedTool, history, lastMessage, isActionable.Api)
resp, err := c.ToolsPlanner.RunAction(ctx, vals.SessionID, vals.InteractionID, selectedTool, history, isActionable.Api)
if err != nil {
return nil, false, fmt.Errorf("failed to perform action: %w", err)
}
Expand Down Expand Up @@ -178,10 +177,9 @@ func (c *Controller) evaluateToolUsageStream(ctx context.Context, user *types.Us
return nil, false, nil
}

lastMessage := getLastMessage(req)
history := types.HistoryFromChatCompletionRequest(req)

stream, err := c.ToolsPlanner.RunActionStream(ctx, vals.SessionID, vals.InteractionID, selectedTool, history, lastMessage, isActionable.Api)
stream, err := c.ToolsPlanner.RunActionStream(ctx, vals.SessionID, vals.InteractionID, selectedTool, history, isActionable.Api)
if err != nil {
return nil, false, fmt.Errorf("failed to perform action: %w", err)
}
Expand All @@ -205,9 +203,6 @@ func (c *Controller) selectAndConfigureTool(ctx context.Context, user *types.Use
return nil, nil, false, nil
}

// Get last message from the chat completion messages
lastMessage := getLastMessage(req)

var options []tools.Option

// If assistant has configured an actionable template, use it
Expand All @@ -222,7 +217,7 @@ func (c *Controller) selectAndConfigureTool(ctx context.Context, user *types.Use
vals = &oai.ContextValues{}
}

isActionable, err := c.ToolsPlanner.IsActionable(ctx, vals.SessionID, vals.InteractionID, assistant.Tools, history, lastMessage, options...)
isActionable, err := c.ToolsPlanner.IsActionable(ctx, vals.SessionID, vals.InteractionID, assistant.Tools, history, options...)
if err != nil {
log.Error().Err(err).Msg("failed to evaluate if the message is actionable, skipping to general knowledge")
return nil, nil, false, fmt.Errorf("failed to evaluate if the message is actionable: %w", err)
Expand Down
18 changes: 8 additions & 10 deletions api/pkg/controller/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,16 +600,14 @@ func (c *Controller) checkForActions(session *types.Session) (*types.Session, er
return session, nil
}

userInteraction, err := data.GetLastUserInteraction(session.Interactions)
if err != nil {
return nil, fmt.Errorf("failed to get last user interaction: %w", err)
}

history := data.GetLastInteractions(session, actionContextHistorySize)

// If history has more than 2 interactions, remove the last 2 as it's the current user and assistant interaction
if len(history) > 2 {
history = history[:len(history)-2]
for i, interaction := range history {
log.Info().
Int("index", i).
Str("creator", string(interaction.Creator)).
Str("message", interaction.Message).
Msg("History item")
}

messageHistory := types.HistoryFromInteractions(history)
Expand All @@ -627,7 +625,7 @@ func (c *Controller) checkForActions(session *types.Session) (*types.Session, er
options = append(options, tools.WithIsActionableTemplate(assistant.IsActionableTemplate))
}

isActionable, err := c.ToolsPlanner.IsActionable(ctx, session.ID, lastInteraction.ID, activeTools, messageHistory, userInteraction.Message, options...)
isActionable, err := c.ToolsPlanner.IsActionable(ctx, session.ID, lastInteraction.ID, activeTools, messageHistory, options...)
if err != nil {
log.Error().Err(err).Msg("failed to evaluate if the message is actionable, skipping to general knowledge")
return session, nil
Expand All @@ -637,7 +635,7 @@ func (c *Controller) checkForActions(session *types.Session) (*types.Session, er
Str("api", isActionable.Api).
Str("actionable", isActionable.NeedsTool).
Str("justification", isActionable.Justification).
Str("message", userInteraction.Message).
Str("history", fmt.Sprintf("%+v", messageHistory)).
Msg("checked for actionable")

if !isActionable.Actionable() {
Expand Down
19 changes: 2 additions & 17 deletions api/pkg/controller/tool_actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ func (c *Controller) runActionInteraction(ctx context.Context, session *types.Se
return nil, fmt.Errorf("tool ID not found in interaction metadata")
}

systemPrompt := ""

if session.ParentApp != "" {
app, err := c.Options.Store.GetApp(ctx, session.ParentApp)
if err != nil {
Expand All @@ -45,8 +43,6 @@ func (c *Controller) runActionInteraction(ctx context.Context, session *types.Se
return nil, fmt.Errorf("we could not find the assistant with the id: %s", assistantID)
}

systemPrompt = assistant.SystemPrompt

for _, appTool := range assistant.Tools {
if appTool.ID == toolID {
tool = appTool
Expand All @@ -70,25 +66,14 @@ func (c *Controller) runActionInteraction(ctx context.Context, session *types.Se
}
}

userInteraction, err := data.GetLastUserInteraction(session.Interactions)
if err != nil {
return nil, fmt.Errorf("failed to get last user interaction: %w", err)
}

var updated *types.Session

history := data.GetLastInteractions(session, actionContextHistorySize)

// If history has more than 2 interactions, remove the last 2 as it's the current user and assistant interaction
if len(history) > 2 {
history = history[:len(history)-2]
}

messageHistory := types.HistoryFromInteractions(history)

message := fmt.Sprintf("%s %s", systemPrompt, userInteraction.Message)
log.Info().Str("tool", tool.Name).Str("action", action).Str("message", message).Msg("Running tool action")
resp, err := c.ToolsPlanner.RunAction(ctx, session.ID, assistantInteraction.ID, tool, messageHistory, message, action)
log.Info().Str("tool", tool.Name).Str("action", action).Str("history", fmt.Sprintf("%+v", messageHistory)).Msg("Running tool action")
resp, err := c.ToolsPlanner.RunAction(ctx, session.ID, assistantInteraction.ID, tool, messageHistory, action)
if err != nil {
return nil, fmt.Errorf("failed to perform action: %w", err)
}
Expand Down
5 changes: 5 additions & 0 deletions api/pkg/controller/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,11 @@ func sessionToChatCompletion(session *types.Session) (*openai.ChatCompletionRequ
Content: interaction.Message,
})
case types.CreatorTypeSystem:
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleSystem,
Content: interaction.Message,
})
case types.CreatorTypeAssistant:
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: interaction.Message,
Expand Down
37 changes: 25 additions & 12 deletions api/pkg/tools/informative_or_actionable.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ func (i *IsActionableResponse) Actionable() bool {
return i.NeedsTool == "yes"
}

func (c *ChainStrategy) IsActionable(ctx context.Context, sessionID, interactionID string, tools []*types.Tool, history []*types.ToolHistoryMessage, currentMessage string, options ...Option) (*IsActionableResponse, error) {
func (c *ChainStrategy) IsActionable(ctx context.Context, sessionID, interactionID string, tools []*types.Tool, history []*types.ToolHistoryMessage, options ...Option) (*IsActionableResponse, error) {
return retry.DoWithData(
func() (*IsActionableResponse, error) {
return c.isActionable(ctx, sessionID, interactionID, tools, history, currentMessage, options...)
return c.isActionable(ctx, sessionID, interactionID, tools, history, options...)
},
retry.Attempts(apiActionRetries),
retry.Delay(delayBetweenApiRetries),
Expand All @@ -37,7 +37,7 @@ func (c *ChainStrategy) IsActionable(ctx context.Context, sessionID, interaction
log.Warn().
Err(err).
Str("session_id", sessionID).
Str("user_input", currentMessage).
Str("history", fmt.Sprintf("%+v", history)).
Uint("retry_number", n).
Msg("retrying isActionable")
}),
Expand All @@ -50,7 +50,7 @@ func (c *ChainStrategy) getDefaultOptions() Options {
}
}

func (c *ChainStrategy) isActionable(ctx context.Context, sessionID, interactionID string, tools []*types.Tool, history []*types.ToolHistoryMessage, currentMessage string, options ...Option) (*IsActionableResponse, error) {
func (c *ChainStrategy) isActionable(ctx context.Context, sessionID, interactionID string, tools []*types.Tool, history []*types.ToolHistoryMessage, options ...Option) (*IsActionableResponse, error) {
opts := c.getDefaultOptions()

for _, opt := range options {
Expand Down Expand Up @@ -82,9 +82,26 @@ func (c *ChainStrategy) isActionable(ctx context.Context, sessionID, interaction
return nil, fmt.Errorf("failed to prepare system prompt: %w", err)
}

var messages []openai.ChatCompletionMessage
messages := []openai.ChatCompletionMessage{systemPrompt}

messages = append(messages, systemPrompt)
// Log history and current message in a readable way
log.Info().
Str("session_id", sessionID).
Str("interaction_id", interactionID).
Msg("Processing isActionable request")

if len(history) > 0 {
log.Info().Msg("Message history:")
for i, msg := range history {
log.Info().
Int("message_number", i+1).
Str("role", string(msg.Role)).
Str("content", msg.Content).
Msg("Historical message")
}
} else {
log.Info().Msg("No message history")
}

for _, msg := range history {
messages = append(messages, openai.ChatCompletionMessage{
Expand All @@ -95,10 +112,6 @@ func (c *ChainStrategy) isActionable(ctx context.Context, sessionID, interaction

// Adding current message
messages = append(messages,
openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: fmt.Sprintf("<user_message>\n\n%s\n\n</user_message>", currentMessage),
},
openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: "Return the corresponding json for the last user input",
Expand Down Expand Up @@ -140,7 +153,7 @@ func (c *ChainStrategy) isActionable(ctx context.Context, sessionID, interaction
}

log.Info().
Str("user_input", currentMessage).
Str("history", fmt.Sprintf("%+v", history)).
Str("justification", actionableResponse.Justification).
Str("needs_tool", actionableResponse.NeedsTool).
Dur("time_taken", time.Since(started)).
Expand Down Expand Up @@ -279,7 +292,7 @@ Examples:
}
` + "```" + `
**Response Format:** Always respond with JSON without any commentary, wrapped in markdown json tags, for example:
**Response Format:** Always respond with JSON without any commentary, wrapped in markdown json tags (` + "```" + `json at the start and ` + "```" + `at the end), for example:
` + "```" + `json
{
Expand Down
34 changes: 22 additions & 12 deletions api/pkg/tools/informative_or_actionable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/helixml/helix/api/pkg/types"

"github.com/kelseyhightower/envconfig"
oai "github.com/lukemarsden/go-openai2"
openai_ext "github.com/lukemarsden/go-openai2"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
Expand Down Expand Up @@ -91,11 +92,14 @@ func (suite *ActionTestSuite) TestIsActionable_Yes() {
},
}

history := []*types.ToolHistoryMessage{}

currentMessage := "What is the weather like in San Francisco?"
history := []*types.ToolHistoryMessage{
{
Role: oai.ChatMessageRoleUser,
Content: "What is the weather like in San Francisco?",
},
}

resp, err := suite.strategy.IsActionable(suite.ctx, "session-123", "i-123", tools, history, currentMessage)
resp, err := suite.strategy.IsActionable(suite.ctx, "session-123", "i-123", tools, history)
suite.Require().NoError(err)

suite.strategy.wg.Wait()
Expand Down Expand Up @@ -161,11 +165,14 @@ func (suite *ActionTestSuite) TestIsActionable_Retryable() {
},
}

history := []*types.ToolHistoryMessage{}

currentMessage := "What is the weather like in San Francisco?"
history := []*types.ToolHistoryMessage{
{
Role: oai.ChatMessageRoleUser,
Content: "What is the weather like in San Francisco?",
},
}

resp, err := suite.strategy.IsActionable(suite.ctx, "session-123", "i-123", tools, history, currentMessage)
resp, err := suite.strategy.IsActionable(suite.ctx, "session-123", "i-123", tools, history)
suite.Require().NoError(err)

suite.strategy.wg.Wait()
Expand Down Expand Up @@ -206,11 +213,14 @@ func (suite *ActionTestSuite) TestIsActionable_NotActionable() {
},
}

history := []*types.ToolHistoryMessage{}

currentMessage := "What's the reason why oceans have less fish??"
history := []*types.ToolHistoryMessage{
{
Role: oai.ChatMessageRoleUser,
Content: "What's the reason why oceans have less fish??",
},
}

resp, err := suite.strategy.IsActionable(suite.ctx, "session-123", "i-123", tools, history, currentMessage)
resp, err := suite.strategy.IsActionable(suite.ctx, "session-123", "i-123", tools, history)
suite.NoError(err)

suite.strategy.wg.Wait()
Expand Down
6 changes: 3 additions & 3 deletions api/pkg/tools/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import (

// TODO: probably move planner into a separate package so we can decide when we want to call APIs, when to go with RAG, etc.
type Planner interface {
IsActionable(ctx context.Context, sessionID, interactionID string, tools []*types.Tool, history []*types.ToolHistoryMessage, currentMessage string, options ...Option) (*IsActionableResponse, error)
IsActionable(ctx context.Context, sessionID, interactionID string, tools []*types.Tool, history []*types.ToolHistoryMessage, options ...Option) (*IsActionableResponse, error)
// TODO: RAG lookup
RunAction(ctx context.Context, sessionID, interactionID string, tool *types.Tool, history []*types.ToolHistoryMessage, currentMessage, action string) (*RunActionResponse, error)
RunActionStream(ctx context.Context, sessionID, interactionID string, tool *types.Tool, history []*types.ToolHistoryMessage, currentMessage, action string) (*oai.ChatCompletionStream, error)
RunAction(ctx context.Context, sessionID, interactionID string, tool *types.Tool, history []*types.ToolHistoryMessage, action string) (*RunActionResponse, error)
RunActionStream(ctx context.Context, sessionID, interactionID string, tool *types.Tool, history []*types.ToolHistoryMessage, action string) (*oai.ChatCompletionStream, error)
// Validation and defaulting
ValidateAndDefault(ctx context.Context, tool *types.Tool) (*types.Tool, error)
}
Expand Down
11 changes: 8 additions & 3 deletions api/pkg/tools/tools_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ func (c *ChainStrategy) prepareRequest(ctx context.Context, tool *types.Tool, ac
return req, nil
}

func (c *ChainStrategy) getAPIRequestParameters(ctx context.Context, sessionID, interactionID string, tool *types.Tool, history []*types.ToolHistoryMessage, currentMessage, action string) (map[string]string, error) {
func (c *ChainStrategy) getAPIRequestParameters(ctx context.Context, sessionID, interactionID string, tool *types.Tool, history []*types.ToolHistoryMessage, action string) (map[string]string, error) {
systemPrompt, err := c.getApiSystemPrompt(tool)
if err != nil {
return nil, fmt.Errorf("failed to prepare system prompt: %w", err)
}

userPrompt, err := c.getApiUserPrompt(tool, history, currentMessage, action)
userPrompt, err := c.getApiUserPrompt(tool, history, action)
if err != nil {
return nil, fmt.Errorf("failed to prepare user prompt: %w", err)
}
Expand Down Expand Up @@ -178,7 +178,7 @@ func (c *ChainStrategy) getApiSystemPrompt(_ *types.Tool) (openai.ChatCompletion
}, nil
}

func (c *ChainStrategy) getApiUserPrompt(tool *types.Tool, history []*types.ToolHistoryMessage, currentMessage, action string) (openai.ChatCompletionMessage, error) {
func (c *ChainStrategy) getApiUserPrompt(tool *types.Tool, history []*types.ToolHistoryMessage, action string) (openai.ChatCompletionMessage, error) {
// Render template
apiUserPromptTemplate := apiUserPrompt

Expand All @@ -196,6 +196,11 @@ func (c *ChainStrategy) getApiUserPrompt(tool *types.Tool, history []*types.Tool
return openai.ChatCompletionMessage{}, err
}

// for preparing the API request, we ONLY use the last message for now (but
// we might want to revisit this, because it could make sense to fill in api
// params from previous messages)
currentMessage := history[len(history)-1].Content

// Render template
var sb strings.Builder
err = tmpl.Execute(&sb, struct {
Expand Down
Loading

0 comments on commit 1415e21

Please sign in to comment.