Skip to content

Commit

Permalink
Include query params in proxied request (#58)
Browse files Browse the repository at this point in the history
* include query parameters in proxied request

* replace httptest with an actual Gin server (makes testing easier)

* add test for query params
  • Loading branch information
tpetr authored Jan 17, 2024
1 parent 320b5ee commit 50ecc7b
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 14 deletions.
95 changes: 81 additions & 14 deletions it/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"testing"
"time"

"github.com/gin-gonic/gin"
"github.com/mcuadros/go-defaults"
"github.com/semgrep/semgrep-network-broker/cmd"
"github.com/semgrep/semgrep-network-broker/pkg"
Expand Down Expand Up @@ -90,6 +91,35 @@ func (tc *testClient) AssertStatusCode(t *testing.T, method string, rawUrl strin
}
}

func (tc *testClient) AssertStatusAndContent(t *testing.T, method string, rawUrl string, expectedStatusCode int, expectedContent string) {
url, err := url.Parse(rawUrl)
if err != nil {
t.Errorf("error while making %v %v: %v", method, rawUrl, err)
}

req := &http.Request{
Method: method,
URL: url,
}

if method != "GET" {
req.Body = io.NopCloser(strings.NewReader("{\"foo\": 2}"))
}

statusCode, content, err := tc.Request(req)
if err != nil {
t.Errorf("error while making %v %v: %v", method, rawUrl, err)
}

if statusCode != expectedStatusCode {
t.Errorf("%v %v returned HTTP %v, expected HTTP %v", method, rawUrl, statusCode, expectedStatusCode)
}

if content != expectedContent {
t.Errorf("%v %v returned '%v', expected '%v'", method, rawUrl, content, expectedContent)
}
}

func TestWireguardInboundProxy(t *testing.T) {
gatewayWireguardPort := mustGetFreePort()
gatewayWireguardAddress := mustGetRandomPrivateAddress()
Expand Down Expand Up @@ -125,13 +155,39 @@ func TestWireguardInboundProxy(t *testing.T) {
defer remoteWireguardTeardown()
log.Info("Remote wireguard peer is up")

// set up internal service
internalServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello")
}))
defer internalServer.Close()
// set up internal service (the thing that the broker proxies to)
internalServer := gin.Default()

// we want this proxy to be transparent, so don't un-escape characters in the URL
internalServer.UseRawPath = true
internalServer.UnescapePathValues = false

internalServer.Any("/allowed-get", func(ctx *gin.Context) {
ctx.String(200, "Hello")
})
internalServer.Any("/unallowed-get", func(ctx *gin.Context) {
ctx.String(200, "Hello")
})
internalServer.Any("/allowed-post", func(ctx *gin.Context) {
ctx.String(200, "Hello")
})
internalServer.Any("/allowed-path/:path", func(ctx *gin.Context) {
ctx.String(200, "Hello %v", ctx.GetString("path"))
})
internalServer.Any("/introspect/query-params", func(ctx *gin.Context) {
ctx.String(200, ctx.Request.URL.RawQuery)
})

internalListener, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
t.Errorf("Failed to start internal listener: %v", err)
}
defer internalListener.Close()
go internalServer.RunListener(internalListener)
log.Info("Internal server is up")

internalServerBaseUrl := fmt.Sprintf("http://%v", internalListener.Addr().String())

