Skip to content

Commit

Permalink
Add support for fd passing
Browse files Browse the repository at this point in the history
This adds a new message type for passing file descriptors.
How this works is:

1. Client sends a message with a header for messageTypeFileDescriptor
   along with the list of descriptors to be sent
2. Client sends 2nd message to actually pass along the descriptors
   (needed for unix sockets).
3. Server sees the message type and waits to receive the fd's.
4. Once fd's are seen the server responds with the real fd numbers that
   are used which an application can use in future calls.

To accomplish this reliably (on unix sockets) I had to drop the usage of
the bufio.Reader because we need to ensure exact message boundaries.

Within ttrpc this only support unix sockets and `net.Conn` implementations
that implement `SendFds`/`ReceiveFds` (this interface is totally
invented here).

Something to consider, I have not attempted to do fd passing on Windows
which will need other mechanisms entirely (and the conn's provided by
winio are not sufficient for fd passing).
I'm not sure if this new messaging will actually work on a Windows
implementation.

Perhaps the message tpye should be specifically for unix sockets? I'm
not sure how this would be enforced at the moment except by checking if
the `net.Conn` is a `*net.UnixConn`.

Signed-off-by: Brian Goff <[email protected]>
  • Loading branch information
cpuguy83 committed May 7, 2021
1 parent 25f5476 commit a19e82d
Show file tree
Hide file tree
Showing 7 changed files with 472 additions and 37 deletions.
94 changes: 90 additions & 4 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ package ttrpc
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"net"
"sync"

