Skip to content

Commit

Permalink
GH-36952: [C++][FlightRPC][Python] Add methods to send headers (#36956)
Browse files Browse the repository at this point in the history
### Rationale for this change

Sending headers/trailers is required for services, but you couldn't do this before.

### What changes are included in this PR?

Add new methods to directly send headers/trailers.

### Are these changes tested?

Yes

### Are there any user-facing changes?

Yes (new APIs)

* Closes: #36952

Authored-by: David Li <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
  • Loading branch information
lidavidm authored Jul 31, 2023
1 parent a06b261 commit 37cb592
Show file tree
Hide file tree
Showing 11 changed files with 174 additions and 22 deletions.
5 changes: 5 additions & 0 deletions cpp/src/arrow/flight/client_middleware.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class ARROW_FLIGHT_EXPORT ClientMiddleware {
virtual void SendingHeaders(AddCallHeaders* outgoing_headers) = 0;

/// \brief A callback when headers are received from the server.
///
/// This may be called more than once, since servers send both
/// headers and trailers. Some implementations (e.g. gRPC-Java, and
/// hence Arrow Flight in Java) may consolidate headers into
/// trailers if the RPC errored.
virtual void ReceivedHeaders(const CallHeaders& incoming_headers) = 0;

/// \brief A callback after the call has completed.
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/arrow/flight/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ class ARROW_FLIGHT_EXPORT ServerCallContext {
virtual const std::string& peer_identity() const = 0;
/// \brief The peer address (not validated)
virtual const std::string& peer() const = 0;
/// \brief Add a response header. This is only valid before the server
/// starts sending the response; generally this isn't an issue unless you
/// are implementing FlightDataStream, ResultStream, or similar interfaces
/// yourself, or during a DoExchange or DoPut.
virtual void AddHeader(const std::string& key, const std::string& value) const = 0;
/// \brief Add a response trailer. This is only valid before the server
/// sends the final status; generally this isn't an issue unless your RPC
/// handler launches a thread or similar.
virtual void AddTrailer(const std::string& key, const std::string& value) const = 0;
/// \brief Look up a middleware by key. Do not maintain a reference
/// to the object beyond the request body.
/// \return The middleware, or nullptr if not found.
Expand Down
87 changes: 81 additions & 6 deletions cpp/src/arrow/flight/test_definitions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,22 @@
#include "arrow/flight/test_definitions.h"

#include <chrono>
#include <memory>
#include <mutex>

#include "arrow/array/array_base.h"
#include "arrow/array/array_dict.h"
#include "arrow/array/util.h"
#include "arrow/flight/api.h"
#include "arrow/flight/client_middleware.h"
#include "arrow/flight/test_util.h"
#include "arrow/table.h"
#include "arrow/testing/generator.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/config.h"
#include "arrow/util/logging.h"
#include "gmock/gmock.h"

#if defined(ARROW_CUDA)
#include "arrow/gpu/cuda_api.h"
Expand Down Expand Up @@ -1438,20 +1443,26 @@ class ErrorHandlingTestServer : public FlightServerBase {
public:
Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
std::unique_ptr<FlightInfo>* info) override {
if (request.path.size() >= 2) {
if (request.path.size() == 1 && request.path[0] == "metadata") {
context.AddHeader("x-header", "header-value");
context.AddHeader("x-header-bin", "header\x01value");
context.AddTrailer("x-trailer", "trailer-value");
context.AddTrailer("x-trailer-bin", "trailer\x01value");
return Status::Invalid("Expected");
} else if (request.path.size() >= 2) {
const int raw_code = std::atoi(request.path[0].c_str());
ARROW_ASSIGN_OR_RAISE(StatusCode code, TryConvertStatusCode(raw_code));

if (request.path.size() == 2) {
return Status(code, request.path[1]);
return {code, request.path[1]};
} else if (request.path.size() == 3) {
return Status(code, request.path[1], std::make_shared<TestStatusDetail>());
return {code, request.path[1], std::make_shared<TestStatusDetail>()};
} else {
const int raw_code = std::atoi(request.path[2].c_str());
ARROW_ASSIGN_OR_RAISE(FlightStatusCode flight_code,
TryConvertFlightStatusCode(raw_code));
return Status(code, request.path[1],
std::make_shared<FlightStatusDetail>(flight_code, request.path[3]));
return {code, request.path[1],
std::make_shared<FlightStatusDetail>(flight_code, request.path[3])};
}
}
return Status::NotImplemented("NYI");
Expand All @@ -1469,20 +1480,70 @@ class ErrorHandlingTestServer : public FlightServerBase {
return MakeFlightError(FlightStatusCode::Unauthorized, "Unauthorized", "extra info");
}
};

class MetadataRecordingClientMiddleware : public ClientMiddleware {
public:
explicit MetadataRecordingClientMiddleware(
std::mutex& mutex, std::vector<std::pair<std::string, std::string>>& headers)
: mutex_(mutex), headers_(headers) {}
void SendingHeaders(AddCallHeaders*) override {}
void ReceivedHeaders(const CallHeaders& incoming_headers) override {
std::lock_guard<std::mutex> guard(mutex_);
for (const auto& [key, value] : incoming_headers) {
headers_.emplace_back(key, value);
}
}
void CallCompleted(const Status&) override {}

private:
std::mutex& mutex_;
std::vector<std::pair<std::string, std::string>>& headers_;
};

class MetadataRecordingClientMiddlewareFactory : public ClientMiddlewareFactory {
public:
void StartCall(const CallInfo&,
std::unique_ptr<ClientMiddleware>* middleware) override {
*middleware = std::make_unique<MetadataRecordingClientMiddleware>(mutex_, headers_);
}

std::vector<std::pair<std::string, std::string>> GetHeaders() const {
std::lock_guard<std::mutex> guard(mutex_);
// Take copy
return headers_;
}

private:
mutable std::mutex mutex_;
std::vector<std::pair<std::string, std::string>> headers_;
};
} // namespace

struct ErrorHandlingTest::Impl {
std::shared_ptr<MetadataRecordingClientMiddlewareFactory> metadata =
std::make_shared<MetadataRecordingClientMiddlewareFactory>();
};

void ErrorHandlingTest::SetUpTest() {
impl_ = std::make_shared<Impl>();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
ASSERT_OK(MakeServer<ErrorHandlingTestServer>(
location, &server_, &client_,
[](FlightServerOptions* options) { return Status::OK(); },
[](FlightClientOptions* options) { return Status::OK(); }));
[&](FlightClientOptions* options) {
options->middleware.emplace_back(impl_->metadata);
return Status::OK();
}));
}
void ErrorHandlingTest::TearDownTest() {
ASSERT_OK(client_->Close());
ASSERT_OK(server_->Shutdown());
}

std::vector<std::pair<std::string, std::string>> ErrorHandlingTest::GetHeaders() {
return impl_->metadata->GetHeaders();
}

void ErrorHandlingTest::TestGetFlightInfo() {
std::unique_ptr<FlightInfo> info;
for (const auto code : kStatusCodes) {
Expand Down Expand Up @@ -1518,6 +1579,20 @@ void ErrorHandlingTest::TestGetFlightInfo() {
}
}

void ErrorHandlingTest::TestGetFlightInfoMetadata() {
auto descr = FlightDescriptor::Path({"metadata"});
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Expected"),
client_->GetFlightInfo(descr));
// This is janky because we don't/can't expose grpc::CallContext.
// See https://github.com/apache/arrow/issues/34607
ASSERT_THAT(GetHeaders(), ::testing::IsSupersetOf({
std::make_pair("x-header", "header-value"),
std::make_pair("x-header-bin", "header\x01value"),
std::make_pair("x-trailer", "trailer-value"),
std::make_pair("x-trailer-bin", "trailer\x01value"),
}));
}

void CheckErrorDetail(const Status& status) {
auto detail = FlightStatusDetail::UnwrapStatus(status);
ASSERT_NE(detail, nullptr) << status.ToString();
Expand Down
9 changes: 8 additions & 1 deletion cpp/src/arrow/flight/test_definitions.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,16 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public FlightTest {

// Test methods
void TestGetFlightInfo();
void TestGetFlightInfoMetadata();
void TestDoPut();
void TestDoExchange();

private:
protected:
struct Impl;

std::vector<std::pair<std::string, std::string>> GetHeaders();

std::shared_ptr<Impl> impl_;
std::unique_ptr<FlightClient> client_;
std::unique_ptr<FlightServerBase> server_;
};
Expand All @@ -277,6 +283,7 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public FlightTest {
static_assert(std::is_base_of<ErrorHandlingTest, FIXTURE>::value, \
ARROW_STRINGIFY(FIXTURE) " must inherit from ErrorHandlingTest"); \
TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); } \
TEST_F(FIXTURE, TestGetFlightInfoMetadata) { TestGetFlightInfoMetadata(); } \
TEST_F(FIXTURE, TestDoPut) { TestDoPut(); } \
TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); }

