From 6a30f4e59a440ad6f213e0b3d45ecd3923d211b7 Mon Sep 17 00:00:00 2001 From: Daniel Theophanes Date: Thu, 15 Mar 2018 11:00:31 -0700 Subject: [PATCH] mssql: rename structs to prevent stutter. Use alias for compatibility for now. --- bulkcopy.go | 34 ++++++------ bulkcopy_sql.go | 12 ++--- bulkcopy_test.go | 2 +- examples/bulk/bulk.go | 2 +- mssql.go | 118 +++++++++++++++++++++--------------------- mssql_go18.go | 16 +++--- mssql_go19.go | 17 ++++-- mssql_go19pre.go | 2 +- queries_go18_test.go | 2 +- queries_test.go | 12 ++--- 10 files changed, 114 insertions(+), 103 deletions(-) diff --git a/bulkcopy.go b/bulkcopy.go index 3a2f7dff..8c0a4e0a 100644 --- a/bulkcopy.go +++ b/bulkcopy.go @@ -12,8 +12,8 @@ import ( "time" ) -type MssqlBulk struct { - cn *MssqlConn +type Bulk struct { + cn *Conn metadata []columnStruct bulkColumns []columnStruct columnsName []string @@ -21,10 +21,10 @@ type MssqlBulk struct { numRows int headerSent bool - Options MssqlBulkOptions + Options BulkOptions Debug bool } -type MssqlBulkOptions struct { +type BulkOptions struct { CheckConstraints bool FireTriggers bool KeepNulls bool @@ -36,13 +36,13 @@ type MssqlBulkOptions struct { type DataValue interface{} -func (cn *MssqlConn) CreateBulk(table string, columns []string) (_ *MssqlBulk) { - b := MssqlBulk{cn: cn, tablename: table, headerSent: false, columnsName: columns} +func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) { + b := Bulk{cn: cn, tablename: table, headerSent: false, columnsName: columns} b.Debug = false return &b } -func (b *MssqlBulk) sendBulkCommand() (err error) { +func (b *Bulk) sendBulkCommand() (err error) { //get table columns info err = b.getMetadata() if err != nil { @@ -139,7 +139,7 @@ func (b *MssqlBulk) sendBulkCommand() (err error) { // AddRow immediately writes the row to the destination table. // The arguments are the row values in the order they were specified. -func (b *MssqlBulk) AddRow(row []interface{}) (err error) { +func (b *Bulk) AddRow(row []interface{}) (err error) { if !b.headerSent { err = b.sendBulkCommand() if err != nil { @@ -166,7 +166,7 @@ func (b *MssqlBulk) AddRow(row []interface{}) (err error) { return } -func (b *MssqlBulk) makeRowData(row []interface{}) ([]byte, error) { +func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) { buf := new(bytes.Buffer) buf.WriteByte(byte(tokenRow)) @@ -196,7 +196,7 @@ func (b *MssqlBulk) makeRowData(row []interface{}) ([]byte, error) { return buf.Bytes(), nil } -func (b *MssqlBulk) Done() (rowcount int64, err error) { +func (b *Bulk) Done() (rowcount int64, err error) { if b.headerSent == false { //no rows had been sent return 0, nil @@ -235,7 +235,7 @@ func (b *MssqlBulk) Done() (rowcount int64, err error) { return rowCount, nil } -func (b *MssqlBulk) createColMetadata() []byte { +func (b *Bulk) createColMetadata() []byte { buf := new(bytes.Buffer) buf.WriteByte(byte(tokenColMetadata)) // token binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count @@ -267,7 +267,7 @@ func (b *MssqlBulk) createColMetadata() []byte { return buf.Bytes() } -func (b *MssqlBulk) getMetadata() (err error) { +func (b *Bulk) getMetadata() (err error) { stmt, err := b.cn.Prepare("SET FMTONLY ON") if err != nil { return @@ -283,7 +283,7 @@ func (b *MssqlBulk) getMetadata() (err error) { if err != nil { return } - stmt2 := stmt.(*MssqlStmt) + stmt2 := stmt.(*Stmt) cols, err := stmt2.QueryMeta() if err != nil { return fmt.Errorf("get columns info failed: %v", err.Error()) @@ -301,8 +301,8 @@ func (b *MssqlBulk) getMetadata() (err error) { return nil } -// QueryMeta is almost the same as MssqlStmt.Query, but returns all the columns info. -func (s *MssqlStmt) QueryMeta() (cols []columnStruct, err error) { +// QueryMeta is almost the same as mssql.Stmt.Query, but returns all the columns info. +func (s *Stmt) QueryMeta() (cols []columnStruct, err error) { if err = s.sendQuery(nil); err != nil { return } @@ -324,7 +324,7 @@ loop: return cols, nil } -func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err error) { +func (b *Bulk) makeParam(val DataValue, col columnStruct) (res Param, err error) { res.ti.Size = col.ti.Size res.ti.TypeId = col.ti.TypeId @@ -609,7 +609,7 @@ func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err e } -func (b *MssqlBulk) dlogf(format string, v ...interface{}) { +func (b *Bulk) dlogf(format string, v ...interface{}) { if b.Debug { b.cn.sess.log.Printf(format, v...) } diff --git a/bulkcopy_sql.go b/bulkcopy_sql.go index ea697978..0af51df8 100644 --- a/bulkcopy_sql.go +++ b/bulkcopy_sql.go @@ -8,22 +8,22 @@ import ( ) type copyin struct { - cn *MssqlConn - bulkcopy *MssqlBulk + cn *Conn + bulkcopy *Bulk closed bool } type serializableBulkConfig struct { TableName string ColumnsName []string - Options MssqlBulkOptions + Options BulkOptions } -func (d *MssqlDriver) OpenConnection(dsn string) (*MssqlConn, error) { +func (d *Driver) OpenConnection(dsn string) (*Conn, error) { return d.open(context.Background(), dsn) } -func (c *MssqlConn) prepareCopyIn(query string) (_ driver.Stmt, err error) { +func (c *Conn) prepareCopyIn(query string) (_ driver.Stmt, err error) { config_json := query[11:] bulkconfig := serializableBulkConfig{} @@ -43,7 +43,7 @@ func (c *MssqlConn) prepareCopyIn(query string) (_ driver.Stmt, err error) { return ci, nil } -func CopyIn(table string, options MssqlBulkOptions, columns ...string) string { +func CopyIn(table string, options BulkOptions, columns ...string) string { bulkconfig := &serializableBulkConfig{TableName: table, Options: options, ColumnsName: columns} config_json, err := json.Marshal(bulkconfig) diff --git a/bulkcopy_test.go b/bulkcopy_test.go index ad9ba34e..5f6cd549 100644 --- a/bulkcopy_test.go +++ b/bulkcopy_test.go @@ -82,7 +82,7 @@ func TestBulkcopy(t *testing.T) { log.Println("Preparing copyin statement") - stmt, err := conn.Prepare(CopyIn(tableName, MssqlBulkOptions{}, columns...)) + stmt, err := conn.Prepare(CopyIn(tableName, BulkOptions{}, columns...)) for i := 0; i < 10; i++ { log.Printf("Executing copy in statement %d time with %d values", i+1, len(values)) diff --git a/examples/bulk/bulk.go b/examples/bulk/bulk.go index 33851c35..8ce12388 100644 --- a/examples/bulk/bulk.go +++ b/examples/bulk/bulk.go @@ -60,7 +60,7 @@ func main() { log.Fatal(err) } - stmt, err := txn.Prepare(mssql.CopyIn("test_table", mssql.MssqlBulkOptions{}, "test_varchar", "test_nvarchar", "test_float", "test_bigint")) + stmt, err := txn.Prepare(mssql.CopyIn("test_table", mssql.BulkOptions{}, "test_varchar", "test_nvarchar", "test_float", "test_bigint")) if err != nil { log.Fatal(err.Error()) } diff --git a/mssql.go b/mssql.go index 74dbe68d..8f5ff2d0 100644 --- a/mssql.go +++ b/mssql.go @@ -15,8 +15,8 @@ import ( "time" ) -var driverInstance = &MssqlDriver{processQueryText: true} -var driverInstanceNoProcess = &MssqlDriver{processQueryText: false} +var driverInstance = &Driver{processQueryText: true} +var driverInstanceNoProcess = &Driver{processQueryText: false} func init() { sql.Register("mssql", driverInstance) @@ -41,14 +41,14 @@ func (d tcpDialer) Dial(ctx context.Context, addr string) (net.Conn, error) { return d.nd.DialContext(ctx, "tcp", addr) } -type MssqlDriver struct { +type Driver struct { log optionalLogger processQueryText bool } // OpenConnector opens a new connector. Useful to dial with a context. -func (d *MssqlDriver) OpenConnector(dsn string) (*Connector, error) { +func (d *Driver) OpenConnector(dsn string) (*Connector, error) { params, err := parseConnectParams(dsn) if err != nil { return nil, err @@ -59,7 +59,7 @@ func (d *MssqlDriver) OpenConnector(dsn string) (*Connector, error) { }, nil } -func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) { +func (d *Driver) Open(dsn string) (driver.Conn, error) { return d.open(context.Background(), dsn) } @@ -70,7 +70,7 @@ func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) { // may be set directly on the connector. type Connector struct { params connectParams - driver *MssqlDriver + driver *Driver } // Connect to the server and return a TDS connection. @@ -88,11 +88,11 @@ func SetLogger(logger Logger) { driverInstanceNoProcess.SetLogger(logger) } -func (d *MssqlDriver) SetLogger(logger Logger) { +func (d *Driver) SetLogger(logger Logger) { d.log = optionalLogger{logger} } -type MssqlConn struct { +type Conn struct { sess *tdsSession transactionCtx context.Context @@ -102,7 +102,7 @@ type MssqlConn struct { outs map[string]interface{} } -func (c *MssqlConn) checkBadConn(err error) error { +func (c *Conn) checkBadConn(err error) error { // this is a hack to address Issue #275 // we set connectionGood flag to false if // error indicates that connection is not usable @@ -121,7 +121,7 @@ func (c *MssqlConn) checkBadConn(err error) error { case driver.ErrBadConn: // It is an internal programming error if driver.ErrBadConn // is ever passed to this function. driver.ErrBadConn should - // only ever be returned in response to a *MssqlConn.connectionGood == false + // only ever be returned in response to a *mssql.Conn.connectionGood == false // check in the external facing API. panic("driver.ErrBadConn in checkBadConn. This should not happen.") } @@ -138,11 +138,11 @@ func (c *MssqlConn) checkBadConn(err error) error { } } -func (c *MssqlConn) clearOuts() { +func (c *Conn) clearOuts() { c.outs = nil } -func (c *MssqlConn) simpleProcessResp(ctx context.Context) error { +func (c *Conn) simpleProcessResp(ctx context.Context) error { tokchan := make(chan tokenStruct, 5) go processResponse(ctx, c.sess, tokchan, c.outs) c.clearOuts() @@ -159,7 +159,7 @@ func (c *MssqlConn) simpleProcessResp(ctx context.Context) error { return nil } -func (c *MssqlConn) Commit() error { +func (c *Conn) Commit() error { if !c.connectionGood { return driver.ErrBadConn } @@ -169,7 +169,7 @@ func (c *MssqlConn) Commit() error { return c.simpleProcessResp(c.transactionCtx) } -func (c *MssqlConn) sendCommitRequest() error { +func (c *Conn) sendCommitRequest() error { headers := []headerStruct{ {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{c.sess.tranid, 1}.pack()}, @@ -184,7 +184,7 @@ func (c *MssqlConn) sendCommitRequest() error { return nil } -func (c *MssqlConn) Rollback() error { +func (c *Conn) Rollback() error { if !c.connectionGood { return driver.ErrBadConn } @@ -194,7 +194,7 @@ func (c *MssqlConn) Rollback() error { return c.simpleProcessResp(c.transactionCtx) } -func (c *MssqlConn) sendRollbackRequest() error { +func (c *Conn) sendRollbackRequest() error { headers := []headerStruct{ {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{c.sess.tranid, 1}.pack()}, @@ -209,11 +209,11 @@ func (c *MssqlConn) sendRollbackRequest() error { return nil } -func (c *MssqlConn) Begin() (driver.Tx, error) { +func (c *Conn) Begin() (driver.Tx, error) { return c.begin(context.Background(), isolationUseCurrent) } -func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) { +func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) { if !c.connectionGood { return nil, driver.ErrBadConn } @@ -228,7 +228,7 @@ func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver return } -func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error { +func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error { c.transactionCtx = ctx headers := []headerStruct{ {hdrtype: dataStmHdrTransDescr, @@ -244,7 +244,7 @@ func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) return nil } -func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error) { +func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) { if err := c.simpleProcessResp(ctx); err != nil { return nil, err } @@ -253,7 +253,7 @@ func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error) return c, nil } -func (d *MssqlDriver) open(ctx context.Context, dsn string) (*MssqlConn, error) { +func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) { params, err := parseConnectParams(dsn) if err != nil { return nil, err @@ -262,7 +262,7 @@ func (d *MssqlDriver) open(ctx context.Context, dsn string) (*MssqlConn, error) } // connect to the server, using the provided context for dialing only. -func (d *MssqlDriver) connect(ctx context.Context, params connectParams) (*MssqlConn, error) { +func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, error) { sess, err := connect(ctx, d.log, params) if err != nil { // main server failed, try fail-over partner @@ -282,7 +282,7 @@ func (d *MssqlDriver) connect(ctx context.Context, params connectParams) (*Mssql } } - conn := &MssqlConn{ + conn := &Conn{ sess: sess, transactionCtx: context.Background(), processQueryText: d.processQueryText, @@ -292,12 +292,12 @@ func (d *MssqlDriver) connect(ctx context.Context, params connectParams) (*Mssql return conn, nil } -func (c *MssqlConn) Close() error { +func (c *Conn) Close() error { return c.sess.buf.transport.Close() } -type MssqlStmt struct { - c *MssqlConn +type Stmt struct { + c *Conn query string paramCount int notifSub *queryNotifSub @@ -309,7 +309,7 @@ type queryNotifSub struct { timeout uint32 } -func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) { +func (c *Conn) Prepare(query string) (driver.Stmt, error) { if !c.connectionGood { return nil, driver.ErrBadConn } @@ -320,19 +320,19 @@ func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) { return c.prepareContext(context.Background(), query) } -func (c *MssqlConn) prepareContext(ctx context.Context, query string) (*MssqlStmt, error) { +func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) { paramCount := -1 if c.processQueryText { query, paramCount = parseParams(query) } - return &MssqlStmt{c, query, paramCount, nil}, nil + return &Stmt{c, query, paramCount, nil}, nil } -func (s *MssqlStmt) Close() error { +func (s *Stmt) Close() error { return nil } -func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Duration) { +func (s *Stmt) SetQueryNotification(id, options string, timeout time.Duration) { to := uint32(timeout / time.Second) if to < 1 { to = 1 @@ -340,11 +340,11 @@ func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Durati s.notifSub = &queryNotifSub{id, options, to} } -func (s *MssqlStmt) NumInput() int { +func (s *Stmt) NumInput() int { return s.paramCount } -func (s *MssqlStmt) sendQuery(args []namedValue) (err error) { +func (s *Stmt) sendQuery(args []namedValue) (err error) { headers := []headerStruct{ {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{s.c.sess.tranid, 1}.pack()}, @@ -422,7 +422,7 @@ func isProc(s string) bool { return !strings.ContainsAny(s, " \t\n\r;") } -func (s *MssqlStmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string, error) { +func (s *Stmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string, error) { var err error params := make([]Param, len(args)+offset) decls := make([]string, len(args)) @@ -460,11 +460,11 @@ func convertOldArgs(args []driver.Value) []namedValue { return list } -func (s *MssqlStmt) Query(args []driver.Value) (driver.Rows, error) { +func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { return s.queryContext(context.Background(), convertOldArgs(args)) } -func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) { +func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) { if !s.c.connectionGood { return nil, driver.ErrBadConn } @@ -474,7 +474,7 @@ func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (rows d return s.processQueryResponse(ctx) } -func (s *MssqlStmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) { +func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) { tokchan := make(chan tokenStruct, 5) ctx, cancel := context.WithCancel(ctx) go processResponse(ctx, s.c.sess, tokchan, s.c.outs) @@ -502,15 +502,15 @@ loop: return nil, s.c.checkBadConn(token) } } - res = &MssqlRows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel} + res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel} return } -func (s *MssqlStmt) Exec(args []driver.Value) (driver.Result, error) { +func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) { return s.exec(context.Background(), convertOldArgs(args)) } -func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) { +func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) { if !s.c.connectionGood { return nil, driver.ErrBadConn } @@ -523,7 +523,7 @@ func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (res driver.Res return } -func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) { +func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) { tokchan := make(chan tokenStruct, 5) go processResponse(ctx, s.c.sess, tokchan, s.c.outs) s.c.clearOuts() @@ -545,11 +545,11 @@ func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err err return nil, token } } - return &MssqlResult{s.c, rowCount}, nil + return &Result{s.c, rowCount}, nil } -type MssqlRows struct { - stmt *MssqlStmt +type Rows struct { + stmt *Stmt cols []columnStruct tokchan chan tokenStruct @@ -558,7 +558,7 @@ type MssqlRows struct { cancel func() } -func (rc *MssqlRows) Close() error { +func (rc *Rows) Close() error { rc.cancel() for _ = range rc.tokchan { } @@ -566,7 +566,7 @@ func (rc *MssqlRows) Close() error { return nil } -func (rc *MssqlRows) Columns() (res []string) { +func (rc *Rows) Columns() (res []string) { res = make([]string, len(rc.cols)) for i, col := range rc.cols { res[i] = col.ColName @@ -574,7 +574,7 @@ func (rc *MssqlRows) Columns() (res []string) { return } -func (rc *MssqlRows) Next(dest []driver.Value) error { +func (rc *Rows) Next(dest []driver.Value) error { if !rc.stmt.c.connectionGood { return driver.ErrBadConn } @@ -602,11 +602,11 @@ func (rc *MssqlRows) Next(dest []driver.Value) error { return io.EOF } -func (rc *MssqlRows) HasNextResultSet() bool { +func (rc *Rows) HasNextResultSet() bool { return rc.nextCols != nil } -func (rc *MssqlRows) NextResultSet() error { +func (rc *Rows) NextResultSet() error { rc.cols = rc.nextCols rc.nextCols = nil if rc.cols == nil { @@ -618,7 +618,7 @@ func (rc *MssqlRows) NextResultSet() error { // It should return // the value type that can be used to scan types into. For example, the database // column type "bigint" this should return "reflect.TypeOf(int64(0))". -func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type { +func (r *Rows) ColumnTypeScanType(index int) reflect.Type { return makeGoLangScanType(r.cols[index].ti) } @@ -627,7 +627,7 @@ func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type { // Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", // "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML", // "TIMESTAMP". -func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string { +func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { return makeGoLangTypeName(r.cols[index].ti) } @@ -642,7 +642,7 @@ func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string { // decimal (0, false) // int (0, false) // bytea(30) (30, true) -func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) { +func (r *Rows) ColumnTypeLength(index int) (int64, bool) { return makeGoLangTypeLength(r.cols[index].ti) } @@ -652,7 +652,7 @@ func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) { // decimal(38, 4) (38, 4, true) // int (0, 0, false) // decimal (math.MaxInt64, math.MaxInt64, true) -func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) { +func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) { return makeGoLangTypePrecisionScale(r.cols[index].ti) } @@ -660,7 +660,7 @@ func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) { // be true if it is known the column may be null, or false if the column is known // to be not nullable. // If the column nullability is unknown, ok should be false. -func (r *MssqlRows) ColumnTypeNullable(index int) (nullable, ok bool) { +func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) { nullable = r.cols[index].Flags&colFlagNullable != 0 ok = true return @@ -673,7 +673,7 @@ func makeStrParam(val string) (res Param) { return } -func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) { +func (s *Stmt) makeParam(val driver.Value) (res Param, err error) { if val == nil { res.ti.TypeId = typeNull res.buffer = nil @@ -742,16 +742,16 @@ func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) { return } -type MssqlResult struct { - c *MssqlConn +type Result struct { + c *Conn rowsAffected int64 } -func (r *MssqlResult) RowsAffected() (int64, error) { +func (r *Result) RowsAffected() (int64, error) { return r.rowsAffected, nil } -func (r *MssqlResult) LastInsertId() (int64, error) { +func (r *Result) LastInsertId() (int64, error) { s, err := r.c.Prepare("select cast(@@identity as bigint)") if err != nil { return 0, err diff --git a/mssql_go18.go b/mssql_go18.go index 9eaeb167..74179c47 100644 --- a/mssql_go18.go +++ b/mssql_go18.go @@ -10,22 +10,22 @@ import ( "strings" ) -var _ driver.Pinger = &MssqlConn{} +var _ driver.Pinger = &Conn{} // Ping is used to check if the remote server is available and satisfies the Pinger interface. -func (c *MssqlConn) Ping(ctx context.Context) error { +func (c *Conn) Ping(ctx context.Context) error { if !c.connectionGood { return driver.ErrBadConn } - stmt := &MssqlStmt{c, `select 1;`, 0, nil} + stmt := &Stmt{c, `select 1;`, 0, nil} _, err := stmt.ExecContext(ctx, nil) return err } -var _ driver.ConnBeginTx = &MssqlConn{} +var _ driver.ConnBeginTx = &Conn{} // BeginTx satisfies ConnBeginTx. -func (c *MssqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { +func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if !c.connectionGood { return nil, driver.ErrBadConn } @@ -57,7 +57,7 @@ func (c *MssqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver. return c.begin(ctx, tdsIsolation) } -func (c *MssqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { +func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if !c.connectionGood { return nil, driver.ErrBadConn } @@ -68,7 +68,7 @@ func (c *MssqlConn) PrepareContext(ctx context.Context, query string) (driver.St return c.prepareContext(ctx, query) } -func (s *MssqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { +func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { if !s.c.connectionGood { return nil, driver.ErrBadConn } @@ -79,7 +79,7 @@ func (s *MssqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) return s.queryContext(ctx, list) } -func (s *MssqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { +func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { if !s.c.connectionGood { return nil, driver.ErrBadConn } diff --git a/mssql_go19.go b/mssql_go19.go index 5e8432b4..250151ab 100644 --- a/mssql_go19.go +++ b/mssql_go19.go @@ -9,9 +9,20 @@ import ( // "github.com/cockroachdb/apd" ) -var _ driver.NamedValueChecker = &MssqlConn{} +// Type alias provided for compibility. +// +// Deprecated: users should transition to the new names when possible. +type MssqlDriver = Driver +type MssqlBulk = Bulk +type MssqlBulkOptions = BulkOptions +type MssqlConn = Conn +type MssqlResult = Result +type MssqlRows = Rows +type MssqlStmt = Stmt -func (c *MssqlConn) CheckNamedValue(nv *driver.NamedValue) error { +var _ driver.NamedValueChecker = &Conn{} + +func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error { switch v := nv.Value.(type) { case sql.Out: if c.outs == nil { @@ -41,7 +52,7 @@ func (c *MssqlConn) CheckNamedValue(nv *driver.NamedValue) error { } } -func (s *MssqlStmt) makeParamExtra(val driver.Value) (res Param, err error) { +func (s *Stmt) makeParamExtra(val driver.Value) (res Param, err error) { switch val := val.(type) { case sql.Out: res, err = s.makeParam(val.Dest) diff --git a/mssql_go19pre.go b/mssql_go19pre.go index 27cce0bd..3bad7fb0 100644 --- a/mssql_go19pre.go +++ b/mssql_go19pre.go @@ -7,6 +7,6 @@ import ( "fmt" ) -func (s *MssqlStmt) makeParamExtra(val driver.Value) (Param, error) { +func (s *Stmt) makeParamExtra(val driver.Value) (Param, error) { return Param{}, fmt.Errorf("mssql: unknown type for %T", val) } diff --git a/queries_go18_test.go b/queries_go18_test.go index ef17e8e6..9010775d 100644 --- a/queries_go18_test.go +++ b/queries_go18_test.go @@ -288,7 +288,7 @@ func TestBeginTxtReadOnlyNotSupported(t *testing.T) { } } -func TestMssqlConn_BeginTx(t *testing.T) { +func TestConn_BeginTx(t *testing.T) { conn := open(t) defer conn.Close() _, err := conn.Exec("create table test (f int)") diff --git a/queries_test.go b/queries_test.go index 5d38a89d..941ec176 100644 --- a/queries_test.go +++ b/queries_test.go @@ -13,14 +13,14 @@ import ( "time" ) -func driverWithProcess(t *testing.T) *MssqlDriver { - return &MssqlDriver{ +func driverWithProcess(t *testing.T) *Driver { + return &Driver{ log: optionalLogger{testLogger{t}}, processQueryText: true, } } -func driverNoProcess(t *testing.T) *MssqlDriver { - return &MssqlDriver{ +func driverNoProcess(t *testing.T) *Driver { + return &Driver{ log: optionalLogger{testLogger{t}}, processQueryText: false, } @@ -809,7 +809,7 @@ func TestIgnoreEmptyResults(t *testing.T) { } } -func TestMssqlStmt_SetQueryNotification(t *testing.T) { +func TestStmt_SetQueryNotification(t *testing.T) { checkConnStr(t) mssqldriver := driverWithProcess(t) cn, err := mssqldriver.Open(makeConnStr(t).String()) @@ -821,7 +821,7 @@ func TestMssqlStmt_SetQueryNotification(t *testing.T) { t.Error("Connection failed", err) } - sqlstmt := stmt.(*MssqlStmt) + sqlstmt := stmt.(*Stmt) sqlstmt.SetQueryNotification("ABC", "service=WebCacheNotifications", time.Hour) rows, err := sqlstmt.Query(nil)