From f01a111c9467a814efc939b5d383a2969c3e832d Mon Sep 17 00:00:00 2001 From: Joel Jr Vanier Date: Mon, 14 Nov 2022 09:10:39 -0500 Subject: [PATCH 1/2] Add Context Functions --- cnuodb.cpp | 17 +++++++- cnuodb.h | 3 +- nuodb.go | 83 ++++++++++++++++++++++++++++++++++++++- nuodb_test.go | 106 ++++++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 202 insertions(+), 7 deletions(-) diff --git a/cnuodb.cpp b/cnuodb.cpp index 214fb06..5fce737 100644 --- a/cnuodb.cpp +++ b/cnuodb.cpp @@ -164,10 +164,11 @@ static int fetchExecuteResult(struct nuodb *db, Statement *stmt, } int nuodb_execute(struct nuodb *db, const char *sql, - int64_t *rows_affected, int64_t *last_insert_id) { + int64_t *rows_affected, int64_t *last_insert_id, int64_t timeout_micro_seconds) { Statement *stmt = 0; try { stmt = db->conn->createStatement(); + stmt->setQueryTimeoutMicros(timeout_micro_seconds); stmt->executeUpdate(sql, RETURN_GENERATED_KEYS); int rc = fetchExecuteResult(db, stmt, rows_affected, last_insert_id); stmt->close(); @@ -296,6 +297,20 @@ int nuodb_statement_close(struct nuodb *db, struct nuodb_statement **st) { } } +int nuodb_statement_set_query_micros(struct nuodb *db, struct nuodb_statement *st, + int64_t timeout_micro_seconds) { + try { + if (st) { + PreparedStatement *stmt = reinterpret_cast(st); + stmt->setQueryTimeoutMicros(timeout_micro_seconds); + st = 0; + } + return 0; + } catch (SQLException &e) { + return setError(db, e); + } +} + int nuodb_resultset_column_names(struct nuodb *db, struct nuodb_resultset *rs, struct nuodb_value names[]) { ResultSet *resultSet = reinterpret_cast(rs); diff --git a/cnuodb.h b/cnuodb.h index fbf7bfa..e9e5a94 100644 --- a/cnuodb.h +++ b/cnuodb.h @@ -49,13 +49,14 @@ int nuodb_autocommit(struct nuodb *db, int *state); int nuodb_autocommit_set(struct nuodb *db, int state); int nuodb_commit(struct nuodb *db); int nuodb_rollback(struct nuodb *db); -int nuodb_execute(struct nuodb *db, const char *sql, int64_t *rows_affected, int64_t *last_insert_id); +int nuodb_execute(struct nuodb *db, const char *sql, int64_t *rows_affected, int64_t *last_insert_id, int64_t timeout_micro_seconds); int nuodb_statement_prepare(struct nuodb *db, const char *sql, struct nuodb_statement **st, int *parameter_count); int nuodb_statement_bind(struct nuodb *db, struct nuodb_statement *st, struct nuodb_value parameters[]); int nuodb_statement_execute(struct nuodb *db, struct nuodb_statement *st, int64_t *rows_affected, int64_t *last_insert_id); int nuodb_statement_query(struct nuodb *db, struct nuodb_statement *st, struct nuodb_resultset **rs, int *column_count); int nuodb_statement_close(struct nuodb *db, struct nuodb_statement **st); +int nuodb_statement_set_query_micros(struct nuodb *db, struct nuodb_statement *st, int64_t timeout_micro_seconds); int nuodb_resultset_column_names(struct nuodb *db, struct nuodb_resultset *rs, struct nuodb_value names[]); int nuodb_resultset_next(struct nuodb *db, struct nuodb_resultset *rs, int *has_values, struct nuodb_value values[]); diff --git a/nuodb.go b/nuodb.go index a4efb38..d5fcbbb 100644 --- a/nuodb.go +++ b/nuodb.go @@ -2,12 +2,13 @@ package nuodb -// #cgo CPPFLAGS: -I/opt/nuodb/include +// #cgo CPPFLAGS: -I/opt/nuodb/include // #cgo LDFLAGS: -L. -lcnuodb -L/opt/nuodb/lib64/ -lNuoRemote // #include "cnuodb.h" // #include import "C" import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -168,13 +169,26 @@ func (c *Conn) Begin() (driver.Tx, error) { } func (c *Conn) Exec(sql string, args []driver.Value) (driver.Result, error) { + if len(args) > 0 { + return nil, driver.ErrSkip + } + return c.ExecContext(context.Background(), sql, nil) +} + +func (c *Conn) ExecContext(ctx context.Context, sql string, args []driver.NamedValue) (driver.Result, error) { if len(args) > 0 { return nil, driver.ErrSkip } csql := C.CString(sql) defer C.free(unsafe.Pointer(csql)) result := &Result{} - if rc := C.nuodb_execute(c.db, csql, &result.rowsAffected, &result.lastInsertId); rc != 0 { + + uSec, err := getMicrosecondsUntilDeadline(ctx) + if err != nil { + return nil, err + } + + if rc := C.nuodb_execute(c.db, csql, &result.rowsAffected, &result.lastInsertId, uSec); rc != 0 { return nil, c.lastError(rc) } if result.rowsAffected == 0 && ddlStatement(sql) { @@ -254,6 +268,19 @@ func (stmt *Stmt) bind(args []driver.Value) error { } func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { + return stmt.execQuery(context.Background(), args) +} + +func (stmt *Stmt) ExecQuery(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + values, err := namedValuesToValues(args) + if err != nil { + return nil, err + } + + return stmt.execQuery(ctx, values) +} + +func (stmt *Stmt) execQuery(ctx context.Context, args []driver.Value) (driver.Result, error) { var err error c := stmt.c if c.db == nil { @@ -262,6 +289,9 @@ func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { if err = stmt.bind(args); err != nil { return nil, fmt.Errorf("bind: %s", err) } + if err = stmt.addTimeoutFromContext(ctx); err != nil { + return nil, err + } result := &Result{} if rc := C.nuodb_statement_execute(c.db, stmt.st, &result.rowsAffected, &result.lastInsertId); rc != 0 { return nil, c.lastError(rc) @@ -273,6 +303,18 @@ func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { } func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { + return stmt.queryContext(context.Background(), args) +} + +func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + values, err := namedValuesToValues(args) + if err != nil { + return nil, err + } + return stmt.queryContext(ctx, values) +} + +func (stmt *Stmt) queryContext(ctx context.Context, args []driver.Value) (driver.Rows, error) { var err error c := stmt.c if c.db == nil { @@ -281,6 +323,9 @@ func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { if err = stmt.bind(args); err != nil { return nil, fmt.Errorf("bind: %s", err) } + if err = stmt.addTimeoutFromContext(ctx); err != nil { + return nil, err + } rows := &Rows{c: c} var columnCount C.int if rc := C.nuodb_statement_query(c.db, stmt.st, &rows.rs, &columnCount); rc != 0 { @@ -304,6 +349,40 @@ func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { return rows, nil } +func (stmt *Stmt) addTimeoutFromContext(ctx context.Context) error { + uSec, err := getMicrosecondsUntilDeadline(ctx) + if err != nil { + return err + } + + C.nuodb_statement_set_query_micros(stmt.c.db, stmt.st, uSec) + + return nil +} + +func getMicrosecondsUntilDeadline(ctx context.Context) (uSec C.int64_t, err error) { + if deadline, ok := ctx.Deadline(); ok { + uSec = C.int64_t(time.Until(deadline).Microseconds()) + } + + if err = ctx.Err(); err != nil { + return 0, err + } + + return uSec, nil +} + +func namedValuesToValues(namedValues []driver.NamedValue) ([]driver.Value, error) { + values := make([]driver.Value, 0, len(namedValues)) + for _, namedValue := range namedValues { + if len(namedValue.Name) != 0 { + return nil, fmt.Errorf("sql driver doesn't support named values") + } + values = append(values, namedValue.Value) + } + return values, nil +} + func (stmt *Stmt) Close() error { if stmt != nil && stmt.c.db != nil { if rc := C.nuodb_statement_close(stmt.c.db, &stmt.st); rc != 0 { diff --git a/nuodb_test.go b/nuodb_test.go index 515dc03..6aa7c40 100644 --- a/nuodb_test.go +++ b/nuodb_test.go @@ -3,6 +3,7 @@ package nuodb import ( + "context" "database/sql" "log" "math" @@ -18,9 +19,9 @@ const default_dsn = base_dsn + "?timezone=America/Los_Angeles" const ( syntaxError = -1 - compileError = -4 conversionError = -8 connectionError = -10 + ddlError = -11 noSuchTableError = -25 ) @@ -35,6 +36,17 @@ func exec(t *testing.T, db *sql.DB, sql string, args ...interface{}) (li, ra int return } +func execContext(t *testing.T, db *sql.DB, ctx context.Context, sql string, args ...interface{}) (li, ra int64) { + result, err := db.ExecContext(ctx, sql, args...) + if err != nil { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line:%d sql: %s err: %s", line, sql, err) + } + li, _ = result.LastInsertId() + ra, _ = result.RowsAffected() + return +} + func query(t *testing.T, db *sql.DB, sql string, args ...interface{}) *sql.Rows { rows, err := db.Query(sql, args...) if err != nil { @@ -43,6 +55,14 @@ func query(t *testing.T, db *sql.DB, sql string, args ...interface{}) *sql.Rows return rows } +func queryContext(t *testing.T, db *sql.DB, ctx context.Context, sql string, args ...interface{}) *sql.Rows { + rows, err := db.QueryContext(ctx, sql, args...) + if err != nil { + t.Fatal(sql, "=>", err) + } + return rows +} + func testConn(t *testing.T) *sql.DB { db, err := sql.Open("nuodb", default_dsn) if err != nil { @@ -73,7 +93,7 @@ func expectErrorCode(t *testing.T, err error, code int) { func TestConnectionError(t *testing.T) { // Use an invalid IP address to force a connection error - db, err := sql.Open("nuodb", "nuodb://robinh:crossbow@0.0.0.1:48004/tests") + db, err := sql.Open("nuodb", "nuodb://robinh:crossbow@0.0.0:48004/tests") if err != nil { t.Fatal("sql.Open:", err) } @@ -193,12 +213,92 @@ func TestExecAndQuery(t *testing.T) { } } +func TestExecAndQueryContext(t *testing.T) { + db := testConn(t) + defer db.Close() + + id, ra := exec(t, db, "CREATE TABLE FooBar (id BIGINT GENERATED ALWAYS AS IDENTITY NOT NULL, ir INTEGER)") + if id|ra != 0 { + t.Fatal(id, ra) + } + + t.Run("context canceled before call", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := db.ExecContext(ctx, "INSERT INTO FooBar (ir) VALUES (1)") + if !strings.Contains(err.Error(), "context canceled") { + t.Fatal(err) + } + + _, err = db.QueryContext(ctx, "SELECT 1 FROM Dual") + if !strings.Contains(err.Error(), "context canceled") { + t.Fatal(err) + } + }) + + t.Run("call done with context without deadline", func(t *testing.T) { + ctx := context.Background() + + id, ra := execContext(t, db, ctx, "INSERT INTO FooBar (ir) VALUES (1)") + if id|ra == 0 { + t.Fatal(id, ra) + } + + rows := queryContext(t, db, ctx, "SELECT 1 FROM Dual") + defer rows.Close() + if !rows.Next() { + t.Fatal("Expected rows") + } + }) + + t.Run("call done before context deadline", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Hour) + defer cancel() + + id, ra := execContext(t, db, ctx, "INSERT INTO FooBar (ir) VALUES (2)") + if id|ra == 0 { + t.Fatal(id, ra) + } + + rows := queryContext(t, db, ctx, "SELECT 1 FROM Dual") + defer rows.Close() + if !rows.Next() { + t.Fatal("Expected rows") + } + }) + + t.Run("call is interrupted by context", func(t *testing.T) { + ctxExec, cancelExec := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancelExec() + + // SQL query that will 'spin' artificially for 5s. + longQuery := ` +VAR until TIMESTAMP=(SELECT DATE_ADD(NOW(), INTERVAL 5 SECOND) FROM DUAL); +WHILE ((SELECT NOW() FROM DUAL) < until ) +END_WHILE` + + _, err := db.ExecContext(ctxExec, longQuery) + if !strings.Contains(err.Error(), "exceeded") { + t.Fatal(err) + } + + ctxQuery, cancelQuery := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancelQuery() + + _, err = db.QueryContext(ctxQuery, longQuery) + if !strings.Contains(err.Error(), "exceeded") { + t.Fatal(err) + } + }) +} + func TestExecAndQueryError(t *testing.T) { db := testConn(t) defer db.Close() _, err := db.Exec("CALL NotARealFunction()") - expectErrorCode(t, err, compileError) + expectErrorCode(t, err, ddlError) _, err = db.Query("SELECT * FROM tests.NotARealTable") expectErrorCode(t, err, noSuchTableError) From 02eec624a3e7bc23b931a8652550d40acc21181b Mon Sep 17 00:00:00 2001 From: Joel Jr Vanier Date: Wed, 16 Nov 2022 11:53:27 -0500 Subject: [PATCH 2/2] review comments --- cnuodb.cpp | 2 +- go.mod | 3 +++ nuodb.go | 9 +++++++++ 3 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 go.mod diff --git a/cnuodb.cpp b/cnuodb.cpp index 5fce737..cafc127 100644 --- a/cnuodb.cpp +++ b/cnuodb.cpp @@ -302,8 +302,8 @@ int nuodb_statement_set_query_micros(struct nuodb *db, struct nuodb_statement *s try { if (st) { PreparedStatement *stmt = reinterpret_cast(st); + // Set the timeout in micro seconds; zero means there is no limit. stmt->setQueryTimeoutMicros(timeout_micro_seconds); - st = 0; } return 0; } catch (SQLException &e) { diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..792d2eb --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/tilinna/go-nuodb + +go 1.13 diff --git a/nuodb.go b/nuodb.go index d5fcbbb..e6bc1d0 100644 --- a/nuodb.go +++ b/nuodb.go @@ -35,6 +35,12 @@ type Stmt struct { ddlStatement bool } +var _ interface { + driver.Stmt + driver.StmtQueryContext + // driver.StmtExecContext +} = (*Stmt)(nil) + type Result struct { rowsAffected C.int64_t lastInsertId C.int64_t @@ -360,6 +366,9 @@ func (stmt *Stmt) addTimeoutFromContext(ctx context.Context) error { return nil } +// getMicrosecondsUntilDeadline returns the number of micro seconds until the context's deadline is reached. +// Returns an error if the context is already done. +// N.B. A value of zero means no limit. func getMicrosecondsUntilDeadline(ctx context.Context) (uSec C.int64_t, err error) { if deadline, ok := ctx.Deadline(); ok { uSec = C.int64_t(time.Until(deadline).Microseconds())