Skip to content

Commit

Permalink
Fix and test goroutine leak (#269)
Browse files Browse the repository at this point in the history
Fix goroutine leaks unaddressed by #265 and add more tests for goroutine leak.
  • Loading branch information
mingyech authored Feb 13, 2024
1 parent c8df961 commit 6e9ba3c
Show file tree
Hide file tree
Showing 8 changed files with 382 additions and 57 deletions.
2 changes: 1 addition & 1 deletion pkg/dtls/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (net.
dtlsConn, err := dtls.ClientWithContext(ctx, conn, dtlsConf)

if err != nil {
return nil, fmt.Errorf("error creating dtls connection: %v", err)
return nil, fmt.Errorf("error creating dtls connection: %w", err)
}

wrappedConn, err := wrapSCTP(dtlsConn, config)
Expand Down
28 changes: 28 additions & 0 deletions pkg/dtls/goroutine_leak_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dtls

import (
"runtime"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func passGoroutineLeak(testFunc func(*testing.T), t *testing.T) bool {
initialGoroutines := runtime.NumGoroutine()

testFunc(t)

time.Sleep(2 * time.Second)

return runtime.NumGoroutine() <= initialGoroutines
}

func TestGoroutineLeak(t *testing.T) {
testFuncs := []func(*testing.T){TestSend, TestServerFail, TestClientFail, TestListenSuccess, TestListenFail}

for _, test := range testFuncs {
require.True(t, passGoroutineLeak(test, t))
}

}
48 changes: 30 additions & 18 deletions pkg/dtls/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package dtls
import (
"bytes"
"errors"
"io"
"net"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -76,6 +75,12 @@ func (c *hbConn) recvLoop() {
for {
buffer := make([]byte, c.maxMessageSize)

err := c.stream.SetReadDeadline(time.Now().Add(c.timeout))
if err != nil {
c.Close()
return
}

n, err := c.stream.Read(buffer)

if bytes.Equal(c.hb, buffer[:n]) {
Expand All @@ -84,16 +89,19 @@ func (c *hbConn) recvLoop() {
}

if err != nil {
c.recvCh <- errBytes{nil, err}
switch {
case errors.Is(err, net.ErrClosed):
case errors.Is(err, io.EOF):
c.Close()
return
}
c.Close()
return
}

c.recvCh <- errBytes{buffer[:n], err}
timer := time.NewTimer(c.timeout)
select {
case c.recvCh <- errBytes{buffer[:n], err}:
timer.Stop()
continue
case <-timer.C:
c.Close()
return
}
}

}
Expand All @@ -108,18 +116,22 @@ func (c *hbConn) Write(b []byte) (n int, err error) {
}

func (c *hbConn) Read(b []byte) (int, error) {
readBytes := <-c.recvCh
if readBytes.err != nil {
return 0, readBytes.err
}
select {
case <-c.closed:
return 0, net.ErrClosed
case readBytes := <-c.recvCh:
if readBytes.err != nil {
return 0, readBytes.err
}

if len(b) < len(readBytes.b) {
return 0, ErrInsufficientBuffer
}
if len(b) < len(readBytes.b) {
return 0, ErrInsufficientBuffer
}

n := copy(b, readBytes.b)
n := copy(b, readBytes.b)

return n, nil
return n, nil
}
}

func (c *hbConn) BufferedAmount() uint64 {
Expand Down
93 changes: 62 additions & 31 deletions pkg/dtls/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ import (
"fmt"
"net"
"sync"
"time"

"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/transport/v2/udp"
)

const defaultAcceptTimeout = 5 * time.Second

// Listen creates a listener and starts listening
func Listen(network string, laddr *net.UDPAddr, config *Config) (*Listener, error) {
lc := udp.ListenConfig{}
Expand All @@ -39,40 +42,49 @@ func (l *Listener) acceptLoop() {
}

for {
c, err := l.parent.Accept()
if err != nil {
continue
}

go func() {
newDTLSConn, err := dtls.Server(c, config)
select {
case <-l.closed:
return
default:
c, err := l.parent.Accept()
if err != nil {
switch addr := c.RemoteAddr().(type) {
case *net.UDPAddr:
l.logIP(err, &addr.IP)
case *net.TCPAddr:
l.logIP(err, &addr.IP)
case *net.IPAddr:
l.logIP(err, &addr.IP)
}

return
continue
}

connState := newDTLSConn.ConnectionState()
connID := connState.RemoteRandomBytes()
go func() {
ctx, cancel := context.WithTimeout(context.Background(), defaultAcceptTimeout)
defer cancel()
newDTLSConn, err := dtls.ServerWithContext(ctx, c, config)
if err != nil {
switch addr := c.RemoteAddr().(type) {
case *net.UDPAddr:
l.logIP(err, &addr.IP)
case *net.TCPAddr:
l.logIP(err, &addr.IP)
case *net.IPAddr:
l.logIP(err, &addr.IP)
}

return
}

l.connMapMutex.RLock()
defer l.connMapMutex.RUnlock()
connState := newDTLSConn.ConnectionState()
connID := connState.RemoteRandomBytes()

acceptCh, ok := l.connMap[connID]
acceptCh, err := l.chFromID(connID)

if !ok {
return
}
if err != nil {
return
}

acceptCh <- newDTLSConn
}()
select {
case acceptCh <- newDTLSConn:
return
case <-ctx.Done():
return
}
}()
}
}
}

Expand All @@ -97,6 +109,7 @@ func NewListener(inner net.Listener, config *Config) (*Listener, error) {
connMap: map[[handshake.RandomBytesLength]byte](chan net.Conn){},
connToCert: map[[handshake.RandomBytesLength]byte]*certPair{},
defaultCert: defaultCert,
closed: make(chan struct{}),
logAuthFail: config.LogAuthFail,
logOther: config.LogOther,
}
Expand All @@ -110,18 +123,22 @@ func NewListener(inner net.Listener, config *Config) (*Listener, error) {
type Listener struct {
parent net.Listener
connMap map[[handshake.RandomBytesLength]byte](chan net.Conn)
connMapMutex sync.RWMutex
connMapMutex sync.Mutex
connToCert map[[handshake.RandomBytesLength]byte]*certPair
connToCertMutex sync.RWMutex
connToCertMutex sync.Mutex
defaultCert *tls.Certificate
logAuthFail func(*net.IP)
logOther func(*net.IP)

closeOnce sync.Once
closed chan struct{}
}

// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
// Already Accepted connections are not closed.
func (l *Listener) Close() error {
l.closeOnce.Do(func() { close(l.closed) })
return l.parent.Close()
}

Expand Down Expand Up @@ -223,6 +240,20 @@ func (l *Listener) registerChannel(connID [handshake.RandomBytesLength]byte) (<-
return connChan, nil
}

func (l *Listener) chFromID(id [handshake.RandomBytesLength]byte) (chan<- net.Conn, error) {

l.connMapMutex.Lock()
defer l.connMapMutex.Unlock()

acceptCh, ok := l.connMap[id]

if !ok {
return nil, fmt.Errorf("id not registered")
}

return acceptCh, nil
}

func (l *Listener) removeChannel(connID [handshake.RandomBytesLength]byte) {
l.connMapMutex.Lock()
defer l.connMapMutex.Unlock()
Expand All @@ -237,8 +268,8 @@ func (l *Listener) getCertificateFromClientHello(clientHello *dtls.ClientHelloIn
return l.defaultCert, nil
}

l.connToCertMutex.RLock()
defer l.connToCertMutex.RUnlock()
l.connToCertMutex.Lock()
defer l.connToCertMutex.Unlock()

certs, ok := l.connToCert[clientHello.RandomBytes]

Expand Down
Loading

0 comments on commit 6e9ba3c

Please sign in to comment.