Skip to content

Commit

Permalink
fix: type compatibility (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaukas authored Mar 15, 2024
1 parent 3806a87 commit 58cc372
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 27 deletions.
89 changes: 62 additions & 27 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,27 @@ import (

var ErrNoDialer = errors.New("no dialer available")

func unprotectedDial(network, address string) (net.Conn, error) {
return net.Dial(network, address)
}
type Dialer struct {
protectedDial func(network, address string) (net.Conn, error)
unprotectedDial func(network, address string) (net.Conn, error)

var protectedDial = func(network, address string) (net.Conn, error) {
return nil, ErrNoDialer
configJSON []byte
configPB []byte
}

type Protector interface {
Protect(fd int) bool
func NewDialer() *Dialer {
return &Dialer{
protectedDial: func(network, address string) (net.Conn, error) {
return nil, ErrNoDialer
},
unprotectedDial: net.Dial,
}
}

// SetProtector updates the protectedDial function to use the provided Protector
// to protect the file descriptor of the connection.
func SetProtector(p Protector) {
protectedDial = func(network, address string) (net.Conn, error) {
func (d *Dialer) SetProtector(p Protector) {
d.protectedDial = func(network, address string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: time.Second * 16,
LocalAddr: nil,
Expand All @@ -50,35 +55,65 @@ func SetProtector(p Protector) {
}
}

func ProtectedDialWATER(remoteAddr string, wasm []byte) (net.Conn, error) {
return dialWATER(remoteAddr, wasm, protectedDial)
func (d *Dialer) SetConfigJSON(configJSON []byte) {
d.configJSON = configJSON
d.configPB = nil
}

func (d *Dialer) SetConfigPB(configPB []byte) {
d.configPB = configPB
d.configJSON = nil
}

func (d *Dialer) DialWATERProtected(network, remoteAddr string, wasm []byte) (NetConn, error) {
return d.dialWATER(network, remoteAddr, wasm, d.protectedDial)
}

func (d *Dialer) DialWATERUnprotected(network, remoteAddr string, wasm []byte) (NetConn, error) {
return d.dialWATER(network, remoteAddr, wasm, d.unprotectedDial)
}

func (d *Dialer) DirectlyStartWorkerProtected(network, remoteAddr string, wasm []byte) error {
conn, err := d.DialWATERProtected(network, remoteAddr, wasm)
if err != nil {
panic(fmt.Sprintf("failed to dial: %v", err))
}

return startWorker(conn)
}

func UnprotectedDialWATER(remoteAddr string, wasm []byte) (net.Conn, error) {
return dialWATER(remoteAddr, wasm, unprotectedDial)
func (d *Dialer) DirectlyStartWorkerUnprotected(network, remoteAddr string, wasm []byte) error {
conn, err := d.DialWATERUnprotected(network, remoteAddr, wasm)
if err != nil {
panic(fmt.Sprintf("failed to dial: %v", err))
}

return startWorker(conn)
}

func dialWATER(remoteAddr string, wasm []byte, dialerFunc func(network string, address string) (net.Conn, error)) (net.Conn, error) {
func (d *Dialer) dialWATER(network, remoteAddr string,
wasm []byte,
dialerFunc func(network, address string) (net.Conn, error),
) (NetConn, error) {
config := &water.Config{
TransportModuleBin: wasm,
NetworkDialerFunc: dialerFunc,
}
// configuring the standard out of the WebAssembly instance to inherit
// from the parent process
config.ModuleConfig().InheritStdout()
config.ModuleConfig().InheritStderr()

if d.configJSON != nil {
config.UnmarshalJSON(d.configJSON)
} else if d.configPB != nil {
config.UnmarshalProto(d.configPB)
}

ctx := context.Background()
// // optional: enable wazero logging
// ctx = context.WithValue(ctx, experimental.FunctionListenerFactoryKey{},
// logging.NewHostLoggingListenerFactory(os.Stderr, logging.LogScopeFilesystem|logging.LogScopePoll|logging.LogScopeSock))

dialer, err := water.NewDialerWithContext(ctx, config)
if err != nil {
panic(fmt.Sprintf("failed to create dialer: %v", err))
}

conn, err := dialer.DialContext(ctx, "tcp", remoteAddr)
conn, err := dialer.DialContext(ctx, network, remoteAddr)
if err != nil {
panic(fmt.Sprintf("failed to dial: %v", err))
}
Expand All @@ -87,10 +122,10 @@ func dialWATER(remoteAddr string, wasm []byte, dialerFunc func(network string, a
// So effectively, W.A.T.E.R. API ends here and everything below
// this line is just how you treat a net.Conn.

return conn, nil
return &netConn{conn}, nil
}

func StartWorker(conn net.Conn) {
func startWorker(conn NetConn) error {
defer conn.Close()

log.Printf("Connected to %s", conn.RemoteAddr())
Expand Down Expand Up @@ -119,22 +154,22 @@ func StartWorker(conn net.Conn) {
select {
case msg := <-chanMsgRecv:
if msg == nil {
return // connection closed
return errors.New("connection closed")
}
log.Printf("peer: %x\n", msg)
case <-ticker.C:
n, err := rand.Read(sendBuf)
if err != nil {
log.Printf("rand.Read: error %v, tearing down connection...", err)
return
return err
}
// print the bytes sending as hex string
log.Printf("sending: %x\n", sendBuf[:n])

_, err = conn.Write(sendBuf[:n])
if err != nil {
log.Printf("write: error %v, tearing down connection...", err)
return
return err
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions protector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package watermob

type Protector interface {
Protect(fd int) bool
}
67 changes: 67 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package watermob

import (
"net"
"time"
)

// This file contains the explicitly defined types for the watermob package that
// maps to standard library types. This is done to allow for easy swapping of

// net.Addr
type NetAddr interface {
Network() string
String() string
}

// net.Conn
type NetConn interface {
Read(b []byte) (n int, err error)
Write(b []byte) (n int, err error)
Close() error
LocalAddr() NetAddr
RemoteAddr() NetAddr
SetDeadline(t time.Time) error
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
}

type netConn struct {
embeddedConn net.Conn
}

func (c *netConn) Read(b []byte) (n int, err error) {
return c.embeddedConn.Read(b)
}

func (c *netConn) Write(b []byte) (n int, err error) {
return c.embeddedConn.Write(b)
}

func (c *netConn) Close() error {
return c.embeddedConn.Close()
}

func (c *netConn) LocalAddr() NetAddr {
return c.embeddedConn.LocalAddr()
}

func (c *netConn) RemoteAddr() NetAddr {
return c.embeddedConn.RemoteAddr()
}

func (c *netConn) SetDeadline(t time.Time) error {
return c.embeddedConn.SetDeadline(t)
}

func (c *netConn) SetReadDeadline(t time.Time) error {
return c.embeddedConn.SetReadDeadline(t)
}

func (c *netConn) SetWriteDeadline(t time.Time) error {
return c.embeddedConn.SetWriteDeadline(t)
}

func NewNetConn(c net.Conn) NetConn {
return &netConn{c}
}

0 comments on commit 58cc372

Please sign in to comment.