Skip to content

Commit

Permalink
refactor(finetune): pass through owner and session information for logging
Browse files Browse the repository at this point in the history
  • Loading branch information
philwinder committed Aug 28, 2024
1 parent 989afa2 commit 368ed59
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 13 deletions.
2 changes: 1 addition & 1 deletion api/cmd/helix/qapairs.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func newQapairCommand() *cobra.Command {
serverConfig.FineTuning.QAPairGenModel = qaPairGenModel
}

return qapairs.Run(client, serverConfig.FineTuning.QAPairGenModel, prompt, theText)
return qapairs.Run(client, "n/a", "n/a", serverConfig.FineTuning.QAPairGenModel, prompt, theText)
},
}

Expand Down
12 changes: 11 additions & 1 deletion api/pkg/controller/dataprep.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,14 @@ func (c *Controller) getQAChunksToProcess(session *types.Session, dataprep text.
return nil, err
}

log.Debug().Int("beforechunks", len(splitter.Chunks)).Msg("PHIL")
// Some qapair generators expand each chunk into N chunks so they can be run
// by our outer concurrency manager
allChunks, err := dataprep.ExpandChunks(splitter.Chunks)
if err != nil {
return nil, err
}
log.Debug().Int("afterChunks", len(allChunks)).Msg("PHIL")

