Skip to content

Commit

Permalink
Merge pull request #917 from tursodatabase/golibsql
Browse files Browse the repository at this point in the history
Go bindings support remote only dbs
  • Loading branch information
haaawk authored Jan 22, 2024
2 parents b01a442 + ee964f6 commit 2f25725
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 68 deletions.
7 changes: 1 addition & 6 deletions bindings/c/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,7 @@ pub unsafe extern "C" fn libsql_open_remote(
return 2;
}
};
match RT.block_on(libsql::Database::open_with_remote_sync(
url.to_string(),
url,
auth_token,
None,
)) {
match libsql::Database::open_remote(url, auth_token) {
Ok(db) => {
let db = Box::leak(Box::new(libsql_database { db }));
*out_db = libsql_database_t::from(db);
Expand Down
164 changes: 107 additions & 57 deletions bindings/go/libsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
sqldriver "database/sql/driver"
"fmt"
"io"
"net/url"
"strings"
"time"
"unsafe"
)
Expand All @@ -30,25 +32,44 @@ func init() {
}

func NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string) (*Connector, error) {
return openConnector(dbPath, primaryUrl, authToken, 0)
return openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, 0)
}

func NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
return openConnector(dbPath, primaryUrl, authToken, syncInterval)
return openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, syncInterval)
}

type driver struct{}

