From b0786d48a58a5f95fe22db12932a3a9dffb4101f Mon Sep 17 00:00:00 2001 From: qmmk <47608571+qmmk@users.noreply.github.com> Date: Tue, 3 Sep 2024 15:30:06 +0200 Subject: [PATCH] GH-43907: [C#][FlightRPC] Add Grpc Call Options support on Flight Client (#43910) ### Rationale for this change This implementation add default grpc call options on the csharp implementation FlightClient ### What changes are included in this PR? - FlightClient.cs with updated signature for all the methods accepting grpc call options - FlightTest.cs update test to verify the raise of the right exception ### Are these changes tested? Yes, tests are added in FlightTest.cs I've tested locally with the C++ implementation. ### Are there any user-facing changes? No is transparent for the user, following the already present documentation should be sufficient. ### References * GitHub Issue: #43907 Authored-by: Marco Malagoli Signed-off-by: Curt Hagenlocher --- .../Client/FlightClient.cs | 69 ++++++++++--- .../Apache.Arrow.Flight.Tests/FlightTests.cs | 97 ++++++++++++++++++- 2 files changed, 150 insertions(+), 16 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs index efb22b1948a01..b89ce9da79d14 100644 --- a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs +++ b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Flight.Internal; using Apache.Arrow.Flight.Protocol; @@ -34,12 +35,17 @@ public FlightClient(ChannelBase grpcChannel) public AsyncServerStreamingCall ListFlights(FlightCriteria criteria = null, Metadata headers = null) { - if(criteria == null) + return ListFlights(criteria, headers, null, CancellationToken.None); + } + + public AsyncServerStreamingCall ListFlights(FlightCriteria criteria, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default) + { + if (criteria == null) { criteria = FlightCriteria.Empty; } - - var response = _client.ListFlights(criteria.ToProtocol(), headers); + + var response = _client.ListFlights(criteria.ToProtocol(), headers, deadline, cancellationToken); var convertStream = new StreamReader(response.ResponseStream, inFlight => new FlightInfo(inFlight)); return new AsyncServerStreamingCall(convertStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); @@ -47,7 +53,12 @@ public AsyncServerStreamingCall ListFlights(FlightCriteria criteria public AsyncServerStreamingCall ListActions(Metadata headers = null) { - var response = _client.ListActions(EmptyInstance, headers); + return ListActions(headers, null, CancellationToken.None); + } + + public AsyncServerStreamingCall ListActions(Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default) + { + var response = _client.ListActions(EmptyInstance, headers, deadline, cancellationToken); var convertStream = new StreamReader(response.ResponseStream, actionType => new FlightActionType(actionType)); return new AsyncServerStreamingCall(convertStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); @@ -55,14 +66,24 @@ public AsyncServerStreamingCall ListActions(Metadata headers = public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, Metadata headers = null) { - var stream = _client.DoGet(ticket.ToProtocol(), headers); + return GetStream(ticket, headers, null, CancellationToken.None); + } + + public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default) + { + var stream = _client.DoGet(ticket.ToProtocol(), headers, deadline, cancellationToken); var responseStream = new FlightClientRecordBatchStreamReader(stream.ResponseStream); return new FlightRecordBatchStreamingCall(responseStream, stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers, stream.Dispose); } public AsyncUnaryCall GetInfo(FlightDescriptor flightDescriptor, Metadata headers = null) { - var flightInfoResult = _client.GetFlightInfoAsync(flightDescriptor.ToProtocol(), headers); + return GetInfo(flightDescriptor, headers, null, CancellationToken.None); + } + + public AsyncUnaryCall GetInfo(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default) + { + var flightInfoResult = _client.GetFlightInfoAsync(flightDescriptor.ToProtocol(), headers, deadline, cancellationToken); var flightInfo = flightInfoResult .ResponseAsync @@ -79,7 +100,12 @@ public AsyncUnaryCall GetInfo(FlightDescriptor flightDescriptor, Met public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers = null) { - var channels = _client.DoPut(headers); + return StartPut(flightDescriptor, headers, null, CancellationToken.None); + } + + public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default) + { + var channels = _client.DoPut(headers, deadline, cancellationToken); var requestStream = new FlightClientRecordBatchStreamWriter(channels.RequestStream, flightDescriptor); var readStream = new StreamReader(channels.ResponseStream, putResult => new FlightPutResult(putResult)); return new FlightRecordBatchDuplexStreamingCall( @@ -93,7 +119,13 @@ public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDesc public AsyncDuplexStreamingCall Handshake(Metadata headers = null) { - var channel = _client.Handshake(headers); + return Handshake(headers, null, CancellationToken.None); + + } + + public AsyncDuplexStreamingCall Handshake(Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default) + { + var channel = _client.Handshake(headers, deadline, cancellationToken); var readStream = new StreamReader(channel.ResponseStream, response => new FlightHandshakeResponse(response)); var writeStream = new FlightHandshakeStreamWriterAdapter(channel.RequestStream); var call = new AsyncDuplexStreamingCall( @@ -109,7 +141,12 @@ public AsyncDuplexStreamingCall public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescriptor, Metadata headers = null) { - var channel = _client.DoExchange(headers); + return DoExchange(flightDescriptor, headers, null, CancellationToken.None); + } + + public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default) + { + var channel = _client.DoExchange(headers, deadline, cancellationToken); var requestStream = new FlightClientRecordBatchStreamWriter(channel.RequestStream, flightDescriptor); var responseStream = new FlightClientRecordBatchStreamReader(channel.ResponseStream); var call = new FlightRecordBatchExchangeCall( @@ -125,14 +162,24 @@ public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescripto public AsyncServerStreamingCall DoAction(FlightAction action, Metadata headers = null) { - var stream = _client.DoAction(action.ToProtocol(), headers); + return DoAction(action, headers, null, CancellationToken.None); + } + + public AsyncServerStreamingCall DoAction(FlightAction action, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default) + { + var stream = _client.DoAction(action.ToProtocol(), headers, deadline, cancellationToken); var streamReader = new StreamReader(stream.ResponseStream, result => new FlightResult(result)); return new AsyncServerStreamingCall(streamReader, stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers, stream.Dispose); } public AsyncUnaryCall GetSchema(FlightDescriptor flightDescriptor, Metadata headers = null) { - var schemaResult = _client.GetSchemaAsync(flightDescriptor.ToProtocol(), headers); + return GetSchema(flightDescriptor, headers, null, CancellationToken.None); + } + + public AsyncUnaryCall GetSchema(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default) + { + var schemaResult = _client.GetSchemaAsync(flightDescriptor.ToProtocol(), headers, deadline, cancellationToken); var schema = schemaResult .ResponseAsync diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs index aac4e4209240a..8bf6e1120c6d3 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs @@ -16,12 +16,15 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Flight.Client; using Apache.Arrow.Flight.TestWeb; using Apache.Arrow.Tests; using Google.Protobuf; +using Grpc.Core; using Grpc.Core.Utils; +using Python.Runtime; using Xunit; namespace Apache.Arrow.Flight.Tests @@ -70,7 +73,7 @@ private FlightInfo GivenStoreBatches(FlightDescriptor flightDescriptor, params R var flightHolder = new FlightHolder(flightDescriptor, initialBatch.RecordBatch.Schema, _testWebFactory.GetAddress()); - foreach(var batch in batches) + foreach (var batch in batches) { flightHolder.AddBatch(batch); } @@ -187,8 +190,8 @@ public async Task TestGetFlightMetadata() var getStream = _flightClient.GetStream(endpoint.Ticket); - List actualMetadata = new List(); - while(await getStream.ResponseStream.MoveNext(default)) + List actualMetadata = new List(); + while (await getStream.ResponseStream.MoveNext(default)) { actualMetadata.AddRange(getStream.ResponseStream.ApplicationMetadata); } @@ -277,7 +280,7 @@ public async Task TestListFlights() var actualFlights = await listFlightStream.ResponseStream.ToListAsync(); - for(int i = 0; i < expectedFlightInfo.Count; i++) + for (int i = 0; i < expectedFlightInfo.Count; i++) { FlightInfoComparer.Compare(expectedFlightInfo[i], actualFlights[i]); } @@ -386,7 +389,7 @@ public async Task TestGetBatchesWithAsyncEnumerable() List resultList = new List(); - await foreach(var recordBatch in getStream.ResponseStream) + await foreach (var recordBatch in getStream.ResponseStream) { resultList.Add(recordBatch); } @@ -415,5 +418,89 @@ public async Task EnsureTheSerializedBatchContainsTheProperTotalRecordsAndTotalB Assert.Equal(expectedBatch.Length, result.TotalRecords); Assert.Equal(expectedTotalBytes, result.TotalBytes); } + + [Fact] + public async Task EnsureCallRaisesDeadlineExceeded() + { + var flightDescriptor = FlightDescriptor.CreatePathDescriptor("raise_deadline"); + var deadline = DateTime.UtcNow; + var batch = CreateTestBatch(0, 100); + + RpcException exception = null; + + var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, deadline); + Assert.Equal(StatusCode.DeadlineExceeded, asyncServerStreamingCallFlights.GetStatus().StatusCode); + + var asyncServerStreamingCallActions = _flightClient.ListActions(null, deadline); + Assert.Equal(StatusCode.DeadlineExceeded, asyncServerStreamingCallFlights.GetStatus().StatusCode); + + GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(batch)); + exception = await Assert.ThrowsAsync(async () => await _flightClient.GetInfo(flightDescriptor, null, deadline)); + Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode); + + var flightInfo = await _flightClient.GetInfo(flightDescriptor); + var endpoint = flightInfo.Endpoints.FirstOrDefault(); + var getStream = _flightClient.GetStream(endpoint.Ticket, null, deadline); + Assert.Equal(StatusCode.DeadlineExceeded, getStream.GetStatus().StatusCode); + + var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor, null, deadline); + exception = await Assert.ThrowsAsync(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch)); + Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode); + + var putStream = _flightClient.StartPut(flightDescriptor, null, deadline); + exception = await Assert.ThrowsAsync(async () => await putStream.RequestStream.WriteAsync(batch)); + Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode); + + exception = await Assert.ThrowsAsync(async () => await _flightClient.GetSchema(flightDescriptor, null, deadline)); + Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode); + + var handshakeStreamingCall = _flightClient.Handshake(null, deadline); + exception = await Assert.ThrowsAsync(async () => await handshakeStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty))); + Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode); + } + + [Fact] + public async Task EnsureCallRaisesRequestCancelled() + { + var cts = new CancellationTokenSource(); + cts.CancelAfter(1); + + var batch = CreateTestBatch(0, 100); + var metadata = new Metadata(); + var flightDescriptor = FlightDescriptor.CreatePathDescriptor("raise_cancelled"); + await Task.Delay(5); + RpcException exception = null; + + var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, null, cts.Token); + Assert.Equal(StatusCode.Cancelled, asyncServerStreamingCallFlights.GetStatus().StatusCode); + + var asyncServerStreamingCallActions = _flightClient.ListActions(null, null, cts.Token); + Assert.Equal(StatusCode.Cancelled, asyncServerStreamingCallFlights.GetStatus().StatusCode); + + GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(batch)); + exception = await Assert.ThrowsAsync(async () => await _flightClient.GetInfo(flightDescriptor, null, null, cts.Token)); + Assert.Equal(StatusCode.Cancelled, exception.StatusCode); + + var flightInfo = await _flightClient.GetInfo(flightDescriptor); + var endpoint = flightInfo.Endpoints.FirstOrDefault(); + var getStream = _flightClient.GetStream(endpoint.Ticket, null, null, cts.Token); + Assert.Equal(StatusCode.Cancelled, getStream.GetStatus().StatusCode); + + var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor, null, null, cts.Token); + exception = await Assert.ThrowsAsync(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch)); + Assert.Equal(StatusCode.Cancelled, exception.StatusCode); + + var putStream = _flightClient.StartPut(flightDescriptor, null, null, cts.Token); + exception = await Assert.ThrowsAsync(async () => await putStream.RequestStream.WriteAsync(batch)); + Assert.Equal(StatusCode.Cancelled, exception.StatusCode); + + exception = await Assert.ThrowsAsync(async () => await _flightClient.GetSchema(flightDescriptor, null, null, cts.Token)); + Assert.Equal(StatusCode.Cancelled, exception.StatusCode); + + var handshakeStreamingCall = _flightClient.Handshake(null, null, cts.Token); + exception = await Assert.ThrowsAsync(async () => await handshakeStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty))); + Assert.Equal(StatusCode.Cancelled, exception.StatusCode); + + } } }