Skip to content

Commit

Permalink
Replace net.IP with netip.Addr (where possible)
Browse files Browse the repository at this point in the history
  • Loading branch information
hslatman committed Oct 2, 2024
1 parent a875f18 commit c332c11
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 40 deletions.
4 changes: 2 additions & 2 deletions crowdsec/crowdsec.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"reflect"
"runtime/debug"
"slices"
Expand Down Expand Up @@ -248,7 +248,7 @@ func (c *CrowdSec) Stop() error {

// IsAllowed is used by the CrowdSec HTTP handler to check if
// an IP is allowed to perform a request
func (c *CrowdSec) IsAllowed(ip net.IP) (bool, *models.Decision, error) {
func (c *CrowdSec) IsAllowed(ip netip.Addr) (bool, *models.Decision, error) {
// TODO: check if running? fully loaded, etc?
return c.bouncer.IsAllowed(ip)
}
Expand Down
6 changes: 3 additions & 3 deletions crowdsec/crowdsec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -200,7 +200,7 @@ func TestCrowdSec_streamingBouncerRuntime(t *testing.T) {
time.Sleep(100 * time.Millisecond)

// simulate a lookup
allowed, decision, err := c.IsAllowed(net.ParseIP("127.0.0.1"))
allowed, decision, err := c.IsAllowed(netip.MustParseAddr("127.0.0.1"))
assert.NoError(t, err)
assert.Nil(t, decision)
assert.True(t, allowed)
Expand Down Expand Up @@ -253,7 +253,7 @@ func TestCrowdSec_liveBouncerRuntime(t *testing.T) {
require.NoError(t, err)

// simulate a lookup
allowed, decision, err := c.IsAllowed(net.ParseIP("127.0.0.1"))
allowed, decision, err := c.IsAllowed(netip.MustParseAddr("127.0.0.1"))
assert.NoError(t, err)
assert.Nil(t, decision)
assert.True(t, allowed)
Expand Down
18 changes: 9 additions & 9 deletions http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ package http
import (
"errors"
"fmt"
"net"
"net/http"
"net/netip"
"time"

"github.com/caddyserver/caddy/v2"
Expand Down Expand Up @@ -74,7 +74,6 @@ func (h *Handler) Validate() error {

// ServeHTTP is the Caddy handler for serving HTTP requests
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {

ipToCheck, err := determineIPFromRequest(r)
if err != nil {
return err // TODO: return error here? Or just log it and continue serving
Expand Down Expand Up @@ -149,25 +148,26 @@ func writeThrottleResponse(w http.ResponseWriter, duration string) error {
// Caddy extracts from the original request and stores in the request context.
// Support for setting the real client IP in case a proxy sits in front of
// Caddy was added, so the client IP reported here is the actual client IP.
func determineIPFromRequest(r *http.Request) (net.IP, error) {
func determineIPFromRequest(r *http.Request) (netip.Addr, error) {
zero := netip.Addr{}
clientIPVar := caddyhttp.GetVar(r.Context(), caddyhttp.ClientIPVarKey)
if clientIPVar == nil {
return nil, errors.New("failed getting client IP from context")
return zero, errors.New("failed getting client IP from context")
}

var clientIP string
var ok bool
if clientIP, ok = clientIPVar.(string); !ok {
return nil, fmt.Errorf("client IP from request context is invalid type %T", clientIPVar)
return zero, fmt.Errorf("client IP from request context is invalid type %T", clientIPVar)
}

if clientIP == "" {
return nil, errors.New("client IP from request context is empty")
return zero, errors.New("client IP from request context is empty")
}

ip := net.ParseIP(clientIP)
if ip == nil {
return nil, fmt.Errorf("could not parse %q into net.IP", clientIP)
ip, err := netip.ParseAddr(clientIP)
if err != nil {
return zero, fmt.Errorf("could not parse %q into netip.Addr", clientIP)
}

return ip, nil
Expand Down
14 changes: 7 additions & 7 deletions http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ package http

import (
"context"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"reflect"
"testing"

Expand Down Expand Up @@ -47,14 +47,14 @@ func Test_determineIPFromRequest(t *testing.T) {
tests := []struct {
name string
args args
want net.IP
want netip.Addr
wantErr bool
}{
{"ok", args{r.WithContext(okCtx)}, net.ParseIP("127.0.0.1"), false},
{"no-ip", args{r.WithContext(noIPCtx)}, nil, true},
{"wrong-type", args{r.WithContext(wrongTypeCtx)}, nil, true},
{"empty-ip", args{r.WithContext(emptyIPCtx)}, nil, true},
{"invalid-ip", args{r.WithContext(invalidIPCtx)}, nil, true},
{"ok", args{r.WithContext(okCtx)}, netip.MustParseAddr("127.0.0.1"), false},
{"no-ip", args{r.WithContext(noIPCtx)}, netip.Addr{}, true},
{"wrong-type", args{r.WithContext(wrongTypeCtx)}, netip.Addr{}, true},
{"empty-ip", args{r.WithContext(emptyIPCtx)}, netip.Addr{}, true},
{"invalid-ip", args{r.WithContext(invalidIPCtx)}, netip.Addr{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
5 changes: 2 additions & 3 deletions internal/bouncer/bouncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"encoding/hex"
"fmt"
"math/rand"
"net"
"net/netip"
"sync"
"time"

Expand Down Expand Up @@ -201,8 +201,7 @@ func (b *Bouncer) Shutdown() error {
}

// IsAllowed checks if an IP is allowed or not
func (b *Bouncer) IsAllowed(ip net.IP) (bool, *models.Decision, error) {

func (b *Bouncer) IsAllowed(ip netip.Addr) (bool, *models.Decision, error) {
// TODO: perform lookup in explicit allowlist as a kind of quick lookup in front of the CrowdSec lookup list?
isAllowed := false
decision, err := b.retrieveDecision(ip)
Expand Down
18 changes: 9 additions & 9 deletions internal/bouncer/bouncer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package bouncer

import (
"context"
"net"
"net/netip"
"net/url"
"regexp"
"testing"
Expand Down Expand Up @@ -144,7 +144,7 @@ func TestStreamingBouncer(t *testing.T) {
time.Sleep(1 * time.Second)

type args struct {
ip net.IP
ip netip.Addr
}
tests := []struct {
name string
Expand All @@ -155,55 +155,55 @@ func TestStreamingBouncer(t *testing.T) {
{
name: "127.0.0.1 not allowed",
args: args{
ip: net.ParseIP("127.0.0.1"),
ip: netip.MustParseAddr("127.0.0.1"),
},
want: false,
wantErr: false,
},
{
name: "127.0.0.2 not allowed",
args: args{
ip: net.ParseIP("127.0.0.2"),
ip: netip.MustParseAddr("127.0.0.2"),
},
want: false,
wantErr: false,
},
{
name: "127.0.0.3 allowed",
args: args{
ip: net.ParseIP("127.0.0.3"),
ip: netip.MustParseAddr("127.0.0.3"),
},
want: true,
wantErr: false,
},
{
name: "10.0.0.1/24 (10.0.0.1) not allowed",
args: args{
ip: net.ParseIP("10.0.0.1"),
ip: netip.MustParseAddr("10.0.0.1"),
},
want: false,
wantErr: false,
},
{
name: "10.0.1.0 allowed",
args: args{
ip: net.ParseIP("10.0.1.0"),
ip: netip.MustParseAddr("10.0.1.0"),
},
want: true,
wantErr: false,
},
{
name: "128.0.0.1 not allowed",
args: args{
ip: net.ParseIP("128.0.0.1"),
ip: netip.MustParseAddr("128.0.0.1"),
},
want: false,
wantErr: false,
},
{
name: "129.0.0.1 allowed",
args: args{
ip: net.ParseIP("129.0.0.1"),
ip: netip.MustParseAddr("129.0.0.1"),
},
want: true,
wantErr: false,
Expand Down
5 changes: 4 additions & 1 deletion internal/bouncer/decisions.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"net/netip"

"github.com/crowdsecurity/crowdsec/pkg/models"
"go.uber.org/zap"
Expand Down Expand Up @@ -95,7 +96,9 @@ func (b *Bouncer) delete(decision *models.Decision) error {
return b.store.delete(decision)
}

func (b *Bouncer) retrieveDecision(ip net.IP) (*models.Decision, error) {
func (b *Bouncer) retrieveDecision(ipAddr netip.Addr) (*models.Decision, error) {
ip := net.IP(ipAddr.AsSlice()) // TODO: feed through netip.Addr fully

if b.useStreamingBouncer {
return b.store.get(ip)
}
Expand Down
11 changes: 5 additions & 6 deletions layer4/l4.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package layer4
import (
"fmt"
"net"
"net/netip"

"github.com/caddyserver/caddy/v2"
"github.com/hslatman/caddy-crowdsec-bouncer/crowdsec"
Expand Down Expand Up @@ -89,18 +90,16 @@ func (m Matcher) Match(cx *l4.Connection) (bool, error) {

// getClientIP determines the IP of the client connecting
// Implementation taken from github.com/mholt/caddy-l4/layer4/matchers.go
func (m Matcher) getClientIP(cx *l4.Connection) (net.IP, error) {

func (m Matcher) getClientIP(cx *l4.Connection) (netip.Addr, error) {
remote := cx.Conn.RemoteAddr().String()

ipStr, _, err := net.SplitHostPort(remote)
if err != nil {
ipStr = remote
}

ip := net.ParseIP(ipStr)
if ip == nil {
return nil, fmt.Errorf("invalid client IP address: %s", ipStr)
ip, err := netip.ParseAddr(ipStr)
if err != nil {
return netip.Addr{}, fmt.Errorf("invalid client IP address: %s", ipStr)
}

return ip, nil
Expand Down

0 comments on commit c332c11

Please sign in to comment.