Expand Down
18 changes: 4 additions & 14 deletions cpp/src/arrow/flight/transport/grpc/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ class GrpcClientInterceptorAdapter : public ::grpc::experimental::Interceptor {
public:
explicit GrpcClientInterceptorAdapter(
std::vector<std::unique_ptr<ClientMiddleware>> middleware)
: middleware_(std::move(middleware)), received_headers_(false) {}
: middleware_(std::move(middleware)) {}

void Intercept(::grpc::experimental::InterceptorBatchMethods* methods) {
void Intercept(::grpc::experimental::InterceptorBatchMethods* methods) override {
using InterceptionHookPoints = ::grpc::experimental::InterceptionHookPoints;
if (methods->QueryInterceptionHookPoint(
InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
Expand Down Expand Up @@ -142,10 +142,6 @@ class GrpcClientInterceptorAdapter : public ::grpc::experimental::Interceptor {
private:
void ReceivedHeaders(
const std::multimap<::grpc::string_ref, ::grpc::string_ref>& metadata) {
if (received_headers_) {
return;
}
received_headers_ = true;
CallHeaders headers;
for (const auto& entry : metadata) {
headers.insert({std::string_view(entry.first.data(), entry.first.length()),
Expand All @@ -157,20 +153,14 @@ class GrpcClientInterceptorAdapter : public ::grpc::experimental::Interceptor {
}

std::vector<std::unique_ptr<ClientMiddleware>> middleware_;
// When communicating with a gRPC-Java server, the server may not
// send back headers if the call fails right away. Instead, the
// headers will be consolidated into the trailers. We don't want to
// call the client middleware callback twice, so instead track
// whether we saw headers - if not, then we need to check trailers.
bool received_headers_;
};

class GrpcClientInterceptorAdapterFactory
: public ::grpc::experimental::ClientInterceptorFactoryInterface {
public:
GrpcClientInterceptorAdapterFactory(
explicit GrpcClientInterceptorAdapterFactory(
std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware)
: middleware_(middleware) {}
: middleware_(std::move(middleware)) {}

::grpc::experimental::Interceptor* CreateClientInterceptor(
::grpc::experimental::ClientRpcInfo* info) override {
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/arrow/flight/transport/grpc/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class GrpcServerAuthSender : public ServerAuthSender {
};

class GrpcServerCallContext : public ServerCallContext {
public:
explicit GrpcServerCallContext(::grpc::ServerContext* context)
: context_(context), peer_(context_->peer()) {
for (const auto& entry : context->client_metadata()) {
Expand Down Expand Up @@ -143,6 +144,14 @@ class GrpcServerCallContext : public ServerCallContext {
return ToGrpcStatus(status, context_);
}

void AddHeader(const std::string& key, const std::string& value) const override {
context_->AddInitialMetadata(key, value);
}

void AddTrailer(const std::string& key, const std::string& value) const override {
context_->AddTrailingMetadata(key, value);
}

ServerMiddleware* GetMiddleware(const std::string& key) const override {
const auto& instance = middleware_map_.find(key);
if (instance == middleware_map_.end()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ class UcxErrorHandlingTest : public ErrorHandlingTest, public ::testing::Test {
std::string transport() const override { return "ucx"; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }

void TestGetFlightInfoMetadata() { GTEST_SKIP() << "Middleware not implemented"; }
};
ARROW_FLIGHT_TEST_ERROR_HANDLING(UcxErrorHandlingTest);

Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/flight/transport/ucx/ucx_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class UcxServerCallContext : public flight::ServerCallContext {
public:
const std::string& peer_identity() const override { return peer_; }
const std::string& peer() const override { return peer_; }
// Not supported
void AddHeader(const std::string& key, const std::string& value) const override {}
void AddTrailer(const std::string& key, const std::string& value) const override {}
ServerMiddleware* GetMiddleware(const std::string& key) const override {
return nullptr;
}
Expand Down
8 changes: 8 additions & 0 deletions python/pyarrow/_flight.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1756,6 +1756,14 @@ cdef class ServerCallContext(_Weakrefable):
"""Check if the current RPC call has been canceled by the client."""
return self.context.is_cancelled()

def add_header(self, key, value):
"""Add a response header."""
self.context.AddHeader(tobytes(key), tobytes(value))

def add_trailer(self, key, value):
"""Add a response trailer."""
self.context.AddTrailer(tobytes(key), tobytes(value))

def get_middleware(self, key):
"""
Get a middleware instance by key.
Expand Down
2 changes: 2 additions & 0 deletions python/pyarrow/includes/libarrow_flight.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
c_string& peer_identity()
c_string& peer()
c_bool is_cancelled()
void AddHeader(const c_string& key, const c_string& value)
void AddTrailer(const c_string& key, const c_string& value)
CServerMiddleware* GetMiddleware(const c_string& key)

cdef cppclass CTimeoutDuration" arrow::flight::TimeoutDuration":
Expand Down
44 changes: 43 additions & 1 deletion python/pyarrow/tests/test_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ def sending_headers(self):
def received_headers(self, headers):
# Let the test code know what the last set of headers we
# received were.
self.factory.last_headers = headers
self.factory.last_headers.update(headers)


class MultiHeaderServerMiddlewareFactory(ServerMiddlewareFactory):
Expand Down Expand Up @@ -2323,3 +2323,45 @@ def test_do_put_does_not_crash_when_schema_is_none():
with pytest.raises(TypeError, match=msg):
client.do_put(flight.FlightDescriptor.for_command('foo'),
schema=None)


def test_headers_trailers():
"""Ensure that server-sent headers/trailers make it through."""

class HeadersTrailersFlightServer(FlightServerBase):
def get_flight_info(self, context, descriptor):
context.add_header("x-header", "header-value")
context.add_header("x-header-bin", "header\x01value")
context.add_trailer("x-trailer", "trailer-value")
context.add_trailer("x-trailer-bin", "trailer\x01value")
return flight.FlightInfo(
pa.schema([]),
descriptor,
[],
-1, -1
)

class HeadersTrailersMiddlewareFactory(ClientMiddlewareFactory):
def __init__(self):
self.headers = []

def start_call(self, info):
return HeadersTrailersMiddleware(self)

class HeadersTrailersMiddleware(ClientMiddleware):
def __init__(self, factory):
self.factory = factory

def received_headers(self, headers):
for key, values in headers.items():
for value in values:
self.factory.headers.append((key, value))

factory = HeadersTrailersMiddlewareFactory()
with HeadersTrailersFlightServer() as server, \
FlightClient(("localhost", server.port), middleware=[factory]) as client:
client.get_flight_info(flight.FlightDescriptor.for_path(""))
assert ("x-header", "header-value") in factory.headers
assert ("x-header-bin", b"header\x01value") in factory.headers
assert ("x-trailer", "trailer-value") in factory.headers
assert ("x-trailer-bin", b"trailer\x01value") in factory.headers

0 comments on commit 37cb592

Please sign in to comment.