Skip to content

Commit

Permalink
Fix job logs following to work with auth
Browse files Browse the repository at this point in the history
Auth broke logs following: not only was there an incorrect cast, but the
headers were also not being set correctly on the websocket request.

To bring logs reading into the fold, we've had to refactor a little
bit and make Dial a first class method on the client. Unfortunately
there is a collision between the interface-typed Client and the
generic-typed Dialler, meaning that Dial now has to return just byte
slices and a utility is needed to parse these into strongly-typed
objects.
  • Loading branch information
simonwo committed Mar 14, 2024
1 parent f1f430b commit 720aae4
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 41 deletions.
3 changes: 1 addition & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ endif
export GO111MODULE = on
export CGO_ENABLED = 0
export PRECOMMIT = poetry run pre-commit
export EARTHLY ?= $(shell which earthly)
export EARTHLY ?= $(shell command -v earthly --push 2> /dev/null)

BUILD_DIR = bacalhau
BINARY_NAME = bacalhau
Expand All @@ -57,7 +57,6 @@ TEST_PARALLEL_PACKAGES ?= 1
PRIVATE_KEY_FILE := /tmp/private.pem
PUBLIC_KEY_FILE := /tmp/public.pem

export EARTHLY := $(shell command -v earthly --push 2> /dev/null)
export MAKE := $(shell command -v make 2> /dev/null)

define BUILD_FLAGS
Expand Down
2 changes: 1 addition & 1 deletion pkg/publicapi/client/v2/api_jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,5 @@ func (j *Jobs) Stop(ctx context.Context, r *apimodels.StopJobRequest) (*apimodel

// Logs returns a stream of logs for a given job/execution.
func (j *Jobs) Logs(ctx context.Context, r *apimodels.GetLogsRequest) (<-chan *concurrency.AsyncResult[models.ExecutionLog], error) {
return webSocketDialer[models.ExecutionLog](ctx, j.client.(*httpClient), jobsPath+"/"+r.JobID+"/logs", r)
return DialAsyncResult[*apimodels.GetLogsRequest, models.ExecutionLog](ctx, j.client, jobsPath+"/"+r.JobID+"/logs", r)
}
76 changes: 76 additions & 0 deletions pkg/publicapi/client/v2/client.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package client

import (
"bytes"
"context"
"io"
"net/http"
"net/url"
"time"

"github.com/bacalhau-project/bacalhau/pkg/lib/concurrency"
"github.com/bacalhau-project/bacalhau/pkg/publicapi/apimodels"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"go.uber.org/multierr"
)
Expand All @@ -19,6 +23,7 @@ type Client interface {
Put(context.Context, string, apimodels.PutRequest, apimodels.PutResponse) error
Post(context.Context, string, apimodels.PutRequest, apimodels.PutResponse) error
Delete(context.Context, string, apimodels.PutRequest, apimodels.Response) error
Dial(context.Context, string, apimodels.Request) (<-chan *concurrency.AsyncResult[[]byte], error)
}

// New creates a new transport.
Expand Down Expand Up @@ -112,6 +117,64 @@ func (c *httpClient) Delete(ctx context.Context, endpoint string, in apimodels.P
return c.write(ctx, http.MethodDelete, endpoint, in, out)
}

// Dial upgrades to a websocket connection and subscribes to an endpoint.
//
// It returns an error if the request cannot be built or the websocket
// handshake fails. Otherwise it returns a channel carrying the raw bytes of
// every message received. The channel is closed once the connection ends or
// ctx is cancelled. An unexpected close or a read failure is delivered as a
// final AsyncResult with Err set.
func (c *httpClient) Dial(ctx context.Context, endpoint string, in apimodels.Request) (<-chan *concurrency.AsyncResult[[]byte], error) {
	r := in.ToHTTPRequest()
	httpR, err := c.toHTTP(ctx, http.MethodGet, endpoint, r)
	if err != nil {
		return nil, err
	}

	// Preserve transport security: an https endpoint must be dialed as wss,
	// otherwise the handshake would be attempted in plaintext.
	if httpR.URL.Scheme == "https" {
		httpR.URL.Scheme = "wss"
	} else {
		httpR.URL.Scheme = "ws"
	}

	// Connect to the server, honouring ctx during the dial and handshake.
	conn, resp, err := websocket.DefaultDialer.DialContext(ctx, httpR.URL.String(), httpR.Header)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	const closeTimeout = time.Second
	closeMessage := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")

	output := make(chan *concurrency.AsyncResult[[]byte], c.config.WebsocketChannelBuffer)
	done := make(chan struct{})

	// NextReader blocks until a message arrives, so cancellation cannot be
	// observed from the read loop alone. This watcher closes the connection
	// when ctx is cancelled, which makes the pending read return. Only
	// WriteControl and Close are used here because they are the calls that
	// gorilla/websocket allows concurrently with other connection methods.
	go func() {
		select {
		case <-ctx.Done():
			_ = conn.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(closeTimeout))
			conn.Close()
		case <-done:
		}
	}()

	// send delivers a result unless ctx is cancelled first, so a consumer
	// that has stopped reading cannot block the read loop forever.
	send := func(result *concurrency.AsyncResult[[]byte]) bool {
		select {
		case output <- result:
			return true
		case <-ctx.Done():
			return false
		}
	}

	// Read messages from the server and forward them until the conn is closed
	// or the context is cancelled. Each message has to be fully read here
	// because its reader is discarded upon the next call to NextReader.
	go func() {
		defer func() {
			close(done)
			// Best-effort close handshake; the connection may already be gone.
			_ = conn.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(closeTimeout))
			conn.Close()
			close(output)
		}()

		for ctx.Err() == nil {
			_, reader, err := conn.NextReader()
			if err != nil {
				// A normal closure simply ends the stream; anything else is
				// reported to the consumer.
				if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
					send(&concurrency.AsyncResult[[]byte]{Err: err})
				}
				return
			}
			if reader == nil {
				continue
			}

			var buf bytes.Buffer
			if _, err := io.Copy(&buf, reader); err != nil {
				send(&concurrency.AsyncResult[[]byte]{Err: err})
				return
			}
			if !send(&concurrency.AsyncResult[[]byte]{Value: buf.Bytes()}) {
				return
			}
		}
	}()

	return output, nil
}

