Skip to content

Commit

Permalink
Merge pull request #2 from codibre/preparing-mars
Browse files Browse the repository at this point in the history
fix: preparing transaction for MARS
  • Loading branch information
Farenheith authored Oct 14, 2024
2 parents e218af1 + 624f35e commit b518ef1
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 26 deletions.
17 changes: 15 additions & 2 deletions src/Codibre.GrpcSqlProxy.Api/Services/SqlProxyService.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Codibre.GrpcSqlProxy.Api.Utils;
using Codibre.GrpcSqlProxy.Common;
using Dapper;
using Grpc.Core;
using static Codibre.GrpcSqlProxy.Api.SqlProxy;

Expand Down Expand Up @@ -27,7 +28,7 @@ await responseStream.Catch(request.Id, async () =>
if (connection is not null)
{
if (request.PacketSize <= 0) request.PacketSize = 1000;
responseStream.PipeResponse(connection, request);
responseStream.PipeResponse(connection, request, proxyContext);
}
});
}
Expand All @@ -38,7 +39,19 @@ await responseStream.Catch(request.Id, async () =>
}
finally
{
if (proxyContext.Connection is not null) await proxyContext.Connection.CloseAsync();
if (proxyContext.Transaction is not null)
{
try
{
await proxyContext.Transaction.RollbackAsync();
}
catch
{
// Ignore if error occurs as no transaction were there
}
}
if (proxyContext.Connection is not null)
await proxyContext.Connection.CloseAsync();
}
}
}
Expand Down
53 changes: 45 additions & 8 deletions src/Codibre.GrpcSqlProxy.Api/Utils/IAsyncEnumerableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,56 @@ public static class IAsyncEnumerableExtensions
yield return _empty;
}

public static IAsyncEnumerable<(ByteString, bool, bool)> GetResult(
public static async IAsyncEnumerable<(ByteString, bool, bool)> EmptyResult(this ValueTask result)
{
await result;
yield return _empty;
}

public static async IAsyncEnumerable<(ByteString, bool, bool)> EmptyResult<T>(this ValueTask<T> result)
{
await result;
yield return _empty;
}

internal static IAsyncEnumerable<(ByteString, bool, bool)> GetResult(
this SqlConnection connection,
SqlRequest request
SqlRequest request,
ProxyContext context
)
{
var query = request.Query;
var options = JsonConvert.DeserializeObject<Dictionary<string, object>?>(request.Params);
return string.IsNullOrWhiteSpace(request.Schema)
? connection.ExecuteAsync(query, options)
.EmptyResult()
: connection
.QueryUnbufferedAsync(request.Query, options)
.StreamByteChunks(request.Schema, request.PacketSize, request.Compress);
return query.ToUpperInvariant().Replace(";", "") switch
{
"BEGIN TRANSACTION" => StartTransaction(connection, context).EmptyResult(),
"COMMIT" => Commit(context).EmptyResult(),
"ROLLBACK" => Rollback(context).EmptyResult(),
_ => string.IsNullOrWhiteSpace(request.Schema)
? connection.ExecuteAsync(query, options, context.Transaction)
.EmptyResult()
: connection
.QueryUnbufferedAsync(request.Query, options, context.Transaction)
.StreamByteChunks(request.Schema, request.PacketSize, request.Compress),
};
;
}

private static async ValueTask Rollback(ProxyContext context)
{
await context.Transaction!.RollbackAsync();
context.Transaction = null;
}

private static async ValueTask Commit(ProxyContext context)
{
await context.Transaction!.CommitAsync();
context.Transaction = null;
}

private static async ValueTask StartTransaction(SqlConnection connection, ProxyContext context)
{
context.Transaction = await connection.BeginTransactionAsync();
}
}
}
5 changes: 4 additions & 1 deletion src/Codibre.GrpcSqlProxy.Api/Utils/ProxyContext.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
using Microsoft.Data.SqlClient;
using System.Data.Common;
using System.Transactions;
using Microsoft.Data.SqlClient;

namespace Codibre.GrpcSqlProxy.Api.Utils
{
internal class ProxyContext
{
public string? ConnectionString { get; set; } = null;
public SqlConnection? Connection { get; set; } = null;
public DbTransaction? Transaction { get; set; } = null;
}
}
7 changes: 4 additions & 3 deletions src/Codibre.GrpcSqlProxy.Api/Utils/ResponsStreamExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ private static Task WriteSuccess(
);
}

