Skip to content

Commit

Permalink
Merge pull request #3 from shueybubbles/shueybubbles/fixmessageq
Browse files Browse the repository at this point in the history
Fix message ordering in Rowsq
  • Loading branch information
shueybubbles authored Apr 4, 2022
2 parents c47e89f + fafb9d9 commit 983a80e
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 27 deletions.
52 changes: 39 additions & 13 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,6 @@ type Rowsq struct {
stmt *Stmt
cols []columnStruct
reader *tokenProcessor
nextCols []columnStruct
cancel func()
requestDone bool
inResultSet bool
Expand Down Expand Up @@ -1102,8 +1101,11 @@ func (rc *Rowsq) Close() error {
}
}

// data/sql calls Columns during the app's call to Next
// ProcessSingleResponse queues MsgNext for every columns token.
// data/sql calls Columns during the app's call to Next.
func (rc *Rowsq) Columns() (res []string) {
// r.cols is nil if the first query in a batch is a SELECT or similar query that returns a rowset.
// if will be non-nil for subsequent queries where NextResultSet() has populated it
if rc.cols == nil {
scan:
for {
Expand Down Expand Up @@ -1145,6 +1147,10 @@ func (rc *Rowsq) Next(dest []driver.Value) error {
if tok == nil {
return io.EOF
} else {
switch tokdata := tok.(type) {
case doneInProcStruct:
tok = (doneStruct)(tokdata)
}
switch tokdata := tok.(type) {
case []interface{}:
for i := range dest {
Expand Down Expand Up @@ -1172,9 +1178,11 @@ func (rc *Rowsq) Next(dest []driver.Value) error {
if rc.reader.outs.returnStatus != nil {
*rc.reader.outs.returnStatus = tokdata
}
case ServerError:
rc.requestDone = true
return tokdata
}
}

} else {
return rc.stmt.c.checkBadConn(rc.reader.ctx, err, false)
}
Expand All @@ -1187,15 +1195,14 @@ func (rc *Rowsq) HasNextResultSet() bool {
return !rc.requestDone
}

// Scans to the next set of columns in the stream
// Scans to the end of the current statement being processed
// Note that the caller may not have read all the rows in the prior set
func (rc *Rowsq) NextResultSet() error {
if rc.requestDone {
return io.EOF
}
scan:
for {
// we should have a columns token in the channel if we aren't at the end
tok, err := rc.reader.nextToken()
if rc.reader.sess.logFlags&logDebug != 0 {
rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("NextResultSet() token type:%v", reflect.TypeOf(tok)))
Expand All @@ -1208,23 +1215,42 @@ scan:
return io.EOF
}
switch tokdata := tok.(type) {
case doneInProcStruct:
tok = (doneStruct)(tokdata)
}
// ProcessSingleResponse queues a MsgNextResult for every "done" and "server error" token
// The only tokens to consume after a "done" should be "done", "server error", or "columns"
switch tokdata := tok.(type) {
case []columnStruct:
rc.nextCols = tokdata
rc.cols = tokdata
rc.inResultSet = true
break scan
case doneStruct:
if tokdata.Status&doneMore == 0 {
rc.nextCols = nil
rc.requestDone = true
break scan
}
if tokdata.isError() {
e := rc.stmt.c.checkBadConn(rc.reader.ctx, tokdata.getError(), false)
switch e.(type) {
case Error:
// Ignore non-fatal server errors. Fatal errors are of type ServerError
default:
return e
}
}
rc.inResultSet = false
rc.cols = nil
break scan
case ReturnStatus:
if rc.reader.outs.returnStatus != nil {
*rc.reader.outs.returnStatus = tokdata
}
case ServerError:
rc.requestDone = true
return tokdata
}
}
rc.cols = rc.nextCols
rc.nextCols = nil
if rc.cols == nil {
return io.EOF
}

return nil
}

Expand Down
143 changes: 137 additions & 6 deletions queries_go19_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build go1.9
// +build go1.9

package mssql
Expand Down Expand Up @@ -1126,17 +1127,19 @@ func TestMessageQueue(t *testing.T) {

msgs := []interface{}{
sqlexp.MsgNotice{Message: "msg1"},
sqlexp.MsgNextResultSet{},
sqlexp.MsgNext{},
sqlexp.MsgRowsAffected{Count: 1},
sqlexp.MsgNextResultSet{},
sqlexp.MsgNotice{Message: "msg2"},
sqlexp.MsgNextResultSet{},
sqlexp.MsgNextResultSet{},
}
i := 0
rsCount := 0
for active {
msg := retmsg.Message(ctx)
if i >= len(msgs) {
t.Fatalf("Got extra message:%+v", msg)
t.Fatalf("Got extra message:%+v", reflect.TypeOf(msg))
}
t.Log(reflect.TypeOf(msg))
if reflect.TypeOf(msgs[i]) != reflect.TypeOf(msg) {
Expand All @@ -1147,10 +1150,6 @@ func TestMessageQueue(t *testing.T) {
t.Log(m.Message)
case sqlexp.MsgNextResultSet:
active = rows.NextResultSet()
if active {
t.Fatal("NextResultSet returned true")
}
rsCount++
case sqlexp.MsgNext:
if !rows.Next() {
t.Fatal("rows.Next() returned false")
Expand Down Expand Up @@ -1368,3 +1367,135 @@ func TestCancelWithNoResults(t *testing.T) {
t.Fatalf("Unexpected error: %v", r.Err())
}
}

const DropSprocWithCursor = `IF EXISTS (SELECT * FROM sys.objects WHERE object_id = OBJECT_ID(N'[dbo].[TestSqlCmd]') AND type in (N'P', N'PC'))
DROP PROCEDURE [dbo].[TestSqlCmd]
`

// This query generates half a dozen tokenDoneInProc tokens which fill the channel if the app isn't scanning Rowsq
const CreateSprocWithCursor = `
CREATE PROCEDURE [dbo].[TestSqlCmd]
AS
BEGIN
DECLARE @tmp int;
DECLARE Server_Cursor CURSOR FOR
SELECT 1 UNION SELECT 2
OPEN Server_Cursor;
FETCH NEXT FROM Server_Cursor INTO @tmp;
WHILE @@FETCH_STATUS = 0
BEGIN
PRINT @tmp
FETCH NEXT FROM Server_Cursor INTO @tmp;
END;
CLOSE Server_Cursor;
DEALLOCATE Server_Cursor;
END
`

func TestSprocWithCursorNoResult(t *testing.T) {
conn, logger := open(t)
defer conn.Close()
defer logger.StopLogging()

_, e := conn.Exec(DropSprocWithCursor)
if e != nil {
t.Fatalf("Unable to drop test sproc: %v", e)
}
_, e = conn.Exec(CreateSprocWithCursor)
if e != nil {
t.Fatalf("Unable to create test sproc: %v", e)
}
defer conn.Exec(DropSprocWithCursor)
latency, _ := getLatency(t)
ctx, cancel := context.WithTimeout(context.Background(), latency+500*time.Millisecond)
defer cancel()
retmsg := &sqlexp.ReturnMessage{}
// Use a sproc instead of the cursor loop directly to cover the different code path in token.go
r, err := conn.QueryContext(ctx, `exec [dbo].[TestSqlCmd]`, retmsg)
if err != nil {
t.Fatal(err.Error())
}
defer r.Close()
active := true
rsCount := 0
msgCount := 0
for active {
msg := retmsg.Message(ctx)
t.Logf("Got a message: %v", reflect.TypeOf(msg))
switch m := msg.(type) {
case sqlexp.MsgNext:
t.Fatalf("Got a MsgNext from a query with no rows")
case sqlexp.MsgError:
t.Fatalf("Got an error: %s", m.Error.Error())
case sqlexp.MsgNotice:
msgCount++
case sqlexp.MsgNextResultSet:
if active = r.NextResultSet(); active {
rsCount++
}
}
}
if r.Err() != nil {
t.Fatalf("Got an error: %v", r.Err())
}
if rsCount != 13 {
t.Fatalf("Unexpected record set count: %v", rsCount)
}
if msgCount != 2 {
t.Fatalf("Unexpected message count: %v", msgCount)
}
}

func TestErrorAsLastResult(t *testing.T) {
conn, logger := open(t)
defer conn.Close()
defer logger.StopLogging()
latency, _ := getLatency(t)
ctx, cancel := context.WithTimeout(context.Background(), latency+5000*time.Millisecond)
defer cancel()
retmsg := &sqlexp.ReturnMessage{}
// Use a sproc instead of the cursor loop directly to cover the different code path in token.go
r, err := conn.QueryContext(ctx,
`
Print N'message'
select 1
raiserror(N'Error!', 16, 1)`,
retmsg)
if err != nil {
t.Fatal(err.Error())
}
defer r.Close()
active := true
d := 0
err = nil
for active {
msg := retmsg.Message(ctx)
t.Logf("Got a message: %s", reflect.TypeOf(msg))
switch m := msg.(type) {
case sqlexp.MsgNext:
if !r.Next() {
t.Fatalf("Next returned false")
}
r.Scan(&d)
if r.Next() {
t.Fatal("Second Next returned true")
}
case sqlexp.MsgError:
err = m.Error
case sqlexp.MsgNextResultSet:
active = r.NextResultSet()
}
}
if err == nil {
t.Fatal("Should have gotten an error message")
} else {
switch e := err.(type) {
case Error:
if e.Message != "Error!" || e.Class != 16 {
t.Fatalf("Got the wrong mssql error %v", e)
}
default:
t.Fatalf("Got an unexpected error %v", e)
}
}
}
21 changes: 13 additions & 8 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,6 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) {
}

func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs outputs) {
firstResult := true
defer func() {
if err := recover(); err != nil {
if sess.logFlags&logErrors != 0 {
Expand Down Expand Up @@ -704,11 +703,16 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)})
}
}
if outs.msgq != nil {
// For now we ignore ctx->Done errors that ReturnMessageEnqueue might return
// It's not clear how to handle them correctly here, and data/sql seems
// to set Rows.Err correctly when ctx expires already
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
if done.Status&doneMore == 0 {
// Rows marks the request as done when seeing this done token. We queue another result set message
// so the app calls NextResultSet again which will return false.
if outs.msgq != nil {
// For now we ignore ctx->Done errors that ReturnMessageEnqueue might return
// It's not clear how to handle them correctly here, and data/sql seems
// to set Rows.Err correctly when ctx expires already
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
return
Expand Down Expand Up @@ -738,7 +742,12 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)})
}
}
if outs.msgq != nil {
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
if done.Status&doneMore == 0 {
// Rows marks the request as done when seeing this done token. We queue another result set message
// so the app calls NextResultSet again which will return false.
if outs.msgq != nil {
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
Expand All @@ -749,12 +758,8 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
ch <- columns

if outs.msgq != nil {
if !firstResult {
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNext{})
}
firstResult = false

case tokenRow:
row := make([]interface{}, len(columns))
Expand Down

0 comments on commit 983a80e

Please sign in to comment.