diff --git a/cmd/ingest/main.go b/cmd/ingest/main.go new file mode 100644 index 0000000..4d56eff --- /dev/null +++ b/cmd/ingest/main.go @@ -0,0 +1,261 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "io/fs" + "mime/multipart" + "net/http" + "net/url" + "os" + "path/filepath" + "slices" + + "github.com/google/uuid" +) + +func main() { + urlFlag := flag.String("url", "http://localhost:8080", "server url") + tokenFlag := flag.String("token", "", "server token") + indexFlag := flag.String("index", "docs", "index name") + dirFlag := flag.String("dir", ".", "directory") + + flag.Parse() + + url, err := url.Parse(*urlFlag) + + if err != nil { + panic(err) + } + + c := client{ + url: url, + client: http.DefaultClient, + } + + ctx := context.Background() + + _ = tokenFlag + + supported := []string{ + ".txt", ".md", + } + + filepath.WalkDir(*dirFlag, func(path string, e fs.DirEntry, err error) error { + if e.IsDir() { + return nil + } + + if !slices.Contains(supported, filepath.Ext(path)) { + return nil + } + + name := filepath.Base(path) + + f, err := os.Open(path) + + if err != nil { + return err + } + + content, err := c.Extract(ctx, name, f, nil) + + if err != nil { + return err + } + + segments, err := c.Segment(ctx, content, nil) + + if err != nil { + return err + } + + var documents []Document + + for i, segment := range segments { + part := i + 1 + + id := uuid.NewMD5(uuid.NameSpaceOID, []byte(fmt.Sprintf("%s#%d", name, part))).String() + + document := Document{ + ID: id, + + Content: segment.Text, + + Metadata: map[string]string{ + "name": name, + "path": path, + "part": fmt.Sprintf("%d", part), + }, + } + + documents = append(documents, document) + } + + if err := c.Ingest(ctx, *indexFlag, documents, nil); err != nil { + return err + } + + return nil + }) +} + +type client struct { + url *url.URL + client *http.Client +} + +func (c *client) Extract(ctx context.Context, name string, reader io.Reader, options *ExtractOptions) (string, error) { + if options == nil { + options = new(ExtractOptions) + } + + var body bytes.Buffer + w := multipart.NewWriter(&body) + + //w.WriteField("model", string(options.Model)) + //w.WriteField("format", string(options.Format)) + + file, err := w.CreateFormFile("file", name) + + if err != nil { + return "", err + } + + if _, err := io.Copy(file, reader); err != nil { + return "", err + } + + w.Close() + + req, _ := http.NewRequestWithContext(ctx, "POST", c.url.JoinPath("/v1/extract").String(), &body) + req.Header.Set("Content-Type", w.FormDataContentType()) + + resp, err := http.DefaultClient.Do(req) + + if err != nil { + return "", err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", errors.New(resp.Status) + } + + data, err := io.ReadAll(resp.Body) + + if err != nil { + return "", err + } + + return string(data), nil +} + +type ExtractOptions struct { +} + +func (c *client) Segment(ctx context.Context, content string, options *SegmentOptions) ([]Segment, error) { + if options == nil { + options = new(SegmentOptions) + } + + request := SegmentRequest{ + Content: content, + + SegmentLength: options.SegmentLength, + SegmentOverlap: options.SegmentOverlap, + } + + var body bytes.Buffer + + if err := json.NewEncoder(&body).Encode(request); err != nil { + return nil, err + } + + req, _ := http.NewRequestWithContext(ctx, "POST", c.url.JoinPath("/v1/segment").String(), &body) + req.Header.Set("Content-Type", "application/json") + + resp, err := c.client.Do(req) + + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, errors.New(resp.Status) + } + + var result struct { + Segments []Segment `json:"segments,omitempty"` + } + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + + return result.Segments, nil +} + +type Segment struct { + Text string `json:"text"` +} + +type SegmentOptions struct { + SegmentLength *int + SegmentOverlap *int +} + +type SegmentRequest struct { + Content string `json:"content"` + + SegmentLength *int `json:"segment_length"` + SegmentOverlap *int `json:"segment_overlap"` +} + +func (c *client) Ingest(ctx context.Context, index string, documents []Document, options *IngestOptions) error { + if options == nil { + options = new(IngestOptions) + } + + var body bytes.Buffer + + if err := json.NewEncoder(&body).Encode(documents); err != nil { + return err + } + + req, _ := http.NewRequestWithContext(ctx, "POST", c.url.JoinPath("/v1/index/"+index).String(), &body) + req.Header.Set("Content-Type", "application/json") + + resp, err := c.client.Do(req) + + if err != nil { + return err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + return errors.New(resp.Status) + } + + return nil +} + +type Document struct { + ID string `json:"id,omitempty"` + + Content string `json:"content,omitempty"` + + Metadata map[string]string `json:"metadata,omitempty"` +} + +type IngestOptions struct { +}