Skip to content

Commit

Permalink
Merge pull request #1886 from fermyon/postgres-resources
Browse files Browse the repository at this point in the history
Use resources for Postgres API
  • Loading branch information
rylev authored Oct 16, 2023
2 parents 88373dd + 4573dce commit e2c844a
Show file tree
Hide file tree
Showing 12 changed files with 189 additions and 86 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion crates/outbound-pg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ native-tls = "0.2.11"
postgres-native-tls = "0.5.0"
spin-core = { path = "../core" }
spin-world = { path = "../world" }
tokio = { version = "1", features = [ "rt-multi-thread" ] }
table = { path = "../table" }
tokio = { version = "1", features = ["rt-multi-thread"] }
tokio-postgres = { version = "0.7.7" }
tracing = { workspace = true }
125 changes: 94 additions & 31 deletions crates/outbound-pg/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use anyhow::{anyhow, Result};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use spin_core::{async_trait, HostComponent};
use spin_core::{async_trait, wasmtime::component::Resource, HostComponent};
use spin_world::v1::{
postgres::{self, PgError},
postgres as v1,
rdbms_types::{Column, DbDataType, DbValue, ParameterValue, RowSet},
};
use std::collections::HashMap;
use spin_world::v2::postgres::{self as v2, Connection};
use tokio_postgres::{
config::SslMode,
types::{ToSql, Type},
Expand All @@ -16,7 +16,15 @@ use tokio_postgres::{
/// A simple implementation to support outbound pg connection
#[derive(Default)]
pub struct OutboundPg {
pub connections: HashMap<String, Client>,
pub connections: table::Table<Client>,
}

impl OutboundPg {
async fn get_client(&mut self, connection: Resource<Connection>) -> Result<&Client, v2::Error> {
self.connections
.get(connection.rep())
.ok_or_else(|| v2::Error::ConnectionFailed("no connection found".into()))
}
}

impl HostComponent for OutboundPg {
Expand All @@ -26,7 +34,8 @@ impl HostComponent for OutboundPg {
linker: &mut spin_core::Linker<T>,
get: impl Fn(&mut spin_core::Data<T>) -> &mut Self::Data + Send + Sync + Copy + 'static,
) -> anyhow::Result<()> {
postgres::add_to_linker(linker, get)
v1::add_to_linker(linker, get)?;
v2::add_to_linker(linker, get)
}

fn build_data(&self) -> Self::Data {
Expand All @@ -35,27 +44,43 @@ impl HostComponent for OutboundPg {
}

#[async_trait]
impl postgres::Host for OutboundPg {
impl v2::Host for OutboundPg {}

#[async_trait]
impl v2::HostConnection for OutboundPg {
async fn open(&mut self, address: String) -> Result<Result<Resource<Connection>, v2::Error>> {
Ok(async {
self.connections
.push(
build_client(&address)
.await
.map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?,
)
.map_err(|_| v2::Error::Other("too many connections".into()))
.map(Resource::new_own)
}
.await)
}

async fn execute(
&mut self,
address: String,
connection: Resource<Connection>,
statement: String,
params: Vec<ParameterValue>,
) -> Result<Result<u64, PgError>> {
) -> Result<Result<u64, v2::Error>> {
Ok(async {
let params: Vec<&(dyn ToSql + Sync)> = params
.iter()
.map(to_sql_parameter)
.collect::<anyhow::Result<Vec<_>>>()
.map_err(|e| PgError::ValueConversionFailed(format!("{:?}", e)))?;
.map_err(|e| v2::Error::ValueConversionFailed(format!("{:?}", e)))?;

let nrow = self
.get_client(&address)
.await
.map_err(|e| PgError::ConnectionFailed(format!("{:?}", e)))?
.get_client(connection)
.await?
.execute(&statement, params.as_slice())
.await
.map_err(|e| PgError::QueryFailed(format!("{:?}", e)))?;
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;

Ok(nrow)
}
Expand All @@ -64,24 +89,23 @@ impl postgres::Host for OutboundPg {

async fn query(
&mut self,
address: String,
connection: Resource<Connection>,
statement: String,
params: Vec<ParameterValue>,
) -> Result<Result<RowSet, PgError>> {
) -> Result<Result<RowSet, v2::Error>> {
Ok(async {
let params: Vec<&(dyn ToSql + Sync)> = params
.iter()
.map(to_sql_parameter)
.collect::<anyhow::Result<Vec<_>>>()
.map_err(|e| PgError::BadParameter(format!("{:?}", e)))?;
.map_err(|e| v2::Error::BadParameter(format!("{:?}", e)))?;

let results = self
.get_client(&address)
.await
.map_err(|e| PgError::ConnectionFailed(format!("{:?}", e)))?
.get_client(connection)
.await?
.query(&statement, params.as_slice())
.await
.map_err(|e| PgError::QueryFailed(format!("{:?}", e)))?;
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;

if results.is_empty() {
return Ok(RowSet {
Expand All @@ -95,12 +119,17 @@ impl postgres::Host for OutboundPg {
.iter()
.map(convert_row)
.collect::<Result<Vec<_>, _>>()
.map_err(|e| PgError::QueryFailed(format!("{:?}", e)))?;
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;

Ok(RowSet { columns, rows })
}
.await)
}

fn drop(&mut self, connection: Resource<Connection>) -> anyhow::Result<()> {
self.connections.remove(connection.rep());
Ok(())
}
}

fn to_sql_parameter(value: &ParameterValue) -> anyhow::Result<&(dyn ToSql + Sync)> {
Expand Down Expand Up @@ -233,16 +262,6 @@ fn convert_entry(row: &Row, index: usize) -> Result<DbValue, tokio_postgres::Err
Ok(value)
}

impl OutboundPg {
async fn get_client(&mut self, address: &str) -> anyhow::Result<&Client> {
let client = match self.connections.entry(address.to_owned()) {
std::collections::hash_map::Entry::Occupied(o) => o.into_mut(),
std::collections::hash_map::Entry::Vacant(v) => v.insert(build_client(address).await?),
};
Ok(client)
}
}

async fn build_client(address: &str) -> anyhow::Result<Client> {
let config = address.parse::<tokio_postgres::Config>()?;

Expand Down Expand Up @@ -325,3 +344,47 @@ impl std::fmt::Debug for PgNull {
f.debug_struct("NULL").finish()
}
}

/// Delegate a function call to the v2::HostConnection implementation
macro_rules! delegate {
($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
let connection = match <Self as v2::HostConnection>::open($self, $address).await? {
Ok(c) => c,
Err(e) => return Ok(Err(to_legacy_error(e))),
};
Ok(<Self as v2::HostConnection>::$name($self, connection, $($arg),*)
.await?
.map_err(|e| to_legacy_error(e)))
}};
}

#[async_trait]
impl v1::Host for OutboundPg {
async fn execute(
&mut self,
address: String,
statement: String,
params: Vec<ParameterValue>,
) -> Result<Result<u64, v1::PgError>> {
delegate!(self.execute(address, statement, params))
}

async fn query(
&mut self,
address: String,
statement: String,
params: Vec<ParameterValue>,
) -> Result<Result<RowSet, v1::PgError>> {
delegate!(self.query(address, statement, params))
}
}

fn to_legacy_error(error: v2::Error) -> v1::PgError {
match error {
v2::Error::ConnectionFailed(e) => v1::PgError::ConnectionFailed(e),
v2::Error::BadParameter(e) => v1::PgError::BadParameter(e),
v2::Error::QueryFailed(e) => v1::PgError::QueryFailed(e),
v2::Error::ValueConversionFailed(e) => v1::PgError::ValueConversionFailed(e),
v2::Error::Other(e) => v1::PgError::OtherError(e),
}
}
2 changes: 0 additions & 2 deletions crates/table/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,4 @@ version.workspace = true
authors.workspace = true
edition.workspace = true

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
6 changes: 6 additions & 0 deletions crates/table/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ pub struct Table<V> {
tuples: HashMap<u32, V>,
}

impl<V> Default for Table<V> {
fn default() -> Self {
Self::new(1024)
}
}

impl<V> Table<V> {
/// Create a new, empty table with the specified capacity.
pub fn new(capacity: u32) -> Self {
Expand Down
11 changes: 7 additions & 4 deletions examples/rust-outbound-pg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ fn process(req: Request) -> Result<Response> {

fn read(_req: Request) -> Result<Response> {
let address = std::env::var(DB_URL_ENV)?;
let conn = pg::Connection::open(&address)?;

let sql = "SELECT id, title, content, authorname, coauthor FROM articletest";
let rowset = pg::query(&address, sql, &[])?;
let rowset = conn.query(sql, &[])?;

let column_summary = rowset
.columns
Expand Down Expand Up @@ -89,14 +90,15 @@ fn read(_req: Request) -> Result<Response> {

fn write(_req: Request) -> Result<Response> {
let address = std::env::var(DB_URL_ENV)?;
let conn = pg::Connection::open(&address)?;

let sql = "INSERT INTO articletest (title, content, authorname) VALUES ('aaa', 'bbb', 'ccc')";
let nrow_executed = pg::execute(&address, sql, &[])?;
let nrow_executed = conn.execute(sql, &[])?;

println!("nrow_executed: {}", nrow_executed);

let sql = "SELECT COUNT(id) FROM articletest";
let rowset = pg::query(&address, sql, &[])?;
let rowset = conn.query(sql, &[])?;
let row = &rowset.rows[0];
let count = i64::decode(&row[0])?;
let response = format!("Count: {}\n", count);
Expand All @@ -108,10 +110,11 @@ fn write(_req: Request) -> Result<Response> {

fn pg_backend_pid(_req: Request) -> Result<Response> {
let address = std::env::var(DB_URL_ENV)?;
let conn = pg::Connection::open(&address)?;
let sql = "SELECT pg_backend_pid()";

let get_pid = || {
let rowset = pg::query(&address, sql, &[])?;
let rowset = conn.query(sql, &[])?;
let row = &rowset.rows[0];

i32::decode(&row[0])
Expand Down
3 changes: 2 additions & 1 deletion examples/spin-timer/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions sdk/rust/src/pg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
//! | `String` | str(string) | VARCHAR, CHAR(N), TEXT |
//! | `Vec<u8>` | binary(list\<u8\>) | BYTEA |
pub use super::wit::v1::postgres::{execute, query, PgError};
#[doc(inline)]
pub use super::wit::v1::rdbms_types::*;
#[doc(inline)]
pub use super::wit::v2::postgres::{Connection, Error as PgError};

/// A pg error
#[derive(Debug, thiserror::Error)]
Expand All @@ -23,7 +25,7 @@ pub enum Error {
#[error("error value decoding: {0}")]
Decode(String),
/// Pg query failed with an error
#[error("{0}")]
#[error(transparent)]
PgError(#[from] PgError),
}

Expand Down
Loading

0 comments on commit e2c844a

Please sign in to comment.