From b984ae44b13386e4b874d5893c4878d24d1cdc1f Mon Sep 17 00:00:00 2001 From: Lucian Buzzo Date: Sun, 15 Oct 2023 13:58:28 +0100 Subject: [PATCH] feat: add support for nested transaction rollbacks via savepoints in sql This is my first OSS contribution for a Rust project, so I'm sure I've made some stupid mistakes, but I think it should mostly work :) This change adds a mutable depth counter, that can track how many levels deep a transaction is, and uses savepoints to implement correct rollback behaviour. Previously, once a nested transaction was complete, it would be saved with `COMMIT`, meaning that even if the outer transaction was rolled back, the operations in the inner transaction would persist. With this change, if the outer transaction gets rolled back, then all inner transactions will also be rolled back. Different flavours of SQL servers have different syntax for handling savepoints, so I've had to add new methods to the `Queryable` trait for getting the commit and rollback statements. These are both parameterized by the current depth. I've additionally had to modify the `begin_statement` method to accept a depth parameter, as it will need to conditionally create a savepoint. Signed-off-by: Lucian Buzzo --- quaint/src/connector/mssql.rs | 59 +++++++++++++++++++++++++++-- quaint/src/connector/mysql.rs | 39 +++++++++++++++++++ quaint/src/connector/postgres.rs | 39 +++++++++++++++++++ quaint/src/connector/queryable.rs | 36 ++++++++++++++++-- quaint/src/connector/sqlite.rs | 44 ++++++++++++++++++++- quaint/src/connector/transaction.rs | 56 +++++++++++++++++++++++---- quaint/src/pooled.rs | 5 ++- quaint/src/pooled/manager.rs | 15 +++++++- quaint/src/single.rs | 17 +++++++-- quaint/src/tests/query.rs | 16 +++++++- quaint/src/tests/query/error.rs | 2 +- 11 files changed, 303 insertions(+), 25 deletions(-) diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index cef092edb9d7..7e1ee59f9be7 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -22,6 +22,7 @@ use std::{ use tiberius::*; use tokio::net::TcpStream; use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; +use std::sync::Arc; /// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature. #[cfg(feature = "expose-drivers")] @@ -106,10 +107,25 @@ impl TransactionCapable for Mssql { .or(self.url.query_params.transaction_isolation_level) .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); - let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); + let mut transaction_depth = self.transaction_depth.lock().await; + *transaction_depth += 1; + let st_depth = *transaction_depth + 0; + + let begin_statement = self.begin_statement(st_depth).await; + let commit_stmt = self.commit_statement(st_depth).await; + let rollback_stmt = self.rollback_statement(st_depth).await; + + let opts = TransactionOptions::new( + isolation, + self.requires_isolation_first(), + self.transaction_depth.clone(), + commit_stmt, + rollback_stmt, + ); + Ok(Box::new( - DefaultTransaction::new(self, self.begin_statement(), opts).await?, + DefaultTransaction::new(self, &begin_statement, opts).await?, )) } } @@ -273,6 +289,7 @@ pub struct Mssql { url: MssqlUrl, socket_timeout: Option, is_healthy: AtomicBool, + transaction_depth: Arc>, } impl Mssql { @@ -304,6 +321,7 @@ impl Mssql { url, socket_timeout, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(Mutex::new(0)), }; if let Some(isolation) = this.url.transaction_isolation_level() { @@ -443,8 +461,41 @@ impl Queryable for Mssql { Ok(()) } - fn begin_statement(&self) -> &'static str { - "BEGIN TRAN" + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVE TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN TRAN".to_string() + }; + + return ret + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + // MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested + // transaction we just continue onwards + let ret = if depth > 1 { + " ".to_string() + } else { + "COMMIT".to_string() + }; + + return ret + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret } fn requires_isolation_first(&self) -> bool { diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index e5a1b794ab5b..68d9cdb95e65 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -23,6 +23,7 @@ use std::{ }; use tokio::sync::Mutex; use url::{Host, Url}; +use std::sync::Arc; /// The underlying MySQL driver. Only available with the `expose-drivers` /// Cargo feature. @@ -39,6 +40,7 @@ pub struct Mysql { socket_timeout: Option, is_healthy: AtomicBool, statement_cache: Mutex>, + transaction_depth: Arc>, } /// Wraps a connection url and exposes the parsing logic used by quaint, including default values. @@ -374,6 +376,7 @@ impl Mysql { statement_cache: Mutex::new(url.cache()), url, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -581,6 +584,42 @@ impl Queryable for Mysql { fn requires_isolation_first(&self) -> bool { true } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN".to_string() + }; + + return ret + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret + } } #[cfg(test)] diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 2c81144c812b..e1ab15bb9a8e 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -26,6 +26,7 @@ use tokio_postgres::{ Client, Config, Statement, }; use url::{Host, Url}; +use std::sync::Arc; pub(crate) const DEFAULT_SCHEMA: &str = "public"; @@ -61,6 +62,7 @@ pub struct PostgreSql { socket_timeout: Option, statement_cache: Mutex>, is_healthy: AtomicBool, + transaction_depth: Arc>, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -650,6 +652,7 @@ impl PostgreSql { pg_bouncer: url.query_params.pg_bouncer, statement_cache: Mutex::new(url.cache()), is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -930,6 +933,42 @@ impl Queryable for PostgreSql { fn requires_isolation_first(&self) -> bool { false } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN".to_string() + }; + + return ret + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret + } } /// Sorted list of CockroachDB's reserved keywords. diff --git a/quaint/src/connector/queryable.rs b/quaint/src/connector/queryable.rs index 09dbc7abba4c..f991e5506d9e 100644 --- a/quaint/src/connector/queryable.rs +++ b/quaint/src/connector/queryable.rs @@ -87,8 +87,18 @@ pub trait Queryable: Send + Sync { } /// Statement to begin a transaction - fn begin_statement(&self) -> &'static str { - "BEGIN" + async fn begin_statement(&self, _depth: i32) -> String { + "BEGIN".to_string() + } + + /// Statement to commit a transaction + async fn commit_statement(&self, _depth: i32) -> String { + "COMMIT".to_string() + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, _depth: i32) -> String { + "ROLLBACK".to_string() } /// Sets the transaction isolation level to given value. @@ -117,10 +127,28 @@ macro_rules! impl_default_TransactionCapable { &'a self, isolation: Option, ) -> crate::Result> { - let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first()); + let depth = self.transaction_depth.clone(); + let mut depth_guard = self.transaction_depth.lock().await; + *depth_guard += 1; + + let st_depth = *depth_guard; + + let begin_statement = self.begin_statement(st_depth).await; + let commit_stmt = self.commit_statement(st_depth).await; + let rollback_stmt = self.rollback_statement(st_depth).await; + + + + let opts = crate::connector::TransactionOptions::new( + isolation, + self.requires_isolation_first(), + depth, + commit_stmt, + rollback_stmt, + ); Ok(Box::new( - crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?, + crate::connector::DefaultTransaction::new(self, &begin_statement, opts).await?, )) } } diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 6db49523c80a..83a3c3d0274a 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -13,6 +13,7 @@ use crate::{ use async_trait::async_trait; use std::{convert::TryFrom, path::Path, time::Duration}; use tokio::sync::Mutex; +use std::sync::Arc; pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main"; @@ -23,6 +24,7 @@ pub use rusqlite; /// A connector interface for the SQLite database pub struct Sqlite { pub(crate) client: Mutex, + transaction_depth: Arc>, } /// Wraps a connection url and exposes the parsing logic used by Quaint, @@ -139,7 +141,10 @@ impl TryFrom<&str> for Sqlite { let client = Mutex::new(conn); - Ok(Sqlite { client }) + Ok(Sqlite { + client, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), + }) } } @@ -154,6 +159,7 @@ impl Sqlite { Ok(Sqlite { client: Mutex::new(client), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -250,6 +256,42 @@ impl Queryable for Sqlite { fn requires_isolation_first(&self) -> bool { false } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN".to_string() + }; + + return ret + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret + } } #[cfg(test)] diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index b7e91e97f6a8..e6caa01b0e2a 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -6,16 +6,18 @@ use crate::{ use async_trait::async_trait; use metrics::{decrement_gauge, increment_gauge}; use std::{fmt, str::FromStr}; +use futures::lock::Mutex; +use std::sync::Arc; extern crate metrics as metrics; #[async_trait] pub trait Transaction: Queryable { /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()>; + async fn commit(&mut self) -> crate::Result<()>; /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()>; + async fn rollback(&mut self) -> crate::Result<()>; /// workaround for lack of upcasting between traits https://github.com/rust-lang/rust/issues/65991 fn as_queryable(&self) -> &dyn Queryable; @@ -27,6 +29,15 @@ pub(crate) struct TransactionOptions { /// Whether or not to put the isolation level `SET` before or after the `BEGIN`. pub(crate) isolation_first: bool, + + /// The depth of the transaction, used to determine the nested transaction statements. + pub depth: Arc>, + + /// The statement to use to commit the transaction. + pub commit_stmt: String, + + /// The statement to use to rollback the transaction. + pub rollback_stmt: String, } /// A default representation of an SQL database transaction. If not commited, a @@ -36,6 +47,9 @@ pub(crate) struct TransactionOptions { /// transaction object will panic. pub struct DefaultTransaction<'a> { pub inner: &'a dyn Queryable, + pub depth: Arc>, + pub commit_stmt: String, + pub rollback_stmt: String, } impl<'a> DefaultTransaction<'a> { @@ -44,7 +58,12 @@ impl<'a> DefaultTransaction<'a> { begin_stmt: &str, tx_opts: TransactionOptions, ) -> crate::Result> { - let this = Self { inner }; + let this = Self { + inner, + depth: tx_opts.depth, + commit_stmt: tx_opts.commit_stmt, + rollback_stmt: tx_opts.rollback_stmt, + }; if tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -70,17 +89,29 @@ impl<'a> DefaultTransaction<'a> { #[async_trait] impl<'a> Transaction for DefaultTransaction<'a> { /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()> { + async fn commit(&mut self) -> crate::Result<()> { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("COMMIT").await?; + + let mut depth_guard = self.depth.lock().await; + + self.inner.raw_cmd(&self.commit_stmt).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; Ok(()) } /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()> { + async fn rollback(&mut self) -> crate::Result<()> { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("ROLLBACK").await?; + + let mut depth_guard = self.depth.lock().await; + + self.inner.raw_cmd(&self.rollback_stmt).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; Ok(()) } @@ -190,10 +221,19 @@ impl FromStr for IsolationLevel { } } impl TransactionOptions { - pub fn new(isolation_level: Option, isolation_first: bool) -> Self { + pub fn new( + isolation_level: Option, + isolation_first: bool, + depth: Arc>, + commit_stmt: String, + rollback_stmt: String, + ) -> Self { Self { isolation_level, isolation_first, + depth, + commit_stmt, + rollback_stmt, } } } diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 4c4152923377..aec229b744dc 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -500,7 +500,10 @@ impl Quaint { } }; - Ok(PooledConnection { inner }) + Ok(PooledConnection { + inner, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)) + }) } /// Info about the connection and underlying database. diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index c0aa8c93b75d..27367961cbe5 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -11,11 +11,14 @@ use crate::{ }; use async_trait::async_trait; use mobc::{Connection as MobcPooled, Manager}; +use futures::lock::Mutex; +use std::sync::Arc; /// A connection from the pool. Implements /// [Queryable](connector/trait.Queryable.html). pub struct PooledConnection { pub(crate) inner: MobcPooled, + pub transaction_depth: Arc>, } impl_default_TransactionCapable!(PooledConnection); @@ -62,8 +65,16 @@ impl Queryable for PooledConnection { self.inner.server_reset_query(tx).await } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 82042f58010b..da173321ff51 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -8,6 +8,7 @@ use crate::{ }; use async_trait::async_trait; use std::{fmt, sync::Arc}; +use futures::lock::Mutex; #[cfg(feature = "sqlite")] use std::convert::TryFrom; @@ -17,6 +18,7 @@ use std::convert::TryFrom; pub struct Quaint { inner: Arc, connection_info: Arc, + transaction_depth: Arc>, } impl fmt::Debug for Quaint { @@ -163,7 +165,7 @@ impl Quaint { let connection_info = Arc::new(ConnectionInfo::from_url(url_str)?); Self::log_start(&connection_info); - Ok(Self { inner, connection_info }) + Ok(Self { inner, connection_info, transaction_depth: Arc::new(Mutex::new(0)) }) } #[cfg(feature = "sqlite")] @@ -174,6 +176,7 @@ impl Quaint { connection_info: Arc::new(ConnectionInfo::InMemorySqlite { db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), }), + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -228,8 +231,16 @@ impl Queryable for Quaint { self.inner.is_healthy() } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/tests/query.rs b/quaint/src/tests/query.rs index 06bebe1a9601..cf471fbf7330 100644 --- a/quaint/src/tests/query.rs +++ b/quaint/src/tests/query.rs @@ -64,7 +64,7 @@ async fn select_star_from(api: &mut dyn TestApi) -> crate::Result<()> { async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { let table = api.create_temp_table("value int").await?; - let tx = api.conn().start_transaction(None).await?; + let mut tx = api.conn().start_transaction(None).await?; let insert = Insert::single_into(&table).value("value", 10); let rows_affected = tx.execute(insert.into()).await?; @@ -75,6 +75,20 @@ async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { assert_eq!(Value::int32(10), res[0]); + // Check that nested transactions are also rolled back, even at multiple levels deep + let mut tx_inner = api.conn().start_transaction(None).await?; + let inner_insert1 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected1 = tx.execute(inner_insert1.into()).await?; + assert_eq!(1, inner_rows_affected1); + + let mut tx_inner2 = api.conn().start_transaction(None).await?; + let inner_insert2 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected2 = tx.execute(inner_insert2.into()).await?; + assert_eq!(1, inner_rows_affected2); + tx_inner2.commit().await?; + + tx_inner.commit().await?; + tx.rollback().await?; let select = Select::from_table(&table).column("value"); diff --git a/quaint/src/tests/query/error.rs b/quaint/src/tests/query/error.rs index 69c57332b6d3..67334858576e 100644 --- a/quaint/src/tests/query/error.rs +++ b/quaint/src/tests/query/error.rs @@ -456,7 +456,7 @@ async fn concurrent_transaction_conflict(api: &mut dyn TestApi) -> crate::Result let conn1 = api.create_additional_connection().await?; let conn2 = api.create_additional_connection().await?; - let tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; + let mut tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; let tx2 = conn2.start_transaction(Some(IsolationLevel::Serializable)).await?; tx1.query(Select::from_table(&table).into()).await?;