public static void PipeResponse(
internal static void PipeResponse(
this IServerStreamWriter<SqlResponse> responseStream,
SqlConnection connection,
SqlRequest request
SqlRequest request,
ProxyContext context
) => _ = responseStream.Catch(request.Id, () =>
connection.GetResult(request)
connection.GetResult(request, context)
.ForEachAwaitAsync((x) => responseStream.WriteSuccess(request, x))
);
}
Expand Down
3 changes: 3 additions & 0 deletions src/Codibre.GrpcSqlProxy.Client/ISqlProxyClientTunnel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ namespace Codibre.GrpcSqlProxy.Client
public interface ISqlProxyClientTunnel : IDisposable
{
event ErrorHandlerEvent? ErrorHandler;
ValueTask BeginTransaction();
ValueTask Commit();
ValueTask Rollback();
ValueTask Execute(string sql, SqlProxyQueryOptions? options = null);
IAsyncEnumerable<T> Query<T>(string sql, SqlProxyQueryOptions? options = null) where T : class, new();
ValueTask<T?> QueryFirstOrDefault<T>(string sql, SqlProxyQueryOptions? options = null) where T : class, new();
Expand Down
6 changes: 6 additions & 0 deletions src/Codibre.GrpcSqlProxy.Client/Impl/SqlProxyClientTunnel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,11 @@ public void Dispose()
_running = false;
_cancellationTokenSource.Cancel();
}

public ValueTask BeginTransaction() => Execute("BEGIN TRANSACTION");

public ValueTask Commit() => Execute("COMMIT");

public ValueTask Rollback() => Execute("ROLLBACK");
}
}
24 changes: 12 additions & 12 deletions test/Codibre.GrpcSqlProxy.Test/GrpcSqlProxyClientTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ public async Task Should_Keep_Transaction_Opened()
// Act
using var channel = client.CreateChannel();
await channel.Execute("DELETE FROM TB_PEDIDO");
await channel.Execute("BEGIN TRANSACTION");
await channel.Execute("INSERT INTO TB_PEDIDO VALUES (1)");
await channel.BeginTransaction();
await channel.Execute("INSERT INTO TB_PEDIDO (CD_PEDIDO) VALUES (1)");
var result1 = await channel.QueryFirstOrDefault<TB_PEDIDO>("SELECT * FROM TB_PEDIDO");
await channel.Execute("ROLLBACK");
await channel.Rollback();
var result2 = await channel.Query<TB_PEDIDO>("SELECT * FROM TB_PEDIDO").ToArrayAsync();

// Assert
Expand Down Expand Up @@ -93,10 +93,10 @@ public async Task Should_Use_Compression()
// Act
using var channel = client.CreateChannel();
await channel.Execute("DELETE FROM TB_PRODUTO");
await channel.Execute("BEGIN TRANSACTION");
await channel.Execute("INSERT INTO TB_PRODUTO VALUES (1)");
await channel.BeginTransaction();
await channel.Execute("INSERT INTO TB_PRODUTO (CD_PRODUTO) VALUES (1)");
var result1 = await channel.QueryFirstOrDefault<TB_PRODUTO>("SELECT * FROM TB_PRODUTO");
await channel.Execute("ROLLBACK");
await channel.Rollback();
var result2 = await channel.Query<TB_PRODUTO>("SELECT * FROM TB_PRODUTO").ToArrayAsync();

// Assert
Expand Down Expand Up @@ -156,10 +156,10 @@ public async Task Should_Keep_Parallel_Transaction_Opened()
using var channel1 = client.CreateChannel();
using var channel2 = client.CreateChannel();
await channel1.Execute("DELETE FROM TB_PESSOA");
await channel1.Execute("INSERT INTO TB_PESSOA VALUES (1)");
await channel1.Execute("INSERT INTO TB_PESSOA VALUES (2)");
await channel1.Execute("BEGIN TRANSACTION");
await channel2.Execute("BEGIN TRANSACTION");
await channel1.Execute("INSERT INTO TB_PESSOA (CD_PESSOA) VALUES (1)");
await channel1.Execute("INSERT INTO TB_PESSOA (CD_PESSOA) VALUES (2)");
await channel1.BeginTransaction();
await channel2.BeginTransaction();
await channel1.Execute("UPDATE TB_PESSOA SET CD_PESSOA = 3 WHERE CD_PESSOA = @Id", new()
{
Params = new
Expand All @@ -174,10 +174,10 @@ public async Task Should_Keep_Parallel_Transaction_Opened()
Id = 3
}
});
await channel1.Execute("ROLLBACK");
await channel1.Rollback();
await channel2.Execute("UPDATE TB_PESSOA SET CD_PESSOA = 5 WHERE CD_PESSOA = 2");
var result2 = await channel2.Query<TB_PESSOA>("SELECT * FROM TB_PESSOA").ToArrayAsync();
await channel2.Execute("ROLLBACK");
await channel2.Rollback();
var result3 = await channel1.Query<TB_PESSOA>("SELECT * FROM TB_PESSOA", new()
{
PacketSize = 1
Expand Down

0 comments on commit b518ef1

Please sign in to comment.