Skip to content

Commit

Permalink
Merge pull request #139 from helixml/system-prompt
Browse files Browse the repository at this point in the history
add system prompt for text finetuning, and retries
  • Loading branch information
lukemarsden authored Jan 29, 2024
2 parents 0cdad03 + 6b1f2c0 commit 84e58fe
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 7 deletions.
20 changes: 20 additions & 0 deletions api/pkg/controller/dataprep.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/helixml/helix/api/pkg/dataprep/text"
"github.com/helixml/helix/api/pkg/system"
"github.com/helixml/helix/api/pkg/types"
"github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log"
)

Expand Down Expand Up @@ -312,6 +313,8 @@ func (c *Controller) convertChunksToQuestions(session *types.Session) (*types.Se
var writeUpdatesMutex sync.Mutex

runningFileList := copyFileList(userInteraction.Files)
docIds := xsync.NewMapOf[string, bool]()
docGroupId := ""

outerError = system.ForEachConcurrently[*text.DataPrepTextSplitterChunk](
chunksToProcess,
Expand All @@ -320,6 +323,8 @@ func (c *Controller) convertChunksToQuestions(session *types.Session) (*types.Se
log.Info().Msgf("🔵 question conversion start %d of %d", i+1, len(chunksToProcess))
questions, convertError := dataprep.ConvertChunk(chunk.Text, chunk.Index, chunk.DocumentID, chunk.DocumentGroupID, chunk.PromptName)

docIds.Store(chunk.DocumentID, true)
docGroupId = chunk.DocumentGroupID
// if this is set then we have a non GPT error and should just stop what we are doing
if outerError != nil {
return nil
Expand Down Expand Up @@ -398,7 +403,22 @@ func (c *Controller) convertChunksToQuestions(session *types.Session) (*types.Se
systemInteraction.DataPrepStage = types.TextDataPrepStageEditQuestions
systemInteraction.Progress = 0
systemInteraction.State = types.InteractionStateEditing
// This is going to appear as if the chatbot said it, which might confuse
// the chatbot. Should we wrap it in [INST][/INST]?? But that is model
// specific. Really, we should add explicit support for system prompts to
// the system.
docIdsList := []string{}
docIds.Range(func(key string, value bool) bool {
docIdsList = append(docIdsList, key)
return true
})
session = c.WriteInteraction(session, systemInteraction)
systemPrompt := fmt.Sprintf(
"You are an intelligent chatbot that has been fine-tuned on document(s) %s in document group %s. The document group contains %d document(s). The user will ask you questions about these documents: you must ONLY answer with context from the documents listed. Do NOT refer to background knowledge.",
strings.Join(docIdsList, " "), docGroupId, len(docIdsList),
)
session.Metadata.SystemPrompt = systemPrompt
c.WriteSession(session)

return session, len(chunksToProcess), nil
}
Expand Down
18 changes: 11 additions & 7 deletions api/pkg/dataprep/qapairs/qapairs.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func FindPrompt(name string) (Prompt, error) {
}
}

return Prompt{}, fmt.Errorf("Could not find prompt with name %s", name)
return Prompt{}, fmt.Errorf("could not find prompt with name %s", name)
}

func FindTarget(name string) (Target, error) {
Expand All @@ -133,7 +133,7 @@ func FindTarget(name string) (Target, error) {
}

log.Fatalf("Could not find target with name %s", name)
return Target{}, fmt.Errorf("Could not find target with name %s", name)
return Target{}, fmt.Errorf("could not find target with name %s", name)
}

func Run(targetFilter, promptFilter, textFilter []string) {
Expand Down Expand Up @@ -273,7 +273,11 @@ func Query(target Target, prompt Prompt, text Text, documentID, documentGroupID
debug := fmt.Sprintf("prompt %s", prompt.Name)
resp, err := chatWithModel(target.ApiUrl, os.Getenv(target.TokenFromEnv), target.Model, systemPrompt, userPrompt, debug)
if err != nil {
return nil, err
log.Printf("ChatCompletion error, trying again (%s): %v\n", debug, err)
resp, err = chatWithModel(target.ApiUrl, os.Getenv(target.TokenFromEnv), target.Model, systemPrompt, userPrompt, debug)
if err != nil {
return nil, err
}
}
latency := time.Since(startTime).Milliseconds()

Expand Down Expand Up @@ -346,7 +350,7 @@ func chatWithModel(apiUrl, token, model, system, user, debug string) ([]types.Da
},
)
if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err)
fmt.Printf("ChatCompletion error (%s): %v\n", debug, err)
return nil, err
}

Expand All @@ -364,7 +368,7 @@ func chatWithModel(apiUrl, token, model, system, user, debug string) ([]types.Da
// backslashes for now...
answer = strings.Replace(answer, "\\", "", -1)

return TryVariousJSONFormats(answer)
return TryVariousJSONFormats(answer, fmt.Sprintf("%s respID=%s", debug, resp.ID))

}

Expand All @@ -381,7 +385,7 @@ type QuestionSet struct {
Questions []types.DataPrepTextQuestionRaw `json:"questions"`
}

func TryVariousJSONFormats(jsonString string) ([]types.DataPrepTextQuestionRaw, error) {
func TryVariousJSONFormats(jsonString, debug string) ([]types.DataPrepTextQuestionRaw, error) {
var res []types.DataPrepTextQuestionRaw
var err error

Expand Down Expand Up @@ -409,5 +413,5 @@ func TryVariousJSONFormats(jsonString string) ([]types.DataPrepTextQuestionRaw,
return topLevel.Questions, nil
}

return nil, fmt.Errorf("error parsing JSON:\n\n%s", jsonString)
return nil, fmt.Errorf("error parsing JSON (%s):\n\n%s", debug, jsonString)
}
9 changes: 9 additions & 0 deletions api/pkg/model/mistral7b.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ func (l *Mistral7bInstruct01) GetTask(session *types.Session, fileManager ModelS
task.DatasetDir = fileManager.GetFolder()

var messages []string

// XXX Should there be spaces after the [INST]?
// XXX Should we be including a </s>?
// https://docs.mistral.ai/models/

if session.Metadata.SystemPrompt != "" {
messages = append(messages, fmt.Sprintf("[INST]%s[/INST]", session.Metadata.SystemPrompt))
}

for _, interaction := range session.Interactions {
// Chat API mode
// if len(interaction.Messages) > 0 {
Expand Down
1 change: 1 addition & 0 deletions api/pkg/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type SessionMetadata struct {
DocumentIDs map[string]string `json:"document_ids"`
DocumentGroupID string `json:"document_group_id"`
ManuallyReviewQuestions bool `json:"manually_review_questions"`
SystemPrompt string `json:"system_prompt"`
}

// the packet we put a list of sessions into so pagination is supported and we know the total amount
Expand Down

0 comments on commit 84e58fe

Please sign in to comment.