chunksToProcess := []*text.DataPrepTextSplitterChunk{}
for _, chunk := range allChunks {
Expand All @@ -317,6 +319,7 @@ func (c *Controller) getQAChunksToProcess(session *types.Session, dataprep text.
}

func (c *Controller) getRagChunksToProcess(session *types.Session) ([]*text.DataPrepTextSplitterChunk, error) {
log.Debug().Msg("PHIL getRagChunksToProcess")
filesToConvert, err := c.getTextFilesToConvert(session)
if err != nil {
return nil, err
Expand All @@ -337,6 +340,9 @@ func (c *Controller) getRagChunksToProcess(session *types.Session) ([]*text.Data
newMeta.DocumentIDs = map[string]string{}
}

log.Debug().
Interface("filesToConvert", filesToConvert).
Msg("PHIL Files to convert for RAG processing")
for _, file := range filesToConvert {
fileContent, err := getFileContent(c.Ctx, c.Options.Filestore, file)
if err != nil {
Expand Down Expand Up @@ -488,6 +494,10 @@ func (c *Controller) convertChunksToQuestions(session *types.Session) (*types.Se
if err != nil {
return nil, 0, err
}
log.Debug().
Int("chunksToProcess", len(chunksToProcess)).
Str("sessionID", session.ID).
Msg("PHIL Retrieved chunks to process for QA conversion")

if len(chunksToProcess) == 0 {
return session, 0, nil
Expand Down Expand Up @@ -610,7 +620,7 @@ func (c *Controller) convertChunksToQuestions(session *types.Session) (*types.Se
dataprep.GetConcurrency(),
func(chunk *text.DataPrepTextSplitterChunk, i int) error {
log.Info().Msgf("🔵 question conversion start %d of %d", i+1, len(chunksToProcess))
questions, convertError := dataprep.ConvertChunk(chunk.Text, chunk.Index, chunk.DocumentID, chunk.DocumentGroupID, chunk.PromptName)
questions, convertError := dataprep.ConvertChunk(session.Owner, session.ID, chunk.Text, chunk.Index, chunk.DocumentID, chunk.DocumentGroupID, chunk.PromptName)

// if this is set then we have a non GPT error and should just stop what we are doing
if outerError != nil {
Expand Down
16 changes: 8 additions & 8 deletions api/pkg/dataprep/qapairs/qapairs.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func FindPrompt(name string) (Prompt, error) {
return Prompt{}, fmt.Errorf("could not find prompt with name %s", name)
}

func Run(client openai.Client, model string, promptFilter, textFilter []string) error {
func Run(client openai.Client, ownerID, sessionID, model string, promptFilter, textFilter []string) error {
var config Config
err := yaml.Unmarshal([]byte(qapairConfig), &config)
if err != nil {
Expand Down Expand Up @@ -154,7 +154,7 @@ func Run(client openai.Client, model string, promptFilter, textFilter []string)
for _, prompt := range filteredPrompts {
for _, text := range filteredTexts {
fmt.Printf("Running helix qapairs --target=\"%s\" --prompt=\"%s\" --text=\"%s\"\n", model, prompt.Name, text.Name)
resp, err := Query(client, model, prompt, text, "", "", 0)
resp, err := Query(client, ownerID, sessionID, model, prompt, text, "", "", 0)
if err != nil {
return fmt.Errorf("error querying model: %v", err)
}
Expand All @@ -176,7 +176,7 @@ type TemplateData struct {
DocumentChunk string
}

func Query(client openai.Client, model string, prompt Prompt, text Text, documentID, documentGroupID string, numQuestions int) ([]types.DataPrepTextQuestionRaw, error) {
func Query(client openai.Client, ownerID, sessionID, model string, prompt Prompt, text Text, documentID, documentGroupID string, numQuestions int) ([]types.DataPrepTextQuestionRaw, error) {
// Perform the query for the given target and prompt
var (
contents string
Expand Down Expand Up @@ -233,10 +233,10 @@ func Query(client openai.Client, model string, prompt Prompt, text Text, documen
startTime := time.Now()
debug := fmt.Sprintf("prompt %s", prompt.Name)
// try not enforcing json schema initially, only retry if we fail to parse
resp, err := chatWithModel(client, model, systemPrompt, userPrompt, debug, nil)
resp, err := chatWithModel(client, ownerID, sessionID, model, systemPrompt, userPrompt, debug, nil)
if err != nil {
log.Warn().Msgf("ChatCompletion error non-JSON mode, trying again (%s): %v\n", debug, err)
resp, err = chatWithModel(client, model, systemPrompt, userPrompt, debug, prompt.JsonSchema)
resp, err = chatWithModel(client, ownerID, sessionID, model, systemPrompt, userPrompt, debug, prompt.JsonSchema)
if err != nil {
log.Warn().Msgf("ChatCompletion error JSON mode, giving up, but not propagating the error further for now. (%s): %v\n", debug, err)
latency := time.Since(startTime).Milliseconds()
Expand Down Expand Up @@ -295,7 +295,7 @@ func loadFile(filePath string) (string, error) {
return string(content), nil
}

func chatWithModel(client openai.Client, model, system, user, debug string, jsonSchema map[string]interface{}) ([]types.DataPrepTextQuestionRaw, error) {
func chatWithModel(client openai.Client, ownerID, sessionID, model, system, user, debug string, jsonSchema map[string]interface{}) ([]types.DataPrepTextQuestionRaw, error) {
req := ext_openai.ChatCompletionRequest{
Model: model,
Messages: []ext_openai.ChatCompletionMessage{
Expand All @@ -318,8 +318,8 @@ func chatWithModel(client openai.Client, model, system, user, debug string, json
}

ctx := openai.SetContextValues(context.Background(), &openai.ContextValues{
OwnerID: "n/a",
SessionID: "n/a",
OwnerID: ownerID,
SessionID: sessionID,
InteractionID: "n/a",
})

Expand Down
4 changes: 2 additions & 2 deletions api/pkg/dataprep/text/dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (d *DynamicDataPrep) ExpandChunks(chunks []*DataPrepTextSplitterChunk) (
}

func (d *DynamicDataPrep) ConvertChunk(
chunk string, index int, documentID, documentGroupID, promptName string,
ownerID, sessionID, chunk string, index int, documentID, documentGroupID, promptName string,
) ([]types.DataPrepTextQuestion, error) {
prompt, err := qapairs.FindPrompt(promptName)
if err != nil {
Expand All @@ -59,7 +59,7 @@ func (d *DynamicDataPrep) ConvertChunk(
Name: "user-provided",
Contents: chunk,
}
resRaw, err := qapairs.Query(d.client, d.model, prompt, text, documentID, documentGroupID, 0)
resRaw, err := qapairs.Query(d.client, ownerID, sessionID, d.model, prompt, text, documentID, documentGroupID, 0)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion api/pkg/dataprep/text/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

type DataPrepTextQuestionGenerator interface {
ExpandChunks(chunks []*DataPrepTextSplitterChunk) ([]*DataPrepTextSplitterChunk, error)
ConvertChunk(chunk string, index int, documentID, documentGroupID, promptName string) ([]types.DataPrepTextQuestion, error)
ConvertChunk(ownerID, sessionID, chunk string, index int, documentID, documentGroupID, promptName string) ([]types.DataPrepTextQuestion, error)
GetConcurrency() int
GetChunkSize() int
}

0 comments on commit 368ed59

Please sign in to comment.