"github.com/pkg/errors"
"golang.org/x/sys/unix"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
Expand All @@ -36,8 +38,9 @@ const (
type messageType uint8

const (
messageTypeRequest messageType = 0x1
messageTypeResponse messageType = 0x2
messageTypeRequest messageType = 0x1
messageTypeResponse messageType = 0x2
messageTypeFileDescriptor messageType = 0x3
)

// messageHeader represents the fixed-length message header of 10 bytes sent
Expand Down Expand Up @@ -98,7 +101,7 @@ func newChannel(conn net.Conn) *channel {
// the correct consumer. The bytes on the underlying channel
// will be discarded.
func (ch *channel) recv() (messageHeader, []byte, error) {
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
mh, err := readMessageHeader(ch.hrbuf[:messageHeaderLength], ch.conn)
if err != nil {
return messageHeader{}, nil, err
}
Expand All @@ -112,13 +115,78 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
}

p := ch.getmbuf(int(mh.Length))
if _, err := io.ReadFull(ch.br, p); err != nil {
if _, err := io.ReadFull(ch.conn, p[:int(mh.Length)]); err != nil {
return messageHeader{}, nil, errors.Wrapf(err, "failed reading message")
}

return mh, p, nil
}

func (ch *channel) recvFD(files *FileList) error {
var (
fds []int
err error
)

switch t := ch.conn.(type) {
case FdReceiver:
fds, err = t.Recvfd()
if err != nil {
return err
}
case unixReader:
oob := ch.getmbuf(unix.CmsgSpace(len(files.List) * 4))
defer ch.putmbuf(oob)

_, oobn, _, _, err := t.ReadMsgUnix(make([]byte, 1), oob)
if err != nil {
return err
}

ls, err := unix.ParseSocketControlMessage(oob[:oobn])
if err != nil {
return fmt.Errorf("error parsing socket controll message: %w", err)
}

for _, m := range ls {
fdsTemp, err := unix.ParseUnixRights(&m)
if err != nil {
return fmt.Errorf("error parsing unix rights message: %w", err)
}
fds = append(fds, fdsTemp...)
}
default:
return fmt.Errorf("receiving file descriptors is not supported on the transport")
}

if len(files.List) != len(fds) {
return fmt.Errorf("received %d file descriptors, expected %d", len(fds), len(files.List))
}
for i, fd := range fds {
files.List[i].Fileno = int64(fd)
}
return nil
}

func (ch *channel) sendFd(streamID uint32, mt messageType, files *FileList) error {
fds := make([]int, len(files.List))

for i, f := range files.List {
fds[i] = int(f.Fileno)
}

switch t := ch.conn.(type) {
case unixWriter:
// Must send at least a single byte over unix sockets for the ancillary data to be accepted.
_, _, err := t.WriteMsgUnix(make([]byte, 1), unix.UnixRights(fds...), nil)
return err
case FdSender:
return t.SendFd(fds)
default:
return fmt.Errorf("sending file descriptors is not supported on the transport")
}
}

func (ch *channel) send(streamID uint32, t messageType, p []byte) error {
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
return err
Expand Down Expand Up @@ -151,3 +219,21 @@ func (ch *channel) getmbuf(size int) []byte {
func (ch *channel) putmbuf(p []byte) {
buffers.Put(&p)
}

// FdReceiver is an interface used that the transport may implement to receive file descriptors from the client
type FdReceiver interface {
Recvfd() ([]int, error)
}

// FdSender is an interface used that the transport may implement to send file descriptors to the server.
type FdSender interface {
SendFd([]int) error
}

type unixReader interface {
ReadMsgUnix(p, oob []byte) (n, oobn, flags int, addr *net.UnixAddr, err error)
}

type unixWriter interface {
WriteMsgUnix(b, oob []byte, addr *net.UnixAddr) (n, oobn int, err error)
}
79 changes: 68 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,42 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
}

type callRequest struct {
ctx context.Context
req *Request
resp *Response // response will be written back here
errs chan error // error written here on completion
ctx context.Context
req *Request
resp *Response // response will be written back here
errs chan error // error written here on completion
files *FileList
}

func (c *Client) Sendfd(ctx context.Context, files []*os.File) ([]int64, error) {
ls := make([]*File, len(files))
for i, f := range files {
ls[i] = &File{
Name: f.Name(),
Fileno: int64(f.Fd()),
}
}

resp := &Response{}
fl := &FileList{List: ls}
if err := c.dispatch(ctx, nil, resp, fl); err != nil {
return nil, err
}
if resp.Status != nil && resp.Status.Code != int32(codes.OK) {
return nil, status.ErrorProto(resp.Status)
}

fl.Reset()

if err := c.codec.Unmarshal(resp.Payload, fl); err != nil {
return nil, err
}

fds := make([]int64, len(fl.List))
for i, f := range fl.List {
fds[i] = f.Fileno
}
return fds, nil
}

func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
Expand Down Expand Up @@ -129,7 +161,9 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
info := &UnaryClientInfo{
FullMethod: fullPath(service, method),
}
if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
if err := c.interceptor(ctx, creq, cresp, info, func(ctx context.Context, req *Request, resp *Response) error {
return c.dispatch(ctx, req, resp, nil)
}); err != nil {
return err
}

Expand All @@ -143,13 +177,14 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
return nil
}

func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response, files *FileList) error {
errs := make(chan error, 1)
call := &callRequest{
ctx: ctx,
req: req,
resp: resp,
errs: errs,
ctx: ctx,
req: req,
resp: resp,
errs: errs,
files: files,
}

select {
Expand Down Expand Up @@ -270,13 +305,35 @@ func (c *Client) run() {
for {
select {
case call := <-calls:
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
var (
data interface{}
mt messageType
)

switch {
case call.files != nil:
data = call.files
mt = messageTypeFileDescriptor
case call.req != nil:
data = call.req
mt = messageTypeRequest
}

if err := c.send(streamID, mt, data); err != nil {
call.errs <- err
continue
}

waiters[streamID] = call
streamID += 2 // enforce odd client initiated request ids

if call.files != nil {
if err := c.channel.sendFd(streamID, mt, call.files); err != nil {
call.errs <- err
continue
}
}

case msg := <-incoming:
call, ok := waiters[msg.StreamID]
if !ok {
Expand Down
146 changes: 146 additions & 0 deletions fd_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ttrpc

import (
"bytes"
"context"
"fmt"
"io"
"net"
"os"
"strconv"
"testing"
"time"
)

func TestSendRecvFd(t *testing.T) {
var (
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(1*time.Minute))
addr, listener = newTestListener(t)
)

defer cancel()

// Spin up an out of process ttrpc server
if err := listenerCmd(ctx, t.Name(), listener); err != nil {
t.Fatal(err)
}

var (
client, cleanup = newTestClient(t, addr)

tclient = testFdClient{client}
)
defer cleanup()

r, w, err := os.Pipe()
if err != nil {
t.Fatal(err, "error creating test pipe")
}
defer r.Close()

type readResp struct {
buf []byte
err error
}

expect := []byte("hello")

chResp := make(chan readResp, 1)
go func() {
buf := make([]byte, len(expect))
_, err := io.ReadFull(r, buf)
chResp <- readResp{buf, err}
}()

if err := tclient.Test(ctx, w); err != nil {
t.Fatal(err)
}

select {
case <-ctx.Done():
t.Fatal(ctx.Err())
case resp := <-chResp:
if resp.err != nil {
t.Error(err)
}
if !bytes.Equal(resp.buf, expect) {
t.Fatalf("got unexpected respone data, exepcted %q, got %q", string(expect), string(resp.buf))
}
}
}

type testFdPayload struct {
Fds []int64 `protobuf:"varint,1,opt,name=fds,proto3"`
}

func (r *testFdPayload) Reset() { *r = testFdPayload{} }
func (r *testFdPayload) String() string { return fmt.Sprintf("%+#v", r) }
func (r *testFdPayload) ProtoMessage() {}

type testingServerFd struct {
respData []byte
}

func (s *testingServerFd) Test(ctx context.Context, req *testFdPayload) error {
for i, fd := range req.Fds {
f := os.NewFile(uintptr(fd), "TEST_FILE_"+strconv.Itoa(i))
go func() {
f.Write(s.respData)
f.Close()
}()
}

return nil
}

type testFdClient struct {
client *Client
}

func (c *testFdClient) Test(ctx context.Context, files ...*os.File) error {
fds, err := c.client.Sendfd(ctx, files)
if err != nil {
return fmt.Errorf("error sending fds: %w", err)
}

tp := testFdPayload{}
return c.client.Call(ctx, "Test", "Test", &testFdPayload{Fds: fds}, &tp)
}

func handleTestSendRecvFd(l net.Listener) error {
s, err := NewServer()
if err != nil {
return err
}
testImpl := &testingServerFd{respData: []byte("hello")}

s.Register("Test", map[string]Method{
"Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
req := &testFdPayload{}

if err := unmarshal(req); err != nil {
return nil, err
}

return &testFdPayload{}, testImpl.Test(ctx, req)
},
})

return s.Serve(context.TODO(), l)
}
Loading

0 comments on commit a19e82d

Please sign in to comment.