diff --git a/mssql.go b/mssql.go index fbd44d1f..d9a52112 100644 --- a/mssql.go +++ b/mssql.go @@ -1074,7 +1074,6 @@ type Rowsq struct { stmt *Stmt cols []columnStruct reader *tokenProcessor - nextCols []columnStruct cancel func() requestDone bool inResultSet bool @@ -1102,8 +1101,11 @@ func (rc *Rowsq) Close() error { } } -// data/sql calls Columns during the app's call to Next +// ProcessSingleResponse queues MsgNext for every columns token. +// data/sql calls Columns during the app's call to Next. func (rc *Rowsq) Columns() (res []string) { + // r.cols is nil if the first query in a batch is a SELECT or similar query that returns a rowset. + // if will be non-nil for subsequent queries where NextResultSet() has populated it if rc.cols == nil { scan: for { @@ -1145,6 +1147,10 @@ func (rc *Rowsq) Next(dest []driver.Value) error { if tok == nil { return io.EOF } else { + switch tokdata := tok.(type) { + case doneInProcStruct: + tok = (doneStruct)(tokdata) + } switch tokdata := tok.(type) { case []interface{}: for i := range dest { @@ -1172,9 +1178,11 @@ func (rc *Rowsq) Next(dest []driver.Value) error { if rc.reader.outs.returnStatus != nil { *rc.reader.outs.returnStatus = tokdata } + case ServerError: + rc.requestDone = true + return tokdata } } - } else { return rc.stmt.c.checkBadConn(rc.reader.ctx, err, false) } @@ -1187,7 +1195,7 @@ func (rc *Rowsq) HasNextResultSet() bool { return !rc.requestDone } -// Scans to the next set of columns in the stream +// Scans to the end of the current statement being processed // Note that the caller may not have read all the rows in the prior set func (rc *Rowsq) NextResultSet() error { if rc.requestDone { @@ -1195,7 +1203,6 @@ func (rc *Rowsq) NextResultSet() error { } scan: for { - // we should have a columns token in the channel if we aren't at the end tok, err := rc.reader.nextToken() if rc.reader.sess.logFlags&logDebug != 0 { rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("NextResultSet() token type:%v", reflect.TypeOf(tok))) @@ -1208,23 +1215,42 @@ scan: return io.EOF } switch tokdata := tok.(type) { + case doneInProcStruct: + tok = (doneStruct)(tokdata) + } + // ProcessSingleResponse queues a MsgNextResult for every "done" and "server error" token + // The only tokens to consume after a "done" should be "done", "server error", or "columns" + switch tokdata := tok.(type) { case []columnStruct: - rc.nextCols = tokdata + rc.cols = tokdata rc.inResultSet = true break scan case doneStruct: if tokdata.Status&doneMore == 0 { - rc.nextCols = nil rc.requestDone = true - break scan } + if tokdata.isError() { + e := rc.stmt.c.checkBadConn(rc.reader.ctx, tokdata.getError(), false) + switch e.(type) { + case Error: + // Ignore non-fatal server errors. Fatal errors are of type ServerError + default: + return e + } + } + rc.inResultSet = false + rc.cols = nil + break scan + case ReturnStatus: + if rc.reader.outs.returnStatus != nil { + *rc.reader.outs.returnStatus = tokdata + } + case ServerError: + rc.requestDone = true + return tokdata } } - rc.cols = rc.nextCols - rc.nextCols = nil - if rc.cols == nil { - return io.EOF - } + return nil } diff --git a/queries_go19_test.go b/queries_go19_test.go index 12371094..c8f431e8 100644 --- a/queries_go19_test.go +++ b/queries_go19_test.go @@ -1,3 +1,4 @@ +//go:build go1.9 // +build go1.9 package mssql @@ -1126,17 +1127,19 @@ func TestMessageQueue(t *testing.T) { msgs := []interface{}{ sqlexp.MsgNotice{Message: "msg1"}, + sqlexp.MsgNextResultSet{}, sqlexp.MsgNext{}, sqlexp.MsgRowsAffected{Count: 1}, + sqlexp.MsgNextResultSet{}, sqlexp.MsgNotice{Message: "msg2"}, sqlexp.MsgNextResultSet{}, + sqlexp.MsgNextResultSet{}, } i := 0 - rsCount := 0 for active { msg := retmsg.Message(ctx) if i >= len(msgs) { - t.Fatalf("Got extra message:%+v", msg) + t.Fatalf("Got extra message:%+v", reflect.TypeOf(msg)) } t.Log(reflect.TypeOf(msg)) if reflect.TypeOf(msgs[i]) != reflect.TypeOf(msg) { @@ -1147,10 +1150,6 @@ func TestMessageQueue(t *testing.T) { t.Log(m.Message) case sqlexp.MsgNextResultSet: active = rows.NextResultSet() - if active { - t.Fatal("NextResultSet returned true") - } - rsCount++ case sqlexp.MsgNext: if !rows.Next() { t.Fatal("rows.Next() returned false") @@ -1368,3 +1367,135 @@ func TestCancelWithNoResults(t *testing.T) { t.Fatalf("Unexpected error: %v", r.Err()) } } + +const DropSprocWithCursor = `IF EXISTS (SELECT * FROM sys.objects WHERE object_id = OBJECT_ID(N'[dbo].[TestSqlCmd]') AND type in (N'P', N'PC')) +DROP PROCEDURE [dbo].[TestSqlCmd] +` + +// This query generates half a dozen tokenDoneInProc tokens which fill the channel if the app isn't scanning Rowsq +const CreateSprocWithCursor = ` +CREATE PROCEDURE [dbo].[TestSqlCmd] +AS +BEGIN + DECLARE @tmp int; + DECLARE Server_Cursor CURSOR FOR + SELECT 1 UNION SELECT 2 + OPEN Server_Cursor; + FETCH NEXT FROM Server_Cursor INTO @tmp; + WHILE @@FETCH_STATUS = 0 + BEGIN + PRINT @tmp + FETCH NEXT FROM Server_Cursor INTO @tmp; + END; + CLOSE Server_Cursor; + DEALLOCATE Server_Cursor; +END +` + +func TestSprocWithCursorNoResult(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + + _, e := conn.Exec(DropSprocWithCursor) + if e != nil { + t.Fatalf("Unable to drop test sproc: %v", e) + } + _, e = conn.Exec(CreateSprocWithCursor) + if e != nil { + t.Fatalf("Unable to create test sproc: %v", e) + } + defer conn.Exec(DropSprocWithCursor) + latency, _ := getLatency(t) + ctx, cancel := context.WithTimeout(context.Background(), latency+500*time.Millisecond) + defer cancel() + retmsg := &sqlexp.ReturnMessage{} + // Use a sproc instead of the cursor loop directly to cover the different code path in token.go + r, err := conn.QueryContext(ctx, `exec [dbo].[TestSqlCmd]`, retmsg) + if err != nil { + t.Fatal(err.Error()) + } + defer r.Close() + active := true + rsCount := 0 + msgCount := 0 + for active { + msg := retmsg.Message(ctx) + t.Logf("Got a message: %v", reflect.TypeOf(msg)) + switch m := msg.(type) { + case sqlexp.MsgNext: + t.Fatalf("Got a MsgNext from a query with no rows") + case sqlexp.MsgError: + t.Fatalf("Got an error: %s", m.Error.Error()) + case sqlexp.MsgNotice: + msgCount++ + case sqlexp.MsgNextResultSet: + if active = r.NextResultSet(); active { + rsCount++ + } + } + } + if r.Err() != nil { + t.Fatalf("Got an error: %v", r.Err()) + } + if rsCount != 13 { + t.Fatalf("Unexpected record set count: %v", rsCount) + } + if msgCount != 2 { + t.Fatalf("Unexpected message count: %v", msgCount) + } +} + +func TestErrorAsLastResult(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + latency, _ := getLatency(t) + ctx, cancel := context.WithTimeout(context.Background(), latency+5000*time.Millisecond) + defer cancel() + retmsg := &sqlexp.ReturnMessage{} + // Use a sproc instead of the cursor loop directly to cover the different code path in token.go + r, err := conn.QueryContext(ctx, + ` + Print N'message' + select 1 + raiserror(N'Error!', 16, 1)`, + retmsg) + if err != nil { + t.Fatal(err.Error()) + } + defer r.Close() + active := true + d := 0 + err = nil + for active { + msg := retmsg.Message(ctx) + t.Logf("Got a message: %s", reflect.TypeOf(msg)) + switch m := msg.(type) { + case sqlexp.MsgNext: + if !r.Next() { + t.Fatalf("Next returned false") + } + r.Scan(&d) + if r.Next() { + t.Fatal("Second Next returned true") + } + case sqlexp.MsgError: + err = m.Error + case sqlexp.MsgNextResultSet: + active = r.NextResultSet() + } + } + if err == nil { + t.Fatal("Should have gotten an error message") + } else { + switch e := err.(type) { + case Error: + if e.Message != "Error!" || e.Class != 16 { + t.Fatalf("Got the wrong mssql error %v", e) + } + default: + t.Fatalf("Got an unexpected error %v", e) + } + } +} diff --git a/token.go b/token.go index 43039d3d..9692129e 100644 --- a/token.go +++ b/token.go @@ -645,7 +645,6 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) { } func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs outputs) { - firstResult := true defer func() { if err := recover(); err != nil { if sess.logFlags&logErrors != 0 { @@ -704,11 +703,16 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)}) } } + if outs.msgq != nil { + // For now we ignore ctx->Done errors that ReturnMessageEnqueue might return + // It's not clear how to handle them correctly here, and data/sql seems + // to set Rows.Err correctly when ctx expires already + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) + } if done.Status&doneMore == 0 { + // Rows marks the request as done when seeing this done token. We queue another result set message + // so the app calls NextResultSet again which will return false. if outs.msgq != nil { - // For now we ignore ctx->Done errors that ReturnMessageEnqueue might return - // It's not clear how to handle them correctly here, and data/sql seems - // to set Rows.Err correctly when ctx expires already _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) } return @@ -738,7 +742,12 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)}) } } + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) + } if done.Status&doneMore == 0 { + // Rows marks the request as done when seeing this done token. We queue another result set message + // so the app calls NextResultSet again which will return false. if outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) } @@ -749,12 +758,8 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS ch <- columns if outs.msgq != nil { - if !firstResult { - _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) - } _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNext{}) } - firstResult = false case tokenRow: row := make([]interface{}, len(columns))