diff --git a/crowdsec/crowdsec.go b/crowdsec/crowdsec.go index 3c0fab7f..f01a2a89 100644 --- a/crowdsec/crowdsec.go +++ b/crowdsec/crowdsec.go @@ -18,7 +18,7 @@ import ( "context" "errors" "fmt" - "net" + "net/netip" "reflect" "runtime/debug" "slices" @@ -248,7 +248,7 @@ func (c *CrowdSec) Stop() error { // IsAllowed is used by the CrowdSec HTTP handler to check if // an IP is allowed to perform a request -func (c *CrowdSec) IsAllowed(ip net.IP) (bool, *models.Decision, error) { +func (c *CrowdSec) IsAllowed(ip netip.Addr) (bool, *models.Decision, error) { // TODO: check if running? fully loaded, etc? return c.bouncer.IsAllowed(ip) } diff --git a/crowdsec/crowdsec_test.go b/crowdsec/crowdsec_test.go index 67b72395..35c6582e 100644 --- a/crowdsec/crowdsec_test.go +++ b/crowdsec/crowdsec_test.go @@ -18,9 +18,9 @@ import ( "context" "encoding/json" "fmt" - "net" "net/http" "net/http/httptest" + "net/netip" "sync" "testing" "time" @@ -200,7 +200,7 @@ func TestCrowdSec_streamingBouncerRuntime(t *testing.T) { time.Sleep(100 * time.Millisecond) // simulate a lookup - allowed, decision, err := c.IsAllowed(net.ParseIP("127.0.0.1")) + allowed, decision, err := c.IsAllowed(netip.MustParseAddr("127.0.0.1")) assert.NoError(t, err) assert.Nil(t, decision) assert.True(t, allowed) @@ -253,7 +253,7 @@ func TestCrowdSec_liveBouncerRuntime(t *testing.T) { require.NoError(t, err) // simulate a lookup - allowed, decision, err := c.IsAllowed(net.ParseIP("127.0.0.1")) + allowed, decision, err := c.IsAllowed(netip.MustParseAddr("127.0.0.1")) assert.NoError(t, err) assert.Nil(t, decision) assert.True(t, allowed) diff --git a/go.mod b/go.mod index 6bfd00bb..87f4f217 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/crowdsecurity/go-cs-bouncer v0.0.14 github.com/crowdsecurity/go-cs-lib v0.0.15 github.com/google/go-cmp v0.6.0 - github.com/hslatman/ipstore v0.0.0-20210131120430-64b55d649887 + github.com/hslatman/ipstore v0.2.0 github.com/jarcoal/httpmock v1.3.1 github.com/mholt/caddy-l4 v0.0.0-20231016112149-a362a1fbf652 github.com/sirupsen/logrus v1.9.3 diff --git a/go.sum b/go.sum index 290290ee..ee23c5da 100644 --- a/go.sum +++ b/go.sum @@ -199,8 +199,8 @@ github.com/groob/finalizer v0.0.0-20170707115354-4c2ed49aabda/go.mod h1:MyndkAZd github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hslatman/cidranger v1.0.3-0.20210102151717-b2292da972c3 h1:Sv/aRgGM6Qpidn4IaCeW1M184rkdXCuKHCMGW3slpnY= github.com/hslatman/cidranger v1.0.3-0.20210102151717-b2292da972c3/go.mod h1:gcrMfr0dObt7Xdm3JrZqrshMoaCFs9Plkc+ID9ygSdY= -github.com/hslatman/ipstore v0.0.0-20210131120430-64b55d649887 h1:/is/XCIDQs5vfEfRk7dV5rpnuindiJD6FMIGwMej4Go= -github.com/hslatman/ipstore v0.0.0-20210131120430-64b55d649887/go.mod h1:/EV+ke1dTQNL6ZF1xuyemZcDw4vpQa9xHaSVTI6S2lI= +github.com/hslatman/ipstore v0.2.0 h1:q320dnrCF78ruZta0zNuterclga4tTFzxXosHfbEbfU= +github.com/hslatman/ipstore v0.2.0/go.mod h1:O5HTtag+448N/IuPezCz/3B+p/Ev7DMrqW2q0VZedRg= github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/huandu/xstrings v1.4.0 h1:D17IlohoQq4UcpqD7fDk80P7l+lwAmlFaBHgOipl2FU= github.com/huandu/xstrings v1.4.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= diff --git a/http/http.go b/http/http.go index 11681f04..6b784748 100644 --- a/http/http.go +++ b/http/http.go @@ -17,8 +17,8 @@ package http import ( "errors" "fmt" - "net" "net/http" + "net/netip" "time" "github.com/caddyserver/caddy/v2" @@ -74,7 +74,6 @@ func (h *Handler) Validate() error { // ServeHTTP is the Caddy handler for serving HTTP requests func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { - ipToCheck, err := determineIPFromRequest(r) if err != nil { return err // TODO: return error here? Or just log it and continue serving @@ -149,25 +148,26 @@ func writeThrottleResponse(w http.ResponseWriter, duration string) error { // Caddy extracts from the original request and stores in the request context. // Support for setting the real client IP in case a proxy sits in front of // Caddy was added, so the client IP reported here is the actual client IP. -func determineIPFromRequest(r *http.Request) (net.IP, error) { +func determineIPFromRequest(r *http.Request) (netip.Addr, error) { + zero := netip.Addr{} clientIPVar := caddyhttp.GetVar(r.Context(), caddyhttp.ClientIPVarKey) if clientIPVar == nil { - return nil, errors.New("failed getting client IP from context") + return zero, errors.New("failed getting client IP from context") } var clientIP string var ok bool if clientIP, ok = clientIPVar.(string); !ok { - return nil, fmt.Errorf("client IP from request context is invalid type %T", clientIPVar) + return zero, fmt.Errorf("client IP from request context is invalid type %T", clientIPVar) } if clientIP == "" { - return nil, errors.New("client IP from request context is empty") + return zero, errors.New("client IP from request context is empty") } - ip := net.ParseIP(clientIP) - if ip == nil { - return nil, fmt.Errorf("could not parse %q into net.IP", clientIP) + ip, err := netip.ParseAddr(clientIP) + if err != nil { + return zero, fmt.Errorf("could not parse %q into netip.Addr", clientIP) } return ip, nil diff --git a/http/http_test.go b/http/http_test.go index ddd45d2e..2776331c 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -16,9 +16,9 @@ package http import ( "context" - "net" "net/http" "net/http/httptest" + "net/netip" "reflect" "testing" @@ -47,14 +47,14 @@ func Test_determineIPFromRequest(t *testing.T) { tests := []struct { name string args args - want net.IP + want netip.Addr wantErr bool }{ - {"ok", args{r.WithContext(okCtx)}, net.ParseIP("127.0.0.1"), false}, - {"no-ip", args{r.WithContext(noIPCtx)}, nil, true}, - {"wrong-type", args{r.WithContext(wrongTypeCtx)}, nil, true}, - {"empty-ip", args{r.WithContext(emptyIPCtx)}, nil, true}, - {"invalid-ip", args{r.WithContext(invalidIPCtx)}, nil, true}, + {"ok", args{r.WithContext(okCtx)}, netip.MustParseAddr("127.0.0.1"), false}, + {"no-ip", args{r.WithContext(noIPCtx)}, netip.Addr{}, true}, + {"wrong-type", args{r.WithContext(wrongTypeCtx)}, netip.Addr{}, true}, + {"empty-ip", args{r.WithContext(emptyIPCtx)}, netip.Addr{}, true}, + {"invalid-ip", args{r.WithContext(invalidIPCtx)}, netip.Addr{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/bouncer/bouncer.go b/internal/bouncer/bouncer.go index 67f94c45..c92832e5 100644 --- a/internal/bouncer/bouncer.go +++ b/internal/bouncer/bouncer.go @@ -19,7 +19,7 @@ import ( "encoding/hex" "fmt" "math/rand" - "net" + "net/netip" "sync" "time" @@ -36,15 +36,15 @@ const ( maxNumberOfDecisionsToLog = 10 ) -// Bouncer is a wrapper for a CrowdSec bouncer. It supports both the the +// Bouncer is a wrapper for a CrowdSec bouncer. It supports both the // streaming and live bouncer implementations. The streaming bouncer is // backed by an immutable radix tree storing known bad IPs and IP ranges. -// The live bouncer will reach out to the CrowdSec agent on every check. +// The live bouncer will reach out to the CrowdSec LAPI on every check. type Bouncer struct { streamingBouncer *csbouncer.StreamBouncer liveBouncer *csbouncer.LiveBouncer metricsProvider *csbouncer.MetricsProvider - store *crowdSecStore + store *store logger *zap.Logger useStreamingBouncer bool shouldFailHard bool @@ -201,8 +201,7 @@ func (b *Bouncer) Shutdown() error { } // IsAllowed checks if an IP is allowed or not -func (b *Bouncer) IsAllowed(ip net.IP) (bool, *models.Decision, error) { - +func (b *Bouncer) IsAllowed(ip netip.Addr) (bool, *models.Decision, error) { // TODO: perform lookup in explicit allowlist as a kind of quick lookup in front of the CrowdSec lookup list? isAllowed := false decision, err := b.retrieveDecision(ip) diff --git a/internal/bouncer/bouncer_test.go b/internal/bouncer/bouncer_test.go index 3822fcb1..8c16217b 100644 --- a/internal/bouncer/bouncer_test.go +++ b/internal/bouncer/bouncer_test.go @@ -2,7 +2,7 @@ package bouncer import ( "context" - "net" + "net/netip" "net/url" "regexp" "testing" @@ -58,7 +58,6 @@ func newBouncer(t *testing.T) (*Bouncer, error) { } func decisions() *models.DecisionsStreamResponse { - duration := "120s" source := "cscli" scenario := "manual ban ..." @@ -145,7 +144,7 @@ func TestStreamingBouncer(t *testing.T) { time.Sleep(1 * time.Second) type args struct { - ip net.IP + ip netip.Addr } tests := []struct { name string @@ -156,7 +155,7 @@ func TestStreamingBouncer(t *testing.T) { { name: "127.0.0.1 not allowed", args: args{ - ip: net.ParseIP("127.0.0.1"), + ip: netip.MustParseAddr("127.0.0.1"), }, want: false, wantErr: false, @@ -164,7 +163,7 @@ func TestStreamingBouncer(t *testing.T) { { name: "127.0.0.2 not allowed", args: args{ - ip: net.ParseIP("127.0.0.2"), + ip: netip.MustParseAddr("127.0.0.2"), }, want: false, wantErr: false, @@ -172,7 +171,7 @@ func TestStreamingBouncer(t *testing.T) { { name: "127.0.0.3 allowed", args: args{ - ip: net.ParseIP("127.0.0.3"), + ip: netip.MustParseAddr("127.0.0.3"), }, want: true, wantErr: false, @@ -180,7 +179,7 @@ func TestStreamingBouncer(t *testing.T) { { name: "10.0.0.1/24 (10.0.0.1) not allowed", args: args{ - ip: net.ParseIP("10.0.0.1"), + ip: netip.MustParseAddr("10.0.0.1"), }, want: false, wantErr: false, @@ -188,7 +187,7 @@ func TestStreamingBouncer(t *testing.T) { { name: "10.0.1.0 allowed", args: args{ - ip: net.ParseIP("10.0.1.0"), + ip: netip.MustParseAddr("10.0.1.0"), }, want: true, wantErr: false, @@ -196,7 +195,7 @@ func TestStreamingBouncer(t *testing.T) { { name: "128.0.0.1 not allowed", args: args{ - ip: net.ParseIP("128.0.0.1"), + ip: netip.MustParseAddr("128.0.0.1"), }, want: false, wantErr: false, @@ -204,7 +203,7 @@ func TestStreamingBouncer(t *testing.T) { { name: "129.0.0.1 allowed", args: args{ - ip: net.ParseIP("129.0.0.1"), + ip: netip.MustParseAddr("129.0.0.1"), }, want: true, wantErr: false, diff --git a/internal/bouncer/decisions.go b/internal/bouncer/decisions.go index 567e0d0b..129e8660 100644 --- a/internal/bouncer/decisions.go +++ b/internal/bouncer/decisions.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/netip" "github.com/crowdsecurity/crowdsec/pkg/models" "go.uber.org/zap" @@ -95,7 +96,9 @@ func (b *Bouncer) delete(decision *models.Decision) error { return b.store.delete(decision) } -func (b *Bouncer) retrieveDecision(ip net.IP) (*models.Decision, error) { +func (b *Bouncer) retrieveDecision(ipAddr netip.Addr) (*models.Decision, error) { + ip := net.IP(ipAddr.AsSlice()) // TODO: feed through netip.Addr fully + if b.useStreamingBouncer { return b.store.get(ip) } diff --git a/internal/bouncer/store.go b/internal/bouncer/store.go index 90a45166..9ecda232 100644 --- a/internal/bouncer/store.go +++ b/internal/bouncer/store.go @@ -19,20 +19,20 @@ import ( "net" "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/hslatman/ipstore/pkg/ipstore" + "github.com/hslatman/ipstore" ) -type crowdSecStore struct { +type store struct { store *ipstore.Store } -func newStore() *crowdSecStore { - return &crowdSecStore{ +func newStore() *store { + return &store{ store: ipstore.New(), } } -func (s *crowdSecStore) add(decision *models.Decision) error { +func (s *store) add(decision *models.Decision) error { if isInvalid(decision) { return nil } @@ -58,8 +58,7 @@ func (s *crowdSecStore) add(decision *models.Decision) error { } } -func (s *crowdSecStore) delete(decision *models.Decision) error { - +func (s *store) delete(decision *models.Decision) error { if isInvalid(decision) { return nil } @@ -87,8 +86,7 @@ func (s *crowdSecStore) delete(decision *models.Decision) error { } } -func (s *crowdSecStore) get(key net.IP) (*models.Decision, error) { - +func (s *store) get(key net.IP) (*models.Decision, error) { r, err := s.store.Get(key) if err != nil { return nil, err @@ -137,7 +135,6 @@ func parseIP(value string) (net.IP, error) { // valid, meaning that it's not pointing to nil and has a // Scope, Value and Type set, the minimum required to operate func isInvalid(d *models.Decision) bool { - if d == nil { return true } diff --git a/internal/bouncer/store_test.go b/internal/bouncer/store_test.go new file mode 100644 index 00000000..4c45b28b --- /dev/null +++ b/internal/bouncer/store_test.go @@ -0,0 +1,115 @@ +// Copyright 2021 Herman Slatman +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bouncer + +import ( + "net" + "testing" + + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/stretchr/testify/require" +) + +func TestStore(t *testing.T) { + duration := "120s" + source := "cscli" + scenario := "manual ban ..." + scopeIP := "Ip" + scopeRange := "Range" + typ := "ban" + value1 := "127.0.0.1" + value2 := "127.0.0.2" + value3 := "10.0.0.1/24" + value4 := "128.0.0.1/32" + value5 := "129.0.0.1/24" + + d1 := &models.Decision{ + Duration: &duration, + ID: 1, + Origin: &source, + Scenario: &scenario, + Scope: &scopeIP, + Type: &typ, + Value: &value1, + } + + d2 := &models.Decision{ + Duration: &duration, + ID: 2, + Origin: &source, + Scenario: &scenario, + Scope: &scopeIP, + Type: &typ, + Value: &value2, + } + + d3 := &models.Decision{ + Duration: &duration, + ID: 3, + Origin: &source, + Scenario: &scenario, + Scope: &scopeRange, + Type: &typ, + Value: &value3, + } + + d4 := &models.Decision{ + Duration: &duration, + ID: 4, + Origin: &source, + Scenario: &scenario, + Scope: &scopeIP, + Type: &typ, + Value: &value4, + } + + d5 := &models.Decision{ + Duration: &duration, + ID: 5, + Origin: &source, + Scenario: &scenario, + Scope: &scopeIP, // IP scope + Type: &typ, + Value: &value5, // range + } + + s := newStore() + err := s.add(d1) + require.NoError(t, err) + err = s.add(d2) + require.NoError(t, err) + err = s.add(d3) + require.NoError(t, err) + err = s.add(d4) + require.NoError(t, err) + err = s.add(d5) + require.Error(t, err) + + ip1 := net.ParseIP(value1) + r1, err := s.get(ip1) + require.NoError(t, err) + require.NotNil(t, r1) + require.Equal(t, value1, *r1.Value) + + err = s.delete(d1) + require.NoError(t, err) + + err = s.delete(d3) + require.NoError(t, err) + + r1, err = s.get(ip1) + require.NoError(t, err) + require.Nil(t, r1) +} diff --git a/layer4/l4.go b/layer4/l4.go index 74002821..66c847df 100644 --- a/layer4/l4.go +++ b/layer4/l4.go @@ -17,6 +17,7 @@ package layer4 import ( "fmt" "net" + "net/netip" "github.com/caddyserver/caddy/v2" "github.com/hslatman/caddy-crowdsec-bouncer/crowdsec" @@ -89,18 +90,16 @@ func (m Matcher) Match(cx *l4.Connection) (bool, error) { // getClientIP determines the IP of the client connecting // Implementation taken from github.com/mholt/caddy-l4/layer4/matchers.go -func (m Matcher) getClientIP(cx *l4.Connection) (net.IP, error) { - +func (m Matcher) getClientIP(cx *l4.Connection) (netip.Addr, error) { remote := cx.Conn.RemoteAddr().String() - ipStr, _, err := net.SplitHostPort(remote) if err != nil { ipStr = remote } - ip := net.ParseIP(ipStr) - if ip == nil { - return nil, fmt.Errorf("invalid client IP address: %s", ipStr) + ip, err := netip.ParseAddr(ipStr) + if err != nil { + return netip.Addr{}, fmt.Errorf("invalid client IP address: %s", ipStr) } return ip, nil