From a2a6fc557491ad1eeced0fdf383c65891f546665 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Sat, 30 Nov 2024 12:07:42 +0200 Subject: [PATCH] libsql: WAL pull support --- libsql/examples/offline_writes_pull.rs | 55 ++++++++++++++++++++++ libsql/src/database.rs | 2 +- libsql/src/local/connection.rs | 40 ++++++++++++++++ libsql/src/local/database.rs | 39 +++++++++++++++- libsql/src/sync.rs | 64 ++++++++++++++++++++++++++ 5 files changed, 197 insertions(+), 3 deletions(-) create mode 100644 libsql/examples/offline_writes_pull.rs diff --git a/libsql/examples/offline_writes_pull.rs b/libsql/examples/offline_writes_pull.rs new file mode 100644 index 0000000000..3505df75c2 --- /dev/null +++ b/libsql/examples/offline_writes_pull.rs @@ -0,0 +1,55 @@ +// Example of using a offline writes with libSQL. + +use libsql::Builder; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + // The local database path where the data will be stored. + let db_path = std::env::var("LIBSQL_DB_PATH") + .map_err(|_| { + eprintln!( + "Please set the LIBSQL_DB_PATH environment variable to set to local database path." + ) + }) + .unwrap(); + + // The remote sync URL to use. + let sync_url = std::env::var("LIBSQL_SYNC_URL") + .map_err(|_| { + eprintln!( + "Please set the LIBSQL_SYNC_URL environment variable to set to remote sync URL." + ) + }) + .unwrap(); + + // The authentication token to use. + let auth_token = std::env::var("LIBSQL_AUTH_TOKEN").unwrap_or("".to_string()); + + let db_builder = Builder::new_synced_database(db_path, sync_url, auth_token); + + let db = match db_builder.build().await { + Ok(db) => db, + Err(error) => { + eprintln!("Error connecting to remote sync server: {}", error); + return; + } + }; + + println!("Syncing database from remote..."); + db.sync().await.unwrap(); + + let conn = db.connect().unwrap(); + let mut results = conn + .query("SELECT * FROM guest_book_entries", ()) + .await + .unwrap(); + println!("Guest book entries:"); + while let Some(row) = results.next().await.unwrap() { + let text: String = row.get(0).unwrap(); + println!(" {}", text); + } + + println!("Done!"); +} diff --git a/libsql/src/database.rs b/libsql/src/database.rs index b93ea66e98..ba9b3ad7cc 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -373,7 +373,7 @@ cfg_replication! { #[cfg(feature = "replication")] DbType::Sync { db, encryption_config: _ } => db.sync().await, #[cfg(feature = "sync")] - DbType::Offline { db } => db.push().await, + DbType::Offline { db } => db.sync_offline().await, _ => Err(Error::SyncNotSupported(format!("{:?}", self.db_type))), } } diff --git a/libsql/src/local/connection.rs b/libsql/src/local/connection.rs index bb1c7b7ab0..8e73e7efb1 100644 --- a/libsql/src/local/connection.rs +++ b/libsql/src/local/connection.rs @@ -482,6 +482,46 @@ impl Connection { Ok(buf) } + + pub(crate) fn wal_insert_begin(&self) -> Result<()> { + let rc = unsafe { libsql_sys::ffi::libsql_wal_insert_begin(self.handle()) }; + if rc != 0 { + return Err(crate::errors::Error::SqliteFailure( + rc as std::ffi::c_int, + format!("wal_insert_begin failed"), + )); + } + Ok(()) + } + + pub(crate) fn wal_insert_end(&self) -> Result<()> { + let rc = unsafe { libsql_sys::ffi::libsql_wal_insert_end(self.handle()) }; + if rc != 0 { + return Err(crate::errors::Error::SqliteFailure( + rc as std::ffi::c_int, + format!("wal_insert_end failed"), + )); + } + Ok(()) + } + + pub(crate) fn wal_insert_frame(&self, frame: &[u8]) -> Result<()> { + let rc = unsafe { + libsql_sys::ffi::libsql_wal_insert_frame( + self.handle(), + frame.len() as u32, + frame.as_ptr() as *mut std::ffi::c_void, + 0 + ) + }; + if rc != 0 { + return Err(crate::errors::Error::SqliteFailure( + rc as std::ffi::c_int, + format!("wal_insert_frame failed"), + )); + } + Ok(()) + } } impl fmt::Debug for Connection { diff --git a/libsql/src/local/database.rs b/libsql/src/local/database.rs index 64f09fcc12..c69045728b 100644 --- a/libsql/src/local/database.rs +++ b/libsql/src/local/database.rs @@ -386,8 +386,8 @@ impl Database { } #[cfg(feature = "sync")] - /// Push WAL frames to remote. - pub async fn push(&self) -> Result { + /// Sync WAL frames to remote. + pub async fn sync_offline(&self) -> Result { use crate::sync::SyncError; use crate::Error; @@ -425,6 +425,10 @@ impl Database { let max_frame_no = conn.wal_frame_count(); + if max_frame_no == 0 { + return self.try_pull(&mut sync_ctx).await; + } + let generation = sync_ctx.generation(); // TODO: Probe from WAL. let start_frame_no = sync_ctx.durable_frame_num() + 1; let end_frame_no = max_frame_no; @@ -448,6 +452,10 @@ impl Database { sync_ctx.write_metadata().await?; + if start_frame_no > end_frame_no { + return self.try_pull(&mut sync_ctx).await; + } + // TODO(lucio): this can underflow if the server previously returned a higher max_frame_no // than what we have stored here. let frame_count = end_frame_no - start_frame_no + 1; @@ -457,6 +465,33 @@ impl Database { }) } + #[cfg(feature = "sync")] + async fn try_pull(&self, sync_ctx: &mut SyncContext) -> Result { + let generation = sync_ctx.generation(); + let mut frame_no = sync_ctx.durable_frame_num() + 1; + let conn = self.connect()?; + conn.wal_insert_begin()?; + loop { + match sync_ctx.pull_one_frame(generation, frame_no).await { + Ok(frame) => { + conn.wal_insert_frame(&frame)?; + frame_no += 1; + } + Err(e) => { + println!("pull_one_frame error: {:?}", e); + break; + } + } + + } + conn.wal_insert_end()?; + sync_ctx.write_metadata().await?; + Ok(crate::database::Replicated { + frame_no: None, + frames_synced: 1, + }) + } + pub(crate) fn path(&self) -> &str { &self.db_path } diff --git a/libsql/src/sync.rs b/libsql/src/sync.rs index 24f7a5318f..f8c5d7f7a3 100644 --- a/libsql/src/sync.rs +++ b/libsql/src/sync.rs @@ -47,6 +47,8 @@ pub enum SyncError { InvalidPushFrameNoLow(u32, u32), #[error("server returned a higher frame_no: sent={0}, got={1}")] InvalidPushFrameNoHigh(u32, u32), + #[error("failed to pull frame: status={0}, error={1}")] + PullFrame(StatusCode, String), } impl SyncError { @@ -104,6 +106,21 @@ impl SyncContext { Ok(me) } + #[tracing::instrument(skip(self))] + pub(crate) async fn pull_one_frame(&mut self, generation: u32, frame_no: u32) -> Result { + let uri = format!( + "{}/sync/{}/{}/{}", + self.sync_url, + generation, + frame_no, + frame_no + 1 + ); + tracing::debug!("pulling frame"); + let frame = self.pull_with_retry(uri, self.max_retries).await?; + self.durable_frame_num = frame_no; + Ok(frame) + } + #[tracing::instrument(skip(self, frame))] pub(crate) async fn push_one_frame( &mut self, @@ -215,6 +232,53 @@ impl SyncContext { } } + async fn pull_with_retry(&self, uri: String, max_retries: usize) -> Result { + let mut nr_retries = 0; + loop { + let mut req = http::Request::builder().method("GET").uri(uri.clone()); + + match &self.auth_token { + Some(auth_token) => { + req = req.header("Authorization", auth_token); + } + None => {} + } + + let req = req.body(Body::empty()) + .expect("valid request"); + + let res = self + .client + .request(req) + .await + .map_err(SyncError::HttpDispatch)?; + + if res.status().is_success() { + let frame = hyper::body::to_bytes(res.into_body()) + .await + .map_err(SyncError::HttpBody)?; + return Ok(frame); + } + // If we've retried too many times or the error is not a server error, + // return the error. + if nr_retries > max_retries || !res.status().is_server_error() { + let status = res.status(); + + let res_body = hyper::body::to_bytes(res.into_body()) + .await + .map_err(SyncError::HttpBody)?; + + let msg = String::from_utf8_lossy(&res_body[..]); + + return Err(SyncError::PullFrame(status, msg.to_string()).into()); + } + + let delay = std::time::Duration::from_millis(100 * (1 << nr_retries)); + tokio::time::sleep(delay).await; + nr_retries += 1; + } + } + pub(crate) fn durable_frame_num(&self) -> u32 { self.durable_frame_num }