diff --git a/realclientip.go b/realclientip.go index 193c0ca..4aa5d96 100644 --- a/realclientip.go +++ b/realclientip.go @@ -5,6 +5,7 @@ package realclientip import ( "fmt" + "iter" "net" "net/http" "strings" @@ -42,7 +43,8 @@ func Must(strat Strategy, err error) Strategy { // strategies are exhausted. // A common use for this is if a server is both directly connected to the internet and // expecting a header to check. It might be called like: -// NewChainStrategy(Must(LeftmostNonPrivateStrategy("X-Forwarded-For")), RemoteAddrStrategy) +// +// NewChainStrategy(Must(LeftmostNonPrivateStrategy("X-Forwarded-For")), RemoteAddrStrategy) type ChainStrategy struct { strategies []Strategy } @@ -188,8 +190,7 @@ func NewLeftmostNonPrivateStrategy(headerName string) (LeftmostNonPrivateStrateg // The returned IP may contain a zone identifier. // If no valid IP can be derived, empty string will be returned. func (strat LeftmostNonPrivateStrategy) ClientIP(headers http.Header, _ string) string { - ipAddrs := getIPAddrList(headers, strat.headerName) - for _, ip := range ipAddrs { + for ip := range allIPAddrFromFirstToLast(headers, strat.headerName) { if ip != nil && !isPrivateOrLocal(ip.IP) { // This is the leftmost valid, non-private IP return ip.String() @@ -231,12 +232,11 @@ func NewRightmostNonPrivateStrategy(headerName string) (RightmostNonPrivateStrat // The returned IP may contain a zone identifier. // If no valid IP can be derived, empty string will be returned. func (strat RightmostNonPrivateStrategy) ClientIP(headers http.Header, _ string) string { - ipAddrs := getIPAddrList(headers, strat.headerName) // Look backwards through the list of IP addresses - for i := len(ipAddrs) - 1; i >= 0; i-- { - if ipAddrs[i] != nil && !isPrivateOrLocal(ipAddrs[i].IP) { + for ip := range allIPAddrFromLastToFirst(headers, strat.headerName) { + if ip != nil && !isPrivateOrLocal(ip.IP) { // This is the rightmost non-private IP - return ipAddrs[i].String() + return ip.String() } } @@ -283,27 +283,33 @@ func NewRightmostTrustedCountStrategy(headerName string, trustedCount int) (Righ // The returned IP may contain a zone identifier. // If no valid IP can be derived, empty string will be returned. func (strat RightmostTrustedCountStrategy) ClientIP(headers http.Header, _ string) string { - ipAddrs := getIPAddrList(headers, strat.headerName) - // We want the (N-1)th from the rightmost. For example, if there's only one // trusted proxy, we want the last. - rightmostIndex := len(ipAddrs) - 1 - targetIndex := rightmostIndex - (strat.trustedCount - 1) + targetCount := strat.trustedCount - 1 + var ( + ip *net.IPAddr + count int + ) + for ip = range allIPAddrFromLastToFirst(headers, strat.headerName) { + if count < targetCount { + count++ + continue + } + break + } - if targetIndex < 0 { + if count < targetCount { // This is a misconfiguration error. There were fewer IPs than we expected. return "" } - resultIP := ipAddrs[targetIndex] - - if resultIP == nil { + if ip == nil { // This is a misconfiguration error. Our first trusted proxy didn't add a // valid IP address to the header. return "" } - return resultIP.String() + return ip.String() } // AddressesAndRangesToIPNets converts a slice of strings with IPv4 and IPv6 addresses and @@ -393,21 +399,20 @@ func NewRightmostTrustedRangeStrategy(headerName string, trustedRanges []net.IPN // The returned IP may contain a zone identifier. // If no valid IP can be derived, empty string will be returned. func (strat RightmostTrustedRangeStrategy) ClientIP(headers http.Header, _ string) string { - ipAddrs := getIPAddrList(headers, strat.headerName) // Look backwards through the list of IP addresses - for i := len(ipAddrs) - 1; i >= 0; i-- { - if ipAddrs[i] != nil && isIPContainedInRanges(ipAddrs[i].IP, strat.trustedRanges) { + for ip := range allIPAddrFromLastToFirst(headers, strat.headerName) { + if ip != nil && isIPContainedInRanges(ip.IP, strat.trustedRanges) { // This IP is trusted continue } // At this point we have found the first-from-the-rightmost untrusted IP - if ipAddrs[i] == nil { + if ip == nil { return "" } - return ipAddrs[i].String() + return ip.String() } // Either there are no addresses or they are all in our trusted ranges @@ -445,44 +450,105 @@ func lastHeader(headers http.Header, headerName string) string { return matches[len(matches)-1] } -// getIPAddrList creates a single list of all of the X-Forwarded-For or Forwarded header -// values, in order. Any invalid IPs will result in nil elements. headerName must already -// be canonicalized. -func getIPAddrList(headers http.Header, headerName string) []*net.IPAddr { - var result []*net.IPAddr - - // There may be multiple XFF headers present. We need to iterate through them all, - // in order, and collect all of the IPs. - // Note that we're not joining all of the headers into a single string and then - // splitting. Doing it that way would use more memory. - // Note that Go's Header map uses canonicalized keys. - for _, h := range headers[headerName] { - // We now have a string with comma-separated list items - for _, rawListItem := range strings.Split(h, ",") { - // The IPs are often comma-space separated, so we'll need to trim the string - rawListItem = strings.TrimSpace(rawListItem) - - var ipAddr *net.IPAddr - // If this is the XFF header, rawListItem is just an IP; - // if it's the Forwarded header, then there's more parsing to do. - if headerName == forwardedHdr { - ipAddr = parseForwardedListItem(rawListItem) - } else { // == XFF - ipAddr = goodIPAddr(rawListItem) +// allIPAddrFromLastToFirst returns an iterator over all of the X-Forwarded-For +// or Forwarded header values, from first to last. Any invalid IPs will result +// in nil elements. headerName must already be canonicalized. +func allIPAddrFromFirstToLast(headers http.Header, headerName string) iter.Seq[*net.IPAddr] { + return func(yield func(*net.IPAddr) bool) { + // There may be multiple XFF headers present. We need to iterate through them all, + // in order, and collect all of the IPs. + // Note that we're not joining all of the headers into a single string and then + // splitting. Doing it that way would use more memory. + // Note that Go's Header map uses canonicalized keys. + for _, h := range headers[headerName] { + var ( + rawListItem string + commaFound bool + ) + rawListItem, h, commaFound = strings.Cut(h, ",") + // We now have a string with comma-separated list items + for { + // The IPs are often comma-space separated, so we'll need to trim the string + rawListItem = strings.TrimSpace(rawListItem) + + var ipAddr *net.IPAddr + // If this is the XFF header, rawListItem is just an IP; + // if it's the Forwarded header, then there's more parsing to do. + if headerName == forwardedHdr { + ipAddr = parseForwardedListItem(rawListItem) + } else { // == XFF + ipAddr = goodIPAddr(rawListItem) + } + + // ipAddr is nil if not valid + if !yield(ipAddr) { + return + } + if !commaFound { + break + } + rawListItem, h, commaFound = strings.Cut(h, ",") } - - // ipAddr is nil if not valid - result = append(result, ipAddr) } } +} - // Possible performance improvements: - // Here we are parsing _all_ of the IPs in the XFF headers, but we don't need all of - // them. Instead, we could start from the left or the right (depending on strategy), - // parse as we go, and stop when we've come to the one we want. But that would make - // the various strategies somewhat more complex. +// allIPAddrFromLastToFirst returns an iterator over all of the X-Forwarded-For +// or Forwarded header values, from last to first. Any invalid IPs will result +// in nil elements. headerName must already be canonicalized. +func allIPAddrFromLastToFirst(headers http.Header, headerName string) iter.Seq[*net.IPAddr] { + return func(yield func(*net.IPAddr) bool) { + // There may be multiple XFF headers present. We need to iterate through them all, + // in order, and collect all of the IPs. + // Note that we're not joining all of the headers into a single string and then + // splitting. Doing it that way would use more memory. + // Note that Go's Header map uses canonicalized keys. + hs := headers[headerName] + last := len(hs) - 1 + for i := last; 0 <= i; i-- { + h := hs[i] + var ( + rawListItem string + commaFound bool + ) + h, rawListItem, commaFound = cutLast(h, ",") + // We now have a string with comma-separated list items + for { + // The IPs are often comma-space separated, so we'll need to trim the string + rawListItem = strings.TrimSpace(rawListItem) + + var ipAddr *net.IPAddr + // If this is the XFF header, rawListItem is just an IP; + // if it's the Forwarded header, then there's more parsing to do. + if headerName == forwardedHdr { + ipAddr = parseForwardedListItem(rawListItem) + } else { // == XFF + ipAddr = goodIPAddr(rawListItem) + } + + // ipAddr is nil if not valid + if !yield(ipAddr) { + return + } + if !commaFound { + break + } + h, rawListItem, commaFound = cutLast(h, ",") + } + } + } +} - return result +// cutLast slices s around the last instance of sep, +// returning the text before and after sep. +// The found result reports whether sep appears in s. +// If sep does not appear in s, cut returns "", s, false. +// (Adapted from strings.Cut.) +func cutLast(s, sep string) (before, after string, found bool) { + if i := strings.LastIndex(s, sep); i >= 0 { + return s[:i], s[i+len(sep):], true + } + return "", s, false } // parseForwardedListItem parses a Forwarded header list item, and returns the "for" IP diff --git a/realclientip_test.go b/realclientip_test.go index fb11485..c4d0398 100644 --- a/realclientip_test.go +++ b/realclientip_test.go @@ -2034,7 +2034,11 @@ func Test_forwardedHeaderRFCDeviations(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := getIPAddrList(tt.args.headers, tt.args.headerName); !reflect.DeepEqual(got, tt.want) { + var got []*net.IPAddr + for ip := range allIPAddrFromFirstToLast(tt.args.headers, tt.args.headerName) { + got = append(got, ip) + } + if !reflect.DeepEqual(got, tt.want) { t.Errorf("getIPAddrList() = %v, want %v", got, tt.want) } })