Skip to content

Commit

Permalink
feat: supporting batchquery
Browse files Browse the repository at this point in the history
  • Loading branch information
Farenheith committed Oct 15, 2024
1 parent 4d80e7b commit c78c5ab
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 55 deletions.
73 changes: 47 additions & 26 deletions src/Codibre.GrpcSqlProxy.Api/Utils/IAsyncEnumerableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@
using Dapper;
using Google.Protobuf;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.ObjectPool;
using Newtonsoft.Json;

namespace Codibre.GrpcSqlProxy.Api.Utils
{
public static class IAsyncEnumerableExtensions
{
private static readonly (ByteString, bool, bool) _empty = (ByteString.Empty, true, false);
public static async IAsyncEnumerable<(ByteString, bool, bool)> StreamByteChunks(this IAsyncEnumerable<dynamic> result, string schema, int packetSize, bool compress)
public static async IAsyncEnumerable<QueryPacket> StreamByteChunks(
this IAsyncEnumerable<dynamic> result,
string schema,
int packetSize,
bool compress,
int index,
int maxSchema
)
{
var schemaResult = CachedSchema.GetSchema(schema);
var queue = new ChunkQueue(compress, schemaResult, packetSize);
Expand All @@ -23,56 +30,70 @@ public static class IAsyncEnumerableExtensions
if (dict.TryGetValue(field, out var value)) record.Add(field, value);
}
queue.Write(record);
if (queue.Count > 1) yield return (queue.Pop(), false, compress);
if (queue.Count > 1) yield return QueryPacket.GetMid(queue, compress, index);
}
if (queue.Empty) yield return _empty;
if (queue.Empty) yield return QueryPacket.Empty(index, LastKind(index, maxSchema));
else
{
queue.EnqueueRest();
while (queue.Count > 1) yield return (queue.Pop(), false, compress);
if (queue.Count > 0) yield return (queue.Pop(), true, compress);
while (queue.Count > 1) yield return QueryPacket.GetMid(queue, compress, index);
if (queue.Count > 0) yield return QueryPacket.GetLast(queue, compress, index);
}
}

public static async IAsyncEnumerable<(ByteString, bool, bool)> EmptyResult(this Task result)
private static LastEnum LastKind(int index, int maxSchema)
{
await result;
yield return _empty;
return index < maxSchema ? LastEnum.SetLast : LastEnum.Last;
}

public static async IAsyncEnumerable<(ByteString, bool, bool)> EmptyResult(this ValueTask result)
{
await result;
yield return _empty;
}
public static IAsyncEnumerable<QueryPacket> EmptyResult(this Task result, int index, LastEnum last)
=> EmptyResult(new ValueTask(result), index, last);

public static async IAsyncEnumerable<(ByteString, bool, bool)> EmptyResult<T>(this ValueTask<T> result)
public static async IAsyncEnumerable<QueryPacket> EmptyResult(this ValueTask result, int index, LastEnum last)
{
await result;
yield return _empty;
yield return QueryPacket.Empty(index, last);
}

internal static IAsyncEnumerable<(ByteString, bool, bool)> GetResult(
internal static IAsyncEnumerable<QueryPacket> GetResult(
this SqlConnection connection,
SqlRequest request,
ProxyContext context
)
{
var query = request.Query;
var options = JsonConvert.DeserializeObject<Dictionary<string, object>?>(request.Params);
if (request.Schema.Count > 1) return RunQuery(connection, request, context, query, options);
return query.ToUpperInvariant().Replace(";", "") switch
{
"BEGIN TRANSACTION" => StartTransaction(connection, context).EmptyResult(),
"COMMIT" => Commit(context).EmptyResult(),
"ROLLBACK" => Rollback(context).EmptyResult(),
_ => string.IsNullOrWhiteSpace(request.Schema)
? connection.ExecuteAsync(query, options, context.Transaction)
.EmptyResult()
"BEGIN TRANSACTION" => StartTransaction(connection, context).EmptyResult(0, LastEnum.Last),
"COMMIT" => Commit(context).EmptyResult(0, LastEnum.Last),
"ROLLBACK" => Rollback(context).EmptyResult(0, LastEnum.Last),
_ => RunQuery(connection, request, context, query, options)
};
}

private static IAsyncEnumerable<QueryPacket> RunQuery(
SqlConnection connection,
SqlRequest request,
ProxyContext context,
string query,
Dictionary<string, object>? options
)
{
var max = request.Schema.Count - 1;
return request.Schema
.Select((schema, pos) => new { schema, pos })
.ToAsyncEnumerable()
.SelectMany(x =>
string.IsNullOrWhiteSpace(x.schema)
? connection
.ExecuteAsync(query, options, context.Transaction)
.EmptyResult(x.pos, LastKind(x.pos, max))
: connection
.QueryUnbufferedAsync(request.Query, options, context.Transaction)
.StreamByteChunks(request.Schema, request.PacketSize, request.Compress),
};
;
.StreamByteChunks(x.schema, request.PacketSize, request.Compress, x.pos, max)
);
}

