diff --git a/frontend/main.go b/frontend/main.go index c30393b..6d0d3e0 100644 --- a/frontend/main.go +++ b/frontend/main.go @@ -26,6 +26,7 @@ type settingType struct { nameFilter string timeOut int connectionTimeOut int + trustProxyHeaders bool } var setting settingType diff --git a/frontend/settings.go b/frontend/settings.go index 4c4d1fd..6171c7c 100644 --- a/frontend/settings.go +++ b/frontend/settings.go @@ -27,6 +27,7 @@ type viperSettingType struct { NameFilter string `mapstructure:"name_filter"` TimeOut int `mapstructure:"timeout"` ConnectionTimeOut int `mapstructure:"connection_timeout"` + TrustProxyHeaders bool `mapstructure:"trust_proxy_headers"` } // Parse settings with viper, and convert to legacy setting format @@ -94,6 +95,9 @@ func parseSettings() { pflag.Int("connection-time-out", 5, "time before backend TCP connection times out, in seconds; defaults to 5 if not set") viper.BindPFlag("connection_timeout", pflag.Lookup("connection-time-out")) + pflag.Bool("trust-proxy-headers", false, "Trust X-Forwared-For, X-Real-IP, X-Forwarded-Proto, X-Forwarded-Scheme and X-Forwarded-Host sent by the client") + viper.BindPFlag("trust_proxy_headers", pflag.Lookup("trust-proxy-headers")) + pflag.Parse() if err := viper.ReadInConfig(); err != nil { @@ -144,6 +148,7 @@ func parseSettings() { setting.nameFilter = viperSettings.NameFilter setting.timeOut = viperSettings.TimeOut setting.connectionTimeOut = viperSettings.ConnectionTimeOut + setting.trustProxyHeaders = viperSettings.TrustProxyHeaders fmt.Printf("%#v\n", setting) } diff --git a/frontend/webserver.go b/frontend/webserver.go index 0442c61..84a19a9 100644 --- a/frontend/webserver.go +++ b/frontend/webserver.go @@ -75,7 +75,6 @@ func webHandlerWhois(w http.ResponseWriter, r *http.Request) { // serve up results from bird func webBackendCommunicator(endpoint string, command string) func(w http.ResponseWriter, r *http.Request) { - backendCommandPrimitive, commandPresent := primitiveMap[command] if !commandPresent { panic("invalid command: " + command) @@ -195,7 +194,6 @@ func webHandlerBGPMap(endpoint string, command string) func(w http.ResponseWrite // set up routing paths and start webserver func webServerStart(l net.Listener) { - // redirect main page to all server summary http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/summary/"+url.PathEscape(strings.Join(setting.servers, "+")), 302) @@ -239,5 +237,11 @@ func webServerStart(l net.Listener) { http.HandleFunc("/telegram/", webHandlerTelegramBot) // Start HTTP server - http.Serve(l, handlers.LoggingHandler(os.Stdout, http.DefaultServeMux)) + var handler http.Handler + handler = http.DefaultServeMux + if setting.trustProxyHeaders { + handler = handlers.ProxyHeaders(handler) + } + handler = handlers.LoggingHandler(os.Stdout, handler) + http.Serve(l, handler) }