Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Callback ordering for ConnectionState #702

Merged
merged 2 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 35 additions & 55 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
"sync/atomic"
"time"

atomicx "github.com/pion/ice/v2/internal/atomic"

Check failure on line 19 in agent.go

View workflow job for this annotation

GitHub Actions / lint / Go

import 'github.com/pion/ice/v2/internal/atomic' is not allowed from list 'Main' (depguard)
stunx "github.com/pion/ice/v2/internal/stun"

Check failure on line 20 in agent.go

View workflow job for this annotation

GitHub Actions / lint / Go

import 'github.com/pion/ice/v2/internal/stun' is not allowed from list 'Main' (depguard)
"github.com/pion/logging"

Check failure on line 21 in agent.go

View workflow job for this annotation

GitHub Actions / lint / Go

import 'github.com/pion/logging' is not allowed from list 'Main' (depguard)
"github.com/pion/mdns"

Check failure on line 22 in agent.go

View workflow job for this annotation

GitHub Actions / lint / Go

import 'github.com/pion/mdns' is not allowed from list 'Main' (depguard)
"github.com/pion/stun"
"github.com/pion/transport/v2"
"github.com/pion/transport/v2/packetio"
Expand Down Expand Up @@ -133,9 +133,9 @@
gatherCandidateCancel func()
gatherCandidateDone chan struct{}

chanCandidate chan Candidate
chanCandidatePair chan *CandidatePair
chanState chan ConnectionState
connectionStateNotifier *handlerNotifier
candidateNotifier *handlerNotifier
selectedCandidatePairNotifier *handlerNotifier

loggerFactory logging.LoggerFactory
log logging.LeveledLogger
Expand Down Expand Up @@ -232,9 +232,6 @@

after()

close(a.chanState)
close(a.chanCandidate)
close(a.chanCandidatePair)
close(a.taskLoopDone)
}()

Expand Down Expand Up @@ -282,33 +279,30 @@
startedCtx, startedFn := context.WithCancel(context.Background())

a := &Agent{
chanTask: make(chan task),
chanState: make(chan ConnectionState),
chanCandidate: make(chan Candidate),
chanCandidatePair: make(chan *CandidatePair),
tieBreaker: globalMathRandomGenerator.Uint64(),
lite: config.Lite,
gatheringState: GatheringStateNew,
connectionState: ConnectionStateNew,
localCandidates: make(map[NetworkType][]Candidate),
remoteCandidates: make(map[NetworkType][]Candidate),
urls: config.Urls,
networkTypes: config.NetworkTypes,
onConnected: make(chan struct{}),
buf: packetio.NewBuffer(),
done: make(chan struct{}),
taskLoopDone: make(chan struct{}),
startedCh: startedCtx.Done(),
startedFn: startedFn,
portMin: config.PortMin,
portMax: config.PortMax,
loggerFactory: loggerFactory,
log: log,
net: config.Net,
proxyDialer: config.ProxyDialer,
tcpMux: config.TCPMux,
udpMux: config.UDPMux,
udpMuxSrflx: config.UDPMuxSrflx,
chanTask: make(chan task),
tieBreaker: globalMathRandomGenerator.Uint64(),
lite: config.Lite,
gatheringState: GatheringStateNew,
connectionState: ConnectionStateNew,
localCandidates: make(map[NetworkType][]Candidate),
remoteCandidates: make(map[NetworkType][]Candidate),
urls: config.Urls,
networkTypes: config.NetworkTypes,
onConnected: make(chan struct{}),
buf: packetio.NewBuffer(),
done: make(chan struct{}),
taskLoopDone: make(chan struct{}),
startedCh: startedCtx.Done(),
startedFn: startedFn,
portMin: config.PortMin,
portMax: config.PortMax,
loggerFactory: loggerFactory,
log: log,
net: config.Net,
proxyDialer: config.ProxyDialer,
tcpMux: config.TCPMux,
udpMux: config.UDPMux,
udpMuxSrflx: config.UDPMuxSrflx,

mDNSMode: mDNSMode,
mDNSName: mDNSName,
Expand All @@ -329,6 +323,9 @@

userBindingRequestHandler: config.BindingRequestHandler,
}
a.connectionStateNotifier = &handlerNotifier{connectionStateFunc: a.onConnectionStateChange}
a.candidateNotifier = &handlerNotifier{candidateFunc: a.onCandidate}
a.selectedCandidatePairNotifier = &handlerNotifier{candidatePairFunc: a.onSelectedCandidatePairChange}

if a.net == nil {
a.net, err = stdnet.NewNet()
Expand Down Expand Up @@ -372,13 +369,6 @@

go a.taskLoop()

// CandidatePair and ConnectionState are usually changed at once.
// Blocking one by the other one causes deadlock.
// Hence, we call handlers from independent Goroutines.
go a.candidatePairRoutine()
go a.connectionStateRoutine()
go a.candidateRoutine()

// Restart is also used to initialize the agent for the first time
if err := a.Restart(config.LocalUfrag, config.LocalPwd); err != nil {
a.closeMulticastConn()
Expand Down Expand Up @@ -516,12 +506,7 @@

a.log.Infof("Setting new connection state: %s", newState)
a.connectionState = newState

// Call handler after finishing current task since we may be holding the agent lock
// and the handler may also require it
a.afterRun(func(ctx context.Context) {
a.chanState <- newState
})
a.connectionStateNotifier.EnqueueConnectionState(newState)
}
}

Expand All @@ -540,12 +525,7 @@
a.updateConnectionState(ConnectionStateConnected)

