Skip to content

Commit

Permalink
libsql: WAL pull support
Browse files Browse the repository at this point in the history
  • Loading branch information
penberg committed Dec 5, 2024
1 parent 7843e86 commit a2a6fc5
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 3 deletions.
55 changes: 55 additions & 0 deletions libsql/examples/offline_writes_pull.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Example of using a offline writes with libSQL.

use libsql::Builder;

#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();

// The local database path where the data will be stored.
let db_path = std::env::var("LIBSQL_DB_PATH")
.map_err(|_| {
eprintln!(
"Please set the LIBSQL_DB_PATH environment variable to set to local database path."
)
})
.unwrap();

// The remote sync URL to use.
let sync_url = std::env::var("LIBSQL_SYNC_URL")
.map_err(|_| {
eprintln!(
"Please set the LIBSQL_SYNC_URL environment variable to set to remote sync URL."
)
})
.unwrap();

// The authentication token to use.
let auth_token = std::env::var("LIBSQL_AUTH_TOKEN").unwrap_or("".to_string());

let db_builder = Builder::new_synced_database(db_path, sync_url, auth_token);

let db = match db_builder.build().await {
Ok(db) => db,
Err(error) => {
eprintln!("Error connecting to remote sync server: {}", error);
return;
}
};

println!("Syncing database from remote...");
db.sync().await.unwrap();

let conn = db.connect().unwrap();
let mut results = conn
.query("SELECT * FROM guest_book_entries", ())
.await
.unwrap();
println!("Guest book entries:");
while let Some(row) = results.next().await.unwrap() {
let text: String = row.get(0).unwrap();
println!(" {}", text);
}

println!("Done!");
}
2 changes: 1 addition & 1 deletion libsql/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ cfg_replication! {
#[cfg(feature = "replication")]
DbType::Sync { db, encryption_config: _ } => db.sync().await,
#[cfg(feature = "sync")]
DbType::Offline { db } => db.push().await,
DbType::Offline { db } => db.sync_offline().await,
_ => Err(Error::SyncNotSupported(format!("{:?}", self.db_type))),
}
}
Expand Down
40 changes: 40 additions & 0 deletions libsql/src/local/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,46 @@ impl Connection {

Ok(buf)
}

pub(crate) fn wal_insert_begin(&self) -> Result<()> {
let rc = unsafe { libsql_sys::ffi::libsql_wal_insert_begin(self.handle()) };
if rc != 0 {
return Err(crate::errors::Error::SqliteFailure(
rc as std::ffi::c_int,
format!("wal_insert_begin failed"),
));
}
Ok(())
}

pub(crate) fn wal_insert_end(&self) -> Result<()> {
let rc = unsafe { libsql_sys::ffi::libsql_wal_insert_end(self.handle()) };
if rc != 0 {
return Err(crate::errors::Error::SqliteFailure(
rc as std::ffi::c_int,
format!("wal_insert_end failed"),
));
}
Ok(())
}

pub(crate) fn wal_insert_frame(&self, frame: &[u8]) -> Result<()> {
let rc = unsafe {
libsql_sys::ffi::libsql_wal_insert_frame(
self.handle(),
frame.len() as u32,
frame.as_ptr() as *mut std::ffi::c_void,
0
)
};
if rc != 0 {
return Err(crate::errors::Error::SqliteFailure(
rc as std::ffi::c_int,
format!("wal_insert_frame failed"),
));
}
Ok(())
}
}

