From 8f64921843d4cb27ad0e88b86a789869c369975a Mon Sep 17 00:00:00 2001 From: Jeb Bearer Date: Mon, 16 Sep 2024 17:53:24 -0400 Subject: [PATCH 1/5] Switch postgres client to sqlx Tests passing --- Cargo.lock | 33 +- Cargo.toml | 23 +- flake.nix | 12 + src/data_source.rs | 59 +- src/data_source/extension.rs | 12 +- src/data_source/fetching.rs | 115 ++-- src/data_source/fetching/block.rs | 12 +- src/data_source/fetching/header.rs | 2 +- src/data_source/fetching/leaf.rs | 6 +- src/data_source/fetching/notify_storage.rs | 100 +-- src/data_source/fetching/transaction.rs | 4 +- src/data_source/fetching/vid.rs | 6 +- src/data_source/sql.rs | 7 +- src/data_source/storage.rs | 186 ++---- src/data_source/storage/fs.rs | 43 +- src/data_source/storage/no_storage.rs | 117 ++-- src/data_source/storage/pruning.rs | 15 +- src/data_source/storage/sql.rs | 395 +++++------- src/data_source/storage/sql/db.rs | 25 + src/data_source/storage/sql/migrate.rs | 71 +++ src/data_source/storage/sql/queries.rs | 318 +++++++++ .../sql/{query => queries}/availability.rs | 175 +++-- .../sql/{query => queries}/explorer.rs | 525 +++++++-------- .../storage/sql/{query => queries}/node.rs | 137 ++-- .../storage/sql/{query => queries}/state.rs | 603 +++++++----------- src/data_source/storage/sql/query.rs | 268 -------- src/data_source/storage/sql/transaction.rs | 566 +++++++--------- src/data_source/update.rs | 18 - src/explorer.rs | 5 +- src/explorer/data_source.rs | 92 +++ src/fetching/provider/query_service.rs | 6 +- src/merklized_state.rs | 2 +- src/merklized_state/data_source.rs | 29 +- src/node.rs | 2 +- src/node/data_source.rs | 40 +- 35 files changed, 1911 insertions(+), 2118 deletions(-) create mode 100644 src/data_source/storage/sql/db.rs create mode 100644 src/data_source/storage/sql/migrate.rs create mode 100644 src/data_source/storage/sql/queries.rs rename src/data_source/storage/sql/{query => queries}/availability.rs (50%) rename src/data_source/storage/sql/{query => queries}/explorer.rs (52%) rename src/data_source/storage/sql/{query => queries}/node.rs (77%) rename src/data_source/storage/sql/{query => queries}/state.rs (74%) delete mode 100644 src/data_source/storage/sql/query.rs create mode 100644 src/explorer/data_source.rs diff --git a/Cargo.lock b/Cargo.lock index 88f0c2705..13cc2222b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1109,9 +1109,6 @@ name = "bit-vec" version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" -dependencies = [ - "serde", -] [[package]] name = "bitflags" @@ -3236,7 +3233,6 @@ dependencies = [ "atomic_store", "backtrace-on-stack-overflow", "bincode", - "bit-vec", "chrono", "clap", "committable", @@ -3255,24 +3251,22 @@ dependencies = [ "itertools 0.12.1", "jf-merkle-tree", "jf-vid", - "native-tls", "portpicker", - "postgres-native-tls", "prometheus", "rand 0.8.5", "refinery", + "refinery-core", "reqwest", "serde", "serde_json", "snafu 0.8.4", "spin_sleep", + "sqlx", "surf-disco", "tagged-base64", "tempfile", "tide-disco", "time 0.3.36", - "tokio", - "tokio-postgres", "toml", "tracing", "trait-variant", @@ -5608,19 +5602,6 @@ dependencies = [ "rand 0.8.5", ] -[[package]] -name = "postgres-native-tls" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d442770e2b1e244bb5eb03b31c79b65bb2568f413b899eaba850fa945a65954" -dependencies = [ - "futures", - "native-tls", - "tokio", - "tokio-native-tls", - "tokio-postgres", -] - [[package]] name 
= "postgres-protocol" version = "0.6.7" @@ -5645,13 +5626,9 @@ version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02048d9e032fb3cc3413bbf7b83a15d84a5d419778e2628751896d856498eee9" dependencies = [ - "bit-vec", "bytes", "fallible-iterator", "postgres-protocol", - "serde", - "serde_json", - "time 0.3.36", ] [[package]] @@ -7085,7 +7062,10 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4d8060b456358185f7d50c55d9b5066ad956956fddec42ee2e8567134a8936e" dependencies = [ + "async-io 1.13.0", + "async-std", "atoi", + "bit-vec", "byteorder", "bytes", "crc", @@ -7103,6 +7083,7 @@ dependencies = [ "indexmap 2.5.0", "log", "memchr", + "native-tls", "once_cell", "paste", "percent-encoding", @@ -7138,6 +7119,7 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1804e8a7c7865599c9c79be146dc8a9fd8cc86935fa641d3ea58e5f0688abaa5" dependencies = [ + "async-std", "dotenvy", "either", "heck 0.5.0", @@ -7209,6 +7191,7 @@ checksum = "6fa91a732d854c5d7726349bb4bb879bb9478993ceb764247660aee25f67c2f8" dependencies = [ "atoi", "base64 0.22.1", + "bit-vec", "bitflags 2.6.0", "byteorder", "crc", diff --git a/Cargo.toml b/Cargo.toml index 356929363..cbedc441f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,11 +36,9 @@ no-storage = [] # Enable the availability data source backed by a Postgres database. sql-data-source = [ "include_dir", - "native-tls", - "postgres-native-tls", "refinery", - "tokio", - "tokio-postgres", + "refinery-core", + "sqlx", ] # Enable extra features useful for writing tests with a query service. @@ -67,7 +65,6 @@ async-lock = "3.3.0" async-std = { version = "1.9.0", features = ["unstable", "attributes"] } async-trait = "0.1" bincode = "1.3" -bit-vec = { version = "0.6.3", features = ["serde_std"] } chrono = "0.4" committable = "0.2" custom_debug = "0.6" @@ -107,15 +104,15 @@ atomic_store = { git = "https://github.com/EspressoSystems/atomicstore.git", tag # Dependencies enabled by feature "sql-data-source". include_dir = { version = "0.7", optional = true } -native-tls = { version = "0.2", optional = true } -postgres-native-tls = { version = "0.5", optional = true } refinery = { version = "0.8", features = ["tokio-postgres"], optional = true } -tokio = { version = "1.37", optional = true } -tokio-postgres = { version = "0.7", optional = true, default-features = false, features = [ # disabling the default features removes dependence on the tokio runtime - "with-serde_json-1", - "with-time-0_3", - "with-bit-vec-0_6", -] } +refinery-core = { version = "0.8", optional = true } +sqlx = { version = "0.8", features = [ + "bit-vec", + "postgres", + "runtime-async-std", + "sqlite", + "tls-native-tls", +], optional = true } # Dependencies enabled by feature "testing". 
espresso-macros = { git = "https://github.com/EspressoSystems/espresso-macros.git", tag = "0.1.0", optional = true } diff --git a/flake.nix b/flake.nix index 822491363..7d9580cb0 100644 --- a/flake.nix +++ b/flake.nix @@ -42,6 +42,9 @@ rustToolchain = pkgs.rust-bin.stable.latest.minimal.override { extensions = [ "rustfmt" "clippy" "llvm-tools-preview" "rust-src" ]; }; + nightlyToolchain = pkgs.rust-bin.nightly.latest.minimal.override { + extensions = [ "rustfmt" "clippy" "llvm-tools-preview" "rust-src" ]; + }; rustDeps = with pkgs; [ pkg-config @@ -184,6 +187,15 @@ inherit RUST_SRC_PATH RUST_BACKTRACE RUST_LOG RUSTFLAGS CARGO_TARGET_DIR; }; devShells = { + nightlyShell = pkgs.mkShell { + shellHook = shellHook; + buildInputs = with pkgs; + [ + nixWithFlakes + git + nightlyToolchain + ] ++ myPython ++ rustDeps; + }; perfShell = pkgs.mkShell { shellHook = shellHook; buildInputs = with pkgs; diff --git a/src/data_source.rs b/src/data_source.rs index 3a4bfeebe..02595ef61 100644 --- a/src/data_source.rs +++ b/src/data_source.rs @@ -43,7 +43,7 @@ pub use fs::FileSystemDataSource; pub use metrics::MetricsDataSource; #[cfg(feature = "sql-data-source")] pub use sql::SqlDataSource; -pub use update::{ReadOnly, Transaction, UpdateDataSource, VersionedDataSource}; +pub use update::{Transaction, UpdateDataSource, VersionedDataSource}; #[cfg(any(test, feature = "testing"))] mod test_helpers { @@ -57,6 +57,7 @@ mod test_helpers { stream::{BoxStream, StreamExt}, }; use std::ops::{Bound, RangeBounds}; + /// Apply an upper bound to a range based on the currently available block height. async fn bound_range(ds: &D, range: R) -> impl RangeBounds where @@ -98,9 +99,12 @@ mod test_helpers { .boxed() } - pub async fn get_non_empty_blocks( - ds: &impl TestableDataSource, - ) -> Vec<(LeafQueryData, BlockQueryData)> { + pub async fn get_non_empty_blocks( + ds: &D, + ) -> Vec<(LeafQueryData, BlockQueryData)> + where + D: TestableDataSource, + { // Ignore the genesis block (start from height 1). leaf_range(ds, 1..) .await @@ -118,6 +122,7 @@ pub mod availability_tests { use super::test_helpers::*; use crate::{ availability::{payload_size, BlockId}, + data_source::storage::NodeStorage, node::NodeDataSource, testing::{ consensus::{MockNetwork, TestableDataSource}, @@ -274,7 +279,7 @@ pub mod availability_tests { #[async_std::test] pub async fn test_update() where - for<'a> D::ReadOnly<'a>: NodeDataSource, + for<'a> D::ReadOnly<'a>: NodeStorage, { setup_test(); @@ -348,7 +353,7 @@ pub mod availability_tests { #[async_std::test] pub async fn test_range() where - for<'a> D::ReadOnly<'a>: NodeDataSource, + for<'a> D::ReadOnly<'a>: NodeStorage, { setup_test(); @@ -358,7 +363,7 @@ pub mod availability_tests { // Wait for there to be at least 3 blocks. 
let block_height = loop { - let tx = ds.read().await.unwrap(); + let mut tx = ds.read().await.unwrap(); let block_height = tx.block_height().await.unwrap(); if block_height >= 3 { break block_height as u64; @@ -450,7 +455,10 @@ pub mod availability_tests { pub mod persistence_tests { use crate::{ availability::{BlockQueryData, LeafQueryData, UpdateAvailabilityData}, - data_source::{storage::AvailabilityStorage, Transaction, UpdateDataSource}, + data_source::{ + storage::{AvailabilityStorage, NodeStorage}, + Transaction, UpdateDataSource, + }, node::NodeDataSource, testing::{ consensus::TestableDataSource, @@ -467,9 +475,8 @@ pub mod persistence_tests { #[async_std::test] pub async fn test_revert() where - for<'a> D::Transaction<'a>: UpdateDataSource - + NodeDataSource - + AvailabilityStorage, + for<'a> D::Transaction<'a>: + UpdateDataSource + AvailabilityStorage + NodeStorage, { use hotshot_example_types::node_types::TestVersions; @@ -502,12 +509,7 @@ pub mod persistence_tests { tx.insert_leaf(leaf.clone()).await.unwrap(); tx.insert_block(block.clone()).await.unwrap(); - assert_eq!( - NodeDataSource::::block_height(&tx) - .await - .unwrap(), - 2 - ); + assert_eq!(tx.block_height().await.unwrap(), 2); assert_eq!(leaf, tx.get_leaf(1.into()).await.unwrap()); assert_eq!(block, tx.get_block(1.into()).await.unwrap()); @@ -603,10 +605,9 @@ pub mod persistence_tests { #[async_std::test] pub async fn test_drop_tx() where - for<'a> D::Transaction<'a>: UpdateDataSource - + NodeDataSource - + AvailabilityStorage, - for<'a> D::ReadOnly<'a>: NodeDataSource, + for<'a> D::Transaction<'a>: + UpdateDataSource + AvailabilityStorage + NodeStorage, + for<'a> D::ReadOnly<'a>: NodeStorage, { use hotshot_example_types::node_types::TestVersions; @@ -649,7 +650,7 @@ pub mod persistence_tests { // Open a new transaction and check that the changes are reverted. 
tracing::info!("read"); - let tx = ds.read().await.unwrap(); + let mut tx = ds.read().await.unwrap(); assert_eq!(tx.block_height().await.unwrap(), 0); drop(tx); @@ -691,8 +692,8 @@ pub mod node_tests { BlockQueryData, LeafQueryData, QueryableHeader, UpdateAvailabilityData, VidCommonQueryData, }, - data_source::{update::Transaction, UpdateDataSource}, - node::{BlockId, NodeDataSource, SyncStatus, TimeWindowQueryData, WindowStart}, + data_source::{storage::NodeStorage, update::Transaction, UpdateDataSource}, + node::{BlockId, SyncStatus, TimeWindowQueryData, WindowStart}, testing::{ consensus::{MockNetwork, TestableDataSource}, mocks::{mock_transaction, MockPayload, MockTypes}, @@ -948,7 +949,7 @@ pub mod node_tests { #[async_std::test] pub async fn test_vid_shares() where - for<'a> D::ReadOnly<'a>: NodeDataSource, + for<'a> D::ReadOnly<'a>: NodeStorage, { setup_test(); @@ -961,7 +962,7 @@ pub mod node_tests { let mut leaves = ds.subscribe_leaves(0).await.take(3); while let Some(leaf) = leaves.next().await { tracing::info!("got leaf {}", leaf.height()); - let tx = ds.read().await.unwrap(); + let mut tx = ds.read().await.unwrap(); let share = tx.vid_share(leaf.height() as usize).await.unwrap(); assert_eq!(share, tx.vid_share(leaf.block_hash()).await.unwrap()); assert_eq!( @@ -977,7 +978,7 @@ pub mod node_tests { pub async fn test_vid_monotonicity() where for<'a> D::Transaction<'a>: UpdateDataSource, - for<'a> D::ReadOnly<'a>: NodeDataSource, + for<'a> D::ReadOnly<'a>: NodeStorage, { use hotshot_example_types::node_types::TestVersions; @@ -1027,7 +1028,7 @@ pub mod node_tests { #[async_std::test] pub async fn test_vid_recovery() where - for<'a> D::ReadOnly<'a>: NodeDataSource, + for<'a> D::ReadOnly<'a>: NodeStorage, { setup_test(); @@ -1099,7 +1100,7 @@ pub mod node_tests { #[async_std::test] pub async fn test_timestamp_window() where - for<'a> D::ReadOnly<'a>: NodeDataSource, + for<'a> D::ReadOnly<'a>: NodeStorage, { setup_test(); diff --git a/src/data_source/extension.rs b/src/data_source/extension.rs index bfeca8092..7b948f1dd 100644 --- a/src/data_source/extension.rs +++ b/src/data_source/extension.rs @@ -10,14 +10,14 @@ // You should have received a copy of the GNU General Public License along with this program. If not, // see . 
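The `extension.rs` changes that follow move the wrapper's explorer support from the storage-level `ExplorerStorage` trait to the new high-level `ExplorerDataSource` trait, keeping `&self` methods on the wrapper while the storage layer moves to `&mut self`. A minimal sketch of calling the high-level trait, under the assumption that its methods take `&self` as the impls in this patch do (some trait bounds on `Types` are elided for brevity):

    // Sketch: any high-level explorer data source, including the
    // `ExtensibleDataSource` wrapper below, can serve summary queries
    // through a shared reference.
    use crate::explorer::{
        query_data::{ExplorerSummary, GetExplorerSummaryError},
        ExplorerDataSource,
    };
    use hotshot_types::traits::node_implementation::NodeType;

    async fn explorer_summary<Types, D>(
        ds: &D,
    ) -> Result<ExplorerSummary<Types>, GetExplorerSummaryError>
    where
        Types: NodeType, // additional bounds on Header/Transaction elided
        D: ExplorerDataSource<Types> + Sync,
    {
        // Delegation is uniform: the wrapper forwards to its inner source.
        ds.get_explorer_summary().await
    }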
-use super::{storage::ExplorerStorage, VersionedDataSource}; +use super::VersionedDataSource; use crate::{ availability::{ AvailabilityDataSource, BlockId, BlockQueryData, Fetch, LeafId, LeafQueryData, PayloadQueryData, QueryableHeader, QueryablePayload, TransactionHash, TransactionQueryData, UpdateAvailabilityData, VidCommonQueryData, }, - explorer::{self, ExplorerHeader, ExplorerTransaction}, + explorer::{self, ExplorerDataSource, ExplorerHeader, ExplorerTransaction}, merklized_state::{ MerklizedState, MerklizedStateDataSource, MerklizedStateHeightPersistence, Snapshot, UpdateStateData, @@ -289,7 +289,7 @@ where impl MerklizedStateDataSource for ExtensibleDataSource where - D: MerklizedStateDataSource + Send + Sync, + D: MerklizedStateDataSource + Sync, U: Send + Sync, Types: NodeType, State: MerklizedState, @@ -306,7 +306,7 @@ where #[async_trait] impl MerklizedStateHeightPersistence for ExtensibleDataSource where - D: MerklizedStateHeightPersistence + Send + Sync, + D: MerklizedStateHeightPersistence + Sync, U: Send + Sync, { async fn get_last_state_height(&self) -> QueryResult { @@ -340,9 +340,9 @@ where } #[async_trait] -impl ExplorerStorage for ExtensibleDataSource +impl ExplorerDataSource for ExtensibleDataSource where - D: ExplorerStorage + Send + Sync, + D: ExplorerDataSource + Sync, U: Send + Sync, Types: NodeType, Payload: QueryablePayload, diff --git a/src/data_source/fetching.rs b/src/data_source/fetching.rs index da341d875..c9ddf9461 100644 --- a/src/data_source/fetching.rs +++ b/src/data_source/fetching.rs @@ -76,7 +76,8 @@ use super::{ storage::{ pruning::{PruneStorage, PrunedHeightStorage}, - AvailabilityStorage, ExplorerStorage, + AvailabilityStorage, ExplorerStorage, MerklizedStateHeightStorage, MerklizedStateStorage, + NodeStorage, }, VersionedDataSource, }; @@ -86,7 +87,7 @@ use crate::{ PayloadQueryData, QueryableHeader, QueryablePayload, TransactionHash, TransactionQueryData, UpdateAvailabilityData, VidCommonQueryData, }, - explorer, + explorer::{self, ExplorerDataSource}, fetching::{self, request, Provider}, merklized_state::{ MerklizedState, MerklizedStateDataSource, MerklizedStateHeightPersistence, Snapshot, @@ -104,7 +105,7 @@ use async_std::{sync::Arc, task::sleep}; use async_trait::async_trait; use derivative::Derivative; use futures::{ - future::{self, join_all, BoxFuture, FutureExt}, + future::{join_all, BoxFuture, FutureExt}, stream::{self, BoxStream, Stream, StreamExt}, }; use hotshot_types::traits::node_implementation::NodeType; @@ -292,8 +293,7 @@ where Payload: QueryablePayload, Header: QueryableHeader, S: PruneStorage + VersionedDataSource + 'static, - for<'a> S::ReadOnly<'a>: - PrunedHeightStorage + NodeDataSource + AvailabilityStorage, + for<'a> S::ReadOnly<'a>: AvailabilityStorage + PrunedHeightStorage + NodeStorage, for<'a> S::Transaction<'a>: UpdateAvailabilityData, P: AvailabilityProvider, { @@ -387,8 +387,7 @@ where Header: QueryableHeader, S: VersionedDataSource + PruneStorage + 'static, for<'a> S::Transaction<'a>: UpdateAvailabilityData, - for<'a> S::ReadOnly<'a>: - PrunedHeightStorage + NodeDataSource + AvailabilityStorage, + for<'a> S::ReadOnly<'a>: AvailabilityStorage + NodeStorage + PrunedHeightStorage, P: AvailabilityProvider, { /// Build a [`FetchingDataSource`] with the given `storage` and `provider`. 
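From here on, every read path follows the same new shape: `read()` hands back a transaction that must be bound mutably, because with sqlx each query takes exclusive (`&mut`) access to the underlying connection. Public traits such as `StatusDataSource` and `NodeDataSource` keep `&self` receivers and open a short-lived read transaction per call, delegating to the `&mut self` storage traits. A minimal sketch of that calling convention (error conversion simplified; assumes `read()` yields an `anyhow`-compatible error, as the `map_err` calls below suggest):

    // Sketch: one storage-level query through a fresh read transaction.
    use crate::data_source::{storage::NodeStorage, VersionedDataSource};
    use hotshot_types::traits::node_implementation::NodeType;

    async fn current_height<Types, D>(ds: &D) -> anyhow::Result<usize>
    where
        Types: NodeType,
        D: VersionedDataSource,
        for<'a> D::ReadOnly<'a>: NodeStorage<Types>,
    {
        // The transaction must be `mut`: executing a query with sqlx
        // requires `&mut` access to the connection it wraps.
        let mut tx = ds.read().await?;
        let height = tx.block_height().await?;
        Ok(height)
    }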
@@ -446,11 +445,11 @@ impl StatusDataSource for FetchingDataSource where Types: NodeType, S: VersionedDataSource + Send + Sync + 'static, - for<'a> S::ReadOnly<'a>: NodeDataSource, + for<'a> S::ReadOnly<'a>: NodeStorage, P: Send + Sync, { async fn block_height(&self) -> QueryResult { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.block_height().await @@ -468,8 +467,7 @@ where Payload: QueryablePayload, S: VersionedDataSource + 'static, for<'a> S::Transaction<'a>: UpdateAvailabilityData, - for<'a> S::ReadOnly<'a>: - PrunedHeightStorage + AvailabilityStorage + NodeDataSource, + for<'a> S::ReadOnly<'a>: AvailabilityStorage + NodeStorage + PrunedHeightStorage, P: AvailabilityProvider, { type LeafRange = BoxStream<'static, Fetch>> @@ -615,7 +613,7 @@ impl Fetcher where Types: NodeType, S: VersionedDataSource + Sync, - for<'a> S::ReadOnly<'a>: NodeDataSource + PrunedHeightStorage, + for<'a> S::ReadOnly<'a>: PrunedHeightStorage + NodeStorage, { async fn new(builder: Builder) -> anyhow::Result { let mut payload_fetcher = fetching::Fetcher::default(); @@ -654,8 +652,7 @@ where Payload: QueryablePayload, S: VersionedDataSource + 'static, for<'a> S::Transaction<'a>: UpdateAvailabilityData, - for<'a> S::ReadOnly<'a>: - AvailabilityStorage + PrunedHeightStorage + NodeDataSource, + for<'a> S::ReadOnly<'a>: AvailabilityStorage + NodeStorage + PrunedHeightStorage, P: AvailabilityProvider, { async fn get(self: &Arc, req: R) -> Fetch @@ -668,7 +665,7 @@ where // notifications sent in between checking local storage and triggering a fetch if necessary. let passive = T::passive_fetch(self.storage.notifiers(), req).await; - let tx = match self.read().await { + let mut tx = match self.read().await { Ok(tx) => tx, Err(err) => { tracing::warn!( @@ -679,8 +676,8 @@ where } }; - self.ok_or_fetch(Some(&tx), passive, req, T::load(&tx, req).await) - .await + let res = T::load(&mut tx, req).await; + self.ok_or_fetch(Some(&mut tx), passive, req, res).await } /// Get a range of objects from local storage or a provider. @@ -763,9 +760,9 @@ where ) .await; - let (tx, ts) = match self.read().await { - Ok(tx) => { - let ts = T::load_range(&tx, chunk.clone()) + let (mut tx, ts) = match self.read().await { + Ok(mut tx) => { + let ts = T::load_range(&mut tx, chunk.clone()) .await .context(format!("when fetching items in range {chunk:?}")) .ok_or_trace() @@ -798,37 +795,29 @@ where fetches.len() ); fetches.push( - self.fetch(tx.as_ref(), passive.remove(0), chunk.start + fetches.len()) - .boxed(), + self.fetch(tx.as_mut(), passive.remove(0), chunk.start + fetches.len()) + .await, ); } // `t` itself is already available, we don't have to trigger a fetch for it. Remove (and // drop without awaiting) the passive fetch we preemptively started. drop(passive.remove(0)); - fetches.push(future::ready(Fetch::Ready(t)).boxed()); + fetches.push(Fetch::Ready(t)); } // Fetch missing objects from the end of the range. while fetches.len() < chunk.len() { fetches.push( - self.fetch(tx.as_ref(), passive.remove(0), chunk.start + fetches.len()) - .boxed(), + self.fetch(tx.as_mut(), passive.remove(0), chunk.start + fetches.len()) + .await, ); } - // We `join_all` here because we want this iterator to be evaluated eagerly for two reasons: - // 1. It borrows from `self`, which is local to this future. This avoids having to clone - // `self` for every entry, instead we clone it for every chunk. - // 2. 
We evaluate all the `some_or_fetch` calls eagerly, so the fetches are triggered as - // soon as we evaluate the chunk. This ensures we don't miss any notifications, since we - // load from storage and subscribe to notifications for missing objects all while we have - // a read lock on `self.storage`. No notifications can be sent during this time since - // sending a notification requires a write lock. - stream::iter(join_all(fetches).await) + stream::iter(fetches) } async fn ok_or_fetch( self: &Arc, - tx: Option<&::ReadOnly<'_>>, + tx: Option<&mut ::ReadOnly<'_>>, passive: PassiveFetch, req: R, res: QueryResult, @@ -849,7 +838,7 @@ where async fn some_or_fetch( self: &Arc, - tx: Option<&::ReadOnly<'_>>, + tx: Option<&mut ::ReadOnly<'_>>, passive: PassiveFetch, req: R, res: Option, @@ -866,7 +855,7 @@ where async fn fetch( self: &Arc, - tx: Option<&::ReadOnly<'_>>, + tx: Option<&mut ::ReadOnly<'_>>, passive: PassiveFetch, req: R, ) -> Fetch @@ -944,7 +933,7 @@ where // We can't start the scan until we know the current block height and pruned height, // so we know which blocks to scan. Thus we retry until this succeeds. let heights = loop { - let tx = match self.read().await { + let mut tx = match self.read().await { Ok(tx) => tx, Err(err) => { tracing::error!( @@ -1028,8 +1017,8 @@ impl MerklizedStateDataSource where Types: NodeType, - S: VersionedDataSource, - for<'a> S::ReadOnly<'a>: MerklizedStateDataSource + Send + Sync, + S: VersionedDataSource + 'static, + for<'a> S::ReadOnly<'a>: MerklizedStateStorage, P: Send + Sync, State: MerklizedState + 'static, ::Commitment: Send, @@ -1039,7 +1028,7 @@ where snapshot: Snapshot, key: State::Key, ) -> QueryResult> { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.get_path(snapshot, key).await @@ -1051,12 +1040,12 @@ impl MerklizedStateHeightPersistence for FetchingDataSource: QueryablePayload, - S: VersionedDataSource, - for<'a> S::ReadOnly<'a>: MerklizedStateHeightPersistence + Send + Sync, + S: VersionedDataSource + 'static, + for<'a> S::ReadOnly<'a>: MerklizedStateHeightStorage, P: Send + Sync, { async fn get_last_state_height(&self) -> QueryResult { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.get_last_state_height().await @@ -1068,25 +1057,25 @@ impl NodeDataSource for FetchingDataSource where Types: NodeType, S: VersionedDataSource + 'static, - for<'a> S::ReadOnly<'a>: NodeDataSource + Sync, + for<'a> S::ReadOnly<'a>: NodeStorage, P: Send + Sync, { async fn block_height(&self) -> QueryResult { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.block_height().await } async fn count_transactions(&self) -> QueryResult { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.count_transactions().await } async fn payload_size(&self) -> QueryResult { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.payload_size().await @@ -1096,14 +1085,14 @@ where where ID: Into> + Send + Sync, { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = 
self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.vid_share(id).await } async fn sync_status(&self) -> QueryResult { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.sync_status().await @@ -1114,7 +1103,7 @@ where start: impl Into> + Send + Sync, end: u64, ) -> QueryResult>> { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.get_header_window(start, end).await @@ -1122,14 +1111,14 @@ where } #[async_trait] -impl ExplorerStorage for FetchingDataSource +impl ExplorerDataSource for FetchingDataSource where Types: NodeType, Payload: QueryablePayload, Header: QueryableHeader + explorer::traits::ExplorerHeader, crate::Transaction: explorer::traits::ExplorerTransaction, - S: VersionedDataSource, - for<'a> S::ReadOnly<'a>: ExplorerStorage + Send + Sync, + S: VersionedDataSource + 'static, + for<'a> S::ReadOnly<'a>: ExplorerStorage, P: Send + Sync, { async fn get_block_summaries( @@ -1139,7 +1128,7 @@ where Vec>, explorer::query_data::GetBlockSummariesError, > { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.get_block_summaries(request).await @@ -1150,7 +1139,7 @@ where request: explorer::query_data::BlockIdentifier, ) -> Result, explorer::query_data::GetBlockDetailError> { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.get_block_detail(request).await @@ -1163,7 +1152,7 @@ where Vec>, explorer::query_data::GetTransactionSummariesError, > { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.get_transaction_summaries(request).await @@ -1176,7 +1165,7 @@ where explorer::query_data::TransactionDetailResponse, explorer::query_data::GetTransactionDetailError, > { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.get_transaction_detail(request).await @@ -1188,7 +1177,7 @@ where explorer::query_data::ExplorerSummary, explorer::query_data::GetExplorerSummaryError, > { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.get_explorer_summary().await @@ -1201,7 +1190,7 @@ where explorer::query_data::SearchResult, explorer::query_data::GetSearchResultsError, > { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; tx.get_search_results(query).await @@ -1277,7 +1266,7 @@ where /// receive it passively, since we will eventually receive all blocks and leaves that are ever /// produced. Active fetching merely helps us receive certain objects sooner. async fn active_fetch( - tx: &impl AvailabilityStorage, + tx: &mut impl AvailabilityStorage, fetcher: Arc>, req: Self::Request, ) where @@ -1292,7 +1281,7 @@ where /// /// This function assumes `req.might_exist()` has already been checked before calling it, and so /// may do unnecessary work if the caller does not ensure this. 
- async fn load(storage: &S, req: Self::Request) -> QueryResult + async fn load(storage: &mut S, req: Self::Request) -> QueryResult where S: AvailabilityStorage; } @@ -1308,7 +1297,7 @@ where type RangedRequest: FetchRequest + From + Send; /// Load a range of these objects from local storage. - async fn load_range(storage: &S, range: R) -> QueryResult>> + async fn load_range(storage: &mut S, range: R) -> QueryResult>> where S: AvailabilityStorage, R: RangeBounds + Send + 'static; diff --git a/src/data_source/fetching/block.rs b/src/data_source/fetching/block.rs index 252b7b13b..a494908b4 100644 --- a/src/data_source/fetching/block.rs +++ b/src/data_source/fetching/block.rs @@ -82,7 +82,7 @@ where } async fn active_fetch( - tx: &impl AvailabilityStorage, + tx: &mut impl AvailabilityStorage, fetcher: Arc>, req: Self::Request, ) where @@ -100,7 +100,7 @@ where .await } - async fn load(storage: &S, req: Self::Request) -> QueryResult + async fn load(storage: &mut S, req: Self::Request) -> QueryResult where S: AvailabilityStorage, { @@ -116,7 +116,7 @@ where { type RangedRequest = BlockId; - async fn load_range(storage: &S, range: R) -> QueryResult>> + async fn load_range(storage: &mut S, range: R) -> QueryResult>> where S: AvailabilityStorage, R: RangeBounds + Send + 'static, @@ -196,7 +196,7 @@ where } async fn active_fetch( - tx: &impl AvailabilityStorage, + tx: &mut impl AvailabilityStorage, fetcher: Arc>, req: Self::Request, ) where @@ -210,7 +210,7 @@ where BlockQueryData::active_fetch(tx, fetcher, req).await } - async fn load(storage: &S, req: Self::Request) -> QueryResult + async fn load(storage: &mut S, req: Self::Request) -> QueryResult where S: AvailabilityStorage, { @@ -226,7 +226,7 @@ where { type RangedRequest = BlockId; - async fn load_range(storage: &S, range: R) -> QueryResult>> + async fn load_range(storage: &mut S, range: R) -> QueryResult>> where S: AvailabilityStorage, R: RangeBounds + Send + 'static, diff --git a/src/data_source/fetching/header.rs b/src/data_source/fetching/header.rs index 9270989df..1dd5e9f52 100644 --- a/src/data_source/fetching/header.rs +++ b/src/data_source/fetching/header.rs @@ -106,7 +106,7 @@ where } pub(super) async fn fetch_header_and_then( - tx: &impl AvailabilityStorage, + tx: &mut impl AvailabilityStorage, req: BlockId, callback: HeaderCallback, ) where diff --git a/src/data_source/fetching/leaf.rs b/src/data_source/fetching/leaf.rs index 6618c395a..5941ab25c 100644 --- a/src/data_source/fetching/leaf.rs +++ b/src/data_source/fetching/leaf.rs @@ -75,7 +75,7 @@ where } async fn active_fetch( - _tx: &impl AvailabilityStorage, + _tx: &mut impl AvailabilityStorage, fetcher: Arc>, req: Self::Request, ) where @@ -86,7 +86,7 @@ where fetch_leaf_with_callbacks(fetcher, req, None) } - async fn load(storage: &S, req: Self::Request) -> QueryResult + async fn load(storage: &mut S, req: Self::Request) -> QueryResult where S: AvailabilityStorage, { @@ -149,7 +149,7 @@ where { type RangedRequest = LeafId; - async fn load_range(storage: &S, range: R) -> QueryResult>> + async fn load_range(storage: &mut S, range: R) -> QueryResult>> where S: AvailabilityStorage, R: RangeBounds + Send + 'static, diff --git a/src/data_source/fetching/notify_storage.rs b/src/data_source/fetching/notify_storage.rs index d6b26e377..157f711d1 100644 --- a/src/data_source/fetching/notify_storage.rs +++ b/src/data_source/fetching/notify_storage.rs @@ -20,16 +20,14 @@ use crate::{ notifier::Notifier, storage::{ pruning::{PruneStorage, PrunedHeightStorage, PrunerCfg}, - 
AvailabilityStorage, ExplorerStorage, + AvailabilityStorage, ExplorerStorage, MerklizedStateHeightStorage, + MerklizedStateStorage, NodeStorage, }, update::{self, VersionedDataSource}, }, explorer, - merklized_state::{ - MerklizedState, MerklizedStateDataSource, MerklizedStateHeightPersistence, Snapshot, - UpdateStateData, - }, - node::{NodeDataSource, SyncStatus, TimeWindowQueryData, WindowStart}, + merklized_state::{MerklizedState, Snapshot, UpdateStateData}, + node::{SyncStatus, TimeWindowQueryData, WindowStart}, Header, Payload, QueryResult, VidShare, }; use anyhow::Context; @@ -218,9 +216,9 @@ where impl<'a, Types, T> PrunedHeightStorage for Transaction<'a, Types, T> where Types: NodeType, - T: PrunedHeightStorage + Sync, + T: PrunedHeightStorage + Send, { - async fn load_pruned_height(&self) -> anyhow::Result> { + async fn load_pruned_height(&mut self) -> anyhow::Result> { self.inner.load_pruned_height().await } } @@ -228,9 +226,9 @@ where impl<'a, Types, T> Transaction<'a, Types, T> where Types: NodeType, - T: PrunedHeightStorage + NodeDataSource + Sync, + T: PrunedHeightStorage + NodeStorage + Send + Sync, { - pub(super) async fn heights(&self) -> anyhow::Result { + pub(super) async fn heights(&mut self) -> anyhow::Result { let height = self.block_height().await.context("loading block height")? as u64; let pruned_height = self .load_pruned_height() @@ -244,32 +242,31 @@ where } #[async_trait] -impl<'a, Types, T, State, const ARITY: usize> MerklizedStateDataSource +impl<'a, Types, T, State, const ARITY: usize> MerklizedStateStorage for Transaction<'a, Types, T> where Types: NodeType, - T: MerklizedStateDataSource + Send + Sync, + T: MerklizedStateStorage + Send, State: MerklizedState + 'static, ::Commitment: Send, { async fn get_path( - &self, + &mut self, snapshot: Snapshot, key: State::Key, ) -> QueryResult> { - self.as_ref().get_path(snapshot, key).await + self.as_mut().get_path(snapshot, key).await } } #[async_trait] -impl<'a, Types, T> MerklizedStateHeightPersistence for Transaction<'a, Types, T> +impl<'a, Types, T> MerklizedStateHeightStorage for Transaction<'a, Types, T> where Types: NodeType, - Payload: QueryablePayload, - T: MerklizedStateHeightPersistence + Send + Sync, + T: MerklizedStateHeightStorage + Send, { - async fn get_last_state_height(&self) -> QueryResult { - self.as_ref().get_last_state_height().await + async fn get_last_state_height(&mut self) -> QueryResult { + self.as_mut().get_last_state_height().await } } @@ -349,28 +346,31 @@ where Payload: QueryablePayload, T: AvailabilityStorage, { - async fn get_leaf(&self, id: LeafId) -> QueryResult> { + async fn get_leaf(&mut self, id: LeafId) -> QueryResult> { self.inner.get_leaf(id).await } - async fn get_block(&self, id: BlockId) -> QueryResult> { + async fn get_block(&mut self, id: BlockId) -> QueryResult> { self.inner.get_block(id).await } - async fn get_header(&self, id: BlockId) -> QueryResult> { + async fn get_header(&mut self, id: BlockId) -> QueryResult> { self.inner.get_header(id).await } - async fn get_payload(&self, id: BlockId) -> QueryResult> { + async fn get_payload(&mut self, id: BlockId) -> QueryResult> { self.inner.get_payload(id).await } - async fn get_vid_common(&self, id: BlockId) -> QueryResult> { + async fn get_vid_common( + &mut self, + id: BlockId, + ) -> QueryResult> { self.inner.get_vid_common(id).await } async fn get_leaf_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where @@ -380,7 +380,7 @@ where } async fn get_block_range( - &self, + &mut self, range: R, ) -> 
QueryResult>>> where @@ -390,7 +390,7 @@ where } async fn get_payload_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where @@ -400,7 +400,7 @@ where } async fn get_vid_common_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where @@ -410,7 +410,7 @@ where } async fn get_transaction( - &self, + &mut self, hash: TransactionHash, ) -> QueryResult> { self.inner.get_transaction(hash).await @@ -418,36 +418,36 @@ where } #[async_trait] -impl<'a, Types, T> NodeDataSource for Transaction<'a, Types, T> +impl<'a, Types, T> NodeStorage for Transaction<'a, Types, T> where Types: NodeType, - T: NodeDataSource + Sync, + T: NodeStorage + Send, { - async fn block_height(&self) -> QueryResult { + async fn block_height(&mut self) -> QueryResult { self.inner.block_height().await } - async fn count_transactions(&self) -> QueryResult { + async fn count_transactions(&mut self) -> QueryResult { self.inner.count_transactions().await } - async fn payload_size(&self) -> QueryResult { + async fn payload_size(&mut self) -> QueryResult { self.inner.payload_size().await } - async fn vid_share(&self, id: ID) -> QueryResult + async fn vid_share(&mut self, id: ID) -> QueryResult where ID: Into> + Send + Sync, { self.inner.vid_share(id).await } - async fn sync_status(&self) -> QueryResult { + async fn sync_status(&mut self) -> QueryResult { self.inner.sync_status().await } async fn get_header_window( - &self, + &mut self, start: impl Into> + Send + Sync, end: u64, ) -> QueryResult>> { @@ -462,63 +462,63 @@ where Payload: QueryablePayload, Header: QueryableHeader + explorer::traits::ExplorerHeader, crate::Transaction: explorer::traits::ExplorerTransaction, - T: ExplorerStorage + Send + Sync, + T: ExplorerStorage + Send, { async fn get_block_summaries( - &self, + &mut self, request: explorer::query_data::GetBlockSummariesRequest, ) -> Result< Vec>, explorer::query_data::GetBlockSummariesError, > { - self.as_ref().get_block_summaries(request).await + self.as_mut().get_block_summaries(request).await } async fn get_block_detail( - &self, + &mut self, request: explorer::query_data::BlockIdentifier, ) -> Result, explorer::query_data::GetBlockDetailError> { - self.as_ref().get_block_detail(request).await + self.as_mut().get_block_detail(request).await } async fn get_transaction_summaries( - &self, + &mut self, request: explorer::query_data::GetTransactionSummariesRequest, ) -> Result< Vec>, explorer::query_data::GetTransactionSummariesError, > { - self.as_ref().get_transaction_summaries(request).await + self.as_mut().get_transaction_summaries(request).await } async fn get_transaction_detail( - &self, + &mut self, request: explorer::query_data::TransactionIdentifier, ) -> Result< explorer::query_data::TransactionDetailResponse, explorer::query_data::GetTransactionDetailError, > { - self.as_ref().get_transaction_detail(request).await + self.as_mut().get_transaction_detail(request).await } async fn get_explorer_summary( - &self, + &mut self, ) -> Result< explorer::query_data::ExplorerSummary, explorer::query_data::GetExplorerSummaryError, > { - self.as_ref().get_explorer_summary().await + self.as_mut().get_explorer_summary().await } async fn get_search_results( - &self, + &mut self, query: String, ) -> Result< explorer::query_data::SearchResult, explorer::query_data::GetSearchResultsError, > { - self.as_ref().get_search_results(query).await + self.as_mut().get_search_results(query).await } } diff --git a/src/data_source/fetching/transaction.rs b/src/data_source/fetching/transaction.rs index 5a5b5d20d..b7ab5df0e 
100644 --- a/src/data_source/fetching/transaction.rs +++ b/src/data_source/fetching/transaction.rs @@ -62,7 +62,7 @@ where } async fn active_fetch( - _tx: &impl AvailabilityStorage, + _tx: &mut impl AvailabilityStorage, _fetcher: Arc>, req: Self::Request, ) where @@ -76,7 +76,7 @@ where tracing::debug!("not fetching unknown transaction {req:?}"); } - async fn load(storage: &S, req: Self::Request) -> QueryResult + async fn load(storage: &mut S, req: Self::Request) -> QueryResult where S: AvailabilityStorage, { diff --git a/src/data_source/fetching/vid.rs b/src/data_source/fetching/vid.rs index b9b39b711..64d6a5a80 100644 --- a/src/data_source/fetching/vid.rs +++ b/src/data_source/fetching/vid.rs @@ -82,7 +82,7 @@ where } async fn active_fetch( - tx: &impl AvailabilityStorage, + tx: &mut impl AvailabilityStorage, fetcher: Arc>, req: Self::Request, ) where @@ -100,7 +100,7 @@ where .await } - async fn load(storage: &S, req: Self::Request) -> QueryResult + async fn load(storage: &mut S, req: Self::Request) -> QueryResult where S: AvailabilityStorage, { @@ -116,7 +116,7 @@ where { type RangedRequest = VidCommonRequest; - async fn load_range(storage: &S, range: R) -> QueryResult>> + async fn load_range(storage: &mut S, range: R) -> QueryResult>> where S: AvailabilityStorage, R: RangeBounds + Send + 'static, diff --git a/src/data_source/sql.rs b/src/data_source/sql.rs index 435578ddf..f9c9dcb3f 100644 --- a/src/data_source/sql.rs +++ b/src/data_source/sql.rs @@ -25,11 +25,11 @@ use crate::{ pub use anyhow::Error; use hotshot_types::traits::node_implementation::NodeType; pub use refinery::Migration; -pub use tokio_postgres as postgres; -pub use sql::{Config, Transaction}; +pub use sql::{Postgres, Transaction}; pub type Builder = fetching::Builder; +pub type Config = sql::Config; impl Config { /// Connect to the database with this config. 
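The `Config` re-export is now an alias for the storage-level `sql::Config`, parameterized by the sqlx database type, with Postgres as the backend used here; `connect` (documented above) turns it into a live `SqlDataSource`. A hypothetical usage sketch: the builder-style setters shown are assumptions for illustration, and the real method set lives on `sql::Config`:

    // Hypothetical: open a Postgres-backed SqlDataSource from a Config.
    // (Trait bounds on `Types` are simplified for the sketch.)
    use crate::{
        data_source::sql::{Config, Error, SqlDataSource},
        fetching::provider::NoFetching,
    };
    use hotshot_types::traits::node_implementation::NodeType;

    async fn open<Types: NodeType>() -> Result<SqlDataSource<Types, NoFetching>, Error> {
        let cfg = Config::default() // assumed `Default` impl
            .host("localhost")      // assumed setter
            .port(5432)             // assumed setter
            .user("postgres")       // assumed setter
            .database("hotshot");   // assumed setter
        // `connect` consumes the config along with a fetch provider.
        cfg.connect(NoFetching).await
    }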
@@ -374,9 +374,8 @@ mod test { availability::{ AvailabilityDataSource, LeafQueryData, UpdateAvailabilityData, VidCommonQueryData, }, - data_source::{Transaction, VersionedDataSource}, + data_source::{storage::NodeStorage, Transaction, VersionedDataSource}, fetching::provider::NoFetching, - node::NodeDataSource, testing::{consensus::DataSourceLifeCycle, mocks::MockTypes, setup_test}, }; use hotshot_example_types::state_types::{TestInstanceState, TestValidatedState}; diff --git a/src/data_source/storage.rs b/src/data_source/storage.rs index b556ca7ba..7029adb71 100644 --- a/src/data_source/storage.rs +++ b/src/data_source/storage.rs @@ -29,7 +29,6 @@ use crate::{ BlockId, BlockQueryData, LeafId, LeafQueryData, PayloadQueryData, QueryableHeader, QueryablePayload, TransactionHash, TransactionQueryData, VidCommonQueryData, }, - data_source::ReadOnly, explorer::{ query_data::{ BlockDetail, BlockIdentifier, BlockSummary, ExplorerSummary, GetBlockDetailError, @@ -40,10 +39,13 @@ use crate::{ }, traits::{ExplorerHeader, ExplorerTransaction}, }, - Header, Payload, QueryResult, Transaction, + merklized_state::{MerklizedState, Snapshot}, + node::{SyncStatus, TimeWindowQueryData, WindowStart}, + Header, Payload, QueryResult, Transaction, VidShare, }; use async_trait::async_trait; use hotshot_types::traits::node_implementation::NodeType; +use jf_merkle_tree::prelude::MerkleProof; use std::ops::RangeBounds; pub mod fs; @@ -79,116 +81,62 @@ where Types: NodeType, Payload: QueryablePayload, { - async fn get_leaf(&self, id: LeafId) -> QueryResult>; - async fn get_block(&self, id: BlockId) -> QueryResult>; - async fn get_header(&self, id: BlockId) -> QueryResult>; - async fn get_payload(&self, id: BlockId) -> QueryResult>; - async fn get_vid_common(&self, id: BlockId) -> QueryResult>; + async fn get_leaf(&mut self, id: LeafId) -> QueryResult>; + async fn get_block(&mut self, id: BlockId) -> QueryResult>; + async fn get_header(&mut self, id: BlockId) -> QueryResult>; + async fn get_payload(&mut self, id: BlockId) -> QueryResult>; + async fn get_vid_common( + &mut self, + id: BlockId, + ) -> QueryResult>; async fn get_leaf_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where R: RangeBounds + Send + 'static; async fn get_block_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where R: RangeBounds + Send + 'static; async fn get_payload_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where R: RangeBounds + Send + 'static; async fn get_vid_common_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where R: RangeBounds + Send + 'static; async fn get_transaction( - &self, + &mut self, hash: TransactionHash, ) -> QueryResult>; } #[async_trait] -impl AvailabilityStorage for ReadOnly -where - Types: NodeType, - Payload: QueryablePayload, - T: AvailabilityStorage, -{ - async fn get_leaf(&self, id: LeafId) -> QueryResult> { - (**self).get_leaf(id).await - } - - async fn get_block(&self, id: BlockId) -> QueryResult> { - (**self).get_block(id).await - } - - async fn get_header(&self, id: BlockId) -> QueryResult> { - (**self).get_header(id).await - } - - async fn get_payload(&self, id: BlockId) -> QueryResult> { - (**self).get_payload(id).await - } - - async fn get_vid_common(&self, id: BlockId) -> QueryResult> { - (**self).get_vid_common(id).await - } - - async fn get_leaf_range( - &self, - range: R, - ) -> QueryResult>>> - where - R: RangeBounds + Send + 'static, - { - (**self).get_leaf_range(range).await - } - - async fn get_block_range( - &self, - range: R, - ) -> QueryResult>>> 
+pub trait NodeStorage { + async fn block_height(&mut self) -> QueryResult; + async fn count_transactions(&mut self) -> QueryResult; + async fn payload_size(&mut self) -> QueryResult; + async fn vid_share(&mut self, id: ID) -> QueryResult where - R: RangeBounds + Send + 'static, - { - (**self).get_block_range(range).await - } - - async fn get_payload_range( - &self, - range: R, - ) -> QueryResult>>> - where - R: RangeBounds + Send + 'static, - { - (**self).get_payload_range(range).await - } - - async fn get_vid_common_range( - &self, - range: R, - ) -> QueryResult>>> - where - R: RangeBounds + Send + 'static, - { - (**self).get_vid_common_range(range).await - } - - async fn get_transaction( - &self, - hash: TransactionHash, - ) -> QueryResult> { - (**self).get_transaction(hash).await - } + ID: Into> + Send + Sync; + async fn get_header_window( + &mut self, + start: impl Into> + Send + Sync, + end: u64, + ) -> QueryResult>>; + + /// Search the database for missing objects and generate a report. + async fn sync_status(&mut self) -> QueryResult; } /// An interface for querying Data and Statistics from the HotShot Blockchain. @@ -212,7 +160,7 @@ where /// block from the blockchain. The block is identified by the given /// [BlockIdentifier]. async fn get_block_detail( - &self, + &mut self, request: BlockIdentifier, ) -> Result, GetBlockDetailError>; @@ -220,7 +168,7 @@ where /// summaries from the blockchain. The list is generated from the given /// [GetBlockSummariesRequest]. async fn get_block_summaries( - &self, + &mut self, request: GetBlockSummariesRequest, ) -> Result>, GetBlockSummariesError>; @@ -228,7 +176,7 @@ where /// specific transaction from the blockchain. The transaction is identified /// by the given [TransactionIdentifier]. async fn get_transaction_detail( - &self, + &mut self, request: TransactionIdentifier, ) -> Result, GetTransactionDetailError>; @@ -236,72 +184,42 @@ where /// transaction summaries from the blockchain. The list is generated from /// the given [GetTransactionSummariesRequest]. async fn get_transaction_summaries( - &self, + &mut self, request: GetTransactionSummariesRequest, ) -> Result>, GetTransactionSummariesError>; /// `get_explorer_summary` is a method that retrieves a summary overview of /// the blockchain. This is useful for displaying information that /// indicates the overall status of the blockchain. - async fn get_explorer_summary(&self) - -> Result, GetExplorerSummaryError>; + async fn get_explorer_summary( + &mut self, + ) -> Result, GetExplorerSummaryError>; /// `get_search_results` is a method that retrieves the results of a search /// query against the blockchain. The results are generated from the given /// query string. async fn get_search_results( - &self, + &mut self, query: String, ) -> Result, GetSearchResultsError>; } +/// This trait defines methods that a data source should implement. +/// It enables retrieval of the membership path for a leaf node, which can be used to reconstruct the Merkle tree state.
#[async_trait] -impl ExplorerStorage for ReadOnly +pub trait MerklizedStateStorage where Types: NodeType, - Header: ExplorerHeader + QueryableHeader, - Transaction: ExplorerTransaction, - Payload: QueryablePayload, - T: ExplorerStorage + Sync, + State: MerklizedState, { - async fn get_block_detail( - &self, - request: BlockIdentifier, - ) -> Result, GetBlockDetailError> { - (**self).get_block_detail(request).await - } - - async fn get_block_summaries( - &self, - request: GetBlockSummariesRequest, - ) -> Result>, GetBlockSummariesError> { - (**self).get_block_summaries(request).await - } - - async fn get_transaction_detail( - &self, - request: TransactionIdentifier, - ) -> Result, GetTransactionDetailError> { - (**self).get_transaction_detail(request).await - } - - async fn get_transaction_summaries( - &self, - request: GetTransactionSummariesRequest, - ) -> Result>, GetTransactionSummariesError> { - (**self).get_transaction_summaries(request).await - } - - async fn get_explorer_summary( - &self, - ) -> Result, GetExplorerSummaryError> { - (**self).get_explorer_summary().await - } + async fn get_path( + &mut self, + snapshot: Snapshot, + key: State::Key, + ) -> QueryResult>; +} - async fn get_search_results( - &self, - query: String, - ) -> Result, GetSearchResultsError> { - (**self).get_search_results(query).await - } +#[async_trait] +pub trait MerklizedStateHeightStorage { + async fn get_last_state_height(&mut self) -> QueryResult; } diff --git a/src/data_source/storage/fs.rs b/src/data_source/storage/fs.rs index 2c2a65bf9..f07cbb4be 100644 --- a/src/data_source/storage/fs.rs +++ b/src/data_source/storage/fs.rs @@ -15,7 +15,7 @@ use super::{ ledger_log::{Iter, LedgerLog}, pruning::{PruneStorage, PrunedHeightStorage, PrunerConfig}, - AvailabilityStorage, + AvailabilityStorage, NodeStorage, }; use crate::{ @@ -27,7 +27,7 @@ use crate::{ }, }, data_source::{update, VersionedDataSource}, - node::{NodeDataSource, SyncStatus, TimeWindowQueryData, WindowStart}, + node::{SyncStatus, TimeWindowQueryData, WindowStart}, types::HeightIndexed, ErrorSnafu, Header, MissingSnafu, NotFoundSnafu, Payload, QueryResult, VidCommitment, VidShare, }; @@ -419,7 +419,7 @@ where Header: QueryableHeader, T: Revert + Deref> + Send + Sync, { - async fn get_leaf(&self, id: LeafId) -> QueryResult> { + async fn get_leaf(&mut self, id: LeafId) -> QueryResult> { let n = match id { LeafId::Number(n) => n, LeafId::Hash(h) => *self @@ -436,19 +436,22 @@ where .context(MissingSnafu) } - async fn get_block(&self, id: BlockId) -> QueryResult> { + async fn get_block(&mut self, id: BlockId) -> QueryResult> { self.inner.get_block(id) } - async fn get_header(&self, id: BlockId) -> QueryResult> { + async fn get_header(&mut self, id: BlockId) -> QueryResult> { self.inner.get_header(id) } - async fn get_payload(&self, id: BlockId) -> QueryResult> { + async fn get_payload(&mut self, id: BlockId) -> QueryResult> { self.get_block(id).await.map(PayloadQueryData::from) } - async fn get_vid_common(&self, id: BlockId) -> QueryResult> { + async fn get_vid_common( + &mut self, + id: BlockId, + ) -> QueryResult> { Ok(self .inner .vid_storage @@ -460,7 +463,7 @@ where } async fn get_leaf_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where @@ -470,7 +473,7 @@ where } async fn get_block_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where @@ -480,7 +483,7 @@ where } async fn get_payload_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where @@ -492,7 +495,7 @@ where } async fn get_vid_common_range( - &self, + 
&mut self, range: R, ) -> QueryResult>>> where @@ -504,7 +507,7 @@ where } async fn get_transaction( - &self, + &mut self, hash: TransactionHash, ) -> QueryResult> { let height = self @@ -605,26 +608,26 @@ fn update_index_by_hash(index: &mut HashMap, hash: H } #[async_trait] -impl NodeDataSource for Transaction +impl NodeStorage for Transaction where Types: NodeType, Payload: QueryablePayload, Header: QueryableHeader, - T: Revert + Deref> + Sync, + T: Revert + Deref> + Send, { - async fn block_height(&self) -> QueryResult { + async fn block_height(&mut self) -> QueryResult { Ok(self.inner.leaf_storage.iter().len()) } - async fn count_transactions(&self) -> QueryResult { + async fn count_transactions(&mut self) -> QueryResult { Ok(self.inner.num_transactions) } - async fn payload_size(&self) -> QueryResult { + async fn payload_size(&mut self) -> QueryResult { Ok(self.inner.payload_size) } - async fn vid_share(&self, id: ID) -> QueryResult + async fn vid_share(&mut self, id: ID) -> QueryResult where ID: Into> + Send + Sync, { @@ -638,7 +641,7 @@ where .context(MissingSnafu) } - async fn sync_status(&self) -> QueryResult { + async fn sync_status(&mut self) -> QueryResult { let height = self.inner.leaf_storage.iter().len(); // The number of missing VID common is just the number of completely missing VID @@ -662,7 +665,7 @@ where } async fn get_header_window( - &self, + &mut self, start: impl Into> + Send + Sync, end: u64, ) -> QueryResult>> { diff --git a/src/data_source/storage/no_storage.rs b/src/data_source/storage/no_storage.rs index 6ee6f74bf..e1dc7d5fd 100644 --- a/src/data_source/storage/no_storage.rs +++ b/src/data_source/storage/no_storage.rs @@ -12,17 +12,17 @@ #![cfg(feature = "no-storage")] -use super::AvailabilityStorage; +use super::{ + pruning::{PruneStorage, PrunedHeightStorage, PrunerConfig}, + AvailabilityStorage, NodeStorage, +}; use crate::{ availability::{ BlockId, BlockQueryData, LeafId, LeafQueryData, PayloadQueryData, QueryablePayload, TransactionHash, TransactionQueryData, UpdateAvailabilityData, VidCommonQueryData, }, - data_source::{ - storage::pruning::{PruneStorage, PrunedHeightStorage, PrunerConfig}, - update, VersionedDataSource, - }, - node::{NodeDataSource, SyncStatus, TimeWindowQueryData, WindowStart}, + data_source::{update, VersionedDataSource}, + node::{SyncStatus, TimeWindowQueryData, WindowStart}, types::HeightIndexed, Header, Payload, QueryError, QueryResult, VidShare, }; @@ -98,28 +98,31 @@ impl<'a, Types: NodeType> AvailabilityStorage for Transaction<'a> where Payload: QueryablePayload, { - async fn get_leaf(&self, _id: LeafId) -> QueryResult> { + async fn get_leaf(&mut self, _id: LeafId) -> QueryResult> { Err(QueryError::Missing) } - async fn get_block(&self, _id: BlockId) -> QueryResult> { + async fn get_block(&mut self, _id: BlockId) -> QueryResult> { Err(QueryError::Missing) } - async fn get_header(&self, _id: BlockId) -> QueryResult> { + async fn get_header(&mut self, _id: BlockId) -> QueryResult> { Err(QueryError::Missing) } - async fn get_payload(&self, _id: BlockId) -> QueryResult> { + async fn get_payload(&mut self, _id: BlockId) -> QueryResult> { Err(QueryError::Missing) } - async fn get_vid_common(&self, _id: BlockId) -> QueryResult> { + async fn get_vid_common( + &mut self, + _id: BlockId, + ) -> QueryResult> { Err(QueryError::Missing) } async fn get_leaf_range( - &self, + &mut self, _range: R, ) -> QueryResult>>> where @@ -129,7 +132,7 @@ where } async fn get_block_range( - &self, + &mut self, _range: R, ) -> QueryResult>>> where @@ -139,7 
+142,7 @@ where } async fn get_payload_range( - &self, + &mut self, _range: R, ) -> QueryResult>>> where @@ -149,7 +152,7 @@ where } async fn get_vid_common_range( - &self, + &mut self, _range: R, ) -> QueryResult>>> where @@ -159,7 +162,7 @@ where } async fn get_transaction( - &self, + &mut self, _hash: TransactionHash, ) -> QueryResult> { Err(QueryError::Missing) @@ -192,35 +195,35 @@ where } #[async_trait] -impl<'a, Types: NodeType> NodeDataSource for Transaction<'a> +impl<'a, Types: NodeType> NodeStorage for Transaction<'a> where Payload: QueryablePayload, { - async fn block_height(&self) -> QueryResult { + async fn block_height(&mut self) -> QueryResult { Ok(self.height as usize) } - async fn count_transactions(&self) -> QueryResult { + async fn count_transactions(&mut self) -> QueryResult { Err(QueryError::Missing) } - async fn payload_size(&self) -> QueryResult { + async fn payload_size(&mut self) -> QueryResult { Err(QueryError::Missing) } - async fn vid_share(&self, _id: ID) -> QueryResult + async fn vid_share(&mut self, _id: ID) -> QueryResult where ID: Into> + Send + Sync, { Err(QueryError::Missing) } - async fn sync_status(&self) -> QueryResult { + async fn sync_status(&mut self) -> QueryResult { Err(QueryError::Missing) } async fn get_header_window( - &self, + &mut self, _start: impl Into> + Send + Sync, _end: u64, ) -> QueryResult>> { @@ -577,56 +580,56 @@ pub mod testing { } #[async_trait] - impl<'a, T> NodeDataSource for Transaction<'a, T> + impl<'a, T> NodeStorage for Transaction<'a, T> where - T: NodeDataSource + Sync, + T: NodeStorage + Send, { - async fn block_height(&self) -> QueryResult { + async fn block_height(&mut self) -> QueryResult { match self { - Self::Sql(tx) => NodeDataSource::block_height(tx).await, - Self::NoStorage(tx) => NodeDataSource::block_height(tx).await, + Transaction::Sql(tx) => tx.block_height().await, + Transaction::NoStorage(tx) => tx.block_height().await, } } - async fn count_transactions(&self) -> QueryResult { + async fn count_transactions(&mut self) -> QueryResult { match self { - Self::Sql(tx) => tx.count_transactions().await, - Self::NoStorage(tx) => tx.count_transactions().await, + Transaction::Sql(tx) => tx.count_transactions().await, + Transaction::NoStorage(tx) => tx.count_transactions().await, } } - async fn payload_size(&self) -> QueryResult { + async fn payload_size(&mut self) -> QueryResult { match self { - Self::Sql(tx) => tx.payload_size().await, - Self::NoStorage(tx) => tx.payload_size().await, + Transaction::Sql(tx) => tx.payload_size().await, + Transaction::NoStorage(tx) => tx.payload_size().await, } } - async fn vid_share(&self, id: ID) -> QueryResult + async fn vid_share(&mut self, id: ID) -> QueryResult where ID: Into> + Send + Sync, { match self { - Self::Sql(tx) => tx.vid_share(id).await, - Self::NoStorage(tx) => tx.vid_share(id).await, + Transaction::Sql(tx) => tx.vid_share(id).await, + Transaction::NoStorage(tx) => tx.vid_share(id).await, } } - async fn sync_status(&self) -> QueryResult { + async fn sync_status(&mut self) -> QueryResult { match self { - Self::Sql(tx) => tx.sync_status().await, - Self::NoStorage(tx) => tx.sync_status().await, + Transaction::Sql(tx) => tx.sync_status().await, + Transaction::NoStorage(tx) => tx.sync_status().await, } } async fn get_header_window( - &self, + &mut self, start: impl Into> + Send + Sync, end: u64, ) -> QueryResult>> { match self { - Self::Sql(tx) => tx.get_header_window(start, end).await, - Self::NoStorage(tx) => tx.get_header_window(start, end).await, + Transaction::Sql(tx) 
=> tx.get_header_window(start, end).await, + Transaction::NoStorage(tx) => tx.get_header_window(start, end).await, } } } @@ -652,22 +655,24 @@ pub mod testing { impl NodeDataSource for DataSource { async fn block_height(&self) -> QueryResult { match self { - Self::Sql(data_source) => NodeDataSource::block_height(data_source).await, - Self::NoStorage(data_source) => NodeDataSource::block_height(data_source).await, + DataSource::Sql(data_source) => NodeDataSource::block_height(data_source).await, + DataSource::NoStorage(data_source) => { + NodeDataSource::block_height(data_source).await + } } } async fn count_transactions(&self) -> QueryResult { match self { - Self::Sql(data_source) => data_source.count_transactions().await, - Self::NoStorage(data_source) => data_source.count_transactions().await, + DataSource::Sql(data_source) => data_source.count_transactions().await, + DataSource::NoStorage(data_source) => data_source.count_transactions().await, } } async fn payload_size(&self) -> QueryResult { match self { - Self::Sql(data_source) => data_source.payload_size().await, - Self::NoStorage(data_source) => data_source.payload_size().await, + DataSource::Sql(data_source) => data_source.payload_size().await, + DataSource::NoStorage(data_source) => data_source.payload_size().await, } } @@ -676,15 +681,15 @@ pub mod testing { ID: Into> + Send + Sync, { match self { - Self::Sql(data_source) => data_source.vid_share(id).await, - Self::NoStorage(data_source) => data_source.vid_share(id).await, + DataSource::Sql(data_source) => data_source.vid_share(id).await, + DataSource::NoStorage(data_source) => data_source.vid_share(id).await, } } async fn sync_status(&self) -> QueryResult { match self { - Self::Sql(data_source) => data_source.sync_status().await, - Self::NoStorage(data_source) => data_source.sync_status().await, + DataSource::Sql(data_source) => data_source.sync_status().await, + DataSource::NoStorage(data_source) => data_source.sync_status().await, } } @@ -694,8 +699,10 @@ pub mod testing { end: u64, ) -> QueryResult>> { match self { - Self::Sql(data_source) => data_source.get_header_window(start, end).await, - Self::NoStorage(data_source) => data_source.get_header_window(start, end).await, + DataSource::Sql(data_source) => data_source.get_header_window(start, end).await, + DataSource::NoStorage(data_source) => { + data_source.get_header_window(start, end).await + } } } } diff --git a/src/data_source/storage/pruning.rs b/src/data_source/storage/pruning.rs index 8d9ce72e4..21fb970dc 100644 --- a/src/data_source/storage/pruning.rs +++ b/src/data_source/storage/pruning.rs @@ -10,7 +10,6 @@ // You should have received a copy of the GNU General Public License along with this program. If not, // see . 
-use crate::data_source::ReadOnly; use anyhow::bail; use async_trait::async_trait; use std::{fmt::Debug, time::Duration}; @@ -39,22 +38,12 @@ pub trait PruneStorage: PrunerConfig { } #[async_trait] -pub trait PrunedHeightStorage { - async fn load_pruned_height(&self) -> anyhow::Result> { +pub trait PrunedHeightStorage: Sized { + async fn load_pruned_height(&mut self) -> anyhow::Result> { Ok(None) } } -#[async_trait] -impl PrunedHeightStorage for ReadOnly -where - T: PrunedHeightStorage + Sync, -{ - async fn load_pruned_height(&self) -> anyhow::Result> { - (**self).load_pruned_height().await - } -} - pub trait PrunerConfig { fn set_pruning_config(&mut self, _cfg: PrunerCfg) {} fn get_pruning_config(&self) -> Option { diff --git a/src/data_source/storage/sql.rs b/src/data_source/storage/sql.rs index eab89384f..15b87c7a7 100644 --- a/src/data_source/storage/sql.rs +++ b/src/data_source/storage/sql.rs @@ -15,38 +15,44 @@ use crate::{ data_source::{ storage::pruning::{PruneStorage, PrunerCfg, PrunerConfig}, - update::{ReadOnly, Transaction as _}, + update::Transaction as _, VersionedDataSource, }, - BackgroundTask, QueryError, QueryResult, -}; -use async_std::{ - net::ToSocketAddrs, - sync::Mutex, - task::{Context, Poll}, + QueryError, QueryResult, }; use async_trait::async_trait; use chrono::Utc; -use futures::{AsyncRead, AsyncWrite}; +use futures::future::FutureExt; use itertools::Itertools; -use postgres_native_tls::TlsConnector; -use std::{cmp::min, fmt::Debug, pin::Pin, str::FromStr}; -use tokio_postgres::{config::Host, tls::TlsConnect, Client, NoTls}; +use sqlx::{ + pool::{Pool, PoolOptions}, + postgres::{PgConnectOptions, PgSslMode}, + ConnectOptions, Connection, Row, +}; +use std::{cmp::min, fmt::Debug, str::FromStr}; -mod query; -mod transaction; +pub extern crate sqlx; +pub use sqlx::{Database, Postgres, Sqlite}; -use self::transaction::Connection; +mod db; +mod migrate; +mod queries; +mod transaction; pub use anyhow::Error; // This needs to be reexported so that we can reference it by absolute path relative to this crate // in the expansion of `include_migrations`, even when `include_migrations` is invoked from another // crate which doesn't have `include_dir` as a dependency. pub use crate::include_migrations; +pub use db::Db; pub use include_dir::include_dir; pub use refinery::Migration; -pub use tokio_postgres as postgres; -pub use transaction::Transaction; +pub use transaction::{query, query_as, Executor, Query, QueryAs, Transaction}; + +use self::{ + migrate::Migrator, + transaction::{Read, Write}, +}; /// Embed migrations from the given directory into the current binary. /// @@ -167,68 +173,58 @@ fn add_custom_migrations( /// Postgres client config. 
#[derive(Clone, Debug)] -pub struct Config { - pgcfg: postgres::Config, - host: String, - port: u16, +pub struct Config +where + DB: Database, +{ + db_opt: ::Options, + pool_opt: PoolOptions, schema: String, reset: bool, migrations: Vec, no_migrations: bool, - tls: bool, pruner_cfg: Option, archive: bool, } -impl Default for Config { +impl Default for Config { fn default() -> Self { + PgConnectOptions::default() + .host("localhost") + .port(5432) + .into() + } +} + +impl From for Config { + fn from(db_opt: PgConnectOptions) -> Self { Self { - pgcfg: Default::default(), - host: "localhost".into(), - port: 5432, + db_opt, + pool_opt: PoolOptions::default(), schema: "hotshot".into(), reset: false, migrations: vec![], no_migrations: false, - tls: false, pruner_cfg: None, archive: false, } } } -impl From for Config { - fn from(pgcfg: postgres::Config) -> Self { - // We connect via TCP manually, without using the host and port from pgcfg. So we need to - // pull those out of pgcfg if they have been specified, to override the defaults. - let host = match pgcfg.get_hosts().first() { - Some(Host::Tcp(host)) => host.to_string(), - _ => "localhost".into(), - }; - let port = *pgcfg.get_ports().first().unwrap_or(&5432); - Self { - pgcfg, - host, - port, - ..Default::default() - } - } -} - -impl FromStr for Config { - type Err = ::Err; +impl FromStr for Config { + type Err = ::Err; fn from_str(s: &str) -> Result { - Ok(postgres::Config::from_str(s)?.into()) + Ok(PgConnectOptions::from_str(s)?.into()) } } -impl Config { +impl Config { /// Set the hostname of the database server. /// /// The default is `localhost`. pub fn host(mut self, host: impl Into) -> Self { - self.host = host.into(); + self.db_opt = self.db_opt.host(&host.into()); self } @@ -236,28 +232,40 @@ impl Config { /// /// The default is 5432, the default Postgres port. pub fn port(mut self, port: u16) -> Self { - self.port = port; + self.db_opt = self.db_opt.port(port); self } /// Set the DB user to connect as. pub fn user(mut self, user: &str) -> Self { - self.pgcfg.user(user); + self.db_opt = self.db_opt.username(user); self } /// Set a password for connecting to the database. pub fn password(mut self, password: &str) -> Self { - self.pgcfg.password(password); + self.db_opt = self.db_opt.password(password); self } /// Set the name of the database to connect to. pub fn database(mut self, database: &str) -> Self { - self.pgcfg.dbname(database); + self.db_opt = self.db_opt.database(database); + self + } + + /// Use TLS for an encrypted connection to the database. + /// + /// Note that an encrypted connection may be established even if this option is not set, as long + /// as both the client and server support it. This option merely causes connection to fail if an + /// encrypted stream cannot be established. + pub fn tls(mut self) -> Self { + self.db_opt = self.db_opt.ssl_mode(PgSslMode::Require); self } +} +impl Config { /// Set the name of the schema to use for queries. /// /// The default schema is named `hotshot` and is created via the default migrations. @@ -293,12 +301,6 @@ impl Config { self } - /// Use TLS for an encrypted connection to the database. - pub fn tls(mut self) -> Self { - self.tls = true; - self - } - /// Enable pruning with a given configuration. /// /// If [`archive`](Self::archive) was previously specified, this will override it. @@ -327,14 +329,7 @@ impl Config { /// Storage for the APIs provided in this crate, backed by a remote PostgreSQL database. 
#[derive(Debug)] pub struct SqlStorage { - client: Mutex, - _client_task: BackgroundTask, - // We use a separate client for mutable access (database transactions). This allows runtime - // serialization of mutable operations while immutable operations can proceed in parallel, which - // mimics the lose concurrency semantics provided by Postgres. Eventually we will have a - // transaction pool which allows even multiple transactions to happen in parallel. - tx_client: Mutex, - _tx_client_task: BackgroundTask, + pool: Pool, pruner_cfg: Option, } @@ -347,29 +342,32 @@ pub struct Pruner { impl SqlStorage { /// Connect to a remote database. - pub async fn connect(mut config: Config) -> Result { - // Establish a TCP connection to the server. - let tcp = TcpStream::connect((config.host.as_str(), config.port)).await?; - - // Convert the TCP connection into a postgres connection. - let (mut client, client_task) = if config.tls { - let tls = TlsConnector::new(native_tls::TlsConnector::new()?, config.host.as_str()); - connect(config.pgcfg.clone(), tcp, tls).await? - } else { - connect(config.pgcfg.clone(), tcp, NoTls).await? - }; + pub async fn connect(mut config: Config) -> Result { + let schema = config.schema.clone(); + let pool = config + .pool_opt + .after_connect(move |conn, _| { + let schema = schema.clone(); + async move { + query(&format!("SET search_path TO {schema}")) + .execute(conn) + .await?; + Ok(()) + } + .boxed() + }) + .connect(config.db_opt.to_url_lossy().as_ref()) + .await?; // Create or connect to the schema for this query service. + let mut conn = pool.acquire().await?; if config.reset { - client - .batch_execute(&format!("DROP SCHEMA IF EXISTS {} CASCADE", config.schema)) + query(&format!("DROP SCHEMA IF EXISTS {} CASCADE", config.schema)) + .execute(conn.as_mut()) .await?; } - client - .batch_execute(&format!("CREATE SCHEMA IF NOT EXISTS {}", config.schema)) - .await?; - client - .batch_execute(&format!("SET search_path TO {}", config.schema)) + query(&format!("CREATE SCHEMA IF NOT EXISTS {}", config.schema)) + .execute(conn.as_mut()) .await?; // Get migrations and interleave with custom migrations, sorting by version number. @@ -384,7 +382,9 @@ impl SqlStorage { if config.no_migrations { // We've been asked not to run any migrations. Abort if the DB is not already up to // date. - let last_applied = runner.get_last_applied_migration_async(&mut client).await?; + let last_applied = runner + .get_last_applied_migration_async(&mut Migrator::from(&mut conn)) + .await?; let last_expected = migrations.last(); if last_applied.as_ref() != last_expected { return Err(Error::msg(format!( @@ -393,7 +393,7 @@ impl SqlStorage { } } else { // Run migrations using `refinery`. - match runner.run_async(&mut client).await { + match runner.run_async(&mut Migrator::from(&mut conn)).await { Ok(report) => { tracing::info!("ran DB migrations: {report:?}"); } @@ -407,28 +407,13 @@ impl SqlStorage { if config.archive { // If running in archive mode, ensure the pruned height is set to 0, so the fetcher will // reconstruct previously pruned data. - client - .batch_execute("DELETE FROM pruned_height WHERE id = 1") + query("DELETE FROM pruned_height WHERE id = 1") + .execute(conn.as_mut()) .await?; } - // Open a second connection for mutable transactions. - let tcp = TcpStream::connect((config.host.as_str(), config.port)).await?; - let (tx_client, tx_client_task) = if config.tls { - let tls = TlsConnector::new(native_tls::TlsConnector::new()?, config.host.as_str()); - connect(config.pgcfg, tcp, tls).await? 
- } else { - connect(config.pgcfg, tcp, NoTls).await? - }; - tx_client - .batch_execute(&format!("SET search_path TO {}", config.schema)) - .await?; - Ok(Self { - client: Mutex::new(client.into()), - _client_task: client_task, - tx_client: Mutex::new(tx_client.into()), - _tx_client_task: tx_client_task, + pool, pruner_cfg: config.pruner_cfg, }) } @@ -446,20 +431,21 @@ impl PrunerConfig for SqlStorage { impl SqlStorage { async fn get_minimum_height(&self) -> QueryResult> { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; - let row = tx - .query_one_static("SELECT MIN(height) as height FROM header") - .await?; - - let height = row.get::<_, Option>(0).map(|h| h as u64); - - Ok(height) + let (Some(height),) = + query_as::<(Option,)>("SELECT MIN(height) as height FROM header") + .fetch_one(tx.as_mut()) + .await? + else { + return Ok(None); + }; + Ok(Some(height as u64)) } async fn get_height_by_timestamp(&self, timestamp: i64) -> QueryResult> { - let tx = self.read().await.map_err(|err| QueryError::Error { + let mut tx = self.read().await.map_err(|err| QueryError::Error { message: err.to_string(), })?; @@ -469,19 +455,19 @@ impl SqlStorage { // based on the timestamp index. The remaining sort on height, which guarantees a unique // block if multiple blocks have the same timestamp, is very efficient, because there are // never more than a handful of blocks with the same timestamp. - let row = tx - .query_opt( - "SELECT height FROM header - WHERE timestamp <= $1 - ORDER BY timestamp DESC, height DESC - LIMIT 1", - [×tamp], - ) - .await?; - - let height = row.map(|row| row.get::<_, i64>(0) as u64); - - Ok(height) + let Some((height,)) = query_as::<(i64,)>( + "SELECT height FROM header + WHERE timestamp <= $1 + ORDER BY timestamp DESC, height DESC + LIMIT 1", + ) + .bind(timestamp) + .fetch_optional(tx.as_mut()) + .await? + else { + return Ok(None); + }; + Ok(Some(height as u64)) } } @@ -490,9 +476,9 @@ impl PruneStorage for SqlStorage { type Pruner = Pruner; async fn get_disk_usage(&self) -> anyhow::Result { - let tx = self.read().await?; + let mut tx = self.read().await?; let row = tx - .query_one_static("SELECT pg_database_size(current_database())") + .fetch_one("SELECT pg_database_size(current_database())") .await?; let size: i64 = row.get(0); Ok(size as u64) @@ -599,108 +585,19 @@ impl PruneStorage for SqlStorage { } impl VersionedDataSource for SqlStorage { - type Transaction<'a> = Transaction<'a> + type Transaction<'a> = Transaction where Self: 'a; - type ReadOnly<'a> = ReadOnly> + type ReadOnly<'a> = Transaction where Self: 'a; - async fn write(&self) -> anyhow::Result> { - let tx = self.tx_client.lock().await; - Transaction::write(tx).await + async fn write(&self) -> anyhow::Result> { + Transaction::new(&self.pool).await } - async fn read(&self) -> anyhow::Result>> { - let tx = self.client.lock().await; - Transaction::read(tx).await - } -} - -/// Connect to a Postgres database with a TLS implementation. -/// -/// Spawns a background task to run the connection. Returns a client and a handle to the spawned -/// task. 
-async fn connect( - pgcfg: postgres::Config, - tcp: TcpStream, - tls: T, -) -> anyhow::Result<(Client, BackgroundTask)> -where - T: TlsConnect, - T::Stream: Send + 'static, -{ - let (client, connection) = pgcfg.connect_raw(tcp, tls).await?; - Ok(( - client, - BackgroundTask::spawn("postgres connection", connection), - )) -} - -// tokio-postgres is written in terms of the tokio AsyncRead/AsyncWrite traits. However, these -// traits do not require any specifics of the tokio runtime. Thus we can implement them using the -// async_std TcpStream type, and have a stream which is compatible with tokio-postgres but will run -// on the async_std executor. -// -// To avoid orphan impls, we wrap this tream in a new type. -struct TcpStream(async_std::net::TcpStream); - -impl TcpStream { - async fn connect(addrs: A) -> Result { - Ok(Self(async_std::net::TcpStream::connect(addrs).await?)) - } -} - -impl tokio::io::AsyncRead for TcpStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - // tokio uses this hyper-optimized `ReadBuf` construct, where there is a filled portion, an - // unfilled portion where we append new data, and the unfilled portion of the buffer need - // not even be initialized. However the async_std implementation we're delegating to just - // expects a normal `&mut [u8]` buffer which is entirely unfilled. To simplify the - // conversion, we will abandon the uninitialized buffer optimization and force - // initialization of the entire buffer, resulting in a plain old `&mut [u8]` representing - // the unfilled portion. But first, we need to grab the length of the filled region so we - // can increment it after we read new data from async_std. - let filled = buf.filled().len(); - - // Initialize the buffer and get a slice of the unfilled region. This operation is free - // after the first time it is called, so we don't need to worry about maintaining state - // between subsequent calls to `poll_read`. - let unfilled = buf.initialize_unfilled(); - - // Read data into the unfilled portion of the buffer. - match Pin::new(&mut self.0).poll_read(cx, unfilled) { - Poll::Ready(Ok(bytes_read)) => { - // After the read completes, the first `bytes_read` of `unfilled` have now been - // filled. Increment the `filled` cursor within the `ReadBuf` to account for this. 
- buf.set_filled(filled + bytes_read); - Poll::Ready(Ok(())) - } - Poll::Ready(Err(err)) => Poll::Ready(Err(err)), - Poll::Pending => Poll::Pending, - } - } -} - -impl tokio::io::AsyncWrite for TcpStream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_close(cx) + async fn read(&self) -> anyhow::Result> { + Transaction::new(&self.pool).await } } @@ -719,7 +616,7 @@ pub mod testing { use portpicker::pick_unused_port; use refinery::Migration; - use super::Config; + use super::{Config, Postgres}; use crate::testing::sleep; #[derive(Debug)] @@ -786,13 +683,12 @@ pub mod testing { self.port } - pub fn config(&self) -> Config { + pub fn config(&self) -> Config { Config::default() .user("postgres") .password("password") .host(self.host()) .port(self.port()) - .tls() .migrations(vec![Migration::unapplied( "V11__create_test_merkle_tree_table.sql", &TestMerkleTreeMigration::create("test_tree"), @@ -1019,21 +915,20 @@ mod test { #[test] fn test_config_from_str() { let cfg = Config::from_str("postgresql://user:password@host:8080").unwrap(); - assert_eq!(cfg.pgcfg.get_user(), Some("user")); - assert_eq!(cfg.pgcfg.get_password(), Some("password".as_bytes())); - assert_eq!(cfg.host, "host"); - assert_eq!(cfg.port, 8080); + assert_eq!(cfg.db_opt.get_username(), "user"); + assert_eq!(cfg.db_opt.get_host(), "host"); + assert_eq!(cfg.db_opt.get_port(), 8080); } - #[test] - fn test_config_from_pgcfg() { - let mut pgcfg = postgres::Config::default(); - pgcfg.dbname("db"); - let cfg = Config::from(pgcfg.clone()); - assert_eq!(cfg.pgcfg, pgcfg); - // Default values. - assert_eq!(cfg.host, "localhost"); - assert_eq!(cfg.port, 5432); + async fn vacuum(storage: &SqlStorage) { + storage + .pool + .acquire() + .await + .unwrap() + .execute("VACUUM") + .await + .unwrap(); } #[async_std::test] @@ -1075,7 +970,7 @@ mod test { // Vacuum the database to reclaim space. // This is necessary to ensure the test passes. // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically. - storage.client.lock().await.vacuum().await.unwrap(); + vacuum(&storage).await; // Pruned height should be none assert!(pruned_height.is_none()); @@ -1096,7 +991,7 @@ mod test { // Vacuum the database to reclaim space. // This is necessary to ensure the test passes. // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically. 
-        storage.client.lock().await.vacuum().await.unwrap();
+        vacuum(&storage).await;
 
         // Pruned height should be some
         assert!(pruned_height.is_some());
@@ -1107,10 +1002,10 @@ mod test {
             .read()
             .await
             .unwrap()
-            .query_one_static("select count(*) as count from header")
+            .fetch_one("select count(*) as count from header")
             .await
             .unwrap()
-            .get::<_, i64>("count");
+            .get::<i64, _>("count");
 
         // the table should be empty
         assert_eq!(header_rows, 0);
@@ -1121,10 +1016,10 @@ mod test {
             .read()
             .await
             .unwrap()
-            .query_one_static("select count(*) as count from leaf")
+            .fetch_one("select count(*) as count from leaf")
             .await
             .unwrap()
-            .get::<_, i64>("count");
+            .get::<i64, _>("count");
 
         // the table should be empty
         assert_eq!(leaf_rows, 0);
@@ -1169,7 +1064,7 @@ mod test {
         // Vacuum the database to reclaim space.
         // This is necessary to ensure the test passes.
         // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
-        storage.client.lock().await.vacuum().await.unwrap();
+        vacuum(&storage).await;
 
         // Pruned height should be none
         assert!(pruned_height.is_none());
@@ -1193,7 +1088,7 @@ mod test {
         // Vacuum the database to reclaim space.
         // This is necessary to ensure the test passes.
         // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
-        storage.client.lock().await.vacuum().await.unwrap();
+        vacuum(&storage).await;
 
         // Pruned height should be some
         assert!(pruned_height.is_some());
@@ -1203,10 +1098,10 @@ mod test {
             .read()
             .await
             .unwrap()
-            .query_one_static("select count(*) as count from header")
+            .fetch_one("select count(*) as count from header")
             .await
             .unwrap()
-            .get::<_, i64>("count");
+            .get::<i64, _>("count");
         // the table should be empty
         assert_eq!(header_rows, 0);
     }
diff --git a/src/data_source/storage/sql/db.rs b/src/data_source/storage/sql/db.rs
new file mode 100644
index 000000000..cdd3a90ea
--- /dev/null
+++ b/src/data_source/storage/sql/db.rs
@@ -0,0 +1,25 @@
+/// The concrete database backing a SQL data source.
+///
+/// Currently only Postgres is supported. In the future we can support SQLite as well by making this
+/// an enum with variants for each (we'll then need to create enums and trait implementations for
+/// all the associated types as well; it will be messy).
+///
+/// The reason for taking this approach over sqlx's `Any` database is that we can support SQL types
+/// which are implemented for the two backends we care about (Postgres and SQLite) but not for _any_
+/// SQL database, such as MySQL. Crucially, JSON types fall in this category.
+///
+/// The reason for taking this approach rather than writing all of our code to be generic over the
+/// `Database` implementation is that `sqlx` does not have the necessary trait bounds on all of the
+/// associated types (e.g. `Database::Connection` does not implement `Executor` for all possible
+/// databases, the `Executor` impl lives on each concrete connection type) and Rust does not provide
+/// a good way of encapsulating a collection of trait bounds on associated types. Thus, our function
+/// signatures become untenably messy with bounds like
+///
+/// ```text
+/// where
+///     for<'a> &'a mut DB::Connection: Executor<'a>,
+///     for<'q> DB::Arguments<'q>: IntoArguments<'q, DB>,
+///     for<'a> i64: Type<DB> + Encode<'a, DB>,
+/// ```
+/// etc.
+pub type Db = sqlx::Postgres;
diff --git a/src/data_source/storage/sql/migrate.rs b/src/data_source/storage/sql/migrate.rs
new file mode 100644
index 000000000..694a68608
--- /dev/null
+++ b/src/data_source/storage/sql/migrate.rs
@@ -0,0 +1,71 @@
+use super::{queries::DecodeError, Db};
+use async_trait::async_trait;
+use derive_more::From;
+use futures::stream::StreamExt;
+use refinery_core::{
+    traits::r#async::{AsyncMigrate, AsyncQuery, AsyncTransaction},
+    Migration,
+};
+use sqlx::{pool::PoolConnection, Acquire, Executor, Row};
+use time::{format_description::well_known::Rfc3339, OffsetDateTime};
+
+/// Run migrations using a sqlx connection.
+///
+/// While SQLx has its own built-in migration functionality, we use Refinery, and alas we must
+/// support existing deployed databases which are already using Refinery to handle migrations.
+/// Rather than implement a tricky "migration of the migrations table", or supporting separate
+/// migrations interfaces for databases deployed before and after the switch to SQLx, we continue
+/// using Refinery. This wrapper implements the Refinery traits for SQLx types.
+#[derive(Debug, From)]
+pub(super) struct Migrator<'a> {
+    conn: &'a mut PoolConnection<Db>,
+}
+
+#[async_trait]
+impl<'a> AsyncTransaction for Migrator<'a> {
+    type Error = sqlx::Error;
+
+    async fn execute(&mut self, queries: &[&str]) -> sqlx::Result<usize> {
+        let mut tx = self.conn.begin().await?;
+        let mut count = 0;
+        for query in queries {
+            let res = tx.execute(*query).await?;
+            count += res.rows_affected();
+        }
+        tx.commit().await?;
+        Ok(count as usize)
+    }
+}
+
+#[async_trait]
+impl<'a> AsyncQuery<Vec<Migration>> for Migrator<'a> {
+    async fn query(&mut self, query: &str) -> sqlx::Result<Vec<Migration>> {
+        let mut tx = self.conn.begin().await?;
+
+        let mut applied = Vec::new();
+        let mut rows = tx.fetch(query);
+        while let Some(row) = rows.next().await {
+            let row = row?;
+            let version = row.try_get(0)?;
+            let applied_on: String = row.try_get(2)?;
+            let applied_on = OffsetDateTime::parse(&applied_on, &Rfc3339)
+                .decode_error("malformed migration timestamp")?;
+            let checksum: String = row.get(3);
+
+            applied.push(Migration::applied(
+                version,
+                row.try_get(1)?,
+                applied_on,
+                checksum
+                    .parse::<u64>()
+                    .decode_error("malformed migration checksum")?,
+            ));
+        }
+
+        drop(rows);
+        tx.commit().await?;
+        Ok(applied)
+    }
+}
+
+impl<'a> AsyncMigrate for Migrator<'a> {}
diff --git a/src/data_source/storage/sql/queries.rs b/src/data_source/storage/sql/queries.rs
new file mode 100644
index 000000000..13cd9cf5b
--- /dev/null
+++ b/src/data_source/storage/sql/queries.rs
@@ -0,0 +1,318 @@
+// Copyright (c) 2022 Espresso Systems (espressosys.com)
+// This file is part of the HotShot Query Service library.
+//
+// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
+// General Public License as published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
+// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+// You should have received a copy of the GNU General Public License along with this program. If not,
+// see <https://www.gnu.org/licenses/>.
+
+//! Immutable query functionality of a SQL database.
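For orientation, here is a minimal sketch (not part of the patch) of how the `Migrator` adapter above is driven. It mirrors the calls made in `SqlStorage::connect`; the pool and the refinery `Runner` are assumed to be constructed elsewhere.

```rust
use refinery::Runner;
use sqlx::{Pool, Postgres};

// Run refinery migrations over a sqlx pool via the `Migrator` adapter.
async fn run_migrations(pool: &Pool<Postgres>, runner: Runner) -> anyhow::Result<()> {
    // Check out a pooled connection; `Migrator` borrows it mutably.
    let mut conn = pool.acquire().await?;
    // `Migrator` implements refinery's `AsyncMigrate`, so the stock async
    // runner can apply the embedded migrations through the sqlx connection.
    let report = runner.run_async(&mut Migrator::from(&mut conn)).await?;
    tracing::info!("ran DB migrations: {report:?}");
    Ok(())
}
```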
+
+use super::{transaction::TransactionMode, Database, Db, Query, QueryAs, Transaction};
+use crate::{
+    availability::{
+        BlockId, BlockQueryData, LeafQueryData, PayloadQueryData, QueryablePayload,
+        VidCommonQueryData,
+    },
+    Header, Leaf, Payload, QueryError, QueryResult,
+};
+use anyhow::Context;
+use derivative::Derivative;
+use hotshot_types::{
+    simple_certificate::QuorumCertificate,
+    traits::{
+        block_contents::{BlockHeader, BlockPayload},
+        node_implementation::NodeType,
+    },
+};
+use sqlx::{Arguments, FromRow, Row};
+use std::{
+    fmt::Display,
+    ops::{Bound, RangeBounds},
+};
+
+pub(super) mod availability;
+pub(super) mod explorer;
+pub(super) mod node;
+pub(super) mod state;
+
+/// Helper type for programmatically constructing queries.
+///
+/// This type can be used to bind arguments of various types, similar to [`Query`] or [`QueryAs`].
+/// With [`QueryBuilder`], though, the arguments are bound *first* and the SQL statement is given
+/// last. Each time an argument is bound, a SQL fragment is returned as a string which can be used
+/// to represent that argument in the statement (e.g. `$1` for the first argument bound). This makes
+/// it easier to programmatically construct queries where the statement is not a compile-time
+/// constant.
+///
+/// # Example
+///
+/// ```ignore
+/// async fn search_and_maybe_filter(
+///     tx: &mut Transaction<Read>,
+///     id: Option<i64>,
+/// ) -> QueryResult<Vec<(i64, String)>> {
+///     let mut query = QueryBuilder::default();
+///     let mut sql: String = "SELECT * FROM table".into();
+///     if let Some(id) = id {
+///         sql = format!("{sql} WHERE id = {}", query.bind(id)?);
+///     }
+///     let results = query
+///         .query_as(&sql)
+///         .fetch_all(tx.as_mut())
+///         .await?;
+///     Ok(results)
+/// }
+/// ```
+#[derive(Derivative, Default)]
+#[derivative(Debug)]
+pub struct QueryBuilder<'q> {
+    #[derivative(Debug = "ignore")]
+    arguments: <Db as Database>::Arguments<'q>,
+}
+
+impl<'q> QueryBuilder<'q> {
+    /// Add an argument and return its name as a formal parameter in a SQL prepared statement.
+    pub fn bind<T>(&mut self, arg: T) -> QueryResult<String>
+    where
+        T: 'q + sqlx::Encode<'q, Db> + sqlx::Type<Db>,
+    {
+        self.arguments.add(arg).map_err(|err| QueryError::Error {
+            message: format!("{err:#}"),
+        })?;
+        Ok(format!("${}", self.arguments.len()))
+    }
+
+    /// Finalize the query with a constructed SQL statement.
+    pub fn query(self, sql: &'q str) -> Query<'q> {
+        sqlx::query_with(sql, self.arguments)
+    }
+
+    /// Finalize the query with a constructed SQL statement and a specified output type.
+    pub fn query_as<T>(self, sql: &'q str) -> QueryAs<'q, T>
+    where
+        T: for<'r> FromRow<'r, <Db as Database>::Row>,
+    {
+        sqlx::query_as_with(sql, self.arguments)
+    }
+}
+
+impl<'q> QueryBuilder<'q> {
+    /// Construct a SQL `WHERE` clause which filters for a header exactly matching `id`.
+    pub fn header_where_clause<Types: NodeType>(
+        &mut self,
+        id: BlockId<Types>,
+    ) -> QueryResult<String> {
+        let clause = match id {
+            BlockId::Number(n) => format!("h.height = {}", self.bind(n as i64)?),
+            BlockId::Hash(h) => format!("h.hash = {}", self.bind(h.to_string())?),
+            BlockId::PayloadHash(h) => format!("h.payload_hash = {}", self.bind(h.to_string())?),
+        };
+        Ok(clause)
+    }
+
+    /// Convert range bounds to a SQL `WHERE` clause constraining a given column.
+    pub fn bounds_to_where_clause<R>(&mut self, range: R, column: &str) -> QueryResult<String>
+    where
+        R: RangeBounds<usize>,
+    {
+        let mut bounds = vec![];
+
+        match range.start_bound() {
+            Bound::Included(n) => {
+                bounds.push(format!("{column} >= {}", self.bind(*n as i64)?));
+            }
+            Bound::Excluded(n) => {
+                bounds.push(format!("{column} > {}", self.bind(*n as i64)?));
+            }
+            Bound::Unbounded => {}
+        }
+        match range.end_bound() {
+            Bound::Included(n) => {
+                bounds.push(format!("{column} <= {}", self.bind(*n as i64)?));
+            }
+            Bound::Excluded(n) => {
+                bounds.push(format!("{column} < {}", self.bind(*n as i64)?));
+            }
+            Bound::Unbounded => {}
+        }
+
+        let mut where_clause = bounds.join(" AND ");
+        if !where_clause.is_empty() {
+            where_clause = format!(" WHERE {where_clause}");
+        }
+
+        Ok(where_clause)
+    }
+}
+
+const LEAF_COLUMNS: &str = "leaf, qc";
+
+impl<'r, Types> FromRow<'r, <Db as Database>::Row> for LeafQueryData<Types>
+where
+    Types: NodeType,
+{
+    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
+        let leaf = row.try_get("leaf")?;
+        let leaf: Leaf<Types> = serde_json::from_value(leaf).decode_error("malformed leaf")?;
+
+        let qc = row.try_get("qc")?;
+        let qc: QuorumCertificate<Types> =
+            serde_json::from_value(qc).decode_error("malformed QC")?;
+
+        Ok(Self { leaf, qc })
+    }
+}
+
+const BLOCK_COLUMNS: &str =
+    "h.hash AS hash, h.data AS header_data, p.size AS payload_size, p.data AS payload_data";
+
+impl<'r, Types> FromRow<'r, <Db as Database>::Row> for BlockQueryData<Types>
+where
+    Types: NodeType,
+    Payload<Types>: QueryablePayload<Types>,
+{
+    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
+        // First, check if we have the payload for this block yet.
+        let size: Option<i32> = row.try_get("payload_size")?;
+        let payload_data: Option<Vec<u8>> = row.try_get("payload_data")?;
+        let (size, payload_data) = size.zip(payload_data).ok_or(sqlx::Error::RowNotFound)?;
+        let size = size as u64;
+
+        // Reconstruct the full header.
+        let header_data = row.try_get("header_data")?;
+        let header: Header<Types> =
+            serde_json::from_value(header_data).decode_error("malformed header")?;
+
+        // Reconstruct the full block payload.
+        let payload = Payload::<Types>::from_bytes(&payload_data, header.metadata());
+
+        // Reconstruct the query data by adding metadata.
+        let hash: String = row.try_get("hash")?;
+        let hash = hash.parse().decode_error("malformed block hash")?;
+
+        Ok(Self {
+            num_transactions: payload.len(header.metadata()) as u64,
+            header,
+            payload,
+            size,
+            hash,
+        })
+    }
+}
+
+const PAYLOAD_COLUMNS: &str = BLOCK_COLUMNS;
+
+impl<'r, Types> FromRow<'r, <Db as Database>::Row> for PayloadQueryData<Types>
+where
+    Types: NodeType,
+    Payload<Types>: QueryablePayload<Types>,
+{
+    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
+        <BlockQueryData<Types> as FromRow<'r, <Db as Database>::Row>>::from_row(row)
+            .map(Self::from)
+    }
+}
+
+const VID_COMMON_COLUMNS: &str = "h.height AS height, h.hash AS block_hash, h.payload_hash AS payload_hash, v.common AS common_data";
+
+impl<'r, Types> FromRow<'r, <Db as Database>::Row> for VidCommonQueryData<Types>
+where
+    Types: NodeType,
+    Payload<Types>: QueryablePayload<Types>,
+{
+    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
+        let height = row.try_get::<i64, _>("height")? as u64;
+        let block_hash: String = row.try_get("block_hash")?;
+        let block_hash = block_hash.parse().decode_error("malformed block hash")?;
+        let payload_hash: String = row.try_get("payload_hash")?;
+        let payload_hash = payload_hash
+            .parse()
+            .decode_error("malformed payload hash")?;
+        let common_data: Vec<u8> = row.try_get("common_data")?;
+        let common =
+            bincode::deserialize(&common_data).decode_error("malformed VID common data")?;
+        Ok(Self {
+            height,
+            block_hash,
+            payload_hash,
+            common,
+        })
+    }
+}
+
+const HEADER_COLUMNS: &str = "h.data AS data";
+
+// We can't implement `FromRow` for `Header<Types>` since `Header<Types>` is not actually a type
+// defined in this crate; it's just an alias for `Types::BlockHeader`. So this standalone function
+// will have to do.
+fn parse_header<Types>(row: <Db as Database>::Row) -> sqlx::Result<Header<Types>>
+where
+    Types: NodeType,
+{
+    // Reconstruct the full header.
+    let data = row.try_get("data")?;
+    serde_json::from_value(data).decode_error("malformed header")
+}
+
+impl From<sqlx::Error> for QueryError {
+    fn from(err: sqlx::Error) -> Self {
+        if matches!(err, sqlx::Error::RowNotFound) {
+            Self::NotFound
+        } else {
+            Self::Error {
+                message: err.to_string(),
+            }
+        }
+    }
+}
+
+impl<Mode: TransactionMode> Transaction<Mode> {
+    /// Load a header from storage.
+    ///
+    /// This function is similar to `AvailabilityStorage::get_header`, but
+    /// * does not require the `QueryablePayload` bound that that trait impl does
+    /// * makes it easier to specify types since the type parameter is on the function and not on a
+    ///   trait impl
+    /// * allows type conversions for the `id` parameter
+    ///
+    /// This more ergonomic interface is useful as loading headers is important for many SQL storage
+    /// functions, not just the `AvailabilityStorage` interface.
+    pub async fn load_header<Types: NodeType>(
+        &mut self,
+        id: impl Into<BlockId<Types>> + Send,
+    ) -> QueryResult<Header<Types>> {
+        let mut query = QueryBuilder::default();
+        let where_clause = query.header_where_clause(id.into())?;
+        // ORDER BY h.height ASC ensures that if there are duplicate blocks (this can happen when
+        // selecting by payload ID, as payloads are not unique), we return the first one.
+        let sql = format!(
+            "SELECT {HEADER_COLUMNS}
+               FROM header AS h
+              WHERE {where_clause}
+              ORDER BY h.height ASC
+              LIMIT 1"
+        );
+        let row = query.query(&sql).fetch_one(self.as_mut()).await?;
+        let header = parse_header::<Types>(row)?;
+        Ok(header)
+    }
+}
+
+pub(super) trait DecodeError {
+    type Ok;
+    fn decode_error(self, msg: impl Display) -> sqlx::Result<Self::Ok>;
+}
+
+impl<T, E> DecodeError for Result<T, E>
+where
+    E: std::error::Error + Send + Sync + 'static,
+{
+    type Ok = T;
+
+    fn decode_error(self, msg: impl Display) -> sqlx::Result<<Self as DecodeError>::Ok> {
+        self.context(msg.to_string())
+            .map_err(|err| sqlx::Error::Decode(err.into()))
+    }
+}
diff --git a/src/data_source/storage/sql/query/availability.rs b/src/data_source/storage/sql/queries/availability.rs
similarity index 50%
rename from src/data_source/storage/sql/query/availability.rs
rename to src/data_source/storage/sql/queries/availability.rs
index 5c1c97e28..53fd59622 100644
--- a/src/data_source/storage/sql/query/availability.rs
+++ b/src/data_source/storage/sql/queries/availability.rs
@@ -13,9 +13,8 @@
 //! Availability storage implementation for a database query engine.
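Before the trait impl itself, a hedged usage sketch of the reworked availability interface: storage methods now take `&mut self`, so callers hold a transaction mutably while a query runs, since each query needs exclusive use of the underlying sqlx connection. `MockTypes` is a stand-in for a concrete `NodeType`, not something introduced by this patch.

```rust
// Fetch the first ten leaves through `AvailabilityStorage`.
async fn first_ten_leaves(
    storage: &SqlStorage,
) -> anyhow::Result<Vec<QueryResult<LeafQueryData<MockTypes>>>> {
    // Open a read-only transaction; it must be mutable to execute queries.
    let mut tx = storage.read().await?;
    Ok(tx.get_leaf_range(0..10).await?)
}
```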
use super::{ - bounds_to_where_clause, header_where_clause, parse_block, parse_leaf, parse_payload, - parse_vid_common, postgres::types::ToSql, Transaction, BLOCK_COLUMNS, PAYLOAD_COLUMNS, - VID_COMMON_COLUMNS, + super::transaction::{Transaction, TransactionMode}, + QueryBuilder, BLOCK_COLUMNS, LEAF_COLUMNS, PAYLOAD_COLUMNS, VID_COMMON_COLUMNS, }; use crate::{ availability::{ @@ -24,36 +23,45 @@ use crate::{ }, data_source::storage::AvailabilityStorage, types::HeightIndexed, - ErrorSnafu, Header, Payload, QueryResult, + ErrorSnafu, Header, Payload, QueryError, QueryResult, }; use async_trait::async_trait; -use futures::stream::StreamExt; +use futures::stream::{StreamExt, TryStreamExt}; use hotshot_types::traits::node_implementation::NodeType; use snafu::OptionExt; +use sqlx::FromRow; use std::ops::RangeBounds; #[async_trait] -impl<'a, Types> AvailabilityStorage for Transaction<'a> +impl AvailabilityStorage for Transaction where Types: NodeType, + Mode: TransactionMode, Payload: QueryablePayload, Header: QueryableHeader, { - async fn get_leaf(&self, id: LeafId) -> QueryResult> { - let (where_clause, param): (&str, Box) = match id { - LeafId::Number(n) => ("height = $1", Box::new(n as i64)), - LeafId::Hash(h) => ("hash = $1", Box::new(h.to_string())), + async fn get_leaf(&mut self, id: LeafId) -> QueryResult> { + let mut query = QueryBuilder::default(); + let where_clause = match id { + LeafId::Number(n) => format!("height = {}", query.bind(n as i64)?), + LeafId::Hash(h) => format!("hash = {}", query.bind(h.to_string())?), }; - let query = format!("SELECT leaf, qc FROM leaf WHERE {where_clause}"); - let row = self.query_one(&query, [param]).await?; - parse_leaf(row) + let row = query + .query(&format!( + "SELECT {LEAF_COLUMNS} FROM leaf WHERE {where_clause}" + )) + .fetch_one(self.as_mut()) + .await?; + let leaf = LeafQueryData::from_row(&row)?; + Ok(leaf) } - async fn get_block(&self, id: BlockId) -> QueryResult> { - let (where_clause, param) = header_where_clause(id); + async fn get_block(&mut self, id: BlockId) -> QueryResult> { + let mut query = QueryBuilder::default(); + let where_clause = query.header_where_clause(id)?; // ORDER BY h.height ASC ensures that if there are duplicate blocks (this can happen when // selecting by payload ID, as payloads are not unique), we return the first one. - let query = format!( + let sql = format!( "SELECT {BLOCK_COLUMNS} FROM header AS h JOIN payload AS p ON h.height = p.height @@ -61,139 +69,170 @@ where ORDER BY h.height ASC LIMIT 1" ); - let row = self.query_one(&query, [param]).await?; - parse_block(row) + let row = query.query(&sql).fetch_one(self.as_mut()).await?; + let block = BlockQueryData::from_row(&row)?; + Ok(block) } - async fn get_header(&self, id: BlockId) -> QueryResult> { + async fn get_header(&mut self, id: BlockId) -> QueryResult> { self.load_header(id).await } - async fn get_payload(&self, id: BlockId) -> QueryResult> { - let (where_clause, param) = header_where_clause(id); + async fn get_payload(&mut self, id: BlockId) -> QueryResult> { + let mut query = QueryBuilder::default(); + let where_clause = query.header_where_clause(id)?; // ORDER BY h.height ASC ensures that if there are duplicate blocks (this can happen when // selecting by payload ID, as payloads are not unique), we return the first one. 
- let query = format!( + let sql = format!( "SELECT {PAYLOAD_COLUMNS} - FROM header AS h - JOIN payload AS p ON h.height = p.height - WHERE {where_clause} - ORDER BY h.height ASC - LIMIT 1" + FROM header AS h + JOIN payload AS p ON h.height = p.height + WHERE {where_clause} + ORDER BY h.height ASC + LIMIT 1" ); - let row = self.query_one(&query, [param]).await?; - parse_payload(row) + let row = query.query(&sql).fetch_one(self.as_mut()).await?; + let payload = PayloadQueryData::from_row(&row)?; + Ok(payload) } - async fn get_vid_common(&self, id: BlockId) -> QueryResult> { - let (where_clause, param) = header_where_clause(id); + async fn get_vid_common( + &mut self, + id: BlockId, + ) -> QueryResult> { + let mut query = QueryBuilder::default(); + let where_clause = query.header_where_clause(id)?; // ORDER BY h.height ASC ensures that if there are duplicate blocks (this can happen when // selecting by payload ID, as payloads are not unique), we return the first one. - let query = format!( + let sql = format!( "SELECT {VID_COMMON_COLUMNS} - FROM header AS h - JOIN vid AS v ON h.height = v.height - WHERE {where_clause} - ORDER BY h.height ASC - LIMIT 1" + FROM header AS h + JOIN vid AS v ON h.height = v.height + WHERE {where_clause} + ORDER BY h.height ASC + LIMIT 1" ); - let row = self.query_one(&query, [param]).await?; - parse_vid_common(row) + let row = query.query(&sql).fetch_one(self.as_mut()).await?; + let common = VidCommonQueryData::from_row(&row)?; + Ok(common) } async fn get_leaf_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where R: RangeBounds + Send, { - let (where_clause, params) = bounds_to_where_clause(range, "height"); - let query = format!("SELECT leaf, qc FROM leaf {where_clause} ORDER BY height ASC"); - let rows = self.query(&query, params).await?; - - Ok(rows.map(|res| parse_leaf(res?)).collect().await) + let mut query = QueryBuilder::default(); + let where_clause = query.bounds_to_where_clause(range, "height")?; + let sql = format!("SELECT {LEAF_COLUMNS} FROM leaf {where_clause} ORDER BY height ASC"); + Ok(query + .query(&sql) + .fetch(self.as_mut()) + .map(|res| LeafQueryData::from_row(&res?)) + .map_err(QueryError::from) + .collect() + .await) } async fn get_block_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where R: RangeBounds + Send, { - let (where_clause, params) = bounds_to_where_clause(range, "h.height"); - let query = format!( + let mut query = QueryBuilder::default(); + let where_clause = query.bounds_to_where_clause(range, "h.height")?; + let sql = format!( "SELECT {BLOCK_COLUMNS} FROM header AS h JOIN payload AS p ON h.height = p.height {where_clause} ORDER BY h.height ASC" ); - let rows = self.query(&query, params).await?; - - Ok(rows.map(|res| parse_block(res?)).collect().await) + Ok(query + .query(&sql) + .fetch(self.as_mut()) + .map(|res| BlockQueryData::from_row(&res?)) + .map_err(QueryError::from) + .collect() + .await) } async fn get_payload_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where R: RangeBounds + Send, { - let (where_clause, params) = bounds_to_where_clause(range, "h.height"); - let query = format!( + let mut query = QueryBuilder::default(); + let where_clause = query.bounds_to_where_clause(range, "h.height")?; + let sql = format!( "SELECT {PAYLOAD_COLUMNS} FROM header AS h JOIN payload AS p ON h.height = p.height {where_clause} ORDER BY h.height ASC" ); - let rows = self.query(&query, params).await?; - - Ok(rows.map(|res| parse_payload(res?)).collect().await) + Ok(query + .query(&sql) + 
.fetch(self.as_mut()) + .map(|res| PayloadQueryData::from_row(&res?)) + .map_err(QueryError::from) + .collect() + .await) } async fn get_vid_common_range( - &self, + &mut self, range: R, ) -> QueryResult>>> where R: RangeBounds + Send, { - let (where_clause, params) = bounds_to_where_clause(range, "h.height"); - let query = format!( + let mut query = QueryBuilder::default(); + let where_clause = query.bounds_to_where_clause(range, "h.height")?; + let sql = format!( "SELECT {VID_COMMON_COLUMNS} FROM header AS h JOIN vid AS v ON h.height = v.height {where_clause} ORDER BY h.height ASC" ); - let rows = self.query(&query, params).await?; - - Ok(rows.map(|res| parse_vid_common(res?)).collect().await) + Ok(query + .query(&sql) + .fetch(self.as_mut()) + .map(|res| VidCommonQueryData::from_row(&res?)) + .map_err(QueryError::from) + .collect() + .await) } async fn get_transaction( - &self, + &mut self, hash: TransactionHash, ) -> QueryResult> { + let mut query = QueryBuilder::default(); + let hash_param = query.bind(hash.to_string())?; + // ORDER BY ASC ensures that if there are duplicate transactions, we return the first // one. - let query = format!( + let sql = format!( "SELECT {BLOCK_COLUMNS}, t.index AS tx_index FROM header AS h JOIN payload AS p ON h.height = p.height JOIN transaction AS t ON t.block_height = h.height - WHERE t.hash = $1 + WHERE t.hash = {hash_param} ORDER BY (t.block_height, t.index) ASC LIMIT 1" ); - let row = self.query_one(&query, &[&hash.to_string()]).await?; + let row = query.query(&sql).fetch_one(self.as_mut()).await?; // Extract the block. - let block = parse_block(row)?; + let block = BlockQueryData::from_row(&row)?; TransactionQueryData::with_hash(&block, hash).context(ErrorSnafu { message: format!( diff --git a/src/data_source/storage/sql/query/explorer.rs b/src/data_source/storage/sql/queries/explorer.rs similarity index 52% rename from src/data_source/storage/sql/query/explorer.rs rename to src/data_source/storage/sql/queries/explorer.rs index d58a37d05..2d90ad28b 100644 --- a/src/data_source/storage/sql/query/explorer.rs +++ b/src/data_source/storage/sql/queries/explorer.rs @@ -13,20 +13,20 @@ //! Explorer storage implementation for a database query engine. 
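For contrast with the `QueryBuilder` style, a sketch of the plain `query`/`query_as` helpers used throughout the module below. It approximates the transaction-target lookup from `get_transaction_summaries`; `Json<Vec<i64>>` stands in for the real serialized transaction index type.

```rust
use sqlx::types::Json;

// Find the (height, index) of the first transaction at a given block height.
async fn transaction_target(
    tx: &mut Transaction<Read>,
    height: u64,
) -> QueryResult<Option<(i64, Json<Vec<i64>>)>> {
    Ok(query_as::<(i64, Json<Vec<i64>>)>(
        "SELECT t.block_height AS height, t.index AS index
           FROM transaction AS t
          WHERE t.block_height = $1
          ORDER BY (t.block_height, t.index) ASC
          LIMIT 1",
    )
    .bind(height as i64) // positional parameters bind in declaration order
    .fetch_optional(tx.as_mut())
    .await?)
}
```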
use super::{ - parse_block, - postgres::types::{Json, ToSql}, - Transaction, BLOCK_COLUMNS, + super::transaction::{query, Transaction, TransactionMode}, + Database, Db, DecodeError, QueryBuilder, BLOCK_COLUMNS, }; use crate::{ - availability::{QueryableHeader, QueryablePayload, TransactionIndex}, + availability::{BlockQueryData, QueryableHeader, QueryablePayload, TransactionIndex}, data_source::storage::ExplorerStorage, explorer::{ - self, errors::NotFound, query_data::TransactionDetailResponse, BalanceAmount, BlockDetail, - BlockIdentifier, BlockRange, BlockSummary, ExplorerHistograms, ExplorerSummary, - GenesisOverview, GetBlockDetailError, GetBlockSummariesError, GetBlockSummariesRequest, - GetExplorerSummaryError, GetSearchResultsError, GetTransactionDetailError, - GetTransactionSummariesError, GetTransactionSummariesRequest, SearchResult, - TransactionIdentifier, TransactionRange, TransactionSummary, TransactionSummaryFilter, + self, errors::NotFound, query_data::TransactionDetailResponse, traits::ExplorerHeader, + BalanceAmount, BlockDetail, BlockIdentifier, BlockRange, BlockSummary, ExplorerHistograms, + ExplorerSummary, GenesisOverview, GetBlockDetailError, GetBlockSummariesError, + GetBlockSummariesRequest, GetExplorerSummaryError, GetSearchResultsError, + GetTransactionDetailError, GetTransactionSummariesError, GetTransactionSummariesRequest, + MonetaryValue, SearchResult, TransactionIdentifier, TransactionRange, TransactionSummary, + TransactionSummaryFilter, }, Header, Payload, QueryError, QueryResult, }; @@ -35,127 +35,174 @@ use committable::Committable; use futures::stream::{self, StreamExt, TryStreamExt}; use hotshot_types::traits::node_implementation::NodeType; use itertools::Itertools; +use sqlx::{types::Json, FromRow, Row}; use std::num::NonZeroUsize; +impl From for GetExplorerSummaryError { + fn from(err: sqlx::Error) -> Self { + Self::from(QueryError::from(err)) + } +} + +impl From for GetTransactionDetailError { + fn from(err: sqlx::Error) -> Self { + Self::from(QueryError::from(err)) + } +} + +impl From for GetTransactionSummariesError { + fn from(err: sqlx::Error) -> Self { + Self::from(QueryError::from(err)) + } +} + +impl From for GetBlockDetailError { + fn from(err: sqlx::Error) -> Self { + Self::from(QueryError::from(err)) + } +} + +impl From for GetBlockSummariesError { + fn from(err: sqlx::Error) -> Self { + Self::from(QueryError::from(err)) + } +} + +impl From for GetSearchResultsError { + fn from(err: sqlx::Error) -> Self { + Self::from(QueryError::from(err)) + } +} + +impl<'r, Types> FromRow<'r, ::Row> for BlockSummary +where + Types: NodeType, + Header: QueryableHeader + ExplorerHeader, + Payload: QueryablePayload, +{ + fn from_row(row: &'r ::Row) -> sqlx::Result { + BlockQueryData::::from_row(row)? + .try_into() + .decode_error("malformed block summary") + } +} + +impl<'r, Types> FromRow<'r, ::Row> for BlockDetail +where + Types: NodeType, + Header: QueryableHeader + ExplorerHeader, + Payload: QueryablePayload, + BalanceAmount: Into, +{ + fn from_row(row: &'r ::Row) -> sqlx::Result { + BlockQueryData::::from_row(row)? 
+ .try_into() + .decode_error("malformed block detail") + } +} + #[async_trait] -impl<'a, Types: NodeType> ExplorerStorage for Transaction<'a> +impl ExplorerStorage for Transaction where + Mode: TransactionMode, Types: NodeType, Payload: QueryablePayload, - Header: QueryableHeader + explorer::traits::ExplorerHeader, + Header: QueryableHeader + ExplorerHeader, crate::Transaction: explorer::traits::ExplorerTransaction, BalanceAmount: Into, { async fn get_block_summaries( - &self, + &mut self, request: GetBlockSummariesRequest, ) -> Result>, GetBlockSummariesError> { let request = &request.0; - let (query, params): (String, Vec>) = match request.target { - BlockIdentifier::Latest => ( - format!( - "SELECT {BLOCK_COLUMNS} - FROM header AS h - JOIN payload AS p ON h.height = p.height - ORDER BY h.height DESC - LIMIT $1" - ), - vec![Box::new(request.num_blocks.get() as i64)], + let mut query = QueryBuilder::default(); + let sql = match request.target { + BlockIdentifier::Latest => format!( + "SELECT {BLOCK_COLUMNS} + FROM header AS h + JOIN payload AS p ON h.height = p.height + ORDER BY h.height DESC + LIMIT {}", + query.bind(request.num_blocks.get() as i64)?, ), - BlockIdentifier::Height(height) => ( - format!( - "SELECT {BLOCK_COLUMNS} - FROM header AS h - JOIN payload AS p ON h.height = p.height - WHERE h.height <= $1 - ORDER BY h.height DESC - LIMIT $2" - ), - vec![ - Box::new(height as i64), - Box::new(request.num_blocks.get() as i64), - ], + BlockIdentifier::Height(height) => format!( + "SELECT {BLOCK_COLUMNS} + FROM header AS h + JOIN payload AS p ON h.height = p.height + WHERE h.height <= {} + ORDER BY h.height DESC + LIMIT {}", + query.bind(height as i64)?, + query.bind(request.num_blocks.get() as i64)?, ), - BlockIdentifier::Hash(hash) => ( - // We want to match the blocks starting with the given hash, and working backwards until we - // have returned up to the number of requested blocks. The hash for a block should be unique, - // so we should just need to start with identifying the block height with the given hash, and - // return all blocks with a height less than or equal to that height, up to the number of - // requested blocks. + BlockIdentifier::Hash(hash) => { + // We want to match the blocks starting with the given hash, and working backwards + // until we have returned up to the number of requested blocks. The hash for a + // block should be unique, so we should just need to start with identifying the + // block height with the given hash, and return all blocks with a height less than + // or equal to that height, up to the number of requested blocks. format!( "SELECT {BLOCK_COLUMNS} FROM header AS h JOIN payload AS p ON h.height = p.height - WHERE h.height <= (SELECT h1.height FROM header AS h1 WHERE h1.hash = $1) + WHERE h.height <= (SELECT h1.height FROM header AS h1 WHERE h1.hash = {}) ORDER BY h.height DESC - LIMIT $2" - ), - vec![ - Box::new(hash.to_string()), - Box::new(request.num_blocks.get() as i64), - ], - ), + LIMIT {}", + query.bind(hash.to_string())?, + query.bind(request.num_blocks.get() as i64)?, + ) + } }; - let row_stream = self.query(&query, params).await?; - let result = row_stream.map(|row| -> QueryResult> { - let block = parse_block::(row?)?; - Ok(BlockSummary::try_from(block)?) - }); + let row_stream = query.query(&sql).fetch(self.as_mut()); + let result = row_stream.map(|row| BlockSummary::from_row(&row?)); Ok(result.try_collect().await?) 
} async fn get_block_detail( - &self, + &mut self, request: BlockIdentifier, ) -> Result, GetBlockDetailError> { - let (query, params): (String, Vec>) = match request { - BlockIdentifier::Latest => ( - format!( - "SELECT {BLOCK_COLUMNS} - FROM header AS h - JOIN payload AS p ON h.height = p.height - ORDER BY h.height DESC - LIMIT 1" - ), - vec![], + let mut query = QueryBuilder::default(); + let sql = match request { + BlockIdentifier::Latest => format!( + "SELECT {BLOCK_COLUMNS} + FROM header AS h + JOIN payload AS p ON h.height = p.height + ORDER BY h.height DESC + LIMIT 1" ), - BlockIdentifier::Height(height) => ( - format!( - "SELECT {BLOCK_COLUMNS} - FROM header AS h - JOIN payload AS p ON h.height = p.height - WHERE h.height = $1 - ORDER BY h.height DESC - LIMIT 1" - ), - vec![Box::new(height as i64)], + BlockIdentifier::Height(height) => format!( + "SELECT {BLOCK_COLUMNS} + FROM header AS h + JOIN payload AS p ON h.height = p.height + WHERE h.height = {} + ORDER BY h.height DESC + LIMIT 1", + query.bind(height as i64)?, ), - BlockIdentifier::Hash(hash) => ( - format!( - "SELECT {BLOCK_COLUMNS} - FROM header AS h - JOIN payload AS p ON h.height = p.height - WHERE h.hash = $1 - ORDER BY h.height DESC - LIMIT 1" - ), - vec![Box::new(hash.to_string())], + BlockIdentifier::Hash(hash) => format!( + "SELECT {BLOCK_COLUMNS} + FROM header AS h + JOIN payload AS p ON h.height = p.height + WHERE h.hash = {} + ORDER BY h.height DESC + LIMIT 1", + query.bind(hash.to_string())?, ), }; - let query_result = self.query_one(&query, params).await?; - let block = parse_block(query_result)?; + let query_result = query.query(&sql).fetch_one(self.as_mut()).await?; + let block = BlockDetail::from_row(&query_result)?; - Ok(BlockDetail::try_from(block).map_err(|e| QueryError::Error { - message: e.to_string(), - })?) + Ok(block) } async fn get_transaction_summaries( - &self, + &mut self, request: GetTransactionSummariesRequest, ) -> Result>, GetTransactionSummariesError> { let range = &request.range; @@ -164,31 +211,30 @@ where // We need to figure out the transaction target we are going to start // returned results based on. - let transaction_target = match target { - TransactionIdentifier::Latest => self.query_one::>>( + let transaction_target_query = match target { + TransactionIdentifier::Latest => query( "SELECT t.block_height AS height, t.index AS index FROM transaction AS t ORDER BY (t.block_height, t.index) DESC LIMIT 1", - vec![], - ).await, - TransactionIdentifier::HeightAndOffset(height, _) => self.query_one( + ), + TransactionIdentifier::HeightAndOffset(height, _) => query( "SELECT t.block_height AS height, t.index AS index FROM transaction AS t WHERE t.block_height = $1 ORDER BY (t.block_height, t.index) DESC LIMIT 1", - [*height as i64], - ).await, - TransactionIdentifier::Hash(hash) => self.query_one( + ) + .bind(*height as i64), + TransactionIdentifier::Hash(hash) => query( "SELECT t.block_height AS height, t.index AS index FROM transaction AS t WHERE t.hash = $1 ORDER BY (t.block_height, t.index) DESC LIMIT 1", - [hash.to_string()], - ).await, + ) + .bind(hash.to_string()), + }; + let Some(transaction_target) = transaction_target_query + .fetch_optional(self.as_mut()) + .await? 
+ else { + // If nothing is found, then we want to return an empty summary list as it means there + // is either no transaction, or the targeting criteria fails to identify any transaction + return Ok(vec![]); }; - let transaction_target = match transaction_target { - // If nothing is found, then we want to return an empty summary list - // as it means there is either no transaction, or the targeting - // criteria fails to identify any transaction - Err(QueryError::NotFound) => return Ok(vec![]), - _ => transaction_target, - }?; - - let block_height = transaction_target.get::<_, i64>("height") as usize; - let transaction_index = transaction_target.get::<_, Json>>("index"); + let block_height = transaction_target.get::("height") as usize; + let transaction_index = transaction_target.get::>, _>("index"); let offset = if let TransactionIdentifier::HeightAndOffset(_, offset) = target { *offset } else { @@ -202,51 +248,40 @@ where // transactions from that point. We then grab only the blocks for those // identified transactions, as only those blocks are needed to pull all // of the relevant transactions. - let block_stream = match filter { + let mut query = QueryBuilder::default(); + let sql = match filter { TransactionSummaryFilter::RollUp(_) => return Ok(vec![]), - TransactionSummaryFilter::None => { - self.query::>>( - &format!( - "SELECT {BLOCK_COLUMNS} + TransactionSummaryFilter::None => format!( + "SELECT {BLOCK_COLUMNS} FROM header AS h JOIN payload AS p ON h.height = p.height WHERE h.height IN ( SELECT t.block_height FROM transaction AS t - WHERE (t.block_height, t.index) <= ($1, $2) + WHERE (t.block_height, t.index) <= ({}, {}) ORDER BY (t.block_height, t.index) DESC - LIMIT $3 + LIMIT {} ) - ORDER BY h.height DESC" - ), - vec![ - Box::new(block_height as i64), - Box::new(&transaction_index), - Box::new((range.num_transactions.get() + offset) as i64), - ], - ) - .await - } + ORDER BY h.height DESC", + query.bind(block_height as i64)?, + query.bind(transaction_index)?, + query.bind((range.num_transactions.get() + offset) as i64)?, + ), - TransactionSummaryFilter::Block(block) => { - self.query::>>( - &format!( - "SELECT {BLOCK_COLUMNS} + TransactionSummaryFilter::Block(block) => format!( + "SELECT {BLOCK_COLUMNS} FROM header AS h JOIN payload AS p ON h.height = p.height - WHERE h.height = $1 - ORDER BY h.height DESC" - ), - vec![Box::new(*block as i64)], - ) - .await - } - }? 
- .map(|result| match result { - Ok(row) => parse_block::(row), - Err(err) => Err(err), - }); + WHERE h.height = {} + ORDER BY h.height DESC", + query.bind(*block as i64)?, + ), + }; + let block_stream = query + .query(&sql) + .fetch(self.as_mut()) + .map(|row| BlockQueryData::from_row(&row?)); let transaction_summary_stream = block_stream.flat_map(|row| match row { Ok(block) => stream::iter( @@ -265,7 +300,7 @@ where .rev() .collect::>>>(), ), - Err(err) => stream::iter(vec![Err(err)]), + Err(err) => stream::iter(vec![Err(err.into())]), }); let transaction_summary_vec = transaction_summary_stream @@ -286,62 +321,57 @@ where } async fn get_transaction_detail( - &self, + &mut self, request: TransactionIdentifier, ) -> Result, GetTransactionDetailError> { let target = request; - let (query, params): (String, Vec>) = match target { - TransactionIdentifier::Latest => ( - format!( - "SELECT {BLOCK_COLUMNS} - FROM header AS h - JOIN payload AS p ON h.height = p.height - WHERE h.height = ( - SELECT MAX(t1.block_height) - FROM transaction AS t1 - ) - ORDER BY h.height DESC" - ), - vec![], + let mut query = QueryBuilder::default(); + let sql = match target { + TransactionIdentifier::Latest => format!( + "SELECT {BLOCK_COLUMNS} + FROM header AS h + JOIN payload AS p ON h.height = p.height + WHERE h.height = ( + SELECT MAX(t1.block_height) + FROM transaction AS t1 + ) + ORDER BY h.height DESC" ), - TransactionIdentifier::HeightAndOffset(height, offset) => ( - format!( - "SELECT {BLOCK_COLUMNS} - FROM header AS h - JOIN payload AS p ON h.height = p.height - WHERE h.height = ( - SELECT t1.block_height - FROM transaction AS t1 - WHERE t1.block_height = $1 - ORDER BY (t1.block_height, t1.index) DESC - OFFSET $2 - LIMIT 1 - ) - ORDER BY h.height DESC" - ), - vec![Box::new(height as i64), Box::new(offset as i64)], + TransactionIdentifier::HeightAndOffset(height, offset) => format!( + "SELECT {BLOCK_COLUMNS} + FROM header AS h + JOIN payload AS p ON h.height = p.height + WHERE h.height = ( + SELECT t1.block_height + FROM transaction AS t1 + WHERE t1.block_height = {} + ORDER BY (t1.block_height, t1.index) DESC + OFFSET {} + LIMIT 1 + ) + ORDER BY h.height DESC", + query.bind(height as i64)?, + query.bind(offset as i64)?, ), - TransactionIdentifier::Hash(hash) => ( - format!( - "SELECT {BLOCK_COLUMNS} - FROM header AS h - JOIN payload AS p ON h.height = p.height - WHERE h.height = ( - SELECT t1.block_height - FROM transaction AS t1 - WHERE t1.hash = $1 - ORDER BY (t1.block_height, t1.index) DESC - LIMIT 1 - ) - ORDER BY h.height DESC" - ), - vec![Box::new(hash.to_string())], + TransactionIdentifier::Hash(hash) => format!( + "SELECT {BLOCK_COLUMNS} + FROM header AS h + JOIN payload AS p ON h.height = p.height + WHERE h.height = ( + SELECT t1.block_height + FROM transaction AS t1 + WHERE t1.hash = {} + ORDER BY (t1.block_height, t1.index) DESC + LIMIT 1 + ) + ORDER BY h.height DESC", + query.bind(hash.to_string())?, ), }; - let query_row = self.query_one(&query, params).await?; - let block = parse_block::(query_row)?; + let query_row = query.query(&sql).fetch_one(self.as_mut()).await?; + let block = BlockQueryData::::from_row(&query_row)?; let txns = block.enumerate().map(|(_, txn)| txn).collect::>(); @@ -371,10 +401,10 @@ where } async fn get_explorer_summary( - &self, + &mut self, ) -> Result, GetExplorerSummaryError> { - let histograms= async { - let historgram_query_result = self.query( + let histograms = { + let historgram_query_result = query( "SELECT h.height AS height, h.timestamp AS timestamp, @@ 
-388,29 +418,16 @@ where h.height IN (SELECT height FROM header ORDER BY height DESC LIMIT 50) ORDER BY h.height ASC ", - Vec::>::new(), - ).await?; + ).fetch(self.as_mut()); - let histograms:Result = historgram_query_result + let histograms: Result = historgram_query_result .map(|row_stream| { row_stream.map(|row| { - let height: i64 = row.try_get("height").map_err(|e| QueryError::Error { - message: format!("failed to get column height {e}"), - })?; - let timestamp: i64 = - row.try_get("timestamp").map_err(|e| QueryError::Error { - message: format!("failed to get column timestamp {e}"), - })?; - let time: Option = row.try_get("time").map_err(|e| QueryError::Error { - message: format!("failed to get column time {e}"), - })?; - let size: Option = row.try_get("size").map_err(|e| QueryError::Error { - message: format!("failed to get column size {e}"), - })?; - let num_transactions: i64 = - row.try_get("transactions").map_err(|e| QueryError::Error { - message: format!("failed to get column transactions {e}"), - })?; + let height: i64 = row.try_get("height")?; + let timestamp: i64 = row.try_get("timestamp")?; + let time: Option = row.try_get("time")?; + let size: Option = row.try_get("size")?; + let num_transactions: i64 = row.try_get("transactions")?; Ok((height, timestamp, time, size, num_transactions)) }) @@ -423,7 +440,7 @@ where block_heights: Vec::with_capacity(50), }, |mut histograms: ExplorerHistograms, - row: Result<(i64, i64, Option, Option, i64), QueryError>| async { + row: sqlx::Result<(i64, i64, Option, Option, i64)>| async { let (height, _timestamp, time, size, num_transactions) = row?; histograms.block_time.push(time.map(|i| i as u64)); histograms.block_size.push(size.map(|i| i as u64)); @@ -431,42 +448,37 @@ where histograms.block_heights.push(height as u64); Ok(histograms) }, - ).await; + ) + .await; - histograms - }.await?; + histograms? 
+        };
 
-        let genesis_overview = async {
-            let row = self
-                .query_one(
-                    "SELECT
+        let genesis_overview = {
+            let row = query(
+                "SELECT
                  (SELECT MAX(height) + 1 FROM header) AS blocks,
                  (SELECT COUNT(*) FROM transaction) AS transactions",
-                    Vec::<Box<dyn ToSql + Send + Sync>>::new(),
-                )
-                .await?;
-
-            let blocks: i64 = row.try_get("blocks").map_err(|e| QueryError::Error {
-                message: format!("failed to get column blocks {e}"),
-            })?;
-            let transactions: i64 = row.try_get("transactions").map_err(|e| QueryError::Error {
-                message: format!("failed to get column transactions {e}"),
-            })?;
-
-            let blocks: u64 = blocks.try_into().map_err(|e| QueryError::Error {
-                message: format!("failed to convert blocks to u64 {e}"),
-            })?;
-            let transactions: u64 = transactions.try_into().map_err(|e| QueryError::Error {
-                message: format!("failed to convert transactions to u64 {e}"),
-            })?;
-
-            Ok::<_, QueryError>(GenesisOverview {
+            )
+            .fetch_one(self.as_mut())
+            .await?;
+
+            let blocks: i64 = row.try_get("blocks")?;
+            let transactions: i64 = row.try_get("transactions")?;
+
+            let blocks: u64 = blocks
+                .try_into()
+                .decode_error("failed to convert blocks to u64")?;
+            let transactions: u64 = transactions
+                .try_into()
+                .decode_error("failed to convert transactions to u64")?;
+
+            GenesisOverview {
                 rollups: 0,
                 transactions,
                 blocks,
-            })
-        }
-        .await?;
+            }
+        };
 
         let latest_block: BlockDetail<Types> =
             self.get_block_detail(BlockIdentifier::Latest).await?;
@@ -496,8 +508,8 @@ where
     }
 
     async fn get_search_results(
-        &self,
-        query: String,
+        &mut self,
+        search_query: String,
     ) -> Result<SearchResult<Types>, GetSearchResultsError> {
         let block_query = format!(
             "SELECT {BLOCK_COLUMNS}
             FROM header AS h
             JOIN payload AS p ON h.height = p.height
             WHERE h.hash LIKE $1 || '%'
             ORDER BY h.height DESC
             LIMIT 5"
         );
-        let block_query_rows = self.query(&block_query, vec![&query]).await?;
+        let block_query_rows = query(block_query.as_str())
+            .bind(&search_query)
+            .fetch(self.as_mut());
         let block_query_result: Vec<BlockSummary<Types>> = block_query_rows
-            .map(|row| -> Result<BlockSummary<Types>, QueryError> {
-                let block = parse_block::<Types>(row?)?;
-                Ok(BlockSummary::try_from(block)?)
-            })
+            .map(|row| BlockSummary::from_row(&row?))
            .try_collect()
            .await?;
 
         let transactions_query = format!(
             "SELECT {BLOCK_COLUMNS}
             FROM header AS h
             JOIN payload AS p ON h.height = p.height
             JOIN transaction AS t ON h.height = t.block_height
             WHERE t.hash LIKE $1 || '%'
             ORDER BY h.height DESC
             LIMIT 5"
         );
-        let transactions_query_rows = self.query(&transactions_query, vec![&query]).await?;
+        let transactions_query_rows = query(transactions_query.as_str())
+            .bind(&search_query)
+            .fetch(self.as_mut());
         let transactions_query_result: Vec<TransactionSummary<Types>> = transactions_query_rows
            .map(|row| -> Result<Vec<TransactionSummary<Types>>, QueryError> {
-                let block = parse_block::<Types>(row?)?;
+                let block = BlockQueryData::<Types>::from_row(&row?)?;
                 let transactions = block
                    .enumerate()
                    .enumerate()
-                    .filter(|(_, (_, txn))| txn.commit().to_string().starts_with(&query))
+                    .filter(|(_, (_, txn))| txn.commit().to_string().starts_with(&search_query))
                    .map(|(offset, (_, txn))| {
                         Ok(TransactionSummary::try_from((
                             &block, offset, txn,
diff --git a/src/data_source/storage/sql/query/node.rs b/src/data_source/storage/sql/queries/node.rs
similarity index 77%
rename from src/data_source/storage/sql/query/node.rs
rename to src/data_source/storage/sql/queries/node.rs
index e669beff5..286f02e80 100644
--- a/src/data_source/storage/sql/query/node.rs
+++ b/src/data_source/storage/sql/queries/node.rs
@@ -12,80 +12,83 @@
 //! Node storage implementation for a database query engine.
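
(A sketch for review, not part of the patch: the bind-and-format idiom the converted explorer queries above all follow. The call sites show `QueryBuilder::bind` recording a value and returning its positional `$n` placeholder for splicing into dynamically built SQL; the function and column choice below are hypothetical.)

    // Assumes `use sqlx::Row;` plus the `QueryBuilder`, `Transaction`, and `Read`
    // types from this module; conversion of sqlx errors into `QueryResult` is
    // assumed, as in the converted code above.
    async fn block_hash_at(tx: &mut Transaction<Read>, height: u64) -> QueryResult<String> {
        let mut query = QueryBuilder::default();
        // `bind` records the value and hands back its placeholder, "$1" here.
        let sql = format!(
            "SELECT h.hash AS hash FROM header AS h WHERE h.height = {}",
            query.bind(height as i64)?,
        );
        // Execute with the accumulated parameters.
        let row = query.query(&sql).fetch_one(tx.as_mut()).await?;
        Ok(row.try_get("hash")?)
    }
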
-use super::{header_where_clause, parse_header, Transaction, HEADER_COLUMNS};
+use super::{
+    super::transaction::{query, query_as, Transaction, TransactionMode},
+    parse_header, DecodeError, QueryBuilder, HEADER_COLUMNS,
+};
 use crate::{
-    node::{BlockId, NodeDataSource, SyncStatus, TimeWindowQueryData, WindowStart},
+    data_source::storage::NodeStorage,
+    node::{BlockId, SyncStatus, TimeWindowQueryData, WindowStart},
     Header, MissingSnafu, NotFoundSnafu, QueryError, QueryResult, VidShare,
 };
 use async_trait::async_trait;
 use futures::stream::{StreamExt, TryStreamExt};
 use hotshot_types::traits::{block_contents::BlockHeader, node_implementation::NodeType};
 use snafu::OptionExt;
+use sqlx::Row;
 
 #[async_trait]
-impl<'a, Types> NodeDataSource<Types> for Transaction<'a>
+impl<Mode, Types> NodeStorage<Types> for Transaction<Mode>
 where
+    Mode: TransactionMode,
     Types: NodeType,
 {
-    async fn block_height(&self) -> QueryResult<usize> {
-        let query = "SELECT max(height) FROM header";
-        let row = self.query_one_static(query).await?;
-        let height: Option<i64> = row.get(0);
-        match height {
-            Some(height) => {
+    async fn block_height(&mut self) -> QueryResult<usize> {
+        match query_as::<(Option<i64>,)>("SELECT max(height) FROM header")
+            .fetch_one(self.as_mut())
+            .await?
+        {
+            (Some(height),) => {
                 // The height of the block is the number of blocks below it, so the total number of
                 // blocks is one more than the height of the highest block.
                 Ok(height as usize + 1)
             }
-            None => {
+            (None,) => {
                 // If there are no blocks yet, the height is 0.
                 Ok(0)
             }
         }
     }
 
-    async fn count_transactions(&self) -> QueryResult<usize> {
-        let row = self
-            .query_one_static("SELECT count(*) FROM transaction")
+    async fn count_transactions(&mut self) -> QueryResult<usize> {
+        let (count,) = query_as::<(i64,)>("SELECT count(*) FROM transaction")
+            .fetch_one(self.as_mut())
             .await?;
-        let count: i64 = row.get(0);
         Ok(count as usize)
     }
 
-    async fn payload_size(&self) -> QueryResult<usize> {
-        let row = self
-            .query_one_static("SELECT sum(size) FROM payload")
+    async fn payload_size(&mut self) -> QueryResult<usize> {
+        let (sum,) = query_as::<(Option<i64>,)>("SELECT sum(size) FROM payload")
+            .fetch_one(self.as_mut())
            .await?;
-        let sum: Option<i64> = row.get(0);
         Ok(sum.unwrap_or(0) as usize)
     }
 
-    async fn vid_share<ID>(&self, id: ID) -> QueryResult<VidShare>
+    async fn vid_share<ID>(&mut self, id: ID) -> QueryResult<VidShare>
     where
         ID: Into<BlockId<Types>> + Send + Sync,
     {
-        let (where_clause, param) = header_where_clause(id.into());
+        let mut query = QueryBuilder::default();
+        let where_clause = query.header_where_clause(id.into())?;
         // ORDER BY h.height ASC ensures that if there are duplicate blocks (this can happen when
         // selecting by payload ID, as payloads are not unique), we return the first one.
- let query = format!( + let sql = format!( "SELECT v.share AS share FROM vid AS v JOIN header AS h ON v.height = h.height WHERE {where_clause} ORDER BY h.height ASC LIMIT 1" ); - let row = self.query_one(&query, [param]).await?; - let share_data: Option> = - row.try_get("share").map_err(|err| QueryError::Error { - message: format!("error extracting share data from query results: {err}"), - })?; + let (share_data,) = query + .query_as::<(Option>,)>(&sql) + .fetch_one(self.as_mut()) + .await?; let share_data = share_data.context(MissingSnafu)?; - bincode::deserialize(&share_data).map_err(|err| QueryError::Error { - message: format!("malformed VID share: {err}"), - }) + let share = bincode::deserialize(&share_data).decode_error("malformed VID share")?; + Ok(share) } - async fn sync_status(&self) -> QueryResult { + async fn sync_status(&mut self) -> QueryResult { // A leaf can only be missing if there is no row for it in the database (all its columns are // non-nullable). A block can be missing if its corresponding leaf is missing or if the // block's `data` field is `NULL`. We can find the number of missing leaves and blocks by @@ -109,16 +112,19 @@ where // missing in that case _or_ if the row is present but share data is NULL. Thus, we also // need to select the total number of VID rows and the number of present VID rows with a // NULL share. - let query = "SELECT l.max_height, l.total_leaves, p.null_payloads, v.total_vid, vn.null_vid, pruned_height FROM + let sql = "SELECT l.max_height, l.total_leaves, p.null_payloads, v.total_vid, vn.null_vid, pruned_height FROM (SELECT max(leaf.height) AS max_height, count(*) AS total_leaves FROM leaf) AS l, (SELECT count(*) AS null_payloads FROM payload WHERE data IS NULL) AS p, (SELECT count(*) AS total_vid FROM vid) AS v, (SELECT count(*) AS null_vid FROM vid WHERE share IS NULL) AS vn, coalesce((SELECT last_height FROM pruned_height ORDER BY id DESC LIMIT 1)) as pruned_height "; - let row = self.query_opt_static(query).await?.context(NotFoundSnafu)?; + let row = query(sql) + .fetch_optional(self.as_mut()) + .await? + .context(NotFoundSnafu)?; - let block_height = match row.get::<_, Option>("max_height") { + let block_height = match row.get::, _>("max_height") { Some(height) => { // The height of the block is the number of blocks below it, so the total number of // blocks is one more than the height of the highest block. @@ -129,12 +135,12 @@ where 0 } }; - let total_leaves = row.get::<_, i64>("total_leaves") as usize; - let null_payloads = row.get::<_, i64>("null_payloads") as usize; - let total_vid = row.get::<_, i64>("total_vid") as usize; - let null_vid = row.get::<_, i64>("null_vid") as usize; + let total_leaves = row.get::("total_leaves") as usize; + let null_payloads = row.get::("null_payloads") as usize; + let total_vid = row.get::("total_vid") as usize; + let null_vid = row.get::("null_vid") as usize; let pruned_height = row - .get::<_, Option>("pruned_height") + .get::, _>("pruned_height") .map(|h| h as usize); let missing_leaves = block_height.saturating_sub(total_leaves); @@ -152,7 +158,7 @@ where } async fn get_header_window( - &self, + &mut self, start: impl Into> + Send + Sync, end: u64, ) -> QueryResult>> { @@ -172,21 +178,18 @@ where // Find all blocks starting from `first_block` with timestamps less than `end`. Block // timestamps are monotonically increasing, so this query is guaranteed to return a // contiguous range of blocks ordered by increasing height. 
- let query = format!( + let sql = format!( "SELECT {HEADER_COLUMNS} FROM header AS h WHERE h.height >= $1 AND h.timestamp < $2 ORDER BY h.height" ); - let rows = self - .query(&query, [&(first_block as i64), &(end as i64)]) - .await?; + let rows = query(&sql) + .bind(first_block as i64) + .bind(end as i64) + .fetch(self.as_mut()); let window = rows - .map(|row| { - parse_header::(row.map_err(|err| QueryError::Error { - message: err.to_string(), - })?) - }) + .map(|row| parse_header::(row?)) .try_collect() .await?; @@ -204,15 +207,16 @@ where // unique, deterministic result (the first block with a given timestamp). This sort may not // be able to use an index, but it shouldn't be too expensive, since there will never be // more than a handful of blocks with the same timestamp. - let query = format!( + let sql = format!( "SELECT {HEADER_COLUMNS} FROM header AS h WHERE h.timestamp >= $1 ORDER BY h.timestamp, h.height LIMIT 1" ); - let next = self - .query_opt(&query, [&(end as i64)]) + let next = query(&sql) + .bind(end as i64) + .fetch_optional(self.as_mut()) .await? .map(parse_header::) .transpose()?; @@ -221,9 +225,9 @@ where } } -impl<'a> Transaction<'a> { +impl Transaction { async fn time_window( - &self, + &mut self, start: u64, end: u64, ) -> QueryResult>> { @@ -238,32 +242,32 @@ impl<'a> Transaction<'a> { // with a given timestamp). This sort may not be able to use an index, but it shouldn't be // too expensive, since there will never be more than a handful of blocks with the same // timestamp. - let query = format!( + let sql = format!( "SELECT {HEADER_COLUMNS} FROM header AS h WHERE h.timestamp >= $1 AND h.timestamp < $2 ORDER BY h.timestamp, h.height" ); - let rows = self.query(&query, [&(start as i64), &(end as i64)]).await?; + let rows = query(&sql) + .bind(start as i64) + .bind(end as i64) + .fetch(self.as_mut()); let window: Vec<_> = rows - .map(|row| { - parse_header::(row.map_err(|err| QueryError::Error { - message: err.to_string(), - })?) - }) + .map(|row| parse_header::(row?)) .try_collect() .await?; // Find the block just after the window. - let query = format!( + let sql = format!( "SELECT {HEADER_COLUMNS} FROM header AS h WHERE h.timestamp >= $1 ORDER BY h.timestamp, h.height LIMIT 1" ); - let next = self - .query_opt(&query, [&(end as i64)]) + let next = query(&sql) + .bind(end as i64) + .fetch_optional(self.as_mut()) .await? .map(parse_header::) .transpose()?; @@ -280,15 +284,16 @@ impl<'a> Transaction<'a> { } // Find the block just before the window. - let query = format!( + let sql = format!( "SELECT {HEADER_COLUMNS} FROM header AS h WHERE h.timestamp < $1 ORDER BY h.timestamp DESC, h.height DESC LIMIT 1" ); - let prev = self - .query_opt(&query, [&(start as i64)]) + let prev = query(&sql) + .bind(start as i64) + .fetch_optional(self.as_mut()) .await? .map(parse_header::) .transpose()?; diff --git a/src/data_source/storage/sql/query/state.rs b/src/data_source/storage/sql/queries/state.rs similarity index 74% rename from src/data_source/storage/sql/query/state.rs rename to src/data_source/storage/sql/queries/state.rs index dc39ff1fc..e1d35db67 100644 --- a/src/data_source/storage/sql/query/state.rs +++ b/src/data_source/storage/sql/queries/state.rs @@ -13,41 +13,37 @@ //! Merklized state storage implementation for a database query engine. 
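
(A sketch for review, not part of the patch: the `query_as` idiom that replaces the old `row.get(0)` calls in the node queries above. Rows decode directly into typed tuples, so a one-column aggregate comes back as a single-element tuple, and a nullable aggregate such as `max` over an empty table surfaces as an `Option` inside that tuple. The function below is hypothetical.)

    async fn latest_height(tx: &mut Transaction<Read>) -> QueryResult<Option<u64>> {
        // max(height) is NULL when the header table is empty, hence Option<i64>.
        let (height,) = query_as::<(Option<i64>,)>("SELECT max(height) FROM header")
            .fetch_one(tx.as_mut())
            .await?;
        Ok(height.map(|h| h as u64))
    }
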
use super::{ - postgres::{types::ToSql, Row}, - sql_param, Transaction, + super::transaction::{query_as, Transaction, TransactionMode, Write}, + DecodeError, QueryBuilder, }; use crate::{ - merklized_state::{ - MerklizedState, MerklizedStateDataSource, MerklizedStateHeightPersistence, Snapshot, - }, + data_source::storage::{MerklizedStateHeightStorage, MerklizedStateStorage}, + merklized_state::{MerklizedState, Snapshot}, QueryError, QueryResult, }; -use ark_serialize::{CanonicalDeserialize, SerializationError}; +use ark_serialize::CanonicalDeserialize; use async_std::sync::Arc; use async_trait::async_trait; -use bit_vec::BitVec; -use futures::stream::{StreamExt, TryStreamExt}; +use futures::stream::TryStreamExt; use hotshot_types::traits::node_implementation::NodeType; -use itertools::Itertools; use jf_merkle_tree::{ prelude::{MerkleNode, MerkleProof}, DigestAlgorithm, MerkleCommitment, ToTraversalPath, }; -use std::{ - collections::{HashMap, HashSet, VecDeque}, - fmt::{self, Display, Formatter}, -}; +use sqlx::{types::BitVec, FromRow}; +use std::collections::{HashMap, HashSet, VecDeque}; #[async_trait] -impl<'a, Types, State, const ARITY: usize> MerklizedStateDataSource - for Transaction<'a> +impl MerklizedStateStorage + for Transaction where + Mode: TransactionMode, Types: NodeType, State: MerklizedState + 'static, { /// Retreives a Merkle path from the database async fn get_path( - &self, + &mut self, snapshot: Snapshot, key: State::Key, ) -> QueryResult> { @@ -61,11 +57,11 @@ where // Get all the nodes in the path to the index. // Order by pos DESC is to return nodes from the leaf to the root - let (params, stmt) = build_get_path_query(state_type, traversal_path.clone(), created); - - let nodes = self.query(&stmt, params).await?; - - let nodes: Vec<_> = nodes.map(|res| Node::try_from(res?)).try_collect().await?; + let (query, sql) = build_get_path_query(state_type, traversal_path.clone(), created)?; + let nodes = query + .query_as::(&sql) + .fetch_all(self.as_mut()) + .await?; // insert all the hash ids to a hashset which is used to query later // HashSet is used to avoid duplicates @@ -79,16 +75,12 @@ where // Find all the hash values and create a hashmap // Hashmap will be used to get the hash value of the nodes children and the node itself. 
- let hashes_query = self - .query( - "SELECT * FROM hash WHERE id = ANY( $1)", - [sql_param(&hash_ids.into_iter().collect::>())], - ) - .await?; - let hashes: HashMap<_, _> = hashes_query - .map(|row| HashTableRow::try_from(row?).map(|h| (h.id, h.value))) - .try_collect() - .await?; + let hashes: HashMap> = + query_as("SELECT id, value FROM hash WHERE id = ANY( $1)") + .bind(hash_ids.into_iter().collect::>()) + .fetch(self.as_mut()) + .try_collect() + .await?; let mut proof_path = VecDeque::with_capacity(State::tree_height()); for Node { @@ -125,7 +117,7 @@ where })?; Ok(Arc::new(MerkleNode::ForgettenSubtree { value: State::T::deserialize_compressed(value.as_slice()) - .map_err(ParseError::Deserialize)?, + .decode_error("malformed merkle node value")?, })) } else { Ok(Arc::new(MerkleNode::Empty)) @@ -135,7 +127,7 @@ where // Use the Children merkle nodes to reconstruct the branch node proof_path.push_back(MerkleNode::Branch { value: State::T::deserialize_compressed(value.as_slice()) - .map_err(ParseError::Deserialize)?, + .decode_error("malformed merkle node value")?, children: child_nodes, }); } @@ -143,11 +135,11 @@ where (None, None, Some(index), Some(entry)) => { proof_path.push_back(MerkleNode::Leaf { value: State::T::deserialize_compressed(value.as_slice()) - .map_err(ParseError::Deserialize)?, + .decode_error("malformed merkle node value")?, pos: serde_json::from_value(index.clone()) - .map_err(ParseError::Serde)?, + .decode_error("malformed merkle node index")?, elem: serde_json::from_value(entry.clone()) - .map_err(ParseError::Serde)?, + .decode_error("malformed merkle element")?, }); } // Otherwise, it's empty. @@ -236,26 +228,26 @@ where } #[async_trait] -impl<'a> MerklizedStateHeightPersistence for Transaction<'a> { - async fn get_last_state_height(&self) -> QueryResult { - let row = self - .query_opt_static("SELECT * from last_merklized_state_height") - .await?; - - let height = row.map(|r| r.get::<_, i64>("height") as usize); - - Ok(height.unwrap_or(0)) +impl MerklizedStateHeightStorage for Transaction { + async fn get_last_state_height(&mut self) -> QueryResult { + let Some((height,)) = query_as::<(i64,)>("SELECT height from last_merklized_state_height") + .fetch_optional(self.as_mut()) + .await? + else { + return Ok(0); + }; + Ok(height as usize) } } -impl<'a> Transaction<'a> { +impl Transaction { /// Get information identifying a [`Snapshot`]. /// /// If the given snapshot is known to the database, this function returns /// * The block height at which the snapshot was created /// * A digest of the Merkle commitment to the snapshotted state async fn snapshot_info( - &self, + &mut self, snapshot: Snapshot, ) -> QueryResult<(i64, State::Commit)> where @@ -271,34 +263,30 @@ impl<'a> Transaction<'a> { // height we get, since any query against equivalent states will yield equivalent // results, regardless of which block the state is from. Thus, we can make this // query fast with `LIMIT 1` and no `ORDER BY`. 
- let query = self - .query_one( - &format!( - "SELECT height - FROM header - WHERE {header_state_commitment_field} = $1 - LIMIT 1" - ), - &[&commit.to_string()], - ) - .await?; - - (query.get(0), commit) + let (height,) = query_as(&format!( + "SELECT height + FROM header + WHERE {header_state_commitment_field} = $1 + LIMIT 1" + )) + .bind(commit.to_string()) + .fetch_one(self.as_mut()) + .await?; + + (height, commit) } Snapshot::Index(created) => { let created = created as i64; - let row = self - .query_one( - &format!( - "SELECT {header_state_commitment_field} AS root_commmitment - FROM header - WHERE height = $1" - ), - [sql_param(&created)], - ) - .await?; - let commit: String = row.get(0); - let commit = serde_json::from_value(commit.into()).map_err(ParseError::Serde)?; + let (commit,) = query_as::<(String,)>(&format!( + "SELECT {header_state_commitment_field} AS root_commmitment + FROM header + WHERE height = $1" + )) + .bind(created) + .fetch_one(self.as_mut()) + .await?; + let commit = serde_json::from_value(commit.into()) + .decode_error("malformed state commitment")?; (created, commit) } }; @@ -313,83 +301,24 @@ impl<'a> Transaction<'a> { } } -/// Represents a Hash table row -pub(crate) struct HashTableRow { - /// Hash id to be used by the state table to save space - id: i32, - /// hash value - value: Vec, -} - -impl HashTableRow { - // TODO: create a generic upsert function with retries that returns the column - pub(crate) fn build_batch_insert(hashes: &[Vec]) -> (Vec<&(dyn ToSql + Sync)>, String) { - let len = hashes.len(); - let params: Vec<_> = hashes - .iter() - .flat_map(|c| [c as &(dyn ToSql + Sync)]) - .collect(); - let stmt = format!( - "INSERT INTO hash(value) values {} ON CONFLICT (value) DO UPDATE SET value = EXCLUDED.value returning *", - (1..len+1) - .format_with(", ", |v, f| { f(&format_args!("(${v})")) }), +// TODO: create a generic upsert function with retries that returns the column +pub(crate) fn build_hash_batch_insert( + hashes: &[Vec], +) -> QueryResult<(QueryBuilder<'_>, String)> { + let mut query = QueryBuilder::default(); + let params = hashes + .iter() + .map(|hash| Ok(format!("({})", query.bind(hash)?))) + .collect::>>()?; + let sql = format!( + "INSERT INTO hash(value) values {} ON CONFLICT (value) DO UPDATE SET value = EXCLUDED.value returning value, id", + params.join(",") ); - - (params, stmt) - } -} - -// Parse a row to a HashTableRow -impl TryFrom for HashTableRow { - type Error = QueryError; - fn try_from(row: Row) -> QueryResult { - Ok(Self { - id: row.try_get(0).map_err(|e| QueryError::Error { - message: format!("failed to get column id {e}"), - })?, - value: row.try_get(1).map_err(|e| QueryError::Error { - message: format!("failed to get column value {e}"), - })?, - }) - } -} - -// parsing errors -#[derive(Debug)] -pub(crate) enum ParseError { - Serde(serde_json::Error), - Deserialize(SerializationError), - Serialize(SerializationError), -} - -impl Display for ParseError { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - match self { - Self::Serde(err) => { - write!(f, "failed to parse value {err:?}") - } - Self::Deserialize(err) => { - write!(f, "failed to deserialize {err:?}") - } - Self::Serialize(err) => { - write!(f, "failed to serialize {err:?}") - } - } - } -} - -impl std::error::Error for ParseError {} - -impl From for QueryError { - fn from(value: ParseError) -> Self { - Self::Error { - message: value.to_string(), - } - } + Ok((query, sql)) } // Represents a row in a state table -#[derive(Debug, Default, Clone)] +#[derive(Debug, 
Default, Clone, FromRow)] pub(crate) struct Node { pub(crate) path: Vec, pub(crate) created: i64, @@ -401,103 +330,69 @@ pub(crate) struct Node { } impl Node { - pub(crate) fn build_batch_insert<'a>( - name: &'a str, - nodes: &'a [Self], - ) -> (Vec<&'a (dyn ToSql + Sync)>, String) { - let params: Vec<&(dyn ToSql + Sync)> = nodes - .iter() - .flat_map(|n| { - [ - &n.path as &(dyn ToSql + Sync), - &n.created, - &n.hash_id, - &n.children, - &n.children_bitvec, - &n.index, - &n.entry, - ] - }) - .collect(); - - let stmt = format!( - "INSERT INTO {name} (path, created, hash_id, children, children_bitvec, index, entry) values {} ON CONFLICT (path, created) - DO UPDATE SET hash_id = EXCLUDED.hash_id, children = EXCLUDED.children, children_bitvec = EXCLUDED.children_bitvec, - index = EXCLUDED.index, entry = EXCLUDED.entry RETURNING path", - (1..params.len()+1) - .tuples() - .format_with(", ", |(path, created, id, children, bitmap, i, e), f| - { f(&format_args!("(${path}, ${created}, ${id}, ${children}, ${bitmap}, ${i}, ${e})")) }), - ); - - (params, stmt) - } -} - -// Parse a Row to a Node -impl TryFrom for Node { - type Error = QueryError; - fn try_from(row: Row) -> Result { - Ok(Self { - path: row.try_get(0).map_err(|e| QueryError::Error { - message: format!("failed to get column path: {e}"), - })?, - created: row.try_get(1).map_err(|e| QueryError::Error { - message: format!("failed to get column created: {e}"), - })?, - hash_id: row.try_get(2).map_err(|e| QueryError::Error { - message: format!("failed to get column hash_id: {e}"), - })?, - children: row.try_get(3).map_err(|e| QueryError::Error { - message: format!("failed to get column children: {e}"), - })?, - children_bitvec: row.try_get(4).map_err(|e| QueryError::Error { - message: format!("failed to get column children bitmap: {e}"), - })?, - index: row.try_get(5).map_err(|e| QueryError::Error { - message: format!("failed to get column index: {e}"), - })?, - entry: row.try_get(6).map_err(|e| QueryError::Error { - message: format!("failed to get column entry: {e}"), - })?, - }) + pub(crate) async fn upsert( + name: &str, + nodes: impl IntoIterator, + tx: &mut Transaction, + ) -> anyhow::Result<()> { + tx.upsert( + name, + [ + "path", + "created", + "hash_id", + "children", + "children_bitvec", + "index", + "entry", + ], + ["path", "created"], + nodes.into_iter().map(|n| { + ( + n.path.clone(), + n.created, + n.hash_id, + n.children.clone(), + n.children_bitvec.clone(), + n.index.clone(), + n.entry.clone(), + ) + }), + ) + .await } } -fn build_get_path_query( +fn build_get_path_query<'q>( table: &'static str, traversal_path: Vec, created: i64, -) -> (Vec>, String) { +) -> QueryResult<(QueryBuilder<'q>, String)> { + let mut query = QueryBuilder::default(); let mut traversal_path = traversal_path.into_iter().map(|x| x as i32); + let created = query.bind(created)?; - // Since the 'created' parameter is common to all queries, - // we place it at the first position in the 'params' vector. 
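// For review, the shape of the SQL that `build_get_path_query` (below) now
// assembles; a sketch, with `created` bound once as $1 and one reversed path
// prefix per ancestor bound as $2, $3, ..., down to the empty root path:
//
//     (SELECT * FROM test_tree WHERE path = $2 AND created <= $1 ORDER BY created DESC LIMIT 1)
//     UNION
//     (SELECT * FROM test_tree WHERE path = $3 AND created <= $1 ORDER BY created DESC LIMIT 1)
//     UNION
//     ...
//     ORDER BY path DESC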
// We iterate through the path vector skipping the first element after each iteration - let mut params: Vec> = vec![Box::new(created)]; let len = traversal_path.len(); - let mut queries = Vec::new(); - - for i in 0..=len { - let node_path = traversal_path.clone().rev().collect::>(); - - let query = format!( - "(SELECT * FROM {table} WHERE path = ${} AND created <= $1 ORDER BY created DESC LIMIT 1)", - i + 2 + let mut sub_queries = Vec::new(); + for _ in 0..=len { + let node_path = query.bind(traversal_path.clone().rev().collect::>())?; + let sub_query = format!( + "(SELECT * FROM {table} WHERE path = {node_path} AND created <= {created} ORDER BY created DESC LIMIT 1)", ); - queries.push(query); - params.push(Box::new(node_path)); + sub_queries.push(sub_query); traversal_path.next(); } - let mut final_query: String = queries.join(" UNION "); - final_query.push_str("ORDER BY path DESC"); - (params, final_query) + let mut sql: String = sub_queries.join(" UNION "); + sql.push_str("ORDER BY path DESC"); + Ok((query, sql)) } #[cfg(test)] mod test { + use futures::stream::StreamExt; use jf_merkle_tree::{ universal_merkle_tree::UniversalMerkleTree, LookupResult, MerkleTreeScheme, UniversalMerkleTreeScheme, @@ -507,7 +402,7 @@ mod test { use super::*; use crate::{ data_source::{ - storage::sql::{query::sql_param, testing::TmpDb, *}, + storage::sql::{testing::TmpDb, *}, VersionedDataSource, }, merklized_state::UpdateStateData, @@ -539,17 +434,20 @@ mod test { // data field of the header let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()}); - tx - .query_opt( - "INSERT INTO HEADER(height, hash, payload_hash, timestamp, data) VALUES ($1, $2, 't', 0, $3) ON CONFLICT(height) DO UPDATE set data = excluded.data", - [ - sql_param(&(block_height as i64)), - sql_param(&format!("randomHash{i}")), - sql_param(&test_data), - ], - ) - .await - .unwrap(); + tx.upsert( + "header", + ["height", "hash", "payload_hash", "timestamp", "data"], + ["height"], + [( + block_height as i64, + format!("randomHash{i}"), + "t", + 0, + test_data, + )], + ) + .await + .unwrap(); // proof for the index from the tree let (_, proof) = test_tree.lookup(i).expect_ok().unwrap(); // traversal path for the index. @@ -574,7 +472,7 @@ mod test { //Get the path and check if it matches the lookup for i in 0..27 { // Query the path for the index - let tx = storage.read().await.unwrap(); + let mut tx = storage.read().await.unwrap(); let merkle_path = tx .get_path( Snapshot::<_, MockMerkleTree, 8>::Index(block_height as u64), @@ -602,17 +500,14 @@ mod test { // data field of the header let mut tx = storage.write().await.unwrap(); let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()}); - tx - .query_opt( - "INSERT INTO HEADER(height, hash, payload_hash, timestamp, data) VALUES ($1, $2, 't', 0, $3) ON CONFLICT(height) DO UPDATE set data = excluded.data", - [ - sql_param(&(2_i64)), - sql_param(&"randomstring"), - sql_param(&test_data), - ], - ) - .await - .unwrap(); + tx.upsert( + "header", + ["height", "hash", "payload_hash", "timestamp", "data"], + ["height"], + [(2i64, "randomstring", "t", 0, test_data)], + ) + .await + .unwrap(); let (_, proof_bh_2) = test_tree.lookup(0).expect_ok().unwrap(); // traversal path for the index. 
let traversal_path = @@ -640,19 +535,13 @@ mod test { .collect::>(); // Find all the nodes of Index 0 in table - let rows = storage - .read() - .await - .unwrap() - .query( - "SELECT * from test_tree where path = $1 ORDER BY created", - [sql_param(&node_path)], - ) - .await - .unwrap(); + let mut tx = storage.read().await.unwrap(); + let rows = query("SELECT * from test_tree where path = $1 ORDER BY created") + .bind(node_path) + .fetch(tx.as_mut()); let nodes: Vec<_> = rows - .map(|res| Node::try_from(res.unwrap())) + .map(|res| Node::from_row(&res.unwrap())) .try_collect() .await .unwrap(); @@ -706,17 +595,14 @@ mod test { let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(commitment).unwrap()}); // insert the header with merkle commitment let mut tx = storage.write().await.unwrap(); - tx - .query_opt( - "INSERT INTO HEADER(height, hash, payload_hash, timestamp, data) VALUES ($1, $2, 't', 0, $3) ON CONFLICT(height) DO UPDATE set data = excluded.data", - [ - sql_param(&(block_height as i64)), - sql_param(&"randomString"), - sql_param(&test_data), - ], - ) - .await - .unwrap(); + tx.upsert( + "header", + ["height", "hash", "payload_hash", "timestamp", "data"], + ["height"], + [(block_height as i64, "randomString", "t", 0, test_data)], + ) + .await + .unwrap(); // proof for the index from the tree let (_, proof_before_remove) = test_tree.lookup(0).expect_ok().unwrap(); // traversal path for the index. @@ -772,17 +658,20 @@ mod test { .await .expect("failed to insert nodes"); // Insert the new header - tx - .query_opt( - "INSERT INTO HEADER(height, hash, payload_hash, timestamp, data) VALUES ($1, $2, 't', 0, $3) ON CONFLICT(height) DO UPDATE set data = excluded.data", - [ - sql_param(&2_i64), - sql_param(&"randomString2"), - sql_param(&serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()})), - ], - ) - .await - .unwrap(); + tx.upsert( + "header", + ["height", "hash", "payload_hash", "timestamp", "data"], + ["height"], + [( + 2i64, + "randomString2", + "t", + 0, + serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()}), + )], + ) + .await + .unwrap(); // update saved state height UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, 2) .await @@ -833,17 +722,20 @@ mod test { let mut tx = storage.write().await.unwrap(); // Insert a dummy header - tx - .query_opt( - "INSERT INTO HEADER(height, hash, payload_hash, timestamp, data) VALUES ($1, $2, 't', 0, $3)", - [ - sql_param(&(i as i64)), - sql_param(&format!("hash{i}")), - sql_param(&serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()})), - ], - ) - .await - .unwrap(); + tx.upsert( + "header", + ["height", "hash", "payload_hash", "timestamp", "data"], + ["height"], + [( + i as i64, + format!("hash{i}"), + "t", + 0, + serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()}) + )], + ) + .await + .unwrap(); // update saved state height UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, i) .await @@ -899,17 +791,14 @@ mod test { let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(commitment).unwrap()}); // insert the header with merkle commitment let mut tx = storage.write().await.unwrap(); - tx - .query_opt( - "INSERT INTO 
HEADER(height, hash, payload_hash, timestamp, data) VALUES ($1, $2, 't', 0, $3) ON CONFLICT(height) DO UPDATE set data = excluded.data",
-                [
-                    sql_param(&(block_height as i64)),
-                    sql_param(&"randomString"),
-                    sql_param(&test_data),
-                ],
-            )
-            .await
-            .unwrap();
+        tx.upsert(
+            "header",
+            ["height", "hash", "payload_hash", "timestamp", "data"],
+            ["height"],
+            [(block_height as i64, "randomString", "t", 0, test_data)],
+        )
+        .await
+        .unwrap();
         // proof for the index from the tree
         let (_, proof) = test_tree.lookup(0).expect_ok().unwrap();
         // traversal path for the index.
@@ -965,17 +854,20 @@ mod test {
             test_tree.update(i, i).unwrap();
             let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});
             // insert the header with merkle commitment
-            tx
-            .query_opt(
-                "INSERT INTO HEADER(height, hash, payload_hash, timestamp, data) VALUES ($1, $2, 't', 0, $3) ON CONFLICT(height) DO UPDATE set data = excluded.data",
-                [
-                    sql_param(&(block_height as i64)),
-                    sql_param(&format!("randomString{i}")),
-                    sql_param(&test_data),
-                ],
-            )
-            .await
-            .unwrap();
+            tx.upsert(
+                "header",
+                ["height", "hash", "payload_hash", "timestamp", "data"],
+                ["height"],
+                [(
+                    block_height as i64,
+                    format!("randomString{i}"),
+                    "t",
+                    0,
+                    test_data,
+                )],
+            )
+            .await
+            .unwrap();
             // proof for the index from the tree
             let (_, proof) = test_tree.lookup(i).expect_ok().unwrap();
             // traversal path for the index.
@@ -1026,18 +918,15 @@ mod test {
         let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});
         // insert the header with merkle commitment
-        let tx = storage.write().await.unwrap();
-        tx
-            .query_opt(
-                "INSERT INTO HEADER(height, hash, payload_hash, timestamp, data) VALUES ($1, $2, 't', 0, $3) ON CONFLICT(height) DO UPDATE set data = excluded.data",
-                [
-                    sql_param(&(block_height as i64)),
-                    sql_param(&"randomStringgg"),
-                    sql_param(&test_data),
-                ],
-            )
-            .await
-            .unwrap();
+        let mut tx = storage.write().await.unwrap();
+        tx.upsert(
+            "header",
+            ["height", "hash", "payload_hash", "timestamp", "data"],
+            ["height"],
+            [(block_height as i64, "randomStringgg", "t", 0, test_data)],
+        )
+        .await
+        .unwrap();
         tx.commit().await.unwrap();
         // Querying the path again
         let merkle_proof = storage
@@ -1061,17 +950,14 @@ mod test {
 
         // insert the header with merkle commitment
         let mut tx = storage.write().await.unwrap();
-        tx
-            .query_opt(
-                "INSERT INTO HEADER(height, hash, payload_hash, timestamp, data) VALUES ($1, $2, 't', 0, $3) ON CONFLICT(height) DO UPDATE set data = excluded.data",
-                [
-                    sql_param(&(2_i64)),
-                    sql_param(&"randomHashString"),
-                    sql_param(&test_data),
-                ],
-            )
-            .await
-            .unwrap();
+        tx.upsert(
+            "header",
+            ["height", "hash", "payload_hash", "timestamp", "data"],
+            ["height"],
+            [(2i64, "randomHashString", "t", 0, test_data)],
+        )
+        .await
+        .unwrap();
         UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
             &mut tx,
             proof.clone(),
@@ -1089,11 +975,11 @@ mod test {
             .map(|n| *n as i32)
             .collect::>();
         tx.execute_one(
-            &format!(
+            query(&format!(
                 "DELETE FROM {} WHERE created = 2 and path = $1",
                 MockMerkleTree::state_type()
-            ),
-            [sql_param(&node_path)],
+            ))
+            .bind(node_path),
         )
         .await
         .expect("failed to delete internal node");
@@ -1186,13 +1072,13 @@ mod test {
 // insert the header with merkle commitment
 tx
 .upsert("header", ["height", "hash", "payload_hash", "timestamp", "data"], ["height"],
-                [[
-                    sql_param(&(block_height as i64)),
sql_param(&format!("hash{block_height}")), - sql_param(&"hash"), - sql_param(&0i64), - sql_param(&serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(tree.commitment()).unwrap()})), - ]], + [( + block_height as i64, + format!("hash{block_height}"), + "hash", + 0i64, + serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(tree.commitment()).unwrap()}), + )], ) .await .unwrap(); @@ -1306,15 +1192,20 @@ mod test { let mut tx = storage.write().await.unwrap(); // Insert a header with the tree commitment. - tx - .query_opt( - "INSERT INTO HEADER(height, hash, payload_hash, timestamp, data) VALUES (0, 'hash', 'hash', 0, $1)", - [ - sql_param(&serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()})), - ], - ) - .await - .unwrap(); + tx.upsert( + "header", + ["height", "hash", "payload_hash", "timestamp", "data"], + ["height"], + [( + 0i64, + "hash", + "hash", + 0, + serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()}), + )], + ) + .await + .unwrap(); // Insert Merkle nodes. for i in 0..tree_size { @@ -1360,7 +1251,7 @@ mod test { "DELETE FROM {} WHERE index = $1", MockMerkleTree::state_type() ), - [index], + (index,), ) .await .unwrap(); diff --git a/src/data_source/storage/sql/query.rs b/src/data_source/storage/sql/query.rs deleted file mode 100644 index 2142d97ea..000000000 --- a/src/data_source/storage/sql/query.rs +++ /dev/null @@ -1,268 +0,0 @@ -// Copyright (c) 2022 Espresso Systems (espressosys.com) -// This file is part of the HotShot Query Service library. -// -// This program is free software: you can redistribute it and/or modify it under the terms of the GNU -// General Public License as published by the Free Software Foundation, either version 3 of the -// License, or (at your option) any later version. -// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without -// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -// General Public License for more details. -// You should have received a copy of the GNU General Public License along with this program. If not, -// see . - -//! Immutable query functionality of a SQL database. - -use super::{ - postgres::{self, types::ToSql, Row}, - Transaction, -}; -use crate::{ - availability::{ - BlockId, BlockQueryData, LeafQueryData, PayloadQueryData, QueryablePayload, - VidCommonQueryData, - }, - Header, Leaf, Payload, - QueryError::{self, Missing}, - QueryResult, -}; -use hotshot_types::{ - simple_certificate::QuorumCertificate, - traits::{ - block_contents::{BlockHeader, BlockPayload}, - node_implementation::NodeType, - }, -}; -use std::ops::{Bound, RangeBounds}; - -pub(super) mod availability; -pub(super) mod explorer; -pub(super) mod node; -pub(super) mod state; - -impl<'a> Transaction<'a> { - /// Load a header from storage. - /// - /// This function is similar to `AvailabilityStorage::get_header`, but - /// * does not require the `QueryablePayload` bound that that trait impl does - /// * makes it easier to specify types since the type parameter is on the function and not on a - /// trait impl - /// * allows type conversions for the `id` parameter - /// - /// This more ergonomic interface is useful as loading headers is important for many SQL storage - /// functions, not just the `AvailabilityStorage` interface. 
- async fn load_header( - &self, - id: impl Into> + Send, - ) -> QueryResult> { - let (where_clause, param) = header_where_clause(id.into()); - // ORDER BY h.height ASC ensures that if there are duplicate blocks (this can happen when - // selecting by payload ID, as payloads are not unique), we return the first one. - let query = format!( - "SELECT {HEADER_COLUMNS} - FROM header AS h - WHERE {where_clause} - ORDER BY h.height ASC - LIMIT 1" - ); - let row = self.query_one(&query, [param]).await?; - parse_header::(row) - } -} - -fn parse_leaf(row: Row) -> QueryResult> -where - Types: NodeType, -{ - let leaf = row.try_get("leaf").map_err(|err| QueryError::Error { - message: format!("error extracting leaf from query results: {err}"), - })?; - let leaf: Leaf = serde_json::from_value(leaf).map_err(|err| QueryError::Error { - message: format!("malformed leaf: {err}"), - })?; - - let qc = row.try_get("qc").map_err(|err| QueryError::Error { - message: format!("error extracting QC from query results: {err}"), - })?; - let qc: QuorumCertificate = - serde_json::from_value(qc).map_err(|err| QueryError::Error { - message: format!("malformed QC: {err}"), - })?; - - Ok(LeafQueryData { leaf, qc }) -} - -fn header_where_clause( - id: BlockId, -) -> (&'static str, Box) { - match id { - BlockId::Number(n) => ("h.height = $1", Box::new(n as i64)), - BlockId::Hash(h) => ("h.hash = $1", Box::new(h.to_string())), - BlockId::PayloadHash(h) => ("h.payload_hash = $1", Box::new(h.to_string())), - } -} - -const BLOCK_COLUMNS: &str = - "h.hash AS hash, h.data AS header_data, p.size AS payload_size, p.data AS payload_data"; - -fn parse_block(row: Row) -> QueryResult> -where - Types: NodeType, - Payload: QueryablePayload, -{ - // First, check if we have the payload for this block yet. - let size: Option = row - .try_get("payload_size") - .map_err(|err| QueryError::Error { - message: format!("error extracting payload size from query results: {err}"), - })?; - let payload_data: Option> = - row.try_get("payload_data") - .map_err(|err| QueryError::Error { - message: format!("error extracting payload data from query results: {err}"), - })?; - let (size, payload_data) = size.zip(payload_data).ok_or(Missing)?; - let size = size as u64; - - // Reconstruct the full header. - let header_data = row - .try_get("header_data") - .map_err(|err| QueryError::Error { - message: format!("error extracting header data from query results: {err}"), - })?; - let header: Header = - serde_json::from_value(header_data).map_err(|err| QueryError::Error { - message: format!("malformed header: {err}"), - })?; - - // Reconstruct the full block payload. - let payload = Payload::::from_bytes(&payload_data, header.metadata()); - - // Reconstruct the query data by adding metadata. 
- let hash: String = row.try_get("hash").map_err(|err| QueryError::Error { - message: format!("error extracting block hash from query results: {err}"), - })?; - let hash = hash.parse().map_err(|err| QueryError::Error { - message: format!("malformed block hash: {err}"), - })?; - - Ok(BlockQueryData { - num_transactions: payload.len(header.metadata()) as u64, - header, - payload, - size, - hash, - }) -} - -const PAYLOAD_COLUMNS: &str = BLOCK_COLUMNS; - -fn parse_payload(row: Row) -> QueryResult> -where - Types: NodeType, - Payload: QueryablePayload, -{ - parse_block(row).map(PayloadQueryData::from) -} - -const VID_COMMON_COLUMNS: &str = "h.height AS height, h.hash AS block_hash, h.payload_hash AS payload_hash, v.common AS common_data"; - -fn parse_vid_common(row: Row) -> QueryResult> -where - Types: NodeType, - Payload: QueryablePayload, -{ - let height = row - .try_get::<_, i64>("height") - .map_err(|err| QueryError::Error { - message: format!("error extracting height from query results: {err}"), - })? as u64; - let block_hash: String = row.try_get("block_hash").map_err(|err| QueryError::Error { - message: format!("error extracting block_hash from query results: {err}"), - })?; - let block_hash = block_hash.parse().map_err(|err| QueryError::Error { - message: format!("malformed block hash: {err}"), - })?; - let payload_hash: String = row - .try_get("payload_hash") - .map_err(|err| QueryError::Error { - message: format!("error extracting payload_hash from query results: {err}"), - })?; - let payload_hash = payload_hash.parse().map_err(|err| QueryError::Error { - message: format!("malformed payload hash: {err}"), - })?; - let common_data: Vec = row - .try_get("common_data") - .map_err(|err| QueryError::Error { - message: format!("error extracting common_data from query results: {err}"), - })?; - let common = bincode::deserialize(&common_data).map_err(|err| QueryError::Error { - message: format!("malformed VID common data: {err}"), - })?; - Ok(VidCommonQueryData { - height, - block_hash, - payload_hash, - common, - }) -} - -const HEADER_COLUMNS: &str = "h.data AS data"; - -fn parse_header(row: Row) -> QueryResult> -where - Types: NodeType, -{ - // Reconstruct the full header. - let data = row.try_get("data").map_err(|err| QueryError::Error { - message: format!("error extracting header data from query results: {err}"), - })?; - serde_json::from_value(data).map_err(|err| QueryError::Error { - message: format!("malformed header: {err}"), - }) -} - -/// Convert range bounds to a SQL where clause constraining a given column. -/// -/// Returns the where clause as a string and a list of query parameters. We assume that there are no -/// other parameters in the query; that is, parameters in the where clause will start from $1. 
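// For example (a sketch derived from the function below, which this patch
// deletes):
//
//     let (clause, params) = bounds_to_where_clause(10..20, "height");
//     assert_eq!(clause, " WHERE height >= $1 AND height < $2");
//     assert_eq!(params, vec![10, 20]);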
-fn bounds_to_where_clause(range: R, column: &str) -> (String, Vec) -where - R: RangeBounds, -{ - let mut bounds = vec![]; - let mut params = vec![]; - - match range.start_bound() { - Bound::Included(n) => { - params.push(*n as i64); - bounds.push(format!("{column} >= ${}", params.len())); - } - Bound::Excluded(n) => { - params.push(*n as i64); - bounds.push(format!("{column} > ${}", params.len())); - } - Bound::Unbounded => {} - } - match range.end_bound() { - Bound::Included(n) => { - params.push(*n as i64); - bounds.push(format!("{column} <= ${}", params.len())); - } - Bound::Excluded(n) => { - params.push(*n as i64); - bounds.push(format!("{column} < ${}", params.len())); - } - Bound::Unbounded => {} - } - - let mut where_clause = bounds.join(" AND "); - if !where_clause.is_empty() { - where_clause = format!(" WHERE {where_clause}"); - } - - (where_clause, params) -} - -pub(super) fn sql_param(param: &T) -> &(dyn ToSql + Sync) { - param -} diff --git a/src/data_source/storage/sql/transaction.rs b/src/data_source/storage/sql/transaction.rs index 8a92f93b9..db749cc4e 100644 --- a/src/data_source/storage/sql/transaction.rs +++ b/src/data_source/storage/sql/transaction.rs @@ -19,284 +19,228 @@ //! transaction. use super::{ - postgres::{types::BorrowToSql, Client, Row, ToStatement}, - query::{ - sql_param, - state::{HashTableRow, Node, ParseError}, + queries::{ + state::{build_hash_batch_insert, Node}, + DecodeError, }, + Database, Db, }; use crate::{ availability::{ BlockQueryData, LeafQueryData, QueryableHeader, QueryablePayload, UpdateAvailabilityData, VidCommonQueryData, }, - data_source::{ - storage::pruning::PrunedHeightStorage, - update::{self, ReadOnly}, - }, + data_source::{storage::pruning::PrunedHeightStorage, update}, merklized_state::{MerklizedState, UpdateStateData}, - task::Task, types::HeightIndexed, - Header, Payload, - QueryError::{self, NotFound}, - QueryResult, VidShare, + Header, Payload, QueryError, VidShare, }; use anyhow::{bail, ensure, Context}; use ark_serialize::CanonicalSerialize; -use async_std::{ - sync::{Arc, MutexGuard}, - task::sleep, -}; +use async_std::task::sleep; use async_trait::async_trait; -use bit_vec::BitVec; use committable::Committable; -use derivative::Derivative; -use derive_more::{Deref, DerefMut, From}; -use futures::{ - future::Future, - stream::{BoxStream, StreamExt, TryStreamExt}, -}; +use derive_more::{Deref, DerefMut}; +use futures::{future::Future, stream::TryStreamExt}; use hotshot_types::traits::{ block_contents::BlockHeader, node_implementation::NodeType, EncodeBytes, }; -use itertools::{izip, Itertools}; +use itertools::Itertools; use jf_merkle_tree::prelude::{MerkleNode, MerkleProof}; +use sqlx::{pool::Pool, types::BitVec, Encode, Execute, FromRow, Type}; use std::{ collections::{HashMap, HashSet}, - fmt::Display, + marker::PhantomData, time::Duration, }; -#[derive(Debug, Deref, DerefMut)] -pub(super) struct Connection { - #[deref] - #[deref_mut] - client: Arc, +pub use sqlx::Executor; - // When the connection is dropped in a synchronous context, we spawn a task to revert the - // in-progress transaction, so the revert can run asynchronously without blocking the dropping - // thread. If such a revert is in progress, a handle to the task will be stored here. The handle - // _must_ be awaited to allow the revert to finish before the connection can be used again. 
- revert: Option>>, +pub type Query<'q> = sqlx::query::Query<'q, Db, ::Arguments<'q>>; +pub type QueryAs<'q, T> = sqlx::query::QueryAs<'q, Db, T, ::Arguments<'q>>; + +pub fn query(sql: &str) -> Query<'_> { + sqlx::query(sql) } -impl From for Connection { - fn from(client: Client) -> Self { - Self { - client: Arc::new(client), - revert: None, - } - } +pub fn query_as<'q, T>(sql: &'q str) -> QueryAs<'q, T> +where + T: for<'r> FromRow<'r, ::Row>, +{ + sqlx::query_as(sql) } -impl Connection { - /// Prepare the connection for use. - /// - /// This will wait for any async operations to finish which were spawned when the connection was - /// dropped in a synchronous context. This _must_ be called the first time the connection is - /// invoked in an async context after being dropped in a sync context. - async fn acquire(&mut self) -> anyhow::Result<()> { - if let Some(revert) = self.revert.take() { - // If a revert was started when this connection was released in a synchronous context, - // we must wait for it to finish before reusing the connection. - revert.join().await?; - } - Ok(()) - } +/// Marker type indicating a transaction with read-write access to the database. +#[derive(Clone, Copy, Debug, Default)] +pub struct Write; - /// Spawn a revert to run asynchronously. - fn spawn_revert(&mut self) { - // Consistency check. - assert!( - self.revert.is_none(), - "attempting to queue revert while a queued revert is in progress; this should not be possible", - ); +/// Marker type indicating a transaction with read-only access to the database. +#[derive(Clone, Copy, Debug, Default)] +pub struct Read; - // Get a client that is not connected to the lifetime of this reference, so we can run the - // revert command in the background. - let client = self.client.clone(); - self.revert = Some(Task::spawn("revert postgres transaction", async move { - client.batch_execute("ROLLBACK").await?; - Ok(()) - })); +/// Trait for marker types indicating what type of access a transaction has to the database. +pub trait TransactionMode: Send + Sync { + fn begin( + conn: &mut ::Connection, + ) -> impl Future> + Send; +} + +impl TransactionMode for Write { + async fn begin(conn: &mut ::Connection) -> anyhow::Result<()> { + conn.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE") + .await?; + Ok(()) } +} - #[cfg(test)] - pub(super) async fn vacuum(&mut self) -> anyhow::Result<()> { - self.acquire().await?; - self.batch_execute("VACUUM").await?; +impl TransactionMode for Read { + async fn begin(conn: &mut ::Connection) -> anyhow::Result<()> { + conn.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY, DEFERRABLE") + .await?; Ok(()) } } /// An atomic SQL transaction. 
-#[derive(Derivative, From)]
-#[derivative(Debug)]
-pub struct Transaction<'a> {
-    inner: MutexGuard<'a, Connection>,
-    finalized: bool,
+#[derive(Debug, Deref, DerefMut)]
+pub struct Transaction<Mode> {
+    #[deref]
+    #[deref_mut]
+    inner: sqlx::Transaction<'static, Db>,
+    _mode: PhantomData<Mode>,
 }
 
-impl<'a> Transaction<'a> {
-    pub(super) async fn read(
-        mut inner: MutexGuard<'a, Connection>,
-    ) -> anyhow::Result<ReadOnly<Self>> {
-        inner.acquire().await?;
-        inner
-            .batch_execute("BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE")
-            .await?;
+impl<Mode: TransactionMode> Transaction<Mode> {
+    pub(super) async fn new(pool: &Pool<Db>) -> anyhow::Result<Self> {
+        let mut tx = pool.begin().await?;
+        Mode::begin(tx.as_mut()).await?;
         Ok(Self {
-            inner,
-            finalized: false,
-        }
-        .into())
-    }
-
-    pub(super) async fn write(mut inner: MutexGuard<'a, Connection>) -> anyhow::Result<Self> {
-        inner.acquire().await?;
-        inner
-            .batch_execute("BEGIN ISOLATION LEVEL SERIALIZABLE")
-            .await?;
-        Ok(Self {
-            inner,
-            finalized: false,
+            inner: tx,
+            _mode: Default::default(),
         })
     }
 }
 
-impl<'a> update::Transaction for Transaction<'a> {
-    async fn commit(mut self) -> anyhow::Result<()> {
-        self.inner.batch_execute("COMMIT").await?;
-        self.finalized = true;
+impl<Mode: TransactionMode> update::Transaction for Transaction<Mode> {
+    async fn commit(self) -> anyhow::Result<()> {
+        self.inner.commit().await?;
         Ok(())
     }
-    fn revert(mut self) -> impl Future + Send {
+    fn revert(self) -> impl Future + Send {
         async move {
-            self.inner.batch_execute("ROLLBACK").await.unwrap();
-            self.finalized = true;
+            self.inner.rollback().await.unwrap();
         }
     }
 }
 
-impl<'a> Drop for Transaction<'a> {
-    fn drop(&mut self) {
-        if !self.finalized {
-            // Since `drop` is synchronous, we can't execute the asynchronous revert process here,
-            // at least not without blocking the current thread (which may be an async executor
-            // thread, blocking other unrelated futures and causing deadlocks). Instead, we will
-            // revert the transaction asynchronously.
-            self.inner.spawn_revert();
-        }
-    }
-}
-
-/// Low-level, general database queries and mutation.
-impl<'a> Transaction<'a> {
-    pub async fn query<T, P>(
-        &self,
-        query: &T,
-        params: P,
-    ) -> QueryResult<BoxStream<'static, QueryResult<Row>>>
-    where
-        T: ?Sized + ToStatement + Sync,
-        P: IntoIterator + Send,
-        P::IntoIter: ExactSizeIterator,
-        P::Item: BorrowToSql,
-    {
-        Ok(self
-            .inner
-            .query_raw(query, params)
-            .await
-            .map_err(postgres_err)?
-            .map_err(postgres_err)
-            .boxed())
-    }
-
-    /// Query the underlying SQL database with no parameters.
-    pub async fn query_static<T>(
-        &self,
-        query: &T,
-    ) -> QueryResult<BoxStream<'static, QueryResult<Row>>>
-    where
-        T: ?Sized + ToStatement + Sync,
-    {
-        self.query::<T, [_; 0]>(query, []).await
-    }
-
-    /// Query the underlying SQL database, returning exactly one result or failing.
-    pub async fn query_one<T, P>(&self, query: &T, params: P) -> QueryResult<Row>
+/// A collection of parameters which can be bound to a SQL query.
+///
+/// This trait allows us to carry around heterogeneous lists of parameters (e.g. tuples) and bind them
+/// to a query at the last moment before executing. This means we can manipulate the parameters
+/// independently of the query before executing it. For example, by requiring a trait bound of
+/// `Params<'p> + Clone`, we get a list (or tuple) of parameters which can be cloned and then bound
+/// to a query, which allows us to keep a copy of the parameters around in order to retry the query
+/// if it fails.
+///
+/// # Lifetimes
+///
+/// A SQL [`Query`] with lifetime `'q` borrows from both its SQL statement (`&'q str`) and its
+/// Sometimes, though, it is necessary for the statement and its parameters to have different (but
+/// overlapping) lifetimes. For example, the parameters might be passed in and owned by the caller,
+/// while the query string is constructed in the callee and its lifetime is limited to the callee
+/// scope. (See for example the [`upsert`](Transaction::upsert) function which does exactly this.)
+///
+/// We could rectify this situation with a trait bound like `P: for<'q> Params<'q>`, meaning `P`
+/// must be bindable to a query with a lifetime chosen by the callee. However, when `P` is an
+/// associated type, such as an element of an iterator, as in
+/// `<I as IntoIterator>::Item: for<'q> Params<'q>`, [a current limitation](https://blog.rust-lang.org/2022/10/28/gats-stabilization.html#implied-static-requirement-from-higher-ranked-trait-bounds)
+/// in the Rust compiler then requires `P: 'static`, which we don't necessarily want: the caller
+/// should be able to pass in a reference to avoid expensive cloning.
+///
+/// So, instead, we work around this by making it explicit in the [`Params`] trait that the
+/// lifetime of the query we're binding to (`'q`) may be different from the lifetime of the
+/// parameters (`'p`), as long as the parameters outlive the duration of the query (the `'p: 'q`
+/// bound on the [`bind`](Self::bind) function).
+pub trait Params<'p> {
+    fn bind<'q>(self, q: Query<'q>) -> Query<'q>
     where
-        T: ?Sized + ToStatement + Sync,
-        P: IntoIterator + Send,
-        P::IntoIter: ExactSizeIterator,
-        P::Item: BorrowToSql,
-    {
-        self.query_opt(query, params).await?.ok_or(NotFound)
-    }
+        'p: 'q;
+}
 
-    /// Query the underlying SQL database with no parameters, returning exactly one result or
-    /// failing.
-    pub async fn query_one_static<T>(&self, query: &T) -> QueryResult<Row>
-    where
-        T: ?Sized + ToStatement + Sync,
-    {
-        self.query_one::<T, [i64; 0]>(query, []).await
-    }
+/// A collection of parameters with a statically known length.
+///
+/// This is a simple trick for enforcing at compile time that a list of parameters has a certain
+/// length, such as matching the length of a list of column names. This can prevent easy mistakes
+/// like leaving out a parameter. It is implemented for tuples up to length 8.
+pub trait FixedLengthParams<'p, const N: usize>: Params<'p> {}
+
+macro_rules! impl_tuple_params {
+    ($n:literal, ($($t:ident,)+)) => {
+        impl<'p, $($t),+> Params<'p> for ($($t,)+)
+        where $(
+            $t: 'p + for<'q> Encode<'q, Db> + Type<Db>
+        ),+ {
+            fn bind<'q>(self, q: Query<'q>) -> Query<'q>
+            where
+                'p: 'q
+            {
+                #[allow(non_snake_case)]
+                let ($($t,)+) = self;
+                q $(
+                    .bind($t)
+                )+
+            }
+        }
 
-    /// Query the underlying SQL database, returning zero or one results.
-    pub async fn query_opt<T, P>(&self, query: &T, params: P) -> QueryResult<Option<Row>>
-    where
-        T: ?Sized + ToStatement + Sync,
-        P: IntoIterator + Send,
-        P::IntoIter: ExactSizeIterator,
-        P::Item: BorrowToSql,
-    {
-        self.query(query, params).await?.try_next().await
-    }
+        impl<'p, $($t),+> FixedLengthParams<'p, $n> for ($($t,)+)
+        where $(
+            $t: 'p + for<'q> Encode<'q, Db> + Type<Db>
+        ),+ {
+        }
+    };
+}
 
-    /// Query the underlying SQL database with no parameters, returning zero or one results.
-    pub async fn query_opt_static<T>(&self, query: &T) -> QueryResult<Option<Row>>
-    where
-        T: ?Sized + ToStatement + Sync,
-    {
-        self.query_opt::<T, [i64; 0]>(query, []).await
-    }
+impl_tuple_params!(1, (T,));
+impl_tuple_params!(2, (T1, T2,));
+impl_tuple_params!(3, (T1, T2, T3,));
+impl_tuple_params!(4, (T1, T2, T3, T4,));
+impl_tuple_params!(5, (T1, T2, T3, T4, T5,));
+impl_tuple_params!(6, (T1, T2, T3, T4, T5, T6,));
+impl_tuple_params!(7, (T1, T2, T3, T4, T5, T6, T7,));
+impl_tuple_params!(8, (T1, T2, T3, T4, T5, T6, T7, T8,));
 
-    /// Execute a statement against the underlying database.
-    ///
-    /// The results of the statement will be reflected immediately in future statements made within
-    /// this transaction, but will not be reflected in the underlying database until the transaction
-    /// is committed with [`commit`](update::Transaction::commit).
-    pub async fn execute<T, P>(&mut self, statement: &T, params: P) -> anyhow::Result<u64>
+impl<'p, T> Params<'p> for Vec<T>
+where
+    T: Params<'p>,
+{
+    fn bind<'q>(self, mut q: Query<'q>) -> Query<'q>
     where
-        T: ?Sized + ToStatement,
-        P: IntoIterator,
-        P::IntoIter: ExactSizeIterator,
-        P::Item: BorrowToSql,
+        'p: 'q,
     {
-        Ok(self.inner.execute_raw(statement, params).await?)
+        for params in self {
+            q = params.bind(q);
+        }
+        q
     }
+}
 
+/// Low-level, general database queries and mutation.
+impl<Mode: TransactionMode> Transaction<Mode> {
     /// Execute a statement that is expected to modify exactly one row.
     ///
     /// Returns an error if the database is not modified.
-    pub async fn execute_one<T, P>(&mut self, statement: &T, params: P) -> anyhow::Result<()>
+    pub async fn execute_one<'q, E>(&mut self, statement: E) -> anyhow::Result<()>
     where
-        T: ?Sized + ToStatement + Display,
-        P: IntoIterator,
-        P::IntoIter: ExactSizeIterator,
-        P::Item: BorrowToSql,
+        E: 'q + Execute<'q, Db>,
     {
-        let nrows = self.execute_many(statement, params).await?;
+        let nrows = self.execute_many(statement).await?;
         if nrows > 1 {
             // If more than one row is affected, we don't return an error, because clearly
             // _something_ happened and modified the database. So we don't necessarily want the
             // caller to retry. But we do log an error, because it seems the query did something
             // different than the caller intended.
-            tracing::error!(
-                %statement,
-                "statement modified more rows ({nrows}) than expected (1)"
-            );
+            tracing::error!("statement modified more rows ({nrows}) than expected (1)");
         }
         Ok(())
     }
@@ -304,21 +248,18 @@ impl<'a> Transaction<'a> {
     /// Execute a statement that is expected to modify exactly one row.
     ///
     /// Returns an error if the database is not modified. Retries several times before failing.
-    pub async fn execute_one_with_retries<T, P>(
+    pub async fn execute_one_with_retries<'q>(
         &mut self,
-        statement: &T,
-        params: P,
-    ) -> anyhow::Result<()>
-    where
-        T: ?Sized + ToStatement + Display,
-        P: IntoIterator + Clone,
-        P::IntoIter: ExactSizeIterator,
-        P::Item: BorrowToSql,
-    {
+        statement: &'q str,
+        params: impl Params<'q> + Clone,
+    ) -> anyhow::Result<()> {
         let interval = Duration::from_secs(1);
         let mut retries = 5;
 
-        while let Err(err) = self.execute_one(statement, params.clone()).await {
+        while let Err(err) = self
+            .execute_one(params.clone().bind(query(statement)))
+            .await
+        {
             tracing::error!(
                 %statement,
                 "error in statement execution ({retries} tries remaining): {err}"
@@ -336,40 +277,34 @@ impl<'a> Transaction<'a> {
     /// Execute a statement that is expected to modify at least one row.
     ///
     /// Returns an error if the database is not modified.
-    pub async fn execute_many<T, P>(&mut self, statement: &T, params: P) -> anyhow::Result<u64>
+    pub async fn execute_many<'q, E>(&mut self, statement: E) -> anyhow::Result<u64>
     where
-        T: ?Sized + ToStatement + Display,
-        P: IntoIterator,
-        P::IntoIter: ExactSizeIterator,
-        P::Item: BorrowToSql,
+        E: 'q + Execute<'q, Db>,
     {
-        let nrows = self.execute(statement, params).await?;
-        ensure!(
-            nrows > 0,
-            "statement failed: 0 rows affected. Statement: {statement}"
-        );
+        let nrows = self.execute(statement).await?.rows_affected();
+        ensure!(nrows > 0, "statement failed: 0 rows affected");
         Ok(nrows)
     }
 
     /// Execute a statement that is expected to modify at least one row.
     ///
     /// Returns an error if the database is not modified. Retries several times before failing.
-    pub async fn execute_many_with_retries<T, P>(
+    pub async fn execute_many_with_retries<'q, 'p>(
         &mut self,
-        statement: &T,
-        params: P,
+        statement: &'q str,
+        params: impl Params<'p> + Clone,
     ) -> anyhow::Result<u64>
     where
-        T: ?Sized + ToStatement + Display,
-        P: IntoIterator + Clone,
-        P::IntoIter: ExactSizeIterator,
-        P::Item: BorrowToSql,
+        'p: 'q,
     {
         let interval = Duration::from_secs(1);
         let mut retries = 5;
 
         loop {
-            match self.execute_many(statement, params.clone()).await {
+            match self
+                .execute_many(params.clone().bind(query(statement)))
+                .await
+            {
                 Ok(nrows) => return Ok(nrows),
                 Err(err) => {
                     tracing::error!(
@@ -386,15 +321,16 @@ impl<'a> Transaction<'a> {
         }
     }
 
-    pub async fn upsert<const N: usize, P>(
+    pub async fn upsert<'p, const N: usize, R>(
         &mut self,
         table: &str,
         columns: [&str; N],
         pk: impl IntoIterator<Item = &str>,
-        rows: impl IntoIterator<Item = [P; N]>,
+        rows: R,
     ) -> anyhow::Result<()>
     where
-        P: BorrowToSql + Clone,
+        R: IntoIterator,
+        R::Item: 'p + FixedLengthParams<'p, N> + Clone,
     {
         let set_columns = columns
             .iter()
@@ -412,7 +348,7 @@
             let row_params = (start..end).map(|i| format!("${}", i + 1)).join(",");
 
             values.push(format!("({row_params})"));
-            params.extend(entries);
+            params.push(entries);
             num_rows += 1;
         }
 
@@ -440,10 +376,10 @@
 }
 
 /// Query service specific mutations.
-impl<'a> Transaction<'a> {
+impl Transaction<Write> {
     /// Delete a batch of data for pruning.
     pub(super) async fn delete_batch(&mut self, height: u64) -> anyhow::Result<()> {
-        self.execute("DELETE FROM header WHERE height <= $1", &[&(height as i64)])
+        self.execute(query("DELETE FROM header WHERE height <= $1").bind(height as i64))
             .await?;
         self.save_pruned_height(height).await?;
         Ok(())
@@ -457,14 +393,14 @@ impl<'a> Transaction<'a> {
             "pruned_height",
             ["id", "last_height"],
             ["id"],
-            [[sql_param(&(1_i32)), sql_param(&(height as i64))]],
+            [(1i32, height as i64)],
         )
         .await
     }
 }
 
 #[async_trait]
-impl<'a, Types> UpdateAvailabilityData<Types> for Transaction<'a>
+impl<Types> UpdateAvailabilityData<Types> for Transaction<Write>
 where
     Types: NodeType,
     Payload<Types>: QueryablePayload<Types>,
@@ -479,20 +415,20 @@ where
             "header",
             ["height", "hash", "payload_hash", "data", "timestamp"],
             ["height"],
-            [[
-                sql_param(&(leaf.height() as i64)),
-                sql_param(&leaf.block_hash().to_string()),
-                sql_param(&leaf.leaf().block_header().payload_commitment().to_string()),
-                sql_param(&header_json),
-                sql_param(&(leaf.leaf().block_header().timestamp() as i64)),
-            ]],
+            [(
+                leaf.height() as i64,
+                leaf.block_hash().to_string(),
+                leaf.leaf().block_header().payload_commitment().to_string(),
+                header_json,
+                leaf.leaf().block_header().timestamp() as i64,
+            )],
         )
         .await?;
 
         // Similarly, we can initialize the payload table with a null payload, which can help us
         // distinguish between blocks that haven't been produced yet and blocks we haven't received
         // yet when answering queries.
-        self.upsert("payload", ["height"], ["height"], [[leaf.height() as i64]])
+        self.upsert("payload", ["height"], ["height"], [(leaf.height() as i64,)])
             .await?;
 
         // Finally, we insert the leaf itself, which references the header row we created.
@@ -503,13 +439,13 @@ where
             "leaf",
             ["height", "hash", "block_hash", "leaf", "qc"],
             ["height"],
-            [[
-                sql_param(&(leaf.height() as i64)),
-                sql_param(&leaf.hash().to_string()),
-                sql_param(&leaf.block_hash().to_string()),
-                sql_param(&leaf_json),
-                sql_param(&qc_json),
-            ]],
+            [(
+                leaf.height() as i64,
+                leaf.hash().to_string(),
+                leaf.block_hash().to_string(),
+                leaf_json,
+                qc_json,
+            )],
         )
         .await?;
 
@@ -524,43 +460,23 @@ where
             "payload",
             ["height", "data", "size"],
             ["height"],
-            [[
-                sql_param(&(block.height() as i64)),
-                sql_param(&payload.as_ref()),
-                sql_param(&(block.size() as i32)),
-            ]],
+            [(block.height() as i64, payload.as_ref(), block.size() as i32)],
         )
         .await?;
 
-        // Index the transactions in the block. For each transaction, collect, separately, its
-        // hash, height, and index. These items all have different types, so we collect them into
-        // different vecs.
-        let mut tx_hashes = vec![];
-        let mut tx_block_heights = vec![];
-        let mut tx_indexes = vec![];
+        // Index the transactions in the block.
+        let mut rows = vec![];
         for (txn_ix, txn) in block.enumerate() {
             let txn_ix =
                 serde_json::to_value(&txn_ix).context("failed to serialize transaction index")?;
-            tx_hashes.push(txn.commit().to_string());
-            tx_block_heights.push(block.height() as i64);
-            tx_indexes.push(txn_ix);
+            rows.push((txn.commit().to_string(), block.height() as i64, txn_ix));
         }
 
-        if !tx_hashes.is_empty() {
+        if !rows.is_empty() {
             self.upsert(
                 "transaction",
                 ["hash", "block_height", "index"],
                 ["block_height", "index"],
-                // Now that we have the transaction hashes, block heights, and indexes collected in
-                // memory, we can combine them all into a single vec using type erasure: all the
-                // values get converted to `&dyn ToSql`.
-                // The references all borrow from one of `tx_hashes`, `tx_block_heights`, or
-                // `tx_indexes`, which all outlive this function call, so the lifetimes work out.
-                izip!(
-                    tx_hashes.iter().map(sql_param),
-                    tx_block_heights.iter().map(sql_param),
-                    tx_indexes.iter().map(sql_param),
-                )
-                .map(|(hash, height, index)| [hash, height, index]),
+                rows,
             )
             .await?;
         }
@@ -581,11 +497,7 @@ where
                 "vid",
                 ["height", "common", "share"],
                 ["height"],
-                [[
-                    sql_param(&(common.height() as i64)),
-                    sql_param(&common_data),
-                    sql_param(&share_data),
-                ]],
+                [(common.height() as i64, common_data, share_data)],
             )
             .await
         } else {
@@ -596,10 +508,7 @@ where
                 "vid",
                 ["height", "common"],
                 ["height"],
-                [[
-                    sql_param(&(common.height() as i64)),
-                    sql_param(&common_data),
-                ]],
+                [(common.height() as i64, common_data)],
             )
             .await
         }
     }
 }
 
 #[async_trait]
-impl<'a, Types: NodeType, State: MerklizedState<Types, ARITY>, const ARITY: usize>
-    UpdateStateData<Types, State, ARITY> for Transaction<'a>
+impl<Types: NodeType, State: MerklizedState<Types, ARITY>, const ARITY: usize>
+    UpdateStateData<Types, State, ARITY> for Transaction<Write>
 {
     async fn set_last_state_height(&mut self, height: usize) -> anyhow::Result<()> {
         self.upsert(
             "last_merklized_state_height",
             ["id", "height"],
             ["id"],
-            [[sql_param(&(1_i32)), sql_param(&(height as i64))]],
+            [(1i32, height as i64)],
         )
         .await?;
 
@@ -644,7 +553,8 @@ impl<'a, Types: NodeType, State: MerklizedState, const ARITY: usiz
         for node in path.iter() {
             match node {
                 MerkleNode::Empty => {
-                    let index = serde_json::to_value(pos.clone()).map_err(ParseError::Serde)?;
+                    let index = serde_json::to_value(pos.clone())
+                        .decode_error("malformed merkle position")?;
                     // The node path represents the sequence of nodes from the root down to a specific node.
                     // Therefore, the traversal path needs to be reversed
                     // The root node path is an empty array.
@@ -668,12 +578,14 @@ impl<'a, Types: NodeType, State: MerklizedState, const ARITY: usiz
                     // Serialize the leaf node hash value into a vector
                     value
                         .serialize_compressed(&mut leaf_commit)
-                        .map_err(ParseError::Serialize)?;
+                        .decode_error("malformed merkle leaf commitment")?;
 
                     let path = traversal_path.clone().rev().collect();
-                    let index = serde_json::to_value(pos.clone()).map_err(ParseError::Serde)?;
-                    let entry = serde_json::to_value(elem).map_err(ParseError::Serde)?;
+                    let index = serde_json::to_value(pos.clone())
+                        .decode_error("malformed merkle position")?;
+                    let entry =
+                        serde_json::to_value(elem).decode_error("malformed merkle element")?;
 
                     nodes.push((
                         Node {
@@ -693,7 +605,7 @@ impl<'a, Types: NodeType, State: MerklizedState, const ARITY: usiz
                     let mut branch_hash = Vec::new();
                     value
                         .serialize_compressed(&mut branch_hash)
-                        .map_err(ParseError::Serialize)?;
+                        .decode_error("malformed merkle branch hash")?;
 
                     // We only insert the non-empty children in the children field of the table
                     // BitVec is used to separate out Empty children positions
@@ -711,7 +623,7 @@ impl<'a, Types: NodeType, State: MerklizedState, const ARITY: usiz
                         let mut hash = Vec::new();
                         value
                             .serialize_compressed(&mut hash)
-                            .map_err(ParseError::Serialize)?;
+                            .decode_error("malformed merkle node hash")?;
 
                         children_values.push(hash);
                         // Mark the entry as 1 in bitvec to indiciate a non-empty child
@@ -747,14 +659,10 @@ impl<'a, Types: NodeType, State: MerklizedState, const ARITY: usiz
         // insert all the hashes into database
         // It returns all the ids inserted in the order they were inserted
         // We use the hash ids to insert all the nodes
-        let (params, batch_hash_insert_stmt) = HashTableRow::build_batch_insert(&hashes);
-
-        // Batch insert all the hashes
-        let nodes_hash_ids: HashMap<Vec<u8>, i32> = self
-            .inner
-            .query_raw(&batch_hash_insert_stmt, params)
-            .await?
-            .map_ok(|r| (r.get(1), r.get(0)))
+        let (query, sql) = build_hash_batch_insert(&hashes)?;
+        let nodes_hash_ids: HashMap<Vec<u8>, i32> = query
+            .query_as(&sql)
+            .fetch(self.as_mut())
             .try_collect()
             .await?;
 
@@ -777,33 +685,21 @@ impl<'a, Types: NodeType, State: MerklizedState, const ARITY: usiz
                 node.children = Some(children_hashes);
             }
         }
-        let nodes = nodes.into_iter().map(|(n, _, _)| n).collect::<Vec<_>>();
-        let (params, batch_stmt) = Node::build_batch_insert(name, &nodes);
-
-        // Batch insert all the child hashes
-        let rows_inserted = self.inner.query_raw(&batch_stmt, params).await?;
-
-        if rows_inserted.count().await != path.len() {
-            bail!("failed to insert all merkle nodes");
-        }
-
+        Node::upsert(name, nodes.into_iter().map(|(n, _, _)| n), self).await?;
         Ok(())
     }
 }
 
 #[async_trait]
-impl<'a> PrunedHeightStorage for Transaction<'a> {
-    async fn load_pruned_height(&self) -> anyhow::Result<Option<u64>> {
-        let row = self
-            .query_opt_static("SELECT last_height FROM pruned_height ORDER BY id DESC LIMIT 1")
-            .await?;
-        let height = row.map(|row| row.get::<_, i64>(0) as u64);
-        Ok(height)
-    }
-}
-
-fn postgres_err(err: tokio_postgres::Error) -> QueryError {
-    QueryError::Error {
-        message: format!("postgres error: {err:#}"),
+impl<Mode: TransactionMode> PrunedHeightStorage for Transaction<Mode> {
+    async fn load_pruned_height(&mut self) -> anyhow::Result<Option<u64>> {
+        let Some((height,)) =
+            query_as::<(i64,)>("SELECT last_height FROM pruned_height ORDER BY id DESC LIMIT 1")
+                .fetch_optional(self.as_mut())
+                .await?
+        else {
+            return Ok(None);
+        };
+        Ok(Some(height as u64))
     }
 }
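A hedged sketch of calling the new parameter-binding API shown above (not part of the patch; the statement, table, and values are illustrative):

```rust
// Hypothetical usage sketch of the APIs introduced in this file.
async fn demo(tx: &mut Transaction<Write>) -> anyhow::Result<()> {
    // `Params` tuples are `Clone`, so the same parameters can be re-bound to a
    // fresh query on every retry attempt.
    tx.execute_many_with_retries(
        "UPDATE header SET timestamp = $1 WHERE height = $2",
        (12345_i64, 42_i64),
    )
    .await?;

    // Rows are fixed-length tuples: a tuple whose arity does not match the
    // column list is a compile-time error rather than a runtime one.
    tx.upsert(
        "pruned_height",
        ["id", "last_height"],
        ["id"],
        [(1_i32, 42_i64)],
    )
    .await?;
    Ok(())
}
```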
diff --git a/src/data_source/update.rs b/src/data_source/update.rs
index 5051f136e..c4ca862ca 100644
--- a/src/data_source/update.rs
+++ b/src/data_source/update.rs
@@ -18,7 +18,6 @@ use crate::{
     Leaf, Payload,
 };
 use async_trait::async_trait;
-use derive_more::{Deref, From};
 use futures::future::Future;
 use hotshot::types::{Event, EventType};
 use hotshot_types::event::LeafInfo;
@@ -210,20 +209,3 @@ pub trait Transaction: Send + Sync {
     fn commit(self) -> impl Future<Output = anyhow::Result<()>> + Send;
     fn revert(self) -> impl Future<Output = ()> + Send;
 }
-
-/// A wrapper around a [`Transaction`] that permits immutable operations only.
-#[derive(Debug, Deref, From)]
-pub struct ReadOnly<T>(T);
-
-impl<T> Transaction for ReadOnly<T>
-where
-    T: Transaction,
-{
-    async fn commit(self) -> anyhow::Result<()> {
-        self.0.commit().await
-    }
-
-    fn revert(self) -> impl Future<Output = ()> + Send {
-        self.0.revert()
-    }
-}
diff --git a/src/explorer.rs b/src/explorer.rs
index 2cdfd5963..ed3ca5a6f 100644
--- a/src/explorer.rs
+++ b/src/explorer.rs
@@ -11,6 +11,7 @@
 // see <https://www.gnu.org/licenses/>.
 
 pub(crate) mod currency;
+pub(crate) mod data_source;
 pub(crate) mod errors;
 pub(crate) mod monetary_value;
 pub(crate) mod query_data;
@@ -18,10 +19,10 @@ pub(crate) mod traits;
 
 use self::errors::InvalidLimit;
 use crate::availability::{QueryableHeader, QueryablePayload};
-use crate::data_source::storage::ExplorerStorage;
 use crate::{api::load_api, Header, Payload, Transaction};
 pub use currency::*;
+pub use data_source::*;
 use futures::FutureExt;
 use hotshot_types::traits::node_implementation::NodeType;
 pub use monetary_value::*;
@@ -242,7 +243,7 @@ where
     Header<Types>: ExplorerHeader<Types> + QueryableHeader<Types>,
     Transaction<Types>: ExplorerTransaction,
     Payload<Types>: QueryablePayload<Types>,
-    <State as ReadState>::State: Send + Sync + ExplorerStorage<Types>,
+    <State as ReadState>::State: ExplorerDataSource<Types> + Send + Sync,
 {
     let mut api = load_api::<State, Error, Ver>(
         Option::<Box<Path>>::None,
diff --git a/src/explorer/data_source.rs b/src/explorer/data_source.rs
new file mode 100644
index 000000000..c236f56ba
--- /dev/null
+++ b/src/explorer/data_source.rs
@@ -0,0 +1,92 @@
+// Copyright (c) 2022 Espresso Systems (espressosys.com)
+// This file is part of the HotShot Query Service library.
+//
+// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
+// General Public License as published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
+// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+// You should have received a copy of the GNU General Public License along with this program. If not,
+// see <https://www.gnu.org/licenses/>.
+
+use super::{
+    query_data::{
+        BlockDetail, BlockIdentifier, BlockSummary, ExplorerSummary, GetBlockDetailError,
+        GetBlockSummariesError, GetBlockSummariesRequest, GetExplorerSummaryError,
+        GetSearchResultsError, GetTransactionDetailError, GetTransactionSummariesError,
+        GetTransactionSummariesRequest, SearchResult, TransactionDetailResponse,
+        TransactionIdentifier, TransactionSummary,
+    },
+    traits::{ExplorerHeader, ExplorerTransaction},
+};
+use crate::{
+    availability::{QueryableHeader, QueryablePayload},
+    Header, Payload, Transaction,
+};
+use async_trait::async_trait;
+use hotshot_types::traits::node_implementation::NodeType;
+
+/// An interface for querying Data and Statistics from the HotShot Blockchain.
+///
+/// This interface provides methods that enable querying data concerning the blockchain
+/// from the stored data for use with a block explorer. It does not provide the same
+/// guarantees as the Availability data source with data fetching. It is not concerned with
+/// being up-to-date or having all of the data required, but rather it is
+/// concerned with providing the requested data as quickly as possible, and in
+/// a way that can be easily cached.
+#[async_trait]
+pub trait ExplorerDataSource<Types>
+where
+    Types: NodeType,
+    Header<Types>: ExplorerHeader<Types> + QueryableHeader<Types>,
+    Transaction<Types>: ExplorerTransaction,
+    Payload<Types>: QueryablePayload<Types>,
+{
+    /// `get_block_detail` is a method that retrieves the details of a specific
+    /// block from the blockchain. The block is identified by the given
+    /// [BlockIdentifier].
+    async fn get_block_detail(
+        &self,
+        request: BlockIdentifier<Types>,
+    ) -> Result<BlockDetail<Types>, GetBlockDetailError>;
+
+    /// `get_block_summaries` is a method that retrieves a list of block
+    /// summaries from the blockchain. The list is generated from the given
+    /// [GetBlockSummariesRequest].
+    async fn get_block_summaries(
+        &self,
+        request: GetBlockSummariesRequest<Types>,
+    ) -> Result<Vec<BlockSummary<Types>>, GetBlockSummariesError>;
+
+    /// `get_transaction_detail` is a method that retrieves the details of a
+    /// specific transaction from the blockchain. The transaction is identified
+    /// by the given [TransactionIdentifier].
+    async fn get_transaction_detail(
+        &self,
+        request: TransactionIdentifier<Types>,
+    ) -> Result<TransactionDetailResponse<Types>, GetTransactionDetailError>;
+
+    /// `get_transaction_summaries` is a method that retrieves a list of
+    /// transaction summaries from the blockchain. The list is generated from
+    /// the given [GetTransactionSummariesRequest].
+    async fn get_transaction_summaries(
+        &self,
+        request: GetTransactionSummariesRequest<Types>,
+    ) -> Result<Vec<TransactionSummary<Types>>, GetTransactionSummariesError>;
+
+    /// `get_explorer_summary` is a method that retrieves a summary overview of
+    /// the blockchain. This is useful for displaying information that
+    /// indicates the overall status of the blockchain.
+    async fn get_explorer_summary(&self)
+        -> Result<ExplorerSummary<Types>, GetExplorerSummaryError>;
+
+    /// `get_search_results` is a method that retrieves the results of a search
+    /// query against the blockchain. The results are generated from the given
+    /// query string.
+    async fn get_search_results(
+        &self,
+        query: String,
+    ) -> Result<SearchResult<Types>, GetSearchResultsError>;
+}
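A hedged sketch of client code written against this trait (not part of the patch; the `block_detail` helper is hypothetical):

```rust
// Hypothetical generic helper: any backend implementing ExplorerDataSource can
// serve explorer queries through the same interface.
async fn block_detail<Types, D>(
    ds: &D,
    id: BlockIdentifier<Types>,
) -> Result<BlockDetail<Types>, GetBlockDetailError>
where
    Types: NodeType,
    Header<Types>: ExplorerHeader<Types> + QueryableHeader<Types>,
    Transaction<Types>: ExplorerTransaction,
    Payload<Types>: QueryablePayload<Types>,
    D: ExplorerDataSource<Types>,
{
    ds.get_block_detail(id).await
}
```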
diff --git a/src/fetching/provider/query_service.rs b/src/fetching/provider/query_service.rs
index 51c2db7f6..ebd717f9c 100644
--- a/src/fetching/provider/query_service.rs
+++ b/src/fetching/provider/query_service.rs
@@ -953,7 +953,7 @@ mod test {
             .read()
             .await
             .unwrap()
-            .as_ref()
+            .as_mut()
             .load_pruned_height()
             .await
             .unwrap();
@@ -982,7 +982,7 @@ mod test {
             .read()
             .await
             .unwrap()
-            .as_ref()
+            .as_mut()
             .load_pruned_height()
             .await
             .unwrap();
@@ -1015,7 +1015,7 @@ mod test {
             .read()
             .await
             .unwrap()
-            .as_ref()
+            .as_mut()
             .load_pruned_height()
             .await
             .unwrap();
diff --git a/src/merklized_state.rs b/src/merklized_state.rs
index 7bf2152de..453a34151 100644
--- a/src/merklized_state.rs
+++ b/src/merklized_state.rs
@@ -80,7 +80,7 @@ pub fn define_api<
 where
     State: 'static + Send + Sync + ReadState,
     <State as ReadState>::State:
-        Send + Sync + MerklizedStateDataSource<Types, State, ARITY> + MerklizedStateHeightPersistence,
+        MerklizedStateDataSource<Types, State, ARITY> + MerklizedStateHeightPersistence + Send + Sync,
     for<'a> <State::Key as TryFrom<&'a TaggedBase64>>::Error: Display,
 {
     let mut api = load_api::<State, Error, Ver>(
diff --git a/src/merklized_state/data_source.rs b/src/merklized_state/data_source.rs
index d6f5f967f..029682ebc 100644
--- a/src/merklized_state/data_source.rs
+++ b/src/merklized_state/data_source.rs
@@ -33,7 +33,7 @@ use tagged_base64::TaggedBase64;
 
 use std::cmp::Ordering;
 
-use crate::{data_source::ReadOnly, QueryResult};
+use crate::QueryResult;
 
 /// This trait defines methods that a data source should implement
 /// It enables retrieval of the membership path for a leaf node, which can be used to reconstruct the Merkle tree state.
@@ -50,23 +50,6 @@ where
     ) -> QueryResult<MerkleProof<State::Entry, State::Key, State::T, ARITY>>;
 }
 
-#[async_trait]
-impl<Types, State, T, const ARITY: usize> MerklizedStateDataSource<Types, State, ARITY>
-    for ReadOnly<T>
-where
-    Types: NodeType,
-    State: MerklizedState<Types, ARITY>,
-    T: MerklizedStateDataSource<Types, State, ARITY> + Sync,
-{
-    async fn get_path(
-        &self,
-        snapshot: Snapshot<Types, State, ARITY>,
-        key: State::Key,
-    ) -> QueryResult<MerkleProof<State::Entry, State::Key, State::T, ARITY>> {
-        (**self).get_path(snapshot, key).await
-    }
-}
-
 /// This trait defines methods for updating the storage with the merkle tree state.
 #[async_trait]
 pub trait UpdateStateData<Types: NodeType, State: MerklizedState<Types, ARITY>, const ARITY: usize>:
@@ -86,16 +69,6 @@ pub trait MerklizedStateHeightPersistence {
     async fn get_last_state_height(&self) -> QueryResult<usize>;
 }
 
-#[async_trait]
-impl<T> MerklizedStateHeightPersistence for ReadOnly<T>
-where
-    T: MerklizedStateHeightPersistence + Sync,
-{
-    async fn get_last_state_height(&self) -> QueryResult<usize> {
-        (**self).get_last_state_height().await
-    }
-}
-
 type StateCommitment<Types, State, const ARITY: usize> = <State as MerklizedState<Types, ARITY>>::Commit;
 
 /// Snapshot can be queried by block height (index) or merkle tree commitment
diff --git a/src/node.rs b/src/node.rs
index 24f51d69e..e8fcf281f 100644
--- a/src/node.rs
+++ b/src/node.rs
@@ -103,7 +103,7 @@ pub fn define_api(
 ) -> Result<Api<State, Error, Ver>, ApiError>
 where
     State: 'static + Send + Sync + ReadState,
-    <State as ReadState>::State: Send + Sync + NodeDataSource<Types>,
+    <State as ReadState>::State: NodeDataSource<Types> + Send + Sync,
 {
     let mut api = load_api::<State, Error, Ver>(
         options.api_path.as_ref(),
diff --git a/src/node/data_source.rs b/src/node/data_source.rs
index caa9e839f..821e3176f 100644
--- a/src/node/data_source.rs
+++ b/src/node/data_source.rs
@@ -25,7 +25,7 @@
 //! trait](crate::availability::UpdateAvailabilityData).
 
 use super::query_data::{BlockHash, BlockId, SyncStatus, TimeWindowQueryData};
-use crate::{data_source::ReadOnly, Header, QueryResult, VidShare};
+use crate::{Header, QueryResult, VidShare};
 use async_trait::async_trait;
 use derivative::Derivative;
 use derive_more::From;
@@ -64,41 +64,3 @@ pub trait NodeDataSource {
     /// Search the database for missing objects and generate a report.
     async fn sync_status(&self) -> QueryResult<SyncStatus>;
 }
-
-#[async_trait]
-impl<Types, T> NodeDataSource<Types> for ReadOnly<T>
-where
-    Types: NodeType,
-    T: NodeDataSource<Types> + Sync,
-{
-    async fn block_height(&self) -> QueryResult<usize> {
-        (**self).block_height().await
-    }
-
-    async fn count_transactions(&self) -> QueryResult<usize> {
-        (**self).count_transactions().await
-    }
-
-    async fn payload_size(&self) -> QueryResult<usize> {
-        (**self).payload_size().await
-    }
-
-    async fn vid_share<ID>(&self, id: ID) -> QueryResult<VidShare>
-    where
-        ID: Into<BlockId<Types>> + Send + Sync,
-    {
-        (**self).vid_share(id).await
-    }
-
-    async fn sync_status(&self) -> QueryResult<SyncStatus> {
-        (**self).sync_status().await
-    }
-
-    async fn get_header_window(
-        &self,
-        start: impl Into<WindowStart<Types>> + Send + Sync,
-        end: u64,
-    ) -> QueryResult<TimeWindowQueryData<Header<Types>>> {
-        (**self).get_header_window(start, end).await
-    }
-}

From 9bd1f2cb946f19c2aa26f9dd972e55d8ac48240e Mon Sep 17 00:00:00 2001
From: Jeb Bearer
Date: Thu, 26 Sep 2024 13:39:02 -0400
Subject: [PATCH 2/5] Document storage traits

---
 src/data_source/storage.rs | 33 +++++++++++++++++++++++++++++++++
 src/task.rs                |  9 ++++-----
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/src/data_source/storage.rs b/src/data_source/storage.rs
index 7029adb71..4133ce26c 100644
--- a/src/data_source/storage.rs
+++ b/src/data_source/storage.rs
@@ -23,6 +23,39 @@
 //! * [`FileSystemStorage`]
 //! * [`NoStorage`]
 //!
+//! # Storage Traits vs Data Source Traits
+//!
+//! Many of the traits defined in this module (e.g. [`NodeStorage`], [`ExplorerStorage`], and
+//! others) are nearly identical to the corresponding data source traits (e.g.
+//! [`NodeDataSource`](crate::node::NodeDataSource),
+//! [`ExplorerDataSource`](crate::explorer::ExplorerDataSource), etc). They typically differ in
+//! mutability: the storage traits are intended to be implemented on storage
+//! [transactions](super::Transaction), and because even reading may update the internal
+//! state of a transaction, such as a buffer or database cursor, these traits typically take `&mut
+//! self`. This is not a barrier for concurrency since there may be many transactions open
+//! simultaneously from a single data source. The data source traits, meanwhile, are implemented on
+//! the data source itself. Internally, they usually open a fresh transaction and do all their work
+//! on the transaction, not modifying the data source itself, so they take `&self`.
+//!
+//! For traits that differ _only_ in the mutability of the `self` parameter, it is almost possible
+//! to combine them into a single trait whose methods take `self` by value, and implementing said
+//! traits for the reference types `&SomeDataSource` and `&mut SomeDataSourceTransaction`. There are
+//! two problems with this approach, which lead us to prefer the slight redundancy of having
+//! separate versions of the traits with mutable and immutable methods:
+//! * The trait bounds quickly get out of hand, since we now have trait bounds not only on the type
+//!   itself, but also on references to that type, and the reference also requires the introduction
+//!   of an additional lifetime parameter.
+//! * We run into a longstanding [`rustc` bug](https://github.com/rust-lang/rust/issues/85063) in
+//!   which type inference diverges when given trait bounds on reference types, even when
+//!   theoretically the types are uniquely inferrable. This issue can be worked around by
+//!   [explicitly specifying type parameters at every call site](https://users.rust-lang.org/t/type-recursion-when-trait-bound-is-added-on-reference-type/74525/2),
+//!   but this further exacerbates the ergonomic issues with this approach, past the point of
+//!   viability.
+//!
+//! Occasionally, there may be further differences between the data source traits and corresponding
+//! storage traits. For example, [`AvailabilityStorage`] also differs from
+//! [`AvailabilityDataSource`](crate::availability::AvailabilityDataSource) in fallibility.
+//!
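To make the mutability split concrete, a hedged sketch of the paired trait shapes (the `Example*` names are hypothetical and the method sets are trimmed for illustration):

```rust
// Hypothetical illustration of the pattern described above.
#[async_trait]
trait ExampleNodeStorage<Types: NodeType> {
    // Storage trait: implemented on a transaction, where even a read may
    // advance internal state (a buffer, a cursor), hence `&mut self`.
    async fn block_height(&mut self) -> QueryResult<usize>;
}

#[async_trait]
trait ExampleNodeDataSource<Types: NodeType> {
    // Data source trait: implemented on the data source itself, which opens a
    // fresh transaction internally, hence `&self`.
    async fn block_height(&self) -> QueryResult<usize>;
}
```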
 use crate::{
     availability::{
diff --git a/src/task.rs b/src/task.rs
index 29c733b2b..6fb0193ee 100644
--- a/src/task.rs
+++ b/src/task.rs
@@ -39,9 +39,8 @@ impl BackgroundTask {
     /// The caller should ensure that `future` yields back to the executor fairly frequently, to
     /// ensure timely cancellation in case the task is dropped. If an operation in `future` may run
     /// for a long time without blocking or yielding, consider using
-    /// [`yield_now`](async_std::task::yield_now) periodically, or using
-    /// [`spawn`](async_std::task::spawn) or [`spawn_blocking`](async_std::task::spawn_blocking) to
-    /// run long operations in a sub-task.
+    /// [`yield_now`](async_std::task::yield_now) periodically, or using [`spawn`] or
+    /// [`spawn_blocking`](async_std::task::spawn_blocking) to run long operations in a sub-task.
     pub fn spawn<F>(name: impl Display, future: F) -> Self
     where
         F: Future + Send + 'static,
@@ -80,8 +79,8 @@ impl Task {
     /// ensure timely cancellation in case the task is dropped. If an operation in `future` may run
     /// for a long time without blocking or yielding, consider using
     /// [`yield_now`](async_std::task::yield_now) periodically, or using
-    /// [`spawn`](async_std::task::spawn) or [`spawn_blocking`](async_std::task::spawn_blocking) to
-    /// run long operations in a sub-task.
+    /// [`spawn`] or [`spawn_blocking`](async_std::task::spawn_blocking) to run long operations in a
+    /// sub-task.
     pub fn spawn<F>(name: impl Display, future: F) -> Self
     where
         F: Future + Send + 'static,

From eca2f363e5ef5046d35a9106c95b88c565bc5e4d Mon Sep 17 00:00:00 2001
From: Jeb Bearer
Date: Thu, 26 Sep 2024 13:55:46 -0400
Subject: [PATCH 3/5] Fix doc tests

---
 src/data_source/storage/sql.rs         |  1 +
 src/data_source/storage/sql/db.rs      | 15 +++++++++++++++
 src/data_source/storage/sql/queries.rs | 19 +++++++++++++++----
 3 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/src/data_source/storage/sql.rs b/src/data_source/storage/sql.rs
index 15b87c7a7..aef208f7e 100644
--- a/src/data_source/storage/sql.rs
+++ b/src/data_source/storage/sql.rs
@@ -46,6 +46,7 @@ pub use anyhow::Error;
 pub use crate::include_migrations;
 pub use db::Db;
 pub use include_dir::include_dir;
+pub use queries::QueryBuilder;
 pub use refinery::Migration;
 pub use transaction::{query, query_as, Executor, Query, QueryAs, Transaction};
 
diff --git a/src/data_source/storage/sql/db.rs b/src/data_source/storage/sql/db.rs
index cdd3a90ea..94a868d63 100644
--- a/src/data_source/storage/sql/db.rs
+++ b/src/data_source/storage/sql/db.rs
@@ -1,3 +1,15 @@
+// Copyright (c) 2022 Espresso Systems (espressosys.com)
+// This file is part of the HotShot Query Service library.
+//
+// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
+// General Public License as published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
+// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+// You should have received a copy of the GNU General Public License along with this program. If not,
+// see <https://www.gnu.org/licenses/>.
+
 /// The concrete database backing a SQL data source.
 ///
 /// Currently only Postgres is supported. In the future we can support SQLite as well by making this
 /// signatures become untenably messy with bounds like
 ///
 /// ```
+/// # use sqlx::{Database, Encode, Executor, IntoArguments, Type};
+/// fn foo<DB: Database>()
 /// where
 ///     for<'a> &'a mut DB::Connection: Executor<'a>,
 ///     for<'q> DB::Arguments<'q>: IntoArguments<'q, DB>,
 ///     for<'a> i64: Type<DB> + Encode<'a, DB>,
+/// {}
 /// ```
 /// etc.
 pub type Db = sqlx::Postgres;
diff --git a/src/data_source/storage/sql/queries.rs b/src/data_source/storage/sql/queries.rs
index 13cd9cf5b..5b5561f9d 100644
--- a/src/data_source/storage/sql/queries.rs
+++ b/src/data_source/storage/sql/queries.rs
@@ -52,18 +52,29 @@ pub(super) mod state;
 /// # Example
 ///
 /// ```
-/// fn search_and_maybe_filter<Mode, T>(
+/// # use hotshot_query_service::{
+/// #     data_source::storage::sql::{
+/// #         Database, Db, QueryBuilder, Transaction,
+/// #     },
+/// #     QueryResult,
+/// # };
+/// # use sqlx::FromRow;
+/// async fn search_and_maybe_filter<Mode, T>(
+///     tx: &mut Transaction<Mode>,
+///     id: Option<i64>,
-/// ) -> QueryResult<Vec<T>> {
+/// ) -> QueryResult<Vec<T>>
+/// where
+///     for<'r> T: FromRow<'r, <Db as Database>::Row> + Send + Unpin,
+/// {
///     let mut query = QueryBuilder::default();
///     let mut sql = "SELECT * FROM table".into();
///     if let Some(id) = id {
-///         sql = format!("{sql} WHERE id = {}", query.bind(id));
+///         sql = format!("{sql} WHERE id = {}", query.bind(id)?);
///     }
///     let results = query
///         .query_as(&sql)
-///         .fetch_all(tx)?;
+///         .fetch_all(tx.as_mut())
+///         .await?;
///     Ok(results)
/// }
/// ```

From 10e356ec91ed8df0fa227f6f4d2104cf35b302f1 Mon Sep 17 00:00:00 2001
From: Jeb Bearer
Date: Thu, 26 Sep 2024 19:12:46 -0400
Subject: [PATCH 4/5] Add connection pool configurations

---
 src/data_source/storage/sql.rs | 41 +++++++++++++++++++++++++++++++++-
 1 file changed, 40 insertions(+), 1 deletion(-)

diff --git a/src/data_source/storage/sql.rs b/src/data_source/storage/sql.rs
index aef208f7e..d0054767c 100644
--- a/src/data_source/storage/sql.rs
+++ b/src/data_source/storage/sql.rs
@@ -29,7 +29,7 @@ use sqlx::{
     postgres::{PgConnectOptions, PgSslMode},
     ConnectOptions, Connection, Row,
 };
-use std::{cmp::min, fmt::Debug, str::FromStr};
+use std::{cmp::min, fmt::Debug, str::FromStr, time::Duration};
 
 pub extern crate sqlx;
 pub use sqlx::{Database, Postgres, Sqlite};
@@ -325,6 +325,45 @@ impl Config {
         self.archive = true;
         self
     }
+
+    /// Set the maximum idle time of a connection.
+    ///
+    /// Any connection which has been open and unused longer than this duration will be
+    /// automatically closed to reduce load on the server.
+    pub fn idle_connection_timeout(mut self, timeout: Duration) -> Self {
+        self.pool_opt = self.pool_opt.idle_timeout(Some(timeout));
+        self
+    }
+
+    /// Set the maximum lifetime of a connection.
+    ///
+    /// Any connection which has been open longer than this duration will be automatically closed
+    /// (and, if needed, replaced), even if it is otherwise healthy. It is good practice to refresh
+    /// even healthy connections once in a while (e.g. daily) in case of resource leaks in the
+    /// server implementation.
+    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
+        self.pool_opt = self.pool_opt.max_lifetime(Some(timeout));
+        self
+    }
+
+    /// Set the minimum number of connections to maintain at any time.
+    ///
+    /// The data source will, to the best of its ability, maintain at least `min` open connections
+    /// at all times. This can be used to reduce the latency hit of opening new connections when at
+    /// least this many simultaneous connections are frequently needed.
+    pub fn min_connections(mut self, min: u32) -> Self {
+        self.pool_opt = self.pool_opt.min_connections(min);
+        self
+    }
+
+    /// Set the maximum number of connections to maintain at any time.
+    ///
+    /// Once `max` connections are in use simultaneously, further attempts to acquire a connection
+    /// (or begin a transaction) will block until one of the existing connections is released.
+    pub fn max_connections(mut self, max: u32) -> Self {
+        self.pool_opt = self.pool_opt.max_connections(max);
+        self
+    }
 }
 
 /// Storage for the APIs provided in this crate, backed by a remote PostgreSQL database.
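These options compose as a builder on the existing `Config`. A hedged sketch of tuning a pool (not part of the patch; the values are illustrative):

```rust
// Hypothetical usage sketch of the pool options added above.
use std::time::Duration;

fn tuned(config: Config) -> Config {
    config
        // Close connections idle for more than five minutes.
        .idle_connection_timeout(Duration::from_secs(300))
        // Recycle even healthy connections once a day.
        .connection_timeout(Duration::from_secs(86_400))
        // Keep a couple of warm connections to hide setup latency.
        .min_connections(2)
        // Beyond this, acquiring a connection (or beginning a transaction) blocks.
        .max_connections(16)
}
```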
From 4a1c84b65308b043bbd0a9e757aa2e282f75c656 Mon Sep 17 00:00:00 2001
From: Jeb Bearer
Date: Mon, 30 Sep 2024 09:48:34 -0400
Subject: [PATCH 5/5] Remove disabled test for reading during a write
 transaction

This is never going to be supported with READ DEFERRABLE, which blocks until
in-progress write transactions complete.
---
 src/data_source.rs | 17 -----------------
 1 file changed, 17 deletions(-)

diff --git a/src/data_source.rs b/src/data_source.rs
index 02595ef61..dbbb3cb3d 100644
--- a/src/data_source.rs
+++ b/src/data_source.rs
@@ -513,23 +513,6 @@ pub mod persistence_tests {
         assert_eq!(leaf, tx.get_leaf(1.into()).await.unwrap());
         assert_eq!(block, tx.get_block(1.into()).await.unwrap());
 
-        // TODO currently the following check causes a deadlock, because it tries to open a new
-        // transaction (implicitly via the NodeDataSource and AvailabilityDataSource traits) while
-        // the current one is still open, which is not yet supported. Once we have proper support
-        // for multiple concurrent connections
-        // (https://github.com/EspressoSystems/hotshot-query-service/issues/567), we should reenable
-        // this.
-        // // The inserted data is _not_ returned when reading through the data source itself (as
-        // // opposed to the transaction) since it is not yet committed.
-        // assert_eq!(
-        //     NodeDataSource::<MockTypes>::block_height(&ds)
-        //         .await
-        //         .unwrap(),
-        //     0
-        // );
-        // ds.get_leaf(1).await.try_resolve().unwrap_err();
-        // ds.get_block(1).await.try_resolve().unwrap_err();
-
         // Revert the changes.
         tx.revert().await;
 
         assert_eq!(