diff --git a/udp_mux_test.go b/udp_mux_test.go index 5fdb67a4..77c575f4 100644 --- a/udp_mux_test.go +++ b/udp_mux_test.go @@ -173,11 +173,25 @@ func BenchmarkAddressEncoding(b *testing.B) { } buf := make([]byte, 64) - for i := 0; i < b.N; i++ { - if _, err := encodeUDPAddr(addr, buf); err != nil { - require.NoError(b, err) + b.Run("encode", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := encodeUDPAddr(addr, buf); err != nil { + require.NoError(b, err) + } } - } + }) + + b.Run("decode", func(b *testing.B) { + n, _ := encodeUDPAddr(addr, buf) + var addr *net.UDPAddr + var err error + for i := 0; i < b.N; i++ { + if addr, err = decodeUDPAddr(buf[:n]); err != nil { + require.NoError(b, err) + } + } + _ = addr + }) } func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) { diff --git a/udp_muxed_conn.go b/udp_muxed_conn.go index 88c1a53c..9c279b89 100644 --- a/udp_muxed_conn.go +++ b/udp_muxed_conn.go @@ -7,11 +7,8 @@ import ( "encoding/binary" "io" "net" - "net/netip" - "reflect" "sync" "time" - "unsafe" "github.com/pion/logging" "github.com/pion/transport/v3/packetio" @@ -214,51 +211,46 @@ func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error { } func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) { - if len(addr.IP) != 0 && len(addr.IP) != net.IPv4len && len(addr.IP) != net.IPv6len { - return 0, errInvalidAddress + total := 1 + len(addr.IP) + 2 + len(addr.Zone) + if len(buf) < total { + return 0, io.ErrShortBuffer } - var n int - if ip4 := addr.IP.To4(); len(ip4) == net.IPv4len { - d := (*reflect.SliceHeader)(unsafe.Pointer(&ip4)) // nolint:gosec - n = len(netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(d.Data))).AppendTo(buf[2:2])) // nolint:gosec - } else if len(addr.IP) != 0 { - d := (*reflect.SliceHeader)(unsafe.Pointer(&addr.IP)) // nolint:gosec - n = len(netip.AddrFrom16(*(*[16]byte)(unsafe.Pointer(d.Data))).AppendTo(buf[2:2])) // nolint:gosec - } + buf[0] = uint8(len(addr.IP)) + offset := 1 - total := 2 + n + 2 + len(addr.Zone) - if total > len(buf) { - return 0, io.ErrShortBuffer - } + copy(buf[offset:], addr.IP) + offset += len(addr.IP) - binary.LittleEndian.PutUint16(buf, uint16(n)) - offset := 2 + n binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port)) offset += 2 + copy(buf[offset:], addr.Zone) return total, nil } func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) { - addr := net.UDPAddr{} + addr := &net.UDPAddr{} - offset := 0 - ipLen := int(binary.LittleEndian.Uint16(buf[:2])) - offset += 2 // Basic bounds checking - if ipLen+offset > len(buf) { + if len(buf) == 0 || len(buf) < int(buf[0])+3 { return nil, io.ErrShortBuffer } - if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil { - return nil, err + + ipLen := int(buf[0]) + offset := 1 + + if ipLen == 0 { + addr.IP = nil + } else { + addr.IP = append(addr.IP[:0], buf[offset:offset+ipLen]...) + offset += ipLen } - offset += ipLen - addr.Port = int(binary.LittleEndian.Uint16(buf[offset : offset+2])) + + addr.Port = int(binary.LittleEndian.Uint16(buf[offset:])) offset += 2 - zone := make([]byte, len(buf[offset:])) - copy(zone, buf[offset:]) - addr.Zone = string(zone) - return &addr, nil + addr.Zone = string(buf[offset:]) + + return addr, nil }