Skip to content

Commit

Permalink
Add a route to ask chat completion to the RAG
Browse files Browse the repository at this point in the history
  • Loading branch information
nono committed Sep 19, 2024
1 parent 8660cc6 commit 392e491
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 0 deletions.
48 changes: 48 additions & 0 deletions docs/ai.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,51 @@ When a user starts a chat, their prompts are sent to the RAG that can use the
vector database to find relevant documents (technically, only some parts of
the documents called chunks). Those documents are added to the prompt, so
that the LLM can use them as a context when answering.

### POST /ai/chat/completions/:id

This route can be used to ask AI for a chat completion. The id in the path
must be the identifier of a chat session. The client can generate a random
identifier for a new chat session.

The stack responds as soon as it has pushed a job for this task, so its reply
does not contain the AI's answer. To receive the answer, the client must use
the real-time websocket and subscribe to `io.cozy.ai.chat.completions`.

#### Request

```http
POST /ai/chat/completions/e21dce8058b9013d800a18c04daba326 HTTP/1.1
Content-Type: application/json
```

```json
{
"q": "Why the sky is blue?"
}
```

#### Response

```http
HTTP/1.1 202 Accepted
Content-Type: application/vnd.api+json
```

```json
{
"data": {
"type": "io.cozy.ai.chat.completions",
"id": "e21dce8058b9013d800a18c04daba326",
"rev": "1-23456",
"attributes": {
"messages": [
{
"role": "user",
"content": "Why the sky is blue?"
}
]
}
}
}
```
152 changes: 152 additions & 0 deletions model/rag/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package rag

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"

"github.com/cozy/cozy-stack/model/instance"
"github.com/cozy/cozy-stack/model/job"
"github.com/cozy/cozy-stack/pkg/consts"
"github.com/cozy/cozy-stack/pkg/couchdb"
"github.com/cozy/cozy-stack/pkg/jsonapi"
"github.com/cozy/cozy-stack/pkg/logger"
"github.com/labstack/echo/v4"
)

// ChatPayload is the input of Chat: the chat session to append to and the
// prompt of the user.
type ChatPayload struct {
// ChatCompletionID is the identifier of the chat session document. It has
// no JSON tag: the web handler sets it from the :id path parameter.
ChatCompletionID string
// Query is the prompt of the user, read from the "q" field of the request.
Query string `json:"q"`
}

// ChatCompletion is the document (doctype consts.ChatCompletions) that keeps
// the messages exchanged during a chat session.
type ChatCompletion struct {
// DocID is the document identifier, serialized as "_id".
DocID string `json:"_id"`
// DocRev is the document revision, serialized as "_rev". It is empty (and
// omitted from the JSON) for a document that has not been saved yet.
DocRev string `json:"_rev,omitempty"`
// Messages is the list of messages, in the order they were appended.
Messages []ChatMessage `json:"messages"`
}

// ChatMessage is a single message of a chat session.
type ChatMessage struct {
// Role tells who wrote the message. The code in this file sets it to
// HumanRole or AIRole.
Role string `json:"role"`
// Content is the text of the message.
Content string `json:"content"`
}

// Values used for the Role field of a ChatMessage.
const (
// HumanRole is the role of the messages built from the user's query.
HumanRole = "human"
// AIRole is the role of the messages built from the RAG server's answer.
AIRole = "ai"
)

// ID returns the identifier of the chat completion document.
func (c *ChatCompletion) ID() string {
	return c.DocID
}

// Rev returns the revision of the chat completion document.
func (c *ChatCompletion) Rev() string {
	return c.DocRev
}

// DocType returns the doctype of the chat completions.
func (c *ChatCompletion) DocType() string {
	return consts.ChatCompletions
}

// SetID changes the identifier of the chat completion document.
func (c *ChatCompletion) SetID(id string) {
	c.DocID = id
}

// SetRev changes the revision of the chat completion document.
func (c *ChatCompletion) SetRev(rev string) {
	c.DocRev = rev
}

// Clone returns a copy of the chat completion that has its own messages
// slice, so that appending to or editing the messages of one does not affect
// the other.
func (c *ChatCompletion) Clone() couchdb.Doc {
	dup := new(ChatCompletion)
	*dup = *c
	// Always allocate, even for an empty list: the copy then holds a non-nil
	// slice of the same length as the original.
	dup.Messages = make([]ChatMessage, len(c.Messages))
	copy(dup.Messages, c.Messages)
	return dup
}

// Included is part of the jsonapi.Object interface. A chat completion has no
// included objects.
func (c *ChatCompletion) Included() []jsonapi.Object {
	return nil
}