private static async ValueTask Rollback(ProxyContext context)
Expand Down
21 changes: 21 additions & 0 deletions src/Codibre.GrpcSqlProxy.Api/Utils/QueryPacket.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using Codibre.GrpcSqlProxy.Api;
using Codibre.GrpcSqlProxy.Api.Utils;
using Google.Protobuf;

public record QueryPacket(
ByteString Result,
bool Compressed,
LastEnum Last,
int Index
)
{
public static QueryPacket Empty(int index, LastEnum last) => new(ByteString.Empty, false, last, index);
public static QueryPacket GetMid(ChunkQueue queue, bool compressed, int index)
=> new(queue.Pop(), compressed, LastEnum.Mid, index);

public static QueryPacket GetSetLast(ChunkQueue queue, bool compressed, int index)
=> new(queue.Pop(), compressed, LastEnum.SetLast, index);

public static QueryPacket GetLast(ChunkQueue queue, bool compressed, int index)
=> new(queue.Pop(), compressed, LastEnum.Last, index);
}
6 changes: 2 additions & 4 deletions src/Codibre.GrpcSqlProxy.Api/Utils/ResponsStreamExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,12 @@ private static void WriteError(this IServerStreamWriter<SqlResponse> responseStr
private static Task WriteSuccess(
this IServerStreamWriter<SqlResponse> responseStream,
SqlRequest request,
(ByteString, bool, bool) x)
QueryPacket packet)
{
return responseStream.WriteSqlResponse(
SqlResponseEx.Create(
request.Id,
x.Item1,
x.Item2,
x.Item3
packet
)
);
}
Expand Down
10 changes: 5 additions & 5 deletions src/Codibre.GrpcSqlProxy.Api/Utils/SqlResponseEx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@ namespace Codibre.GrpcSqlProxy.Api.Utils
{
public static class SqlResponseEx
{
public static SqlResponse Create(string id, ByteString result, bool last, bool compressed) => new()
public static SqlResponse Create(string id, QueryPacket packet) => new()
{
Id = id,
Result = result,
Result = packet.Result,
Error = "",
Last = last,
Compressed = compressed
Last = packet.Last,
Compressed = packet.Compressed
};

public static SqlResponse CreateError(string id, string error) => new()
{
Id = id,
Result = ByteString.Empty,
Error = error,
Last = true,
Last = LastEnum.Last,
Compressed = false
};
}
Expand Down
54 changes: 36 additions & 18 deletions src/Codibre.GrpcSqlProxy.Client/Impl/SqlProxyClientTunnel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Codibre.GrpcSqlProxy.Api;
using Codibre.GrpcSqlProxy.Client.Impl.Utils;
using Codibre.GrpcSqlProxy.Common;
using Google.Protobuf.Collections;
using Grpc.Core;

namespace Codibre.GrpcSqlProxy.Client.Impl
Expand Down Expand Up @@ -67,7 +68,7 @@ public async IAsyncEnumerable<T> Query<T>(string sql, SqlProxyQueryOptions? opti
var type = typeof(T);
var schema = type.GetCachedSchema();

var results = InternalRun(sql, schema.Item2, options);
var results = InternalRun(sql, [schema.Item2], options);
await foreach (var result in results)
{
using var memStream = result.Compressed ? result.Result.DecompressData() : result.Result.ToMemoryStream();
Expand All @@ -92,37 +93,54 @@ public async ValueTask Execute(string sql, SqlProxyQueryOptions? options = null)
await InternalRun(sql, null, options).LastAsync();
}

private async IAsyncEnumerable<SqlResponse> InternalRun(string sql, string? schema, SqlProxyQueryOptions? options)
private async IAsyncEnumerable<SqlResponse> InternalRun(string sql, string[]? schemas, SqlProxyQueryOptions? options)
{
var id = GuidEx.NewBase64Guid();
await _stream.RequestStream.WriteAsync(new()
{
Id = id,
ConnString = _started ? "" : _connString,
Query = sql,
Schema = schema ?? "",
Compress = options?.Compress ?? clientOptions.Compress,
PacketSize = options?.PacketSize ?? clientOptions.PacketSize,
Params = JsonSerializer.Serialize(options?.Params)
});
var count = (schemas?.Length ?? 1);

Check failure on line 99 in src/Codibre.GrpcSqlProxy.Client/Impl/SqlProxyClientTunnel.cs

View workflow job for this annotation

GitHub Actions / lint

Parentheses can be removed

Check failure on line 99 in src/Codibre.GrpcSqlProxy.Client/Impl/SqlProxyClientTunnel.cs

View workflow job for this annotation

GitHub Actions / build

Parentheses can be removed (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0047)

Check failure on line 99 in src/Codibre.GrpcSqlProxy.Client/Impl/SqlProxyClientTunnel.cs

View workflow job for this annotation

GitHub Actions / build

Parentheses can be removed (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0047)

Check failure on line 99 in src/Codibre.GrpcSqlProxy.Client/Impl/SqlProxyClientTunnel.cs

View workflow job for this annotation

GitHub Actions / test

Parentheses can be removed (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0047)
var message = GetRequest(clientOptions, sql, schemas, options, id);
await _stream.RequestStream.WriteAsync(message);
MonitorResponse();
var channel = Channel.CreateUnbounded<SqlResponse>();
_responseHooks.TryAdd(id, channel.Writer);
var reader = channel.Reader;

while (await reader.WaitToReadAsync(_cancellationTokenSource.Token))
for (var i = 0; i < count; i++)
{
reader.TryRead(out var item);
if (item is not null)
while (await reader.WaitToReadAsync(_cancellationTokenSource.Token))
{
if (!string.IsNullOrEmpty(item.Error)) throw new SqlProxyException(item.Error);
yield return item;
reader.TryRead(out var item);
if (item is not null)
{
if (!string.IsNullOrEmpty(item.Error)) throw new SqlProxyException(item.Error);
yield return item;
if (item.Last == LastEnum.SetLast) break;
if (item.Last == LastEnum.Last) yield break;
}
}
if (item?.Last == true) break;
}
_responseHooks.TryRemove(id, out _);
}

private SqlRequest GetRequest(SqlProxyClientOptions clientOptions, string sql, string[]? schemas, SqlProxyQueryOptions? options, string id)
{
SqlRequest message = new()
{
Id = id,
ConnString = _started ? "" : _connString,
Query = sql,
Schema = { },
Compress = options?.Compress ?? clientOptions.Compress,
PacketSize = options?.PacketSize ?? clientOptions.PacketSize,
Params = JsonSerializer.Serialize(options?.Params)
};
foreach (var schema in schemas ?? [])
{
message.Schema.Add(schema);
}

return message;
}

public ValueTask<T?> QueryFirstOrDefault<T>(string sql, SqlProxyQueryOptions? options = null) where T : class, new()
=> Query<T>(sql, options).FirstOrDefaultAsync();

Expand Down
11 changes: 9 additions & 2 deletions src/Protos/sql-proxy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,25 @@ service SqlProxy {

message SqlRequest {
string id = 1;
optional string schema = 2;
repeated string schema = 2;
string connString = 3;
string query = 4;
int32 packetSize = 5;
bool compress = 6;
optional string params = 7;
}

enum LastEnum {
Mid = 0;
SetLast = 1;
Last = 2;
}

message SqlResponse {
string id = 1;
optional bytes result = 2;
optional string error = 3;
bool last = 4;
LastEnum last = 4;
bool compressed = 5;
int32 index = 6;
}

0 comments on commit c78c5ab

Please sign in to comment.