diff --git a/api/pkg/controller/dataprep.go b/api/pkg/controller/dataprep.go index 34770ace0..dc36b3475 100644 --- a/api/pkg/controller/dataprep.go +++ b/api/pkg/controller/dataprep.go @@ -13,6 +13,7 @@ import ( "github.com/helixml/helix/api/pkg/dataprep/text" "github.com/helixml/helix/api/pkg/system" "github.com/helixml/helix/api/pkg/types" + "github.com/puzpuzpuz/xsync/v3" "github.com/rs/zerolog/log" ) @@ -312,6 +313,8 @@ func (c *Controller) convertChunksToQuestions(session *types.Session) (*types.Se var writeUpdatesMutex sync.Mutex runningFileList := copyFileList(userInteraction.Files) + docIds := xsync.NewMapOf[string, bool]() + docGroupId := "" outerError = system.ForEachConcurrently[*text.DataPrepTextSplitterChunk]( chunksToProcess, @@ -320,6 +323,8 @@ func (c *Controller) convertChunksToQuestions(session *types.Session) (*types.Se log.Info().Msgf("🔵 question conversion start %d of %d", i+1, len(chunksToProcess)) questions, convertError := dataprep.ConvertChunk(chunk.Text, chunk.Index, chunk.DocumentID, chunk.DocumentGroupID, chunk.PromptName) + docIds.Store(chunk.DocumentID, true) + docGroupId = chunk.DocumentGroupID // if this is set then we have a non GPT error and should just stop what we are doing if outerError != nil { return nil @@ -398,7 +403,22 @@ func (c *Controller) convertChunksToQuestions(session *types.Session) (*types.Se systemInteraction.DataPrepStage = types.TextDataPrepStageEditQuestions systemInteraction.Progress = 0 systemInteraction.State = types.InteractionStateEditing + // This is going to appear as if the chatbot said it, which might confuse + // the chatbot. Should we wrap it in [INST][/INST]?? But that is model + // specific. Really, we should add explicit support for system prompts to + // the system. + docIdsList := []string{} + docIds.Range(func(key string, value bool) bool { + docIdsList = append(docIdsList, key) + return true + }) session = c.WriteInteraction(session, systemInteraction) + systemPrompt := fmt.Sprintf( + "You are an intelligent chatbot that has been fine-tuned on document(s) %s in document group %s. The document group contains %d document(s). The user will ask you questions about these documents: you must ONLY answer with context from the documents listed. Do NOT refer to background knowledge.", + strings.Join(docIdsList, " "), docGroupId, len(docIdsList), + ) + session.Metadata.SystemPrompt = systemPrompt + c.WriteSession(session) return session, len(chunksToProcess), nil } diff --git a/api/pkg/dataprep/qapairs/qapairs.go b/api/pkg/dataprep/qapairs/qapairs.go index 85f9098c6..b64b03ce5 100644 --- a/api/pkg/dataprep/qapairs/qapairs.go +++ b/api/pkg/dataprep/qapairs/qapairs.go @@ -116,7 +116,7 @@ func FindPrompt(name string) (Prompt, error) { } } - return Prompt{}, fmt.Errorf("Could not find prompt with name %s", name) + return Prompt{}, fmt.Errorf("could not find prompt with name %s", name) } func FindTarget(name string) (Target, error) { @@ -133,7 +133,7 @@ func FindTarget(name string) (Target, error) { } log.Fatalf("Could not find target with name %s", name) - return Target{}, fmt.Errorf("Could not find target with name %s", name) + return Target{}, fmt.Errorf("could not find target with name %s", name) } func Run(targetFilter, promptFilter, textFilter []string) { @@ -273,7 +273,11 @@ func Query(target Target, prompt Prompt, text Text, documentID, documentGroupID debug := fmt.Sprintf("prompt %s", prompt.Name) resp, err := chatWithModel(target.ApiUrl, os.Getenv(target.TokenFromEnv), target.Model, systemPrompt, userPrompt, debug) if err != nil { - return nil, err + log.Printf("ChatCompletion error, trying again (%s): %v\n", debug, err) + resp, err = chatWithModel(target.ApiUrl, os.Getenv(target.TokenFromEnv), target.Model, systemPrompt, userPrompt, debug) + if err != nil { + return nil, err + } } latency := time.Since(startTime).Milliseconds() @@ -346,7 +350,7 @@ func chatWithModel(apiUrl, token, model, system, user, debug string) ([]types.Da }, ) if err != nil { - fmt.Printf("ChatCompletion error: %v\n", err) + fmt.Printf("ChatCompletion error (%s): %v\n", debug, err) return nil, err } @@ -364,7 +368,7 @@ func chatWithModel(apiUrl, token, model, system, user, debug string) ([]types.Da // backslashes for now... answer = strings.Replace(answer, "\\", "", -1) - return TryVariousJSONFormats(answer) + return TryVariousJSONFormats(answer, fmt.Sprintf("%s respID=%s", debug, resp.ID)) } @@ -381,7 +385,7 @@ type QuestionSet struct { Questions []types.DataPrepTextQuestionRaw `json:"questions"` } -func TryVariousJSONFormats(jsonString string) ([]types.DataPrepTextQuestionRaw, error) { +func TryVariousJSONFormats(jsonString, debug string) ([]types.DataPrepTextQuestionRaw, error) { var res []types.DataPrepTextQuestionRaw var err error @@ -409,5 +413,5 @@ func TryVariousJSONFormats(jsonString string) ([]types.DataPrepTextQuestionRaw, return topLevel.Questions, nil } - return nil, fmt.Errorf("error parsing JSON:\n\n%s", jsonString) + return nil, fmt.Errorf("error parsing JSON (%s):\n\n%s", debug, jsonString) } diff --git a/api/pkg/model/mistral7b.go b/api/pkg/model/mistral7b.go index 982458734..72397175d 100644 --- a/api/pkg/model/mistral7b.go +++ b/api/pkg/model/mistral7b.go @@ -40,6 +40,15 @@ func (l *Mistral7bInstruct01) GetTask(session *types.Session, fileManager ModelS task.DatasetDir = fileManager.GetFolder() var messages []string + + // XXX Should there be spaces after the [INST]? + // XXX Should we be including a ? + // https://docs.mistral.ai/models/ + + if session.Metadata.SystemPrompt != "" { + messages = append(messages, fmt.Sprintf("[INST]%s[/INST]", session.Metadata.SystemPrompt)) + } + for _, interaction := range session.Interactions { // Chat API mode // if len(interaction.Messages) > 0 { diff --git a/api/pkg/types/types.go b/api/pkg/types/types.go index 608de74f4..d07a50f34 100644 --- a/api/pkg/types/types.go +++ b/api/pkg/types/types.go @@ -64,6 +64,7 @@ type SessionMetadata struct { DocumentIDs map[string]string `json:"document_ids"` DocumentGroupID string `json:"document_group_id"` ManuallyReviewQuestions bool `json:"manually_review_questions"` + SystemPrompt string `json:"system_prompt"` } // the packet we put a list of sessions into so pagination is supported and we know the total amount