Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ordering of message queue messages #723

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,6 @@ type Rowsq struct {
stmt *Stmt
cols []columnStruct
reader *tokenProcessor
nextCols []columnStruct
cancel func()
requestDone bool
inResultSet bool
Expand Down Expand Up @@ -1102,8 +1101,11 @@ func (rc *Rowsq) Close() error {
}
}

// data/sql calls Columns during the app's call to Next
// ProcessSingleResponse queues MsgNext for every columns token.
// data/sql calls Columns during the app's call to Next.
func (rc *Rowsq) Columns() (res []string) {
// r.cols is nil if the first query in a batch is a SELECT or similar query that returns a rowset.
// if will be non-nil for subsequent queries where NextResultSet() has populated it
if rc.cols == nil {
scan:
for {
Expand Down Expand Up @@ -1145,6 +1147,10 @@ func (rc *Rowsq) Next(dest []driver.Value) error {
if tok == nil {
return io.EOF
} else {
switch tokdata := tok.(type) {
case doneInProcStruct:
tok = (doneStruct)(tokdata)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a more idiomatic way to combine the handling of doneInProcStruct and doneStruct?

}
switch tokdata := tok.(type) {
case []interface{}:
for i := range dest {
Expand Down Expand Up @@ -1172,9 +1178,11 @@ func (rc *Rowsq) Next(dest []driver.Value) error {
if rc.reader.outs.returnStatus != nil {
*rc.reader.outs.returnStatus = tokdata
}
case ServerError:
rc.requestDone = true
return tokdata
}
}

} else {
return rc.stmt.c.checkBadConn(rc.reader.ctx, err, false)
}
Expand All @@ -1187,15 +1195,14 @@ func (rc *Rowsq) HasNextResultSet() bool {
return !rc.requestDone
}

// Scans to the next set of columns in the stream
// Scans to the end of the current statement being processed
// Note that the caller may not have read all the rows in the prior set
func (rc *Rowsq) NextResultSet() error {
if rc.requestDone {
return io.EOF
}
scan:
for {
// we should have a columns token in the channel if we aren't at the end
tok, err := rc.reader.nextToken()
if rc.reader.sess.logFlags&logDebug != 0 {
rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("NextResultSet() token type:%v", reflect.TypeOf(tok)))
Expand All @@ -1208,21 +1215,42 @@ scan:
return io.EOF
}
switch tokdata := tok.(type) {
case doneInProcStruct:
tok = (doneStruct)(tokdata)
}
// ProcessSingleResponse queues a MsgNextResult for every "done" and "server error" token
// The only tokens to consume after a "done" should be "done", "server error", or "columns"
switch tokdata := tok.(type) {
case []columnStruct:
rc.nextCols = tokdata
rc.cols = tokdata
rc.inResultSet = true
break scan
case doneStruct:
if tokdata.Status&doneMore == 0 {
rc.nextCols = nil
rc.requestDone = true
break scan
}
if tokdata.isError() {
e := rc.stmt.c.checkBadConn(rc.reader.ctx, tokdata.getError(), false)
switch e.(type) {
case Error:
// Ignore non-fatal server errors. Fatal errors are of type ServerError
default:
return e
}
}
rc.inResultSet = false
rc.cols = nil
break scan
case ReturnStatus:
if rc.reader.outs.returnStatus != nil {
*rc.reader.outs.returnStatus = tokdata
}
case ServerError:
rc.requestDone = true
return tokdata
}
}
rc.cols = rc.nextCols
rc.nextCols = nil
if rc.cols == nil {
if rc.requestDone {
return io.EOF
}
return nil
Expand Down
80 changes: 75 additions & 5 deletions queries_go19_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build go1.9
// +build go1.9

package mssql
Expand Down Expand Up @@ -1126,13 +1127,14 @@ func TestMessageQueue(t *testing.T) {

msgs := []interface{}{
sqlexp.MsgNotice{Message: "msg1"},
sqlexp.MsgNextResultSet{},
sqlexp.MsgNext{},
sqlexp.MsgRowsAffected{Count: 1},
sqlexp.MsgNextResultSet{},
sqlexp.MsgNotice{Message: "msg2"},
sqlexp.MsgNextResultSet{},
}
i := 0
rsCount := 0
for active {
msg := retmsg.Message(ctx)
if i >= len(msgs) {
Expand All @@ -1147,10 +1149,6 @@ func TestMessageQueue(t *testing.T) {
t.Log(m.Message)
case sqlexp.MsgNextResultSet:
active = rows.NextResultSet()
if active {
t.Fatal("NextResultSet returned true")
}
rsCount++
case sqlexp.MsgNext:
if !rows.Next() {
t.Fatal("rows.Next() returned false")
Expand Down Expand Up @@ -1368,3 +1366,75 @@ func TestCancelWithNoResults(t *testing.T) {
t.Fatalf("Unexpected error: %v", r.Err())
}
}

const DropSprocWithCursor = `IF EXISTS (SELECT * FROM sys.objects WHERE object_id = OBJECT_ID(N'[dbo].[TestSqlCmd]') AND type in (N'P', N'PC'))
DROP PROCEDURE [dbo].[TestSqlCmd]
`

// This query generates half a dozen tokenDoneInProc tokens which fill the channel if the app isn't scanning Rowsq
const CreateSprocWithCursor = `
CREATE PROCEDURE [dbo].[TestSqlCmd]
AS
BEGIN
DECLARE @tmp int;
DECLARE Server_Cursor CURSOR FOR
SELECT 1 UNION SELECT 2
OPEN Server_Cursor;
FETCH NEXT FROM Server_Cursor INTO @tmp;
WHILE @@FETCH_STATUS = 0
BEGIN
PRINT @tmp
FETCH NEXT FROM Server_Cursor INTO @tmp;
END;
CLOSE Server_Cursor;
DEALLOCATE Server_Cursor;
END
`

func TestSprocWithCursorNoResult(t *testing.T) {
conn, logger := open(t)
defer conn.Close()
defer logger.StopLogging()

conn.Exec(DropSprocWithCursor)
conn.Exec(CreateSprocWithCursor)
defer conn.Exec(DropSprocWithCursor)
latency, _ := getLatency(t)
ctx, cancel := context.WithTimeout(context.Background(), latency+5000*time.Millisecond)
defer cancel()
retmsg := &sqlexp.ReturnMessage{}
// Use a sproc instead of the cursor loop directly to cover the different code path in token.go
r, err := conn.QueryContext(ctx, `exec [dbo].[TestSqlCmd]`, retmsg)
if err != nil {
t.Fatal(err.Error())
}
defer r.Close()
active := true
rsCount := 0
msgCount := 0
for active {
msg := retmsg.Message(ctx)
t.Logf("Got a message: %s", reflect.TypeOf(msg))
switch m := msg.(type) {
case sqlexp.MsgNext:
t.Fatalf("Got a MsgNext from a query with no rows")
case sqlexp.MsgError:
t.Fatalf("Got an error: %s", m.Error.Error())
case sqlexp.MsgNotice:
msgCount++
case sqlexp.MsgNextResultSet:
if active = r.NextResultSet(); active {
rsCount++
}
}
}
if r.Err() != nil {
t.Fatalf("Got an error: %v", r.Err())
}
if rsCount != 12 {
t.Fatalf("Unexpected record set count: %v", rsCount)
}
if msgCount != 2 {
t.Fatalf("Unexpected message count: %v", msgCount)
}
}
23 changes: 9 additions & 14 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,6 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) {
}

func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs outputs) {
firstResult := true
defer func() {
if err := recover(); err != nil {
if sess.logFlags&logErrors != 0 {
Expand Down Expand Up @@ -704,13 +703,13 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)})
}
}
if outs.msgq != nil {
// For now we ignore ctx->Done errors that ReturnMessageEnqueue might return
// It's not clear how to handle them correctly here, and data/sql seems
// to set Rows.Err correctly when ctx expires already
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
if done.Status&doneMore == 0 {
if outs.msgq != nil {
// For now we ignore ctx->Done errors that ReturnMessageEnqueue might return
// It's not clear how to handle them correctly here, and data/sql seems
// to set Rows.Err correctly when ctx expires already
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
return
}
case tokenDone, tokenDoneProc:
Expand Down Expand Up @@ -738,23 +737,19 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)})
}
}
if outs.msgq != nil {
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
if done.Status&doneMore == 0 {
if outs.msgq != nil {
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
return
}
case tokenColMetadata:
columns = parseColMetadata72(sess.buf)
ch <- columns

if outs.msgq != nil {
if !firstResult {
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNext{})
}
firstResult = false

case tokenRow:
row := make([]interface{}, len(columns))
Expand Down