Skip to content

Commit

Permalink
add rag chain (again)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianliechti committed Sep 30, 2024
1 parent bfbe808 commit 1510451
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 2 deletions.
3 changes: 3 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ RUN CGO_ENABLED=0 go build -o server
WORKDIR /src/cmd/client
RUN CGO_ENABLED=0 go build -o client

WORKDIR /src/cmd/ingest
RUN CGO_ENABLED=0 go build -o ingest

FROM alpine

Expand All @@ -23,6 +25,7 @@ RUN apk add --no-cache tini ca-certificates mailcap
WORKDIR /
COPY --from=build /src/cmd/server/server .
COPY --from=build /src/cmd/client/client .
COPY --from=build /src/cmd/ingest/ingest .

EXPOSE 8080

Expand Down
45 changes: 43 additions & 2 deletions config/config_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ import (
"github.com/adrianliechti/llama/pkg/limiter"
"github.com/adrianliechti/llama/pkg/otel"
"github.com/adrianliechti/llama/pkg/provider"
"github.com/adrianliechti/llama/pkg/template"
"golang.org/x/time/rate"

"github.com/adrianliechti/llama/pkg/chain"
"github.com/adrianliechti/llama/pkg/chain/agent"
"github.com/adrianliechti/llama/pkg/chain/assistant"
"github.com/adrianliechti/llama/pkg/chain/rag"
"github.com/adrianliechti/llama/pkg/chain/reasoning"

"github.com/adrianliechti/llama/pkg/to"
Expand All @@ -36,9 +38,11 @@ type chainConfig struct {

Model string `yaml:"model"`

Tools []string `yaml:"tools"`
Template string `yaml:"template"`
Messages []message `yaml:"messages"`

Tools []string `yaml:"tools"`

Limit *int `yaml:"limit"`
Temperature *float32 `yaml:"temperature"`
}
Expand All @@ -49,9 +53,11 @@ type chainContext struct {
Embedder provider.Embedder
Completer provider.Completer

Tools map[string]tool.Tool
Template *template.Template
Messages []provider.Message

Tools map[string]tool.Tool

Limiter *rate.Limiter
}

Expand Down Expand Up @@ -100,6 +106,12 @@ func (cfg *Config) registerChains(f *configFile) error {
context.Tools[t] = tool
}

if c.Template != "" {
if context.Template, err = parseTemplate(c.Template); err != nil {
return err
}
}

if c.Messages != nil {
if context.Messages, err = parseMessages(c.Messages); err != nil {
return err
Expand Down Expand Up @@ -134,6 +146,9 @@ func createChain(cfg chainConfig, context chainContext) (chain.Provider, error)
case "assistant":
return assistantChain(cfg, context)

case "rag":
return ragChain(cfg, context)

case "reasoning":
return reasoningChain(cfg, context)

Expand Down Expand Up @@ -182,6 +197,32 @@ func assistantChain(cfg chainConfig, context chainContext) (chain.Provider, erro
return assistant.New(options...)
}

func ragChain(cfg chainConfig, context chainContext) (chain.Provider, error) {
var options []rag.Option

if context.Completer != nil {
options = append(options, rag.WithCompleter(context.Completer))
}

if context.Template != nil {
options = append(options, rag.WithTemplate(context.Template))
}

if context.Messages != nil {
options = append(options, rag.WithMessages(context.Messages...))
}

if context.Index != nil {
options = append(options, rag.WithIndex(context.Index))
}

if cfg.Temperature != nil {
options = append(options, rag.WithTemperature(*cfg.Temperature))
}

return rag.New(options...)
}

func reasoningChain(cfg chainConfig, context chainContext) (chain.Provider, error) {
var options []reasoning.Option

Expand Down
20 changes: 20 additions & 0 deletions config/config_template.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package config

import (
"errors"
"os"

"github.com/adrianliechti/llama/pkg/template"
)

func parseTemplate(val string) (*template.Template, error) {
if val == "" {
return nil, errors.New("empty template")
}

if data, err := os.ReadFile(val); err == nil {
return template.NewTemplate(string(data))
}

return template.NewTemplate(val)
}
140 changes: 140 additions & 0 deletions pkg/chain/rag/chain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package rag

import (
"context"
"errors"
"strings"

"github.com/adrianliechti/llama/pkg/chain"
"github.com/adrianliechti/llama/pkg/index"
"github.com/adrianliechti/llama/pkg/provider"
"github.com/adrianliechti/llama/pkg/template"
"github.com/adrianliechti/llama/pkg/text"
)

var _ chain.Provider = &Chain{}

type Chain struct {
completer provider.Completer

template *template.Template
messages []provider.Message

index index.Provider

limit *int
temperature *float32
}

type Option func(*Chain)

func New(options ...Option) (*Chain, error) {
c := &Chain{
template: template.MustTemplate(promptTemplate),
}

for _, option := range options {
option(c)
}

if c.completer == nil {
return nil, errors.New("missing completer provider")
}

if c.index == nil {
return nil, errors.New("missing index provider")
}

return c, nil
}

func WithCompleter(completer provider.Completer) Option {
return func(c *Chain) {
c.completer = completer
}
}

func WithTemplate(template *template.Template) Option {
return func(c *Chain) {
c.template = template
}
}

func WithMessages(messages ...provider.Message) Option {
return func(c *Chain) {
c.messages = messages
}
}

func WithIndex(index index.Provider) Option {
return func(c *Chain) {
c.index = index
}
}

func WithTemperature(temperature float32) Option {
return func(c *Chain) {
c.temperature = &temperature
}
}

func (c *Chain) Complete(ctx context.Context, messages []provider.Message, options *provider.CompleteOptions) (*provider.Completion, error) {
if options == nil {
options = new(provider.CompleteOptions)
}

if options.Temperature == nil {
options.Temperature = c.temperature
}

message := messages[len(messages)-1]

if message.Role != provider.MessageRoleUser {
return nil, errors.New("last message must be from user")
}

query := strings.TrimSpace(message.Content)

results, err := c.index.Query(ctx, query, &index.QueryOptions{
Limit: c.limit,
})

if err != nil {
return nil, err
}

data := promptData{
Input: query,
}

for _, r := range results {
data.Results = append(data.Results, promptResult{
Title: r.Title,
Content: text.Normalize(r.Content),
Location: r.Location,

Metadata: r.Metadata,
})
}

prompt, err := c.template.Execute(data)

if err != nil {
return nil, err
}

message = provider.Message{
Role: provider.MessageRoleUser,
Content: prompt,
}

messages[len(messages)-1] = message

result, err := c.completer.Complete(ctx, messages, options)

if err != nil {
return nil, err
}

return result, nil
}
23 changes: 23 additions & 0 deletions pkg/chain/rag/prompt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package rag

import (
_ "embed"
)

var (
//go:embed prompt.tmpl
promptTemplate string
)

type promptData struct {
Input string
Results []promptResult
}

type promptResult struct {
Title string
Content string
Location string

Metadata map[string]string
}
16 changes: 16 additions & 0 deletions pkg/chain/rag/prompt.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{{- if .Results -}}
Use the provided documents to answer questions:
{{ range .Results }}
---
{{- if .Title }}
Title: {{ .Title }}
{{- end }}
{{- if .Location }}
Location: {{ .Location }}
{{- end }}
{{ .Content }}
{{ end }}
---
{{- end -}}

Question: {{ .Input }}

0 comments on commit 1510451

Please sign in to comment.