diff --git a/buf.go b/buf.go index 365acd48..bf8a7c7c 100644 --- a/buf.go +++ b/buf.go @@ -115,15 +115,23 @@ func (w *tdsBuffer) WriteByte(b byte) error { return nil } -func (w *tdsBuffer) BeginPacket(packetType packetType) { - w.wbuf[1] = 0 // Packet is incomplete. This byte is set again in FinishPacket. +func (w *tdsBuffer) BeginPacket(packetType packetType, resetSession bool) { + status := byte(0) + if resetSession { + switch packetType { + // Reset session can only be set on the following packet types. + case packSQLBatch, packRPCRequest, packTransMgrReq: + status = 0x8 + } + } + w.wbuf[1] = status // Packet is incomplete. This byte is set again in FinishPacket. w.wpos = 8 w.wPacketSeq = 1 w.wPacketType = packetType } func (w *tdsBuffer) FinishPacket() error { - w.wbuf[1] = 1 // Mark this as the last packet in the message. + w.wbuf[1] |= 1 // Mark this as the last packet in the message. return w.flush() } diff --git a/buf_test.go b/buf_test.go index dd4c8fe2..4149d7b5 100644 --- a/buf_test.go +++ b/buf_test.go @@ -151,7 +151,7 @@ func TestReadFailsOnSecondPacket(t *testing.T) { func TestWrite(t *testing.T) { memBuf := bytes.NewBuffer([]byte{}) buf := newTdsBuffer(11, closableBuffer{memBuf}) - buf.BeginPacket(1) + buf.BeginPacket(1, false) err := buf.WriteByte(2) if err != nil { t.Fatal("WriteByte failed:", err.Error()) @@ -172,7 +172,7 @@ func TestWrite(t *testing.T) { t.Fatalf("Written buffer has invalid content: %v", memBuf.Bytes()) } - buf.BeginPacket(2) + buf.BeginPacket(2, false) wrote, err = buf.Write([]byte{3, 4, 5, 6}) if err != nil { t.Fatal("Write failed:", err.Error()) diff --git a/bulkcopy.go b/bulkcopy.go index 8c0a4e0a..235d81ef 100644 --- a/bulkcopy.go +++ b/bulkcopy.go @@ -128,7 +128,7 @@ func (b *Bulk) sendBulkCommand() (err error) { b.headerSent = true var buf = b.cn.sess.buf - buf.BeginPacket(packBulkLoadBCP) + buf.BeginPacket(packBulkLoadBCP, false) // send the columns metadata columnMetadata := b.createColMetadata() diff --git a/bulkcopy_test.go b/bulkcopy_test.go index 5f6cd549..171f2faa 100644 --- a/bulkcopy_test.go +++ b/bulkcopy_test.go @@ -1,6 +1,9 @@ +// +build go1.9 + package mssql import ( + "context" "database/sql" "encoding/hex" "log" @@ -20,6 +23,7 @@ func TestBulkcopy(t *testing.T) { colname string val interface{} } + tableName := "#table_test" geom, _ := hex.DecodeString("E6100000010C00000000000034400000000000004440") testValues := []testValue{ @@ -71,18 +75,30 @@ func TestBulkcopy(t *testing.T) { values[i] = val.val } - conn := open(t) + pool := open(t) + defer pool.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Now that session resetting is supported, the use of the per session + // temp table requires the use of a dedicated connection from the connection + // pool. + conn, err := pool.Conn(ctx) + if err != nil { + t.Error("failed to pull connection from pool", err) + } defer conn.Close() - err := setupTable(conn, tableName) + err = setupTable(ctx, conn, tableName) if err != nil { - t.Error("Setup table failed: ", err.Error()) + t.Error("Setup table failed: ", err) return } log.Println("Preparing copyin statement") - stmt, err := conn.Prepare(CopyIn(tableName, BulkOptions{}, columns...)) + stmt, err := conn.PrepareContext(ctx, CopyIn(tableName, BulkOptions{}, columns...)) for i := 0; i < 10; i++ { log.Printf("Executing copy in statement %d time with %d values", i+1, len(values)) @@ -105,14 +121,14 @@ func TestBulkcopy(t *testing.T) { //check that all rows are present var rowCount int - err = conn.QueryRow("select count(*) c from " + tableName).Scan(&rowCount) + err = conn.QueryRowContext(ctx, "select count(*) c from "+tableName).Scan(&rowCount) if rowCount != 10 { t.Errorf("unexpected row count %d", rowCount) } //data verification - rows, err := conn.Query("select " + strings.Join(columns, ",") + " from " + tableName) + rows, err := conn.QueryContext(ctx, "select "+strings.Join(columns, ",")+" from "+tableName) if err != nil { log.Fatal(err) } @@ -158,7 +174,7 @@ func compareValue(a interface{}, expected interface{}) bool { } } -func setupTable(conn *sql.DB, tableName string) (err error) { +func setupTable(ctx context.Context, conn *sql.Conn, tableName string) (err error) { tablesql := `CREATE TABLE ` + tableName + ` ( [id] [int] IDENTITY(1,1) NOT NULL, [test_nvarchar] [nvarchar](50) NULL, @@ -203,7 +219,7 @@ func setupTable(conn *sql.DB, tableName string) (err error) { [id] ASC )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] ) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY];` - _, err = conn.Exec(tablesql) + _, err = conn.ExecContext(ctx, tablesql) if err != nil { log.Fatal("tablesql failed:", err) } diff --git a/mssql.go b/mssql.go index 8f5ff2d0..238c1272 100644 --- a/mssql.go +++ b/mssql.go @@ -95,6 +95,7 @@ func (d *Driver) SetLogger(logger Logger) { type Conn struct { sess *tdsSession transactionCtx context.Context + resetSession bool processQueryText bool connectionGood bool @@ -102,6 +103,15 @@ type Conn struct { outs map[string]interface{} } +func (c *Conn) ResetSession(ctx context.Context) error { + if !c.connectionGood { + return driver.ErrBadConn + } + c.resetSession = true + + return nil +} + func (c *Conn) checkBadConn(err error) error { // this is a hack to address Issue #275 // we set connectionGood flag to false if @@ -117,6 +127,7 @@ func (c *Conn) checkBadConn(err error) error { case nil: return nil case io.EOF: + c.connectionGood = false return driver.ErrBadConn case driver.ErrBadConn: // It is an internal programming error if driver.ErrBadConn @@ -174,7 +185,9 @@ func (c *Conn) sendCommitRequest() error { {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{c.sess.tranid, 1}.pack()}, } - if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil { + reset := c.resetSession + c.resetSession = false + if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil { if c.sess.logFlags&logErrors != 0 { c.sess.log.Printf("Failed to send CommitXact with %v", err) } @@ -199,7 +212,9 @@ func (c *Conn) sendRollbackRequest() error { {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{c.sess.tranid, 1}.pack()}, } - if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil { + reset := c.resetSession + c.resetSession = false + if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil { if c.sess.logFlags&logErrors != 0 { c.sess.log.Printf("Failed to send RollbackXact with %v", err) } @@ -234,7 +249,9 @@ func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) erro {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{0, 1}.pack()}, } - if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil { + reset := c.resetSession + c.resetSession = false + if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil { if c.sess.logFlags&logErrors != 0 { c.sess.log.Printf("Failed to send BeginXact with %v", err) } @@ -362,11 +379,13 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) { }) } + conn := s.c + // no need to check number of parameters here, it is checked by database/sql - if s.c.sess.logFlags&logSQL != 0 { - s.c.sess.log.Println(s.query) + if conn.sess.logFlags&logSQL != 0 { + conn.sess.log.Println(s.query) } - if s.c.sess.logFlags&logParams != 0 && len(args) > 0 { + if conn.sess.logFlags&logParams != 0 && len(args) > 0 { for i := 0; i < len(args); i++ { if len(args[i].Name) > 0 { s.c.sess.log.Printf("\t@%s\t%v\n", args[i].Name, args[i].Value) @@ -374,14 +393,16 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) { s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i].Value) } } - } + + reset := conn.resetSession + conn.resetSession = false if len(args) == 0 { - if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil { - if s.c.sess.logFlags&logErrors != 0 { - s.c.sess.log.Printf("Failed to send SqlBatch with %v", err) + if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil { + if conn.sess.logFlags&logErrors != 0 { + conn.sess.log.Printf("Failed to send SqlBatch with %v", err) } - s.c.connectionGood = false + conn.connectionGood = false return fmt.Errorf("failed to send SQL Batch: %v", err) } } else { @@ -399,11 +420,11 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) { params[0] = makeStrParam(s.query) params[1] = makeStrParam(strings.Join(decls, ",")) } - if err = sendRpc(s.c.sess.buf, headers, proc, 0, params); err != nil { - if s.c.sess.logFlags&logErrors != 0 { - s.c.sess.log.Printf("Failed to send Rpc with %v", err) + if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil { + if conn.sess.logFlags&logErrors != 0 { + conn.sess.log.Printf("Failed to send Rpc with %v", err) } - s.c.connectionGood = false + conn.connectionGood = false return fmt.Errorf("Failed to send RPC: %v", err) } } diff --git a/mssql_go110.go b/mssql_go110.go new file mode 100644 index 00000000..adab4999 --- /dev/null +++ b/mssql_go110.go @@ -0,0 +1,10 @@ +// +build go1.10 + +package mssql + +import ( + "database/sql/driver" +) + +var _ driver.Connector = &Connector{} +var _ driver.SessionResetter = &Conn{} diff --git a/net.go b/net.go index 8c3c8ef8..73d5ac93 100644 --- a/net.go +++ b/net.go @@ -58,7 +58,7 @@ func (c *timeoutConn) Read(b []byte) (n int, err error) { func (c *timeoutConn) Write(b []byte) (n int, err error) { if c.buf != nil { if !c.packetPending { - c.buf.BeginPacket(packPrelogin) + c.buf.BeginPacket(packPrelogin, false) c.packetPending = true } n, err = c.buf.Write(b) diff --git a/rpc.go b/rpc.go index 00b9b1e2..873474d4 100644 --- a/rpc.go +++ b/rpc.go @@ -57,8 +57,8 @@ var ( ) // http://msdn.microsoft.com/en-us/library/dd357576.aspx -func sendRpc(buf *tdsBuffer, headers []headerStruct, proc ProcId, flags uint16, params []Param) (err error) { - buf.BeginPacket(packRPCRequest) +func sendRpc(buf *tdsBuffer, headers []headerStruct, proc ProcId, flags uint16, params []Param, resetSession bool) (err error) { + buf.BeginPacket(packRPCRequest, resetSession) writeAllHeaders(buf, headers) if len(proc.name) == 0 { var idswitch uint16 = 0xffff diff --git a/tds.go b/tds.go index 54ac6dba..b17c65e6 100644 --- a/tds.go +++ b/tds.go @@ -162,7 +162,7 @@ func (p KeySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error { var err error - w.BeginPacket(packPrelogin) + w.BeginPacket(packPrelogin, false) offset := uint16(5*len(fields) + 1) keys := make(KeySlice, 0, len(fields)) for k, _ := range fields { @@ -352,7 +352,7 @@ func manglePassword(password string) []byte { // http://msdn.microsoft.com/en-us/library/dd304019.aspx func sendLogin(w *tdsBuffer, login login) error { - w.BeginPacket(packLogin7) + w.BeginPacket(packLogin7, false) hostname := str2ucs2(login.HostName) username := str2ucs2(login.UserName) password := manglePassword(login.Password) @@ -633,8 +633,8 @@ func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) { return nil } -func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err error) { - buf.BeginPacket(packSQLBatch) +func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) { + buf.BeginPacket(packSQLBatch, resetSession) if err = writeAllHeaders(buf, headers); err != nil { return @@ -650,7 +650,7 @@ func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx func sendAttention(buf *tdsBuffer) error { - buf.BeginPacket(packAttention) + buf.BeginPacket(packAttention, false) return buf.FinishPacket() } @@ -1337,7 +1337,7 @@ continue_login: } } if sspi_msg != nil { - outbuf.BeginPacket(packSSPIMessage) + outbuf.BeginPacket(packSSPIMessage, false) _, err = outbuf.Write(sspi_msg) if err != nil { return nil, err diff --git a/tds_test.go b/tds_test.go index 40e2f92a..683ed7aa 100644 --- a/tds_test.go +++ b/tds_test.go @@ -89,7 +89,7 @@ func TestSendSqlBatch(t *testing.T) { {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{0, 1}.pack()}, } - err = sendSqlBatch72(conn.buf, "select 1", headers) + err = sendSqlBatch72(conn.buf, "select 1", headers, true) if err != nil { t.Error("Sending sql batch failed", err.Error()) return diff --git a/tran.go b/tran.go index 75e7a2ae..cb643681 100644 --- a/tran.go +++ b/tran.go @@ -28,9 +28,8 @@ const ( isolationSnapshot = 5 ) -func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel, - name string) (err error) { - buf.BeginPacket(packTransMgrReq) +func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel, name string, resetSession bool) (err error) { + buf.BeginPacket(packTransMgrReq, resetSession) writeAllHeaders(buf, headers) var rqtype uint16 = tmBeginXact err = binary.Write(buf, binary.LittleEndian, &rqtype) @@ -52,8 +51,8 @@ const ( fBeginXact = 1 ) -func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string) error { - buf.BeginPacket(packTransMgrReq) +func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error { + buf.BeginPacket(packTransMgrReq, resetSession) writeAllHeaders(buf, headers) var rqtype uint16 = tmCommitXact err := binary.Write(buf, binary.LittleEndian, &rqtype) @@ -81,8 +80,8 @@ func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags u return buf.FinishPacket() } -func sendRollbackXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string) error { - buf.BeginPacket(packTransMgrReq) +func sendRollbackXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error { + buf.BeginPacket(packTransMgrReq, resetSession) writeAllHeaders(buf, headers) var rqtype uint16 = tmRollbackXact err := binary.Write(buf, binary.LittleEndian, &rqtype)