// doRequest runs a request with our client
func (c *httpClient) doRequest(
ctx context.Context,
Expand Down Expand Up @@ -271,6 +334,19 @@ func (t *AuthenticatingClient) Delete(ctx context.Context, path string, in apimo
})
}

// Dial opens a streaming connection via the wrapped Client's Dial, routing the
// call through doRequest. As far as is visible here, doRequest attaches this
// client's credential to the request when one is set; the rest of its
// behaviour is not shown. Dial returns whatever channel the most recent
// underlying Dial attempt produced, together with doRequest's error.
func (t *AuthenticatingClient) Dial(
	ctx context.Context,
	path string,
	in apimodels.Request,
) (<-chan *concurrency.AsyncResult[[]byte], error) {
	var results <-chan *concurrency.AsyncResult[[]byte]

	// dial captures the channel so it survives past the doRequest callback.
	dial := func(req apimodels.Request) error {
		ch, dialErr := t.Client.Dial(ctx, path, req)
		results = ch
		return dialErr
	}

	err := doRequest(t, in, dial)
	return results, err
}

func doRequest[R apimodels.Request](t *AuthenticatingClient, request R, runRequest func(R) error) (err error) {
if t.Credential != nil {
request.SetCredential(t.Credential)
Expand Down
60 changes: 23 additions & 37 deletions pkg/publicapi/client/v2/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (

"github.com/bacalhau-project/bacalhau/pkg/lib/concurrency"
"github.com/bacalhau-project/bacalhau/pkg/publicapi/apimodels"
"github.com/gorilla/websocket"
"go.uber.org/multierr"
)

// encodeBody prepares the reader to serve as the request body.
Expand Down Expand Up @@ -101,48 +101,34 @@ func autoUnzip(resp *http.Response) error {
return nil
}

func webSocketDialer[T any](ctx context.Context, c *httpClient, endpoint string, in apimodels.GetRequest) (
<-chan *concurrency.AsyncResult[T], error) {
r := in.ToHTTPRequest()
httpR, err := c.toHTTP(ctx, http.MethodGet, endpoint, r)
// DialAsyncResult makes a Dial call to the passed client and interprets any
// received messages as AsyncResult objects, decoding them and posting them on
// the returned channel.
func DialAsyncResult[In apimodels.Request, Out any](
ctx context.Context,
client Client,
endpoint string,
r In,
) (<-chan *concurrency.AsyncResult[Out], error) {
output := make(chan *concurrency.AsyncResult[Out])

input, err := client.Dial(ctx, endpoint, r)
if err != nil {
return nil, err
}
httpR.URL.Scheme = "ws"

// Connect to the server
conn, resp, err := websocket.DefaultDialer.Dial(httpR.URL.String(), httpR.Header)
if err != nil {
return nil, err
}
defer resp.Body.Close()

// Read messages from the server, and send them to the conn is closed or the context is cancelled
ch := make(chan *concurrency.AsyncResult[T], c.config.WebsocketChannelBuffer)
go func() {
defer func() {
_ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
conn.Close()
close(ch)
}()
for {
select {
case <-ctx.Done():
return
default:
result := new(concurrency.AsyncResult[T])
err := conn.ReadJSON(result)
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
result.Err = err
ch <- result
}
return
}
ch <- result
for result := range input {
outResult := new(concurrency.AsyncResult[Out])
if result.Value != nil {
decodeErr := json.NewDecoder(bytes.NewReader(result.Value)).Decode(outResult)
outResult.Err = multierr.Combine(outResult.Err, result.Err, decodeErr)
} else {
outResult.Err = result.Err
}
output <- outResult
}
close(output)
}()

return ch, nil
return output, nil
}
4 changes: 3 additions & 1 deletion pkg/publicapi/endpoint/orchestrator/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/bacalhau-project/bacalhau/pkg/lib/concurrency"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"golang.org/x/exp/slices"

Expand Down Expand Up @@ -458,11 +459,12 @@ func (e *Endpoint) logs(c echo.Context) error {

err = e.logsWS(c, ws)
if err != nil {
log.Ctx(c.Request().Context()).Error().Err(err).Msg("websocket failure")
err = ws.WriteJSON(concurrency.AsyncResult[models.ExecutionLog]{
Err: err,
})
if err != nil {
c.Logger().Errorf("failed to write error to websocket: %s", err)
log.Ctx(c.Request().Context()).Error().Err(err).Msg("failed to write error to websocket")
}
}
_ = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
Expand Down
12 changes: 12 additions & 0 deletions test/logs.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!bin/bashtub

source bin/bacalhau.sh

# Checks that `bacalhau job run --follow` runs the WASM test job and follows
# its logs through to a clean exit. The helpers used here (create_node,
# subject, assert_*) are defined outside this file — presumably by bashtub and
# the sourced bin/bacalhau.sh; confirm there.
testcase_can_follow_job_logs() {
# Presumably starts a single node acting as both requester and compute.
create_node requester,compute

# Run the job in follow mode; the assertions below inspect $status, $stdout
# and $stderr, which are expected to be populated by `subject`.
subject bacalhau job run --follow $ROOT/testdata/jobs/wasm.yaml
assert_equal 0 $status
# The last line of output should contain the job's greeting.
assert_match 'Hello, world!' $(echo $stdout | tail -n1)
# Following logs must not emit anything on stderr.
assert_equal '' $stderr
}

0 comments on commit 720aae4

Please sign in to comment.