// start network broker
brokerConfig := &pkg.Config{
Inbound: pkg.InboundProxyConfig{
Expand All @@ -148,17 +204,21 @@ func TestWireguardInboundProxy(t *testing.T) {
},
Allowlist: []pkg.AllowlistItem{
{
URL: internalServer.URL + "/allowed-get",
URL: internalServerBaseUrl + "/allowed-get",
Methods: pkg.ParseHttpMethods([]string{"GET"}),
},
{
URL: internalServer.URL + "/allowed-post",
URL: internalServerBaseUrl + "/allowed-post",
Methods: pkg.ParseHttpMethods([]string{"POST"}),
},
{
URL: internalServer.URL + "/allowed-path/:path",
URL: internalServerBaseUrl + "/allowed-path/:path",
Methods: pkg.ParseHttpMethods([]string{"POST"}),
},
{
URL: internalServerBaseUrl + "/introspect/*",
Methods: pkg.ParseHttpMethods([]string{"GET", "POST"}),
},
},
Heartbeat: pkg.HeartbeatConfig{
URL: fmt.Sprintf("http://[%v]/ping", gatewayWireguardAddress),
Expand Down Expand Up @@ -190,17 +250,24 @@ func TestWireguardInboundProxy(t *testing.T) {
}

// it should proxy requests that match the allowlist
remoteHttpClient.AssertStatusCode(t, "GET", fmt.Sprintf("http://[%v]/proxy/%v/allowed-get", clientWireguardAddress, internalServer.URL), 200)
remoteHttpClient.AssertStatusCode(t, "POST", fmt.Sprintf("http://[%v]/proxy/%v/allowed-post", clientWireguardAddress, internalServer.URL), 200)
remoteHttpClient.AssertStatusCode(t, "POST", fmt.Sprintf("http://[%v]/proxy/%v/allowed-path/foobar", clientWireguardAddress, internalServer.URL), 200)
remoteHttpClient.AssertStatusCode(t, "GET", fmt.Sprintf("http://[%v]/proxy/%v/allowed-get", clientWireguardAddress, internalServerBaseUrl), 200)
remoteHttpClient.AssertStatusCode(t, "POST", fmt.Sprintf("http://[%v]/proxy/%v/allowed-post", clientWireguardAddress, internalServerBaseUrl), 200)
remoteHttpClient.AssertStatusCode(t, "POST", fmt.Sprintf("http://[%v]/proxy/%v/allowed-path/foobar", clientWireguardAddress, internalServerBaseUrl), 200)

// it should pass along all query params
remoteHttpClient.AssertStatusCode(t, "GET", fmt.Sprintf("http://[%v]/proxy/%v/allowed-get?foo=bar", clientWireguardAddress, internalServerBaseUrl), 200)

// it shouldnt decode urlencoded characters
remoteHttpClient.AssertStatusCode(t, "POST", fmt.Sprintf("http://[%v]/proxy/%v/allowed-path/%s", clientWireguardAddress, internalServer.URL, "foobar%2Fbla"), 200)
remoteHttpClient.AssertStatusCode(t, "POST", fmt.Sprintf("http://[%v]/proxy/%v/allowed-path/%s", clientWireguardAddress, internalServerBaseUrl, "foobar%2Fbla"), 200)

// it should reject requests that don't match the allowlist
remoteHttpClient.AssertStatusCode(t, "POST", fmt.Sprintf("http://[%v]/proxy/%v/allowed-get", clientWireguardAddress, internalServer.URL), 403)
remoteHttpClient.AssertStatusCode(t, "GET", fmt.Sprintf("http://[%v]/proxy/%v/allowed-post", clientWireguardAddress, internalServer.URL), 403)
remoteHttpClient.AssertStatusCode(t, "POST", fmt.Sprintf("http://[%v]/proxy/%v/unallowed-get", clientWireguardAddress, internalServerBaseUrl), 403)
remoteHttpClient.AssertStatusCode(t, "POST", fmt.Sprintf("http://[%v]/proxy/%v/allowed-get", clientWireguardAddress, internalServerBaseUrl), 403)
remoteHttpClient.AssertStatusCode(t, "GET", fmt.Sprintf("http://[%v]/proxy/%v/allowed-post", clientWireguardAddress, internalServerBaseUrl), 403)
remoteHttpClient.AssertStatusCode(t, "GET", fmt.Sprintf("http://[%v]/proxy/https://google.com", clientWireguardAddress), 403)

// it should include query params in the proxied request
remoteHttpClient.AssertStatusAndContent(t, "GET", fmt.Sprintf("http://[%v]/proxy/%v/introspect/query-params?foo=bar", clientWireguardAddress, internalServerBaseUrl), 200, "foo=bar")
}

func TestRelay(t *testing.T) {
Expand Down
4 changes: 4 additions & 0 deletions pkg/inbound_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ func (config *InboundProxyConfig) Start(tnet *netstack.Net) error {
r.Any(proxyPath, func(c *gin.Context) {
logger := log.WithFields(GetRequestFields(c))
destinationUrl, err := url.Parse(c.Param(destinationUrlParam)[1:])

// we have to explicitly copy over the query params
destinationUrl.RawQuery = c.Request.URL.RawQuery

logger = logger.WithField("destinationUrl", destinationUrl)

if err != nil {
Expand Down

0 comments on commit 50ecc7b

Please sign in to comment.