// Relationships is part of the jsonapi.Object interface. A chat completion
// has no relationships.
func (c *ChatCompletion) Relationships() jsonapi.RelationshipMap {
	return nil
}

// Links is part of the jsonapi.Object interface. A chat completion has no
// links.
func (c *ChatCompletion) Links() *jsonapi.LinksList {
	return nil
}

// Compile-time check that *ChatCompletion implements jsonapi.Object.
var _ jsonapi.Object = (*ChatCompletion)(nil)

// QueryMessage is the message of the "rag-query" jobs: Chat builds it when
// pushing the job and Query receives it.
type QueryMessage struct {
// Task is the kind of work to do. Chat sets it to "chat-completion".
Task string `json:"task"`
// DocID is the identifier of the chat completion document to answer.
DocID string `json:"doc_id"`
}

// Chat appends the query of the user to the chat session identified by
// payload.ChatCompletionID, saves the document (creating it when the session
// does not exist yet), and pushes a "rag-query" job for it.
//
// It returns the saved chat completion. The answer is not part of it: it is
// added to the document later, when the job runs Query.
func Chat(inst *instance.Instance, payload ChatPayload) (*ChatCompletion, error) {
	chat := &ChatCompletion{}
	err := couchdb.GetDoc(inst, consts.ChatCompletions, payload.ChatCompletionID, chat)
	switch {
	case couchdb.IsNotFoundError(err):
		// New session: start from an empty document with the requested ID.
		chat.DocID = payload.ChatCompletionID
	case err != nil:
		return nil, err
	}

	chat.Messages = append(chat.Messages, ChatMessage{
		Role:    HumanRole,
		Content: payload.Query,
	})

	// A document with a revision already exists in CouchDB and is updated;
	// without a revision, it is created under the chosen identifier.
	if chat.DocRev != "" {
		err = couchdb.UpdateDoc(inst, chat)
	} else {
		err = couchdb.CreateNamedDocWithDB(inst, chat)
	}
	if err != nil {
		return nil, err
	}

	msg, err := job.NewMessage(&QueryMessage{
		Task:  "chat-completion",
		DocID: chat.DocID,
	})
	if err != nil {
		return nil, err
	}
	request := &job.JobRequest{
		WorkerType: "rag-query",
		Message:    msg,
	}
	if _, err = job.System().PushJob(inst, request); err != nil {
		return nil, err
	}
	return chat, nil
}

// Query asks the RAG server for a completion of the chat session identified
// by query.DocID. It loads the chat completion document, sends the content of
// its last message as the "q" parameter, and appends the raw response body to
// the document as a message with the AI role before saving it.
//
// It returns an error when the document cannot be loaded or has no messages,
// when the RAG server cannot be called or does not answer with a 200 status,
// when the response body cannot be read, or when the document cannot be
// updated. The logger parameter is currently unused.
func Query(inst *instance.Instance, logger logger.Logger, query QueryMessage) error {
	var chat ChatCompletion
	err := couchdb.GetDoc(inst, consts.ChatCompletions, query.DocID, &chat)
	if err != nil {
		return err
	}

	// Taking the last element of an empty slice would panic with an index
	// out of range, so a document without messages is reported as an error.
	if len(chat.Messages) == 0 {
		return fmt.Errorf("chat completion %s has no messages", query.DocID)
	}
	msg := chat.Messages[len(chat.Messages)-1]
	payload := map[string]interface{}{
		"q": msg.Content,
	}

	res, err := callRAGQuery(inst, payload)
	if err != nil {
		return err
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		return fmt.Errorf("POST status code: %d", res.StatusCode)
	}

	// TODO streaming
	completion, err := io.ReadAll(res.Body)
	if err != nil {
		return err
	}
	answer := ChatMessage{
		Role:    AIRole,
		Content: string(completion),
	}
	chat.Messages = append(chat.Messages, answer)
	return couchdb.UpdateDoc(inst, &chat)
}

