From e53873151a46af22b77fa010fc1485c31cbb38b1 Mon Sep 17 00:00:00 2001 From: David Shiflet Date: Thu, 20 Jan 2022 10:12:18 -0500 Subject: [PATCH] Implement sqlexp Messages for Query/QueryContext (#690) * Ask for usedb and set language messages * add functional test for message queue * first basic version of msg queue * complete messages implementation * fix lint and build issues * fix test to avoid 0 rows in a query * add sqlexp to appveyor yml * fix build * fix the merge * move messageq tests to go1.9 * fix example function and support env vars * implement data type discovery * fix pipeline variables for tests --- .pipelines/TestSql2017.yml | 4 +- appveyor.yml | 1 + bulkimport_example_test.go | 14 +- datetimeoffset_example_test.go | 13 +- go.mod | 1 + go.sum | 2 + lastinsertid_example_test.go | 14 +- messages_benchmark_test.go | 63 ++++++++ messages_example_test.go | 76 ++++++++++ mssql.go | 224 +++++++++++++++++++++++++++- mssql_go19.go | 6 + newconnector_example_test.go | 29 ++-- queries_go19_test.go | 263 +++++++++++++++++++++++++++++++++ queries_test.go | 2 +- tds.go | 8 + tds_go110_test.go | 3 +- tds_login_test.go | 20 ++- tds_test.go | 36 +++-- token.go | 54 ++++++- tvp_example_test.go | 14 +- 20 files changed, 762 insertions(+), 85 deletions(-) create mode 100644 messages_benchmark_test.go create mode 100644 messages_example_test.go diff --git a/.pipelines/TestSql2017.yml b/.pipelines/TestSql2017.yml index 2a052e96..046e3e98 100644 --- a/.pipelines/TestSql2017.yml +++ b/.pipelines/TestSql2017.yml @@ -58,9 +58,9 @@ steps: workingDirectory: '$(Build.SourcesDirectory)' displayName: 'run tests' env: - SQLSERVER_DSN: 'server=.;user id=sa;password=$(TESTPASSWORD)' + SQLPASSWORD: $(SQLPASSWORD) AZURESERVER_DSN: $(AZURESERVER_DSN) - + SQLSERVER_DSN: $(SQLSERVER_DSN) continueOnError: true - task: PublishTestResults@2 displayName: "Publish junit-style results" diff --git a/appveyor.yml b/appveyor.yml index ecb893a3..c03f375c 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -49,6 +49,7 @@ install: - go version - go env - go get -u github.com/golang-sql/civil + - go get -u github.com/golang-sql/sqlexp build_script: - go build diff --git a/bulkimport_example_test.go b/bulkimport_example_test.go index 54d3dc2d..4fafbc99 100644 --- a/bulkimport_example_test.go +++ b/bulkimport_example_test.go @@ -1,11 +1,10 @@ +//go:build go1.10 // +build go1.10 package mssql_test import ( "database/sql" - "flag" - "fmt" "log" "strings" "unicode/utf8" @@ -32,19 +31,8 @@ const ( // This example shows how to perform bulk imports func ExampleCopyIn() { - flag.Parse() - - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) - } connString := makeConnURL().String() - if *debug { - fmt.Printf(" connString:%s\n", connString) - } db, err := sql.Open("sqlserver", connString) if err != nil { diff --git a/datetimeoffset_example_test.go b/datetimeoffset_example_test.go index fa3dffb3..ad419c41 100644 --- a/datetimeoffset_example_test.go +++ b/datetimeoffset_example_test.go @@ -1,10 +1,10 @@ +//go:build go1.10 // +build go1.10 package mssql_test import ( "database/sql" - "flag" "fmt" "log" "time" @@ -15,19 +15,8 @@ import ( // This example shows how to insert and retrieve date and time types data func ExampleDateTimeOffset() { - flag.Parse() - - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) - } connString := makeConnURL().String() - if *debug { - fmt.Printf(" connString:%s\n", connString) - } db, err := sql.Open("sqlserver", connString) if err != nil { diff --git a/go.mod b/go.mod index d0d623e8..792e0382 100644 --- a/go.mod +++ b/go.mod @@ -6,5 +6,6 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0 github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe + github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188 golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 ) diff --git a/go.sum b/go.sum index 67630618..f8a505f6 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188 h1:+eHOFJl1BaXrQxKX+T06f78590z4qA2ZzBTqahsKSE4= +github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188/go.mod h1:vXjM/+wXQnTPR4KqTKDgJukSZ6amVRtWMPEjE6sQoK8= github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 h1:49lOXmGaUpV9Fz3gd7TFZY106KVlPVa5jcYD1gaQf98= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= diff --git a/lastinsertid_example_test.go b/lastinsertid_example_test.go index 9a82284f..260b44ec 100644 --- a/lastinsertid_example_test.go +++ b/lastinsertid_example_test.go @@ -1,29 +1,17 @@ +//go:build go1.10 // +build go1.10 package mssql_test import ( "database/sql" - "flag" - "fmt" "log" ) // This example shows the usage of Connector type func ExampleLastInsertId() { - flag.Parse() - - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) - } connString := makeConnURL().String() - if *debug { - fmt.Printf(" connString:%s\n", connString) - } db, err := sql.Open("sqlserver", connString) if err != nil { diff --git a/messages_benchmark_test.go b/messages_benchmark_test.go new file mode 100644 index 00000000..b7425429 --- /dev/null +++ b/messages_benchmark_test.go @@ -0,0 +1,63 @@ +// +build go1.14 + +package mssql + +import ( + "testing" +) + +func BenchmarkMessageQueue(b *testing.B) { + conn, logger := open(b) + defer conn.Close() + defer logger.StopLogging() + + b.Run("BlockingQuery", func(b *testing.B) { + var errs, results float64 + for i := 0; i < b.N; i++ { + r, err := conn.Query(mixedQuery) + if err != nil { + b.Fatal(err.Error()) + } + defer r.Close() + active := true + first := true + for active { + active = r.Next() + if active && first { + results++ + } + first = false + if !active { + if r.Err() != nil { + b.Logf("r.Err:%v", r.Err()) + errs++ + } + active = r.NextResultSet() + if active { + first = true + } + } + } + } + b.ReportMetric(float64(0), "msgs/op") + b.ReportMetric(errs/float64(b.N), "errors/op") + b.ReportMetric(results/float64(b.N), "results/op") + }) + b.Run("NonblockingQuery", func(b *testing.B) { + var msgs, errs, results, rowcounts float64 + for i := 0; i < b.N; i++ { + m, e, r, rc := testMixedQuery(conn, b) + msgs += float64(m) + errs += float64(e) + results += float64(r) + rowcounts += float64(rc) + if r != 4 { + b.Fatalf("Got wrong results count: %d, expected 4", r) + } + } + b.ReportMetric(msgs/float64(b.N), "msgs/op") + b.ReportMetric(errs/float64(b.N), "errors/op") + b.ReportMetric(results/float64(b.N), "results/op") + b.ReportMetric(rowcounts/float64(b.N), "rowcounts/op") + }) +} diff --git a/messages_example_test.go b/messages_example_test.go new file mode 100644 index 00000000..37dc0e8f --- /dev/null +++ b/messages_example_test.go @@ -0,0 +1,76 @@ +//go:build go1.10 +// +build go1.10 + +package mssql_test + +import ( + "context" + "database/sql" + "fmt" + "log" + + mssql "github.com/denisenkom/go-mssqldb" + "github.com/golang-sql/sqlexp" +) + +const ( + msgQuery = `select 'name' as Name +PRINT N'This is a message' +select 199 +RAISERROR (N'Testing!' , 11, 1) +select 300 +` +) + +// This example shows the usage of sqlexp/Messages +func ExampleRows_usingmessages() { + + connString := makeConnURL().String() + + // Create a new connector object by calling NewConnector + connector, err := mssql.NewConnector(connString) + if err != nil { + log.Println(err) + return + } + + // Pass connector to sql.OpenDB to get a sql.DB object + db := sql.OpenDB(connector) + defer db.Close() + retmsg := &sqlexp.ReturnMessage{} + ctx := context.Background() + rows, err := db.QueryContext(ctx, msgQuery, retmsg) + if err != nil { + log.Fatalf("QueryContext failed: %v", err) + } + active := true + for active { + msg := retmsg.Message(ctx) + switch m := msg.(type) { + case sqlexp.MsgNotice: + fmt.Println(m.Message) + case sqlexp.MsgNext: + inresult := true + for inresult { + inresult = rows.Next() + if inresult { + cols, err := rows.Columns() + if err != nil { + log.Fatalf("Columns failed: %v", err) + } + fmt.Println(cols) + var d interface{} + if err = rows.Scan(&d); err == nil { + fmt.Println(d) + } + } + } + case sqlexp.MsgNextResultSet: + active = rows.NextResultSet() + case sqlexp.MsgError: + fmt.Println("Error:", m.Error) + case sqlexp.MsgRowsAffected: + fmt.Println("Rows affected:", m.Count) + } + } +} diff --git a/mssql.go b/mssql.go index c34f2a03..fbd44d1f 100644 --- a/mssql.go +++ b/mssql.go @@ -17,6 +17,7 @@ import ( "github.com/denisenkom/go-mssqldb/internal/querytext" "github.com/denisenkom/go-mssqldb/msdsn" + "github.com/golang-sql/sqlexp" ) // ReturnStatus may be used to return the return value from a proc. @@ -206,6 +207,7 @@ type Conn struct { type outputs struct { params map[string]interface{} returnStatus *ReturnStatus + msgq *sqlexp.ReturnMessage } // IsValid satisfies the driver.Validator interface. @@ -667,6 +669,11 @@ func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err e ctx, cancel := context.WithCancel(ctx) reader := startReading(s.c.sess, ctx, s.c.outs) s.c.clearOuts() + // For apps using a message queue, return right away and let Rowsq do all the work + if reader.outs.msgq != nil { + res = &Rowsq{stmt: s, reader: reader, cols: nil, cancel: cancel} + return res, nil + } // process metadata var cols []columnStruct loop: @@ -738,13 +745,13 @@ func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) { return &Result{s.c, reader.rowCount}, nil } +// Rows represents the non-experimental data/sql model for Query and QueryContext type Rows struct { stmt *Stmt cols []columnStruct reader *tokenProcessor nextCols []columnStruct - - cancel func() + cancel func() } func (rc *Rows) Close() error { @@ -772,6 +779,7 @@ func (rc *Rows) Close() error { } func (rc *Rows) Columns() (res []string) { + res = make([]string, len(rc.cols)) for i, col := range rc.cols { res[i] = col.ColName @@ -793,6 +801,7 @@ func (rc *Rows) Next(dest []driver.Value) error { return io.EOF } else { switch tokdata := tok.(type) { + // processQueryResponse may have delegated all the token reading to us case []columnStruct: rc.nextCols = tokdata return io.EOF @@ -1058,3 +1067,214 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive } return s.exec(ctx, list) } + +// Rowsq implements the sqlexp messages model for Query and QueryContext +// Theory: We could also implement the non-experimental model this way +type Rowsq struct { + stmt *Stmt + cols []columnStruct + reader *tokenProcessor + nextCols []columnStruct + cancel func() + requestDone bool + inResultSet bool +} + +func (rc *Rowsq) Close() error { + rc.cancel() + + for { + tok, err := rc.reader.nextToken() + if err == nil { + if tok == nil { + return nil + } else { + // continue consuming tokens + continue + } + } else { + if err == rc.reader.ctx.Err() { + return nil + } else { + return err + } + } + } +} + +// data/sql calls Columns during the app's call to Next +func (rc *Rowsq) Columns() (res []string) { + if rc.cols == nil { + scan: + for { + tok, err := rc.reader.nextToken() + if err == nil { + if rc.reader.sess.logFlags&logDebug != 0 { + rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("Columns() token type:%v", reflect.TypeOf(tok))) + } + if tok == nil { + return []string{} + } else { + switch tokdata := tok.(type) { + case []columnStruct: + rc.cols = tokdata + rc.inResultSet = true + break scan + } + } + } + } + } + res = make([]string, len(rc.cols)) + for i, col := range rc.cols { + res[i] = col.ColName + } + return +} + +func (rc *Rowsq) Next(dest []driver.Value) error { + if !rc.stmt.c.connectionGood { + return driver.ErrBadConn + } + for { + tok, err := rc.reader.nextToken() + if rc.reader.sess.logFlags&logDebug != 0 { + rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("Next() token type:%v", reflect.TypeOf(tok))) + } + if err == nil { + if tok == nil { + return io.EOF + } else { + switch tokdata := tok.(type) { + case []interface{}: + for i := range dest { + dest[i] = tokdata[i] + } + return nil + case doneStruct: + if tokdata.Status&doneMore == 0 { + rc.requestDone = true + } + if tokdata.isError() { + e := rc.stmt.c.checkBadConn(rc.reader.ctx, tokdata.getError(), false) + switch e.(type) { + case Error: + // Ignore non-fatal server errors. Fatal errors are of type ServerError + default: + return e + } + } + if rc.inResultSet { + rc.inResultSet = false + return io.EOF + } + case ReturnStatus: + if rc.reader.outs.returnStatus != nil { + *rc.reader.outs.returnStatus = tokdata + } + } + } + + } else { + return rc.stmt.c.checkBadConn(rc.reader.ctx, err, false) + } + } +} + +// In Message Queue mode, we always claim another resultset could be on the way +// to avoid Rows being closed prematurely +func (rc *Rowsq) HasNextResultSet() bool { + return !rc.requestDone +} + +// Scans to the next set of columns in the stream +// Note that the caller may not have read all the rows in the prior set +func (rc *Rowsq) NextResultSet() error { + if rc.requestDone { + return io.EOF + } +scan: + for { + // we should have a columns token in the channel if we aren't at the end + tok, err := rc.reader.nextToken() + if rc.reader.sess.logFlags&logDebug != 0 { + rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("NextResultSet() token type:%v", reflect.TypeOf(tok))) + } + + if err != nil { + return err + } + if tok == nil { + return io.EOF + } + switch tokdata := tok.(type) { + case []columnStruct: + rc.nextCols = tokdata + rc.inResultSet = true + break scan + case doneStruct: + if tokdata.Status&doneMore == 0 { + rc.nextCols = nil + rc.requestDone = true + break scan + } + } + } + rc.cols = rc.nextCols + rc.nextCols = nil + if rc.cols == nil { + return io.EOF + } + return nil +} + +// It should return +// the value type that can be used to scan types into. For example, the database +// column type "bigint" this should return "reflect.TypeOf(int64(0))". +func (r *Rowsq) ColumnTypeScanType(index int) reflect.Type { + return makeGoLangScanType(r.cols[index].ti) +} + +// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the +// database system type name without the length. Type names should be uppercase. +// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", +// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML", +// "TIMESTAMP". +func (r *Rowsq) ColumnTypeDatabaseTypeName(index int) string { + return makeGoLangTypeName(r.cols[index].ti) +} + +// RowsColumnTypeLength may be implemented by Rows. It should return the length +// of the column type if the column is a variable length type. If the column is +// not a variable length type ok should return false. +// If length is not limited other than system limits, it should return math.MaxInt64. +// The following are examples of returned values for various types: +// TEXT (math.MaxInt64, true) +// varchar(10) (10, true) +// nvarchar(10) (10, true) +// decimal (0, false) +// int (0, false) +// bytea(30) (30, true) +func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) { + return makeGoLangTypeLength(r.cols[index].ti) +} + +// It should return +// the precision and scale for decimal types. If not applicable, ok should be false. +// The following are examples of returned values for various types: +// decimal(38, 4) (38, 4, true) +// int (0, 0, false) +// decimal (math.MaxInt64, math.MaxInt64, true) +func (r *Rowsq) ColumnTypePrecisionScale(index int) (int64, int64, bool) { + return makeGoLangTypePrecisionScale(r.cols[index].ti) +} + +// The nullable value should +// be true if it is known the column may be null, or false if the column is known +// to be not nullable. +// If the column nullability is unknown, ok should be false. +func (r *Rowsq) ColumnTypeNullable(index int) (nullable, ok bool) { + nullable = r.cols[index].Flags&colFlagNullable != 0 + ok = true + return +} diff --git a/mssql_go19.go b/mssql_go19.go index e77eebba..508b03a0 100644 --- a/mssql_go19.go +++ b/mssql_go19.go @@ -10,6 +10,8 @@ import ( "reflect" "time" + "github.com/golang-sql/sqlexp" + // "github.com/cockroachdb/apd" "github.com/golang-sql/civil" ) @@ -114,6 +116,10 @@ func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error { return driver.ErrRemoveArgument case TVP: return nil + case *sqlexp.ReturnMessage: + sqlexp.ReturnMessageInit(v) + c.outs.msgq = v + return driver.ErrRemoveArgument default: var err error nv.Value, err = convertInputParameter(nv.Value) diff --git a/newconnector_example_test.go b/newconnector_example_test.go index 613866bb..8dc74baa 100644 --- a/newconnector_example_test.go +++ b/newconnector_example_test.go @@ -1,3 +1,4 @@ +//go:build go1.10 // +build go1.10 package mssql_test @@ -18,7 +19,7 @@ var ( debug = flag.Bool("debug", false, "enable debugging") password = flag.String("password", "", "the database password") port *int = flag.Int("port", 1433, "the database port") - server = flag.String("server", "", "the database server") + server = flag.String("server", ".", "the database server") user = flag.String("user", "", "the database user") ) @@ -32,23 +33,31 @@ const ( ) func makeConnURL() *url.URL { + flag.Parse() + if *debug { + fmt.Printf(" password:%s\n", *password) + fmt.Printf(" port:%d\n", *port) + fmt.Printf(" server:%s\n", *server) + fmt.Printf(" user:%s\n", *user) + } + + params, err := mssql.GetConnParams() + if err == nil && params != nil { + return params.URL() + } + var userInfo *url.Userinfo + if *user != "" { + userInfo = url.UserPassword(*user, *password) + } return &url.URL{ Scheme: "sqlserver", Host: *server + ":" + strconv.Itoa(*port), - User: url.UserPassword(*user, *password), + User: userInfo, } } // This example shows the usage of Connector type func ExampleConnector() { - flag.Parse() - - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) - } connString := makeConnURL().String() if *debug { diff --git a/queries_go19_test.go b/queries_go19_test.go index bbb75d74..12371094 100644 --- a/queries_go19_test.go +++ b/queries_go19_test.go @@ -7,9 +7,12 @@ import ( "context" "database/sql" "fmt" + "reflect" "regexp" "testing" "time" + + "github.com/golang-sql/sqlexp" ) func TestOutputParam(t *testing.T) { @@ -1105,3 +1108,263 @@ func TestClearReturnStatus(t *testing.T) { t.Errorf("expected status=42, got %d", rs) } } + +func TestMessageQueue(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + retmsg := &sqlexp.ReturnMessage{} + latency, _ := getLatency(t) + ctx, cancel := context.WithTimeout(context.Background(), latency+200000*time.Millisecond) + defer cancel() + rows, err := conn.QueryContext(ctx, "PRINT 'msg1'; select 100 as c; PRINT 'msg2'", retmsg) + if err != nil { + t.Fatal(err.Error()) + } + defer rows.Close() + active := true + + msgs := []interface{}{ + sqlexp.MsgNotice{Message: "msg1"}, + sqlexp.MsgNext{}, + sqlexp.MsgRowsAffected{Count: 1}, + sqlexp.MsgNotice{Message: "msg2"}, + sqlexp.MsgNextResultSet{}, + } + i := 0 + rsCount := 0 + for active { + msg := retmsg.Message(ctx) + if i >= len(msgs) { + t.Fatalf("Got extra message:%+v", msg) + } + t.Log(reflect.TypeOf(msg)) + if reflect.TypeOf(msgs[i]) != reflect.TypeOf(msg) { + t.Fatalf("Out of order or incorrect message at %d. Actual: %+v. Expected: %+v", i, reflect.TypeOf(msg), reflect.TypeOf(msgs[i])) + } + switch m := msg.(type) { + case sqlexp.MsgNotice: + t.Log(m.Message) + case sqlexp.MsgNextResultSet: + active = rows.NextResultSet() + if active { + t.Fatal("NextResultSet returned true") + } + rsCount++ + case sqlexp.MsgNext: + if !rows.Next() { + t.Fatal("rows.Next() returned false") + } + var c int + err = rows.Scan(&c) + if err != nil { + t.Fatalf("rows.Scan() failed: %s", err.Error()) + } + if c != 100 { + t.Fatalf("query returned wrong value: %d", c) + } + } + i++ + } +} + +func TestAdvanceResultSetAfterPartialRead(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + + ctx := context.Background() + retmsg := &sqlexp.ReturnMessage{} + + rows, err := conn.QueryContext(ctx, "select top 2 object_id from sys.all_objects; print 'this is a message'; select 100 as Count; ", retmsg) + if err != nil { + t.Fatal(err.Error()) + } + defer rows.Close() + + rows.Next() + var g interface{} + err = rows.Scan(&g) + if err != nil { + t.Fatalf("Scan failed: %v", err) + } + next := rows.NextResultSet() + if !next { + t.Fatalf("NextResultSet returned false") + } + next = rows.Next() + if !next { + t.Fatalf("Next on the second result set returned false") + } + cols, err := rows.Columns() + if err != nil { + t.Fatalf("Columns() error: %s", err) + } + if cols[0] != "Count" { + t.Fatalf("Wrong column in second result:%s, expected Count", cols[0]) + } + var c int + err = rows.Scan(&c) + if err != nil { + t.Fatalf("Scan errored out on second result: %s", err) + } + if c != 100 { + t.Fatalf("Scan returned incorrect value on second result set: %d, expected 100", c) + } +} +func TestMessageQueueWithErrors(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + + msgs, errs, results, rowcounts := testMixedQuery(conn, t) + if msgs != 1 { + t.Fatalf("Got %d messages, expected 1", msgs) + } + if errs != 1 { + t.Fatalf("Got %d errors, expected 1", errs) + } + if results != 4 { + t.Fatalf("Got %d results, expected 4", results) + } + if rowcounts != 4 { + t.Fatalf("Got %d row counts, expected 4", rowcounts) + } +} + +const mixedQuery = `select top 5 name from sys.system_columns +select getdate() +PRINT N'This is a message' +select 199 +RAISERROR (N'Testing!' , 11, 1) +select 300 +` + +func testMixedQuery(conn *sql.DB, b testing.TB) (msgs, errs, results, rowcounts int) { + ctx := context.Background() + retmsg := &sqlexp.ReturnMessage{} + r, err := conn.QueryContext(ctx, mixedQuery, retmsg) + if err != nil { + b.Fatal(err.Error()) + } + defer r.Close() + active := true + first := true + for active { + msg := retmsg.Message(ctx) + switch m := msg.(type) { + case sqlexp.MsgNotice: + b.Logf("MsgNotice:%s", m.Message) + msgs++ + case sqlexp.MsgNext: + b.Logf("MsgNext") + inresult := true + for inresult { + inresult = r.Next() + if first { + if !inresult { + b.Fatalf("First Next call returned false") + } + results++ + } + if inresult { + var d interface{} + err = r.Scan(&d) + if err != nil { + b.Fatalf("Scan failed:%v", err) + } + b.Logf("Row data:%v", d) + } + first = false + } + case sqlexp.MsgNextResultSet: + b.Log("MsgNextResultSet") + active = r.NextResultSet() + first = true + case sqlexp.MsgError: + b.Logf("MsgError:%v", m.Error) + errs++ + case sqlexp.MsgRowsAffected: + b.Logf("MsgRowsAffected:%d", m.Count) + rowcounts++ + } + } + return msgs, errs, results, rowcounts +} + +func TestTimeoutWithNoResults(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + + latency, _ := getLatency(t) + ctx, cancel := context.WithTimeout(context.Background(), latency+5000*time.Millisecond) + defer cancel() + retmsg := &sqlexp.ReturnMessage{} + r, err := conn.QueryContext(ctx, `waitfor delay '00:00:15'; select 100`, retmsg) + if err != nil { + t.Fatal(err.Error()) + } + defer r.Close() + active := true + for active { + msg := retmsg.Message(ctx) + t.Logf("Got a message: %s", reflect.TypeOf(msg)) + switch m := msg.(type) { + case sqlexp.MsgNextResultSet: + active = r.NextResultSet() + if active { + t.Fatal("NextResultSet returned true") + } + case sqlexp.MsgNext: + if r.Next() { + t.Fatal("Got a successful Next even though the query should have timed out") + } + case sqlexp.MsgRowsAffected: + t.Fatalf("Got a MsgRowsAffected %d", m.Count) + } + } + if r.Err() != context.DeadlineExceeded { + t.Fatalf("Unexpected error: %v", r.Err()) + } + +} + +func TestCancelWithNoResults(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + + latency, _ := getLatency(t) + ctx, cancel := context.WithTimeout(context.Background(), latency+5000*time.Millisecond) + retmsg := &sqlexp.ReturnMessage{} + r, err := conn.QueryContext(ctx, `waitfor delay '00:00:15'; select 100`, retmsg) + if err != nil { + cancel() + t.Fatal(err.Error()) + } + defer r.Close() + time.Sleep(latency + 100*time.Millisecond) + cancel() + active := true + for active { + msg := retmsg.Message(ctx) + t.Logf("Got a message: %s", reflect.TypeOf(msg)) + switch m := msg.(type) { + case sqlexp.MsgNextResultSet: + active = r.NextResultSet() + if active { + t.Fatal("NextResultSet returned true") + } + case sqlexp.MsgNext: + if r.Next() { + t.Fatal("Got a successful Next even though the query should been cancelled") + } + case sqlexp.MsgRowsAffected: + t.Fatalf("Got a MsgRowsAffected %d", m.Count) + } + } + if r.Err() != context.Canceled { + t.Fatalf("Unexpected error: %v", r.Err()) + } +} diff --git a/queries_test.go b/queries_test.go index 438db695..2b1125df 100644 --- a/queries_test.go +++ b/queries_test.go @@ -2077,7 +2077,7 @@ func TestLoginTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), latency+(2*increment)) defer cancel() _, err := conn.ExecContext(ctx, "waitfor delay '00:00:03'") - t.Log("Got error ", err) + t.Logf("Got error type %v: %s ", reflect.TypeOf(err), err.Error()) if oe, ok := err.(*net.OpError); ok { if !oe.Timeout() { t.Fatalf("Got non-timeout error %s", oe.Error()) diff --git a/tds.go b/tds.go index fbc0d149..dbe95272 100644 --- a/tds.go +++ b/tds.go @@ -245,6 +245,13 @@ func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) { return results, nil } +// OptionFlags1 +// http://msdn.microsoft.com/en-us/library/dd304019.aspx +const ( + fUseDB = 0x20 + fSetLang = 0x80 +) + // OptionFlags2 // http://msdn.microsoft.com/en-us/library/dd304019.aspx const ( @@ -981,6 +988,7 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont PacketSize: packetSize, Database: p.Database, OptionFlags2: fODBC, // to get unlimited TEXTSIZE + OptionFlags1: fUseDB | fSetLang, HostName: p.Workstation, ServerName: serverName, AppName: p.AppName, diff --git a/tds_go110_test.go b/tds_go110_test.go index 8f2b0d50..c36f29a9 100644 --- a/tds_go110_test.go +++ b/tds_go110_test.go @@ -7,10 +7,9 @@ import ( "testing" ) -func open(t *testing.T) (*sql.DB, *testLogger) { +func open(t testing.TB) (*sql.DB, *testLogger) { tl := testLogger{t: t} SetLogger(&tl) - checkConnStr(t) connector, err := NewConnector(makeConnStr(t).String()) if err != nil { t.Error("Open connection failed:", err.Error()) diff --git a/tds_login_test.go b/tds_login_test.go index 08f3e3e5..f7c4cb79 100644 --- a/tds_login_test.go +++ b/tds_login_test.go @@ -75,8 +75,16 @@ func testLoginSequenceServer(result chan error, conn net.Conn, expectedPackets, for bi := 0; bi < n; bi++ { if expectedBytes[bi+b] != packet[bi] { - err = fmt.Errorf("Client sent unexpected byte %02X != %02X at offset %d of packet %d", - packet[bi], expectedBytes[bi+b], bi+b, i) + suffix := "" + if bi > 0 { + suffix = fmt.Sprintf("Previous byte: %02X", packet[bi-1]) + } + if bi < n { + suffix = fmt.Sprintf("%s Next byte:%02X", suffix, packet[bi+1]) + } + err = fmt.Errorf("Client sent unexpected byte %02X != %02X at offset %d of packet %d. %s", + packet[bi], expectedBytes[bi+b], bi+b, i, suffix) + result <- err return } @@ -126,7 +134,7 @@ func TestLoginWithSQLServerAuth(t *testing.T) { "01 ff 00 00 00 00 00 00 00 00 00 00 00 00 00\n", " 10 01 00 b2 00 00 01 00 aa 00 00 00 04 00 00 74\n" + "00 10 00 00 00 00 00 00 00 00 00 00 00 00 00 00\n" + - "00 02 00 00 00 00 00 00 00 00 00 00 5e 00 09 00\n" + + "A0 02 00 00 00 00 00 00 00 00 00 00 5e 00 09 00\n" + "70 00 04 00 78 00 06 00 84 00 0a 00 98 00 09 00\n" + "00 00 00 00 aa 00 00 00 aa 00 00 00 aa 00 00 00\n" + "00 00 00 00 00 00 aa 00 00 00 aa 00 00 00 aa 00\n" + @@ -187,7 +195,7 @@ func TestLoginWithSecurityTokenAuth(t *testing.T) { "00 00 00 00 01\n", " 10 01 00 BB 00 00 01 00 B3 00 00 00 04 00 00 74\n" + "00 10 00 00 00 00 00 00 00 00 00 00 00 00 00 00\n" + - "00 02 00 10 00 00 00 00 00 00 00 00 5E 00 09 00\n" + + "A0 02 00 10 00 00 00 00 00 00 00 00 5E 00 09 00\n" + "70 00 00 00 70 00 00 00 70 00 0A 00 84 00 09 00\n" + "96 00 04 00 96 00 00 00 96 00 00 00 96 00 00 00\n" + "00 00 00 00 00 00 96 00 00 00 96 00 00 00 96 00\n" + @@ -250,7 +258,7 @@ func TestLoginWithADALUsernamePasswordAuth(t *testing.T) { "00 00 00 00 01\n", " 10 01 00 aa 00 00 01 00 a2 00 00 00 04 00 00 74\n" + "00 10 00 00 00 00 00 00 00 00 00 00 00 00 00 00\n" + - "00 02 00 10 00 00 00 00 00 00 00 00 5e 00 09 00\n" + + "A0 02 00 10 00 00 00 00 00 00 00 00 5e 00 09 00\n" + "70 00 00 00 70 00 00 00 70 00 0a 00 84 00 09 00\n" + "96 00 04 00 96 00 00 00 96 00 00 00 96 00 00 00\n" + "00 00 00 00 00 00 96 00 00 00 96 00 00 00 96 00\n" + @@ -324,7 +332,7 @@ func TestLoginWithADALManagedIdentityAuth(t *testing.T) { "00 00 00 00 01\n", " 10 01 00 aa 00 00 01 00 a2 00 00 00 04 00 00 74\n" + "00 10 00 00 00 00 00 00 00 00 00 00 00 00 00 00\n" + - "00 02 00 10 00 00 00 00 00 00 00 00 5e 00 09 00\n" + + "A0 02 00 10 00 00 00 00 00 00 00 00 5e 00 09 00\n" + "70 00 00 00 70 00 00 00 70 00 0a 00 84 00 09 00\n" + "96 00 04 00 96 00 00 00 96 00 00 00 96 00 00 00\n" + "00 00 00 00 00 00 96 00 00 00 96 00 00 00 96 00\n" + diff --git a/tds_test.go b/tds_test.go index bfce86c5..7d9af0de 100644 --- a/tds_test.go +++ b/tds_test.go @@ -202,25 +202,43 @@ func TestSendSqlBatch(t *testing.T) { // returns parsed connection parameters derived from // environment variables func testConnParams(t testing.TB) msdsn.Config { + params, err := GetConnParams() + if err != nil { + t.Fatal("unable to parse SQLSERVER_DSN or read .connstr", err) + } + if params == nil { + t.Skip("no database connection string") + return msdsn.Config{} + } + return *params +} + +// TestConnParams returns a connection configuration based on environment variables or the contents of a text file +// Set environment variable SQLSERVER_DSN to provide an entire connection string +// Set environment variables HOST and DATABASE from which a minimal config will be created. +// If HOST and DATABASE are set, you can optionally set INSTANCE, SQLUSER, and SQLPASSWORD as well +// If environment variables are not set, it will look in the working directory for a file named .connstr +// If the file exists it will use the first line of the file as the file as the DSN +func GetConnParams() (*msdsn.Config, error) { dsn := os.Getenv("SQLSERVER_DSN") const logFlags = 127 if len(dsn) > 0 { params, _, err := msdsn.Parse(dsn) if err != nil { - t.Fatal("unable to parse SQLSERVER_DSN", err) + return nil, err } params.LogFlags = logFlags - return params + return ¶ms, nil } if len(os.Getenv("HOST")) > 0 && len(os.Getenv("DATABASE")) > 0 { - return msdsn.Config{ + return &msdsn.Config{ Host: os.Getenv("HOST"), Instance: os.Getenv("INSTANCE"), Database: os.Getenv("DATABASE"), User: os.Getenv("SQLUSER"), Password: os.Getenv("SQLPASSWORD"), LogFlags: logFlags, - } + }, nil } // try loading connection string from file f, err := os.Open(".connstr") @@ -228,17 +246,17 @@ func testConnParams(t testing.TB) msdsn.Config { rdr := bufio.NewReader(f) dsn, err := rdr.ReadString('\n') if err != io.EOF && err != nil { - t.Fatal(err) + return nil, err } params, _, err := msdsn.Parse(dsn) if err != nil { - t.Fatal("unable to parse connection string loaded from file", err) + return nil, err } params.LogFlags = logFlags - return params + return ¶ms, nil } - t.Skip("no database connection string") - return msdsn.Config{} + + return nil, nil } func checkConnStr(t testing.TB) { diff --git a/token.go b/token.go index 643a78ac..43039d3d 100644 --- a/token.go +++ b/token.go @@ -10,6 +10,7 @@ import ( "strconv" "github.com/denisenkom/go-mssqldb/msdsn" + "github.com/golang-sql/sqlexp" ) //go:generate go run golang.org/x/tools/cmd/stringer -type token @@ -108,6 +109,7 @@ func (d doneStruct) getError() Error { return Error{Message: "Request failed but didn't provide reason"} } err := d.errors[n-1] + // should this return the most severe error? err.All = make([]Error, n) copy(err.All, d.errors) return err @@ -643,6 +645,7 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) { } func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs outputs) { + firstResult := true defer func() { if err := recover(); err != nil { if sess.logFlags&logErrors != 0 { @@ -692,30 +695,67 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS ch <- order case tokenDoneInProc: done := parseDoneInProc(sess.buf) + + ch <- done if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 { - sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d row(s) affected)", done.RowCount)) + sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d rows affected)", done.RowCount)) + + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)}) + } + } + if done.Status&doneMore == 0 { + if outs.msgq != nil { + // For now we ignore ctx->Done errors that ReturnMessageEnqueue might return + // It's not clear how to handle them correctly here, and data/sql seems + // to set Rows.Err correctly when ctx expires already + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) + } + return } - ch <- done case tokenDone, tokenDoneProc: done := parseDone(sess.buf) done.errors = errs + if outs.msgq != nil { + errs = make([]Error, 0, 5) + } if sess.logFlags&logDebug != 0 { sess.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("got DONE or DONEPROC status=%d", done.Status)) } if done.Status&doneSrvError != 0 { ch <- ServerError{done.getError()} + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) + } return } if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 { sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d row(s) affected)", done.RowCount)) } ch <- done + if done.Status&doneCount != 0 { + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)}) + } + } if done.Status&doneMore == 0 { + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) + } return } case tokenColMetadata: columns = parseColMetadata72(sess.buf) ch <- columns + + if outs.msgq != nil { + if !firstResult { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) + } + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNext{}) + } + firstResult = false + case tokenRow: row := make([]interface{}, len(columns)) parseRow(sess.buf, columns, row) @@ -735,6 +775,9 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS if sess.logFlags&logErrors != 0 { sess.logger.Log(ctx, msdsn.LogErrors, err.Message) } + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgError{Error: err}) + } case tokenInfo: info := parseInfo(sess.buf) if sess.logFlags&logDebug != 0 { @@ -743,6 +786,9 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS if sess.logFlags&logMessages != 0 { sess.logger.Log(ctx, msdsn.LogMessages, info.Message) } + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNotice{Message: info.Message}) + } case tokenReturnValue: nv := parseReturnValue(sess.buf) if len(nv.Name) > 0 { @@ -854,6 +900,10 @@ func (t tokenProcessor) nextToken() (tokenStruct, error) { return nil, nil } case <-t.ctx.Done(): + // It seems the Message function on t.outs.msgq doesn't get the Done if it comes here instead + if t.outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(t.ctx, t.outs.msgq, sqlexp.MsgNextResultSet{}) + } if t.noAttn { return nil, t.ctx.Err() } diff --git a/tvp_example_test.go b/tvp_example_test.go index 99582155..c27f98e5 100644 --- a/tvp_example_test.go +++ b/tvp_example_test.go @@ -1,10 +1,10 @@ +//go:build go1.10 // +build go1.10 package mssql_test import ( "database/sql" - "flag" "fmt" "log" @@ -46,19 +46,7 @@ func ExampleTVP() { Currency string `json:"-"` } - flag.Parse() - - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) - } - connString := makeConnURL().String() - if *debug { - fmt.Printf(" connString:%s\n", connString) - } db, err := sql.Open("sqlserver", connString) if err != nil {