diff --git a/agent.go b/agent.go index 5befb4a4..344200c2 100644 --- a/agent.go +++ b/agent.go @@ -133,9 +133,9 @@ type Agent struct { gatherCandidateCancel func() gatherCandidateDone chan struct{} - chanCandidate chan Candidate - chanCandidatePair chan *CandidatePair - chanState chan ConnectionState + connectionStateNotifier *handlerNotifier + candidateNotifier *handlerNotifier + selectedCandidatePairNotifier *handlerNotifier loggerFactory logging.LoggerFactory log logging.LeveledLogger @@ -232,9 +232,6 @@ func (a *Agent) taskLoop() { after() - close(a.chanState) - close(a.chanCandidate) - close(a.chanCandidatePair) close(a.taskLoopDone) }() @@ -282,33 +279,30 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit startedCtx, startedFn := context.WithCancel(context.Background()) a := &Agent{ - chanTask: make(chan task), - chanState: make(chan ConnectionState), - chanCandidate: make(chan Candidate), - chanCandidatePair: make(chan *CandidatePair), - tieBreaker: globalMathRandomGenerator.Uint64(), - lite: config.Lite, - gatheringState: GatheringStateNew, - connectionState: ConnectionStateNew, - localCandidates: make(map[NetworkType][]Candidate), - remoteCandidates: make(map[NetworkType][]Candidate), - urls: config.Urls, - networkTypes: config.NetworkTypes, - onConnected: make(chan struct{}), - buf: packetio.NewBuffer(), - done: make(chan struct{}), - taskLoopDone: make(chan struct{}), - startedCh: startedCtx.Done(), - startedFn: startedFn, - portMin: config.PortMin, - portMax: config.PortMax, - loggerFactory: loggerFactory, - log: log, - net: config.Net, - proxyDialer: config.ProxyDialer, - tcpMux: config.TCPMux, - udpMux: config.UDPMux, - udpMuxSrflx: config.UDPMuxSrflx, + chanTask: make(chan task), + tieBreaker: globalMathRandomGenerator.Uint64(), + lite: config.Lite, + gatheringState: GatheringStateNew, + connectionState: ConnectionStateNew, + localCandidates: make(map[NetworkType][]Candidate), + remoteCandidates: make(map[NetworkType][]Candidate), + urls: config.Urls, + networkTypes: config.NetworkTypes, + onConnected: make(chan struct{}), + buf: packetio.NewBuffer(), + done: make(chan struct{}), + taskLoopDone: make(chan struct{}), + startedCh: startedCtx.Done(), + startedFn: startedFn, + portMin: config.PortMin, + portMax: config.PortMax, + loggerFactory: loggerFactory, + log: log, + net: config.Net, + proxyDialer: config.ProxyDialer, + tcpMux: config.TCPMux, + udpMux: config.UDPMux, + udpMuxSrflx: config.UDPMuxSrflx, mDNSMode: mDNSMode, mDNSName: mDNSName, @@ -329,6 +323,9 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit userBindingRequestHandler: config.BindingRequestHandler, } + a.connectionStateNotifier = &handlerNotifier{connectionStateFunc: a.onConnectionStateChange} + a.candidateNotifier = &handlerNotifier{candidateFunc: a.onCandidate} + a.selectedCandidatePairNotifier = &handlerNotifier{candidatePairFunc: a.onSelectedCandidatePairChange} if a.net == nil { a.net, err = stdnet.NewNet() @@ -372,13 +369,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit go a.taskLoop() - // CandidatePair and ConnectionState are usually changed at once. - // Blocking one by the other one causes deadlock. - // Hence, we call handlers from independent Goroutines. - go a.candidatePairRoutine() - go a.connectionStateRoutine() - go a.candidateRoutine() - // Restart is also used to initialize the agent for the first time if err := a.Restart(config.LocalUfrag, config.LocalPwd); err != nil { a.closeMulticastConn() @@ -516,12 +506,7 @@ func (a *Agent) updateConnectionState(newState ConnectionState) { a.log.Infof("Setting new connection state: %s", newState) a.connectionState = newState - - // Call handler after finishing current task since we may be holding the agent lock - // and the handler may also require it - a.afterRun(func(ctx context.Context) { - a.chanState <- newState - }) + a.connectionStateNotifier.EnqueueConnectionState(newState) } } @@ -540,12 +525,7 @@ func (a *Agent) setSelectedPair(p *CandidatePair) { a.updateConnectionState(ConnectionStateConnected) // Notify when the selected pair changes - a.afterRun(func(ctx context.Context) { - select { - case a.chanCandidatePair <- p: - case <-ctx.Done(): - } - }) + a.selectedCandidatePairNotifier.EnqueueSelectedCandidatePair(p) // Signal connected a.onConnectedOnce.Do(func() { close(a.onConnected) }) @@ -781,7 +761,7 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) { localCandidate.start(a, conn, a.startedCh) a.localCandidates[localCandidate.NetworkType()] = append(a.localCandidates[localCandidate.NetworkType()], localCandidate) - a.chanCandidate <- localCandidate + a.candidateNotifier.EnqueueCandidate(localCandidate) a.addPair(localCandidate, remoteCandidate) } @@ -851,7 +831,7 @@ func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net a.requestConnectivityCheck() - a.chanCandidate <- c + a.candidateNotifier.EnqueueCandidate(c) }) } @@ -1287,7 +1267,7 @@ func (a *Agent) setGatheringState(newState GatheringState) error { done := make(chan struct{}) if err := a.run(a.context(), func(ctx context.Context, agent *Agent) { if a.gatheringState != newState && newState == GatheringStateComplete { - a.chanCandidate <- nil + a.candidateNotifier.EnqueueCandidate(nil) } a.gatheringState = newState diff --git a/agent_handlers.go b/agent_handlers.go index c5a5ec03..bb0c8d30 100644 --- a/agent_handlers.go +++ b/agent_handlers.go @@ -3,6 +3,8 @@ package ice +import "sync" + // OnConnectionStateChange sets a handler that is fired when the connection state changes func (a *Agent) OnConnectionStateChange(f func(ConnectionState)) error { a.onConnectionStateChangeHdlr.Store(f) @@ -41,20 +43,94 @@ func (a *Agent) onConnectionStateChange(s ConnectionState) { } } -func (a *Agent) candidatePairRoutine() { - for p := range a.chanCandidatePair { - a.onSelectedCandidatePairChange(p) +type handlerNotifier struct { + sync.Mutex + running bool + + connectionStates []ConnectionState + connectionStateFunc func(ConnectionState) + + candidates []Candidate + candidateFunc func(Candidate) + + selectedCandidatePairs []*CandidatePair + candidatePairFunc func(*CandidatePair) +} + +func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) { + h.Lock() + defer h.Unlock() + + notify := func() { + for { + h.Lock() + if len(h.connectionStates) == 0 { + h.running = false + h.Unlock() + return + } + notification := h.connectionStates[0] + h.connectionStates = h.connectionStates[1:] + h.Unlock() + h.connectionStateFunc(notification) + } + } + + h.connectionStates = append(h.connectionStates, s) + if !h.running { + h.running = true + go notify() } } -func (a *Agent) connectionStateRoutine() { - for s := range a.chanState { - go a.onConnectionStateChange(s) +func (h *handlerNotifier) EnqueueCandidate(c Candidate) { + h.Lock() + defer h.Unlock() + + notify := func() { + for { + h.Lock() + if len(h.candidates) == 0 { + h.running = false + h.Unlock() + return + } + notification := h.candidates[0] + h.candidates = h.candidates[1:] + h.Unlock() + h.candidateFunc(notification) + } + } + + h.candidates = append(h.candidates, c) + if !h.running { + h.running = true + go notify() } } -func (a *Agent) candidateRoutine() { - for c := range a.chanCandidate { - a.onCandidate(c) +func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) { + h.Lock() + defer h.Unlock() + + notify := func() { + for { + h.Lock() + if len(h.selectedCandidatePairs) == 0 { + h.running = false + h.Unlock() + return + } + notification := h.selectedCandidatePairs[0] + h.selectedCandidatePairs = h.selectedCandidatePairs[1:] + h.Unlock() + h.candidatePairFunc(notification) + } + } + + h.selectedCandidatePairs = append(h.selectedCandidatePairs, p) + if !h.running { + h.running = true + go notify() } } diff --git a/agent_handlers_test.go b/agent_handlers_test.go new file mode 100644 index 00000000..66ff8048 --- /dev/null +++ b/agent_handlers_test.go @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "testing" + "time" + + "github.com/pion/transport/v2/test" +) + +func TestConnectionStateNotifier(t *testing.T) { + t.Run("TestManyUpdates", func(t *testing.T) { + report := test.CheckRoutines(t) + defer report() + updates := make(chan struct{}, 1) + c := &handlerNotifier{ + connectionStateFunc: func(_ ConnectionState) { + updates <- struct{}{} + }, + } + // Enqueue all updates upfront to ensure that it + // doesn't block + for i := 0; i < 10000; i++ { + c.EnqueueConnectionState(ConnectionStateNew) + } + done := make(chan struct{}) + go func() { + for i := 0; i < 10000; i++ { + <-updates + } + select { + case <-updates: + t.Errorf("received more updates than expected") + case <-time.After(1 * time.Second): + } + close(done) + }() + <-done + }) + t.Run("TestUpdateOrdering", func(t *testing.T) { + report := test.CheckRoutines(t) + defer report() + updates := make(chan ConnectionState) + c := &handlerNotifier{ + connectionStateFunc: func(cs ConnectionState) { + updates <- cs + }, + } + done := make(chan struct{}) + go func() { + for i := 0; i < 10000; i++ { + x := <-updates + if x != ConnectionState(i) { + t.Errorf("expected %d got %d", x, i) + } + } + select { + case <-updates: + t.Errorf("received more updates than expected") + case <-time.After(1 * time.Second): + } + close(done) + }() + for i := 0; i < 10000; i++ { + c.EnqueueConnectionState(ConnectionState(i)) + } + <-done + }) +} diff --git a/selection_test.go b/selection_test.go index 260ea4e9..5caea144 100644 --- a/selection_test.go +++ b/selection_test.go @@ -17,7 +17,7 @@ import ( "time" "github.com/pion/stun" - "github.com/pion/transport/v3/test" + "github.com/pion/transport/v2/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" )