From 66150c5731150288516236c43644fee508a53b48 Mon Sep 17 00:00:00 2001 From: Tom Petr Date: Tue, 17 Sep 2024 13:11:48 -0400 Subject: [PATCH] better support for running multiple separate brokers (#85) * better support for HA brokers * slightly better genkey/pubkey impl --- cmd/dump.go | 2 +- cmd/genkey.go | 28 ++++++++++++++++++++++----- cmd/pubkey.go | 47 +++++++++++++++++++++++++--------------------- cmd/relay.go | 2 +- cmd/root.go | 6 ++++-- pkg/config.go | 21 +++++++++++++-------- pkg/config_test.go | 2 +- pkg/wireguard.go | 35 ++++++++++++++++++++++++++-------- 8 files changed, 96 insertions(+), 47 deletions(-) diff --git a/cmd/dump.go b/cmd/dump.go index 3d6703f..2911aa5 100644 --- a/cmd/dump.go +++ b/cmd/dump.go @@ -13,7 +13,7 @@ var dumpCmd = &cobra.Command{ Use: "dump", Short: "Dump current config", Run: func(cmd *cobra.Command, args []string) { - config, err := pkg.LoadConfig(configFiles, deploymentId) + config, err := pkg.LoadConfig(configFiles, deploymentId, brokerIndex) if err != nil { log.Panic(err) } diff --git a/cmd/genkey.go b/cmd/genkey.go index 2bb277e..b1ff9db 100644 --- a/cmd/genkey.go +++ b/cmd/genkey.go @@ -3,25 +3,43 @@ package cmd import ( "encoding/base64" "fmt" + "os" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +var replicaCount int + +const defaultReplicaCount = 3 +const minReplicaCount = 1 +const maxReplicaCount = 16 + var genkeyCmd = &cobra.Command{ Use: "genkey", - Short: "Generates a random private key in base64 and prints it to stdout", + Short: "Generates a random Semgrep Network Broker private key and prints it to stdout.", Run: func(cmd *cobra.Command, args []string) { - privateKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - log.Panic(fmt.Errorf("failed to generate private key: %v", err)) + if replicaCount < minReplicaCount || replicaCount > maxReplicaCount { + log.Panic(fmt.Errorf("replica count must be between %v and %v", minReplicaCount, maxReplicaCount)) } - fmt.Println(base64.StdEncoding.EncodeToString(privateKey[:])) + encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout) + defer encoder.Close() + + for i := 0; i < replicaCount; i++ { + privateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + log.Panic(fmt.Errorf("failed to generate private key %v: %v", i, err)) + } + if _, err := encoder.Write(privateKey[:]); err != nil { + log.Panic(fmt.Errorf("failed to write private key %v: %v", i, err)) + } + } }, } func init() { + genkeyCmd.PersistentFlags().IntVarP(&replicaCount, "replica-count", "r", defaultReplicaCount, "Number of broker replicas to support") rootCmd.AddCommand(genkeyCmd) } diff --git a/cmd/pubkey.go b/cmd/pubkey.go index cd3b341..32c1d95 100644 --- a/cmd/pubkey.go +++ b/cmd/pubkey.go @@ -8,35 +8,40 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) var pubkeyCmd = &cobra.Command{ Use: "pubkey", - Short: "Reads a base64 private key from stdin, outputs the corresponding base64 public key", + Short: "Reads a Semgrep Network Broker private key from stdin and ptints the corresponding public key to stdout.", Run: func(cmd *cobra.Command, args []string) { - keyBase64, err := io.ReadAll(os.Stdin) - if err != nil { - log.Panic(err) - } - - keyBytes := make([]byte, 32) - n, err := base64.StdEncoding.Decode(keyBytes, keyBase64) - if err != nil { - log.Panic(err) - } - if n != 32 { - log.Panic("not enough bytes") - } - privateKey, err := wgtypes.NewKey(keyBytes) - if err != nil { - log.Panic(err) + decoder := base64.NewDecoder(base64.StdEncoding, os.Stdin) + encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout) + defer encoder.Close() + + privateKeyBytes := make([]byte, device.NoisePrivateKeySize) + + for i := 0; ; i++ { + _, err := io.ReadFull(decoder, privateKeyBytes) + if err != nil { + if err == io.EOF { + break + } else { + log.Panic(fmt.Errorf("error reading private key %v: %v", i, err)) + } + } + privateKey, err := wgtypes.NewKey(privateKeyBytes) + if err != nil { + log.Panic(fmt.Errorf("error creating private key %v: %v", i, err)) + } + + publicKey := privateKey.PublicKey() + if _, err := encoder.Write(publicKey[:]); err != nil { + log.Panic(fmt.Errorf("error writing public key %v: %v", i, err)) + } } - - publicKey := privateKey.PublicKey() - - fmt.Println(base64.StdEncoding.EncodeToString(publicKey[:])) }, } diff --git a/cmd/relay.go b/cmd/relay.go index 0820cd9..c05d6bd 100644 --- a/cmd/relay.go +++ b/cmd/relay.go @@ -30,7 +30,7 @@ var relayCmd = &cobra.Command{ }() // load config(s) - config, err := pkg.LoadConfig(configFiles, 0) + config, err := pkg.LoadConfig(configFiles, 0, 0) if err != nil { log.Panic(err) } diff --git a/cmd/root.go b/cmd/root.go index f49c8d7..ef176b6 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -18,6 +18,7 @@ import ( var configFiles []string var jsonLog bool var deploymentId int +var brokerIndex int var rootCmd = &cobra.Command{ Use: "semgrep-network-broker", @@ -39,7 +40,7 @@ var rootCmd = &cobra.Command{ }() // load config(s) - config, err := pkg.LoadConfig(configFiles, deploymentId) + config, err := pkg.LoadConfig(configFiles, deploymentId, brokerIndex) if err != nil { log.Panic(err) } @@ -75,7 +76,7 @@ func StartNetworkBroker(config *pkg.Config) (func() error, error) { return wireguardTeardown() } - // start inbound proxy (r2c --> customer) + // start inbound proxy (semgrep --> customer) if err := config.Inbound.Start(tnet); err != nil { teardown() return nil, fmt.Errorf("failed to start inbound proxy: %v", err) @@ -95,4 +96,5 @@ func init() { rootCmd.PersistentFlags().StringArrayVarP(&configFiles, "config", "c", nil, "config file(s)") rootCmd.PersistentFlags().BoolVarP(&jsonLog, "json-log", "j", false, "JSON log output") rootCmd.PersistentFlags().IntVarP(&deploymentId, "deployment-id", "d", 0, "Semgrep deployment ID") + rootCmd.PersistentFlags().IntVarP(&brokerIndex, "broker-index", "i", 0, "Semgrep network broker index") } diff --git a/pkg/config.go b/pkg/config.go index 84f7310..588e22a 100644 --- a/pkg/config.go +++ b/pkg/config.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "net/netip" "net/url" "os" "reflect" @@ -72,13 +73,15 @@ type WireguardPeer struct { } type WireguardBase struct { - LocalAddress string `mapstructure:"localAddress" json:"localAddress" validate:"format=ip"` - Dns []string `mapstructure:"dns" json:"dns" validate:"empty=true > format=ip"` - Mtu int `mapstructure:"mtu" json:"mtu" validate:"gte=0" default:"1420"` - PrivateKey SensitiveBase64String `mapstructure:"privateKey" json:"privateKey" validate:"empty=false"` - ListenPort int `mapstructure:"listenPort" json:"listenPort" validate:"gte=0"` - Peers []WireguardPeer `mapstructure:"peers" json:"peers" validate:"empty=false"` - Verbose bool `mapstructure:"verbose" json:"verbose"` + resolvedLocalAddress netip.Addr + LocalAddress string `mapstructure:"localAddress" json:"localAddress" validate:"format=ip"` + Dns []string `mapstructure:"dns" json:"dns" validate:"empty=true > format=ip"` + Mtu int `mapstructure:"mtu" json:"mtu" validate:"gte=0" default:"1420"` + PrivateKey SensitiveBase64String `mapstructure:"privateKey" json:"privateKey" validate:"empty=false"` + ListenPort int `mapstructure:"listenPort" json:"listenPort" validate:"gte=0"` + Peers []WireguardPeer `mapstructure:"peers" json:"peers" validate:"empty=false"` + Verbose bool `mapstructure:"verbose" json:"verbose"` + BrokerIndex int `mapstructure:"brokerIndex" json:"brokerIndex" validate:"gte=0"` } type BitTester interface { @@ -262,9 +265,11 @@ type Config struct { Outbound OutboundProxyConfig `mapstructure:"outbound" json:"outbound"` } -func LoadConfig(configFiles []string, deploymentId int) (*Config, error) { +func LoadConfig(configFiles []string, deploymentId int, brokerIndex int) (*Config, error) { config := new(Config) + config.Inbound.Wireguard.BrokerIndex = brokerIndex + if deploymentId > 0 { hostname := os.Getenv("SEMGREP_HOSTNAME") if hostname == "" { diff --git a/pkg/config_test.go b/pkg/config_test.go index b7fdded..36786b3 100644 --- a/pkg/config_test.go +++ b/pkg/config_test.go @@ -11,7 +11,7 @@ import ( ) func TestEmptyConfigs(t *testing.T) { - config, err := LoadConfig(nil, 0) + config, err := LoadConfig(nil, 0, 0) if err != nil { t.Error(err) } diff --git a/pkg/wireguard.go b/pkg/wireguard.go index dba3be8..e70f60c 100644 --- a/pkg/wireguard.go +++ b/pkg/wireguard.go @@ -2,6 +2,7 @@ package pkg import ( "encoding/hex" + "errors" "fmt" "io" "math/rand" @@ -34,10 +35,22 @@ func (peer WireguardPeer) WriteTo(sb io.StringWriter) { } } +func (base WireguardBase) Validate() error { + privateKeyCount := len(base.PrivateKey) / device.NoisePrivateKeySize + + if base.BrokerIndex >= privateKeyCount { + return errors.New("broker index beyond private key count") + } + + return nil +} + func (base WireguardBase) GenerateConfig() string { sb := strings.Builder{} - sb.WriteString(fmt.Sprintf("private_key=%s\n", hex.EncodeToString(base.PrivateKey))) + indexedPrivateKey := base.PrivateKey[device.NoisePrivateKeySize*base.BrokerIndex : device.NoisePrivateKeySize*(base.BrokerIndex+1)] + + sb.WriteString(fmt.Sprintf("private_key=%s\n", hex.EncodeToString(indexedPrivateKey))) sb.WriteString(fmt.Sprintf("listen_port=%d\n", base.ListenPort)) for i := range base.Peers { @@ -47,7 +60,16 @@ func (base WireguardBase) GenerateConfig() string { return sb.String() } -func (base *WireguardBase) ResolvePeerEndpoints() error { +func (base *WireguardBase) ResolveConfig() error { + resolvedLocalAddress, err := netip.ParseAddr(base.LocalAddress) + if err != nil { + return fmt.Errorf("LocalAddress parse failed: %v", err) + } + for i := 0; i < base.BrokerIndex; i++ { + resolvedLocalAddress = resolvedLocalAddress.Next() + } + base.resolvedLocalAddress = resolvedLocalAddress + for i := range base.Peers { if base.Peers[i].Endpoint == "" { continue @@ -77,14 +99,11 @@ func (config *WireguardBase) Start() (*netstack.Net, func() error, error) { return nil, nil, fmt.Errorf("invalid wireguard config: %v", err) } - // resolve peer endpoints (if not IP address already) - if err := config.ResolvePeerEndpoints(); err != nil { + // resolve local address and peer endpoints (if not IP address already) + if err := config.ResolveConfig(); err != nil { return nil, nil, fmt.Errorf("failed to resolve peer endpoint: %v", err) } - // parse localAddres and DNS addresses -- MustParseAddr is fine here because we've already validated the config - localAddress := netip.MustParseAddr(config.LocalAddress) - var dnsAddresses = make([]netip.Addr, len(config.Dns)) for i := range config.Dns { dnsAddresses[i] = netip.MustParseAddr(config.Dns[i]) @@ -92,7 +111,7 @@ func (config *WireguardBase) Start() (*netstack.Net, func() error, error) { // create the wireguard interface tun, tnet, err := netstack.CreateNetTUN( - []netip.Addr{localAddress}, + []netip.Addr{config.resolvedLocalAddress}, dnsAddresses, config.Mtu, )