// callRAGQuery sends the payload, encoded as JSON, in a POST request to the
// /query/<domain> endpoint of the RAG server configured for the instance.
//
// It returns the HTTP response as is: the caller must check the status code
// and close the body. An error is returned when no RAG server is configured,
// when its URL is invalid, when the payload cannot be encoded, or when the
// request fails.
func callRAGQuery(inst *instance.Instance, payload map[string]interface{}) (*http.Response, error) {
	server := inst.RAGServer()
	if server.URL == "" {
		return nil, errors.New("no RAG server configured")
	}

	endpoint, err := url.Parse(server.URL)
	if err != nil {
		return nil, err
	}
	// The path of the configured URL is replaced, not extended.
	endpoint.Path = fmt.Sprintf("/query/%s", inst.Domain)

	encoded, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequest(http.MethodPost, endpoint.String(), bytes.NewReader(encoded))
	if err != nil {
		return nil, err
	}
	req.Header.Add("Content-Type", echo.MIMEApplicationJSON)

	// NOTE(review): http.DefaultClient has no timeout, so this call can block
	// for as long as the RAG server keeps the connection open.
	return http.DefaultClient.Do(req)
}
2 changes: 2 additions & 0 deletions pkg/consts/doctype.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,6 @@ const (
// NextCloudFiles doc type is used when listing files from a NextCloud via
// WebDAV.
NextCloudFiles = "io.cozy.remote.nextcloud.files"
// ChatCompletions doc type is used for a chat between the user and a chatbot.
ChatCompletions = "io.cozy.ai.chat.completions"
)
35 changes: 35 additions & 0 deletions web/ai/ai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package ai

import (
"net/http"

"github.com/cozy/cozy-stack/model/permission"
"github.com/cozy/cozy-stack/model/rag"
"github.com/cozy/cozy-stack/pkg/consts"
"github.com/cozy/cozy-stack/pkg/jsonapi"
"github.com/cozy/cozy-stack/web/middlewares"
"github.com/labstack/echo/v4"
)

// Chat is the route for asking a chat completion to AI.
//
// It requires the POST permission on the whole chat completions doctype,
// binds the request body to a rag.ChatPayload, takes the chat session
// identifier from the :id path parameter, and responds with 202 Accepted and
// the chat completion returned by rag.Chat.
func Chat(c echo.Context) error {
	err := middlewares.AllowWholeType(c, permission.POST, consts.ChatCompletions)
	if err != nil {
		return middlewares.ErrForbidden
	}

	payload := rag.ChatPayload{}
	if err = c.Bind(&payload); err != nil {
		return err
	}
	// The identifier always comes from the path, whatever the body contains.
	payload.ChatCompletionID = c.Param("id")

	completion, err := rag.Chat(middlewares.GetInstance(c), payload)
	if err != nil {
		return jsonapi.InternalServerError(err)
	}
	return jsonapi.Data(c, http.StatusAccepted, completion, nil)
}

// Routes sets the routing for the AI tasks: it registers the Chat handler for
// POST /chat/completions/:id on the given group.
func Routes(router *echo.Group) {
router.POST("/chat/completions/:id", Chat)
}
2 changes: 2 additions & 0 deletions web/routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/cozy/cozy-stack/pkg/jsonapi"
"github.com/cozy/cozy-stack/pkg/metrics"
"github.com/cozy/cozy-stack/web/accounts"
"github.com/cozy/cozy-stack/web/ai"
"github.com/cozy/cozy-stack/web/apps"
"github.com/cozy/cozy-stack/web/auth"
"github.com/cozy/cozy-stack/web/bitwarden"
Expand Down Expand Up @@ -235,6 +236,7 @@ func SetupRoutes(router *echo.Echo, services *stack.Services) error {
sharings.Routes(router.Group("/sharings", mws...))
bitwarden.Routes(router.Group("/bitwarden", mws...))
shortcuts.Routes(router.Group("/shortcuts", mws...))
ai.Routes(router.Group("/ai", mws...))

// The settings routes needs not to be blocked
apps.WebappsRoutes(router.Group("/apps", mwsNotBlocked...))
Expand Down
19 changes: 19 additions & 0 deletions worker/rag/rag.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ func init() {
Timeout: 15 * time.Minute,
WorkerFunc: WorkerIndex,
})

job.AddWorker(&job.WorkerConfig{
WorkerType: "rag-query",
Concurrency: runtime.NumCPU(),
MaxExecCount: 1,
Reserved: true,
Timeout: 15 * time.Minute,
WorkerFunc: WorkerQuery,
})
}

func WorkerIndex(ctx *job.TaskContext) error {
Expand All @@ -28,3 +37,13 @@ func WorkerIndex(ctx *job.TaskContext) error {
logger.Debugf("RAG: index %s", msg.Doctype)
return rag.Index(ctx.Instance, logger, msg)
}

// WorkerQuery is the function of the "rag-query" worker. It decodes the job
// message as a rag.QueryMessage, logs it at debug level, and calls rag.Query
// for the instance of the job. It returns the decoding error or the error of
// rag.Query, if any.
func WorkerQuery(ctx *job.TaskContext) error {
	log := ctx.Logger()

	var query rag.QueryMessage
	err := ctx.UnmarshalMessage(&query)
	if err != nil {
		return err
	}

	log.Debugf("RAG: query %v", query)
	return rag.Query(ctx.Instance, log, query)
}

0 comments on commit 392e491

Please sign in to comment.