Skip to content

Commit

Permalink
factors: Make spin-factor-outbound-mysql generic across clients and w…
Browse files Browse the repository at this point in the history
…rite tests

Signed-off-by: Caleb Schoepp <[email protected]>
Co-authored-by: Karthik Ganeshram <[email protected]>
  • Loading branch information
calebschoepp and karthik2804 committed Aug 19, 2024
1 parent 6fea860 commit e01162d
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 464 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions crates/factor-outbound-mysql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,9 @@ tokio = { version = "1", features = ["rt-multi-thread"] }
tracing = { version = "0.1", features = ["log"] }
url = "2.3.1"

[dev-dependencies]
spin-factor-variables = { path = "../factor-variables" }
spin-factors-test = { path = "../factors-test" }

[lints]
workspace = true
Original file line number Diff line number Diff line change
@@ -1,133 +1,72 @@
use anyhow::{Context, Result};
use mysql_async::{consts::ColumnType, from_value_opt, prelude::*, Opts, OptsBuilder, SslOpts};
use spin_app::DynamicHostComponent;
use spin_core::wasmtime::component::Resource;
use spin_core::{async_trait, HostComponent};
use spin_world::v1::mysql as v1;
use spin_world::v2::mysql::{self as v2, Connection};
use spin_world::v2::rdbms_types as v2_types;
use spin_world::v2::rdbms_types::{Column, DbDataType, DbValue, ParameterValue};
use std::sync::Arc;
use tracing::{instrument, Level};
use url::Url;

/// A simple implementation to support outbound mysql connection
pub struct OutboundMysqlComponent {
pub resolver: spin_expressions::SharedPreparedResolver,
}

#[derive(Default)]
pub struct OutboundMysql {
allowed_hosts: spin_outbound_networking::AllowedHostsConfig,
pub connections: table::Table<mysql_async::Conn>,
}
use anyhow::{anyhow, Result};
use mysql_async::consts::ColumnType;
use mysql_async::prelude::{FromValue, Queryable as _};
use mysql_async::{from_value_opt, Conn as MysqlClient, Opts, OptsBuilder, SslOpts};
use spin_core::async_trait;
use spin_world::v2::mysql::{self as v2};
use spin_world::v2::rdbms_types::{
self as v2_types, Column, DbDataType, DbValue, ParameterValue, RowSet,
};
use url::Url;

impl OutboundMysql {
async fn open_connection(&mut self, address: &str) -> Result<Resource<Connection>, v2::Error> {
self.connections
.push(
build_conn(address)
.await
.map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?,
)
.map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))
.map(Resource::new_own)
}
#[async_trait]
pub trait Client {
async fn build_client(address: &str) -> Result<Self>
where
Self: Sized;

async fn get_conn(
async fn execute(
&mut self,
connection: Resource<Connection>,
) -> Result<&mut mysql_async::Conn, v2::Error> {
self.connections
.get_mut(connection.rep())
.ok_or_else(|| v2::Error::ConnectionFailed("no connection found".into()))
}
statement: String,
params: Vec<ParameterValue>,
) -> Result<(), v2::Error>;

fn is_address_allowed(&self, address: &str) -> bool {
spin_outbound_networking::check_url(address, "mysql", &self.allowed_hosts)
}
async fn query(
&mut self,
statement: String,
params: Vec<ParameterValue>,
) -> Result<RowSet, v2::Error>;
}

