diff --git a/mssql.go b/mssql.go index 4db17e72..d9a52112 100644 --- a/mssql.go +++ b/mssql.go @@ -1250,9 +1250,7 @@ scan: return tokdata } } - if rc.requestDone { - return io.EOF - } + return nil } diff --git a/queries_go19_test.go b/queries_go19_test.go index c1640f98..c8f431e8 100644 --- a/queries_go19_test.go +++ b/queries_go19_test.go @@ -1133,12 +1133,13 @@ func TestMessageQueue(t *testing.T) { sqlexp.MsgNextResultSet{}, sqlexp.MsgNotice{Message: "msg2"}, sqlexp.MsgNextResultSet{}, + sqlexp.MsgNextResultSet{}, } i := 0 for active { msg := retmsg.Message(ctx) if i >= len(msgs) { - t.Fatalf("Got extra message:%+v", msg) + t.Fatalf("Got extra message:%+v", reflect.TypeOf(msg)) } t.Log(reflect.TypeOf(msg)) if reflect.TypeOf(msgs[i]) != reflect.TypeOf(msg) { @@ -1406,7 +1407,7 @@ func TestSprocWithCursorNoResult(t *testing.T) { } defer conn.Exec(DropSprocWithCursor) latency, _ := getLatency(t) - ctx, cancel := context.WithTimeout(context.Background(), latency+5000*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), latency+500*time.Millisecond) defer cancel() retmsg := &sqlexp.ReturnMessage{} // Use a sproc instead of the cursor loop directly to cover the different code path in token.go @@ -1420,7 +1421,7 @@ func TestSprocWithCursorNoResult(t *testing.T) { msgCount := 0 for active { msg := retmsg.Message(ctx) - t.Logf("Got a message: %s", reflect.TypeOf(msg)) + t.Logf("Got a message: %v", reflect.TypeOf(msg)) switch m := msg.(type) { case sqlexp.MsgNext: t.Fatalf("Got a MsgNext from a query with no rows") @@ -1437,10 +1438,64 @@ func TestSprocWithCursorNoResult(t *testing.T) { if r.Err() != nil { t.Fatalf("Got an error: %v", r.Err()) } - if rsCount != 12 { + if rsCount != 13 { t.Fatalf("Unexpected record set count: %v", rsCount) } if msgCount != 2 { t.Fatalf("Unexpected message count: %v", msgCount) } } + +func TestErrorAsLastResult(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + latency, _ := getLatency(t) + ctx, cancel := context.WithTimeout(context.Background(), latency+5000*time.Millisecond) + defer cancel() + retmsg := &sqlexp.ReturnMessage{} + // Use a sproc instead of the cursor loop directly to cover the different code path in token.go + r, err := conn.QueryContext(ctx, + ` + Print N'message' + select 1 + raiserror(N'Error!', 16, 1)`, + retmsg) + if err != nil { + t.Fatal(err.Error()) + } + defer r.Close() + active := true + d := 0 + err = nil + for active { + msg := retmsg.Message(ctx) + t.Logf("Got a message: %s", reflect.TypeOf(msg)) + switch m := msg.(type) { + case sqlexp.MsgNext: + if !r.Next() { + t.Fatalf("Next returned false") + } + r.Scan(&d) + if r.Next() { + t.Fatal("Second Next returned true") + } + case sqlexp.MsgError: + err = m.Error + case sqlexp.MsgNextResultSet: + active = r.NextResultSet() + } + } + if err == nil { + t.Fatal("Should have gotten an error message") + } else { + switch e := err.(type) { + case Error: + if e.Message != "Error!" || e.Class != 16 { + t.Fatalf("Got the wrong mssql error %v", e) + } + default: + t.Fatalf("Got an unexpected error %v", e) + } + } +} diff --git a/token.go b/token.go index f71f061b..9692129e 100644 --- a/token.go +++ b/token.go @@ -710,6 +710,11 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) } if done.Status&doneMore == 0 { + // Rows marks the request as done when seeing this done token. We queue another result set message + // so the app calls NextResultSet again which will return false. + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) + } return } case tokenDone, tokenDoneProc: @@ -741,6 +746,11 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) } if done.Status&doneMore == 0 { + // Rows marks the request as done when seeing this done token. We queue another result set message + // so the app calls NextResultSet again which will return false. + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) + } return } case tokenColMetadata: