Skip to content

Commit

Permalink
UDPMux support multi ports
Browse files Browse the repository at this point in the history
  • Loading branch information
cnderrauber committed Aug 16, 2023
1 parent ce9ec53 commit d8026b9
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 53 deletions.
80 changes: 60 additions & 20 deletions udp_mux_multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package ice
import (
"fmt"
"net"
"sync/atomic"
"time"

"github.com/pion/logging"
Expand All @@ -19,20 +20,50 @@ import (
type MultiUDPMuxDefault struct {
muxes []UDPMux
localAddrToMux map[string]UDPMux

// Manage port balance for mux that listen on multiple ports for same IP,
// for each IP, only return one addr (one port) for each GetListenAddresses call to
// avoid duplicate ip candidates be gathered for a single ice agent.
multiPortsAddresses []*multiPortsAddress
}

type multiPortsAddress struct {
addresses []net.Addr
nextPos atomic.Int32
}

func (addr *multiPortsAddress) next() net.Addr {
return addr.addresses[addr.nextPos.Add(1)%int32(len(addr.addresses))]
}

// NewMultiUDPMuxDefault creates an instance of MultiUDPMuxDefault that
// uses the provided UDPMux instances.
func NewMultiUDPMuxDefault(muxes ...UDPMux) *MultiUDPMuxDefault {
addrToMux := make(map[string]UDPMux)
ipToAddrs := make(map[string]*multiPortsAddress)
for _, mux := range muxes {
for _, addr := range mux.GetListenAddresses() {
addrToMux[addr.String()] = mux

ip := addr.(*net.UDPAddr).IP.String()
if mpa, ok := ipToAddrs[ip]; ok {
mpa.addresses = append(mpa.addresses, addr)
} else {
ipToAddrs[ip] = &multiPortsAddress{
addresses: []net.Addr{addr},
}
}
}
}

multiPortsAddresses := make([]*multiPortsAddress, 0, len(ipToAddrs))
for _, mpa := range ipToAddrs {
multiPortsAddresses = append(multiPortsAddresses, mpa)
}
return &MultiUDPMuxDefault{
muxes: muxes,
localAddrToMux: addrToMux,
muxes: muxes,
localAddrToMux: addrToMux,
multiPortsAddresses: multiPortsAddresses,
}
}

Expand Down Expand Up @@ -67,16 +98,20 @@ func (m *MultiUDPMuxDefault) Close() error {

// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr {
addrs := make([]net.Addr, 0, len(m.localAddrToMux))
for _, mux := range m.muxes {
addrs = append(addrs, mux.GetListenAddresses()...)
addrs := make([]net.Addr, 0, len(m.multiPortsAddresses))
for _, mpa := range m.multiPortsAddresses {
addrs = append(addrs, mpa.next())
}
return addrs
}

// NewMultiUDPMuxFromPort creates an instance of MultiUDPMuxDefault that
// listen all interfaces on the provided port.
func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) {
return NewMultiUDPMuxFromPorts([]int{port}, opts...)
}

func NewMultiUDPMuxFromPorts(ports []int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) {
params := multiUDPMuxFromPortParam{
networks: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
}
Expand All @@ -96,23 +131,28 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
return nil, err
}

conns := make([]net.PacketConn, 0, len(ips))
conns := make([]net.PacketConn, 0, len(ports)*len(ips))
for _, ip := range ips {
conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{IP: ip, Port: port})
if listenErr != nil {
err = listenErr
break
for _, port := range ports {
conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{IP: ip, Port: port})
if listenErr != nil {
err = listenErr
break
}
if params.readBufferSize > 0 {
_ = conn.SetReadBuffer(params.readBufferSize)
}
if params.writeBufferSize > 0 {
_ = conn.SetWriteBuffer(params.writeBufferSize)
}
if params.batchWriteSize > 0 {
conns = append(conns, NewBatchConn(conn, params.batchWriteSize, params.batchWriteInterval, params.logger))
} else {
conns = append(conns, conn)
}
}
if params.readBufferSize > 0 {
_ = conn.SetReadBuffer(params.readBufferSize)
}
if params.writeBufferSize > 0 {
_ = conn.SetWriteBuffer(params.writeBufferSize)
}
if params.batchWriteSize > 0 {
conns = append(conns, NewBatchConn(conn, params.batchWriteSize, params.batchWriteInterval, params.logger))
} else {
conns = append(conns, conn)
if err != nil {
break
}
}

Expand Down
95 changes: 62 additions & 33 deletions udp_mux_multi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package ice

import (
"fmt"
"net"
"strings"
"sync"
Expand Down Expand Up @@ -117,39 +118,67 @@ func TestUnspecifiedUDPMux(t *testing.T) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()

muxPort := 7778
udpMuxMulti, err := NewMultiUDPMuxFromPort(muxPort, UDPMuxFromPortWithInterfaceFilter(func(s string) bool {
defaultDockerBridgeNetwork := strings.Contains(s, "docker")
customDockerBridgeNetwork := strings.Contains(s, "br-")
return !defaultDockerBridgeNetwork && !customDockerBridgeNetwork
}))
require.NoError(t, err)

require.GreaterOrEqual(t, len(udpMuxMulti.muxes), 1, "at least have 1 muxes")
defer func() {
_ = udpMuxMulti.Close()
}()

wg := sync.WaitGroup{}

wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp)
}()
wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", udp4)
}()

// Skip IPv6 test on i386
const ptrSize = 32 << (^uintptr(0) >> 63)
if ptrSize != 32 {
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", udp6)
cases := map[string][]int{
"single port": {7778},
"multi ports": {7779, 7780, 7781},
}

wg.Wait()

require.NoError(t, udpMuxMulti.Close())
for name, ports := range cases {
cname, cports := name, ports
t.Run(cname, func(t *testing.T) {
udpMuxMulti, err := NewMultiUDPMuxFromPorts(cports, UDPMuxFromPortWithInterfaceFilter(func(s string) bool {
defaultDockerBridgeNetwork := strings.Contains(s, "docker")
customDockerBridgeNetwork := strings.Contains(s, "br-")
return !defaultDockerBridgeNetwork && !customDockerBridgeNetwork
}))
require.NoError(t, err)

require.GreaterOrEqual(t, len(udpMuxMulti.muxes), 1, "at least have 1 muxes")
defer func() {
_ = udpMuxMulti.Close()
}()

wg := sync.WaitGroup{}

wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp)
}()
wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", udp4)
}()

// Skip IPv6 test on i386
const ptrSize = 32 << (^uintptr(0) >> 63)
if ptrSize != 32 {
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", udp6)
}

wg.Wait()

// check port allocation is balanced
if len(cports) > 1 {
expectPorts := make(map[int]bool)
for i := range cports {
addr := udpMuxMulti.GetListenAddresses()[0]
ufrag := fmt.Sprintf("ufragetest%d", i)
conn, err := udpMuxMulti.GetConn(ufrag, addr)
require.NoError(t, err)
require.NotNil(t, conn)
require.False(t, expectPorts[conn.LocalAddr().(*net.UDPAddr).Port], fmt.Sprint("port ", conn.LocalAddr().(*net.UDPAddr).Port, " is already used", expectPorts))
expectPorts[conn.LocalAddr().(*net.UDPAddr).Port] = true

conn2, err := udpMuxMulti.GetConn(ufrag, addr)
require.NoError(t, err)
require.Equal(t, conn, conn2)
}
require.Equal(t, len(cports), len(expectPorts))
}

require.NoError(t, udpMuxMulti.Close())
})
}
}

0 comments on commit d8026b9

Please sign in to comment.