diff --git a/cmd/dump.go b/cmd/dump.go index 2911aa5..3d6703f 100644 --- a/cmd/dump.go +++ b/cmd/dump.go @@ -13,7 +13,7 @@ var dumpCmd = &cobra.Command{ Use: "dump", Short: "Dump current config", Run: func(cmd *cobra.Command, args []string) { - config, err := pkg.LoadConfig(configFiles, deploymentId, brokerIndex) + config, err := pkg.LoadConfig(configFiles, deploymentId) if err != nil { log.Panic(err) } diff --git a/cmd/genkey.go b/cmd/genkey.go index b1ff9db..2bb277e 100644 --- a/cmd/genkey.go +++ b/cmd/genkey.go @@ -3,43 +3,25 @@ package cmd import ( "encoding/base64" "fmt" - "os" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -var replicaCount int - -const defaultReplicaCount = 3 -const minReplicaCount = 1 -const maxReplicaCount = 16 - var genkeyCmd = &cobra.Command{ Use: "genkey", - Short: "Generates a random Semgrep Network Broker private key and prints it to stdout.", + Short: "Generates a random private key in base64 and prints it to stdout", Run: func(cmd *cobra.Command, args []string) { - if replicaCount < minReplicaCount || replicaCount > maxReplicaCount { - log.Panic(fmt.Errorf("replica count must be between %v and %v", minReplicaCount, maxReplicaCount)) + privateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + log.Panic(fmt.Errorf("failed to generate private key: %v", err)) } - encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout) - defer encoder.Close() - - for i := 0; i < replicaCount; i++ { - privateKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - log.Panic(fmt.Errorf("failed to generate private key %v: %v", i, err)) - } - if _, err := encoder.Write(privateKey[:]); err != nil { - log.Panic(fmt.Errorf("failed to write private key %v: %v", i, err)) - } - } + fmt.Println(base64.StdEncoding.EncodeToString(privateKey[:])) }, } func init() { - genkeyCmd.PersistentFlags().IntVarP(&replicaCount, "replica-count", "r", defaultReplicaCount, "Number of broker replicas to support") rootCmd.AddCommand(genkeyCmd) } diff --git a/cmd/pubkey.go b/cmd/pubkey.go index 32c1d95..cd3b341 100644 --- a/cmd/pubkey.go +++ b/cmd/pubkey.go @@ -8,40 +8,35 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) var pubkeyCmd = &cobra.Command{ Use: "pubkey", - Short: "Reads a Semgrep Network Broker private key from stdin and ptints the corresponding public key to stdout.", + Short: "Reads a base64 private key from stdin, outputs the corresponding base64 public key", Run: func(cmd *cobra.Command, args []string) { + keyBase64, err := io.ReadAll(os.Stdin) + if err != nil { + log.Panic(err) + } + + keyBytes := make([]byte, 32) + n, err := base64.StdEncoding.Decode(keyBytes, keyBase64) + if err != nil { + log.Panic(err) + } + if n != 32 { + log.Panic("not enough bytes") + } - decoder := base64.NewDecoder(base64.StdEncoding, os.Stdin) - encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout) - defer encoder.Close() - - privateKeyBytes := make([]byte, device.NoisePrivateKeySize) - - for i := 0; ; i++ { - _, err := io.ReadFull(decoder, privateKeyBytes) - if err != nil { - if err == io.EOF { - break - } else { - log.Panic(fmt.Errorf("error reading private key %v: %v", i, err)) - } - } - privateKey, err := wgtypes.NewKey(privateKeyBytes) - if err != nil { - log.Panic(fmt.Errorf("error creating private key %v: %v", i, err)) - } - - publicKey := privateKey.PublicKey() - if _, err := encoder.Write(publicKey[:]); err != nil { - log.Panic(fmt.Errorf("error writing public key %v: %v", i, err)) - } + privateKey, err := wgtypes.NewKey(keyBytes) + if err != nil { + log.Panic(err) } + + publicKey := privateKey.PublicKey() + + fmt.Println(base64.StdEncoding.EncodeToString(publicKey[:])) }, } diff --git a/cmd/relay.go b/cmd/relay.go index c05d6bd..0820cd9 100644 --- a/cmd/relay.go +++ b/cmd/relay.go @@ -30,7 +30,7 @@ var relayCmd = &cobra.Command{ }() // load config(s) - config, err := pkg.LoadConfig(configFiles, 0, 0) + config, err := pkg.LoadConfig(configFiles, 0) if err != nil { log.Panic(err) } diff --git a/cmd/root.go b/cmd/root.go index ef176b6..f49c8d7 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -18,7 +18,6 @@ import ( var configFiles []string var jsonLog bool var deploymentId int -var brokerIndex int var rootCmd = &cobra.Command{ Use: "semgrep-network-broker", @@ -40,7 +39,7 @@ var rootCmd = &cobra.Command{ }() // load config(s) - config, err := pkg.LoadConfig(configFiles, deploymentId, brokerIndex) + config, err := pkg.LoadConfig(configFiles, deploymentId) if err != nil { log.Panic(err) } @@ -76,7 +75,7 @@ func StartNetworkBroker(config *pkg.Config) (func() error, error) { return wireguardTeardown() } - // start inbound proxy (semgrep --> customer) + // start inbound proxy (r2c --> customer) if err := config.Inbound.Start(tnet); err != nil { teardown() return nil, fmt.Errorf("failed to start inbound proxy: %v", err) @@ -96,5 +95,4 @@ func init() { rootCmd.PersistentFlags().StringArrayVarP(&configFiles, "config", "c", nil, "config file(s)") rootCmd.PersistentFlags().BoolVarP(&jsonLog, "json-log", "j", false, "JSON log output") rootCmd.PersistentFlags().IntVarP(&deploymentId, "deployment-id", "d", 0, "Semgrep deployment ID") - rootCmd.PersistentFlags().IntVarP(&brokerIndex, "broker-index", "i", 0, "Semgrep network broker index") } diff --git a/pkg/config.go b/pkg/config.go index 588e22a..84f7310 100644 --- a/pkg/config.go +++ b/pkg/config.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "net/http" - "net/netip" "net/url" "os" "reflect" @@ -73,15 +72,13 @@ type WireguardPeer struct { } type WireguardBase struct { - resolvedLocalAddress netip.Addr - LocalAddress string `mapstructure:"localAddress" json:"localAddress" validate:"format=ip"` - Dns []string `mapstructure:"dns" json:"dns" validate:"empty=true > format=ip"` - Mtu int `mapstructure:"mtu" json:"mtu" validate:"gte=0" default:"1420"` - PrivateKey SensitiveBase64String `mapstructure:"privateKey" json:"privateKey" validate:"empty=false"` - ListenPort int `mapstructure:"listenPort" json:"listenPort" validate:"gte=0"` - Peers []WireguardPeer `mapstructure:"peers" json:"peers" validate:"empty=false"` - Verbose bool `mapstructure:"verbose" json:"verbose"` - BrokerIndex int `mapstructure:"brokerIndex" json:"brokerIndex" validate:"gte=0"` + LocalAddress string `mapstructure:"localAddress" json:"localAddress" validate:"format=ip"` + Dns []string `mapstructure:"dns" json:"dns" validate:"empty=true > format=ip"` + Mtu int `mapstructure:"mtu" json:"mtu" validate:"gte=0" default:"1420"` + PrivateKey SensitiveBase64String `mapstructure:"privateKey" json:"privateKey" validate:"empty=false"` + ListenPort int `mapstructure:"listenPort" json:"listenPort" validate:"gte=0"` + Peers []WireguardPeer `mapstructure:"peers" json:"peers" validate:"empty=false"` + Verbose bool `mapstructure:"verbose" json:"verbose"` } type BitTester interface { @@ -265,11 +262,9 @@ type Config struct { Outbound OutboundProxyConfig `mapstructure:"outbound" json:"outbound"` } -func LoadConfig(configFiles []string, deploymentId int, brokerIndex int) (*Config, error) { +func LoadConfig(configFiles []string, deploymentId int) (*Config, error) { config := new(Config) - config.Inbound.Wireguard.BrokerIndex = brokerIndex - if deploymentId > 0 { hostname := os.Getenv("SEMGREP_HOSTNAME") if hostname == "" { diff --git a/pkg/config_test.go b/pkg/config_test.go index 36786b3..b7fdded 100644 --- a/pkg/config_test.go +++ b/pkg/config_test.go @@ -11,7 +11,7 @@ import ( ) func TestEmptyConfigs(t *testing.T) { - config, err := LoadConfig(nil, 0, 0) + config, err := LoadConfig(nil, 0) if err != nil { t.Error(err) } diff --git a/pkg/wireguard.go b/pkg/wireguard.go index e70f60c..dba3be8 100644 --- a/pkg/wireguard.go +++ b/pkg/wireguard.go @@ -2,7 +2,6 @@ package pkg import ( "encoding/hex" - "errors" "fmt" "io" "math/rand" @@ -35,22 +34,10 @@ func (peer WireguardPeer) WriteTo(sb io.StringWriter) { } } -func (base WireguardBase) Validate() error { - privateKeyCount := len(base.PrivateKey) / device.NoisePrivateKeySize - - if base.BrokerIndex >= privateKeyCount { - return errors.New("broker index beyond private key count") - } - - return nil -} - func (base WireguardBase) GenerateConfig() string { sb := strings.Builder{} - indexedPrivateKey := base.PrivateKey[device.NoisePrivateKeySize*base.BrokerIndex : device.NoisePrivateKeySize*(base.BrokerIndex+1)] - - sb.WriteString(fmt.Sprintf("private_key=%s\n", hex.EncodeToString(indexedPrivateKey))) + sb.WriteString(fmt.Sprintf("private_key=%s\n", hex.EncodeToString(base.PrivateKey))) sb.WriteString(fmt.Sprintf("listen_port=%d\n", base.ListenPort)) for i := range base.Peers { @@ -60,16 +47,7 @@ func (base WireguardBase) GenerateConfig() string { return sb.String() } -func (base *WireguardBase) ResolveConfig() error { - resolvedLocalAddress, err := netip.ParseAddr(base.LocalAddress) - if err != nil { - return fmt.Errorf("LocalAddress parse failed: %v", err) - } - for i := 0; i < base.BrokerIndex; i++ { - resolvedLocalAddress = resolvedLocalAddress.Next() - } - base.resolvedLocalAddress = resolvedLocalAddress - +func (base *WireguardBase) ResolvePeerEndpoints() error { for i := range base.Peers { if base.Peers[i].Endpoint == "" { continue @@ -99,11 +77,14 @@ func (config *WireguardBase) Start() (*netstack.Net, func() error, error) { return nil, nil, fmt.Errorf("invalid wireguard config: %v", err) } - // resolve local address and peer endpoints (if not IP address already) - if err := config.ResolveConfig(); err != nil { + // resolve peer endpoints (if not IP address already) + if err := config.ResolvePeerEndpoints(); err != nil { return nil, nil, fmt.Errorf("failed to resolve peer endpoint: %v", err) } + // parse localAddres and DNS addresses -- MustParseAddr is fine here because we've already validated the config + localAddress := netip.MustParseAddr(config.LocalAddress) + var dnsAddresses = make([]netip.Addr, len(config.Dns)) for i := range config.Dns { dnsAddresses[i] = netip.MustParseAddr(config.Dns[i]) @@ -111,7 +92,7 @@ func (config *WireguardBase) Start() (*netstack.Net, func() error, error) { // create the wireguard interface tun, tnet, err := netstack.CreateNetTUN( - []netip.Addr{config.resolvedLocalAddress}, + []netip.Addr{localAddress}, dnsAddresses, config.Mtu, )