Skip to content

Commit

Permalink
feat: add support for nested transaction rollbacks via savepoints in sql
Browse files Browse the repository at this point in the history
This is my first OSS contribution for a Rust project, so I'm sure I've
made some stupid mistakes, but I think it should mostly work :)

This change adds a mutable depth counter, that can track how many levels
deep a transaction is, and uses savepoints to implement correct rollback
behaviour. Previously, once a nested transaction was complete, it would
be saved with `COMMIT`, meaning that even if the outer transaction was
rolled back, the operations in the inner transaction would persist. With
this change, if the outer transaction gets rolled back, then all inner
transactions will also be rolled back.

Different flavours of SQL servers have different syntax for handling
savepoints, so I've had to add new methods to the `Queryable` trait for
getting the commit and rollback statements. These are both parameterized
by the current depth.

I've additionally had to modify the `begin_statement` method to accept a depth
parameter, as it will need to conditionally create a savepoint.

Signed-off-by: Lucian Buzzo <[email protected]>
  • Loading branch information
LucianBuzzo committed Oct 17, 2023
1 parent 72963d8 commit b984ae4
Show file tree
Hide file tree
Showing 11 changed files with 303 additions and 25 deletions.
59 changes: 55 additions & 4 deletions quaint/src/connector/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use std::{
use tiberius::*;
use tokio::net::TcpStream;
use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
use std::sync::Arc;

/// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature.
#[cfg(feature = "expose-drivers")]
Expand Down Expand Up @@ -106,10 +107,25 @@ impl TransactionCapable for Mssql {
.or(self.url.query_params.transaction_isolation_level)
.or(Some(SQL_SERVER_DEFAULT_ISOLATION));

let opts = TransactionOptions::new(isolation, self.requires_isolation_first());
let mut transaction_depth = self.transaction_depth.lock().await;
*transaction_depth += 1;
let st_depth = *transaction_depth + 0;

let begin_statement = self.begin_statement(st_depth).await;
let commit_stmt = self.commit_statement(st_depth).await;
let rollback_stmt = self.rollback_statement(st_depth).await;

let opts = TransactionOptions::new(
isolation,
self.requires_isolation_first(),
self.transaction_depth.clone(),
commit_stmt,
rollback_stmt,
);


Ok(Box::new(
DefaultTransaction::new(self, self.begin_statement(), opts).await?,
DefaultTransaction::new(self, &begin_statement, opts).await?,
))
}
}
Expand Down Expand Up @@ -273,6 +289,7 @@ pub struct Mssql {
url: MssqlUrl,
socket_timeout: Option<Duration>,
is_healthy: AtomicBool,
transaction_depth: Arc<Mutex<i32>>,
}

impl Mssql {
Expand Down Expand Up @@ -304,6 +321,7 @@ impl Mssql {
url,
socket_timeout,
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(Mutex::new(0)),
};

if let Some(isolation) = this.url.transaction_isolation_level() {
Expand Down Expand Up @@ -443,8 +461,41 @@ impl Queryable for Mssql {
Ok(())
}

fn begin_statement(&self) -> &'static str {
"BEGIN TRAN"
/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVE TRANSACTION savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"BEGIN TRAN".to_string()
};

return ret
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
// MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested
// transaction we just continue onwards
let ret = if depth > 1 {
" ".to_string()
} else {
"COMMIT".to_string()
};

return ret
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TRANSACTION savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret
}

fn requires_isolation_first(&self) -> bool {
Expand Down
39 changes: 39 additions & 0 deletions quaint/src/connector/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use std::{
};
use tokio::sync::Mutex;
use url::{Host, Url};
use std::sync::Arc;

/// The underlying MySQL driver. Only available with the `expose-drivers`
/// Cargo feature.
Expand All @@ -39,6 +40,7 @@ pub struct Mysql {
socket_timeout: Option<Duration>,
is_healthy: AtomicBool,
statement_cache: Mutex<LruCache<String, my::Statement>>,
transaction_depth: Arc<futures::lock::Mutex<i32>>,
}

/// Wraps a connection url and exposes the parsing logic used by quaint, including default values.
Expand Down Expand Up @@ -374,6 +376,7 @@ impl Mysql {
statement_cache: Mutex::new(url.cache()),
url,
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}

Expand Down Expand Up @@ -581,6 +584,42 @@ impl Queryable for Mysql {
fn requires_isolation_first(&self) -> bool {
true
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"BEGIN".to_string()
};

return ret
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret
}
}

