Skip to content

Commit

Permalink
Avoid creating needless slices of IPs (fixes realclientip#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
jub0bs committed Sep 4, 2024
1 parent 0b402ba commit 2d07147
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 55 deletions.
174 changes: 120 additions & 54 deletions realclientip.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package realclientip

import (
"fmt"
"iter"
"net"
"net/http"
"strings"
Expand Down Expand Up @@ -42,7 +43,8 @@ func Must(strat Strategy, err error) Strategy {
// strategies are exhausted.
// A common use for this is if a server is both directly connected to the internet and
// expecting a header to check. It might be called like:
// NewChainStrategy(Must(LeftmostNonPrivateStrategy("X-Forwarded-For")), RemoteAddrStrategy)
//
// NewChainStrategy(Must(LeftmostNonPrivateStrategy("X-Forwarded-For")), RemoteAddrStrategy)
type ChainStrategy struct {
strategies []Strategy
}
Expand Down Expand Up @@ -188,8 +190,7 @@ func NewLeftmostNonPrivateStrategy(headerName string) (LeftmostNonPrivateStrateg
// The returned IP may contain a zone identifier.
// If no valid IP can be derived, empty string will be returned.
func (strat LeftmostNonPrivateStrategy) ClientIP(headers http.Header, _ string) string {
ipAddrs := getIPAddrList(headers, strat.headerName)
for _, ip := range ipAddrs {
for ip := range allIPAddrFromFirstToLast(headers, strat.headerName) {
if ip != nil && !isPrivateOrLocal(ip.IP) {
// This is the leftmost valid, non-private IP
return ip.String()
Expand Down Expand Up @@ -231,12 +232,11 @@ func NewRightmostNonPrivateStrategy(headerName string) (RightmostNonPrivateStrat
// The returned IP may contain a zone identifier.
// If no valid IP can be derived, empty string will be returned.
func (strat RightmostNonPrivateStrategy) ClientIP(headers http.Header, _ string) string {
ipAddrs := getIPAddrList(headers, strat.headerName)
// Look backwards through the list of IP addresses
for i := len(ipAddrs) - 1; i >= 0; i-- {
if ipAddrs[i] != nil && !isPrivateOrLocal(ipAddrs[i].IP) {
for ip := range allIPAddrFromLastToFirst(headers, strat.headerName) {
if ip != nil && !isPrivateOrLocal(ip.IP) {
// This is the rightmost non-private IP
return ipAddrs[i].String()
return ip.String()
}
}

Expand Down Expand Up @@ -283,27 +283,33 @@ func NewRightmostTrustedCountStrategy(headerName string, trustedCount int) (Righ
// The returned IP may contain a zone identifier.
// If no valid IP can be derived, empty string will be returned.
func (strat RightmostTrustedCountStrategy) ClientIP(headers http.Header, _ string) string {
ipAddrs := getIPAddrList(headers, strat.headerName)

// We want the (N-1)th from the rightmost. For example, if there's only one
// trusted proxy, we want the last.
rightmostIndex := len(ipAddrs) - 1
targetIndex := rightmostIndex - (strat.trustedCount - 1)
targetCount := strat.trustedCount - 1
var (
ip *net.IPAddr
count int
)
for ip = range allIPAddrFromLastToFirst(headers, strat.headerName) {
if count < targetCount {
count++
continue
}
break
}

if targetIndex < 0 {
if count < targetCount {
// This is a misconfiguration error. There were fewer IPs than we expected.
return ""
}

resultIP := ipAddrs[targetIndex]

if resultIP == nil {
if ip == nil {
// This is a misconfiguration error. Our first trusted proxy didn't add a
// valid IP address to the header.
return ""
}

return resultIP.String()
return ip.String()
}

// AddressesAndRangesToIPNets converts a slice of strings with IPv4 and IPv6 addresses and
Expand Down Expand Up @@ -393,21 +399,20 @@ func NewRightmostTrustedRangeStrategy(headerName string, trustedRanges []net.IPN
// The returned IP may contain a zone identifier.
// If no valid IP can be derived, empty string will be returned.
func (strat RightmostTrustedRangeStrategy) ClientIP(headers http.Header, _ string) string {
ipAddrs := getIPAddrList(headers, strat.headerName)
// Look backwards through the list of IP addresses
for i := len(ipAddrs) - 1; i >= 0; i-- {
if ipAddrs[i] != nil && isIPContainedInRanges(ipAddrs[i].IP, strat.trustedRanges) {
for ip := range allIPAddrFromLastToFirst(headers, strat.headerName) {
if ip != nil && isIPContainedInRanges(ip.IP, strat.trustedRanges) {
// This IP is trusted
continue
}

// At this point we have found the first-from-the-rightmost untrusted IP

if ipAddrs[i] == nil {
if ip == nil {
return ""
}

return ipAddrs[i].String()
return ip.String()
}

// Either there are no addresses or they are all in our trusted ranges
Expand Down Expand Up @@ -445,44 +450,105 @@ func lastHeader(headers http.Header, headerName string) string {
return matches[len(matches)-1]
}

// getIPAddrList creates a single list of all of the X-Forwarded-For or Forwarded header
// values, in order. Any invalid IPs will result in nil elements. headerName must already
// be canonicalized.
func getIPAddrList(headers http.Header, headerName string) []*net.IPAddr {
var result []*net.IPAddr

// There may be multiple XFF headers present. We need to iterate through them all,
// in order, and collect all of the IPs.
// Note that we're not joining all of the headers into a single string and then
// splitting. Doing it that way would use more memory.
// Note that Go's Header map uses canonicalized keys.
for _, h := range headers[headerName] {
// We now have a string with comma-separated list items
for _, rawListItem := range strings.Split(h, ",") {
// The IPs are often comma-space separated, so we'll need to trim the string
rawListItem = strings.TrimSpace(rawListItem)

var ipAddr *net.IPAddr
// If this is the XFF header, rawListItem is just an IP;
// if it's the Forwarded header, then there's more parsing to do.
if headerName == forwardedHdr {
ipAddr = parseForwardedListItem(rawListItem)
} else { // == XFF
ipAddr = goodIPAddr(rawListItem)
// allIPAddrFromLastToFirst returns an iterator over all of the X-Forwarded-For
// or Forwarded header values, from first to last. Any invalid IPs will result
// in nil elements. headerName must already be canonicalized.
func allIPAddrFromFirstToLast(headers http.Header, headerName string) iter.Seq[*net.IPAddr] {
return func(yield func(*net.IPAddr) bool) {
// There may be multiple XFF headers present. We need to iterate through them all,
// in order, and collect all of the IPs.
// Note that we're not joining all of the headers into a single string and then
// splitting. Doing it that way would use more memory.
// Note that Go's Header map uses canonicalized keys.
for _, h := range headers[headerName] {
var (
rawListItem string
commaFound bool
)
rawListItem, h, commaFound = strings.Cut(h, ",")
// We now have a string with comma-separated list items
for {
// The IPs are often comma-space separated, so we'll need to trim the string
rawListItem = strings.TrimSpace(rawListItem)

var ipAddr *net.IPAddr
// If this is the XFF header, rawListItem is just an IP;
// if it's the Forwarded header, then there's more parsing to do.
if headerName == forwardedHdr {
ipAddr = parseForwardedListItem(rawListItem)
} else { // == XFF
ipAddr = goodIPAddr(rawListItem)
}

// ipAddr is nil if not valid
if !yield(ipAddr) {
return
}
if !commaFound {
break
}
rawListItem, h, commaFound = strings.Cut(h, ",")
}

// ipAddr is nil if not valid
result = append(result, ipAddr)
}
}
}

// Possible performance improvements:
// Here we are parsing _all_ of the IPs in the XFF headers, but we don't need all of
// them. Instead, we could start from the left or the right (depending on strategy),
// parse as we go, and stop when we've come to the one we want. But that would make
// the various strategies somewhat more complex.
// allIPAddrFromLastToFirst returns an iterator over all of the X-Forwarded-For
// or Forwarded header values, from last to first. Any invalid IPs will result
// in nil elements. headerName must already be canonicalized.
func allIPAddrFromLastToFirst(headers http.Header, headerName string) iter.Seq[*net.IPAddr] {
return func(yield func(*net.IPAddr) bool) {
// There may be multiple XFF headers present. We need to iterate through them all,
// in order, and collect all of the IPs.
// Note that we're not joining all of the headers into a single string and then
// splitting. Doing it that way would use more memory.
// Note that Go's Header map uses canonicalized keys.
hs := headers[headerName]
last := len(hs) - 1
for i := last; 0 <= i; i-- {
h := hs[i]
var (
rawListItem string
commaFound bool
)
h, rawListItem, commaFound = cutLast(h, ",")
// We now have a string with comma-separated list items
for {
// The IPs are often comma-space separated, so we'll need to trim the string
rawListItem = strings.TrimSpace(rawListItem)

var ipAddr *net.IPAddr
// If this is the XFF header, rawListItem is just an IP;
// if it's the Forwarded header, then there's more parsing to do.
if headerName == forwardedHdr {
ipAddr = parseForwardedListItem(rawListItem)
} else { // == XFF
ipAddr = goodIPAddr(rawListItem)
}

// ipAddr is nil if not valid
if !yield(ipAddr) {
return
}
if !commaFound {
break
}
h, rawListItem, commaFound = cutLast(h, ",")
}
}
}
}

return result
// cutLast slices s around the last instance of sep,
// returning the text before and after sep.
// The found result reports whether sep appears in s.
// If sep does not appear in s, cut returns "", s, false.
// (Adapted from strings.Cut.)
func cutLast(s, sep string) (before, after string, found bool) {
if i := strings.LastIndex(s, sep); i >= 0 {
return s[:i], s[i+len(sep):], true
}
return "", s, false
}

// parseForwardedListItem parses a Forwarded header list item, and returns the "for" IP
Expand Down
6 changes: 5 additions & 1 deletion realclientip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2034,7 +2034,11 @@ func Test_forwardedHeaderRFCDeviations(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getIPAddrList(tt.args.headers, tt.args.headerName); !reflect.DeepEqual(got, tt.want) {
var got []*net.IPAddr
for ip := range allIPAddrFromFirstToLast(tt.args.headers, tt.args.headerName) {
got = append(got, ip)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("getIPAddrList() = %v, want %v", got, tt.want)
}
})
Expand Down

0 comments on commit 2d07147

Please sign in to comment.