From eb76cc894989949ef732f811dbb4ea8a2d9f06dd Mon Sep 17 00:00:00 2001 From: Lucian Buzzo Date: Tue, 17 Sep 2024 12:13:37 +0100 Subject: [PATCH] encapsulate transaction depth inside transaction instannce The depth tracking can be encapsulated entirely inside a transaction instance, simplifying the code significantly. --- quaint/src/connector/mssql/native/mod.rs | 13 +--- quaint/src/connector/mysql/native/mod.rs | 7 +- quaint/src/connector/postgres/native/mod.rs | 7 +- quaint/src/connector/queryable.rs | 6 +- quaint/src/connector/sqlite/native/mod.rs | 9 +-- quaint/src/connector/transaction.rs | 71 +++++++++++-------- quaint/src/pooled.rs | 5 +- quaint/src/pooled/manager.rs | 3 - quaint/src/single.rs | 9 +-- query-engine/driver-adapters/src/queryable.rs | 42 +++++------ .../driver-adapters/src/transaction.rs | 32 ++++----- 11 files changed, 80 insertions(+), 124 deletions(-) diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index 3579114a364..8eda2704674 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -18,10 +18,7 @@ use futures::lock::Mutex; use std::{ convert::TryFrom, future::Future, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::atomic::{AtomicBool, Ordering}, time::Duration, }; use tiberius::*; @@ -48,11 +45,7 @@ impl TransactionCapable for Mssql { .or(self.url.query_params.transaction_isolation_level) .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); - let opts = TransactionOptions::new( - isolation, - self.requires_isolation_first(), - self.transaction_depth.clone(), - ); + let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); Ok(Box::new(DefaultTransaction::new(self, opts).await?)) } @@ -65,7 +58,6 @@ pub struct Mssql { url: MssqlUrl, socket_timeout: Option, is_healthy: AtomicBool, - transaction_depth: Arc>, } impl Mssql { @@ -97,7 +89,6 @@ impl Mssql { url, socket_timeout, is_healthy: AtomicBool::new(true), - transaction_depth: Arc::new(Mutex::new(0)), }; if let Some(isolation) = this.url.transaction_isolation_level() { diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index 6465db6684e..fc0b9667e19 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -23,10 +23,7 @@ use mysql_async::{ }; use std::{ future::Future, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::atomic::{AtomicBool, Ordering}, time::Duration, }; use tokio::sync::Mutex; @@ -79,7 +76,6 @@ pub struct Mysql { socket_timeout: Option, is_healthy: AtomicBool, statement_cache: Mutex>, - transaction_depth: Arc>, } impl Mysql { @@ -93,7 +89,6 @@ impl Mysql { statement_cache: Mutex::new(url.cache()), url, is_healthy: AtomicBool::new(true), - transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 2a5fee9450f..825cf2cabf9 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -33,10 +33,7 @@ use std::{ fmt::{Debug, Display}, fs, future::Future, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::atomic::{AtomicBool, Ordering}, time::Duration, }; use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; @@ -64,7 +61,6 @@ pub struct PostgreSql { is_healthy: AtomicBool, is_cockroachdb: bool, is_materialize: bool, - transaction_depth: Arc>, } /// Key uniquely representing an SQL statement in the prepared statements cache. @@ -293,7 +289,6 @@ impl PostgreSql { is_healthy: AtomicBool::new(true), is_cockroachdb, is_materialize, - transaction_depth: Arc::new(Mutex::new(0)), }) } diff --git a/quaint/src/connector/queryable.rs b/quaint/src/connector/queryable.rs index 8aed583a7a0..1894176b0df 100644 --- a/quaint/src/connector/queryable.rs +++ b/quaint/src/connector/queryable.rs @@ -148,11 +148,7 @@ macro_rules! impl_default_TransactionCapable { &'a self, isolation: Option, ) -> crate::Result> { - let opts = crate::connector::TransactionOptions::new( - isolation, - self.requires_isolation_first(), - self.transaction_depth.clone(), - ); + let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first()); Ok(Box::new( crate::connector::DefaultTransaction::new(self, opts).await?, diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index 902c5fb91cb..5b3d2abbe0f 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -17,7 +17,7 @@ use crate::{ visitor::{self, Visitor}, }; use async_trait::async_trait; -use std::{convert::TryFrom, sync::Arc}; +use std::convert::TryFrom; use tokio::sync::Mutex; /// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. @@ -27,7 +27,6 @@ pub use rusqlite; /// A connector interface for the SQLite database pub struct Sqlite { pub(crate) client: Mutex, - transaction_depth: Arc>, } impl TryFrom<&str> for Sqlite { @@ -65,10 +64,7 @@ impl TryFrom<&str> for Sqlite { let client = Mutex::new(conn); - Ok(Sqlite { - client, - transaction_depth: Arc::new(futures::lock::Mutex::new(0)), - }) + Ok(Sqlite { client }) } } @@ -83,7 +79,6 @@ impl Sqlite { Ok(Sqlite { client: Mutex::new(client), - transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index 20e1ee6b029..42ce8120961 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -4,9 +4,12 @@ use crate::{ error::{Error, ErrorKind}, }; use async_trait::async_trait; -use futures::lock::Mutex; use metrics::{decrement_gauge, increment_gauge}; -use std::{fmt, str::FromStr, sync::Arc}; +use std::{ + fmt, + str::FromStr, + sync::{Arc, Mutex}, +}; extern crate metrics as metrics; @@ -31,9 +34,6 @@ pub(crate) struct TransactionOptions { /// Whether or not to put the isolation level `SET` before or after the `BEGIN`. pub(crate) isolation_first: bool, - - /// The depth of the transaction, used to determine the nested transaction statements. - pub depth: Arc>, } /// A default representation of an SQL database transaction. If not commited, a @@ -53,7 +53,7 @@ impl<'a> DefaultTransaction<'a> { ) -> crate::Result> { let mut this = Self { inner, - depth: tx_opts.depth, + depth: Arc::new(Mutex::new(0)), }; if tx_opts.isolation_first { @@ -81,14 +81,13 @@ impl<'a> Transaction for DefaultTransaction<'a> { async fn begin(&mut self) -> crate::Result<()> { increment_gauge!("prisma_client_queries_active", 1.0); - let mut depth_guard = self.depth.lock().await; - - // Modify the depth value through the MutexGuard - *depth_guard += 1; - - let st_depth = *depth_guard; + let current_depth = { + let mut depth = self.depth.lock().unwrap(); + *depth += 1; + *depth + }; - let begin_statement = self.inner.begin_statement(st_depth).await; + let begin_statement = self.inner.begin_statement(current_depth).await; self.inner.raw_cmd(&begin_statement).await?; @@ -99,36 +98,49 @@ impl<'a> Transaction for DefaultTransaction<'a> { async fn commit(&mut self) -> crate::Result { decrement_gauge!("prisma_client_queries_active", 1.0); - let mut depth_guard = self.depth.lock().await; - - let st_depth = *depth_guard; - - let commit_statement = self.inner.commit_statement(st_depth).await; + // Lock the mutex and get the depth value + let depth_val = { + let depth = self.depth.lock().unwrap(); + *depth + }; + // Perform the asynchronous operation without holding the lock + let commit_statement = self.inner.commit_statement(depth_val).await; self.inner.raw_cmd(&commit_statement).await?; - // Modify the depth value through the MutexGuard - *depth_guard -= 1; + // Lock the mutex again to modify the depth + let new_depth = { + let mut depth = self.depth.lock().unwrap(); + *depth -= 1; + *depth + }; - Ok(*depth_guard) + Ok(new_depth) } /// Rolls back the changes to the database. async fn rollback(&mut self) -> crate::Result { decrement_gauge!("prisma_client_queries_active", 1.0); - let mut depth_guard = self.depth.lock().await; - - let st_depth = *depth_guard; + // Lock the mutex and get the depth value + let depth_val = { + let depth = self.depth.lock().unwrap(); + *depth + }; - let rollback_statement = self.inner.rollback_statement(st_depth).await; + // Perform the asynchronous operation without holding the lock + let rollback_statement = self.inner.rollback_statement(depth_val).await; self.inner.raw_cmd(&rollback_statement).await?; - // Modify the depth value through the MutexGuard - *depth_guard -= 1; + // Lock the mutex again to modify the depth + let new_depth = { + let mut depth = self.depth.lock().unwrap(); + *depth -= 1; + *depth + }; - Ok(*depth_guard) + Ok(new_depth) } fn as_queryable(&self) -> &dyn Queryable { @@ -240,11 +252,10 @@ impl FromStr for IsolationLevel { } } impl TransactionOptions { - pub fn new(isolation_level: Option, isolation_first: bool, depth: Arc>) -> Self { + pub fn new(isolation_level: Option, isolation_first: bool) -> Self { Self { isolation_level, isolation_first, - depth, } } } diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 9bacf46d421..381f0c82414 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -507,10 +507,7 @@ impl Quaint { } }; - Ok(PooledConnection { - inner, - transaction_depth: Arc::new(futures::lock::Mutex::new(0)), - }) + Ok(PooledConnection { inner }) } /// Info about the connection and underlying database. diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index e20bf7a341a..5e96c3c51bd 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -10,15 +10,12 @@ use crate::{ error::Error, }; use async_trait::async_trait; -use futures::lock::Mutex; use mobc::{Connection as MobcPooled, Manager}; -use std::sync::Arc; /// A connection from the pool. Implements /// [Queryable](connector/trait.Queryable.html). pub struct PooledConnection { pub(crate) inner: MobcPooled, - pub transaction_depth: Arc>, } impl_default_TransactionCapable!(PooledConnection); diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 96806653867..128e63c5a6c 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -5,7 +5,6 @@ use crate::{ connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, }; use async_trait::async_trait; -use futures::lock::Mutex; use std::{fmt, sync::Arc}; #[cfg(feature = "sqlite-native")] @@ -19,7 +18,6 @@ use crate::connector::NativeConnectionInfo; pub struct Quaint { inner: Arc, connection_info: Arc, - transaction_depth: Arc>, } impl fmt::Debug for Quaint { @@ -167,11 +165,7 @@ impl Quaint { let connection_info = Arc::new(ConnectionInfo::from_url(url_str)?); Self::log_start(&connection_info); - Ok(Self { - inner, - connection_info, - transaction_depth: Arc::new(Mutex::new(0)), - }) + Ok(Self { inner, connection_info }) } #[cfg(feature = "sqlite-native")] @@ -184,7 +178,6 @@ impl Quaint { connection_info: Arc::new(ConnectionInfo::Native(NativeConnectionInfo::InMemorySqlite { db_name: DEFAULT_SQLITE_DATABASE.to_owned(), })), - transaction_depth: Arc::new(Mutex::new(0)), }) } diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index 06ad4b499ab..db5f9141577 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -5,7 +5,7 @@ use crate::JsObject; use super::conversion; use crate::send_future::UnsafeFuture; use async_trait::async_trait; -use futures::{lock::Mutex, Future}; +use futures::Future; use quaint::connector::{DescribedQuery, ExternalConnectionInfo, ExternalConnector}; use quaint::{ connector::{metrics, IsolationLevel, Transaction}, @@ -13,7 +13,6 @@ use quaint::{ prelude::{Query as QuaintQuery, Queryable as QuaintQueryable, ResultSet, TransactionCapable}, visitor::{self, Visitor}, }; -use std::sync::Arc; use tracing::{info_span, Instrument}; /// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the @@ -228,7 +227,6 @@ impl JsBaseQueryable { pub struct JsQueryable { inner: JsBaseQueryable, driver_proxy: DriverProxy, - pub transaction_depth: Arc>, } impl std::fmt::Display for JsQueryable { @@ -324,34 +322,29 @@ impl JsQueryable { } // 3. Spawn a transaction from the context. - let tx = tx_ctx.start_transaction().await?; + let mut tx = tx_ctx.start_transaction().await?; - { - let mut depth_guard = tx.depth.lock().await; - *depth_guard += 1; + tx.depth += 1; - let st_depth = *depth_guard; + let begin_stmt = tx.begin_statement(tx.depth).await; + let tx_opts = tx.options(); - let begin_stmt = tx.begin_statement(st_depth).await; - let tx_opts = tx.options(); - - if tx_opts.use_phantom_query { - let begin_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); - tx.raw_phantom_cmd(begin_stmt.as_str()).await?; - } else { - tx.raw_cmd(&begin_stmt).await?; - } + if tx_opts.use_phantom_query { + let begin_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); + tx.raw_phantom_cmd(begin_stmt.as_str()).await?; + } else { + tx.raw_cmd(&begin_stmt).await?; + } - // 4. Set the isolation level (if specified) if we didn't do it before. - if !requires_isolation_first { - if let Some(isolation) = isolation { - tx.set_tx_isolation_level(isolation).await?; - } + // 4. Set the isolation level (if specified) if we didn't do it before. + if !requires_isolation_first { + if let Some(isolation) = isolation { + tx.set_tx_isolation_level(isolation).await?; } - - self.server_reset_query(tx.as_ref()).await?; } + self.server_reset_query(tx.as_ref()).await?; + Ok(tx) } } @@ -373,6 +366,5 @@ pub fn from_js(driver: JsObject) -> JsQueryable { JsQueryable { inner: JsBaseQueryable::new(common), driver_proxy, - transaction_depth: Arc::new(futures::lock::Mutex::new(0)), } } diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs index 4d0ddb389a8..d6b2fe88be9 100644 --- a/query-engine/driver-adapters/src/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -1,14 +1,12 @@ use std::future::Future; use async_trait::async_trait; -use futures::lock::Mutex; use metrics::decrement_gauge; use quaint::{ connector::{DescribedQuery, IsolationLevel, Transaction as QuaintTransaction}, prelude::{Query as QuaintQuery, Queryable, ResultSet}, Value, }; -use std::sync::Arc; use crate::proxy::{TransactionContextProxy, TransactionOptions, TransactionProxy}; use crate::{proxy::CommonProxy, queryable::JsBaseQueryable, send_future::UnsafeFuture}; @@ -88,7 +86,7 @@ impl Queryable for JsTransactionContext { pub(crate) struct JsTransaction { tx_proxy: TransactionProxy, inner: JsBaseQueryable, - pub depth: Arc>, + pub depth: i32, } impl JsTransaction { @@ -96,7 +94,7 @@ impl JsTransaction { Self { inner, tx_proxy, - depth: Arc::new(futures::lock::Mutex::new(0)), + depth: 0, } } @@ -116,11 +114,9 @@ impl QuaintTransaction for JsTransaction { // increment of this gauge is done in DriverProxy::startTransaction decrement_gauge!("prisma_client_queries_active", 1.0); - let mut depth_guard = self.depth.lock().await; - // Modify the depth value through the MutexGuard - *depth_guard += 1; + self.depth += 1; - let begin_stmt = self.begin_statement(*depth_guard).await; + let begin_stmt = self.begin_statement(self.depth).await; if self.options().use_phantom_query { let commit_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); @@ -129,7 +125,7 @@ impl QuaintTransaction for JsTransaction { self.inner.raw_cmd(&begin_stmt).await?; } - println!("JsTransaction begin: incrementing depth_guard to: {}", *depth_guard); + println!("JsTransaction begin: incrementing depth_guard to: {}", self.depth); UnsafeFuture(self.tx_proxy.begin()).await } @@ -137,8 +133,7 @@ impl QuaintTransaction for JsTransaction { // increment of this gauge is done in DriverProxy::startTransaction decrement_gauge!("prisma_client_queries_active", 1.0); - let mut depth_guard = self.depth.lock().await; - let commit_stmt = self.commit_statement(*depth_guard).await; + let commit_stmt = self.commit_statement(self.depth).await; if self.options().use_phantom_query { let commit_stmt = JsBaseQueryable::phantom_query_message(&commit_stmt); @@ -147,20 +142,19 @@ impl QuaintTransaction for JsTransaction { self.inner.raw_cmd(&commit_stmt).await?; } - // Modify the depth value through the MutexGuard - *depth_guard -= 1; + // Modify the depth value + self.depth -= 1; let _ = UnsafeFuture(self.tx_proxy.commit()).await; - Ok(*depth_guard) + Ok(self.depth) } async fn rollback(&mut self) -> quaint::Result { // increment of this gauge is done in DriverProxy::startTransaction decrement_gauge!("prisma_client_queries_active", 1.0); - let mut depth_guard = self.depth.lock().await; - let rollback_stmt = self.rollback_statement(*depth_guard).await; + let rollback_stmt = self.rollback_statement(self.depth).await; if self.options().use_phantom_query { let rollback_stmt = JsBaseQueryable::phantom_query_message(&rollback_stmt); @@ -169,12 +163,12 @@ impl QuaintTransaction for JsTransaction { self.inner.raw_cmd(&rollback_stmt).await?; } - // Modify the depth value through the MutexGuard - *depth_guard -= 1; + // Modify the depth value + self.depth -= 1; let _ = UnsafeFuture(self.tx_proxy.rollback()).await; - Ok(*depth_guard) + Ok(self.depth) } fn as_queryable(&self) -> &dyn Queryable {