Fix nodes not getting set to offline when disconnects #2131

Draft · wants to merge 2 commits into main
1 change: 1 addition & 0 deletions .github/workflows/test-integration.yaml
@@ -58,6 +58,7 @@ jobs:
- TestEnableDisableAutoApprovedRoute
- TestAutoApprovedSubRoute2068
- TestSubnetRouteACL
- TestHASubnetRouterFailoverWhenNodeDisconnects2129
- TestHeadscale
- TestCreateTailscale
- TestTailscaleNodesJoiningHeadcale
37 changes: 8 additions & 29 deletions hscontrol/poll.go
@@ -191,6 +191,7 @@ func (m *mapSession) serve() {
//
//nolint:gocyclo
func (m *mapSession) serveLongPoll() {
start := time.Now()
m.beforeServeLongPoll()

// Clean up the session when the client disconnects
@@ -235,16 +236,6 @@ func (m *mapSession) serveLongPoll() {

m.pollFailoverRoutes("node connected", m.node)

// Upgrade the writer to a ResponseController
rc := http.NewResponseController(m.w)

// Longpolling will break if there is a write timeout,
// so it needs to be disabled.
rc.SetWriteDeadline(time.Time{})

ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
defer cancel()

m.keepAliveTicker = time.NewTicker(m.keepAlive)

m.h.nodeNotifier.AddNode(m.node.ID, m.ch)
@@ -258,12 +249,12 @@ func (m *mapSession) serveLongPoll() {
// consume channels with update, keep alives or "batch" blocking signals
select {
case <-m.cancelCh:
m.tracef("poll cancelled received")
m.tracef("poll cancelled received (%s)", time.Since(start).String())
mapResponseEnded.WithLabelValues("cancelled").Inc()
return

case <-ctx.Done():
m.tracef("poll context done")
case <-m.ctx.Done():
m.tracef("poll context done (%s): %s", time.Since(start).String(), m.ctx.Err().Error())
mapResponseEnded.WithLabelValues("done").Inc()
return

@@ -354,14 +345,7 @@ func (m *mapSession) serveLongPoll() {
m.errf(err, "could not write the map response(%s), for mapSession: %p", update.Type.String(), m)
return
}

err = rc.Flush()
if err != nil {
mapResponseSent.WithLabelValues("error", updateType).Inc()
m.errf(err, "flushing the map response to client, for mapSession: %p", m)
return
}

m.w.(http.Flusher).Flush()
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")

if debugHighCardinalityMetrics {
@@ -375,22 +359,17 @@ func (m *mapSession) serveLongPoll() {
case <-m.keepAliveTicker.C:
data, err := m.mapper.KeepAliveResponse(m.req, m.node)
if err != nil {
m.errf(err, "Error generating the keep alive msg")
m.errf(err, "Error generating the keepalive msg")
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
return
}
_, err = m.w.Write(data)
if err != nil {
m.errf(err, "Cannot write keep alive message")
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
return
}
err = rc.Flush()
if err != nil {
m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
m.errf(err, "Cannot write keepalive message")
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
return
}
m.w.(http.Flusher).Flush()

if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
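The poll.go change above drops the http.ResponseController indirection and instead flushes through a plain http.Flusher type assertion after each write. Below is a minimal sketch of that write-then-flush long-poll pattern as a hypothetical standalone handler (not code from this PR). Headscale asserts the interface directly, which panics if the writer does not implement it; the sketch uses the checked form:

package main

import (
    "fmt"
    "net/http"
    "time"
)

// longPoll streams periodic keepalives until the client goes away,
// flushing after every write so nothing sits in the response buffer.
func longPoll(w http.ResponseWriter, r *http.Request) {
    flusher, ok := w.(http.Flusher)
    if !ok {
        // Some ResponseWriter wrappers do not expose http.Flusher.
        http.Error(w, "streaming unsupported", http.StatusInternalServerError)
        return
    }

    ticker := time.NewTicker(10 * time.Second)
    defer ticker.Stop()

    for {
        select {
        case <-r.Context().Done():
            // Client disconnected; this is the point where a server like
            // headscale would mark the node offline.
            return
        case <-ticker.C:
            if _, err := fmt.Fprintf(w, "keepalive %d\n", time.Now().Unix()); err != nil {
                return
            }
            flusher.Flush()
        }
    }
}

func main() {
    http.HandleFunc("/poll", longPoll)
    _ = http.ListenAndServe(":8080", nil)
}
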
250 changes: 250 additions & 0 deletions integration/route_test.go
@@ -13,6 +13,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
@@ -1314,3 +1315,252 @@ func TestSubnetRouteACL(t *testing.T) {
t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)
}
}

func TestHASubnetRouterFailoverWhenNodeDisconnects2129(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

user := "enable-routing"

scenario, err := NewScenario(dockertestMaxWait())
assertNoErrf(t, "failed to create scenario: %s", err)
// defer scenario.ShutdownAssertNoPanics(t)

spec := map[string]int{
user: 3,
}

err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{},
hsic.WithTestName("clientdisc"),
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
hsic.WithHostnameAsServerURL(),
hsic.WithIPAllocationStrategy(types.IPAllocationStrategyRandom),
)
assertNoErrHeadscaleEnv(t, err)

allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)

err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)

headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)

expectedRoutes := map[string]string{
"1": "10.0.0.0/24",
"2": "10.0.0.0/24",
}

// Sort nodes by ID
sort.SliceStable(allClients, func(i, j int) bool {
statusI, err := allClients[i].Status()
if err != nil {
return false
}

statusJ, err := allClients[j].Status()
if err != nil {
return false
}

return statusI.Self.ID < statusJ.Self.ID
})

subRouter1 := allClients[0]
subRouter2 := allClients[1]

t.Logf("Advertise route from r1 (%s) and r2 (%s), making it HA, n1 is primary", subRouter1.Hostname(), subRouter2.Hostname())
// advertise HA route on node 1 and 2
// ID 1 will be primary
// ID 2 will be secondary
for _, client := range allClients[:2] {
status, err := client.Status()
assertNoErr(t, err)

if route, ok := expectedRoutes[string(status.Self.ID)]; ok {
command := []string{
"tailscale",
"set",
"--advertise-routes=" + route,
}
_, _, err = client.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
} else {
t.Fatalf("failed to find route for Node %s (id: %s)", status.Self.HostName, status.Self.ID)
}
}

err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)

var routes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routes,
)

assertNoErr(t, err)
assert.Len(t, routes, 2)

t.Logf("initial routes %#v", routes)

for _, route := range routes {
assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, false, route.GetEnabled())
assert.Equal(t, false, route.GetIsPrimary())
}

// Verify that no routes have been sent to the clients,
// as they are not yet enabled.
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)

for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]

assert.Nil(t, peerStatus.PrimaryRoutes)
}
}

// Enable all routes
for _, route := range routes {
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"enable",
"--route",
strconv.Itoa(int(route.GetId())),
})
assertNoErr(t, err)

time.Sleep(time.Second)
}

var enablingRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&enablingRoutes,
)
assertNoErr(t, err)
assert.Len(t, enablingRoutes, 2)

// Node 1 is primary
assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
assert.Equal(t, true, enablingRoutes[0].GetEnabled())
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary(), "both subnet routers are up, expected r1 to be primary")

// Node 2 is not primary
assert.Equal(t, true, enablingRoutes[1].GetAdvertised())
assert.Equal(t, true, enablingRoutes[1].GetEnabled())
assert.Equal(t, false, enablingRoutes[1].GetIsPrimary(), "both subnet routers are up, expected r2 to be non-primary")

var nodeList []v1.Node
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--output",
"json",
},
&nodeList,
)
assert.Nil(t, err)
assert.Len(t, nodeList, 3)
assert.True(t, nodeList[0].Online)
assert.True(t, nodeList[1].Online)
assert.True(t, nodeList[2].Online)

// Kill off one of the docker containers to simulate a disconnect
err = scenario.DisconnectContainersFromScenario(subRouter1.Hostname())
assertNoErr(t, err)

time.Sleep(5 * time.Second)

var nodeListAfterDisconnect []v1.Node
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--output",
"json",
},
&nodeListAfterDisconnect,
)
assert.Nil(t, err)
assert.Len(t, nodeListAfterDisconnect, 3)
assert.False(t, nodeListAfterDisconnect[0].Online)
assert.True(t, nodeListAfterDisconnect[1].Online)
assert.True(t, nodeListAfterDisconnect[2].Online)

var routesAfterDisconnect []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterDisconnect,
)
assertNoErr(t, err)
assert.Len(t, routesAfterDisconnect, 2)

// Node 1 is no longer primary
assert.Equal(t, true, routesAfterDisconnect[0].GetAdvertised())
assert.Equal(t, true, routesAfterDisconnect[0].GetEnabled())
assert.Equal(t, false, routesAfterDisconnect[0].GetIsPrimary(), "r1 is disconnected, expected r1 to be non-primary")

// Node 2 is now primary
assert.Equal(t, true, routesAfterDisconnect[1].GetAdvertised())
assert.Equal(t, true, routesAfterDisconnect[1].GetEnabled())
assert.Equal(t, true, routesAfterDisconnect[1].GetIsPrimary(), "r1 is disconnected, expected r2 to be primary")

// // Ensure the node can reconnect as expected
// err = scenario.ConnectContainersToScenario(subRouter1.Hostname())
// assertNoErr(t, err)

// time.Sleep(5 * time.Second)

// var nodeListAfterReconnect []v1.Node
// err = executeAndUnmarshal(
// headscale,
// []string{
// "headscale",
// "nodes",
// "list",
// "--output",
// "json",
// },
// &nodeListAfterReconnect,
// )
// assert.Nil(t, err)
// assert.Len(t, nodeListAfterReconnect, 3)
// assert.True(t, nodeListAfterReconnect[0].Online)
// assert.True(t, nodeListAfterReconnect[1].Online)
// assert.True(t, nodeListAfterReconnect[2].Online)
}
28 changes: 28 additions & 0 deletions integration/scenario.go
@@ -649,3 +649,31 @@ func (s *Scenario) WaitForTailscaleLogout() error {

return nil
}

// DisconnectContainersFromScenario disconnects a list of containers from the network.
func (s *Scenario) DisconnectContainersFromScenario(containers ...string) error {
for _, container := range containers {
if ctr, ok := s.pool.ContainerByName(container); ok {
err := ctr.DisconnectFromNetwork(s.network)
if err != nil {
return err
}
}
}

return nil
}

// ConnectContainersToScenario connects a list of containers to the network.
func (s *Scenario) ConnectContainersToScenario(containers ...string) error {
for _, container := range containers {
if ctr, ok := s.pool.ContainerByName(container); ok {
err := ctr.ConnectToNetwork(s.network)
if err != nil {
return err
}
}
}

return nil
}
Comment on lines +667 to +679
Contributor

Consider extracting the common logic into a separate function to make the code more DRY.

The ConnectContainersToScenario and DisconnectContainersFromScenario functions share the same structure and logic, with the only difference being the specific action performed on the container (connect or disconnect). To make the code more DRY (Don't Repeat Yourself), consider extracting the common logic into a separate function that takes a function parameter for the specific action.

Here's an example of how the refactored code could look:

func (s *Scenario) performActionOnContainers(containers []string, action func(*dockertest.Resource) error) error {
    for _, container := range containers {
        if ctr, ok := s.pool.ContainerByName(container); ok {
            err := action(ctr)
            if err != nil {
                return err
            }
        } else {
            log.Printf("Warning: Container %s not found in the pool", container)
        }
    }
    return nil
}

func (s *Scenario) DisconnectContainersFromScenario(containers ...string) error {
    return s.performActionOnContainers(containers, func(ctr *dockertest.Resource) error {
        return ctr.DisconnectFromNetwork(s.network)
    })
}

func (s *Scenario) ConnectContainersToScenario(containers ...string) error {
    return s.performActionOnContainers(containers, func(ctr *dockertest.Resource) error {
        return ctr.ConnectToNetwork(s.network)
    })
}
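
Passing the per-container action as a closure keeps the name lookup and error propagation in one place, so each public helper stays a thin wrapper. As a hypothetical extension under that assumption (the ExpireContainersInScenario name is invented for illustration; Resource.Expire is an existing dockertest API that force-removes a container after a timeout):

// Hypothetical wrapper reusing the same traversal: expire matching
// containers after a timeout instead of disconnecting them.
func (s *Scenario) ExpireContainersInScenario(seconds uint, containers ...string) error {
    return s.performActionOnContainers(containers, func(ctr *dockertest.Resource) error {
        return ctr.Expire(seconds)
    })
}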
