diff --git a/appveyor.yml b/appveyor.yml index 2c143550..471cf477 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -41,5 +41,5 @@ before_test: test_script: - - go test -race -coverprofile=coverage.txt -covermode=atomic + - go test -race -cpu 4 -coverprofile=coverage.txt -covermode=atomic - codecov -f coverage.txt diff --git a/bulkcopy.go b/bulkcopy.go index 235d81ef..14c6b442 100644 --- a/bulkcopy.go +++ b/bulkcopy.go @@ -13,6 +13,12 @@ import ( ) type Bulk struct { + // ctx is used only for AddRow and Done methods. + // This could be removed if AddRow and Done accepted + // a ctx field as well, which is available with the + // database/sql call. + ctx context.Context + cn *Conn metadata []columnStruct bulkColumns []columnStruct @@ -37,14 +43,20 @@ type BulkOptions struct { type DataValue interface{} func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) { - b := Bulk{cn: cn, tablename: table, headerSent: false, columnsName: columns} + b := Bulk{ctx: context.Background(), cn: cn, tablename: table, headerSent: false, columnsName: columns} b.Debug = false return &b } -func (b *Bulk) sendBulkCommand() (err error) { +func (cn *Conn) CreateBulkContext(ctx context.Context, table string, columns []string) (_ *Bulk) { + b := Bulk{ctx: ctx, cn: cn, tablename: table, headerSent: false, columnsName: columns} + b.Debug = false + return &b +} + +func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) { //get table columns info - err = b.getMetadata() + err = b.getMetadata(ctx) if err != nil { return err } @@ -114,13 +126,13 @@ func (b *Bulk) sendBulkCommand() (err error) { query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part) - stmt, err := b.cn.Prepare(query) + stmt, err := b.cn.PrepareContext(ctx, query) if err != nil { return fmt.Errorf("Prepare failed: %s", err.Error()) } b.dlogf(query) - _, err = stmt.Exec(nil) + _, err = stmt.(*Stmt).ExecContext(ctx, nil) if err != nil { return err } @@ -130,7 +142,7 @@ func (b *Bulk) sendBulkCommand() (err error) { var buf = b.cn.sess.buf buf.BeginPacket(packBulkLoadBCP, false) - // send the columns metadata + // Send the columns metadata. columnMetadata := b.createColMetadata() _, err = buf.Write(columnMetadata) @@ -141,7 +153,7 @@ func (b *Bulk) sendBulkCommand() (err error) { // The arguments are the row values in the order they were specified. func (b *Bulk) AddRow(row []interface{}) (err error) { if !b.headerSent { - err = b.sendBulkCommand() + err = b.sendBulkCommand(b.ctx) if err != nil { return } @@ -216,7 +228,7 @@ func (b *Bulk) Done() (rowcount int64, err error) { buf.FinishPacket() tokchan := make(chan tokenStruct, 5) - go processResponse(context.Background(), b.cn.sess, tokchan, nil) + go processResponse(b.ctx, b.cn.sess, tokchan, nil) var rowCount int64 for token := range tokchan { @@ -267,28 +279,27 @@ func (b *Bulk) createColMetadata() []byte { return buf.Bytes() } -func (b *Bulk) getMetadata() (err error) { - stmt, err := b.cn.Prepare("SET FMTONLY ON") +func (b *Bulk) getMetadata(ctx context.Context) (err error) { + stmt, err := b.cn.prepareContext(ctx, "SET FMTONLY ON") if err != nil { return } - _, err = stmt.Exec(nil) + _, err = stmt.ExecContext(ctx, nil) if err != nil { return } - //get columns info - stmt, err = b.cn.Prepare(fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename)) + // Get columns info. + stmt, err = b.cn.prepareContext(ctx, fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename)) if err != nil { return } - stmt2 := stmt.(*Stmt) - cols, err := stmt2.QueryMeta() + rows, err := stmt.QueryContext(ctx, nil) if err != nil { - return fmt.Errorf("get columns info failed: %v", err.Error()) + return fmt.Errorf("get columns info failed: %v", err) } - b.metadata = cols + b.metadata = rows.(*Rows).cols if b.Debug { for _, col := range b.metadata { @@ -298,30 +309,7 @@ func (b *Bulk) getMetadata() (err error) { } } - return nil -} - -// QueryMeta is almost the same as mssql.Stmt.Query, but returns all the columns info. -func (s *Stmt) QueryMeta() (cols []columnStruct, err error) { - if err = s.sendQuery(nil); err != nil { - return - } - tokchan := make(chan tokenStruct, 5) - go processResponse(context.Background(), s.c.sess, tokchan, s.c.outs) - s.c.clearOuts() -loop: - for tok := range tokchan { - switch token := tok.(type) { - case doneStruct: - break loop - case []columnStruct: - cols = token - break loop - case error: - return nil, s.c.checkBadConn(token) - } - } - return cols, nil + return rows.Close() } func (b *Bulk) makeParam(val DataValue, col columnStruct) (res Param, err error) { diff --git a/bulkcopy_sql.go b/bulkcopy_sql.go index 0af51df8..4824df9a 100644 --- a/bulkcopy_sql.go +++ b/bulkcopy_sql.go @@ -23,7 +23,7 @@ func (d *Driver) OpenConnection(dsn string) (*Conn, error) { return d.open(context.Background(), dsn) } -func (c *Conn) prepareCopyIn(query string) (_ driver.Stmt, err error) { +func (c *Conn) prepareCopyIn(ctx context.Context, query string) (_ driver.Stmt, err error) { config_json := query[11:] bulkconfig := serializableBulkConfig{} @@ -32,7 +32,7 @@ func (c *Conn) prepareCopyIn(query string) (_ driver.Stmt, err error) { return } - bulkcopy := c.CreateBulk(bulkconfig.TableName, bulkconfig.ColumnsName) + bulkcopy := c.CreateBulkContext(ctx, bulkconfig.TableName, bulkconfig.ColumnsName) bulkcopy.Options = bulkconfig.Options ci := ©in{ @@ -61,7 +61,7 @@ func (ci *copyin) NumInput() int { } func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) { - return nil, errors.New("ErrNotSupported") + panic("should never be called") } func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { diff --git a/bulkcopy_test.go b/bulkcopy_test.go index 4f173de2..38d908ba 100644 --- a/bulkcopy_test.go +++ b/bulkcopy_test.go @@ -95,7 +95,7 @@ func TestBulkcopy(t *testing.T) { return } - t.Log("Preparing copyin statement") + t.Log("Preparing copy in statement") stmt, err := conn.PrepareContext(ctx, CopyIn(tableName, BulkOptions{}, columns...)) diff --git a/mssql.go b/mssql.go index ed52625c..50c57e01 100644 --- a/mssql.go +++ b/mssql.go @@ -333,9 +333,8 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) { return nil, driver.ErrBadConn } if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") { - return c.prepareCopyIn(query) + return c.prepareCopyIn(context.Background(), query) } - return c.prepareContext(context.Background(), query) } diff --git a/mssql_go18.go b/mssql_go18.go index 74179c47..082a65ed 100644 --- a/mssql_go18.go +++ b/mssql_go18.go @@ -62,7 +62,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e return nil, driver.ErrBadConn } if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") { - return c.prepareCopyIn(query) + return c.prepareCopyIn(ctx, query) } return c.prepareContext(ctx, query)