From 768b5e3bcc7414431d14a730dea7c3279ac9dbc5 Mon Sep 17 00:00:00 2001 From: "Rose, William" Date: Wed, 16 Dec 2020 11:06:53 -0800 Subject: [PATCH] Add connection logging to help with debugging --- conn_str.go | 7 +++ conn_str_test.go | 5 +- log_conn.go | 80 +++++++++++++++++++++++++++++++ log_conn_test.go | 121 +++++++++++++++++++++++++++++++++++++++++++++++ tds.go | 20 +++++++- 5 files changed, 230 insertions(+), 3 deletions(-) create mode 100644 log_conn.go create mode 100644 log_conn_test.go diff --git a/conn_str.go b/conn_str.go index d7d9e06a..decb7632 100644 --- a/conn_str.go +++ b/conn_str.go @@ -1,6 +1,7 @@ package mssql import ( + "errors" "fmt" "net" "net/url" @@ -39,6 +40,7 @@ type connectParams struct { packetSize uint16 fedAuthLibrary int fedAuthADALWorkflow byte + tlsKeyLogFile string } // default packet size for TDS buffer @@ -235,6 +237,11 @@ func parseConnectParams(dsn string) (connectParams, error) { } } + p.tlsKeyLogFile, ok = params["tls key log file"] + if ok && p.tlsKeyLogFile != "" && p.disableEncryption { + return p, errors.New("Cannot set tlsKeyLogFile when encryption is disabled") + } + return p, nil } diff --git a/conn_str_test.go b/conn_str_test.go index bb6e2682..a7f953a3 100644 --- a/conn_str_test.go +++ b/conn_str_test.go @@ -67,6 +67,7 @@ func TestValidConnectionString(t *testing.T) { {"trustservercertificate=false", func(p connectParams) bool { return !p.trustServerCertificate }}, {"certificate=abc", func(p connectParams) bool { return p.certificate == "abc" }}, {"hostnameincertificate=abc", func(p connectParams) bool { return p.hostInCertificate == "abc" }}, + {"tls key log file=tls.log", func(p connectParams) bool { return p.tlsKeyLogFile == "tls.log" }}, {"connection timeout=3;dial timeout=4;keepalive=5", func(p connectParams) bool { return p.conn_timeout == 3*time.Second && p.dial_timeout == 4*time.Second && p.keepAlive == 5*time.Second }}, @@ -186,10 +187,10 @@ func testConnParams(t testing.TB) connectParams { } if len(os.Getenv("HOST")) > 0 && len(os.Getenv("DATABASE")) > 0 { return connectParams{ - host: os.Getenv("HOST"), + host: os.Getenv("HOST"), instance: os.Getenv("INSTANCE"), database: os.Getenv("DATABASE"), - user: os.Getenv("SQLUSER"), + user: os.Getenv("SQLUSER"), password: os.Getenv("SQLPASSWORD"), logFlags: logFlags, } diff --git a/log_conn.go b/log_conn.go new file mode 100644 index 00000000..4777e4c1 --- /dev/null +++ b/log_conn.go @@ -0,0 +1,80 @@ +package mssql + +import ( + "encoding/hex" + "net" + "strings" + "time" +) + +type connLogger struct { + conn net.Conn + readKind, writeKind string + readCount, writeCount int + logger Logger +} + +var _ net.Conn = &connLogger{} + +func newConnLogger(conn net.Conn, kind string, logger Logger) net.Conn { + if len(kind) > 0 && !strings.HasPrefix(kind, " ") { + kind = " " + kind + } + + cl := &connLogger{ + conn: conn, + readKind: "R" + kind, + writeKind: "W" + kind, + logger: logger, + } + + return cl +} + +func (cl *connLogger) Read(p []byte) (n int, err error) { + n, err = cl.conn.Read(p) + + if n > 0 { + dump := hex.Dump(p) + cl.logger.Printf("%s %d\n%s", cl.readKind, cl.readCount, dump) + cl.readCount += n + } + + return +} + +func (cl *connLogger) Write(p []byte) (n int, err error) { + n, err = cl.conn.Write(p) + + if n > 0 { + dump := hex.Dump(p) + cl.logger.Printf("%s %d\n%s", cl.writeKind, cl.writeCount, dump) + cl.writeCount += n + } + + return +} + +func (cl *connLogger) Close() (err error) { + return cl.conn.Close() +} + +func (cl *connLogger) LocalAddr() net.Addr { + return cl.conn.LocalAddr() +} + +func (cl *connLogger) RemoteAddr() net.Addr { + return cl.conn.RemoteAddr() +} + +func (cl *connLogger) SetDeadline(t time.Time) error { + return cl.conn.SetDeadline(t) +} + +func (cl *connLogger) SetReadDeadline(t time.Time) error { + return cl.conn.SetReadDeadline(t) +} + +func (cl *connLogger) SetWriteDeadline(t time.Time) error { + return cl.conn.SetWriteDeadline(t) +} diff --git a/log_conn_test.go b/log_conn_test.go new file mode 100644 index 00000000..2e4b91d5 --- /dev/null +++ b/log_conn_test.go @@ -0,0 +1,121 @@ +package mssql + +import ( + "net" + "sync/atomic" + "testing" + "time" +) + +func TestConnLoggerOperations(t *testing.T) { + clt := &connLoggerTest{} + cl := newConnLogger(clt, "test", nullLogger{}) + packet := append(make([]byte, 0, 10), 1, 2, 3, 4, 5) + n, err := cl.Read(packet) + if n != 10 || err != nil { + t.Error("Unexpected return value from call to Read()") + } + + n, err = cl.Write(packet) + if n != 5 || err != nil { + t.Error("Unexpected return value from call to Write()") + } + + if cl.Close() != nil { + t.Error("Unexpected return value from call to Close()") + } + + if cl.LocalAddr() == nil { + t.Error("Unexpected return value from call to LocalAddr()") + } + + if cl.RemoteAddr() == nil { + t.Error("Unexpected return value from call to RemoteAddr()") + } + + if cl.SetDeadline(time.Now()) != nil { + t.Error("Unexpected return value from call to SetDeadline()") + } + + if cl.SetReadDeadline(time.Now()) != nil { + t.Error("Unexpected return value from call to SetReadDeadline()") + } + + if cl.SetWriteDeadline(time.Now()) != nil { + t.Error("Unexpected return value from call to SetWriteDeadline()") + } + + if atomic.LoadInt32(&clt.calls) != 8 { + t.Error("Unexpected number of calls recorded") + } +} + +type connLoggerTest struct { + calls int32 +} + +var _ net.Conn = &connLoggerTest{} + +type addressTest struct { +} + +var _ net.Addr = &addressTest{} + +type nullLogger struct { +} + +var _ Logger = nullLogger{} + +func (n nullLogger) Printf(format string, v ...interface{}) { +} + +func (n nullLogger) Println(v ...interface{}) { +} + +func (a *addressTest) Network() string { + return "test" +} + +func (a *addressTest) String() string { + return "test" +} + +func (cl *connLoggerTest) Read(p []byte) (int, error) { + atomic.AddInt32(&cl.calls, 1) + return cap(p), nil +} + +func (cl *connLoggerTest) Write(p []byte) (int, error) { + atomic.AddInt32(&cl.calls, 1) + return len(p), nil +} + +func (cl *connLoggerTest) Close() error { + atomic.AddInt32(&cl.calls, 1) + return nil +} + +func (cl *connLoggerTest) LocalAddr() net.Addr { + atomic.AddInt32(&cl.calls, 1) + return &addressTest{} +} + +func (cl *connLoggerTest) RemoteAddr() net.Addr { + atomic.AddInt32(&cl.calls, 1) + return &addressTest{} +} + +func (cl *connLoggerTest) SetDeadline(t time.Time) error { + atomic.AddInt32(&cl.calls, 1) + return nil +} + +func (cl *connLoggerTest) SetReadDeadline(t time.Time) error { + atomic.AddInt32(&cl.calls, 1) + return nil +} + +func (cl *connLoggerTest) SetWriteDeadline(t time.Time) error { + atomic.AddInt32(&cl.calls, 1) + return nil +} diff --git a/tds.go b/tds.go index e1b63300..a41a9ddf 100644 --- a/tds.go +++ b/tds.go @@ -10,6 +10,7 @@ import ( "io" "io/ioutil" "net" + "os" "sort" "strconv" "strings" @@ -152,6 +153,7 @@ const ( logParams = 16 logTransaction = 32 logDebug = 64 + logTraffic = 128 ) type columnStruct struct { @@ -1059,6 +1061,10 @@ initiate_connection: return nil, err } + if p.logFlags&logTraffic != 0 { + conn = newConnLogger(conn, "TCP", log) + } + toconn := newTimeoutConn(conn, p.conn_timeout) outbuf := newTdsBuffer(p.packetSize, toconn) @@ -1104,6 +1110,14 @@ initiate_connection: if p.trustServerCertificate { config.InsecureSkipVerify = true } + if p.tlsKeyLogFile != "" { + if w, err := os.OpenFile(p.tlsKeyLogFile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0600); err == nil { + defer w.Close() + config.KeyLogWriter = w + } else { + return nil, fmt.Errorf("Cannot open TLS key log file %s: %v", p.tlsKeyLogFile, err) + } + } config.ServerName = p.hostInCertificate // fix for https://github.com/denisenkom/go-mssqldb/issues/166 // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, @@ -1116,7 +1130,11 @@ initiate_connection: tlsConn := tls.Client(&passthrough, &config) err = tlsConn.Handshake() passthrough.c = toconn - outbuf.transport = tlsConn + if sess.logFlags&logTraffic != 0 { + outbuf.transport = newConnLogger(tlsConn, "TLS", log) + } else { + outbuf.transport = tlsConn + } if err != nil { return nil, fmt.Errorf("TLS Handshake failed: %v", err) }