From e65c8135094100d7a3272bf286651ab285710fcd Mon Sep 17 00:00:00 2001 From: David Zbarsky Date: Fri, 17 May 2024 18:10:33 -0400 Subject: [PATCH] A slew of fixes to make hot-reloading truly work --- cmd/svcinit/main.go | 160 ++++++++++++++++++++++--------------- itest.bzl | 11 ++- private/itest.bzl | 4 +- runner/BUILD.bazel | 2 + runner/runner.go | 23 +++++- runner/runner_unix.go | 12 +++ runner/runner_windows.go | 10 +++ runner/service_instance.go | 3 +- 8 files changed, 148 insertions(+), 77 deletions(-) create mode 100644 runner/runner_unix.go create mode 100644 runner/runner_windows.go diff --git a/cmd/svcinit/main.go b/cmd/svcinit/main.go index 62b121d..127145b 100644 --- a/cmd/svcinit/main.go +++ b/cmd/svcinit/main.go @@ -84,7 +84,13 @@ func main() { isOneShot := !shouldHotReload && testLabel != "" - serviceSpecs, err := readVersionedServiceSpecs(serviceSpecsPath) + unversionedSpecs, err := readServiceSpecs(serviceSpecsPath) + must(err) + + ports, err := assignPorts(unversionedSpecs) + must(err) + + serviceSpecs, err := augmentServiceSpecs(unversionedSpecs, ports) must(err) /*if *allowSvcctl { @@ -226,7 +232,10 @@ func main() { log.Println(ibazelCmd) // Restart any services as needed. - serviceSpecs, err := readVersionedServiceSpecs(serviceSpecsPath) + unversionedSpecs, err := readServiceSpecs(serviceSpecsPath) + must(err) + + serviceSpecs, err := augmentServiceSpecs(unversionedSpecs, ports) must(err) criticalPath, err = r.UpdateSpecsAndRestart(serviceSpecs, []byte(ibazelCmd)) @@ -235,66 +244,30 @@ func main() { } } -func readVersionedServiceSpecs( +func readServiceSpecs( path string, ) ( - map[string]svclib.VersionedServiceSpec, error, + map[string]svclib.ServiceSpec, error, ) { data, err := os.ReadFile(path) must(err) var serviceSpecs map[string]svclib.ServiceSpec err = json.Unmarshal(data, &serviceSpecs) - must(err) - - tmpDir := os.Getenv("TMPDIR") - socketDir := os.Getenv("SOCKET_DIR") + return serviceSpecs, err +} +func assignPorts( + serviceSpecs map[string]svclib.ServiceSpec, +) ( + svclib.Ports, error, +) { var toClose []net.Listener - ports := svclib.Ports{} - versionedServiceSpecs := make(map[string]svclib.VersionedServiceSpec, len(serviceSpecs)) - for label, serviceSpec := range serviceSpecs { - s := svclib.VersionedServiceSpec{ - ServiceSpec: serviceSpec, - } - - if s.Type == "group" { - versionedServiceSpecs[label] = s - continue - } - - exePath, err := runfiles.Rlocation(s.Exe) - if err != nil { - return nil, err - } - s.Exe = exePath - if s.HealthCheck != "" { - healthCheckPath, err := runfiles.Rlocation(serviceSpec.HealthCheck) - if err != nil { - return nil, err - } - s.HealthCheck = healthCheckPath - } - - if serviceSpec.VersionFile != "" { - versionFilePath, err := runfiles.Rlocation(serviceSpec.VersionFile) - if err != nil { - return nil, err - } - - version, err := os.ReadFile(versionFilePath) - if err != nil { - return nil, err - } - s.Version = string(version) - } - - s.Color = logger.Colorize(s.Label) - - namedPorts := slices.Clone(s.NamedPorts) - if s.AutoassignPort { + for label, spec := range serviceSpecs { + namedPorts := slices.Clone(spec.NamedPorts) + if spec.AutoassignPort { namedPorts = append(namedPorts, "") } @@ -327,7 +300,7 @@ func readVersionedServiceSpecs( return nil, err } - qualifiedPortName := s.Label + qualifiedPortName := label if portName != "" { qualifiedPortName += ":" + portName } @@ -335,21 +308,7 @@ func readVersionedServiceSpecs( fmt.Printf("Assigning port %s to %s\n", port, qualifiedPortName) ports.Set(qualifiedPortName, port) toClose = append(toClose, listener) - - if portName == "" { - for i := range s.ServiceSpec.Args { - s.Args[i] = strings.ReplaceAll(s.Args[i], "$${PORT}", port) - } - s.HttpHealthCheckAddress = strings.ReplaceAll(s.HttpHealthCheckAddress, "$${PORT}", port) - } } - - for i := range s.Args { - s.Args[i] = strings.ReplaceAll(s.Args[i], "$${TMPDIR}", tmpDir) - s.Args[i] = strings.ReplaceAll(s.Args[i], "$${SOCKET_DIR}", socketDir) - } - - versionedServiceSpecs[label] = s } for _, listener := range toClose { @@ -364,8 +323,77 @@ func readVersionedServiceSpecs( time.Sleep(10 * time.Millisecond) serializedPorts, err := ports.Marshal() - must(err) + if err != nil { + return nil, err + } os.Setenv("ASSIGNED_PORTS", string(serializedPorts)) + return ports, nil +} + +func augmentServiceSpecs( + serviceSpecs map[string]svclib.ServiceSpec, + ports svclib.Ports, +) ( + map[string]svclib.VersionedServiceSpec, error, +) { + tmpDir := os.Getenv("TMPDIR") + socketDir := os.Getenv("SOCKET_DIR") + + versionedServiceSpecs := make(map[string]svclib.VersionedServiceSpec, len(serviceSpecs)) + for label, serviceSpec := range serviceSpecs { + s := svclib.VersionedServiceSpec{ + ServiceSpec: serviceSpec, + } + + if s.Type == "group" { + versionedServiceSpecs[label] = s + continue + } + + exePath, err := runfiles.Rlocation(s.Exe) + if err != nil { + return nil, err + } + s.Exe = exePath + + if s.HealthCheck != "" { + healthCheckPath, err := runfiles.Rlocation(serviceSpec.HealthCheck) + if err != nil { + return nil, err + } + s.HealthCheck = healthCheckPath + } + + if serviceSpec.VersionFile != "" { + versionFilePath, err := runfiles.Rlocation(serviceSpec.VersionFile) + if err != nil { + return nil, err + } + + version, err := os.ReadFile(versionFilePath) + if err != nil { + return nil, err + } + s.Version = string(version) + } + + s.Color = logger.Colorize(s.Label) + + if s.AutoassignPort { + port := ports[s.Label] + for i := range s.ServiceSpec.Args { + s.Args[i] = strings.ReplaceAll(s.Args[i], "$${PORT}", port) + } + s.HttpHealthCheckAddress = strings.ReplaceAll(s.HttpHealthCheckAddress, "$${PORT}", port) + } + + for i := range s.Args { + s.Args[i] = strings.ReplaceAll(s.Args[i], "$${TMPDIR}", tmpDir) + s.Args[i] = strings.ReplaceAll(s.Args[i], "$${SOCKET_DIR}", socketDir) + } + + versionedServiceSpecs[label] = s + } replacements := make([]Replacement, 0, len(ports)) for label, port := range ports { diff --git a/itest.bzl b/itest.bzl index 11ee562..e4dd71a 100644 --- a/itest.bzl +++ b/itest.bzl @@ -11,7 +11,7 @@ load( def itest_service(name, tags = [], **kwargs): _itest_service( name = name, - tags = tags, + tags = tags + ["ibazel_notify_changes"], **kwargs ) _hygiene_test( @@ -22,7 +22,7 @@ def itest_service(name, tags = [], **kwargs): def itest_service_group(name, tags = [], **kwargs): _itest_service_group( name = name, - tags = tags, + tags = tags + ["ibazel_notify_changes"], **kwargs ) _hygiene_test( @@ -33,7 +33,7 @@ def itest_service_group(name, tags = [], **kwargs): def itest_task(name, tags = [], **kwargs): _itest_task( name = name, - tags = tags, + tags = tags + ["ibazel_notify_changes"], **kwargs ) _hygiene_test( @@ -41,9 +41,6 @@ def itest_task(name, tags = [], **kwargs): tags = tags, ) -def service_test(tags = [], **kwargs): - _service_test(tags = tags + ["ibazel_notify_changes"], **kwargs) - def _hygiene_test(name, **kwargs): service_test( name = name + "_hygiene_test", @@ -51,3 +48,5 @@ def _hygiene_test(name, **kwargs): test = "@rules_itest//:exit0", **kwargs ) + +service_test = _service_test diff --git a/private/itest.bzl b/private/itest.bzl index 167d1eb..4484323 100644 --- a/private/itest.bzl +++ b/private/itest.bzl @@ -154,7 +154,7 @@ def _itest_service_impl(ctx): _itest_service_attrs = _itest_binary_attrs | { # Note, autoassigning a port is a little racy. If you can stick to hardcoded ports and network namespace, you should prefer that. "autoassign_port": attr.bool( - doc = """If true, the service manager will pick a free port and assign it to the service. + doc = """If true, the service manager will pick a free port and assign it to the service. The port will be interpolated into `$${PORT}` in the service's `http_health_check_address` and `args`. It will also be exported under the target's fully qualified label in the service-port mapping. @@ -351,7 +351,7 @@ def _create_version_file(ctx, inputs): mnemonic = "SvcVersionFile", # disable remote cache and sandbox, since the output is not stable given the inputs # additionally, running this action in the sandbox is way too expensive - execution_requirements = {"local": "1"}, + execution_requirements = {"local": "1", "no-cache": "1"}, ) return output diff --git a/runner/BUILD.bazel b/runner/BUILD.bazel index 452003e..3b13df4 100644 --- a/runner/BUILD.bazel +++ b/runner/BUILD.bazel @@ -4,6 +4,8 @@ go_library( name = "runner", srcs = [ "runner.go", + "runner_unix.go", + "runner_windows.go", "service_instance.go", "topo.go", ], diff --git a/runner/runner.go b/runner/runner.go index 57b8176..2dcd979 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -7,7 +7,9 @@ import ( "os" "os/exec" "reflect" + "runtime" "sync" + "syscall" "time" "rules_itest/logger" @@ -15,6 +17,14 @@ import ( "rules_itest/svclib" ) +// We need to use process groups to reliably tear down services and their descendants. +// This is especially important in hot-reload mode, where you need to restart the child +// and have it bind the same port. +// However, we don't want to do this in tests, because Bazel will already terminate the +// test process (svcinit) and all its children. +// If we were to start new process groups in tests, we could leak children (at least on Mac). +var shouldUseProcessGroups = runtime.GOOS != "windows" && os.Getenv("BAZEL_TEST") != "1" + type ServiceSpecs = map[string]svclib.VersionedServiceSpec type runner struct { @@ -108,7 +118,7 @@ func computeUpdateActions(currentServices, newServices ServiceSpecs) updateActio // We technically don't need a restart if the change is the list of deps. // But that should not be a common use case, so it's not worth the complexity. if !reflect.DeepEqual(service, newService) { - fmt.Println(label + " definition or code has changed, restarting...") + log.Printf(colorize(service) + " definition or code has changed, restarting...") if service.HotReloadable && reflect.DeepEqual(service.ServiceSpec, newService.ServiceSpec) { // The only difference is the Version. Trust the service that // it prefers to receive the ibazel reload command. @@ -194,6 +204,10 @@ func prepareServiceInstance(ctx context.Context, s svclib.VersionedServiceSpec) cmd.Stdout = logger.New(s.Label+"> ", s.Color, os.Stdout) cmd.Stderr = logger.New(s.Label+"> ", s.Color, os.Stderr) + if shouldUseProcessGroups { + setPgid(cmd) + } + // Even if a child process exits, Wait will block until the I/O pipes are closed. // They may have been forwarded to an orphaned child, so we disable that behavior to unblock exit. if s.Type == "service" { @@ -205,6 +219,7 @@ func prepareServiceInstance(ctx context.Context, s svclib.VersionedServiceSpec) Cmd: cmd, startErrFn: sync.OnceValue(cmd.Start), + waitErrFn: sync.OnceValue(cmd.Wait), } if s.HotReloadable { @@ -218,7 +233,11 @@ func prepareServiceInstance(ctx context.Context, s svclib.VersionedServiceSpec) } func stopInstance(serviceInstance *ServiceInstance) { - serviceInstance.Cmd.Process.Kill() + pid := serviceInstance.Cmd.Process.Pid + if shouldUseProcessGroups { + pid = -pid + } + syscall.Kill(pid, syscall.SIGKILL) serviceInstance.Cmd.Wait() for serviceInstance.Cmd.ProcessState == nil { diff --git a/runner/runner_unix.go b/runner/runner_unix.go new file mode 100644 index 0000000..189d544 --- /dev/null +++ b/runner/runner_unix.go @@ -0,0 +1,12 @@ +//go:build unix + +package runner + +import ( + "os/exec" + "syscall" +) + +func setPgid(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} +} diff --git a/runner/runner_windows.go b/runner/runner_windows.go new file mode 100644 index 0000000..8388e8c --- /dev/null +++ b/runner/runner_windows.go @@ -0,0 +1,10 @@ +//go:build windows + +package runner + +import "os/exec" + +func setPgid(cmd *exec.Cmd) { + panic("Pgid not implemented on windows!") +} + diff --git a/runner/service_instance.go b/runner/service_instance.go index 561c0c0..4223d77 100644 --- a/runner/service_instance.go +++ b/runner/service_instance.go @@ -28,6 +28,7 @@ type ServiceInstance struct { startDuration time.Duration startErrFn func() error + waitErrFn func() error mu sync.Mutex runErr error @@ -51,7 +52,7 @@ func (s *ServiceInstance) WaitUntilHealthy(ctx context.Context) error { coloredLabel := colorize(s.VersionedServiceSpec) if s.Type == "task" { - err := s.Wait() + err := s.waitErrFn() log.Printf("%s completed.\n", coloredLabel) return err }