// Notify when the selected pair changes
a.afterRun(func(ctx context.Context) {
select {
case a.chanCandidatePair <- p:
case <-ctx.Done():
}
})
a.selectedCandidatePairNotifier.EnqueueSelectedCandidatePair(p)

// Signal connected
a.onConnectedOnce.Do(func() { close(a.onConnected) })
Expand Down Expand Up @@ -781,7 +761,7 @@

localCandidate.start(a, conn, a.startedCh)
a.localCandidates[localCandidate.NetworkType()] = append(a.localCandidates[localCandidate.NetworkType()], localCandidate)
a.chanCandidate <- localCandidate
a.candidateNotifier.EnqueueCandidate(localCandidate)

a.addPair(localCandidate, remoteCandidate)
}
Expand Down Expand Up @@ -851,7 +831,7 @@

a.requestConnectivityCheck()

a.chanCandidate <- c
a.candidateNotifier.EnqueueCandidate(c)
})
}

Expand Down Expand Up @@ -1287,7 +1267,7 @@
done := make(chan struct{})
if err := a.run(a.context(), func(ctx context.Context, agent *Agent) {
if a.gatheringState != newState && newState == GatheringStateComplete {
a.chanCandidate <- nil
a.candidateNotifier.EnqueueCandidate(nil)
}

a.gatheringState = newState
Expand Down
94 changes: 85 additions & 9 deletions agent_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

package ice

import "sync"

// OnConnectionStateChange sets a handler that is fired when the connection state changes
func (a *Agent) OnConnectionStateChange(f func(ConnectionState)) error {
a.onConnectionStateChangeHdlr.Store(f)
Expand Down Expand Up @@ -41,20 +43,94 @@ func (a *Agent) onConnectionStateChange(s ConnectionState) {
}
}

func (a *Agent) candidatePairRoutine() {
for p := range a.chanCandidatePair {
a.onSelectedCandidatePairChange(p)
type handlerNotifier struct {
sync.Mutex
running bool

connectionStates []ConnectionState
connectionStateFunc func(ConnectionState)

candidates []Candidate
candidateFunc func(Candidate)

selectedCandidatePairs []*CandidatePair
candidatePairFunc func(*CandidatePair)
}

func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
h.Lock()
defer h.Unlock()

notify := func() {
for {
h.Lock()
if len(h.connectionStates) == 0 {
h.running = false
h.Unlock()
return
}
notification := h.connectionStates[0]
h.connectionStates = h.connectionStates[1:]
h.Unlock()
h.connectionStateFunc(notification)
}
}

h.connectionStates = append(h.connectionStates, s)
if !h.running {
h.running = true
go notify()
}
}

func (a *Agent) connectionStateRoutine() {
for s := range a.chanState {
go a.onConnectionStateChange(s)
func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
h.Lock()
defer h.Unlock()

notify := func() {
for {
h.Lock()
if len(h.candidates) == 0 {
h.running = false
h.Unlock()
return
}
notification := h.candidates[0]
h.candidates = h.candidates[1:]
h.Unlock()
h.candidateFunc(notification)
}
}

h.candidates = append(h.candidates, c)
if !h.running {
h.running = true
go notify()
}
}

func (a *Agent) candidateRoutine() {
for c := range a.chanCandidate {
a.onCandidate(c)
func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) {
h.Lock()
defer h.Unlock()

notify := func() {
for {
h.Lock()
if len(h.selectedCandidatePairs) == 0 {
h.running = false
h.Unlock()
return
}
notification := h.selectedCandidatePairs[0]
h.selectedCandidatePairs = h.selectedCandidatePairs[1:]
h.Unlock()
h.candidatePairFunc(notification)
}
}

h.selectedCandidatePairs = append(h.selectedCandidatePairs, p)
if !h.running {
h.running = true
go notify()
}
}
71 changes: 71 additions & 0 deletions agent_handlers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package ice

import (
"testing"
"time"

"github.com/pion/transport/v2/test"
)

func TestConnectionStateNotifier(t *testing.T) {
t.Run("TestManyUpdates", func(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
updates := make(chan struct{}, 1)
c := &handlerNotifier{
connectionStateFunc: func(_ ConnectionState) {
updates <- struct{}{}
},
}
// Enqueue all updates upfront to ensure that it
// doesn't block
for i := 0; i < 10000; i++ {
c.EnqueueConnectionState(ConnectionStateNew)
}
done := make(chan struct{})
go func() {
for i := 0; i < 10000; i++ {
<-updates
}
select {
case <-updates:
t.Errorf("received more updates than expected")
case <-time.After(1 * time.Second):
}
close(done)
}()
<-done
})
t.Run("TestUpdateOrdering", func(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
updates := make(chan ConnectionState)
c := &handlerNotifier{
connectionStateFunc: func(cs ConnectionState) {
updates <- cs
},
}
done := make(chan struct{})
go func() {
for i := 0; i < 10000; i++ {
x := <-updates
if x != ConnectionState(i) {
t.Errorf("expected %d got %d", x, i)
}
}
select {
case <-updates:
t.Errorf("received more updates than expected")
case <-time.After(1 * time.Second):
}
close(done)
}()
for i := 0; i < 10000; i++ {
c.EnqueueConnectionState(ConnectionState(i))
}
<-done
})
}
2 changes: 1 addition & 1 deletion selection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"time"

"github.com/pion/stun"
"github.com/pion/transport/v3/test"
"github.com/pion/transport/v2/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down
Loading