Skip to content

Commit

Permalink
Merge pull request #72 from tursodatabase/lucio/fix-rt
Browse files Browse the repository at this point in the history
fix rt drop before hrana close
  • Loading branch information
LucioFranco authored Oct 8, 2024
2 parents 5cdf4b3 + 9537eba commit ff13876
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 49 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "libsql-python"
version = "0.0.40"
version = "0.0.41"
edition = "2021"

[lib]
Expand All @@ -16,4 +16,4 @@ tracing-subscriber = "0.3"
[build-dependencies]
version_check = "0.9.5"
# used where logic has to be version/distribution specific, e.g. pypy
pyo3-build-config = { version = "0.19.0" }
pyo3-build-config = { version = "0.19.0" }
10 changes: 5 additions & 5 deletions shell.nix
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
(pkgs.buildFHSUserEnv {
name = "pipzone";
targetPkgs = pkgs: (with pkgs; [
python39
python39Packages.pip
python39Packages.virtualenv
python39Packages.pytest
python39Packages.pyperf
python312
python312Packages.pip
python312Packages.virtualenv
python312Packages.pytest
python312Packages.pyperf
maturin
]);
runScript = "bash";
Expand Down
75 changes: 34 additions & 41 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,26 @@ use pyo3::create_exception;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyList, PyTuple};
use std::cell::RefCell;
use std::sync::Arc;
use std::cell::{OnceCell, RefCell};
use std::sync::{Arc, OnceLock};
use tokio::runtime::{Handle, Runtime};

const LEGACY_TRANSACTION_CONTROL: i32 = -1;

fn rt() -> Handle {
static RT: OnceLock<Runtime> = OnceLock::new();

RT.get_or_init(|| {
tokio::runtime::Builder::new_multi_thread()
.worker_threads(1)
.enable_all()
.build()
.unwrap()
})
.handle()
.clone()
}

fn to_py_err(error: libsql_core::errors::Error) -> PyErr {
let msg = match error {
libsql::Error::SqliteFailure(_, err) => err,
Expand Down Expand Up @@ -99,7 +114,7 @@ fn _connect_core(
) -> PyResult<Connection> {
let ver = env!("CARGO_PKG_VERSION");
let ver = format!("libsql-python-rpc-{ver}");
let rt = tokio::runtime::Runtime::new().unwrap();
let rt = rt();
let encryption_config = match encryption_key {
Some(key) => {
let cipher = libsql::Cipher::default();
Expand Down Expand Up @@ -147,9 +162,8 @@ fn _connect_core(
db,
conn: Arc::new(ConnectionGuard {
conn: Some(conn),
handle: rt.handle().clone(),
handle: rt.clone(),
}),
rt,
isolation_level,
autocommit,
})
Expand Down Expand Up @@ -186,7 +200,6 @@ impl Drop for ConnectionGuard {
pub struct Connection {
db: libsql_core::Database,
conn: Arc<ConnectionGuard>,
rt: tokio::runtime::Runtime,
isolation_level: Option<String>,
autocommit: i32,
}
Expand All @@ -199,7 +212,6 @@ impl Connection {
fn cursor(&self) -> PyResult<Cursor> {
Ok(Cursor {
arraysize: 1,
rt: self.rt.handle().clone(),
conn: self.conn.clone(),
stmt: RefCell::new(None),
rows: RefCell::new(None),
Expand All @@ -212,24 +224,19 @@ impl Connection {

fn sync(self_: PyRef<'_, Self>, py: Python<'_>) -> PyResult<()> {
let fut = {
let _enter = self_.rt.enter();
let _enter = rt().enter();
self_.db.sync()
};
tokio::pin!(fut);

self_
.rt
.block_on(check_signals(py, fut))
.map_err(to_py_err)?;
rt().block_on(check_signals(py, fut)).map_err(to_py_err)?;
Ok(())
}

fn commit(self_: PyRef<'_, Self>) -> PyResult<()> {
// TODO: Switch to libSQL transaction API
if !self_.conn.is_autocommit() {
self_
.rt
.block_on(async { self_.conn.execute("COMMIT", ()).await })
rt().block_on(async { self_.conn.execute("COMMIT", ()).await })
.map_err(to_py_err)?;
}
Ok(())
Expand All @@ -238,9 +245,7 @@ impl Connection {
fn rollback(self_: PyRef<'_, Self>) -> PyResult<()> {
// TODO: Switch to libSQL transaction API
if !self_.conn.is_autocommit() {
self_
.rt
.block_on(async { self_.conn.execute("ROLLBACK", ()).await })
rt().block_on(async { self_.conn.execute("ROLLBACK", ()).await })
.map_err(to_py_err)?;
}
Ok(())
Expand All @@ -252,8 +257,7 @@ impl Connection {
parameters: Option<&PyTuple>,
) -> PyResult<Cursor> {
let cursor = Connection::cursor(&self_)?;
let rt = self_.rt.handle();
rt.block_on(async { execute(&cursor, sql, parameters).await })?;
rt().block_on(async { execute(&cursor, sql, parameters).await })?;
Ok(cursor)
}

Expand All @@ -265,17 +269,15 @@ impl Connection {
let cursor = Connection::cursor(&self_)?;
for parameters in parameters.unwrap().iter() {
let parameters = parameters.extract::<&PyTuple>()?;
self_
.rt
.block_on(async { execute(&cursor, sql.clone(), Some(parameters)).await })?;
rt().block_on(async { execute(&cursor, sql.clone(), Some(parameters)).await })?;
}
Ok(cursor)
}

fn executescript(self_: PyRef<'_, Self>, script: String) -> PyResult<()> {
let _ = self_.rt.block_on(async {
self_.conn.execute_batch(&script).await
}).map_err(to_py_err);
let _ = rt()
.block_on(async { self_.conn.execute_batch(&script).await })
.map_err(to_py_err);
Ok(())
}

Expand Down Expand Up @@ -316,7 +318,6 @@ impl Connection {
pub struct Cursor {
#[pyo3(get, set)]
arraysize: usize,
rt: tokio::runtime::Handle,
conn: Arc<ConnectionGuard>,
stmt: RefCell<Option<libsql_core::Statement>>,
rows: RefCell<Option<libsql_core::Rows>>,
Expand All @@ -336,9 +337,7 @@ impl Cursor {
sql: String,
parameters: Option<&PyTuple>,
) -> PyResult<pyo3::PyRef<'a, Self>> {
self_
.rt
.block_on(async { execute(&self_, sql, parameters).await })?;
rt().block_on(async { execute(&self_, sql, parameters).await })?;
Ok(self_)
}

Expand All @@ -349,9 +348,7 @@ impl Cursor {
) -> PyResult<pyo3::PyRef<'a, Cursor>> {
for parameters in parameters.unwrap().iter() {
let parameters = parameters.extract::<&PyTuple>()?;
self_
.rt
.block_on(async { execute(&self_, sql.clone(), Some(parameters)).await })?;
rt().block_on(async { execute(&self_, sql.clone(), Some(parameters)).await })?;
}
Ok(self_)
}
Expand All @@ -360,9 +357,7 @@ impl Cursor {
self_: PyRef<'a, Self>,
script: String,
) -> PyResult<pyo3::PyRef<'a, Self>> {
self_
.rt
.block_on(async { self_.conn.execute_batch(&script).await })
rt().block_on(async { self_.conn.execute_batch(&script).await })
.map_err(to_py_err)?;
Ok(self_)
}
Expand Down Expand Up @@ -398,7 +393,7 @@ impl Cursor {
let mut rows = self_.rows.borrow_mut();
match rows.as_mut() {
Some(rows) => {
let row = self_.rt.block_on(rows.next()).map_err(to_py_err)?;
let row = rt().block_on(rows.next()).map_err(to_py_err)?;
match row {
Some(row) => {
let row = convert_row(self_.py(), row, rows.column_count())?;
Expand All @@ -422,8 +417,7 @@ impl Cursor {
// done before iterating.
if !*self_.done.borrow() {
for _ in 0..size {
let row = self_
.rt
let row = rt()
.block_on(async { rows.next().await })
.map_err(to_py_err)?;
match row {
Expand All @@ -450,8 +444,7 @@ impl Cursor {
Some(rows) => {
let mut elements: Vec<Py<PyAny>> = vec![];
loop {
let row = self_
.rt
let row = rt()
.block_on(async { rows.next().await })
.map_err(to_py_err)?;
match row {
Expand Down

0 comments on commit ff13876

Please sign in to comment.