impl HostComponent for OutboundMysqlComponent {
type Data = OutboundMysql;

fn add_to_linker<T: Send>(
linker: &mut spin_core::Linker<T>,
get: impl Fn(&mut spin_core::Data<T>) -> &mut Self::Data + Send + Sync + Copy + 'static,
) -> anyhow::Result<()> {
v2::add_to_linker(linker, get)?;
v1::add_to_linker(linker, get)
}

fn build_data(&self) -> Self::Data {
Default::default()
}
}
#[async_trait]
impl Client for MysqlClient {
async fn build_client(address: &str) -> Result<Self>
where
Self: Sized,
{
tracing::debug!("Build new connection: {}", address);

impl DynamicHostComponent for OutboundMysqlComponent {
fn update_data(
&self,
data: &mut Self::Data,
component: &spin_app::AppComponent,
) -> anyhow::Result<()> {
let hosts = component
.get_metadata(spin_outbound_networking::ALLOWED_HOSTS_KEY)?
.unwrap_or_default();
data.allowed_hosts = spin_outbound_networking::AllowedHostsConfig::parse(
&hosts,
self.resolver.get().unwrap(),
)
.context("`allowed_outbound_hosts` contained an invalid url")?;
Ok(())
}
}
let opts = build_opts(address)?;

impl v2::Host for OutboundMysql {}
let connection_pool = mysql_async::Pool::new(opts);

#[async_trait]
impl v2::HostConnection for OutboundMysql {
#[instrument(name = "spin_outbound_mysql.open_connection", skip(self), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql"))]
async fn open(&mut self, address: String) -> Result<Resource<Connection>, v2::Error> {
if !self.is_address_allowed(&address) {
return Err(v2::Error::ConnectionFailed(format!(
"address {address} is not permitted"
)));
}
self.open_connection(&address).await
connection_pool.get_conn().await.map_err(|e| anyhow!(e))
}

#[instrument(name = "spin_outbound_mysql.execute", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
async fn execute(
&mut self,
connection: Resource<Connection>,
statement: String,
params: Vec<ParameterValue>,
) -> Result<(), v2::Error> {
let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
let parameters = mysql_async::Params::Positional(db_params);

self.get_conn(connection)
.await?
.exec_batch(&statement, &[parameters])
self.exec_batch(&statement, &[parameters])
.await
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;

Ok(())
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))
}

#[instrument(name = "spin_outbound_mysql.query", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
async fn query(
&mut self,
connection: Resource<Connection>,
statement: String,
params: Vec<ParameterValue>,
) -> Result<v2_types::RowSet, v2::Error> {
) -> Result<RowSet, v2::Error> {
let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
let parameters = mysql_async::Params::Positional(db_params);

let mut query_result = self
.get_conn(connection)
.await?
.exec_iter(&statement, parameters)
.await
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;
Expand All @@ -147,69 +86,6 @@ impl v2::HostConnection for OutboundMysql {
}
}
}

fn drop(&mut self, connection: Resource<Connection>) -> Result<()> {
self.connections.remove(connection.rep());
Ok(())
}
}

impl v2_types::Host for OutboundMysql {
fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
Ok(error)
}
}

/// Delegate a function call to the v2::HostConnection implementation
macro_rules! delegate {
($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
if !$self.is_address_allowed(&$address) {
return Err(v1::MysqlError::ConnectionFailed(format!(
"address {} is not permitted", $address
)));
}
let connection = match $self.open_connection(&$address).await {
Ok(c) => c,
Err(e) => return Err(e.into()),
};
<Self as v2::HostConnection>::$name($self, connection, $($arg),*)
.await
.map_err(Into::into)
}};
}

#[async_trait]
impl v1::Host for OutboundMysql {
async fn execute(
&mut self,
address: String,
statement: String,
params: Vec<v1::ParameterValue>,
) -> Result<(), v1::MysqlError> {
delegate!(self.execute(
address,
statement,
params.into_iter().map(Into::into).collect()
))
}

async fn query(
&mut self,
address: String,
statement: String,
params: Vec<v1::ParameterValue>,
) -> Result<v1::RowSet, v1::MysqlError> {
delegate!(self.query(
address,
statement,
params.into_iter().map(Into::into).collect()
))
.map(Into::into)
}

fn convert_mysql_error(&mut self, error: v1::MysqlError) -> Result<v1::MysqlError> {
Ok(error)
}
}

fn to_sql_parameter(value: ParameterValue) -> mysql_async::Value {
Expand Down Expand Up @@ -347,16 +223,6 @@ fn convert_value(value: mysql_async::Value, column: &Column) -> Result<DbValue,
}
}

async fn build_conn(address: &str) -> Result<mysql_async::Conn, mysql_async::Error> {
tracing::debug!("Build new connection: {}", address);

let opts = build_opts(address)?;

let connection_pool = mysql_async::Pool::new(opts);

connection_pool.get_conn().await
}

fn is_ssl_param(s: &str) -> bool {
["ssl-mode", "sslmode"].contains(&s.to_lowercase().as_str())
}
Expand Down
Loading

0 comments on commit e01162d

Please sign in to comment.