Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(go/adbc/driver/flightsql): add context to gRPC errors #921

Merged
merged 1 commit into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions go/adbc/driver/flightsql/flightsql_adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -892,10 +892,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
info, err := c.cl.GetSqlInfo(ctx, translated, c.timeouts)
if err == nil {
for _, endpoint := range info.Endpoint {
for i, endpoint := range info.Endpoint {
rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, c.timeouts)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location)
}

for rdr.Next() {
Expand All @@ -922,11 +922,11 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re
}

if rdr.Err() != nil {
return nil, adbcFromFlightStatus(rdr.Err())
return nil, adbcFromFlightStatus(rdr.Err(), "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location)
}
}
} else if grpcstatus.Code(err) != grpccodes.Unimplemented {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetInfo(GetSqlInfo)")
}

final := bldr.NewRecord()
Expand Down Expand Up @@ -1032,12 +1032,12 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *
// To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response.
info, err := c.cl.GetCatalogs(ctx)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)")
}

rdr, err := c.readInfo(ctx, schema_ref.Catalogs, info)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)")
}
defer rdr.Release()

Expand All @@ -1058,7 +1058,7 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *
}

if err = rdr.Err(); err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)")
}

return g.Finish()
Expand All @@ -1069,7 +1069,7 @@ func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info
// use a default queueSize for the reader
rdr, err := newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "DoGet")
}

if !rdr.Schema().Equal(expectedSchema) {
Expand All @@ -1091,12 +1091,12 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,
// Pre-populate the map of which schemas are in which catalogs
info, err := c.cl.GetDBSchemas(ctx, &flightsql.GetDBSchemasOpts{DbSchemaFilterPattern: dbSchema})
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetDBSchemas)")
}

rdr, err := c.readInfo(ctx, schema_ref.DBSchemas, info)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetDBSchemas)")
}
defer rdr.Release()

Expand All @@ -1117,7 +1117,7 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,

if rdr.Err() != nil {
result = nil
err = adbcFromFlightStatus(rdr.Err())
err = adbcFromFlightStatus(rdr.Err(), "GetObjects(GetDBSchemas)")
}
return
}
Expand All @@ -1137,7 +1137,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat
IncludeSchema: includeSchema,
})
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)")
}

expectedSchema := schema_ref.Tables
Expand All @@ -1146,7 +1146,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat
}
rdr, err := c.readInfo(ctx, expectedSchema, info)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)")
}
defer rdr.Release()

Expand Down Expand Up @@ -1195,7 +1195,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat

if rdr.Err() != nil {
result = nil
err = adbcFromFlightStatus(rdr.Err())
err = adbcFromFlightStatus(rdr.Err(), "GetObjects(GetTables)")
}
return
}
Expand All @@ -1211,12 +1211,12 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
info, err := c.cl.GetTables(ctx, opts, c.timeouts)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetTableSchema(GetTables)")
}

rdr, err := doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)")
}
defer rdr.Release()

Expand All @@ -1228,7 +1228,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st
Code: adbc.StatusNotFound,
}
}
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)")
}

if rec.NumRows() == 0 {
Expand All @@ -1246,7 +1246,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st
schemaBytes := rec.Column(4).(*array.Binary).Value(0)
s, err := flight.DeserializeSchema(schemaBytes, c.db.alloc)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetTableSchema")
}
return s, nil
}
Expand All @@ -1262,7 +1262,7 @@ func (c *cnxn) GetTableTypes(ctx context.Context) (array.RecordReader, error) {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
info, err := c.cl.GetTableTypes(ctx, c.timeouts)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetTableTypes")
}

return newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5)
Expand All @@ -1289,12 +1289,12 @@ func (c *cnxn) Commit(ctx context.Context) error {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
err := c.txn.Commit(ctx, c.timeouts)
if err != nil {
return adbcFromFlightStatus(err)
return adbcFromFlightStatus(err, "Commit")
}

c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts)
if err != nil {
return adbcFromFlightStatus(err)
return adbcFromFlightStatus(err, "BeginTransaction")
}
return nil
}
Expand All @@ -1320,12 +1320,12 @@ func (c *cnxn) Rollback(ctx context.Context) error {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
err := c.txn.Rollback(ctx, c.timeouts)
if err != nil {
return adbcFromFlightStatus(err)
return adbcFromFlightStatus(err, "Rollback")
}

c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts)
if err != nil {
return adbcFromFlightStatus(err)
return adbcFromFlightStatus(err, "BeginTransaction")
}
return nil
}
Expand Down Expand Up @@ -1428,7 +1428,7 @@ func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte) (r
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
rdr, err = doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "ReadPartition(DoGet)")
}
return rdr, nil
}
Expand Down
2 changes: 1 addition & 1 deletion go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ func (ts *TimeoutTests) TestDoActionTimeout() {
ts.ErrorAs(stmt.Prepare(context.Background()), &adbcErr)
ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
// Exact match - we don't want extra fluff in the message
ts.Equal("context deadline exceeded", adbcErr.Msg)
ts.Equal("[FlightSQL] context deadline exceeded (DeadlineExceeded; Prepare)", adbcErr.Msg)
}

func (ts *TimeoutTests) TestDoGetTimeout() {
Expand Down
10 changes: 5 additions & 5 deletions go/adbc/driver/flightsql/flightsql_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n
}

if err != nil {
return nil, -1, adbcFromFlightStatus(err)
return nil, -1, adbcFromFlightStatus(err, "ExecuteQuery")
}

nrec = info.TotalRecords
Expand All @@ -259,7 +259,7 @@ func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) {
}

if err != nil {
err = adbcFromFlightStatus(err)
err = adbcFromFlightStatus(err, "ExecuteUpdate")
}

return
Expand All @@ -271,7 +271,7 @@ func (s *statement) Prepare(ctx context.Context) error {
ctx = metadata.NewOutgoingContext(ctx, s.hdrs)
prep, err := s.query.prepare(ctx, s.cnxn, s.timeouts)
if err != nil {
return adbcFromFlightStatus(err)
return adbcFromFlightStatus(err, "Prepare")
}
s.prepared = prep
return nil
Expand Down Expand Up @@ -394,13 +394,13 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
}

if err != nil {
return nil, out, -1, adbcFromFlightStatus(err)
return nil, out, -1, adbcFromFlightStatus(err, "ExecutePartitions")
}

if len(info.Schema) > 0 {
sc, err = flight.DeserializeSchema(info.Schema, s.alloc)
if err != nil {
return nil, out, -1, adbcFromFlightStatus(err)
return nil, out, -1, adbcFromFlightStatus(err, "ExecutePartitions: could not deserialize FlightInfo schema:")
}
}

Expand Down
4 changes: 2 additions & 2 deletions go/adbc/driver/flightsql/record_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.
} else {
rdr, err := doGet(ctx, cl, endpoints[0], clCache, opts...)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "DoGet: endpoint 0: remote: %s", endpoints[0].Location)
}
schema = rdr.Schema()
group.Go(func() error {
Expand Down Expand Up @@ -135,7 +135,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.

rdr, err := doGet(ctx, cl, endpoint, clCache, opts...)
if err != nil {
return err
return adbcFromFlightStatus(err, "DoGet: endpoint %d: %s", endpointIndex, endpoint.Location)
}
defer rdr.Release()

Expand Down
7 changes: 5 additions & 2 deletions go/adbc/driver/flightsql/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
package flightsql

import (
"fmt"

"github.com/apache/arrow-adbc/go/adbc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

func adbcFromFlightStatus(err error) error {
func adbcFromFlightStatus(err error, context string, args ...any) error {
if _, ok := err.(adbc.Error); ok {
return err
}
Expand Down Expand Up @@ -70,8 +72,9 @@ func adbcFromFlightStatus(err error) error {
adbcCode = adbc.StatusUnknown
}

// People don't read error messages, so backload the context and frontload the server error
return adbc.Error{
Msg: grpcStatus.Message(),
Msg: fmt.Sprintf("[FlightSQL] %s (%s; %s)", grpcStatus.Message(), grpcStatus.Code(), fmt.Sprintf(context, args...)),
Code: adbcCode,
}
}
Loading