diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go b/go/adbc/driver/flightsql/cmd/testserver/main.go index 9951df2357..1b9bf7a95e 100644 --- a/go/adbc/driver/flightsql/cmd/testserver/main.go +++ b/go/adbc/driver/flightsql/cmd/testserver/main.go @@ -36,21 +36,36 @@ import ( "github.com/apache/arrow/go/v17/arrow/array" "github.com/apache/arrow/go/v17/arrow/flight" "github.com/apache/arrow/go/v17/arrow/flight/flightsql" + "github.com/apache/arrow/go/v17/arrow/flight/flightsql/schema_ref" "github.com/apache/arrow/go/v17/arrow/memory" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/wrapperspb" ) +type RecordedHeader struct { + method string + header string + value string +} + type ExampleServer struct { flightsql.BaseServer mu sync.Mutex pollingStatus map[string]int + headers []RecordedHeader } +var recordedHeadersSchema = arrow.NewSchema([]arrow.Field{ + {Name: "method", Type: arrow.BinaryTypes.String, Nullable: false}, + {Name: "header", Type: arrow.BinaryTypes.String, Nullable: false}, + {Name: "value", Type: arrow.BinaryTypes.String, Nullable: false}, +}, nil) + func StatusWithDetail(code codes.Code, message string, details ...proto.Message) error { p := status.New(code, message).Proto() // Have to do this by hand because gRPC uses deprecated proto import @@ -64,11 +79,41 @@ func StatusWithDetail(code codes.Code, message string, details ...proto.Message) return status.FromProto(p).Err() } +func (srv *ExampleServer) recordHeaders(ctx context.Context, method string) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + panic("Misuse of recordHeaders") + } + + srv.mu.Lock() + defer srv.mu.Unlock() + for k, vv := range md { + for _, v := range vv { + log.Printf("Header: %s: %s = %s\n", method, k, v) + srv.headers = append(srv.headers, RecordedHeader{ + method: method, header: k, value: v, + }) + } + } +} + +func (srv *ExampleServer) BeginTransaction(ctx context.Context, req flightsql.ActionBeginTransactionRequest) ([]byte, error) { + srv.recordHeaders(ctx, "BeginTransaction") + return []byte("foo"), nil +} + +func (srv *ExampleServer) EndTransaction(ctx context.Context, req flightsql.ActionEndTransactionRequest) error { + srv.recordHeaders(ctx, "EndTransaction") + return nil +} + func (srv *ExampleServer) ClosePreparedStatement(ctx context.Context, request flightsql.ActionClosePreparedStatementRequest) error { + srv.recordHeaders(ctx, "ClosePreparedStatement") return nil } func (srv *ExampleServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) { + srv.recordHeaders(ctx, "CreatePreparedStatement") switch req.GetQuery() { case "error_create_prepared_statement": err = status.Error(codes.InvalidArgument, "expected error (DoAction)") @@ -83,7 +128,8 @@ func (srv *ExampleServer) CreatePreparedStatement(ctx context.Context, req fligh return } -func (srv *ExampleServer) GetFlightInfoPreparedStatement(_ context.Context, cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { +func (srv *ExampleServer) GetFlightInfoPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + srv.recordHeaders(ctx, "GetFlightInfoPreparedStatement") switch string(cmd.GetPreparedStatementHandle()) { case "error_do_get", "error_do_get_stream", "error_do_get_detail", "error_do_get_stream_detail", "forever": schema := arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) @@ -111,6 +157,7 @@ func (srv *ExampleServer) GetFlightInfoPreparedStatement(_ context.Context, cmd } func (srv *ExampleServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + srv.recordHeaders(ctx, "GetFlightInfoStatement") ticket, err := flightsql.CreateStatementQueryTicket(desc.Cmd) if err != nil { return nil, err @@ -239,6 +286,7 @@ func (srv *ExampleServer) PollFlightInfoPreparedStatement(ctx context.Context, q } func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) { + srv.recordHeaders(ctx, "DoGetPreparedStatement") log.Printf("DoGetPreparedStatement: %v", cmd.GetPreparedStatementHandle()) switch string(cmd.GetPreparedStatementHandle()) { case "error_do_get": @@ -271,6 +319,45 @@ func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd flight case "stateless_prepared_statement": err = status.Error(codes.InvalidArgument, "client didn't use the updated handle") return + case "recorded_headers": + schema = recordedHeadersSchema + ch := make(chan flight.StreamChunk) + + methods := array.NewStringBuilder(srv.Alloc) + headers := array.NewStringBuilder(srv.Alloc) + values := array.NewStringBuilder(srv.Alloc) + defer methods.Release() + defer headers.Release() + defer values.Release() + + srv.mu.Lock() + defer srv.mu.Unlock() + + count := int64(0) + for _, recorded := range srv.headers { + count++ + methods.AppendString(recorded.method) + headers.AppendString(recorded.header) + values.AppendString(recorded.value) + } + srv.headers = make([]RecordedHeader, 0) + + rec := array.NewRecord(recordedHeadersSchema, []arrow.Array{ + methods.NewArray(), + headers.NewArray(), + values.NewArray(), + }, count) + + go func() { + defer close(ch) + ch <- flight.StreamChunk{ + Data: rec, + Desc: nil, + Err: nil, + } + }() + out = ch + return } schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) @@ -323,6 +410,7 @@ func (srv *ExampleServer) DoGetStatement(ctx context.Context, cmd flightsql.Stat } func (srv *ExampleServer) DoPutPreparedStatementQuery(ctx context.Context, cmd flightsql.PreparedStatementQuery, reader flight.MessageReader, writer flight.MetadataWriter) ([]byte, error) { + srv.recordHeaders(ctx, "DoPutPreparedStatementQuery") switch string(cmd.GetPreparedStatementHandle()) { case "error_do_put": return nil, status.Error(codes.Unknown, "expected error (DoPut)") @@ -341,6 +429,115 @@ func (srv *ExampleServer) DoPutPreparedStatementUpdate(context.Context, flightsq return 0, status.Error(codes.Unimplemented, "DoPutPreparedStatementUpdate not implemented") } +func (srv *ExampleServer) GetFlightInfoCatalogs(ctx context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + srv.recordHeaders(ctx, "GetFlightInfoCatalogs") + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: desc.Cmd}}}, + FlightDescriptor: desc, + Schema: flight.SerializeSchema(schema_ref.Catalogs, srv.Alloc), + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (srv *ExampleServer) DoGetCatalogs(ctx context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) { + srv.recordHeaders(ctx, "DoGetCatalogs") + + // Just return some dummy data + schema := schema_ref.Catalogs + ch := make(chan flight.StreamChunk, 1) + catalogs, _, err := array.FromJSON(srv.Alloc, arrow.BinaryTypes.String, strings.NewReader(`["catalog"]`)) + if err != nil { + return nil, nil, err + } + defer catalogs.Release() + + batch := array.NewRecord(schema, []arrow.Array{catalogs}, 1) + ch <- flight.StreamChunk{Data: batch} + close(ch) + return schema, ch, nil +} + +func (srv *ExampleServer) GetFlightInfoSchemas(ctx context.Context, req flightsql.GetDBSchemas, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + srv.recordHeaders(ctx, "GetFlightInfoDBSchemas") + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: desc.Cmd}}}, + FlightDescriptor: desc, + Schema: flight.SerializeSchema(schema_ref.DBSchemas, srv.Alloc), + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (srv *ExampleServer) DoGetDBSchemas(ctx context.Context, req flightsql.GetDBSchemas) (*arrow.Schema, <-chan flight.StreamChunk, error) { + srv.recordHeaders(ctx, "DoGetDBSchemas") + + // Just return some dummy data + schema := schema_ref.DBSchemas + ch := make(chan flight.StreamChunk, 1) + // Not really a proper match, but good enough + if req.GetDBSchemaFilterPattern() == nil || *req.GetDBSchemaFilterPattern() == "" || *req.GetDBSchemaFilterPattern() == "main" { + catalogs, _, err := array.FromJSON(srv.Alloc, arrow.BinaryTypes.String, strings.NewReader(`["main"]`)) + if err != nil { + return nil, nil, err + } + defer catalogs.Release() + + dbSchemas, _, err := array.FromJSON(srv.Alloc, arrow.BinaryTypes.String, strings.NewReader(`[""]`)) + if err != nil { + return nil, nil, err + } + defer dbSchemas.Release() + + batch := array.NewRecord(schema, []arrow.Array{catalogs, dbSchemas}, 1) + ch <- flight.StreamChunk{Data: batch} + } + close(ch) + return schema, ch, nil +} + +func (srv *ExampleServer) GetFlightInfoTables(ctx context.Context, req flightsql.GetTables, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + srv.recordHeaders(ctx, "GetFlightInfoTables") + schema := schema_ref.Tables + if req.GetIncludeSchema() { + schema = schema_ref.TablesWithIncludedSchema + } + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: desc.Cmd}}}, + FlightDescriptor: desc, + Schema: flight.SerializeSchema(schema, srv.Alloc), + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (srv *ExampleServer) DoGetTables(ctx context.Context, req flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) { + srv.recordHeaders(ctx, "DoGetTables") + // Just return some dummy data + schema := schema_ref.Tables + if req.GetIncludeSchema() { + schema = schema_ref.TablesWithIncludedSchema + } + ch := make(chan flight.StreamChunk, 1) + close(ch) + return schema, ch, nil +} + +func (srv *ExampleServer) SetSessionOptions(ctx context.Context, req *flight.SetSessionOptionsRequest) (*flight.SetSessionOptionsResult, error) { + srv.recordHeaders(ctx, "SetSessionOptions") + return &flight.SetSessionOptionsResult{}, nil +} + +func (srv *ExampleServer) GetSessionOptions(ctx context.Context, req *flight.GetSessionOptionsRequest) (*flight.GetSessionOptionsResult, error) { + srv.recordHeaders(ctx, "GetSessionOptions") + return &flight.GetSessionOptionsResult{}, nil +} + +func (srv *ExampleServer) CloseSession(ctx context.Context, req *flight.CloseSessionRequest) (*flight.CloseSessionResult, error) { + srv.recordHeaders(ctx, "CloseSession") + return &flight.CloseSessionResult{}, nil +} + func main() { var ( host = flag.String("host", "localhost", "hostname to bind to") @@ -351,6 +548,9 @@ func main() { srv := &ExampleServer{pollingStatus: make(map[string]int)} srv.Alloc = memory.DefaultAllocator + if err := srv.RegisterSqlInfo(flightsql.SqlInfoFlightSqlServerTransaction, int32(flightsql.SqlTransactionTransaction)); err != nil { + log.Fatal(err) + } server := flight.NewServerWithMiddleware(nil) server.RegisterFlightService(flightsql.NewFlightServer(srv)) diff --git a/go/adbc/driver/flightsql/flightsql_connection.go b/go/adbc/driver/flightsql/flightsql_connection.go index 5c8269fb85..3a8fb8ef9f 100644 --- a/go/adbc/driver/flightsql/flightsql_connection.go +++ b/go/adbc/driver/flightsql/flightsql_connection.go @@ -225,6 +225,13 @@ func (c *connectionImpl) getSessionOptions(ctx context.Context) (map[string]inte func (c *connectionImpl) setSessionOptions(ctx context.Context, key string, val interface{}) error { req := flight.SetSessionOptionsRequest{} + hdrs := make([]string, 0) + for k, vv := range c.hdrs { + for _, v := range vv { + hdrs = append(hdrs, k, v) + } + } + ctx = metadata.AppendToOutgoingContext(ctx, hdrs...) var err error req.SessionOptions, err = flight.NewSessionOptionValues(map[string]any{key: val}) @@ -238,7 +245,7 @@ func (c *connectionImpl) setSessionOptions(ctx context.Context, key string, val var header, trailer metadata.MD errors, err := c.cl.SetSessionOptions(ctx, &req, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { - return adbcFromFlightStatusWithDetails(err, header, trailer, "GetSessionOptions") + return adbcFromFlightStatusWithDetails(err, header, trailer, "SetSessionOptions") } if len(errors.Errors) > 0 { msg := strings.Builder{} @@ -635,6 +642,7 @@ func (c *connectionImpl) GetObjectsCatalogs(ctx context.Context, catalog *string header, trailer metadata.MD numCatalogs int64 ) + ctx = metadata.NewOutgoingContext(ctx, c.hdrs) // To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response. info, err := c.cl.GetCatalogs(ctx, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { @@ -675,6 +683,7 @@ func (c *connectionImpl) GetObjectsDbSchemas(ctx context.Context, depth adbc.Obj if depth == adbc.ObjectDepthCatalogs { return } + ctx = metadata.NewOutgoingContext(ctx, c.hdrs) result = make(map[string][]string) var header, trailer metadata.MD // Pre-populate the map of which schemas are in which catalogs @@ -716,6 +725,7 @@ func (c *connectionImpl) GetObjectsTables(ctx context.Context, depth adbc.Object if depth == adbc.ObjectDepthCatalogs || depth == adbc.ObjectDepthDBSchemas { return } + ctx = metadata.NewOutgoingContext(ctx, c.hdrs) result = make(map[internal.CatalogAndSchema][]internal.TableInfo) // Pre-populate the map of which schemas are in which catalogs diff --git a/python/adbc_driver_flightsql/tests/conftest.py b/python/adbc_driver_flightsql/tests/conftest.py index b4eb181105..96105a6f27 100644 --- a/python/adbc_driver_flightsql/tests/conftest.py +++ b/python/adbc_driver_flightsql/tests/conftest.py @@ -69,6 +69,7 @@ def dremio_dbapi(dremio_uri, dremio_user, dremio_pass): adbc_driver_manager.DatabaseOptions.USERNAME.value: dremio_user, adbc_driver_manager.DatabaseOptions.PASSWORD.value: dremio_pass, }, + autocommit=True, ) as conn: yield conn @@ -79,5 +80,8 @@ def test_dbapi(): if not uri: pytest.skip("Set ADBC_TEST_FLIGHTSQL_URI to run tests") - with adbc_driver_flightsql.dbapi.connect(uri) as conn: + with adbc_driver_flightsql.dbapi.connect( + uri, + autocommit=True, + ) as conn: yield conn diff --git a/python/adbc_driver_flightsql/tests/test_errors.py b/python/adbc_driver_flightsql/tests/test_errors.py index 688369a96e..d572589aee 100644 --- a/python/adbc_driver_flightsql/tests/test_errors.py +++ b/python/adbc_driver_flightsql/tests/test_errors.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. +import contextlib import re +import secrets import threading import time @@ -190,3 +192,88 @@ def test_stateless_prepared_statement(test_dbapi) -> None: with test_dbapi.cursor() as cur: cur.adbc_prepare("stateless_prepared_statement") cur.execute("stateless_prepared_statement", parameters=[(1,)]) + + +def test_header_propagation(test_dbapi) -> None: + header = "x-trace" + option = f"adbc.flight.sql.rpc.call_header.{header}" + + @contextlib.contextmanager + def _trace(value): + test_dbapi.adbc_connection.set_options(**{option: value}) + yield + test_dbapi.adbc_connection.set_options(**{option: ""}) + + getobjects = secrets.token_hex(16) + with _trace(getobjects): + with test_dbapi.adbc_get_objects(): + pass + + stmt = secrets.token_hex(16) + with _trace(stmt): + with test_dbapi.cursor() as cur: + cur.execute("foo") + cur.fetchall() + + prepared = secrets.token_hex(16) + with _trace(prepared): + with test_dbapi.cursor() as cur: + cur.adbc_prepare("stateless_prepared_statement") + cur.execute("stateless_prepared_statement", parameters=[(1,)]) + + txn = secrets.token_hex(16) + with _trace(txn): + test_dbapi.adbc_connection.set_autocommit(False) + test_dbapi.adbc_connection.set_autocommit(True) + + sess = secrets.token_hex(16) + with _trace(sess): + test_dbapi.adbc_connection.get_option("adbc.flight.sql.session.options") + test_dbapi.adbc_connection.set_options( + **{ + "adbc.flight.sql.session.option.foo": 2, + } + ) + + with test_dbapi.cursor() as cur: + cur.execute("recorded_headers") + headers = [x for x in cur.fetchall() if x[1] == header] + + for method in [ + "GetFlightInfoCatalogs", + "DoGetCatalogs", + "GetFlightInfoDBSchemas", + "DoGetDBSchemas", + "GetFlightInfoTables", + "DoGetTables", + ]: + assert (method, header, getobjects) in headers + + for method in [ + "CreatePreparedStatement", + "ClosePreparedStatement", + "GetFlightInfoPreparedStatement", + "DoGetPreparedStatement", + ]: + assert (method, header, stmt) in headers + + for method in [ + "CreatePreparedStatement", + "ClosePreparedStatement", + "GetFlightInfoPreparedStatement", + "DoGetPreparedStatement", + "DoPutPreparedStatementQuery", + ]: + assert (method, header, prepared) in headers + + for method in [ + "BeginTransaction", + "EndTransaction", + ]: + assert (method, header, txn) in headers + + for method in [ + "GetSessionOptions", + "SetSessionOptions", + ]: + assert (method, header, sess) in headers