Skip to content

Commit

Permalink
make all embedders rerankers too
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianliechti committed Sep 18, 2024
1 parent 9ed2e2f commit caac226
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
3 changes: 3 additions & 0 deletions config/config_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package config
import (
"errors"

"github.com/adrianliechti/llama/pkg/provider"

"golang.org/x/time/rate"
"gopkg.in/yaml.v3"
)
Expand Down Expand Up @@ -55,6 +57,7 @@ func (cfg *Config) registerProviders(f *configFile) error {
}

cfg.RegisterEmbedder(p.Type, id, embedder)
cfg.RegisterReranker(p.Type, id, provider.FromEmbedder(embedder))

case ModelTypeReranker:
reranker, err := createReranker(p, context)
Expand Down
82 changes: 82 additions & 0 deletions pkg/provider/reranker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package provider

import (
"context"
"math"
"sort"
)

type Reranker interface {
Expand All @@ -11,3 +13,83 @@ type Reranker interface {
type RerankOptions struct {
Limit *int
}

type embedderWrapper struct {
embedder Embedder
}

func FromEmbedder(embedder Embedder) Reranker {
return embedderWrapper{
embedder: embedder,
}
}

func (e embedderWrapper) Rerank(ctx context.Context, query string, inputs []string, options *RerankOptions) ([]Result, error) {
result, err := e.embedder.Embed(ctx, query)

if err != nil {
return nil, err
}

var results []Result

for _, input := range inputs {
embedding, err := e.embedder.Embed(ctx, input)

if err != nil {
return nil, err
}

score := cosineSimilarity(result.Data, embedding.Data)

result := Result{
Content: input,
Score: float64(score),
}

results = append(results, result)
}

sort.Slice(results, func(i, j int) bool {
return results[i].Score > results[j].Score
})

if options.Limit != nil {
limit := *options.Limit

if limit > len(results) {
limit = len(results)
}

results = results[:limit]
}

return results, nil
}

func cosineSimilarity(a []float32, b []float32) float32 {
if len(a) != len(b) {
return 0.0
}

dotproduct := 0.0

magnitudeA := 0.0
magnitudeB := 0.0

for k := 0; k < len(a); k++ {
valA := float64(a[k])
valB := float64(b[k])

dotproduct += valA * valB

magnitudeA += math.Pow(valA, 2)
magnitudeB += math.Pow(valB, 2)
}

if magnitudeA == 0 || magnitudeB == 0 {
return 0.0
}

return float32(dotproduct / (math.Sqrt(magnitudeA) * math.Sqrt(magnitudeB)))
}

0 comments on commit caac226

Please sign in to comment.