func (d driver) Open(dbPath string) (sqldriver.Conn, error) {
connector, err := d.OpenConnector(dbPath)
func (d driver) Open(dbAddress string) (sqldriver.Conn, error) {
connector, err := d.OpenConnector(dbAddress)
if err != nil {
return nil, err
}
return connector.Connect(context.Background())
}

func (d driver) OpenConnector(dbPath string) (sqldriver.Connector, error) {
return openConnector(dbPath, "", "", 0)
func (d driver) OpenConnector(dbAddress string) (sqldriver.Connector, error) {
if strings.HasPrefix(dbAddress, ":memory:") {
return openLocalConnector(dbAddress)
}
u, err := url.Parse(dbAddress)
if err != nil {
return nil, err
}
switch u.Scheme {
case "file":
return openLocalConnector(dbAddress)
case "http":
fallthrough
case "https":
fallthrough
case "libsql":
authToken := u.Query().Get("authToken")
u.RawQuery = ""
return openRemoteConnector(u.String(), authToken)
}
return nil, fmt.Errorf("unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file://, https:// or http://", u.Scheme)
}

func libsqlSync(nativeDbPtr C.libsql_database_t) error {
Expand All @@ -60,44 +81,54 @@ func libsqlSync(nativeDbPtr C.libsql_database_t) error {
return nil
}

func openConnector(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
var nativeDbPtr C.libsql_database_t
var err error
func openLocalConnector(dbPath string) (*Connector, error) {
nativeDbPtr, err := libsqlOpenLocal(dbPath)
if err != nil {
return nil, err
}
return &Connector{nativeDbPtr: nativeDbPtr}, nil
}

func openRemoteConnector(primaryUrl, authToken string) (*Connector, error) {
nativeDbPtr, err := libsqlOpenRemote(primaryUrl, authToken)
if err != nil {
return nil, err
}
return &Connector{nativeDbPtr: nativeDbPtr}, nil
}

func openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
var closeCh chan struct{}
var closeAckCh chan struct{}
if primaryUrl != "" {
nativeDbPtr, err = libsqlOpenWithSync(dbPath, primaryUrl, authToken)
if err != nil {
return nil, err
}
if err := libsqlSync(nativeDbPtr); err != nil {
C.libsql_close(nativeDbPtr)
return nil, err
}
if syncInterval != 0 {
closeCh = make(chan struct{}, 1)
closeAckCh = make(chan struct{}, 1)
go func() {
for {
timerCh := make(chan struct{}, 1)
go func() {
time.Sleep(syncInterval)
timerCh <- struct{}{}
}()
select {
case <-closeCh:
closeAckCh <- struct{}{}
return
case <-timerCh:
if err := libsqlSync(nativeDbPtr); err != nil {
fmt.Println(err)
}
nativeDbPtr, err := libsqlOpenWithSync(dbPath, primaryUrl, authToken)
if err != nil {
return nil, err
}
if err := libsqlSync(nativeDbPtr); err != nil {
C.libsql_close(nativeDbPtr)
return nil, err
}
if syncInterval != 0 {
closeCh = make(chan struct{}, 1)
closeAckCh = make(chan struct{}, 1)
go func() {
for {
timerCh := make(chan struct{}, 1)
go func() {
time.Sleep(syncInterval)
timerCh <- struct{}{}
}()
select {
case <-closeCh:
closeAckCh <- struct{}{}
return
case <-timerCh:
if err := libsqlSync(nativeDbPtr); err != nil {
fmt.Println(err)
}
}
}()
}
} else {
nativeDbPtr, err = libsqlOpen(dbPath)
}
}()
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -147,15 +178,30 @@ func libsqlError(message string, statusCode C.int, errMsg *C.char) error {
}
}

func libsqlOpen(dataSourceName string) (C.libsql_database_t, error) {
func libsqlOpenLocal(dataSourceName string) (C.libsql_database_t, error) {
connectionString := C.CString(dataSourceName)
defer C.free(unsafe.Pointer(connectionString))

var db C.libsql_database_t
var errMsg *C.char
statusCode := C.libsql_open_ext(connectionString, &db, &errMsg)
statusCode := C.libsql_open_file(connectionString, &db, &errMsg)
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to open local database ", dataSourceName), statusCode, errMsg)
}
return db, nil
}

func libsqlOpenRemote(url, authToken string) (C.libsql_database_t, error) {
connectionString := C.CString(url)
defer C.free(unsafe.Pointer(connectionString))
authTokenNativeString := C.CString(authToken)
defer C.free(unsafe.Pointer(authTokenNativeString))

var db C.libsql_database_t
var errMsg *C.char
statusCode := C.libsql_open_remote(connectionString, authTokenNativeString, &db, &errMsg)
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to open database ", dataSourceName), statusCode, errMsg)
return nil, libsqlError(fmt.Sprint("failed to open remote database ", url), statusCode, errMsg)
}
return db, nil
}
Expand Down Expand Up @@ -249,18 +295,8 @@ func newRows(nativePtr C.libsql_rows_t) (*rows, error) {
return &rows{nil, nil, nil}, nil
}
columnCount := int(C.libsql_column_count(nativePtr))
columnTypes := make([]int, columnCount)
columns := make([]string, columnCount)
for i := 0; i < columnCount; i++ {
var columnType C.int
var errMsg *C.char
statusCode := C.libsql_column_type(nativePtr, C.int(i), &columnType, &errMsg)
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to get column type for index ", i), statusCode, errMsg)
}
columnTypes[i] = int(columnType)
}
columns := make([]string, len(columnTypes))
for i := 0; i < len(columnTypes); i++ {
var ptr *C.char
var errMsg *C.char
statusCode := C.libsql_column_name(nativePtr, C.int(i), &ptr, &errMsg)
Expand All @@ -270,7 +306,7 @@ func newRows(nativePtr C.libsql_rows_t) (*rows, error) {
columns[i] = C.GoString(ptr)
C.libsql_free_string(ptr)
}
return &rows{nativePtr, columnTypes, columns}, nil
return &rows{nativePtr, nil, columns}, nil
}

type rows struct {
Expand Down Expand Up @@ -306,9 +342,23 @@ func (r *rows) Next(dest []sqldriver.Value) error {
return io.EOF
}
defer C.libsql_free_row(row)
if len(r.columnTypes) == 0 {
columnCount := len(r.columnNames)
columnTypes := make([]int, columnCount)
for i := 0; i < columnCount; i++ {
var columnType C.int
var errMsg *C.char
statusCode := C.libsql_column_type(r.nativePtr, C.int(i), &columnType, &errMsg)
if statusCode != 0 {
return libsqlError(fmt.Sprint("failed to get column type for index ", i), statusCode, errMsg)
}
columnTypes[i] = int(columnType)
}
r.columnTypes = columnTypes
}
count := len(dest)
if count > len(r.columnTypes) {
count = len(r.columnTypes)
if count > len(r.columnNames) {
count = len(r.columnNames)
}
for i := 0; i < count; i++ {
switch r.columnTypes[i] {
Expand Down
70 changes: 65 additions & 5 deletions bindings/go/libsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,74 @@ func TestSync(t *testing.T) {
})
}

func TestRemote(t *testing.T) {
primaryUrl := os.Getenv("LIBSQL_PRIMARY_URL")
if primaryUrl == "" {
t.Skip("LIBSQL_PRIMARY_URL is not set")
return
}
authToken := os.Getenv("LIBSQL_AUTH_TOKEN")
db, err := sql.Open("libsql", primaryUrl+"?authToken="+authToken)
if err != nil {
t.Fatal(err)
}
tableName := fmt.Sprintf("test_%d", time.Now().UnixNano())
_, err = db.Exec(fmt.Sprintf("CREATE TABLE %s (id INTEGER, name TEXT, gpa REAL, cv BLOB);", tableName))
if err != nil {
t.Fatal(err)
}
_, err = db.Exec(fmt.Sprintf("INSERT INTO %s (id, name, gpa, cv) VALUES (%d, '%d', %d.5, randomblob(10));", tableName, 0, 0, 0))
if err != nil {
t.Fatal(err)
}
rows, err := db.QueryContext(context.Background(), "SELECT NULL, id, name, gpa, cv FROM "+tableName)
if err != nil {
t.Fatal(err)
}
defer rows.Close()
idx := 0
for rows.Next() {
if idx > 0 {
t.Fatal("idx should be <= ", 0)
}
var null any
var id int
var name string
var gpa float64
var cv []byte
if err := rows.Scan(&null, &id, &name, &gpa, &cv); err != nil {
t.Fatal(err)
}
if null != nil {
t.Fatal("null should be nil")
}
if id != int(idx) {
t.Fatal("id should be ", idx, " got ", id)
}
if name != fmt.Sprint(idx) {
t.Fatal("name should be", idx)
}
if gpa != float64(idx)+0.5 {
t.Fatal("gpa should be", float64(idx)+0.5)
}
if len(cv) != 10 {
t.Fatal("cv should be 10 bytes")
}
idx++
}
if idx != 1 {
t.Fatal("idx should be 1 got ", idx)
}
}

func runFileTest(t *testing.T, test func(*testing.T, *sql.DB)) {
t.Parallel()
dir, err := os.MkdirTemp("", "libsql-*")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
db, err := sql.Open("libsql", dir+"/test.db")
db, err := sql.Open("libsql", "file:"+dir+"/test.db")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -266,7 +326,7 @@ func runMemoryAndFileTests(t *testing.T, test func(*testing.T, *sql.DB)) {

func TestErrorNonUtf8URL(t *testing.T) {
t.Parallel()
db, err := sql.Open("libsql", "a\xc5z")
db, err := sql.Open("libsql", "file:a\xc5z")
if err == nil {
defer func() {
if err := db.Close(); err != nil {
Expand All @@ -275,7 +335,7 @@ func TestErrorNonUtf8URL(t *testing.T) {
}()
t.Fatal("expected error")
}
if err.Error() != "failed to open database a\xc5z\nerror code = 1: Wrong URL: invalid utf-8 sequence of 1 bytes from index 1" {
if err.Error() != "failed to open local database file:a\xc5z\nerror code = 1: Wrong URL: invalid utf-8 sequence of 1 bytes from index 6" {
t.Fatal("unexpected error:", err)
}
}
Expand All @@ -299,7 +359,7 @@ func TestErrorWrongURL(t *testing.T) {

func TestErrorCanNotConnect(t *testing.T) {
t.Parallel()
db, err := sql.Open("libsql", "/root/test.db")
db, err := sql.Open("libsql", "file:/root/test.db")
if err != nil {
t.Fatal(err)
}
Expand All @@ -317,7 +377,7 @@ func TestErrorCanNotConnect(t *testing.T) {
}()
t.Fatal("expected error")
}
if err.Error() != "failed to connect to database\nerror code = 1: Unable to connect: Failed to connect to database: `/root/test.db`" {
if err.Error() != "failed to connect to database\nerror code = 1: Unable to connect: Failed to connect to database: `file:/root/test.db`" {
t.Fatal("unexpected error:", err)
}
}
Expand Down

0 comments on commit 2f25725

Please sign in to comment.