Skip to content

Commit

Permalink
ollama tool streaming hack
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianliechti committed Oct 8, 2024
1 parent 45217e9 commit 2c0e018
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 30 deletions.
4 changes: 0 additions & 4 deletions config/config_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,6 @@ func agentChain(cfg chainConfig, context chainContext) (chain.Provider, error) {
options = append(options, agent.WithMessages(context.Messages...))
}

if cfg.Temperature != nil {
options = append(options, agent.WithTemperature(*cfg.Temperature))
}

return agent.New(options...)
}

Expand Down
20 changes: 6 additions & 14 deletions pkg/chain/agent/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ type Chain struct {

tools []tool.Tool
messages []provider.Message

temperature *float32
}

type Option func(*Chain)
Expand Down Expand Up @@ -58,12 +56,6 @@ func WithTools(tool ...tool.Tool) Option {
}
}

func WithTemperature(temperature float32) Option {
return func(c *Chain) {
c.temperature = &temperature
}
}

func (c *Chain) Complete(ctx context.Context, messages []provider.Message, options *provider.CompleteOptions) (*provider.Completion, error) {
if options == nil {
options = new(provider.CompleteOptions)
Expand Down Expand Up @@ -104,14 +96,14 @@ func (c *Chain) Complete(ctx context.Context, messages []provider.Message, optio

for {
inputOptions := &provider.CompleteOptions{
Temperature: options.Temperature,
Tools: to.Values(inputTools),

Stream: options.Stream,
}

if inputOptions.Temperature == nil {
inputOptions.Temperature = c.temperature
Tools: to.Values(inputTools),

MaxTokens: options.MaxTokens,
Temperature: options.Temperature,

Format: options.Format,
}

completion, err := c.completer.Complete(ctx, input, inputOptions)
Expand Down
1 change: 1 addition & 0 deletions pkg/chain/reasoning/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ func (c *Chain) Complete(ctx context.Context, messages []provider.Message, optio
})

inputOptions := &provider.CompleteOptions{
MaxTokens: options.MaxTokens,
Temperature: options.Temperature,

Format: provider.CompletionFormatJSON,
Expand Down
2 changes: 1 addition & 1 deletion pkg/provider/custom/completer.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func NewCompleter(url string, options ...Option) (*Completer, error) {

func (c *Completer) Complete(ctx context.Context, messages []provider.Message, options *provider.CompleteOptions) (*provider.Completion, error) {
if options == nil {
options = &provider.CompleteOptions{}
options = new(provider.CompleteOptions)
}

req := &CompletionRequest{
Expand Down
64 changes: 60 additions & 4 deletions pkg/provider/ollama/completer.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
package ollama

import (
"context"
"strings"

"github.com/adrianliechti/llama/pkg/provider"
"github.com/adrianliechti/llama/pkg/provider/openai"
)

type Completer = openai.Completer
var _ provider.Completer = (*Completer)(nil)

type Completer struct {
completer *openai.Completer
}

func NewCompleter(url, model string, options ...Option) (*Completer, error) {
if url == "" {
Expand All @@ -16,11 +22,61 @@ func NewCompleter(url, model string, options ...Option) (*Completer, error) {
url = strings.TrimRight(url, "/")
url = strings.TrimSuffix(url, "/v1")

c := &Config{}
cfg := &Config{}

for _, option := range options {
option(c)
option(cfg)
}

opts := []openai.Option{}

if cfg.client != nil {
opts = append(opts, openai.WithClient(cfg.client))
}

completer, err := openai.NewCompleter(url+"/v1", model, opts...)

if err != nil {
return nil, err
}

return &Completer{
completer: completer,
}, nil
}

func (c *Completer) Complete(ctx context.Context, messages []provider.Message, options *provider.CompleteOptions) (*provider.Completion, error) {
if options == nil {
options = new(provider.CompleteOptions)
}

inputOptions := &provider.CompleteOptions{
Stream: options.Stream,

Stop: options.Stop,
Tools: options.Tools,

MaxTokens: options.MaxTokens,
Temperature: options.Temperature,

Format: options.Format,
}

if len(options.Tools) > 0 {
inputOptions.Stream = nil
}

result, err := c.completer.Complete(ctx, messages, inputOptions)

if err != nil {
return nil, err
}

if inputOptions.Stream == nil && options.Stream != nil {
if err = options.Stream(ctx, *result); err != nil {
return nil, err
}
}

return openai.NewCompleter(url+"/v1", model, c.options...)
return result, nil
}
6 changes: 2 additions & 4 deletions pkg/provider/ollama/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@ package ollama

import (
"net/http"

"github.com/adrianliechti/llama/pkg/provider/openai"
)

type Config struct {
options []openai.Option
client *http.Client
}

type Option func(*Config)

func WithClient(client *http.Client) Option {
return func(c *Config) {
c.options = append(c.options, openai.WithClient(client))
c.client = client
}
}
12 changes: 9 additions & 3 deletions pkg/provider/ollama/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,17 @@ func NewEmbedder(url, model string, options ...Option) (*Embedder, error) {
url = strings.TrimRight(url, "/")
url = strings.TrimSuffix(url, "/v1")

c := &Config{}
cfg := &Config{}

for _, option := range options {
option(c)
option(cfg)
}

return openai.NewEmbedder(url+"/v1", model, c.options...)
opts := []openai.Option{}

if cfg.client != nil {
opts = append(opts, openai.WithClient(cfg.client))
}

return openai.NewEmbedder(url+"/v1", model, opts...)
}

0 comments on commit 2c0e018

Please sign in to comment.