#[cfg(test)]
Expand Down
39 changes: 39 additions & 0 deletions quaint/src/connector/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use tokio_postgres::{
Client, Config, Statement,
};
use url::{Host, Url};
use std::sync::Arc;

pub(crate) const DEFAULT_SCHEMA: &str = "public";

Expand Down Expand Up @@ -61,6 +62,7 @@ pub struct PostgreSql {
socket_timeout: Option<Duration>,
statement_cache: Mutex<LruCache<String, Statement>>,
is_healthy: AtomicBool,
transaction_depth: Arc<Mutex<i32>>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
Expand Down Expand Up @@ -650,6 +652,7 @@ impl PostgreSql {
pg_bouncer: url.query_params.pg_bouncer,
statement_cache: Mutex::new(url.cache()),
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(Mutex::new(0)),
})
}

Expand Down Expand Up @@ -930,6 +933,42 @@ impl Queryable for PostgreSql {
fn requires_isolation_first(&self) -> bool {
false
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"BEGIN".to_string()
};

return ret
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret
}
}

/// Sorted list of CockroachDB's reserved keywords.
Expand Down
36 changes: 32 additions & 4 deletions quaint/src/connector/queryable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,18 @@ pub trait Queryable: Send + Sync {
}

/// Statement to begin a transaction
fn begin_statement(&self) -> &'static str {
"BEGIN"
async fn begin_statement(&self, _depth: i32) -> String {
"BEGIN".to_string()
}

/// Statement to commit a transaction
async fn commit_statement(&self, _depth: i32) -> String {
"COMMIT".to_string()
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, _depth: i32) -> String {
"ROLLBACK".to_string()
}

/// Sets the transaction isolation level to given value.
Expand Down Expand Up @@ -117,10 +127,28 @@ macro_rules! impl_default_TransactionCapable {
&'a self,
isolation: Option<IsolationLevel>,
) -> crate::Result<Box<dyn crate::connector::Transaction + 'a>> {
let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first());
let depth = self.transaction_depth.clone();
let mut depth_guard = self.transaction_depth.lock().await;
*depth_guard += 1;

let st_depth = *depth_guard;

let begin_statement = self.begin_statement(st_depth).await;
let commit_stmt = self.commit_statement(st_depth).await;
let rollback_stmt = self.rollback_statement(st_depth).await;



let opts = crate::connector::TransactionOptions::new(
isolation,
self.requires_isolation_first(),
depth,
commit_stmt,
rollback_stmt,
);

Ok(Box::new(
crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?,
crate::connector::DefaultTransaction::new(self, &begin_statement, opts).await?,
))
}
}
Expand Down
44 changes: 43 additions & 1 deletion quaint/src/connector/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{
use async_trait::async_trait;
use std::{convert::TryFrom, path::Path, time::Duration};
use tokio::sync::Mutex;
use std::sync::Arc;

pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main";

Expand All @@ -23,6 +24,7 @@ pub use rusqlite;
/// A connector interface for the SQLite database
pub struct Sqlite {
pub(crate) client: Mutex<rusqlite::Connection>,
transaction_depth: Arc<futures::lock::Mutex<i32>>,
}

/// Wraps a connection url and exposes the parsing logic used by Quaint,
Expand Down Expand Up @@ -139,7 +141,10 @@ impl TryFrom<&str> for Sqlite {

let client = Mutex::new(conn);

Ok(Sqlite { client })
Ok(Sqlite {
client,
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}
}

Expand All @@ -154,6 +159,7 @@ impl Sqlite {

Ok(Sqlite {
client: Mutex::new(client),
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}

Expand Down Expand Up @@ -250,6 +256,42 @@ impl Queryable for Sqlite {
fn requires_isolation_first(&self) -> bool {
false
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"BEGIN".to_string()
};

return ret
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret
}
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit b984ae4

Please sign in to comment.