Skip to content

Commit

Permalink
better support for running multiple separate brokers (#85)
Browse files Browse the repository at this point in the history
* better support for HA brokers

* slightly better genkey/pubkey impl
  • Loading branch information
tpetr authored Sep 17, 2024
1 parent 62bfa29 commit 66150c5
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 47 deletions.
2 changes: 1 addition & 1 deletion cmd/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var dumpCmd = &cobra.Command{
Use: "dump",
Short: "Dump current config",
Run: func(cmd *cobra.Command, args []string) {
config, err := pkg.LoadConfig(configFiles, deploymentId)
config, err := pkg.LoadConfig(configFiles, deploymentId, brokerIndex)
if err != nil {
log.Panic(err)
}
Expand Down
28 changes: 23 additions & 5 deletions cmd/genkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,43 @@ package cmd
import (
"encoding/base64"
"fmt"
"os"

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

var replicaCount int

const defaultReplicaCount = 3
const minReplicaCount = 1
const maxReplicaCount = 16

var genkeyCmd = &cobra.Command{
Use: "genkey",
Short: "Generates a random private key in base64 and prints it to stdout",
Short: "Generates a random Semgrep Network Broker private key and prints it to stdout.",
Run: func(cmd *cobra.Command, args []string) {
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
log.Panic(fmt.Errorf("failed to generate private key: %v", err))
if replicaCount < minReplicaCount || replicaCount > maxReplicaCount {
log.Panic(fmt.Errorf("replica count must be between %v and %v", minReplicaCount, maxReplicaCount))
}

fmt.Println(base64.StdEncoding.EncodeToString(privateKey[:]))
encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout)
defer encoder.Close()

for i := 0; i < replicaCount; i++ {
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
log.Panic(fmt.Errorf("failed to generate private key %v: %v", i, err))
}
if _, err := encoder.Write(privateKey[:]); err != nil {
log.Panic(fmt.Errorf("failed to write private key %v: %v", i, err))
}
}
},
}

func init() {
genkeyCmd.PersistentFlags().IntVarP(&replicaCount, "replica-count", "r", defaultReplicaCount, "Number of broker replicas to support")
rootCmd.AddCommand(genkeyCmd)
}
47 changes: 26 additions & 21 deletions cmd/pubkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,40 @@ import (

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

var pubkeyCmd = &cobra.Command{
Use: "pubkey",
Short: "Reads a base64 private key from stdin, outputs the corresponding base64 public key",
Short: "Reads a Semgrep Network Broker private key from stdin and ptints the corresponding public key to stdout.",
Run: func(cmd *cobra.Command, args []string) {
keyBase64, err := io.ReadAll(os.Stdin)
if err != nil {
log.Panic(err)
}

keyBytes := make([]byte, 32)
n, err := base64.StdEncoding.Decode(keyBytes, keyBase64)
if err != nil {
log.Panic(err)
}
if n != 32 {
log.Panic("not enough bytes")
}

privateKey, err := wgtypes.NewKey(keyBytes)
if err != nil {
log.Panic(err)
decoder := base64.NewDecoder(base64.StdEncoding, os.Stdin)
encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout)
defer encoder.Close()

privateKeyBytes := make([]byte, device.NoisePrivateKeySize)

for i := 0; ; i++ {
_, err := io.ReadFull(decoder, privateKeyBytes)
if err != nil {
if err == io.EOF {
break
} else {
log.Panic(fmt.Errorf("error reading private key %v: %v", i, err))
}
}
privateKey, err := wgtypes.NewKey(privateKeyBytes)
if err != nil {
log.Panic(fmt.Errorf("error creating private key %v: %v", i, err))
}

publicKey := privateKey.PublicKey()
if _, err := encoder.Write(publicKey[:]); err != nil {
log.Panic(fmt.Errorf("error writing public key %v: %v", i, err))
}
}

publicKey := privateKey.PublicKey()

fmt.Println(base64.StdEncoding.EncodeToString(publicKey[:]))
},
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ var relayCmd = &cobra.Command{
}()

// load config(s)
config, err := pkg.LoadConfig(configFiles, 0)
config, err := pkg.LoadConfig(configFiles, 0, 0)
if err != nil {
log.Panic(err)
}
Expand Down
6 changes: 4 additions & 2 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
var configFiles []string
var jsonLog bool
var deploymentId int
var brokerIndex int

var rootCmd = &cobra.Command{
Use: "semgrep-network-broker",
Expand All @@ -39,7 +40,7 @@ var rootCmd = &cobra.Command{
}()

// load config(s)
config, err := pkg.LoadConfig(configFiles, deploymentId)
config, err := pkg.LoadConfig(configFiles, deploymentId, brokerIndex)
if err != nil {
log.Panic(err)
}
Expand Down Expand Up @@ -75,7 +76,7 @@ func StartNetworkBroker(config *pkg.Config) (func() error, error) {
return wireguardTeardown()
}

// start inbound proxy (r2c --> customer)
// start inbound proxy (semgrep --> customer)
if err := config.Inbound.Start(tnet); err != nil {
teardown()
return nil, fmt.Errorf("failed to start inbound proxy: %v", err)
Expand All @@ -95,4 +96,5 @@ func init() {
rootCmd.PersistentFlags().StringArrayVarP(&configFiles, "config", "c", nil, "config file(s)")
rootCmd.PersistentFlags().BoolVarP(&jsonLog, "json-log", "j", false, "JSON log output")
rootCmd.PersistentFlags().IntVarP(&deploymentId, "deployment-id", "d", 0, "Semgrep deployment ID")
rootCmd.PersistentFlags().IntVarP(&brokerIndex, "broker-index", "i", 0, "Semgrep network broker index")
}
21 changes: 13 additions & 8 deletions pkg/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"net/netip"
"net/url"
"os"
"reflect"
Expand Down Expand Up @@ -72,13 +73,15 @@ type WireguardPeer struct {
}

type WireguardBase struct {
LocalAddress string `mapstructure:"localAddress" json:"localAddress" validate:"format=ip"`
Dns []string `mapstructure:"dns" json:"dns" validate:"empty=true > format=ip"`
Mtu int `mapstructure:"mtu" json:"mtu" validate:"gte=0" default:"1420"`
PrivateKey SensitiveBase64String `mapstructure:"privateKey" json:"privateKey" validate:"empty=false"`
ListenPort int `mapstructure:"listenPort" json:"listenPort" validate:"gte=0"`
Peers []WireguardPeer `mapstructure:"peers" json:"peers" validate:"empty=false"`
Verbose bool `mapstructure:"verbose" json:"verbose"`
resolvedLocalAddress netip.Addr
LocalAddress string `mapstructure:"localAddress" json:"localAddress" validate:"format=ip"`
Dns []string `mapstructure:"dns" json:"dns" validate:"empty=true > format=ip"`
Mtu int `mapstructure:"mtu" json:"mtu" validate:"gte=0" default:"1420"`
PrivateKey SensitiveBase64String `mapstructure:"privateKey" json:"privateKey" validate:"empty=false"`
ListenPort int `mapstructure:"listenPort" json:"listenPort" validate:"gte=0"`
Peers []WireguardPeer `mapstructure:"peers" json:"peers" validate:"empty=false"`
Verbose bool `mapstructure:"verbose" json:"verbose"`
BrokerIndex int `mapstructure:"brokerIndex" json:"brokerIndex" validate:"gte=0"`
}

type BitTester interface {
Expand Down Expand Up @@ -262,9 +265,11 @@ type Config struct {
Outbound OutboundProxyConfig `mapstructure:"outbound" json:"outbound"`
}

func LoadConfig(configFiles []string, deploymentId int) (*Config, error) {
func LoadConfig(configFiles []string, deploymentId int, brokerIndex int) (*Config, error) {
config := new(Config)

config.Inbound.Wireguard.BrokerIndex = brokerIndex

if deploymentId > 0 {
hostname := os.Getenv("SEMGREP_HOSTNAME")
if hostname == "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func TestEmptyConfigs(t *testing.T) {
config, err := LoadConfig(nil, 0)
config, err := LoadConfig(nil, 0, 0)
if err != nil {
t.Error(err)
}
Expand Down
35 changes: 27 additions & 8 deletions pkg/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pkg

import (
"encoding/hex"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -34,10 +35,22 @@ func (peer WireguardPeer) WriteTo(sb io.StringWriter) {
}
}

func (base WireguardBase) Validate() error {
privateKeyCount := len(base.PrivateKey) / device.NoisePrivateKeySize

if base.BrokerIndex >= privateKeyCount {
return errors.New("broker index beyond private key count")
}

return nil
}

func (base WireguardBase) GenerateConfig() string {
sb := strings.Builder{}

sb.WriteString(fmt.Sprintf("private_key=%s\n", hex.EncodeToString(base.PrivateKey)))
indexedPrivateKey := base.PrivateKey[device.NoisePrivateKeySize*base.BrokerIndex : device.NoisePrivateKeySize*(base.BrokerIndex+1)]

sb.WriteString(fmt.Sprintf("private_key=%s\n", hex.EncodeToString(indexedPrivateKey)))
sb.WriteString(fmt.Sprintf("listen_port=%d\n", base.ListenPort))

for i := range base.Peers {
Expand All @@ -47,7 +60,16 @@ func (base WireguardBase) GenerateConfig() string {
return sb.String()
}

func (base *WireguardBase) ResolvePeerEndpoints() error {
func (base *WireguardBase) ResolveConfig() error {
resolvedLocalAddress, err := netip.ParseAddr(base.LocalAddress)
if err != nil {
return fmt.Errorf("LocalAddress parse failed: %v", err)
}
for i := 0; i < base.BrokerIndex; i++ {
resolvedLocalAddress = resolvedLocalAddress.Next()
}
base.resolvedLocalAddress = resolvedLocalAddress

for i := range base.Peers {
if base.Peers[i].Endpoint == "" {
continue
Expand Down Expand Up @@ -77,22 +99,19 @@ func (config *WireguardBase) Start() (*netstack.Net, func() error, error) {
return nil, nil, fmt.Errorf("invalid wireguard config: %v", err)
}

// resolve peer endpoints (if not IP address already)
if err := config.ResolvePeerEndpoints(); err != nil {
// resolve local address and peer endpoints (if not IP address already)
if err := config.ResolveConfig(); err != nil {
return nil, nil, fmt.Errorf("failed to resolve peer endpoint: %v", err)
}

// parse localAddres and DNS addresses -- MustParseAddr is fine here because we've already validated the config
localAddress := netip.MustParseAddr(config.LocalAddress)

var dnsAddresses = make([]netip.Addr, len(config.Dns))
for i := range config.Dns {
dnsAddresses[i] = netip.MustParseAddr(config.Dns[i])
}

// create the wireguard interface
tun, tnet, err := netstack.CreateNetTUN(
[]netip.Addr{localAddress},
[]netip.Addr{config.resolvedLocalAddress},
dnsAddresses,
config.Mtu,
)
Expand Down

0 comments on commit 66150c5

Please sign in to comment.