diff --git a/Dockerfile b/Dockerfile index 5dc1c12..6c12140 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,6 +15,8 @@ RUN CGO_ENABLED=0 go build -o server WORKDIR /src/cmd/client RUN CGO_ENABLED=0 go build -o client +WORKDIR /src/cmd/ingest +RUN CGO_ENABLED=0 go build -o ingest FROM alpine @@ -23,6 +25,7 @@ RUN apk add --no-cache tini ca-certificates mailcap WORKDIR / COPY --from=build /src/cmd/server/server . COPY --from=build /src/cmd/client/client . +COPY --from=build /src/cmd/ingest/ingest . EXPOSE 8080 diff --git a/config/config_chain.go b/config/config_chain.go index 451120c..cca1c5d 100644 --- a/config/config_chain.go +++ b/config/config_chain.go @@ -8,11 +8,13 @@ import ( "github.com/adrianliechti/llama/pkg/limiter" "github.com/adrianliechti/llama/pkg/otel" "github.com/adrianliechti/llama/pkg/provider" + "github.com/adrianliechti/llama/pkg/template" "golang.org/x/time/rate" "github.com/adrianliechti/llama/pkg/chain" "github.com/adrianliechti/llama/pkg/chain/agent" "github.com/adrianliechti/llama/pkg/chain/assistant" + "github.com/adrianliechti/llama/pkg/chain/rag" "github.com/adrianliechti/llama/pkg/chain/reasoning" "github.com/adrianliechti/llama/pkg/to" @@ -36,9 +38,11 @@ type chainConfig struct { Model string `yaml:"model"` - Tools []string `yaml:"tools"` + Template string `yaml:"template"` Messages []message `yaml:"messages"` + Tools []string `yaml:"tools"` + Limit *int `yaml:"limit"` Temperature *float32 `yaml:"temperature"` } @@ -49,9 +53,11 @@ type chainContext struct { Embedder provider.Embedder Completer provider.Completer - Tools map[string]tool.Tool + Template *template.Template Messages []provider.Message + Tools map[string]tool.Tool + Limiter *rate.Limiter } @@ -100,6 +106,12 @@ func (cfg *Config) registerChains(f *configFile) error { context.Tools[t] = tool } + if c.Template != "" { + if context.Template, err = parseTemplate(c.Template); err != nil { + return err + } + } + if c.Messages != nil { if context.Messages, err = parseMessages(c.Messages); err != nil { return err @@ -134,6 +146,9 @@ func createChain(cfg chainConfig, context chainContext) (chain.Provider, error) case "assistant": return assistantChain(cfg, context) + case "rag": + return ragChain(cfg, context) + case "reasoning": return reasoningChain(cfg, context) @@ -182,6 +197,32 @@ func assistantChain(cfg chainConfig, context chainContext) (chain.Provider, erro return assistant.New(options...) } +func ragChain(cfg chainConfig, context chainContext) (chain.Provider, error) { + var options []rag.Option + + if context.Completer != nil { + options = append(options, rag.WithCompleter(context.Completer)) + } + + if context.Template != nil { + options = append(options, rag.WithTemplate(context.Template)) + } + + if context.Messages != nil { + options = append(options, rag.WithMessages(context.Messages...)) + } + + if context.Index != nil { + options = append(options, rag.WithIndex(context.Index)) + } + + if cfg.Temperature != nil { + options = append(options, rag.WithTemperature(*cfg.Temperature)) + } + + return rag.New(options...) +} + func reasoningChain(cfg chainConfig, context chainContext) (chain.Provider, error) { var options []reasoning.Option diff --git a/config/config_template.go b/config/config_template.go new file mode 100644 index 0000000..33dfd0c --- /dev/null +++ b/config/config_template.go @@ -0,0 +1,20 @@ +package config + +import ( + "errors" + "os" + + "github.com/adrianliechti/llama/pkg/template" +) + +func parseTemplate(val string) (*template.Template, error) { + if val == "" { + return nil, errors.New("empty template") + } + + if data, err := os.ReadFile(val); err == nil { + return template.NewTemplate(string(data)) + } + + return template.NewTemplate(val) +} diff --git a/pkg/chain/rag/chain.go b/pkg/chain/rag/chain.go new file mode 100644 index 0000000..ce93563 --- /dev/null +++ b/pkg/chain/rag/chain.go @@ -0,0 +1,140 @@ +package rag + +import ( + "context" + "errors" + "strings" + + "github.com/adrianliechti/llama/pkg/chain" + "github.com/adrianliechti/llama/pkg/index" + "github.com/adrianliechti/llama/pkg/provider" + "github.com/adrianliechti/llama/pkg/template" + "github.com/adrianliechti/llama/pkg/text" +) + +var _ chain.Provider = &Chain{} + +type Chain struct { + completer provider.Completer + + template *template.Template + messages []provider.Message + + index index.Provider + + limit *int + temperature *float32 +} + +type Option func(*Chain) + +func New(options ...Option) (*Chain, error) { + c := &Chain{ + template: template.MustTemplate(promptTemplate), + } + + for _, option := range options { + option(c) + } + + if c.completer == nil { + return nil, errors.New("missing completer provider") + } + + if c.index == nil { + return nil, errors.New("missing index provider") + } + + return c, nil +} + +func WithCompleter(completer provider.Completer) Option { + return func(c *Chain) { + c.completer = completer + } +} + +func WithTemplate(template *template.Template) Option { + return func(c *Chain) { + c.template = template + } +} + +func WithMessages(messages ...provider.Message) Option { + return func(c *Chain) { + c.messages = messages + } +} + +func WithIndex(index index.Provider) Option { + return func(c *Chain) { + c.index = index + } +} + +func WithTemperature(temperature float32) Option { + return func(c *Chain) { + c.temperature = &temperature + } +} + +func (c *Chain) Complete(ctx context.Context, messages []provider.Message, options *provider.CompleteOptions) (*provider.Completion, error) { + if options == nil { + options = new(provider.CompleteOptions) + } + + if options.Temperature == nil { + options.Temperature = c.temperature + } + + message := messages[len(messages)-1] + + if message.Role != provider.MessageRoleUser { + return nil, errors.New("last message must be from user") + } + + query := strings.TrimSpace(message.Content) + + results, err := c.index.Query(ctx, query, &index.QueryOptions{ + Limit: c.limit, + }) + + if err != nil { + return nil, err + } + + data := promptData{ + Input: query, + } + + for _, r := range results { + data.Results = append(data.Results, promptResult{ + Title: r.Title, + Content: text.Normalize(r.Content), + Location: r.Location, + + Metadata: r.Metadata, + }) + } + + prompt, err := c.template.Execute(data) + + if err != nil { + return nil, err + } + + message = provider.Message{ + Role: provider.MessageRoleUser, + Content: prompt, + } + + messages[len(messages)-1] = message + + result, err := c.completer.Complete(ctx, messages, options) + + if err != nil { + return nil, err + } + + return result, nil +} diff --git a/pkg/chain/rag/prompt.go b/pkg/chain/rag/prompt.go new file mode 100644 index 0000000..1fab895 --- /dev/null +++ b/pkg/chain/rag/prompt.go @@ -0,0 +1,23 @@ +package rag + +import ( + _ "embed" +) + +var ( + //go:embed prompt.tmpl + promptTemplate string +) + +type promptData struct { + Input string + Results []promptResult +} + +type promptResult struct { + Title string + Content string + Location string + + Metadata map[string]string +} diff --git a/pkg/chain/rag/prompt.tmpl b/pkg/chain/rag/prompt.tmpl new file mode 100644 index 0000000..377041f --- /dev/null +++ b/pkg/chain/rag/prompt.tmpl @@ -0,0 +1,16 @@ +{{- if .Results -}} +Use the provided documents to answer questions: +{{ range .Results }} +--- +{{- if .Title }} +Title: {{ .Title }} +{{- end }} +{{- if .Location }} +Location: {{ .Location }} +{{- end }} +{{ .Content }} +{{ end }} +--- +{{- end -}} + +Question: {{ .Input }} \ No newline at end of file