This repository has been archived by the owner on Mar 21, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
71 lines (59 loc) · 2.01 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
package main
import (
"context"
"log"
"time"
)
// Helix is the Dagger module object. It holds no state of its own; each of
// its functions builds the containers it needs from scratch.
type Helix struct{}

const (
	// HELIX_IMAGE is the image reference shared by the runner service and the
	// client container, pinned to tag v0.0.10.
	HELIX_IMAGE = "europe-docker.pkg.dev/helixml/helix/runner:v0.0.10"
)
// TODO: need to make client download file from runner via the API

// Service returns the Helix runner, built from HELIX_IMAGE with all GPUs
// attached, as a Dagger service exposing port 8080. The runner entrypoint is
// invoked as `helix runner --timeout-seconds 600 --memory 24GB`.
//
// ctx is currently unused.
func (m *Helix) Service(ctx context.Context) *Service {
	entrypoint := []string{
		"/app/helix/helix", "runner",
		"--timeout-seconds", "600",
		"--memory", "24GB",
	}

	runner := dag.Container().From(HELIX_IMAGE).ExperimentalWithAllGPUs()
	runner = runner.WithEntrypoint(entrypoint)
	runner = runner.WithExposedPort(8080)
	return runner.AsService()
}
// Client returns a container built from HELIX_IMAGE whose entrypoint is
// `/app/helix/helix run`. Nothing is executed here; the caller supplies the
// remaining command-line arguments via WithExec.
//
// ctx is currently unused.
func (m *Helix) Client(ctx context.Context) *Container {
	cli := dag.Container().From(HELIX_IMAGE)
	return cli.WithEntrypoint([]string{"/app/helix/helix", "run"})
}
// Generate asks the Helix runner to produce an image for prompt and returns
// the client container that ran the request.
//
// Loading model weights is expensive, so the GPU memory should be reused
// rather than reloaded on every use. For that reason the model runs in a
// separate service (see Service), meant to persist across the many dagger
// calls made within a DAG. This function binds that service under the
// hostname "helix-runner" and points the client at it on port 8080.
//
// The returned error is always nil at present.
func (m *Helix) Generate(ctx context.Context, prompt string) (*Container, error) {
	runner := m.Service(ctx)
	client := m.Client(ctx)

	// A fresh nanosecond timestamp makes every call's container definition
	// unique — presumably so Dagger never reuses a cached result for the exec.
	stamp := time.Now().Format(time.RFC3339Nano)

	args := []string{
		"--api-host", "http://helix-runner:8080",
		"--type", "image",
		"--prompt", prompt,
	}

	ctr := client.
		WithServiceBinding("helix-runner", runner).
		WithEnvVariable("CACHE_BUSTER", stamp).
		WithExec(args)
	return ctr, nil
}
// GenerateFile runs Generate for prompt, logs the client container's stdout
// and stderr, and returns the file at /app/helix/output.png inside that
// container.
//
// Any error from Generate, Stdout or Stderr is returned unchanged.
func (m *Helix) GenerateFile(ctx context.Context, prompt string) (*File, error) {
	ctr, err := m.Generate(ctx, prompt)
	if err != nil {
		return nil, err
	}

	var outText, errText string
	if outText, err = ctr.Stdout(ctx); err != nil {
		return nil, err
	}
	if errText, err = ctr.Stderr(ctx); err != nil {
		return nil, err
	}

	log.Printf("Got stdout from generate: %s", outText)
	log.Printf("Got stderr from generate: %s", errText)

	return ctr.File("/app/helix/output.png"), nil
}
// NvidiaSmi runs `nvidia-smi` in a CUDA 12.2.2 Ubuntu 22.04 base image with
// all GPUs attached and returns its standard output.
//
// Example usage: "dagger call nvidia-smi"
func (m *Helix) NvidiaSmi(ctx context.Context) (string, error) {
	const cudaImage = "nvidia/cuda:12.2.2-base-ubuntu22.04"

	probe := dag.Container().
		From(cudaImage).
		ExperimentalWithAllGPUs().
		WithExec([]string{"nvidia-smi"})
	return probe.Stdout(ctx)
}