Skip to content

Commit

Permalink
Go bindings: separate query and exec
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Jastrzebski <[email protected]>
  • Loading branch information
haaawk committed Feb 19, 2024
1 parent 27222a5 commit 54ed280
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 14 deletions.
2 changes: 1 addition & 1 deletion bindings/c/example.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ int main(int argc, char *argv[])
goto quit;
}

retval = libsql_execute(conn, "SELECT 1", &rows, &err);
retval = libsql_query(conn, "SELECT 1", &rows, &err);
if (retval != 0) {
fprintf(stderr, "%s\n", err);
goto quit;
Expand Down
8 changes: 6 additions & 2 deletions bindings/c/include/libsql.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,15 @@ int libsql_bind_string(libsql_stmt_t stmt, int idx, const char *value, const cha

int libsql_bind_blob(libsql_stmt_t stmt, int idx, const unsigned char *value, int value_len, const char **out_err_msg);

int libsql_execute_stmt(libsql_stmt_t stmt, libsql_rows_t *out_rows, const char **out_err_msg);
int libsql_query_stmt(libsql_stmt_t stmt, libsql_rows_t *out_rows, const char **out_err_msg);

int libsql_execute_stmt(libsql_stmt_t stmt, const char **out_err_msg);

void libsql_free_stmt(libsql_stmt_t stmt);

int libsql_execute(libsql_connection_t conn, const char *sql, libsql_rows_t *out_rows, const char **out_err_msg);
int libsql_query(libsql_connection_t conn, const char *sql, libsql_rows_t *out_rows, const char **out_err_msg);

int libsql_execute(libsql_connection_t conn, const char *sql, const char **out_err_msg);

void libsql_free_rows(libsql_rows_t res);

Expand Down
47 changes: 45 additions & 2 deletions bindings/c/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ pub unsafe extern "C" fn libsql_bind_blob(
}

#[no_mangle]
pub unsafe extern "C" fn libsql_execute_stmt(
pub unsafe extern "C" fn libsql_query_stmt(
stmt: libsql_stmt_t,
out_rows: *mut libsql_rows_t,
out_err_msg: *mut *const std::ffi::c_char,
Expand All @@ -400,6 +400,25 @@ pub unsafe extern "C" fn libsql_execute_stmt(
0
}

#[no_mangle]
pub unsafe extern "C" fn libsql_execute_stmt(
stmt: libsql_stmt_t,
out_err_msg: *mut *const std::ffi::c_char,
) -> std::ffi::c_int {
if stmt.is_null() {
set_err_msg("Null statement".to_string(), out_err_msg);
return 1;
}
let stmt = stmt.get_ref_mut();
match RT.block_on(stmt.stmt.execute(stmt.params.clone())) {
Ok(_) => 0,
Err(e) => {
set_err_msg(format!("Error executing statement: {}", e), out_err_msg);
2
}
}
}

#[no_mangle]
pub unsafe extern "C" fn libsql_free_stmt(stmt: libsql_stmt_t) {
if stmt.is_null() {
Expand All @@ -409,7 +428,7 @@ pub unsafe extern "C" fn libsql_free_stmt(stmt: libsql_stmt_t) {
}

#[no_mangle]
pub unsafe extern "C" fn libsql_execute(
pub unsafe extern "C" fn libsql_query(
conn: libsql_connection_t,
sql: *const std::ffi::c_char,
out_rows: *mut libsql_rows_t,
Expand Down Expand Up @@ -437,6 +456,30 @@ pub unsafe extern "C" fn libsql_execute(
0
}

#[no_mangle]
pub unsafe extern "C" fn libsql_execute(
conn: libsql_connection_t,
sql: *const std::ffi::c_char,
out_err_msg: *mut *const std::ffi::c_char,
) -> std::ffi::c_int {
let sql = unsafe { std::ffi::CStr::from_ptr(sql) };
let sql = match sql.to_str() {
Ok(sql) => sql,
Err(e) => {
set_err_msg(format!("Wrong SQL: {}", e), out_err_msg);
return 1;
}
};
let conn = conn.get_ref();
match RT.block_on(conn.execute(sql, ())) {
Ok(_) => 0,
Err(e) => {
set_err_msg(format!("Error executing statement: {}", e), out_err_msg);
2
}
}
}

#[no_mangle]
pub unsafe extern "C" fn libsql_free_rows(res: libsql_rows_t) {
if res.is_null() {
Expand Down
23 changes: 16 additions & 7 deletions bindings/go/libsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,22 +368,27 @@ func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver
return &tx{c}, nil
}

func (c *conn) executeNoArgs(query string) (C.libsql_rows_t, error) {
func (c *conn) executeNoArgs(query string, exec bool) (C.libsql_rows_t, error) {
queryCString := C.CString(query)
defer C.free(unsafe.Pointer(queryCString))

var rows C.libsql_rows_t
var errMsg *C.char
statusCode := C.libsql_execute(c.nativePtr, queryCString, &rows, &errMsg)
var statusCode C.int
if exec {
statusCode = C.libsql_execute(c.nativePtr, queryCString, &errMsg)
} else {
statusCode = C.libsql_query(c.nativePtr, queryCString, &rows, &errMsg)
}
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to execute query ", query), statusCode, errMsg)
}
return rows, nil
}

func (c *conn) execute(query string, args []sqldriver.NamedValue) (C.libsql_rows_t, error) {
func (c *conn) execute(query string, args []sqldriver.NamedValue, exec bool) (C.libsql_rows_t, error) {
if len(args) == 0 {
return c.executeNoArgs(query)
return c.executeNoArgs(query, exec)
}
queryCString := C.CString(query)
defer C.free(unsafe.Pointer(queryCString))
Expand Down Expand Up @@ -425,7 +430,11 @@ func (c *conn) execute(query string, args []sqldriver.NamedValue) (C.libsql_rows
}

var rows C.libsql_rows_t
statusCode = C.libsql_execute_stmt(stmt, &rows, &errMsg)
if exec {
statusCode = C.libsql_execute_stmt(stmt, &errMsg)
} else {
statusCode = C.libsql_query_stmt(stmt, &rows, &errMsg)
}
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to execute query ", query), statusCode, errMsg)
}
Expand All @@ -446,7 +455,7 @@ func (r execResult) RowsAffected() (int64, error) {
}

func (c *conn) ExecContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Result, error) {
rows, err := c.execute(query, args)
rows, err := c.execute(query, args, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -644,7 +653,7 @@ func (r *rows) Next(dest []sqldriver.Value) error {
}

func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
rowsNativePtr, err := c.execute(query, args)
rowsNativePtr, err := c.execute(query, args, false)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions bindings/go/libsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ func TestExec(t *testing.T) {

func TestExecWithQuery(t *testing.T) {
runMemoryAndFileTests(t, func(t *testing.T, db *sql.DB) {
if _, err := db.ExecContext(context.Background(), "SELECT 1"); err != nil {
if _, err := db.QueryContext(context.Background(), "SELECT 1"); err != nil {
t.Fatal(err)
}
})
Expand All @@ -1087,7 +1087,7 @@ func TestErrorExec(t *testing.T) {
if err == nil {
t.Fatal("expected error")
}
if err.Error() != "failed to execute query CREATE TABLES test (id INTEGER, name TEXT)\nerror code = 1: Error executing statement: SQLite failure: `near \"TABLES\": syntax error`" {
if err.Error() != "failed to execute query CREATE TABLES test (id INTEGER, name TEXT)\nerror code = 2: Error executing statement: SQLite failure: `near \"TABLES\": syntax error`" {
t.Fatal("unexpected error:", err)
}
})
Expand Down

0 comments on commit 54ed280

Please sign in to comment.