impl fmt::Debug for Connection {
Expand Down
39 changes: 37 additions & 2 deletions libsql/src/local/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,8 @@ impl Database {
}

#[cfg(feature = "sync")]
/// Push WAL frames to remote.
pub async fn push(&self) -> Result<crate::database::Replicated> {
/// Sync WAL frames to remote.
pub async fn sync_offline(&self) -> Result<crate::database::Replicated> {
use crate::sync::SyncError;
use crate::Error;

Expand Down Expand Up @@ -425,6 +425,10 @@ impl Database {

let max_frame_no = conn.wal_frame_count();

if max_frame_no == 0 {
return self.try_pull(&mut sync_ctx).await;
}

let generation = sync_ctx.generation(); // TODO: Probe from WAL.
let start_frame_no = sync_ctx.durable_frame_num() + 1;
let end_frame_no = max_frame_no;
Expand All @@ -448,6 +452,10 @@ impl Database {

sync_ctx.write_metadata().await?;

if start_frame_no > end_frame_no {
return self.try_pull(&mut sync_ctx).await;
}

// TODO(lucio): this can underflow if the server previously returned a higher max_frame_no
// than what we have stored here.
let frame_count = end_frame_no - start_frame_no + 1;
Expand All @@ -457,6 +465,33 @@ impl Database {
})
}

#[cfg(feature = "sync")]
async fn try_pull(&self, sync_ctx: &mut SyncContext) -> Result<crate::database::Replicated> {
let generation = sync_ctx.generation();
let mut frame_no = sync_ctx.durable_frame_num() + 1;
let conn = self.connect()?;
conn.wal_insert_begin()?;
loop {
match sync_ctx.pull_one_frame(generation, frame_no).await {
Ok(frame) => {
conn.wal_insert_frame(&frame)?;
frame_no += 1;
}
Err(e) => {
println!("pull_one_frame error: {:?}", e);
break;
}
}

}
conn.wal_insert_end()?;
sync_ctx.write_metadata().await?;
Ok(crate::database::Replicated {
frame_no: None,
frames_synced: 1,
})
}

pub(crate) fn path(&self) -> &str {
&self.db_path
}
Expand Down
64 changes: 64 additions & 0 deletions libsql/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ pub enum SyncError {
InvalidPushFrameNoLow(u32, u32),
#[error("server returned a higher frame_no: sent={0}, got={1}")]
InvalidPushFrameNoHigh(u32, u32),
#[error("failed to pull frame: status={0}, error={1}")]
PullFrame(StatusCode, String),
}

impl SyncError {
Expand Down Expand Up @@ -104,6 +106,21 @@ impl SyncContext {
Ok(me)
}

#[tracing::instrument(skip(self))]
pub(crate) async fn pull_one_frame(&mut self, generation: u32, frame_no: u32) -> Result<Bytes> {
let uri = format!(
"{}/sync/{}/{}/{}",
self.sync_url,
generation,
frame_no,
frame_no + 1
);
tracing::debug!("pulling frame");
let frame = self.pull_with_retry(uri, self.max_retries).await?;
self.durable_frame_num = frame_no;
Ok(frame)
}

#[tracing::instrument(skip(self, frame))]
pub(crate) async fn push_one_frame(
&mut self,
Expand Down Expand Up @@ -215,6 +232,53 @@ impl SyncContext {
}
}

async fn pull_with_retry(&self, uri: String, max_retries: usize) -> Result<Bytes> {
let mut nr_retries = 0;
loop {
let mut req = http::Request::builder().method("GET").uri(uri.clone());

match &self.auth_token {
Some(auth_token) => {
req = req.header("Authorization", auth_token);
}
None => {}
}

let req = req.body(Body::empty())
.expect("valid request");

let res = self
.client
.request(req)
.await
.map_err(SyncError::HttpDispatch)?;

if res.status().is_success() {
let frame = hyper::body::to_bytes(res.into_body())
.await
.map_err(SyncError::HttpBody)?;
return Ok(frame);
}
// If we've retried too many times or the error is not a server error,
// return the error.
if nr_retries > max_retries || !res.status().is_server_error() {
let status = res.status();

let res_body = hyper::body::to_bytes(res.into_body())
.await
.map_err(SyncError::HttpBody)?;

let msg = String::from_utf8_lossy(&res_body[..]);

return Err(SyncError::PullFrame(status, msg.to_string()).into());
}

let delay = std::time::Duration::from_millis(100 * (1 << nr_retries));
tokio::time::sleep(delay).await;
nr_retries += 1;
}
}

pub(crate) fn durable_frame_num(&self) -> u32 {
self.durable_frame_num
}
Expand Down

0 comments on commit a2a6fc5

Please sign in to comment.