diff --git a/agent.go b/agent.go index f4cb67ab..15f9f0ca 100644 --- a/agent.go +++ b/agent.go @@ -130,7 +130,7 @@ type Agent struct { chanCandidate chan Candidate chanCandidatePair chan *CandidatePair - chanState chan ConnectionState + stateNotifier *connectionStateNotifier loggerFactory logging.LoggerFactory log logging.LeveledLogger @@ -223,9 +223,10 @@ func (a *Agent) taskLoop() { } a.closeMulticastConn() + a.updateConnectionState(ConnectionStateClosed) + after() - close(a.chanState) close(a.chanCandidate) close(a.chanCandidatePair) close(a.taskLoopDone) @@ -276,7 +277,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit a := &Agent{ chanTask: make(chan task), - chanState: make(chan ConnectionState), chanCandidate: make(chan Candidate), chanCandidatePair: make(chan *CandidatePair), tieBreaker: globalMathRandomGenerator.Uint64(), @@ -320,6 +320,7 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit disableActiveTCP: config.DisableActiveTCP, } + a.stateNotifier = &connectionStateNotifier{NotificationFunc: a.onConnectionStateChange} if a.net == nil { a.net, err = stdnet.NewNet() @@ -367,7 +368,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit // Blocking one by the other one causes deadlock. // Hence, we call handlers from independent Goroutines. go a.candidatePairRoutine() - go a.connectionStateRoutine() go a.candidateRoutine() // Restart is also used to initialize the agent for the first time @@ -501,12 +501,7 @@ func (a *Agent) updateConnectionState(newState ConnectionState) { a.log.Infof("Setting new connection state: %s", newState) a.connectionState = newState - - // Call handler after finishing current task since we may be holding the agent lock - // and the handler may also require it - a.afterRun(func(ctx context.Context) { - a.chanState <- newState - }) + a.stateNotifier.Enqueue(newState) } } diff --git a/agent_handlers.go b/agent_handlers.go index 0e8af04b..4ec24843 100644 --- a/agent_handlers.go +++ b/agent_handlers.go @@ -3,6 +3,8 @@ package ice +import "sync" + // OnConnectionStateChange sets a handler that is fired when the connection state changes func (a *Agent) OnConnectionStateChange(f func(ConnectionState)) error { a.onConnectionStateChangeHdlr.Store(f) @@ -47,11 +49,36 @@ func (a *Agent) candidatePairRoutine() { } } -func (a *Agent) connectionStateRoutine() { - for s := range a.chanState { - a.onConnectionStateChange(s) +type connectionStateNotifier struct { + sync.Mutex + states []ConnectionState + running bool + NotificationFunc func(ConnectionState) +} + +func (c *connectionStateNotifier) Enqueue(s ConnectionState) { + c.Lock() + defer c.Unlock() + c.states = append(c.states, s) + if !c.running { + c.running = true + go c.notify() + } +} + +func (c *connectionStateNotifier) notify() { + for { + c.Lock() + if len(c.states) == 0 { + c.running = false + c.Unlock() + return + } + s := c.states[0] + c.states = c.states[1:] + c.Unlock() + c.NotificationFunc(s) } - a.onConnectionStateChange(ConnectionStateClosed) } func (a *Agent) candidateRoutine() { diff --git a/agent_handlers_test.go b/agent_handlers_test.go new file mode 100644 index 00000000..05ed5b8f --- /dev/null +++ b/agent_handlers_test.go @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "testing" + "time" + + "github.com/pion/transport/v3/test" +) + +func TestConnectionStateNotifier(t *testing.T) { + t.Run("TestManyUpdates", func(t *testing.T) { + report := test.CheckRoutines(t) + defer report() + updates := make(chan struct{}, 1) + c := &connectionStateNotifier{ + NotificationFunc: func(_ ConnectionState) { + updates <- struct{}{} + }, + } + // Enqueue all updates upfront to ensure that it + // doesn't block + for i := 0; i < 10000; i++ { + c.Enqueue(ConnectionStateNew) + } + done := make(chan struct{}) + go func() { + for i := 0; i < 10000; i++ { + <-updates + } + select { + case <-updates: + t.Errorf("received more updates than expected") + case <-time.After(1 * time.Second): + } + close(done) + }() + <-done + }) + t.Run("TestUpdateOrdering", func(t *testing.T) { + report := test.CheckRoutines(t) + defer report() + updates := make(chan ConnectionState) + c := &connectionStateNotifier{ + NotificationFunc: func(cs ConnectionState) { + updates <- cs + }, + } + done := make(chan struct{}) + go func() { + for i := 0; i < 10000; i++ { + x := <-updates + if x != ConnectionState(i) { + t.Errorf("expected %d got %d", x, i) + } + } + select { + case <-updates: + t.Errorf("received more updates than expected") + case <-time.After(1 * time.Second): + } + close(done) + }() + for i := 0; i < 10000; i++ { + c.Enqueue(ConnectionState(i)) + } + <-done + }) +} diff --git a/agent_test.go b/agent_test.go index 2d07aa17..291df4d7 100644 --- a/agent_test.go +++ b/agent_test.go @@ -1394,7 +1394,7 @@ func TestCloseInConnectionStateCallback(t *testing.T) { report := test.CheckRoutines(t) defer report() - lim := test.TimeOut(time.Second * 10) + lim := test.TimeOut(time.Second * 5) defer lim.Stop() disconnectedDuration := time.Second