Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "better support for running multiple separate brokers" #87

Merged
merged 1 commit into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var dumpCmd = &cobra.Command{
Use: "dump",
Short: "Dump current config",
Run: func(cmd *cobra.Command, args []string) {
config, err := pkg.LoadConfig(configFiles, deploymentId, brokerIndex)
config, err := pkg.LoadConfig(configFiles, deploymentId)
if err != nil {
log.Panic(err)
}
Expand Down
28 changes: 5 additions & 23 deletions cmd/genkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,25 @@ package cmd
import (
"encoding/base64"
"fmt"
"os"

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

var replicaCount int

const defaultReplicaCount = 3
const minReplicaCount = 1
const maxReplicaCount = 16

var genkeyCmd = &cobra.Command{
Use: "genkey",
Short: "Generates a random Semgrep Network Broker private key and prints it to stdout.",
Short: "Generates a random private key in base64 and prints it to stdout",
Run: func(cmd *cobra.Command, args []string) {
if replicaCount < minReplicaCount || replicaCount > maxReplicaCount {
log.Panic(fmt.Errorf("replica count must be between %v and %v", minReplicaCount, maxReplicaCount))
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
log.Panic(fmt.Errorf("failed to generate private key: %v", err))
}

encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout)
defer encoder.Close()

for i := 0; i < replicaCount; i++ {
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
log.Panic(fmt.Errorf("failed to generate private key %v: %v", i, err))
}
if _, err := encoder.Write(privateKey[:]); err != nil {
log.Panic(fmt.Errorf("failed to write private key %v: %v", i, err))
}
}
fmt.Println(base64.StdEncoding.EncodeToString(privateKey[:]))
},
}

func init() {
genkeyCmd.PersistentFlags().IntVarP(&replicaCount, "replica-count", "r", defaultReplicaCount, "Number of broker replicas to support")
rootCmd.AddCommand(genkeyCmd)
}
47 changes: 21 additions & 26 deletions cmd/pubkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,35 @@ import (

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

var pubkeyCmd = &cobra.Command{
Use: "pubkey",
Short: "Reads a Semgrep Network Broker private key from stdin and ptints the corresponding public key to stdout.",
Short: "Reads a base64 private key from stdin, outputs the corresponding base64 public key",
Run: func(cmd *cobra.Command, args []string) {
keyBase64, err := io.ReadAll(os.Stdin)
if err != nil {
log.Panic(err)
}

keyBytes := make([]byte, 32)
n, err := base64.StdEncoding.Decode(keyBytes, keyBase64)
if err != nil {
log.Panic(err)
}
if n != 32 {
log.Panic("not enough bytes")
}

decoder := base64.NewDecoder(base64.StdEncoding, os.Stdin)
encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout)
defer encoder.Close()

privateKeyBytes := make([]byte, device.NoisePrivateKeySize)

for i := 0; ; i++ {
_, err := io.ReadFull(decoder, privateKeyBytes)
if err != nil {
if err == io.EOF {
break
} else {
log.Panic(fmt.Errorf("error reading private key %v: %v", i, err))
}
}
privateKey, err := wgtypes.NewKey(privateKeyBytes)
if err != nil {
log.Panic(fmt.Errorf("error creating private key %v: %v", i, err))
}

publicKey := privateKey.PublicKey()
if _, err := encoder.Write(publicKey[:]); err != nil {
log.Panic(fmt.Errorf("error writing public key %v: %v", i, err))
}
privateKey, err := wgtypes.NewKey(keyBytes)
if err != nil {
log.Panic(err)
}

publicKey := privateKey.PublicKey()

fmt.Println(base64.StdEncoding.EncodeToString(publicKey[:]))
},
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ var relayCmd = &cobra.Command{
}()

// load config(s)
config, err := pkg.LoadConfig(configFiles, 0, 0)
config, err := pkg.LoadConfig(configFiles, 0)
if err != nil {
log.Panic(err)
}
Expand Down
6 changes: 2 additions & 4 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
var configFiles []string
var jsonLog bool
var deploymentId int
var brokerIndex int

var rootCmd = &cobra.Command{
Use: "semgrep-network-broker",
Expand All @@ -40,7 +39,7 @@ var rootCmd = &cobra.Command{
}()

// load config(s)
config, err := pkg.LoadConfig(configFiles, deploymentId, brokerIndex)
config, err := pkg.LoadConfig(configFiles, deploymentId)
if err != nil {
log.Panic(err)
}
Expand Down Expand Up @@ -76,7 +75,7 @@ func StartNetworkBroker(config *pkg.Config) (func() error, error) {
return wireguardTeardown()
}

// start inbound proxy (semgrep --> customer)
// start inbound proxy (r2c --> customer)
if err := config.Inbound.Start(tnet); err != nil {
teardown()
return nil, fmt.Errorf("failed to start inbound proxy: %v", err)
Expand All @@ -96,5 +95,4 @@ func init() {
rootCmd.PersistentFlags().StringArrayVarP(&configFiles, "config", "c", nil, "config file(s)")
rootCmd.PersistentFlags().BoolVarP(&jsonLog, "json-log", "j", false, "JSON log output")
rootCmd.PersistentFlags().IntVarP(&deploymentId, "deployment-id", "d", 0, "Semgrep deployment ID")
rootCmd.PersistentFlags().IntVarP(&brokerIndex, "broker-index", "i", 0, "Semgrep network broker index")
}
21 changes: 8 additions & 13 deletions pkg/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"io"
"net/http"
"net/netip"
"net/url"
"os"
"reflect"
Expand Down Expand Up @@ -73,15 +72,13 @@ type WireguardPeer struct {
}

type WireguardBase struct {
resolvedLocalAddress netip.Addr
LocalAddress string `mapstructure:"localAddress" json:"localAddress" validate:"format=ip"`
Dns []string `mapstructure:"dns" json:"dns" validate:"empty=true > format=ip"`
Mtu int `mapstructure:"mtu" json:"mtu" validate:"gte=0" default:"1420"`
PrivateKey SensitiveBase64String `mapstructure:"privateKey" json:"privateKey" validate:"empty=false"`
ListenPort int `mapstructure:"listenPort" json:"listenPort" validate:"gte=0"`
Peers []WireguardPeer `mapstructure:"peers" json:"peers" validate:"empty=false"`
Verbose bool `mapstructure:"verbose" json:"verbose"`
BrokerIndex int `mapstructure:"brokerIndex" json:"brokerIndex" validate:"gte=0"`
LocalAddress string `mapstructure:"localAddress" json:"localAddress" validate:"format=ip"`
Dns []string `mapstructure:"dns" json:"dns" validate:"empty=true > format=ip"`
Mtu int `mapstructure:"mtu" json:"mtu" validate:"gte=0" default:"1420"`
PrivateKey SensitiveBase64String `mapstructure:"privateKey" json:"privateKey" validate:"empty=false"`
ListenPort int `mapstructure:"listenPort" json:"listenPort" validate:"gte=0"`
Peers []WireguardPeer `mapstructure:"peers" json:"peers" validate:"empty=false"`
Verbose bool `mapstructure:"verbose" json:"verbose"`
}

type BitTester interface {
Expand Down Expand Up @@ -265,11 +262,9 @@ type Config struct {
Outbound OutboundProxyConfig `mapstructure:"outbound" json:"outbound"`
}

func LoadConfig(configFiles []string, deploymentId int, brokerIndex int) (*Config, error) {
func LoadConfig(configFiles []string, deploymentId int) (*Config, error) {
config := new(Config)

config.Inbound.Wireguard.BrokerIndex = brokerIndex

if deploymentId > 0 {
hostname := os.Getenv("SEMGREP_HOSTNAME")
if hostname == "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func TestEmptyConfigs(t *testing.T) {
config, err := LoadConfig(nil, 0, 0)
config, err := LoadConfig(nil, 0)
if err != nil {
t.Error(err)
}
Expand Down
35 changes: 8 additions & 27 deletions pkg/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package pkg

import (
"encoding/hex"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -35,22 +34,10 @@ func (peer WireguardPeer) WriteTo(sb io.StringWriter) {
}
}

func (base WireguardBase) Validate() error {
privateKeyCount := len(base.PrivateKey) / device.NoisePrivateKeySize

if base.BrokerIndex >= privateKeyCount {
return errors.New("broker index beyond private key count")
}

return nil
}

func (base WireguardBase) GenerateConfig() string {
sb := strings.Builder{}

indexedPrivateKey := base.PrivateKey[device.NoisePrivateKeySize*base.BrokerIndex : device.NoisePrivateKeySize*(base.BrokerIndex+1)]

sb.WriteString(fmt.Sprintf("private_key=%s\n", hex.EncodeToString(indexedPrivateKey)))
sb.WriteString(fmt.Sprintf("private_key=%s\n", hex.EncodeToString(base.PrivateKey)))
sb.WriteString(fmt.Sprintf("listen_port=%d\n", base.ListenPort))

for i := range base.Peers {
Expand All @@ -60,16 +47,7 @@ func (base WireguardBase) GenerateConfig() string {
return sb.String()
}

func (base *WireguardBase) ResolveConfig() error {
resolvedLocalAddress, err := netip.ParseAddr(base.LocalAddress)
if err != nil {
return fmt.Errorf("LocalAddress parse failed: %v", err)
}
for i := 0; i < base.BrokerIndex; i++ {
resolvedLocalAddress = resolvedLocalAddress.Next()
}
base.resolvedLocalAddress = resolvedLocalAddress

func (base *WireguardBase) ResolvePeerEndpoints() error {
for i := range base.Peers {
if base.Peers[i].Endpoint == "" {
continue
Expand Down Expand Up @@ -99,19 +77,22 @@ func (config *WireguardBase) Start() (*netstack.Net, func() error, error) {
return nil, nil, fmt.Errorf("invalid wireguard config: %v", err)
}

// resolve local address and peer endpoints (if not IP address already)
if err := config.ResolveConfig(); err != nil {
// resolve peer endpoints (if not IP address already)
if err := config.ResolvePeerEndpoints(); err != nil {
return nil, nil, fmt.Errorf("failed to resolve peer endpoint: %v", err)
}

// parse localAddres and DNS addresses -- MustParseAddr is fine here because we've already validated the config
localAddress := netip.MustParseAddr(config.LocalAddress)

var dnsAddresses = make([]netip.Addr, len(config.Dns))
for i := range config.Dns {
dnsAddresses[i] = netip.MustParseAddr(config.Dns[i])
}

// create the wireguard interface
tun, tnet, err := netstack.CreateNetTUN(
[]netip.Addr{config.resolvedLocalAddress},
[]netip.Addr{localAddress},
dnsAddresses,
config.Mtu,
)
Expand Down
Loading