Skip to content

Commit

Permalink
Revert "better support for running multiple separate brokers (#85)" (#87
Browse files Browse the repository at this point in the history
)

This reverts commit 66150c5.
  • Loading branch information
tpetr authored Oct 7, 2024
1 parent 66150c5 commit 63532f1
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 96 deletions.
2 changes: 1 addition & 1 deletion cmd/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var dumpCmd = &cobra.Command{
Use: "dump",
Short: "Dump current config",
Run: func(cmd *cobra.Command, args []string) {
config, err := pkg.LoadConfig(configFiles, deploymentId, brokerIndex)
config, err := pkg.LoadConfig(configFiles, deploymentId)
if err != nil {
log.Panic(err)
}
Expand Down
28 changes: 5 additions & 23 deletions cmd/genkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,25 @@ package cmd
import (
"encoding/base64"
"fmt"
"os"

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

var replicaCount int

const defaultReplicaCount = 3
const minReplicaCount = 1
const maxReplicaCount = 16

var genkeyCmd = &cobra.Command{
Use: "genkey",
Short: "Generates a random Semgrep Network Broker private key and prints it to stdout.",
Short: "Generates a random private key in base64 and prints it to stdout",
Run: func(cmd *cobra.Command, args []string) {
if replicaCount < minReplicaCount || replicaCount > maxReplicaCount {
log.Panic(fmt.Errorf("replica count must be between %v and %v", minReplicaCount, maxReplicaCount))
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
log.Panic(fmt.Errorf("failed to generate private key: %v", err))
}

encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout)
defer encoder.Close()

for i := 0; i < replicaCount; i++ {
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
log.Panic(fmt.Errorf("failed to generate private key %v: %v", i, err))
}
if _, err := encoder.Write(privateKey[:]); err != nil {
log.Panic(fmt.Errorf("failed to write private key %v: %v", i, err))
}
}
fmt.Println(base64.StdEncoding.EncodeToString(privateKey[:]))
},
}

func init() {
genkeyCmd.PersistentFlags().IntVarP(&replicaCount, "replica-count", "r", defaultReplicaCount, "Number of broker replicas to support")
rootCmd.AddCommand(genkeyCmd)
}
47 changes: 21 additions & 26 deletions cmd/pubkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,35 @@ import (

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

var pubkeyCmd = &cobra.Command{
Use: "pubkey",
Short: "Reads a Semgrep Network Broker private key from stdin and ptints the corresponding public key to stdout.",
Short: "Reads a base64 private key from stdin, outputs the corresponding base64 public key",
Run: func(cmd *cobra.Command, args []string) {
keyBase64, err := io.ReadAll(os.Stdin)
if err != nil {
log.Panic(err)
}

keyBytes := make([]byte, 32)
n, err := base64.StdEncoding.Decode(keyBytes, keyBase64)
if err != nil {
log.Panic(err)
}
if n != 32 {
log.Panic("not enough bytes")
}

decoder := base64.NewDecoder(base64.StdEncoding, os.Stdin)
encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout)
defer encoder.Close()

privateKeyBytes := make([]byte, device.NoisePrivateKeySize)

for i := 0; ; i++ {
_, err := io.ReadFull(decoder, privateKeyBytes)
if err != nil {
if err == io.EOF {
break
} else {
log.Panic(fmt.Errorf("error reading private key %v: %v", i, err))
}
}
privateKey, err := wgtypes.NewKey(privateKeyBytes)
if err != nil {
log.Panic(fmt.Errorf("error creating private key %v: %v", i, err))
}

publicKey := privateKey.PublicKey()
if _, err := encoder.Write(publicKey[:]); err != nil {
log.Panic(fmt.Errorf("error writing public key %v: %v", i, err))
}
privateKey, err := wgtypes.NewKey(keyBytes)
if err != nil {
log.Panic(err)
}

publicKey := privateKey.PublicKey()

fmt.Println(base64.StdEncoding.EncodeToString(publicKey[:]))
},
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ var relayCmd = &cobra.Command{
}()

// load config(s)
config, err := pkg.LoadConfig(configFiles, 0, 0)
config, err := pkg.LoadConfig(configFiles, 0)
if err != nil {
log.Panic(err)
}
Expand Down
6 changes: 2 additions & 4 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
var configFiles []string
var jsonLog bool
var deploymentId int
var brokerIndex int

var rootCmd = &cobra.Command{
Use: "semgrep-network-broker",
Expand All @@ -40,7 +39,7 @@ var rootCmd = &cobra.Command{
}()

// load config(s)
config, err := pkg.LoadConfig(configFiles, deploymentId, brokerIndex)
config, err := pkg.LoadConfig(configFiles, deploymentId)
if err != nil {
log.Panic(err)
}
Expand Down Expand Up @@ -76,7 +75,7 @@ func StartNetworkBroker(config *pkg.Config) (func() error, error) {
return wireguardTeardown()
}

// start inbound proxy (semgrep --> customer)
// start inbound proxy (r2c --> customer)
if err := config.Inbound.Start(tnet); err != nil {
teardown()
return nil, fmt.Errorf("failed to start inbound proxy: %v", err)
Expand All @@ -96,5 +95,4 @@ func init() {
rootCmd.PersistentFlags().StringArrayVarP(&configFiles, "config", "c", nil, "config file(s)")
rootCmd.PersistentFlags().BoolVarP(&jsonLog, "json-log", "j", false, "JSON log output")
rootCmd.PersistentFlags().IntVarP(&deploymentId, "deployment-id", "d", 0, "Semgrep deployment ID")
rootCmd.PersistentFlags().IntVarP(&brokerIndex, "broker-index", "i", 0, "Semgrep network broker index")
}
21 changes: 8 additions & 13 deletions pkg/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"io"
"net/http"
"net/netip"
"net/url"
"os"
"reflect"
Expand Down Expand Up @@ -73,15 +72,13 @@ type WireguardPeer struct {
}

type WireguardBase struct {
resolvedLocalAddress netip.Addr
LocalAddress string `mapstructure:"localAddress" json:"localAddress" validate:"format=ip"`
Dns []string `mapstructure:"dns" json:"dns" validate:"empty=true > format=ip"`
Mtu int `mapstructure:"mtu" json:"mtu" validate:"gte=0" default:"1420"`
PrivateKey SensitiveBase64String `mapstructure:"privateKey" json:"privateKey" validate:"empty=false"`
ListenPort int `mapstructure:"listenPort" json:"listenPort" validate:"gte=0"`
Peers []WireguardPeer `mapstructure:"peers" json:"peers" validate:"empty=false"`
Verbose bool `mapstructure:"verbose" json:"verbose"`
BrokerIndex int `mapstructure:"brokerIndex" json:"brokerIndex" validate:"gte=0"`
LocalAddress string `mapstructure:"localAddress" json:"localAddress" validate:"format=ip"`
Dns []string `mapstructure:"dns" json:"dns" validate:"empty=true > format=ip"`
Mtu int `mapstructure:"mtu" json:"mtu" validate:"gte=0" default:"1420"`
PrivateKey SensitiveBase64String `mapstructure:"privateKey" json:"privateKey" validate:"empty=false"`
ListenPort int `mapstructure:"listenPort" json:"listenPort" validate:"gte=0"`
Peers []WireguardPeer `mapstructure:"peers" json:"peers" validate:"empty=false"`
Verbose bool `mapstructure:"verbose" json:"verbose"`
}

type BitTester interface {
Expand Down Expand Up @@ -265,11 +262,9 @@ type Config struct {
Outbound OutboundProxyConfig `mapstructure:"outbound" json:"outbound"`
}

func LoadConfig(configFiles []string, deploymentId int, brokerIndex int) (*Config, error) {
func LoadConfig(configFiles []string, deploymentId int) (*Config, error) {
config := new(Config)

config.Inbound.Wireguard.BrokerIndex = brokerIndex

if deploymentId > 0 {
hostname := os.Getenv("SEMGREP_HOSTNAME")
if hostname == "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func TestEmptyConfigs(t *testing.T) {
config, err := LoadConfig(nil, 0, 0)
config, err := LoadConfig(nil, 0)
if err != nil {
t.Error(err)
}
Expand Down
35 changes: 8 additions & 27 deletions pkg/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package pkg

import (
"encoding/hex"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -35,22 +34,10 @@ func (peer WireguardPeer) WriteTo(sb io.StringWriter) {
}
}

func (base WireguardBase) Validate() error {
privateKeyCount := len(base.PrivateKey) / device.NoisePrivateKeySize

if base.BrokerIndex >= privateKeyCount {
return errors.New("broker index beyond private key count")
}

return nil
}

func (base WireguardBase) GenerateConfig() string {
sb := strings.Builder{}

indexedPrivateKey := base.PrivateKey[device.NoisePrivateKeySize*base.BrokerIndex : device.NoisePrivateKeySize*(base.BrokerIndex+1)]

sb.WriteString(fmt.Sprintf("private_key=%s\n", hex.EncodeToString(indexedPrivateKey)))
sb.WriteString(fmt.Sprintf("private_key=%s\n", hex.EncodeToString(base.PrivateKey)))
sb.WriteString(fmt.Sprintf("listen_port=%d\n", base.ListenPort))

for i := range base.Peers {
Expand All @@ -60,16 +47,7 @@ func (base WireguardBase) GenerateConfig() string {
return sb.String()
}

func (base *WireguardBase) ResolveConfig() error {
resolvedLocalAddress, err := netip.ParseAddr(base.LocalAddress)
if err != nil {
return fmt.Errorf("LocalAddress parse failed: %v", err)
}
for i := 0; i < base.BrokerIndex; i++ {
resolvedLocalAddress = resolvedLocalAddress.Next()
}
base.resolvedLocalAddress = resolvedLocalAddress

func (base *WireguardBase) ResolvePeerEndpoints() error {
for i := range base.Peers {
if base.Peers[i].Endpoint == "" {
continue
Expand Down Expand Up @@ -99,19 +77,22 @@ func (config *WireguardBase) Start() (*netstack.Net, func() error, error) {
return nil, nil, fmt.Errorf("invalid wireguard config: %v", err)
}

// resolve local address and peer endpoints (if not IP address already)
if err := config.ResolveConfig(); err != nil {
// resolve peer endpoints (if not IP address already)
if err := config.ResolvePeerEndpoints(); err != nil {
return nil, nil, fmt.Errorf("failed to resolve peer endpoint: %v", err)
}

// parse localAddres and DNS addresses -- MustParseAddr is fine here because we've already validated the config
localAddress := netip.MustParseAddr(config.LocalAddress)

var dnsAddresses = make([]netip.Addr, len(config.Dns))
for i := range config.Dns {
dnsAddresses[i] = netip.MustParseAddr(config.Dns[i])
}

// create the wireguard interface
tun, tnet, err := netstack.CreateNetTUN(
[]netip.Addr{config.resolvedLocalAddress},
[]netip.Addr{localAddress},
dnsAddresses,
config.Mtu,
)
Expand Down

0 comments on commit 63532f1

Please sign in to comment.