Skip to content

Commit

Permalink
Gemini Function Calling
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianliechti committed Oct 6, 2024
1 parent bb713a3 commit 31341ba
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 29 deletions.
182 changes: 154 additions & 28 deletions pkg/provider/google/completer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"unicode"

"github.com/adrianliechti/llama/pkg/provider"
"github.com/adrianliechti/llama/pkg/to"

"github.com/google/uuid"
)

Expand Down Expand Up @@ -80,16 +82,15 @@ func (c *Completer) Complete(ctx context.Context, messages []provider.Message, o

candidate := response.Candidates[0]

content := candidate.Content.Parts[0].Text
content = strings.TrimRight(content, "\n ")

return &provider.Completion{
ID: uuid.New().String(),
Reason: toCompletionResult(candidate.FinishReason),

Message: provider.Message{
Role: provider.MessageRoleAssistant,
Content: content,
Content: toContent(candidate.Content),

ToolCalls: toToolCalls(candidate.Content),
},
}, nil
} else {
Expand Down Expand Up @@ -129,6 +130,8 @@ func (c *Completer) Complete(ctx context.Context, messages []provider.Message, o
//Usage: &provider.Usage{},
}

resultToolCalls := map[string]provider.ToolCall{}

for i := 0; ; i++ {
data, err := reader.ReadBytes('\n')

Expand All @@ -142,9 +145,7 @@ func (c *Completer) Complete(ctx context.Context, messages []provider.Message, o

data = bytes.TrimSpace(data)

println(string(data))

if bytes.HasPrefix(data, []byte("event:")) {
if !bytes.HasPrefix(data, []byte("data:")) {
continue
}

Expand All @@ -163,27 +164,35 @@ func (c *Completer) Complete(ctx context.Context, messages []provider.Message, o

candidate := event.Candidates[0]

content := candidate.Content.Parts[0].Text
content := toContent(candidate.Content)

if i == 0 {
content = strings.TrimLeftFunc(content, unicode.IsSpace)
}

result.Message.Content += content

options.Stream <- provider.Completion{
ID: result.ID,
//Reason: result.Reason,
if len(content) > 0 {
options.Stream <- provider.Completion{
ID: result.ID,

Message: provider.Message{
Role: result.Message.Role,
Message: provider.Message{
Role: provider.MessageRoleAssistant,
Content: content,
},
}
}

Content: content,
},
for _, c := range toToolCalls(candidate.Content) {
resultToolCalls[c.Name] = c
}
}

result.Message.Content = strings.TrimRight(result.Message.Content, "\n ")
result.Message.Content = strings.TrimRightFunc(result.Message.Content, unicode.IsSpace)

if len(resultToolCalls) > 0 {
result.Message.ToolCalls = to.Values(resultToolCalls)
}

return result, nil
}
Expand All @@ -200,36 +209,85 @@ func convertGenerateRequest(messages []provider.Message, options *provider.Compl
switch m.Role {

case provider.MessageRoleUser:
content := Content{
Role: ContentRoleUser,
}

content.Parts = []ContentPart{
parts := []ContentPart{
{
Text: m.Content,
},
}

req.Contents = append(req.Contents, content)
req.Contents = append(req.Contents, Content{
Role: ContentRoleUser,
Parts: parts,
})

case provider.MessageRoleAssistant:
content := Content{
Role: ContentRoleUser,
var parts []ContentPart

if m.Content != "" {
parts = append(parts, ContentPart{
Text: m.Content,
})
}

for _, c := range m.ToolCalls {
parts = append(parts, ContentPart{
FunctionCall: &FunctionCall{
Name: c.Name,
Args: json.RawMessage([]byte(c.Arguments)),
},
})
}

content.Parts = []ContentPart{
req.Contents = append(req.Contents, Content{
Role: ContentRoleModel,
Parts: parts,
})

case provider.MessageRoleTool:
parts := []ContentPart{
{
Text: m.Content,
FunctionResponse: &FunctionResponse{
Name: m.Tool,

Response: Response{
Name: m.Tool,
Content: json.RawMessage([]byte(m.Content)),
},
},
},
}

req.Contents = append(req.Contents, content)
req.Contents = append(req.Contents, Content{
Role: ContentRoleUser,
Parts: parts,
})

default:
return nil, errors.New("unsupported message role")
}
}

var functions []FunctionDeclaration

for _, t := range options.Tools {
function := FunctionDeclaration{
Name: t.Name,
Description: t.Description,

Parameters: t.Parameters,
}

functions = append(functions, function)
}

if len(functions) > 0 {
req.Tools = []Tool{
{
FunctionDeclarations: functions,
},
}
}

return req, nil
}

Expand All @@ -243,6 +301,8 @@ var (
// https://ai.google.dev/gemini-api/docs/text-generation?lang=rest#chat
type GenerateRequest struct {
Contents []Content `json:"contents"`

Tools []Tool `json:"tools,omitempty"`
}

type Content struct {
Expand All @@ -252,7 +312,38 @@ type Content struct {
}

type ContentPart struct {
Text string `json:"text"`
Text string `json:"text,omitempty"`

FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
}

type FunctionCall struct {
Name string `json:"name"`
Args json.RawMessage `json:"args"`
}

type FunctionResponse struct {
Name string `json:"name"`

Response Response `json:"response,omitempty"`
}

type Response struct {
Name string `json:"name"`

Content json.RawMessage `json:"content"`
}

type Tool struct {
FunctionDeclarations []FunctionDeclaration `json:"function_declarations,omitempty"`
}

type FunctionDeclaration struct {
Name string `json:"name"`
Description string `json:"description"`

Parameters map[string]any `json:"parameters"`
}

type GenerateResponse struct {
Expand All @@ -279,6 +370,41 @@ type UsageMetadata struct {
TotalTokenCount int `json:"totalTokenCount"`
}

func toContent(content Content) string {
for _, p := range content.Parts {
if p.Text == "" {
continue
}

return p.Text
}

return ""
}

func toToolCalls(content Content) []provider.ToolCall {
var result []provider.ToolCall

for _, p := range content.Parts {
if p.FunctionCall == nil {
continue
}

arguments, _ := p.FunctionCall.Args.MarshalJSON()

call := provider.ToolCall{
ID: uuid.NewString(),

Name: p.FunctionCall.Name,
Arguments: string(arguments),
}

result = append(result, call)
}

return result
}

func toCompletionResult(val FinishReason) provider.CompletionReason {
switch val {
case FinishReasonStop:
Expand Down
2 changes: 1 addition & 1 deletion pkg/tool/translate/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (*Tool) Parameters() map[string]any {
},
},

"required": []string{"query", "lang"},
"required": []string{"text", "lang"},
}
}

Expand Down

0 comments on commit 31341ba

Please sign in to comment.