Skip to content

Commit

Permalink
mssql: support custom SQL to init connection setup
Browse files Browse the repository at this point in the history
  • Loading branch information
kardianos committed Mar 15, 2018
1 parent 9340cdc commit 73efb3e
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 34 deletions.
13 changes: 6 additions & 7 deletions bulkcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"context"
"database/sql"
"encoding/hex"
"log"
"math"
"reflect"
"strings"
Expand Down Expand Up @@ -90,18 +89,18 @@ func TestBulkcopy(t *testing.T) {
}
defer conn.Close()

err = setupTable(ctx, conn, tableName)
err = setupTable(ctx, t, conn, tableName)
if err != nil {
t.Error("Setup table failed: ", err)
return
}

log.Println("Preparing copyin statement")
t.Log("Preparing copyin statement")

stmt, err := conn.PrepareContext(ctx, CopyIn(tableName, BulkOptions{}, columns...))

for i := 0; i < 10; i++ {
log.Printf("Executing copy in statement %d time with %d values", i+1, len(values))
t.Logf("Executing copy in statement %d time with %d values", i+1, len(values))
_, err = stmt.Exec(values...)
if err != nil {
t.Error("AddRow failed: ", err.Error())
Expand Down Expand Up @@ -130,7 +129,7 @@ func TestBulkcopy(t *testing.T) {
//data verification
rows, err := conn.QueryContext(ctx, "select "+strings.Join(columns, ",")+" from "+tableName)
if err != nil {
log.Fatal(err)
t.Fatal(err)
}
defer rows.Close()
for rows.Next() {
Expand Down Expand Up @@ -174,7 +173,7 @@ func compareValue(a interface{}, expected interface{}) bool {
}
}

func setupTable(ctx context.Context, conn *sql.Conn, tableName string) (err error) {
func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName string) (err error) {
tablesql := `CREATE TABLE ` + tableName + ` (
[id] [int] IDENTITY(1,1) NOT NULL,
[test_nvarchar] [nvarchar](50) NULL,
Expand Down Expand Up @@ -221,7 +220,7 @@ func setupTable(ctx context.Context, conn *sql.Conn, tableName string) (err erro
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY];`
_, err = conn.ExecContext(ctx, tablesql)
if err != nil {
log.Fatal("tablesql failed:", err)
t.Fatal("tablesql failed:", err)
}
return
}
56 changes: 29 additions & 27 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ func (d *Driver) Open(dsn string) (driver.Conn, error) {
return d.open(context.Background(), dsn)
}

func SetLogger(logger Logger) {
driverInstance.SetLogger(logger)
driverInstanceNoProcess.SetLogger(logger)
}

func (d *Driver) SetLogger(logger Logger) {
d.log = optionalLogger{logger}
}

// Connector holds the parsed DSN and is ready to make a new connection
// at any time.
//
Expand All @@ -71,28 +80,29 @@ func (d *Driver) Open(dsn string) (driver.Conn, error) {
type Connector struct {
params connectParams
driver *Driver
}

// Connect to the server and return a TDS connection.
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
return c.driver.connect(ctx, c.params)
}

// Driver underlying the Connector.
func (c *Connector) Driver() driver.Driver {
return c.driver
}

func SetLogger(logger Logger) {
driverInstance.SetLogger(logger)
driverInstanceNoProcess.SetLogger(logger)
}

func (d *Driver) SetLogger(logger Logger) {
d.log = optionalLogger{logger}
// ResetSQL is executed after marking a given connection to be reset.
// When not present, the next query will be reset to the database
// defaults.
// When present the connection will immediately mark the connection to
// be reset, then execute the ResetSQL text to setup the session
// that may be different from the base database defaults.
//
// For Example, the application relies on the following defaults
// but is not allowed to set them at the database system level.
//
// SET XACT_ABORT ON;
// SET TEXTSIZE -1;
// SET ANSI_NULLS ON;
// SET LOCK_TIMEOUT 10000;
//
// ResetSQL should not attempt to manually call sp_reset_connection.
// This will happen at the TDS layer.
ResetSQL string
}

type Conn struct {
connector *Connector
sess *tdsSession
transactionCtx context.Context
resetSession bool
Expand All @@ -103,15 +113,6 @@ type Conn struct {
outs map[string]interface{}
}

func (c *Conn) ResetSession(ctx context.Context) error {
if !c.connectionGood {
return driver.ErrBadConn
}
c.resetSession = true

return nil
}

func (c *Conn) checkBadConn(err error) error {
// this is a hack to address Issue #275
// we set connectionGood flag to false if
Expand Down Expand Up @@ -306,6 +307,7 @@ func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, erro
connectionGood: true,
}
conn.sess.log = d.log

return conn, nil
}

Expand Down
40 changes: 40 additions & 0 deletions mssql_go110.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,48 @@
package mssql

import (
"context"
"database/sql/driver"
)

var _ driver.Connector = &Connector{}
var _ driver.SessionResetter = &Conn{}

func (c *Conn) ResetSession(ctx context.Context) error {
if !c.connectionGood {
return driver.ErrBadConn
}
c.resetSession = true

if c.connector == nil || len(c.connector.ResetSQL) == 0 {
return nil
}

s, err := c.prepareContext(ctx, c.connector.ResetSQL)
if err != nil {
return driver.ErrBadConn
}
_, err = s.exec(ctx, nil)
if err != nil {
return driver.ErrBadConn
}

return nil
}

// Connect to the server and return a TDS connection.
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
conn, err := c.driver.connect(ctx, c.params)
if conn != nil {
conn.connector = c
}
if err == nil {
err = conn.ResetSession(ctx)
}
return conn, err
}

// Driver underlying the Connector.
func (c *Connector) Driver() driver.Driver {
return c.driver
}
43 changes: 43 additions & 0 deletions queries_go110_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// +build go1.10

package mssql

import (
"context"
"database/sql"
"testing"
)

func TestResetSQL(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})

d := &Driver{}
connector, err := d.OpenConnector(makeConnStr(t).String())
if err != nil {
t.Fatal("unable to open connector", err)
}
connector.ResetSQL = `
SET XACT_ABORT ON; -- 16384
SET ANSI_NULLS ON; -- 32
SET ARITHIGNORE ON; -- 128
`

pool := sql.OpenDB(connector)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

var opt int32
err = pool.QueryRowContext(ctx, `
select Options = @@OPTIONS;
`).Scan(&opt)
if err != nil {
t.Fatal("failed to run query", err)
}
mask := int32(16384 | 128 | 32)

if opt&mask != mask {
t.Fatal("incorrect session settings", opt)
}
}

0 comments on commit 73efb3e

Please sign in to comment.