Skip to content

Commit

Permalink
Merge pull request #11 from DavidZbarsky-at/master
Browse files Browse the repository at this point in the history
A slew of fixes to make hot-reloading truly work
  • Loading branch information
dzbarsky authored May 18, 2024
2 parents fbb0b85 + e65c813 commit 5b1a8ac
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 77 deletions.
160 changes: 94 additions & 66 deletions cmd/svcinit/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,13 @@ func main() {

isOneShot := !shouldHotReload && testLabel != ""

serviceSpecs, err := readVersionedServiceSpecs(serviceSpecsPath)
unversionedSpecs, err := readServiceSpecs(serviceSpecsPath)
must(err)

ports, err := assignPorts(unversionedSpecs)
must(err)

serviceSpecs, err := augmentServiceSpecs(unversionedSpecs, ports)
must(err)

/*if *allowSvcctl {
Expand Down Expand Up @@ -226,7 +232,10 @@ func main() {
log.Println(ibazelCmd)

// Restart any services as needed.
serviceSpecs, err := readVersionedServiceSpecs(serviceSpecsPath)
unversionedSpecs, err := readServiceSpecs(serviceSpecsPath)
must(err)

serviceSpecs, err := augmentServiceSpecs(unversionedSpecs, ports)
must(err)

criticalPath, err = r.UpdateSpecsAndRestart(serviceSpecs, []byte(ibazelCmd))
Expand All @@ -235,66 +244,30 @@ func main() {
}
}

func readVersionedServiceSpecs(
func readServiceSpecs(
path string,
) (
map[string]svclib.VersionedServiceSpec, error,
map[string]svclib.ServiceSpec, error,
) {
data, err := os.ReadFile(path)
must(err)

var serviceSpecs map[string]svclib.ServiceSpec
err = json.Unmarshal(data, &serviceSpecs)
must(err)

tmpDir := os.Getenv("TMPDIR")
socketDir := os.Getenv("SOCKET_DIR")
return serviceSpecs, err
}

func assignPorts(
serviceSpecs map[string]svclib.ServiceSpec,
) (
svclib.Ports, error,
) {
var toClose []net.Listener

ports := svclib.Ports{}
versionedServiceSpecs := make(map[string]svclib.VersionedServiceSpec, len(serviceSpecs))
for label, serviceSpec := range serviceSpecs {
s := svclib.VersionedServiceSpec{
ServiceSpec: serviceSpec,
}

if s.Type == "group" {
versionedServiceSpecs[label] = s
continue
}

exePath, err := runfiles.Rlocation(s.Exe)
if err != nil {
return nil, err
}
s.Exe = exePath

if s.HealthCheck != "" {
healthCheckPath, err := runfiles.Rlocation(serviceSpec.HealthCheck)
if err != nil {
return nil, err
}
s.HealthCheck = healthCheckPath
}

if serviceSpec.VersionFile != "" {
versionFilePath, err := runfiles.Rlocation(serviceSpec.VersionFile)
if err != nil {
return nil, err
}

version, err := os.ReadFile(versionFilePath)
if err != nil {
return nil, err
}
s.Version = string(version)
}

s.Color = logger.Colorize(s.Label)

namedPorts := slices.Clone(s.NamedPorts)
if s.AutoassignPort {
for label, spec := range serviceSpecs {
namedPorts := slices.Clone(spec.NamedPorts)
if spec.AutoassignPort {
namedPorts = append(namedPorts, "")
}

Expand Down Expand Up @@ -327,29 +300,15 @@ func readVersionedServiceSpecs(
return nil, err
}

qualifiedPortName := s.Label
qualifiedPortName := label
if portName != "" {
qualifiedPortName += ":" + portName
}

fmt.Printf("Assigning port %s to %s\n", port, qualifiedPortName)
ports.Set(qualifiedPortName, port)
toClose = append(toClose, listener)

if portName == "" {
for i := range s.ServiceSpec.Args {
s.Args[i] = strings.ReplaceAll(s.Args[i], "$${PORT}", port)
}
s.HttpHealthCheckAddress = strings.ReplaceAll(s.HttpHealthCheckAddress, "$${PORT}", port)
}
}

for i := range s.Args {
s.Args[i] = strings.ReplaceAll(s.Args[i], "$${TMPDIR}", tmpDir)
s.Args[i] = strings.ReplaceAll(s.Args[i], "$${SOCKET_DIR}", socketDir)
}

versionedServiceSpecs[label] = s
}

for _, listener := range toClose {
Expand All @@ -364,8 +323,77 @@ func readVersionedServiceSpecs(
time.Sleep(10 * time.Millisecond)

serializedPorts, err := ports.Marshal()
must(err)
if err != nil {
return nil, err
}
os.Setenv("ASSIGNED_PORTS", string(serializedPorts))
return ports, nil
}

func augmentServiceSpecs(
serviceSpecs map[string]svclib.ServiceSpec,
ports svclib.Ports,
) (
map[string]svclib.VersionedServiceSpec, error,
) {
tmpDir := os.Getenv("TMPDIR")
socketDir := os.Getenv("SOCKET_DIR")

versionedServiceSpecs := make(map[string]svclib.VersionedServiceSpec, len(serviceSpecs))
for label, serviceSpec := range serviceSpecs {
s := svclib.VersionedServiceSpec{
ServiceSpec: serviceSpec,
}

if s.Type == "group" {
versionedServiceSpecs[label] = s
continue
}

exePath, err := runfiles.Rlocation(s.Exe)
if err != nil {
return nil, err
}
s.Exe = exePath

if s.HealthCheck != "" {
healthCheckPath, err := runfiles.Rlocation(serviceSpec.HealthCheck)
if err != nil {
return nil, err
}
s.HealthCheck = healthCheckPath
}

if serviceSpec.VersionFile != "" {
versionFilePath, err := runfiles.Rlocation(serviceSpec.VersionFile)
if err != nil {
return nil, err
}

version, err := os.ReadFile(versionFilePath)
if err != nil {
return nil, err
}
s.Version = string(version)
}

s.Color = logger.Colorize(s.Label)

if s.AutoassignPort {
port := ports[s.Label]
for i := range s.ServiceSpec.Args {
s.Args[i] = strings.ReplaceAll(s.Args[i], "$${PORT}", port)
}
s.HttpHealthCheckAddress = strings.ReplaceAll(s.HttpHealthCheckAddress, "$${PORT}", port)
}

for i := range s.Args {
s.Args[i] = strings.ReplaceAll(s.Args[i], "$${TMPDIR}", tmpDir)
s.Args[i] = strings.ReplaceAll(s.Args[i], "$${SOCKET_DIR}", socketDir)
}

versionedServiceSpecs[label] = s
}

replacements := make([]Replacement, 0, len(ports))
for label, port := range ports {
Expand Down
11 changes: 5 additions & 6 deletions itest.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ load(
def itest_service(name, tags = [], **kwargs):
_itest_service(
name = name,
tags = tags,
tags = tags + ["ibazel_notify_changes"],
**kwargs
)
_hygiene_test(
Expand All @@ -22,7 +22,7 @@ def itest_service(name, tags = [], **kwargs):
def itest_service_group(name, tags = [], **kwargs):
_itest_service_group(
name = name,
tags = tags,
tags = tags + ["ibazel_notify_changes"],
**kwargs
)
_hygiene_test(
Expand All @@ -33,21 +33,20 @@ def itest_service_group(name, tags = [], **kwargs):
def itest_task(name, tags = [], **kwargs):
_itest_task(
name = name,
tags = tags,
tags = tags + ["ibazel_notify_changes"],
**kwargs
)
_hygiene_test(
name = name,
tags = tags,
)

def service_test(tags = [], **kwargs):
_service_test(tags = tags + ["ibazel_notify_changes"], **kwargs)

def _hygiene_test(name, **kwargs):
service_test(
name = name + "_hygiene_test",
services = [name],
test = "@rules_itest//:exit0",
**kwargs
)

service_test = _service_test
4 changes: 2 additions & 2 deletions private/itest.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _itest_service_impl(ctx):
_itest_service_attrs = _itest_binary_attrs | {
# Note, autoassigning a port is a little racy. If you can stick to hardcoded ports and network namespace, you should prefer that.
"autoassign_port": attr.bool(
doc = """If true, the service manager will pick a free port and assign it to the service.
doc = """If true, the service manager will pick a free port and assign it to the service.
The port will be interpolated into `$${PORT}` in the service's `http_health_check_address` and `args`.
It will also be exported under the target's fully qualified label in the service-port mapping.
Expand Down Expand Up @@ -351,7 +351,7 @@ def _create_version_file(ctx, inputs):
mnemonic = "SvcVersionFile",
# disable remote cache and sandbox, since the output is not stable given the inputs
# additionally, running this action in the sandbox is way too expensive
execution_requirements = {"local": "1"},
execution_requirements = {"local": "1", "no-cache": "1"},
)

return output
2 changes: 2 additions & 0 deletions runner/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ go_library(
name = "runner",
srcs = [
"runner.go",
"runner_unix.go",
"runner_windows.go",
"service_instance.go",
"topo.go",
],
Expand Down
23 changes: 21 additions & 2 deletions runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,24 @@ import (
"os"
"os/exec"
"reflect"
"runtime"
"sync"
"syscall"
"time"

"rules_itest/logger"
"rules_itest/runner/topological"
"rules_itest/svclib"
)

// We need to use process groups to reliably tear down services and their descendants.
// This is especially important in hot-reload mode, where you need to restart the child
// and have it bind the same port.
// However, we don't want to do this in tests, because Bazel will already terminate the
// test process (svcinit) and all its children.
// If we were to start new process groups in tests, we could leak children (at least on Mac).
var shouldUseProcessGroups = runtime.GOOS != "windows" && os.Getenv("BAZEL_TEST") != "1"

type ServiceSpecs = map[string]svclib.VersionedServiceSpec

type runner struct {
Expand Down Expand Up @@ -108,7 +118,7 @@ func computeUpdateActions(currentServices, newServices ServiceSpecs) updateActio
// We technically don't need a restart if the change is the list of deps.
// But that should not be a common use case, so it's not worth the complexity.
if !reflect.DeepEqual(service, newService) {
fmt.Println(label + " definition or code has changed, restarting...")
log.Printf(colorize(service) + " definition or code has changed, restarting...")
if service.HotReloadable && reflect.DeepEqual(service.ServiceSpec, newService.ServiceSpec) {
// The only difference is the Version. Trust the service that
// it prefers to receive the ibazel reload command.
Expand Down Expand Up @@ -194,6 +204,10 @@ func prepareServiceInstance(ctx context.Context, s svclib.VersionedServiceSpec)
cmd.Stdout = logger.New(s.Label+"> ", s.Color, os.Stdout)
cmd.Stderr = logger.New(s.Label+"> ", s.Color, os.Stderr)

if shouldUseProcessGroups {
setPgid(cmd)
}

// Even if a child process exits, Wait will block until the I/O pipes are closed.
// They may have been forwarded to an orphaned child, so we disable that behavior to unblock exit.
if s.Type == "service" {
Expand All @@ -205,6 +219,7 @@ func prepareServiceInstance(ctx context.Context, s svclib.VersionedServiceSpec)
Cmd: cmd,

startErrFn: sync.OnceValue(cmd.Start),
waitErrFn: sync.OnceValue(cmd.Wait),
}

if s.HotReloadable {
Expand All @@ -218,7 +233,11 @@ func prepareServiceInstance(ctx context.Context, s svclib.VersionedServiceSpec)
}

func stopInstance(serviceInstance *ServiceInstance) {
serviceInstance.Cmd.Process.Kill()
pid := serviceInstance.Cmd.Process.Pid
if shouldUseProcessGroups {
pid = -pid
}
syscall.Kill(pid, syscall.SIGKILL)
serviceInstance.Cmd.Wait()

for serviceInstance.Cmd.ProcessState == nil {
Expand Down
12 changes: 12 additions & 0 deletions runner/runner_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//go:build unix

package runner

import (
"os/exec"
"syscall"
)

func setPgid(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}
10 changes: 10 additions & 0 deletions runner/runner_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//go:build windows

package runner

import "os/exec"

func setPgid(cmd *exec.Cmd) {
panic("Pgid not implemented on windows!")
}

3 changes: 2 additions & 1 deletion runner/service_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type ServiceInstance struct {
startDuration time.Duration

startErrFn func() error
waitErrFn func() error

mu sync.Mutex
runErr error
Expand All @@ -51,7 +52,7 @@ func (s *ServiceInstance) WaitUntilHealthy(ctx context.Context) error {

coloredLabel := colorize(s.VersionedServiceSpec)
if s.Type == "task" {
err := s.Wait()
err := s.waitErrFn()
log.Printf("%s completed.\n", coloredLabel)
return err
}
Expand Down

0 comments on commit 5b1a8ac

Please sign in to comment.