From c78c5ab47f585228cd8c06861e5d8d2eb742494e Mon Sep 17 00:00:00 2001 From: Thiago Oliveira Santos Date: Mon, 14 Oct 2024 23:13:27 -0300 Subject: [PATCH] feat: supporting batchquery --- .../Utils/IAsyncEnumerableExtensions.cs | 73 ++++++++++++------- .../Utils/QueryPacket.cs | 21 ++++++ .../Utils/ResponsStreamExtensions.cs | 6 +- .../Utils/SqlResponseEx.cs | 10 +-- .../Impl/SqlProxyClientTunnel.cs | 54 +++++++++----- src/Protos/sql-proxy.proto | 11 ++- 6 files changed, 120 insertions(+), 55 deletions(-) create mode 100644 src/Codibre.GrpcSqlProxy.Api/Utils/QueryPacket.cs diff --git a/src/Codibre.GrpcSqlProxy.Api/Utils/IAsyncEnumerableExtensions.cs b/src/Codibre.GrpcSqlProxy.Api/Utils/IAsyncEnumerableExtensions.cs index 3702e45..4f29d4b 100644 --- a/src/Codibre.GrpcSqlProxy.Api/Utils/IAsyncEnumerableExtensions.cs +++ b/src/Codibre.GrpcSqlProxy.Api/Utils/IAsyncEnumerableExtensions.cs @@ -3,14 +3,21 @@ using Dapper; using Google.Protobuf; using Microsoft.Data.SqlClient; +using Microsoft.Extensions.ObjectPool; using Newtonsoft.Json; namespace Codibre.GrpcSqlProxy.Api.Utils { public static class IAsyncEnumerableExtensions { - private static readonly (ByteString, bool, bool) _empty = (ByteString.Empty, true, false); - public static async IAsyncEnumerable<(ByteString, bool, bool)> StreamByteChunks(this IAsyncEnumerable result, string schema, int packetSize, bool compress) + public static async IAsyncEnumerable StreamByteChunks( + this IAsyncEnumerable result, + string schema, + int packetSize, + bool compress, + int index, + int maxSchema + ) { var schemaResult = CachedSchema.GetSchema(schema); var queue = new ChunkQueue(compress, schemaResult, packetSize); @@ -23,36 +30,32 @@ public static class IAsyncEnumerableExtensions if (dict.TryGetValue(field, out var value)) record.Add(field, value); } queue.Write(record); - if (queue.Count > 1) yield return (queue.Pop(), false, compress); + if (queue.Count > 1) yield return QueryPacket.GetMid(queue, compress, index); } - if (queue.Empty) yield return _empty; + if (queue.Empty) yield return QueryPacket.Empty(index, LastKind(index, maxSchema)); else { queue.EnqueueRest(); - while (queue.Count > 1) yield return (queue.Pop(), false, compress); - if (queue.Count > 0) yield return (queue.Pop(), true, compress); + while (queue.Count > 1) yield return QueryPacket.GetMid(queue, compress, index); + if (queue.Count > 0) yield return QueryPacket.GetLast(queue, compress, index); } } - public static async IAsyncEnumerable<(ByteString, bool, bool)> EmptyResult(this Task result) + private static LastEnum LastKind(int index, int maxSchema) { - await result; - yield return _empty; + return index < maxSchema ? LastEnum.SetLast : LastEnum.Last; } - public static async IAsyncEnumerable<(ByteString, bool, bool)> EmptyResult(this ValueTask result) - { - await result; - yield return _empty; - } + public static IAsyncEnumerable EmptyResult(this Task result, int index, LastEnum last) + => EmptyResult(new ValueTask(result), index, last); - public static async IAsyncEnumerable<(ByteString, bool, bool)> EmptyResult(this ValueTask result) + public static async IAsyncEnumerable EmptyResult(this ValueTask result, int index, LastEnum last) { await result; - yield return _empty; + yield return QueryPacket.Empty(index, last); } - internal static IAsyncEnumerable<(ByteString, bool, bool)> GetResult( + internal static IAsyncEnumerable GetResult( this SqlConnection connection, SqlRequest request, ProxyContext context @@ -60,19 +63,37 @@ ProxyContext context { var query = request.Query; var options = JsonConvert.DeserializeObject?>(request.Params); + if (request.Schema.Count > 1) return RunQuery(connection, request, context, query, options); return query.ToUpperInvariant().Replace(";", "") switch { - "BEGIN TRANSACTION" => StartTransaction(connection, context).EmptyResult(), - "COMMIT" => Commit(context).EmptyResult(), - "ROLLBACK" => Rollback(context).EmptyResult(), - _ => string.IsNullOrWhiteSpace(request.Schema) - ? connection.ExecuteAsync(query, options, context.Transaction) -.EmptyResult() + "BEGIN TRANSACTION" => StartTransaction(connection, context).EmptyResult(0, LastEnum.Last), + "COMMIT" => Commit(context).EmptyResult(0, LastEnum.Last), + "ROLLBACK" => Rollback(context).EmptyResult(0, LastEnum.Last), + _ => RunQuery(connection, request, context, query, options) + }; + } + + private static IAsyncEnumerable RunQuery( + SqlConnection connection, + SqlRequest request, + ProxyContext context, + string query, + Dictionary? options + ) + { + var max = request.Schema.Count - 1; + return request.Schema + .Select((schema, pos) => new { schema, pos }) + .ToAsyncEnumerable() + .SelectMany(x => + string.IsNullOrWhiteSpace(x.schema) + ? connection + .ExecuteAsync(query, options, context.Transaction) + .EmptyResult(x.pos, LastKind(x.pos, max)) : connection .QueryUnbufferedAsync(request.Query, options, context.Transaction) -.StreamByteChunks(request.Schema, request.PacketSize, request.Compress), - }; - ; + .StreamByteChunks(x.schema, request.PacketSize, request.Compress, x.pos, max) + ); } private static async ValueTask Rollback(ProxyContext context) diff --git a/src/Codibre.GrpcSqlProxy.Api/Utils/QueryPacket.cs b/src/Codibre.GrpcSqlProxy.Api/Utils/QueryPacket.cs new file mode 100644 index 0000000..5bf0a0f --- /dev/null +++ b/src/Codibre.GrpcSqlProxy.Api/Utils/QueryPacket.cs @@ -0,0 +1,21 @@ +using Codibre.GrpcSqlProxy.Api; +using Codibre.GrpcSqlProxy.Api.Utils; +using Google.Protobuf; + +public record QueryPacket( + ByteString Result, + bool Compressed, + LastEnum Last, + int Index +) +{ + public static QueryPacket Empty(int index, LastEnum last) => new(ByteString.Empty, false, last, index); + public static QueryPacket GetMid(ChunkQueue queue, bool compressed, int index) + => new(queue.Pop(), compressed, LastEnum.Mid, index); + + public static QueryPacket GetSetLast(ChunkQueue queue, bool compressed, int index) + => new(queue.Pop(), compressed, LastEnum.SetLast, index); + + public static QueryPacket GetLast(ChunkQueue queue, bool compressed, int index) + => new(queue.Pop(), compressed, LastEnum.Last, index); +} \ No newline at end of file diff --git a/src/Codibre.GrpcSqlProxy.Api/Utils/ResponsStreamExtensions.cs b/src/Codibre.GrpcSqlProxy.Api/Utils/ResponsStreamExtensions.cs index e3779d9..8a42ba3 100644 --- a/src/Codibre.GrpcSqlProxy.Api/Utils/ResponsStreamExtensions.cs +++ b/src/Codibre.GrpcSqlProxy.Api/Utils/ResponsStreamExtensions.cs @@ -65,14 +65,12 @@ private static void WriteError(this IServerStreamWriter responseStr private static Task WriteSuccess( this IServerStreamWriter responseStream, SqlRequest request, - (ByteString, bool, bool) x) + QueryPacket packet) { return responseStream.WriteSqlResponse( SqlResponseEx.Create( request.Id, - x.Item1, - x.Item2, - x.Item3 + packet ) ); } diff --git a/src/Codibre.GrpcSqlProxy.Api/Utils/SqlResponseEx.cs b/src/Codibre.GrpcSqlProxy.Api/Utils/SqlResponseEx.cs index bd13c99..30b6c4b 100644 --- a/src/Codibre.GrpcSqlProxy.Api/Utils/SqlResponseEx.cs +++ b/src/Codibre.GrpcSqlProxy.Api/Utils/SqlResponseEx.cs @@ -4,13 +4,13 @@ namespace Codibre.GrpcSqlProxy.Api.Utils { public static class SqlResponseEx { - public static SqlResponse Create(string id, ByteString result, bool last, bool compressed) => new() + public static SqlResponse Create(string id, QueryPacket packet) => new() { Id = id, - Result = result, + Result = packet.Result, Error = "", - Last = last, - Compressed = compressed + Last = packet.Last, + Compressed = packet.Compressed }; public static SqlResponse CreateError(string id, string error) => new() @@ -18,7 +18,7 @@ public static class SqlResponseEx Id = id, Result = ByteString.Empty, Error = error, - Last = true, + Last = LastEnum.Last, Compressed = false }; } diff --git a/src/Codibre.GrpcSqlProxy.Client/Impl/SqlProxyClientTunnel.cs b/src/Codibre.GrpcSqlProxy.Client/Impl/SqlProxyClientTunnel.cs index f980fac..b432126 100644 --- a/src/Codibre.GrpcSqlProxy.Client/Impl/SqlProxyClientTunnel.cs +++ b/src/Codibre.GrpcSqlProxy.Client/Impl/SqlProxyClientTunnel.cs @@ -6,6 +6,7 @@ using Codibre.GrpcSqlProxy.Api; using Codibre.GrpcSqlProxy.Client.Impl.Utils; using Codibre.GrpcSqlProxy.Common; +using Google.Protobuf.Collections; using Grpc.Core; namespace Codibre.GrpcSqlProxy.Client.Impl @@ -67,7 +68,7 @@ public async IAsyncEnumerable Query(string sql, SqlProxyQueryOptions? opti var type = typeof(T); var schema = type.GetCachedSchema(); - var results = InternalRun(sql, schema.Item2, options); + var results = InternalRun(sql, [schema.Item2], options); await foreach (var result in results) { using var memStream = result.Compressed ? result.Result.DecompressData() : result.Result.ToMemoryStream(); @@ -92,37 +93,54 @@ public async ValueTask Execute(string sql, SqlProxyQueryOptions? options = null) await InternalRun(sql, null, options).LastAsync(); } - private async IAsyncEnumerable InternalRun(string sql, string? schema, SqlProxyQueryOptions? options) + private async IAsyncEnumerable InternalRun(string sql, string[]? schemas, SqlProxyQueryOptions? options) { var id = GuidEx.NewBase64Guid(); - await _stream.RequestStream.WriteAsync(new() - { - Id = id, - ConnString = _started ? "" : _connString, - Query = sql, - Schema = schema ?? "", - Compress = options?.Compress ?? clientOptions.Compress, - PacketSize = options?.PacketSize ?? clientOptions.PacketSize, - Params = JsonSerializer.Serialize(options?.Params) - }); + var count = (schemas?.Length ?? 1); + var message = GetRequest(clientOptions, sql, schemas, options, id); + await _stream.RequestStream.WriteAsync(message); MonitorResponse(); var channel = Channel.CreateUnbounded(); _responseHooks.TryAdd(id, channel.Writer); var reader = channel.Reader; - while (await reader.WaitToReadAsync(_cancellationTokenSource.Token)) + for (var i = 0; i < count; i++) { - reader.TryRead(out var item); - if (item is not null) + while (await reader.WaitToReadAsync(_cancellationTokenSource.Token)) { - if (!string.IsNullOrEmpty(item.Error)) throw new SqlProxyException(item.Error); - yield return item; + reader.TryRead(out var item); + if (item is not null) + { + if (!string.IsNullOrEmpty(item.Error)) throw new SqlProxyException(item.Error); + yield return item; + if (item.Last == LastEnum.SetLast) break; + if (item.Last == LastEnum.Last) yield break; + } } - if (item?.Last == true) break; } _responseHooks.TryRemove(id, out _); } + private SqlRequest GetRequest(SqlProxyClientOptions clientOptions, string sql, string[]? schemas, SqlProxyQueryOptions? options, string id) + { + SqlRequest message = new() + { + Id = id, + ConnString = _started ? "" : _connString, + Query = sql, + Schema = { }, + Compress = options?.Compress ?? clientOptions.Compress, + PacketSize = options?.PacketSize ?? clientOptions.PacketSize, + Params = JsonSerializer.Serialize(options?.Params) + }; + foreach (var schema in schemas ?? []) + { + message.Schema.Add(schema); + } + + return message; + } + public ValueTask QueryFirstOrDefault(string sql, SqlProxyQueryOptions? options = null) where T : class, new() => Query(sql, options).FirstOrDefaultAsync(); diff --git a/src/Protos/sql-proxy.proto b/src/Protos/sql-proxy.proto index 3793eb3..5e8d8e6 100644 --- a/src/Protos/sql-proxy.proto +++ b/src/Protos/sql-proxy.proto @@ -10,7 +10,7 @@ service SqlProxy { message SqlRequest { string id = 1; - optional string schema = 2; + repeated string schema = 2; string connString = 3; string query = 4; int32 packetSize = 5; @@ -18,10 +18,17 @@ message SqlRequest { optional string params = 7; } +enum LastEnum { + Mid = 0; + SetLast = 1; + Last = 2; +} + message SqlResponse { string id = 1; optional bytes result = 2; optional string error = 3; - bool last = 4; + LastEnum last = 4; bool compressed = 5; + int32 index = 6; }