From 55ce40d1519fff2766a075e5e5f099e61d58b7ae Mon Sep 17 00:00:00 2001 From: Yuki Wong Date: Tue, 9 Jul 2019 16:59:04 -0700 Subject: [PATCH 01/10] Update issue templates --- .github/ISSUE_TEMPLATE/custom.md | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/custom.md diff --git a/.github/ISSUE_TEMPLATE/custom.md b/.github/ISSUE_TEMPLATE/custom.md new file mode 100644 index 00000000..48d5f81f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/custom.md @@ -0,0 +1,10 @@ +--- +name: Custom issue template +about: Describe this issue template's purpose here. +title: '' +labels: '' +assignees: '' + +--- + + From 9ea2f0be6f269a539898d3aab78c067a2864fa77 Mon Sep 17 00:00:00 2001 From: Yuki Wong Date: Tue, 9 Jul 2019 17:04:05 -0700 Subject: [PATCH 02/10] Delete custom.md --- .github/ISSUE_TEMPLATE/custom.md | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 .github/ISSUE_TEMPLATE/custom.md diff --git a/.github/ISSUE_TEMPLATE/custom.md b/.github/ISSUE_TEMPLATE/custom.md deleted file mode 100644 index 48d5f81f..00000000 --- a/.github/ISSUE_TEMPLATE/custom.md +++ /dev/null @@ -1,10 +0,0 @@ ---- -name: Custom issue template -about: Describe this issue template's purpose here. -title: '' -labels: '' -assignees: '' - ---- - - From b57a23fd32ea652bbba42374ec01c7f322cf7b48 Mon Sep 17 00:00:00 2001 From: v-kaywon Date: Wed, 11 Sep 2019 14:29:29 -0700 Subject: [PATCH 03/10] sub package out mssql specific types --- buf.go | 32 ++-- bulkcopy.go | 6 +- datetimeoffset_example_test.go | 10 +- error.go => internal/mssqlerror/error.go | 17 +- internal/{decimal => mssqltypes}/decimal.go | 2 +- .../{decimal => mssqltypes}/decimal_test.go | 2 +- internal/mssqltypes/mssqltypes.go | 20 ++ .../mssqltypes/uniqueidentifier.go | 6 +- .../mssqltypes/uniqueidentifier_test.go | 2 +- mssql.go | 3 +- mssql_go19.go | 33 ++-- queries_go110_test.go | 15 +- queries_go19_test.go | 25 +-- queries_test.go | 13 +- token.go | 110 +++++------ tvp_go19_db_test.go | 174 +++++++++--------- types.go | 39 ++-- 17 files changed, 272 insertions(+), 237 deletions(-) rename error.go => internal/mssqlerror/error.go (64%) rename internal/{decimal => mssqltypes}/decimal.go (99%) rename internal/{decimal => mssqltypes}/decimal_test.go (99%) create mode 100644 internal/mssqltypes/mssqltypes.go rename uniqueidentifier.go => internal/mssqltypes/uniqueidentifier.go (85%) rename uniqueidentifier_test.go => internal/mssqltypes/uniqueidentifier_test.go (98%) diff --git a/buf.go b/buf.go index ba39b40f..d8252b84 100644 --- a/buf.go +++ b/buf.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "errors" "io" + + "github.com/denisenkom/go-mssqldb/internal/mssqlerror" ) type packetType uint8 @@ -186,7 +188,7 @@ func (r *tdsBuffer) ReadByte() (res byte, err error) { func (r *tdsBuffer) byte() byte { b, err := r.ReadByte() if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } return b } @@ -194,7 +196,7 @@ func (r *tdsBuffer) byte() byte { func (r *tdsBuffer) ReadFull(buf []byte) { _, err := io.ReadFull(r, buf[:]) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } } @@ -221,27 +223,23 @@ func (r *tdsBuffer) uint16() uint16 { } func (r *tdsBuffer) BVarChar() string { - return readBVarCharOrPanic(r) + l := int(r.byte()) + return r.readUcs2(l) } -func readBVarCharOrPanic(r io.Reader) string { - s, err := readBVarChar(r) - if err != nil { - badStreamPanic(err) - } - return s +func (r *tdsBuffer) UsVarChar() string { + l := int(r.uint16()) + return r.readUcs2(l) } -func readUsVarCharOrPanic(r io.Reader) string { - s, err := readUsVarChar(r) +func (r *tdsBuffer) readUcs2(numchars int) string { + b := make([]byte, numchars*2) + r.ReadFull(b) + res, err := ucs22str(b) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } - return s -} - -func (r *tdsBuffer) UsVarChar() string { - return readUsVarCharOrPanic(r) + return res } func (r *tdsBuffer) Read(buf []byte) (copied int, err error) { diff --git a/bulkcopy.go b/bulkcopy.go index 3b7bbb44..c7e3e527 100644 --- a/bulkcopy.go +++ b/bulkcopy.go @@ -11,7 +11,7 @@ import ( "strings" "time" - "github.com/denisenkom/go-mssqldb/internal/decimal" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" ) type Bulk struct { @@ -483,8 +483,8 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) perc := col.ti.Prec scale := col.ti.Scale - var dec decimal.Decimal - dec, err = decimal.Float64ToDecimalScale(value, scale) + var dec mssqltypes.Decimal + dec, err = mssqltypes.Float64ToDecimalScale(value, scale) if err != nil { return res, err } diff --git a/datetimeoffset_example_test.go b/datetimeoffset_example_test.go index 186f8196..95a0de03 100644 --- a/datetimeoffset_example_test.go +++ b/datetimeoffset_example_test.go @@ -9,8 +9,8 @@ import ( "log" "time" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" "github.com/golang-sql/civil" - "github.com/denisenkom/go-mssqldb" ) // This example shows how to insert and retrieve date and time types data @@ -56,9 +56,9 @@ func insertDateTime(db *sql.DB) { var timeCol civil.Time = civil.TimeOf(tin) var dateCol civil.Date = civil.DateOf(tin) var smalldatetimeCol string = "2006-01-02 22:04:00" - var datetimeCol mssql.DateTime1 = mssql.DateTime1(tin) + var datetimeCol mssqltypes.DateTime1 = mssqltypes.DateTime1(tin) var datetime2Col civil.DateTime = civil.DateTimeOf(tin) - var datetimeoffsetCol mssql.DateTimeOffset = mssql.DateTimeOffset(tin) + var datetimeoffsetCol mssqltypes.DateTimeOffset = mssqltypes.DateTimeOffset(tin) _, err = stmt.Exec(timeCol, dateCol, smalldatetimeCol, datetimeCol, datetime2Col, datetimeoffsetCol) if err != nil { log.Fatal(err) @@ -103,8 +103,8 @@ func retrieveDateTimeOutParam(db *sql.DB) { log.Fatal(err) } var ( - timeOutParam, datetime2OutParam, datetimeoffsetOutParam mssql.DateTimeOffset - dateOutParam, datetimeOutParam mssql.DateTime1 + timeOutParam, datetime2OutParam, datetimeoffsetOutParam mssqltypes.DateTimeOffset + dateOutParam, datetimeOutParam mssqltypes.DateTime1 smalldatetimeOutParam string ) _, err = db.Exec("OutDatetimeProc", diff --git a/error.go b/internal/mssqlerror/error.go similarity index 64% rename from error.go rename to internal/mssqlerror/error.go index 2e5bacee..fe4ce619 100644 --- a/error.go +++ b/internal/mssqlerror/error.go @@ -1,4 +1,4 @@ -package mssql +package mssqlerror import ( "fmt" @@ -19,6 +19,7 @@ type Error struct { LineNo int32 } +// Error returns the SQL Server error message. func (e Error) Error() string { return "mssql: " + e.Message } @@ -28,34 +29,42 @@ func (e Error) SQLErrorNumber() int32 { return e.Number } +// SQLErrorState returns the SQL Server error state. func (e Error) SQLErrorState() uint8 { return e.State } +// SQLErrorClass returns the SQL Server error class. func (e Error) SQLErrorClass() uint8 { return e.Class } +// SQLErrorMessage returns the SQL Server error message. func (e Error) SQLErrorMessage() string { return e.Message } +// SQLErrorServerName returns the SQL Server name. func (e Error) SQLErrorServerName() string { return e.ServerName } +// SQLErrorProcName returns the procedure name. func (e Error) SQLErrorProcName() string { return e.ProcName } +// SQLErrorLineNo returns the error line number. func (e Error) SQLErrorLineNo() int32 { return e.LineNo } +// StreamError represents TDS stream error. type StreamError struct { Message string } +// Error returns the TDS stream error message. func (e StreamError) Error() string { return e.Message } @@ -64,10 +73,12 @@ func streamErrorf(format string, v ...interface{}) StreamError { return StreamError{"Invalid TDS stream: " + fmt.Sprintf(format, v...)} } -func badStreamPanic(err error) { +// BadStreamPanic calls panic with err. +func BadStreamPanic(err error) { panic(err) } -func badStreamPanicf(format string, v ...interface{}) { +// BadStreamPanicf calls panic with a formatted error message as an invalid TDS stream error. +func BadStreamPanicf(format string, v ...interface{}) { panic(streamErrorf(format, v...)) } diff --git a/internal/decimal/decimal.go b/internal/mssqltypes/decimal.go similarity index 99% rename from internal/decimal/decimal.go rename to internal/mssqltypes/decimal.go index 68f790a0..3bafb65a 100644 --- a/internal/decimal/decimal.go +++ b/internal/mssqltypes/decimal.go @@ -1,4 +1,4 @@ -package decimal +package mssqltypes import ( "encoding/binary" diff --git a/internal/decimal/decimal_test.go b/internal/mssqltypes/decimal_test.go similarity index 99% rename from internal/decimal/decimal_test.go rename to internal/mssqltypes/decimal_test.go index f59b06c3..9d18e768 100644 --- a/internal/decimal/decimal_test.go +++ b/internal/mssqltypes/decimal_test.go @@ -1,4 +1,4 @@ -package decimal +package mssqltypes import ( "math" diff --git a/internal/mssqltypes/mssqltypes.go b/internal/mssqltypes/mssqltypes.go new file mode 100644 index 00000000..92b04b3c --- /dev/null +++ b/internal/mssqltypes/mssqltypes.go @@ -0,0 +1,20 @@ +// +build go1.9 + +package mssqltypes + +import "time" + +// VarChar parameter types. +type VarChar string + +// NVarCharMax encodes parameters to NVarChar(max) SQL type. +type NVarCharMax string + +// VarCharMax encodes parameter to VarChar(max) SQL type. +type VarCharMax string + +// DateTime1 encodes parameters to original DateTime SQL types. +type DateTime1 time.Time + +// DateTimeOffset encodes parameters to DateTimeOffset, preserving the UTC offset. +type DateTimeOffset time.Time diff --git a/uniqueidentifier.go b/internal/mssqltypes/uniqueidentifier.go similarity index 85% rename from uniqueidentifier.go rename to internal/mssqltypes/uniqueidentifier.go index c8ef3149..39a28ead 100644 --- a/uniqueidentifier.go +++ b/internal/mssqltypes/uniqueidentifier.go @@ -1,4 +1,4 @@ -package mssql +package mssqltypes import ( "database/sql/driver" @@ -7,8 +7,10 @@ import ( "fmt" ) +// UniqueIdentifier encodes parameters to Uniqueidentifier SQL type. type UniqueIdentifier [16]byte +// Scan converts v to UniqueIdentifier func (u *UniqueIdentifier) Scan(v interface{}) error { reverse := func(b []byte) { for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 { @@ -52,6 +54,7 @@ func (u *UniqueIdentifier) Scan(v interface{}) error { } } +// Value converts UniqueIdentifier to bytes func (u UniqueIdentifier) Value() (driver.Value, error) { reverse := func(b []byte) { for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 { @@ -69,6 +72,7 @@ func (u UniqueIdentifier) Value() (driver.Value, error) { return raw, nil } +// String converts UniqueIdentifier to string func (u UniqueIdentifier) String() string { return fmt.Sprintf("%X-%X-%X-%X-%X", u[0:4], u[4:6], u[6:8], u[8:10], u[10:]) } diff --git a/uniqueidentifier_test.go b/internal/mssqltypes/uniqueidentifier_test.go similarity index 98% rename from uniqueidentifier_test.go rename to internal/mssqltypes/uniqueidentifier_test.go index 2a29133a..22249614 100644 --- a/uniqueidentifier_test.go +++ b/internal/mssqltypes/uniqueidentifier_test.go @@ -1,4 +1,4 @@ -package mssql +package mssqltypes import ( "bytes" diff --git a/mssql.go b/mssql.go index e37109cd..2949163d 100644 --- a/mssql.go +++ b/mssql.go @@ -15,6 +15,7 @@ import ( "time" "unicode" + "github.com/denisenkom/go-mssqldb/internal/mssqlerror" "github.com/denisenkom/go-mssqldb/internal/querytext" ) @@ -188,7 +189,7 @@ func (c *Conn) checkBadConn(err error) error { case net.Error: c.connectionGood = false return err - case StreamError: + case mssqlerror.StreamError: c.connectionGood = false return err default: diff --git a/mssql_go19.go b/mssql_go19.go index a2bd1167..170a2a0d 100644 --- a/mssql_go19.go +++ b/mssql_go19.go @@ -11,6 +11,7 @@ import ( "time" // "github.com/cockroachdb/apd" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" "github.com/golang-sql/civil" ) @@ -26,29 +27,17 @@ type MssqlStmt = Stmt // Deprecated: users should transition to th var _ driver.NamedValueChecker = &Conn{} -// VarChar parameter types. -type VarChar string - -type NVarCharMax string -type VarCharMax string - -// DateTime1 encodes parameters to original DateTime SQL types. -type DateTime1 time.Time - -// DateTimeOffset encodes parameters to DateTimeOffset, preserving the UTC offset. -type DateTimeOffset time.Time - func convertInputParameter(val interface{}) (interface{}, error) { switch v := val.(type) { - case VarChar: + case mssqltypes.VarChar: return val, nil - case NVarCharMax: + case mssqltypes.NVarCharMax: return val, nil - case VarCharMax: + case mssqltypes.VarCharMax: return val, nil - case DateTime1: + case mssqltypes.DateTime1: return val, nil - case DateTimeOffset: + case mssqltypes.DateTimeOffset: return val, nil case civil.Date: return val, nil @@ -123,24 +112,24 @@ func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error { func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) { switch val := val.(type) { - case VarChar: + case mssqltypes.VarChar: res.ti.TypeId = typeBigVarChar res.buffer = []byte(val) res.ti.Size = len(res.buffer) - case VarCharMax: + case mssqltypes.VarCharMax: res.ti.TypeId = typeBigVarChar res.buffer = []byte(val) res.ti.Size = 0 // currently zero forces varchar(max) - case NVarCharMax: + case mssqltypes.NVarCharMax: res.ti.TypeId = typeNVarChar res.buffer = str2ucs2(string(val)) res.ti.Size = 0 // currently zero forces nvarchar(max) - case DateTime1: + case mssqltypes.DateTime1: t := time.Time(val) res.ti.TypeId = typeDateTimeN res.buffer = encodeDateTime(t) res.ti.Size = len(res.buffer) - case DateTimeOffset: + case mssqltypes.DateTimeOffset: res.ti.TypeId = typeDateTimeOffsetN res.ti.Scale = 7 res.buffer = encodeDateTimeOffset(time.Time(val), int(res.ti.Scale)) diff --git a/queries_go110_test.go b/queries_go110_test.go index debb636f..f7e7b16e 100644 --- a/queries_go110_test.go +++ b/queries_go110_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" "github.com/golang-sql/civil" ) @@ -80,14 +81,14 @@ select ; `, sql.Named("nv", "base type nvarchar"), - sql.Named("v", VarChar("base type varchar")), - sql.Named("nvcm", NVarCharMax(strings.Repeat("x", 5000))), - sql.Named("vcm", VarCharMax(strings.Repeat("x", 5000))), - sql.Named("dt1", DateTime1(tin)), + sql.Named("v", mssqltypes.VarChar("base type varchar")), + sql.Named("nvcm", mssqltypes.NVarCharMax(strings.Repeat("x", 5000))), + sql.Named("vcm", mssqltypes.VarCharMax(strings.Repeat("x", 5000))), + sql.Named("dt1", mssqltypes.DateTime1(tin)), sql.Named("dt2", civil.DateTimeOf(tin)), sql.Named("d", civil.DateOf(tin)), sql.Named("tm", civil.TimeOf(tin)), - sql.Named("dto", DateTimeOffset(tin)), + sql.Named("dto", mssqltypes.DateTimeOffset(tin)), ) err = row.Scan(&nv, &v, &nvcm, &vcm, &dt1, &dt2, &d, &tm, &dto) if err != nil { @@ -153,11 +154,11 @@ select sql.Named("nv", sin), sql.Named("v", sin), sql.Named("tgo", tin), - sql.Named("dt1", DateTime1(tin)), + sql.Named("dt1", mssqltypes.DateTime1(tin)), sql.Named("dt2", civil.DateTimeOf(tin)), sql.Named("d", civil.DateOf(tin)), sql.Named("tm", civil.TimeOf(tin)), - sql.Named("dto", DateTimeOffset(tin)), + sql.Named("dto", mssqltypes.DateTimeOffset(tin)), ).Scan(&nv, &v, &tgo, &dt1, &dt2, &d, &tm, &dto) if err != nil { t.Fatal(err) diff --git a/queries_go19_test.go b/queries_go19_test.go index e3addd03..5e86818d 100644 --- a/queries_go19_test.go +++ b/queries_go19_test.go @@ -10,6 +10,9 @@ import ( "regexp" "testing" "time" + + "github.com/denisenkom/go-mssqldb/internal/mssqlerror" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" ) func TestOutputParam(t *testing.T) { @@ -174,8 +177,8 @@ END; if err != nil { t.Fatal(err) } - var datetime_param DateTime1 - datetime_param = DateTime1(tin) + var datetime_param mssqltypes.DateTime1 + datetime_param = mssqltypes.DateTime1(tin) _, err = db.ExecContext(ctx, sqltextrun, sql.Named("datetime", sql.Out{Dest: &datetime_param}), ) @@ -224,7 +227,7 @@ END; t.Run("should fail if destination has invalid type", func(t *testing.T) { // Error type should not be supported - var err_out Error + var err_out mssqlerror.Error _, err := db.ExecContext(ctx, sqltextrun, sql.Named("bid", sql.Out{Dest: &err_out}), ) @@ -244,7 +247,7 @@ END; t.Run("should fail if parameter has invalid type", func(t *testing.T) { // passing invalid parameter type - var err_val Error + var err_val mssqlerror.Error _, err = db.ExecContext(ctx, sqltextrun, err_val) if err == nil { t.Error("Expected to fail but it didn't") @@ -346,7 +349,7 @@ END; t.Run("original test", func(t *testing.T) { var bout int64 = 3 var cout string - var vout VarChar + var vout mssqltypes.VarChar _, err = db.ExecContext(ctx, sqltextrun, sql.Named("aid", 5), sql.Named("bid", sql.Out{Dest: &bout}), @@ -964,12 +967,12 @@ func TestDateTimeParam19(t *testing.T) { var emptydate time.Time mindate1 := time.Date(1753, 1, 1, 0, 0, 0, 0, time.UTC) maxdate1 := time.Date(9999, 12, 31, 23, 59, 59, 997000000, time.UTC) - testdates1 := []DateTime1{ - DateTime1(mindate1), - DateTime1(maxdate1), - DateTime1(time.Date(1752, 12, 31, 23, 59, 59, 997000000, time.UTC)), // just a little below minimum date - DateTime1(time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC)), // just a little over maximum date - DateTime1(emptydate), + testdates1 := []mssqltypes.DateTime1{ + mssqltypes.DateTime1(mindate1), + mssqltypes.DateTime1(maxdate1), + mssqltypes.DateTime1(time.Date(1752, 12, 31, 23, 59, 59, 997000000, time.UTC)), // just a little below minimum date + mssqltypes.DateTime1(time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC)), // just a little over maximum date + mssqltypes.DateTime1(emptydate), } for _, test := range testdates1 { diff --git a/queries_test.go b/queries_test.go index 2b476a84..1098899a 100644 --- a/queries_test.go +++ b/queries_test.go @@ -14,6 +14,9 @@ import ( "sync" "testing" "time" + + "github.com/denisenkom/go-mssqldb/internal/mssqlerror" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" ) func driverWithProcess(t *testing.T) *Driver { @@ -619,7 +622,7 @@ func TestError(t *testing.T) { t.Fatal("Query should fail") } - if sqlerr, ok := err.(Error); !ok { + if sqlerr, ok := err.(mssqlerror.Error); !ok { t.Fatalf("Should be sql error, actually %T, %v", err, err) } else { if sqlerr.Number != 2812 { // Could not find stored procedure 'bad' @@ -836,7 +839,7 @@ func TestUniqueIdentifierParam(t *testing.T) { uuid interface{} } - expected := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, + expected := mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, @@ -856,7 +859,7 @@ func TestUniqueIdentifierParam(t *testing.T) { for _, test := range values { t.Run(test.name, func(t *testing.T) { - var uuid2 UniqueIdentifier + var uuid2 mssqltypes.UniqueIdentifier err := conn.QueryRow("select ?", test.uuid).Scan(&uuid2) if err != nil { t.Fatal("select / scan failed", err.Error()) @@ -969,7 +972,7 @@ func TestErrorInfo(t *testing.T) { defer conn.Close() _, err := conn.Exec("select bad") - if sqlError, ok := err.(Error); ok { + if sqlError, ok := err.(mssqlerror.Error); ok { if sqlError.SQLErrorNumber() != 207 /*invalid column name*/ { t.Errorf("Query failed with unexpected error number %d %s", sqlError.SQLErrorNumber(), sqlError.SQLErrorMessage()) } @@ -980,7 +983,7 @@ func TestErrorInfo(t *testing.T) { t.Error("Failed to convert error to SQLErorr", err) } _, err = conn.Exec("RAISERROR('test message', 18, 111)") - if sqlError, ok := err.(Error); ok { + if sqlError, ok := err.(mssqlerror.Error); ok { if sqlError.SQLErrorNumber() != 50000 { t.Errorf("Query failed with unexpected error number %d %s", sqlError.SQLErrorNumber(), sqlError.SQLErrorMessage()) } diff --git a/token.go b/token.go index 1acac8a5..7b4bd4f0 100644 --- a/token.go +++ b/token.go @@ -9,6 +9,8 @@ import ( "net" "strconv" "strings" + + "github.com/denisenkom/go-mssqldb/internal/mssqlerror" ) //go:generate stringer -type token @@ -87,18 +89,18 @@ type doneStruct struct { Status uint16 CurCmd uint16 RowCount uint64 - errors []Error + errors []mssqlerror.Error } func (d doneStruct) isError() bool { return d.Status&doneError != 0 || len(d.errors) > 0 } -func (d doneStruct) getError() Error { +func (d doneStruct) getError() mssqlerror.Error { if len(d.errors) > 0 { return d.errors[len(d.errors)-1] } else { - return Error{Message: "Request failed but didn't provide reason"} + return mssqlerror.Error{Message: "Request failed but didn't provide reason"} } } @@ -137,127 +139,127 @@ func processEnvChg(sess *tdsSession) { return } if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } switch envtype { case envTypDatabase: sess.database, err = readBVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } _, err = readBVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTypLanguage: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTypCharset: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTypPacketSize: packetsize, err := readBVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } _, err = readBVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } packetsizei, err := strconv.Atoi(packetsize) if err != nil { - badStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error()) + mssqlerror.BadStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error()) } sess.buf.ResizeBuffer(packetsizei) case envSortId: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envSortFlags: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envSqlCollation: // currently ignored var collationSize uint8 err = binary.Read(r, binary.LittleEndian, &collationSize) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // SQL Collation data should contain 5 bytes in length if collationSize != 5 { - badStreamPanicf("Invalid SQL Collation size value returned from server: %d", collationSize) + mssqlerror.BadStreamPanicf("Invalid SQL Collation size value returned from server: %d", collationSize) } // 4 bytes, contains: LCID ColFlags Version var info uint32 err = binary.Read(r, binary.LittleEndian, &info) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // 1 byte, contains: sortID var sortID uint8 err = binary.Read(r, binary.LittleEndian, &sortID) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTypBeginTran: tranid, err := readBVarByte(r) if len(tranid) != 8 { - badStreamPanicf("invalid size of transaction identifier: %d", len(tranid)) + mssqlerror.BadStreamPanicf("invalid size of transaction identifier: %d", len(tranid)) } sess.tranid = binary.LittleEndian.Uint64(tranid) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } if sess.logFlags&logTransaction != 0 { sess.log.Printf("BEGIN TRANSACTION %x\n", sess.tranid) } _, err = readBVarByte(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTypCommitTran, envTypRollbackTran: _, err = readBVarByte(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } _, err = readBVarByte(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } if sess.logFlags&logTransaction != 0 { if envtype == envTypCommitTran { @@ -271,81 +273,81 @@ func processEnvChg(sess *tdsSession) { // currently ignored // new value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envDefectTran: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envDatabaseMirrorPartner: sess.partner, err = readBVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } _, err = readBVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envPromoteTran: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // dtc token // spec says it should be L_VARBYTE, so this code might be wrong if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTranMgrAddr: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // XACT_MANAGER_ADDRESS = B_VARBYTE if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTranEnded: // currently ignored // old value, B_VARBYTE if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envResetConnAck: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envStartedInstanceName: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // instance name if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envRouting: // RoutingData message is: @@ -355,24 +357,24 @@ func processEnvChg(sess *tdsSession) { // AlternateServer US_VARCHAR _, err := readUshort(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } protocol, err := readByte(r) if err != nil || protocol != 0 { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } newPort, err := readUshort(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } newServer, err := readUsVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // consume the OLDVALUE = %x00 %x00 _, err = readUshort(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } sess.routedServer = newServer sess.routedPort = newPort @@ -441,7 +443,7 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct { prognamelen := buf[1+4] var err error if res.ProgName, err = ucs22str(buf[1+4+1 : 1+4+1+prognamelen*2]); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } res.ProgVer = binary.BigEndian.Uint32(buf[size-4:]) return res @@ -489,7 +491,7 @@ func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { } // http://msdn.microsoft.com/en-us/library/dd304156.aspx -func parseError72(r *tdsBuffer) (res Error) { +func parseError72(r *tdsBuffer) (res mssqlerror.Error) { length := r.uint16() _ = length // ignore length res.Number = r.int32() @@ -503,7 +505,7 @@ func parseError72(r *tdsBuffer) (res Error) { } // http://msdn.microsoft.com/en-us/library/dd304156.aspx -func parseInfo(r *tdsBuffer) (res Error) { +func parseInfo(r *tdsBuffer) (res mssqlerror.Error) { length := r.uint16() _ = length // ignore length res.Number = r.int32() @@ -558,10 +560,10 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin return } if packet_type != packReply { - badStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply)) + mssqlerror.BadStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply)) } var columns []columnStruct - errs := make([]Error, 0, 5) + errs := make([]mssqlerror.Error, 0, 5) for { token := token(sess.buf.byte()) if sess.logFlags&logDebug != 0 { @@ -646,7 +648,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin } } default: - badStreamPanic(fmt.Errorf("unknown token type returned: %v", token)) + mssqlerror.BadStreamPanic(fmt.Errorf("unknown token type returned: %v", token)) } } } diff --git a/tvp_go19_db_test.go b/tvp_go19_db_test.go index 6cf42641..d5f33db0 100644 --- a/tvp_go19_db_test.go +++ b/tvp_go19_db_test.go @@ -9,6 +9,8 @@ import ( "reflect" "testing" "time" + + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" ) const ( @@ -40,87 +42,87 @@ const ( ) type TvptableRow struct { - PBinary []byte `db:"p_binary"` - PVarchar string `db:"p_varchar"` - PVarcharNull *string `db:"p_varcharNull"` - PNvarchar string `db:"p_nvarchar"` - PNvarcharNull *string `db:"p_nvarcharNull"` - PID UniqueIdentifier `db:"p_id"` - PIDNull *UniqueIdentifier `db:"p_idNull"` - PVarbinary []byte `db:"p_varbinary"` - PTinyint int8 `db:"p_tinyint"` - PTinyintNull *int8 `db:"p_tinyintNull"` - PSmallint int16 `db:"p_smallint"` - PSmallintNull *int16 `db:"p_smallintNull"` - PInt int32 `db:"p_int"` - PIntNull *int32 `db:"p_intNull"` - PBigint int64 `db:"p_bigint"` - PBigintNull *int64 `db:"p_bigintNull"` - PBit bool `db:"p_bit"` - PBitNull *bool `db:"p_bitNull"` - PFloat32 float32 `db:"p_float32"` - PFloatNull32 *float32 `db:"p_floatNull32"` - PFloat64 float64 `db:"p_float64"` - PFloatNull64 *float64 `db:"p_floatNull64"` - DTime time.Time `db:"p_timeNull"` - DTimeNull *time.Time `db:"p_time"` - Pint int `db:"pInt"` - PintNull *int `db:"pIntNull"` + PBinary []byte `db:"p_binary"` + PVarchar string `db:"p_varchar"` + PVarcharNull *string `db:"p_varcharNull"` + PNvarchar string `db:"p_nvarchar"` + PNvarcharNull *string `db:"p_nvarcharNull"` + PID mssqltypes.UniqueIdentifier `db:"p_id"` + PIDNull *mssqltypes.UniqueIdentifier `db:"p_idNull"` + PVarbinary []byte `db:"p_varbinary"` + PTinyint int8 `db:"p_tinyint"` + PTinyintNull *int8 `db:"p_tinyintNull"` + PSmallint int16 `db:"p_smallint"` + PSmallintNull *int16 `db:"p_smallintNull"` + PInt int32 `db:"p_int"` + PIntNull *int32 `db:"p_intNull"` + PBigint int64 `db:"p_bigint"` + PBigintNull *int64 `db:"p_bigintNull"` + PBit bool `db:"p_bit"` + PBitNull *bool `db:"p_bitNull"` + PFloat32 float32 `db:"p_float32"` + PFloatNull32 *float32 `db:"p_floatNull32"` + PFloat64 float64 `db:"p_float64"` + PFloatNull64 *float64 `db:"p_floatNull64"` + DTime time.Time `db:"p_timeNull"` + DTimeNull *time.Time `db:"p_time"` + Pint int `db:"pInt"` + PintNull *int `db:"pIntNull"` } type TvptableRowWithSkipTag struct { - PBinary []byte `db:"p_binary"` - SkipPBinary []byte `json:"-"` - PVarchar string `db:"p_varchar"` - SkipPVarchar string `tvp:"-"` - PVarcharNull *string `db:"p_varcharNull"` - SkipPVarcharNull *string `json:"-" tvp:"-"` - PNvarchar string `db:"p_nvarchar"` - SkipPNvarchar string `json:"-"` - PNvarcharNull *string `db:"p_nvarcharNull"` - SkipPNvarcharNull *string `json:"-"` - PID UniqueIdentifier `db:"p_id"` - SkipPID UniqueIdentifier `json:"-"` - PIDNull *UniqueIdentifier `db:"p_idNull"` - SkipPIDNull *UniqueIdentifier `tvp:"-"` - PVarbinary []byte `db:"p_varbinary"` - SkipPVarbinary []byte `json:"-" tvp:"-"` - PTinyint int8 `db:"p_tinyint"` - SkipPTinyint int8 `tvp:"-"` - PTinyintNull *int8 `db:"p_tinyintNull"` - SkipPTinyintNull *int8 `tvp:"-" json:"any"` - PSmallint int16 `db:"p_smallint"` - SkipPSmallint int16 `json:"-"` - PSmallintNull *int16 `db:"p_smallintNull"` - SkipPSmallintNull *int16 `tvp:"-"` - PInt int32 `db:"p_int"` - SkipPInt int32 `json:"-"` - PIntNull *int32 `db:"p_intNull"` - SkipPIntNull *int32 `tvp:"-"` - PBigint int64 `db:"p_bigint"` - SkipPBigint int64 `tvp:"-"` - PBigintNull *int64 `db:"p_bigintNull"` - SkipPBigintNull *int64 `json:"any" tvp:"-"` - PBit bool `db:"p_bit"` - SkipPBit bool `json:"-"` - PBitNull *bool `db:"p_bitNull"` - SkipPBitNull *bool `json:"-"` - PFloat32 float32 `db:"p_float32"` - SkipPFloat32 float32 `tvp:"-"` - PFloatNull32 *float32 `db:"p_floatNull32"` - SkipPFloatNull32 *float32 `tvp:"-"` - PFloat64 float64 `db:"p_float64"` - SkipPFloat64 float64 `tvp:"-"` - PFloatNull64 *float64 `db:"p_floatNull64"` - SkipPFloatNull64 *float64 `tvp:"-"` - DTime time.Time `db:"p_timeNull"` - SkipDTime time.Time `tvp:"-"` - DTimeNull *time.Time `db:"p_time"` - SkipDTimeNull *time.Time `tvp:"-"` - Pint int `db:"p_int_null"` - SkipPint int `tvp:"-"` - PintNull *int `db:"p_int_"` - SkipPintNull *int `tvp:"-"` + PBinary []byte `db:"p_binary"` + SkipPBinary []byte `json:"-"` + PVarchar string `db:"p_varchar"` + SkipPVarchar string `tvp:"-"` + PVarcharNull *string `db:"p_varcharNull"` + SkipPVarcharNull *string `json:"-" tvp:"-"` + PNvarchar string `db:"p_nvarchar"` + SkipPNvarchar string `json:"-"` + PNvarcharNull *string `db:"p_nvarcharNull"` + SkipPNvarcharNull *string `json:"-"` + PID mssqltypes.UniqueIdentifier `db:"p_id"` + SkipPID mssqltypes.UniqueIdentifier `json:"-"` + PIDNull *mssqltypes.UniqueIdentifier `db:"p_idNull"` + SkipPIDNull *mssqltypes.UniqueIdentifier `tvp:"-"` + PVarbinary []byte `db:"p_varbinary"` + SkipPVarbinary []byte `json:"-" tvp:"-"` + PTinyint int8 `db:"p_tinyint"` + SkipPTinyint int8 `tvp:"-"` + PTinyintNull *int8 `db:"p_tinyintNull"` + SkipPTinyintNull *int8 `tvp:"-" json:"any"` + PSmallint int16 `db:"p_smallint"` + SkipPSmallint int16 `json:"-"` + PSmallintNull *int16 `db:"p_smallintNull"` + SkipPSmallintNull *int16 `tvp:"-"` + PInt int32 `db:"p_int"` + SkipPInt int32 `json:"-"` + PIntNull *int32 `db:"p_intNull"` + SkipPIntNull *int32 `tvp:"-"` + PBigint int64 `db:"p_bigint"` + SkipPBigint int64 `tvp:"-"` + PBigintNull *int64 `db:"p_bigintNull"` + SkipPBigintNull *int64 `json:"any" tvp:"-"` + PBit bool `db:"p_bit"` + SkipPBit bool `json:"-"` + PBitNull *bool `db:"p_bitNull"` + SkipPBitNull *bool `json:"-"` + PFloat32 float32 `db:"p_float32"` + SkipPFloat32 float32 `tvp:"-"` + PFloatNull32 *float32 `db:"p_floatNull32"` + SkipPFloatNull32 *float32 `tvp:"-"` + PFloat64 float64 `db:"p_float64"` + SkipPFloat64 float64 `tvp:"-"` + PFloatNull64 *float64 `db:"p_floatNull64"` + SkipPFloatNull64 *float64 `tvp:"-"` + DTime time.Time `db:"p_timeNull"` + SkipDTime time.Time `tvp:"-"` + DTimeNull *time.Time `db:"p_time"` + SkipDTimeNull *time.Time `tvp:"-"` + Pint int `db:"p_int_null"` + SkipPint int `tvp:"-"` + PintNull *int `db:"p_int_"` + SkipPintNull *int `tvp:"-"` } func TestTVP(t *testing.T) { @@ -215,7 +217,7 @@ func TestTVP(t *testing.T) { PBinary: []byte("ccc"), PVarchar: varcharNull, PNvarchar: nvarchar, - PID: UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PID: mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, PVarbinary: bytesMock, PTinyint: i8, PSmallint: i16, @@ -231,7 +233,7 @@ func TestTVP(t *testing.T) { PBinary: []byte("www"), PVarchar: "eee", PNvarchar: "lll", - PID: UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PID: mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, PVarbinary: []byte("zzz"), PTinyint: 5, PSmallint: 16000, @@ -247,7 +249,7 @@ func TestTVP(t *testing.T) { PBinary: nil, PVarcharNull: &varcharNull, PNvarcharNull: &nvarchar, - PIDNull: &UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PIDNull: &mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, PTinyintNull: &i8, PSmallintNull: &i16, PIntNull: &i32, @@ -263,7 +265,7 @@ func TestTVP(t *testing.T) { PBinary: []byte("www"), PVarchar: "eee", PNvarchar: "lll", - PIDNull: &UniqueIdentifier{}, + PIDNull: &mssqltypes.UniqueIdentifier{}, PVarbinary: []byte("zzz"), PTinyint: 5, PSmallint: 16000, @@ -468,7 +470,7 @@ func TestTVP_WithTag(t *testing.T) { PBinary: []byte("ccc"), PVarchar: varcharNull, PNvarchar: nvarchar, - PID: UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PID: mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, PVarbinary: bytesMock, PTinyint: i8, PSmallint: i16, @@ -485,7 +487,7 @@ func TestTVP_WithTag(t *testing.T) { PBinary: []byte("www"), PVarchar: "eee", PNvarchar: "lll", - PID: UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PID: mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, PVarbinary: []byte("zzz"), PTinyint: 5, PSmallint: 16000, @@ -502,7 +504,7 @@ func TestTVP_WithTag(t *testing.T) { PBinary: nil, PVarcharNull: &varcharNull, PNvarcharNull: &nvarchar, - PIDNull: &UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PIDNull: &mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, PTinyintNull: &i8, PSmallintNull: &i16, PIntNull: &i32, @@ -518,7 +520,7 @@ func TestTVP_WithTag(t *testing.T) { PBinary: []byte("www"), PVarchar: "eee", PNvarchar: "lll", - PIDNull: &UniqueIdentifier{}, + PIDNull: &mssqltypes.UniqueIdentifier{}, PVarbinary: []byte("zzz"), PTinyint: 5, PSmallint: 16000, diff --git a/types.go b/types.go index b6e7fb2b..6bf297f7 100644 --- a/types.go +++ b/types.go @@ -11,7 +11,8 @@ import ( "time" "github.com/denisenkom/go-mssqldb/internal/cp" - "github.com/denisenkom/go-mssqldb/internal/decimal" + "github.com/denisenkom/go-mssqldb/internal/mssqlerror" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" ) // fixed-length data types @@ -338,7 +339,7 @@ func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} { case typeInt8: return int64(binary.LittleEndian.Uint64(buf)) default: - badStreamPanicf("Invalid typeid") + mssqlerror.BadStreamPanicf("Invalid typeid") } panic("shoulnd't get here") } @@ -353,7 +354,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { switch ti.TypeId { case typeDateN: if len(buf) != 3 { - badStreamPanicf("Invalid size for DATENTYPE") + mssqlerror.BadStreamPanicf("Invalid size for DATENTYPE") } return decodeDate(buf) case typeTimeN: @@ -375,13 +376,13 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { case 8: return int64(binary.LittleEndian.Uint64(buf)) default: - badStreamPanicf("Invalid size for INTNTYPE: %d", len(buf)) + mssqlerror.BadStreamPanicf("Invalid size for INTNTYPE: %d", len(buf)) } case typeDecimal, typeNumeric, typeDecimalN, typeNumericN: return decodeDecimal(ti.Prec, ti.Scale, buf) case typeBitN: if len(buf) != 1 { - badStreamPanicf("Invalid size for BITNTYPE") + mssqlerror.BadStreamPanicf("Invalid size for BITNTYPE") } return buf[0] != 0 case typeFltN: @@ -391,7 +392,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { case 8: return math.Float64frombits(binary.LittleEndian.Uint64(buf)) default: - badStreamPanicf("Invalid size for FLTNTYPE") + mssqlerror.BadStreamPanicf("Invalid size for FLTNTYPE") } case typeMoneyN: switch len(buf) { @@ -400,7 +401,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { case 8: return decodeMoney(buf) default: - badStreamPanicf("Invalid size for MONEYNTYPE") + mssqlerror.BadStreamPanicf("Invalid size for MONEYNTYPE") } case typeDateTim4: return decodeDateTim4(buf) @@ -413,7 +414,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { case 8: return decodeDateTime(buf) default: - badStreamPanicf("Invalid size for DATETIMENTYPE") + mssqlerror.BadStreamPanicf("Invalid size for DATETIMENTYPE") } case typeChar, typeVarChar: return decodeChar(ti.Collation, buf) @@ -425,7 +426,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { copy(cpy, buf) return cpy default: - badStreamPanicf("Invalid typeid") + mssqlerror.BadStreamPanicf("Invalid typeid") } panic("shoulnd't get here") } @@ -464,7 +465,7 @@ func readShortLenType(ti *typeInfo, r *tdsBuffer) interface{} { case typeUdt: return decodeUdt(*ti, buf) default: - badStreamPanicf("Invalid typeid") + mssqlerror.BadStreamPanicf("Invalid typeid") } panic("shoulnd't get here") } @@ -512,7 +513,7 @@ func readLongLenType(ti *typeInfo, r *tdsBuffer) interface{} { case typeNText: return decodeNChar(buf) default: - badStreamPanicf("Invalid typeid") + mssqlerror.BadStreamPanicf("Invalid typeid") } panic("shoulnd't get here") } @@ -645,7 +646,7 @@ func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} { r.ReadFull(buf) return decodeNChar(buf) default: - badStreamPanicf("Invalid variant typeid") + mssqlerror.BadStreamPanicf("Invalid variant typeid") } panic("shoulnd't get here") } @@ -671,7 +672,7 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} { break } if _, err := io.CopyN(buf, r, int64(chunksize)); err != nil { - badStreamPanicf("Reading PLP type failed: %s", err.Error()) + mssqlerror.BadStreamPanicf("Reading PLP type failed: %s", err.Error()) } } switch ti.TypeId { @@ -725,7 +726,7 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) { case 5, 6, 7: ti.Size = 5 default: - badStreamPanicf("Invalid scale for TIME/DATETIME2/DATETIMEOFFSET type") + mssqlerror.BadStreamPanicf("Invalid scale for TIME/DATETIME2/DATETIMEOFFSET type") } switch ti.TypeId { case typeDateTime2N: @@ -805,7 +806,7 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) { ti.Reader = readVariantType } default: - badStreamPanicf("Invalid type %d", ti.TypeId) + mssqlerror.BadStreamPanicf("Invalid type %d", ti.TypeId) } return } @@ -819,12 +820,12 @@ func decodeMoney(buf []byte) []byte { uint64(buf[1])<<40 | uint64(buf[2])<<48 | uint64(buf[3])<<56) - return decimal.ScaleBytes(strconv.FormatInt(money, 10), 4) + return mssqltypes.ScaleBytes(strconv.FormatInt(money, 10), 4) } func decodeMoney4(buf []byte) []byte { money := int32(binary.LittleEndian.Uint32(buf[0:4])) - return decimal.ScaleBytes(strconv.FormatInt(int64(money), 10), 4) + return mssqltypes.ScaleBytes(strconv.FormatInt(int64(money), 10), 4) } func decodeGuid(buf []byte) []byte { @@ -836,7 +837,7 @@ func decodeGuid(buf []byte) []byte { func decodeDecimal(prec uint8, scale uint8, buf []byte) []byte { var sign uint8 sign = buf[0] - var dec decimal.Decimal + var dec mssqltypes.Decimal dec.SetPositive(sign != 0) dec.SetPrec(prec) dec.SetScale(scale) @@ -994,7 +995,7 @@ func decodeChar(col cp.Collation, buf []byte) string { func decodeUcs2(buf []byte) string { res, err := ucs22str(buf) if err != nil { - badStreamPanicf("Invalid UCS2 encoding: %s", err.Error()) + mssqlerror.BadStreamPanicf("Invalid UCS2 encoding: %s", err.Error()) } return res } From 95d9e8fd5071cfe6a88c86a4fb61deb7fd10fb79 Mon Sep 17 00:00:00 2001 From: v-kaywon Date: Wed, 11 Sep 2019 14:31:31 -0700 Subject: [PATCH 04/10] fix buf_test.go --- buf_test.go | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/buf_test.go b/buf_test.go index 76efbd1d..9d681665 100644 --- a/buf_test.go +++ b/buf_test.go @@ -282,35 +282,3 @@ func TestWrite_BufferBounds(t *testing.T) { t.Fatal("FinishPacket failed:", err.Error()) } } - -func TestReadUsVarCharOrPanic(t *testing.T) { - memBuf := bytes.NewBuffer([]byte{3, 0, 0x31, 0, 0x32, 0, 0x33, 0}) - s := readUsVarCharOrPanic(memBuf) - if s != "123" { - t.Errorf("UsVarChar expected to return 123 but it returned %s", s) - } - - // test invalid usvarchar - defer func() { - recover() - }() - memBuf = bytes.NewBuffer([]byte{}) - s = readUsVarCharOrPanic(memBuf) - t.Fatal("UsVarChar() should panic, but it didn't") -} - -func TestReadBVarCharOrPanic(t *testing.T) { - memBuf := bytes.NewBuffer([]byte{3, 0x31, 0, 0x32, 0, 0x33, 0}) - s := readBVarCharOrPanic(memBuf) - if s != "123" { - t.Errorf("readBVarCharOrPanic expected to return 123 but it returned %s", s) - } - - // test invalid varchar - defer func() { - recover() - }() - memBuf = bytes.NewBuffer([]byte{}) - s = readBVarCharOrPanic(memBuf) - t.Fatal("readBVarCharOrPanic() should panic on empty buffer, but it didn't") -} From 372f3d2d1ba241dfcf5ae80ef19771aba498fc9b Mon Sep 17 00:00:00 2001 From: Yuki Wong Date: Tue, 9 Jul 2019 16:59:04 -0700 Subject: [PATCH 05/10] Update issue templates --- .github/ISSUE_TEMPLATE/custom.md | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/custom.md diff --git a/.github/ISSUE_TEMPLATE/custom.md b/.github/ISSUE_TEMPLATE/custom.md new file mode 100644 index 00000000..48d5f81f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/custom.md @@ -0,0 +1,10 @@ +--- +name: Custom issue template +about: Describe this issue template's purpose here. +title: '' +labels: '' +assignees: '' + +--- + + From 83a323c964eeafa4d3318fef3d31265d4b18e7f3 Mon Sep 17 00:00:00 2001 From: Yuki Wong Date: Tue, 9 Jul 2019 17:04:05 -0700 Subject: [PATCH 06/10] Delete custom.md --- .github/ISSUE_TEMPLATE/custom.md | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 .github/ISSUE_TEMPLATE/custom.md diff --git a/.github/ISSUE_TEMPLATE/custom.md b/.github/ISSUE_TEMPLATE/custom.md deleted file mode 100644 index 48d5f81f..00000000 --- a/.github/ISSUE_TEMPLATE/custom.md +++ /dev/null @@ -1,10 +0,0 @@ ---- -name: Custom issue template -about: Describe this issue template's purpose here. -title: '' -labels: '' -assignees: '' - ---- - - From 2b0a628ab0e1d4a185ae822a91b4303964b114cb Mon Sep 17 00:00:00 2001 From: v-kaywon Date: Wed, 11 Sep 2019 14:29:29 -0700 Subject: [PATCH 07/10] sub package out mssql specific types --- buf.go | 20 +- bulkcopy.go | 20 +- datetimeoffset_example_test.go | 10 +- error.go => internal/mssqlerror/error.go | 17 +- internal/{decimal => mssqltypes}/decimal.go | 2 +- .../{decimal => mssqltypes}/decimal_test.go | 2 +- internal/mssqltypes/mssqltypes.go | 20 ++ .../mssqltypes/uniqueidentifier.go | 6 +- .../mssqltypes/uniqueidentifier_test.go | 2 +- mssql.go | 3 +- mssql_go19.go | 33 ++-- queries_go110_test.go | 15 +- queries_go19_test.go | 25 +-- queries_test.go | 13 +- token.go | 110 +++++------ tvp_go19_db_test.go | 174 +++++++++--------- types.go | 39 ++-- 17 files changed, 280 insertions(+), 231 deletions(-) rename error.go => internal/mssqlerror/error.go (64%) rename internal/{decimal => mssqltypes}/decimal.go (99%) rename internal/{decimal => mssqltypes}/decimal_test.go (99%) create mode 100644 internal/mssqltypes/mssqltypes.go rename uniqueidentifier.go => internal/mssqltypes/uniqueidentifier.go (85%) rename uniqueidentifier_test.go => internal/mssqltypes/uniqueidentifier_test.go (98%) diff --git a/buf.go b/buf.go index ba39b40f..56e56dda 100644 --- a/buf.go +++ b/buf.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "errors" "io" + + "github.com/denisenkom/go-mssqldb/internal/mssqlerror" ) type packetType uint8 @@ -186,7 +188,7 @@ func (r *tdsBuffer) ReadByte() (res byte, err error) { func (r *tdsBuffer) byte() byte { b, err := r.ReadByte() if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } return b } @@ -194,7 +196,7 @@ func (r *tdsBuffer) byte() byte { func (r *tdsBuffer) ReadFull(buf []byte) { _, err := io.ReadFull(r, buf[:]) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } } @@ -227,7 +229,7 @@ func (r *tdsBuffer) BVarChar() string { func readBVarCharOrPanic(r io.Reader) string { s, err := readBVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } return s } @@ -235,7 +237,7 @@ func readBVarCharOrPanic(r io.Reader) string { func readUsVarCharOrPanic(r io.Reader) string { s, err := readUsVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } return s } @@ -244,6 +246,16 @@ func (r *tdsBuffer) UsVarChar() string { return readUsVarCharOrPanic(r) } +func (r *tdsBuffer) readUcs2(numchars int) string { + b := make([]byte, numchars*2) + r.ReadFull(b) + res, err := ucs22str(b) + if err != nil { + mssqlerror.BadStreamPanic(err) + } + return res +} + func (r *tdsBuffer) Read(buf []byte) (copied int, err error) { copied = 0 err = nil diff --git a/bulkcopy.go b/bulkcopy.go index 1d5eacb3..1792d275 100644 --- a/bulkcopy.go +++ b/bulkcopy.go @@ -10,7 +10,7 @@ import ( "strings" "time" - "github.com/denisenkom/go-mssqldb/internal/decimal" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" ) type Bulk struct { @@ -490,24 +490,24 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) case typeDecimal, typeDecimalN, typeNumeric, typeNumericN: prec := col.ti.Prec scale := col.ti.Scale - var dec decimal.Decimal + var dec mssqltypes.Decimal switch v := val.(type) { case int: - dec = decimal.Int64ToDecimalScale(int64(v), 0) + dec = mssqltypes.Int64ToDecimalScale(int64(v), 0) case int8: - dec = decimal.Int64ToDecimalScale(int64(v), 0) + dec = mssqltypes.Int64ToDecimalScale(int64(v), 0) case int16: - dec = decimal.Int64ToDecimalScale(int64(v), 0) + dec = mssqltypes.Int64ToDecimalScale(int64(v), 0) case int32: - dec = decimal.Int64ToDecimalScale(int64(v), 0) + dec = mssqltypes.Int64ToDecimalScale(int64(v), 0) case int64: - dec = decimal.Int64ToDecimalScale(int64(v), 0) + dec = mssqltypes.Int64ToDecimalScale(int64(v), 0) case float32: - dec, err = decimal.Float64ToDecimalScale(float64(v), scale) + dec, err = mssqltypes.Float64ToDecimalScale(float64(v), scale) case float64: - dec, err = decimal.Float64ToDecimalScale(float64(v), scale) + dec, err = mssqltypes.Float64ToDecimalScale(float64(v), scale) case string: - dec, err = decimal.StringToDecimalScale(v, scale) + dec, err = mssqltypes.StringToDecimalScale(v, scale) default: return res, fmt.Errorf("unknown value for decimal: %T %#v", v, v) } diff --git a/datetimeoffset_example_test.go b/datetimeoffset_example_test.go index 186f8196..95a0de03 100644 --- a/datetimeoffset_example_test.go +++ b/datetimeoffset_example_test.go @@ -9,8 +9,8 @@ import ( "log" "time" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" "github.com/golang-sql/civil" - "github.com/denisenkom/go-mssqldb" ) // This example shows how to insert and retrieve date and time types data @@ -56,9 +56,9 @@ func insertDateTime(db *sql.DB) { var timeCol civil.Time = civil.TimeOf(tin) var dateCol civil.Date = civil.DateOf(tin) var smalldatetimeCol string = "2006-01-02 22:04:00" - var datetimeCol mssql.DateTime1 = mssql.DateTime1(tin) + var datetimeCol mssqltypes.DateTime1 = mssqltypes.DateTime1(tin) var datetime2Col civil.DateTime = civil.DateTimeOf(tin) - var datetimeoffsetCol mssql.DateTimeOffset = mssql.DateTimeOffset(tin) + var datetimeoffsetCol mssqltypes.DateTimeOffset = mssqltypes.DateTimeOffset(tin) _, err = stmt.Exec(timeCol, dateCol, smalldatetimeCol, datetimeCol, datetime2Col, datetimeoffsetCol) if err != nil { log.Fatal(err) @@ -103,8 +103,8 @@ func retrieveDateTimeOutParam(db *sql.DB) { log.Fatal(err) } var ( - timeOutParam, datetime2OutParam, datetimeoffsetOutParam mssql.DateTimeOffset - dateOutParam, datetimeOutParam mssql.DateTime1 + timeOutParam, datetime2OutParam, datetimeoffsetOutParam mssqltypes.DateTimeOffset + dateOutParam, datetimeOutParam mssqltypes.DateTime1 smalldatetimeOutParam string ) _, err = db.Exec("OutDatetimeProc", diff --git a/error.go b/internal/mssqlerror/error.go similarity index 64% rename from error.go rename to internal/mssqlerror/error.go index 2e5bacee..fe4ce619 100644 --- a/error.go +++ b/internal/mssqlerror/error.go @@ -1,4 +1,4 @@ -package mssql +package mssqlerror import ( "fmt" @@ -19,6 +19,7 @@ type Error struct { LineNo int32 } +// Error returns the SQL Server error message. func (e Error) Error() string { return "mssql: " + e.Message } @@ -28,34 +29,42 @@ func (e Error) SQLErrorNumber() int32 { return e.Number } +// SQLErrorState returns the SQL Server error state. func (e Error) SQLErrorState() uint8 { return e.State } +// SQLErrorClass returns the SQL Server error class. func (e Error) SQLErrorClass() uint8 { return e.Class } +// SQLErrorMessage returns the SQL Server error message. func (e Error) SQLErrorMessage() string { return e.Message } +// SQLErrorServerName returns the SQL Server name. func (e Error) SQLErrorServerName() string { return e.ServerName } +// SQLErrorProcName returns the procedure name. func (e Error) SQLErrorProcName() string { return e.ProcName } +// SQLErrorLineNo returns the error line number. func (e Error) SQLErrorLineNo() int32 { return e.LineNo } +// StreamError represents TDS stream error. type StreamError struct { Message string } +// Error returns the TDS stream error message. func (e StreamError) Error() string { return e.Message } @@ -64,10 +73,12 @@ func streamErrorf(format string, v ...interface{}) StreamError { return StreamError{"Invalid TDS stream: " + fmt.Sprintf(format, v...)} } -func badStreamPanic(err error) { +// BadStreamPanic calls panic with err. +func BadStreamPanic(err error) { panic(err) } -func badStreamPanicf(format string, v ...interface{}) { +// BadStreamPanicf calls panic with a formatted error message as an invalid TDS stream error. +func BadStreamPanicf(format string, v ...interface{}) { panic(streamErrorf(format, v...)) } diff --git a/internal/decimal/decimal.go b/internal/mssqltypes/decimal.go similarity index 99% rename from internal/decimal/decimal.go rename to internal/mssqltypes/decimal.go index 7da0375d..6f4e7b2a 100644 --- a/internal/decimal/decimal.go +++ b/internal/mssqltypes/decimal.go @@ -1,4 +1,4 @@ -package decimal +package mssqltypes import ( "encoding/binary" diff --git a/internal/decimal/decimal_test.go b/internal/mssqltypes/decimal_test.go similarity index 99% rename from internal/decimal/decimal_test.go rename to internal/mssqltypes/decimal_test.go index 4086a964..f75cf79e 100644 --- a/internal/decimal/decimal_test.go +++ b/internal/mssqltypes/decimal_test.go @@ -1,4 +1,4 @@ -package decimal +package mssqltypes import ( "math" diff --git a/internal/mssqltypes/mssqltypes.go b/internal/mssqltypes/mssqltypes.go new file mode 100644 index 00000000..92b04b3c --- /dev/null +++ b/internal/mssqltypes/mssqltypes.go @@ -0,0 +1,20 @@ +// +build go1.9 + +package mssqltypes + +import "time" + +// VarChar parameter types. +type VarChar string + +// NVarCharMax encodes parameters to NVarChar(max) SQL type. +type NVarCharMax string + +// VarCharMax encodes parameter to VarChar(max) SQL type. +type VarCharMax string + +// DateTime1 encodes parameters to original DateTime SQL types. +type DateTime1 time.Time + +// DateTimeOffset encodes parameters to DateTimeOffset, preserving the UTC offset. +type DateTimeOffset time.Time diff --git a/uniqueidentifier.go b/internal/mssqltypes/uniqueidentifier.go similarity index 85% rename from uniqueidentifier.go rename to internal/mssqltypes/uniqueidentifier.go index c8ef3149..39a28ead 100644 --- a/uniqueidentifier.go +++ b/internal/mssqltypes/uniqueidentifier.go @@ -1,4 +1,4 @@ -package mssql +package mssqltypes import ( "database/sql/driver" @@ -7,8 +7,10 @@ import ( "fmt" ) +// UniqueIdentifier encodes parameters to Uniqueidentifier SQL type. type UniqueIdentifier [16]byte +// Scan converts v to UniqueIdentifier func (u *UniqueIdentifier) Scan(v interface{}) error { reverse := func(b []byte) { for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 { @@ -52,6 +54,7 @@ func (u *UniqueIdentifier) Scan(v interface{}) error { } } +// Value converts UniqueIdentifier to bytes func (u UniqueIdentifier) Value() (driver.Value, error) { reverse := func(b []byte) { for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 { @@ -69,6 +72,7 @@ func (u UniqueIdentifier) Value() (driver.Value, error) { return raw, nil } +// String converts UniqueIdentifier to string func (u UniqueIdentifier) String() string { return fmt.Sprintf("%X-%X-%X-%X-%X", u[0:4], u[4:6], u[6:8], u[8:10], u[10:]) } diff --git a/uniqueidentifier_test.go b/internal/mssqltypes/uniqueidentifier_test.go similarity index 98% rename from uniqueidentifier_test.go rename to internal/mssqltypes/uniqueidentifier_test.go index 2a29133a..22249614 100644 --- a/uniqueidentifier_test.go +++ b/internal/mssqltypes/uniqueidentifier_test.go @@ -1,4 +1,4 @@ -package mssql +package mssqltypes import ( "bytes" diff --git a/mssql.go b/mssql.go index e37109cd..2949163d 100644 --- a/mssql.go +++ b/mssql.go @@ -15,6 +15,7 @@ import ( "time" "unicode" + "github.com/denisenkom/go-mssqldb/internal/mssqlerror" "github.com/denisenkom/go-mssqldb/internal/querytext" ) @@ -188,7 +189,7 @@ func (c *Conn) checkBadConn(err error) error { case net.Error: c.connectionGood = false return err - case StreamError: + case mssqlerror.StreamError: c.connectionGood = false return err default: diff --git a/mssql_go19.go b/mssql_go19.go index a2bd1167..170a2a0d 100644 --- a/mssql_go19.go +++ b/mssql_go19.go @@ -11,6 +11,7 @@ import ( "time" // "github.com/cockroachdb/apd" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" "github.com/golang-sql/civil" ) @@ -26,29 +27,17 @@ type MssqlStmt = Stmt // Deprecated: users should transition to th var _ driver.NamedValueChecker = &Conn{} -// VarChar parameter types. -type VarChar string - -type NVarCharMax string -type VarCharMax string - -// DateTime1 encodes parameters to original DateTime SQL types. -type DateTime1 time.Time - -// DateTimeOffset encodes parameters to DateTimeOffset, preserving the UTC offset. -type DateTimeOffset time.Time - func convertInputParameter(val interface{}) (interface{}, error) { switch v := val.(type) { - case VarChar: + case mssqltypes.VarChar: return val, nil - case NVarCharMax: + case mssqltypes.NVarCharMax: return val, nil - case VarCharMax: + case mssqltypes.VarCharMax: return val, nil - case DateTime1: + case mssqltypes.DateTime1: return val, nil - case DateTimeOffset: + case mssqltypes.DateTimeOffset: return val, nil case civil.Date: return val, nil @@ -123,24 +112,24 @@ func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error { func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) { switch val := val.(type) { - case VarChar: + case mssqltypes.VarChar: res.ti.TypeId = typeBigVarChar res.buffer = []byte(val) res.ti.Size = len(res.buffer) - case VarCharMax: + case mssqltypes.VarCharMax: res.ti.TypeId = typeBigVarChar res.buffer = []byte(val) res.ti.Size = 0 // currently zero forces varchar(max) - case NVarCharMax: + case mssqltypes.NVarCharMax: res.ti.TypeId = typeNVarChar res.buffer = str2ucs2(string(val)) res.ti.Size = 0 // currently zero forces nvarchar(max) - case DateTime1: + case mssqltypes.DateTime1: t := time.Time(val) res.ti.TypeId = typeDateTimeN res.buffer = encodeDateTime(t) res.ti.Size = len(res.buffer) - case DateTimeOffset: + case mssqltypes.DateTimeOffset: res.ti.TypeId = typeDateTimeOffsetN res.ti.Scale = 7 res.buffer = encodeDateTimeOffset(time.Time(val), int(res.ti.Scale)) diff --git a/queries_go110_test.go b/queries_go110_test.go index debb636f..f7e7b16e 100644 --- a/queries_go110_test.go +++ b/queries_go110_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" "github.com/golang-sql/civil" ) @@ -80,14 +81,14 @@ select ; `, sql.Named("nv", "base type nvarchar"), - sql.Named("v", VarChar("base type varchar")), - sql.Named("nvcm", NVarCharMax(strings.Repeat("x", 5000))), - sql.Named("vcm", VarCharMax(strings.Repeat("x", 5000))), - sql.Named("dt1", DateTime1(tin)), + sql.Named("v", mssqltypes.VarChar("base type varchar")), + sql.Named("nvcm", mssqltypes.NVarCharMax(strings.Repeat("x", 5000))), + sql.Named("vcm", mssqltypes.VarCharMax(strings.Repeat("x", 5000))), + sql.Named("dt1", mssqltypes.DateTime1(tin)), sql.Named("dt2", civil.DateTimeOf(tin)), sql.Named("d", civil.DateOf(tin)), sql.Named("tm", civil.TimeOf(tin)), - sql.Named("dto", DateTimeOffset(tin)), + sql.Named("dto", mssqltypes.DateTimeOffset(tin)), ) err = row.Scan(&nv, &v, &nvcm, &vcm, &dt1, &dt2, &d, &tm, &dto) if err != nil { @@ -153,11 +154,11 @@ select sql.Named("nv", sin), sql.Named("v", sin), sql.Named("tgo", tin), - sql.Named("dt1", DateTime1(tin)), + sql.Named("dt1", mssqltypes.DateTime1(tin)), sql.Named("dt2", civil.DateTimeOf(tin)), sql.Named("d", civil.DateOf(tin)), sql.Named("tm", civil.TimeOf(tin)), - sql.Named("dto", DateTimeOffset(tin)), + sql.Named("dto", mssqltypes.DateTimeOffset(tin)), ).Scan(&nv, &v, &tgo, &dt1, &dt2, &d, &tm, &dto) if err != nil { t.Fatal(err) diff --git a/queries_go19_test.go b/queries_go19_test.go index e3addd03..5e86818d 100644 --- a/queries_go19_test.go +++ b/queries_go19_test.go @@ -10,6 +10,9 @@ import ( "regexp" "testing" "time" + + "github.com/denisenkom/go-mssqldb/internal/mssqlerror" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" ) func TestOutputParam(t *testing.T) { @@ -174,8 +177,8 @@ END; if err != nil { t.Fatal(err) } - var datetime_param DateTime1 - datetime_param = DateTime1(tin) + var datetime_param mssqltypes.DateTime1 + datetime_param = mssqltypes.DateTime1(tin) _, err = db.ExecContext(ctx, sqltextrun, sql.Named("datetime", sql.Out{Dest: &datetime_param}), ) @@ -224,7 +227,7 @@ END; t.Run("should fail if destination has invalid type", func(t *testing.T) { // Error type should not be supported - var err_out Error + var err_out mssqlerror.Error _, err := db.ExecContext(ctx, sqltextrun, sql.Named("bid", sql.Out{Dest: &err_out}), ) @@ -244,7 +247,7 @@ END; t.Run("should fail if parameter has invalid type", func(t *testing.T) { // passing invalid parameter type - var err_val Error + var err_val mssqlerror.Error _, err = db.ExecContext(ctx, sqltextrun, err_val) if err == nil { t.Error("Expected to fail but it didn't") @@ -346,7 +349,7 @@ END; t.Run("original test", func(t *testing.T) { var bout int64 = 3 var cout string - var vout VarChar + var vout mssqltypes.VarChar _, err = db.ExecContext(ctx, sqltextrun, sql.Named("aid", 5), sql.Named("bid", sql.Out{Dest: &bout}), @@ -964,12 +967,12 @@ func TestDateTimeParam19(t *testing.T) { var emptydate time.Time mindate1 := time.Date(1753, 1, 1, 0, 0, 0, 0, time.UTC) maxdate1 := time.Date(9999, 12, 31, 23, 59, 59, 997000000, time.UTC) - testdates1 := []DateTime1{ - DateTime1(mindate1), - DateTime1(maxdate1), - DateTime1(time.Date(1752, 12, 31, 23, 59, 59, 997000000, time.UTC)), // just a little below minimum date - DateTime1(time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC)), // just a little over maximum date - DateTime1(emptydate), + testdates1 := []mssqltypes.DateTime1{ + mssqltypes.DateTime1(mindate1), + mssqltypes.DateTime1(maxdate1), + mssqltypes.DateTime1(time.Date(1752, 12, 31, 23, 59, 59, 997000000, time.UTC)), // just a little below minimum date + mssqltypes.DateTime1(time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC)), // just a little over maximum date + mssqltypes.DateTime1(emptydate), } for _, test := range testdates1 { diff --git a/queries_test.go b/queries_test.go index 2b476a84..1098899a 100644 --- a/queries_test.go +++ b/queries_test.go @@ -14,6 +14,9 @@ import ( "sync" "testing" "time" + + "github.com/denisenkom/go-mssqldb/internal/mssqlerror" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" ) func driverWithProcess(t *testing.T) *Driver { @@ -619,7 +622,7 @@ func TestError(t *testing.T) { t.Fatal("Query should fail") } - if sqlerr, ok := err.(Error); !ok { + if sqlerr, ok := err.(mssqlerror.Error); !ok { t.Fatalf("Should be sql error, actually %T, %v", err, err) } else { if sqlerr.Number != 2812 { // Could not find stored procedure 'bad' @@ -836,7 +839,7 @@ func TestUniqueIdentifierParam(t *testing.T) { uuid interface{} } - expected := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, + expected := mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, @@ -856,7 +859,7 @@ func TestUniqueIdentifierParam(t *testing.T) { for _, test := range values { t.Run(test.name, func(t *testing.T) { - var uuid2 UniqueIdentifier + var uuid2 mssqltypes.UniqueIdentifier err := conn.QueryRow("select ?", test.uuid).Scan(&uuid2) if err != nil { t.Fatal("select / scan failed", err.Error()) @@ -969,7 +972,7 @@ func TestErrorInfo(t *testing.T) { defer conn.Close() _, err := conn.Exec("select bad") - if sqlError, ok := err.(Error); ok { + if sqlError, ok := err.(mssqlerror.Error); ok { if sqlError.SQLErrorNumber() != 207 /*invalid column name*/ { t.Errorf("Query failed with unexpected error number %d %s", sqlError.SQLErrorNumber(), sqlError.SQLErrorMessage()) } @@ -980,7 +983,7 @@ func TestErrorInfo(t *testing.T) { t.Error("Failed to convert error to SQLErorr", err) } _, err = conn.Exec("RAISERROR('test message', 18, 111)") - if sqlError, ok := err.(Error); ok { + if sqlError, ok := err.(mssqlerror.Error); ok { if sqlError.SQLErrorNumber() != 50000 { t.Errorf("Query failed with unexpected error number %d %s", sqlError.SQLErrorNumber(), sqlError.SQLErrorMessage()) } diff --git a/token.go b/token.go index 1acac8a5..7b4bd4f0 100644 --- a/token.go +++ b/token.go @@ -9,6 +9,8 @@ import ( "net" "strconv" "strings" + + "github.com/denisenkom/go-mssqldb/internal/mssqlerror" ) //go:generate stringer -type token @@ -87,18 +89,18 @@ type doneStruct struct { Status uint16 CurCmd uint16 RowCount uint64 - errors []Error + errors []mssqlerror.Error } func (d doneStruct) isError() bool { return d.Status&doneError != 0 || len(d.errors) > 0 } -func (d doneStruct) getError() Error { +func (d doneStruct) getError() mssqlerror.Error { if len(d.errors) > 0 { return d.errors[len(d.errors)-1] } else { - return Error{Message: "Request failed but didn't provide reason"} + return mssqlerror.Error{Message: "Request failed but didn't provide reason"} } } @@ -137,127 +139,127 @@ func processEnvChg(sess *tdsSession) { return } if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } switch envtype { case envTypDatabase: sess.database, err = readBVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } _, err = readBVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTypLanguage: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTypCharset: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTypPacketSize: packetsize, err := readBVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } _, err = readBVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } packetsizei, err := strconv.Atoi(packetsize) if err != nil { - badStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error()) + mssqlerror.BadStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error()) } sess.buf.ResizeBuffer(packetsizei) case envSortId: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envSortFlags: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envSqlCollation: // currently ignored var collationSize uint8 err = binary.Read(r, binary.LittleEndian, &collationSize) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // SQL Collation data should contain 5 bytes in length if collationSize != 5 { - badStreamPanicf("Invalid SQL Collation size value returned from server: %d", collationSize) + mssqlerror.BadStreamPanicf("Invalid SQL Collation size value returned from server: %d", collationSize) } // 4 bytes, contains: LCID ColFlags Version var info uint32 err = binary.Read(r, binary.LittleEndian, &info) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // 1 byte, contains: sortID var sortID uint8 err = binary.Read(r, binary.LittleEndian, &sortID) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTypBeginTran: tranid, err := readBVarByte(r) if len(tranid) != 8 { - badStreamPanicf("invalid size of transaction identifier: %d", len(tranid)) + mssqlerror.BadStreamPanicf("invalid size of transaction identifier: %d", len(tranid)) } sess.tranid = binary.LittleEndian.Uint64(tranid) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } if sess.logFlags&logTransaction != 0 { sess.log.Printf("BEGIN TRANSACTION %x\n", sess.tranid) } _, err = readBVarByte(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTypCommitTran, envTypRollbackTran: _, err = readBVarByte(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } _, err = readBVarByte(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } if sess.logFlags&logTransaction != 0 { if envtype == envTypCommitTran { @@ -271,81 +273,81 @@ func processEnvChg(sess *tdsSession) { // currently ignored // new value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envDefectTran: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envDatabaseMirrorPartner: sess.partner, err = readBVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } _, err = readBVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envPromoteTran: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // dtc token // spec says it should be L_VARBYTE, so this code might be wrong if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTranMgrAddr: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // XACT_MANAGER_ADDRESS = B_VARBYTE if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envTranEnded: // currently ignored // old value, B_VARBYTE if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envResetConnAck: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envStartedInstanceName: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // instance name if _, err = readBVarChar(r); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } case envRouting: // RoutingData message is: @@ -355,24 +357,24 @@ func processEnvChg(sess *tdsSession) { // AlternateServer US_VARCHAR _, err := readUshort(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } protocol, err := readByte(r) if err != nil || protocol != 0 { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } newPort, err := readUshort(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } newServer, err := readUsVarChar(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } // consume the OLDVALUE = %x00 %x00 _, err = readUshort(r) if err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } sess.routedServer = newServer sess.routedPort = newPort @@ -441,7 +443,7 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct { prognamelen := buf[1+4] var err error if res.ProgName, err = ucs22str(buf[1+4+1 : 1+4+1+prognamelen*2]); err != nil { - badStreamPanic(err) + mssqlerror.BadStreamPanic(err) } res.ProgVer = binary.BigEndian.Uint32(buf[size-4:]) return res @@ -489,7 +491,7 @@ func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { } // http://msdn.microsoft.com/en-us/library/dd304156.aspx -func parseError72(r *tdsBuffer) (res Error) { +func parseError72(r *tdsBuffer) (res mssqlerror.Error) { length := r.uint16() _ = length // ignore length res.Number = r.int32() @@ -503,7 +505,7 @@ func parseError72(r *tdsBuffer) (res Error) { } // http://msdn.microsoft.com/en-us/library/dd304156.aspx -func parseInfo(r *tdsBuffer) (res Error) { +func parseInfo(r *tdsBuffer) (res mssqlerror.Error) { length := r.uint16() _ = length // ignore length res.Number = r.int32() @@ -558,10 +560,10 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin return } if packet_type != packReply { - badStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply)) + mssqlerror.BadStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply)) } var columns []columnStruct - errs := make([]Error, 0, 5) + errs := make([]mssqlerror.Error, 0, 5) for { token := token(sess.buf.byte()) if sess.logFlags&logDebug != 0 { @@ -646,7 +648,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin } } default: - badStreamPanic(fmt.Errorf("unknown token type returned: %v", token)) + mssqlerror.BadStreamPanic(fmt.Errorf("unknown token type returned: %v", token)) } } } diff --git a/tvp_go19_db_test.go b/tvp_go19_db_test.go index 6cf42641..d5f33db0 100644 --- a/tvp_go19_db_test.go +++ b/tvp_go19_db_test.go @@ -9,6 +9,8 @@ import ( "reflect" "testing" "time" + + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" ) const ( @@ -40,87 +42,87 @@ const ( ) type TvptableRow struct { - PBinary []byte `db:"p_binary"` - PVarchar string `db:"p_varchar"` - PVarcharNull *string `db:"p_varcharNull"` - PNvarchar string `db:"p_nvarchar"` - PNvarcharNull *string `db:"p_nvarcharNull"` - PID UniqueIdentifier `db:"p_id"` - PIDNull *UniqueIdentifier `db:"p_idNull"` - PVarbinary []byte `db:"p_varbinary"` - PTinyint int8 `db:"p_tinyint"` - PTinyintNull *int8 `db:"p_tinyintNull"` - PSmallint int16 `db:"p_smallint"` - PSmallintNull *int16 `db:"p_smallintNull"` - PInt int32 `db:"p_int"` - PIntNull *int32 `db:"p_intNull"` - PBigint int64 `db:"p_bigint"` - PBigintNull *int64 `db:"p_bigintNull"` - PBit bool `db:"p_bit"` - PBitNull *bool `db:"p_bitNull"` - PFloat32 float32 `db:"p_float32"` - PFloatNull32 *float32 `db:"p_floatNull32"` - PFloat64 float64 `db:"p_float64"` - PFloatNull64 *float64 `db:"p_floatNull64"` - DTime time.Time `db:"p_timeNull"` - DTimeNull *time.Time `db:"p_time"` - Pint int `db:"pInt"` - PintNull *int `db:"pIntNull"` + PBinary []byte `db:"p_binary"` + PVarchar string `db:"p_varchar"` + PVarcharNull *string `db:"p_varcharNull"` + PNvarchar string `db:"p_nvarchar"` + PNvarcharNull *string `db:"p_nvarcharNull"` + PID mssqltypes.UniqueIdentifier `db:"p_id"` + PIDNull *mssqltypes.UniqueIdentifier `db:"p_idNull"` + PVarbinary []byte `db:"p_varbinary"` + PTinyint int8 `db:"p_tinyint"` + PTinyintNull *int8 `db:"p_tinyintNull"` + PSmallint int16 `db:"p_smallint"` + PSmallintNull *int16 `db:"p_smallintNull"` + PInt int32 `db:"p_int"` + PIntNull *int32 `db:"p_intNull"` + PBigint int64 `db:"p_bigint"` + PBigintNull *int64 `db:"p_bigintNull"` + PBit bool `db:"p_bit"` + PBitNull *bool `db:"p_bitNull"` + PFloat32 float32 `db:"p_float32"` + PFloatNull32 *float32 `db:"p_floatNull32"` + PFloat64 float64 `db:"p_float64"` + PFloatNull64 *float64 `db:"p_floatNull64"` + DTime time.Time `db:"p_timeNull"` + DTimeNull *time.Time `db:"p_time"` + Pint int `db:"pInt"` + PintNull *int `db:"pIntNull"` } type TvptableRowWithSkipTag struct { - PBinary []byte `db:"p_binary"` - SkipPBinary []byte `json:"-"` - PVarchar string `db:"p_varchar"` - SkipPVarchar string `tvp:"-"` - PVarcharNull *string `db:"p_varcharNull"` - SkipPVarcharNull *string `json:"-" tvp:"-"` - PNvarchar string `db:"p_nvarchar"` - SkipPNvarchar string `json:"-"` - PNvarcharNull *string `db:"p_nvarcharNull"` - SkipPNvarcharNull *string `json:"-"` - PID UniqueIdentifier `db:"p_id"` - SkipPID UniqueIdentifier `json:"-"` - PIDNull *UniqueIdentifier `db:"p_idNull"` - SkipPIDNull *UniqueIdentifier `tvp:"-"` - PVarbinary []byte `db:"p_varbinary"` - SkipPVarbinary []byte `json:"-" tvp:"-"` - PTinyint int8 `db:"p_tinyint"` - SkipPTinyint int8 `tvp:"-"` - PTinyintNull *int8 `db:"p_tinyintNull"` - SkipPTinyintNull *int8 `tvp:"-" json:"any"` - PSmallint int16 `db:"p_smallint"` - SkipPSmallint int16 `json:"-"` - PSmallintNull *int16 `db:"p_smallintNull"` - SkipPSmallintNull *int16 `tvp:"-"` - PInt int32 `db:"p_int"` - SkipPInt int32 `json:"-"` - PIntNull *int32 `db:"p_intNull"` - SkipPIntNull *int32 `tvp:"-"` - PBigint int64 `db:"p_bigint"` - SkipPBigint int64 `tvp:"-"` - PBigintNull *int64 `db:"p_bigintNull"` - SkipPBigintNull *int64 `json:"any" tvp:"-"` - PBit bool `db:"p_bit"` - SkipPBit bool `json:"-"` - PBitNull *bool `db:"p_bitNull"` - SkipPBitNull *bool `json:"-"` - PFloat32 float32 `db:"p_float32"` - SkipPFloat32 float32 `tvp:"-"` - PFloatNull32 *float32 `db:"p_floatNull32"` - SkipPFloatNull32 *float32 `tvp:"-"` - PFloat64 float64 `db:"p_float64"` - SkipPFloat64 float64 `tvp:"-"` - PFloatNull64 *float64 `db:"p_floatNull64"` - SkipPFloatNull64 *float64 `tvp:"-"` - DTime time.Time `db:"p_timeNull"` - SkipDTime time.Time `tvp:"-"` - DTimeNull *time.Time `db:"p_time"` - SkipDTimeNull *time.Time `tvp:"-"` - Pint int `db:"p_int_null"` - SkipPint int `tvp:"-"` - PintNull *int `db:"p_int_"` - SkipPintNull *int `tvp:"-"` + PBinary []byte `db:"p_binary"` + SkipPBinary []byte `json:"-"` + PVarchar string `db:"p_varchar"` + SkipPVarchar string `tvp:"-"` + PVarcharNull *string `db:"p_varcharNull"` + SkipPVarcharNull *string `json:"-" tvp:"-"` + PNvarchar string `db:"p_nvarchar"` + SkipPNvarchar string `json:"-"` + PNvarcharNull *string `db:"p_nvarcharNull"` + SkipPNvarcharNull *string `json:"-"` + PID mssqltypes.UniqueIdentifier `db:"p_id"` + SkipPID mssqltypes.UniqueIdentifier `json:"-"` + PIDNull *mssqltypes.UniqueIdentifier `db:"p_idNull"` + SkipPIDNull *mssqltypes.UniqueIdentifier `tvp:"-"` + PVarbinary []byte `db:"p_varbinary"` + SkipPVarbinary []byte `json:"-" tvp:"-"` + PTinyint int8 `db:"p_tinyint"` + SkipPTinyint int8 `tvp:"-"` + PTinyintNull *int8 `db:"p_tinyintNull"` + SkipPTinyintNull *int8 `tvp:"-" json:"any"` + PSmallint int16 `db:"p_smallint"` + SkipPSmallint int16 `json:"-"` + PSmallintNull *int16 `db:"p_smallintNull"` + SkipPSmallintNull *int16 `tvp:"-"` + PInt int32 `db:"p_int"` + SkipPInt int32 `json:"-"` + PIntNull *int32 `db:"p_intNull"` + SkipPIntNull *int32 `tvp:"-"` + PBigint int64 `db:"p_bigint"` + SkipPBigint int64 `tvp:"-"` + PBigintNull *int64 `db:"p_bigintNull"` + SkipPBigintNull *int64 `json:"any" tvp:"-"` + PBit bool `db:"p_bit"` + SkipPBit bool `json:"-"` + PBitNull *bool `db:"p_bitNull"` + SkipPBitNull *bool `json:"-"` + PFloat32 float32 `db:"p_float32"` + SkipPFloat32 float32 `tvp:"-"` + PFloatNull32 *float32 `db:"p_floatNull32"` + SkipPFloatNull32 *float32 `tvp:"-"` + PFloat64 float64 `db:"p_float64"` + SkipPFloat64 float64 `tvp:"-"` + PFloatNull64 *float64 `db:"p_floatNull64"` + SkipPFloatNull64 *float64 `tvp:"-"` + DTime time.Time `db:"p_timeNull"` + SkipDTime time.Time `tvp:"-"` + DTimeNull *time.Time `db:"p_time"` + SkipDTimeNull *time.Time `tvp:"-"` + Pint int `db:"p_int_null"` + SkipPint int `tvp:"-"` + PintNull *int `db:"p_int_"` + SkipPintNull *int `tvp:"-"` } func TestTVP(t *testing.T) { @@ -215,7 +217,7 @@ func TestTVP(t *testing.T) { PBinary: []byte("ccc"), PVarchar: varcharNull, PNvarchar: nvarchar, - PID: UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PID: mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, PVarbinary: bytesMock, PTinyint: i8, PSmallint: i16, @@ -231,7 +233,7 @@ func TestTVP(t *testing.T) { PBinary: []byte("www"), PVarchar: "eee", PNvarchar: "lll", - PID: UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PID: mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, PVarbinary: []byte("zzz"), PTinyint: 5, PSmallint: 16000, @@ -247,7 +249,7 @@ func TestTVP(t *testing.T) { PBinary: nil, PVarcharNull: &varcharNull, PNvarcharNull: &nvarchar, - PIDNull: &UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PIDNull: &mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, PTinyintNull: &i8, PSmallintNull: &i16, PIntNull: &i32, @@ -263,7 +265,7 @@ func TestTVP(t *testing.T) { PBinary: []byte("www"), PVarchar: "eee", PNvarchar: "lll", - PIDNull: &UniqueIdentifier{}, + PIDNull: &mssqltypes.UniqueIdentifier{}, PVarbinary: []byte("zzz"), PTinyint: 5, PSmallint: 16000, @@ -468,7 +470,7 @@ func TestTVP_WithTag(t *testing.T) { PBinary: []byte("ccc"), PVarchar: varcharNull, PNvarchar: nvarchar, - PID: UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PID: mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, PVarbinary: bytesMock, PTinyint: i8, PSmallint: i16, @@ -485,7 +487,7 @@ func TestTVP_WithTag(t *testing.T) { PBinary: []byte("www"), PVarchar: "eee", PNvarchar: "lll", - PID: UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PID: mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, PVarbinary: []byte("zzz"), PTinyint: 5, PSmallint: 16000, @@ -502,7 +504,7 @@ func TestTVP_WithTag(t *testing.T) { PBinary: nil, PVarcharNull: &varcharNull, PNvarcharNull: &nvarchar, - PIDNull: &UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + PIDNull: &mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, PTinyintNull: &i8, PSmallintNull: &i16, PIntNull: &i32, @@ -518,7 +520,7 @@ func TestTVP_WithTag(t *testing.T) { PBinary: []byte("www"), PVarchar: "eee", PNvarchar: "lll", - PIDNull: &UniqueIdentifier{}, + PIDNull: &mssqltypes.UniqueIdentifier{}, PVarbinary: []byte("zzz"), PTinyint: 5, PSmallint: 16000, diff --git a/types.go b/types.go index b6e7fb2b..6bf297f7 100644 --- a/types.go +++ b/types.go @@ -11,7 +11,8 @@ import ( "time" "github.com/denisenkom/go-mssqldb/internal/cp" - "github.com/denisenkom/go-mssqldb/internal/decimal" + "github.com/denisenkom/go-mssqldb/internal/mssqlerror" + "github.com/denisenkom/go-mssqldb/internal/mssqltypes" ) // fixed-length data types @@ -338,7 +339,7 @@ func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} { case typeInt8: return int64(binary.LittleEndian.Uint64(buf)) default: - badStreamPanicf("Invalid typeid") + mssqlerror.BadStreamPanicf("Invalid typeid") } panic("shoulnd't get here") } @@ -353,7 +354,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { switch ti.TypeId { case typeDateN: if len(buf) != 3 { - badStreamPanicf("Invalid size for DATENTYPE") + mssqlerror.BadStreamPanicf("Invalid size for DATENTYPE") } return decodeDate(buf) case typeTimeN: @@ -375,13 +376,13 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { case 8: return int64(binary.LittleEndian.Uint64(buf)) default: - badStreamPanicf("Invalid size for INTNTYPE: %d", len(buf)) + mssqlerror.BadStreamPanicf("Invalid size for INTNTYPE: %d", len(buf)) } case typeDecimal, typeNumeric, typeDecimalN, typeNumericN: return decodeDecimal(ti.Prec, ti.Scale, buf) case typeBitN: if len(buf) != 1 { - badStreamPanicf("Invalid size for BITNTYPE") + mssqlerror.BadStreamPanicf("Invalid size for BITNTYPE") } return buf[0] != 0 case typeFltN: @@ -391,7 +392,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { case 8: return math.Float64frombits(binary.LittleEndian.Uint64(buf)) default: - badStreamPanicf("Invalid size for FLTNTYPE") + mssqlerror.BadStreamPanicf("Invalid size for FLTNTYPE") } case typeMoneyN: switch len(buf) { @@ -400,7 +401,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { case 8: return decodeMoney(buf) default: - badStreamPanicf("Invalid size for MONEYNTYPE") + mssqlerror.BadStreamPanicf("Invalid size for MONEYNTYPE") } case typeDateTim4: return decodeDateTim4(buf) @@ -413,7 +414,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { case 8: return decodeDateTime(buf) default: - badStreamPanicf("Invalid size for DATETIMENTYPE") + mssqlerror.BadStreamPanicf("Invalid size for DATETIMENTYPE") } case typeChar, typeVarChar: return decodeChar(ti.Collation, buf) @@ -425,7 +426,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { copy(cpy, buf) return cpy default: - badStreamPanicf("Invalid typeid") + mssqlerror.BadStreamPanicf("Invalid typeid") } panic("shoulnd't get here") } @@ -464,7 +465,7 @@ func readShortLenType(ti *typeInfo, r *tdsBuffer) interface{} { case typeUdt: return decodeUdt(*ti, buf) default: - badStreamPanicf("Invalid typeid") + mssqlerror.BadStreamPanicf("Invalid typeid") } panic("shoulnd't get here") } @@ -512,7 +513,7 @@ func readLongLenType(ti *typeInfo, r *tdsBuffer) interface{} { case typeNText: return decodeNChar(buf) default: - badStreamPanicf("Invalid typeid") + mssqlerror.BadStreamPanicf("Invalid typeid") } panic("shoulnd't get here") } @@ -645,7 +646,7 @@ func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} { r.ReadFull(buf) return decodeNChar(buf) default: - badStreamPanicf("Invalid variant typeid") + mssqlerror.BadStreamPanicf("Invalid variant typeid") } panic("shoulnd't get here") } @@ -671,7 +672,7 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} { break } if _, err := io.CopyN(buf, r, int64(chunksize)); err != nil { - badStreamPanicf("Reading PLP type failed: %s", err.Error()) + mssqlerror.BadStreamPanicf("Reading PLP type failed: %s", err.Error()) } } switch ti.TypeId { @@ -725,7 +726,7 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) { case 5, 6, 7: ti.Size = 5 default: - badStreamPanicf("Invalid scale for TIME/DATETIME2/DATETIMEOFFSET type") + mssqlerror.BadStreamPanicf("Invalid scale for TIME/DATETIME2/DATETIMEOFFSET type") } switch ti.TypeId { case typeDateTime2N: @@ -805,7 +806,7 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) { ti.Reader = readVariantType } default: - badStreamPanicf("Invalid type %d", ti.TypeId) + mssqlerror.BadStreamPanicf("Invalid type %d", ti.TypeId) } return } @@ -819,12 +820,12 @@ func decodeMoney(buf []byte) []byte { uint64(buf[1])<<40 | uint64(buf[2])<<48 | uint64(buf[3])<<56) - return decimal.ScaleBytes(strconv.FormatInt(money, 10), 4) + return mssqltypes.ScaleBytes(strconv.FormatInt(money, 10), 4) } func decodeMoney4(buf []byte) []byte { money := int32(binary.LittleEndian.Uint32(buf[0:4])) - return decimal.ScaleBytes(strconv.FormatInt(int64(money), 10), 4) + return mssqltypes.ScaleBytes(strconv.FormatInt(int64(money), 10), 4) } func decodeGuid(buf []byte) []byte { @@ -836,7 +837,7 @@ func decodeGuid(buf []byte) []byte { func decodeDecimal(prec uint8, scale uint8, buf []byte) []byte { var sign uint8 sign = buf[0] - var dec decimal.Decimal + var dec mssqltypes.Decimal dec.SetPositive(sign != 0) dec.SetPrec(prec) dec.SetScale(scale) @@ -994,7 +995,7 @@ func decodeChar(col cp.Collation, buf []byte) string { func decodeUcs2(buf []byte) string { res, err := ucs22str(buf) if err != nil { - badStreamPanicf("Invalid UCS2 encoding: %s", err.Error()) + mssqlerror.BadStreamPanicf("Invalid UCS2 encoding: %s", err.Error()) } return res } From 3dde7868fd453d585a5f41aaa25dbf6a99dc3bba Mon Sep 17 00:00:00 2001 From: v-kaywon Date: Wed, 11 Sep 2019 14:31:31 -0700 Subject: [PATCH 08/10] fix buf_test.go --- buf_test.go | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/buf_test.go b/buf_test.go index 76efbd1d..9d681665 100644 --- a/buf_test.go +++ b/buf_test.go @@ -282,35 +282,3 @@ func TestWrite_BufferBounds(t *testing.T) { t.Fatal("FinishPacket failed:", err.Error()) } } - -func TestReadUsVarCharOrPanic(t *testing.T) { - memBuf := bytes.NewBuffer([]byte{3, 0, 0x31, 0, 0x32, 0, 0x33, 0}) - s := readUsVarCharOrPanic(memBuf) - if s != "123" { - t.Errorf("UsVarChar expected to return 123 but it returned %s", s) - } - - // test invalid usvarchar - defer func() { - recover() - }() - memBuf = bytes.NewBuffer([]byte{}) - s = readUsVarCharOrPanic(memBuf) - t.Fatal("UsVarChar() should panic, but it didn't") -} - -func TestReadBVarCharOrPanic(t *testing.T) { - memBuf := bytes.NewBuffer([]byte{3, 0x31, 0, 0x32, 0, 0x33, 0}) - s := readBVarCharOrPanic(memBuf) - if s != "123" { - t.Errorf("readBVarCharOrPanic expected to return 123 but it returned %s", s) - } - - // test invalid varchar - defer func() { - recover() - }() - memBuf = bytes.NewBuffer([]byte{}) - s = readBVarCharOrPanic(memBuf) - t.Fatal("readBVarCharOrPanic() should panic on empty buffer, but it didn't") -} From 1844e8344ecc3dbb03955a41892cb80c5e5bd796 Mon Sep 17 00:00:00 2001 From: v-kaywon Date: Thu, 12 Sep 2019 17:14:27 -0700 Subject: [PATCH 09/10] update buf_test.go --- buf_test.go | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/buf_test.go b/buf_test.go index 9d681665..76efbd1d 100644 --- a/buf_test.go +++ b/buf_test.go @@ -282,3 +282,35 @@ func TestWrite_BufferBounds(t *testing.T) { t.Fatal("FinishPacket failed:", err.Error()) } } + +func TestReadUsVarCharOrPanic(t *testing.T) { + memBuf := bytes.NewBuffer([]byte{3, 0, 0x31, 0, 0x32, 0, 0x33, 0}) + s := readUsVarCharOrPanic(memBuf) + if s != "123" { + t.Errorf("UsVarChar expected to return 123 but it returned %s", s) + } + + // test invalid usvarchar + defer func() { + recover() + }() + memBuf = bytes.NewBuffer([]byte{}) + s = readUsVarCharOrPanic(memBuf) + t.Fatal("UsVarChar() should panic, but it didn't") +} + +func TestReadBVarCharOrPanic(t *testing.T) { + memBuf := bytes.NewBuffer([]byte{3, 0x31, 0, 0x32, 0, 0x33, 0}) + s := readBVarCharOrPanic(memBuf) + if s != "123" { + t.Errorf("readBVarCharOrPanic expected to return 123 but it returned %s", s) + } + + // test invalid varchar + defer func() { + recover() + }() + memBuf = bytes.NewBuffer([]byte{}) + s = readBVarCharOrPanic(memBuf) + t.Fatal("readBVarCharOrPanic() should panic on empty buffer, but it didn't") +} From e6e91f5f31a94caf4b704d6720b0d8d580c418aa Mon Sep 17 00:00:00 2001 From: v-kaywon Date: Tue, 17 Sep 2019 16:26:52 -0700 Subject: [PATCH 10/10] pull error out of subpackage --- buf.go | 10 +-- internal/mssqlerror/error.go => error.go | 8 +- mssql.go | 3 +- queries_go19_test.go | 5 +- queries_test.go | 7 +- token.go | 110 +++++++++++------------ types.go | 31 ++++--- 7 files changed, 82 insertions(+), 92 deletions(-) rename internal/mssqlerror/error.go => error.go (87%) diff --git a/buf.go b/buf.go index e20d7aa1..ba39b40f 100644 --- a/buf.go +++ b/buf.go @@ -4,8 +4,6 @@ import ( "encoding/binary" "errors" "io" - - "github.com/denisenkom/go-mssqldb/internal/mssqlerror" ) type packetType uint8 @@ -188,7 +186,7 @@ func (r *tdsBuffer) ReadByte() (res byte, err error) { func (r *tdsBuffer) byte() byte { b, err := r.ReadByte() if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } return b } @@ -196,7 +194,7 @@ func (r *tdsBuffer) byte() byte { func (r *tdsBuffer) ReadFull(buf []byte) { _, err := io.ReadFull(r, buf[:]) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } } @@ -229,7 +227,7 @@ func (r *tdsBuffer) BVarChar() string { func readBVarCharOrPanic(r io.Reader) string { s, err := readBVarChar(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } return s } @@ -237,7 +235,7 @@ func readBVarCharOrPanic(r io.Reader) string { func readUsVarCharOrPanic(r io.Reader) string { s, err := readUsVarChar(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } return s } diff --git a/internal/mssqlerror/error.go b/error.go similarity index 87% rename from internal/mssqlerror/error.go rename to error.go index fe4ce619..6a0175c0 100644 --- a/internal/mssqlerror/error.go +++ b/error.go @@ -1,4 +1,4 @@ -package mssqlerror +package mssql import ( "fmt" @@ -73,12 +73,10 @@ func streamErrorf(format string, v ...interface{}) StreamError { return StreamError{"Invalid TDS stream: " + fmt.Sprintf(format, v...)} } -// BadStreamPanic calls panic with err. -func BadStreamPanic(err error) { +func badStreamPanic(err error) { panic(err) } -// BadStreamPanicf calls panic with a formatted error message as an invalid TDS stream error. -func BadStreamPanicf(format string, v ...interface{}) { +func badStreamPanicf(format string, v ...interface{}) { panic(streamErrorf(format, v...)) } diff --git a/mssql.go b/mssql.go index 2949163d..e37109cd 100644 --- a/mssql.go +++ b/mssql.go @@ -15,7 +15,6 @@ import ( "time" "unicode" - "github.com/denisenkom/go-mssqldb/internal/mssqlerror" "github.com/denisenkom/go-mssqldb/internal/querytext" ) @@ -189,7 +188,7 @@ func (c *Conn) checkBadConn(err error) error { case net.Error: c.connectionGood = false return err - case mssqlerror.StreamError: + case StreamError: c.connectionGood = false return err default: diff --git a/queries_go19_test.go b/queries_go19_test.go index 5e86818d..222a034d 100644 --- a/queries_go19_test.go +++ b/queries_go19_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - "github.com/denisenkom/go-mssqldb/internal/mssqlerror" "github.com/denisenkom/go-mssqldb/internal/mssqltypes" ) @@ -227,7 +226,7 @@ END; t.Run("should fail if destination has invalid type", func(t *testing.T) { // Error type should not be supported - var err_out mssqlerror.Error + var err_out Error _, err := db.ExecContext(ctx, sqltextrun, sql.Named("bid", sql.Out{Dest: &err_out}), ) @@ -247,7 +246,7 @@ END; t.Run("should fail if parameter has invalid type", func(t *testing.T) { // passing invalid parameter type - var err_val mssqlerror.Error + var err_val Error _, err = db.ExecContext(ctx, sqltextrun, err_val) if err == nil { t.Error("Expected to fail but it didn't") diff --git a/queries_test.go b/queries_test.go index 1098899a..8f5c8d40 100644 --- a/queries_test.go +++ b/queries_test.go @@ -15,7 +15,6 @@ import ( "testing" "time" - "github.com/denisenkom/go-mssqldb/internal/mssqlerror" "github.com/denisenkom/go-mssqldb/internal/mssqltypes" ) @@ -622,7 +621,7 @@ func TestError(t *testing.T) { t.Fatal("Query should fail") } - if sqlerr, ok := err.(mssqlerror.Error); !ok { + if sqlerr, ok := err.(Error); !ok { t.Fatalf("Should be sql error, actually %T, %v", err, err) } else { if sqlerr.Number != 2812 { // Could not find stored procedure 'bad' @@ -972,7 +971,7 @@ func TestErrorInfo(t *testing.T) { defer conn.Close() _, err := conn.Exec("select bad") - if sqlError, ok := err.(mssqlerror.Error); ok { + if sqlError, ok := err.(Error); ok { if sqlError.SQLErrorNumber() != 207 /*invalid column name*/ { t.Errorf("Query failed with unexpected error number %d %s", sqlError.SQLErrorNumber(), sqlError.SQLErrorMessage()) } @@ -983,7 +982,7 @@ func TestErrorInfo(t *testing.T) { t.Error("Failed to convert error to SQLErorr", err) } _, err = conn.Exec("RAISERROR('test message', 18, 111)") - if sqlError, ok := err.(mssqlerror.Error); ok { + if sqlError, ok := err.(Error); ok { if sqlError.SQLErrorNumber() != 50000 { t.Errorf("Query failed with unexpected error number %d %s", sqlError.SQLErrorNumber(), sqlError.SQLErrorMessage()) } diff --git a/token.go b/token.go index 7b4bd4f0..1acac8a5 100644 --- a/token.go +++ b/token.go @@ -9,8 +9,6 @@ import ( "net" "strconv" "strings" - - "github.com/denisenkom/go-mssqldb/internal/mssqlerror" ) //go:generate stringer -type token @@ -89,18 +87,18 @@ type doneStruct struct { Status uint16 CurCmd uint16 RowCount uint64 - errors []mssqlerror.Error + errors []Error } func (d doneStruct) isError() bool { return d.Status&doneError != 0 || len(d.errors) > 0 } -func (d doneStruct) getError() mssqlerror.Error { +func (d doneStruct) getError() Error { if len(d.errors) > 0 { return d.errors[len(d.errors)-1] } else { - return mssqlerror.Error{Message: "Request failed but didn't provide reason"} + return Error{Message: "Request failed but didn't provide reason"} } } @@ -139,127 +137,127 @@ func processEnvChg(sess *tdsSession) { return } if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } switch envtype { case envTypDatabase: sess.database, err = readBVarChar(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } _, err = readBVarChar(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envTypLanguage: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envTypCharset: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envTypPacketSize: packetsize, err := readBVarChar(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } _, err = readBVarChar(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } packetsizei, err := strconv.Atoi(packetsize) if err != nil { - mssqlerror.BadStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error()) + badStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error()) } sess.buf.ResizeBuffer(packetsizei) case envSortId: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envSortFlags: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envSqlCollation: // currently ignored var collationSize uint8 err = binary.Read(r, binary.LittleEndian, &collationSize) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // SQL Collation data should contain 5 bytes in length if collationSize != 5 { - mssqlerror.BadStreamPanicf("Invalid SQL Collation size value returned from server: %d", collationSize) + badStreamPanicf("Invalid SQL Collation size value returned from server: %d", collationSize) } // 4 bytes, contains: LCID ColFlags Version var info uint32 err = binary.Read(r, binary.LittleEndian, &info) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // 1 byte, contains: sortID var sortID uint8 err = binary.Read(r, binary.LittleEndian, &sortID) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envTypBeginTran: tranid, err := readBVarByte(r) if len(tranid) != 8 { - mssqlerror.BadStreamPanicf("invalid size of transaction identifier: %d", len(tranid)) + badStreamPanicf("invalid size of transaction identifier: %d", len(tranid)) } sess.tranid = binary.LittleEndian.Uint64(tranid) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } if sess.logFlags&logTransaction != 0 { sess.log.Printf("BEGIN TRANSACTION %x\n", sess.tranid) } _, err = readBVarByte(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envTypCommitTran, envTypRollbackTran: _, err = readBVarByte(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } _, err = readBVarByte(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } if sess.logFlags&logTransaction != 0 { if envtype == envTypCommitTran { @@ -273,81 +271,81 @@ func processEnvChg(sess *tdsSession) { // currently ignored // new value, should be 0 if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envDefectTran: // currently ignored // new value if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envDatabaseMirrorPartner: sess.partner, err = readBVarChar(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } _, err = readBVarChar(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envPromoteTran: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // dtc token // spec says it should be L_VARBYTE, so this code might be wrong if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envTranMgrAddr: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // XACT_MANAGER_ADDRESS = B_VARBYTE if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envTranEnded: // currently ignored // old value, B_VARBYTE if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // should be 0 if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envResetConnAck: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // should be 0 if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envStartedInstanceName: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // instance name if _, err = readBVarChar(r); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } case envRouting: // RoutingData message is: @@ -357,24 +355,24 @@ func processEnvChg(sess *tdsSession) { // AlternateServer US_VARCHAR _, err := readUshort(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } protocol, err := readByte(r) if err != nil || protocol != 0 { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } newPort, err := readUshort(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } newServer, err := readUsVarChar(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } // consume the OLDVALUE = %x00 %x00 _, err = readUshort(r) if err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } sess.routedServer = newServer sess.routedPort = newPort @@ -443,7 +441,7 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct { prognamelen := buf[1+4] var err error if res.ProgName, err = ucs22str(buf[1+4+1 : 1+4+1+prognamelen*2]); err != nil { - mssqlerror.BadStreamPanic(err) + badStreamPanic(err) } res.ProgVer = binary.BigEndian.Uint32(buf[size-4:]) return res @@ -491,7 +489,7 @@ func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { } // http://msdn.microsoft.com/en-us/library/dd304156.aspx -func parseError72(r *tdsBuffer) (res mssqlerror.Error) { +func parseError72(r *tdsBuffer) (res Error) { length := r.uint16() _ = length // ignore length res.Number = r.int32() @@ -505,7 +503,7 @@ func parseError72(r *tdsBuffer) (res mssqlerror.Error) { } // http://msdn.microsoft.com/en-us/library/dd304156.aspx -func parseInfo(r *tdsBuffer) (res mssqlerror.Error) { +func parseInfo(r *tdsBuffer) (res Error) { length := r.uint16() _ = length // ignore length res.Number = r.int32() @@ -560,10 +558,10 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin return } if packet_type != packReply { - mssqlerror.BadStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply)) + badStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply)) } var columns []columnStruct - errs := make([]mssqlerror.Error, 0, 5) + errs := make([]Error, 0, 5) for { token := token(sess.buf.byte()) if sess.logFlags&logDebug != 0 { @@ -648,7 +646,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin } } default: - mssqlerror.BadStreamPanic(fmt.Errorf("unknown token type returned: %v", token)) + badStreamPanic(fmt.Errorf("unknown token type returned: %v", token)) } } } diff --git a/types.go b/types.go index 6bf297f7..497d3abf 100644 --- a/types.go +++ b/types.go @@ -11,7 +11,6 @@ import ( "time" "github.com/denisenkom/go-mssqldb/internal/cp" - "github.com/denisenkom/go-mssqldb/internal/mssqlerror" "github.com/denisenkom/go-mssqldb/internal/mssqltypes" ) @@ -339,7 +338,7 @@ func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} { case typeInt8: return int64(binary.LittleEndian.Uint64(buf)) default: - mssqlerror.BadStreamPanicf("Invalid typeid") + badStreamPanicf("Invalid typeid") } panic("shoulnd't get here") } @@ -354,7 +353,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { switch ti.TypeId { case typeDateN: if len(buf) != 3 { - mssqlerror.BadStreamPanicf("Invalid size for DATENTYPE") + badStreamPanicf("Invalid size for DATENTYPE") } return decodeDate(buf) case typeTimeN: @@ -376,13 +375,13 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { case 8: return int64(binary.LittleEndian.Uint64(buf)) default: - mssqlerror.BadStreamPanicf("Invalid size for INTNTYPE: %d", len(buf)) + badStreamPanicf("Invalid size for INTNTYPE: %d", len(buf)) } case typeDecimal, typeNumeric, typeDecimalN, typeNumericN: return decodeDecimal(ti.Prec, ti.Scale, buf) case typeBitN: if len(buf) != 1 { - mssqlerror.BadStreamPanicf("Invalid size for BITNTYPE") + badStreamPanicf("Invalid size for BITNTYPE") } return buf[0] != 0 case typeFltN: @@ -392,7 +391,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { case 8: return math.Float64frombits(binary.LittleEndian.Uint64(buf)) default: - mssqlerror.BadStreamPanicf("Invalid size for FLTNTYPE") + badStreamPanicf("Invalid size for FLTNTYPE") } case typeMoneyN: switch len(buf) { @@ -401,7 +400,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { case 8: return decodeMoney(buf) default: - mssqlerror.BadStreamPanicf("Invalid size for MONEYNTYPE") + badStreamPanicf("Invalid size for MONEYNTYPE") } case typeDateTim4: return decodeDateTim4(buf) @@ -414,7 +413,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { case 8: return decodeDateTime(buf) default: - mssqlerror.BadStreamPanicf("Invalid size for DATETIMENTYPE") + badStreamPanicf("Invalid size for DATETIMENTYPE") } case typeChar, typeVarChar: return decodeChar(ti.Collation, buf) @@ -426,7 +425,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { copy(cpy, buf) return cpy default: - mssqlerror.BadStreamPanicf("Invalid typeid") + badStreamPanicf("Invalid typeid") } panic("shoulnd't get here") } @@ -465,7 +464,7 @@ func readShortLenType(ti *typeInfo, r *tdsBuffer) interface{} { case typeUdt: return decodeUdt(*ti, buf) default: - mssqlerror.BadStreamPanicf("Invalid typeid") + badStreamPanicf("Invalid typeid") } panic("shoulnd't get here") } @@ -513,7 +512,7 @@ func readLongLenType(ti *typeInfo, r *tdsBuffer) interface{} { case typeNText: return decodeNChar(buf) default: - mssqlerror.BadStreamPanicf("Invalid typeid") + badStreamPanicf("Invalid typeid") } panic("shoulnd't get here") } @@ -646,7 +645,7 @@ func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} { r.ReadFull(buf) return decodeNChar(buf) default: - mssqlerror.BadStreamPanicf("Invalid variant typeid") + badStreamPanicf("Invalid variant typeid") } panic("shoulnd't get here") } @@ -672,7 +671,7 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} { break } if _, err := io.CopyN(buf, r, int64(chunksize)); err != nil { - mssqlerror.BadStreamPanicf("Reading PLP type failed: %s", err.Error()) + badStreamPanicf("Reading PLP type failed: %s", err.Error()) } } switch ti.TypeId { @@ -726,7 +725,7 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) { case 5, 6, 7: ti.Size = 5 default: - mssqlerror.BadStreamPanicf("Invalid scale for TIME/DATETIME2/DATETIMEOFFSET type") + badStreamPanicf("Invalid scale for TIME/DATETIME2/DATETIMEOFFSET type") } switch ti.TypeId { case typeDateTime2N: @@ -806,7 +805,7 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) { ti.Reader = readVariantType } default: - mssqlerror.BadStreamPanicf("Invalid type %d", ti.TypeId) + badStreamPanicf("Invalid type %d", ti.TypeId) } return } @@ -995,7 +994,7 @@ func decodeChar(col cp.Collation, buf []byte) string { func decodeUcs2(buf []byte) string { res, err := ucs22str(buf) if err != nil { - mssqlerror.BadStreamPanicf("Invalid UCS2 encoding: %s", err.Error()) + badStreamPanicf("Invalid UCS2 encoding: %s", err.Error()) } return res }