Skip to content

Commit

Permalink
Merge pull request #1900 from fermyon/mysql-v2
Browse files Browse the repository at this point in the history
Move mysql interface to resources
  • Loading branch information
rylev authored Oct 18, 2023
2 parents 7954a73 + 8d30055 commit d1d8184
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 87 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions crates/outbound-mysql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ doctest = false
anyhow = "1.0"
flate2 = "1.0.17"
# Removing default features for mysql_async to remove flate2/zlib feature
mysql_async = { version = "0.32.2", default-features = false, features = ["native-tls-tls"] }
mysql_async = { version = "0.32.2", default-features = false, features = [
"native-tls-tls",
] }
# Removing default features for mysql_common to remove flate2/zlib feature
mysql_common = { version = "0.30.6", default-features = false }
spin-core = { path = "../core" }
spin-world = { path = "../world" }
tokio = { version = "1", features = [ "rt-multi-thread" ] }
tracing = { version = "0.1", features = [ "log" ] }
table = { path = "../table" }
tokio = { version = "1", features = ["rt-multi-thread"] }
tracing = { version = "0.1", features = ["log"] }
url = "2.3.1"
190 changes: 121 additions & 69 deletions crates/outbound-mysql/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
use anyhow::Result;
pub use mysql::add_to_linker;
use mysql_async::{consts::ColumnType, from_value_opt, prelude::*, Opts, OptsBuilder, SslOpts};
use spin_core::wasmtime::component::Resource;
use spin_core::{async_trait, HostComponent};
use spin_world::v1::{
mysql::{self, MysqlError},
rdbms_types::{Column, DbDataType, DbValue, ParameterValue, RowSet},
};
use std::collections::HashMap;
use spin_world::v1::mysql as v1;
use spin_world::v1::rdbms_types::{Column, DbDataType, DbValue, ParameterValue, RowSet};
use spin_world::v2::mysql::{self as v2, Connection};
use std::sync::Arc;
use url::Url;

/// A simple implementation to support outbound mysql connection
#[derive(Default)]
pub struct OutboundMysql {
pub connections: HashMap<String, mysql_async::Conn>,
pub connections: table::Table<mysql_async::Conn>,
}

impl OutboundMysql {
async fn get_conn(
&mut self,
connection: Resource<Connection>,
) -> Result<&mut mysql_async::Conn, v2::Error> {
self.connections
.get_mut(connection.rep())
.ok_or_else(|| v2::Error::ConnectionFailed("no connection found".into()))
}
}

impl HostComponent for OutboundMysql {
Expand All @@ -23,37 +32,48 @@ impl HostComponent for OutboundMysql {
linker: &mut spin_core::Linker<T>,
get: impl Fn(&mut spin_core::Data<T>) -> &mut Self::Data + Send + Sync + Copy + 'static,
) -> anyhow::Result<()> {
mysql::add_to_linker(linker, get)
v2::add_to_linker(linker, get)?;
v1::add_to_linker(linker, get)
}

fn build_data(&self) -> Self::Data {
Default::default()
}
}

impl v2::Host for OutboundMysql {}

#[async_trait]
impl mysql::Host for OutboundMysql {
impl v2::HostConnection for OutboundMysql {
async fn open(&mut self, address: String) -> Result<Result<Resource<Connection>, v2::Error>> {
Ok(async {
self.connections
.push(
build_conn(&address)
.await
.map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?,
)
.map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))
.map(Resource::new_own)
}
.await)
}

async fn execute(
&mut self,
address: String,
connection: Resource<Connection>,
statement: String,
params: Vec<ParameterValue>,
) -> Result<Result<(), MysqlError>> {
) -> Result<Result<(), v2::Error>> {
Ok(async {
let db_params = params
.iter()
.map(to_sql_parameter)
.collect::<anyhow::Result<Vec<_>>>()
.map_err(|e| MysqlError::QueryFailed(format!("{:?}", e)))?;

let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
let parameters = mysql_async::Params::Positional(db_params);

self.get_conn(&address)
.await
.map_err(|e| MysqlError::ConnectionFailed(format!("{:?}", e)))?
self.get_conn(connection)
.await?
.exec_batch(&statement, &[parameters])
.await
.map_err(|e| MysqlError::QueryFailed(format!("{:?}", e)))?;
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;

Ok(())
}
Expand All @@ -62,63 +82,105 @@ impl mysql::Host for OutboundMysql {

async fn query(
&mut self,
address: String,
connection: Resource<Connection>,
statement: String,
params: Vec<ParameterValue>,
) -> Result<Result<RowSet, MysqlError>> {
) -> Result<Result<RowSet, v2::Error>> {
Ok(async {
let db_params = params
.iter()
.map(to_sql_parameter)
.collect::<anyhow::Result<Vec<_>>>()
.map_err(|e| MysqlError::QueryFailed(format!("{:?}", e)))?;

let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
let parameters = mysql_async::Params::Positional(db_params);

let mut query_result = self
.get_conn(&address)
.await
.map_err(|e| MysqlError::ConnectionFailed(format!("{:?}", e)))?
.get_conn(connection)
.await?
.exec_iter(&statement, parameters)
.await
.map_err(|e| MysqlError::QueryFailed(format!("{:?}", e)))?;
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;

// We have to get these before collect() destroys them
let columns = convert_columns(query_result.columns());

match query_result.collect::<mysql_async::Row>().await {
Err(e) => Err(MysqlError::OtherError(format!("{:?}", e))),
Err(e) => Err(v2::Error::Other(e.to_string())),
Ok(result_set) => {
let rows = result_set
.into_iter()
.map(|row| convert_row(row, &columns))
.collect::<Result<Vec<_>, _>>()
.map_err(|e| MysqlError::QueryFailed(format!("{:?}", e)))?;
.collect::<Result<Vec<_>, _>>()?;

Ok(RowSet { columns, rows })
}
}
}
.await)
}

fn drop(&mut self, connection: Resource<Connection>) -> Result<()> {
self.connections.remove(connection.rep());
Ok(())
}
}

/// Delegate a function call to the v2::HostConnection implementation
macro_rules! delegate {
($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
let connection = match <Self as v2::HostConnection>::open($self, $address).await? {
Ok(c) => c,
Err(e) => return Ok(Err(to_legacy_error(e))),
};
Ok(<Self as v2::HostConnection>::$name($self, connection, $($arg),*)
.await?
.map_err(|e| to_legacy_error(e)))
}};
}

#[async_trait]
impl v1::Host for OutboundMysql {
async fn execute(
&mut self,
address: String,
statement: String,
params: Vec<ParameterValue>,
) -> Result<Result<(), v1::MysqlError>> {
delegate!(self.execute(address, statement, params))
}

async fn query(
&mut self,
address: String,
statement: String,
params: Vec<ParameterValue>,
) -> Result<Result<RowSet, v1::MysqlError>> {
delegate!(self.query(address, statement, params))
}
}

fn to_legacy_error(error: v2::Error) -> v1::MysqlError {
match error {
v2::Error::ConnectionFailed(e) => v1::MysqlError::ConnectionFailed(e),
v2::Error::BadParameter(e) => v1::MysqlError::BadParameter(e),
v2::Error::QueryFailed(e) => v1::MysqlError::QueryFailed(e),
v2::Error::ValueConversionFailed(e) => v1::MysqlError::ValueConversionFailed(e),
v2::Error::Other(e) => v1::MysqlError::OtherError(e),
}
}

fn to_sql_parameter(value: &ParameterValue) -> anyhow::Result<mysql_async::Value> {
fn to_sql_parameter(value: ParameterValue) -> mysql_async::Value {
match value {
ParameterValue::Boolean(v) => Ok(mysql_async::Value::from(v)),
ParameterValue::Int32(v) => Ok(mysql_async::Value::from(v)),
ParameterValue::Int64(v) => Ok(mysql_async::Value::from(v)),
ParameterValue::Int8(v) => Ok(mysql_async::Value::from(v)),
ParameterValue::Int16(v) => Ok(mysql_async::Value::from(v)),
ParameterValue::Floating32(v) => Ok(mysql_async::Value::from(v)),
ParameterValue::Floating64(v) => Ok(mysql_async::Value::from(v)),
ParameterValue::Uint8(v) => Ok(mysql_async::Value::from(v)),
ParameterValue::Uint16(v) => Ok(mysql_async::Value::from(v)),
ParameterValue::Uint32(v) => Ok(mysql_async::Value::from(v)),
ParameterValue::Uint64(v) => Ok(mysql_async::Value::from(v)),
ParameterValue::Str(v) => Ok(mysql_async::Value::from(v)),
ParameterValue::Binary(v) => Ok(mysql_async::Value::from(v)),
ParameterValue::DbNull => Ok(mysql_async::Value::NULL),
ParameterValue::Boolean(v) => mysql_async::Value::from(v),
ParameterValue::Int32(v) => mysql_async::Value::from(v),
ParameterValue::Int64(v) => mysql_async::Value::from(v),
ParameterValue::Int8(v) => mysql_async::Value::from(v),
ParameterValue::Int16(v) => mysql_async::Value::from(v),
ParameterValue::Floating32(v) => mysql_async::Value::from(v),
ParameterValue::Floating64(v) => mysql_async::Value::from(v),
ParameterValue::Uint8(v) => mysql_async::Value::from(v),
ParameterValue::Uint16(v) => mysql_async::Value::from(v),
ParameterValue::Uint32(v) => mysql_async::Value::from(v),
ParameterValue::Uint64(v) => mysql_async::Value::from(v),
ParameterValue::Str(v) => mysql_async::Value::from(v),
ParameterValue::Binary(v) => mysql_async::Value::from(v),
ParameterValue::DbNull => mysql_async::Value::NULL,
}
}

Expand All @@ -130,7 +192,7 @@ fn convert_columns(columns: Option<Arc<[mysql_async::Column]>>) -> Vec<Column> {
}

fn convert_column(column: &mysql_async::Column) -> Column {
let name = column.name_str().to_string();
let name = column.name_str().into_owned();
let data_type = convert_data_type(column);

Column { name, data_type }
Expand Down Expand Up @@ -192,7 +254,7 @@ fn is_binary(column: &mysql_async::Column) -> bool {
.contains(mysql_async::consts::ColumnFlags::BINARY_FLAG)
}

fn convert_row(mut row: mysql_async::Row, columns: &[Column]) -> Result<Vec<DbValue>, MysqlError> {
fn convert_row(mut row: mysql_async::Row, columns: &[Column]) -> Result<Vec<DbValue>, v2::Error> {
let mut result = Vec::with_capacity(row.len());
for index in 0..row.len() {
result.push(convert_entry(&mut row, index, columns)?);
Expand All @@ -204,10 +266,10 @@ fn convert_entry(
row: &mut mysql_async::Row,
index: usize,
columns: &[Column],
) -> Result<DbValue, MysqlError> {
) -> Result<DbValue, v2::Error> {
match (row.take(index), columns.get(index)) {
(None, _) => Ok(DbValue::DbNull), // TODO: is this right or is this an "index out of range" thing
(_, None) => Err(MysqlError::OtherError(format!(
(_, None) => Err(v2::Error::Other(format!(
"Can't get column at index {}",
index
))),
Expand All @@ -216,7 +278,7 @@ fn convert_entry(
}
}

fn convert_value(value: mysql_async::Value, column: &Column) -> Result<DbValue, MysqlError> {
fn convert_value(value: mysql_async::Value, column: &Column) -> Result<DbValue, v2::Error> {
match column.data_type {
DbDataType::Binary => convert_value_to::<Vec<u8>>(value).map(DbValue::Binary),
DbDataType::Boolean => convert_value_to::<bool>(value).map(DbValue::Boolean),
Expand All @@ -231,23 +293,13 @@ fn convert_value(value: mysql_async::Value, column: &Column) -> Result<DbValue,
DbDataType::Uint16 => convert_value_to::<u16>(value).map(DbValue::Uint16),
DbDataType::Uint32 => convert_value_to::<u32>(value).map(DbValue::Uint32),
DbDataType::Uint64 => convert_value_to::<u64>(value).map(DbValue::Uint64),
DbDataType::Other => Err(MysqlError::ValueConversionFailed(format!(
DbDataType::Other => Err(v2::Error::ValueConversionFailed(format!(
"Cannot convert value {:?} in column {} data type {:?}",
value, column.name, column.data_type
))),
}
}

impl OutboundMysql {
async fn get_conn(&mut self, address: &str) -> anyhow::Result<&mut mysql_async::Conn> {
let client = match self.connections.entry(address.to_owned()) {
std::collections::hash_map::Entry::Occupied(o) => o.into_mut(),
std::collections::hash_map::Entry::Vacant(v) => v.insert(build_conn(address).await?),
};
Ok(client)
}
}

async fn build_conn(address: &str) -> Result<mysql_async::Conn, mysql_async::Error> {
tracing::log::debug!("Build new connection: {}", address);

Expand Down Expand Up @@ -295,8 +347,8 @@ fn build_opts(address: &str) -> Result<Opts, mysql_async::Error> {
.into())
}

fn convert_value_to<T: FromValue>(value: mysql_async::Value) -> Result<T, MysqlError> {
from_value_opt::<T>(value).map_err(|e| MysqlError::ValueConversionFailed(format!("{}", e)))
fn convert_value_to<T: FromValue>(value: mysql_async::Value) -> Result<T, v2::Error> {
from_value_opt::<T>(value).map_err(|e| v2::Error::ValueConversionFailed(format!("{}", e)))
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion crates/outbound-pg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl v2::HostConnection for OutboundPg {
.await
.map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?,
)
.map_err(|_| v2::Error::Other("too many connections".into()))
.map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))
.map(Resource::new_own)
}
.await)
Expand Down
Loading

0 comments on commit d1d8184

Please sign in to comment.