-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a route to ask chat completion to the RAG
- Loading branch information
Showing
6 changed files
with
258 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
package rag | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"net/url" | ||
|
||
"github.com/cozy/cozy-stack/model/instance" | ||
"github.com/cozy/cozy-stack/model/job" | ||
"github.com/cozy/cozy-stack/pkg/consts" | ||
"github.com/cozy/cozy-stack/pkg/couchdb" | ||
"github.com/cozy/cozy-stack/pkg/jsonapi" | ||
"github.com/cozy/cozy-stack/pkg/logger" | ||
"github.com/labstack/echo/v4" | ||
) | ||
|
||
type ChatPayload struct { | ||
ChatCompletionID string | ||
Query string `json:"q"` | ||
} | ||
|
||
type ChatCompletion struct { | ||
DocID string `json:"_id"` | ||
DocRev string `json:"_rev,omitempty"` | ||
Messages []ChatMessage `json:"messages"` | ||
} | ||
|
||
type ChatMessage struct { | ||
Role string `json:"role"` | ||
Content string `json:"content"` | ||
} | ||
|
||
const ( | ||
HumanRole = "human" | ||
AIRole = "ai" | ||
) | ||
|
||
func (c *ChatCompletion) ID() string { return c.DocID } | ||
func (c *ChatCompletion) Rev() string { return c.DocRev } | ||
func (c *ChatCompletion) DocType() string { return consts.ChatCompletions } | ||
func (c *ChatCompletion) SetID(id string) { c.DocID = id } | ||
func (c *ChatCompletion) SetRev(rev string) { c.DocRev = rev } | ||
func (c *ChatCompletion) Clone() couchdb.Doc { | ||
cloned := *c | ||
cloned.Messages = make([]ChatMessage, len(c.Messages)) | ||
copy(cloned.Messages, c.Messages) | ||
return &cloned | ||
} | ||
func (c *ChatCompletion) Included() []jsonapi.Object { return nil } | ||
func (c *ChatCompletion) Relationships() jsonapi.RelationshipMap { return nil } | ||
func (c *ChatCompletion) Links() *jsonapi.LinksList { return nil } | ||
|
||
var _ jsonapi.Object = (*ChatCompletion)(nil) | ||
|
||
type QueryMessage struct { | ||
Task string `json:"task"` | ||
DocID string `json:"doc_id"` | ||
} | ||
|
||
func Chat(inst *instance.Instance, payload ChatPayload) (*ChatCompletion, error) { | ||
var chat ChatCompletion | ||
err := couchdb.GetDoc(inst, consts.ChatCompletions, payload.ChatCompletionID, &chat) | ||
if couchdb.IsNotFoundError(err) { | ||
chat.DocID = payload.ChatCompletionID | ||
} else if err != nil { | ||
return nil, err | ||
} | ||
msg := ChatMessage{Role: HumanRole, Content: payload.Query} | ||
chat.Messages = append(chat.Messages, msg) | ||
if chat.DocRev == "" { | ||
err = couchdb.CreateNamedDocWithDB(inst, &chat) | ||
} else { | ||
err = couchdb.UpdateDoc(inst, &chat) | ||
} | ||
if err != nil { | ||
return nil, err | ||
} | ||
query, err := job.NewMessage(&QueryMessage{ | ||
Task: "chat-completion", | ||
DocID: chat.DocID, | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
_, err = job.System().PushJob(inst, &job.JobRequest{ | ||
WorkerType: "rag-query", | ||
Message: query, | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return &chat, nil | ||
} | ||
|
||
func Query(inst *instance.Instance, logger logger.Logger, query QueryMessage) error { | ||
var chat ChatCompletion | ||
err := couchdb.GetDoc(inst, consts.ChatCompletions, query.DocID, &chat) | ||
if err != nil { | ||
return err | ||
} | ||
msg := chat.Messages[len(chat.Messages)-1] | ||
payload := map[string]interface{}{ | ||
"q": msg.Content, | ||
} | ||
|
||
res, err := callRAGQuery(inst, payload) | ||
if err != nil { | ||
return err | ||
} | ||
defer res.Body.Close() | ||
if res.StatusCode != 200 { | ||
return fmt.Errorf("POST status code: %d", res.StatusCode) | ||
} | ||
|
||
// TODO streaming | ||
completion, err := io.ReadAll(res.Body) | ||
if err != nil { | ||
return err | ||
} | ||
answer := ChatMessage{ | ||
Role: AIRole, | ||
Content: string(completion), | ||
} | ||
chat.Messages = append(chat.Messages, answer) | ||
return couchdb.UpdateDoc(inst, &chat) | ||
} | ||
|
||
func callRAGQuery(inst *instance.Instance, payload map[string]interface{}) (*http.Response, error) { | ||
ragServer := inst.RAGServer() | ||
if ragServer.URL == "" { | ||
return nil, errors.New("no RAG server configured") | ||
} | ||
u, err := url.Parse(ragServer.URL) | ||
if err != nil { | ||
return nil, err | ||
} | ||
u.Path = fmt.Sprintf("/query/%s", inst.Domain) | ||
body, err := json.Marshal(payload) | ||
if err != nil { | ||
return nil, err | ||
} | ||
req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
req.Header.Add("Content-Type", echo.MIMEApplicationJSON) | ||
return http.DefaultClient.Do(req) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
package ai | ||
|
||
import ( | ||
"net/http" | ||
|
||
"github.com/cozy/cozy-stack/model/permission" | ||
"github.com/cozy/cozy-stack/model/rag" | ||
"github.com/cozy/cozy-stack/pkg/consts" | ||
"github.com/cozy/cozy-stack/pkg/jsonapi" | ||
"github.com/cozy/cozy-stack/web/middlewares" | ||
"github.com/labstack/echo/v4" | ||
) | ||
|
||
// Chat is the route for asking a chat completion to AI. | ||
func Chat(c echo.Context) error { | ||
if err := middlewares.AllowWholeType(c, permission.POST, consts.ChatCompletions); err != nil { | ||
return middlewares.ErrForbidden | ||
} | ||
var payload rag.ChatPayload | ||
if err := c.Bind(&payload); err != nil { | ||
return err | ||
} | ||
payload.ChatCompletionID = c.Param("id") | ||
inst := middlewares.GetInstance(c) | ||
chat, err := rag.Chat(inst, payload) | ||
if err != nil { | ||
return jsonapi.InternalServerError(err) | ||
} | ||
return jsonapi.Data(c, http.StatusAccepted, chat, nil) | ||
} | ||
|
||
// Routes sets the routing for the AI tasks. | ||
func Routes(router *echo.Group) { | ||
router.POST("/chat/completions/:id", Chat) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters