Skip to content

Commit

Permalink
Reduce allocation in udp muxed conn addr decode (#686)
Browse files Browse the repository at this point in the history
  • Loading branch information
paulwe authored Apr 21, 2024
1 parent 7263f68 commit c1b4386
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 36 deletions.
22 changes: 18 additions & 4 deletions udp_mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,25 @@ func BenchmarkAddressEncoding(b *testing.B) {
}
buf := make([]byte, 64)

for i := 0; i < b.N; i++ {
if _, err := encodeUDPAddr(addr, buf); err != nil {
require.NoError(b, err)
b.Run("encode", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if _, err := encodeUDPAddr(addr, buf); err != nil {
require.NoError(b, err)
}
}
}
})

b.Run("decode", func(b *testing.B) {
n, _ := encodeUDPAddr(addr, buf)
var addr *net.UDPAddr
var err error
for i := 0; i < b.N; i++ {
if addr, err = decodeUDPAddr(buf[:n]); err != nil {
require.NoError(b, err)
}
}
_ = addr
})
}

func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
Expand Down
56 changes: 24 additions & 32 deletions udp_muxed_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@ import (
"encoding/binary"
"io"
"net"
"net/netip"
"reflect"
"sync"
"time"
"unsafe"

"github.com/pion/logging"
"github.com/pion/transport/v3/packetio"
Expand Down Expand Up @@ -214,51 +211,46 @@ func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
}

func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
if len(addr.IP) != 0 && len(addr.IP) != net.IPv4len && len(addr.IP) != net.IPv6len {
return 0, errInvalidAddress
total := 1 + len(addr.IP) + 2 + len(addr.Zone)
if len(buf) < total {
return 0, io.ErrShortBuffer
}

var n int
if ip4 := addr.IP.To4(); len(ip4) == net.IPv4len {
d := (*reflect.SliceHeader)(unsafe.Pointer(&ip4)) // nolint:gosec
n = len(netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(d.Data))).AppendTo(buf[2:2])) // nolint:gosec
} else if len(addr.IP) != 0 {
d := (*reflect.SliceHeader)(unsafe.Pointer(&addr.IP)) // nolint:gosec
n = len(netip.AddrFrom16(*(*[16]byte)(unsafe.Pointer(d.Data))).AppendTo(buf[2:2])) // nolint:gosec
}
buf[0] = uint8(len(addr.IP))
offset := 1

total := 2 + n + 2 + len(addr.Zone)
if total > len(buf) {
return 0, io.ErrShortBuffer
}
copy(buf[offset:], addr.IP)
offset += len(addr.IP)

binary.LittleEndian.PutUint16(buf, uint16(n))
offset := 2 + n
binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port))
offset += 2

copy(buf[offset:], addr.Zone)
return total, nil
}

func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
addr := net.UDPAddr{}
addr := &net.UDPAddr{}

offset := 0
ipLen := int(binary.LittleEndian.Uint16(buf[:2]))
offset += 2
// Basic bounds checking
if ipLen+offset > len(buf) {
if len(buf) == 0 || len(buf) < int(buf[0])+3 {
return nil, io.ErrShortBuffer
}
if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil {
return nil, err

ipLen := int(buf[0])
offset := 1

if ipLen == 0 {
addr.IP = nil
} else {
addr.IP = append(addr.IP[:0], buf[offset:offset+ipLen]...)
offset += ipLen
}
offset += ipLen
addr.Port = int(binary.LittleEndian.Uint16(buf[offset : offset+2]))

addr.Port = int(binary.LittleEndian.Uint16(buf[offset:]))
offset += 2
zone := make([]byte, len(buf[offset:]))
copy(zone, buf[offset:])
addr.Zone = string(zone)

return &addr, nil
addr.Zone = string(buf[offset:])

return addr, nil
}

0 comments on commit c1b4386

Please sign in to comment.