Skip to content

Commit

Permalink
Merge pull request #9 from jovakaloom/add-context-function
Browse files Browse the repository at this point in the history
Add Context Functions and a go.mod -file.
  • Loading branch information
tilinna authored Nov 17, 2022
2 parents 3cb12ee + 02eec62 commit fd48376
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 7 deletions.
17 changes: 16 additions & 1 deletion cnuodb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,11 @@ static int fetchExecuteResult(struct nuodb *db, Statement *stmt,
}

int nuodb_execute(struct nuodb *db, const char *sql,
int64_t *rows_affected, int64_t *last_insert_id) {
int64_t *rows_affected, int64_t *last_insert_id, int64_t timeout_micro_seconds) {
Statement *stmt = 0;
try {
stmt = db->conn->createStatement();
stmt->setQueryTimeoutMicros(timeout_micro_seconds);
stmt->executeUpdate(sql, RETURN_GENERATED_KEYS);
int rc = fetchExecuteResult(db, stmt, rows_affected, last_insert_id);
stmt->close();
Expand Down Expand Up @@ -296,6 +297,20 @@ int nuodb_statement_close(struct nuodb *db, struct nuodb_statement **st) {
}
}

int nuodb_statement_set_query_micros(struct nuodb *db, struct nuodb_statement *st,
int64_t timeout_micro_seconds) {
try {
if (st) {
PreparedStatement *stmt = reinterpret_cast<PreparedStatement *>(st);
// Set the timeout in micro seconds; zero means there is no limit.
stmt->setQueryTimeoutMicros(timeout_micro_seconds);
}
return 0;
} catch (SQLException &e) {
return setError(db, e);
}
}

int nuodb_resultset_column_names(struct nuodb *db, struct nuodb_resultset *rs,
struct nuodb_value names[]) {
ResultSet *resultSet = reinterpret_cast<ResultSet *>(rs);
Expand Down
3 changes: 2 additions & 1 deletion cnuodb.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,14 @@ int nuodb_autocommit(struct nuodb *db, int *state);
int nuodb_autocommit_set(struct nuodb *db, int state);
int nuodb_commit(struct nuodb *db);
int nuodb_rollback(struct nuodb *db);
int nuodb_execute(struct nuodb *db, const char *sql, int64_t *rows_affected, int64_t *last_insert_id);
int nuodb_execute(struct nuodb *db, const char *sql, int64_t *rows_affected, int64_t *last_insert_id, int64_t timeout_micro_seconds);

int nuodb_statement_prepare(struct nuodb *db, const char *sql, struct nuodb_statement **st, int *parameter_count);
int nuodb_statement_bind(struct nuodb *db, struct nuodb_statement *st, struct nuodb_value parameters[]);
int nuodb_statement_execute(struct nuodb *db, struct nuodb_statement *st, int64_t *rows_affected, int64_t *last_insert_id);
int nuodb_statement_query(struct nuodb *db, struct nuodb_statement *st, struct nuodb_resultset **rs, int *column_count);
int nuodb_statement_close(struct nuodb *db, struct nuodb_statement **st);
int nuodb_statement_set_query_micros(struct nuodb *db, struct nuodb_statement *st, int64_t timeout_micro_seconds);

int nuodb_resultset_column_names(struct nuodb *db, struct nuodb_resultset *rs, struct nuodb_value names[]);
int nuodb_resultset_next(struct nuodb *db, struct nuodb_resultset *rs, int *has_values, struct nuodb_value values[]);
Expand Down
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module github.com/tilinna/go-nuodb

go 1.13
92 changes: 90 additions & 2 deletions nuodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

package nuodb

// #cgo CPPFLAGS: -I/opt/nuodb/include
// #cgo CPPFLAGS: -I/opt/nuodb/include
// #cgo LDFLAGS: -L. -lcnuodb -L/opt/nuodb/lib64/ -lNuoRemote
// #include "cnuodb.h"
// #include <stdlib.h>
import "C"
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
Expand All @@ -34,6 +35,12 @@ type Stmt struct {
ddlStatement bool
}

var _ interface {
driver.Stmt
driver.StmtQueryContext
// driver.StmtExecContext
} = (*Stmt)(nil)

type Result struct {
rowsAffected C.int64_t
lastInsertId C.int64_t
Expand Down Expand Up @@ -168,13 +175,26 @@ func (c *Conn) Begin() (driver.Tx, error) {
}

func (c *Conn) Exec(sql string, args []driver.Value) (driver.Result, error) {
if len(args) > 0 {
return nil, driver.ErrSkip
}
return c.ExecContext(context.Background(), sql, nil)
}

func (c *Conn) ExecContext(ctx context.Context, sql string, args []driver.NamedValue) (driver.Result, error) {
if len(args) > 0 {
return nil, driver.ErrSkip
}
csql := C.CString(sql)
defer C.free(unsafe.Pointer(csql))
result := &Result{}
if rc := C.nuodb_execute(c.db, csql, &result.rowsAffected, &result.lastInsertId); rc != 0 {

uSec, err := getMicrosecondsUntilDeadline(ctx)
if err != nil {
return nil, err
}

if rc := C.nuodb_execute(c.db, csql, &result.rowsAffected, &result.lastInsertId, uSec); rc != 0 {
return nil, c.lastError(rc)
}
if result.rowsAffected == 0 && ddlStatement(sql) {
Expand Down Expand Up @@ -254,6 +274,19 @@ func (stmt *Stmt) bind(args []driver.Value) error {
}

func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) {
return stmt.execQuery(context.Background(), args)
}

func (stmt *Stmt) ExecQuery(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
values, err := namedValuesToValues(args)
if err != nil {
return nil, err
}

return stmt.execQuery(ctx, values)
}

func (stmt *Stmt) execQuery(ctx context.Context, args []driver.Value) (driver.Result, error) {
var err error
c := stmt.c
if c.db == nil {
Expand All @@ -262,6 +295,9 @@ func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) {
if err = stmt.bind(args); err != nil {
return nil, fmt.Errorf("bind: %s", err)
}
if err = stmt.addTimeoutFromContext(ctx); err != nil {
return nil, err
}
result := &Result{}
if rc := C.nuodb_statement_execute(c.db, stmt.st, &result.rowsAffected, &result.lastInsertId); rc != 0 {
return nil, c.lastError(rc)
Expand All @@ -273,6 +309,18 @@ func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) {
}

func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) {
return stmt.queryContext(context.Background(), args)
}

func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
values, err := namedValuesToValues(args)
if err != nil {
return nil, err
}
return stmt.queryContext(ctx, values)
}

func (stmt *Stmt) queryContext(ctx context.Context, args []driver.Value) (driver.Rows, error) {
var err error
c := stmt.c
if c.db == nil {
Expand All @@ -281,6 +329,9 @@ func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) {
if err = stmt.bind(args); err != nil {
return nil, fmt.Errorf("bind: %s", err)
}
if err = stmt.addTimeoutFromContext(ctx); err != nil {
return nil, err
}
rows := &Rows{c: c}
var columnCount C.int
if rc := C.nuodb_statement_query(c.db, stmt.st, &rows.rs, &columnCount); rc != 0 {
Expand All @@ -304,6 +355,43 @@ func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) {
return rows, nil
}

func (stmt *Stmt) addTimeoutFromContext(ctx context.Context) error {
uSec, err := getMicrosecondsUntilDeadline(ctx)
if err != nil {
return err
}

C.nuodb_statement_set_query_micros(stmt.c.db, stmt.st, uSec)

return nil
}

// getMicrosecondsUntilDeadline returns the number of micro seconds until the context's deadline is reached.
// Returns an error if the context is already done.
// N.B. A value of zero means no limit.
func getMicrosecondsUntilDeadline(ctx context.Context) (uSec C.int64_t, err error) {
if deadline, ok := ctx.Deadline(); ok {
uSec = C.int64_t(time.Until(deadline).Microseconds())
}

if err = ctx.Err(); err != nil {
return 0, err
}

return uSec, nil
}

func namedValuesToValues(namedValues []driver.NamedValue) ([]driver.Value, error) {
values := make([]driver.Value, 0, len(namedValues))
for _, namedValue := range namedValues {
if len(namedValue.Name) != 0 {
return nil, fmt.Errorf("sql driver doesn't support named values")
}
values = append(values, namedValue.Value)
}
return values, nil
}

func (stmt *Stmt) Close() error {
if stmt != nil && stmt.c.db != nil {
if rc := C.nuodb_statement_close(stmt.c.db, &stmt.st); rc != 0 {
Expand Down
106 changes: 103 additions & 3 deletions nuodb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package nuodb

import (
"context"
"database/sql"
"log"
"math"
Expand All @@ -18,9 +19,9 @@ const default_dsn = base_dsn + "?timezone=America/Los_Angeles"

const (
syntaxError = -1
compileError = -4
conversionError = -8
connectionError = -10
ddlError = -11
noSuchTableError = -25
)

Expand All @@ -35,6 +36,17 @@ func exec(t *testing.T, db *sql.DB, sql string, args ...interface{}) (li, ra int
return
}

func execContext(t *testing.T, db *sql.DB, ctx context.Context, sql string, args ...interface{}) (li, ra int64) {
result, err := db.ExecContext(ctx, sql, args...)
if err != nil {
_, _, line, _ := runtime.Caller(1)
t.Fatalf("line:%d sql: %s err: %s", line, sql, err)
}
li, _ = result.LastInsertId()
ra, _ = result.RowsAffected()
return
}

func query(t *testing.T, db *sql.DB, sql string, args ...interface{}) *sql.Rows {
rows, err := db.Query(sql, args...)
if err != nil {
Expand All @@ -43,6 +55,14 @@ func query(t *testing.T, db *sql.DB, sql string, args ...interface{}) *sql.Rows
return rows
}

func queryContext(t *testing.T, db *sql.DB, ctx context.Context, sql string, args ...interface{}) *sql.Rows {
rows, err := db.QueryContext(ctx, sql, args...)
if err != nil {
t.Fatal(sql, "=>", err)
}
return rows
}

func testConn(t *testing.T) *sql.DB {
db, err := sql.Open("nuodb", default_dsn)
if err != nil {
Expand Down Expand Up @@ -73,7 +93,7 @@ func expectErrorCode(t *testing.T, err error, code int) {

func TestConnectionError(t *testing.T) {
// Use an invalid IP address to force a connection error
db, err := sql.Open("nuodb", "nuodb://robinh:[email protected].1:48004/tests")
db, err := sql.Open("nuodb", "nuodb://robinh:[email protected]:48004/tests")
if err != nil {
t.Fatal("sql.Open:", err)
}
Expand Down Expand Up @@ -193,12 +213,92 @@ func TestExecAndQuery(t *testing.T) {
}
}

func TestExecAndQueryContext(t *testing.T) {
db := testConn(t)
defer db.Close()

id, ra := exec(t, db, "CREATE TABLE FooBar (id BIGINT GENERATED ALWAYS AS IDENTITY NOT NULL, ir INTEGER)")
if id|ra != 0 {
t.Fatal(id, ra)
}

t.Run("context canceled before call", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()

_, err := db.ExecContext(ctx, "INSERT INTO FooBar (ir) VALUES (1)")
if !strings.Contains(err.Error(), "context canceled") {
t.Fatal(err)
}

_, err = db.QueryContext(ctx, "SELECT 1 FROM Dual")
if !strings.Contains(err.Error(), "context canceled") {
t.Fatal(err)
}
})

t.Run("call done with context without deadline", func(t *testing.T) {
ctx := context.Background()

id, ra := execContext(t, db, ctx, "INSERT INTO FooBar (ir) VALUES (1)")
if id|ra == 0 {
t.Fatal(id, ra)
}

rows := queryContext(t, db, ctx, "SELECT 1 FROM Dual")
defer rows.Close()
if !rows.Next() {
t.Fatal("Expected rows")
}
})

t.Run("call done before context deadline", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
defer cancel()

id, ra := execContext(t, db, ctx, "INSERT INTO FooBar (ir) VALUES (2)")
if id|ra == 0 {
t.Fatal(id, ra)
}

rows := queryContext(t, db, ctx, "SELECT 1 FROM Dual")
defer rows.Close()
if !rows.Next() {
t.Fatal("Expected rows")
}
})

t.Run("call is interrupted by context", func(t *testing.T) {
ctxExec, cancelExec := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancelExec()

// SQL query that will 'spin' artificially for 5s.
longQuery := `
VAR until TIMESTAMP=(SELECT DATE_ADD(NOW(), INTERVAL 5 SECOND) FROM DUAL);
WHILE ((SELECT NOW() FROM DUAL) < until )
END_WHILE`

_, err := db.ExecContext(ctxExec, longQuery)
if !strings.Contains(err.Error(), "exceeded") {
t.Fatal(err)
}

ctxQuery, cancelQuery := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancelQuery()

_, err = db.QueryContext(ctxQuery, longQuery)
if !strings.Contains(err.Error(), "exceeded") {
t.Fatal(err)
}
})
}

func TestExecAndQueryError(t *testing.T) {
db := testConn(t)
defer db.Close()

_, err := db.Exec("CALL NotARealFunction()")
expectErrorCode(t, err, compileError)
expectErrorCode(t, err, ddlError)

_, err = db.Query("SELECT * FROM tests.NotARealTable")
expectErrorCode(t, err, noSuchTableError)
Expand Down

0 comments on commit fd48376

Please sign in to comment.